mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
Compare commits
63 Commits
feat/add-m
...
tmp/wip_ds
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e5e1c97a6c | ||
|
|
1594ae60a7 | ||
|
|
7cd710857d | ||
|
|
5c9bfd57ec | ||
|
|
f2ff370459 | ||
|
|
25f60c301b | ||
|
|
0699b46d87 | ||
|
|
b8f7e401d4 | ||
|
|
656fc0f059 | ||
|
|
829d2d1ad9 | ||
|
|
4ccf28437a | ||
|
|
9a49e57c72 | ||
|
|
6c28ef894a | ||
|
|
bf3c8746b7 | ||
|
|
9f32e00f90 | ||
|
|
fcaa0ea5f9 | ||
|
|
5ac9356135 | ||
|
|
b74e2a6113 | ||
|
|
a4bed41132 | ||
|
|
5c8dd883be | ||
|
|
38f6fc816b | ||
|
|
abde7be3b3 | ||
|
|
b6c528a438 | ||
|
|
6d331310ab | ||
|
|
5dfdec9288 | ||
|
|
50977a2c28 | ||
|
|
a0d7627d81 | ||
|
|
1ad2da403d | ||
|
|
2d3a605b3c | ||
|
|
f173265354 | ||
|
|
bbcf66bd82 | ||
|
|
c378a325f0 | ||
|
|
90684a9690 | ||
|
|
f59eb54f5c | ||
|
|
62e9849ffd | ||
|
|
e3b572992e | ||
|
|
5b647e3bcb | ||
|
|
ddfff054bc | ||
|
|
49918efbc1 | ||
|
|
c5b5955c5a | ||
|
|
ec40ccde0d | ||
|
|
d2782cf66b | ||
|
|
9627765ce2 | ||
|
|
43d878a102 | ||
|
|
ddba994d73 | ||
|
|
a87d4c9a74 | ||
|
|
170c09e7f6 | ||
|
|
853cc70194 | ||
|
|
ec63225dc1 | ||
|
|
af1760f175 | ||
|
|
163df97c0c | ||
|
|
cdd2bf1c4e | ||
|
|
1cba47da20 | ||
|
|
7359e18eb6 | ||
|
|
13010647bc | ||
|
|
acbc14f60a | ||
|
|
2b59850f15 | ||
|
|
42e4b3d09e | ||
|
|
98bcda2d8b | ||
|
|
a4178f385b | ||
|
|
bd09b2153f | ||
|
|
1033680a57 | ||
|
|
7cf04a5ec3 |
4
.github/workflows/stale.yml
vendored
4
.github/workflows/stale.yml
vendored
@@ -31,11 +31,11 @@ env:
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
WARN_ISSUE_MESSAGE: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity (1 year). It will be closed if no further activity occurs.
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
Thank you for your contributions.
|
||||
WARN_PR_MESSAGE: >
|
||||
This PR has been automatically marked as stale because it has not had
|
||||
recent activity (1 year). It will be closed if no further activity occurs.
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
Thank you for your contributions.
|
||||
|
||||
jobs:
|
||||
|
||||
183
.github/workflows/unbound_deps_tests.yml
vendored
Normal file
183
.github/workflows/unbound_deps_tests.yml
vendored
Normal file
@@ -0,0 +1,183 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This workflow handles full testing with unboud dependencies versions.
|
||||
name: Unbound Dependency Tests
|
||||
|
||||
on:
|
||||
# Allows running this workflow manually from the Actions tab
|
||||
workflow_dispatch:
|
||||
|
||||
# Run on the 1st and 15th of every month at 09:00 UTC
|
||||
schedule:
|
||||
- cron: '0 2 1,15 * *'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
|
||||
|
||||
# Ensures that only the latest action is built, canceling older runs.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
|
||||
# This job runs the E2E tests + pytest with all unbound extras
|
||||
full-tests:
|
||||
name: Full Unbound Tests
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
|
||||
- name: Setup uv and Python
|
||||
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Unbound dependencies
|
||||
run: |
|
||||
sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml
|
||||
echo "Dependencies unbound:" && cat pyproject.toml
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv
|
||||
|
||||
- name: Run end-to-end tests
|
||||
run: uv run make test-end-to-end
|
||||
|
||||
# This job builds a GPU enabled image for testing
|
||||
build-and-push-docker:
|
||||
name: Build and Push Docker
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
outputs:
|
||||
image_tag: ${{ env.DOCKER_IMAGE_NAME }}
|
||||
env:
|
||||
GITHUB_REF: ${{ github.ref }}
|
||||
steps:
|
||||
- name: Install Git LFS
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
cache-binary: false
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/Dockerfile.internal
|
||||
push: true
|
||||
tags: ${{ env.DOCKER_IMAGE_NAME }}
|
||||
build-args: |
|
||||
UNBOUND_DEPS=true
|
||||
|
||||
# This job runs pytest with all unbound extras in a GPU enabled host
|
||||
# It runs everytime a test image is created
|
||||
gpu-tests:
|
||||
name: GPU Unbound Tests
|
||||
needs: [build-and-push-docker]
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
env:
|
||||
HF_HOME: /home/user_lerobot/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
container:
|
||||
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv
|
||||
- name: Run end-to-end tests
|
||||
run: make test-end-to-end
|
||||
|
||||
# This job deletes the test image recently created
|
||||
# It runs everytime after the gpu-tests have finished
|
||||
delete-unbound-image:
|
||||
name: Delete Unbound Image
|
||||
needs: [gpu-tests, build-and-push-docker]
|
||||
if: always() && needs.build-and-push-docker.result == 'success'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Get Docker Hub Token and Delete Image
|
||||
# zizmor: ignore[template-injection]
|
||||
run: |
|
||||
IMAGE_NAME=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f1)
|
||||
IMAGE_TAG=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f2)
|
||||
|
||||
echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG"
|
||||
|
||||
TOKEN=$(curl -s -H "Content-Type: application/json" \
|
||||
-X POST \
|
||||
-d '{"username": "${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}", "password": "${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}"}' \
|
||||
https://hub.docker.com/v2/users/login/ | jq -r .token)
|
||||
|
||||
if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then
|
||||
echo "::error::Failed to get Docker Hub token."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \
|
||||
-H "Authorization: JWT ${TOKEN}" \
|
||||
-X DELETE \
|
||||
https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/${IMAGE_TAG}/)
|
||||
|
||||
if [ "$HTTP_RESPONSE" -eq 204 ]; then
|
||||
echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG"
|
||||
else
|
||||
echo "::error::Failed to delete Docker image. HTTP status: $HTTP_RESPONSE"
|
||||
exit 1
|
||||
fi
|
||||
@@ -86,11 +86,12 @@ repos:
|
||||
|
||||
# TODO(Steven): Uncomment when ready to use
|
||||
##### Static Analysis & Typing #####
|
||||
# - repo: https://github.com/pre-commit/mirrors-mypy
|
||||
# rev: v1.16.0
|
||||
# hooks:
|
||||
# - id: mypy
|
||||
# args: [--python-version=3.10]
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.16.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: [--config-file=pyproject.toml]
|
||||
exclude: ^(examples|benchmarks|tests)/
|
||||
|
||||
##### Docstring Checks #####
|
||||
# - repo: https://github.com/akaihola/darglint2
|
||||
|
||||
@@ -197,7 +197,7 @@ wandb login
|
||||
|
||||
### Visualize datasets
|
||||
|
||||
Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
|
||||
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
|
||||
|
||||
|
||||
@@ -35,12 +35,13 @@ import torch
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from benchmarks.video.benchmark import TimeBenchmark
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import (
|
||||
decode_video_frames_torchvision,
|
||||
encode_video_frames,
|
||||
)
|
||||
from lerobot.utils.benchmark import TimeBenchmark
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
|
||||
BASE_ENCODING = OrderedDict(
|
||||
[
|
||||
@@ -117,7 +118,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# We only save images from the first camera
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||
|
||||
for i, item in enumerate(
|
||||
|
||||
@@ -75,6 +75,14 @@ RUN uv venv --python python${PYTHON_VERSION}
|
||||
# Install Python dependencies for caching
|
||||
COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
|
||||
ARG UNBOUND_DEPS=false
|
||||
|
||||
RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \
|
||||
echo "Dependencies unbound:" && cat pyproject.toml; \
|
||||
fi
|
||||
|
||||
RUN uv pip install --no-cache ".[all]"
|
||||
|
||||
# Copy the rest of the application source code
|
||||
|
||||
@@ -61,6 +61,14 @@ RUN uv venv
|
||||
# Install Python dependencies for caching
|
||||
COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
|
||||
ARG UNBOUND_DEPS=false
|
||||
|
||||
RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \
|
||||
echo "Dependencies unbound:" && cat pyproject.toml; \
|
||||
fi
|
||||
|
||||
RUN uv pip install --no-cache ".[all]"
|
||||
|
||||
# Copy the rest of the application code
|
||||
|
||||
@@ -25,14 +25,21 @@
|
||||
title: Using LeRobotDataset
|
||||
- local: porting_datasets_v3
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
title: "Datasets"
|
||||
- sections:
|
||||
- local: act
|
||||
title: ACT
|
||||
- local: smolvla
|
||||
title: Finetune SmolVLA
|
||||
title: SmolVLA
|
||||
- local: pi0
|
||||
title: π₀ (Pi0)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: libero
|
||||
title: Using Libero
|
||||
title: "Policies"
|
||||
|
||||
- sections:
|
||||
- local: introduction_processors
|
||||
title: Introduction to Robot Processors
|
||||
|
||||
92
docs/source/act.mdx
Normal file
92
docs/source/act.mdx
Normal file
@@ -0,0 +1,92 @@
|
||||
# ACT (Action Chunking with Transformers)
|
||||
|
||||
ACT is a **lightweight and efficient policy for imitation learning**, especially well-suited for fine-grained manipulation tasks. It's the **first model we recommend when you're starting out** with LeRobot due to its fast training time, low computational requirements, and strong performance.
|
||||
|
||||
<div class="video-container">
|
||||
<iframe
|
||||
width="100%"
|
||||
height="415"
|
||||
src="https://www.youtube.com/embed/ft73x0LfGpM"
|
||||
title="LeRobot ACT Tutorial"
|
||||
frameborder="0"
|
||||
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
|
||||
allowfullscreen
|
||||
></iframe>
|
||||
</div>
|
||||
|
||||
_Watch this tutorial from the LeRobot team to learn how ACT works: [LeRobot ACT Tutorial](https://www.youtube.com/watch?v=ft73x0LfGpM)_
|
||||
|
||||
## Model Overview
|
||||
|
||||
Action Chunking with Transformers (ACT) was introduced in the paper [Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware](https://arxiv.org/abs/2304.13705) by Zhao et al. The policy was designed to enable precise, contact-rich manipulation tasks using affordable hardware and minimal demonstration data.
|
||||
|
||||
### Why ACT is Great for Beginners
|
||||
|
||||
ACT stands out as an excellent starting point for several reasons:
|
||||
|
||||
- **Fast Training**: Trains in a few hours on a single GPU
|
||||
- **Lightweight**: Only ~80M parameters, making it efficient and easy to work with
|
||||
- **Data Efficient**: Often achieves high success rates with just 50 demonstrations
|
||||
|
||||
### Architecture
|
||||
|
||||
ACT uses a transformer-based architecture with three main components:
|
||||
|
||||
1. **Vision Backbone**: ResNet-18 processes images from multiple camera viewpoints
|
||||
2. **Transformer Encoder**: Synthesizes information from camera features, joint positions, and a learned latent variable
|
||||
3. **Transformer Decoder**: Generates coherent action sequences using cross-attention
|
||||
|
||||
The policy takes as input:
|
||||
|
||||
- Multiple RGB images (e.g., from wrist cameras, front/top cameras)
|
||||
- Current robot joint positions
|
||||
- A latent style variable `z` (learned during training, set to zero during inference)
|
||||
|
||||
And outputs a chunk of `k` future action sequences.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. ACT is included in the base LeRobot installation, so no additional dependencies are needed!
|
||||
|
||||
## Training ACT
|
||||
|
||||
ACT works seamlessly with the standard LeRobot training pipeline. Here's a complete example for training ACT on your dataset:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/your_dataset \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_your_dataset \
|
||||
--job_name=act_your_dataset \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=${HF_USER}/act_policy
|
||||
```
|
||||
|
||||
### Training Tips
|
||||
|
||||
1. **Start with defaults**: ACT's default hyperparameters work well for most tasks
|
||||
2. **Training duration**: Expect a few hours for 100k training steps on a single GPU
|
||||
3. **Batch size**: Start with batch size 8 and adjust based on your GPU memory
|
||||
|
||||
### Train using Google Colab
|
||||
|
||||
If your local computer doesn't have a powerful GPU, you can utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
## Evaluating ACT
|
||||
|
||||
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_robot \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Your task description" \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
```
|
||||
@@ -31,15 +31,15 @@ Then, spin up a policy server (in one terminal, or in a separate machine) specif
|
||||
You can spin up a policy server running:
|
||||
|
||||
```shell
|
||||
python src/lerobot/scripts/server/policy_server.py \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080 \
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
|
||||
This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with:
|
||||
|
||||
```shell
|
||||
python src/lerobot/scripts/server/robot_client.py \
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
@@ -113,17 +113,17 @@ As such, spinning up a policy server is as easy as specifying the host address a
|
||||
<hfoptions id="start_policy_server">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.scripts.server.policy_server \
|
||||
--host="localhost" \
|
||||
--port=8080
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.scripts.server.configs import PolicyServerConfig
|
||||
from lerobot.scripts.server.policy_server import serve
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host="localhost",
|
||||
@@ -148,7 +148,7 @@ The `RobotClient` streams observations to the `PolicyServer`, and receives actio
|
||||
<hfoptions id="start_robot_client">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python src/lerobot/scripts/server/robot_client.py \
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
@@ -171,9 +171,9 @@ python src/lerobot/scripts/server/robot_client.py \
|
||||
import threading
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.scripts.server.configs import RobotClientConfig
|
||||
from lerobot.scripts.server.robot_client import RobotClient
|
||||
from lerobot.scripts.server.helpers import visualize_action_queue_size
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
|
||||
# 1. Create the robot instance
|
||||
"""Check out the cameras available in your setup by running `python lerobot/find_cameras.py`"""
|
||||
|
||||
@@ -95,7 +95,6 @@ class HILSerlProcessorConfig:
|
||||
class ObservationConfig:
|
||||
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
|
||||
add_current_to_observation: bool = False # Add motor currents to state
|
||||
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
|
||||
display_cameras: bool = False # Display camera feeds during execution
|
||||
|
||||
class ImagePreprocessingConfig:
|
||||
@@ -105,7 +104,6 @@ class ImagePreprocessingConfig:
|
||||
class GripperConfig:
|
||||
use_gripper: bool = True # Enable gripper control
|
||||
gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage
|
||||
gripper_penalty_in_reward: bool = False # Include gripper penalty in reward
|
||||
|
||||
class ResetConfig:
|
||||
fixed_reset_joint_positions: Any | None = None # Joint positions for reset
|
||||
@@ -288,7 +286,6 @@ You can enable multiple observation processing features simultaneously:
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true,
|
||||
"add_current_to_observation": true,
|
||||
"add_ee_pose_to_observation": false,
|
||||
"display_cameras": false
|
||||
}
|
||||
}
|
||||
@@ -304,19 +301,19 @@ Before collecting demonstrations, you need to determine the appropriate operatio
|
||||
|
||||
This helps simplify the problem of learning on the real robot in two ways: 1) by limiting the robot's operational space to a specific region that solves the task and avoids unnecessary or unsafe exploration, and 2) by allowing training in end-effector space rather than joint space. Empirically, learning in joint space for reinforcement learning in manipulation is often a harder problem - some tasks are nearly impossible to learn in joint space but become learnable when the action space is transformed to end-effector coordinates.
|
||||
|
||||
**Using find_joint_limits.py**
|
||||
**Using lerobot-find-joint-limits**
|
||||
|
||||
This script helps you find the safe operational bounds for your robot's end-effector. Given that you have a follower and leader arm, you can use the script to find the bounds for the follower arm that will be applied during training.
|
||||
Bounding the action space will reduce the redundant exploration of the agent and guarantees safety.
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.find_joint_limits \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue
|
||||
lerobot-find-joint-limits \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue
|
||||
```
|
||||
|
||||
**Workflow**
|
||||
|
||||
@@ -200,7 +200,7 @@ from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderCo
|
||||
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.record import record_loop
|
||||
|
||||
NUM_EPISODES = 5
|
||||
@@ -237,7 +237,7 @@ dataset = LeRobotDataset.create(
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording")
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
@@ -517,7 +517,7 @@ from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerCon
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.policies.factory import make_processor
|
||||
|
||||
@@ -557,7 +557,7 @@ dataset = LeRobotDataset.create(
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording")
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
@@ -8,7 +8,7 @@ To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blo
|
||||
|
||||
- Your own robot which exposes a communication interface (e.g. serial, CAN, TCP)
|
||||
- A way to read sensor data and send motor commands programmatically, e.g. manufacturer's SDK or API, or your own protocol implementation.
|
||||
- LeRobot installed in your environment. Follow our [Installation Guide](./installation.mdx).
|
||||
- LeRobot installed in your environment. Follow our [Installation Guide](./installation).
|
||||
|
||||
## Choose your motors
|
||||
|
||||
@@ -65,7 +65,7 @@ class MyCoolRobotConfig(RobotConfig):
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
[Cameras tutorial](./cameras.mdx) to understand how to detect and add your camera.
|
||||
[Cameras tutorial](./cameras) to understand how to detect and add your camera.
|
||||
|
||||
Next, we'll create our actual robot class which inherits from `Robot`. This abstract class defines a contract you must follow for your robot to be usable with the rest of the LeRobot tools.
|
||||
|
||||
@@ -335,6 +335,134 @@ For implementing teleoperation devices, we also provide a [`Teleoperator`](https
|
||||
|
||||
The main differences are in the I/O functions: a teleoperator allows you to produce action via `get_action` and can receive feedback actions via `send_feedback`. Feedback could be anything controllable on the teleoperation device that could help the person controlling it understand the consequences of the actions sent. Think motion/force feedback on a leader arm, vibrations on a gamepad controller for example. To implement a teleoperator, you can follow this same tutorial and adapt it for these two methods.
|
||||
|
||||
## Using Your Own `LeRobot` Devices 🔌
|
||||
|
||||
You can easily extend `lerobot` with your own custom hardware—be it a camera, robot, or teleoperation device—by creating a separate, installable Python package. If you follow a few simple conventions, the `lerobot` command-line tools (like `lerobot-teleop` and `lerobot-record`) will **automatically discover and integrate your creations** without requiring any changes to the `lerobot` source code.
|
||||
|
||||
This guide outlines the conventions your plugin must follow.
|
||||
|
||||
### The 4 Core Conventions
|
||||
|
||||
To ensure your custom device is discoverable, you must adhere to the following four rules.
|
||||
|
||||
#### 1\. Create an Installable Package with a Specific Prefix
|
||||
|
||||
Your project must be a standard, installable Python package. Crucially, the name of your package (as defined in `pyproject.toml` or `setup.py`) must begin with one of these prefixes:
|
||||
|
||||
- `lerobot_robot_` for a robot.
|
||||
- `lerobot_camera_` for a camera.
|
||||
- `lerobot_teleoperator_` for a teleoperation device.
|
||||
|
||||
This prefix system is how `lerobot` automatically finds your plugin in the Python environment.
|
||||
|
||||
#### 2\. Follow the `SomethingConfig`/`Something` Naming Pattern
|
||||
|
||||
Your device's implementation class must be named after its configuration class, simply by removing the `Config` suffix.
|
||||
|
||||
- **Config Class:** `MyAwesomeTeleopConfig`
|
||||
- **Device Class:** `MyAwesomeTeleop`
|
||||
|
||||
#### 3\. Place Your Files in a Predictable Structure
|
||||
|
||||
The device class (`MyAwesomeTeleop`) must be located in a predictable module relative to its configuration class (`MyAwesomeTeleopConfig`). `lerobot` will automatically search in these locations:
|
||||
|
||||
- In the **same module** as the config class.
|
||||
- In a **submodule named after the device** (e.g., `my_awesome_teleop.py`).
|
||||
|
||||
The recommended and simplest structure is to place them in separate, clearly named files within the same directory.
|
||||
|
||||
#### 4\. Expose Classes in `__init__.py`
|
||||
|
||||
Your package's `__init__.py` file should import and expose both the configuration and the device classes, making them easily accessible.
|
||||
|
||||
### Putting It All Together: A Complete Example
|
||||
|
||||
Let's create a new teleoperator called `my_awesome_teleop`.
|
||||
|
||||
#### Directory Structure
|
||||
|
||||
Here is what the project folder should look like. The package name, `lerobot_teleoperator_my_awesome_teleop`, follows **Convention \#1**.
|
||||
|
||||
```
|
||||
lerobot_teleoperator_my_awesome_teleop/
|
||||
├── pyproject.toml # (or setup.py) lists lerobot as a dependency
|
||||
└── lerobot_teleoperator_my_awesome_teleop/
|
||||
├── __init__.py
|
||||
├── config_my_awesome_teleop.py
|
||||
└── my_awesome_teleop.py
|
||||
```
|
||||
|
||||
#### File Contents
|
||||
|
||||
- **`config_my_awesome_teleop.py`**: Defines the configuration class. Note the `Config` suffix (**Convention \#2**).
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||
|
||||
@TeleoperatorConfig.register_subclass("my_awesome_teleop")
|
||||
@dataclass
|
||||
class MyAwesomeTeleopConfig(TeleoperatorConfig):
|
||||
# Your configuration fields go here
|
||||
port: str = "192.168.1.1"
|
||||
```
|
||||
|
||||
- **`my_awesome_teleop.py`**: Implements the device. The class name `MyAwesomeTeleop` matches its config class name (**Convention \#2**). This file structure adheres to **Convention \#3**.
|
||||
|
||||
```python
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
|
||||
from .config_my_awesome_teleop import MyAwesomeTeleopConfig
|
||||
|
||||
class MyAwesomeTeleop(Teleoperator):
|
||||
config_class = MyAwesomeTeleopConfig
|
||||
name = "my_awesome_teleop"
|
||||
|
||||
def __init__(self, config: MyAwesomeTeleopConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Your device logic (e.g., connect) goes here
|
||||
```
|
||||
|
||||
- **`__init__.py`**: Exposes the key classes (**Convention \#4**).
|
||||
|
||||
```python
|
||||
from .config_my_awesome_teleop import MyAwesomeTeleopConfig
|
||||
from .my_awesome_teleop import MyAwesomeTeleop
|
||||
```
|
||||
|
||||
### Installation and Usage
|
||||
|
||||
1. **Install your new plugin in your Python environment.** You can install your local plugin package using `pip`'s editable mode or from PyPi.
|
||||
|
||||
```bash
|
||||
# Locally
|
||||
# Navigate to your plugin's root directory and install it
|
||||
cd lerobot_teleoperator_my_awesome_teleop
|
||||
pip install -e .
|
||||
|
||||
# From PyPi
|
||||
pip install lerobot_teleoperator_my_awesome_teleop
|
||||
```
|
||||
|
||||
2. **Use it directly from the command line.** Now, you can use your custom device by referencing its type.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate --teleop.type=my_awesome_teleop \
|
||||
# other arguments
|
||||
```
|
||||
|
||||
And that's it\! Your custom device is now fully integrated.
|
||||
|
||||
### Looking for an example ?
|
||||
|
||||
Check out these two packages from the community:
|
||||
|
||||
- https://github.com/SpesRobotics/lerobot-robot-xarm
|
||||
- https://github.com/SpesRobotics/lerobot-teleoperator-teleop
|
||||
|
||||
## Wrapping Up
|
||||
|
||||
Once your robot class is complete, you can leverage the LeRobot ecosystem:
|
||||
|
||||
@@ -297,9 +297,9 @@ LeRobot provides many registered processor steps. Here are the most commonly use
|
||||
|
||||
### Next Steps
|
||||
|
||||
- **[Implement Your Own Processor](implement_your_own_processor.mdx)** - Create custom processor steps
|
||||
- **[Debug Your Pipeline](debug_processor_pipeline.mdx)** - Troubleshoot and optimize pipelines
|
||||
- **[Processors for Robots and Teleoperators](processors_robots_teleop.mdx)** - Real-world integration patterns
|
||||
- **[Implement Your Own Processor](./implement_your_own_processor)** - Create custom processor steps
|
||||
- **[Debug Your Pipeline](./debug_processor_pipeline)** - Troubleshoot and optimize pipelines
|
||||
- **[Processors for Robots and Teleoperators](./processors_robots_teleop)** - Real-world integration patterns
|
||||
|
||||
## Summary
|
||||
|
||||
|
||||
@@ -277,7 +277,7 @@ leader.disconnect()
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots)
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
|
||||
|
||||
@@ -323,7 +323,7 @@ To replay an episode run the API example below, make sure to change `remote_ip`,
|
||||
python examples/lekiwi/replay.py
|
||||
```
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./il_robots)
|
||||
|
||||
## Evaluate your policy
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ To Install LIBERO, after following LeRobot official instructions, just do:
|
||||
Evaluate a policy on one LIBERO suite:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/eval.py \
|
||||
lerobot-eval \
|
||||
--policy.path="your-policy-id" \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
@@ -52,7 +52,7 @@ python src/lerobot/scripts/eval.py \
|
||||
Benchmark a policy across multiple suites at once:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/eval.py \
|
||||
lerobot-eval \
|
||||
--policy.path="your-policy-id" \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object,libero_spatial \
|
||||
@@ -103,10 +103,11 @@ For reference, here is the **original dataset** published by Physical Intelligen
|
||||
### Example training command
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/train.py \
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--policy.repo_id=${HF_USER}/libero-test \
|
||||
--dataset.repo_id=jadechoghari/smol-libero3 \
|
||||
--policy.load_vlm_weights=true \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--output_dir=./outputs/ \
|
||||
@@ -124,3 +125,42 @@ python src/lerobot/scripts/train.py \
|
||||
LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation:
|
||||
|
||||
- `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud)
|
||||
|
||||
## Reproducing π₀.₅ results
|
||||
|
||||
We reproduce the results of π₀.₅ on the LIBERO benchmark using the LeRobot implementation. We take the Physical Intelligence LIBERO base model (`pi05_libero`) and finetune for an additional 6k steps in bfloat16, with batch size of 256 on 8 H100 GPUs using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
|
||||
|
||||
The finetuned model can be found here:
|
||||
|
||||
- **π₀.₅ LIBERO**: [lerobot/pi05_libero_finetuned](https://huggingface.co/lerobot/pi05_libero_finetuned)
|
||||
|
||||
We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/eval.py \
|
||||
--output_dir=/logs/ \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10 \
|
||||
--policy.path=pi05_libero_finetuned \
|
||||
--policy.n_action_steps=10 \
|
||||
--output_dir=./eval_logs/ \
|
||||
--env.max_parallel_tasks=1
|
||||
```
|
||||
|
||||
**Note:** We set `n_action_steps=10`, similar to the original OpenPI implementation.
|
||||
|
||||
### Results
|
||||
|
||||
We obtain the following results on the LIBERO benchmark:
|
||||
|
||||
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
|
||||
| -------- | -------------- | ------------- | ----------- | --------- | -------- |
|
||||
| **π₀.₅** | 97.0 | 99.0 | 98.0 | 96.0 | **97.5** |
|
||||
|
||||
These results are consistent with the original [results](https://github.com/Physical-Intelligence/openpi/tree/main/examples/libero#results) reported by Physical Intelligence:
|
||||
|
||||
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
|
||||
| -------- | -------------- | ------------- | ----------- | --------- | --------- |
|
||||
| **π₀.₅** | 98.8 | 98.2 | 98.0 | 92.4 | **96.85** |
|
||||
|
||||
@@ -79,7 +79,7 @@ After running the example:
|
||||
- Android: after starting the script, open the printed local URL on your phone, tap Start, then press and hold Move.
|
||||
- iOS: open HEBI Mobile I/O first; B1 enables motion. A3 controls the gripper.
|
||||
|
||||
Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop.mdx) guide.
|
||||
Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop) guide.
|
||||
|
||||
- Run this example to record a dataset, which saves absolute end effector observations and actions:
|
||||
|
||||
@@ -136,13 +136,12 @@ Additionally you can customize mapping or safety limits by editing the processor
|
||||
),
|
||||
```
|
||||
|
||||
- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` and `max_ee_twist_step_rad` are the step limits for the EE pose and can be modified to change the safety limits.
|
||||
- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` are the step limits for the EE pose and can be modified to change the safety limits.
|
||||
|
||||
```examples/phone_to_so100/teleoperate.py
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
79
docs/source/pi0.mdx
Normal file
79
docs/source/pi0.mdx
Normal file
@@ -0,0 +1,79 @@
|
||||
# π₀ (Pi0)
|
||||
|
||||
π₀ is a **Vision-Language-Action model for general robot control**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository.
|
||||
|
||||
## Model Overview
|
||||
|
||||
π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi0). Unlike traditional robot programs that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
|
||||
|
||||
### The Vision for Physical Intelligence
|
||||
|
||||
As described by Physical Intelligence, while AI has achieved remarkable success in digital domains, from chess-playing to drug discovery, human intelligence still dramatically outpaces AI in the physical world. To paraphrase Moravec's paradox, winning a game of chess represents an "easy" problem for AI, but folding a shirt or cleaning up a table requires solving some of the most difficult engineering problems ever conceived. π₀ represents a first step toward developing artificial physical intelligence that enables users to simply ask robots to perform any task they want, just like they can with large language models.
|
||||
|
||||
### Architecture and Approach
|
||||
|
||||
π₀ combines several key innovations:
|
||||
|
||||
- **Flow Matching**: Uses a novel method to augment pre-trained VLMs with continuous action outputs via flow matching (a variant of diffusion models)
|
||||
- **Cross-Embodiment Training**: Trained on data from 8 distinct robot platforms including UR5e, Bimanual UR5e, Franka, Bimanual Trossen, Bimanual ARX, Mobile Trossen, and Mobile Fibocom
|
||||
- **Internet-Scale Pre-training**: Inherits semantic knowledge from a pre-trained 3B parameter Vision-Language Model
|
||||
- **High-Frequency Control**: Outputs motor commands at up to 50 Hz for real-time dexterous manipulation
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. Install Pi0 dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
## Training Data and Capabilities
|
||||
|
||||
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
|
||||
|
||||
1. **Internet-Scale Pre-training**: Vision-language data from the web for semantic understanding
|
||||
2. **Open X-Embodiment Dataset**: Open-source robot manipulation datasets
|
||||
3. **Physical Intelligence Dataset**: Large and diverse dataset of dexterous tasks across 8 distinct robots
|
||||
|
||||
## Usage
|
||||
|
||||
To use π₀ in LeRobot, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=pi0
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=./outputs/pi0_training \
|
||||
--job_name=pi0_training \
|
||||
--policy.pretrained_path=lerobot/pi0_base \
|
||||
--policy.repo_id=your_repo_id \
|
||||
--policy.compile_model=true \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
- **`--policy.compile_model=true`**: Enables model compilation for faster training
|
||||
- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training
|
||||
- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency
|
||||
- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory
|
||||
- **`--policy.pretrained_path=lerobot/pi0_base`**: The base π₀ model you want to finetune, options are:
|
||||
- [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base)
|
||||
- [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset)
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
107
docs/source/pi05.mdx
Normal file
107
docs/source/pi05.mdx
Normal file
@@ -0,0 +1,107 @@
|
||||
# π₀.₅ (Pi05) Policy
|
||||
|
||||
π₀.₅ is a **Vision-Language-Action model with open-world generalization**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository.
|
||||
|
||||
## Model Overview
|
||||
|
||||
π₀.₅ represents a significant evolution from π₀, developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi05) to address a big challenge in robotics: **open-world generalization**. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
|
||||
|
||||
### The Generalization Challenge
|
||||
|
||||
As Physical Intelligence explains, the fundamental challenge isn't performing tasks of agility or dexterity, but generalization, the ability to correctly perform tasks in new settings with new objects. Consider a robot cleaning different homes: each home has different objects in different places. Generalization must occur at multiple levels:
|
||||
|
||||
- **Physical Level**: Understanding how to pick up a spoon (by the handle) or plate (by the edge), even with unseen objects in cluttered environments
|
||||
- **Semantic Level**: Understanding task semantics, where to put clothes and shoes (laundry hamper, not on the bed), and what tools are appropriate for cleaning spills
|
||||
- **Environmental Level**: Adapting to "messy" real-world environments like homes, grocery stores, offices, and hospitals
|
||||
|
||||
### Co-Training on Heterogeneous Data
|
||||
|
||||
The breakthrough innovation in π₀.₅ is **co-training on heterogeneous data sources**. The model learns from:
|
||||
|
||||
1. **Multimodal Web Data**: Image captioning, visual question answering, object detection
|
||||
2. **Verbal Instructions**: Humans coaching robots through complex tasks step-by-step
|
||||
3. **Subtask Commands**: High-level semantic behavior labels (e.g., "pick up the pillow" for an unmade bed)
|
||||
4. **Cross-Embodiment Robot Data**: Data from various robot platforms with different capabilities
|
||||
5. **Multi-Environment Data**: Static robots deployed across many different homes
|
||||
6. **Mobile Manipulation Data**: ~400 hours of mobile robot demonstrations
|
||||
|
||||
This diverse training mixture creates a "curriculum" that enables generalization across physical, visual, and semantic levels simultaneously.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. Install Pi0.5 dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=pi05
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Training Command Example
|
||||
|
||||
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py\
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi05 \
|
||||
--output_dir=./outputs/pi05_training \
|
||||
--job_name=pi05_training \
|
||||
--policy.repo_id=your_repo_id \
|
||||
--policy.pretrained_path=lerobot/pi05_base \
|
||||
--policy.compile_model=true \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--wandb.enable=true \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
- **`--policy.compile_model=true`**: Enables model compilation for faster training
|
||||
- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training
|
||||
- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency
|
||||
- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory
|
||||
- **`--policy.pretrained_path=lerobot/pi05_base`**: The base π₀.₅ model you want to finetune, options are:
|
||||
- [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base)
|
||||
- [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset)
|
||||
|
||||
If your dataset is not converted with `quantiles`, you can convert it with the following command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset \
|
||||
```
|
||||
|
||||
Or train pi05 with this normalization mapping: `--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'`
|
||||
|
||||
## Performance Results
|
||||
|
||||
### Libero Benchmark Results
|
||||
|
||||
π₀.₅ has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the libero base model for an additional 6k steps on the Libero dataset and compared the results to the OpenPI reference results.
|
||||
|
||||
| Benchmark | LeRobot Implementation | OpenPI Reference |
|
||||
| ------------------ | ---------------------- | ---------------- |
|
||||
| **Libero Spatial** | 97.0% | 98.8% |
|
||||
| **Libero Object** | 99.0% | 98.2% |
|
||||
| **Libero Goal** | 98.0% | 98.0% |
|
||||
| **Libero 10** | 96.0% | 92.4% |
|
||||
| **Average** | 97.5% | 96.85% |
|
||||
|
||||
These results demonstrate π₀.₅'s strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
@@ -38,7 +38,7 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, RobotActi
|
||||
kinematics=kinematics_solver, end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20, max_ee_twist_step_rad=0.50,
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20,
|
||||
),
|
||||
GripperVelocityToJoint(),
|
||||
],
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Finetune SmolVLA
|
||||
# SmolVLA
|
||||
|
||||
SmolVLA is Hugging Face’s lightweight foundation model for robotics. Designed for easy fine-tuning on LeRobot datasets, it helps accelerate your development!
|
||||
|
||||
@@ -29,7 +29,7 @@ SmolVLA is Hugging Face’s lightweight foundation model for robotics. Designed
|
||||
## Collect a dataset
|
||||
|
||||
SmolVLA is a base model, so fine-tuning on your own data is required for optimal performance in your setup.
|
||||
We recommend recording ~50 episodes of your task as a starting point. Follow our guide to get started: [Recording a Dataset](https://huggingface.co/docs/lerobot/getting_started_real_world_robot#record-a-dataset)
|
||||
We recommend recording ~50 episodes of your task as a starting point. Follow our guide to get started: [Recording a Dataset](./il_robots)
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -93,7 +93,7 @@ lerobot-train --help
|
||||
|
||||
## Evaluate the finetuned model and run it in real-time
|
||||
|
||||
Similarly for when recording an episode, it is recommended that you are logged in to the HuggingFace Hub. You can follow the corresponding steps: [Record a dataset](./getting_started_real_world_robot#record-a-dataset).
|
||||
Similarly for when recording an episode, it is recommended that you are logged in to the HuggingFace Hub. You can follow the corresponding steps: [Record a dataset](./il_robots).
|
||||
Once you are logged in, you can run inference in your setup by doing:
|
||||
|
||||
```bash
|
||||
|
||||
@@ -634,7 +634,7 @@ leader.disconnect()
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots)
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
|
||||
|
||||
@@ -430,7 +430,7 @@ leader.disconnect()
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots)
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
|
||||
|
||||
102
docs/source/using_dataset_tools.mdx
Normal file
102
docs/source/using_dataset_tools.mdx
Normal file
@@ -0,0 +1,102 @@
|
||||
# Using Dataset Tools
|
||||
|
||||
This guide covers the dataset tools utilities available in LeRobot for modifying and editing existing datasets.
|
||||
|
||||
## Overview
|
||||
|
||||
LeRobot provides several utilities for manipulating datasets:
|
||||
|
||||
1. **Delete Episodes** - Remove specific episodes from a dataset
|
||||
2. **Split Dataset** - Divide a dataset into multiple smaller datasets
|
||||
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
|
||||
## Command-Line Tool: lerobot-edit-dataset
|
||||
|
||||
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
|
||||
|
||||
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
|
||||
|
||||
### Usage Examples
|
||||
|
||||
#### Delete Episodes
|
||||
|
||||
Remove specific episodes from a dataset. This is useful for filtering out undesired data.
|
||||
|
||||
```bash
|
||||
# Delete episodes 0, 2, and 5 (modifies original dataset)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
# Delete episodes and save to a new dataset (preserves original dataset)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_after_deletion \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
```
|
||||
|
||||
#### Split Dataset
|
||||
|
||||
Divide a dataset into multiple subsets.
|
||||
|
||||
```bash
|
||||
# Split by fractions (e.g. 80% train, 20% test, 20% val)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.8, "test": 0.2, "val": 0.2}'
|
||||
|
||||
# Split by specific episode indices
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"task1": [0, 1, 2, 3], "task2": [4, 5]}'
|
||||
```
|
||||
|
||||
There are no constraints on the split names, they can be determined by the user. Resulting datasets are saved under the repo id with the split name appended, e.g. `lerobot/pusht_train`, `lerobot/pusht_task1`, `lerobot/pusht_task2`.
|
||||
|
||||
#### Merge Datasets
|
||||
|
||||
Combine multiple datasets into a single dataset.
|
||||
|
||||
```bash
|
||||
# Merge train and validation splits back into one dataset
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_merged \
|
||||
--operation.type merge \
|
||||
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
||||
```
|
||||
|
||||
#### Remove Features
|
||||
|
||||
Remove features from a dataset.
|
||||
|
||||
```bash
|
||||
# Remove a camera feature
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
```
|
||||
|
||||
### Push to Hub
|
||||
|
||||
Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
|
||||
```bash
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_after_deletion \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
|
||||
@@ -44,6 +44,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
@@ -78,16 +79,16 @@ def replay(cfg: ReplayConfig):
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
actions = dataset.hf_dataset.select_columns(ACTION)
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx]["action"]
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features["action"]["names"]):
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
key = f"{name.removeprefix('main_')}.pos"
|
||||
action[key] = action_array[i].item()
|
||||
|
||||
|
||||
117
examples/dataset/use_dataset_tools.py
Normal file
117
examples/dataset/use_dataset_tools.py
Normal file
@@ -0,0 +1,117 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example script demonstrating dataset tools utilities.
|
||||
|
||||
This script shows how to:
|
||||
1. Delete episodes from a dataset
|
||||
2. Split a dataset into train/val sets
|
||||
3. Add/remove features
|
||||
4. Merge datasets
|
||||
|
||||
Usage:
|
||||
python examples/dataset/use_dataset_tools.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
add_feature,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def main():
|
||||
dataset = LeRobotDataset("lerobot/pusht")
|
||||
|
||||
print(f"Original dataset: {dataset.meta.total_episodes} episodes, {dataset.meta.total_frames} frames")
|
||||
print(f"Features: {list(dataset.meta.features.keys())}")
|
||||
|
||||
print("\n1. Deleting episodes 0 and 2...")
|
||||
filtered_dataset = delete_episodes(dataset, episode_indices=[0, 2], repo_id="lerobot/pusht_filtered")
|
||||
print(f"Filtered dataset: {filtered_dataset.meta.total_episodes} episodes")
|
||||
|
||||
print("\n2. Splitting dataset into train/val...")
|
||||
splits = split_dataset(
|
||||
dataset,
|
||||
splits={"train": 0.8, "val": 0.2},
|
||||
)
|
||||
print(f"Train split: {splits['train'].meta.total_episodes} episodes")
|
||||
print(f"Val split: {splits['val'].meta.total_episodes} episodes")
|
||||
|
||||
print("\n3. Adding a reward feature...")
|
||||
|
||||
reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
|
||||
dataset_with_reward = add_feature(
|
||||
dataset,
|
||||
feature_name="reward",
|
||||
feature_values=reward_values,
|
||||
feature_info={
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
repo_id="lerobot/pusht_with_reward",
|
||||
)
|
||||
|
||||
def compute_success(row_dict, episode_index, frame_index):
|
||||
episode_length = 10
|
||||
return float(frame_index >= episode_length - 10)
|
||||
|
||||
dataset_with_success = add_feature(
|
||||
dataset_with_reward,
|
||||
feature_name="success",
|
||||
feature_values=compute_success,
|
||||
feature_info={
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
repo_id="lerobot/pusht_with_reward_and_success",
|
||||
)
|
||||
|
||||
print(f"New features: {list(dataset_with_success.meta.features.keys())}")
|
||||
|
||||
print("\n4. Removing the success feature...")
|
||||
dataset_cleaned = remove_feature(
|
||||
dataset_with_success, feature_names="success", repo_id="lerobot/pusht_cleaned"
|
||||
)
|
||||
print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")
|
||||
|
||||
print("\n5. Merging train and val splits back together...")
|
||||
merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged")
|
||||
print(f"Merged dataset: {merged.meta.total_episodes} episodes")
|
||||
|
||||
print("\n6. Complex workflow example...")
|
||||
|
||||
if len(dataset.meta.camera_keys) > 1:
|
||||
camera_to_remove = dataset.meta.camera_keys[0]
|
||||
print(f"Removing camera: {camera_to_remove}")
|
||||
dataset_no_cam = remove_feature(
|
||||
dataset, feature_names=camera_to_remove, repo_id="pusht_no_first_camera"
|
||||
)
|
||||
print(f"Remaining cameras: {dataset_no_cam.meta.camera_keys}")
|
||||
|
||||
print("\nDone! Check ~/.cache/huggingface/lerobot/ for the created datasets.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -19,11 +19,12 @@ from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
@@ -41,8 +42,8 @@ robot = LeKiwiClient(robot_config)
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
@@ -73,7 +74,7 @@ teleop_action_processor, robot_action_processor, robot_observation_processor = m
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="lekiwi_evaluate")
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
@@ -17,14 +17,15 @@
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
@@ -47,8 +48,8 @@ keyboard = KeyboardTeleop(keyboard_config)
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
@@ -69,7 +70,7 @@ keyboard.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="lekiwi_record")
|
||||
init_rerun(session_name="lekiwi_record")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
@@ -19,6 +19,7 @@ import time
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -34,7 +35,7 @@ robot = LeKiwiClient(robot_config)
|
||||
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns("action")
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
@@ -49,7 +50,7 @@ for idx in range(len(episode_frames)):
|
||||
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Send action to robot
|
||||
|
||||
@@ -20,7 +20,7 @@ from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
@@ -41,7 +41,7 @@ leader_arm.connect()
|
||||
keyboard.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
_init_rerun(session_name="lekiwi_teleop")
|
||||
init_rerun(session_name="lekiwi_teleop")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
@@ -34,16 +34,16 @@ from lerobot.processor.converters import (
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -137,7 +137,7 @@ robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="phone_so100_evaluate")
|
||||
init_rerun(session_name="phone_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
@@ -26,7 +26,6 @@ from lerobot.processor.converters import (
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
@@ -36,12 +35,13 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
@@ -84,7 +84,6 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline[tuple[RobotAction, Rob
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
GripperVelocityToJoint(speed_factor=20.0),
|
||||
],
|
||||
@@ -143,7 +142,7 @@ phone.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="phone_so100_record")
|
||||
init_rerun(session_name="phone_so100_record")
|
||||
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
@@ -28,6 +28,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -66,7 +67,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotOb
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns("action")
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
@@ -81,7 +82,7 @@ for idx in range(len(episode_frames)):
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
|
||||
@@ -33,7 +33,7 @@ from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
@@ -67,7 +67,6 @@ phone_to_robot_joints_processor = RobotProcessorPipeline[tuple[RobotAction, Robo
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
speed_factor=20.0,
|
||||
@@ -87,7 +86,7 @@ robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
_init_rerun(session_name="phone_so100_teleop")
|
||||
init_rerun(session_name="phone_so100_teleop")
|
||||
|
||||
if not robot.is_connected or not teleop_device.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
@@ -34,16 +34,16 @@ from lerobot.processor.converters import (
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -138,7 +138,7 @@ robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="so100_so100_evaluate")
|
||||
init_rerun(session_name="so100_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
@@ -27,7 +27,6 @@ from lerobot.processor.converters import (
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
@@ -35,11 +34,12 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
@@ -101,7 +101,6 @@ ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservati
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
@@ -143,7 +142,7 @@ follower.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording_phone")
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
@@ -29,6 +29,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -67,7 +68,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotOb
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns("action")
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
@@ -82,7 +83,7 @@ for idx in range(len(episode_frames)):
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
|
||||
@@ -33,7 +33,7 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
@@ -78,7 +78,6 @@ ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservati
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
@@ -95,7 +94,7 @@ follower.connect()
|
||||
leader.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
_init_rerun(session_name="so100_so100_EE_teleop")
|
||||
init_rerun(session_name="so100_so100_EE_teleop")
|
||||
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
|
||||
@@ -20,13 +20,13 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.constants import ACTION
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
175
pyproject.toml
175
pyproject.toml
@@ -59,20 +59,20 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
||||
dependencies = [
|
||||
|
||||
# Hugging Face dependencies
|
||||
"datasets>=4.0.0",
|
||||
"diffusers>=0.27.2",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2",
|
||||
"datasets>=4.0.0,<4.2.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
|
||||
# Core dependencies
|
||||
"cmake>=3.29.0.1",
|
||||
"einops>=0.8.0",
|
||||
"opencv-python-headless>=4.9.0",
|
||||
"av>=14.2.0",
|
||||
"jsonlines>=4.0.0",
|
||||
"packaging>=24.2",
|
||||
"pynput>=1.7.7",
|
||||
"pyserial>=3.5",
|
||||
"wandb>=0.20.0",
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"einops>=0.8.0,<0.9.0",
|
||||
"opencv-python-headless>=4.9.0,<4.13.0",
|
||||
"av>=15.0.0,<16.0.0",
|
||||
"jsonlines>=4.0.0,<5.0.0",
|
||||
"packaging>=24.2,<26.0",
|
||||
"pynput>=1.7.7,<1.9.0",
|
||||
"pyserial>=3.5,<4.0",
|
||||
"wandb>=0.20.0,<0.23.0",
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
@@ -92,26 +92,26 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1"]
|
||||
placo-dep = ["placo>=0.9.6"]
|
||||
transformers-dep = ["transformers>=4.52.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31"]
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
|
||||
# Robots
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
|
||||
]
|
||||
phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0"]
|
||||
# stretch = [
|
||||
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
|
||||
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
|
||||
@@ -119,22 +119,22 @@ phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
|
||||
# ] # TODO: Currently not supported
|
||||
|
||||
# Policies
|
||||
pi0 = ["lerobot[transformers-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||
test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
aloha = ["gym-aloha>=0.1.1"]
|
||||
pusht = ["gym-pusht>=0.1.5", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
xarm = ["gym-xarm>=0.1.1"]
|
||||
aloha = ["gym-aloha>=0.1.1,<0.2.0"]
|
||||
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
xarm = ["gym-xarm>=0.1.1,<0.2.0"]
|
||||
libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
|
||||
|
||||
|
||||
@@ -147,7 +147,7 @@ all = [
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[pi0]",
|
||||
"lerobot[pi]",
|
||||
"lerobot[smolvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
@@ -162,18 +162,20 @@ all = [
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
lerobot-calibrate="lerobot.calibrate:main"
|
||||
lerobot-find-cameras="lerobot.find_cameras:main"
|
||||
lerobot-find-port="lerobot.find_port:main"
|
||||
lerobot-record="lerobot.record:main"
|
||||
lerobot-replay="lerobot.replay:main"
|
||||
lerobot-setup-motors="lerobot.setup_motors:main"
|
||||
lerobot-teleoperate="lerobot.teleoperate:main"
|
||||
lerobot-eval="lerobot.scripts.eval:main"
|
||||
lerobot-train="lerobot.scripts.train:main"
|
||||
lerobot-calibrate="lerobot.scripts.lerobot_calibrate:main"
|
||||
lerobot-find-cameras="lerobot.scripts.lerobot_find_cameras:main"
|
||||
lerobot-find-port="lerobot.scripts.lerobot_find_port:main"
|
||||
lerobot-record="lerobot.scripts.lerobot_record:main"
|
||||
lerobot-replay="lerobot.scripts.lerobot_replay:main"
|
||||
lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main"
|
||||
lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main"
|
||||
lerobot-eval="lerobot.scripts.lerobot_eval:main"
|
||||
lerobot-train="lerobot.scripts.lerobot_train:main"
|
||||
lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
|
||||
lerobot-info="lerobot.scripts.lerobot_info:main"
|
||||
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.packages.find]
|
||||
@@ -200,7 +202,7 @@ exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
|
||||
# N: pep8-naming
|
||||
# TODO: Uncomment rules when ready to use
|
||||
select = [
|
||||
"E", "W", "F", "I", "B", "C4", "T20", "N" # "SIM", "A", "S", "D", "RUF", "UP"
|
||||
"E", "W", "F", "I", "B", "C4", "T20", "N", "UP", "SIM" #, "A", "S", "D", "RUF"
|
||||
]
|
||||
ignore = [
|
||||
"E501", # Line too long
|
||||
@@ -266,8 +268,87 @@ default.extend-ignore-identifiers-re = [
|
||||
# color = true
|
||||
# paths = ["src/lerobot"]
|
||||
|
||||
# [tool.mypy]
|
||||
# python_version = "3.10"
|
||||
# TODO: Enable mypy gradually module by module across multiple PRs
|
||||
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
follow_imports = "skip"
|
||||
# warn_return_any = true
|
||||
# warn_unused_configs = true
|
||||
# ignore_missing_imports = false
|
||||
# strict = true
|
||||
# disallow_untyped_defs = true
|
||||
# disallow_incomplete_defs = true
|
||||
# check_untyped_defs = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.*"
|
||||
ignore_errors = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.envs.*"
|
||||
# Enable type checking only for the envs module
|
||||
ignore_errors = false
|
||||
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.utils.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.configs.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.optim.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.model.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.processor.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.datasets.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.cameras.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.motors.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.robots.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.teleoperators.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.policies.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.rl.*"
|
||||
# ignore_errors = false
|
||||
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.async_inference.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.transport.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.scripts.*"
|
||||
# ignore_errors = false
|
||||
|
||||
@@ -18,7 +18,8 @@ from dataclasses import dataclass, field
|
||||
import torch
|
||||
|
||||
from lerobot.robots.config import RobotConfig
|
||||
from lerobot.scripts.server.constants import (
|
||||
|
||||
from .constants import (
|
||||
DEFAULT_FPS,
|
||||
DEFAULT_INFERENCE_LATENCY,
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT,
|
||||
@@ -141,11 +142,6 @@ class RobotClientConfig:
|
||||
default=False, metadata={"help": "Visualize the action queue size"}
|
||||
)
|
||||
|
||||
# Verification configuration
|
||||
verify_robot_cameras: bool = field(
|
||||
default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"}
|
||||
)
|
||||
|
||||
@property
|
||||
def environment_dt(self) -> float:
|
||||
"""Environment time step, in seconds"""
|
||||
@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
|
||||
# All action chunking policies
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]
|
||||
@@ -22,16 +22,22 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
|
||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
|
||||
from lerobot.policies import ( # noqa: F401
|
||||
ACTConfig,
|
||||
DiffusionConfig,
|
||||
PI0Config,
|
||||
PI05Config,
|
||||
SmolVLAConfig,
|
||||
VQBeTConfig,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
Action = torch.Tensor
|
||||
ActionChunk = torch.Tensor
|
||||
|
||||
# observation as received from the robot
|
||||
RawObservation = dict[str, torch.Tensor]
|
||||
@@ -46,7 +52,7 @@ Observation = dict[str, torch.Tensor]
|
||||
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
_, ax = plt.subplots()
|
||||
ax.set_title("Action Queue Size Over Time")
|
||||
ax.set_xlabel("Environment steps")
|
||||
ax.set_ylabel("Action Queue Size")
|
||||
@@ -56,17 +62,8 @@ def visualize_action_queue_size(action_queue_size: list[int]) -> None:
|
||||
plt.show()
|
||||
|
||||
|
||||
def validate_robot_cameras_for_policy(
|
||||
lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature]
|
||||
) -> None:
|
||||
image_keys = list(filter(is_image_key, lerobot_observation_features))
|
||||
assert set(image_keys) == set(policy_image_features.keys()), (
|
||||
f"Policy image features must match robot cameras! Received {list(policy_image_features.keys())} != {image_keys}"
|
||||
)
|
||||
|
||||
|
||||
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
|
||||
return hw_to_dataset_features(robot.observation_features, "observation", use_video=False)
|
||||
return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
@@ -86,11 +83,11 @@ def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int,
|
||||
return resized.squeeze(0)
|
||||
|
||||
|
||||
# TODO(Steven): Consider implementing a pipeline step for this
|
||||
def raw_observation_to_observation(
|
||||
raw_observation: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
policy_image_features: dict[str, PolicyFeature],
|
||||
device: str,
|
||||
) -> Observation:
|
||||
observation = {}
|
||||
|
||||
@@ -99,9 +96,7 @@ def raw_observation_to_observation(
|
||||
if isinstance(v, torch.Tensor): # VLAs present natural-language instructions in observations
|
||||
if "image" in k:
|
||||
# Policy expects images in shape (B, C, H, W)
|
||||
observation[k] = prepare_image(v).unsqueeze(0).to(device)
|
||||
else:
|
||||
observation[k] = v.to(device)
|
||||
observation[k] = prepare_image(v).unsqueeze(0)
|
||||
else:
|
||||
observation[k] = v
|
||||
|
||||
@@ -141,7 +136,7 @@ def make_lerobot_observation(
|
||||
lerobot_features: dict[str, dict],
|
||||
) -> LeRobotObservation:
|
||||
"""Make a lerobot observation from a raw observation."""
|
||||
return build_dataset_frame(lerobot_features, robot_obs, prefix="observation")
|
||||
return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR)
|
||||
|
||||
|
||||
def prepare_raw_observation(
|
||||
@@ -15,7 +15,7 @@
|
||||
"""
|
||||
Example:
|
||||
```shell
|
||||
python src/lerobot/scripts/server/policy_server.py \
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080 \
|
||||
--fps=30 \
|
||||
@@ -32,15 +32,26 @@ from concurrent import futures
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
from lerobot.scripts.server.configs import PolicyServerConfig
|
||||
from lerobot.scripts.server.constants import SUPPORTED_POLICIES
|
||||
from lerobot.scripts.server.helpers import (
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks
|
||||
|
||||
from .configs import PolicyServerConfig
|
||||
from .constants import SUPPORTED_POLICIES
|
||||
from .helpers import (
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RemotePolicyConfig,
|
||||
@@ -50,11 +61,6 @@ from lerobot.scripts.server.helpers import (
|
||||
observations_similar,
|
||||
raw_observation_to_observation,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks
|
||||
|
||||
|
||||
class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
@@ -81,6 +87,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
self.lerobot_features = None
|
||||
self.actions_per_chunk = None
|
||||
self.policy = None
|
||||
self.preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None
|
||||
self.postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
@@ -145,6 +153,16 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
start = time.perf_counter()
|
||||
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
|
||||
self.policy.to(self.device)
|
||||
|
||||
# Load preprocessor and postprocessor, overriding device to match requested device
|
||||
device_override = {"device": self.device}
|
||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||
self.policy.config,
|
||||
pretrained_path=policy_specs.pretrained_name_or_path,
|
||||
preprocessor_overrides={"device_processor": device_override},
|
||||
postprocessor_overrides={"device_processor": device_override},
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
|
||||
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
|
||||
@@ -172,7 +190,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
# Calculate FPS metrics
|
||||
fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)
|
||||
|
||||
self.logger.info(
|
||||
self.logger.debug(
|
||||
f"Received observation #{obs_timestep} | "
|
||||
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client
|
||||
f"Target: {fps_metrics['target_fps']:.2f} | "
|
||||
@@ -188,7 +206,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
if not self._enqueue_observation(
|
||||
timed_observation # wrapping a RawObservation
|
||||
):
|
||||
self.logger.info(f"Observation #{obs_timestep} has been filtered out")
|
||||
self.logger.debug(f"Observation #{obs_timestep} has been filtered out")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
@@ -300,23 +318,6 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
for i, action in enumerate(action_chunk)
|
||||
]
|
||||
|
||||
def _prepare_observation(self, observation_t: TimedObservation) -> Observation:
|
||||
"""
|
||||
Prepare observation, ready for policy inference.
|
||||
E.g.: To keep observation sampling rate high (and network packet tiny) we send int8 [0,255] images from the
|
||||
client and then convert them to float32 [0,1] images here, before running inference.
|
||||
"""
|
||||
# RawObservation from robot.get_observation() - wrong keys, wrong dtype, wrong image shape
|
||||
observation: Observation = raw_observation_to_observation(
|
||||
observation_t.get_observation(),
|
||||
self.lerobot_features,
|
||||
self.policy_image_features,
|
||||
self.device,
|
||||
)
|
||||
# processed Observation - right keys, right dtype, right image shape
|
||||
|
||||
return observation
|
||||
|
||||
def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Get an action chunk from the policy. The chunk contains only"""
|
||||
chunk = self.policy.predict_action_chunk(observation)
|
||||
@@ -326,44 +327,76 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
return chunk[:, : self.actions_per_chunk, :]
|
||||
|
||||
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
|
||||
"""Predict an action chunk based on an observation"""
|
||||
inference_starts = time.perf_counter()
|
||||
"""Predict an action chunk based on an observation.
|
||||
|
||||
Pipeline:
|
||||
1. Convert raw observation to LeRobot format
|
||||
2. Apply preprocessor (tokenization, normalization, batching, device placement)
|
||||
3. Run policy inference to get action chunk
|
||||
4. Apply postprocessor (unnormalization, device movement)
|
||||
5. Convert to TimedAction list
|
||||
"""
|
||||
"""1. Prepare observation"""
|
||||
start_time = time.perf_counter()
|
||||
observation = self._prepare_observation(observation_t)
|
||||
preprocessing_time = time.perf_counter() - start_time
|
||||
start_prepare = time.perf_counter()
|
||||
observation: Observation = raw_observation_to_observation(
|
||||
observation_t.get_observation(),
|
||||
self.lerobot_features,
|
||||
self.policy_image_features,
|
||||
)
|
||||
prepare_time = time.perf_counter() - start_prepare
|
||||
|
||||
"""2. Apply preprocessor"""
|
||||
start_preprocess = time.perf_counter()
|
||||
observation = self.preprocessor(observation)
|
||||
self.last_processed_obs: TimedObservation = observation_t
|
||||
preprocessing_time = time.perf_counter() - start_preprocess
|
||||
|
||||
"""2. Get action chunk"""
|
||||
start_time = time.perf_counter()
|
||||
"""3. Get action chunk"""
|
||||
start_inference = time.perf_counter()
|
||||
action_tensor = self._get_action_chunk(observation)
|
||||
inference_time = time.perf_counter() - start_time
|
||||
inference_time = time.perf_counter() - start_inference
|
||||
self.logger.info(
|
||||
f"Preprocessing and inference took {inference_time:.4f}s, action shape: {action_tensor.shape}"
|
||||
)
|
||||
|
||||
"""3. Post-inference processing"""
|
||||
start_time = time.perf_counter()
|
||||
# Move to CPU before serializing
|
||||
action_tensor = action_tensor.cpu().squeeze(0)
|
||||
"""4. Apply postprocessor"""
|
||||
# Apply postprocessor (handles unnormalization and device movement)
|
||||
# Postprocessor expects (B, action_dim) per action, but we have (B, chunk_size, action_dim)
|
||||
# So we process each action in the chunk individually
|
||||
start_postprocess = time.perf_counter()
|
||||
_, chunk_size, _ = action_tensor.shape
|
||||
|
||||
# Process each action in the chunk
|
||||
processed_actions = []
|
||||
for i in range(chunk_size):
|
||||
# Extract action at timestep i: (B, action_dim)
|
||||
single_action = action_tensor[:, i, :]
|
||||
processed_action = self.postprocessor(single_action)
|
||||
processed_actions.append(processed_action)
|
||||
|
||||
# Stack back to (B, chunk_size, action_dim), then remove batch dim
|
||||
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
|
||||
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
|
||||
|
||||
"""5. Convert to TimedAction list"""
|
||||
action_chunk = self._time_action_chunk(
|
||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||
)
|
||||
postprocessing_time = time.perf_counter() - start_time
|
||||
inference_stops = time.perf_counter()
|
||||
postprocess_stops = time.perf_counter()
|
||||
postprocessing_time = postprocess_stops - start_postprocess
|
||||
|
||||
self.logger.info(
|
||||
f"Observation {observation_t.get_timestep()} |"
|
||||
f"Inference time: {1000 * (inference_stops - inference_starts):.2f}ms"
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
# full-process latency breakdown for debugging purposes
|
||||
self.logger.debug(
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Preprocessing time: {1000 * (preprocessing_time - inference_starts):.2f}ms | "
|
||||
f"Inference time: {1000 * (inference_time - preprocessing_time):.2f}ms | "
|
||||
f"Postprocessing time: {1000 * (postprocessing_time - inference_time):.2f}ms | "
|
||||
f"Total time: {1000 * (postprocessing_time - inference_starts):.2f}ms"
|
||||
f"Prepare time: {1000 * prepare_time:.2f}ms | "
|
||||
f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | "
|
||||
f"Inference time: {1000 * inference_time:.2f}ms | "
|
||||
f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
return action_chunk
|
||||
@@ -15,7 +15,7 @@
|
||||
"""
|
||||
Example command:
|
||||
```shell
|
||||
python src/lerobot/scripts/server/robot_client.py \
|
||||
python src/lerobot/async_inference/robot_client.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
@@ -48,18 +48,24 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so100_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.scripts.server.configs import RobotClientConfig
|
||||
from lerobot.scripts.server.constants import SUPPORTED_ROBOTS
|
||||
from lerobot.scripts.server.helpers import (
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||
|
||||
from .configs import RobotClientConfig
|
||||
from .constants import SUPPORTED_ROBOTS
|
||||
from .helpers import (
|
||||
Action,
|
||||
FPSTracker,
|
||||
Observation,
|
||||
@@ -69,14 +75,8 @@ from lerobot.scripts.server.helpers import (
|
||||
TimedObservation,
|
||||
get_logger,
|
||||
map_robot_keys_to_lerobot_features,
|
||||
validate_robot_cameras_for_policy,
|
||||
visualize_action_queue_size,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||
|
||||
|
||||
class RobotClient:
|
||||
@@ -96,14 +96,6 @@ class RobotClient:
|
||||
|
||||
lerobot_features = map_robot_keys_to_lerobot_features(self.robot)
|
||||
|
||||
if config.verify_robot_cameras:
|
||||
# Load policy config for validation
|
||||
policy_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path)
|
||||
policy_image_features = policy_config.image_features
|
||||
|
||||
# The cameras specified for inference must match the one supported by the policy chosen
|
||||
validate_robot_cameras_for_policy(lerobot_features, policy_image_features)
|
||||
|
||||
# Use environment variable if server_address is not provided in config
|
||||
self.server_address = config.server_address
|
||||
|
||||
@@ -213,7 +205,7 @@ class RobotClient:
|
||||
)
|
||||
_ = self.stub.SendObservations(observation_iterator)
|
||||
obs_timestep = obs.get_timestep()
|
||||
self.logger.info(f"Sent observation #{obs_timestep} | ")
|
||||
self.logger.debug(f"Sent observation #{obs_timestep} | ")
|
||||
|
||||
return True
|
||||
|
||||
@@ -466,7 +458,7 @@ class RobotClient:
|
||||
if self._ready_to_send_observation():
|
||||
_captured_observation = self.control_loop_observation(task, verbose)
|
||||
|
||||
self.logger.info(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
|
||||
self.logger.debug(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
|
||||
# Dynamically adjust sleep time to maintain the desired control frequency
|
||||
time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))
|
||||
|
||||
@@ -31,7 +31,7 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..utils import get_cv2_backend, get_cv2_rotation
|
||||
|
||||
@@ -31,7 +31,7 @@ import numpy as np
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
from reachy2_sdk.media.camera_manager import CameraManager
|
||||
|
||||
from lerobot.errors import DeviceNotConnectedError
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig
|
||||
|
||||
@@ -29,7 +29,7 @@ try:
|
||||
except Exception as e:
|
||||
logging.info(f"Could not import realsense: {e}")
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..configs import ColorMode
|
||||
|
||||
@@ -15,19 +15,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import TypeAlias
|
||||
from typing import cast
|
||||
|
||||
from lerobot.utils.import_utils import make_device_from_device_class
|
||||
|
||||
from .camera import Camera
|
||||
from .configs import CameraConfig, Cv2Rotation
|
||||
|
||||
IndexOrPath: TypeAlias = int | Path
|
||||
|
||||
|
||||
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
|
||||
cameras = {}
|
||||
cameras: dict[str, Camera] = {}
|
||||
|
||||
for key, cfg in camera_configs.items():
|
||||
# TODO(Steven): Consider just using the make_device_from_device_class for all types
|
||||
if cfg.type == "opencv":
|
||||
from .opencv import OpenCVCamera
|
||||
|
||||
@@ -44,7 +44,10 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
|
||||
cameras[key] = Reachy2Camera(cfg)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
|
||||
try:
|
||||
cameras[key] = cast(Camera, make_device_from_device_class(cfg))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error creating camera {key} with config {cfg}: {e}") from e
|
||||
|
||||
return cameras
|
||||
|
||||
|
||||
@@ -16,9 +16,6 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot import (
|
||||
policies, # noqa: F401
|
||||
)
|
||||
from lerobot.datasets.transforms import ImageTransformsConfig
|
||||
from lerobot.datasets.video_utils import get_safe_default_codec
|
||||
|
||||
|
||||
@@ -27,9 +27,9 @@ from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
|
||||
@@ -71,9 +71,11 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
tags: list[str] | None = None
|
||||
# Add tags to your policy on the hub.
|
||||
license: str | None = None
|
||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||
pretrained_path: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.pretrained_path = None
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
class FeatureType(str, Enum):
|
||||
@@ -36,10 +35,8 @@ class NormalizationMode(str, Enum):
|
||||
MIN_MAX = "MIN_MAX"
|
||||
MEAN_STD = "MEAN_STD"
|
||||
IDENTITY = "IDENTITY"
|
||||
|
||||
|
||||
class DictLike(Protocol):
|
||||
def __getitem__(self, key: Any) -> Any: ...
|
||||
QUANTILES = "QUANTILES"
|
||||
QUANTILE10 = "QUANTILE10"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -31,15 +31,15 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_video_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
|
||||
|
||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
@@ -93,14 +93,13 @@ def update_data_df(df, src_meta, dst_meta):
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices.
|
||||
"""
|
||||
|
||||
def _update(row):
|
||||
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
|
||||
row["index"] = row["index"] + dst_meta.info["total_frames"]
|
||||
task = src_meta.tasks.iloc[row["task_index"]].name
|
||||
row["task_index"] = dst_meta.tasks.loc[task].task_index.item()
|
||||
return row
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
df["index"] = df["index"] + dst_meta.info["total_frames"]
|
||||
|
||||
return df.apply(_update, axis=1)
|
||||
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
|
||||
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def update_meta_data(
|
||||
@@ -126,27 +125,45 @@ def update_meta_data(
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
|
||||
"""
|
||||
|
||||
def _update(row):
|
||||
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"]
|
||||
row["data/chunk_index"] = row["data/chunk_index"] + data_idx["chunk"]
|
||||
row["data/file_index"] = row["data/file_index"] + data_idx["file"]
|
||||
for key, video_idx in videos_idx.items():
|
||||
row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + video_idx["chunk"]
|
||||
row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + video_idx["file"]
|
||||
row[f"videos/{key}/from_timestamp"] = (
|
||||
row[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||
)
|
||||
row[f"videos/{key}/to_timestamp"] = (
|
||||
row[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
|
||||
)
|
||||
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
for key, video_idx in videos_idx.items():
|
||||
# Store original video file indices before updating
|
||||
orig_chunk_col = f"videos/{key}/chunk_index"
|
||||
orig_file_col = f"videos/{key}/file_index"
|
||||
df["_orig_chunk"] = df[orig_chunk_col].copy()
|
||||
df["_orig_file"] = df[orig_file_col].copy()
|
||||
|
||||
row["dataset_from_index"] = row["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
row["dataset_to_index"] = row["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
|
||||
return row
|
||||
# Update chunk and file indices to point to destination
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
|
||||
return df.apply(_update, axis=1)
|
||||
# Apply per-source-file timestamp offsets
|
||||
src_to_offset = video_idx.get("src_to_offset", {})
|
||||
if src_to_offset:
|
||||
# Apply offset based on original source file
|
||||
for idx in df.index:
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||
else:
|
||||
# Fallback to simple offset (for backward compatibility)
|
||||
df[f"videos/{key}/from_timestamp"] = (
|
||||
df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||
)
|
||||
df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
|
||||
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def aggregate_datasets(
|
||||
@@ -200,6 +217,10 @@ def aggregate_datasets(
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
root=aggr_root,
|
||||
use_videos=len(video_keys) > 0,
|
||||
chunks_size=chunk_size,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
video_files_size_in_mb=video_files_size_in_mb,
|
||||
)
|
||||
|
||||
logging.info("Find all tasks")
|
||||
@@ -243,6 +264,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
Returns:
|
||||
dict: Updated videos_idx with current chunk and file indices.
|
||||
"""
|
||||
for key in videos_idx:
|
||||
videos_idx[key]["episode_duration"] = 0
|
||||
# Track offset for each source (chunk, file) pair
|
||||
videos_idx[key]["src_to_offset"] = {}
|
||||
|
||||
for key, video_idx in videos_idx.items():
|
||||
unique_chunk_file_pairs = {
|
||||
(chunk, file)
|
||||
@@ -256,6 +282,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
|
||||
chunk_idx = video_idx["chunk"]
|
||||
file_idx = video_idx["file"]
|
||||
current_offset = video_idx["latest_duration"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
@@ -270,21 +297,25 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
# If a new file is created, we don't want to increment the latest_duration
|
||||
update_latest_duration = False
|
||||
src_duration = get_video_duration_in_s(src_path)
|
||||
|
||||
if not dst_path.exists():
|
||||
# First write to this destination file
|
||||
# Store offset before incrementing
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
continue # not accumulating further, already copied the file in place
|
||||
videos_idx[key]["episode_duration"] += src_duration
|
||||
current_offset += src_duration
|
||||
continue
|
||||
|
||||
# Check file sizes before appending
|
||||
src_size = get_video_size_in_mb(src_path)
|
||||
dst_size = get_video_size_in_mb(dst_path)
|
||||
src_size = get_file_size_in_mb(src_path)
|
||||
dst_size = get_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= video_files_size_in_mb:
|
||||
# Rotate to a new chunk/file
|
||||
# Rotate to a new file, this source becomes start of new destination
|
||||
# So its offset should be 0
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
@@ -293,25 +324,22 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
# Reset offset for next file
|
||||
current_offset = src_duration
|
||||
else:
|
||||
# Get the timestamps shift for this video
|
||||
timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"]
|
||||
|
||||
# Append to existing video file
|
||||
# Append to existing video file - use current accumulated offset
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
concatenate_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
)
|
||||
# Update the latest_duration when appending (shifts timestamps!)
|
||||
update_latest_duration = not update_latest_duration
|
||||
current_offset += src_duration
|
||||
|
||||
videos_idx[key]["episode_duration"] += src_duration
|
||||
|
||||
# Update the videos_idx with the final chunk and file indices for this key
|
||||
videos_idx[key]["chunk"] = chunk_idx
|
||||
videos_idx[key]["file"] = file_idx
|
||||
|
||||
if update_latest_duration:
|
||||
videos_idx[key]["latest_duration"] += timestamps_shift_s
|
||||
|
||||
return videos_idx
|
||||
|
||||
|
||||
@@ -396,9 +424,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
videos_idx,
|
||||
)
|
||||
|
||||
for k in videos_idx:
|
||||
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||
|
||||
meta_idx = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
@@ -410,6 +435,10 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
aggr_root=dst_meta.root,
|
||||
)
|
||||
|
||||
# Increment latest_duration by the total duration added from this source dataset
|
||||
for k in videos_idx:
|
||||
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||
|
||||
return meta_idx
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,9 @@ Please, update your dataset to the new format using this command:
|
||||
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id}
|
||||
```
|
||||
|
||||
If you already have a converted version uploaded to the hub, then this error might be because of
|
||||
an older version in your local cache. Consider deleting the cached version and retrying.
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
@@ -17,6 +17,179 @@ import numpy as np
|
||||
|
||||
from lerobot.datasets.utils import load_image_as_numpy
|
||||
|
||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||
|
||||
|
||||
class RunningQuantileStats:
|
||||
"""
|
||||
Maintains running statistics for batches of vectors, including mean,
|
||||
standard deviation, min, max, and approximate quantiles.
|
||||
|
||||
Statistics are computed per feature dimension and updated incrementally
|
||||
as new batches are observed. Quantiles are estimated using histograms,
|
||||
which adapt dynamically if the observed data range expands.
|
||||
"""
|
||||
|
||||
def __init__(self, quantile_list: list[float] | None = None, num_quantile_bins: int = 5000):
|
||||
self._count = 0
|
||||
self._mean = None
|
||||
self._mean_of_squares = None
|
||||
self._min = None
|
||||
self._max = None
|
||||
self._histograms = None
|
||||
self._bin_edges = None
|
||||
self._num_quantile_bins = num_quantile_bins
|
||||
|
||||
self._quantile_list = quantile_list
|
||||
if self._quantile_list is None:
|
||||
self._quantile_list = DEFAULT_QUANTILES
|
||||
self._quantile_keys = [f"q{int(q * 100):02d}" for q in self._quantile_list]
|
||||
|
||||
def update(self, batch: np.ndarray) -> None:
|
||||
"""Update the running statistics with a batch of vectors.
|
||||
|
||||
Args:
|
||||
batch: An array where all dimensions except the last are batch dimensions.
|
||||
"""
|
||||
batch = batch.reshape(-1, batch.shape[-1])
|
||||
num_elements, vector_length = batch.shape
|
||||
|
||||
if self._count == 0:
|
||||
self._mean = np.mean(batch, axis=0)
|
||||
self._mean_of_squares = np.mean(batch**2, axis=0)
|
||||
self._min = np.min(batch, axis=0)
|
||||
self._max = np.max(batch, axis=0)
|
||||
self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
|
||||
self._bin_edges = [
|
||||
np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
|
||||
for i in range(vector_length)
|
||||
]
|
||||
else:
|
||||
if vector_length != self._mean.size:
|
||||
raise ValueError("The length of new vectors does not match the initialized vector length.")
|
||||
|
||||
new_max = np.max(batch, axis=0)
|
||||
new_min = np.min(batch, axis=0)
|
||||
max_changed = np.any(new_max > self._max)
|
||||
min_changed = np.any(new_min < self._min)
|
||||
self._max = np.maximum(self._max, new_max)
|
||||
self._min = np.minimum(self._min, new_min)
|
||||
|
||||
if max_changed or min_changed:
|
||||
self._adjust_histograms()
|
||||
|
||||
self._count += num_elements
|
||||
|
||||
batch_mean = np.mean(batch, axis=0)
|
||||
batch_mean_of_squares = np.mean(batch**2, axis=0)
|
||||
|
||||
# Update running mean and mean of squares
|
||||
self._mean += (batch_mean - self._mean) * (num_elements / self._count)
|
||||
self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (
|
||||
num_elements / self._count
|
||||
)
|
||||
|
||||
self._update_histograms(batch)
|
||||
|
||||
def get_statistics(self) -> dict[str, np.ndarray]:
|
||||
"""Compute and return the statistics of the vectors processed so far.
|
||||
|
||||
Args:
|
||||
quantiles: List of quantiles to compute (e.g., [0.01, 0.10, 0.50, 0.90, 0.99]). If None, no quantiles computed.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the computed statistics.
|
||||
"""
|
||||
if self._count < 2:
|
||||
raise ValueError("Cannot compute statistics for less than 2 vectors.")
|
||||
|
||||
variance = self._mean_of_squares - self._mean**2
|
||||
|
||||
stddev = np.sqrt(np.maximum(0, variance))
|
||||
|
||||
stats = {
|
||||
"min": self._min.copy(),
|
||||
"max": self._max.copy(),
|
||||
"mean": self._mean.copy(),
|
||||
"std": stddev,
|
||||
"count": np.array([self._count]),
|
||||
}
|
||||
|
||||
quantile_results = self._compute_quantiles()
|
||||
for i, q in enumerate(self._quantile_keys):
|
||||
stats[q] = quantile_results[i]
|
||||
|
||||
return stats
|
||||
|
||||
def _adjust_histograms(self):
|
||||
"""Adjust histograms when min or max changes."""
|
||||
for i in range(len(self._histograms)):
|
||||
old_edges = self._bin_edges[i]
|
||||
old_hist = self._histograms[i]
|
||||
|
||||
# Create new edges with small padding to ensure range coverage
|
||||
padding = (self._max[i] - self._min[i]) * 1e-10
|
||||
new_edges = np.linspace(
|
||||
self._min[i] - padding, self._max[i] + padding, self._num_quantile_bins + 1
|
||||
)
|
||||
|
||||
# Redistribute existing histogram counts to new bins
|
||||
# We need to map each old bin center to the new bins
|
||||
old_centers = (old_edges[:-1] + old_edges[1:]) / 2
|
||||
new_hist = np.zeros(self._num_quantile_bins)
|
||||
|
||||
for old_center, count in zip(old_centers, old_hist, strict=False):
|
||||
if count > 0:
|
||||
# Find which new bin this old center belongs to
|
||||
bin_idx = np.searchsorted(new_edges, old_center) - 1
|
||||
bin_idx = max(0, min(bin_idx, self._num_quantile_bins - 1))
|
||||
new_hist[bin_idx] += count
|
||||
|
||||
self._histograms[i] = new_hist
|
||||
self._bin_edges[i] = new_edges
|
||||
|
||||
def _update_histograms(self, batch: np.ndarray) -> None:
|
||||
"""Update histograms with new vectors."""
|
||||
for i in range(batch.shape[1]):
|
||||
hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
|
||||
self._histograms[i] += hist
|
||||
|
||||
def _compute_quantiles(self) -> list[np.ndarray]:
|
||||
"""Compute quantiles based on histograms."""
|
||||
results = []
|
||||
for q in self._quantile_list:
|
||||
target_count = q * self._count
|
||||
q_values = []
|
||||
|
||||
for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
|
||||
q_value = self._compute_single_quantile(hist, edges, target_count)
|
||||
q_values.append(q_value)
|
||||
|
||||
results.append(np.array(q_values))
|
||||
return results
|
||||
|
||||
def _compute_single_quantile(self, hist: np.ndarray, edges: np.ndarray, target_count: float) -> float:
|
||||
"""Compute a single quantile value from histogram and bin edges."""
|
||||
cumsum = np.cumsum(hist)
|
||||
idx = np.searchsorted(cumsum, target_count)
|
||||
|
||||
if idx == 0:
|
||||
return edges[0]
|
||||
if idx >= len(cumsum):
|
||||
return edges[-1]
|
||||
|
||||
# If not edge case, interpolate within the bin
|
||||
count_before = cumsum[idx - 1]
|
||||
count_in_bin = cumsum[idx] - count_before
|
||||
|
||||
# If no samples in this bin, use the bin edge
|
||||
if count_in_bin == 0:
|
||||
return edges[idx]
|
||||
|
||||
# Linear interpolation within the bin
|
||||
fraction = (target_count - count_before) / count_in_bin
|
||||
return edges[idx] + fraction * (edges[idx + 1] - edges[idx])
|
||||
|
||||
|
||||
def estimate_num_samples(
|
||||
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||
@@ -72,33 +245,282 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
return images
|
||||
|
||||
|
||||
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
||||
return {
|
||||
"min": np.min(array, axis=axis, keepdims=keepdims),
|
||||
"max": np.max(array, axis=axis, keepdims=keepdims),
|
||||
"mean": np.mean(array, axis=axis, keepdims=keepdims),
|
||||
"std": np.std(array, axis=axis, keepdims=keepdims),
|
||||
"count": np.array([len(array)]),
|
||||
def _reshape_stats_by_axis(
|
||||
stats: dict[str, np.ndarray],
|
||||
axis: int | tuple[int, ...] | None,
|
||||
keepdims: bool,
|
||||
original_shape: tuple[int, ...],
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Reshape all statistics to match NumPy's output conventions.
|
||||
|
||||
Applies consistent reshaping to all statistics (except 'count') based on the
|
||||
axis and keepdims parameters. This ensures statistics have the correct shape
|
||||
for broadcasting with the original data.
|
||||
|
||||
Args:
|
||||
stats: Dictionary of computed statistics
|
||||
axis: Axis or axes along which statistics were computed
|
||||
keepdims: Whether to keep reduced dimensions as size-1 dimensions
|
||||
original_shape: Shape of the original array
|
||||
|
||||
Returns:
|
||||
Dictionary with reshaped statistics
|
||||
|
||||
Note:
|
||||
The 'count' statistic is never reshaped as it represents metadata
|
||||
rather than per-feature statistics.
|
||||
"""
|
||||
if axis == (1,) and not keepdims:
|
||||
return stats
|
||||
|
||||
result = {}
|
||||
for key, value in stats.items():
|
||||
if key == "count":
|
||||
result[key] = value
|
||||
else:
|
||||
result[key] = _reshape_single_stat(value, axis, keepdims, original_shape)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _reshape_for_image_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
|
||||
"""Reshape statistics for image data (axis=(0,2,3))."""
|
||||
if keepdims and value.ndim == 1:
|
||||
return value.reshape(1, -1, 1, 1)
|
||||
return value
|
||||
|
||||
|
||||
def _reshape_for_vector_stats(
|
||||
value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
|
||||
) -> np.ndarray:
|
||||
"""Reshape statistics for vector data (axis=0 or axis=(0,))."""
|
||||
if not keepdims:
|
||||
return value
|
||||
|
||||
if len(original_shape) == 1 and value.ndim > 0:
|
||||
return value.reshape(1)
|
||||
elif len(original_shape) >= 2 and value.ndim == 1:
|
||||
return value.reshape(1, -1)
|
||||
return value
|
||||
|
||||
|
||||
def _reshape_for_feature_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
|
||||
"""Reshape statistics for feature-wise computation (axis=(1,))."""
|
||||
if not keepdims:
|
||||
return value
|
||||
|
||||
if value.ndim == 0:
|
||||
return value.reshape(1, 1)
|
||||
elif value.ndim == 1:
|
||||
return value.reshape(-1, 1)
|
||||
return value
|
||||
|
||||
|
||||
def _reshape_for_global_stats(
|
||||
value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
|
||||
) -> np.ndarray | float:
|
||||
"""Reshape statistics for global reduction (axis=None)."""
|
||||
if keepdims:
|
||||
target_shape = tuple(1 for _ in original_shape)
|
||||
return value.reshape(target_shape)
|
||||
# Keep at least 1-D arrays to satisfy validator
|
||||
return np.atleast_1d(value)
|
||||
|
||||
|
||||
def _reshape_single_stat(
|
||||
value: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...]
|
||||
) -> np.ndarray | float:
|
||||
"""Apply appropriate reshaping to a single statistic array.
|
||||
|
||||
This function transforms statistic arrays to match expected output shapes
|
||||
based on the axis configuration and keepdims parameter.
|
||||
|
||||
Args:
|
||||
value: The statistic array to reshape
|
||||
axis: Axis or axes that were reduced during computation
|
||||
keepdims: Whether to maintain reduced dimensions as size-1 dimensions
|
||||
original_shape: Shape of the original data before reduction
|
||||
|
||||
Returns:
|
||||
Reshaped array following NumPy broadcasting conventions
|
||||
|
||||
"""
|
||||
if axis == (0, 2, 3):
|
||||
return _reshape_for_image_stats(value, keepdims)
|
||||
|
||||
if axis in [0, (0,)]:
|
||||
return _reshape_for_vector_stats(value, keepdims, original_shape)
|
||||
|
||||
if axis == (1,):
|
||||
return _reshape_for_feature_stats(value, keepdims)
|
||||
|
||||
if axis is None:
|
||||
return _reshape_for_global_stats(value, keepdims, original_shape)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _prepare_array_for_stats(array: np.ndarray, axis: int | tuple[int, ...] | None) -> tuple[np.ndarray, int]:
|
||||
"""Prepare array for statistics computation by reshaping according to axis.
|
||||
|
||||
Args:
|
||||
array: Input data array
|
||||
axis: Axis or axes along which to compute statistics
|
||||
|
||||
Returns:
|
||||
Tuple of (reshaped_array, sample_count)
|
||||
"""
|
||||
if axis == (0, 2, 3): # Image data
|
||||
batch_size, channels, height, width = array.shape
|
||||
reshaped = array.transpose(0, 2, 3, 1).reshape(-1, channels)
|
||||
return reshaped, batch_size
|
||||
|
||||
if axis == 0 or axis == (0,): # Vector data
|
||||
reshaped = array
|
||||
if array.ndim == 1:
|
||||
reshaped = array.reshape(-1, 1)
|
||||
return reshaped, array.shape[0]
|
||||
|
||||
if axis == (1,): # Feature-wise statistics
|
||||
return array.T, array.shape[1]
|
||||
|
||||
if axis is None: # Global statistics
|
||||
reshaped = array.reshape(-1, 1)
|
||||
# For backward compatibility, count represents the first dimension size
|
||||
return reshaped, array.shape[0] if array.ndim > 0 else 1
|
||||
|
||||
raise ValueError(f"Unsupported axis configuration: {axis}")
|
||||
|
||||
|
||||
def _compute_basic_stats(
|
||||
array: np.ndarray, sample_count: int, quantile_list: list[float] | None = None
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Compute basic statistics for arrays with insufficient samples for quantiles.
|
||||
|
||||
Args:
|
||||
array: Reshaped array ready for statistics computation
|
||||
sample_count: Number of samples represented in the data
|
||||
|
||||
Returns:
|
||||
Dictionary with basic statistics and quantiles set to mean values
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
quantile_list_keys = [f"q{int(q * 100):02d}" for q in quantile_list]
|
||||
|
||||
stats = {
|
||||
"min": np.min(array, axis=0),
|
||||
"max": np.max(array, axis=0),
|
||||
"mean": np.mean(array, axis=0),
|
||||
"std": np.std(array, axis=0),
|
||||
"count": np.array([sample_count]),
|
||||
}
|
||||
|
||||
for q in quantile_list_keys:
|
||||
stats[q] = stats["mean"].copy()
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def get_feature_stats(
|
||||
array: np.ndarray,
|
||||
axis: int | tuple[int, ...] | None,
|
||||
keepdims: bool,
|
||||
quantile_list: list[float] | None = None,
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Compute comprehensive statistics for array features along specified axes.
|
||||
|
||||
This function calculates min, max, mean, std, and quantiles (1%, 10%, 50%, 90%, 99%)
|
||||
for the input array along the specified axes. It handles different data layouts:
|
||||
- Image data: axis=(0,2,3) computes per-channel statistics
|
||||
- Vector data: axis=0 computes per-feature statistics
|
||||
- Feature-wise: axis=1 computes statistics across features
|
||||
- Global: axis=None computes statistics over entire array
|
||||
|
||||
Args:
|
||||
array: Input data array with shape appropriate for the specified axis
|
||||
axis: Axis or axes along which to compute statistics
|
||||
- (0, 2, 3): For image data (batch, channels, height, width)
|
||||
- 0 or (0,): For vector/tabular data (samples, features)
|
||||
- (1,): For computing across features
|
||||
- None: For global statistics over entire array
|
||||
keepdims: If True, reduced axes are kept as dimensions with size 1
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- 'min': Minimum values
|
||||
- 'max': Maximum values
|
||||
- 'mean': Mean values
|
||||
- 'std': Standard deviation
|
||||
- 'count': Number of samples (always shape (1,))
|
||||
- 'q01', 'q10', 'q50', 'q90', 'q99': Quantile values
|
||||
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
|
||||
original_shape = array.shape
|
||||
reshaped, sample_count = _prepare_array_for_stats(array, axis)
|
||||
|
||||
if reshaped.shape[0] < 2:
|
||||
stats = _compute_basic_stats(reshaped, sample_count, quantile_list)
|
||||
else:
|
||||
running_stats = RunningQuantileStats()
|
||||
running_stats.update(reshaped)
|
||||
stats = running_stats.get_statistics()
|
||||
stats["count"] = np.array([sample_count])
|
||||
|
||||
stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
|
||||
return stats
|
||||
|
||||
|
||||
def compute_episode_stats(
|
||||
episode_data: dict[str, list[str] | np.ndarray],
|
||||
features: dict,
|
||||
quantile_list: list[float] | None = None,
|
||||
) -> dict:
|
||||
"""Compute comprehensive statistics for all features in an episode.
|
||||
|
||||
Processes different data types appropriately:
|
||||
- Images/videos: Samples from paths, computes per-channel stats, normalizes to [0,1]
|
||||
- Numerical arrays: Computes per-feature statistics
|
||||
- Strings: Skipped (no statistics computed)
|
||||
|
||||
Args:
|
||||
episode_data: Dictionary mapping feature names to data
|
||||
- For images/videos: list of file paths
|
||||
- For numerical data: numpy arrays
|
||||
features: Dictionary describing each feature's dtype and shape
|
||||
|
||||
Returns:
|
||||
Dictionary mapping feature names to their statistics dictionaries.
|
||||
Each statistics dictionary contains min, max, mean, std, count, and quantiles.
|
||||
|
||||
Note:
|
||||
Image statistics are normalized to [0,1] range and have shape (3,1,1) for
|
||||
per-channel values when dtype is 'image' or 'video'.
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
|
||||
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
continue # HACK: we should receive np.arrays of strings
|
||||
elif features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data) # data is a list of image paths
|
||||
axes_to_reduce = (0, 2, 3) # keep channel dim
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data)
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
keepdims = True
|
||||
else:
|
||||
ep_ft_array = data # data is already a np.ndarray
|
||||
axes_to_reduce = 0 # compute stats over the first axis
|
||||
keepdims = data.ndim == 1 # keep as np.array
|
||||
ep_ft_array = data
|
||||
axes_to_reduce = 0
|
||||
keepdims = data.ndim == 1
|
||||
|
||||
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
||||
ep_stats[key] = get_feature_stats(
|
||||
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list
|
||||
)
|
||||
|
||||
# finally, we normalize and remove batch dim for images
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
@@ -107,20 +529,37 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
|
||||
return ep_stats
|
||||
|
||||
|
||||
def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
|
||||
"""Validate a single statistic value."""
|
||||
if not isinstance(value, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Stats must be composed of numpy array, but key '{key}' of feature '{feature_key}' "
|
||||
f"is of type '{type(value)}' instead."
|
||||
)
|
||||
|
||||
if value.ndim == 0:
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
|
||||
if key == "count" and value.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
|
||||
|
||||
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
|
||||
|
||||
|
||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
for i in range(len(stats_list)):
|
||||
for fkey in stats_list[i]:
|
||||
for k, v in stats_list[i][fkey].items():
|
||||
if not isinstance(v, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
|
||||
)
|
||||
if v.ndim == 0:
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
if k == "count" and v.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
|
||||
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
||||
"""Validate that all statistics have correct types and shapes.
|
||||
|
||||
Args:
|
||||
stats_list: List of statistics dictionaries to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If any statistic has incorrect type or shape
|
||||
"""
|
||||
for stats in stats_list:
|
||||
for feature_key, feature_stats in stats.items():
|
||||
for stat_key, stat_value in feature_stats.items():
|
||||
_validate_stat_value(stat_value, stat_key, feature_key)
|
||||
|
||||
|
||||
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
@@ -143,7 +582,7 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
|
||||
weighted_variances = (variances + delta_means**2) * counts
|
||||
total_variance = weighted_variances.sum(axis=0) / total_count
|
||||
|
||||
return {
|
||||
aggregated = {
|
||||
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
|
||||
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
|
||||
"mean": total_mean,
|
||||
@@ -151,6 +590,17 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
|
||||
"count": total_count,
|
||||
}
|
||||
|
||||
if stats_ft_list:
|
||||
quantile_keys = [k for k in stats_ft_list[0] if k.startswith("q") and k[1:].isdigit()]
|
||||
|
||||
for q_key in quantile_keys:
|
||||
if all(q_key in s for s in stats_ft_list):
|
||||
quantile_values = np.stack([s[q_key] for s in stats_ft_list])
|
||||
weighted_quantiles = quantile_values * counts
|
||||
aggregated[q_key] = weighted_quantiles.sum(axis=0) / total_count
|
||||
|
||||
return aggregated
|
||||
|
||||
|
||||
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||
|
||||
1015
src/lerobot/datasets/dataset_tools.py
Normal file
1015
src/lerobot/datasets/dataset_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -27,6 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
|
||||
)
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
||||
|
||||
IMAGENET_STATS = {
|
||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||
@@ -54,11 +55,11 @@ def resolve_delta_timestamps(
|
||||
"""
|
||||
delta_timestamps = {}
|
||||
for key in ds_meta.features:
|
||||
if key == "next.reward" and cfg.reward_delta_indices is not None:
|
||||
if key == REWARD and cfg.reward_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
||||
if key == "action" and cfg.action_delta_indices is not None:
|
||||
if key == ACTION and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
|
||||
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
|
||||
if len(delta_timestamps) == 0:
|
||||
|
||||
@@ -68,7 +68,30 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
|
||||
"""
|
||||
Saves a NumPy array or PIL Image to a file.
|
||||
|
||||
This function handles both NumPy arrays and PIL Image objects, converting
|
||||
the former to a PIL Image before saving. It includes error handling for
|
||||
the save operation.
|
||||
|
||||
Args:
|
||||
image (np.ndarray | PIL.Image.Image): The image data to save.
|
||||
fpath (Path): The destination file path for the image.
|
||||
compress_level (int, optional): The compression level for the saved
|
||||
image, as used by PIL.Image.save(). Defaults to 1.
|
||||
Refer to: https://github.com/huggingface/lerobot/pull/2135
|
||||
for more details on the default value rationale.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input 'image' is not a NumPy array or a
|
||||
PIL.Image.Image object.
|
||||
|
||||
Side Effects:
|
||||
Prints an error message to the console if the image writing process
|
||||
fails for any reason.
|
||||
"""
|
||||
try:
|
||||
if isinstance(image, np.ndarray):
|
||||
img = image_array_to_pil_image(image)
|
||||
@@ -76,7 +99,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath)
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
except Exception as e:
|
||||
print(f"Error writing image {fpath}: {e}")
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import gc
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
@@ -26,12 +25,13 @@ import numpy as np
|
||||
import packaging.version
|
||||
import pandas as pd
|
||||
import PIL.Image
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.datasets.utils import (
|
||||
@@ -47,13 +47,9 @@ from lerobot.datasets.utils import (
|
||||
embed_images,
|
||||
flatten_dict,
|
||||
get_delta_indices,
|
||||
get_hf_dataset_cache_dir,
|
||||
get_hf_dataset_size_in_mb,
|
||||
get_file_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_parquet_num_frames,
|
||||
get_safe_version,
|
||||
get_video_size_in_mb,
|
||||
hf_transform_to_torch,
|
||||
is_valid_version,
|
||||
load_episodes,
|
||||
@@ -61,7 +57,6 @@ from lerobot.datasets.utils import (
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
@@ -79,6 +74,7 @@ from lerobot.datasets.video_utils import (
|
||||
get_video_duration_in_s,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
@@ -90,10 +86,15 @@ class LeRobotDatasetMetadata:
|
||||
root: str | Path | None = None,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
metadata_buffer_size: int = 10,
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self.metadata_buffer: list[dict] = []
|
||||
self.metadata_buffer_size = metadata_buffer_size
|
||||
|
||||
try:
|
||||
if force_cache_sync:
|
||||
@@ -107,6 +108,54 @@ class LeRobotDatasetMetadata:
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
|
||||
def _flush_metadata_buffer(self) -> None:
|
||||
"""Write all buffered episode metadata to parquet file."""
|
||||
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
|
||||
return
|
||||
|
||||
combined_dict = {}
|
||||
for episode_dict in self.metadata_buffer:
|
||||
for key, value in episode_dict.items():
|
||||
if key not in combined_dict:
|
||||
combined_dict[key] = []
|
||||
# Extract value and serialize numpy arrays
|
||||
# because PyArrow's from_pydict function doesn't support numpy arrays
|
||||
val = value[0] if isinstance(value, list) else value
|
||||
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
|
||||
|
||||
first_ep = self.metadata_buffer[0]
|
||||
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
|
||||
file_idx = first_ep["meta/episodes/file_index"][0]
|
||||
|
||||
table = pa.Table.from_pydict(combined_dict)
|
||||
|
||||
if not self.writer:
|
||||
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True
|
||||
)
|
||||
|
||||
self.writer.write_table(table)
|
||||
|
||||
self.latest_episode = self.metadata_buffer[-1]
|
||||
self.metadata_buffer.clear()
|
||||
|
||||
def _close_writer(self) -> None:
|
||||
"""Close and cleanup the parquet writer if it exists."""
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
writer = getattr(self, "writer", None)
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
self.writer = None
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
|
||||
"""
|
||||
self._close_writer()
|
||||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
@@ -138,6 +187,12 @@ class LeRobotDatasetMetadata:
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep["data/chunk_index"]
|
||||
file_idx = ep["data/file_index"]
|
||||
@@ -145,6 +200,12 @@ class LeRobotDatasetMetadata:
|
||||
return Path(fpath)
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
|
||||
file_idx = ep[f"videos/{vid_key}/file_index"]
|
||||
@@ -260,72 +321,75 @@ class LeRobotDatasetMetadata:
|
||||
write_tasks(self.tasks, self.root)
|
||||
|
||||
def _save_episode_metadata(self, episode_dict: dict) -> None:
|
||||
"""Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata.
|
||||
"""Buffer episode metadata and write to parquet in batches for efficiency.
|
||||
|
||||
This function processes episodes metadata from a dictionary, converts it into a Hugging Face dataset,
|
||||
and saves it as a parquet file. It handles both the creation of new parquet files and the
|
||||
updating of existing ones based on size constraints. After saving the metadata, it reloads
|
||||
the Hugging Face dataset to ensure it is up-to-date.
|
||||
This function accumulates episode metadata in a buffer and flushes it when the buffer
|
||||
reaches the configured size. This reduces I/O overhead by writing multiple episodes
|
||||
at once instead of one row at a time.
|
||||
|
||||
Notes: We both need to update parquet files and HF dataset:
|
||||
- `pandas` loads parquet file in RAM
|
||||
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
|
||||
or loads directly from pyarrow cache.
|
||||
"""
|
||||
# Convert buffer into HF Dataset
|
||||
# Convert to list format for each value
|
||||
episode_dict = {key: [value] for key, value in episode_dict.items()}
|
||||
ep_dataset = datasets.Dataset.from_dict(episode_dict)
|
||||
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
|
||||
df = pd.DataFrame(ep_dataset)
|
||||
num_frames = episode_dict["length"][0]
|
||||
|
||||
if self.episodes is None:
|
||||
if self.latest_episode is None:
|
||||
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
df["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
df["meta/episodes/file_index"] = [file_idx]
|
||||
df["dataset_from_index"] = [0]
|
||||
df["dataset_to_index"] = [num_frames]
|
||||
else:
|
||||
# Retrieve information from the latest parquet file
|
||||
latest_ep = self.episodes[-1]
|
||||
chunk_idx = latest_ep["meta/episodes/chunk_index"]
|
||||
file_idx = latest_ep["meta/episodes/file_index"]
|
||||
if self.episodes is not None and len(self.episodes) > 0:
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"]
|
||||
file_idx = self.episodes[-1]["meta/episodes/file_index"]
|
||||
latest_num_frames = self.episodes[-1]["dataset_to_index"]
|
||||
episode_dict["dataset_from_index"] = [latest_num_frames]
|
||||
episode_dict["dataset_to_index"] = [latest_num_frames + num_frames]
|
||||
|
||||
latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
|
||||
|
||||
if latest_size_in_mb + ep_size_in_mb >= self.data_files_size_in_mb:
|
||||
# Size limit is reached, prepare new parquet file
|
||||
# When resuming, move to the next file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||
else:
|
||||
episode_dict["dataset_from_index"] = [0]
|
||||
episode_dict["dataset_to_index"] = [num_frames]
|
||||
|
||||
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
episode_dict["meta/episodes/file_index"] = [file_idx]
|
||||
else:
|
||||
chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0]
|
||||
file_idx = self.latest_episode["meta/episodes/file_index"][0]
|
||||
|
||||
latest_path = (
|
||||
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
if self.writer is None
|
||||
else self.writer.where
|
||||
)
|
||||
|
||||
if Path(latest_path).exists():
|
||||
latest_size_in_mb = get_file_size_in_mb(Path(latest_path))
|
||||
latest_num_frames = self.latest_episode["episode_index"][0]
|
||||
|
||||
av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0
|
||||
|
||||
if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb:
|
||||
# Size limit is reached, flush buffer and prepare new parquet file
|
||||
self._flush_metadata_buffer()
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||
self._close_writer()
|
||||
|
||||
# Update the existing pandas dataframe with new row
|
||||
df["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
df["meta/episodes/file_index"] = [file_idx]
|
||||
df["dataset_from_index"] = [latest_ep["dataset_to_index"]]
|
||||
df["dataset_to_index"] = [latest_ep["dataset_to_index"] + num_frames]
|
||||
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
episode_dict["meta/episodes/file_index"] = [file_idx]
|
||||
episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]]
|
||||
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
|
||||
|
||||
if latest_size_in_mb + ep_size_in_mb < self.data_files_size_in_mb:
|
||||
# Size limit wasnt reached, concatenate latest dataframe with new one
|
||||
latest_df = pd.read_parquet(latest_path)
|
||||
df = pd.concat([latest_df, df], ignore_index=True)
|
||||
# Add to buffer
|
||||
self.metadata_buffer.append(episode_dict)
|
||||
self.latest_episode = episode_dict
|
||||
|
||||
# Memort optimization
|
||||
del latest_df
|
||||
gc.collect()
|
||||
|
||||
# Write the resulting dataframe from RAM to disk
|
||||
path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(path, index=False)
|
||||
|
||||
if self.episodes is not None:
|
||||
# Remove the episodes cache directory, necessary to avoid cache bloat
|
||||
cached_dir = get_hf_dataset_cache_dir(self.episodes)
|
||||
if cached_dir is not None:
|
||||
shutil.rmtree(cached_dir)
|
||||
|
||||
self.episodes = load_episodes(self.root)
|
||||
if len(self.metadata_buffer) >= self.metadata_buffer_size:
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
@@ -438,6 +502,10 @@ class LeRobotDatasetMetadata:
|
||||
robot_type: str | None = None,
|
||||
root: str | Path | None = None,
|
||||
use_videos: bool = True,
|
||||
metadata_buffer_size: int = 10,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> "LeRobotDatasetMetadata":
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
obj = cls.__new__(cls)
|
||||
@@ -452,11 +520,24 @@ class LeRobotDatasetMetadata:
|
||||
obj.tasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
|
||||
obj.info = create_empty_dataset_info(
|
||||
CODEBASE_VERSION,
|
||||
fps,
|
||||
features,
|
||||
use_videos,
|
||||
robot_type,
|
||||
chunks_size,
|
||||
data_files_size_in_mb,
|
||||
video_files_size_in_mb,
|
||||
)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.revision = None
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
obj.metadata_buffer = []
|
||||
obj.metadata_buffer_size = metadata_buffer_size
|
||||
return obj
|
||||
|
||||
|
||||
@@ -603,6 +684,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
self.episode_buffer = None
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@@ -611,6 +694,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
|
||||
# Track dataset state for efficient incremental writing
|
||||
self._lazy_loading = False
|
||||
self._recorded_frames = self.meta.total_frames
|
||||
self._writer_closed_for_reading = False
|
||||
|
||||
# Load actual data
|
||||
try:
|
||||
if force_cache_sync:
|
||||
@@ -629,6 +717,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
def _close_writer(self) -> None:
|
||||
"""Close and cleanup the parquet writer if it exists."""
|
||||
writer = getattr(self, "writer", None)
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
self.writer = None
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
|
||||
"""
|
||||
self._close_writer()
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
branch: str | None = None,
|
||||
@@ -769,8 +870,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Number of frames in selected episodes."""
|
||||
return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
|
||||
"""Number of frames in selected episodes.
|
||||
|
||||
Note: When episodes a subset of the full dataset is requested, we must return the
|
||||
actual loaded data length (len(self.hf_dataset)) rather than metadata total_frames.
|
||||
self.meta.total_frames is the total number of frames in the full dataset.
|
||||
"""
|
||||
if self.episodes is not None and self.hf_dataset is not None:
|
||||
return len(self.hf_dataset)
|
||||
return self.meta.total_frames
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
@@ -848,15 +956,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return item
|
||||
|
||||
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
|
||||
for key, val in padding.items():
|
||||
item[key] = torch.BoolTensor(val)
|
||||
return item
|
||||
def _ensure_hf_dataset_loaded(self):
|
||||
"""Lazy load the HF dataset only when needed for reading."""
|
||||
if self._lazy_loading or self.hf_dataset is None:
|
||||
# Close the writer before loading to ensure parquet file is properly finalized
|
||||
if self.writer is not None:
|
||||
self._close_writer()
|
||||
self._writer_closed_for_reading = True
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self._lazy_loading = False
|
||||
|
||||
def __len__(self):
|
||||
return self.num_frames
|
||||
|
||||
def __getitem__(self, idx) -> dict:
|
||||
# Ensure dataset is loaded when we actually need to read from it
|
||||
self._ensure_hf_dataset_loaded()
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
@@ -895,6 +1010,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"})',\n"
|
||||
)
|
||||
|
||||
def finalize(self):
|
||||
"""
|
||||
Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files.
|
||||
The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo))
|
||||
"""
|
||||
self._close_writer()
|
||||
self.meta._close_writer()
|
||||
|
||||
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
ep_buffer = {}
|
||||
@@ -1032,7 +1155,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
|
||||
self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0)
|
||||
|
||||
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None):
|
||||
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
|
||||
"""
|
||||
Batch save videos for multiple episodes.
|
||||
|
||||
@@ -1102,74 +1225,101 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ep_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train")
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
|
||||
ep_num_frames = len(ep_dataset)
|
||||
df = pd.DataFrame(ep_dataset)
|
||||
|
||||
if self.meta.episodes is None:
|
||||
if self.latest_episode is None:
|
||||
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
latest_num_frames = 0
|
||||
global_frame_index = 0
|
||||
# However, if the episodes already exists
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
if self.meta.episodes is not None and len(self.meta.episodes) > 0:
|
||||
latest_ep = self.meta.episodes[-1]
|
||||
global_frame_index = latest_ep["dataset_to_index"]
|
||||
chunk_idx = latest_ep["data/chunk_index"]
|
||||
file_idx = latest_ep["data/file_index"]
|
||||
|
||||
# When resuming, move to the next file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||
else:
|
||||
# Retrieve information from the latest parquet file
|
||||
latest_ep = self.meta.episodes[-1]
|
||||
latest_ep = self.latest_episode
|
||||
chunk_idx = latest_ep["data/chunk_index"]
|
||||
file_idx = latest_ep["data/file_index"]
|
||||
global_frame_index = latest_ep["index"][-1] + 1
|
||||
|
||||
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
|
||||
latest_num_frames = get_parquet_num_frames(latest_path)
|
||||
latest_size_in_mb = get_file_size_in_mb(latest_path)
|
||||
|
||||
frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"]
|
||||
av_size_per_frame = (
|
||||
latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
|
||||
)
|
||||
|
||||
# Determine if a new parquet file is needed
|
||||
if latest_size_in_mb + ep_size_in_mb >= self.meta.data_files_size_in_mb:
|
||||
# Size limit is reached, prepare new parquet file
|
||||
if (
|
||||
latest_size_in_mb + av_size_per_frame * ep_num_frames >= self.meta.data_files_size_in_mb
|
||||
or self._writer_closed_for_reading
|
||||
):
|
||||
# Size limit is reached or writer was closed for reading, prepare new parquet file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||
latest_num_frames = 0
|
||||
else:
|
||||
# Update the existing parquet file with new rows
|
||||
latest_df = pd.read_parquet(latest_path)
|
||||
df = pd.concat([latest_df, df], ignore_index=True)
|
||||
self._close_writer()
|
||||
self._writer_closed_for_reading = False
|
||||
|
||||
# Memort optimization
|
||||
del latest_df
|
||||
gc.collect()
|
||||
ep_dict["data/chunk_index"] = chunk_idx
|
||||
ep_dict["data/file_index"] = file_idx
|
||||
|
||||
# Write the resulting dataframe from RAM to disk
|
||||
path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if len(self.meta.image_keys) > 0:
|
||||
to_parquet_with_hf_images(df, path)
|
||||
else:
|
||||
df.to_parquet(path)
|
||||
|
||||
if self.hf_dataset is not None:
|
||||
# Remove hf dataset cache directory, necessary to avoid cache bloat
|
||||
cached_dir = get_hf_dataset_cache_dir(self.hf_dataset)
|
||||
if cached_dir is not None:
|
||||
shutil.rmtree(cached_dir)
|
||||
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
table = ep_dataset.with_format("arrow")[:]
|
||||
if not self.writer:
|
||||
self.writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True
|
||||
)
|
||||
self.writer.write_table(table)
|
||||
|
||||
metadata = {
|
||||
"data/chunk_index": chunk_idx,
|
||||
"data/file_index": file_idx,
|
||||
"dataset_from_index": latest_num_frames,
|
||||
"dataset_to_index": latest_num_frames + ep_num_frames,
|
||||
"dataset_from_index": global_frame_index,
|
||||
"dataset_to_index": global_frame_index + ep_num_frames,
|
||||
}
|
||||
|
||||
# Store metadata with episode data for next episode
|
||||
self.latest_episode = {**ep_dict, **metadata}
|
||||
|
||||
# Mark that the HF dataset needs reloading (lazy loading approach)
|
||||
# This avoids expensive reloading during sequential recording
|
||||
self._lazy_loading = True
|
||||
# Update recorded frames count for efficient length tracking
|
||||
self._recorded_frames += ep_num_frames
|
||||
|
||||
return metadata
|
||||
|
||||
def _save_episode_video(self, video_key: str, episode_index: int):
|
||||
def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
|
||||
# Encode episode frames into a temporary video
|
||||
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
|
||||
ep_size_in_mb = get_video_size_in_mb(ep_path)
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
|
||||
if self.meta.episodes is None or (
|
||||
f"videos/{video_key}/chunk_index" not in self.meta.episodes.column_names
|
||||
or f"videos/{video_key}/file_index" not in self.meta.episodes.column_names
|
||||
if (
|
||||
episode_index == 0
|
||||
or self.meta.latest_episode is None
|
||||
or f"videos/{video_key}/chunk_index" not in self.meta.latest_episode
|
||||
):
|
||||
# Initialize indices for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
if self.meta.episodes is not None and len(self.meta.episodes) > 0:
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
old_chunk_idx = self.meta.episodes[-1][f"videos/{video_key}/chunk_index"]
|
||||
old_file_idx = self.meta.episodes[-1][f"videos/{video_key}/file_index"]
|
||||
chunk_idx, file_idx = update_chunk_file_indices(
|
||||
old_chunk_idx, old_file_idx, self.meta.chunks_size
|
||||
)
|
||||
latest_duration_in_s = 0.0
|
||||
new_path = self.root / self.meta.video_path.format(
|
||||
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
@@ -1177,16 +1327,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(ep_path), str(new_path))
|
||||
else:
|
||||
# Retrieve information from the latest updated video file (possibly several episodes ago)
|
||||
latest_ep = self.meta.episodes[episode_index - 1]
|
||||
chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"]
|
||||
file_idx = latest_ep[f"videos/{video_key}/file_index"]
|
||||
# Retrieve information from the latest updated video file using latest_episode
|
||||
latest_ep = self.meta.latest_episode
|
||||
chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0]
|
||||
file_idx = latest_ep[f"videos/{video_key}/file_index"][0]
|
||||
|
||||
latest_path = self.root / self.meta.video_path.format(
|
||||
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
latest_size_in_mb = get_video_size_in_mb(latest_path)
|
||||
latest_duration_in_s = get_video_duration_in_s(latest_path)
|
||||
latest_size_in_mb = get_file_size_in_mb(latest_path)
|
||||
latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0]
|
||||
|
||||
if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb:
|
||||
# Move temporary episode video to a new video file in the dataset
|
||||
@@ -1263,7 +1413,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> dict:
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""
|
||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
@@ -1320,6 +1470,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
# Initialize tracking for incremental recording
|
||||
obj._lazy_loading = False
|
||||
obj._recorded_frames = 0
|
||||
obj._writer_closed_for_reading = False
|
||||
return obj
|
||||
|
||||
|
||||
@@ -1396,11 +1552,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
|
||||
|
||||
@property
|
||||
def repo_index_to_id(self):
|
||||
"""Return the inverse mapping if repo_id_to_index."""
|
||||
return {v: k for k, v in self.repo_id_to_index}
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection.
|
||||
@@ -1431,7 +1582,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""Keys to access image and video stream from cameras."""
|
||||
keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, (datasets.Image, VideoFrame)):
|
||||
if isinstance(feats, (datasets.Image | VideoFrame)):
|
||||
keys.append(key)
|
||||
return keys
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType
|
||||
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import DataProcessorPipeline
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
|
||||
|
||||
def create_initial_features(
|
||||
@@ -92,8 +92,8 @@ def aggregate_pipeline_dataset_features(
|
||||
|
||||
# Intermediate storage for categorized and filtered features.
|
||||
processed_features: dict[str, dict[str, Any]] = {
|
||||
"action": {},
|
||||
"observation": {},
|
||||
ACTION: {},
|
||||
OBS_STR: {},
|
||||
}
|
||||
images_token = OBS_IMAGES.split(".")[-1]
|
||||
|
||||
@@ -125,17 +125,15 @@ def aggregate_pipeline_dataset_features(
|
||||
# 3. Add the feature to the appropriate group with a clean name.
|
||||
name = strip_prefix(key, PREFIXES_TO_STRIP)
|
||||
if is_action:
|
||||
processed_features["action"][name] = value
|
||||
processed_features[ACTION][name] = value
|
||||
else:
|
||||
processed_features["observation"][name] = value
|
||||
processed_features[OBS_STR][name] = value
|
||||
|
||||
# Convert the processed features into the final dataset format.
|
||||
dataset_features = {}
|
||||
if processed_features["action"]:
|
||||
dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos))
|
||||
if processed_features["observation"]:
|
||||
dataset_features.update(
|
||||
hw_to_dataset_features(processed_features["observation"], "observation", use_videos)
|
||||
)
|
||||
if processed_features[ACTION]:
|
||||
dataset_features.update(hw_to_dataset_features(processed_features[ACTION], ACTION, use_videos))
|
||||
if processed_features[OBS_STR]:
|
||||
dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos))
|
||||
|
||||
return dataset_features
|
||||
|
||||
@@ -13,67 +13,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.video_utils import encode_video_frames
|
||||
|
||||
|
||||
def concatenate_episodes(ep_dicts):
|
||||
data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
data_dict[key].append(x)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
|
||||
|
||||
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def save_image(img_array, i, out_dir):
|
||||
img = PIL.Image.fromarray(img_array)
|
||||
img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100)
|
||||
|
||||
num_images = len(imgs_array)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||
|
||||
|
||||
def get_default_encoding() -> dict:
|
||||
"""Returns the default ffmpeg encoding parameters used by `encode_video_frames`."""
|
||||
signature = inspect.signature(encode_video_frames)
|
||||
return {
|
||||
k: v.default
|
||||
for k, v in signature.parameters.items()
|
||||
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
|
||||
}
|
||||
|
||||
|
||||
def check_repo_id(repo_id: str) -> None:
|
||||
if len(repo_id.split("/")) != 2:
|
||||
raise ValueError(
|
||||
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
|
||||
(e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
|
||||
)
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
|
||||
|
||||
@@ -21,7 +21,6 @@ import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
Backtrackable,
|
||||
@@ -38,6 +37,7 @@ from lerobot.datasets.video_utils import (
|
||||
VideoDecoderCache,
|
||||
decode_video_frames_torchcodec,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
|
||||
|
||||
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
@@ -298,9 +298,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
return padding_mask
|
||||
|
||||
def make_frame(
|
||||
self, dataset_iterator: Backtrackable, previous_dataset_iterator: Backtrackable | None = None
|
||||
) -> Generator:
|
||||
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
|
||||
"""Makes a frame starting from a dataset iterator"""
|
||||
item = next(dataset_iterator)
|
||||
item = item_to_torch(item)
|
||||
|
||||
@@ -120,7 +120,7 @@ class SharpnessJitter(Transform):
|
||||
self.sharpness = self._check_input(sharpness)
|
||||
|
||||
def _check_input(self, sharpness):
|
||||
if isinstance(sharpness, (int, float)):
|
||||
if isinstance(sharpness, (int | float)):
|
||||
if sharpness < 0:
|
||||
raise ValueError("If sharpness is a single number, it must be non negative.")
|
||||
sharpness = [1.0 - sharpness, 1.0 + sharpness]
|
||||
|
||||
@@ -21,7 +21,7 @@ from collections import deque
|
||||
from collections.abc import Iterable, Iterator
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any, Deque, Generic, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -30,7 +30,7 @@ import pandas
|
||||
import pandas as pd
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from datasets import Dataset, concatenate_datasets
|
||||
from datasets import Dataset
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
@@ -43,7 +43,8 @@ from lerobot.datasets.backward_compatibility import (
|
||||
BackwardCompatibilityError,
|
||||
ForwardCompatibilityError,
|
||||
)
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
|
||||
from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_string
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
@@ -66,18 +67,6 @@ DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{fram
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
||||
LEGACY_DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
# Metadata will go there
|
||||
---
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
## {}
|
||||
|
||||
"""
|
||||
|
||||
DEFAULT_FEATURES = {
|
||||
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
@@ -105,12 +94,6 @@ def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
|
||||
return hf_ds.data.nbytes // (1024**2)
|
||||
|
||||
|
||||
def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None:
|
||||
if hf_ds.cache_files is None or len(hf_ds.cache_files) == 0:
|
||||
return None
|
||||
return Path(hf_ds.cache_files[0]["filename"]).parents[2]
|
||||
|
||||
|
||||
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
|
||||
if file_idx == chunks_size - 1:
|
||||
file_idx = 0
|
||||
@@ -134,8 +117,9 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
|
||||
datasets = [Dataset.from_parquet(str(path), features=features) for path in paths]
|
||||
return concatenate_datasets(datasets)
|
||||
with SuppressProgressBars():
|
||||
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||
return datasets
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
@@ -143,10 +127,14 @@ def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
return metadata.num_rows
|
||||
|
||||
|
||||
def get_video_size_in_mb(mp4_path: Path) -> float:
|
||||
file_size_bytes = mp4_path.stat().st_size
|
||||
file_size_mb = file_size_bytes / (1024**2)
|
||||
return file_size_mb
|
||||
def get_file_size_in_mb(file_path: Path) -> float:
|
||||
"""Get file size on disk in megabytes.
|
||||
|
||||
Args:
|
||||
file_path (Path): Path to the file.
|
||||
"""
|
||||
file_size_bytes = file_path.stat().st_size
|
||||
return file_size_bytes / (1024**2)
|
||||
|
||||
|
||||
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
@@ -218,13 +206,13 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
"""
|
||||
serialized_dict = {}
|
||||
for key, value in flatten_dict(stats).items():
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)):
|
||||
if isinstance(value, (torch.Tensor | np.ndarray)):
|
||||
serialized_dict[key] = value.tolist()
|
||||
elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
|
||||
elif isinstance(value, list) and isinstance(value[0], (int | float | list)):
|
||||
serialized_dict[key] = value
|
||||
elif isinstance(value, np.generic):
|
||||
serialized_dict[key] = value.item()
|
||||
elif isinstance(value, (int, float)):
|
||||
elif isinstance(value, (int | float)):
|
||||
serialized_dict[key] = value
|
||||
else:
|
||||
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
|
||||
@@ -382,12 +370,6 @@ def load_episodes(local_dir: Path) -> datasets.Dataset:
|
||||
return episodes
|
||||
|
||||
|
||||
def backward_compatible_episodes_stats(
|
||||
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
||||
) -> dict[int, dict[str, dict[str, np.ndarray]]]:
|
||||
return dict.fromkeys(episodes, stats)
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
|
||||
) -> np.ndarray:
|
||||
@@ -645,14 +627,14 @@ def hw_to_dataset_features(
|
||||
}
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
|
||||
if joint_fts and prefix == "action":
|
||||
if joint_fts and prefix == ACTION:
|
||||
features[prefix] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(joint_fts),),
|
||||
"names": list(joint_fts),
|
||||
}
|
||||
|
||||
if joint_fts and prefix == "observation":
|
||||
if joint_fts and prefix == OBS_STR:
|
||||
features[f"{prefix}.state"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(joint_fts),),
|
||||
@@ -728,11 +710,11 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == "observation.environment_state":
|
||||
elif key == OBS_ENV_STATE:
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith("observation"):
|
||||
elif key.startswith(OBS_STR):
|
||||
type = FeatureType.STATE
|
||||
elif key.startswith("action"):
|
||||
elif key.startswith(ACTION):
|
||||
type = FeatureType.ACTION
|
||||
else:
|
||||
continue
|
||||
@@ -1196,7 +1178,7 @@ def item_to_torch(item: dict) -> dict:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray, list)) and key not in ["task"]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
@@ -1270,8 +1252,8 @@ class Backtrackable(Generic[T]):
|
||||
raise ValueError("lookahead must be > 0")
|
||||
|
||||
self._source: Iterator[T] = iter(iterable)
|
||||
self._back_buf: Deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._back_buf: deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._cursor: int = 0
|
||||
self._history = history
|
||||
self._lookahead = lookahead
|
||||
@@ -1345,12 +1327,6 @@ class Backtrackable(Generic[T]):
|
||||
# When cursor<0, slice so the order remains chronological
|
||||
return list(self._back_buf)[: self._cursor or None]
|
||||
|
||||
def lookahead_buffer(self) -> list[T]:
|
||||
"""
|
||||
Return a copy of the current lookahead buffer.
|
||||
"""
|
||||
return list(self._ahead_buf)
|
||||
|
||||
def can_peek_back(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can go back `steps` items without raising an IndexError.
|
||||
@@ -1376,31 +1352,6 @@ class Backtrackable(Generic[T]):
|
||||
except StopIteration:
|
||||
return False
|
||||
|
||||
def reset_cursor(self) -> None:
|
||||
"""
|
||||
Reset cursor to the most recent position (equivalent to calling next()
|
||||
until you're back to the latest item).
|
||||
"""
|
||||
self._cursor = 0
|
||||
|
||||
def clear_ahead_buffer(self) -> None:
|
||||
"""
|
||||
Clear the ahead buffer, discarding any pre-fetched items.
|
||||
"""
|
||||
self._ahead_buf.clear()
|
||||
|
||||
def switch_source_iterable(self, new_source: Iterable[T]) -> None:
|
||||
"""
|
||||
Switch the source of the backtrackable to a new iterable, keeping the history.
|
||||
|
||||
This is useful when iterating over a sequence of datasets. The history from the
|
||||
previous source is kept, but the lookahead buffer is cleared. The cursor is reset
|
||||
to the present.
|
||||
"""
|
||||
self._source = iter(new_source)
|
||||
self.clear_ahead_buffer()
|
||||
self.reset_cursor()
|
||||
|
||||
|
||||
def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset:
|
||||
"""
|
||||
|
||||
260
src/lerobot/datasets/v30/augment_dataset_quantile_stats.py
Normal file
260
src/lerobot/datasets/v30/augment_dataset_quantile_stats.py
Normal file
@@ -0,0 +1,260 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script augments existing LeRobot datasets with quantile statistics.
|
||||
|
||||
Most datasets created before the quantile feature was added do not contain
|
||||
quantile statistics (q01, q10, q50, q90, q99) in their metadata. This script:
|
||||
|
||||
1. Loads an existing LeRobot dataset in v3.0 format
|
||||
2. Checks if it already contains quantile statistics
|
||||
3. If missing, computes quantile statistics for all features
|
||||
4. Updates the dataset metadata with the new quantile statistics
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
--repo-id=lerobot/pusht \
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.utils import write_stats
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[str] | None = None) -> bool:
|
||||
"""Check if dataset statistics already contain quantile information.
|
||||
|
||||
Args:
|
||||
stats: Dataset statistics dictionary
|
||||
|
||||
Returns:
|
||||
True if quantile statistics are present, False otherwise
|
||||
"""
|
||||
if quantile_list_keys is None:
|
||||
quantile_list_keys = [f"q{int(q * 100):02d}" for q in DEFAULT_QUANTILES]
|
||||
|
||||
if stats is None:
|
||||
return False
|
||||
|
||||
for feature_stats in stats.values():
|
||||
if any(q_key in feature_stats for q_key in quantile_list_keys):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
"""Process a single episode and return its statistics.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset
|
||||
episode_idx: Index of the episode to process
|
||||
|
||||
Returns:
|
||||
Dictionary containing episode statistics
|
||||
"""
|
||||
logging.info(f"Computing stats for episode {episode_idx}")
|
||||
|
||||
start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"]
|
||||
end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"]
|
||||
|
||||
collected_data: dict[str, list] = {}
|
||||
for idx in range(start_idx, end_idx):
|
||||
item = dataset[idx]
|
||||
for key, value in item.items():
|
||||
if key not in dataset.features:
|
||||
continue
|
||||
|
||||
if key not in collected_data:
|
||||
collected_data[key] = []
|
||||
collected_data[key].append(value)
|
||||
|
||||
ep_stats = {}
|
||||
for key, data_list in collected_data.items():
|
||||
if dataset.features[key]["dtype"] == "string":
|
||||
continue
|
||||
|
||||
data = torch.stack(data_list).cpu().numpy()
|
||||
if dataset.features[key]["dtype"] in ["image", "video"]:
|
||||
if data.dtype == np.uint8:
|
||||
data = data.astype(np.float32) / 255.0
|
||||
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
keepdims = True
|
||||
else:
|
||||
axes_to_reduce = 0
|
||||
keepdims = data.ndim == 1
|
||||
|
||||
ep_stats[key] = get_feature_stats(
|
||||
data, axis=axes_to_reduce, keepdims=keepdims, quantile_list=DEFAULT_QUANTILES
|
||||
)
|
||||
|
||||
if dataset.features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
|
||||
|
||||
def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
|
||||
"""Compute quantile statistics for all episodes in the dataset.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to compute statistics for
|
||||
|
||||
Returns:
|
||||
Dictionary containing aggregated statistics with quantiles
|
||||
|
||||
Note:
|
||||
Video decoding operations are not thread-safe, so we process episodes sequentially
|
||||
when video keys are present. For datasets without videos, we use parallel processing
|
||||
with ThreadPoolExecutor for better performance.
|
||||
"""
|
||||
logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")
|
||||
|
||||
episode_stats_list = []
|
||||
has_videos = len(dataset.meta.video_keys) > 0
|
||||
|
||||
if has_videos:
|
||||
logging.info("Dataset contains video keys - using sequential processing for thread safety")
|
||||
for episode_idx in tqdm(range(dataset.num_episodes), desc="Processing episodes"):
|
||||
ep_stats = process_single_episode(dataset, episode_idx)
|
||||
episode_stats_list.append(ep_stats)
|
||||
else:
|
||||
logging.info("Dataset has no video keys - using parallel processing for better performance")
|
||||
max_workers = min(dataset.num_episodes, 16)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_episode = {
|
||||
executor.submit(process_single_episode, dataset, episode_idx): episode_idx
|
||||
for episode_idx in range(dataset.num_episodes)
|
||||
}
|
||||
|
||||
episode_results = {}
|
||||
with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
|
||||
for future in concurrent.futures.as_completed(future_to_episode):
|
||||
episode_idx = future_to_episode[future]
|
||||
ep_stats = future.result()
|
||||
episode_results[episode_idx] = ep_stats
|
||||
pbar.update(1)
|
||||
|
||||
for episode_idx in range(dataset.num_episodes):
|
||||
if episode_idx in episode_results:
|
||||
episode_stats_list.append(episode_results[episode_idx])
|
||||
|
||||
if not episode_stats_list:
|
||||
raise ValueError("No episode data found for computing statistics")
|
||||
|
||||
logging.info(f"Aggregating statistics from {len(episode_stats_list)} episodes")
|
||||
return aggregate_stats(episode_stats_list)
|
||||
|
||||
|
||||
def augment_dataset_with_quantile_stats(
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""Augment a dataset with quantile statistics if they are missing.
|
||||
|
||||
Args:
|
||||
repo_id: Repository ID of the dataset
|
||||
root: Local root directory for the dataset
|
||||
overwrite: Overwrite existing quantile statistics if they already exist
|
||||
"""
|
||||
logging.info(f"Loading dataset: {repo_id}")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
)
|
||||
|
||||
if not overwrite and has_quantile_stats(dataset.meta.stats):
|
||||
logging.info("Dataset already contains quantile statistics. No action needed.")
|
||||
return
|
||||
|
||||
logging.info("Dataset does not contain quantile statistics. Computing them now...")
|
||||
|
||||
new_stats = compute_quantile_stats_for_dataset(dataset)
|
||||
|
||||
logging.info("Updating dataset metadata with new quantile statistics")
|
||||
dataset.meta.stats = new_stats
|
||||
|
||||
write_stats(new_stats, dataset.meta.root)
|
||||
|
||||
logging.info("Successfully updated dataset with quantile statistics")
|
||||
dataset.push_to_hub()
|
||||
|
||||
hub_api = HfApi()
|
||||
try:
|
||||
hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
except HTTPError as e:
|
||||
logging.info(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
|
||||
pass
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=None, repo_type="dataset")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to run the augmentation script."""
|
||||
parser = argparse.ArgumentParser(description="Augment LeRobot dataset with quantile statistics")
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository ID of the dataset (e.g., 'lerobot/pusht')",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
help="Local root directory for the dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
action="store_true",
|
||||
help="Overwrite existing quantile statistics if they already exist",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
root = Path(args.root) if args.root else None
|
||||
|
||||
init_logging()
|
||||
|
||||
augment_dataset_with_quantile_stats(
|
||||
repo_id=args.repo_id,
|
||||
root=root,
|
||||
overwrite=args.overwrite,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -26,14 +26,24 @@ This script will help you convert any LeRobot dataset already pushed to the hub
|
||||
|
||||
Usage:
|
||||
|
||||
Convert a dataset from the hub:
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
--repo-id=lerobot/pusht
|
||||
```
|
||||
|
||||
Convert a local dataset (works in place):
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
--repo-id=lerobot/pusht \
|
||||
--root=/path/to/local/dataset/directory
|
||||
--push-to-hub=false
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -46,7 +56,6 @@ from datasets import Dataset, Features, Image
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from requests import HTTPError
|
||||
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
@@ -71,9 +80,11 @@ from lerobot.datasets.utils import (
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
V21 = "v2.1"
|
||||
|
||||
V30 = "v3.0"
|
||||
|
||||
"""
|
||||
-------------------------
|
||||
@@ -143,7 +154,19 @@ def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
return tasks, task_to_task_index
|
||||
|
||||
|
||||
def validate_local_dataset_version(local_path: Path) -> None:
|
||||
"""Validate that the local dataset has the expected v2.1 version."""
|
||||
info = load_info(local_path)
|
||||
dataset_version = info.get("codebase_version", "unknown")
|
||||
if dataset_version != V21:
|
||||
raise ValueError(
|
||||
f"Local dataset has codebase version '{dataset_version}', expected '{V21}'. "
|
||||
f"This script is specifically for converting v2.1 datasets to v3.0."
|
||||
)
|
||||
|
||||
|
||||
def convert_tasks(root, new_root):
|
||||
logging.info(f"Converting tasks from {root} to {new_root}")
|
||||
tasks, _ = legacy_load_tasks(root)
|
||||
task_indices = tasks.keys()
|
||||
task_strings = tasks.values()
|
||||
@@ -185,7 +208,10 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
num_frames = 0
|
||||
paths_to_cat = []
|
||||
episodes_metadata = []
|
||||
for ep_path in ep_paths:
|
||||
|
||||
logging.info(f"Converting data files from {len(ep_paths)} episodes")
|
||||
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"):
|
||||
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
|
||||
ep_num_frames = get_parquet_num_frames(ep_path)
|
||||
ep_metadata = {
|
||||
@@ -209,7 +235,6 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
|
||||
# Reset for the next file
|
||||
size_in_mb = ep_size_in_mb
|
||||
num_frames = ep_num_frames
|
||||
paths_to_cat = [ep_path]
|
||||
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
@@ -236,6 +261,8 @@ def get_image_keys(root):
|
||||
|
||||
|
||||
def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int):
|
||||
logging.info(f"Converting videos from {root} to {new_root}")
|
||||
|
||||
video_keys = get_video_keys(root)
|
||||
if len(video_keys) == 0:
|
||||
return None
|
||||
@@ -254,7 +281,7 @@ def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int):
|
||||
episods_metadata = []
|
||||
num_cameras = len(video_keys)
|
||||
num_episodes = num_eps_per_cam[0]
|
||||
for ep_idx in range(num_episodes):
|
||||
for ep_idx in tqdm.tqdm(range(num_episodes), desc="convert videos"):
|
||||
# Sanity check
|
||||
ep_ids = [eps_metadata_per_cam[cam_idx][ep_idx]["episode_index"] for cam_idx in range(num_cameras)]
|
||||
ep_ids += [ep_idx]
|
||||
@@ -281,6 +308,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
duration_in_s = 0.0
|
||||
paths_to_cat = []
|
||||
episodes_metadata = []
|
||||
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
|
||||
ep_size_in_mb = get_video_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
@@ -374,6 +402,8 @@ def generate_episode_metadata_dict(
|
||||
|
||||
|
||||
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
|
||||
logging.info(f"Converting episodes metadata from {root} to {new_root}")
|
||||
|
||||
episodes_legacy_metadata = legacy_load_episodes(root)
|
||||
episodes_stats = legacy_load_episodes_stats(root)
|
||||
|
||||
@@ -397,14 +427,15 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
|
||||
|
||||
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
|
||||
info = load_info(root)
|
||||
info["codebase_version"] = "v3.0"
|
||||
info["codebase_version"] = V30
|
||||
del info["total_chunks"]
|
||||
del info["total_videos"]
|
||||
info["data_files_size_in_mb"] = data_file_size_in_mb
|
||||
info["video_files_size_in_mb"] = video_file_size_in_mb
|
||||
info["data_path"] = DEFAULT_DATA_PATH
|
||||
info["video_path"] = DEFAULT_VIDEO_PATH
|
||||
info["video_path"] = DEFAULT_VIDEO_PATH if info["video_path"] is not None else None
|
||||
info["fps"] = int(info["fps"])
|
||||
logging.info(f"Converting info from {root} to {new_root}")
|
||||
for key in info["features"]:
|
||||
if info["features"][key]["dtype"] == "video":
|
||||
# already has fps in video_info
|
||||
@@ -418,16 +449,36 @@ def convert_dataset(
|
||||
branch: str | None = None,
|
||||
data_file_size_in_mb: int | None = None,
|
||||
video_file_size_in_mb: int | None = None,
|
||||
root: str | Path | None = None,
|
||||
push_to_hub: bool = True,
|
||||
force_conversion: bool = False,
|
||||
):
|
||||
root = HF_LEROBOT_HOME / repo_id
|
||||
old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
|
||||
new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
|
||||
|
||||
if data_file_size_in_mb is None:
|
||||
data_file_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||
if video_file_size_in_mb is None:
|
||||
video_file_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
|
||||
# First check if the dataset already has a v3.0 version
|
||||
if root is None and not force_conversion:
|
||||
try:
|
||||
print("Trying to download v3.0 version of the dataset from the hub...")
|
||||
snapshot_download(repo_id, repo_type="dataset", revision=V30, local_dir=HF_LEROBOT_HOME / repo_id)
|
||||
return
|
||||
except Exception:
|
||||
print("Dataset does not have an uploaded v3.0 version. Continuing with conversion.")
|
||||
|
||||
# Set root based on whether local dataset path is provided
|
||||
use_local_dataset = False
|
||||
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id
|
||||
if root.exists():
|
||||
validate_local_dataset_version(root)
|
||||
use_local_dataset = True
|
||||
print(f"Using local dataset at {root}")
|
||||
|
||||
old_root = root.parent / f"{root.name}_old"
|
||||
new_root = root.parent / f"{root.name}_v30"
|
||||
|
||||
# Handle old_root cleanup if both old_root and root exist
|
||||
if old_root.is_dir() and root.is_dir():
|
||||
shutil.rmtree(str(root))
|
||||
shutil.move(str(old_root), str(root))
|
||||
@@ -435,12 +486,13 @@ def convert_dataset(
|
||||
if new_root.is_dir():
|
||||
shutil.rmtree(new_root)
|
||||
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
repo_type="dataset",
|
||||
revision=V21,
|
||||
local_dir=root,
|
||||
)
|
||||
if not use_local_dataset:
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
repo_type="dataset",
|
||||
revision=V21,
|
||||
local_dir=root,
|
||||
)
|
||||
|
||||
convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb)
|
||||
convert_tasks(root, new_root)
|
||||
@@ -451,24 +503,26 @@ def convert_dataset(
|
||||
shutil.move(str(root), str(old_root))
|
||||
shutil.move(str(new_root), str(root))
|
||||
|
||||
hub_api = HfApi()
|
||||
try:
|
||||
hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
except HTTPError as e:
|
||||
print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
|
||||
pass
|
||||
hub_api.delete_files(
|
||||
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
|
||||
repo_id=repo_id,
|
||||
revision=branch,
|
||||
repo_type="dataset",
|
||||
)
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
if push_to_hub:
|
||||
hub_api = HfApi()
|
||||
try:
|
||||
hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
except HTTPError as e:
|
||||
print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
|
||||
pass
|
||||
hub_api.delete_files(
|
||||
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
|
||||
repo_id=repo_id,
|
||||
revision=branch,
|
||||
repo_type="dataset",
|
||||
)
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
LeRobotDataset(repo_id).push_to_hub()
|
||||
LeRobotDataset(repo_id).push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
@@ -495,6 +549,23 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="File size in MB. Defaults to 100 for data and 500 for videos.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Local directory to use for downloading/writing the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
type=lambda input: input.lower() == "true",
|
||||
default=True,
|
||||
help="Push the converted dataset to the hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-conversion",
|
||||
action="store_true",
|
||||
help="Force conversion even if the dataset already has a v3.0 version.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_dataset(**vars(args))
|
||||
|
||||
@@ -428,7 +428,7 @@ def concatenate_video_files(
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
||||
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
||||
for input_path in input_video_paths:
|
||||
tmp_concatenate_file.write(f"file '{str(input_path)}'\n")
|
||||
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
|
||||
tmp_concatenate_file.flush()
|
||||
tmp_concatenate_path = tmp_concatenate_file.name
|
||||
|
||||
@@ -437,7 +437,9 @@ def concatenate_video_files(
|
||||
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
|
||||
) # safe = 0 allows absolute paths as well as relative paths
|
||||
|
||||
tmp_output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
|
||||
output_container = av.open(
|
||||
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
|
||||
) # faststart is to move the metadata to the beginning of the file to speed up loading
|
||||
@@ -449,11 +451,9 @@ def concatenate_video_files(
|
||||
stream_map[input_stream.index] = output_container.add_stream_from_template(
|
||||
template=input_stream, opaque=True
|
||||
)
|
||||
stream_map[
|
||||
input_stream.index
|
||||
].time_base = (
|
||||
input_stream.time_base
|
||||
) # set the time base to the input stream time base (missing in the codec context)
|
||||
|
||||
# set the time base to the input stream time base (missing in the codec context)
|
||||
stream_map[input_stream.index].time_base = input_stream.time_base
|
||||
|
||||
# Demux + remux packets (no re-encode)
|
||||
for packet in input_container.demux():
|
||||
@@ -585,19 +585,6 @@ def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_image_pixel_channels(image: Image):
|
||||
if image.mode == "L":
|
||||
return 1 # Grayscale
|
||||
elif image.mode == "LA":
|
||||
return 2 # Grayscale + Alpha
|
||||
elif image.mode == "RGB":
|
||||
return 3 # RGB
|
||||
elif image.mode == "RGBA":
|
||||
return 4 # RGBA
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
"""
|
||||
Get the duration of a video file in seconds using PyAV.
|
||||
@@ -655,6 +642,9 @@ class VideoEncodingManager:
|
||||
)
|
||||
self.dataset._batch_save_episode_video(start_ep, end_ep)
|
||||
|
||||
# Finalize the dataset to properly close all writers
|
||||
self.dataset.finalize()
|
||||
|
||||
# Clean up episode images if recording was interrupted
|
||||
if exc_type is not None:
|
||||
interrupted_episode_index = self.dataset.num_episodes
|
||||
|
||||
@@ -19,9 +19,9 @@ from typing import Any
|
||||
import draccus
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.robots import RobotConfig
|
||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -50,15 +50,17 @@ class AlohaEnv(EnvConfig):
|
||||
fps: int = 50
|
||||
episode_length: int = 400
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
observation_height: int = 480
|
||||
observation_width: int = 640
|
||||
render_mode: str = "rgb_array"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"top": f"{OBS_IMAGE}.top",
|
||||
"pixels/top": f"{OBS_IMAGES}.top",
|
||||
@@ -67,10 +69,14 @@ class AlohaEnv(EnvConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
self.features["top"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
|
||||
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
self.features["pixels/top"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
@@ -91,15 +97,17 @@ class PushtEnv(EnvConfig):
|
||||
render_mode: str = "rgb_array"
|
||||
visualization_width: int = 384
|
||||
visualization_height: int = 384
|
||||
observation_height: int = 384
|
||||
observation_width: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"environment_state": OBS_ENV_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
@@ -108,7 +116,9 @@ class PushtEnv(EnvConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
|
||||
self.features["pixels"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
elif self.obs_type == "environment_state_agent_pos":
|
||||
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
|
||||
|
||||
@@ -135,13 +145,13 @@ class XarmEnv(EnvConfig):
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
@@ -193,7 +203,6 @@ class ObservationConfig:
|
||||
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
display_cameras: bool = False
|
||||
|
||||
|
||||
@@ -203,7 +212,6 @@ class GripperConfig:
|
||||
|
||||
use_gripper: bool = True
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -256,15 +264,17 @@ class LiberoEnv(EnvConfig):
|
||||
render_mode: str = "rgb_array"
|
||||
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
||||
init_states: bool = True
|
||||
camera_name_mapping: dict[str, str] | None = (None,)
|
||||
camera_name_mapping: dict[str, str] | None = None
|
||||
observation_height: int = 360
|
||||
observation_width: int = 360
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels/agentview_image": f"{OBS_IMAGES}.image",
|
||||
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
|
||||
@@ -274,18 +284,18 @@ class LiberoEnv(EnvConfig):
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["pixels/agentview_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
|
||||
self.features["pixels/agentview_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
||||
|
||||
@@ -63,6 +63,9 @@ def make_env(
|
||||
if "libero" in cfg.type:
|
||||
from lerobot.envs.libero import create_libero_envs
|
||||
|
||||
if cfg.task is None:
|
||||
raise ValueError("LiberoEnv requires a task to be specified")
|
||||
|
||||
return create_libero_envs(
|
||||
task=cfg.task,
|
||||
n_envs=n_envs,
|
||||
|
||||
@@ -35,7 +35,7 @@ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
"""Normalize camera_name into a non-empty list of strings."""
|
||||
if isinstance(camera_name, str):
|
||||
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
||||
elif isinstance(camera_name, (list, tuple)):
|
||||
elif isinstance(camera_name, (list | tuple)):
|
||||
cams = [str(c).strip() for c in camera_name if str(c).strip()]
|
||||
else:
|
||||
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
|
||||
|
||||
@@ -26,6 +26,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.utils import get_channel_first_image_shape
|
||||
|
||||
|
||||
@@ -41,44 +42,44 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
return_observations = {}
|
||||
if "pixels" in observations:
|
||||
if isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in observations["pixels"].items()}
|
||||
else:
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
imgs = {OBS_IMAGE: observations["pixels"]}
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
img = torch.from_numpy(img)
|
||||
img_tensor = torch.from_numpy(img)
|
||||
|
||||
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
|
||||
# This is the case for human-in-the-loop RL where there is only one environment.
|
||||
if img.ndim == 3:
|
||||
img = img.unsqueeze(0)
|
||||
if img_tensor.ndim == 3:
|
||||
img_tensor = img_tensor.unsqueeze(0)
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
_, h, w, c = img_tensor.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img_tensor.shape=}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
assert img_tensor.dtype == torch.uint8, f"expect torch.uint8, but instead {img_tensor.dtype=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()
|
||||
img_tensor = img_tensor.type(torch.float32)
|
||||
img_tensor /= 255
|
||||
|
||||
return_observations[imgkey] = img
|
||||
return_observations[imgkey] = img_tensor
|
||||
|
||||
if "environment_state" in observations:
|
||||
env_state = torch.from_numpy(observations["environment_state"]).float()
|
||||
if env_state.dim() == 1:
|
||||
env_state = env_state.unsqueeze(0)
|
||||
|
||||
return_observations["observation.environment_state"] = env_state
|
||||
return_observations[OBS_ENV_STATE] = env_state
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
return_observations["observation.state"] = agent_pos
|
||||
return_observations[OBS_STATE] = agent_pos
|
||||
|
||||
return return_observations
|
||||
|
||||
@@ -182,10 +183,10 @@ def _(env: Mapping) -> None:
|
||||
|
||||
@close_envs.register
|
||||
def _(envs: Sequence) -> None:
|
||||
if isinstance(envs, (str, bytes)):
|
||||
if isinstance(envs, (str | bytes)):
|
||||
return
|
||||
for v in envs:
|
||||
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)):
|
||||
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str | bytes)):
|
||||
close_envs(v)
|
||||
elif hasattr(v, "close"):
|
||||
_close_single_env(v)
|
||||
|
||||
@@ -1 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
|
||||
|
||||
@@ -22,7 +22,7 @@ import logging
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
|
||||
from lerobot.utils.encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from .tables import (
|
||||
|
||||
@@ -17,7 +17,7 @@ from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pprint import pformat
|
||||
|
||||
from lerobot.utils.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from .tables import (
|
||||
|
||||
@@ -32,7 +32,7 @@ import serial
|
||||
from deepdiff import DeepDiff
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
NameOrID: TypeAlias = str | int
|
||||
@@ -99,12 +99,6 @@ class Motor:
|
||||
norm_mode: MotorNormMode
|
||||
|
||||
|
||||
class JointOutOfRangeError(Exception):
|
||||
def __init__(self, message="Joint is out of range"):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class PortHandler(Protocol):
|
||||
def __init__(self, port_name):
|
||||
self.is_open: bool
|
||||
@@ -348,7 +342,7 @@ class MotorsBus(abc.ABC):
|
||||
raise TypeError(motors)
|
||||
|
||||
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
|
||||
if isinstance(values, (int, float)):
|
||||
if isinstance(values, (int | float)):
|
||||
return dict.fromkeys(self.ids, values)
|
||||
elif isinstance(values, dict):
|
||||
return {self.motors[motor].id: val for motor, val in values.items()}
|
||||
@@ -675,7 +669,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str, int)):
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
@@ -703,7 +697,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str, int)):
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
@@ -739,7 +733,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str, int)):
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
|
||||
@@ -22,11 +22,11 @@ import draccus
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.constants import (
|
||||
from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json
|
||||
from lerobot.utils.constants import (
|
||||
OPTIMIZER_PARAM_GROUPS,
|
||||
OPTIMIZER_STATE,
|
||||
)
|
||||
from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json
|
||||
from lerobot.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
|
||||
@@ -22,8 +22,8 @@ import draccus
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
from lerobot.constants import SCHEDULER_STATE
|
||||
from lerobot.datasets.utils import write_json
|
||||
from lerobot.utils.constants import SCHEDULER_STATE
|
||||
from lerobot.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0.processor_pi0 import Pi0NewLineProcessor
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
@@ -25,6 +25,7 @@ __all__ = [
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"PI0Config",
|
||||
"PI05Config",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
|
||||
@@ -33,9 +33,9 @@ from torch import Tensor, nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class ACTPolicy(PreTrainedPolicy):
|
||||
@@ -394,25 +394,22 @@ class ACT(nn.Module):
|
||||
latent dimension.
|
||||
"""
|
||||
if self.config.use_vae and self.training:
|
||||
assert "action" in batch, (
|
||||
assert ACTION in batch, (
|
||||
"actions must be provided when using the variational objective in training mode."
|
||||
)
|
||||
|
||||
if "observation.images" in batch:
|
||||
batch_size = batch["observation.images"][0].shape[0]
|
||||
else:
|
||||
batch_size = batch["observation.environment_state"].shape[0]
|
||||
batch_size = batch[OBS_IMAGES][0].shape[0] if OBS_IMAGES in batch else batch[OBS_ENV_STATE].shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
if self.config.use_vae and "action" in batch and self.training:
|
||||
if self.config.use_vae and ACTION in batch and self.training:
|
||||
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
|
||||
cls_embed = einops.repeat(
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
if self.config.robot_state_feature:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE])
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch[ACTION]) # (B, S, D)
|
||||
|
||||
if self.config.robot_state_feature:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
@@ -430,7 +427,7 @@ class ACT(nn.Module):
|
||||
cls_joint_is_pad = torch.full(
|
||||
(batch_size, 2 if self.config.robot_state_feature else 1),
|
||||
False,
|
||||
device=batch["observation.state"].device,
|
||||
device=batch[OBS_STATE].device,
|
||||
)
|
||||
key_padding_mask = torch.cat(
|
||||
[cls_joint_is_pad, batch["action_is_pad"]], axis=1
|
||||
@@ -454,7 +451,7 @@ class ACT(nn.Module):
|
||||
mu = log_sigma_x2 = None
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
|
||||
batch["observation.state"].device
|
||||
batch[OBS_STATE].device
|
||||
)
|
||||
|
||||
# Prepare transformer encoder inputs.
|
||||
@@ -462,18 +459,16 @@ class ACT(nn.Module):
|
||||
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
|
||||
# Robot state token.
|
||||
if self.config.robot_state_feature:
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch[OBS_STATE]))
|
||||
# Environment state token.
|
||||
if self.config.env_state_feature:
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_env_state_input_proj(batch["observation.environment_state"])
|
||||
)
|
||||
encoder_in_tokens.append(self.encoder_env_state_input_proj(batch[OBS_ENV_STATE]))
|
||||
|
||||
if self.config.image_features:
|
||||
# For a list of images, the H and W may vary but H*W is constant.
|
||||
# NOTE: If modifying this section, verify on MPS devices that
|
||||
# gradients remain stable (no explosions or NaNs).
|
||||
for img in batch["observation.images"]:
|
||||
for img in batch[OBS_IMAGES]:
|
||||
cam_features = self.backbone(img)["feature_map"]
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features)
|
||||
|
||||
@@ -17,7 +17,6 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -29,6 +28,7 @@ from lerobot.processor import (
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_act_pre_post_processors(
|
||||
|
||||
@@ -45,7 +45,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
chunk_size: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiffusionPolicy.select_action` for more details.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
@@ -105,7 +105,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
# Inputs / output structure.
|
||||
n_obs_steps: int = 2
|
||||
horizon: int = 16
|
||||
chunk_size: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
@@ -118,7 +118,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
# The original implementation doesn't sample frames for the last 7 steps,
|
||||
# which avoids excessive padding and leads to improved training results.
|
||||
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
|
||||
drop_n_last_frames: int = 7 # chunk_size - n_action_steps - n_obs_steps + 1
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
@@ -180,13 +180,13 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
f"Got {self.noise_scheduler_type}."
|
||||
)
|
||||
|
||||
# Check that the horizon size and U-Net downsampling is compatible.
|
||||
# Check that the chunk size and U-Net downsampling is compatible.
|
||||
# U-Net downsamples by 2 with each stage.
|
||||
downsampling_factor = 2 ** len(self.down_dims)
|
||||
if self.horizon % downsampling_factor != 0:
|
||||
if self.chunk_size % downsampling_factor != 0:
|
||||
raise ValueError(
|
||||
"The horizon should be an integer multiple of the downsampling factor (which is determined "
|
||||
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
|
||||
"The chunk_size should be an integer multiple of the downsampling factor (which is determined "
|
||||
f"by `len(down_dims)`). Got {self.chunk_size=} and {self.down_dims=}"
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
@@ -231,7 +231,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
|
||||
@@ -33,7 +33,6 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import (
|
||||
@@ -42,6 +41,7 @@ from lerobot.policies.utils import (
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class DiffusionPolicy(PreTrainedPolicy):
|
||||
@@ -81,43 +81,43 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
def reset(self):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
if self.config.image_features:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
# stack n latest observations from the queue
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
actions = self.diffusion.generate_actions(batch, noise=noise)
|
||||
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method handles caching a history of observations and an action trajectory generated by the
|
||||
underlying diffusion model. Here's how it works:
|
||||
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
|
||||
copied `n_obs_steps` times to fill the cache).
|
||||
- The diffusion model generates `horizon` steps worth of actions.
|
||||
- The diffusion model generates `chunk_size` steps worth of actions.
|
||||
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
|
||||
Schematically this looks like:
|
||||
----------------------------------------------------------------------------------------------
|
||||
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|
||||
(legend: o = n_obs_steps, c = chunk_size, a = n_action_steps)
|
||||
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|
||||
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|
||||
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
||||
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
||||
----------------------------------------------------------------------------------------------
|
||||
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
|
||||
"horizon" may not the best name to describe what the variable actually means, because this period is
|
||||
Note that this means we require: `n_action_steps <= chunk_size - n_obs_steps + 1`. Also, note that
|
||||
this period is
|
||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
"""
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
@@ -131,7 +131,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
actions = self.predict_action_chunk(batch, noise=noise)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues[ACTION].popleft()
|
||||
@@ -199,17 +199,25 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
|
||||
self,
|
||||
batch_size: int,
|
||||
global_cond: Tensor | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
noise: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
dtype = get_dtype_from_parameters(self)
|
||||
|
||||
# Sample prior.
|
||||
sample = torch.randn(
|
||||
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
sample = (
|
||||
noise
|
||||
if noise is not None
|
||||
else torch.randn(
|
||||
size=(batch_size, self.config.chunk_size, self.config.action_feature.shape[0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
)
|
||||
)
|
||||
|
||||
self.noise_scheduler.set_timesteps(self.num_inference_steps)
|
||||
@@ -234,7 +242,7 @@ class DiffusionModel(nn.Module):
|
||||
if self.config.image_features:
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
||||
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
|
||||
images_per_camera = einops.rearrange(batch[OBS_IMAGES], "b s n ... -> n (b s) ...")
|
||||
img_features_list = torch.cat(
|
||||
[
|
||||
encoder(images)
|
||||
@@ -249,7 +257,7 @@ class DiffusionModel(nn.Module):
|
||||
else:
|
||||
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
@@ -264,7 +272,7 @@ class DiffusionModel(nn.Module):
|
||||
# Concatenate features then flatten to (B, global_cond_dim).
|
||||
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||
|
||||
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def generate_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""
|
||||
This function expects `batch` to have:
|
||||
{
|
||||
@@ -275,14 +283,14 @@ class DiffusionModel(nn.Module):
|
||||
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||
}
|
||||
"""
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
# Encode image features and concatenate them all together along with the state vector.
|
||||
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
||||
|
||||
# run sampling
|
||||
actions = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||
actions = self.conditional_sample(batch_size, global_cond=global_cond, noise=noise)
|
||||
|
||||
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||
start = n_obs_steps - 1
|
||||
@@ -301,23 +309,23 @@ class DiffusionModel(nn.Module):
|
||||
AND/OR
|
||||
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||
|
||||
"action": (B, horizon, action_dim)
|
||||
"action_is_pad": (B, horizon)
|
||||
"action": (B, chunk_size, action_dim)
|
||||
"action_is_pad": (B, chunk_size)
|
||||
}
|
||||
"""
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
|
||||
assert "observation.images" in batch or "observation.environment_state" in batch
|
||||
n_obs_steps = batch["observation.state"].shape[1]
|
||||
horizon = batch["action"].shape[1]
|
||||
assert horizon == self.config.horizon
|
||||
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
|
||||
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
|
||||
n_obs_steps = batch[OBS_STATE].shape[1]
|
||||
chunk_size = batch[ACTION].shape[1]
|
||||
assert chunk_size == self.config.chunk_size
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
# Encode image features and concatenate them all together along with the state vector.
|
||||
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
||||
|
||||
# Forward diffusion.
|
||||
trajectory = batch["action"]
|
||||
trajectory = batch[ACTION]
|
||||
# Sample noise to add to the trajectory.
|
||||
eps = torch.randn(trajectory.shape, device=trajectory.device)
|
||||
# Sample a random noising timestep for each item in the batch.
|
||||
@@ -338,7 +346,7 @@ class DiffusionModel(nn.Module):
|
||||
if self.config.prediction_type == "epsilon":
|
||||
target = eps
|
||||
elif self.config.prediction_type == "sample":
|
||||
target = batch["action"]
|
||||
target = batch[ACTION]
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -30,6 +29,7 @@ from lerobot.processor import (
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_diffusion_pre_post_processors(
|
||||
|
||||
244
src/lerobot/policies/dsrl/configuration_dsrl.py
Normal file
244
src/lerobot/policies/dsrl/configuration_dsrl.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import MultiAdamConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def is_image_feature(key: str) -> bool:
|
||||
"""Check if a feature key represents an image feature.
|
||||
|
||||
Args:
|
||||
key: The feature key to check
|
||||
|
||||
Returns:
|
||||
True if the key represents an image feature, False otherwise
|
||||
"""
|
||||
return key.startswith(OBS_IMAGE)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConcurrencyConfig:
|
||||
"""Configuration for the concurrency of the actor and learner.
|
||||
Possible values are:
|
||||
- "threads": Use threads for the actor and learner.
|
||||
- "processes": Use processes for the actor and learner.
|
||||
"""
|
||||
|
||||
actor: str = "threads"
|
||||
learner: str = "threads"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorLearnerConfig:
|
||||
learner_host: str = "127.0.0.1"
|
||||
learner_port: int = 50051
|
||||
policy_parameters_push_frequency: int = 4
|
||||
queue_get_timeout: float = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class CriticNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
final_activation: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
use_layer_norm: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class NoiseActorConfig:
|
||||
"""Configuration for the noise actor in DSRL.
|
||||
The noise actor outputs noise that gets fed to the diffusion policy.
|
||||
"""
|
||||
|
||||
use_tanh_squash: bool = False # Whether to bound the noise output
|
||||
std_min: float = 1e-5
|
||||
std_max: float = 2.0
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("dsrl")
|
||||
@dataclass
|
||||
class DSRLConfig(PreTrainedConfig):
|
||||
"""Diffusion Steering via Reinforcement Learning (DSRL) configuration."""
|
||||
|
||||
# Mapping of feature types to normalization modes
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ENV": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
# Statistics for normalizing different types of inputs
|
||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
||||
default_factory=lambda: {
|
||||
OBS_IMAGE: {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
OBS_STATE: {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
ACTION: {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture specifics
|
||||
# Device to run the model on (e.g., "cuda", "cpu")
|
||||
device: str = "cpu"
|
||||
# Device to store the model on
|
||||
storage_device: str = "cpu"
|
||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
||||
vision_encoder_name: str | None = None
|
||||
# Whether to freeze the vision encoder during training
|
||||
freeze_vision_encoder: bool = True
|
||||
# Hidden dimension size for the image encoder
|
||||
image_encoder_hidden_dim: int = 32
|
||||
# Whether to use a shared encoder for actor and critic
|
||||
shared_encoder: bool = True
|
||||
# Number of discrete actions, eg for gripper actions
|
||||
num_discrete_actions: int | None = None
|
||||
# Dimension of the image embedding pooling
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Name of the action policy
|
||||
action_policy_name: str = "pi0"
|
||||
action_policy_weights: str | None = "lerobot/pi0_base"
|
||||
|
||||
# Training parameter
|
||||
# Number of steps for online training
|
||||
online_steps: int = 1000000
|
||||
# Number of steps for offline training
|
||||
offline_steps: int = 100000
|
||||
# Capacity of the online replay buffer
|
||||
online_buffer_capacity: int = 100000
|
||||
# Capacity of the offline replay buffer
|
||||
offline_buffer_capacity: int = 100000
|
||||
# Whether to use asynchronous prefetching for the buffers
|
||||
async_prefetch: bool = False
|
||||
# Number of steps before learning starts
|
||||
online_step_before_learning: int = 100
|
||||
# Frequency of policy updates
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
discount: float = 0.99
|
||||
# Initial temperature value
|
||||
temperature_init: float = 1.0
|
||||
# Number of critics in the ensemble
|
||||
num_critics: int = 2
|
||||
# Number of subsampled critics for training
|
||||
num_subsample_critics: int | None = None
|
||||
# Learning rate for the critic network
|
||||
critic_lr: float = 3e-4
|
||||
# Learning rate for the actor network
|
||||
actor_lr: float = 3e-4
|
||||
# Learning rate for the temperature parameter
|
||||
temperature_lr: float = 3e-4
|
||||
# Weight for the critic target update
|
||||
critic_target_update_weight: float = 0.005
|
||||
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
||||
utd_ratio: int = 1
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
# Target entropy for the SAC algorithm
|
||||
target_entropy: float | None = None
|
||||
# Whether to use backup entropy for the SAC algorithm
|
||||
use_backup_entropy: bool = True
|
||||
# Gradient clipping norm for the SAC algorithm
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Network configuration
|
||||
# Configuration for the critic network architecture
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the noise critic network architecture
|
||||
noise_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the noise actor network architecture
|
||||
noise_actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the noise actor specific parameters
|
||||
noise_actor_kwargs: NoiseActorConfig = field(default_factory=NoiseActorConfig)
|
||||
# Configuration for actor-learner architecture
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
# Optimizations
|
||||
use_torch_compile: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||
return MultiAdamConfig(
|
||||
weight_decay=0.0,
|
||||
optimizer_groups={
|
||||
"critic_action": {"lr": self.critic_lr},
|
||||
"critic_noise": {"lr": self.critic_lr},
|
||||
"noise_actor": {"lr": self.actor_lr},
|
||||
"temperature": {"lr": self.temperature_lr},
|
||||
},
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
has_image = any(is_image_feature(key) for key in self.input_features)
|
||||
has_state = OBS_STATE in self.input_features
|
||||
|
||||
if not (has_state or has_image):
|
||||
raise ValueError(
|
||||
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
|
||||
def image_features(self) -> list[str]:
|
||||
return [key for key in self.input_features if is_image_feature(key)]
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
1197
src/lerobot/policies/dsrl/modeling_dsrl.py
Normal file
1197
src/lerobot/policies/dsrl/modeling_dsrl.py
Normal file
File diff suppressed because it is too large
Load Diff
89
src/lerobot/policies/dsrl/processor_dsrl.py
Normal file
89
src/lerobot/policies/dsrl/processor_dsrl.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor for DSRL policy.
|
||||
|
||||
DSRL uses a similar processing pipeline as SAC since it operates on
|
||||
state-action transitions. The main difference is that internally it
|
||||
also works with noise, but that's handled within the policy itself.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_dsrl_pre_post_processors(
|
||||
config: DSRLConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict, dict],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Create preprocessor and postprocessor pipelines for DSRL policy.
|
||||
|
||||
Args:
|
||||
config: DSRL policy configuration
|
||||
dataset_stats: Optional dataset statistics for normalization
|
||||
|
||||
Returns:
|
||||
Tuple of (preprocessor, postprocessor) pipelines
|
||||
"""
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -24,15 +24,16 @@ from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.utils import env_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
@@ -46,6 +47,7 @@ from lerobot.processor.converters import (
|
||||
transition_to_batch,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
@@ -57,7 +59,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla".
|
||||
"vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla", "dsrl".
|
||||
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
@@ -81,14 +83,18 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
|
||||
return VQBeTPolicy
|
||||
elif name == "pi0":
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
elif name == "pi0fast":
|
||||
from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
return PI0FASTPolicy
|
||||
elif name == "pi0":
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
elif name == "pi05":
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "sac":
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
@@ -101,6 +107,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "dsrl":
|
||||
from lerobot.policies.dsrl.modeling_dsrl import DSRLPolicy
|
||||
|
||||
return DSRLPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -115,7 +125,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla",
|
||||
"reward_classifier".
|
||||
"reward_classifier", "dsrl".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -132,16 +142,20 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return ACTConfig(**kwargs)
|
||||
elif policy_type == "vqbet":
|
||||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "dsrl":
|
||||
return DSRLConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
@@ -253,6 +267,14 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI0FASTConfig):
|
||||
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
|
||||
|
||||
processors = make_pi0fast_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI0Config):
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
|
||||
|
||||
@@ -261,10 +283,10 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI0FASTConfig):
|
||||
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
|
||||
elif isinstance(policy_cfg, PI05Config):
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors
|
||||
|
||||
processors = make_pi0fast_pre_post_processors(
|
||||
processors = make_pi05_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
@@ -292,6 +314,13 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, DSRLConfig):
|
||||
from lerobot.policies.dsrl.processor_dsrl import make_dsrl_pre_post_processors
|
||||
|
||||
processors = make_dsrl_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||
|
||||
49
src/lerobot/policies/pi0/README.md
Normal file
49
src/lerobot/policies/pi0/README.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# π₀ (pi0)
|
||||
|
||||
This repository contains the Hugging Face port of **π₀**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
|
||||
It is designed as a **Vision-Language-Action model for general robot control**.
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | π₀ | π₀.₅ |
|
||||
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
|
||||
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
|
||||
| AdaRMS | Not used | Used in action expert |
|
||||
| Tokenizer Length | 48 tokens | 200 tokens |
|
||||
| Discrete State Input | False (Uses `state_proj` layer) | True |
|
||||
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite both **OpenPI** and the π₀ paper:
|
||||
|
||||
```bibtex
|
||||
@misc{openpi2024,
|
||||
author = {Physical Intelligence Lab},
|
||||
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
|
||||
year = {2024},
|
||||
publisher = {GitHub},
|
||||
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
|
||||
license = {Apache-2.0}
|
||||
}
|
||||
|
||||
@misc{black2024pi0visionlanguageactionflowmodel,
|
||||
title = {π₀: A Vision-Language-Action Flow Model for General Robot Control},
|
||||
author = {Kevin Black and Noah Brown and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Lucy Xiaoyang Shi and James Tanner and Quan Vuong and Anna Walling and Haohuan Wang and Ury Zhilinsky},
|
||||
year = {2024},
|
||||
eprint = {2410.24164},
|
||||
archivePrefix= {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2410.24164},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
21
src/lerobot/policies/pi0/__init__.py
Normal file
21
src/lerobot/policies/pi0/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_pi0 import PI0Config
|
||||
from .modeling_pi0 import PI0Policy
|
||||
from .processor_pi0 import make_pi0_pre_post_processors
|
||||
|
||||
__all__ = ["PI0Config", "PI0Policy", "make_pi0_pre_post_processors"]
|
||||
@@ -1,4 +1,6 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -17,19 +19,40 @@ from dataclasses import dataclass, field
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0")
|
||||
@dataclass
|
||||
class PI0Config(PreTrainedConfig):
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
|
||||
n_action_steps: int = 50 # Number of action steps to execute
|
||||
|
||||
# Shorter state and action vectors will be padded to these dimensions
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Flow matching parameters: see openpi `PI0Pytorch`
|
||||
num_inference_steps: int = 10 # Number of denoising steps during inference
|
||||
time_sampling_beta_alpha: float = 1.5
|
||||
time_sampling_beta_beta: float = 1.0
|
||||
time_sampling_scale: float = 0.999
|
||||
time_sampling_offset: float = 0.001
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
@@ -38,94 +61,75 @@ class PI0Config(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
# Shorter state and action vectors will be padded
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
# Training settings
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||
|
||||
# Image preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] = (224, 224)
|
||||
|
||||
# Add empty images. Used by pi0_aloha_sim which adds the empty
|
||||
# left and right wrist cameras in addition to the top camera.
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Converts the joint and gripper values from the standard Aloha space to
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi_aloha: bool = False
|
||||
|
||||
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
||||
# Gripper dimensions will remain in absolute values.
|
||||
use_delta_joint_actions_aloha: bool = False
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 48
|
||||
|
||||
# Projector
|
||||
proj_width: int = 1024
|
||||
|
||||
# Decoding
|
||||
num_steps: int = 10
|
||||
|
||||
# Attention utils
|
||||
use_cache: bool = True
|
||||
attention_implementation: str = "eager" # or fa2, flex
|
||||
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = True
|
||||
train_expert_only: bool = False
|
||||
train_state_proj: bool = True
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 2.5e-5
|
||||
# Optimizer settings: see openpi `AdamW``
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
# TODO: Add EMA
|
||||
tokenizer_max_length: int = 48 # see openpi `__post_init__`
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# TODO(Steven): Validate device and amp? in all policy configs?
|
||||
"""Input validation (not exhaustive)."""
|
||||
# Validate configuration
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
if self.n_obs_steps != 1:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
if self.use_delta_joint_actions_aloha:
|
||||
raise NotImplementedError(
|
||||
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
|
||||
)
|
||||
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
|
||||
|
||||
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
|
||||
|
||||
if self.dtype not in ["bfloat16", "float32"]:
|
||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# TODO: implement value error
|
||||
# if not self.image_features and not self.env_state_feature:
|
||||
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
"""Validate and set up input/output features."""
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
key = f"{OBS_IMAGES}.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
shape=(3, *self.image_resolution), # Use configured image resolution
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
if "observation.state" not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||
)
|
||||
self.input_features["observation.state"] = state_feature
|
||||
|
||||
if "action" not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||
)
|
||||
self.output_features["action"] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def main():
|
||||
device = "cuda"
|
||||
dataset_repo_id = "danaaubakirova/koch_test"
|
||||
# model_name = "pi0_base"
|
||||
# ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
ckpt_torch_dir = "lerobot/pi0"
|
||||
|
||||
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
# To device
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].to(device=device, dtype=torch.float32)
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||
cfg.pretrained_path = ckpt_torch_dir
|
||||
policy = make_policy(cfg, ds_meta=dataset.meta)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
warmup_iters = 10
|
||||
benchmark_iters = 30
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iters):
|
||||
torch.cuda.synchronize()
|
||||
policy.select_action(batch)
|
||||
policy.reset()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
for _ in range(benchmark_iters):
|
||||
policy.select_action(batch)
|
||||
policy.reset()
|
||||
end_event.record()
|
||||
|
||||
# Synchronize and measure time
|
||||
torch.cuda.synchronize()
|
||||
elapsed_time_ms = start_event.elapsed_time(end_event)
|
||||
|
||||
avg_time_per_iter = elapsed_time_ms / benchmark_iters
|
||||
print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with torch.inference_mode():
|
||||
main()
|
||||
@@ -1,131 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.factory import make_policy
|
||||
|
||||
|
||||
def display(tensor: torch.Tensor):
|
||||
if tensor.dtype == torch.bool:
|
||||
tensor = tensor.float()
|
||||
print(f"Shape: {tensor.shape}")
|
||||
print(f"Mean: {tensor.mean().item()}")
|
||||
print(f"Std: {tensor.std().item()}")
|
||||
print(f"Min: {tensor.min().item()}")
|
||||
print(f"Max: {tensor.max().item()}")
|
||||
|
||||
|
||||
def main():
|
||||
num_motors = 14
|
||||
device = "cuda"
|
||||
# model_name = "pi0_aloha_towel"
|
||||
model_name = "pi0_aloha_sim"
|
||||
|
||||
if model_name == "pi0_aloha_towel":
|
||||
dataset_repo_id = "lerobot/aloha_static_towel"
|
||||
else:
|
||||
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
|
||||
|
||||
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
|
||||
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
|
||||
save_dir = Path(f"../openpi/data/{model_name}/save")
|
||||
|
||||
with open(save_dir / "example.pkl", "rb") as f:
|
||||
example = pickle.load(f)
|
||||
with open(save_dir / "outputs.pkl", "rb") as f:
|
||||
outputs = pickle.load(f)
|
||||
with open(save_dir / "noise.pkl", "rb") as f:
|
||||
noise = pickle.load(f)
|
||||
|
||||
with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
|
||||
norm_stats = json.load(f)
|
||||
|
||||
# Override stats
|
||||
dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
|
||||
dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
|
||||
norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
|
||||
)
|
||||
dataset_meta.stats["observation.state"]["std"] = torch.tensor(
|
||||
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
|
||||
)
|
||||
|
||||
# Create LeRobot batch from Jax
|
||||
batch = {}
|
||||
for cam_key, uint_chw_array in example["images"].items():
|
||||
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
|
||||
batch["observation.state"] = torch.from_numpy(example["state"])
|
||||
batch["action"] = torch.from_numpy(outputs["actions"])
|
||||
batch["task"] = example["prompt"]
|
||||
|
||||
if model_name == "pi0_aloha_towel":
|
||||
del batch["observation.images.cam_low"]
|
||||
elif model_name == "pi0_aloha_sim":
|
||||
batch["observation.images.top"] = batch["observation.images.cam_high"]
|
||||
del batch["observation.images.cam_high"]
|
||||
|
||||
# Batchify
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].unsqueeze(0)
|
||||
elif isinstance(batch[key], str):
|
||||
batch[key] = [batch[key]]
|
||||
else:
|
||||
raise ValueError(f"{key}, {batch[key]}")
|
||||
|
||||
# To device
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].to(device=device, dtype=torch.float32)
|
||||
|
||||
noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)
|
||||
|
||||
from lerobot import policies # noqa
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||
cfg.pretrained_path = ckpt_torch_dir
|
||||
policy = make_policy(cfg, dataset_meta)
|
||||
|
||||
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
|
||||
# loss_dict["loss"].backward()
|
||||
# print("losses")
|
||||
# display(loss_dict["losses_after_forward"])
|
||||
# print("pi_losses")
|
||||
# display(pi_losses)
|
||||
|
||||
actions = []
|
||||
for _ in range(50):
|
||||
action = policy.select_action(batch, noise=noise)
|
||||
actions.append(action)
|
||||
|
||||
actions = torch.stack(actions, dim=1)
|
||||
pi_actions = batch["action"]
|
||||
print("actions")
|
||||
display(actions)
|
||||
print()
|
||||
print("pi_actions")
|
||||
display(pi_actions)
|
||||
print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
|
||||
print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
|
||||
print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,84 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers import GemmaConfig, PaliGemmaConfig
|
||||
|
||||
|
||||
def get_paligemma_config(precision: str):
|
||||
config = {
|
||||
"image_token_index": None,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 1,
|
||||
}
|
||||
|
||||
# image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
|
||||
|
||||
image_size = 224 # image_sizes[variant]
|
||||
patch_size = 14
|
||||
num_image_tokens = (image_size**2) // (patch_size**2)
|
||||
|
||||
config["image_token_index"] = 257152
|
||||
text_config = {
|
||||
"vocab_size": 257152,
|
||||
"num_hidden_layers": 18,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 256,
|
||||
"torch_dtype": precision,
|
||||
"hidden_size": 2048,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"num_attention_heads": 8,
|
||||
"intermediate_size": 16384,
|
||||
"is_encoder_decoder": False,
|
||||
}
|
||||
vision_config = {
|
||||
"torch_dtype": precision,
|
||||
"image_size": image_size,
|
||||
"patch_size": patch_size,
|
||||
"num_image_tokens": num_image_tokens,
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"num_hidden_layers": 27,
|
||||
"num_attention_heads": 16,
|
||||
"projector_hidden_act": "gelu_fast",
|
||||
"vision_use_head": False,
|
||||
}
|
||||
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
|
||||
return final_config
|
||||
|
||||
|
||||
def get_gemma_config(precision: str):
|
||||
config = {
|
||||
"image_token_index": None,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 1,
|
||||
}
|
||||
|
||||
config["image_token_index"] = 257152
|
||||
text_config = {
|
||||
"vocab_size": 257152,
|
||||
"num_hidden_layers": 18,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 256,
|
||||
"torch_dtype": precision,
|
||||
"hidden_size": 1024,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"num_attention_heads": 8,
|
||||
"intermediate_size": 4096,
|
||||
"is_encoder_decoder": False,
|
||||
}
|
||||
final_config = GemmaConfig()
|
||||
final_config.update(text_config)
|
||||
return final_config
|
||||
@@ -1,437 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Convert pi0 parameters from Jax to Pytorch
|
||||
|
||||
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
|
||||
and install the required libraries.
|
||||
|
||||
```bash
|
||||
cd ~/code/openpi
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
Example downloading parameters:
|
||||
```bash
|
||||
python
|
||||
>>> import openpi.shared.download as download
|
||||
>>> path='s3://openpi-assets/checkpoints/pi0_base/params'
|
||||
>>> download.maybe_download(path)
|
||||
```
|
||||
|
||||
Converting pi0_base:
|
||||
```python
|
||||
python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
|
||||
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
|
||||
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
|
||||
```
|
||||
|
||||
```python
|
||||
python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
|
||||
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
|
||||
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import orbax.checkpoint as ocp
|
||||
import torch
|
||||
from jax.sharding import SingleDeviceSharding
|
||||
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.conversion_scripts.conversion_utils import (
|
||||
get_gemma_config,
|
||||
get_paligemma_config,
|
||||
)
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
||||
|
||||
|
||||
def slice_paligemma_state_dict(state_dict, config):
|
||||
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
|
||||
|
||||
# fmt: off
|
||||
# patch embeddings
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
|
||||
3, 2, 0, 1
|
||||
)
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
|
||||
# positional embeddings
|
||||
state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
|
||||
-1, config.vision_config.hidden_size
|
||||
)
|
||||
|
||||
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
|
||||
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
|
||||
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
|
||||
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
|
||||
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
|
||||
|
||||
encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
|
||||
encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
|
||||
encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
|
||||
encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
|
||||
|
||||
encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
|
||||
encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
|
||||
encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
|
||||
encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
|
||||
encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
|
||||
encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
|
||||
encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
|
||||
encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
|
||||
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
|
||||
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
|
||||
state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose()
|
||||
state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")
|
||||
|
||||
# multimodal projector
|
||||
|
||||
state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose()
|
||||
state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}")
|
||||
|
||||
# text decoder (gemma)
|
||||
embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}")
|
||||
state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector
|
||||
|
||||
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
|
||||
|
||||
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
|
||||
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
|
||||
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
|
||||
|
||||
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
|
||||
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
|
||||
# TODO verify correctness of layer norm loading
|
||||
|
||||
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
|
||||
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
|
||||
|
||||
for i in range(config.text_config.num_hidden_layers):
|
||||
# llm_attention_q_einsum[i].shape = (8, 2048, 256)
|
||||
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
|
||||
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
|
||||
|
||||
# llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
|
||||
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
|
||||
# llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
|
||||
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
|
||||
|
||||
# output projection.
|
||||
|
||||
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
|
||||
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
|
||||
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
||||
# mlp layers
|
||||
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
|
||||
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
|
||||
state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
|
||||
|
||||
state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
|
||||
state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied.
|
||||
|
||||
# fmt: on
|
||||
expert_dict = {}
|
||||
final_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key not in [
|
||||
f"llm/final_norm_1/scale{suffix}",
|
||||
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
|
||||
f"llm/layers/attn/kv_einsum_1/w{suffix}",
|
||||
f"llm/layers/attn/q_einsum_1/w{suffix}",
|
||||
f"llm/layers/mlp_1/gating_einsum{suffix}",
|
||||
f"llm/layers/mlp_1/linear{suffix}",
|
||||
f"llm/layers/pre_attention_norm_1/scale{suffix}",
|
||||
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
|
||||
]:
|
||||
final_state_dict[key] = torch.from_numpy(value)
|
||||
else:
|
||||
expert_dict[key] = value
|
||||
|
||||
return final_state_dict, expert_dict
|
||||
|
||||
|
||||
def slice_gemma_state_dict(state_dict, config, num_expert=1):
|
||||
# fmt: off
|
||||
# text decoder (gemma)
|
||||
# no embedding vector, the expert just has the decoder layers
|
||||
|
||||
embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
|
||||
state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector
|
||||
|
||||
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
|
||||
|
||||
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
|
||||
|
||||
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
|
||||
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
|
||||
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
|
||||
|
||||
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
|
||||
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
|
||||
# TODO verify correctness of layer norm loading
|
||||
|
||||
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
|
||||
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
|
||||
|
||||
for i in range(config.num_hidden_layers):
|
||||
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
||||
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
|
||||
|
||||
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
|
||||
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
|
||||
|
||||
# output projection.
|
||||
|
||||
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024)
|
||||
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0)
|
||||
|
||||
state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
|
||||
# mlp layers
|
||||
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
|
||||
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
|
||||
state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
|
||||
state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
|
||||
|
||||
state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
|
||||
state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here)
|
||||
|
||||
# fmt: on
|
||||
final_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if not isinstance(value, torch.Tensor):
|
||||
final_state_dict[key] = torch.from_numpy(value)
|
||||
else:
|
||||
final_state_dict[key] = value
|
||||
return final_state_dict
|
||||
|
||||
|
||||
def flatten_for_memory(tree, parent_key=""):
|
||||
out = {}
|
||||
for k, v in tree.items():
|
||||
new_key = f"{parent_key}/{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
out.update(flatten_for_memory(v, new_key))
|
||||
else:
|
||||
out[new_key] = np.array(v) # Ensure conversion to np.array for consistency
|
||||
return out
|
||||
|
||||
|
||||
def flatten_for_npz(tree, parent_key=""):
|
||||
out = {}
|
||||
for k, v in tree.items():
|
||||
new_key = f"{parent_key}/{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
out.update(flatten_for_npz(v, new_key))
|
||||
else:
|
||||
# bf16/f32 here?
|
||||
out[new_key] = np.array(v)
|
||||
return out
|
||||
|
||||
|
||||
def slice_initial_orbax_checkpoint(checkpoint_dir: str):
|
||||
params_path = pathlib.Path(checkpoint_dir).resolve()
|
||||
checkpointer = ocp.PyTreeCheckpointer()
|
||||
|
||||
metadata = checkpointer.metadata(params_path)
|
||||
print("Metadata keys:", list(metadata.keys()))
|
||||
|
||||
params_name = "params"
|
||||
|
||||
item = {params_name: metadata[params_name]}
|
||||
device = jax.local_devices()[0] # Use the first local device
|
||||
sharding = SingleDeviceSharding(device)
|
||||
restored = checkpointer.restore(
|
||||
params_path,
|
||||
ocp.args.PyTreeRestore(
|
||||
item=item,
|
||||
restore_args=jax.tree_util.tree_map(
|
||||
lambda _: ocp.ArrayRestoreArgs(
|
||||
restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it
|
||||
sharding=sharding,
|
||||
),
|
||||
item,
|
||||
),
|
||||
transforms={},
|
||||
),
|
||||
)
|
||||
params = restored[params_name]
|
||||
|
||||
# get params for PaliGemma
|
||||
pali_params = params["PaliGemma"]
|
||||
del params["PaliGemma"]
|
||||
pali_params_flat = flatten_for_npz(pali_params)
|
||||
return {"paligemma_params": pali_params_flat, "projection_params": params}
|
||||
|
||||
|
||||
def update_keys_with_prefix(d: dict, prefix: str) -> dict:
|
||||
"""Update dictionary keys by adding a prefix."""
|
||||
return {f"{prefix}{key}": value for key, value in d.items()}
|
||||
|
||||
|
||||
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
|
||||
# Break down orbax ckpts - they are in OCDBT
|
||||
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
||||
# process projection params
|
||||
keys = [
|
||||
"state_proj",
|
||||
"action_in_proj",
|
||||
"action_out_proj",
|
||||
"action_time_mlp_in",
|
||||
"action_time_mlp_out",
|
||||
]
|
||||
|
||||
projection_params = {}
|
||||
for key in keys:
|
||||
kernel_params = initial_params["projection_params"][key]["kernel"]
|
||||
bias_params = initial_params["projection_params"][key]["bias"]
|
||||
if isinstance(kernel_params, dict):
|
||||
weight = kernel_params["value"]
|
||||
bias = bias_params["value"]
|
||||
else:
|
||||
weight = kernel_params
|
||||
bias = bias_params
|
||||
projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
|
||||
projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))
|
||||
|
||||
# Process PaliGemma weights
|
||||
paligemma_config = get_paligemma_config(precision)
|
||||
paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
|
||||
initial_params["paligemma_params"], paligemma_config
|
||||
)
|
||||
|
||||
# Process Gemma weights (at this stage they are unused)
|
||||
gemma_config = get_gemma_config(precision)
|
||||
gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)
|
||||
|
||||
# Instantiate model from configs
|
||||
|
||||
if "pi0_aloha_sim" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
empty_cameras=2,
|
||||
adapt_to_pi_aloha=True,
|
||||
use_delta_joint_actions_aloha=False,
|
||||
)
|
||||
elif "pi0_aloha_towel" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
adapt_to_pi_aloha=True,
|
||||
use_delta_joint_actions_aloha=True,
|
||||
)
|
||||
elif "pi0_base" in checkpoint_dir:
|
||||
pi0_config = PI0Config(
|
||||
empty_cameras=0,
|
||||
adapt_to_pi_aloha=False,
|
||||
use_delta_joint_actions_aloha=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
# gemma_config=gemma_config, paligemma_config=paligemma_config)
|
||||
pi0_model = PI0Policy(pi0_config)
|
||||
|
||||
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
|
||||
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
|
||||
projection_params = update_keys_with_prefix(projection_params, "model.")
|
||||
|
||||
# load state dict
|
||||
torch_dtype = PRECISIONS[precision]
|
||||
pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
|
||||
pi0_model = pi0_model.to(torch_dtype)
|
||||
# pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
|
||||
|
||||
pi0_model.save_pretrained(output_path, safe_serialization=True)
|
||||
# pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)
|
||||
|
||||
# assert that model loads properly
|
||||
del pi0_model
|
||||
PI0Policy.from_pretrained(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--checkpoint_dir",
|
||||
default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
|
||||
type=str,
|
||||
help="Path to the ocdbt checkpoint",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
choices=["float32", "bfloat16", "float16"],
|
||||
default="float32",
|
||||
type=str,
|
||||
help="Precision identifier for model conversion - should match the base checkpoint precision.",
|
||||
)
|
||||
# tokenizer is identical to paligemma, it appears
|
||||
|
||||
parser.add_argument(
|
||||
"--tokenizer_hub_id",
|
||||
default="google/paligemma-3b-pt-224",
|
||||
type=str,
|
||||
help="Hub path to the tokenizer to save",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to save converted weights to",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_pi0_checkpoint(
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
precision=args.precision,
|
||||
tokenizer_id=args.tokenizer_hub_id,
|
||||
output_path=args.output_path,
|
||||
)
|
||||
@@ -1,141 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from packaging.version import Version
|
||||
|
||||
if Version(torch.__version__) > Version("2.5.0"):
|
||||
# Ffex attention is only available from torch 2.5 onwards
|
||||
from torch.nn.attention.flex_attention import (
|
||||
_mask_mod_signature,
|
||||
_round_up_to_multiple,
|
||||
create_block_mask,
|
||||
create_mask,
|
||||
flex_attention,
|
||||
)
|
||||
|
||||
|
||||
# @torch.compile(dynamic=False)
|
||||
def flex_attention_forward(
|
||||
attention_mask: torch.Tensor,
|
||||
batch_size: int,
|
||||
head_dim: int,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
scaling=None,
|
||||
):
|
||||
"""
|
||||
This is defined out of classes to make compile happy.
|
||||
"""
|
||||
|
||||
original_dtype = query_states.dtype
|
||||
num_att_heads = 8
|
||||
num_key_value_heads = 1
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
key_states = key_states[:, :, :, None, :]
|
||||
key_states = key_states.expand(
|
||||
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :]
|
||||
value_states = value_states.expand(
|
||||
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
query_states = query_states.to(torch.float32)
|
||||
key_states = key_states.to(torch.float32)
|
||||
value_states = value_states.to(torch.float32)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if causal_mask is not None:
|
||||
causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
|
||||
|
||||
if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
|
||||
causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
|
||||
|
||||
def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
|
||||
def mask_mod(b, h, q_idx, kv_idx):
|
||||
# Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
|
||||
return precomputed_mask[b][h][q_idx][kv_idx]
|
||||
|
||||
return mask_mod
|
||||
|
||||
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
|
||||
|
||||
block_size = 128
|
||||
q_len_rounded = _round_up_to_multiple(q_len, block_size)
|
||||
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
|
||||
|
||||
# *CRITICAL* we do need to expand here, else we get a CUDA index error
|
||||
|
||||
pad_q = q_len_rounded - q_len
|
||||
pad_k = kv_len_rounded - kv_len
|
||||
|
||||
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
|
||||
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
|
||||
|
||||
mask_4d = create_mask(
|
||||
mod_fn=mask_mod_fn_orig,
|
||||
B=b_mask,
|
||||
H=h_mask,
|
||||
Q_LEN=q_len_rounded,
|
||||
KV_LEN=kv_len_rounded,
|
||||
device=causal_mask.device,
|
||||
_compile=False,
|
||||
)
|
||||
|
||||
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
|
||||
block_mask = create_block_mask(
|
||||
mask_mod=mask_mod_fn_padded,
|
||||
B=b_mask,
|
||||
H=h_mask,
|
||||
Q_LEN=q_len_rounded,
|
||||
KV_LEN=kv_len_rounded,
|
||||
BLOCK_SIZE=block_size,
|
||||
device=causal_mask.device,
|
||||
_compile=False,
|
||||
)
|
||||
|
||||
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
||||
attn_output, attention_weights = flex_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
block_mask=block_mask,
|
||||
enable_gqa=True, # because we shaped query/key states for GQA
|
||||
scale=head_dim**-0.5 if scaling is None else scaling,
|
||||
return_lse=True,
|
||||
)
|
||||
|
||||
attn_output = attn_output.to(dtype=original_dtype)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
|
||||
attn_output = attn_output.reshape(
|
||||
batch_size,
|
||||
-1,
|
||||
attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
|
||||
)
|
||||
return attn_output
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,420 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.version
|
||||
from pytest import Cache
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
GemmaForCausalLM,
|
||||
PaliGemmaForConditionalGeneration,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
)
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.policies.pi0.flex_attention import flex_attention_forward
|
||||
|
||||
|
||||
def apply_rope(x, positions, max_wavelength=10_000):
|
||||
"""
|
||||
Applies RoPE positions [B, L] to x [B, L, H, D].
|
||||
"""
|
||||
d_half = x.shape[-1] // 2
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
|
||||
|
||||
radians = radians[..., None, :]
|
||||
|
||||
sin = torch.sin(radians) # .to(dtype=dtype)
|
||||
cos = torch.cos(radians) # .to(dtype=dtype)
|
||||
|
||||
x1, x2 = x.split(d_half, dim=-1)
|
||||
res = torch.empty_like(x)
|
||||
res[..., :d_half] = x1 * cos - x2 * sin
|
||||
res[..., d_half:] = x2 * cos + x1 * sin
|
||||
|
||||
return res.to(dtype)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertConfig(PretrainedConfig):
|
||||
model_type = "PaliGemmaWithExpertModel"
|
||||
sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
paligemma_config: dict | None = None,
|
||||
gemma_expert_config: dict | None = None,
|
||||
freeze_vision_encoder: bool = True,
|
||||
train_expert_only: bool = True,
|
||||
attention_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
self.attention_implementation = attention_implementation
|
||||
|
||||
if paligemma_config is None:
|
||||
# Default config from Pi0
|
||||
self.paligemma_config = CONFIG_MAPPING["paligemma"](
|
||||
transformers_version="4.48.1",
|
||||
_vocab_size=257152,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
hidden_size=2048,
|
||||
image_token_index=257152,
|
||||
model_type="paligemma",
|
||||
pad_token_id=0,
|
||||
projection_dim=2048,
|
||||
text_config={
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 16384,
|
||||
"model_type": "gemma",
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 18,
|
||||
"num_image_tokens": 256,
|
||||
"num_key_value_heads": 1,
|
||||
"torch_dtype": "float32",
|
||||
"vocab_size": 257152,
|
||||
},
|
||||
vision_config={
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"num_image_tokens": 256,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 2048,
|
||||
"projector_hidden_act": "gelu_fast",
|
||||
"torch_dtype": "float32",
|
||||
"vision_use_head": False,
|
||||
},
|
||||
)
|
||||
elif isinstance(self.paligemma_config, dict):
|
||||
# Override Pi0 default config for PaliGemma
|
||||
if "model_type" not in gemma_expert_config:
|
||||
paligemma_config["model_type"] = "paligemma"
|
||||
|
||||
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
|
||||
self.paligemma_config = cfg_cls(**paligemma_config)
|
||||
|
||||
if gemma_expert_config is None:
|
||||
# Default config from Pi0
|
||||
self.gemma_expert_config = CONFIG_MAPPING["gemma"](
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
head_dim=256,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
hidden_size=1024,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=4096,
|
||||
max_position_embeddings=8192,
|
||||
model_type="gemma",
|
||||
num_attention_heads=8,
|
||||
num_hidden_layers=18,
|
||||
num_key_value_heads=1,
|
||||
pad_token_id=0,
|
||||
rms_norm_eps=1e-06,
|
||||
rope_theta=10000.0,
|
||||
torch_dtype="float32",
|
||||
transformers_version="4.48.1",
|
||||
use_cache=True,
|
||||
vocab_size=257152,
|
||||
)
|
||||
elif isinstance(self.gemma_expert_config, dict):
|
||||
# Override Pi0 default config for Gemma Expert
|
||||
if "model_type" not in gemma_expert_config:
|
||||
gemma_expert_config["model_type"] = "gemma"
|
||||
|
||||
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
|
||||
self.gemma_expert_config = cfg_cls(**gemma_expert_config)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.train_expert_only and not self.freeze_vision_encoder:
|
||||
raise ValueError(
|
||||
"You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
|
||||
)
|
||||
|
||||
if self.attention_implementation not in ["eager", "fa2", "flex"]:
|
||||
raise ValueError(
|
||||
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
|
||||
)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
config_class = PaliGemmaWithExpertConfig
|
||||
|
||||
def __init__(self, config: PaliGemmaWithExpertConfig):
|
||||
super().__init__(config=config)
|
||||
self.config = config
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
|
||||
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
|
||||
# Remove unused embed_tokens
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_like_physical_intelligence()
|
||||
self.set_requires_grad()
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.config.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
for params in self.paligemma.vision_tower.parameters():
|
||||
params.requires_grad = False
|
||||
|
||||
if self.config.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
for params in self.paligemma.parameters():
|
||||
params.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
|
||||
if self.config.freeze_vision_encoder:
|
||||
self.paligemma.vision_tower.eval()
|
||||
|
||||
if self.config.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
|
||||
def to_bfloat16_like_physical_intelligence(self):
|
||||
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
|
||||
|
||||
params_to_change_dtype = [
|
||||
"language_model.model.layers",
|
||||
"gemma_expert.model.layers",
|
||||
"vision_tower",
|
||||
"multi_modal",
|
||||
]
|
||||
for name, param in self.named_parameters():
|
||||
if any(selector in name for selector in params_to_change_dtype):
|
||||
param.data = param.data.to(dtype=torch.bfloat16)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
# Handle different transformers versions
|
||||
if hasattr(self.paligemma, "get_image_features"):
|
||||
return self.paligemma.get_image_features(image)
|
||||
else:
|
||||
return self.paligemma.model.get_image_features(image)
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.language_model.embed_tokens(tokens)
|
||||
|
||||
# TODO: break down this huge forward into modules or functions
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | Cache | None = None,
|
||||
inputs_embeds: list[torch.FloatTensor] = None,
|
||||
use_cache: bool | None = None,
|
||||
fill_kv_cache: bool | None = None,
|
||||
):
|
||||
models = [self.paligemma.language_model, self.gemma_expert.model]
|
||||
|
||||
for hidden_states in inputs_embeds:
|
||||
# TODO this is very inefficient
|
||||
# dtype is always the same, batch size too (if > 1 len)
|
||||
# device could be trickier in multi gpu edge cases but that's it
|
||||
if hidden_states is None:
|
||||
continue
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# RMSNorm
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
head_dim = self.paligemma.config.text_config.head_dim
|
||||
for layer_idx in range(num_layers):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is None:
|
||||
continue
|
||||
layer = models[i].layers[layer_idx]
|
||||
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
# hidden_states = hidden_states * normalizer
|
||||
hidden_states = layer.input_layernorm(hidden_states)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=torch.bfloat16)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
# concatenate on the number of embeddings/tokens
|
||||
query_states = torch.cat(query_states, dim=1)
|
||||
key_states = torch.cat(key_states, dim=1)
|
||||
value_states = torch.cat(value_states, dim=1)
|
||||
|
||||
query_states = apply_rope(query_states, position_ids)
|
||||
key_states = apply_rope(key_states, position_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
value_states = torch.cat(
|
||||
[past_key_values[layer_idx]["value_states"], value_states], dim=1
|
||||
)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
att_output = attention_interface(
|
||||
attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
att_output = att_output.to(dtype=torch.bfloat16)
|
||||
|
||||
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
|
||||
outputs_embeds = []
|
||||
start = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
|
||||
if hidden_states is not None:
|
||||
end = start + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start:end])
|
||||
|
||||
# TODO: first dropout (by default 0.0)
|
||||
|
||||
# first residual
|
||||
out_emb += hidden_states
|
||||
after_first_residual = out_emb.clone()
|
||||
|
||||
out_emb = layer.post_attention_layernorm(out_emb)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
|
||||
# TODO: second dropout (by default 0.0)
|
||||
|
||||
# second residual
|
||||
out_emb += after_first_residual
|
||||
|
||||
outputs_embeds.append(out_emb)
|
||||
|
||||
start = end
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
inputs_embeds = outputs_embeds
|
||||
|
||||
# final norm
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is not None:
|
||||
out_emb = models[i].norm(hidden_states)
|
||||
outputs_embeds.append(out_emb)
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
return outputs_embeds, past_key_values
|
||||
|
||||
def get_attention_interface(self):
|
||||
if self.config.attention_implementation == "fa2":
|
||||
attention_interface = self.flash_attention_forward
|
||||
elif self.config.attention_implementation == "flex":
|
||||
attention_interface = flex_attention_forward
|
||||
else:
|
||||
attention_interface = self.eager_attention_forward
|
||||
return attention_interface
|
||||
|
||||
def flash_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
raise NotImplementedError("FA2 is not implemented (yet)")
|
||||
|
||||
def eager_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
|
||||
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
# query_states: batch_size, sequence_length, num_att_head, head_dim
|
||||
# key_states: batch_size, sequence_length, num_key_value_head, head_dim
|
||||
# value_states: batch_size, sequence_length, num_key_value_head, head_dim
|
||||
sequence_length = key_states.shape[1]
|
||||
|
||||
key_states = key_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||
|
||||
query_states = query_states.to(dtype=torch.float32)
|
||||
key_states = key_states.to(dtype=torch.float32)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
|
||||
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
att_weights *= head_dim**-0.5
|
||||
big_neg = -2.3819763e38 # See gemma/modules.py
|
||||
|
||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||
|
||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||
probs = probs.to(dtype=value_states.dtype)
|
||||
|
||||
# probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
|
||||
# value_states: batch_size, sequence_length, num_att_heads, head_dim
|
||||
|
||||
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
|
||||
|
||||
att_output = att_output.permute(0, 2, 1, 3)
|
||||
# we use -1 because sequence length can change
|
||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||
|
||||
return att_output
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user