mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
Compare commits
1 Commits
feat/robom
...
feat/dynam
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a5be3a3b6f |
@@ -12,11 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Python virtual environments — never copy into Docker images
|
||||
.venv
|
||||
venv
|
||||
env/
|
||||
|
||||
# Misc
|
||||
.git
|
||||
tmp
|
||||
|
||||
9
.github/workflows/fast_tests.yml
vendored
9
.github/workflows/fast_tests.yml
vendored
@@ -44,7 +44,7 @@ permissions:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.12"
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
||||
concurrency:
|
||||
@@ -61,7 +61,6 @@ jobs:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
@@ -90,11 +89,5 @@ jobs:
|
||||
- name: Install lerobot with test extras
|
||||
run: uv sync --extra "test"
|
||||
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
|
||||
17
.github/workflows/full_tests.yml
vendored
17
.github/workflows/full_tests.yml
vendored
@@ -37,7 +37,7 @@ permissions:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.12"
|
||||
PYTHON_VERSION: "3.10"
|
||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
|
||||
|
||||
# Ensures that only the latest action is built, canceling older runs.
|
||||
@@ -60,7 +60,6 @@ jobs:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
@@ -88,12 +87,6 @@ jobs:
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
|
||||
@@ -169,7 +162,6 @@ jobs:
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
@@ -181,13 +173,8 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Fix ptxas permissions
|
||||
run: chmod +x /lerobot/.venv/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
|
||||
24
.github/workflows/nightly.yml
vendored
24
.github/workflows/nightly.yml
vendored
@@ -28,7 +28,7 @@ on:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.12"
|
||||
PYTHON_VERSION: "3.10"
|
||||
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
|
||||
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
|
||||
|
||||
@@ -119,7 +119,6 @@ jobs:
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --shm-size "16gb"
|
||||
@@ -131,11 +130,6 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Run pytest on CPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
@@ -152,7 +146,6 @@ jobs:
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
@@ -164,11 +157,6 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
@@ -186,7 +174,6 @@ jobs:
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
CUDA_VISIBLE_DEVICES: "0,1,2,3"
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
@@ -198,15 +185,12 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Verify GPU availability
|
||||
run: |
|
||||
nvidia-smi
|
||||
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
||||
|
||||
- name: Run multi-GPU training tests
|
||||
run: pytest -vv tests/training/
|
||||
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
|
||||
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
|
||||
timeout-minutes: 10
|
||||
|
||||
2
.github/workflows/quality.yml
vendored
2
.github/workflows/quality.yml
vendored
@@ -50,7 +50,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Run pre-commit hooks
|
||||
uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
|
||||
|
||||
12
.github/workflows/release.yml
vendored
12
.github/workflows/release.yml
vendored
@@ -22,7 +22,7 @@ on:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.12"
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
jobs:
|
||||
# This job builds the Python package and publishes it to PyPI
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Extract Version
|
||||
id: extract_info
|
||||
@@ -83,6 +83,14 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Remove Tags with Git dependencies
|
||||
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
|
||||
run: |
|
||||
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
|
||||
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
|
||||
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
|
||||
echo "::info:: Git dependencies removed. Proceeding with build."
|
||||
|
||||
- name: Install build dependencies
|
||||
run: python -m pip install build
|
||||
|
||||
|
||||
15
.github/workflows/unbound_deps_tests.yml
vendored
15
.github/workflows/unbound_deps_tests.yml
vendored
@@ -29,7 +29,7 @@ permissions:
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.12"
|
||||
PYTHON_VERSION: "3.10"
|
||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
|
||||
|
||||
# Ensures that only the latest action is built, canceling older runs.
|
||||
@@ -48,7 +48,6 @@ jobs:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
@@ -80,11 +79,7 @@ jobs:
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv
|
||||
|
||||
@@ -142,7 +137,6 @@ jobs:
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
container:
|
||||
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
@@ -154,11 +148,6 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv
|
||||
- name: Run end-to-end tests
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
default_language_version:
|
||||
python: python3.12
|
||||
python: python3.10
|
||||
|
||||
exclude: "tests/artifacts/.*\\.safetensors$"
|
||||
|
||||
@@ -55,7 +55,7 @@ repos:
|
||||
rev: v3.21.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py312-plus]
|
||||
args: [--py310-plus]
|
||||
|
||||
##### Markdown Quality #####
|
||||
- repo: https://github.com/rbubley/mirrors-prettier
|
||||
|
||||
19
README.md
19
README.md
@@ -135,7 +135,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
|
||||
## Citation
|
||||
|
||||
If you use LeRobot in your project, please cite the GitHub repository to acknowledge the ongoing development and contributors:
|
||||
If you use LeRobot in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
@@ -146,23 +146,6 @@ If you use LeRobot in your project, please cite the GitHub repository to acknowl
|
||||
}
|
||||
```
|
||||
|
||||
If you are referencing our research or the academic paper, please also cite our ICLR publication:
|
||||
|
||||
<details>
|
||||
<summary><b>ICLR 2026 Paper</b></summary>
|
||||
|
||||
```bibtex
|
||||
@inproceedings{cadenelerobot,
|
||||
title={LeRobot: An Open-Source Library for End-to-End Robot Learning},
|
||||
author={Cadene, Remi and Alibert, Simon and Capuano, Francesco and Aractingi, Michel and Zouitine, Adil and Kooijmans, Pepijn and Choghari, Jade and Russi, Martino and Pascal, Caroline and Palma, Steven and Shukor, Mustafa and Moss, Jess and Soare, Alexander and Aubakirova, Dana and Lhoest, Quentin and Gallou\'edec, Quentin and Wolf, Thomas},
|
||||
booktitle={The Fourteenth International Conference on Learning Representations},
|
||||
year={2026},
|
||||
url={https://arxiv.org/abs/2602.22818}
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Contribute
|
||||
|
||||
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
|
||||
|
||||
@@ -1,200 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Benchmark evaluation container — one image per benchmark, built via BENCHMARK arg.
|
||||
#
|
||||
# Supported values for BENCHMARK:
|
||||
# libero — LIBERO suite (spatial / object / goal / 10 / 90)
|
||||
# libero_plus — LIBERO-plus extended benchmark (requires robosuite, bddl, robomimic)
|
||||
# robomme — RoboMME memory-augmented manipulation benchmark
|
||||
# robocasa — RoboCasa kitchen composite-task benchmark
|
||||
#
|
||||
# Build:
|
||||
# docker build --build-arg BENCHMARK=libero -f docker/Dockerfile.benchmark \
|
||||
# -t lerobot-benchmark-libero .
|
||||
#
|
||||
# Run (interactive):
|
||||
# docker run --gpus all --rm -it lerobot-benchmark-libero
|
||||
# Run eval:
|
||||
# docker run --gpus all --rm lerobot-benchmark-libero lerobot-eval --help
|
||||
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG OS_VERSION=22.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG BENCHMARK=libero
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
MUJOCO_GL=egl \
|
||||
PYOPENGL_PLATFORM=egl \
|
||||
EGL_PLATFORM=device \
|
||||
NVIDIA_DRIVER_CAPABILITIES=all \
|
||||
NVIDIA_VISIBLE_DEVICES=all \
|
||||
PATH=/lerobot/.venv/bin:$PATH \
|
||||
CMAKE_POLICY_VERSION_MINIMUM=3.5 \
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
DEVICE=cuda \
|
||||
BENCHMARK=${BENCHMARK}
|
||||
|
||||
# ── Base system deps (shared across all benchmarks) ───────────────────────────
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
software-properties-common build-essential git curl \
|
||||
libglib2.0-0 libgl1 libgl1-mesa-glx libgles2 \
|
||||
libegl1 libegl1-mesa libegl1-mesa-dev \
|
||||
libglew-dev libglfw3 libglfw3-dev libgl1-mesa-dri \
|
||||
libglvnd-dev libosmesa6 libosmesa6-dev \
|
||||
libvulkan1 mesa-vulkan-drivers \
|
||||
libsm6 libxext6 libxrender-dev \
|
||||
ffmpeg libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
||||
cmake pkg-config ninja-build \
|
||||
&& add-apt-repository -y ppa:deadsnakes/ppa \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
python${PYTHON_VERSION} \
|
||||
python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-dev \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
&& useradd --create-home --shell /bin/bash user_lerobot \
|
||||
&& usermod -aG sudo user_lerobot \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# ── NVIDIA EGL + Vulkan vendor ICDs (lets GLVND find the GPU driver) ──────────
|
||||
RUN mkdir -p /usr/share/vulkan/icd.d /usr/share/glvnd/egl_vendor.d \
|
||||
&& printf '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.2.155"}}\n' \
|
||||
> /usr/share/vulkan/icd.d/nvidia_icd.json \
|
||||
&& printf '{"file_format_version":"1.0.0","ICD":{"library_path":"libEGL_nvidia.so.0"}}\n' \
|
||||
> /usr/share/glvnd/egl_vendor.d/10_nvidia.json
|
||||
|
||||
# ── Benchmark-specific system deps ────────────────────────────────────────────
|
||||
# libero_plus: the `wand` Python package requires ImageMagick headers.
|
||||
RUN case "${BENCHMARK}" in \
|
||||
libero_plus) \
|
||||
apt-get update && apt-get install -y --no-install-recommends \
|
||||
libmagickwand-dev \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* ;; \
|
||||
esac
|
||||
|
||||
WORKDIR /lerobot
|
||||
RUN chown -R user_lerobot:user_lerobot /lerobot
|
||||
|
||||
USER user_lerobot
|
||||
|
||||
ENV HOME=/home/user_lerobot \
|
||||
HF_HOME=/home/user_lerobot/.cache/huggingface \
|
||||
HF_LEROBOT_HOME=/home/user_lerobot/.cache/huggingface/lerobot \
|
||||
TORCH_HOME=/home/user_lerobot/.cache/torch \
|
||||
TRITON_CACHE_DIR=/home/user_lerobot/.cache/triton
|
||||
|
||||
RUN uv venv --seed --python python${PYTHON_VERSION}
|
||||
|
||||
# Copy only the dependency manifests first so Docker can cache this layer
|
||||
# independently of source-code changes.
|
||||
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
|
||||
ARG UNBOUND_DEPS=false
|
||||
RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \
|
||||
echo "Dependencies unbound:" && cat pyproject.toml; \
|
||||
fi
|
||||
|
||||
# Install lerobot core + the selected benchmark extra.
|
||||
# LIBERO-plus needs a dedicated install path because the upstream package is
|
||||
# import-broken when installed via the extras chain alone.
|
||||
RUN case "${BENCHMARK}" in \
|
||||
libero_plus) \
|
||||
PATH=/usr/bin:/bin:/lerobot/.venv/bin:$PATH /lerobot/.venv/bin/python -m pip install --no-cache-dir \
|
||||
"hf-libero>=0.1.3,<0.2.0" \
|
||||
"hf-egl-probe>=1.0.1" \
|
||||
"transformers>=5.3.0,<6.0.0" \
|
||||
"scipy>=1.14.0,<2.0.0" \
|
||||
"bddl>=1.0.1,<2.0.0" \
|
||||
"future" \
|
||||
"easydict>=1.9" \
|
||||
"wand" \
|
||||
"scikit-image>=0.20.0" \
|
||||
"gym>=0.25.0,<0.27.0" \
|
||||
&& git clone --depth 1 https://github.com/sylvestf/LIBERO-plus.git /tmp/LIBERO-plus \
|
||||
&& PATH=/usr/bin:/bin:/lerobot/.venv/bin:$PATH /lerobot/.venv/bin/python -m pip install --no-cache-dir --no-deps /tmp/LIBERO-plus \
|
||||
&& /lerobot/.venv/bin/python -c "import pathlib, site; pathlib.Path(site.getsitepackages()[0], 'libero_plus_repo.pth').write_text('/tmp/LIBERO-plus\n')" \
|
||||
&& /lerobot/.venv/bin/python -m pip install --no-cache-dir . \
|
||||
&& /lerobot/.venv/bin/python -c "\
|
||||
import os, yaml, importlib.util; \
|
||||
root = os.path.dirname(importlib.util.find_spec('libero.libero').origin); \
|
||||
d = dict(benchmark_root=root, bddl_files=os.path.join(root,'bddl_files'), \
|
||||
init_states=os.path.join(root,'init_files'), datasets=os.path.join(root,'..','datasets'), \
|
||||
assets=os.path.join(root,'assets')); \
|
||||
cfg_dir = os.path.expanduser('~/.libero'); os.makedirs(cfg_dir, exist_ok=True); \
|
||||
yaml.dump(d, open(os.path.join(cfg_dir,'config.yaml'),'w')); print('libero config created')" \
|
||||
&& /lerobot/.venv/bin/python -c "from libero.libero import benchmark, get_libero_path; print('libero OK')" ;; \
|
||||
libero) \
|
||||
uv pip install --no-cache ".[libero]" \
|
||||
&& /lerobot/.venv/bin/python -c "\
|
||||
import os, yaml, importlib.util; \
|
||||
root = os.path.dirname(importlib.util.find_spec('libero.libero').origin); \
|
||||
d = dict(benchmark_root=root, bddl_files=os.path.join(root,'bddl_files'), \
|
||||
init_states=os.path.join(root,'init_files'), datasets=os.path.join(root,'..','datasets'), \
|
||||
assets=os.path.join(root,'assets')); \
|
||||
cfg_dir = os.path.expanduser('~/.libero'); os.makedirs(cfg_dir, exist_ok=True); \
|
||||
yaml.dump(d, open(os.path.join(cfg_dir,'config.yaml'),'w')); print('libero config created')" \
|
||||
&& /lerobot/.venv/bin/python -c "from libero.libero import benchmark, get_libero_path; print('libero OK')" ;; \
|
||||
*) \
|
||||
uv pip install --no-cache ".[${BENCHMARK}]" ;; \
|
||||
esac
|
||||
|
||||
# LIBERO-plus requires ~6 GB of scene/texture/object assets from HuggingFace.
|
||||
# Download at build time so containers don't need network access at runtime.
|
||||
USER root
|
||||
COPY <<'FETCH_ASSETS' /tmp/fetch_assets.py
|
||||
from huggingface_hub import hf_hub_download
|
||||
hf_hub_download("Sylvest/LIBERO-plus", "assets.zip",
|
||||
repo_type="dataset", local_dir="/tmp/libero-plus-assets")
|
||||
FETCH_ASSETS
|
||||
COPY <<'VERIFY_ASSETS' /tmp/verify_assets.py
|
||||
from pathlib import Path
|
||||
from libero.libero import get_libero_path
|
||||
d = Path(get_libero_path("benchmark_root")) / "assets" / "scenes"
|
||||
assert d.is_dir(), f"assets missing at {d}"
|
||||
print("assets OK:", d)
|
||||
VERIFY_ASSETS
|
||||
RUN if [ "${BENCHMARK}" = "libero_plus" ]; then \
|
||||
apt-get update && apt-get install -y --no-install-recommends unzip \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
|
||||
&& /lerobot/.venv/bin/python /tmp/fetch_assets.py \
|
||||
&& unzip -q /tmp/libero-plus-assets/assets.zip -d /tmp/libero-plus-unzipped \
|
||||
&& ASSETS_DIR=$(/lerobot/.venv/bin/python -c "from libero.libero import get_libero_path; print(get_libero_path('benchmark_root'))") \
|
||||
&& SRC=$(find /tmp/libero-plus-unzipped -type d -name assets | head -1) \
|
||||
&& mv "$SRC" "$ASSETS_DIR/assets" \
|
||||
&& chown -R user_lerobot:user_lerobot "$ASSETS_DIR/assets" \
|
||||
&& rm -rf /tmp/libero-plus-assets /tmp/libero-plus-unzipped /tmp/fetch_assets.py \
|
||||
&& /lerobot/.venv/bin/python /tmp/verify_assets.py \
|
||||
&& rm /tmp/verify_assets.py; \
|
||||
fi
|
||||
USER user_lerobot
|
||||
|
||||
# Triton requires its ptxas binary to be executable (NVIDIA-specific).
|
||||
RUN if [ -f /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas ]; then \
|
||||
chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas; \
|
||||
fi
|
||||
|
||||
# Verify EGL probe is importable (runtime GPU check requires NVIDIA drivers at container start).
|
||||
RUN /lerobot/.venv/bin/python -c "import egl_probe; print('egl_probe OK')" \
|
||||
2>/dev/null || echo 'NOTE: egl_probe not installed (non-libero build), skipping'
|
||||
|
||||
# Copy full source (tests, examples, configs, etc.)
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,78 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG OS_VERSION=22.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
MUJOCO_GL=egl \
|
||||
PYOPENGL_PLATFORM=egl \
|
||||
EGL_PLATFORM=device \
|
||||
NVIDIA_DRIVER_CAPABILITIES=all \
|
||||
NVIDIA_VISIBLE_DEVICES=all \
|
||||
PATH=/lerobot/.venv/bin:$PATH \
|
||||
# cmake 4.x removed backward compat with cmake_minimum_required < 3.5.
|
||||
# This env var re-enables it so packages like egl-probe can compile.
|
||||
CMAKE_POLICY_VERSION_MINIMUM=3.5
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
software-properties-common build-essential git curl \
|
||||
libglib2.0-0 libgl1 libgl1-mesa-glx libgles2 \
|
||||
libegl1 libegl1-mesa libegl1-mesa-dev \
|
||||
libglew-dev libglfw3 libglvnd-dev \
|
||||
libosmesa6 libosmesa6-dev \
|
||||
libvulkan1 mesa-vulkan-drivers \
|
||||
libsm6 libxext6 libxrender-dev \
|
||||
ffmpeg libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
||||
cmake pkg-config ninja-build \
|
||||
&& add-apt-repository -y ppa:deadsnakes/ppa \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
python${PYTHON_VERSION} \
|
||||
python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-dev \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
&& useradd --create-home --shell /bin/bash user_lerobot \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# NVIDIA EGL + Vulkan vendor ICDs (lets GLVND find the GPU driver)
|
||||
RUN mkdir -p /usr/share/vulkan/icd.d /usr/share/glvnd/egl_vendor.d \
|
||||
&& printf '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.2.155"}}\n' \
|
||||
> /usr/share/vulkan/icd.d/nvidia_icd.json \
|
||||
&& printf '{"file_format_version":"1.0.0","ICD":{"library_path":"libEGL_nvidia.so.0"}}\n' \
|
||||
> /usr/share/glvnd/egl_vendor.d/10_nvidia.json
|
||||
|
||||
WORKDIR /lerobot
|
||||
RUN chown -R user_lerobot:user_lerobot /lerobot
|
||||
USER user_lerobot
|
||||
|
||||
ENV HOME=/home/user_lerobot \
|
||||
HF_HOME=/home/user_lerobot/.cache/huggingface \
|
||||
HF_LEROBOT_HOME=/home/user_lerobot/.cache/huggingface/lerobot \
|
||||
TORCH_HOME=/home/user_lerobot/.cache/torch \
|
||||
TRITON_CACHE_DIR=/home/user_lerobot/.cache/triton
|
||||
|
||||
RUN uv venv --seed --python python${PYTHON_VERSION}
|
||||
|
||||
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
RUN uv pip install --no-cache .
|
||||
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,20 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
FROM lerobot-eval-base:latest
|
||||
|
||||
RUN uv pip install --no-cache ".[libero]" \
|
||||
&& python -c "import libero"
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,47 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
FROM lerobot-eval-base:latest
|
||||
|
||||
# Install libero_plus deps explicitly rather than via ".[libero_plus]" extras chain.
|
||||
# uv has a bug where it considers packages "already resolved" when coming through
|
||||
# a nested lerobot[libero] → lerobot[libero_plus] extras chain, silently skipping them.
|
||||
RUN uv pip install --no-cache \
|
||||
"hf-libero>=0.1.3,<0.2.0" \
|
||||
"hf-egl-probe>=1.0.1" \
|
||||
"transformers>=5.3.0,<6.0.0" \
|
||||
"scipy>=1.14.0,<2.0.0" \
|
||||
"bddl>=1.0.1,<2.0.0" \
|
||||
"future" \
|
||||
"easydict>=1.9" \
|
||||
"wand" \
|
||||
"scikit-image>=0.20.0" \
|
||||
"gym>=0.25.0,<0.27.0"
|
||||
|
||||
# Clone LIBERO-plus; install with --no-deps (runtime deps declared above via hf-libero).
|
||||
# Add .pth so the libero module can locate its data files at runtime.
|
||||
RUN git clone --depth 1 https://github.com/sylvestf/LIBERO-plus.git /tmp/LIBERO-plus \
|
||||
&& uv pip install --no-cache --no-deps /tmp/LIBERO-plus \
|
||||
&& python -c "import pathlib, site; pathlib.Path(site.getsitepackages()[0], 'libero_plus_repo.pth').write_text('/tmp/LIBERO-plus\n')" \
|
||||
&& python -c "\
|
||||
import os, yaml, importlib.util; \
|
||||
root = os.path.dirname(importlib.util.find_spec('libero.libero').origin); \
|
||||
d = dict(benchmark_root=root, bddl_files=os.path.join(root,'bddl_files'), \
|
||||
init_states=os.path.join(root,'init_files'), datasets=os.path.join(root,'..','datasets'), \
|
||||
assets=os.path.join(root,'assets')); \
|
||||
cfg_dir = os.path.expanduser('~/.libero'); os.makedirs(cfg_dir, exist_ok=True); \
|
||||
yaml.dump(d, open(os.path.join(cfg_dir,'config.yaml'),'w')); print('libero config created')" \
|
||||
&& python -c "from libero.libero import benchmark, get_libero_path; print('libero OK')"
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,20 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
FROM lerobot-eval-base:latest
|
||||
|
||||
RUN uv pip install --no-cache ".[metaworld]" \
|
||||
&& python -c "import metaworld"
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,40 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
FROM lerobot-eval-base:latest
|
||||
|
||||
# robocasa README says to use master branch of ARISE-Initiative/robosuite.
|
||||
# Install it with deps (robosuite from master has modern dep declarations).
|
||||
RUN git clone --depth 1 https://github.com/ARISE-Initiative/robosuite.git /tmp/robosuite \
|
||||
&& uv pip install --no-cache /tmp/robosuite
|
||||
|
||||
# Clone robocasa and install with --no-deps to skip its lerobot==0.3.3 pin.
|
||||
# Install robocasa's actual runtime deps explicitly instead.
|
||||
RUN git clone --depth 1 https://github.com/robocasa/robocasa.git /tmp/robocasa \
|
||||
&& uv pip install --no-cache --no-deps /tmp/robocasa \
|
||||
&& uv pip install --no-cache \
|
||||
"scikit-image>=0.20.0" \
|
||||
"numba>=0.61.0,<0.62.0" \
|
||||
"mujoco==3.3.1" \
|
||||
"h5py" \
|
||||
"lxml" \
|
||||
"tianshou==0.4.10" \
|
||||
"easydict>=1.9"
|
||||
|
||||
# robocasa/__init__.py asserts numpy.__version__ in ["2.2.5"] — pin it last
|
||||
# so no subsequent package can bump it away.
|
||||
RUN uv pip install --no-cache "numpy==2.2.5" \
|
||||
&& python -c "import robocasa"
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,26 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
FROM lerobot-eval-base:latest
|
||||
|
||||
# mani-skill==3.0.0b21 (robomme dep) pins gymnasium==0.29.1 and numpy<2.0.0,
|
||||
# conflicting with lerobot's gymnasium>=1.1.1 and numpy>=2.0.0.
|
||||
# Both overrides are safe at runtime:
|
||||
# - gymnasium 0.29.x has the same 5-tuple step() API as 1.x (since gym 0.26)
|
||||
# - numpy 1.26.4 is API-compatible with lerobot's actual usage (no 2.x-only APIs used)
|
||||
RUN printf 'gymnasium==0.29.1\nnumpy==1.26.4\n' > /tmp/robomme_override.txt \
|
||||
&& uv pip install --no-cache --override /tmp/robomme_override.txt ".[robomme]" \
|
||||
&& python -c "import robomme"
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -24,7 +24,7 @@ ARG OS_VERSION=22.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
# Define Python version argument
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG PYTHON_VERSION=3.10
|
||||
|
||||
# Configure environment variables
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
# docker run -it --rm lerobot-user
|
||||
|
||||
# Configure the base image
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG PYTHON_VERSION=3.10
|
||||
FROM python:${PYTHON_VERSION}-slim
|
||||
|
||||
# Configure environment variables
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# Build (and optionally push) all lerobot benchmark eval images.
|
||||
#
|
||||
# Usage:
|
||||
# # Build locally only (for testing on this machine)
|
||||
# bash docker/build_benchmark_images.sh
|
||||
#
|
||||
# # Build and push to Docker Hub under your org
|
||||
# bash docker/build_benchmark_images.sh --push --hub_org=pepijn223
|
||||
#
|
||||
# # Force-rebuild base image (e.g. after Dockerfile.eval-base changes)
|
||||
# bash docker/build_benchmark_images.sh --no-cache-base --push --hub_org=pepijn223
|
||||
#
|
||||
# # Build only specific benchmarks
|
||||
# bash docker/build_benchmark_images.sh --benchmarks="libero_plus robomme"
|
||||
#
|
||||
# After building, run eval with:
|
||||
# lerobot-eval --eval.runtime=docker --eval.docker.pull=false \
|
||||
# --eval.docker.image=<hub_org>/lerobot-benchmark-<benchmark>:latest ...
|
||||
# OR (if run locally with the default tag):
|
||||
# lerobot-eval --eval.runtime=docker --eval.docker.pull=false \
|
||||
# --env.type=<benchmark> ... # auto-resolves to lerobot-benchmark-<benchmark>
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
PUSH=false
|
||||
HUB_ORG=""
|
||||
BENCHMARKS="libero libero_plus robomme robocasa metaworld"
|
||||
NO_CACHE_BASE=false
|
||||
PROGRESS="auto"
|
||||
|
||||
for arg in "$@"; do
|
||||
case "$arg" in
|
||||
--push) PUSH=true ;;
|
||||
--hub_org=*) HUB_ORG="${arg#*=}" ;;
|
||||
--benchmarks=*) BENCHMARKS="${arg#*=}" ;;
|
||||
--no-cache-base) NO_CACHE_BASE=true ;;
|
||||
--plain) PROGRESS="plain" ;;
|
||||
*) echo "Unknown arg: $arg"; exit 1 ;;
|
||||
esac
|
||||
done
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||
|
||||
if [[ "$PUSH" == "true" && -z "$HUB_ORG" ]]; then
|
||||
echo "ERROR: --push requires --hub_org=<your-dockerhub-org>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ok() { echo "[OK] $*"; }
|
||||
fail() { echo "[FAIL] $*"; exit 1; }
|
||||
|
||||
BASE_CACHE_FLAG=""
|
||||
if [[ "$NO_CACHE_BASE" == "true" ]]; then
|
||||
BASE_CACHE_FLAG="--no-cache"
|
||||
fi
|
||||
|
||||
echo "=== Building lerobot-eval-base ==="
|
||||
docker build \
|
||||
${BASE_CACHE_FLAG} \
|
||||
--progress="${PROGRESS}" \
|
||||
-f "${SCRIPT_DIR}/Dockerfile.eval-base" \
|
||||
-t lerobot-eval-base:latest \
|
||||
"${REPO_ROOT}" || fail "lerobot-eval-base build failed"
|
||||
ok "lerobot-eval-base"
|
||||
|
||||
for BENCHMARK in $BENCHMARKS; do
|
||||
LOCAL_TAG="lerobot-benchmark-${BENCHMARK}:latest"
|
||||
DOCKERFILE="${SCRIPT_DIR}/Dockerfile.eval-${BENCHMARK//_/-}"
|
||||
|
||||
# Handle underscore → hyphen mapping for filename lookup
|
||||
DOCKERFILE_HYPHEN="${SCRIPT_DIR}/Dockerfile.eval-${BENCHMARK//_/-}"
|
||||
DOCKERFILE_UNDERSCORE="${SCRIPT_DIR}/Dockerfile.eval-${BENCHMARK}"
|
||||
if [[ -f "$DOCKERFILE_HYPHEN" ]]; then
|
||||
DOCKERFILE="$DOCKERFILE_HYPHEN"
|
||||
elif [[ -f "$DOCKERFILE_UNDERSCORE" ]]; then
|
||||
DOCKERFILE="$DOCKERFILE_UNDERSCORE"
|
||||
else
|
||||
fail "No Dockerfile found for benchmark '${BENCHMARK}' (tried ${DOCKERFILE_HYPHEN} and ${DOCKERFILE_UNDERSCORE})"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=== Building ${LOCAL_TAG} from $(basename ${DOCKERFILE}) ==="
|
||||
docker build \
|
||||
--progress="${PROGRESS}" \
|
||||
-f "${DOCKERFILE}" \
|
||||
-t "${LOCAL_TAG}" \
|
||||
"${REPO_ROOT}" || fail "${LOCAL_TAG} build failed"
|
||||
ok "${LOCAL_TAG}"
|
||||
|
||||
if [[ "$PUSH" == "true" ]]; then
|
||||
HUB_TAG="${HUB_ORG}/lerobot-benchmark-${BENCHMARK}:latest"
|
||||
docker tag "${LOCAL_TAG}" "${HUB_TAG}"
|
||||
docker push "${HUB_TAG}" || fail "push ${HUB_TAG} failed"
|
||||
ok "Pushed ${HUB_TAG}"
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "=== Smoke-testing images ==="
|
||||
for BENCHMARK in $BENCHMARKS; do
|
||||
LOCAL_TAG="lerobot-benchmark-${BENCHMARK}:latest"
|
||||
echo " Smoke test: ${LOCAL_TAG}"
|
||||
docker run --rm -e BENCHMARK="${BENCHMARK}" \
|
||||
"${LOCAL_TAG}" bash docker/smoke_test_benchmark.sh \
|
||||
&& ok "smoke test ${BENCHMARK}" \
|
||||
|| echo "[WARN] smoke test failed for ${BENCHMARK} (may need GPU)"
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "All benchmark images built successfully."
|
||||
if [[ "$PUSH" == "true" ]]; then
|
||||
echo "Pushed to Docker Hub under: ${HUB_ORG}/"
|
||||
echo ""
|
||||
echo "To use Hub images in eval, pass:"
|
||||
for BENCHMARK in $BENCHMARKS; do
|
||||
echo " --eval.docker.image=${HUB_ORG}/lerobot-benchmark-${BENCHMARK}:latest"
|
||||
done
|
||||
fi
|
||||
@@ -1,115 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# Smoke-test a benchmark container: verifies imports and CLI entry-points.
|
||||
#
|
||||
# Build and run for a specific benchmark:
|
||||
# docker build --build-arg BENCHMARK=libero -f docker/Dockerfile.benchmark -t lerobot-benchmark-libero .
|
||||
# docker run --gpus all --rm -e BENCHMARK=libero lerobot-benchmark-libero bash docker/smoke_test_benchmark.sh
|
||||
#
|
||||
# Test all benchmarks individually:
|
||||
# for b in libero libero_plus robomme robocasa; do
|
||||
# docker build --build-arg BENCHMARK=$b -f docker/Dockerfile.benchmark -t lerobot-benchmark-$b .
|
||||
# docker run --gpus all --rm -e BENCHMARK=$b lerobot-benchmark-$b bash docker/smoke_test_benchmark.sh
|
||||
# done
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
BENCHMARK="${BENCHMARK:-libero}"
|
||||
PASS=0
|
||||
FAIL=0
|
||||
|
||||
ok() { echo "[PASS] $*"; PASS=$((PASS + 1)); }
|
||||
fail() { echo "[FAIL] $*"; FAIL=$((FAIL + 1)); }
|
||||
|
||||
python_import() {
|
||||
local module="$1"
|
||||
if python -c "import ${module}" 2>/dev/null; then
|
||||
ok "import ${module}"
|
||||
else
|
||||
fail "import ${module}"
|
||||
fi
|
||||
}
|
||||
|
||||
cli_help() {
|
||||
local cmd="$1"
|
||||
if "${cmd}" --help > /dev/null 2>&1; then
|
||||
ok "${cmd} --help"
|
||||
else
|
||||
fail "${cmd} --help"
|
||||
fi
|
||||
}
|
||||
|
||||
echo "=== Smoke test: benchmark=${BENCHMARK} ==="
|
||||
|
||||
# ── lerobot core ──────────────────────────────────────────────────────────────
|
||||
python_import "lerobot"
|
||||
python_import "lerobot.envs"
|
||||
python_import "lerobot.configs.eval"
|
||||
cli_help "lerobot-eval"
|
||||
|
||||
# ── Benchmark-specific env import ─────────────────────────────────────────────
|
||||
case "${BENCHMARK}" in
|
||||
libero)
|
||||
python_import "lerobot.envs.libero"
|
||||
python -c "
|
||||
from lerobot.envs.configs import LiberoEnv
|
||||
cfg = LiberoEnv(task='libero_spatial/KITCHEN_SCENE1_open_the_bottom_drawer_of_the_cabinet')
|
||||
print(' LiberoEnv config OK:', cfg.type)
|
||||
" && ok "LiberoEnv config instantiation" || fail "LiberoEnv config instantiation"
|
||||
;;
|
||||
|
||||
libero_plus)
|
||||
python_import "lerobot.envs.libero"
|
||||
python -c "
|
||||
from lerobot.envs.configs import LiberoPlusEnv
|
||||
cfg = LiberoPlusEnv()
|
||||
print(' LiberoPlusEnv config OK:', cfg.type)
|
||||
" && ok "LiberoPlusEnv config instantiation" || fail "LiberoPlusEnv config instantiation"
|
||||
# Verify the LIBERO-plus package itself is importable
|
||||
python_import "libero"
|
||||
python_import "robosuite"
|
||||
;;
|
||||
|
||||
robomme)
|
||||
python_import "lerobot.envs.robomme"
|
||||
python -c "
|
||||
from lerobot.envs.robomme import ROBOMME_TASKS, RoboMMEGymEnv
|
||||
assert len(ROBOMME_TASKS) == 16, f'Expected 16 tasks, got {len(ROBOMME_TASKS)}'
|
||||
print(' ROBOMME_TASKS OK:', ROBOMME_TASKS[:3], '...')
|
||||
" && ok "RoboMME task list" || fail "RoboMME task list"
|
||||
python -c "
|
||||
from lerobot.envs.configs import RoboMMEEnv
|
||||
cfg = RoboMMEEnv(task='PickXtimes')
|
||||
print(' RoboMMEEnv config OK:', cfg.type)
|
||||
" && ok "RoboMMEEnv config instantiation" || fail "RoboMMEEnv config instantiation"
|
||||
python_import "robomme"
|
||||
;;
|
||||
|
||||
robocasa)
|
||||
python_import "lerobot.envs.robocasa"
|
||||
python -c "
|
||||
from lerobot.envs.robocasa import ACTION_DIM, STATE_DIM
|
||||
assert ACTION_DIM == 12, f'Expected ACTION_DIM=12, got {ACTION_DIM}'
|
||||
assert STATE_DIM == 16, f'Expected STATE_DIM=16, got {STATE_DIM}'
|
||||
print(' ACTION_DIM:', ACTION_DIM, ' STATE_DIM:', STATE_DIM)
|
||||
" && ok "RoboCasa constants" || fail "RoboCasa constants"
|
||||
python -c "
|
||||
from lerobot.envs.configs import RoboCasaEnv
|
||||
cfg = RoboCasaEnv(task='PickPlaceCounterToCabinet')
|
||||
print(' RoboCasaEnv config OK:', cfg.type)
|
||||
" && ok "RoboCasaEnv config instantiation" || fail "RoboCasaEnv config instantiation"
|
||||
python_import "robocasa"
|
||||
python_import "robosuite"
|
||||
;;
|
||||
|
||||
*)
|
||||
echo "Unknown BENCHMARK='${BENCHMARK}'. Valid values: libero, libero_plus, robomme, robocasa"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
# ── Summary ───────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=== Results: ${PASS} passed, ${FAIL} failed ==="
|
||||
if [ "${FAIL}" -gt 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
@@ -19,8 +19,6 @@
|
||||
title: Multi GPU training
|
||||
- local: peft_training
|
||||
title: Training with PEFT (e.g., LoRA)
|
||||
- local: benchmark_training
|
||||
title: Benchmark Training & Evaluation
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
|
||||
@@ -48,7 +48,7 @@ python -m lerobot.async_inference.robot_client \
|
||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
|
||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||
|
||||
@@ -1,398 +0,0 @@
|
||||
# Benchmark Training & Evaluation
|
||||
|
||||
This guide explains how to train and evaluate policies on the simulation benchmarks
|
||||
integrated in LeRobot: **LIBERO**, **LIBERO-plus**, **MetaWorld**, **RoboCasa**, and **RoboMME**.
|
||||
|
||||
The workflow is:
|
||||
|
||||
1. Pick one or more benchmarks.
|
||||
2. For each benchmark, train a policy on its combined dataset (multi-GPU).
|
||||
3. Upload the trained policy to the Hugging Face Hub.
|
||||
4. Evaluate the policy on every task suite within that benchmark.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Install the benchmark-specific dependencies for the environments you want to evaluate on:
|
||||
|
||||
```bash
|
||||
# LIBERO (original)
|
||||
pip install -e ".[libero]"
|
||||
|
||||
# LIBERO-plus
|
||||
pip install -e ".[libero_plus]"
|
||||
|
||||
# MetaWorld
|
||||
pip install -e ".[metaworld]"
|
||||
|
||||
# RoboCasa
|
||||
pip install -e ".[robocasa]"
|
||||
|
||||
# RoboMME
|
||||
pip install -e ".[robomme]"
|
||||
```
|
||||
|
||||
`libero_plus` includes the same EGL probe dependencies as `libero` so headless
|
||||
renderer setup is consistent between both installs.
|
||||
|
||||
If your environment has CMake build-isolation issues, use the same fallback as
|
||||
standard LIBERO installs:
|
||||
|
||||
```bash
|
||||
PATH=/usr/bin:/bin:$PATH pip install --no-build-isolation -e ".[libero-plus]"
|
||||
```
|
||||
|
||||
For multi-GPU training you also need [Accelerate](https://huggingface.co/docs/accelerate):
|
||||
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
## Docker-isolated evaluation (EnvHub)
|
||||
|
||||
LeRobot eval now supports running the full eval worker in a Docker container
|
||||
while keeping policy loading compatible with local checkpoints and local code changes.
|
||||
|
||||
Use `lerobot-eval` with `--eval.runtime=docker`:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=outputs/train/my_policy/checkpoints/050000/pretrained_model \
|
||||
--env.type=libero_plus \
|
||||
--eval.runtime=docker \
|
||||
--eval.docker.envhub_ref=envhub://lerobot/libero_plus@v1 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=10
|
||||
```
|
||||
|
||||
`eval.docker.envhub_ref` is optional. If omitted, LeRobot resolves a default
|
||||
image from `env.type`. You can also override the image directly:
|
||||
|
||||
```bash
|
||||
--eval.docker.image=docker://ghcr.io/huggingface/lerobot-eval-libero-plus:latest
|
||||
```
|
||||
|
||||
By default (`eval.docker.use_local_code=true`), the local repository is mounted
|
||||
in the container and added to `PYTHONPATH`, so edited policy/env code and local
|
||||
checkpoints continue to work without rebuilding the image for each change.
|
||||
|
||||
Common Docker runtime options:
|
||||
|
||||
```bash
|
||||
--eval.docker.pull=true \
|
||||
--eval.docker.gpus=all \
|
||||
--eval.docker.shm_size=8g \
|
||||
--eval.docker.use_local_code=true
|
||||
```
|
||||
|
||||
The benchmark runner supports the same Docker eval path (extra args are
|
||||
forwarded to each generated `lerobot-eval` call):
|
||||
|
||||
```bash
|
||||
lerobot-benchmark eval \
|
||||
--benchmarks libero_plus,robocasa \
|
||||
--hub-user $HF_USER \
|
||||
--n-episodes 50 \
|
||||
--eval.runtime=docker \
|
||||
--eval.docker.pull=true
|
||||
```
|
||||
|
||||
Build benchmark images locally:
|
||||
|
||||
```bash
|
||||
make build-eval-images
|
||||
```
|
||||
|
||||
## Fast single-machine eval tuning
|
||||
|
||||
`lerobot-eval` now has two orthogonal throughput knobs:
|
||||
|
||||
- `eval.batch_size`: number of sub-envs per task (inside one vector env).
|
||||
- `env.max_parallel_tasks`: number of tasks scheduled concurrently.
|
||||
- `eval.instance_count`: number of full eval instances (process-level sharding).
|
||||
|
||||
Use them in this order:
|
||||
|
||||
1. Increase `eval.batch_size` first for per-task throughput.
|
||||
2. Then increase `env.max_parallel_tasks` to overlap tasks, while monitoring RAM/VRAM.
|
||||
3. Optionally increase `eval.instance_count` for process-level parallelism (best with enough CPU/RAM and small models).
|
||||
|
||||
The eval logs print the active scheduler mode (`sequential`, `threaded`, or `batched_lazy`) so you can verify the effective concurrency path.
|
||||
|
||||
### Suggested starting points
|
||||
|
||||
| Benchmark | Conservative | Faster (single GPU) | Notes |
|
||||
|---|---|---|---|
|
||||
| `libero` / `libero_plus` | `eval.batch_size=1`, `env.max_parallel_tasks=4` | `eval.batch_size=1`, `env.max_parallel_tasks=16` | For large suite sweeps, increase `max_parallel_tasks` before `batch_size` to avoid MuJoCo memory spikes. |
|
||||
| `metaworld` | `eval.batch_size=8`, `env.max_parallel_tasks=1` | `eval.batch_size=16`, `env.max_parallel_tasks=2` | Prefer larger per-task vectorization first. |
|
||||
| `robocasa` | `eval.batch_size=4`, `env.max_parallel_tasks=1` | `eval.batch_size=8`, `env.max_parallel_tasks=2` | Rendering/memory can dominate at high image resolution. |
|
||||
| `robomme` | `eval.batch_size=4`, `env.max_parallel_tasks=1` | `eval.batch_size=8`, `env.max_parallel_tasks=2` | Start small and scale gradually with task count. |
|
||||
|
||||
### Local fast eval recipe
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=$HF_USER/smolvla_libero_plus \
|
||||
--env.type=libero_plus \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--env.max_parallel_tasks=16 \
|
||||
--eval.instance_count=2 \
|
||||
--rename_map='{"observation.images.image":"observation.images.camera1","observation.images.image2":"observation.images.camera2"}' \
|
||||
--output_dir=outputs/eval/smolvla_libero_plus \
|
||||
--push_to_hub=true
|
||||
```
|
||||
|
||||
### Docker fast eval recipe
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=$HF_USER/smolvla_libero_plus \
|
||||
--env.type=libero_plus \
|
||||
--eval.runtime=docker \
|
||||
--eval.docker.envhub_ref=envhub://lerobot/libero_plus@v1 \
|
||||
--eval.docker.gpus=all \
|
||||
--eval.docker.shm_size=16g \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--env.max_parallel_tasks=16
|
||||
```
|
||||
|
||||
## Quick start — single benchmark
|
||||
|
||||
Train SmolVLA on LIBERO-plus with 4 GPUs for 50 000 steps:
|
||||
|
||||
```bash
|
||||
lerobot-benchmark train \
|
||||
--benchmarks libero_plus \
|
||||
--policy-path lerobot/smolvla_base \
|
||||
--hub-user $HF_USER \
|
||||
--num-gpus 4 \
|
||||
--steps 50000 \
|
||||
--batch-size 32 \
|
||||
--wandb
|
||||
```
|
||||
|
||||
This trains on the combined LIBERO-plus dataset and pushes the checkpoint to
|
||||
`$HF_USER/smolvla_libero_plus` on the Hub.
|
||||
|
||||
Then evaluate on **all four** LIBERO suites (spatial, object, goal, 10):
|
||||
|
||||
```bash
|
||||
lerobot-benchmark eval \
|
||||
--benchmarks libero_plus \
|
||||
--hub-user $HF_USER \
|
||||
--n-episodes 50
|
||||
```
|
||||
|
||||
This automatically runs a separate `lerobot-eval` for each suite.
|
||||
|
||||
## Full sweep — multiple benchmarks
|
||||
|
||||
Run training **and** evaluation across all benchmarks:
|
||||
|
||||
```bash
|
||||
lerobot-benchmark all \
|
||||
--benchmarks libero,libero_plus,metaworld,robocasa,robomme \
|
||||
--policy-path lerobot/smolvla_base \
|
||||
--hub-user $HF_USER \
|
||||
--num-gpus 4 \
|
||||
--steps 50000 \
|
||||
--batch-size 32 \
|
||||
--wandb \
|
||||
--push-eval-to-hub
|
||||
```
|
||||
|
||||
For each benchmark the runner:
|
||||
1. Trains a policy on its dataset.
|
||||
2. Evaluates on every eval task in the benchmark (e.g. 4 suites for LIBERO).
|
||||
3. Pushes HF-native `.eval_results` rows (and optional artifacts) to the Hub.
|
||||
|
||||
<Tip>
|
||||
|
||||
Use `--dry-run` to print the exact `lerobot-train` / `lerobot-eval` commands without executing them, so you can inspect or modify them before running.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Using the CLI directly (without the benchmark runner)
|
||||
|
||||
You can also compose the commands yourself. The benchmark runner is a thin wrapper; here is what it does under the hood.
|
||||
|
||||
### Training
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=4 \
|
||||
$(which lerobot-train) \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=$HF_USER/libero_plus \
|
||||
--policy.repo_id=$HF_USER/smolvla_libero_plus \
|
||||
--env.type=libero_plus \
|
||||
--env.task=libero_spatial \
|
||||
--steps=50000 \
|
||||
--batch_size=32 \
|
||||
--eval_freq=10000 \
|
||||
--save_freq=10000 \
|
||||
--output_dir=outputs/train/smolvla_libero_plus \
|
||||
--job_name=smolvla_libero_plus \
|
||||
--policy.push_to_hub=true \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
### Evaluation (run once per suite)
|
||||
|
||||
```bash
|
||||
for SUITE in libero_spatial libero_object libero_goal libero_10; do
|
||||
lerobot-eval \
|
||||
--policy.path=$HF_USER/smolvla_libero_plus \
|
||||
--env.type=libero_plus \
|
||||
--env.task=$SUITE \
|
||||
--eval.n_episodes=50 \
|
||||
--eval.batch_size=10 \
|
||||
--output_dir=outputs/eval/smolvla_libero_plus/$SUITE \
|
||||
--policy.device=cuda \
|
||||
--push_to_hub=true \
|
||||
--benchmark_dataset_id=lerobot/sim-benchmarks
|
||||
done
|
||||
```
|
||||
|
||||
## Available benchmarks
|
||||
|
||||
| Benchmark | Env type | Dataset | Eval tasks | Action dim |
|
||||
|---|---|---|---|---|
|
||||
| `libero` | `libero` | `{hub_user}/libero` | spatial, object, goal, 10 | 7 |
|
||||
| `libero_plus` | `libero_plus` | `{hub_user}/libero_plus` | spatial, object, goal, 10 | 7 |
|
||||
| `metaworld` | `metaworld` | `{hub_user}/metaworld` | push-v2 | 4 |
|
||||
| `robocasa` | `robocasa` | `{hub_user}/robocasa` | PickPlaceCounterToCabinet | 12 |
|
||||
| `robomme` | `robomme` | `{hub_user}/robomme` | PickXtimes | 8 |
|
||||
|
||||
Run `lerobot-benchmark list` to see the full registry with all eval tasks.
|
||||
|
||||
## Policy naming convention
|
||||
|
||||
The benchmark runner stores trained policies under:
|
||||
|
||||
```
|
||||
{hub_user}/{policy_name}_{benchmark}
|
||||
```
|
||||
|
||||
The default `--policy-name` is `smolvla`. So training on `libero_plus` as user `alice` produces `alice/smolvla_libero_plus`.
|
||||
|
||||
You can override this, e.g. `--policy-name pi05` if training π₀.₅ instead.
|
||||
|
||||
## Multi-GPU considerations
|
||||
|
||||
The effective batch size is `batch_size × num_gpus`. With `--batch-size=32` and
|
||||
`--num-gpus=4`, you train with an effective batch of 128 per step. LeRobot does **not**
|
||||
auto-scale the learning rate; see the [Multi-GPU Training guide](./multi_gpu_training) for
|
||||
details on when and how to adjust it.
|
||||
|
||||
## Custom benchmarks
|
||||
|
||||
To add a new benchmark, edit the `BENCHMARK_REGISTRY` in
|
||||
`src/lerobot/scripts/lerobot_benchmark.py`:
|
||||
|
||||
```python
|
||||
from lerobot.scripts.lerobot_benchmark import BenchmarkEntry, BENCHMARK_REGISTRY
|
||||
|
||||
BENCHMARK_REGISTRY["my_benchmark"] = BenchmarkEntry(
|
||||
dataset_repo_id="{hub_user}/my_dataset",
|
||||
env_type="my_env",
|
||||
env_task="MyDefaultTask",
|
||||
eval_tasks=["TaskA", "TaskB", "TaskC"],
|
||||
)
|
||||
```
|
||||
|
||||
Then use `--benchmarks my_benchmark` as usual. The runner will train once and
|
||||
evaluate separately on TaskA, TaskB, and TaskC.
|
||||
|
||||
## Outputs
|
||||
|
||||
After training and evaluation, your outputs directory looks like:
|
||||
|
||||
```
|
||||
outputs/
|
||||
├── train/
|
||||
│ ├── smolvla_libero/
|
||||
│ │ ├── checkpoints/
|
||||
│ │ └── ...
|
||||
│ ├── smolvla_libero_plus/
|
||||
│ ├── smolvla_robocasa/
|
||||
│ └── smolvla_robomme/
|
||||
└── eval/
|
||||
├── smolvla_libero/
|
||||
│ ├── libero_spatial/
|
||||
│ │ ├── eval_info.json
|
||||
│ │ └── videos/
|
||||
│ ├── libero_object/
|
||||
│ ├── libero_goal/
|
||||
│ └── libero_10/
|
||||
├── smolvla_libero_plus/
|
||||
│ ├── libero_spatial/
|
||||
│ ├── libero_object/
|
||||
│ ├── libero_goal/
|
||||
│ └── libero_10/
|
||||
├── smolvla_robocasa/
|
||||
└── smolvla_robomme/
|
||||
```
|
||||
|
||||
Each `eval_info.json` contains per-episode rewards, success rates, and aggregate metrics.
|
||||
|
||||
## HF Eval Results + Leaderboard
|
||||
|
||||
LeRobot publishes benchmark scores using Hugging Face's native
|
||||
`/.eval_results/*.yaml` format, which powers model-page eval cards and
|
||||
benchmark leaderboards.
|
||||
|
||||
Add `--push-eval-to-hub` to push results after each eval run:
|
||||
|
||||
```bash
|
||||
lerobot-benchmark eval \
|
||||
--benchmarks libero_plus,robocasa \
|
||||
--hub-user $HF_USER \
|
||||
--benchmark-dataset-id lerobot/sim-benchmarks \
|
||||
--push-eval-to-hub
|
||||
```
|
||||
|
||||
This writes one or more files under `.eval_results/` in the model repo, for example:
|
||||
|
||||
```yaml
|
||||
- dataset:
|
||||
id: lerobot/sim-benchmarks
|
||||
task_id: libero_plus/spatial
|
||||
value: 82.4
|
||||
notes: lerobot-eval
|
||||
```
|
||||
|
||||
Notes:
|
||||
- `--benchmark-dataset-id` points to your consolidated benchmark dataset repo.
|
||||
- `task_id` values are derived from `env.type` and evaluated suite/task names.
|
||||
- Eval artifacts (`eval_info.json`, `eval_config.json`, videos) are still uploaded
|
||||
for provenance, but leaderboard ranking comes from `.eval_results`.
|
||||
|
||||
## Passing extra arguments
|
||||
|
||||
Any arguments after the recognized flags are forwarded to `lerobot-train` or
|
||||
`lerobot-eval`.
|
||||
|
||||
Example (training): use PEFT/LoRA during training.
|
||||
|
||||
```bash
|
||||
lerobot-benchmark train \
|
||||
--benchmarks libero_plus \
|
||||
--policy-path lerobot/smolvla_base \
|
||||
--hub-user $HF_USER \
|
||||
--num-gpus 4 \
|
||||
--steps 50000 \
|
||||
--peft.method_type=LORA --peft.r=16
|
||||
```
|
||||
|
||||
Example (evaluation): forward Docker runtime flags to each `lerobot-eval` call.
|
||||
|
||||
```bash
|
||||
lerobot-benchmark eval \
|
||||
--benchmarks libero_plus \
|
||||
--hub-user $HF_USER \
|
||||
--eval.runtime=docker \
|
||||
--eval.docker.envhub_ref=envhub://lerobot/libero_plus@v1
|
||||
```
|
||||
@@ -32,7 +32,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
# your policy-specific dependencies
|
||||
]
|
||||
requires-python = ">= 3.12"
|
||||
requires-python = ">= 3.11"
|
||||
|
||||
[build-system]
|
||||
build-backend = # your-build-backend
|
||||
@@ -82,7 +82,7 @@ Create your policy implementation by inheriting from LeRobot's base `PreTrainedP
|
||||
# modeling_my_custom_policy.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Any
|
||||
from typing import Dict, Any
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
@@ -91,7 +91,7 @@ class MyCustomPolicy(PreTrainedPolicy):
|
||||
config_class = MyCustomPolicyConfig
|
||||
name = "my_custom_policy"
|
||||
|
||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
|
||||
super().__init__(config, dataset_stats)
|
||||
...
|
||||
```
|
||||
@@ -102,7 +102,7 @@ Create processor functions:
|
||||
|
||||
```python
|
||||
# processor_my_custom_policy.py
|
||||
from typing import Any
|
||||
from typing import Dict, Any
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu
|
||||
### Hardware
|
||||
|
||||
- EarthRover Mini robot
|
||||
- Computer with Python 3.12 or newer
|
||||
- Computer with Python 3.10 or newer
|
||||
- Internet connection
|
||||
|
||||
### Setting Up the Frodobots SDK
|
||||
@@ -170,13 +170,13 @@ Once you can drive the robot well, you can start recording data to train AI mode
|
||||
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
|
||||
```bash
|
||||
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face username:
|
||||
|
||||
```bash
|
||||
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
|
||||
@@ -155,10 +155,10 @@ Upload your repository to Hugging Face:
|
||||
pip install huggingface_hub
|
||||
|
||||
# Login to Hugging Face
|
||||
hf auth login
|
||||
huggingface-cli login
|
||||
|
||||
# Create a new repository
|
||||
hf repo create my-org/my-custom-env
|
||||
huggingface-cli repo create my-custom-env --type space --org my-org
|
||||
|
||||
# Initialize git and push
|
||||
git init
|
||||
|
||||
@@ -159,7 +159,7 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
|
||||
Add your token to the CLI by running this command:
|
||||
|
||||
```bash
|
||||
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
@@ -327,7 +327,7 @@ You can look for other LeRobot datasets on the hub by searching for `LeRobot` [t
|
||||
You can also push your local dataset to the Hub manually, running:
|
||||
|
||||
```bash
|
||||
hf upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
|
||||
huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
|
||||
```
|
||||
|
||||
#### Record function
|
||||
@@ -491,7 +491,7 @@ If your local computer doesn't have a powerful GPU you could utilize Google Cola
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
```bash
|
||||
hf upload ${HF_USER}/act_so101_test \
|
||||
huggingface-cli upload ${HF_USER}/act_so101_test \
|
||||
outputs/train/act_so101_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
@@ -499,7 +499,7 @@ You can also upload intermediate checkpoints with:
|
||||
|
||||
```bash
|
||||
CKPT=010000
|
||||
hf upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model
|
||||
```
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Installation
|
||||
|
||||
This guide uses `conda` (via miniforge) to manage environments (recommended). If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and `ffmpeg` installed with the `libsvtav1` encoder, then skip ahead to [Environment Setup](#step-2-environment-setup).
|
||||
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
|
||||
|
||||
## Step 1 (`conda` only): Install [`miniforge`](https://conda-forge.org/download/)
|
||||
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
|
||||
|
||||
```bash
|
||||
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
||||
@@ -11,47 +11,22 @@ bash Miniforge3-$(uname)-$(uname -m).sh
|
||||
|
||||
## Step 2: Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.12:
|
||||
Create a virtual environment with Python 3.10, using conda:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
<hfoptions id="create_venv">
|
||||
<hfoption id="conda">
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.12
|
||||
conda create -y -n lerobot python=3.10
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="uv">
|
||||
|
||||
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
|
||||
|
||||
```bash
|
||||
uv python install 3.12
|
||||
uv venv --python 3.12
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
Then activate your virtual environment, you have to do this each time you open a shell to use lerobot:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
<hfoptions id="activate_venv">
|
||||
<hfoption id="conda">```bash
|
||||
conda activate lerobot
|
||||
```</hfoption>
|
||||
<hfoption id="uv">
|
||||
```bash
|
||||
# Linux/macOSsource
|
||||
source .venv/bin/activate
|
||||
# Windows PowerShell
|
||||
source .venv\Scripts\Activate.ps1
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
ffmpeg -version # ffmpeg 8.X is not yet supported !
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
@@ -72,9 +47,6 @@ ffmpeg -version # ffmpeg 8.X is not yet supported !
|
||||
> conda install evdev -c conda-forge
|
||||
> ```
|
||||
|
||||
> [!IMPORTANT]
|
||||
> If you are using `uv` you will have to install `ffmpeg` system-wide (outside of the virtual environment). You rely on `uv` and `torchcodec` ability to dynamically link to the system `ffmpeg`.
|
||||
|
||||
## Step 3: Install LeRobot 🤗
|
||||
|
||||
### From Source
|
||||
@@ -88,45 +60,23 @@ cd lerobot
|
||||
|
||||
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
<hfoptions id="install_lerobot_src">
|
||||
<hfoption id="conda">
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="uv">
|
||||
```bash
|
||||
uv pip install -e .
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Installation from PyPI
|
||||
|
||||
**Core Library:**
|
||||
Install the base package with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
<hfoptions id="install_lerobot_pypi">
|
||||
<hfoption id="conda">
|
||||
```bash
|
||||
pip install lerobot
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="uv">
|
||||
```bash
|
||||
uv pip install lerobot
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
_This installs only the default dependencies._
|
||||
|
||||
**Extra Features:**
|
||||
To install additional functionality, use one of the following (If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.):
|
||||
To install additional functionality, use one of the following:
|
||||
|
||||
```bash
|
||||
pip install 'lerobot[all]' # All available features
|
||||
@@ -140,10 +90,13 @@ _Replace `[...]` with your desired features._
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
To install these for Linux run:
|
||||
To install these for linux run:
|
||||
|
||||
```bash
|
||||
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
|
||||
@@ -153,7 +106,7 @@ For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/
|
||||
|
||||
## Optional dependencies
|
||||
|
||||
LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`. If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.
|
||||
LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`.
|
||||
|
||||
### Simulations
|
||||
|
||||
|
||||
@@ -279,13 +279,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
|
||||
Add your token to the CLI by running this command:
|
||||
|
||||
```bash
|
||||
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
|
||||
```bash
|
||||
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
|
||||
@@ -34,6 +34,11 @@ As described by Physical Intelligence, while AI has achieved remarkable success
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Training Data and Capabilities
|
||||
|
||||
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
|
||||
|
||||
@@ -36,6 +36,11 @@ This diverse training mixture creates a "curriculum" that enables generalization
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Usage
|
||||
|
||||
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
@@ -43,11 +43,16 @@ This approach can transform **any existing VLM** into a VLA by training it to pr
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install the pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Training a Custom FAST Tokenizer
|
||||
|
||||
You have two options for the FAST tokenizer:
|
||||
|
||||
1. **Use the pre-trained tokenizer**: The `lerobot/fast-action-tokenizer` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
|
||||
1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
|
||||
|
||||
2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data.
|
||||
|
||||
@@ -109,15 +114,15 @@ lerobot-train \
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
| -------------------------------------- | -------------------------------------------------- | ------------------------------- |
|
||||
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
|
||||
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
|
||||
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
|
||||
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
|
||||
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
|
||||
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `lerobot/fast-action-tokenizer` |
|
||||
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
|
||||
| Parameter | Description | Default |
|
||||
| -------------------------------------- | -------------------------------------------------- | ---------------------------- |
|
||||
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
|
||||
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
|
||||
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
|
||||
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
|
||||
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
|
||||
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` |
|
||||
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
|
||||
|
||||
## Inference
|
||||
|
||||
|
||||
@@ -1,49 +1,23 @@
|
||||
# Unitree G1
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/unitree_thumbnail.jpg"
|
||||
alt="Unitree G1 locomanipulation demo"
|
||||
style={{ width: "100%" }}
|
||||
/>
|
||||
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
|
||||
|
||||
The Unitree G1 humanoid is now supported in LeRobot! You can teleoperate, train locomanipulation policies, test in sim, and more. Both 29 and 23 DoF variants are supported.
|
||||
## About
|
||||
|
||||
We support both 29 and 23 DOF G1 EDU version. We introduce:
|
||||
|
||||
- **`unitree g1` robot class, handling low level read/write from/to the humanoid**
|
||||
- **ZMQ socket bridge** for remote communication and camera streaming, allowing for remote policy deployment over wlan, eth or directly on the robot
|
||||
- **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma
|
||||
- **Simulation mode** for testing policies without the physical robot in mujoco
|
||||
|
||||
---
|
||||
|
||||
## Part 1: Getting Started
|
||||
## Connection guide
|
||||
|
||||
### Install LeRobot on Your Machine
|
||||
### Step 1: Configure Ethernet Interface
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.12
|
||||
conda activate lerobot
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
```
|
||||
|
||||
### Test the Installation (Simulation)
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.id=wbc_unitree \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1.
|
||||
|
||||
- Press `9` to release the robot
|
||||
- Press `7` / `8` to increase / decrease waist height
|
||||
|
||||
### Connect to the Robot
|
||||
|
||||
The G1's Ethernet IP is fixed at `192.168.123.164`. Your machine must have a static IP on the same subnet: `192.168.123.x` where `x ≠ 164`.
|
||||
Set a static IP on the same subnet as the robot:
|
||||
|
||||
```bash
|
||||
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
|
||||
@@ -52,200 +26,272 @@ sudo ip addr add 192.168.123.200/24 dev enp131s0
|
||||
sudo ip link set enp131s0 up
|
||||
```
|
||||
|
||||
### SSH into the Robot
|
||||
**Note**: The G1's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` with x ≠ 164.
|
||||
|
||||
### Step 2: SSH into the Robot
|
||||
|
||||
```bash
|
||||
ssh unitree@192.168.123.164
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
### Install LeRobot on the G1
|
||||
|
||||
From the robot:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.12
|
||||
conda activate lerobot
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
```
|
||||
|
||||
> **Note:** The Unitree SDK requires CycloneDDS v0.10.2. See the [Unitree SDK docs](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
|
||||
You should now be connected to the G1's Orin.
|
||||
|
||||
---
|
||||
|
||||
## Part 2: Enable WiFi on the Robot
|
||||
|
||||
Wi-Fi connectivity is blocked by default on the G1. To activate:
|
||||
Wlan0 is disabled by default on the G1. To enable it:
|
||||
|
||||
### Step 1: Enable WiFi Hardware
|
||||
|
||||
```bash
|
||||
sudo rfkill unblock wifi
|
||||
sudo rfkill unblock all
|
||||
|
||||
# Bring up wlan0
|
||||
sudo ip link set wlan0 up
|
||||
|
||||
# Enable NetworkManager control of wlan0
|
||||
sudo nmcli radio wifi on
|
||||
sudo nmcli device set wlan0 managed yes
|
||||
sudo systemctl restart NetworkManager
|
||||
```
|
||||
|
||||
**On your laptop** (share internet via Ethernet):
|
||||
### Step 2: Enable Internet Forwarding
|
||||
|
||||
**On your laptop:**
|
||||
|
||||
```bash
|
||||
# Enable IP forwarding
|
||||
sudo sysctl -w net.ipv4.ip_forward=1
|
||||
|
||||
# Replace wlp132s0f0 with your WiFi interface name
|
||||
# Set up NAT (replace wlp132s0f0 with your WiFi interface)
|
||||
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
|
||||
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
|
||||
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
|
||||
```
|
||||
|
||||
**On the G1** (set default route through your laptop):
|
||||
**On the G1:**
|
||||
|
||||
```bash
|
||||
# Add laptop as default gateway
|
||||
sudo ip route del default 2>/dev/null || true
|
||||
sudo ip route add default via 192.168.123.200 dev eth0
|
||||
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
|
||||
|
||||
# Verify
|
||||
# Test connection
|
||||
ping -c 3 8.8.8.8
|
||||
```
|
||||
|
||||
**Connect to a WiFi network:**
|
||||
### Step 3: Connect to WiFi Network
|
||||
|
||||
```bash
|
||||
# List available networks
|
||||
nmcli device wifi list
|
||||
|
||||
# Connect to your WiFi (example)
|
||||
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
|
||||
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
|
||||
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
|
||||
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
|
||||
sudo nmcli connection up "YourNetwork"
|
||||
|
||||
# Check WiFi IP address
|
||||
ip a show wlan0
|
||||
```
|
||||
|
||||
You can now SSH over WiFi:
|
||||
### Step 4: SSH Over WiFi
|
||||
|
||||
Once connected to WiFi, note the robot's IP address and disconnect the Ethernet cable. You can now SSH over WiFi:
|
||||
|
||||
```bash
|
||||
ssh unitree@<ROBOT_WIFI_IP>
|
||||
ssh unitree@<YOUR_ROBOT_IP>
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address.
|
||||
|
||||
---
|
||||
|
||||
## Part 3: Teleoperation & Locomotion
|
||||
## Part 3: Robot Server Setup
|
||||
|
||||
### Run the Robot Server
|
||||
### Step 1: Install LeRobot on the Orin
|
||||
|
||||
SSH into the robot and install LeRobot:
|
||||
|
||||
```bash
|
||||
ssh unitree@<YOUR_ROBOT_IP>
|
||||
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
```
|
||||
|
||||
**Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
|
||||
|
||||
### Step 2: Run the Robot Server
|
||||
|
||||
On the robot:
|
||||
|
||||
```bash
|
||||
python src/lerobot/robots/unitree_g1/run_g1_server.py --camera
|
||||
python src/lerobot/robots/unitree_g1/run_g1_server.py
|
||||
```
|
||||
|
||||
### Run the Locomotion Policy
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.robot_ip=<ROBOT_IP> \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.id=wbc_unitree \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "<ROBOT_IP>", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=true \
|
||||
--robot.controller=HolosomaLocomotionController
|
||||
```
|
||||
|
||||
We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl).
|
||||
**Important**: Keep this terminal running. The server must be active for remote control.
|
||||
|
||||
---
|
||||
|
||||
## Part 4: Loco-Manipulation with the Homunculus Exoskeleton
|
||||
## Part 4: Controlling the robot
|
||||
|
||||
We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Assembly instructions [here](https://github.com/nepyope/hmc_exo).
|
||||
With the robot server running, you can now control the robot remotely. Let's launch a locomotion policy
|
||||
|
||||
### Calibrate
|
||||
### Step 1: Install LeRobot on your machine
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
```
|
||||
|
||||
### Step 2: Update Robot IP in Config
|
||||
|
||||
Edit the config file to match your robot's WiFi IP:
|
||||
|
||||
```python
|
||||
# In src/lerobot/robots/unitree_g1/config_unitree_g1.py
|
||||
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
|
||||
```
|
||||
|
||||
### Step 3: Run the Locomotion Policy
|
||||
|
||||
```bash
|
||||
# Run GR00T locomotion controller
|
||||
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
|
||||
|
||||
# Run Holosoma locomotion controller
|
||||
python examples/unitree_g1/holosoma_locomotion.py
|
||||
|
||||
```
|
||||
|
||||
Press `Ctrl+C` to stop the policy.
|
||||
|
||||
---
|
||||
|
||||
## Running in Simulation Mode (MuJoCo)
|
||||
|
||||
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
|
||||
|
||||
### Calibrate Exoskeleton Teleoperator
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo
|
||||
```
|
||||
|
||||
During calibration move each joint through its entire range. After fitting, move the joint in a neutral position and press `n` to advance.
|
||||
### Teleoperate in Simulation
|
||||
|
||||
### Record a Dataset
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
```
|
||||
|
||||
### Record Dataset in Simulation
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
> **Note:** Omit `--teleop.left_arm_config.port` and `--teleop.right_arm_config.port` if you're only using the joystick.
|
||||
|
||||
Example dataset: [nepyope/unitree_box_move_blue_full](https://huggingface.co/datasets/nepyope/unitree_box_move_blue_full)
|
||||
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
|
||||
|
||||
---
|
||||
|
||||
## Part 5: Training & Inference
|
||||
## Running on Real Robot
|
||||
|
||||
### Train
|
||||
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
|
||||
|
||||
### Start the Camera Server
|
||||
|
||||
On the robot, start the ZMQ image server:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--policy.type=pi05 \
|
||||
--output_dir=./outputs/pi05_training \
|
||||
--job_name=pi05_training \
|
||||
--policy.repo_id=your-username/your-repo-id \
|
||||
--policy.pretrained_path=lerobot/pi05_base \
|
||||
--policy.compile_model=true \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--wandb.enable=true \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.train_expert_only=false \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
python src/lerobot/cameras/zmq/image_server.py
|
||||
```
|
||||
|
||||
### Inference with RTC
|
||||
Keep this running in a separate terminal for camera streaming during recording.
|
||||
|
||||
Once trained, we recommend deploying policies using inference-time RTC:
|
||||
### Teleoperate Real Robot
|
||||
|
||||
```bash
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=your-username/your-repo-id \
|
||||
--policy.device=cuda \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.controller=HolosomaLocomotionController \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "<ROBOT_IP>", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--task="task_description" \
|
||||
--duration=1000 \
|
||||
--fps=30 \
|
||||
--rtc.enabled=true
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
```
|
||||
|
||||
### Record Dataset on Real Robot
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
**Note**: Update `server_address` to match your robot's camera server IP.
|
||||
|
||||
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
@@ -254,8 +300,8 @@ python examples/rtc/eval_with_real_robot.py \
|
||||
- [GR00T-WholeBodyControl](https://github.com/NVlabs/GR00T-WholeBodyControl)
|
||||
- [Holosoma](https://github.com/amazon-far/holosoma)
|
||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||
- [Unitree IL LeRobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
|
||||
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
|
||||
|
||||
---
|
||||
|
||||
_Last updated: March 2026_
|
||||
_Last updated: December 2025_
|
||||
|
||||
@@ -1,490 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SLURM-distributed SARM RA-BC annotation pipeline.
|
||||
|
||||
Computes SARM progress values for all frames in a dataset, distributed across
|
||||
SLURM workers, then merges the shards into a single sarm_progress.parquet.
|
||||
|
||||
Two subcommands, each a separate SLURM submission:
|
||||
|
||||
compute – N workers, each computes progress for a subset of episodes
|
||||
aggregate – 1 worker, merges N shards into sarm_progress.parquet, pushes to hub
|
||||
|
||||
Usage:
|
||||
python slurm_compute_rabc.py compute \\
|
||||
--repo-id user/dataset --reward-model-path user/sarm_model \\
|
||||
--stride 10 --device cpu --workers 50 --partition cpu
|
||||
|
||||
python slurm_compute_rabc.py aggregate \\
|
||||
--repo-id user/dataset --reward-model-path user/sarm_model \\
|
||||
--partition cpu --push-to-hub
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
|
||||
class ComputeProgressShards(PipelineStep):
|
||||
"""Each worker computes SARM progress for its assigned episodes."""
|
||||
|
||||
def __init__(
|
||||
self, repo_id, reward_model_path, stride=1, head_mode="sparse", device="cpu", shard_dir="rabc_shards"
|
||||
):
|
||||
super().__init__()
|
||||
if stride < 1:
|
||||
raise ValueError(f"stride must be >= 1, got {stride}")
|
||||
self.repo_id = repo_id
|
||||
self.reward_model_path = reward_model_path
|
||||
self.stride = stride
|
||||
self.head_mode = head_mode
|
||||
self.device = device
|
||||
self.shard_dir = shard_dir
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.policies.sarm.compute_rabc_weights import (
|
||||
generate_all_frame_indices,
|
||||
interpolate_progress,
|
||||
load_sarm_resources,
|
||||
)
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
|
||||
dataset, reward_model, preprocess = load_sarm_resources(
|
||||
self.repo_id,
|
||||
self.reward_model_path,
|
||||
self.device,
|
||||
)
|
||||
|
||||
if hasattr(preprocess, "eval"):
|
||||
preprocess.eval()
|
||||
for step in preprocess.steps:
|
||||
if hasattr(step, "eval"):
|
||||
step.eval()
|
||||
|
||||
image_key = reward_model.config.image_key
|
||||
state_key = reward_model.config.state_key
|
||||
frame_gap = reward_model.config.frame_gap
|
||||
center_idx = reward_model.config.n_obs_steps // 2
|
||||
|
||||
dual_mode = reward_model.config.uses_dual_heads
|
||||
compute_sparse = self.head_mode in ("sparse", "both") or not dual_mode
|
||||
compute_dense = self.head_mode in ("dense", "both") and dual_mode
|
||||
|
||||
my_episodes = list(range(dataset.num_episodes))[rank::world_size]
|
||||
if not my_episodes:
|
||||
logging.info(f"Rank {rank}: no episodes assigned")
|
||||
return
|
||||
logging.info(f"Rank {rank}: {len(my_episodes)} / {dataset.num_episodes} episodes")
|
||||
|
||||
all_rows = []
|
||||
|
||||
for ep_idx in tqdm(my_episodes, desc=f"Rank {rank}"):
|
||||
ep = dataset.meta.episodes[ep_idx]
|
||||
ep_start, ep_end = ep["dataset_from_index"], ep["dataset_to_index"]
|
||||
task = dataset[ep_start].get("task", "perform the task")
|
||||
|
||||
all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
|
||||
if self.stride > 1:
|
||||
compute_indices = [i for i in all_ep_indices if (i - ep_start) % self.stride == 0]
|
||||
if (ep_end - 1) not in compute_indices:
|
||||
compute_indices.append(ep_end - 1)
|
||||
compute_indices = sorted(set(compute_indices))
|
||||
else:
|
||||
compute_indices = all_ep_indices
|
||||
|
||||
frame_results = {}
|
||||
for qi in tqdm(compute_indices, desc=f" Ep {ep_idx}", leave=False):
|
||||
try:
|
||||
sample = dataset[qi]
|
||||
batch = {
|
||||
image_key: sample[image_key],
|
||||
"task": task,
|
||||
"index": qi,
|
||||
"episode_index": ep_idx,
|
||||
}
|
||||
if state_key in sample:
|
||||
batch[state_key] = sample[state_key]
|
||||
|
||||
with torch.no_grad():
|
||||
processed = preprocess(batch)
|
||||
vf = processed["video_features"].to(self.device)
|
||||
tf = processed["text_features"].to(self.device)
|
||||
sf = processed.get("state_features")
|
||||
if sf is not None:
|
||||
sf = sf.to(self.device)
|
||||
lengths = processed.get("lengths")
|
||||
|
||||
sparse_val = dense_val = np.nan
|
||||
if compute_sparse:
|
||||
r = reward_model.calculate_rewards(
|
||||
text_embeddings=tf,
|
||||
video_embeddings=vf,
|
||||
state_features=sf,
|
||||
lengths=lengths,
|
||||
return_all_frames=True,
|
||||
head_mode="sparse",
|
||||
)
|
||||
sparse_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
|
||||
if compute_dense:
|
||||
r = reward_model.calculate_rewards(
|
||||
text_embeddings=tf,
|
||||
video_embeddings=vf,
|
||||
state_features=sf,
|
||||
lengths=lengths,
|
||||
return_all_frames=True,
|
||||
head_mode="dense",
|
||||
)
|
||||
dense_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
|
||||
|
||||
frame_results[qi] = (sparse_val, dense_val)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed frame {qi}: {e}")
|
||||
|
||||
if not frame_results:
|
||||
logging.warning(f"Episode {ep_idx}: all frames failed, skipping")
|
||||
continue
|
||||
|
||||
# Interpolate to all frames in this episode
|
||||
computed_idx = np.array(sorted(frame_results.keys()))
|
||||
all_frame_arr = np.arange(ep_start, ep_end)
|
||||
|
||||
sparse_vals = np.array([frame_results[i][0] for i in computed_idx]) if compute_sparse else None
|
||||
dense_vals = np.array([frame_results[i][1] for i in computed_idx]) if compute_dense else None
|
||||
|
||||
if self.stride > 1 and len(computed_idx) > 1:
|
||||
if compute_sparse:
|
||||
sparse_vals = interpolate_progress(computed_idx, sparse_vals, all_frame_arr)
|
||||
if compute_dense:
|
||||
dense_vals = interpolate_progress(computed_idx, dense_vals, all_frame_arr)
|
||||
output_frames = all_frame_arr
|
||||
else:
|
||||
# Use only successfully computed frames to avoid indexing mismatch on failures
|
||||
output_frames = computed_idx
|
||||
|
||||
for i, fi in enumerate(output_frames):
|
||||
row = {"index": int(fi), "episode_index": ep_idx, "frame_index": int(fi - ep_start)}
|
||||
if compute_sparse:
|
||||
row["progress_sparse"] = float(sparse_vals[i])
|
||||
if compute_dense:
|
||||
row["progress_dense"] = float(dense_vals[i])
|
||||
all_rows.append(row)
|
||||
|
||||
if all_rows:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame(all_rows).sort_values("index").reset_index(drop=True)
|
||||
table = pa.Table.from_pandas(df, preserve_index=False)
|
||||
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
|
||||
shard_dir = Path(self.shard_dir)
|
||||
shard_dir.mkdir(parents=True, exist_ok=True)
|
||||
out = shard_dir / f"shard_{rank:05d}.parquet"
|
||||
pq.write_table(table, out)
|
||||
logging.info(f"Rank {rank}: saved {len(df)} rows to {out}")
|
||||
|
||||
|
||||
class AggregateProgress(PipelineStep):
|
||||
"""Merge all shard parquets into final sarm_progress.parquet."""
|
||||
|
||||
def __init__(self, repo_id, reward_model_path, shard_dir="rabc_shards", push_to_hub=False):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.reward_model_path = reward_model_path
|
||||
self.shard_dir = shard_dir
|
||||
self.push_to_hub = push_to_hub
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
if rank != 0:
|
||||
return
|
||||
|
||||
shard_dir = Path(self.shard_dir)
|
||||
shards = sorted(shard_dir.glob("shard_*.parquet"))
|
||||
if not shards:
|
||||
raise FileNotFoundError(f"No shards found in {shard_dir}")
|
||||
|
||||
# Log shard modification time range to help detect stale files
|
||||
mtimes = [os.path.getmtime(s) for s in shards]
|
||||
oldest = datetime.datetime.fromtimestamp(min(mtimes)).isoformat(timespec="seconds")
|
||||
newest = datetime.datetime.fromtimestamp(max(mtimes)).isoformat(timespec="seconds")
|
||||
logging.info(f"Aggregating {len(shards)} shards (oldest: {oldest}, newest: {newest})")
|
||||
|
||||
df = pd.concat([pd.read_parquet(s) for s in shards], ignore_index=True)
|
||||
df = df.sort_values("index").reset_index(drop=True)
|
||||
|
||||
table = pa.Table.from_pandas(df, preserve_index=False)
|
||||
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
|
||||
|
||||
temp_ds = LeRobotDataset(self.repo_id, download_videos=False)
|
||||
out_path = Path(temp_ds.root) / "sarm_progress.parquet"
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(table, out_path)
|
||||
logging.info(f"Saved {len(df)} rows to {out_path}")
|
||||
|
||||
for col in ["progress_sparse", "progress_dense"]:
|
||||
if col in df.columns:
|
||||
v = df[col].dropna()
|
||||
logging.info(
|
||||
f"{col}: mean={v.mean():.4f} std={v.std():.4f} min={v.min():.4f} max={v.max():.4f}"
|
||||
)
|
||||
|
||||
if self.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
hub_path = "sarm_progress.parquet"
|
||||
logging.info(f"Uploading to {self.repo_id}/{hub_path}")
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(out_path),
|
||||
path_in_repo=hub_path,
|
||||
repo_id=self.repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
logging.info(f"Uploaded: https://huggingface.co/datasets/{self.repo_id}/blob/main/{hub_path}")
|
||||
|
||||
|
||||
def make_compute_executor(
|
||||
repo_id,
|
||||
reward_model_path,
|
||||
stride,
|
||||
head_mode,
|
||||
device,
|
||||
shard_dir,
|
||||
logs_dir,
|
||||
job_name,
|
||||
slurm,
|
||||
workers,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
ComputeProgressShards(repo_id, reward_model_path, stride, head_mode, device, str(shard_dir)),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": workers,
|
||||
"workers": workers,
|
||||
"time": "24:00:00",
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
return SlurmPipelineExecutor(**kwargs)
|
||||
|
||||
kwargs.update({"tasks": workers, "workers": 1})
|
||||
return LocalPipelineExecutor(**kwargs)
|
||||
|
||||
|
||||
def make_aggregate_executor(
|
||||
repo_id,
|
||||
reward_model_path,
|
||||
shard_dir,
|
||||
logs_dir,
|
||||
job_name,
|
||||
slurm,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
push_to_hub,
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
AggregateProgress(repo_id, reward_model_path, str(shard_dir), push_to_hub),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
"time": "02:00:00",
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
return SlurmPipelineExecutor(**kwargs)
|
||||
|
||||
kwargs.update({"tasks": 1, "workers": 1})
|
||||
return LocalPipelineExecutor(**kwargs)
|
||||
|
||||
|
||||
def _add_shared_args(p):
|
||||
p.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Hugging Face repository identifier, e.g. 'user/dataset'.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--shard-dir",
|
||||
type=Path,
|
||||
default=Path("rabc_shards"),
|
||||
help="Directory to read/write per-rank parquet shards.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
default=Path("logs"),
|
||||
help="Directory for datatrove logs.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="SLURM job name (defaults to rabc_<subcommand>).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="1 = submit via SLURM; 0 = run locally (useful for debugging).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
default=None,
|
||||
help="SLURM partition to submit to.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of CPUs per SLURM task.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="4G",
|
||||
help="Memory per CPU, e.g. '4G' or '1950M'.",
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="SLURM-distributed SARM RA-BC annotation pipeline",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# compute subcommand
|
||||
cp = sub.add_parser(
|
||||
"compute",
|
||||
help="Distribute progress computation across SLURM workers.",
|
||||
)
|
||||
_add_shared_args(cp)
|
||||
cp.add_argument(
|
||||
"--reward-model-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path or HF repo id of the SARM reward model.",
|
||||
)
|
||||
cp.add_argument(
|
||||
"--stride",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Compute every Nth frame; intermediate frames are interpolated (must be >= 1).",
|
||||
)
|
||||
cp.add_argument(
|
||||
"--head-mode",
|
||||
type=str,
|
||||
default="sparse",
|
||||
choices=["sparse", "dense", "both"],
|
||||
help="Which reward head(s) to compute.",
|
||||
)
|
||||
cp.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Device for reward model inference, e.g. 'cpu' or 'cuda'.",
|
||||
)
|
||||
cp.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of parallel SLURM tasks (one shard per worker).",
|
||||
)
|
||||
|
||||
# aggregate subcommand
|
||||
ap = sub.add_parser(
|
||||
"aggregate",
|
||||
help="Merge per-rank shards into a single sarm_progress.parquet.",
|
||||
)
|
||||
_add_shared_args(ap)
|
||||
ap.add_argument(
|
||||
"--reward-model-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path or HF repo id of the SARM reward model (stored in parquet metadata).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Upload sarm_progress.parquet to the Hugging Face Hub after aggregation.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
job_name = args.job_name or f"rabc_{args.command}"
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
kwargs["job_name"] = job_name
|
||||
command = kwargs.pop("command")
|
||||
|
||||
executor = make_compute_executor(**kwargs) if command == "compute" else make_aggregate_executor(**kwargs)
|
||||
|
||||
executor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -78,7 +78,6 @@ from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
@@ -98,7 +97,6 @@ from lerobot.robots import ( # noqa: F401
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
@@ -14,20 +14,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.g1_utils import (
|
||||
REMOTE_AXES,
|
||||
REMOTE_BUTTONS,
|
||||
G1_29_JointIndex,
|
||||
get_gravity_orientation,
|
||||
)
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -36,13 +36,18 @@ GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # Hip pitch
|
||||
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # Knee
|
||||
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # Ankle pitch
|
||||
|
||||
MISSING_JOINTS = []
|
||||
G1_MODEL = "g1_23" # Or "g1_29"
|
||||
if G1_MODEL == "g1_23":
|
||||
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw
|
||||
|
||||
# Control parameters
|
||||
ACTION_SCALE = 0.25
|
||||
CONTROL_DT = 0.02 # 50Hz
|
||||
ANG_VEL_SCALE: float = 0.25
|
||||
DOF_POS_SCALE: float = 1.0
|
||||
DOF_VEL_SCALE: float = 0.05
|
||||
CMD_SCALE: list[float] = [2.0, 2.0, 0.25]
|
||||
CMD_SCALE: list = [2.0, 2.0, 0.25]
|
||||
|
||||
|
||||
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
|
||||
@@ -80,11 +85,11 @@ def load_groot_policies(
|
||||
class GrootLocomotionController:
|
||||
"""GR00T lower-body locomotion controller for the Unitree G1."""
|
||||
|
||||
control_dt = CONTROL_DT # Expose for unitree_g1.py
|
||||
|
||||
def __init__(self):
|
||||
# Load policies
|
||||
self.policy_balance, self.policy_walk = load_groot_policies()
|
||||
def __init__(self, policy_balance, policy_walk, robot, config):
|
||||
self.policy_balance = policy_balance
|
||||
self.policy_walk = policy_walk
|
||||
self.robot = robot
|
||||
self.config = config
|
||||
|
||||
self.cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
|
||||
|
||||
@@ -104,60 +109,45 @@ class GrootLocomotionController:
|
||||
|
||||
logger.info("GrootLocomotionController initialized")
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset internal state for a new episode."""
|
||||
self.cmd[:] = 0.0
|
||||
self.groot_qj_all[:] = 0.0
|
||||
self.groot_dqj_all[:] = 0.0
|
||||
self.groot_action[:] = 0.0
|
||||
self.groot_obs_single[:] = 0.0
|
||||
self.groot_obs_stacked[:] = 0.0
|
||||
self.groot_height_cmd = 0.74
|
||||
self.groot_orientation_cmd[:] = 0.0
|
||||
self.groot_obs_history.clear()
|
||||
for _ in range(6):
|
||||
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
|
||||
def run_step(self):
|
||||
# Get current observation
|
||||
obs = self.robot.get_observation()
|
||||
|
||||
def run_step(self, action: dict, lowstate) -> dict:
|
||||
"""Run one step of the locomotion controller.
|
||||
if not obs:
|
||||
return
|
||||
|
||||
Args:
|
||||
action: Action dict containing remote.lx/ly/rx/ry and buttons
|
||||
lowstate: Robot lowstate containing motor positions/velocities and IMU
|
||||
|
||||
Returns:
|
||||
Action dict for lower body joints (0-14)
|
||||
"""
|
||||
if lowstate is None:
|
||||
return {}
|
||||
|
||||
buttons = [int(action.get(k, 0)) for k in REMOTE_BUTTONS]
|
||||
if buttons[0]: # R1 - raise waist
|
||||
# Get command from remote controller
|
||||
if obs["remote.buttons"][0]: # R1 - raise waist
|
||||
self.groot_height_cmd += 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
if buttons[4]: # R2 - lower waist
|
||||
if obs["remote.buttons"][4]: # R2 - lower waist
|
||||
self.groot_height_cmd -= 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
|
||||
lx, ly, rx, _ry = (action.get(k, 0.0) for k in REMOTE_AXES)
|
||||
self.cmd[0] = ly # Forward/backward
|
||||
self.cmd[1] = -lx # Left/right (negated)
|
||||
self.cmd[2] = -rx # Rotation rate (negated)
|
||||
self.cmd[0] = obs["remote.ly"] # Forward/backward
|
||||
self.cmd[1] = obs["remote.lx"] * -1 # Left/right
|
||||
self.cmd[2] = obs["remote.rx"] * -1 # Rotation rate
|
||||
|
||||
# Get joint positions and velocities from lowstate
|
||||
# Get joint positions and velocities from flat dict
|
||||
for motor in G1_29_JointIndex:
|
||||
name = motor.name
|
||||
idx = motor.value
|
||||
self.groot_qj_all[idx] = lowstate.motor_state[idx].q
|
||||
self.groot_dqj_all[idx] = lowstate.motor_state[idx].dq
|
||||
self.groot_qj_all[idx] = obs[f"{name}.q"]
|
||||
self.groot_dqj_all[idx] = obs[f"{name}.dq"]
|
||||
|
||||
# Adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
self.groot_qj_all[idx] = 0.0
|
||||
self.groot_dqj_all[idx] = 0.0
|
||||
|
||||
# Scale joint positions and velocities
|
||||
qj_obs = self.groot_qj_all.copy()
|
||||
dqj_obs = self.groot_dqj_all.copy()
|
||||
|
||||
# Express IMU data in gravity frame of reference
|
||||
quat = lowstate.imu_state.quaternion
|
||||
ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32)
|
||||
gravity_orientation = get_gravity_orientation(quat)
|
||||
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
|
||||
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
|
||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# Scale joint positions and velocities before policy inference
|
||||
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
|
||||
@@ -196,10 +186,73 @@ class GrootLocomotionController:
|
||||
# Transform action back to target joint positions
|
||||
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * ACTION_SCALE
|
||||
|
||||
# Build action dict
|
||||
# Build action dict (only first 15 joints for GR00T)
|
||||
action_dict = {}
|
||||
for i in range(15):
|
||||
motor_name = G1_29_JointIndex(i).name
|
||||
action_dict[f"{motor_name}.q"] = float(target_dof_pos_15[i])
|
||||
|
||||
return action_dict
|
||||
# Zero out missing joints for g1_23dof
|
||||
for joint_idx in MISSING_JOINTS:
|
||||
motor_name = G1_29_JointIndex(joint_idx).name
|
||||
action_dict[f"{motor_name}.q"] = 0.0
|
||||
|
||||
# Send action to robot
|
||||
self.robot.send_action(action_dict)
|
||||
|
||||
|
||||
def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None:
|
||||
"""Main function to run the GR00T locomotion controller.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repository ID for GR00T policies.
|
||||
"""
|
||||
# Load policies
|
||||
policy_balance, policy_walk = load_groot_policies(repo_id=repo_id)
|
||||
|
||||
# Initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
|
||||
robot.connect()
|
||||
|
||||
# Initialize gr00T locomotion controller
|
||||
groot_controller = GrootLocomotionController(
|
||||
policy_balance=policy_balance,
|
||||
policy_walk=policy_walk,
|
||||
robot=robot,
|
||||
config=config,
|
||||
)
|
||||
|
||||
try:
|
||||
robot.reset(CONTROL_DT, GROOT_DEFAULT_ANGLES)
|
||||
|
||||
logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate, R1=raise waist, R2=lower waist")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# Run step
|
||||
while not robot._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
groot_controller.run_step()
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Stopping locomotion...")
|
||||
finally:
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=DEFAULT_GROOT_REPO_ID,
|
||||
help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run(repo_id=args.repo_id)
|
||||
@@ -14,21 +14,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.g1_utils import (
|
||||
REMOTE_AXES,
|
||||
G1_29_JointArmIndex,
|
||||
G1_29_JointIndex,
|
||||
get_gravity_orientation,
|
||||
)
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
|
||||
@@ -40,13 +40,18 @@ DEFAULT_ANGLES[16] = 0.2 # Left shoulder roll
|
||||
DEFAULT_ANGLES[23] = -0.2 # Right shoulder roll
|
||||
DEFAULT_ANGLES[[18, 25]] = 0.6 # Elbow
|
||||
|
||||
MISSING_JOINTS = []
|
||||
G1_MODEL = "g1_23" # Or "g1_29"
|
||||
if G1_MODEL == "g1_23":
|
||||
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw
|
||||
|
||||
# Control parameters
|
||||
ACTION_SCALE = 0.25
|
||||
CONTROL_DT = 0.005 # 200Hz
|
||||
CONTROL_DT = 0.02 # 50Hz
|
||||
ANG_VEL_SCALE = 0.25
|
||||
DOF_POS_SCALE = 1.0
|
||||
DOF_VEL_SCALE = 0.05
|
||||
GAIT_PERIOD = 0.5
|
||||
GAIT_PERIOD = 1.0
|
||||
|
||||
|
||||
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
|
||||
@@ -82,7 +87,7 @@ def load_policy(
|
||||
logger.info(f"Policy loaded: {policy.get_inputs()[0].shape} → {policy.get_outputs()[0].shape}")
|
||||
|
||||
# Extract KP/KD from ONNX metadata
|
||||
model = onnx.load(policy_path, load_external_data=False)
|
||||
model = onnx.load(policy_path)
|
||||
metadata = {prop.key: prop.value for prop in model.metadata_props}
|
||||
|
||||
if "kp" not in metadata or "kd" not in metadata:
|
||||
@@ -96,13 +101,15 @@ def load_policy(
|
||||
|
||||
|
||||
class HolosomaLocomotionController:
|
||||
"""Holosoma lower-body locomotion controller for Unitree G1."""
|
||||
"""Holosoma whole-body locomotion controller for Unitree G1."""
|
||||
|
||||
control_dt = CONTROL_DT # Expose for unitree_g1.py
|
||||
def __init__(self, policy, robot, kp: np.ndarray, kd: np.ndarray):
|
||||
self.policy = policy
|
||||
self.robot = robot
|
||||
|
||||
def __init__(self):
|
||||
# Load policy and gains
|
||||
self.policy, self.kp, self.kd = load_policy()
|
||||
# Override robot's PD gains with policy gains
|
||||
self.robot.kp = kp
|
||||
self.robot.kd = kd
|
||||
|
||||
self.cmd = np.zeros(3, dtype=np.float32)
|
||||
|
||||
@@ -117,55 +124,35 @@ class HolosomaLocomotionController:
|
||||
self.phase_dt = 2 * np.pi / ((1.0 / CONTROL_DT) * GAIT_PERIOD)
|
||||
self.is_standing = True
|
||||
|
||||
logger.info("HolosomaLocomotionController initialized")
|
||||
def run_step(self):
|
||||
# Get current observation
|
||||
obs = self.robot.get_observation()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset internal state for a new episode."""
|
||||
self.cmd[:] = 0.0
|
||||
self.qj[:] = 0.0
|
||||
self.dqj[:] = 0.0
|
||||
self.obs[:] = 0.0
|
||||
self.last_action[:] = 0.0
|
||||
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||
self.is_standing = True
|
||||
if not obs:
|
||||
return
|
||||
|
||||
def run_step(self, action: dict, lowstate) -> dict:
|
||||
"""Run one step of the locomotion controller.
|
||||
|
||||
Args:
|
||||
action: Action dict containing remote.lx/ly/rx/ry
|
||||
lowstate: Robot lowstate containing motor positions/velocities and IMU
|
||||
|
||||
Returns:
|
||||
Action dict for lower body joints (0-14)
|
||||
"""
|
||||
if lowstate is None:
|
||||
return {}
|
||||
|
||||
lx, ly, rx, _ry = (action.get(k, 0.0) for k in REMOTE_AXES)
|
||||
ly = ly if abs(ly) > 0.1 else 0.0
|
||||
lx = lx if abs(lx) > 0.1 else 0.0
|
||||
rx = rx if abs(rx) > 0.1 else 0.0
|
||||
ly = np.clip(ly, -0.3, 0.3)
|
||||
lx = np.clip(lx, -0.3, 0.3)
|
||||
# Get command from remote controller
|
||||
ly = obs["remote.ly"] if abs(obs["remote.ly"]) > 0.1 else 0.0
|
||||
lx = obs["remote.lx"] if abs(obs["remote.lx"]) > 0.1 else 0.0
|
||||
rx = obs["remote.rx"] if abs(obs["remote.rx"]) > 0.1 else 0.0
|
||||
self.cmd[:] = [ly, -lx, -rx]
|
||||
|
||||
# Get joint positions and velocities from lowstate
|
||||
# Get joint positions and velocities
|
||||
for motor in G1_29_JointIndex:
|
||||
name = motor.name
|
||||
idx = motor.value
|
||||
self.qj[idx] = lowstate.motor_state[idx].q
|
||||
self.dqj[idx] = lowstate.motor_state[idx].dq
|
||||
self.qj[idx] = obs[f"{name}.q"]
|
||||
self.dqj[idx] = obs[f"{name}.dq"]
|
||||
|
||||
# Hide arm positions from policy (show DEFAULT_ANGLES instead)
|
||||
# This prevents policy from reacting to teleop arm movements
|
||||
for arm_joint in G1_29_JointArmIndex:
|
||||
self.qj[arm_joint.value] = DEFAULT_ANGLES[arm_joint.value]
|
||||
self.dqj[arm_joint.value] = 0.0
|
||||
# Adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
self.qj[idx] = 0.0
|
||||
self.dqj[idx] = 0.0
|
||||
|
||||
# Express IMU data in gravity frame of reference
|
||||
quat = lowstate.imu_state.quaternion
|
||||
ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32)
|
||||
gravity = get_gravity_orientation(quat)
|
||||
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
|
||||
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
|
||||
gravity = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# Scale joint positions and velocities before policy inference
|
||||
qj_obs = (self.qj - DEFAULT_ANGLES) * DOF_POS_SCALE
|
||||
@@ -199,16 +186,79 @@ class HolosomaLocomotionController:
|
||||
# Run policy inference
|
||||
ort_in = {self.policy.get_inputs()[0].name: self.obs.reshape(1, -1).astype(np.float32)}
|
||||
raw_action = self.policy.run(None, ort_in)[0].squeeze()
|
||||
policy_action = np.clip(raw_action, -100.0, 100.0)
|
||||
self.last_action = policy_action.copy()
|
||||
action = np.clip(raw_action, -100.0, 100.0)
|
||||
self.last_action = action.copy()
|
||||
|
||||
# Transform action back to target joint positions
|
||||
target = DEFAULT_ANGLES + policy_action * ACTION_SCALE
|
||||
target = DEFAULT_ANGLES + action * ACTION_SCALE
|
||||
|
||||
# Build action dict (first 15 joints only)
|
||||
# Build action dict
|
||||
action_dict = {}
|
||||
for i in range(15):
|
||||
motor_name = G1_29_JointIndex(i).name
|
||||
action_dict[f"{motor_name}.q"] = float(target[i])
|
||||
for motor in G1_29_JointIndex:
|
||||
action_dict[f"{motor.name}.q"] = float(target[motor.value])
|
||||
|
||||
return action_dict
|
||||
# Zero out missing joints for g1_23dof
|
||||
for joint_idx in MISSING_JOINTS:
|
||||
motor_name = G1_29_JointIndex(joint_idx).name
|
||||
action_dict[f"{motor_name}.q"] = 0.0
|
||||
|
||||
# Send action to robot
|
||||
self.robot.send_action(action_dict)
|
||||
|
||||
|
||||
def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -> None:
|
||||
"""Main function to run the Holosoma locomotion controller.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repository ID for Holosoma policies.
|
||||
policy_type: Policy type to use ('fastsac' or 'ppo').
|
||||
"""
|
||||
# Load policy and gains
|
||||
policy, kp, kd = load_policy(repo_id=repo_id, policy_type=policy_type)
|
||||
|
||||
# Initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
robot.connect()
|
||||
|
||||
holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd)
|
||||
|
||||
try:
|
||||
robot.reset(CONTROL_DT, DEFAULT_ANGLES)
|
||||
|
||||
logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# Run step
|
||||
while not robot._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
holosoma_controller.run_step()
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Stopping locomotion...")
|
||||
finally:
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=DEFAULT_HOLOSOMA_REPO_ID,
|
||||
help=f"Hugging Face Hub repo ID for Holosoma policies (default: {DEFAULT_HOLOSOMA_REPO_ID})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--policy",
|
||||
type=str,
|
||||
choices=["fastsac", "ppo"],
|
||||
default="fastsac",
|
||||
help="Policy type to use: 'fastsac' (default) or 'ppo'",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run(repo_id=args.repo_id, policy_type=args.policy)
|
||||
205
pyproject.toml
205
pyproject.toml
@@ -25,11 +25,11 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.5.1"
|
||||
version = "0.4.5"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
dynamic = ["readme"]
|
||||
license = { text = "Apache-2.0" }
|
||||
requires-python = ">=3.12"
|
||||
requires-python = ">=3.10"
|
||||
authors = [
|
||||
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
|
||||
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
|
||||
@@ -50,8 +50,7 @@ classifiers = [
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Topic :: Software Development :: Build Tools",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
]
|
||||
@@ -62,28 +61,26 @@ dependencies = [
|
||||
# Hugging Face dependencies
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub>=1.0.0,<2.0.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
|
||||
# Core dependencies
|
||||
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"packaging>=24.2,<26.0",
|
||||
|
||||
"torch>=2.2.1,<2.11.0",
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||
"torchvision>=0.21.0,<0.26.0",
|
||||
|
||||
"einops>=0.8.0,<0.9.0",
|
||||
"opencv-python-headless>=4.9.0,<4.13.0",
|
||||
"av>=15.0.0,<16.0.0",
|
||||
"jsonlines>=4.0.0,<5.0.0",
|
||||
"pynput>=1.7.8,<1.9.0",
|
||||
"packaging>=24.2,<26.0",
|
||||
"pynput>=1.7.7,<1.9.0",
|
||||
"pyserial>=3.5,<4.0",
|
||||
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
"draccus==0.10.0", # TODO: Relax version constraint
|
||||
|
||||
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
|
||||
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"gymnasium>=1.1.1,<2.0.0",
|
||||
"rerun-sdk>=0.24.0,<0.27.0",
|
||||
|
||||
@@ -98,14 +95,10 @@ dependencies = [
|
||||
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
|
||||
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
@@ -119,36 +112,34 @@ gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
unitree_g1 = [
|
||||
"unitree-sdk2==1.0.1",
|
||||
"pyzmq>=26.2.1,<28.0.0",
|
||||
"onnxruntime>=1.16.0,<2.0.0",
|
||||
"pin>=3.0.0,<4.0.0",
|
||||
"meshcat>=0.3.0,<0.4.0",
|
||||
"lerobot[matplotlib-dep]",
|
||||
"lerobot[pygame-dep]",
|
||||
"matplotlib>=3.9.0,<4.0.0",
|
||||
"casadi>=3.6.0,<4.0.0",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54,<2.57.0 ; sys_platform == 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
|
||||
]
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"]
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
||||
|
||||
# Policies
|
||||
wallx = [
|
||||
"lerobot[transformers-dep]",
|
||||
"lerobot[peft]",
|
||||
"lerobot[scipy-dep]",
|
||||
"torchdiffeq>=0.2.4,<0.3.0",
|
||||
"lerobot[qwen-vl-utils-dep]",
|
||||
"transformers==4.49.0",
|
||||
"peft==0.17.1",
|
||||
"scipy==1.15.3",
|
||||
"torchdiffeq==0.2.5",
|
||||
"qwen_vl_utils==0.0.11"
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
"lerobot[peft]",
|
||||
"peft>=0.13.0,<1.0.0",
|
||||
"dm-tree>=0.1.8,<1.0.0",
|
||||
"timm>=1.0.0,<1.1.0",
|
||||
"safetensors>=0.4.3,<1.0.0",
|
||||
@@ -157,13 +148,13 @@ groot = [
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
|
||||
@@ -171,53 +162,13 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
|
||||
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = [
|
||||
"lerobot[transformers-dep]",
|
||||
"hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'",
|
||||
# hf-egl-probe is the fixed fork of egl-probe (robomimic transitive dep).
|
||||
# egl-probe's CMakeLists.txt requires cmake_minimum_required < 3.5 which
|
||||
# modern cmake rejects. Installing hf-egl-probe first satisfies the egl_probe
|
||||
# import without source compilation.
|
||||
"hf-egl-probe>=1.0.1; sys_platform == 'linux'",
|
||||
"lerobot[scipy-dep]",
|
||||
]
|
||||
libero_plus = [
|
||||
# Inherit all of libero's deps (hf-libero → robosuite/robomimic/egl-probe/scipy/transformers).
|
||||
# LIBERO-plus extends LIBERO with extra task suites; its Python module is installed
|
||||
# from the git clone in Dockerfile.eval-libero-plus (overrides hf-libero via .pth).
|
||||
"lerobot[libero]",
|
||||
# Additional runtime deps declared by LIBERO-plus but absent from its setup.py:
|
||||
"bddl>=1.0.1,<2.0.0; sys_platform == 'linux'",
|
||||
"future; sys_platform == 'linux'", # bddl transitive dep not declared in its metadata
|
||||
"easydict>=1.9; sys_platform == 'linux'",
|
||||
"wand; sys_platform == 'linux'",
|
||||
"scikit-image>=0.20.0; sys_platform == 'linux'",
|
||||
"gym>=0.25.0,<0.27.0; sys_platform == 'linux'",
|
||||
]
|
||||
libero-plus = ["lerobot[libero_plus]"]
|
||||
robomme = [
|
||||
"robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main ; sys_platform == 'linux'",
|
||||
]
|
||||
robocasa = [
|
||||
# robocasa and its robosuite fork are not on PyPI; both installed from source
|
||||
# in Dockerfile.eval-robocasa (requires ARISE-Initiative/robosuite@robocasa_v1.4.1
|
||||
# for PandaOmron and other robocasa-specific robots).
|
||||
"easydict>=1.9; sys_platform == 'linux'",
|
||||
"scikit-image>=0.20.0; sys_platform == 'linux'",
|
||||
"lerobot[scipy-dep]",
|
||||
]
|
||||
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
|
||||
metaworld = ["metaworld==3.0.0"]
|
||||
|
||||
# All
|
||||
all = [
|
||||
# NOTE(resolver hint): scipy is pulled in transitively via lerobot[scipy-dep] through
|
||||
# multiple extras (aloha, metaworld, pi, wallx, phone). Listing it explicitly
|
||||
# helps pip's resolver converge by constraining scipy early, before it encounters
|
||||
# the loose scipy requirements from transitive deps like dm-control and metaworld.
|
||||
"scipy>=1.14.0,<2.0.0",
|
||||
"lerobot[dynamixel]",
|
||||
"lerobot[gamepad]",
|
||||
"lerobot[hopejr]",
|
||||
@@ -225,8 +176,8 @@ all = [
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[wallx]",
|
||||
"lerobot[pi]",
|
||||
# "lerobot[wallx]",
|
||||
# "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
@@ -238,11 +189,10 @@ all = [
|
||||
"lerobot[aloha]",
|
||||
"lerobot[pusht]",
|
||||
"lerobot[phone]",
|
||||
"lerobot[libero]; sys_platform == 'linux'",
|
||||
"lerobot[libero]",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -254,7 +204,6 @@ lerobot-replay="lerobot.scripts.lerobot_replay:main"
|
||||
lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main"
|
||||
lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main"
|
||||
lerobot-eval="lerobot.scripts.lerobot_eval:main"
|
||||
lerobot-eval-worker="lerobot.scripts.lerobot_eval_worker:main"
|
||||
lerobot-train="lerobot.scripts.lerobot_train:main"
|
||||
lerobot-train-tokenizer="lerobot.scripts.lerobot_train_tokenizer:main"
|
||||
lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
|
||||
@@ -262,9 +211,7 @@ lerobot-info="lerobot.scripts.lerobot_info:main"
|
||||
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-leaderboard="lerobot.scripts.lerobot_leaderboard:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-benchmark="lerobot.scripts.lerobot_benchmark:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.package-data]
|
||||
@@ -274,7 +221,7 @@ lerobot = ["envs/*.json"]
|
||||
where = ["src"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
target-version = "py310"
|
||||
line-length = 110
|
||||
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
|
||||
|
||||
@@ -366,7 +313,7 @@ default.extend-ignore-identifiers-re = [
|
||||
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.12"
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
follow_imports = "skip"
|
||||
# warn_return_any = true
|
||||
@@ -450,3 +397,85 @@ ignore_errors = false
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.scripts.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[tool.uv]
|
||||
# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "transformers-dep" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "pi" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "smolvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "groot" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "xvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "sarm" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "hilserl" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "libero" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "peft" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "all" },
|
||||
],
|
||||
# pi uses custom branch which conflicts with transformers-dep
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "transformers-dep" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "smolvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "groot" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "xvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "sarm" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "hilserl" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "libero" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "peft" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "all" },
|
||||
],
|
||||
]
|
||||
|
||||
@@ -1,73 +1,76 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with Python 3.12
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# by the following command:
|
||||
#
|
||||
# pip-compile --output-file=requirements-macos.txt requirements.in
|
||||
#
|
||||
-e .[all]
|
||||
# via -[all]
|
||||
absl-py==2.4.0
|
||||
absl-py==2.3.1
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# labmaze
|
||||
# mujoco
|
||||
accelerate==1.13.0
|
||||
# tensorboard
|
||||
accelerate==1.11.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.3
|
||||
aiohttp==3.13.1
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-doc==0.0.4
|
||||
# via
|
||||
# fastapi
|
||||
# typer
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.12.1
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.11.0
|
||||
# via
|
||||
# httpx
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.1
|
||||
asttokens==3.0.0
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
# dm-tree
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.1.0
|
||||
# via lerobot
|
||||
bddl==1.0.1
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
# via
|
||||
# lerobot
|
||||
# qwen-vl-utils
|
||||
certifi==2026.2.25
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==2.0.0
|
||||
# via pymunk
|
||||
cfgv==3.5.0
|
||||
cfgv==3.4.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.5
|
||||
charset-normalizer==3.4.4
|
||||
# via requests
|
||||
click==8.3.1
|
||||
click==8.3.0
|
||||
# via
|
||||
# typer
|
||||
# uvicorn
|
||||
# wandb
|
||||
cloudpickle==3.1.2
|
||||
# via gymnasium
|
||||
cmake==4.1.3
|
||||
cloudpickle==3.1.1
|
||||
# via
|
||||
# gymnasium
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
# via lerobot
|
||||
cmeel==0.59.0
|
||||
cmeel==0.57.3
|
||||
# via
|
||||
# cmeel-assimp
|
||||
# cmeel-boost
|
||||
@@ -105,17 +108,15 @@ cmeel-zlib==1.3.1
|
||||
# via cmeel-assimp
|
||||
coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.3
|
||||
# via
|
||||
# lerobot
|
||||
# matplotlib
|
||||
coverage[toml]==7.13.4
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.11.0
|
||||
# via pytest-cov
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==4.6.1
|
||||
datasets==4.1.1
|
||||
# via lerobot
|
||||
debugpy==1.8.20
|
||||
debugpy==1.8.17
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
@@ -129,7 +130,7 @@ dill==0.4.0
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.37
|
||||
dm-control==1.0.34
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -137,55 +138,69 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
# via lerobot
|
||||
dynamixel-sdk==3.8.4
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.2
|
||||
# via lerobot
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.14.0
|
||||
# via mujoco
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
faker==34.0.2
|
||||
# via lerobot
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.135.1
|
||||
einops==0.8.1
|
||||
# via
|
||||
# lerobot
|
||||
# teleop
|
||||
# libero
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.13.0
|
||||
# via mujoco
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.25.0
|
||||
filelock==3.20.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# huggingface-hub
|
||||
# python-discovery
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
fonttools==4.61.1
|
||||
fonttools==4.60.1
|
||||
# via matplotlib
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2026.2.0
|
||||
fsspec[http]==2025.9.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.46
|
||||
gitpython==3.1.45
|
||||
# via wandb
|
||||
glfw==2.10.0
|
||||
# via
|
||||
@@ -197,6 +212,7 @@ grpcio==1.73.1
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
grpcio-tools==1.73.1
|
||||
# via
|
||||
# lerobot
|
||||
@@ -207,67 +223,71 @@ gym-hil==0.1.13
|
||||
# via lerobot
|
||||
gym-pusht==0.1.6
|
||||
# via lerobot
|
||||
gymnasium==1.2.3
|
||||
gymnasium==1.2.1
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via robomimic
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
hf-xet==1.3.2
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.10
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
huggingface-hub==1.6.0
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
identify==2.6.17
|
||||
hydra-core==1.3.2
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
# via pre-commit
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.2
|
||||
imageio[ffmpeg]==2.37.0
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robomimic
|
||||
# scikit-image
|
||||
imageio-ffmpeg==0.6.0
|
||||
# via imageio
|
||||
importlib-metadata==8.7.1
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
importlib-metadata==8.7.0
|
||||
# via diffusers
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
ipython==9.11.0
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
ipython==8.37.0
|
||||
# via meshcat
|
||||
ipython-pygments-lexers==1.1.1
|
||||
# via ipython
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
jedi==0.19.2
|
||||
@@ -276,24 +296,44 @@ jinja2==3.1.6
|
||||
# via torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
jsonschema==4.25.1
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.18.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.5
|
||||
lazy-loader==0.4
|
||||
# via scikit-image
|
||||
librt==0.8.1
|
||||
# via mypy
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
# via numba
|
||||
lxml==6.0.2
|
||||
# via dm-control
|
||||
markdown==3.9
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via rich
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
markupsafe==3.0.3
|
||||
# via jinja2
|
||||
matplotlib==3.10.8
|
||||
# via lerobot
|
||||
# via
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.7
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
matplotlib-inline==0.2.1
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
# via jupytext
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mergedeep==1.3.4
|
||||
@@ -306,35 +346,41 @@ mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==3.5.0
|
||||
mujoco==3.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# libero
|
||||
# metaworld
|
||||
multidict==6.7.1
|
||||
# robosuite
|
||||
multidict==6.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
multiprocess==0.70.18
|
||||
multiprocess==0.70.16
|
||||
# via datasets
|
||||
mypy==1.19.1
|
||||
# via lerobot
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# mypy
|
||||
# typing-inspect
|
||||
networkx==3.6.1
|
||||
# via
|
||||
# bddl
|
||||
# scikit-image
|
||||
# torch
|
||||
nodeenv==1.10.0
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.62.1
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# bddl
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
@@ -343,14 +389,16 @@ numpy==2.2.6
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# gymnasium
|
||||
# h5py
|
||||
# hebi-py
|
||||
# imageio
|
||||
# labmaze
|
||||
# lerobot
|
||||
# libero
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
@@ -358,18 +406,26 @@ numpy==2.2.6
|
||||
# pyquaternion
|
||||
# reachy2-sdk
|
||||
# rerun-sdk
|
||||
# robomimic
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# scipy
|
||||
# shapely
|
||||
# teleop
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
opencv-python==4.13.0.92
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.12.0.88
|
||||
# via
|
||||
# gym-pusht
|
||||
# libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
opencv-python-headless==4.12.0.88
|
||||
# via lerobot
|
||||
orderly-set==5.5.0
|
||||
@@ -379,87 +435,97 @@ packaging==25.0
|
||||
# accelerate
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# hydra-core
|
||||
# jupytext
|
||||
# lazy-loader
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# qwen-vl-utils
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# transformers
|
||||
# wandb
|
||||
pandas==2.3.3
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.6
|
||||
parso==0.8.5
|
||||
# via jedi
|
||||
pathspec==1.0.4
|
||||
# via mypy
|
||||
peft==0.18.1
|
||||
peft==0.17.1
|
||||
# via lerobot
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pillow==12.1.1
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==12.0.0
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# qwen-vl-utils
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.16
|
||||
placo==0.9.14
|
||||
# via lerobot
|
||||
platformdirs==4.9.4
|
||||
platformdirs==4.5.0
|
||||
# via
|
||||
# python-discovery
|
||||
# jupyter-core
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.5.1
|
||||
pre-commit==4.3.0
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.52
|
||||
# via ipython
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
protobuf==6.31.1
|
||||
protobuf==6.31.0
|
||||
# via
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.2.2
|
||||
psutil==7.1.1
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
# peft
|
||||
# robomimic
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
# via stack-data
|
||||
pyarrow==23.0.1
|
||||
pyarrow==21.0.0
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==3.0
|
||||
pycparser==2.23
|
||||
# via cffi
|
||||
pydantic==2.12.5
|
||||
pydantic==2.12.3
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.5
|
||||
pydantic-core==2.41.4
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -469,35 +535,33 @@ pygame==2.6.1
|
||||
pygments==2.19.2
|
||||
# via
|
||||
# ipython
|
||||
# ipython-pygments-lexers
|
||||
# pytest
|
||||
# rich
|
||||
pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.5.1
|
||||
pyngrok==7.4.1
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
pyobjc-core==12.1
|
||||
pyobjc-core==12.0
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-cocoa
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-applicationservices==12.1
|
||||
pyobjc-framework-applicationservices==12.0
|
||||
# via pynput
|
||||
pyobjc-framework-cocoa==12.1
|
||||
pyobjc-framework-cocoa==12.0
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-coretext==12.1
|
||||
pyobjc-framework-coretext==12.0
|
||||
# via pyobjc-framework-applicationservices
|
||||
pyobjc-framework-quartz==12.1
|
||||
pyobjc-framework-quartz==12.0
|
||||
# via
|
||||
# pynput
|
||||
# pyobjc-framework-applicationservices
|
||||
@@ -506,13 +570,13 @@ pyopengl==3.1.10
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.3.2
|
||||
pyparsing==3.2.5
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via reachy2-sdk
|
||||
pyrealsense2-macosx==2.56.5
|
||||
pyrealsense2-macosx==2.54.2
|
||||
# via lerobot
|
||||
pyserial==3.5
|
||||
# via
|
||||
@@ -521,6 +585,7 @@ pyserial==3.5
|
||||
# lerobot
|
||||
pytest==8.4.2
|
||||
# via
|
||||
# bddl
|
||||
# lerobot
|
||||
# pytest-cov
|
||||
# pytest-timeout
|
||||
@@ -531,14 +596,11 @@ pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# faker
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-discovery==1.1.1
|
||||
# via virtualenv
|
||||
python-dotenv==1.2.2
|
||||
python-dotenv==1.1.1
|
||||
# via uvicorn
|
||||
pytz==2026.1.post1
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
@@ -547,10 +609,13 @@ pyyaml==6.0.3
|
||||
# draccus
|
||||
# hebi-py
|
||||
# huggingface-hub
|
||||
# jupytext
|
||||
# omegaconf
|
||||
# peft
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
@@ -560,13 +625,15 @@ pyzmq==27.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
qwen-vl-utils==0.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk==1.0.15
|
||||
reachy2-sdk==1.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
regex==2026.2.28
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2025.10.23
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
@@ -575,150 +642,184 @@ requests==2.32.5
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# qwen-vl-utils
|
||||
# huggingface-hub
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.26.2
|
||||
rerun-sdk==0.26.1
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
rich==14.3.3
|
||||
# via typer
|
||||
safetensors==0.7.0
|
||||
robomimic==0.2.0
|
||||
# via libero
|
||||
robosuite==1.4.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.6.2
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
scipy==1.17.1
|
||||
scipy==1.15.3
|
||||
# via
|
||||
# dm-control
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# torchdiffeq
|
||||
sentry-sdk==2.54.0
|
||||
sentry-sdk==2.42.1
|
||||
# via wandb
|
||||
shapely==2.1.2
|
||||
# via gym-pusht
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.17.0
|
||||
# via
|
||||
# pynput
|
||||
# python-dateutil
|
||||
smmap==5.0.3
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.52.1
|
||||
starlette==0.48.0
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
teleop==0.1.4
|
||||
teleop==0.1.2
|
||||
# via lerobot
|
||||
termcolor==3.3.0
|
||||
# via lerobot
|
||||
tifffile==2026.3.3
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via libero
|
||||
tifffile==2025.5.10
|
||||
# via scikit-image
|
||||
tokenizers==0.22.2
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
torch==2.10.0
|
||||
tomli==2.3.0
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
# via
|
||||
# accelerate
|
||||
# lerobot
|
||||
# peft
|
||||
# torchdiffeq
|
||||
# robomimic
|
||||
# thop
|
||||
# timm
|
||||
# torchvision
|
||||
torchcodec==0.10.0
|
||||
torchcodec==0.5
|
||||
# via lerobot
|
||||
torchdiffeq==0.2.5
|
||||
# via lerobot
|
||||
torchvision==0.25.0
|
||||
# via lerobot
|
||||
tornado==6.5.4
|
||||
torchvision==0.22.1
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
# via meshcat
|
||||
tqdm==4.67.3
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# peft
|
||||
# robomimic
|
||||
# transformers
|
||||
traitlets==5.14.3
|
||||
# via
|
||||
# ipython
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
transformers==5.3.0
|
||||
# nbformat
|
||||
transformers==4.57.1
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
typer==0.24.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# faker
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# mypy
|
||||
# ipython
|
||||
# multidict
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
# starlette
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# pydantic
|
||||
tzdata==2025.3
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via pandas
|
||||
u-msgpack-python==2.8.0
|
||||
# via meshcat
|
||||
urllib3==2.6.3
|
||||
urllib3==2.5.0
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
uvicorn[standard]==0.41.0
|
||||
uvicorn[standard]==0.38.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==21.1.0
|
||||
virtualenv==20.35.3
|
||||
# via pre-commit
|
||||
wandb==0.24.2
|
||||
# via lerobot
|
||||
wandb==0.21.4
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.6.0
|
||||
wcwidth==0.2.14
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==16.0
|
||||
websockets==15.0.1
|
||||
# via uvicorn
|
||||
wrapt==2.1.2
|
||||
werkzeug==3.1.3
|
||||
# via tensorboard
|
||||
wrapt==2.0.0
|
||||
# via dm-tree
|
||||
xxhash==3.6.0
|
||||
# via datasets
|
||||
yarl==1.23.0
|
||||
yarl==1.22.0
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with Python 3.12
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# by the following command:
|
||||
#
|
||||
# pip-compile --output-file=requirements-ubuntu.txt requirements.in
|
||||
#
|
||||
-e .[all]
|
||||
# via -[all]
|
||||
absl-py==2.4.0
|
||||
absl-py==2.3.1
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
@@ -14,33 +14,30 @@ absl-py==2.4.0
|
||||
# labmaze
|
||||
# mujoco
|
||||
# tensorboard
|
||||
accelerate==1.13.0
|
||||
accelerate==1.11.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.3
|
||||
aiohttp==3.13.1
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-doc==0.0.4
|
||||
# via
|
||||
# fastapi
|
||||
# typer
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.12.1
|
||||
anyio==4.11.0
|
||||
# via
|
||||
# httpx
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.1
|
||||
asttokens==3.0.0
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
@@ -50,35 +47,30 @@ attrs==25.4.0
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# qwen-vl-utils
|
||||
# via lerobot
|
||||
bddl==1.0.1
|
||||
# via hf-libero
|
||||
certifi==2026.2.25
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==2.0.0
|
||||
# via pymunk
|
||||
cfgv==3.5.0
|
||||
cfgv==3.4.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.5
|
||||
charset-normalizer==3.4.4
|
||||
# via requests
|
||||
click==8.3.1
|
||||
click==8.3.0
|
||||
# via
|
||||
# typer
|
||||
# uvicorn
|
||||
# wandb
|
||||
cloudpickle==3.1.2
|
||||
cloudpickle==3.1.1
|
||||
# via
|
||||
# gymnasium
|
||||
# hf-libero
|
||||
cmake==4.1.3
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
# via lerobot
|
||||
cmeel==0.59.0
|
||||
cmeel==0.57.3
|
||||
# via
|
||||
# cmeel-assimp
|
||||
# cmeel-boost
|
||||
@@ -116,24 +108,20 @@ cmeel-zlib==1.3.1
|
||||
# via cmeel-assimp
|
||||
coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.3
|
||||
# via
|
||||
# lerobot
|
||||
# matplotlib
|
||||
coverage[toml]==7.13.4
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.11.0
|
||||
# via pytest-cov
|
||||
cuda-bindings==12.9.4
|
||||
# via torch
|
||||
cuda-pathfinder==1.4.1
|
||||
# via cuda-bindings
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==4.6.1
|
||||
datasets==4.1.1
|
||||
# via lerobot
|
||||
debugpy==1.8.20
|
||||
debugpy==1.8.17
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
decord==0.6.0
|
||||
# via lerobot
|
||||
deepdiff==8.6.1
|
||||
# via lerobot
|
||||
diffusers==0.35.2
|
||||
@@ -144,7 +132,7 @@ dill==0.4.0
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.37
|
||||
dm-control==1.0.34
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -152,6 +140,7 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
@@ -159,60 +148,66 @@ draccus==0.10.0
|
||||
dynamixel-sdk==3.8.4
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via hf-libero
|
||||
egl-probe==1.0.2
|
||||
# via robomimic
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.2
|
||||
einops==0.8.1
|
||||
# via
|
||||
# hf-libero
|
||||
# flash-attn
|
||||
# lerobot
|
||||
# libero
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.14.0
|
||||
etils[epath,epy]==1.13.0
|
||||
# via mujoco
|
||||
evdev==1.9.3
|
||||
evdev==1.9.2
|
||||
# via pynput
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
faker==34.0.2
|
||||
# via lerobot
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.135.1
|
||||
# via
|
||||
# lerobot
|
||||
# teleop
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.25.0
|
||||
filelock==3.20.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# huggingface-hub
|
||||
# python-discovery
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
fonttools==4.61.1
|
||||
flash-attn==2.8.3
|
||||
# via lerobot
|
||||
fonttools==4.60.1
|
||||
# via matplotlib
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2026.2.0
|
||||
fsspec[http]==2025.9.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via hf-libero
|
||||
# via libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.46
|
||||
gitpython==3.1.45
|
||||
# via wandb
|
||||
glfw==2.10.0
|
||||
# via
|
||||
@@ -235,60 +230,50 @@ gym-hil==0.1.13
|
||||
# via lerobot
|
||||
gym-pusht==0.1.6
|
||||
# via lerobot
|
||||
gymnasium==1.2.3
|
||||
gymnasium==1.2.1
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
h5py==3.16.0
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via robomimic
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
hf-egl-probe==1.0.2
|
||||
# via hf-libero
|
||||
hf-libero==0.1.3
|
||||
# via lerobot
|
||||
hf-xet==1.3.2
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.10
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
huggingface-hub==1.6.0
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
hydra-core==1.3.2
|
||||
# via hf-libero
|
||||
identify==2.6.17
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
# via pre-commit
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.2
|
||||
imageio[ffmpeg]==2.37.0
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
@@ -300,14 +285,16 @@ imageio-ffmpeg==0.6.0
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
importlib-metadata==8.7.1
|
||||
importlib-metadata==8.7.0
|
||||
# via diffusers
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
ipython==9.11.0
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
ipython==8.37.0
|
||||
# via meshcat
|
||||
ipython-pygments-lexers==1.1.1
|
||||
# via ipython
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
jedi==0.19.2
|
||||
@@ -316,41 +303,40 @@ jinja2==3.1.6
|
||||
# via torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
jsonschema==4.26.0
|
||||
jsonschema==4.25.1
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.19.1
|
||||
jupytext==1.18.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.5
|
||||
lazy-loader==0.4
|
||||
# via scikit-image
|
||||
librt==0.8.1
|
||||
# via mypy
|
||||
llvmlite==0.46.0
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
# via numba
|
||||
lxml==6.0.2
|
||||
# via dm-control
|
||||
markdown==3.10.2
|
||||
markdown==3.9
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
# rich
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.8
|
||||
matplotlib==3.10.7
|
||||
# via
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
matplotlib-inline==0.2.1
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
@@ -367,38 +353,36 @@ mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==3.5.0
|
||||
mujoco==3.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# hf-libero
|
||||
# libero
|
||||
# metaworld
|
||||
# robosuite
|
||||
multidict==6.7.1
|
||||
multidict==6.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
multiprocess==0.70.18
|
||||
multiprocess==0.70.16
|
||||
# via datasets
|
||||
mypy==1.19.1
|
||||
# via lerobot
|
||||
mypy-extensions==1.1.0
|
||||
# via
|
||||
# mypy
|
||||
# typing-inspect
|
||||
# via typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.6.1
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# bddl
|
||||
# scikit-image
|
||||
# torch
|
||||
nodeenv==1.10.0
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.64.0
|
||||
numba==0.62.1
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
@@ -407,6 +391,7 @@ numpy==2.2.6
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
# decord
|
||||
# diffusers
|
||||
# dm-control
|
||||
# dm-env
|
||||
@@ -414,10 +399,9 @@ numpy==2.2.6
|
||||
# gymnasium
|
||||
# h5py
|
||||
# hebi-py
|
||||
# hf-libero
|
||||
# imageio
|
||||
# labmaze
|
||||
# lerobot
|
||||
# libero
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
@@ -442,51 +426,49 @@ numpy==2.2.6
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
nvidia-cublas-cu12==12.8.4.1
|
||||
nvidia-cublas-cu12==12.6.4.1
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# torch
|
||||
nvidia-cuda-cupti-cu12==12.8.90
|
||||
nvidia-cuda-cupti-cu12==12.6.80
|
||||
# via torch
|
||||
nvidia-cuda-nvrtc-cu12==12.8.93
|
||||
nvidia-cuda-nvrtc-cu12==12.6.77
|
||||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.8.90
|
||||
nvidia-cuda-runtime-cu12==12.6.77
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==9.10.2.21
|
||||
nvidia-cudnn-cu12==9.5.1.17
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.3.3.83
|
||||
nvidia-cufft-cu12==11.3.0.4
|
||||
# via torch
|
||||
nvidia-cufile-cu12==1.13.1.3
|
||||
nvidia-cufile-cu12==1.11.1.6
|
||||
# via torch
|
||||
nvidia-curand-cu12==10.3.9.90
|
||||
nvidia-curand-cu12==10.3.7.77
|
||||
# via torch
|
||||
nvidia-cusolver-cu12==11.7.3.90
|
||||
nvidia-cusolver-cu12==11.7.1.2
|
||||
# via torch
|
||||
nvidia-cusparse-cu12==12.5.8.93
|
||||
nvidia-cusparse-cu12==12.5.4.2
|
||||
# via
|
||||
# nvidia-cusolver-cu12
|
||||
# torch
|
||||
nvidia-cusparselt-cu12==0.7.1
|
||||
nvidia-cusparselt-cu12==0.6.3
|
||||
# via torch
|
||||
nvidia-nccl-cu12==2.27.5
|
||||
nvidia-nccl-cu12==2.26.2
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.8.93
|
||||
nvidia-nvjitlink-cu12==12.6.85
|
||||
# via
|
||||
# nvidia-cufft-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# nvidia-cusparse-cu12
|
||||
# torch
|
||||
nvidia-nvshmem-cu12==3.4.5
|
||||
# via torch
|
||||
nvidia-nvtx-cu12==12.8.90
|
||||
nvidia-nvtx-cu12==12.6.77
|
||||
# via torch
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.13.0.92
|
||||
opencv-python==4.12.0.88
|
||||
# via
|
||||
# gym-pusht
|
||||
# hf-libero
|
||||
# libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
opencv-python-headless==4.12.0.88
|
||||
@@ -505,7 +487,6 @@ packaging==25.0
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# qwen-vl-utils
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
@@ -516,21 +497,21 @@ pandas==2.3.3
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.6
|
||||
parso==0.8.5
|
||||
# via jedi
|
||||
pathspec==1.0.4
|
||||
# via mypy
|
||||
peft==0.18.1
|
||||
peft==0.17.1
|
||||
# via lerobot
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pillow==12.1.1
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==12.0.0
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# qwen-vl-utils
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
@@ -538,27 +519,28 @@ pillow==12.1.1
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.16
|
||||
placo==0.9.14
|
||||
# via lerobot
|
||||
platformdirs==4.9.4
|
||||
platformdirs==4.5.0
|
||||
# via
|
||||
# jupyter-core
|
||||
# python-discovery
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.5.1
|
||||
pre-commit==4.3.0
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.52
|
||||
# via ipython
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
protobuf==6.31.1
|
||||
protobuf==6.31.0
|
||||
# via
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
@@ -568,7 +550,7 @@ protobuf==6.31.1
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.2.2
|
||||
psutil==7.1.1
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
@@ -578,17 +560,17 @@ ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
# via stack-data
|
||||
pyarrow==23.0.1
|
||||
pyarrow==21.0.0
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==3.0
|
||||
pycparser==2.23
|
||||
# via cffi
|
||||
pydantic==2.12.5
|
||||
pydantic==2.12.3
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.5
|
||||
pydantic-core==2.41.4
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -598,14 +580,12 @@ pygame==2.6.1
|
||||
pygments==2.19.2
|
||||
# via
|
||||
# ipython
|
||||
# ipython-pygments-lexers
|
||||
# pytest
|
||||
# rich
|
||||
pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.5.1
|
||||
pyngrok==7.4.1
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
@@ -615,7 +595,7 @@ pyopengl==3.1.10
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.3.2
|
||||
pyparsing==3.2.5
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
@@ -641,16 +621,13 @@ pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# faker
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-discovery==1.1.1
|
||||
# via virtualenv
|
||||
python-dotenv==1.2.2
|
||||
python-dotenv==1.1.1
|
||||
# via uvicorn
|
||||
python-xlib==0.33
|
||||
# via pynput
|
||||
pytz==2026.1.post1
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
@@ -665,6 +642,7 @@ pyyaml==6.0.3
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
@@ -674,9 +652,7 @@ pyzmq==27.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
qwen-vl-utils==0.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk==1.0.15
|
||||
reachy2-sdk==1.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
@@ -684,7 +660,7 @@ referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2026.2.28
|
||||
regex==2025.10.23
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
@@ -693,62 +669,60 @@ requests==2.32.5
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# qwen-vl-utils
|
||||
# huggingface-hub
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.26.2
|
||||
rerun-sdk==0.26.1
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
rich==14.3.3
|
||||
# via typer
|
||||
robomimic==0.2.0
|
||||
# via hf-libero
|
||||
# via libero
|
||||
robosuite==1.4.0
|
||||
# via hf-libero
|
||||
rpds-py==0.30.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.7.0
|
||||
safetensors==0.6.2
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
scipy==1.17.1
|
||||
scipy==1.15.3
|
||||
# via
|
||||
# dm-control
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# torchdiffeq
|
||||
sentry-sdk==2.54.0
|
||||
sentry-sdk==2.42.1
|
||||
# via wandb
|
||||
shapely==2.1.2
|
||||
# via gym-pusht
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.17.0
|
||||
# via
|
||||
# pynput
|
||||
# python-dateutil
|
||||
# python-xlib
|
||||
smmap==5.0.3
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.52.1
|
||||
starlette==0.48.0
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
teleop==0.1.4
|
||||
teleop==0.1.2
|
||||
# via lerobot
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
@@ -756,38 +730,46 @@ tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.3.0
|
||||
termcolor==3.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via hf-libero
|
||||
tifffile==2026.3.3
|
||||
# via libero
|
||||
tifffile==2025.5.10
|
||||
# via scikit-image
|
||||
tokenizers==0.22.2
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
torch==2.10.0
|
||||
tomli==2.3.0
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
# via
|
||||
# accelerate
|
||||
# flash-attn
|
||||
# lerobot
|
||||
# peft
|
||||
# robomimic
|
||||
# thop
|
||||
# torchdiffeq
|
||||
# timm
|
||||
# torchvision
|
||||
torchcodec==0.10.0
|
||||
torchcodec==0.5
|
||||
# via lerobot
|
||||
torchdiffeq==0.2.5
|
||||
# via lerobot
|
||||
torchvision==0.25.0
|
||||
torchvision==0.22.1
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
tornado==6.5.4
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
# via meshcat
|
||||
tqdm==4.67.3
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
@@ -801,29 +783,26 @@ traitlets==5.14.3
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
# nbformat
|
||||
transformers==5.3.0
|
||||
transformers==4.57.1
|
||||
# via
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
triton==3.6.0
|
||||
triton==3.3.1
|
||||
# via torch
|
||||
typer==0.24.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# faker
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# mypy
|
||||
# ipython
|
||||
# multidict
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
@@ -832,46 +811,46 @@ typing-extensions==4.15.0
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# pydantic
|
||||
tzdata==2025.3
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via pandas
|
||||
u-msgpack-python==2.8.0
|
||||
# via meshcat
|
||||
urllib3==2.6.3
|
||||
urllib3==2.5.0
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
uvicorn[standard]==0.41.0
|
||||
uvicorn[standard]==0.38.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==21.1.0
|
||||
virtualenv==20.35.3
|
||||
# via pre-commit
|
||||
wandb==0.24.2
|
||||
wandb==0.21.4
|
||||
# via
|
||||
# hf-libero
|
||||
# lerobot
|
||||
# libero
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.6.0
|
||||
wcwidth==0.2.14
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==16.0
|
||||
websockets==15.0.1
|
||||
# via uvicorn
|
||||
werkzeug==3.1.6
|
||||
werkzeug==3.1.3
|
||||
# via tensorboard
|
||||
wrapt==2.1.2
|
||||
wrapt==2.0.0
|
||||
# via dm-tree
|
||||
xxhash==3.6.0
|
||||
# via datasets
|
||||
yarl==1.23.0
|
||||
yarl==1.22.0
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# requirements.in
|
||||
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.3.1 25D2128 arm64).
|
||||
# Darwin MacBook-Pro.local 25.3.0 Darwin Kernel Version 25.3.0: Wed Jan 28 20:54:55 PST 2026; root:xnu-12377.91.3~2/RELEASE_ARM64_T8132 arm64
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
|
||||
# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
|
||||
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.4 LTS x86_64).
|
||||
# Linux lerobot-linux 6.17.0-14-generic #14~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Jan 15 15:52:10 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
|
||||
# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
|
||||
-e .[all]
|
||||
|
||||
@@ -1,689 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Chunk-level multi-modality analysis for comparing full/mixed vs curated datasets.
|
||||
|
||||
Treats each action chunk (sliding window of CHUNK_SIZE consecutive frames) as the
|
||||
atomic unit, tagged by the SARM progress score at its start frame. For each
|
||||
progress band, compares the full vs HQ dataset on:
|
||||
|
||||
1. Intra-band action variance
|
||||
2. Progress delta per chunk
|
||||
3. GMM + BIC optimal K (number of distinct strategies)
|
||||
4. PCA embedding (visual cluster inspection)
|
||||
|
||||
Usage:
|
||||
python chunk_multimodality_analysis.py \\
|
||||
--full-dataset lerobot-data-collection/level12_rac_2_2026-02-08_1 \\
|
||||
--hq-dataset lerobot-data-collection/level2_final_quality3 \\
|
||||
--output-dir ./chunk_analysis
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import hf_hub_download
|
||||
from scipy.stats import gaussian_kde
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.mixture import GaussianMixture
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Visual style ──────────────────────────────────────────────────────────
|
||||
|
||||
BG = "#0e1117"
|
||||
CARD = "#1a1d27"
|
||||
BORDER = "#2a2d3a"
|
||||
SUB = "#8b8fa8"
|
||||
TEXT = "#e8eaf0"
|
||||
C_FULL = "#f7934f"
|
||||
C_HQ = "#4dc98a"
|
||||
|
||||
|
||||
def _style_ax(ax: plt.Axes) -> None:
|
||||
ax.set_facecolor(CARD)
|
||||
ax.tick_params(colors=SUB, labelsize=8)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color(BORDER)
|
||||
|
||||
|
||||
def _save(fig: plt.Figure, path: Path) -> None:
|
||||
fig.savefig(path, dpi=150, bbox_inches="tight", facecolor=BG)
|
||||
plt.close(fig)
|
||||
logger.info("Saved %s", path)
|
||||
|
||||
|
||||
# ── Step 0: Load episodes ────────────────────────────────────────────────
|
||||
|
||||
def _load_sarm_progress(repo_id: str) -> pd.DataFrame | None:
|
||||
"""Try to download sarm_progress.parquet from the Hub."""
|
||||
try:
|
||||
path = hf_hub_download(
|
||||
repo_id=repo_id, filename="sarm_progress.parquet",
|
||||
repo_type="dataset",
|
||||
)
|
||||
df = pd.read_parquet(path)
|
||||
col = "progress_sparse" if "progress_sparse" in df.columns else "progress_dense"
|
||||
if col not in df.columns:
|
||||
logger.warning("sarm_progress.parquet has no progress columns — ignoring")
|
||||
return None
|
||||
logger.info("Loaded SARM progress (%s) for %s (%d rows)", col, repo_id, len(df))
|
||||
return df.rename(columns={col: "progress"})[["episode_index", "frame_index", "progress"]]
|
||||
except Exception as exc:
|
||||
logger.warning("Could not load sarm_progress.parquet for %s: %s", repo_id, exc)
|
||||
return None
|
||||
|
||||
|
||||
def load_episodes(
|
||||
repo_id: str,
|
||||
n_joints: int = 16,
|
||||
max_episodes: int | None = None,
|
||||
) -> list[dict]:
|
||||
dataset = LeRobotDataset(repo_id, download_videos=False)
|
||||
raw = dataset.hf_dataset
|
||||
|
||||
sarm_df = _load_sarm_progress(repo_id)
|
||||
# Build per-episode progress arrays from SARM parquet (indexed by frame_index)
|
||||
sarm_by_ep: dict[int, dict[int, float]] = {}
|
||||
if sarm_df is not None:
|
||||
if max_episodes is not None:
|
||||
sarm_df = sarm_df[sarm_df["episode_index"] < max_episodes]
|
||||
for ep_id, grp in sarm_df.groupby("episode_index"):
|
||||
sarm_by_ep[int(ep_id)] = dict(
|
||||
zip(grp["frame_index"].astype(int), grp["progress"].astype(float))
|
||||
)
|
||||
|
||||
episodes: dict[int, dict] = defaultdict(lambda: {"actions": [], "progress": []})
|
||||
for row in raw:
|
||||
ep = int(row["episode_index"])
|
||||
if max_episodes is not None and ep >= max_episodes:
|
||||
continue
|
||||
action = np.array(row["action"], dtype=np.float32)[:n_joints]
|
||||
episodes[ep]["actions"].append(action)
|
||||
fi = int(row["frame_index"])
|
||||
ep_prog = sarm_by_ep.get(ep, {})
|
||||
episodes[ep]["progress"].append(ep_prog.get(fi, float("nan")))
|
||||
|
||||
has_sarm = len(sarm_lookup) > 0
|
||||
result = []
|
||||
for ep_id, d in sorted(episodes.items()):
|
||||
actions = np.stack(d["actions"])
|
||||
T = len(actions)
|
||||
if has_sarm:
|
||||
prog = np.array(d["progress"], dtype=np.float32)
|
||||
prog = np.clip(np.nan_to_num(prog, nan=0.0), 0.0, 1.0)
|
||||
prog = np.maximum.accumulate(prog)
|
||||
else:
|
||||
prog = np.linspace(0.0, 1.0, T, dtype=np.float32)
|
||||
result.append({"episode": ep_id, "actions": actions, "progress": prog})
|
||||
|
||||
src = "SARM" if has_sarm else "time-based"
|
||||
logger.info("Progress source: %s", src)
|
||||
return result
|
||||
|
||||
|
||||
# ── Step 1: Filter short episodes ────────────────────────────────────────
|
||||
|
||||
def auto_length_threshold(
|
||||
episodes_full: list[dict], episodes_hq: list[dict]
|
||||
) -> int:
|
||||
all_lengths = np.array(
|
||||
[e["actions"].shape[0] for e in episodes_full + episodes_hq]
|
||||
)
|
||||
kde = gaussian_kde(all_lengths, bw_method=0.25)
|
||||
xs = np.linspace(all_lengths.min(), np.percentile(all_lengths, 40), 300)
|
||||
return int(xs[np.argmin(kde(xs))])
|
||||
|
||||
|
||||
def plot_length_distribution(
|
||||
episodes_full: list[dict],
|
||||
episodes_hq: list[dict],
|
||||
threshold: int,
|
||||
out_path: Path,
|
||||
) -> None:
|
||||
lens_full = np.array([e["actions"].shape[0] for e in episodes_full])
|
||||
lens_hq = np.array([e["actions"].shape[0] for e in episodes_hq])
|
||||
all_lens = np.concatenate([lens_full, lens_hq])
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 5))
|
||||
fig.patch.set_facecolor(BG)
|
||||
_style_ax(ax)
|
||||
|
||||
bins = np.linspace(all_lens.min(), all_lens.max(), 50)
|
||||
ax.hist(lens_full, bins=bins, alpha=0.5, color=C_FULL, label="Full/Mixed")
|
||||
ax.hist(lens_hq, bins=bins, alpha=0.5, color=C_HQ, label="HQ")
|
||||
|
||||
xs = np.linspace(all_lens.min(), all_lens.max(), 300)
|
||||
kde = gaussian_kde(all_lens, bw_method=0.25)
|
||||
ax.plot(xs, kde(xs) * len(all_lens) * (bins[1] - bins[0]), color=TEXT, lw=1.5, label="KDE (combined)")
|
||||
|
||||
ax.axvline(threshold, color="#ff4b4b", ls="--", lw=1.5, label=f"Threshold = {threshold}")
|
||||
ax.set_xlabel("Episode length (frames)", color=SUB)
|
||||
ax.set_ylabel("Count", color=SUB)
|
||||
ax.set_title("Episode Length Distribution", color=TEXT, fontsize=13)
|
||||
ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)
|
||||
_save(fig, out_path)
|
||||
|
||||
|
||||
def filter_episodes(episodes: list[dict], min_length: int) -> list[dict]:
|
||||
kept = [e for e in episodes if e["actions"].shape[0] >= min_length]
|
||||
logger.info("Kept %d / %d episodes (min_length=%d)", len(kept), len(episodes), min_length)
|
||||
return kept
|
||||
|
||||
|
||||
# ── Step 2: Extract chunks ───────────────────────────────────────────────
|
||||
|
||||
def extract_chunks(
|
||||
episodes: list[dict],
|
||||
chunk_size: int = 30,
|
||||
chunk_stride: int = 15,
|
||||
) -> list[dict]:
|
||||
chunks = []
|
||||
for ep in episodes:
|
||||
actions = ep["actions"]
|
||||
T = len(actions)
|
||||
prog = ep["progress"]
|
||||
|
||||
for t in range(0, T - chunk_size, chunk_stride):
|
||||
chunk = actions[t : t + chunk_size]
|
||||
p_start = float(prog[t])
|
||||
p_end = float(prog[min(t + chunk_size, T - 1)])
|
||||
|
||||
chunks.append({
|
||||
"action_mean": chunk.mean(axis=0).astype(np.float32),
|
||||
"action_flat": chunk.flatten().astype(np.float32),
|
||||
"progress_start": p_start,
|
||||
"progress_delta": p_end - p_start,
|
||||
"episode": ep["episode"],
|
||||
})
|
||||
return chunks
|
||||
|
||||
|
||||
# ── Step 3: Adaptive progress bands ─────────────────────────────────────
|
||||
|
||||
def make_bands(n_bands: int = 5) -> list[tuple[float, float]]:
|
||||
edges = np.linspace(0.0, 1.0, n_bands + 1)
|
||||
return [(float(edges[i]), float(edges[i + 1])) for i in range(n_bands)]
|
||||
|
||||
|
||||
def assign_bands(
|
||||
chunks: list[dict], band_edges: list[tuple[float, float]]
|
||||
) -> list[dict]:
|
||||
n = len(band_edges)
|
||||
for c in chunks:
|
||||
p = c["progress_start"]
|
||||
c["band"] = next(
|
||||
(bi for bi, (lo, hi) in enumerate(band_edges) if p < hi),
|
||||
n - 1,
|
||||
)
|
||||
return chunks
|
||||
|
||||
|
||||
def split_by_band(chunks: list[dict], n_bands: int) -> dict[int, list[dict]]:
|
||||
out: dict[int, list[dict]] = {b: [] for b in range(n_bands)}
|
||||
for c in chunks:
|
||||
out[c["band"]].append(c)
|
||||
return out
|
||||
|
||||
|
||||
# ── Step 4: Intra-band action variance ──────────────────────────────────
|
||||
|
||||
def band_variance_matrix(
|
||||
bands: dict[int, list[dict]], n_bands: int, n_joints: int
|
||||
) -> np.ndarray:
|
||||
var_mat = np.full((n_bands, n_joints), np.nan)
|
||||
for b, clist in bands.items():
|
||||
if len(clist) < 3:
|
||||
continue
|
||||
means = np.stack([c["action_mean"] for c in clist])
|
||||
var_mat[b] = np.var(means, axis=0)
|
||||
return var_mat
|
||||
|
||||
|
||||
def plot_variance_heatmap(
|
||||
var_full: np.ndarray,
|
||||
var_hq: np.ndarray,
|
||||
band_edges: list[tuple[float, float]],
|
||||
out_path: Path,
|
||||
) -> None:
|
||||
n_bands = var_full.shape[0]
|
||||
vmin = 0.0
|
||||
vmax = max(np.nanmax(var_full), np.nanmax(var_hq))
|
||||
|
||||
band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
|
||||
joint_labels = [f"J{j}" for j in range(var_full.shape[1])]
|
||||
|
||||
fig, axes = plt.subplots(3, 1, figsize=(12, 10), gridspec_kw={"height_ratios": [3, 3, 2]})
|
||||
fig.patch.set_facecolor(BG)
|
||||
fig.suptitle("Intra-Band Action Variance", color=TEXT, fontsize=14, y=0.98)
|
||||
|
||||
for ax_idx, (mat, label) in enumerate([(var_full, "Full/Mixed"), (var_hq, "HQ")]):
|
||||
ax = axes[ax_idx]
|
||||
_style_ax(ax)
|
||||
im = ax.imshow(mat, aspect="auto", cmap="YlOrRd", vmin=vmin, vmax=vmax)
|
||||
ax.set_yticks(range(n_bands))
|
||||
ax.set_yticklabels(band_labels, fontsize=7, color=SUB)
|
||||
ax.set_xticks(range(var_full.shape[1]))
|
||||
ax.set_xticklabels(joint_labels, fontsize=7, color=SUB)
|
||||
ax.set_title(f"Panel {'A' if ax_idx == 0 else 'B'}: {label}", color=TEXT, fontsize=11)
|
||||
fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02)
|
||||
|
||||
with np.errstate(invalid="ignore"):
|
||||
mean_full = np.nanmean(var_full, axis=1)
|
||||
mean_hq = np.nanmean(var_hq, axis=1)
|
||||
ratio = np.where(np.isnan(mean_full) | np.isnan(mean_hq), np.nan,
|
||||
mean_full / (mean_hq + 1e-8))
|
||||
ax_bar = axes[2]
|
||||
_style_ax(ax_bar)
|
||||
colors = [
|
||||
"#ff4b4b" if r > 2.0 else "#ffaa33" if r > 1.2 else C_HQ
|
||||
for r in ratio
|
||||
]
|
||||
ax_bar.bar(range(n_bands), ratio, color=colors, edgecolor=BORDER)
|
||||
ax_bar.axhline(1.0, color=SUB, ls="--", lw=0.8)
|
||||
ax_bar.set_xticks(range(n_bands))
|
||||
ax_bar.set_xticklabels(band_labels, fontsize=7, color=SUB)
|
||||
ax_bar.set_ylabel("Variance ratio\n(Full / HQ)", color=SUB, fontsize=9)
|
||||
ax_bar.set_title("Panel C: Variance Ratio per Band", color=TEXT, fontsize=11)
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.96])
|
||||
_save(fig, out_path)
|
||||
|
||||
|
||||
# ── Step 5: Progress delta per band ──────────────────────────────────────
|
||||
|
||||
def plot_progress_delta(
|
||||
bands_full: dict[int, list[dict]],
|
||||
bands_hq: dict[int, list[dict]],
|
||||
band_edges: list[tuple[float, float]],
|
||||
out_path: Path,
|
||||
) -> None:
|
||||
n_bands = len(band_edges)
|
||||
band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
|
||||
x = np.arange(n_bands)
|
||||
w = 0.35
|
||||
|
||||
means_full, stds_full = [], []
|
||||
means_hq, stds_hq = [], []
|
||||
all_deltas_full, all_deltas_hq = [], []
|
||||
|
||||
for b in range(n_bands):
|
||||
df = np.array([c["progress_delta"] for c in bands_full.get(b, [])])
|
||||
dh = np.array([c["progress_delta"] for c in bands_hq.get(b, [])])
|
||||
means_full.append(np.mean(df) if len(df) > 0 else 0)
|
||||
stds_full.append(np.std(df) if len(df) > 0 else 0)
|
||||
means_hq.append(np.mean(dh) if len(dh) > 0 else 0)
|
||||
stds_hq.append(np.std(dh) if len(dh) > 0 else 0)
|
||||
all_deltas_full.extend(df.tolist())
|
||||
all_deltas_hq.extend(dh.tolist())
|
||||
|
||||
fig, (ax_bar, ax_viol) = plt.subplots(1, 2, figsize=(14, 5), gridspec_kw={"width_ratios": [3, 1]})
|
||||
fig.patch.set_facecolor(BG)
|
||||
fig.suptitle("Progress Delta per Chunk", color=TEXT, fontsize=14)
|
||||
|
||||
_style_ax(ax_bar)
|
||||
ax_bar.bar(x - w / 2, means_full, w, yerr=stds_full, color=C_FULL, edgecolor=BORDER,
|
||||
capsize=3, label="Full/Mixed", error_kw={"ecolor": SUB})
|
||||
ax_bar.bar(x + w / 2, means_hq, w, yerr=stds_hq, color=C_HQ, edgecolor=BORDER,
|
||||
capsize=3, label="HQ", error_kw={"ecolor": SUB})
|
||||
ax_bar.set_xticks(x)
|
||||
ax_bar.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
|
||||
ax_bar.set_ylabel("Mean progress Δ", color=SUB)
|
||||
ax_bar.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)
|
||||
|
||||
_style_ax(ax_viol)
|
||||
data_viol = [np.array(all_deltas_full), np.array(all_deltas_hq)]
|
||||
if all(len(d) > 0 for d in data_viol):
|
||||
parts = ax_viol.violinplot(data_viol, positions=[0, 1], showmeans=True, showmedians=True)
|
||||
for pc, c in zip(parts["bodies"], [C_FULL, C_HQ]):
|
||||
pc.set_facecolor(c)
|
||||
pc.set_alpha(0.7)
|
||||
for key in ("cmeans", "cmedians", "cbars", "cmins", "cmaxes"):
|
||||
if key in parts:
|
||||
parts[key].set_color(SUB)
|
||||
ax_viol.set_xticks([0, 1])
|
||||
ax_viol.set_xticklabels(["Full", "HQ"], color=SUB)
|
||||
ax_viol.set_ylabel("Progress Δ", color=SUB)
|
||||
ax_viol.set_title("Overall Distribution", color=TEXT, fontsize=10)
|
||||
|
||||
fig.tight_layout()
|
||||
_save(fig, out_path)
|
||||
|
||||
|
||||
# ── Step 6: GMM + BIC per band ──────────────────────────────────────────
|
||||
|
||||
def gmm_optimal_k(
|
||||
band_chunks: list[dict],
|
||||
pca_components: int = 15,
|
||||
max_k: int = 12,
|
||||
seed: int = 42,
|
||||
) -> int | None:
|
||||
if len(band_chunks) < 20:
|
||||
return None
|
||||
X = np.stack([c["action_flat"] for c in band_chunks])
|
||||
X = StandardScaler().fit_transform(X)
|
||||
n = min(pca_components, X.shape[1], X.shape[0] - 1)
|
||||
X_r = PCA(n_components=n, random_state=seed).fit_transform(X)
|
||||
bics = []
|
||||
for k in range(1, min(max_k + 1, len(X_r) // 6)):
|
||||
gmm = GaussianMixture(
|
||||
n_components=k, covariance_type="full",
|
||||
n_init=5, max_iter=300, random_state=seed,
|
||||
)
|
||||
gmm.fit(X_r)
|
||||
bics.append((k, gmm.bic(X_r)))
|
||||
if not bics:
|
||||
return None
|
||||
return min(bics, key=lambda x: x[1])[0]
|
||||
|
||||
|
||||
def plot_gmm_bic(
|
||||
bands_full: dict[int, list[dict]],
|
||||
bands_hq: dict[int, list[dict]],
|
||||
band_edges: list[tuple[float, float]],
|
||||
seed: int,
|
||||
out_path: Path,
|
||||
) -> tuple[list[int | None], list[int | None]]:
|
||||
n_bands = len(band_edges)
|
||||
ks_full = [gmm_optimal_k(bands_full.get(b, []), seed=seed) for b in range(n_bands)]
|
||||
ks_hq = [gmm_optimal_k(bands_hq.get(b, []), seed=seed) for b in range(n_bands)]
|
||||
|
||||
band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 5))
|
||||
fig.patch.set_facecolor(BG)
|
||||
_style_ax(ax)
|
||||
|
||||
xs = np.arange(n_bands)
|
||||
valid_full = [(i, k) for i, k in enumerate(ks_full) if k is not None]
|
||||
valid_hq = [(i, k) for i, k in enumerate(ks_hq) if k is not None]
|
||||
|
||||
if valid_full:
|
||||
xi, yi = zip(*valid_full)
|
||||
ax.plot(xi, yi, "o-", color=C_FULL, label="Full/Mixed", lw=2, markersize=7)
|
||||
if valid_hq:
|
||||
xi, yi = zip(*valid_hq)
|
||||
ax.plot(xi, yi, "o-", color=C_HQ, label="HQ", lw=2, markersize=7)
|
||||
|
||||
if valid_full and valid_hq:
|
||||
all_x = sorted(set([i for i, _ in valid_full]) & set([i for i, _ in valid_hq]))
|
||||
if len(all_x) >= 2:
|
||||
kf_interp = {i: k for i, k in valid_full}
|
||||
kh_interp = {i: k for i, k in valid_hq}
|
||||
shared_x = [i for i in all_x if i in kf_interp and i in kh_interp]
|
||||
yf = [kf_interp[i] for i in shared_x]
|
||||
yh = [kh_interp[i] for i in shared_x]
|
||||
ax.fill_between(shared_x, yf, yh, alpha=0.15, color=TEXT)
|
||||
|
||||
ax.set_xticks(xs)
|
||||
ax.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
|
||||
ax.set_ylabel("Optimal K (GMM-BIC)", color=SUB)
|
||||
ax.set_title("Number of Distinct Strategies per Band", color=TEXT, fontsize=13)
|
||||
ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=9)
|
||||
ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
|
||||
fig.tight_layout()
|
||||
_save(fig, out_path)
|
||||
return ks_full, ks_hq
|
||||
|
||||
|
||||
# ── Step 7: PCA scatter per band ────────────────────────────────────────
|
||||
|
||||
def plot_pca_scatter(
|
||||
bands_full: dict[int, list[dict]],
|
||||
bands_hq: dict[int, list[dict]],
|
||||
band_edges: list[tuple[float, float]],
|
||||
out_path: Path,
|
||||
) -> None:
|
||||
n_plot = min(4, len(band_edges))
|
||||
fig, axes = plt.subplots(2, n_plot, figsize=(4 * n_plot, 7))
|
||||
fig.patch.set_facecolor(BG)
|
||||
fig.suptitle("PCA of Action Chunks per Band", color=TEXT, fontsize=14)
|
||||
|
||||
if n_plot == 1:
|
||||
axes = axes.reshape(2, 1)
|
||||
|
||||
for col, b in enumerate(range(n_plot)):
|
||||
cf = bands_full.get(b, [])
|
||||
ch = bands_hq.get(b, [])
|
||||
lo, hi = band_edges[b]
|
||||
|
||||
for row, (clist, color, label) in enumerate([
|
||||
(cf, C_FULL, "Full/Mixed"), (ch, C_HQ, "HQ")
|
||||
]):
|
||||
ax = axes[row, col]
|
||||
_style_ax(ax)
|
||||
if row == 0:
|
||||
ax.set_title(f"{lo:.0%}–{hi:.0%}", color=TEXT, fontsize=10)
|
||||
if col == 0:
|
||||
ax.set_ylabel(label, color=SUB, fontsize=9)
|
||||
|
||||
if len(cf) < 3 or len(ch) < 3:
|
||||
ax.text(0.5, 0.5, "Too few\nchunks", transform=ax.transAxes,
|
||||
ha="center", va="center", color=SUB, fontsize=9)
|
||||
continue
|
||||
|
||||
X_full_b = np.stack([c["action_flat"] for c in cf])
|
||||
X_hq_b = np.stack([c["action_flat"] for c in ch])
|
||||
X_all = np.vstack([X_full_b, X_hq_b])
|
||||
X_all = StandardScaler().fit_transform(X_all)
|
||||
X_2d = PCA(n_components=2, random_state=42).fit_transform(X_all)
|
||||
|
||||
X_2d_full = X_2d[: len(cf)]
|
||||
X_2d_hq = X_2d[len(cf) :]
|
||||
|
||||
pts = X_2d_full if row == 0 else X_2d_hq
|
||||
ax.scatter(pts[:, 0], pts[:, 1], s=8, alpha=0.5, color=color, edgecolors="none")
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.95])
|
||||
_save(fig, out_path)
|
||||
|
||||
|
||||
# ── Plot 1: Chunk counts per band ───────────────────────────────────────
|
||||
|
||||
def plot_chunk_counts(
|
||||
bands_full: dict[int, list[dict]],
|
||||
bands_hq: dict[int, list[dict]],
|
||||
band_edges: list[tuple[float, float]],
|
||||
out_path: Path,
|
||||
) -> None:
|
||||
n_bands = len(band_edges)
|
||||
band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
|
||||
x = np.arange(n_bands)
|
||||
w = 0.35
|
||||
|
||||
counts_full = [len(bands_full.get(b, [])) for b in range(n_bands)]
|
||||
counts_hq = [len(bands_hq.get(b, [])) for b in range(n_bands)]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 5))
|
||||
fig.patch.set_facecolor(BG)
|
||||
_style_ax(ax)
|
||||
|
||||
ax.bar(x - w / 2, counts_full, w, color=C_FULL, edgecolor=BORDER, label="Full/Mixed")
|
||||
ax.bar(x + w / 2, counts_hq, w, color=C_HQ, edgecolor=BORDER, label="HQ")
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
|
||||
ax.set_ylabel("Chunk count", color=SUB)
|
||||
ax.set_title("Chunk Counts per Progress Band", color=TEXT, fontsize=13)
|
||||
ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)
|
||||
fig.tight_layout()
|
||||
_save(fig, out_path)
|
||||
|
||||
|
||||
# ── Summary figure ───────────────────────────────────────────────────────
|
||||
|
||||
def plot_summary(
|
||||
var_full: np.ndarray,
|
||||
var_hq: np.ndarray,
|
||||
band_edges: list[tuple[float, float]],
|
||||
ks_full: list[int | None],
|
||||
ks_hq: list[int | None],
|
||||
bands_full: dict[int, list[dict]],
|
||||
bands_hq: dict[int, list[dict]],
|
||||
out_path: Path,
|
||||
) -> None:
|
||||
with np.errstate(invalid="ignore"):
|
||||
mean_full = np.nanmean(var_full, axis=1)
|
||||
mean_hq = np.nanmean(var_hq, axis=1)
|
||||
ratio = np.where(np.isnan(mean_full) | np.isnan(mean_hq), np.nan,
|
||||
mean_full / (mean_hq + 1e-8))
|
||||
valid_ratio = ratio[~np.isnan(ratio)]
|
||||
mean_ratio = float(np.mean(valid_ratio)) if len(valid_ratio) > 0 else float("nan")
|
||||
peak_idx = int(np.argmax(valid_ratio)) if len(valid_ratio) > 0 else 0
|
||||
peak_ratio = float(valid_ratio[peak_idx]) if len(valid_ratio) > 0 else float("nan")
|
||||
lo, hi = band_edges[peak_idx]
|
||||
peak_band = f"{lo:.0%}–{hi:.0%}"
|
||||
|
||||
valid_kf = [k for k in ks_full if k is not None]
|
||||
valid_kh = [k for k in ks_hq if k is not None]
|
||||
mean_k_full = np.mean(valid_kf) if valid_kf else float("nan")
|
||||
mean_k_hq = np.mean(valid_kh) if valid_kh else float("nan")
|
||||
|
||||
n_bands = len(band_edges)
|
||||
deltas_full = [c["progress_delta"] for b in range(n_bands) for c in bands_full.get(b, [])]
|
||||
deltas_hq = [c["progress_delta"] for b in range(n_bands) for c in bands_hq.get(b, [])]
|
||||
mean_delta_full = float(np.mean(deltas_full)) if deltas_full else float("nan")
|
||||
mean_delta_hq = float(np.mean(deltas_hq)) if deltas_hq else float("nan")
|
||||
|
||||
rows = [
|
||||
("Mean variance ratio (Full / HQ)", f"{mean_ratio:.2f}x"),
|
||||
("Peak variance ratio", f"{peak_ratio:.2f}x at {peak_band}"),
|
||||
("Mean GMM K — Full", f"{mean_k_full:.1f}"),
|
||||
("Mean GMM K — HQ", f"{mean_k_hq:.1f}"),
|
||||
("Mean progress Δ — Full", f"{mean_delta_full:.4f}"),
|
||||
("Mean progress Δ — HQ", f"{mean_delta_hq:.4f}"),
|
||||
]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 3))
|
||||
fig.patch.set_facecolor(BG)
|
||||
ax.set_facecolor(CARD)
|
||||
ax.axis("off")
|
||||
|
||||
table = ax.table(
|
||||
cellText=[[m, v] for m, v in rows],
|
||||
colLabels=["Metric", "Value"],
|
||||
loc="center",
|
||||
cellLoc="left",
|
||||
)
|
||||
table.auto_set_font_size(False)
|
||||
table.set_fontsize(10)
|
||||
for key, cell in table.get_celld().items():
|
||||
cell.set_edgecolor(BORDER)
|
||||
cell.set_facecolor(CARD)
|
||||
cell.set_text_props(color=TEXT)
|
||||
if key[0] == 0:
|
||||
cell.set_text_props(color=TEXT, fontweight="bold")
|
||||
table.scale(1, 1.6)
|
||||
ax.set_title("Summary Statistics", color=TEXT, fontsize=13, pad=15)
|
||||
fig.tight_layout()
|
||||
_save(fig, out_path)
|
||||
|
||||
for metric, value in rows:
|
||||
logger.info(" %s: %s", metric, value)
|
||||
|
||||
|
||||
# ── Main ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def main(args: argparse.Namespace) -> None:
|
||||
out = Path(args.output_dir)
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info("Loading FULL dataset: %s", args.full_dataset)
|
||||
episodes_full = load_episodes(args.full_dataset, args.n_joints, args.max_episodes)
|
||||
logger.info("Loading HQ dataset: %s", args.hq_dataset)
|
||||
episodes_hq = load_episodes(args.hq_dataset, args.n_joints, args.max_episodes)
|
||||
logger.info("Loaded %d full episodes, %d HQ episodes", len(episodes_full), len(episodes_hq))
|
||||
|
||||
# Step 1: length threshold + filter
|
||||
if args.min_episode_length is not None:
|
||||
threshold = args.min_episode_length
|
||||
else:
|
||||
threshold = auto_length_threshold(episodes_full, episodes_hq)
|
||||
logger.info("Episode length threshold: %d", threshold)
|
||||
|
||||
plot_length_distribution(episodes_full, episodes_hq, threshold, out / "0_length_distribution.png")
|
||||
episodes_full = filter_episodes(episodes_full, threshold)
|
||||
episodes_hq = filter_episodes(episodes_hq, threshold)
|
||||
|
||||
# Step 2: extract chunks
|
||||
chunks_full = extract_chunks(episodes_full, args.chunk_size, args.chunk_stride)
|
||||
chunks_hq = extract_chunks(episodes_hq, args.chunk_size, args.chunk_stride)
|
||||
logger.info("Extracted %d full chunks, %d HQ chunks", len(chunks_full), len(chunks_hq))
|
||||
|
||||
# Step 3: fixed equal-width bands over episode-relative progress
|
||||
band_edges = make_bands(args.n_bands)
|
||||
n_bands = len(band_edges)
|
||||
logger.info("Progress bands (%d): %s", n_bands,
|
||||
[f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges])
|
||||
|
||||
chunks_full = assign_bands(chunks_full, band_edges)
|
||||
chunks_hq = assign_bands(chunks_hq, band_edges)
|
||||
bands_full = split_by_band(chunks_full, n_bands)
|
||||
bands_hq = split_by_band(chunks_hq, n_bands)
|
||||
|
||||
# Plot 1: chunk counts
|
||||
plot_chunk_counts(bands_full, bands_hq, band_edges, out / "1_chunk_counts_per_band.png")
|
||||
|
||||
# Step 4: variance heatmap
|
||||
var_full = band_variance_matrix(bands_full, n_bands, args.n_joints)
|
||||
var_hq = band_variance_matrix(bands_hq, n_bands, args.n_joints)
|
||||
plot_variance_heatmap(var_full, var_hq, band_edges, out / "2_variance_heatmap.png")
|
||||
|
||||
# Step 5: progress delta
|
||||
plot_progress_delta(bands_full, bands_hq, band_edges, out / "3_progress_delta_per_band.png")
|
||||
|
||||
# Step 6: GMM BIC
|
||||
ks_full, ks_hq = plot_gmm_bic(bands_full, bands_hq, band_edges, args.seed, out / "4_gmm_bic_per_band.png")
|
||||
|
||||
# Step 7: PCA scatter
|
||||
plot_pca_scatter(bands_full, bands_hq, band_edges, out / "5_pca_per_band.png")
|
||||
|
||||
# Summary
|
||||
plot_summary(var_full, var_hq, band_edges, ks_full, ks_hq,
|
||||
bands_full, bands_hq, out / "6_summary.png")
|
||||
|
||||
logger.info("All figures saved to %s", out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p = argparse.ArgumentParser(
|
||||
description="Chunk-level multi-modality analysis: Full/Mixed vs HQ dataset.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
p.add_argument("--full-dataset", default="lerobot-data-collection/level12_rac_2_2026-02-08_1")
|
||||
p.add_argument("--hq-dataset", default="lerobot-data-collection/level2_final_quality3_trim_0_hil_data")
|
||||
p.add_argument("--output-dir", default="./chunk_analysis")
|
||||
p.add_argument("--chunk-size", type=int, default=30)
|
||||
p.add_argument("--chunk-stride", type=int, default=15)
|
||||
p.add_argument("--n-bands", type=int, default=5, help="Number of equal-width progress bands")
|
||||
p.add_argument("--max-episodes", type=int, default=500)
|
||||
p.add_argument("--n-joints", type=int, default=16)
|
||||
p.add_argument("--min-episode-length", type=int, default=None,
|
||||
help="Override auto-detected length filter threshold")
|
||||
p.add_argument("--seed", type=int, default=42)
|
||||
args = p.parse_args()
|
||||
main(args)
|
||||
@@ -1,29 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=smolvla_libero_plus
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --nodes=1
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --gpus-per-node=4
|
||||
#SBATCH --cpus-per-task=48
|
||||
#SBATCH --mem=200G
|
||||
#SBATCH --time=12:00:00
|
||||
#SBATCH --output=logs/smolvla_libero_plus_%j.out
|
||||
#SBATCH --error=logs/smolvla_libero_plus_%j.err
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
eval "$(conda shell.bash hook 2>/dev/null)"
|
||||
conda activate lerobot312
|
||||
|
||||
cd /admin/home/pepijn/lerobot_wt_robocasa
|
||||
|
||||
lerobot-benchmark train \
|
||||
--benchmarks libero_plus \
|
||||
--policy-path lerobot/smolvla_base \
|
||||
--hub-user pepijn223 \
|
||||
--num-gpus 4 \
|
||||
--steps 30000 \
|
||||
--batch-size 32 \
|
||||
--eval-freq 0 \
|
||||
--wandb \
|
||||
--dataset.repo_id=pepijn223/libero_plus_lerobot
|
||||
@@ -49,14 +49,9 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
from lerobot.robots import (
|
||||
RobotConfig, # noqa: F401
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
|
||||
@@ -181,7 +181,7 @@ class ZMQCamera(Camera):
|
||||
try:
|
||||
message = self.socket.recv_string()
|
||||
except Exception as e:
|
||||
# zmq is lazy-imported in connect(), so check by name to avoid a top-level import
|
||||
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
|
||||
if type(e).__name__ == "Again":
|
||||
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
||||
raise
|
||||
|
||||
@@ -23,7 +23,6 @@ import base64
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
@@ -43,57 +42,10 @@ def encode_image(image: np.ndarray, quality: int = 80) -> str:
|
||||
return base64.b64encode(buffer).decode("utf-8")
|
||||
|
||||
|
||||
class CameraCaptureThread:
|
||||
"""Background thread that continuously captures and encodes frames from a camera."""
|
||||
|
||||
def __init__(self, camera: OpenCVCamera, name: str):
|
||||
self.camera = camera
|
||||
self.name = name
|
||||
self.latest_encoded: str | None = None # Pre-encoded JPEG as base64
|
||||
self.latest_timestamp: float = 0.0
|
||||
self.frame_lock = threading.Lock()
|
||||
self.running = False
|
||||
self.thread: threading.Thread | None = None
|
||||
|
||||
def start(self):
|
||||
"""Start the capture thread."""
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self._capture_loop, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""Stop the capture thread."""
|
||||
self.running = False
|
||||
if self.thread:
|
||||
self.thread.join(timeout=1.0)
|
||||
|
||||
def _capture_loop(self):
|
||||
"""Continuously capture and encode frames at the camera's native rate."""
|
||||
while self.running:
|
||||
try:
|
||||
frame = self.camera.read() # Blocks at camera's native rate
|
||||
timestamp = time.time()
|
||||
# Encode immediately in capture thread (this is the slow part)
|
||||
encoded = encode_image(frame)
|
||||
with self.frame_lock:
|
||||
self.latest_encoded = encoded
|
||||
self.latest_timestamp = timestamp
|
||||
except Exception as e:
|
||||
logger.warning(f"Camera {self.name} capture error: {e}")
|
||||
time.sleep(0.01)
|
||||
|
||||
def get_latest(self) -> tuple[str | None, float]:
|
||||
"""Get the latest encoded frame and its timestamp."""
|
||||
with self.frame_lock:
|
||||
return self.latest_encoded, self.latest_timestamp
|
||||
|
||||
|
||||
class ImageServer:
|
||||
def __init__(self, config: dict, port: int = 5555):
|
||||
# fps controls the publish loop rate (how often frames are sent over ZMQ), not the camera capture rate
|
||||
self.fps = config.get("fps", 30)
|
||||
self.cameras: dict[str, OpenCVCamera] = {}
|
||||
self.capture_threads: dict[str, CameraCaptureThread] = {}
|
||||
|
||||
for name, cfg in config.get("cameras", {}).items():
|
||||
shape = cfg.get("shape", [480, 640])
|
||||
@@ -109,10 +61,6 @@ class ImageServer:
|
||||
self.cameras[name] = camera
|
||||
logger.info(f"Camera {name}: {shape[1]}x{shape[0]}")
|
||||
|
||||
# Create capture thread for this camera
|
||||
capture_thread = CameraCaptureThread(camera, name)
|
||||
self.capture_threads[name] = capture_thread
|
||||
|
||||
# ZMQ PUB socket
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.PUB)
|
||||
@@ -125,18 +73,6 @@ class ImageServer:
|
||||
def run(self):
|
||||
frame_count = 0
|
||||
frame_times = deque(maxlen=60)
|
||||
last_published_ts: dict[str, float] = {}
|
||||
|
||||
# Start all capture threads
|
||||
for capture_thread in self.capture_threads.values():
|
||||
capture_thread.start()
|
||||
|
||||
# Wait for first frames to be captured and encoded
|
||||
logger.info("Waiting for cameras to start capturing...")
|
||||
for name, capture_thread in self.capture_threads.items():
|
||||
while capture_thread.get_latest()[0] is None:
|
||||
time.sleep(0.01)
|
||||
logger.info(f"Camera {name} ready (capture + encode in background)")
|
||||
|
||||
try:
|
||||
while True:
|
||||
@@ -144,12 +80,10 @@ class ImageServer:
|
||||
|
||||
# Build message
|
||||
message = {"timestamps": {}, "images": {}}
|
||||
for name, capture_thread in self.capture_threads.items():
|
||||
encoded, timestamp = capture_thread.get_latest()
|
||||
if encoded is not None and timestamp > last_published_ts.get(name, 0.0):
|
||||
message["timestamps"][name] = timestamp
|
||||
message["images"][name] = encoded
|
||||
last_published_ts[name] = timestamp
|
||||
for name, cam in self.cameras.items():
|
||||
frame = cam.read() # Returns RGB
|
||||
message["timestamps"][name] = time.time()
|
||||
message["images"][name] = encode_image(frame)
|
||||
|
||||
# Send as JSON string (suppress if buffer full)
|
||||
with contextlib.suppress(zmq.Again):
|
||||
@@ -168,8 +102,6 @@ class ImageServer:
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
for capture_thread in self.capture_threads.values():
|
||||
capture_thread.stop()
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
self.socket.close()
|
||||
|
||||
@@ -49,64 +49,15 @@ class WandBConfig:
|
||||
mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online'
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalDockerConfig:
|
||||
# Docker image to use for evaluation (e.g., "ghcr.io/org/lerobot-eval-libero:latest").
|
||||
# Takes precedence over eval.envhub_ref.
|
||||
image: str | None = None
|
||||
# Optional EnvHub reference to resolve an image, e.g. "envhub://lerobot/libero_plus@v1".
|
||||
envhub_ref: str | None = None
|
||||
# If true, mount the local repository and prefer local source code in the container.
|
||||
use_local_code: bool = True
|
||||
# Pull the image before running.
|
||||
pull: bool = True
|
||||
# Docker --gpus value. Set to None to disable GPU flags and run CPU-only.
|
||||
gpus: str | None = "all"
|
||||
# Docker --shm-size value (increase when using larger eval.batch_size values).
|
||||
shm_size: str = "8g"
|
||||
# Port on which the host HTTP policy inference server listens.
|
||||
port: int = 50051
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalConfig:
|
||||
n_episodes: int = 50
|
||||
# Number of sub-envs per task inside one VectorEnv. Increase to improve per-task
|
||||
# inference throughput until GPU or simulator memory saturates.
|
||||
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
|
||||
batch_size: int = 50
|
||||
# Use AsyncVectorEnv (multiprocessing). Prefer SyncVectorEnv unless your environment
|
||||
# spends significant time in Python-side stepping and can benefit from process parallelism.
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
use_async_envs: bool = False
|
||||
# Runtime where evaluation executes: "local", "docker", or "multiprocess".
|
||||
# "multiprocess" spawns local worker processes + policy servers.
|
||||
runtime: str = "local"
|
||||
docker: EvalDockerConfig = field(default_factory=EvalDockerConfig)
|
||||
# Number of parallel eval script instances to launch for one run.
|
||||
# instance_count > 1 enables multi-instance task sharding.
|
||||
instance_count: int = 1
|
||||
# 0-indexed shard id for this process. Users usually leave this at 0.
|
||||
# Additional shards are launched automatically by `lerobot-eval` when instance_count > 1.
|
||||
instance_id: int = 0
|
||||
# Number of policy inference servers to run in parallel (docker/multiprocess runtimes).
|
||||
# Each server loads a copy of the model and listens on consecutive ports
|
||||
# starting from eval.docker.port. Workers are round-robin assigned.
|
||||
policy_servers: int = 1
|
||||
# Base port for policy servers in multiprocess mode.
|
||||
port: int = 50051
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.runtime not in {"local", "docker", "multiprocess"}:
|
||||
raise ValueError(
|
||||
f"Unsupported eval.runtime '{self.runtime}'. Expected one of: local, docker, multiprocess."
|
||||
)
|
||||
if self.instance_count < 1:
|
||||
raise ValueError("eval.instance_count must be >= 1.")
|
||||
if self.instance_id < 0 or self.instance_id >= self.instance_count:
|
||||
raise ValueError(
|
||||
f"eval.instance_id must be in [0, {self.instance_count - 1}] (got {self.instance_id})."
|
||||
)
|
||||
if self.policy_servers < 1:
|
||||
raise ValueError("eval.policy_servers must be >= 1.")
|
||||
if self.batch_size > self.n_episodes:
|
||||
raise ValueError(
|
||||
"The eval batch size is greater than the number of eval episodes "
|
||||
|
||||
@@ -40,8 +40,6 @@ class EvalPipelineConfig:
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
# Explicit consent to execute remote code from the Hub (required for hub environments).
|
||||
trust_remote_code: bool = False
|
||||
# Push eval results (metrics JSON, rollout videos, model card update) to the model's Hub repo.
|
||||
push_to_hub: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
|
||||
@@ -50,9 +50,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||
# AND for the evaluation environments.
|
||||
seed: int | None = 1000
|
||||
# Set to True to use deterministic cuDNN algorithms for reproducibility.
|
||||
# This disables cudnn.benchmark and may reduce training speed by ~10-20%.
|
||||
cudnn_deterministic: bool = False
|
||||
# Number of workers for the dataloader.
|
||||
num_workers: int = 4
|
||||
batch_size: int = 8
|
||||
|
||||
@@ -289,9 +289,7 @@ def aggregate_datasets(
|
||||
|
||||
logging.info("Find all tasks")
|
||||
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
|
||||
dst_meta.tasks = pd.DataFrame(
|
||||
{"task_index": range(len(unique_tasks))}, index=pd.Index(unique_tasks, name="task")
|
||||
)
|
||||
dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
|
||||
|
||||
meta_idx = {"chunk": 0, "file": 0}
|
||||
data_idx = {"chunk": 0, "file": 0}
|
||||
|
||||
@@ -89,8 +89,8 @@ def delete_episodes(
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
episode_indices: List of episode indices to delete.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
"""
|
||||
if not episode_indices:
|
||||
raise ValueError("No episodes to delete")
|
||||
@@ -152,7 +152,7 @@ def split_dataset(
|
||||
dataset: The source LeRobotDataset to split.
|
||||
splits: Either a dict mapping split names to episode indices, or a dict mapping
|
||||
split names to fractions (must sum to <= 1.0).
|
||||
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
output_dir: Base directory for output datasets. If None, uses default location.
|
||||
|
||||
Examples:
|
||||
Split by specific episodes
|
||||
@@ -243,8 +243,8 @@ def merge_datasets(
|
||||
|
||||
Args:
|
||||
datasets: List of LeRobotDatasets to merge.
|
||||
output_repo_id: Merged dataset identifier.
|
||||
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
|
||||
output_repo_id: Repository ID for the merged dataset.
|
||||
output_dir: Directory to save the merged dataset. If None, uses default location.
|
||||
"""
|
||||
if not datasets:
|
||||
raise ValueError("No datasets to merge")
|
||||
@@ -288,8 +288,8 @@ def modify_features(
|
||||
dataset: The source LeRobotDataset.
|
||||
add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
|
||||
remove_features: Optional feature name(s) to remove. Can be a single string or list.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
|
||||
Returns:
|
||||
New dataset with features modified.
|
||||
@@ -390,8 +390,8 @@ def add_features(
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
|
||||
Returns:
|
||||
New dataset with all features added.
|
||||
@@ -427,8 +427,8 @@ def remove_feature(
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
feature_names: Name(s) of features to remove. Can be a single string or list.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
|
||||
Returns:
|
||||
New dataset with features removed.
|
||||
@@ -1475,9 +1475,7 @@ def modify_tasks(
|
||||
|
||||
# Collect all unique tasks and create new task mapping
|
||||
unique_tasks = sorted(set(episode_to_task.values()))
|
||||
new_task_df = pd.DataFrame(
|
||||
{"task_index": list(range(len(unique_tasks)))}, index=pd.Index(unique_tasks, name="task")
|
||||
)
|
||||
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
|
||||
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
|
||||
|
||||
logging.info(f"Modifying tasks in {dataset.repo_id}")
|
||||
@@ -1531,7 +1529,7 @@ def modify_tasks(
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path | None = None,
|
||||
output_dir: Path,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
@@ -1550,8 +1548,8 @@ def convert_image_to_video_dataset(
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
output_dir: Directory to save the new video dataset
|
||||
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
@@ -1602,7 +1600,6 @@ def convert_image_to_video_dataset(
|
||||
# Video info will be updated after episodes are encoded
|
||||
|
||||
# Create new metadata for video dataset
|
||||
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
|
||||
@@ -126,11 +126,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
)
|
||||
|
||||
if cfg.dataset.use_imagenet_stats:
|
||||
if dataset.meta.stats is None:
|
||||
dataset.meta.stats = {}
|
||||
for key in dataset.meta.camera_keys:
|
||||
if key not in dataset.meta.stats:
|
||||
dataset.meta.stats[key] = {}
|
||||
for stats_type, stats in IMAGENET_STATS.items():
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
|
||||
@@ -314,7 +314,7 @@ class LeRobotDatasetMetadata:
|
||||
if self.tasks is None:
|
||||
new_tasks = tasks
|
||||
task_indices = range(len(tasks))
|
||||
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
|
||||
self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
|
||||
else:
|
||||
new_tasks = [task for task in tasks if task not in self.tasks.index]
|
||||
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
|
||||
|
||||
@@ -21,7 +21,7 @@ from collections import deque
|
||||
from collections.abc import Iterable, Iterator
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -78,6 +78,8 @@ DEFAULT_FEATURES = {
|
||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
@@ -339,7 +341,6 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
|
||||
|
||||
def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
||||
tasks.index.name = "task"
|
||||
return tasks
|
||||
|
||||
|
||||
@@ -1232,7 +1233,7 @@ class LookAheadError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Backtrackable[T]:
|
||||
class Backtrackable(Generic[T]):
|
||||
"""
|
||||
Wrap any iterator/iterable so you can step back up to `history` items
|
||||
and look ahead up to `lookahead` items.
|
||||
|
||||
@@ -36,11 +36,8 @@ Convert a local dataset (works in place):
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
--repo-id=lerobot/pusht \
|
||||
--root=/path/to/local/dataset/directory \
|
||||
--root=/path/to/local/dataset/directory
|
||||
--push-to-hub=false
|
||||
|
||||
N.B. Path semantics (v2): --root is the exact dataset folder containing
|
||||
meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}.
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -108,7 +105,7 @@ episodes.jsonl
|
||||
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
|
||||
|
||||
NEW
|
||||
meta/episodes/chunk-000/file_000.parquet
|
||||
meta/episodes/chunk-000/episodes_000.parquet
|
||||
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
|
||||
-------------------------
|
||||
OLD
|
||||
@@ -116,16 +113,15 @@ tasks.jsonl
|
||||
{"task_index": 1, "task": "Put the blue block in the green bowl"}
|
||||
|
||||
NEW
|
||||
meta/tasks.parquet
|
||||
meta/tasks/chunk-000/file_000.parquet
|
||||
task_index | task
|
||||
-------------------------
|
||||
OLD
|
||||
episodes_stats.jsonl
|
||||
{"episode_index": 1, "stats": {"feature_name": {"min": ..., "max": ..., "mean": ..., "std": ..., "count": ...}}}
|
||||
|
||||
NEW
|
||||
meta/episodes/chunk-000/file_000.parquet
|
||||
episode_index | feature_name/min | feature_name/max | feature_name/mean | feature_name/std | feature_name/count
|
||||
meta/episodes_stats/chunk-000/file_000.parquet
|
||||
episode_index | mean | std | min | max
|
||||
-------------------------
|
||||
UPDATE
|
||||
meta/info.json
|
||||
@@ -174,7 +170,7 @@ def convert_tasks(root, new_root):
|
||||
tasks, _ = legacy_load_tasks(root)
|
||||
task_indices = tasks.keys()
|
||||
task_strings = tasks.values()
|
||||
df_tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(task_strings, name="task"))
|
||||
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
|
||||
write_tasks(df_tasks, new_root)
|
||||
|
||||
|
||||
@@ -205,6 +201,7 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
|
||||
image_keys = get_image_keys(root)
|
||||
|
||||
ep_idx = 0
|
||||
chunk_idx = 0
|
||||
file_idx = 0
|
||||
size_in_mb = 0
|
||||
@@ -214,23 +211,9 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
|
||||
logging.info(f"Converting data files from {len(ep_paths)} episodes")
|
||||
|
||||
for ep_idx, ep_path in enumerate(tqdm.tqdm(ep_paths, desc="convert data files")):
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"):
|
||||
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
|
||||
ep_num_frames = get_parquet_num_frames(ep_path)
|
||||
|
||||
# Check if we need to start a new file BEFORE creating metadata
|
||||
if size_in_mb + ep_size_in_mb >= data_file_size_in_mb and len(paths_to_cat) > 0:
|
||||
# Write the accumulated data files
|
||||
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
||||
|
||||
# Move to next file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
|
||||
# Reset for the next file
|
||||
size_in_mb = 0
|
||||
paths_to_cat = []
|
||||
|
||||
# Now create metadata with correct chunk/file indices
|
||||
ep_metadata = {
|
||||
"episode_index": ep_idx,
|
||||
"data/chunk_index": chunk_idx,
|
||||
@@ -241,7 +224,20 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
size_in_mb += ep_size_in_mb
|
||||
num_frames += ep_num_frames
|
||||
episodes_metadata.append(ep_metadata)
|
||||
paths_to_cat.append(ep_path)
|
||||
ep_idx += 1
|
||||
|
||||
if size_in_mb < data_file_size_in_mb:
|
||||
paths_to_cat.append(ep_path)
|
||||
continue
|
||||
|
||||
if paths_to_cat:
|
||||
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
||||
|
||||
# Reset for the next file
|
||||
size_in_mb = ep_size_in_mb
|
||||
paths_to_cat = [ep_path]
|
||||
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
|
||||
# Write remaining data if any
|
||||
if paths_to_cat:
|
||||
@@ -473,7 +469,7 @@ def convert_dataset(
|
||||
|
||||
# Set root based on whether local dataset path is provided
|
||||
use_local_dataset = False
|
||||
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root)
|
||||
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id
|
||||
if root.exists():
|
||||
validate_local_dataset_version(root)
|
||||
use_local_dataset = True
|
||||
@@ -557,7 +553,7 @@ if __name__ == "__main__":
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Local directory to use for downloading/writing the dataset. Defaults to $HF_LEROBOT_HOME/repo_id.",
|
||||
help="Local directory to use for downloading/writing the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
|
||||
@@ -45,10 +45,6 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
fps: int = 30
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
features_map: dict[str, str] = field(default_factory=dict)
|
||||
# Upper bound on concurrent task evaluation in `lerobot-eval`.
|
||||
# - For lazy wrappers (e.g. LIBERO/LIBERO-plus), values >1 can enable chunked
|
||||
# task batching with one policy forward pass over multiple tasks.
|
||||
# - For other envs, values >1 use a threaded task scheduler fallback.
|
||||
max_parallel_tasks: int = 1
|
||||
disable_env_checker: bool = True
|
||||
|
||||
@@ -350,105 +346,6 @@ class LiberoEnv(EnvConfig):
|
||||
return kwargs
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("libero_plus")
|
||||
@dataclass
|
||||
class LiberoPlusEnv(LiberoEnv):
|
||||
"""Alias config for LIBERO-plus benchmarks.
|
||||
|
||||
LIBERO-plus keeps the same Python package/module names as LIBERO, so this
|
||||
config reuses the existing LIBERO env implementation while making intent explicit
|
||||
in experiment configs (`env.type=libero_plus`).
|
||||
"""
|
||||
|
||||
task: str = "libero_spatial,libero_object,libero_goal,libero_10"
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("robocasa")
|
||||
@dataclass
|
||||
class RoboCasaEnv(EnvConfig):
|
||||
"""RoboCasa kitchen composite-task environments.
|
||||
|
||||
Wraps ``robocasa.wrappers.gym_wrapper.RoboCasaGymEnv`` with a flat 12-D Box
|
||||
action space and a structured pixel + state observation dict.
|
||||
|
||||
Selected benchmark tasks (3 short + 2 long):
|
||||
Short: PickPlaceCounterToCabinet, PrepareToast, CoffeeSetupMug
|
||||
Long: PrepareCoffee, RestockPantry
|
||||
"""
|
||||
|
||||
task: str = "PickPlaceCounterToCabinet"
|
||||
tasks: list[str] | None = None # multi-task: list of task names (without robocasa/ prefix)
|
||||
fps: int = 20
|
||||
episode_length: int = 500
|
||||
image_size: int = 128
|
||||
split: str = "target" # "pretrain" or "target"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: ACTION,
|
||||
"agentview_left": f"{OBS_IMAGES}.agentview_left",
|
||||
"agentview_right": f"{OBS_IMAGES}.agentview_right",
|
||||
"eye_in_hand": f"{OBS_IMAGES}.eye_in_hand",
|
||||
"robot_state": OBS_STATE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
|
||||
self.features[cam] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.image_size, self.image_size, 3)
|
||||
)
|
||||
self.features["robot_state"] = PolicyFeature(type=FeatureType.STATE, shape=(16,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {"split": self.split}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("robomme")
|
||||
@dataclass
|
||||
class RoboMMEEnv(EnvConfig):
|
||||
"""RoboMME memory-augmented manipulation benchmark (ManiSkill/SAPIEN).
|
||||
|
||||
16 tasks across 4 suites: Counting, Permanence, Reference, Imitation.
|
||||
Uses BenchmarkEnvBuilder from the robomme package.
|
||||
"""
|
||||
|
||||
task: str = "PickXtimes"
|
||||
fps: int = 10
|
||||
episode_length: int = 300
|
||||
action_space: str = "joint_angle"
|
||||
dataset_split: str = "test"
|
||||
task_ids: list[int] | None = None
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(8,)),
|
||||
"front_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
||||
"wrist_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: ACTION,
|
||||
"front_rgb": f"{OBS_IMAGES}.front",
|
||||
"wrist_rgb": f"{OBS_IMAGES}.wrist",
|
||||
OBS_STATE: OBS_STATE,
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"action_space": self.action_space,
|
||||
"dataset": self.dataset_split,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("metaworld")
|
||||
@dataclass
|
||||
class MetaworldEnv(EnvConfig):
|
||||
|
||||
@@ -1,442 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Docker runtime for lerobot-eval.
|
||||
|
||||
The policy stays on the host GPU; gym environments run inside Docker containers.
|
||||
Each container runs `lerobot-eval-worker`, which calls back to a host HTTP inference
|
||||
server for action chunks.
|
||||
|
||||
Architecture:
|
||||
host (GPU):
|
||||
1. Load policy + preprocessors from EvalPipelineConfig.
|
||||
2. Start ``policy_servers`` HTTP inference servers on consecutive ports.
|
||||
3. Spawn ``instance_count`` Docker containers, round-robin assigned to servers.
|
||||
4. Wait; collect per-task JSON written to the mounted output volume.
|
||||
5. Merge shards → aggregate → write eval_info.json.
|
||||
|
||||
container (CPU only):
|
||||
1. make_env(cfg.env) → shard tasks by (instance_id, instance_count).
|
||||
2. For each task: run n_episodes, POST obs to /predict_chunk, step env.
|
||||
3. Write per-task JSON to /results/worker_{instance_id}.json.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import pickle # nosec B403 — internal serialisation only
|
||||
import platform
|
||||
import subprocess # nosec B404
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.envs.factory import make_env_pre_post_processors
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP inference server (host side)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _PolicyInferenceHandler(BaseHTTPRequestHandler):
|
||||
"""POST /predict_chunk → pickled numpy action chunk."""
|
||||
|
||||
server: _InferenceServer
|
||||
|
||||
def do_POST(self) -> None:
|
||||
if self.path != "/predict_chunk":
|
||||
self.send_error(404)
|
||||
return
|
||||
length = int(self.headers["Content-Length"])
|
||||
body = self.rfile.read(length)
|
||||
payload: dict = pickle.loads(body) # nosec B301
|
||||
obs_t: dict = payload["obs_t"]
|
||||
|
||||
with self.server._lock:
|
||||
chunk_np = self.server._predict(obs_t)
|
||||
|
||||
resp = pickle.dumps(chunk_np) # nosec B301
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "application/octet-stream")
|
||||
self.send_header("Content-Length", str(len(resp)))
|
||||
self.end_headers()
|
||||
self.wfile.write(resp)
|
||||
|
||||
def log_message(self, fmt: str, *args: Any) -> None: # noqa: ANN401
|
||||
pass # suppress per-request logs
|
||||
|
||||
|
||||
class _InferenceServer(HTTPServer):
|
||||
"""Wraps the loaded policy behind a trivial HTTP interface."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
addr: tuple[str, int],
|
||||
policy: Any,
|
||||
env_preprocessor: Any,
|
||||
preprocessor: Any,
|
||||
postprocessor: Any,
|
||||
) -> None:
|
||||
super().__init__(addr, _PolicyInferenceHandler)
|
||||
self._policy = policy
|
||||
self._env_preprocessor = env_preprocessor
|
||||
self._preprocessor = preprocessor
|
||||
self._postprocessor = postprocessor
|
||||
self._lock = threading.Lock()
|
||||
self._device = torch.device(str(policy.config.device))
|
||||
|
||||
def _predict(self, obs_t: dict) -> np.ndarray:
|
||||
"""Apply full preprocessing pipeline and return (n_action_steps, A) numpy chunk."""
|
||||
obs = self._env_preprocessor(obs_t)
|
||||
obs = self._preprocessor(obs)
|
||||
obs_gpu: dict = {k: v.to(self._device) if isinstance(v, torch.Tensor) else v for k, v in obs.items()}
|
||||
with torch.no_grad():
|
||||
chunk: torch.Tensor = self._policy.predict_action_chunk(obs_gpu) # (B, T, A)
|
||||
|
||||
n_action_steps = getattr(self._policy.config, "n_action_steps", chunk.shape[1])
|
||||
batch, n_steps, action_dim = chunk.shape
|
||||
chunk_2d = chunk.reshape(batch * n_steps, action_dim) # (B*T, A)
|
||||
chunk_2d = self._postprocessor(chunk_2d) # (B*T, A)
|
||||
return chunk_2d[:n_action_steps].cpu().numpy() # (n_action_steps, A)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_host_ip() -> str:
|
||||
"""Return the IP that containers can use to reach the host."""
|
||||
if platform.system() in ("Darwin", "Windows"):
|
||||
return "host.docker.internal"
|
||||
return "172.17.0.1" # Linux Docker bridge default gateway
|
||||
|
||||
|
||||
def _resolve_image(cfg: EvalPipelineConfig) -> str:
|
||||
"""Return the Docker image name to use for the env containers."""
|
||||
if cfg.eval.docker.image:
|
||||
return cfg.eval.docker.image
|
||||
return f"lerobot-benchmark-{cfg.env.type}"
|
||||
|
||||
|
||||
def _env_argv() -> list[str]:
|
||||
"""Extract --env.* args from sys.argv to forward verbatim to the worker."""
|
||||
return [arg for arg in sys.argv[1:] if arg.startswith("--env.")]
|
||||
|
||||
|
||||
def _spawn_container(
|
||||
*,
|
||||
image: str,
|
||||
instance_id: int,
|
||||
instance_count: int,
|
||||
server_address: str,
|
||||
n_episodes: int,
|
||||
seed: int,
|
||||
output_dir: Path,
|
||||
docker_cfg: Any,
|
||||
env_argv: list[str],
|
||||
) -> subprocess.Popen:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
container_results = "/results"
|
||||
|
||||
cmd: list[str] = ["docker", "run", "--rm"]
|
||||
if docker_cfg.gpus:
|
||||
cmd += [f"--gpus={docker_cfg.gpus}"]
|
||||
cmd += [f"--shm-size={docker_cfg.shm_size}"]
|
||||
cmd += ["-v", f"{output_dir.resolve()}:{container_results}"]
|
||||
# Allow containers on Linux to resolve host.docker.internal.
|
||||
cmd += ["--add-host=host.docker.internal:host-gateway"]
|
||||
cmd.append(image)
|
||||
|
||||
cmd += [
|
||||
"lerobot-eval-worker",
|
||||
*env_argv,
|
||||
f"--server_address={server_address}",
|
||||
f"--n_episodes={n_episodes}",
|
||||
f"--seed={seed}",
|
||||
f"--instance_id={instance_id}",
|
||||
f"--instance_count={instance_count}",
|
||||
f"--output_path={container_results}/worker_{instance_id}.json",
|
||||
]
|
||||
|
||||
logger.info(
|
||||
"Spawning container %d/%d: %s",
|
||||
instance_id + 1,
|
||||
instance_count,
|
||||
" ".join(cmd),
|
||||
)
|
||||
return subprocess.Popen(cmd) # nosec B603 B607
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_eval_in_docker(cfg: EvalPipelineConfig) -> None:
|
||||
"""Run eval with env in Docker containers and policy on the host GPU.
|
||||
|
||||
Writes ``eval_info.json`` to ``cfg.output_dir``. Called by
|
||||
``lerobot_eval._run_eval_worker`` when ``eval.runtime == "docker"``.
|
||||
"""
|
||||
# Import here to avoid circular import at module level.
|
||||
from lerobot.scripts.lerobot_eval import _aggregate_eval_from_per_task
|
||||
|
||||
start_t = time.time()
|
||||
output_dir = Path(cfg.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
docker_cfg = cfg.eval.docker
|
||||
|
||||
# Optionally pull the image before starting.
|
||||
image = _resolve_image(cfg)
|
||||
if docker_cfg.pull:
|
||||
logger.info("Pulling Docker image: %s", image)
|
||||
subprocess.run(["docker", "pull", image], check=True) # nosec B603 B607
|
||||
|
||||
# ── Load policy + all preprocessors on the host GPU ──────────────────
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env, rename_map=cfg.rename_map)
|
||||
policy.eval()
|
||||
|
||||
preprocessor_overrides: dict = {
|
||||
"device_processor": {"device": str(device)},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
}
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
)
|
||||
env_preprocessor, _env_postprocessor = make_env_pre_post_processors(
|
||||
env_cfg=cfg.env,
|
||||
policy_cfg=cfg.policy,
|
||||
)
|
||||
|
||||
# ── Start HTTP inference server(s) ────────────────────────────────────
|
||||
n_policy_servers = cfg.eval.policy_servers
|
||||
base_port = docker_cfg.port
|
||||
host_ip = _get_host_ip()
|
||||
instance_count = cfg.eval.instance_count
|
||||
env_argv = _env_argv()
|
||||
|
||||
servers: list[_InferenceServer] = []
|
||||
for s_idx in range(n_policy_servers):
|
||||
port = base_port + s_idx
|
||||
if s_idx > 0:
|
||||
policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env, rename_map=cfg.rename_map)
|
||||
policy.eval()
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
)
|
||||
env_preprocessor, _ = make_env_pre_post_processors(
|
||||
env_cfg=cfg.env, policy_cfg=cfg.policy,
|
||||
)
|
||||
srv = _InferenceServer(
|
||||
("0.0.0.0", port), # nosec B104
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
t = threading.Thread(target=srv.serve_forever, daemon=True)
|
||||
t.start()
|
||||
servers.append(srv)
|
||||
logger.info("Policy inference server %d/%d running on port %d", s_idx + 1, n_policy_servers, port)
|
||||
|
||||
# ── Spawn containers (round-robin across policy servers) ──────────────
|
||||
container_dirs: list[Path] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
try:
|
||||
for i in range(instance_count):
|
||||
assigned_port = base_port + (i % n_policy_servers)
|
||||
server_address = f"{host_ip}:{assigned_port}"
|
||||
shard_dir = output_dir / "shards" / str(i)
|
||||
container_dirs.append(shard_dir)
|
||||
proc = _spawn_container(
|
||||
image=image,
|
||||
instance_id=i,
|
||||
instance_count=instance_count,
|
||||
server_address=server_address,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
seed=cfg.seed,
|
||||
output_dir=shard_dir,
|
||||
docker_cfg=docker_cfg,
|
||||
env_argv=env_argv,
|
||||
)
|
||||
procs.append(proc)
|
||||
|
||||
failed: list[tuple[int, int]] = []
|
||||
for i, proc in enumerate(procs):
|
||||
rc = proc.wait()
|
||||
if rc != 0:
|
||||
failed.append((i, rc))
|
||||
logger.error("Container %d/%d exited with code %d", i + 1, instance_count, rc)
|
||||
if failed:
|
||||
raise RuntimeError(f"Docker eval containers failed (instance_id, exit_code): {failed}")
|
||||
|
||||
finally:
|
||||
for srv in servers:
|
||||
srv.shutdown()
|
||||
|
||||
# ── Collect and merge per-task results ───────────────────────────────
|
||||
per_task: list[dict] = []
|
||||
for i, shard_dir in enumerate(container_dirs):
|
||||
result_file = shard_dir / f"worker_{i}.json"
|
||||
with open(result_file) as f:
|
||||
shard_data: dict = json.load(f)
|
||||
per_task.extend(shard_data.get("per_task", []))
|
||||
|
||||
per_task.sort(key=lambda x: (x["task_group"], x["task_id"]))
|
||||
|
||||
info = _aggregate_eval_from_per_task(per_task, total_eval_s=time.time() - start_t)
|
||||
with open(output_dir / "eval_info.json", "w") as f:
|
||||
json.dump(info, f, indent=2)
|
||||
|
||||
logger.info("Docker eval complete. Results: %s/eval_info.json", output_dir)
|
||||
|
||||
|
||||
def run_eval_multiprocess(cfg: EvalPipelineConfig) -> None:
|
||||
"""Run eval with multiple local worker processes and policy servers (no Docker).
|
||||
|
||||
Same architecture as Docker runtime but spawns `lerobot-eval-worker` as local
|
||||
subprocesses. Works on SLURM clusters and anywhere Docker is unavailable.
|
||||
"""
|
||||
from lerobot.scripts.lerobot_eval import _aggregate_eval_from_per_task
|
||||
|
||||
start_t = time.time()
|
||||
output_dir = Path(cfg.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env, rename_map=cfg.rename_map)
|
||||
policy.eval()
|
||||
|
||||
preprocessor_overrides: dict = {
|
||||
"device_processor": {"device": str(device)},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
}
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
)
|
||||
env_preprocessor, _env_postprocessor = make_env_pre_post_processors(
|
||||
env_cfg=cfg.env, policy_cfg=cfg.policy,
|
||||
)
|
||||
|
||||
n_policy_servers = cfg.eval.policy_servers
|
||||
base_port = cfg.eval.port
|
||||
instance_count = cfg.eval.instance_count
|
||||
env_argv = _env_argv()
|
||||
|
||||
servers: list[_InferenceServer] = []
|
||||
for s_idx in range(n_policy_servers):
|
||||
port = base_port + s_idx
|
||||
if s_idx > 0:
|
||||
policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env, rename_map=cfg.rename_map)
|
||||
policy.eval()
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
)
|
||||
env_preprocessor, _ = make_env_pre_post_processors(
|
||||
env_cfg=cfg.env, policy_cfg=cfg.policy,
|
||||
)
|
||||
srv = _InferenceServer(
|
||||
("0.0.0.0", port), # nosec B104
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
t = threading.Thread(target=srv.serve_forever, daemon=True)
|
||||
t.start()
|
||||
servers.append(srv)
|
||||
logger.info("Policy server %d/%d on port %d", s_idx + 1, n_policy_servers, port)
|
||||
|
||||
worker_dirs: list[Path] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
try:
|
||||
for i in range(instance_count):
|
||||
assigned_port = base_port + (i % n_policy_servers)
|
||||
shard_dir = output_dir / "shards" / str(i)
|
||||
shard_dir.mkdir(parents=True, exist_ok=True)
|
||||
worker_dirs.append(shard_dir)
|
||||
|
||||
cmd = [
|
||||
sys.executable, "-m", "lerobot.scripts.lerobot_eval_worker",
|
||||
*env_argv,
|
||||
f"--server_address=127.0.0.1:{assigned_port}",
|
||||
f"--n_episodes={cfg.eval.n_episodes}",
|
||||
f"--seed={cfg.seed}",
|
||||
f"--instance_id={i}",
|
||||
f"--instance_count={instance_count}",
|
||||
f"--output_path={shard_dir / f'worker_{i}.json'}",
|
||||
]
|
||||
logger.info("Spawning worker %d/%d → port %d", i + 1, instance_count, assigned_port)
|
||||
procs.append(subprocess.Popen(cmd)) # nosec B603
|
||||
|
||||
failed: list[tuple[int, int]] = []
|
||||
for i, proc in enumerate(procs):
|
||||
rc = proc.wait()
|
||||
if rc != 0:
|
||||
failed.append((i, rc))
|
||||
logger.error("Worker %d/%d exited with code %d", i + 1, instance_count, rc)
|
||||
if failed:
|
||||
raise RuntimeError(f"Multiprocess eval workers failed (id, exit_code): {failed}")
|
||||
|
||||
finally:
|
||||
for srv in servers:
|
||||
srv.shutdown()
|
||||
|
||||
per_task: list[dict] = []
|
||||
for i, shard_dir in enumerate(worker_dirs):
|
||||
result_file = shard_dir / f"worker_{i}.json"
|
||||
with open(result_file) as f:
|
||||
shard_data: dict = json.load(f)
|
||||
per_task.extend(shard_data.get("per_task", []))
|
||||
|
||||
per_task.sort(key=lambda x: (x["task_group"], x["task_id"]))
|
||||
|
||||
info = _aggregate_eval_from_per_task(per_task, total_eval_s=time.time() - start_t)
|
||||
with open(output_dir / "eval_info.json", "w") as f:
|
||||
json.dump(info, f, indent=2)
|
||||
|
||||
logger.info("Multiprocess eval complete. Results: %s/eval_info.json", output_dir)
|
||||
@@ -20,21 +20,11 @@ import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.configs import (
|
||||
AlohaEnv,
|
||||
EnvConfig,
|
||||
HubEnvConfig,
|
||||
IsaaclabArenaEnv,
|
||||
LiberoEnv,
|
||||
LiberoPlusEnv,
|
||||
PushtEnv,
|
||||
RoboCasaEnv,
|
||||
RoboMMEEnv,
|
||||
)
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
|
||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.processor import ProcessorStep
|
||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, RoboCasaProcessorStep
|
||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
|
||||
@@ -45,12 +35,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "libero":
|
||||
return LiberoEnv(**kwargs)
|
||||
elif env_type == "libero_plus":
|
||||
return LiberoPlusEnv(**kwargs)
|
||||
elif env_type == "robocasa":
|
||||
return RoboCasaEnv(**kwargs)
|
||||
elif env_type == "robomme":
|
||||
return RoboMMEEnv(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
@@ -86,13 +70,9 @@ def make_env_pre_post_processors(
|
||||
return make_xvla_libero_pre_post_processors()
|
||||
|
||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, (LiberoEnv, LiberoPlusEnv)) or "libero" in env_cfg.type:
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
preprocessor_steps.append(LiberoProcessorStep())
|
||||
|
||||
# For RoboCasa environments, add the RoboCasaProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, RoboCasaEnv) or "robocasa" in env_cfg.type:
|
||||
preprocessor_steps.append(RoboCasaProcessorStep())
|
||||
|
||||
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
|
||||
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
|
||||
# Parse comma-separated keys (handle None for state-based policies)
|
||||
@@ -125,7 +105,7 @@ def make_env(
|
||||
use_async_envs: bool = False,
|
||||
hub_cache_dir: str | None = None,
|
||||
trust_remote_code: bool = False,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||
"""Makes a gym vector environment according to the config or Hub reference.
|
||||
|
||||
Args:
|
||||
@@ -143,9 +123,8 @@ def make_env(
|
||||
ModuleNotFoundError: If the requested env package is not installed
|
||||
|
||||
Returns:
|
||||
dict[str, dict[int, Any]]:
|
||||
A mapping from suite name to indexed environments. Values are either
|
||||
materialized vector envs or lazy wrappers that materialize on first use.
|
||||
dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||
A mapping from suite name to indexed vectorized environments.
|
||||
- For multi-task benchmarks (e.g., LIBERO): one entry per suite, and one vec env per task_id.
|
||||
- For single-task environments: a single suite entry (cfg.type) with task_id=0.
|
||||
|
||||
@@ -192,11 +171,6 @@ def make_env(
|
||||
if cfg.task is None:
|
||||
raise ValueError("LiberoEnv requires a task to be specified")
|
||||
|
||||
if cfg.type == "libero_plus":
|
||||
from lerobot.envs.libero import _check_libero_plus_assets
|
||||
|
||||
_check_libero_plus_assets()
|
||||
|
||||
return create_libero_envs(
|
||||
task=cfg.task,
|
||||
n_envs=n_envs,
|
||||
@@ -207,33 +181,6 @@ def make_env(
|
||||
control_mode=cfg.control_mode,
|
||||
episode_length=cfg.episode_length,
|
||||
)
|
||||
elif "robocasa" in cfg.type:
|
||||
from lerobot.envs.robocasa import create_robocasa_envs
|
||||
|
||||
tasks = cfg.tasks if cfg.tasks else [cfg.task]
|
||||
return create_robocasa_envs(
|
||||
tasks=tasks,
|
||||
n_envs=n_envs,
|
||||
image_size=cfg.image_size,
|
||||
split=cfg.split,
|
||||
episode_length=cfg.episode_length,
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
|
||||
elif "robomme" in cfg.type:
|
||||
from lerobot.envs.robomme import create_robomme_envs
|
||||
|
||||
return create_robomme_envs(
|
||||
task=cfg.task,
|
||||
n_envs=n_envs,
|
||||
action_space_type=cfg.action_space,
|
||||
dataset=cfg.dataset_split,
|
||||
episode_length=cfg.episode_length,
|
||||
task_ids=cfg.task_ids,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
|
||||
elif "metaworld" in cfg.type:
|
||||
from lerobot.envs.metaworld import create_metaworld_envs
|
||||
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LazyVectorEnv:
|
||||
"""Defer vector-env construction until first usage.
|
||||
|
||||
This is useful for benchmarks with many tasks: we can register one env object
|
||||
per task without eagerly allocating all simulator/rendering resources.
|
||||
"""
|
||||
|
||||
def __init__(self, env_cls: Callable[[Sequence[Callable[[], Any]]], Any], factory_fns: list[Callable]):
|
||||
self._env_cls = env_cls
|
||||
self._factory_fns = factory_fns
|
||||
self._env = None
|
||||
|
||||
@property
|
||||
def env_cls(self) -> Callable[[Sequence[Callable[[], Any]]], Any]:
|
||||
return self._env_cls
|
||||
|
||||
@property
|
||||
def factory_fns(self) -> list[Callable]:
|
||||
return self._factory_fns
|
||||
|
||||
@property
|
||||
def num_factory_fns(self) -> int:
|
||||
return len(self._factory_fns)
|
||||
|
||||
def materialize(self):
|
||||
if self._env is None:
|
||||
self._env = self._env_cls(self._factory_fns)
|
||||
return self._env
|
||||
|
||||
def close(self):
|
||||
if self._env is not None:
|
||||
self._env.close()
|
||||
self._env = None
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.materialize(), name)
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from functools import partial
|
||||
@@ -27,222 +26,11 @@ import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
|
||||
try:
|
||||
import libero as _libero_pkg # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import libero. Install benchmark dependencies with one of:\n"
|
||||
" pip install -e \".[libero]\"\n"
|
||||
" pip install -e \".[libero_plus]\" (alias: \".[libero-plus]\")"
|
||||
)
|
||||
|
||||
# LIBERO's env_wrapper unconditionally imports wand (ImageMagick Python binding)
|
||||
# which requires the system-level libMagickWand library. The wand features are only
|
||||
# used for visual noise perturbations and are not needed for standard evaluation.
|
||||
# Pre-install a stub so the import succeeds even without ImageMagick.
|
||||
import sys
|
||||
import types
|
||||
|
||||
if "wand" not in sys.modules:
|
||||
try:
|
||||
import wand.api # noqa: F401
|
||||
except (ImportError, OSError):
|
||||
|
||||
class _AttrSink:
|
||||
"""Accepts any attribute get/set without error."""
|
||||
|
||||
def __getattr__(self, _name):
|
||||
return self
|
||||
|
||||
def __setattr__(self, _name, _value):
|
||||
pass
|
||||
|
||||
def __call__(self, *a, **kw):
|
||||
pass
|
||||
|
||||
_wand = types.ModuleType("wand")
|
||||
_wand_api = types.ModuleType("wand.api")
|
||||
_wand_api.library = _AttrSink()
|
||||
_wand_image = types.ModuleType("wand.image")
|
||||
_wand_image.Image = type("Image", (), {})
|
||||
_wand.api = _wand_api
|
||||
_wand.image = _wand_image
|
||||
sys.modules["wand"] = _wand
|
||||
sys.modules["wand.api"] = _wand_api
|
||||
sys.modules["wand.image"] = _wand_image
|
||||
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
from lerobot.envs.lazy_vec_env import LazyVectorEnv
|
||||
from lerobot.processor import RobotObservation
|
||||
|
||||
_ASSET_DOWNLOAD_INSTRUCTIONS = """\
|
||||
LIBERO-plus assets not found at: {assets_dir}
|
||||
|
||||
The LIBERO-plus benchmark requires ~6 GB of scene/texture/object assets that
|
||||
are hosted separately on Hugging Face. To download and install them:
|
||||
|
||||
python -c "
|
||||
from huggingface_hub import hf_hub_download
|
||||
hf_hub_download('Sylvest/LIBERO-plus', 'assets.zip',
|
||||
repo_type='dataset', local_dir='/tmp/libero-plus-assets')
|
||||
"
|
||||
unzip /tmp/libero-plus-assets/assets.zip -d /tmp/libero-plus-assets-unzipped
|
||||
# The zip contains a deeply nested path; move the assets directory:
|
||||
mv /tmp/libero-plus-assets-unzipped/inspire/*/assets {assets_dir}
|
||||
rm -rf /tmp/libero-plus-assets /tmp/libero-plus-assets-unzipped
|
||||
|
||||
See https://huggingface.co/datasets/Sylvest/LIBERO-plus for details.
|
||||
"""
|
||||
|
||||
|
||||
def _check_libero_plus_assets() -> None:
|
||||
"""Validate that LIBERO-plus scene assets are present."""
|
||||
assets_dir = Path(get_libero_path("benchmark_root")) / "assets"
|
||||
if not (assets_dir / "scenes").is_dir():
|
||||
raise FileNotFoundError(_ASSET_DOWNLOAD_INSTRUCTIONS.format(assets_dir=assets_dir))
|
||||
|
||||
|
||||
# ---- Perturbation support for LIBERO-Plus -----------------------------------
|
||||
|
||||
PERTURBATION_DIMENSIONS = (
|
||||
"Camera Viewpoints",
|
||||
"Robot Initial States",
|
||||
"Language Instructions",
|
||||
"Light Conditions",
|
||||
"Background Textures",
|
||||
"Sensor Noise",
|
||||
"Objects Layout",
|
||||
)
|
||||
|
||||
PERTURBATION_SHORT_KEYS = {
|
||||
"Camera Viewpoints": "camera",
|
||||
"Robot Initial States": "robot",
|
||||
"Language Instructions": "language",
|
||||
"Light Conditions": "light",
|
||||
"Background Textures": "background",
|
||||
"Sensor Noise": "noise",
|
||||
"Objects Layout": "layout",
|
||||
}
|
||||
|
||||
|
||||
def load_task_classification() -> dict:
|
||||
"""Load task_classification.json shipped with LIBERO-Plus."""
|
||||
import json
|
||||
|
||||
benchmark_root = Path(get_libero_path("benchmark_root"))
|
||||
candidates = [
|
||||
benchmark_root / "benchmark" / "task_classification.json",
|
||||
benchmark_root / "task_classification.json",
|
||||
benchmark_root.parent / "benchmark" / "task_classification.json",
|
||||
]
|
||||
for p in candidates:
|
||||
if p.exists():
|
||||
with open(p) as f:
|
||||
return json.load(f)
|
||||
raise FileNotFoundError(
|
||||
f"task_classification.json not found. Tried: {[str(c) for c in candidates]}"
|
||||
)
|
||||
|
||||
|
||||
def build_perturbation_index(suite_name: str) -> dict[int, str]:
|
||||
"""Return {0-indexed task_id: perturbation_dimension} for *suite_name*."""
|
||||
tc = load_task_classification()
|
||||
suite_data = tc.get(suite_name, {})
|
||||
index: dict[int, str] = {}
|
||||
|
||||
# LIBERO-Plus task_classification.json has appeared in two shapes:
|
||||
# 1) dict[suite][task_id_str] -> meta
|
||||
# 2) dict[suite] -> list[{id, category, ...}]
|
||||
if isinstance(suite_data, list):
|
||||
for item in suite_data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
raw_id = item.get("id")
|
||||
if raw_id is None:
|
||||
continue
|
||||
try:
|
||||
# list-form ids are 1-indexed in current LIBERO-Plus release.
|
||||
tid = int(raw_id) - 1
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if tid < 0:
|
||||
continue
|
||||
dim = item.get("perturbation_type") or item.get("category", "unknown")
|
||||
index[tid] = dim
|
||||
return index
|
||||
|
||||
if isinstance(suite_data, dict):
|
||||
# Handle both 0-indexed and 1-indexed key conventions.
|
||||
numeric_keys: list[int] = []
|
||||
for k in suite_data:
|
||||
try:
|
||||
numeric_keys.append(int(k))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
one_indexed = bool(numeric_keys) and 0 not in numeric_keys and min(numeric_keys) >= 1
|
||||
|
||||
for task_id_str, meta in suite_data.items():
|
||||
try:
|
||||
tid = int(task_id_str)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if one_indexed:
|
||||
tid -= 1
|
||||
if tid < 0:
|
||||
continue
|
||||
if isinstance(meta, dict):
|
||||
dim = meta.get("perturbation_type") or meta.get("category", "unknown")
|
||||
else:
|
||||
dim = "unknown"
|
||||
index[tid] = dim
|
||||
return index
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def aggregate_by_perturbation(
|
||||
per_task: list[dict], suite_indices: dict[str, dict[int, str]]
|
||||
) -> dict[str, dict]:
|
||||
"""Aggregate per-task eval results by perturbation dimension.
|
||||
|
||||
Args:
|
||||
per_task: list of {"task_group": str, "task_id": int, "metrics": {...}}
|
||||
suite_indices: {suite_name: {task_id: dimension_name}} from build_perturbation_index
|
||||
|
||||
Returns:
|
||||
{short_key: {"pc_success": float, "n_episodes": int}} for each perturbation dimension
|
||||
"""
|
||||
dim_successes: dict[str, list] = defaultdict(list)
|
||||
for entry in per_task:
|
||||
suite = entry["task_group"]
|
||||
tid = entry["task_id"]
|
||||
idx = suite_indices.get(suite, {})
|
||||
dim = idx.get(tid, "unknown")
|
||||
short = PERTURBATION_SHORT_KEYS.get(dim, dim.lower().replace(" ", "_"))
|
||||
successes = entry["metrics"].get("successes", [])
|
||||
dim_successes[short].extend(successes)
|
||||
|
||||
results: dict[str, dict] = {}
|
||||
all_successes: list = []
|
||||
for short_key in list(PERTURBATION_SHORT_KEYS.values()) + ["unknown"]:
|
||||
if short_key not in dim_successes:
|
||||
continue
|
||||
s = dim_successes[short_key]
|
||||
all_successes.extend(s)
|
||||
results[short_key] = {
|
||||
"pc_success": float(np.nanmean(s) * 100) if s else float("nan"),
|
||||
"n_episodes": len(s),
|
||||
}
|
||||
if all_successes:
|
||||
results["total"] = {
|
||||
"pc_success": float(np.nanmean(all_successes) * 100),
|
||||
"n_episodes": len(all_successes),
|
||||
}
|
||||
return results
|
||||
|
||||
|
||||
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
"""Normalize camera_name into a non-empty list of strings."""
|
||||
@@ -280,35 +68,13 @@ def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[i
|
||||
|
||||
|
||||
def get_task_init_states(task_suite: Any, i: int) -> np.ndarray:
|
||||
init_states_dir = Path(get_libero_path("init_states")) / task_suite.tasks[i].problem_folder
|
||||
init_states_file = task_suite.tasks[i].init_states_file
|
||||
|
||||
# 1. Direct match
|
||||
direct = init_states_dir / init_states_file
|
||||
if direct.exists():
|
||||
return torch.load(direct, weights_only=False) # nosec B614
|
||||
|
||||
# 2. LIBERO-Plus perturbation filenames append suffixes like
|
||||
# _view_0_0_100_0_0_initstate_1, _language_19, _noise_45, _table_1, _tb_9, _add_16
|
||||
# to the base task name. Instead of regex-matching every variant, find the
|
||||
# longest existing base file whose stem is a prefix of the perturbation stem.
|
||||
stem, ext = os.path.splitext(init_states_file)
|
||||
best_match: Path | None = None
|
||||
best_len = 0
|
||||
for candidate in init_states_dir.glob(f"*{ext}"):
|
||||
cstem = candidate.stem
|
||||
if stem == cstem or (stem.startswith(cstem) and stem[len(cstem)] == "_"):
|
||||
if len(cstem) > best_len:
|
||||
best_len = len(cstem)
|
||||
best_match = candidate
|
||||
|
||||
if best_match is not None:
|
||||
return torch.load(best_match, weights_only=False) # nosec B614
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"Could not find init states for task {i}. "
|
||||
f"Tried '{init_states_file}' and prefix matching in '{init_states_dir}'."
|
||||
init_states_path = (
|
||||
Path(get_libero_path("init_states"))
|
||||
/ task_suite.tasks[i].problem_folder
|
||||
/ task_suite.tasks[i].init_states_file
|
||||
)
|
||||
init_states = torch.load(init_states_path, weights_only=False) # nosec B614
|
||||
return init_states
|
||||
|
||||
|
||||
def get_libero_dummy_action():
|
||||
@@ -328,29 +94,6 @@ TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
||||
}
|
||||
|
||||
|
||||
def _make_offscreen_env_with_renderer_fallback(env_args: dict[str, Any]) -> Any:
|
||||
"""Create OffScreenRenderEnv and fallback to OSMesa if EGL is unavailable."""
|
||||
try:
|
||||
return OffScreenRenderEnv(**env_args)
|
||||
except ImportError as exc:
|
||||
msg = str(exc)
|
||||
if "EGL" not in msg and "PLATFORM_DEVICE" not in msg:
|
||||
raise
|
||||
|
||||
# Headless clusters often miss EGL PLATFORM_DEVICE support. Retry with
|
||||
# software rendering to keep evaluation working.
|
||||
os.environ["MUJOCO_GL"] = "osmesa"
|
||||
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
||||
try:
|
||||
return OffScreenRenderEnv(**env_args)
|
||||
except Exception as fallback_exc:
|
||||
raise ImportError(
|
||||
"Failed to initialize robosuite offscreen renderer with both EGL and "
|
||||
"OSMesa backends. Set up EGL-capable drivers or install OSMesa (e.g. "
|
||||
"`conda install -c conda-forge mesalib`) and retry."
|
||||
) from fallback_exc
|
||||
|
||||
|
||||
class LiberoEnv(gym.Env):
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
|
||||
|
||||
@@ -404,7 +147,6 @@ class LiberoEnv(gym.Env):
|
||||
# Load once and keep
|
||||
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
||||
self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`.
|
||||
self._init_state_error_warned = False
|
||||
|
||||
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||
|
||||
@@ -496,7 +238,7 @@ class LiberoEnv(gym.Env):
|
||||
"camera_heights": self.observation_height,
|
||||
"camera_widths": self.observation_width,
|
||||
}
|
||||
env = _make_offscreen_env_with_renderer_fallback(env_args)
|
||||
env = OffScreenRenderEnv(**env_args)
|
||||
env.reset()
|
||||
return env
|
||||
|
||||
@@ -556,21 +298,8 @@ class LiberoEnv(gym.Env):
|
||||
self._env.seed(seed)
|
||||
raw_obs = self._env.reset()
|
||||
if self.init_states and self._init_states is not None:
|
||||
try:
|
||||
raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
|
||||
self.init_state_id += self._reset_stride # Change init_state_id when reset
|
||||
except Exception as exc:
|
||||
# Some LIBERO-Plus perturbation tasks (notably object-layout variants)
|
||||
# can have different simulator state dimensions than their base init files.
|
||||
# Fall back to plain env.reset() instead of aborting the whole evaluation.
|
||||
self.init_states = False
|
||||
if not self._init_state_error_warned:
|
||||
print(
|
||||
"WARNING: Failed to apply init state for "
|
||||
f"task_id={self.task_id} ({self.task}). "
|
||||
f"Falling back to plain reset. Error: {exc}"
|
||||
)
|
||||
self._init_state_error_warned = True
|
||||
raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
|
||||
self.init_state_id += self._reset_stride # Change init_state_id when reset
|
||||
|
||||
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
|
||||
# Step the simulator with a no-op action for a few frames so everything settles.
|
||||
@@ -596,17 +325,7 @@ class LiberoEnv(gym.Env):
|
||||
f"Expected action to be 1-D (shape (action_dim,)), "
|
||||
f"but got shape {action.shape} with ndim={action.ndim}"
|
||||
)
|
||||
|
||||
try:
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
except ValueError as e:
|
||||
if "terminated episode" not in str(e):
|
||||
raise
|
||||
# Robosuite's internal done flag is stale (e.g. from a previous
|
||||
# termination that wasn't properly cleared by SyncVectorEnv).
|
||||
# Signal termination so the caller resets us.
|
||||
obs, reset_info = self.reset()
|
||||
return obs, 0.0, True, False, {"is_success": False, **reset_info}
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
|
||||
is_success = self._env.check_success()
|
||||
terminated = done or is_success
|
||||
@@ -626,6 +345,7 @@ class LiberoEnv(gym.Env):
|
||||
"done": bool(done),
|
||||
"is_success": bool(is_success),
|
||||
}
|
||||
self.reset()
|
||||
truncated = False
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
@@ -668,9 +388,6 @@ def _make_env_fns(
|
||||
return fns
|
||||
|
||||
|
||||
_LazyVecEnv = LazyVectorEnv
|
||||
|
||||
|
||||
# ---- Main API ----------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -714,23 +431,12 @@ def create_libero_envs(
|
||||
print(f"Restricting to task_ids={task_ids_filter}")
|
||||
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
total_tasks = 0
|
||||
for suite_name in suite_names:
|
||||
suite = _get_suite(suite_name)
|
||||
total = len(suite.tasks)
|
||||
selected = _select_task_ids(total, task_ids_filter)
|
||||
if not selected:
|
||||
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
|
||||
total_tasks += len(selected)
|
||||
|
||||
lazy = total_tasks > 1
|
||||
if lazy:
|
||||
print(f"Using lazy env creation for {total_tasks} tasks (envs created on demand)")
|
||||
|
||||
for suite_name in suite_names:
|
||||
suite = _get_suite(suite_name)
|
||||
total = len(suite.tasks)
|
||||
selected = _select_task_ids(total, task_ids_filter)
|
||||
|
||||
for tid in selected:
|
||||
fns = _make_env_fns(
|
||||
@@ -744,11 +450,8 @@ def create_libero_envs(
|
||||
gym_kwargs=gym_kwargs,
|
||||
control_mode=control_mode,
|
||||
)
|
||||
if lazy:
|
||||
out[suite_name][tid] = LazyVectorEnv(env_cls, fns)
|
||||
else:
|
||||
out[suite_name][tid] = env_cls(fns)
|
||||
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
||||
out[suite_name][tid] = env_cls(fns)
|
||||
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
||||
|
||||
# return plain dicts for predictability
|
||||
return {suite: dict(task_map) for suite, task_map in out.items()}
|
||||
|
||||
@@ -25,7 +25,6 @@ import metaworld.policies as policies
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from lerobot.envs.lazy_vec_env import LazyVectorEnv
|
||||
from lerobot.processor import RobotObservation
|
||||
|
||||
# ---- Load configuration data from the external JSON file ----
|
||||
@@ -298,24 +297,19 @@ def create_metaworld_envs(
|
||||
|
||||
print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}")
|
||||
|
||||
group_to_tasks = {group: DIFFICULTY_TO_TASKS.get(group, [group]) for group in task_groups}
|
||||
total_tasks = sum(len(tasks) for tasks in group_to_tasks.values())
|
||||
lazy = total_tasks > 50
|
||||
if lazy:
|
||||
print(f"Using lazy env creation for {total_tasks} tasks (envs created on demand)")
|
||||
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
for group in task_groups:
|
||||
tasks = group_to_tasks[group]
|
||||
# if not in difficulty presets, treat it as a single custom task
|
||||
tasks = DIFFICULTY_TO_TASKS.get(group, [group])
|
||||
|
||||
for tid, task_name in enumerate(tasks):
|
||||
if not lazy:
|
||||
print(f"Building vec env | group={group} | task_id={tid} | task={task_name}")
|
||||
print(f"Building vec env | group={group} | task_id={tid} | task={task_name}")
|
||||
|
||||
# build n_envs factories
|
||||
fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
|
||||
out[group][tid] = LazyVectorEnv(env_cls, fns) if lazy else env_cls(fns)
|
||||
|
||||
out[group][tid] = env_cls(fns)
|
||||
|
||||
# return a plain dict for consistency
|
||||
return {group: dict(task_map) for group, task_map in out.items()}
|
||||
|
||||
@@ -1,279 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from lerobot.envs.lazy_vec_env import LazyVectorEnv
|
||||
|
||||
# Action layout (flat 12D, normalized to [-1, 1]):
|
||||
# [0:3] end_effector_position (delta x, y, z)
|
||||
# [3:6] end_effector_rotation (delta roll, pitch, yaw)
|
||||
# [6:7] gripper_close (open=-1, close=+1)
|
||||
# [7:11] base_motion (x, y, theta, torso_height)
|
||||
# [11:12] control_mode (arm=-1, base=+1)
|
||||
ACTION_DIM = 12
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
|
||||
# Proprioceptive state layout (flat 16D):
|
||||
# [0:2] gripper_qpos
|
||||
# [2:5] base_position
|
||||
# [5:9] base_rotation (quaternion)
|
||||
# [9:12] end_effector_position_relative
|
||||
# [12:16] end_effector_rotation_relative (quaternion)
|
||||
STATE_DIM = 16
|
||||
|
||||
# Obs dict keys from RoboCasaGymEnv.get_observation()
|
||||
_CAM_KEYS = (
|
||||
"video.robot0_agentview_left",
|
||||
"video.robot0_agentview_right",
|
||||
"video.robot0_eye_in_hand",
|
||||
)
|
||||
_STATE_KEYS_ORDERED = (
|
||||
"state.gripper_qpos", # (2,)
|
||||
"state.base_position", # (3,)
|
||||
"state.base_rotation", # (4,)
|
||||
"state.end_effector_position_relative", # (3,)
|
||||
"state.end_effector_rotation_relative", # (4,)
|
||||
)
|
||||
|
||||
# Mapping from video.* key → short image name used in features_map
|
||||
CAM_KEY_TO_NAME = {
|
||||
"video.robot0_agentview_left": "agentview_left",
|
||||
"video.robot0_agentview_right": "agentview_right",
|
||||
"video.robot0_eye_in_hand": "eye_in_hand",
|
||||
}
|
||||
|
||||
|
||||
def _flat_to_action_dict(flat: np.ndarray) -> dict[str, np.ndarray]:
|
||||
"""Convert a 12D flat action array to the Dict format expected by RoboCasaGymEnv."""
|
||||
return {
|
||||
"action.end_effector_position": flat[0:3],
|
||||
"action.end_effector_rotation": flat[3:6],
|
||||
"action.gripper_close": flat[6:7],
|
||||
"action.base_motion": flat[7:11],
|
||||
"action.control_mode": flat[11:12],
|
||||
}
|
||||
|
||||
|
||||
class RoboCasaEnv(gym.Env):
|
||||
"""Thin wrapper around RoboCasaGymEnv that provides a flat Box action space
|
||||
and a structured observation dict compatible with LeRobot policies.
|
||||
|
||||
Observations returned by step/reset:
|
||||
{
|
||||
"pixels": {
|
||||
"agentview_left": (H, W, 3) uint8,
|
||||
"agentview_right": (H, W, 3) uint8,
|
||||
"eye_in_hand": (H, W, 3) uint8,
|
||||
},
|
||||
"robot_state": (16,) float32,
|
||||
}
|
||||
|
||||
Actions: flat float32 ndarray of shape (12,), normalized to [-1, 1].
|
||||
"""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: str,
|
||||
split: str = "target",
|
||||
image_size: int = 128,
|
||||
render_mode: str = "rgb_array",
|
||||
episode_length: int = 500,
|
||||
**gym_kwargs: Any,
|
||||
):
|
||||
super().__init__()
|
||||
# Lazy import — robocasa is optional
|
||||
import robocasa.environments # noqa: F401 — registers all gym envs
|
||||
|
||||
self.task = task
|
||||
self.render_mode = render_mode
|
||||
self.image_size = image_size
|
||||
self._max_episode_steps = episode_length
|
||||
self._step_count = 0
|
||||
|
||||
self._env = gym.make(
|
||||
f"robocasa/{task}",
|
||||
split=split,
|
||||
camera_widths=image_size,
|
||||
camera_heights=image_size,
|
||||
**gym_kwargs,
|
||||
)
|
||||
|
||||
# Flat 12D Box action space
|
||||
self.action_space = spaces.Box(
|
||||
low=ACTION_LOW,
|
||||
high=ACTION_HIGH,
|
||||
shape=(ACTION_DIM,),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
images = {
|
||||
name: spaces.Box(low=0, high=255, shape=(image_size, image_size, 3), dtype=np.uint8)
|
||||
for name in CAM_KEY_TO_NAME.values()
|
||||
}
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(images),
|
||||
"robot_state": spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(STATE_DIM,), dtype=np.float32
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _format_obs(self, raw_obs: dict) -> dict:
|
||||
pixels = {
|
||||
CAM_KEY_TO_NAME[k]: raw_obs[k]
|
||||
for k in _CAM_KEYS
|
||||
if k in raw_obs
|
||||
}
|
||||
state_parts = [
|
||||
np.asarray(raw_obs[k], dtype=np.float32)
|
||||
for k in _STATE_KEYS_ORDERED
|
||||
if k in raw_obs
|
||||
]
|
||||
robot_state = np.concatenate(state_parts) if state_parts else np.zeros(STATE_DIM, dtype=np.float32)
|
||||
return {"pixels": pixels, "robot_state": robot_state}
|
||||
|
||||
def reset(self, seed: int | None = None, **kwargs) -> tuple[dict, dict]:
|
||||
super().reset(seed=seed)
|
||||
self._step_count = 0
|
||||
raw_obs, info = self._env.reset(seed=seed)
|
||||
info.setdefault("is_success", False)
|
||||
info["task"] = self.task
|
||||
return self._format_obs(raw_obs), info
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[dict, float, bool, bool, dict]:
|
||||
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
|
||||
raise ValueError(
|
||||
f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}"
|
||||
)
|
||||
action_dict = _flat_to_action_dict(action)
|
||||
raw_obs, reward, terminated, truncated, info = self._env.step(action_dict)
|
||||
self._step_count += 1
|
||||
|
||||
is_success = bool(info.get("success", False))
|
||||
terminated = terminated or is_success
|
||||
if self._step_count >= self._max_episode_steps:
|
||||
truncated = True
|
||||
|
||||
info.update({"task": self.task, "is_success": is_success})
|
||||
obs = self._format_obs(raw_obs)
|
||||
|
||||
if terminated or truncated:
|
||||
info["final_info"] = {"task": self.task, "is_success": is_success}
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def render(self) -> np.ndarray | None:
|
||||
if self.render_mode == "rgb_array":
|
||||
return self._env.render()
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
self._env.close()
|
||||
|
||||
|
||||
def _make_env_fns(
|
||||
*,
|
||||
task: str,
|
||||
n_envs: int,
|
||||
image_size: int,
|
||||
split: str,
|
||||
episode_length: int,
|
||||
gym_kwargs: dict[str, Any],
|
||||
) -> list[Callable[[], RoboCasaEnv]]:
|
||||
"""Build n_envs factory callables for a single task."""
|
||||
def _make(episode_index: int) -> RoboCasaEnv: # noqa: ARG001
|
||||
return RoboCasaEnv(
|
||||
task=task,
|
||||
split=split,
|
||||
image_size=image_size,
|
||||
episode_length=episode_length,
|
||||
**gym_kwargs,
|
||||
)
|
||||
|
||||
return [partial(_make, i) for i in range(n_envs)]
|
||||
|
||||
|
||||
def create_robocasa_envs(
|
||||
tasks: str | Sequence[str],
|
||||
n_envs: int,
|
||||
image_size: int = 128,
|
||||
split: str = "target",
|
||||
episode_length: int = 500,
|
||||
gym_kwargs: dict[str, Any] | None = None,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""Create vectorized RoboCasa environments.
|
||||
|
||||
Args:
|
||||
tasks: A single task name or list of task names (without "robocasa/" prefix).
|
||||
E.g. "PickPlaceCounterToCabinet" or ["BoilPot", "PrepareCoffee"].
|
||||
n_envs: Number of parallel envs per task.
|
||||
image_size: Square image resolution for all cameras.
|
||||
split: RoboCasa dataset split — "pretrain" or "target".
|
||||
episode_length: Max steps per episode before truncation.
|
||||
gym_kwargs: Extra kwargs forwarded to each RoboCasaEnv.
|
||||
env_cls: Callable to wrap list of factory fns (SyncVectorEnv or AsyncVectorEnv).
|
||||
|
||||
Returns:
|
||||
dict[task_name][task_id=0] -> vec_env
|
||||
"""
|
||||
if env_cls is None or not callable(env_cls):
|
||||
raise ValueError("env_cls must be a callable wrapping a list of env factory callables.")
|
||||
if not isinstance(n_envs, int) or n_envs <= 0:
|
||||
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
|
||||
|
||||
if isinstance(tasks, str):
|
||||
task_list = [t.strip() for t in tasks.split(",") if t.strip()]
|
||||
else:
|
||||
task_list = [str(t).strip() for t in tasks if str(t).strip()]
|
||||
if not task_list:
|
||||
raise ValueError("`tasks` must contain at least one task name.")
|
||||
|
||||
gym_kwargs = dict(gym_kwargs or {})
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
total_tasks = len(task_list)
|
||||
lazy = total_tasks > 50
|
||||
|
||||
print(f"Creating RoboCasa envs | tasks={task_list} | n_envs(per task)={n_envs} | split={split}")
|
||||
if lazy:
|
||||
print(f"Using lazy env creation for {total_tasks} tasks (envs created on demand)")
|
||||
for task in task_list:
|
||||
fns = _make_env_fns(
|
||||
task=task,
|
||||
n_envs=n_envs,
|
||||
image_size=image_size,
|
||||
split=split,
|
||||
episode_length=episode_length,
|
||||
gym_kwargs=gym_kwargs,
|
||||
)
|
||||
out["robocasa"][len(out["robocasa"])] = LazyVectorEnv(env_cls, fns) if lazy else env_cls(fns)
|
||||
if not lazy:
|
||||
print(f" Built vec env | task={task} | n_envs={n_envs}")
|
||||
|
||||
return {suite: dict(task_map) for suite, task_map in out.items()}
|
||||
@@ -1,181 +0,0 @@
|
||||
"""RoboMME environment wrapper for LeRobot evaluation.
|
||||
|
||||
Wraps the RoboMME ``BenchmarkEnvBuilder`` into a Gymnasium-compatible
|
||||
``VectorEnv`` suitable for ``lerobot_eval``.
|
||||
|
||||
RoboMME tasks:
|
||||
Counting: BinFill, PickXtimes, SwingXtimes, StopCube
|
||||
Permanence: VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap
|
||||
Reference: PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder
|
||||
Imitation: MoveCube, InsertPeg, PatternLock, RouteStick
|
||||
|
||||
Install: pip install robomme (or from source: https://github.com/RoboMME/robomme_benchmark)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from lerobot.envs.lazy_vec_env import LazyVectorEnv
|
||||
|
||||
ROBOMME_TASKS = [
|
||||
"BinFill", "PickXtimes", "SwingXtimes", "StopCube",
|
||||
"VideoUnmask", "VideoUnmaskSwap", "ButtonUnmask", "ButtonUnmaskSwap",
|
||||
"PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder",
|
||||
"MoveCube", "InsertPeg", "PatternLock", "RouteStick",
|
||||
]
|
||||
|
||||
|
||||
class RoboMMEGymEnv(gym.Env):
|
||||
"""Thin Gymnasium wrapper around a single RoboMME episode env."""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"]}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: str = "PickXtimes",
|
||||
action_space_type: str = "joint_angle",
|
||||
dataset: str = "test",
|
||||
episode_idx: int = 0,
|
||||
max_steps: int = 300,
|
||||
):
|
||||
super().__init__()
|
||||
from robomme.env_record_wrapper import BenchmarkEnvBuilder
|
||||
|
||||
self._task = task
|
||||
self._action_space_type = action_space_type
|
||||
self._dataset = dataset
|
||||
self._episode_idx = episode_idx
|
||||
self._max_steps = max_steps
|
||||
|
||||
self._builder = BenchmarkEnvBuilder(
|
||||
env_id=task,
|
||||
dataset=dataset,
|
||||
action_space=action_space_type,
|
||||
gui_render=False,
|
||||
max_steps=max_steps,
|
||||
)
|
||||
self._env = None
|
||||
|
||||
action_dim = 8 if action_space_type == "joint_angle" else 7
|
||||
self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(action_dim,), dtype=np.float32)
|
||||
self.observation_space = spaces.Dict({
|
||||
"front_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
|
||||
"wrist_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
|
||||
"state": spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32),
|
||||
})
|
||||
|
||||
def reset(self, *, seed=None, options=None):
|
||||
super().reset(seed=seed)
|
||||
self._env = self._builder.make_env_for_episode(
|
||||
episode_idx=self._episode_idx, max_steps=self._max_steps,
|
||||
)
|
||||
obs, info = self._env.reset()
|
||||
return self._convert_obs(obs), self._convert_info(info)
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self._env.step(action)
|
||||
|
||||
terminated_bool = bool(terminated.item()) if hasattr(terminated, "item") else bool(terminated)
|
||||
truncated_bool = bool(truncated.item()) if hasattr(truncated, "item") else bool(truncated)
|
||||
|
||||
status = info.get("status", "ongoing")
|
||||
is_success = status == "success"
|
||||
conv_info = self._convert_info(info)
|
||||
conv_info["is_success"] = is_success
|
||||
|
||||
return self._convert_obs(obs), float(reward), terminated_bool, truncated_bool, conv_info
|
||||
|
||||
def _convert_obs(self, obs: dict) -> dict:
|
||||
front_rgb = obs["front_rgb_list"][-1] if isinstance(obs["front_rgb_list"], list) else obs["front_rgb_list"]
|
||||
wrist_rgb = obs["wrist_rgb_list"][-1] if isinstance(obs["wrist_rgb_list"], list) else obs["wrist_rgb_list"]
|
||||
joint_state = obs["joint_state_list"][-1] if isinstance(obs["joint_state_list"], list) else obs["joint_state_list"]
|
||||
gripper_state = obs["gripper_state_list"][-1] if isinstance(obs["gripper_state_list"], list) else obs["gripper_state_list"]
|
||||
|
||||
front_rgb = np.asarray(front_rgb, dtype=np.uint8)
|
||||
wrist_rgb = np.asarray(wrist_rgb, dtype=np.uint8)
|
||||
joint = np.asarray(joint_state, dtype=np.float32).flatten()[:7]
|
||||
gripper = np.asarray(gripper_state, dtype=np.float32).flatten()[:1]
|
||||
state = np.concatenate([joint, gripper])
|
||||
|
||||
return {
|
||||
"front_rgb": front_rgb,
|
||||
"wrist_rgb": wrist_rgb,
|
||||
"state": state,
|
||||
}
|
||||
|
||||
def _convert_info(self, info: dict) -> dict:
|
||||
return {
|
||||
"status": info.get("status", "ongoing"),
|
||||
"task_goal": info.get("task_goal", ""),
|
||||
}
|
||||
|
||||
|
||||
def _make_env_fns(
|
||||
*,
|
||||
task: str,
|
||||
n_envs: int,
|
||||
action_space_type: str,
|
||||
dataset: str,
|
||||
episode_length: int,
|
||||
task_id: int,
|
||||
) -> list[Callable[[], RoboMMEGymEnv]]:
|
||||
"""Build n_envs factory callables for one RoboMME task id."""
|
||||
|
||||
def _make_one(episode_index: int) -> RoboMMEGymEnv:
|
||||
return RoboMMEGymEnv(
|
||||
task=task,
|
||||
action_space_type=action_space_type,
|
||||
dataset=dataset,
|
||||
episode_idx=episode_index,
|
||||
max_steps=episode_length,
|
||||
)
|
||||
|
||||
return [partial(_make_one, task_id + i) for i in range(n_envs)]
|
||||
|
||||
|
||||
def create_robomme_envs(
|
||||
task: str,
|
||||
n_envs: int = 1,
|
||||
action_space_type: str = "joint_angle",
|
||||
dataset: str = "test",
|
||||
episode_length: int = 300,
|
||||
task_ids: list[int] | None = None,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||
"""Create vectorized RoboMME environments for evaluation.
|
||||
|
||||
Returns {suite_name: {task_id: VectorEnv}} matching lerobot's expected format.
|
||||
"""
|
||||
if env_cls is None or not callable(env_cls):
|
||||
raise ValueError("env_cls must be a callable that wraps a list of env factory callables.")
|
||||
if not isinstance(n_envs, int) or n_envs <= 0:
|
||||
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
|
||||
|
||||
if task_ids is None:
|
||||
task_ids = [0]
|
||||
|
||||
suite_name = "robomme"
|
||||
envs_by_task = {}
|
||||
lazy = len(task_ids) > 50
|
||||
if lazy:
|
||||
print(f"Using lazy env creation for {len(task_ids)} tasks (envs created on demand)")
|
||||
|
||||
for task_id in task_ids:
|
||||
fns = _make_env_fns(
|
||||
task=task,
|
||||
n_envs=n_envs,
|
||||
action_space_type=action_space_type,
|
||||
dataset=dataset,
|
||||
episode_length=episode_length,
|
||||
task_id=task_id,
|
||||
)
|
||||
envs_by_task[task_id] = LazyVectorEnv(env_cls, fns) if lazy else env_cls(fns)
|
||||
|
||||
return {suite_name: envs_by_task}
|
||||
@@ -122,12 +122,14 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
protocol_version: int = PROTOCOL_VERSION,
|
||||
):
|
||||
super().__init__(port, motors, calibration)
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
self.port_handler = dxl.PortHandler(self.port)
|
||||
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
|
||||
self.packet_handler = dxl.PacketHandler(protocol_version)
|
||||
print(f"Using protocol version {protocol_version}")
|
||||
self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
|
||||
self.sync_writer = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0)
|
||||
self._comm_success = dxl.COMM_SUCCESS
|
||||
|
||||
@@ -33,6 +33,58 @@
|
||||
# 2. We can change the value of the MyControlTableKey enums without impacting the client code
|
||||
|
||||
|
||||
# {data_name: (address, size_byte)}
|
||||
# https://emanual.robotis.com/docs/en/dxl/ax/{MODEL}/#control-table
|
||||
AX_SERIES_CONTROL_TABLE = {
|
||||
# EEPROM Area
|
||||
"Model_Number": (0, 2),
|
||||
"Firmware_Version": (2, 1),
|
||||
"ID": (3, 1),
|
||||
"Baud_Rate": (4, 1),
|
||||
"Return_Delay_Time": (5, 1),
|
||||
"CW_Angle_Limit": (6, 2),
|
||||
"CCW_Angle_Limit": (8, 2),
|
||||
"Temperature_Limit": (11, 1),
|
||||
"Min_Voltage_Limit": (12, 1),
|
||||
"Max_Voltage_Limit": (13, 1),
|
||||
"Max_Torque": (14, 2),
|
||||
"Status_Return_Level": (16, 1),
|
||||
"Alarm_LED": (17, 1),
|
||||
"Shutdown": (18, 1),
|
||||
# RAM Area
|
||||
"Torque_Enable": (24, 1),
|
||||
"LED": (25, 1),
|
||||
"CW_Compliance_Margin": (26, 1),
|
||||
"CCW_Compliance_Margin": (27, 1),
|
||||
"CW_Compliance_Slope": (28, 1),
|
||||
"CCW_Compliance_Slope": (29, 1),
|
||||
"Goal_Position": (30, 2),
|
||||
"Moving_Speed": (32, 2),
|
||||
"Torque_Limit": (34, 2),
|
||||
"Present_Position": (36, 2),
|
||||
"Present_Speed": (38, 2),
|
||||
"Present_Load": (40, 2),
|
||||
"Present_Voltage": (42, 1),
|
||||
"Present_Temperature": (43, 1),
|
||||
"Registered": (44, 1),
|
||||
"Moving": (46, 1),
|
||||
"Lock": (47, 1),
|
||||
"Punch": (48, 2),
|
||||
}
|
||||
|
||||
# https://emanual.robotis.com/docs/en/dxl/ax/{MODEL}/#baud-rate4
|
||||
AX_SERIES_BAUDRATE_TABLE = {
|
||||
9_600: 207,
|
||||
19_200: 103,
|
||||
57_600: 34,
|
||||
115_200: 16,
|
||||
200_000: 9,
|
||||
250_000: 7,
|
||||
400_000: 4,
|
||||
500_000: 3,
|
||||
1_000_000: 1,
|
||||
}
|
||||
|
||||
# {data_name: (address, size_byte)}
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#control-table
|
||||
X_SERIES_CONTROL_TABLE = {
|
||||
@@ -114,6 +166,14 @@ X_SERIES_ENCODINGS_TABLE = {
|
||||
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],
|
||||
}
|
||||
|
||||
# {data_name: size_byte}
|
||||
AX_SERIES_ENCODINGS_TABLE = {
|
||||
"Goal_Position": AX_SERIES_CONTROL_TABLE["Goal_Position"][1],
|
||||
"Moving_Speed": AX_SERIES_CONTROL_TABLE["Moving_Speed"][1],
|
||||
"Present_Position": AX_SERIES_CONTROL_TABLE["Present_Position"][1],
|
||||
"Present_Speed": AX_SERIES_CONTROL_TABLE["Present_Speed"][1],
|
||||
}
|
||||
|
||||
MODEL_ENCODING_TABLE = {
|
||||
"x_series": X_SERIES_ENCODINGS_TABLE,
|
||||
"xl330-m077": X_SERIES_ENCODINGS_TABLE,
|
||||
@@ -122,6 +182,8 @@ MODEL_ENCODING_TABLE = {
|
||||
"xm430-w350": X_SERIES_ENCODINGS_TABLE,
|
||||
"xm540-w270": X_SERIES_ENCODINGS_TABLE,
|
||||
"xc430-w150": X_SERIES_ENCODINGS_TABLE,
|
||||
"ax_series": AX_SERIES_ENCODINGS_TABLE,
|
||||
"ax-12a": AX_SERIES_ENCODINGS_TABLE,
|
||||
}
|
||||
|
||||
# {model: model_resolution}
|
||||
@@ -134,6 +196,8 @@ MODEL_RESOLUTION = {
|
||||
"xm430-w350": 4096,
|
||||
"xm540-w270": 4096,
|
||||
"xc430-w150": 4096,
|
||||
"ax_series": 1024,
|
||||
"ax-12a": 1024,
|
||||
}
|
||||
|
||||
# {model: model_number}
|
||||
@@ -145,6 +209,7 @@ MODEL_NUMBER_TABLE = {
|
||||
"xm430-w350": 1020,
|
||||
"xm540-w270": 1120,
|
||||
"xc430-w150": 1070,
|
||||
"ax-12a": 12,
|
||||
}
|
||||
|
||||
# {model: available_operating_modes}
|
||||
@@ -166,6 +231,8 @@ MODEL_CONTROL_TABLE = {
|
||||
"xm430-w350": X_SERIES_CONTROL_TABLE,
|
||||
"xm540-w270": X_SERIES_CONTROL_TABLE,
|
||||
"xc430-w150": X_SERIES_CONTROL_TABLE,
|
||||
"ax_series": AX_SERIES_CONTROL_TABLE,
|
||||
"ax-12a": AX_SERIES_CONTROL_TABLE,
|
||||
}
|
||||
|
||||
MODEL_BAUDRATE_TABLE = {
|
||||
@@ -176,6 +243,8 @@ MODEL_BAUDRATE_TABLE = {
|
||||
"xm430-w350": X_SERIES_BAUDRATE_TABLE,
|
||||
"xm540-w270": X_SERIES_BAUDRATE_TABLE,
|
||||
"xc430-w150": X_SERIES_BAUDRATE_TABLE,
|
||||
"ax_series": AX_SERIES_BAUDRATE_TABLE,
|
||||
"ax-12a": AX_SERIES_BAUDRATE_TABLE,
|
||||
}
|
||||
|
||||
AVAILABLE_BAUDRATES = [
|
||||
|
||||
@@ -29,7 +29,7 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from pprint import pformat
|
||||
from typing import Protocol
|
||||
from typing import Protocol, TypeAlias
|
||||
|
||||
import serial
|
||||
from deepdiff import DeepDiff
|
||||
@@ -38,8 +38,8 @@ from tqdm import tqdm
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
type NameOrID = str | int
|
||||
type Value = int | float
|
||||
NameOrID: TypeAlias = str | int
|
||||
Value: TypeAlias = int | float
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1277,4 +1277,4 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
MotorsBus = SerialMotorsBus
|
||||
MotorsBus: TypeAlias = SerialMotorsBus
|
||||
|
||||
@@ -18,9 +18,10 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Any, TypedDict, Unpack
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import torch
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
@@ -4,16 +4,17 @@
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
|
||||
from typing import Optional
|
||||
|
||||
from transformers.image_processing_utils import (
|
||||
BatchFeature,
|
||||
get_patch_output_size,
|
||||
)
|
||||
from transformers.image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
ImagesKwargs,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
@@ -76,7 +77,7 @@ def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> tor
|
||||
return img[:, top:bottom, left:right]
|
||||
|
||||
|
||||
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
|
||||
class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
max_dynamic_tiles: int | None
|
||||
min_dynamic_tiles: int | None
|
||||
use_thumbnail: bool | None
|
||||
@@ -164,11 +165,11 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
def _resize_for_patching(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
image: "torch.Tensor",
|
||||
target_resolution: tuple,
|
||||
interpolation: F.InterpolationMode,
|
||||
interpolation: "F.InterpolationMode",
|
||||
input_data_format: ChannelDimension,
|
||||
) -> torch.Tensor:
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Resizes an image to a target resolution while maintaining aspect ratio.
|
||||
|
||||
@@ -218,8 +219,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
return best_ratio
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> torch.Tensor:
|
||||
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
@@ -235,15 +236,15 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
def _get_image_patches(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
image: "torch.Tensor",
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
size: tuple,
|
||||
tile_size: int,
|
||||
use_thumbnail: bool,
|
||||
interpolation: F.InterpolationMode,
|
||||
interpolation: "F.InterpolationMode",
|
||||
pad_during_tiling: bool,
|
||||
) -> list[torch.Tensor]:
|
||||
) -> list["torch.Tensor"]:
|
||||
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
@@ -304,8 +305,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
def _pad_for_batching(
|
||||
self,
|
||||
pixel_values: list[torch.Tensor],
|
||||
) -> list[torch.Tensor]:
|
||||
pixel_values: list["torch.Tensor"],
|
||||
) -> list["torch.Tensor"]:
|
||||
"""
|
||||
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
||||
|
||||
@@ -326,14 +327,14 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list[torch.Tensor],
|
||||
images: list["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
max_dynamic_tiles: int,
|
||||
min_dynamic_tiles: int,
|
||||
use_thumbnail: bool,
|
||||
pad_during_tiling: bool,
|
||||
interpolation: F.InterpolationMode | None,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
|
||||
@@ -15,16 +15,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -32,21 +32,13 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaForCausalLM,
|
||||
_gated_residual,
|
||||
layernorm_forward,
|
||||
)
|
||||
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
layernorm_forward = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
|
||||
GemmaForCausalLM = None
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||
@@ -199,7 +191,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
resized_images = resized_images.clamp(0.0, 1.0)
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
|
||||
@@ -210,7 +202,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
pad_w1 = pad_w0 + remainder_w
|
||||
|
||||
# Pad
|
||||
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
||||
padded_images = F.pad(
|
||||
resized_images,
|
||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||
@@ -229,14 +221,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
models = [paligemma.language_model, gemma_expert.model]
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
@@ -262,10 +254,10 @@ def compute_layer_complete(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -273,7 +265,7 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
@@ -285,15 +277,15 @@ def compute_layer_complete(
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||
# first residual
|
||||
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
|
||||
after_first_residual = out_emb.clone()
|
||||
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
|
||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
# second residual
|
||||
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
|
||||
outputs_embeds.append(out_emb)
|
||||
start_pos = end_pos
|
||||
return outputs_embeds
|
||||
@@ -366,7 +358,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
@@ -374,7 +366,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
||||
|
||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||
head_dim=action_expert_config.head_dim,
|
||||
@@ -385,13 +377,13 @@ class PaliGemmaWithExpertModel(
|
||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||
vocab_size=257152,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
dtype="float32",
|
||||
torch_dtype="float32",
|
||||
use_adarms=use_adarms[1],
|
||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||
)
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
||||
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
@@ -406,11 +398,10 @@ class PaliGemmaWithExpertModel(
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||
# "same dtype" error). Align with PI05.
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
@@ -422,8 +413,8 @@ class PaliGemmaWithExpertModel(
|
||||
|
||||
def _set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
for param in self.paligemma.model.vision_tower.parameters():
|
||||
self.paligemma.vision_tower.eval()
|
||||
for param in self.paligemma.vision_tower.parameters():
|
||||
param.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
@@ -433,23 +424,15 @@ class PaliGemmaWithExpertModel(
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
self.paligemma.vision_tower.eval()
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
return self.paligemma.model.get_image_features(image)
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -463,7 +446,7 @@ class PaliGemmaWithExpertModel(
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
prefix_output = self.paligemma.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@@ -487,7 +470,7 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
models = [self.paligemma.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
@@ -527,7 +510,7 @@ class PaliGemmaWithExpertModel(
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -593,19 +576,29 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Also compile the main forward pass used during training
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||
|
||||
try:
|
||||
from transformers.models.siglip import check
|
||||
|
||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
raise ValueError(msg)
|
||||
except ImportError:
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
@@ -767,7 +760,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
||||
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -841,7 +834,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks_4d,
|
||||
@@ -915,7 +908,6 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
@@ -1005,12 +997,14 @@ class PI0Policy(PreTrainedPolicy):
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# Load state dict (expects keys with "model." prefix)
|
||||
# Now manually load and remap the state dict
|
||||
try:
|
||||
# Try to load the pytorch_model.bin or model.safetensors file
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
@@ -1018,7 +1012,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
token=kwargs.get("token"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
@@ -1031,7 +1025,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
@@ -1076,7 +1070,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
print("All keys loaded successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load state dict: {e}")
|
||||
print(f"Warning: Could not remap state dict keys: {e}")
|
||||
|
||||
return model
|
||||
|
||||
@@ -1126,14 +1120,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
# Some checkpoints might have this, but current model expects different structure
|
||||
logging.warning(f"Vision embedding key might need handling: {key}")
|
||||
|
||||
if (
|
||||
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
|
||||
or key == "paligemma_with_expert.paligemma.lm_head.weight"
|
||||
):
|
||||
fixed_state_dict[
|
||||
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
] = value.clone()
|
||||
|
||||
fixed_state_dict[new_key] = value
|
||||
|
||||
return fixed_state_dict
|
||||
|
||||
@@ -15,16 +15,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -32,20 +32,14 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaForCausalLM,
|
||||
_gated_residual,
|
||||
layernorm_forward,
|
||||
)
|
||||
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
layernorm_forward = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
GemmaForCausalLM = None
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
@@ -98,11 +92,10 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
||||
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,)).to(device)
|
||||
return dist.sample((bsize,))
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
|
||||
@@ -196,7 +189,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
resized_images = resized_images.clamp(0.0, 1.0)
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
|
||||
@@ -207,7 +200,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
pad_w1 = pad_w0 + remainder_w
|
||||
|
||||
# Pad
|
||||
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
||||
padded_images = F.pad(
|
||||
resized_images,
|
||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||
@@ -226,14 +219,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
models = [paligemma.language_model, gemma_expert.model]
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
@@ -259,10 +252,10 @@ def compute_layer_complete(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -270,7 +263,7 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
@@ -282,15 +275,15 @@ def compute_layer_complete(
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||
# first residual
|
||||
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
|
||||
after_first_residual = out_emb.clone()
|
||||
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
|
||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
# second residual
|
||||
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
|
||||
outputs_embeds.append(out_emb)
|
||||
start_pos = end_pos
|
||||
return outputs_embeds
|
||||
@@ -363,7 +356,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
@@ -371,7 +364,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
||||
|
||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||
head_dim=action_expert_config.head_dim,
|
||||
@@ -382,13 +375,13 @@ class PaliGemmaWithExpertModel(
|
||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||
vocab_size=257152,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
dtype="float32",
|
||||
torch_dtype="float32",
|
||||
use_adarms=use_adarms[1],
|
||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||
)
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
||||
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
@@ -403,11 +396,10 @@ class PaliGemmaWithExpertModel(
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
@@ -419,8 +411,8 @@ class PaliGemmaWithExpertModel(
|
||||
|
||||
def _set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
for param in self.paligemma.model.vision_tower.parameters():
|
||||
self.paligemma.vision_tower.eval()
|
||||
for param in self.paligemma.vision_tower.parameters():
|
||||
param.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
@@ -430,23 +422,15 @@ class PaliGemmaWithExpertModel(
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
self.paligemma.vision_tower.eval()
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
return self.paligemma.model.get_image_features(image)
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -460,7 +444,7 @@ class PaliGemmaWithExpertModel(
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
prefix_output = self.paligemma.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@@ -484,7 +468,7 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
models = [self.paligemma.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
@@ -524,7 +508,7 @@ class PaliGemmaWithExpertModel(
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -589,19 +573,29 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Also compile the main forward pass used during training
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||
|
||||
try:
|
||||
from transformers.models.siglip import check
|
||||
|
||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
raise ValueError(msg)
|
||||
except ImportError:
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
||||
|
||||
@@ -743,7 +737,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -814,7 +808,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks_4d,
|
||||
@@ -886,7 +880,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
@@ -976,12 +969,14 @@ class PI05Policy(PreTrainedPolicy):
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# Load state dict (expects keys with "model." prefix)
|
||||
# Now manually load and remap the state dict
|
||||
try:
|
||||
# Try to load the pytorch_model.bin or model.safetensors file
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
@@ -989,7 +984,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
token=kwargs.get("token"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
@@ -1002,7 +997,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
@@ -1014,6 +1009,8 @@ class PI05Policy(PreTrainedPolicy):
|
||||
new_key = f"model.{key}"
|
||||
remapped_state_dict[new_key] = value
|
||||
remap_count += 1
|
||||
if remap_count <= 10: # Only print first 10 to avoid spam
|
||||
print(f"Remapped: {key} -> {new_key}")
|
||||
else:
|
||||
remapped_state_dict[key] = value
|
||||
|
||||
@@ -1047,7 +1044,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
print("All keys loaded successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load state dict: {e}")
|
||||
print(f"Warning: Could not remap state dict keys: {e}")
|
||||
|
||||
return model
|
||||
|
||||
@@ -1101,14 +1098,6 @@ class PI05Policy(PreTrainedPolicy):
|
||||
# Some checkpoints might have this, but current model expects different structure
|
||||
logging.warning(f"Vision embedding key might need handling: {key}")
|
||||
|
||||
if (
|
||||
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
|
||||
or key == "paligemma_with_expert.paligemma.lm_head.weight"
|
||||
):
|
||||
fixed_state_dict[
|
||||
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
] = value.clone()
|
||||
|
||||
fixed_state_dict[new_key] = value
|
||||
|
||||
return fixed_state_dict
|
||||
|
||||
@@ -23,6 +23,7 @@ import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
@@ -67,6 +68,9 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.max_state_dim)
|
||||
|
||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
|
||||
@@ -54,7 +54,7 @@ class PI0FastConfig(PreTrainedConfig):
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||
action_tokenizer_name: str = "lerobot/fast-action-tokenizer"
|
||||
action_tokenizer_name: str = "physical-intelligence/fast"
|
||||
temperature: float = 0.0
|
||||
max_decoding_steps: int = 256
|
||||
fast_skip_tokens: int = 128
|
||||
|
||||
@@ -19,12 +19,13 @@ import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _scipy_available, _transformers_available
|
||||
|
||||
@@ -37,16 +38,11 @@ else:
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaModel,
|
||||
)
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
AutoTokenizer = None
|
||||
PiGemmaModel = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||
@@ -125,7 +121,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
resized_images = resized_images.clamp(0.0, 1.0)
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
|
||||
@@ -136,7 +132,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
pad_w1 = pad_w0 + remainder_w
|
||||
|
||||
# Pad
|
||||
constant_value = 0 if images.dtype == torch.uint8 else 0.0
|
||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
||||
padded_images = F.pad(
|
||||
resized_images,
|
||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||
@@ -210,22 +206,16 @@ class PI0FastPaliGemma(nn.Module):
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.torch_dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
vlm_config_hf.vision_config.torch_dtype = "float32"
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
|
||||
# Use PI Gemma (AdaRMS) as language model when use_adarms[0] is True so that
|
||||
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
|
||||
if use_adarms[0]:
|
||||
text_config = self.paligemma.config.text_config
|
||||
self.paligemma.model.language_model = PiGemmaModel(text_config)
|
||||
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
|
||||
@@ -238,11 +228,10 @@ class PI0FastPaliGemma(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||
# "same dtype" error). Align with PI05.
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.weight",
|
||||
"vision_tower.vision_model.embeddings.patch_embedding.bias",
|
||||
"vision_tower.vision_model.embeddings.position_embedding.weight",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
@@ -253,18 +242,10 @@ class PI0FastPaliGemma(nn.Module):
|
||||
param.data = param.data.to(dtype=torch.float32)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
return self.paligemma.model.get_image_features(image)
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -278,7 +259,7 @@ class PI0FastPaliGemma(nn.Module):
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
prefix_output = self.paligemma.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@@ -325,14 +306,24 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||
|
||||
try:
|
||||
from transformers.models.siglip import check
|
||||
|
||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
raise ValueError(msg)
|
||||
except ImportError:
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
# Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_enable(
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_enable(
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
logging.info("Enabled gradient checkpointing for PI0FastPytorch model")
|
||||
@@ -341,8 +332,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Disable gradient checkpointing."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
# Call the proper gradient_checkpointing_disable() method
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_disable()
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_disable()
|
||||
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable()
|
||||
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable()
|
||||
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
@@ -532,7 +523,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
# Convert embeddings to bfloat16 if needed
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -625,7 +616,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -723,7 +714,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
# Ensure correct precision (bfloat16/float32)
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
@@ -906,12 +897,14 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# Load state dict (expects keys with "model." prefix)
|
||||
# Now manually load and remap the state dict
|
||||
try:
|
||||
# Try to load the pytorch_model.bin or model.safetensors file
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
@@ -919,7 +912,7 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
token=kwargs.get("token"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
@@ -932,9 +925,8 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
remapped_state_dict = {}
|
||||
remap_count = 0
|
||||
@@ -944,6 +936,8 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
new_key = f"model.{key}"
|
||||
remapped_state_dict[new_key] = value
|
||||
remap_count += 1
|
||||
if remap_count <= 10: # Only print first 10 to avoid spam
|
||||
print(f"Remapped: {key} -> {new_key}")
|
||||
else:
|
||||
remapped_state_dict[key] = value
|
||||
|
||||
@@ -977,7 +971,7 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
print("All keys loaded successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load state dict: {e}")
|
||||
print(f"Warning: Could not remap state dict keys: {e}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||
from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
|
||||
from lerobot.processor import (
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -68,6 +69,9 @@ class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.max_state_dim)
|
||||
|
||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
|
||||
@@ -1,363 +0,0 @@
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.masking_utils import create_causal_mask
|
||||
from transformers.modeling_layers import GradientCheckpointingLayer
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.gemma.modeling_gemma import (
|
||||
GemmaAttention,
|
||||
GemmaConfig,
|
||||
GemmaForCausalLM,
|
||||
GemmaMLP,
|
||||
GemmaModel,
|
||||
)
|
||||
from transformers.models.paligemma.modeling_paligemma import (
|
||||
PaliGemmaForConditionalGeneration,
|
||||
PaliGemmaModel,
|
||||
)
|
||||
else:
|
||||
GemmaAttention = None
|
||||
GemmaConfig = None
|
||||
GemmaForCausalLM = None
|
||||
GemmaMLP = None
|
||||
GemmaModel = None
|
||||
PaliGemmaModel = None
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
DynamicCache = None
|
||||
GradientCheckpointingLayer = None
|
||||
BaseModelOutputWithPast = None
|
||||
create_causal_mask = None
|
||||
|
||||
|
||||
def _gated_residual(
|
||||
x: torch.Tensor | None,
|
||||
y: torch.Tensor | None,
|
||||
gate: torch.Tensor | None,
|
||||
) -> torch.Tensor | None:
|
||||
"""Gated residual: x + y when gate is None, else x + y * gate."""
|
||||
if x is None and y is None:
|
||||
return None
|
||||
if x is None or y is None:
|
||||
return x if x is not None else y
|
||||
if gate is None:
|
||||
return x + y
|
||||
return x + y * gate
|
||||
|
||||
|
||||
def layernorm_forward(
|
||||
layernorm: nn.Module,
|
||||
x: torch.Tensor,
|
||||
cond: torch.Tensor | None = None,
|
||||
):
|
||||
"""
|
||||
call layernorm and return hidden states and gate
|
||||
if cond is not None, use conditional norm
|
||||
otherwise, use normal gemma norm
|
||||
"""
|
||||
if cond is not None:
|
||||
return layernorm(x, cond=cond)
|
||||
else:
|
||||
return layernorm(x)
|
||||
|
||||
|
||||
class PiGemmaRMSNorm(nn.Module):
|
||||
"""
|
||||
Adaptive RMSNorm for PI Gemma (AdaRMS).
|
||||
When cond_dim is set, uses cond to modulate scale/shift/gate; otherwise behaves like standard GemmaRMSNorm.
|
||||
forward(x, cond=None) returns (output, gate) for use with _gated_residual.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.dim = dim
|
||||
self.cond_dim = cond_dim
|
||||
if cond_dim is not None:
|
||||
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
|
||||
nn.init.zeros_(self.dense.weight)
|
||||
else:
|
||||
self.weight = nn.Parameter(torch.zeros(dim))
|
||||
self.dense = None
|
||||
|
||||
def _norm(self, x):
|
||||
# Compute variance in float32 (like the source implementation)
|
||||
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
|
||||
# Compute normalization in float32
|
||||
normed_inputs = x * torch.rsqrt(var + self.eps)
|
||||
return normed_inputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cond: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
dtype = x.dtype
|
||||
normed = self._norm(x)
|
||||
if cond is None or self.dense is None:
|
||||
normed = normed * (1.0 + self.weight.float())
|
||||
return normed.type_as(x), None
|
||||
if cond.shape[-1] != self.cond_dim:
|
||||
raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}")
|
||||
modulation = self.dense(cond)
|
||||
if len(x.shape) == 3:
|
||||
modulation = modulation.unsqueeze(1)
|
||||
scale, shift, gate = modulation.chunk(3, dim=-1)
|
||||
normed = normed * (1 + scale.float()) + shift.float()
|
||||
return normed.to(dtype), gate.to(dtype)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
if self.dense is not None:
|
||||
return f"dim={self.dim}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}"
|
||||
return f"dim={self.dim}, eps={self.eps}"
|
||||
|
||||
|
||||
def _get_pi_gemma_decoder_layer_base():
|
||||
"""base for PiGemmaDecoderLayer"""
|
||||
|
||||
class _PiGemmaDecoderLayerBase(GradientCheckpointingLayer):
|
||||
"""Decoder layer that uses PiGemmaRMSNorm and _gated_residual, compatible with v5 Gemma."""
|
||||
|
||||
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = GemmaMLP(config)
|
||||
cond_dim = (
|
||||
getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
|
||||
)
|
||||
self.input_layernorm = PiGemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
|
||||
)
|
||||
self.post_attention_layernorm = PiGemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values=None,
|
||||
use_cache: bool = False,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states, gate = self.input_layernorm(hidden_states, cond=adarms_cond)
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states, gate = self.post_attention_layernorm(hidden_states, cond=adarms_cond)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||
return hidden_states
|
||||
|
||||
return _PiGemmaDecoderLayerBase
|
||||
|
||||
|
||||
class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
||||
"""
|
||||
GemmaModel extended with AdaRMS (adaptive RMSNorm) and gated residuals when config.use_adarms is True.
|
||||
"""
|
||||
|
||||
def __init__(self, config: GemmaConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
# if not getattr(config, "use_adarms", False):
|
||||
# return
|
||||
cond_dim = getattr(config, "adarms_cond_dim", None)
|
||||
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
|
||||
self.layers = nn.ModuleList(
|
||||
[pi_gemma_decoder_layer_base(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = PiGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: DynamicCache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> BaseModelOutputWithPast:
|
||||
"""
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
import logging
|
||||
|
||||
logging.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
# Convert to bfloat16 if the first layer uses bfloat16
|
||||
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.bfloat16)
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# normalized
|
||||
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
adarms_cond=adarms_cond,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, adarms_cond)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
|
||||
"""
|
||||
Causal LM wrapper using PiGemmaModel as the backbone, for consistency with GemmaForCausalLM
|
||||
and the language model used in pi0_fast. Use this for the action expert in pi0/pi05.
|
||||
"""
|
||||
|
||||
def __init__(self, config: GemmaConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.model = PiGemmaModel(config)
|
||||
|
||||
|
||||
class PaliGemmaModelWithPiGemma(PaliGemmaModel):
|
||||
"""PaliGemmaModel whose language_model is PiGemmaModel (custom decoder with PiGemmaRMSNorm and gated residuals)."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.language_model = PiGemmaModel(config.text_config)
|
||||
|
||||
|
||||
class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGeneration):
|
||||
"""PaliGemmaForConditionalGeneration using PiGemma decoder for the language model."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = PaliGemmaModelWithPiGemma(config)
|
||||
|
||||
# Make modules available through conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
return self.model.language_model
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PiGemmaModel",
|
||||
"PiGemmaForCausalLM",
|
||||
"PiGemmaRMSNorm",
|
||||
"_gated_residual",
|
||||
"layernorm_forward",
|
||||
"PaliGemmaModelWithPiGemma",
|
||||
"PaliGemmaForConditionalGenerationWithPiGemma",
|
||||
]
|
||||
@@ -19,7 +19,7 @@ import os
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TypedDict, TypeVar, Unpack
|
||||
from typing import TypedDict, TypeVar
|
||||
|
||||
import packaging
|
||||
import safetensors
|
||||
@@ -28,6 +28,7 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
@@ -33,7 +33,7 @@ class RewardClassifierConfig(PreTrainedConfig):
|
||||
latent_dim: int = 256
|
||||
image_embedding_pooling_dim: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
|
||||
model_name: str = "helper2424/resnet10"
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
|
||||
@@ -54,11 +54,12 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import TypedDict, Unpack
|
||||
from typing import TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
|
||||
@@ -55,7 +55,7 @@ class WallXConfig(PreTrainedConfig):
|
||||
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
|
||||
|
||||
# Tokenizer settings
|
||||
action_tokenizer_path: str | None = "lerobot/fast-action-tokenizer"
|
||||
action_tokenizer_path: str | None = "physical-intelligence/fast"
|
||||
|
||||
# Action prediction mode: "diffusion" or "fast"
|
||||
prediction_mode: str = "diffusion"
|
||||
|
||||
@@ -261,15 +261,10 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
and optional LoRA fine-tuning support.
|
||||
"""
|
||||
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
config_class = Qwen2_5_VLConfig
|
||||
_no_split_modules = ["Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLVisionBlock"]
|
||||
|
||||
def init_weights(self):
|
||||
if getattr(self.model, "language_model", None) is not None:
|
||||
return
|
||||
super().init_weights()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -317,11 +312,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
processor.action_processor = action_tokenizer
|
||||
else:
|
||||
action_tokenizer = None
|
||||
|
||||
# add pad_token_id to config
|
||||
config.pad_token_id = processor.tokenizer.pad_token_id
|
||||
config.text_config.pad_token_id = processor.tokenizer.pad_token_id
|
||||
|
||||
# Initialize model with configuration and processor
|
||||
model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs)
|
||||
|
||||
@@ -341,7 +331,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
token=kwargs.get("token"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
|
||||
@@ -21,7 +21,6 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
|
||||
window_size=112,
|
||||
out_hidden_size=3584,
|
||||
fullatt_block_indexes=[7, 15, 23, 31],
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@@ -39,7 +38,6 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
|
||||
self.window_size = window_size
|
||||
self.fullatt_block_indexes = fullatt_block_indexes
|
||||
self.out_hidden_size = out_hidden_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
|
||||
class Qwen2_5_VLConfig(PretrainedConfig):
|
||||
|
||||
@@ -11,6 +11,7 @@ from transformers.activations import ACT2FN
|
||||
from transformers.cache_utils import (
|
||||
Cache,
|
||||
DynamicCache,
|
||||
SlidingWindowCache,
|
||||
StaticCache,
|
||||
)
|
||||
from transformers.generation import GenerationMixin
|
||||
@@ -30,15 +31,6 @@ from transformers.utils import (
|
||||
|
||||
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
|
||||
|
||||
|
||||
# TODO(Steven): SlidingWindowCache was removed in transformers v5. Define a placeholder so isinstance checks
|
||||
# always return False (which is the correct behavior when no sliding window cache is in use).
|
||||
class _SlidingWindowCachePlaceholder:
|
||||
pass
|
||||
|
||||
|
||||
SlidingWindowCache = _SlidingWindowCachePlaceholder
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.layers.rotary import apply_rotary_emb
|
||||
@@ -602,40 +594,19 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _compute_default_rope_parameters_qwen2_5_vl(config, device=None):
|
||||
"""
|
||||
compute default rope parameters for Qwen2_5_VL
|
||||
"""
|
||||
base = config.text_config.rope_parameters["rope_theta"]
|
||||
dim = config.hidden_size // config.num_attention_heads
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
||||
)
|
||||
return inv_freq, 1.0
|
||||
|
||||
|
||||
class Qwen2_5_VLRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: Qwen2_5_VLConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
elif hasattr(config, "rope_parameters") and config.rope_parameters is not None:
|
||||
self.rope_type = config.rope_parameters.get("rope_type", "default")
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
|
||||
if self.rope_type == "default":
|
||||
self.rope_init_fn = _compute_default_rope_parameters_qwen2_5_vl
|
||||
self.rope_kwargs = {}
|
||||
else:
|
||||
rope_type_key = "linear" if self.rope_type == "linear" else self.rope_type
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type_key]
|
||||
self.rope_kwargs = {}
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
@@ -1596,7 +1567,7 @@ QWEN2_5_VL_INPUTS_DOCSTRING = r"""
|
||||
|
||||
|
||||
class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
config_class = Qwen2_5_VLConfig
|
||||
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ def preprocesser_call(
|
||||
"""
|
||||
# Process image inputs
|
||||
if images is not None and len(images) > 0:
|
||||
image_inputs = processor.image_processor(images=images, return_tensors=return_tensors)
|
||||
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
@@ -152,7 +152,7 @@ def preprocesser_call(
|
||||
|
||||
# Process video inputs
|
||||
if videos is not None:
|
||||
videos_inputs = processor.image_processor(videos=videos, return_tensors=return_tensors)
|
||||
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
|
||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||
else:
|
||||
videos_inputs = {}
|
||||
|
||||
@@ -276,8 +276,6 @@ class Florence2LanguageConfig(PretrainedConfig):
|
||||
)
|
||||
|
||||
# ensure backward compatibility for BART CNN models
|
||||
if not hasattr(self, "forced_bos_token_id"):
|
||||
self.forced_bos_token_id = None
|
||||
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
||||
self.forced_bos_token_id = self.bos_token_id
|
||||
warnings.warn(
|
||||
|
||||
@@ -1951,10 +1951,7 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
||||
|
||||
|
||||
class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
||||
_tied_weights_keys = {
|
||||
"encoder.embed_tokens.weight": "shared.weight",
|
||||
"decoder.embed_tokens.weight": "shared.weight",
|
||||
}
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: Florence2LanguageConfig):
|
||||
super().__init__(config)
|
||||
@@ -2079,10 +2076,7 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
||||
|
||||
class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
|
||||
base_model_prefix = "model"
|
||||
_tied_weights_keys = {
|
||||
"model.encoder.embed_tokens.weight": "model.shared.weight",
|
||||
"model.decoder.embed_tokens.weight": "model.shared.weight",
|
||||
}
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
||||
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
||||
|
||||
def __init__(self, config: Florence2LanguageConfig):
|
||||
@@ -2442,10 +2436,11 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
|
||||
FLORENCE2_START_DOCSTRING,
|
||||
)
|
||||
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
||||
_tied_weights_keys = {
|
||||
"language_model.model.encoder.embed_tokens.weight": "language_model.model.shared.weight",
|
||||
"language_model.model.decoder.embed_tokens.weight": "language_model.model.shared.weight",
|
||||
}
|
||||
_tied_weights_keys = [
|
||||
"language_model.encoder.embed_tokens.weight",
|
||||
"language_model.decoder.embed_tokens.weight",
|
||||
"language_model.lm_head.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: Florence2Config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypeAlias, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -36,10 +36,10 @@ class TransitionKey(str, Enum):
|
||||
COMPLEMENTARY_DATA = "complementary_data"
|
||||
|
||||
|
||||
PolicyAction = torch.Tensor
|
||||
RobotAction = dict[str, Any]
|
||||
EnvAction = np.ndarray
|
||||
RobotObservation = dict[str, Any]
|
||||
PolicyAction: TypeAlias = torch.Tensor
|
||||
RobotAction: TypeAlias = dict[str, Any]
|
||||
EnvAction: TypeAlias = np.ndarray
|
||||
RobotObservation: TypeAlias = dict[str, Any]
|
||||
|
||||
|
||||
EnvTransition = TypedDict(
|
||||
|
||||
@@ -153,44 +153,6 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="robocasa_processor")
|
||||
class RoboCasaProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes RoboCasa observations into LeRobot format.
|
||||
|
||||
The RoboCasaEnv wrapper returns:
|
||||
- ``pixels.<cam_name>``: (B, C, H, W) float32 images (already converted by vectorenv)
|
||||
- ``observation.robot_state``: (B, 16) float32 proprioception
|
||||
|
||||
This step remaps them to:
|
||||
- ``observation.images.<cam_name>`` (unchanged tensor)
|
||||
- ``observation.state`` (robot_state renamed)
|
||||
"""
|
||||
|
||||
def _process_observation(self, observation: dict) -> dict:
|
||||
processed = {}
|
||||
obs_prefix = OBS_PREFIX # "observation."
|
||||
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_IMAGES}."):
|
||||
# Already in the right place; pass through
|
||||
processed[key] = value
|
||||
elif key == OBS_STATE or key == f"{obs_prefix}robot_state":
|
||||
# Rename robot_state → observation.state
|
||||
processed[OBS_STATE] = value.float() if hasattr(value, "float") else value
|
||||
|
||||
return processed
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
return self._process_observation(observation)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
|
||||
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
|
||||
|
||||
@@ -39,7 +39,7 @@ from collections.abc import Callable, Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, TypedDict, TypeVar, cast
|
||||
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -251,7 +251,7 @@ class ProcessorMigrationError(Exception):
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
|
||||
|
||||
This class chains together multiple `ProcessorStep` instances to form a complete
|
||||
@@ -1432,8 +1432,8 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
|
||||
# Type aliases for semantic clarity.
|
||||
RobotProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
|
||||
PolicyProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
|
||||
RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
||||
PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
||||
|
||||
|
||||
class ObservationProcessorStep(ProcessorStep, ABC):
|
||||
|
||||
@@ -336,7 +336,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
Requires the `transformers` library to be installed.
|
||||
|
||||
Attributes:
|
||||
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "lerobot/fast-action-tokenizer").
|
||||
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
|
||||
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
|
||||
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypeAlias
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
@@ -49,5 +50,5 @@ class SOFollowerRobotConfig(RobotConfig, SOFollowerConfig):
|
||||
pass
|
||||
|
||||
|
||||
SO100FollowerConfig = SOFollowerRobotConfig
|
||||
SO101FollowerConfig = SOFollowerRobotConfig
|
||||
SO100FollowerConfig: TypeAlias = SOFollowerRobotConfig
|
||||
SO101FollowerConfig: TypeAlias = SOFollowerRobotConfig
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import TypeAlias
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
@@ -229,5 +230,5 @@ class SOFollower(Robot):
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
|
||||
SO100Follower = SOFollower
|
||||
SO101Follower = SOFollower
|
||||
SO100Follower: TypeAlias = SOFollower
|
||||
SO101Follower: TypeAlias = SOFollower
|
||||
|
||||
@@ -16,5 +16,3 @@
|
||||
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
from .unitree_g1 import UnitreeG1
|
||||
|
||||
__all__ = ["UnitreeG1", "UnitreeG1Config"]
|
||||
|
||||
@@ -27,10 +27,11 @@ _GAINS: dict[str, dict[str, list[float]]] = {
|
||||
}, # pitch, roll, yaw, knee, ankle_pitch, ankle_roll
|
||||
"right_leg": {"kp": [150, 150, 150, 300, 40, 40], "kd": [2, 2, 2, 4, 2, 2]},
|
||||
"waist": {"kp": [250, 250, 250], "kd": [5, 5, 5]}, # yaw, roll, pitch
|
||||
"left_arm": {"kp": [50, 50, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow
|
||||
"left_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow
|
||||
"left_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]}, # roll, pitch, yaw
|
||||
"right_arm": {"kp": [50, 50, 80, 80], "kd": [3, 3, 3, 3]},
|
||||
"right_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]},
|
||||
"right_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]},
|
||||
"other": {"kp": [80, 80, 80, 80, 80, 80], "kd": [3, 3, 3, 3, 3, 3]},
|
||||
}
|
||||
|
||||
|
||||
@@ -67,7 +68,3 @@ class UnitreeG1Config(RobotConfig):
|
||||
|
||||
# Compensates for gravity on the unitree's arms using the arm ik solver
|
||||
gravity_compensation: bool = False
|
||||
|
||||
# Lower-body controller class name, e.g. "GrootLocomotionController" or
|
||||
# "HolosomaLocomotionController". None disables it.
|
||||
controller: str | None = None
|
||||
|
||||
@@ -14,34 +14,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
from enum import IntEnum
|
||||
|
||||
import numpy as np
|
||||
|
||||
# ruff: noqa: N801, N815
|
||||
|
||||
NUM_MOTORS = 29
|
||||
|
||||
REMOTE_AXES = ("remote.lx", "remote.ly", "remote.rx", "remote.ry")
|
||||
REMOTE_BUTTONS = tuple(f"remote.button.{i}" for i in range(16))
|
||||
REMOTE_KEYS = REMOTE_AXES + REMOTE_BUTTONS
|
||||
|
||||
|
||||
def default_remote_input() -> dict[str, float]:
|
||||
"""Return a zeroed-out remote input dict (axes + buttons)."""
|
||||
return dict.fromkeys(REMOTE_KEYS, 0.0)
|
||||
|
||||
|
||||
def get_gravity_orientation(quaternion: list[float] | np.ndarray) -> np.ndarray:
|
||||
"""Get gravity orientation from quaternion [w, x, y, z]."""
|
||||
qw, qx, qy, qz = quaternion
|
||||
gravity_orientation = np.zeros(3, dtype=np.float32)
|
||||
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
|
||||
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
|
||||
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
||||
return gravity_orientation
|
||||
|
||||
|
||||
class G1_29_JointArmIndex(IntEnum):
|
||||
# Left arm
|
||||
@@ -51,7 +29,7 @@ class G1_29_JointArmIndex(IntEnum):
|
||||
kLeftElbow = 18
|
||||
kLeftWristRoll = 19
|
||||
kLeftWristPitch = 20
|
||||
kLeftWristYaw = 21
|
||||
kLeftWristyaw = 21
|
||||
|
||||
# Right arm
|
||||
kRightShoulderPitch = 22
|
||||
@@ -63,21 +41,6 @@ class G1_29_JointArmIndex(IntEnum):
|
||||
kRightWristYaw = 28
|
||||
|
||||
|
||||
def make_locomotion_controller(name: str | None):
|
||||
"""Instantiate a locomotion controller by class name. Returns None if name is None."""
|
||||
if name is None:
|
||||
return None
|
||||
controllers = {
|
||||
"GrootLocomotionController": "lerobot.robots.unitree_g1.gr00t_locomotion",
|
||||
"HolosomaLocomotionController": "lerobot.robots.unitree_g1.holosoma_locomotion",
|
||||
}
|
||||
module_path = controllers.get(name)
|
||||
if module_path is None:
|
||||
raise ValueError(f"Unknown controller: {name!r}. Available: {list(controllers)}")
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, name)()
|
||||
|
||||
|
||||
class G1_29_JointIndex(IntEnum):
|
||||
# Left leg
|
||||
kLeftHipPitch = 0
|
||||
@@ -106,7 +69,7 @@ class G1_29_JointIndex(IntEnum):
|
||||
kLeftElbow = 18
|
||||
kLeftWristRoll = 19
|
||||
kLeftWristPitch = 20
|
||||
kLeftWristYaw = 21
|
||||
kLeftWristyaw = 21
|
||||
|
||||
# Right arm
|
||||
kRightShoulderPitch = 22
|
||||
|
||||
@@ -16,11 +16,13 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
parent2_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(parent2_dir)
|
||||
|
||||
|
||||
class WeightedMovingFilter:
|
||||
@@ -29,14 +31,18 @@ class WeightedMovingFilter:
|
||||
self._weights = np.array(weights)
|
||||
self._data_size = data_size
|
||||
self._filtered_data = np.zeros(self._data_size)
|
||||
self._data_queue = deque(maxlen=self._window_size)
|
||||
self._data_queue = []
|
||||
|
||||
def _apply_filter(self):
|
||||
if len(self._data_queue) < self._window_size:
|
||||
return self._data_queue[-1]
|
||||
|
||||
data_array = np.array(self._data_queue)
|
||||
return data_array.T @ self._weights
|
||||
temp_filtered_data = np.zeros(self._data_size)
|
||||
for i in range(self._data_size):
|
||||
temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1]
|
||||
|
||||
return temp_filtered_data
|
||||
|
||||
def add_data(self, new_data):
|
||||
assert len(new_data) == self._data_size
|
||||
@@ -46,6 +52,9 @@ class WeightedMovingFilter:
|
||||
): # skip duplicate data
|
||||
return
|
||||
|
||||
if len(self._data_queue) >= self._window_size:
|
||||
self._data_queue.pop(0)
|
||||
|
||||
self._data_queue.append(new_data)
|
||||
self._filtered_data = self._apply_filter()
|
||||
|
||||
@@ -62,6 +71,8 @@ class G1_29_ArmIK: # noqa: N801
|
||||
from pinocchio import casadi as cpin
|
||||
|
||||
self._pin = pin
|
||||
np.set_printoptions(precision=5, suppress=True, linewidth=200)
|
||||
|
||||
self.unit_test = unit_test
|
||||
|
||||
self.repo_path = snapshot_download("lerobot/unitree-g1-mujoco")
|
||||
@@ -238,35 +249,50 @@ class G1_29_ArmIK: # noqa: N801
|
||||
self.opti.set_value(self.param_tf_r, right_wrist)
|
||||
self.opti.set_value(self.var_q_last, self.init_data) # for smooth
|
||||
|
||||
converged = True
|
||||
try:
|
||||
self.opti.solve()
|
||||
|
||||
sol_q = self.opti.value(self.var_q)
|
||||
self.smooth_filter.add_data(sol_q)
|
||||
sol_q = self.smooth_filter.filtered_data
|
||||
|
||||
if current_lr_arm_motor_dq is not None:
|
||||
v = current_lr_arm_motor_dq * 0.0
|
||||
else:
|
||||
v = (sol_q - self.init_data) * 0.0
|
||||
|
||||
self.init_data = sol_q
|
||||
|
||||
sol_tauff = self._pin.rnea(
|
||||
self.reduced_robot.model,
|
||||
self.reduced_robot.data,
|
||||
sol_q,
|
||||
v,
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
)
|
||||
|
||||
return sol_q, sol_tauff
|
||||
|
||||
except Exception as e:
|
||||
converged = False
|
||||
logger.error(f"IK convergence error: {e}")
|
||||
logger.error(f"ERROR in convergence, plotting debug info.{e}")
|
||||
|
||||
sol_q = self.opti.debug.value(self.var_q)
|
||||
self.smooth_filter.add_data(sol_q)
|
||||
sol_q = self.smooth_filter.filtered_data
|
||||
|
||||
self.smooth_filter.add_data(sol_q)
|
||||
sol_q = self.smooth_filter.filtered_data
|
||||
self.init_data = sol_q
|
||||
if current_lr_arm_motor_dq is not None:
|
||||
v = current_lr_arm_motor_dq * 0.0
|
||||
else:
|
||||
v = (sol_q - self.init_data) * 0.0
|
||||
|
||||
self.init_data = sol_q
|
||||
|
||||
if not converged:
|
||||
logger.error(
|
||||
f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}"
|
||||
)
|
||||
|
||||
return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv)
|
||||
|
||||
sol_tauff = self._pin.rnea(
|
||||
self.reduced_robot.model,
|
||||
self.reduced_robot.data,
|
||||
sol_q,
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
)
|
||||
|
||||
return sol_q, sol_tauff
|
||||
|
||||
def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None):
|
||||
try:
|
||||
q_g1 = np.array(current_lr_arm_motor_q, dtype=float)
|
||||
@@ -24,7 +24,6 @@ This server runs on the robot and forwards:
|
||||
Uses JSON for secure serialization instead of pickle.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
@@ -39,8 +38,6 @@ from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
|
||||
from lerobot.cameras.zmq.image_server import ImageServer
|
||||
|
||||
# DDS topic names follow Unitree SDK naming conventions
|
||||
# ruff: noqa: N816
|
||||
kTopicLowCommand_Debug = "rt/lowcmd" # action to robot
|
||||
@@ -153,32 +150,6 @@ def cmd_forward_loop(
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point for the robot server bridge."""
|
||||
parser = argparse.ArgumentParser(description="DDS-to-ZMQ bridge server for Unitree G1")
|
||||
parser.add_argument("--camera", action="store_true", help="Also launch camera server")
|
||||
parser.add_argument("--camera-device", type=int, default=4, help="Camera device ID (default: 4)")
|
||||
parser.add_argument("--camera-fps", type=int, default=30, help="Camera FPS (default: 30)")
|
||||
parser.add_argument("--camera-width", type=int, default=640, help="Camera width (default: 640)")
|
||||
parser.add_argument("--camera-height", type=int, default=480, help="Camera height (default: 480)")
|
||||
parser.add_argument("--camera-port", type=int, default=5555, help="Camera ZMQ port (default: 5555)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Optionally start camera server in background thread
|
||||
camera_thread = None
|
||||
if args.camera:
|
||||
camera_config = {
|
||||
"fps": args.camera_fps,
|
||||
"cameras": {
|
||||
"head_camera": {
|
||||
"device_id": args.camera_device,
|
||||
"shape": [args.camera_height, args.camera_width],
|
||||
}
|
||||
},
|
||||
}
|
||||
camera_server = ImageServer(camera_config, port=args.camera_port)
|
||||
camera_thread = threading.Thread(target=camera_server.run, daemon=True)
|
||||
camera_thread.start()
|
||||
print(f"Camera server started on port {args.camera_port} (device {args.camera_device})")
|
||||
|
||||
# initialize DDS
|
||||
ChannelFactoryInitialize(0)
|
||||
|
||||
@@ -235,8 +206,6 @@ def main() -> None:
|
||||
shutdown_event.set()
|
||||
ctx.term() # terminates blocking zmq.recv() calls
|
||||
t_state.join(timeout=2.0)
|
||||
if camera_thread is not None:
|
||||
camera_thread.join(timeout=2.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -14,67 +14,27 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.unitree_g1.g1_kinematics import G1_29_ArmIK
|
||||
from lerobot.robots.unitree_g1.g1_utils import (
|
||||
REMOTE_AXES,
|
||||
REMOTE_KEYS,
|
||||
G1_29_JointArmIndex,
|
||||
G1_29_JointIndex,
|
||||
default_remote_input,
|
||||
make_locomotion_controller,
|
||||
)
|
||||
from lerobot.utils.import_utils import _unitree_sdk_available
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex, G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
|
||||
if TYPE_CHECKING or _unitree_sdk_available:
|
||||
from unitree_sdk2py.core.channel import (
|
||||
ChannelFactoryInitialize as _SDKChannelFactoryInitialize,
|
||||
ChannelPublisher as _SDKChannelPublisher,
|
||||
ChannelSubscriber as _SDKChannelSubscriber,
|
||||
)
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
|
||||
LowCmd_ as hg_LowCmd,
|
||||
LowState_ as hg_LowState,
|
||||
)
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
else:
|
||||
_SDKChannelFactoryInitialize = None
|
||||
_SDKChannelPublisher = None
|
||||
_SDKChannelSubscriber = None
|
||||
unitree_hg_msg_dds__LowCmd_ = None
|
||||
hg_LowCmd = None
|
||||
hg_LowState = None
|
||||
CRC = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class LocomotionController(Protocol):
|
||||
control_dt: float
|
||||
|
||||
def run_step(self, action: dict, lowstate) -> dict: ...
|
||||
|
||||
def reset(self) -> None: ...
|
||||
|
||||
|
||||
# DDS topic names follow Unitree SDK naming conventions
|
||||
# ruff: noqa: N816
|
||||
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||
@@ -103,7 +63,7 @@ class IMUState:
|
||||
class G1_29_LowState: # noqa: N801
|
||||
motor_state: list[MotorState] = field(default_factory=lambda: [MotorState() for _ in G1_29_JointIndex])
|
||||
imu_state: IMUState = field(default_factory=IMUState)
|
||||
wireless_remote: bytes | None = None # Raw wireless remote data
|
||||
wireless_remote: Any = None # Raw wireless remote data
|
||||
mode_machine: int = 0 # Robot mode
|
||||
|
||||
|
||||
@@ -111,6 +71,25 @@ class UnitreeG1(Robot):
|
||||
config_class = UnitreeG1Config
|
||||
name = "unitree_g1"
|
||||
|
||||
# unitree remote controller
|
||||
class RemoteController:
|
||||
def __init__(self):
|
||||
self.lx = 0
|
||||
self.ly = 0
|
||||
self.rx = 0
|
||||
self.ry = 0
|
||||
self.button = [0] * 16
|
||||
|
||||
def set(self, data):
|
||||
# wireless_remote
|
||||
keys = struct.unpack("H", data[2:4])[0]
|
||||
for i in range(16):
|
||||
self.button[i] = (keys & (1 << i)) >> i
|
||||
self.lx = struct.unpack("f", data[4:8])[0]
|
||||
self.rx = struct.unpack("f", data[8:12])[0]
|
||||
self.ry = struct.unpack("f", data[12:16])[0]
|
||||
self.ly = struct.unpack("f", data[20:24])[0]
|
||||
|
||||
def __init__(self, config: UnitreeG1Config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -124,9 +103,11 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Import channel classes based on mode
|
||||
if config.is_simulation:
|
||||
self._ChannelFactoryInitialize = _SDKChannelFactoryInitialize
|
||||
self._ChannelPublisher = _SDKChannelPublisher
|
||||
self._ChannelSubscriber = _SDKChannelSubscriber
|
||||
from unitree_sdk2py.core.channel import (
|
||||
ChannelFactoryInitialize,
|
||||
ChannelPublisher,
|
||||
ChannelSubscriber,
|
||||
)
|
||||
else:
|
||||
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
|
||||
ChannelFactoryInitialize,
|
||||
@@ -134,30 +115,22 @@ class UnitreeG1(Robot):
|
||||
ChannelSubscriber,
|
||||
)
|
||||
|
||||
self._ChannelFactoryInitialize = ChannelFactoryInitialize
|
||||
self._ChannelPublisher = ChannelPublisher
|
||||
self._ChannelSubscriber = ChannelSubscriber
|
||||
# Store for use in connect()
|
||||
self._ChannelFactoryInitialize = ChannelFactoryInitialize
|
||||
self._ChannelPublisher = ChannelPublisher
|
||||
self._ChannelSubscriber = ChannelSubscriber
|
||||
|
||||
# Initialize state variables
|
||||
self.sim_env = None
|
||||
self._env_wrapper = None
|
||||
self._lowstate = None
|
||||
self._lowstate_lock = threading.Lock()
|
||||
self._shutdown_event = threading.Event()
|
||||
self.subscribe_thread = None
|
||||
self.remote_controller = self.RemoteController()
|
||||
|
||||
self.arm_ik = G1_29_ArmIK() if config.gravity_compensation else None
|
||||
self.arm_ik = G1_29_ArmIK()
|
||||
|
||||
# Lower-body controller loaded dynamically
|
||||
self.controller: LocomotionController | None = make_locomotion_controller(config.controller)
|
||||
|
||||
# Controller thread state
|
||||
self._controller_thread = None
|
||||
self._controller_action_lock = threading.Lock()
|
||||
self.controller_input = default_remote_input()
|
||||
self.controller_output = {}
|
||||
|
||||
def _subscribe_lowstate(self): # polls robot state @ 250Hz
|
||||
def _subscribe_motor_state(self): # polls robot state @ 250Hz
|
||||
while not self._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
|
||||
@@ -170,11 +143,11 @@ class UnitreeG1(Robot):
|
||||
lowstate = G1_29_LowState()
|
||||
|
||||
# Capture motor states using jointindex
|
||||
for joint in G1_29_JointIndex:
|
||||
lowstate.motor_state[joint].q = msg.motor_state[joint].q
|
||||
lowstate.motor_state[joint].dq = msg.motor_state[joint].dq
|
||||
lowstate.motor_state[joint].tau_est = msg.motor_state[joint].tau_est
|
||||
lowstate.motor_state[joint].temperature = msg.motor_state[joint].temperature
|
||||
for id in G1_29_JointIndex:
|
||||
lowstate.motor_state[id].q = msg.motor_state[id].q
|
||||
lowstate.motor_state[id].dq = msg.motor_state[id].dq
|
||||
lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est
|
||||
lowstate.motor_state[id].temperature = msg.motor_state[id].temperature
|
||||
|
||||
# Capture IMU state
|
||||
lowstate.imu_state.quaternion = list(msg.imu_state.quaternion)
|
||||
@@ -189,106 +162,31 @@ class UnitreeG1(Robot):
|
||||
# Capture mode_machine
|
||||
lowstate.mode_machine = msg.mode_machine
|
||||
|
||||
with self._lowstate_lock:
|
||||
self._lowstate = lowstate
|
||||
self._lowstate = lowstate
|
||||
|
||||
current_time = time.time()
|
||||
all_t_elapsed = current_time - start_time
|
||||
sleep_time = max(0, (self.control_dt - all_t_elapsed)) # maintain constant control dt
|
||||
time.sleep(sleep_time)
|
||||
|
||||
def publish_lowcmd(
|
||||
self,
|
||||
action: RobotAction,
|
||||
kp: np.ndarray | list[float] | None = None,
|
||||
kd: np.ndarray | list[float] | None = None,
|
||||
tau: np.ndarray | list[float] | None = None,
|
||||
) -> None: # writes robot command whenever requested
|
||||
for motor in G1_29_JointIndex:
|
||||
key = f"{motor.name}.q"
|
||||
if key in action:
|
||||
self.msg.motor_cmd[motor.value].q = action[key]
|
||||
self.msg.motor_cmd[motor.value].qd = 0
|
||||
self.msg.motor_cmd[motor.value].kp = (
|
||||
kp[motor.value] if kp is not None else self.kp[motor.value]
|
||||
)
|
||||
self.msg.motor_cmd[motor.value].kd = (
|
||||
kd[motor.value] if kd is not None else self.kd[motor.value]
|
||||
)
|
||||
self.msg.motor_cmd[motor.value].tau = tau[motor.value] if tau is not None else 0.0
|
||||
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
if self.controller is None:
|
||||
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
|
||||
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
|
||||
|
||||
arm_features = {f"{G1_29_JointArmIndex(motor).name}.q": float for motor in G1_29_JointArmIndex}
|
||||
remote_features = dict.fromkeys(REMOTE_AXES, float)
|
||||
return {**arm_features, **remote_features}
|
||||
|
||||
def _controller_loop(self):
|
||||
"""Background thread that runs controller at policy's control_dt."""
|
||||
control_dt = self.controller.control_dt
|
||||
logger.info(f"Controller loop starting with control_dt={control_dt} ({1.0 / control_dt:.1f}Hz)")
|
||||
|
||||
loop_count = 0
|
||||
last_log_time = time.time()
|
||||
|
||||
while not self._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
|
||||
with self._lowstate_lock:
|
||||
lowstate = self._lowstate
|
||||
|
||||
if lowstate is not None and self.controller is not None:
|
||||
loop_count += 1
|
||||
if time.time() - last_log_time >= 5.0: # Log every 5 seconds
|
||||
actual_hz = loop_count / (time.time() - last_log_time)
|
||||
logger.info(
|
||||
f"Controller actual rate: {actual_hz:.1f}Hz (target: {1.0 / control_dt:.1f}Hz)"
|
||||
)
|
||||
loop_count = 0
|
||||
last_log_time = time.time()
|
||||
# Read controller input snapshot
|
||||
with self._controller_action_lock:
|
||||
controller_input = dict(self.controller_input)
|
||||
|
||||
# Run controller step
|
||||
controller_action = self.controller.run_step(controller_input, lowstate)
|
||||
|
||||
# Write controller output snapshot
|
||||
with self._controller_action_lock:
|
||||
self.controller_output = dict(controller_action)
|
||||
|
||||
ctrl_kp = self.controller.kp if hasattr(self.controller, "kp") else None
|
||||
ctrl_kd = self.controller.kd if hasattr(self.controller, "kd") else None
|
||||
self.publish_lowcmd(controller_action, kp=ctrl_kp, kd=ctrl_kd)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, control_dt - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
def calibrate(self) -> None:
|
||||
# TODO: implement g1_29 calibration
|
||||
def calibrate(self) -> None: # robot is already calibrated
|
||||
pass
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None: # connect to DDS
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
|
||||
LowCmd_ as hg_LowCmd,
|
||||
LowState_ as hg_LowState,
|
||||
)
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
|
||||
# Initialize DDS channel and simulation environment
|
||||
if self.config.is_simulation:
|
||||
self._ChannelFactoryInitialize(0, "lo")
|
||||
@@ -296,7 +194,7 @@ class UnitreeG1(Robot):
|
||||
# Extract the actual gym env from the dict structure
|
||||
self.sim_env = self._env_wrapper["hub_env"][0].envs[0]
|
||||
else:
|
||||
self._ChannelFactoryInitialize(0, config=self.config)
|
||||
self._ChannelFactoryInitialize(0)
|
||||
|
||||
# Initialize direct motor control interface
|
||||
self.lowcmd_publisher = self._ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||
@@ -305,7 +203,7 @@ class UnitreeG1(Robot):
|
||||
self.lowstate_subscriber.Init()
|
||||
|
||||
# Start subscribe thread to read robot state
|
||||
self.subscribe_thread = threading.Thread(target=self._subscribe_lowstate)
|
||||
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
|
||||
self.subscribe_thread.start()
|
||||
|
||||
# Connect cameras
|
||||
@@ -322,53 +220,25 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Wait for first state message to arrive
|
||||
lowstate = None
|
||||
deadline = time.time() + 10.0
|
||||
while lowstate is None:
|
||||
with self._lowstate_lock:
|
||||
lowstate = self._lowstate
|
||||
lowstate = self._lowstate
|
||||
if lowstate is None:
|
||||
if time.time() > deadline:
|
||||
raise TimeoutError("Timed out waiting for robot state (10s)")
|
||||
logger.warning("[UnitreeG1] Waiting for robot state...")
|
||||
time.sleep(0.01)
|
||||
logger.info("[UnitreeG1] Connected to robot.")
|
||||
logger.warning("[UnitreeG1] Waiting for robot state...")
|
||||
logger.warning("[UnitreeG1] Connected to robot.")
|
||||
self.msg.mode_machine = lowstate.mode_machine
|
||||
|
||||
# Initialize all motors with unified kp/kd from config
|
||||
self.kp = np.array(self.config.kp, dtype=np.float32)
|
||||
self.kd = np.array(self.config.kd, dtype=np.float32)
|
||||
|
||||
for joint in G1_29_JointIndex:
|
||||
self.msg.motor_cmd[joint].mode = 1
|
||||
self.msg.motor_cmd[joint].kp = self.kp[joint.value]
|
||||
self.msg.motor_cmd[joint].kd = self.kd[joint.value]
|
||||
self.msg.motor_cmd[joint].q = lowstate.motor_state[joint.value].q
|
||||
|
||||
# Start controller thread if enabled
|
||||
if self.controller is not None:
|
||||
self._controller_thread = threading.Thread(target=self._controller_loop, daemon=True)
|
||||
self._controller_thread.start()
|
||||
fps = int(1.0 / self.controller.control_dt)
|
||||
logger.info(f"Controller thread started ({fps}Hz)")
|
||||
|
||||
def _send_zero_torque(self) -> None:
|
||||
"""Send a zero-gain command to make joints passive before shutting down."""
|
||||
try:
|
||||
with self._lowstate_lock:
|
||||
lowstate = self._lowstate
|
||||
if lowstate is None:
|
||||
return
|
||||
action = {f"{motor.name}.q": lowstate.motor_state[motor.value].q for motor in G1_29_JointIndex}
|
||||
zero_gains = np.zeros(29, dtype=np.float32)
|
||||
self.publish_lowcmd(action, kp=zero_gains, kd=zero_gains, tau=zero_gains)
|
||||
logger.info("Sent zero-torque command for safe shutdown")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send zero-torque on disconnect: {e}")
|
||||
for id in G1_29_JointIndex:
|
||||
self.msg.motor_cmd[id].mode = 1
|
||||
self.msg.motor_cmd[id].kp = self.kp[id.value]
|
||||
self.msg.motor_cmd[id].kd = self.kd[id.value]
|
||||
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
|
||||
|
||||
def disconnect(self):
|
||||
# Put robot in passive mode before stopping threads
|
||||
if not self.config.is_simulation:
|
||||
self._send_zero_torque()
|
||||
|
||||
# Signal thread to stop and unblock any waits
|
||||
self._shutdown_event.set()
|
||||
|
||||
@@ -378,12 +248,6 @@ class UnitreeG1(Robot):
|
||||
if self.subscribe_thread.is_alive():
|
||||
logger.warning("Subscribe thread did not stop cleanly")
|
||||
|
||||
# Wait for controller thread to finish
|
||||
if self._controller_thread is not None:
|
||||
self._controller_thread.join(timeout=2.0)
|
||||
if self._controller_thread.is_alive():
|
||||
logger.warning("Controller thread did not stop cleanly")
|
||||
|
||||
# Close simulation environment
|
||||
if self.config.is_simulation and self.sim_env is not None:
|
||||
try:
|
||||
@@ -410,8 +274,7 @@ class UnitreeG1(Robot):
|
||||
cam.disconnect()
|
||||
|
||||
def get_observation(self) -> RobotObservation:
|
||||
with self._lowstate_lock:
|
||||
lowstate = self._lowstate
|
||||
lowstate = self._lowstate
|
||||
if lowstate is None:
|
||||
return {}
|
||||
|
||||
@@ -450,9 +313,14 @@ class UnitreeG1(Robot):
|
||||
obs["imu.rpy.pitch"] = lowstate.imu_state.rpy[1]
|
||||
obs["imu.rpy.yaw"] = lowstate.imu_state.rpy[2]
|
||||
|
||||
# Wireless remote (raw bytes for teleoperator)
|
||||
if lowstate.wireless_remote:
|
||||
obs["wireless_remote"] = lowstate.wireless_remote
|
||||
# Controller - parse wireless_remote and add to obs
|
||||
if lowstate.wireless_remote and len(lowstate.wireless_remote) >= 24:
|
||||
self.remote_controller.set(lowstate.wireless_remote)
|
||||
obs["remote.buttons"] = self.remote_controller.button.copy()
|
||||
obs["remote.lx"] = self.remote_controller.lx
|
||||
obs["remote.ly"] = self.remote_controller.ly
|
||||
obs["remote.rx"] = self.remote_controller.rx
|
||||
obs["remote.ry"] = self.remote_controller.ry
|
||||
|
||||
# Cameras - read images from ZMQ cameras
|
||||
for cam_name, cam in self._cameras.items():
|
||||
@@ -460,63 +328,73 @@ class UnitreeG1(Robot):
|
||||
|
||||
return obs
|
||||
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
action_to_publish = action
|
||||
if self.controller is not None:
|
||||
# Controller thread owns legs/waist. Here we only update joystick inputs
|
||||
# and publish arm targets from the teleoperator.
|
||||
self._update_controller_action(action)
|
||||
arm_prefixes = tuple(j.name for j in G1_29_JointArmIndex)
|
||||
action_to_publish = {
|
||||
key: value
|
||||
for key, value in action.items()
|
||||
if key.endswith(".q") and key.startswith(arm_prefixes)
|
||||
}
|
||||
|
||||
tau = None
|
||||
if self.config.gravity_compensation and self.arm_ik is not None:
|
||||
tau = np.zeros(29, dtype=np.float32)
|
||||
action_np = np.array(
|
||||
[
|
||||
action_to_publish.get(f"{joint.name}.q", self.msg.motor_cmd[joint.value].q)
|
||||
for joint in G1_29_JointArmIndex
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
arm_tau = self.arm_ik.solve_tau(action_np)
|
||||
arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value
|
||||
for joint in G1_29_JointArmIndex:
|
||||
local_idx = joint.value - arm_start_idx
|
||||
tau[joint.value] = arm_tau[local_idx]
|
||||
|
||||
self.publish_lowcmd(action_to_publish, tau=tau)
|
||||
return action
|
||||
|
||||
def _update_controller_action(self, action: RobotAction) -> None:
|
||||
"""Update controller input state from incoming teleop action."""
|
||||
with self._controller_action_lock:
|
||||
for key in REMOTE_KEYS:
|
||||
if key in action:
|
||||
self.controller_input[key] = action[key]
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
with self._lowstate_lock:
|
||||
return self._lowstate is not None
|
||||
return self._lowstate is not None
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
"""Joint positions for all 29 joints."""
|
||||
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
|
||||
|
||||
@property
|
||||
def cameras(self) -> dict:
|
||||
return self._cameras
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
for motor in G1_29_JointIndex:
|
||||
key = f"{motor.name}.q"
|
||||
if key in action:
|
||||
self.msg.motor_cmd[motor.value].q = action[key]
|
||||
self.msg.motor_cmd[motor.value].qd = 0
|
||||
self.msg.motor_cmd[motor.value].kp = self.kp[motor.value]
|
||||
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
|
||||
self.msg.motor_cmd[motor.value].tau = 0
|
||||
|
||||
if self.config.gravity_compensation:
|
||||
# Build action_np from motor commands (arm joints are indices 15-28, local indices 0-13)
|
||||
action_np = np.zeros(14)
|
||||
arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value # 15
|
||||
for joint in G1_29_JointArmIndex:
|
||||
local_idx = joint.value - arm_start_idx
|
||||
action_np[local_idx] = self.msg.motor_cmd[joint.value].q
|
||||
tau = self.arm_ik.solve_tau(action_np)
|
||||
|
||||
# Apply tau back to motor commands
|
||||
for joint in G1_29_JointArmIndex:
|
||||
local_idx = joint.value - arm_start_idx
|
||||
self.msg.motor_cmd[joint.value].tau = tau[local_idx]
|
||||
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
return action
|
||||
|
||||
def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion
|
||||
"""Get gravity orientation from quaternion."""
|
||||
qw = quaternion[0]
|
||||
qx = quaternion[1]
|
||||
qy = quaternion[2]
|
||||
qz = quaternion[3]
|
||||
|
||||
gravity_orientation = np.zeros(3)
|
||||
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
|
||||
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
|
||||
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
||||
return gravity_orientation
|
||||
|
||||
def reset(
|
||||
self,
|
||||
control_dt: float | None = None,
|
||||
@@ -529,9 +407,15 @@ class UnitreeG1(Robot):
|
||||
|
||||
if self.config.is_simulation and self.sim_env is not None:
|
||||
self.sim_env.reset()
|
||||
self.publish_lowcmd(
|
||||
{f"{motor.name}.q": float(default_positions[motor.value]) for motor in G1_29_JointIndex}
|
||||
)
|
||||
|
||||
for motor in G1_29_JointIndex:
|
||||
self.msg.motor_cmd[motor.value].q = default_positions[motor.value]
|
||||
self.msg.motor_cmd[motor.value].qd = 0
|
||||
self.msg.motor_cmd[motor.value].kp = self.kp[motor.value]
|
||||
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
|
||||
self.msg.motor_cmd[motor.value].tau = 0
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
else:
|
||||
total_time = 3.0
|
||||
num_steps = int(total_time / control_dt)
|
||||
@@ -562,8 +446,4 @@ class UnitreeG1(Robot):
|
||||
sleep_time = max(0, control_dt - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
# Reset controller internal state (gait phase, obs history, etc.)
|
||||
if self.controller is not None and hasattr(self.controller, "reset"):
|
||||
self.controller.reset()
|
||||
|
||||
logger.info("Reached default position")
|
||||
|
||||
@@ -22,8 +22,6 @@ import zmq
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
|
||||
# Module-level ZMQ state mirrors the Unitree SDK's global ChannelFactory Singleton.
|
||||
# Only one robot connection per process is supported.
|
||||
_ctx: zmq.Context | None = None
|
||||
_lowcmd_sock: zmq.Socket | None = None
|
||||
_lowstate_sock: zmq.Socket | None = None
|
||||
@@ -99,22 +97,17 @@ def lowcmd_to_dict(topic: str, msg: Any) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def ChannelFactoryInitialize(domain_id: int = 0, config: Any = None) -> None: # noqa: N802
|
||||
def ChannelFactoryInitialize(*args: Any, **kwargs: Any) -> None: # noqa: N802
|
||||
"""
|
||||
Initialize ZMQ sockets for robot communication.
|
||||
|
||||
This function mimics the Unitree SDK's ChannelFactoryInitialize but uses
|
||||
ZMQ sockets to connect to the robot server bridge instead of DDS.
|
||||
|
||||
Args:
|
||||
domain_id: Ignored (for API compatibility with Unitree SDK)
|
||||
config: UnitreeG1Config instance with robot_ip
|
||||
"""
|
||||
global _ctx, _lowcmd_sock, _lowstate_sock
|
||||
|
||||
# read socket config
|
||||
if config is None:
|
||||
config = UnitreeG1Config()
|
||||
config = UnitreeG1Config()
|
||||
robot_ip = config.robot_ip
|
||||
|
||||
ctx = zmq.Context.instance()
|
||||
|
||||
@@ -1,462 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Benchmark runner: train and evaluate policies across simulation benchmarks.
|
||||
|
||||
Orchestrates per-benchmark training and evaluation using the existing
|
||||
``lerobot-train`` and ``lerobot-eval`` CLI tools.
|
||||
|
||||
Typical usage::
|
||||
|
||||
# Train SmolVLA on LIBERO-plus (4 GPUs, 50k steps):
|
||||
lerobot-benchmark train \\
|
||||
--benchmarks libero_plus \\
|
||||
--policy-path lerobot/smolvla_base \\
|
||||
--hub-user $HF_USER \\
|
||||
--num-gpus 4 --steps 50000
|
||||
|
||||
# Evaluate the trained policies:
|
||||
lerobot-benchmark eval \\
|
||||
--benchmarks libero_plus \\
|
||||
--hub-user $HF_USER
|
||||
|
||||
# Full pipeline (train → upload → eval) for multiple benchmarks:
|
||||
lerobot-benchmark all \\
|
||||
--benchmarks libero_plus,robocasa,robomme \\
|
||||
--policy-path lerobot/smolvla_base \\
|
||||
--hub-user $HF_USER \\
|
||||
--num-gpus 4 --steps 50000
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkEntry:
|
||||
"""Training + evaluation settings for a single benchmark.
|
||||
|
||||
When ``eval_tasks`` is set, evaluation runs once per task in the list
|
||||
(e.g. libero_spatial, libero_object, …). ``env_task`` is still used as
|
||||
the task for mid-training evaluation during ``lerobot-train``.
|
||||
"""
|
||||
|
||||
dataset_repo_id: str
|
||||
env_type: str
|
||||
env_task: str
|
||||
eval_tasks: list[str] | None = None
|
||||
train_overrides: dict[str, str] = field(default_factory=dict)
|
||||
eval_overrides: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
LIBERO_SUITES = ["libero_spatial", "libero_object", "libero_goal", "libero_10"]
|
||||
|
||||
# Each benchmark maps a human-readable name to its dataset and eval env.
|
||||
# ``dataset_repo_id`` can contain ``{hub_user}`` which is interpolated at
|
||||
# runtime from ``--hub-user``.
|
||||
BENCHMARK_REGISTRY: dict[str, BenchmarkEntry] = {
|
||||
"libero": BenchmarkEntry(
|
||||
dataset_repo_id="{hub_user}/libero",
|
||||
env_type="libero",
|
||||
env_task="libero_spatial",
|
||||
eval_tasks=LIBERO_SUITES,
|
||||
),
|
||||
"libero_plus": BenchmarkEntry(
|
||||
dataset_repo_id="{hub_user}/libero_plus",
|
||||
env_type="libero_plus",
|
||||
env_task="libero_spatial",
|
||||
eval_tasks=LIBERO_SUITES,
|
||||
),
|
||||
"metaworld": BenchmarkEntry(
|
||||
dataset_repo_id="{hub_user}/metaworld",
|
||||
env_type="metaworld",
|
||||
env_task="metaworld-push-v2",
|
||||
),
|
||||
"robocasa": BenchmarkEntry(
|
||||
dataset_repo_id="{hub_user}/robocasa",
|
||||
env_type="robocasa",
|
||||
env_task="PickPlaceCounterToCabinet",
|
||||
),
|
||||
"robomme": BenchmarkEntry(
|
||||
dataset_repo_id="{hub_user}/robomme",
|
||||
env_type="robomme",
|
||||
env_task="PickXtimes",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _policy_repo_id(hub_user: str, policy_name: str, benchmark: str) -> str:
|
||||
return f"{hub_user}/{policy_name}_{benchmark}"
|
||||
|
||||
|
||||
def _extra_keys(extra_args: list[str]) -> set[str]:
|
||||
"""Extract ``--key`` prefixes from extra CLI args for override detection."""
|
||||
keys: set[str] = set()
|
||||
for arg in extra_args:
|
||||
if arg.startswith("--") and "=" in arg:
|
||||
keys.add(arg.split("=", 1)[0])
|
||||
return keys
|
||||
|
||||
|
||||
def _build_train_cmd(
|
||||
benchmark: BenchmarkEntry,
|
||||
*,
|
||||
policy_path: str,
|
||||
hub_user: str,
|
||||
policy_name: str,
|
||||
benchmark_name: str,
|
||||
num_gpus: int,
|
||||
steps: int,
|
||||
batch_size: int,
|
||||
eval_freq: int,
|
||||
save_freq: int,
|
||||
wandb: bool,
|
||||
extra_args: list[str],
|
||||
) -> list[str]:
|
||||
"""Build the ``accelerate launch lerobot-train`` command list."""
|
||||
lerobot_train = shutil.which("lerobot-train")
|
||||
if lerobot_train is None:
|
||||
raise RuntimeError("lerobot-train not found on PATH. Is lerobot installed?")
|
||||
|
||||
# Strip bare "--" separators that argparse may pass through
|
||||
cleaned_extra = [a for a in extra_args if a != "--"]
|
||||
overridden = _extra_keys(cleaned_extra)
|
||||
|
||||
repo_id = _policy_repo_id(hub_user, policy_name, benchmark_name)
|
||||
dataset_id = benchmark.dataset_repo_id.format(hub_user=hub_user)
|
||||
|
||||
defaults: list[tuple[str, str]] = [
|
||||
("--policy.path", policy_path),
|
||||
("--dataset.repo_id", dataset_id),
|
||||
("--policy.repo_id", repo_id),
|
||||
("--env.type", benchmark.env_type),
|
||||
("--env.task", benchmark.env_task),
|
||||
("--steps", str(steps)),
|
||||
("--batch_size", str(batch_size)),
|
||||
("--eval_freq", str(eval_freq)),
|
||||
("--save_freq", str(save_freq)),
|
||||
("--output_dir", f"outputs/train/{policy_name}_{benchmark_name}"),
|
||||
("--job_name", f"{policy_name}_{benchmark_name}"),
|
||||
("--policy.push_to_hub", "true"),
|
||||
]
|
||||
if wandb:
|
||||
defaults.append(("--wandb.enable", "true"))
|
||||
for k, v in benchmark.train_overrides.items():
|
||||
defaults.append((f"--{k}", v))
|
||||
|
||||
cmd: list[str] = [
|
||||
"accelerate", "launch",
|
||||
"--multi_gpu",
|
||||
f"--num_processes={num_gpus}",
|
||||
lerobot_train,
|
||||
]
|
||||
for key, val in defaults:
|
||||
if key not in overridden:
|
||||
cmd.append(f"{key}={val}")
|
||||
cmd.extend(cleaned_extra)
|
||||
return cmd
|
||||
|
||||
|
||||
def _build_eval_cmd(
|
||||
benchmark: BenchmarkEntry,
|
||||
*,
|
||||
hub_user: str,
|
||||
policy_name: str,
|
||||
benchmark_name: str,
|
||||
eval_task: str | None = None,
|
||||
n_episodes: int,
|
||||
batch_size_eval: int,
|
||||
extra_args: list[str],
|
||||
) -> list[str]:
|
||||
"""Build the ``lerobot-eval`` command list.
|
||||
|
||||
``eval_task`` overrides the benchmark's ``env_task`` so the same
|
||||
benchmark can be evaluated on multiple suites (e.g. LIBERO).
|
||||
"""
|
||||
lerobot_eval = shutil.which("lerobot-eval")
|
||||
if lerobot_eval is None:
|
||||
raise RuntimeError("lerobot-eval not found on PATH. Is lerobot installed?")
|
||||
|
||||
task = eval_task or benchmark.env_task
|
||||
repo_id = _policy_repo_id(hub_user, policy_name, benchmark_name)
|
||||
out_dir = _eval_output_dir(policy_name, benchmark_name, eval_task=task)
|
||||
|
||||
cleaned_extra = [a for a in extra_args if a != "--"]
|
||||
overridden = _extra_keys(cleaned_extra)
|
||||
|
||||
defaults: list[tuple[str, str]] = [
|
||||
("--policy.path", repo_id),
|
||||
("--env.type", benchmark.env_type),
|
||||
("--env.task", task),
|
||||
("--eval.n_episodes", str(n_episodes)),
|
||||
("--eval.batch_size", str(batch_size_eval)),
|
||||
("--output_dir", out_dir),
|
||||
("--policy.device", "cuda"),
|
||||
]
|
||||
for k, v in benchmark.eval_overrides.items():
|
||||
defaults.append((f"--{k}", v))
|
||||
|
||||
cmd: list[str] = [lerobot_eval]
|
||||
for key, val in defaults:
|
||||
if key not in overridden:
|
||||
cmd.append(f"{key}={val}")
|
||||
cmd.extend(cleaned_extra)
|
||||
return cmd
|
||||
|
||||
|
||||
def _eval_output_dir(policy_name: str, benchmark_name: str, eval_task: str | None = None) -> Path:
|
||||
if eval_task:
|
||||
return Path(f"outputs/eval/{policy_name}_{benchmark_name}/{eval_task}")
|
||||
return Path(f"outputs/eval/{policy_name}_{benchmark_name}")
|
||||
|
||||
|
||||
def _run(cmd: list[str], *, dry_run: bool) -> None:
|
||||
log.info("Command: %s", " \\\n ".join(cmd))
|
||||
if dry_run:
|
||||
log.info("[dry-run] Skipping execution.")
|
||||
return
|
||||
result = subprocess.run(cmd, check=False)
|
||||
if result.returncode != 0:
|
||||
log.error("Command failed with exit code %d", result.returncode)
|
||||
sys.exit(result.returncode)
|
||||
|
||||
|
||||
def _push_eval_to_hub(
|
||||
*,
|
||||
hub_user: str,
|
||||
policy_name: str,
|
||||
benchmark_name: str,
|
||||
eval_task: str | None = None,
|
||||
dry_run: bool,
|
||||
) -> None:
|
||||
"""Upload eval results (metrics + videos) to the policy repo on the Hub."""
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
repo_id = _policy_repo_id(hub_user, policy_name, benchmark_name)
|
||||
local_dir = _eval_output_dir(policy_name, benchmark_name, eval_task=eval_task)
|
||||
hub_path = f"eval/{eval_task}" if eval_task else f"eval/{benchmark_name}"
|
||||
|
||||
if not local_dir.exists():
|
||||
log.warning("Eval output dir %s does not exist, skipping hub upload.", local_dir)
|
||||
return
|
||||
|
||||
log.info("Uploading eval results from %s to %s (path_in_repo=%s)", local_dir, repo_id, hub_path)
|
||||
if dry_run:
|
||||
log.info("[dry-run] Skipping upload.")
|
||||
return
|
||||
|
||||
api = HfApi()
|
||||
api.upload_folder(
|
||||
folder_path=str(local_dir),
|
||||
repo_id=repo_id,
|
||||
path_in_repo=hub_path,
|
||||
repo_type="model",
|
||||
commit_message=f"Upload eval results for {eval_task or benchmark_name}",
|
||||
)
|
||||
|
||||
|
||||
def _resolve_benchmarks(names: str) -> list[tuple[str, BenchmarkEntry]]:
|
||||
out = []
|
||||
for name in names.split(","):
|
||||
name = name.strip()
|
||||
if name not in BENCHMARK_REGISTRY:
|
||||
available = ", ".join(BENCHMARK_REGISTRY)
|
||||
raise ValueError(f"Unknown benchmark '{name}'. Available: {available}")
|
||||
out.append((name, BENCHMARK_REGISTRY[name]))
|
||||
return out
|
||||
|
||||
|
||||
def cmd_train(args: argparse.Namespace) -> None:
|
||||
benchmarks = _resolve_benchmarks(args.benchmarks)
|
||||
for bname, bentry in benchmarks:
|
||||
log.info("=== Training on benchmark: %s ===", bname)
|
||||
cmd = _build_train_cmd(
|
||||
bentry,
|
||||
policy_path=args.policy_path,
|
||||
hub_user=args.hub_user,
|
||||
policy_name=args.policy_name,
|
||||
benchmark_name=bname,
|
||||
num_gpus=args.num_gpus,
|
||||
steps=args.steps,
|
||||
batch_size=args.batch_size,
|
||||
eval_freq=args.eval_freq,
|
||||
save_freq=args.save_freq,
|
||||
wandb=args.wandb,
|
||||
extra_args=args.extra,
|
||||
)
|
||||
_run(cmd, dry_run=args.dry_run)
|
||||
|
||||
|
||||
def _run_eval_for_benchmark(
|
||||
bname: str,
|
||||
bentry: BenchmarkEntry,
|
||||
args: argparse.Namespace,
|
||||
) -> None:
|
||||
"""Run evaluation for a single benchmark, iterating over all its eval_tasks."""
|
||||
tasks = bentry.eval_tasks or [bentry.env_task]
|
||||
for task in tasks:
|
||||
log.info("=== Evaluating %s / %s ===", bname, task)
|
||||
cmd = _build_eval_cmd(
|
||||
bentry,
|
||||
hub_user=args.hub_user,
|
||||
policy_name=args.policy_name,
|
||||
benchmark_name=bname,
|
||||
eval_task=task if bentry.eval_tasks else None,
|
||||
n_episodes=args.n_episodes,
|
||||
batch_size_eval=args.batch_size_eval,
|
||||
extra_args=args.extra,
|
||||
)
|
||||
_run(cmd, dry_run=args.dry_run)
|
||||
if args.push_eval_to_hub:
|
||||
_push_eval_to_hub(
|
||||
hub_user=args.hub_user,
|
||||
policy_name=args.policy_name,
|
||||
benchmark_name=bname,
|
||||
eval_task=task if bentry.eval_tasks else None,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
|
||||
def cmd_eval(args: argparse.Namespace) -> None:
|
||||
benchmarks = _resolve_benchmarks(args.benchmarks)
|
||||
for bname, bentry in benchmarks:
|
||||
_run_eval_for_benchmark(bname, bentry, args)
|
||||
|
||||
|
||||
def cmd_all(args: argparse.Namespace) -> None:
|
||||
"""Train on each benchmark, then evaluate each."""
|
||||
benchmarks = _resolve_benchmarks(args.benchmarks)
|
||||
|
||||
log.info("Phase 1: Training on %d benchmark(s)", len(benchmarks))
|
||||
for bname, bentry in benchmarks:
|
||||
log.info("=== Training on benchmark: %s ===", bname)
|
||||
cmd = _build_train_cmd(
|
||||
bentry,
|
||||
policy_path=args.policy_path,
|
||||
hub_user=args.hub_user,
|
||||
policy_name=args.policy_name,
|
||||
benchmark_name=bname,
|
||||
num_gpus=args.num_gpus,
|
||||
steps=args.steps,
|
||||
batch_size=args.batch_size,
|
||||
eval_freq=args.eval_freq,
|
||||
save_freq=args.save_freq,
|
||||
wandb=args.wandb,
|
||||
extra_args=args.extra,
|
||||
)
|
||||
_run(cmd, dry_run=args.dry_run)
|
||||
|
||||
log.info("Phase 2: Evaluating %d benchmark(s)", len(benchmarks))
|
||||
for bname, bentry in benchmarks:
|
||||
_run_eval_for_benchmark(bname, bentry, args)
|
||||
|
||||
|
||||
def _add_common_args(p: argparse.ArgumentParser) -> None:
|
||||
p.add_argument(
|
||||
"--benchmarks", required=True,
|
||||
help="Comma-separated benchmark names (e.g. libero_plus,robocasa,robomme).",
|
||||
)
|
||||
p.add_argument("--hub-user", required=True, help="HuggingFace Hub username.")
|
||||
p.add_argument(
|
||||
"--policy-name", default="smolvla",
|
||||
help="Short policy name used in repo IDs and output dirs (default: smolvla).",
|
||||
)
|
||||
p.add_argument("--dry-run", action="store_true", help="Print commands without executing.")
|
||||
|
||||
|
||||
def _add_train_args(p: argparse.ArgumentParser) -> None:
|
||||
p.add_argument("--policy-path", default="lerobot/smolvla_base", help="Pretrained policy path.")
|
||||
p.add_argument("--num-gpus", type=int, default=4, help="Number of GPUs.")
|
||||
p.add_argument("--steps", type=int, default=50_000, help="Total training steps.")
|
||||
p.add_argument("--batch-size", type=int, default=32, help="Per-GPU batch size.")
|
||||
p.add_argument("--eval-freq", type=int, default=10_000, help="Eval every N steps (0 to disable).")
|
||||
p.add_argument("--save-freq", type=int, default=10_000, help="Save checkpoint every N steps.")
|
||||
p.add_argument("--wandb", action="store_true", help="Enable Weights & Biases logging.")
|
||||
|
||||
|
||||
def _add_eval_args(p: argparse.ArgumentParser) -> None:
|
||||
p.add_argument("--n-episodes", type=int, default=50, help="Number of eval episodes.")
|
||||
p.add_argument("--batch-size-eval", type=int, default=10, help="Eval batch size (parallel envs).")
|
||||
p.add_argument(
|
||||
"--push-eval-to-hub", action="store_true",
|
||||
help="Upload eval results (metrics + videos) to the policy repo on the Hub.",
|
||||
)
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="lerobot-benchmark",
|
||||
description="Train and evaluate policies across simulation benchmarks.",
|
||||
)
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# train
|
||||
p_train = sub.add_parser("train", help="Train a policy on each selected benchmark.")
|
||||
_add_common_args(p_train)
|
||||
_add_train_args(p_train)
|
||||
p_train.set_defaults(func=cmd_train)
|
||||
|
||||
# eval
|
||||
p_eval = sub.add_parser("eval", help="Evaluate trained policies on each benchmark.")
|
||||
_add_common_args(p_eval)
|
||||
_add_eval_args(p_eval)
|
||||
p_eval.set_defaults(func=cmd_eval)
|
||||
|
||||
# all (train + eval)
|
||||
p_all = sub.add_parser("all", help="Train then evaluate on each benchmark.")
|
||||
_add_common_args(p_all)
|
||||
_add_train_args(p_all)
|
||||
_add_eval_args(p_all)
|
||||
p_all.set_defaults(func=cmd_all)
|
||||
|
||||
# list
|
||||
p_list = sub.add_parser("list", help="List available benchmarks.")
|
||||
p_list.set_defaults(func=lambda _args: _list_benchmarks())
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _list_benchmarks() -> None:
|
||||
print("Available benchmarks:\n")
|
||||
for name, entry in BENCHMARK_REGISTRY.items():
|
||||
print(f" {name}")
|
||||
print(f" dataset: {entry.dataset_repo_id}")
|
||||
print(f" env: {entry.env_type}")
|
||||
if entry.eval_tasks:
|
||||
print(f" eval on: {', '.join(entry.eval_tasks)}")
|
||||
else:
|
||||
print(f" eval on: {entry.env_task}")
|
||||
print()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = build_parser()
|
||||
args, extra = parser.parse_known_args()
|
||||
args.extra = extra
|
||||
args.func(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -132,13 +132,10 @@ def visualize_dataset(
|
||||
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
first_index = None
|
||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||
if first_index is None:
|
||||
first_index = batch["index"][0].item()
|
||||
# iterate over the batch
|
||||
for i in range(len(batch["index"])):
|
||||
rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index)
|
||||
rr.set_time("frame_index", sequence=batch["frame_index"][i].item())
|
||||
rr.set_time("timestamp", timestamp=batch["timestamp"][i].item())
|
||||
|
||||
# display each camera image
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user