mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
Compare commits
122 Commits
revert/rai
...
feat/compa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
93bc43771b | ||
|
|
f816092993 | ||
|
|
1753235a61 | ||
|
|
739aaa8edd | ||
|
|
15678bd51a | ||
|
|
d72b4fe056 | ||
|
|
f9fd0fb841 | ||
|
|
6cf4555081 | ||
|
|
5ec2615b21 | ||
|
|
65c11eb5e6 | ||
|
|
7621acf776 | ||
|
|
3b33f9e34c | ||
|
|
7157794f58 | ||
|
|
88bc763033 | ||
|
|
64172756a7 | ||
|
|
3cd10d3560 | ||
|
|
dc69ae3fc0 | ||
|
|
bb0175e05e | ||
|
|
cff530a17a | ||
|
|
746336f9c8 | ||
|
|
e48d8babe0 | ||
|
|
da71b233be | ||
|
|
485aa2332c | ||
|
|
0bd16432bc | ||
|
|
5ab6505ea8 | ||
|
|
5170862d23 | ||
|
|
101fb02697 | ||
|
|
0664addec1 | ||
|
|
a7391e82c7 | ||
|
|
3521dd93c1 | ||
|
|
6288439d48 | ||
|
|
1cf768e17a | ||
|
|
d11ec6b5ef | ||
|
|
c75455a6de | ||
|
|
f25ac02e6c | ||
|
|
23cb668cac | ||
|
|
2ea3043b1b | ||
|
|
0f61e2415f | ||
|
|
76a425c600 | ||
|
|
df71f3ce24 | ||
|
|
326aca0a48 | ||
|
|
be46bdea8f | ||
|
|
306429a85b | ||
|
|
12f2f35760 | ||
|
|
a024d33750 | ||
|
|
63cd2111ad | ||
|
|
abe9e79825 | ||
|
|
503fc4e9f4 | ||
|
|
92b479f9ac | ||
|
|
b954337ac7 | ||
|
|
5f6f476f32 | ||
|
|
502fdc0630 | ||
|
|
9db6213895 | ||
|
|
aa1d906802 | ||
|
|
eff8a6fd12 | ||
|
|
c54cd529a2 | ||
|
|
a5ca206c49 | ||
|
|
88100943ef | ||
|
|
a95b15ccc0 | ||
|
|
a97d078d95 | ||
|
|
98662e5f24 | ||
|
|
4d8f242af9 | ||
|
|
1ff8986c77 | ||
|
|
f0aeded142 | ||
|
|
da5d2f3e91 | ||
|
|
d6ea3bbce0 | ||
|
|
7aedbbf81a | ||
|
|
1ee8d824f5 | ||
|
|
f7c4f99545 | ||
|
|
92b6254473 | ||
|
|
79137f58d1 | ||
|
|
da9c2e66f4 | ||
|
|
45730cc71e | ||
|
|
5d4af4b0b1 | ||
|
|
0050d7c61c | ||
|
|
cf2897f545 | ||
|
|
2c18210d02 | ||
|
|
44bf283701 | ||
|
|
a51682b266 | ||
|
|
ed49c9935a | ||
|
|
52455d03a7 | ||
|
|
4afb253825 | ||
|
|
96c664e09f | ||
|
|
8bd0aec618 | ||
|
|
e82e7a02e9 | ||
|
|
845b359d39 | ||
|
|
a6ff3cfebb | ||
|
|
271d92dcaa | ||
|
|
8e940bf361 | ||
|
|
6e8be57eb2 | ||
|
|
723013c71b | ||
|
|
bf6ac5e110 | ||
|
|
3ce5bcf24d | ||
|
|
6f5bb4d4a4 | ||
|
|
f29311ccb0 | ||
|
|
0c79cf8f4e | ||
|
|
f2ff370459 | ||
|
|
25f60c301b | ||
|
|
0699b46d87 | ||
|
|
b8f7e401d4 | ||
|
|
656fc0f059 | ||
|
|
829d2d1ad9 | ||
|
|
4ccf28437a | ||
|
|
9a49e57c72 | ||
|
|
6c28ef894a | ||
|
|
bf3c8746b7 | ||
|
|
9f32e00f90 | ||
|
|
fcaa0ea5f9 | ||
|
|
5ac9356135 | ||
|
|
b74e2a6113 | ||
|
|
a4bed41132 | ||
|
|
5c8dd883be | ||
|
|
38f6fc816b | ||
|
|
abde7be3b3 | ||
|
|
b6c528a438 | ||
|
|
6d331310ab | ||
|
|
5dfdec9288 | ||
|
|
50977a2c28 | ||
|
|
a0d7627d81 | ||
|
|
1ad2da403d | ||
|
|
2d3a605b3c | ||
|
|
f173265354 |
2
.github/workflows/full_tests.yml
vendored
2
.github/workflows/full_tests.yml
vendored
@@ -78,7 +78,7 @@ jobs:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras
|
||||
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
|
||||
34
.github/workflows/nightly.yml
vendored
34
.github/workflows/nightly.yml
vendored
@@ -119,6 +119,7 @@ jobs:
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
container:
|
||||
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
@@ -158,3 +159,36 @@ jobs:
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
run: make test-end-to-end
|
||||
|
||||
# This job runs multi-GPU training tests with 4 GPUs
|
||||
nightly-multi-gpu-tests:
|
||||
name: Nightly Multi-GPU Tests
|
||||
needs: [build-docker-gpu-nightly]
|
||||
runs-on:
|
||||
group: aws-g4dn-12xlarge # Instance with 4 GPUs
|
||||
env:
|
||||
HF_HOME: /home/user_lerobot/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
CUDA_VISIBLE_DEVICES: "0,1,2,3"
|
||||
container:
|
||||
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Verify GPU availability
|
||||
run: |
|
||||
nvidia-smi
|
||||
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
||||
|
||||
- name: Run multi-GPU training tests
|
||||
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
|
||||
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
|
||||
timeout-minutes: 10
|
||||
|
||||
14
.github/workflows/release.yml
vendored
14
.github/workflows/release.yml
vendored
@@ -82,6 +82,14 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Remove Tags with Git dependencies
|
||||
# TODO(Steven): Temporary patch to remove libero and pi from PyPi 0.4.0 release due to its reliance on git dependencies.
|
||||
run: |
|
||||
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
|
||||
grep -E '@ git\+https|lerobot\[pi\]|lerobot\[libero\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
|
||||
sed -E -i '/@ git\+https|lerobot\[pi\]|lerobot\[libero\]/d' pyproject.toml
|
||||
echo "::info:: Git dependencies removed. Proceeding with build."
|
||||
|
||||
- name: Install build dependencies
|
||||
run: python -m pip install build
|
||||
|
||||
@@ -103,7 +111,7 @@ jobs:
|
||||
- name: Publish to TestPyPI for pre-releases
|
||||
# True for tags like 'v0.2.0-rc1'
|
||||
if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-')
|
||||
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
verbose: true
|
||||
@@ -111,7 +119,7 @@ jobs:
|
||||
|
||||
- name: Publish to PyPI
|
||||
if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
|
||||
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
with:
|
||||
verbose: true
|
||||
print-hash: true
|
||||
@@ -138,7 +146,7 @@ jobs:
|
||||
- name: Setup uv and Python
|
||||
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
enable-cache: true
|
||||
enable-cache: true # zizmor: ignore[cache-poisoning]
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
- name: Create uv virtual environment
|
||||
|
||||
12
.github/workflows/stale.yml
vendored
12
.github/workflows/stale.yml
vendored
@@ -27,15 +27,17 @@ env:
|
||||
This issue was closed because it has been stalled for 14 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
CLOSE_PR_MESSAGE: >
|
||||
This PR was closed because it has been stalled for 14 days with no activity.
|
||||
This PR was closed because it has been stalled for 21 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
WARN_ISSUE_MESSAGE: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
Any change, comment or update to this issue will reset this count.
|
||||
Thank you for your contributions.
|
||||
WARN_PR_MESSAGE: >
|
||||
This PR has been automatically marked as stale because it has not had
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
recent activity (1 year). It will be closed if no further activity occurs.
|
||||
Any change, comment or update to this PR will reset this count.
|
||||
Thank you for your contributions.
|
||||
|
||||
jobs:
|
||||
@@ -56,10 +58,10 @@ jobs:
|
||||
stale-pr-label: stale
|
||||
exempt-issue-labels: never-stale
|
||||
exempt-pr-labels: never-stale
|
||||
days-before-issue-stale: 180 # TODO(Steven): Will modify this to 90 after initial cleanup
|
||||
days-before-issue-stale: 180
|
||||
days-before-issue-close: 14
|
||||
days-before-pr-stale: 180
|
||||
days-before-pr-close: 14
|
||||
days-before-pr-stale: 365
|
||||
days-before-pr-close: 21
|
||||
delete-branch: true
|
||||
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
|
||||
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
|
||||
|
||||
183
.github/workflows/unbound_deps_tests.yml
vendored
Normal file
183
.github/workflows/unbound_deps_tests.yml
vendored
Normal file
@@ -0,0 +1,183 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This workflow handles full testing with unboud dependencies versions.
|
||||
name: Unbound Dependency Tests
|
||||
|
||||
on:
|
||||
# Allows running this workflow manually from the Actions tab
|
||||
workflow_dispatch:
|
||||
|
||||
# Run on the 1st and 15th of every month at 09:00 UTC
|
||||
schedule:
|
||||
- cron: '0 2 1,15 * *'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
|
||||
|
||||
# Ensures that only the latest action is built, canceling older runs.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
|
||||
# This job runs the E2E tests + pytest with all unbound extras
|
||||
full-tests:
|
||||
name: Full Unbound Tests
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
|
||||
- name: Setup uv and Python
|
||||
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Unbound dependencies
|
||||
run: |
|
||||
sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml
|
||||
echo "Dependencies unbound:" && cat pyproject.toml
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv
|
||||
|
||||
- name: Run end-to-end tests
|
||||
run: uv run make test-end-to-end
|
||||
|
||||
# This job builds a GPU enabled image for testing
|
||||
build-and-push-docker:
|
||||
name: Build and Push Docker
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
outputs:
|
||||
image_tag: ${{ env.DOCKER_IMAGE_NAME }}
|
||||
env:
|
||||
GITHUB_REF: ${{ github.ref }}
|
||||
steps:
|
||||
- name: Install Git LFS
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
cache-binary: false
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/Dockerfile.internal
|
||||
push: true
|
||||
tags: ${{ env.DOCKER_IMAGE_NAME }}
|
||||
build-args: |
|
||||
UNBOUND_DEPS=true
|
||||
|
||||
# This job runs pytest with all unbound extras in a GPU enabled host
|
||||
# It runs everytime a test image is created
|
||||
gpu-tests:
|
||||
name: GPU Unbound Tests
|
||||
needs: [build-and-push-docker]
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
env:
|
||||
HF_HOME: /home/user_lerobot/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
container:
|
||||
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv
|
||||
- name: Run end-to-end tests
|
||||
run: make test-end-to-end
|
||||
|
||||
# This job deletes the test image recently created
|
||||
# It runs everytime after the gpu-tests have finished
|
||||
delete-unbound-image:
|
||||
name: Delete Unbound Image
|
||||
needs: [gpu-tests, build-and-push-docker]
|
||||
if: always() && needs.build-and-push-docker.result == 'success'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Get Docker Hub Token and Delete Image
|
||||
# zizmor: ignore[template-injection]
|
||||
run: |
|
||||
IMAGE_NAME=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f1)
|
||||
IMAGE_TAG=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f2)
|
||||
|
||||
echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG"
|
||||
|
||||
TOKEN=$(curl -s -H "Content-Type: application/json" \
|
||||
-X POST \
|
||||
-d '{"username": "${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}", "password": "${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}"}' \
|
||||
https://hub.docker.com/v2/users/login/ | jq -r .token)
|
||||
|
||||
if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then
|
||||
echo "::error::Failed to get Docker Hub token."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \
|
||||
-H "Authorization: JWT ${TOKEN}" \
|
||||
-X DELETE \
|
||||
https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/${IMAGE_TAG}/)
|
||||
|
||||
if [ "$HTTP_RESPONSE" -eq 204 ]; then
|
||||
echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG"
|
||||
else
|
||||
echo "::error::Failed to delete Docker image. HTTP status: $HTTP_RESPONSE"
|
||||
exit 1
|
||||
fi
|
||||
@@ -26,7 +26,7 @@ repos:
|
||||
|
||||
##### General Code Quality & Formatting #####
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
rev: v6.0.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
args: ['--maxkb=1024']
|
||||
@@ -39,20 +39,20 @@ repos:
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.4
|
||||
rev: v0.14.1
|
||||
hooks:
|
||||
- id: ruff-format
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
|
||||
- repo: https://github.com/adhtruong/mirrors-typos
|
||||
rev: v1.34.0
|
||||
rev: v1.38.1
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.20.0
|
||||
rev: v3.21.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py310-plus]
|
||||
@@ -68,12 +68,12 @@ repos:
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.27.2
|
||||
rev: v8.28.0
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.11.0
|
||||
rev: v1.15.2
|
||||
hooks:
|
||||
- id: zizmor
|
||||
|
||||
@@ -86,11 +86,12 @@ repos:
|
||||
|
||||
# TODO(Steven): Uncomment when ready to use
|
||||
##### Static Analysis & Typing #####
|
||||
# - repo: https://github.com/pre-commit/mirrors-mypy
|
||||
# rev: v1.16.0
|
||||
# hooks:
|
||||
# - id: mypy
|
||||
# args: [--python-version=3.10]
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.18.2
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: [--config-file=pyproject.toml]
|
||||
exclude: ^(examples|benchmarks|tests)/
|
||||
|
||||
##### Docstring Checks #####
|
||||
# - repo: https://github.com/akaihola/darglint2
|
||||
|
||||
@@ -72,7 +72,6 @@ post it.
|
||||
|
||||
Look at our implementations for [datasets](./src/lerobot/datasets/), [policies](./src/lerobot/policies/),
|
||||
environments ([aloha](https://github.com/huggingface/gym-aloha),
|
||||
[xarm](https://github.com/huggingface/gym-xarm),
|
||||
[pusht](https://github.com/huggingface/gym-pusht))
|
||||
and follow the same api design.
|
||||
|
||||
@@ -138,7 +137,7 @@ Follow these steps to start contributing:
|
||||
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
|
||||
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
|
||||
|
||||
Set up a development environment with conda or miniconda:
|
||||
Set up a development environment with conda:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
|
||||
|
||||
10
Makefile
10
Makefile
@@ -119,10 +119,9 @@ test-tdmpc-ete-train:
|
||||
--policy.type=tdmpc \
|
||||
--policy.device=$(DEVICE) \
|
||||
--policy.push_to_hub=false \
|
||||
--env.type=xarm \
|
||||
--env.task=XarmLift-v0 \
|
||||
--env.type=pusht \
|
||||
--env.episode_length=5 \
|
||||
--dataset.repo_id=lerobot/xarm_lift_medium \
|
||||
--dataset.repo_id=lerobot/pusht_image \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
@@ -140,9 +139,10 @@ test-tdmpc-ete-eval:
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=xarm \
|
||||
--env.type=pusht \
|
||||
--env.episode_length=5 \
|
||||
--env.task=XarmLift-v0 \
|
||||
--env.observation_height=96 \
|
||||
--env.observation_width=96 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
|
||||
21
README.md
21
README.md
@@ -104,14 +104,14 @@ LeRobot works with Python 3.10+ and PyTorch 2.2+.
|
||||
|
||||
### Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/):
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
@@ -185,6 +185,11 @@ _Replace `[...]` with your desired features._
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install libero or pi tags, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
### Weights & Biases
|
||||
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||
@@ -197,7 +202,7 @@ wandb login
|
||||
|
||||
### Visualize datasets
|
||||
|
||||
Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
|
||||
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
|
||||
|
||||
@@ -207,13 +212,13 @@ lerobot-dataset-viz \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
or from a dataset in a local folder with the `root` option and the `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--root ./my_local_data_dir \
|
||||
--local-files-only 1 \
|
||||
--mode local \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
@@ -310,7 +315,7 @@ To upload these to the hub, run the following:
|
||||
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
|
||||
```
|
||||
|
||||
See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy.
|
||||
See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy.
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
@@ -337,7 +342,3 @@ If you want, you can cite this work with:
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -75,6 +75,14 @@ RUN uv venv --python python${PYTHON_VERSION}
|
||||
# Install Python dependencies for caching
|
||||
COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
|
||||
ARG UNBOUND_DEPS=false
|
||||
|
||||
RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \
|
||||
echo "Dependencies unbound:" && cat pyproject.toml; \
|
||||
fi
|
||||
|
||||
RUN uv pip install --no-cache ".[all]"
|
||||
|
||||
# Copy the rest of the application source code
|
||||
|
||||
@@ -61,6 +61,14 @@ RUN uv venv
|
||||
# Install Python dependencies for caching
|
||||
COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
|
||||
ARG UNBOUND_DEPS=false
|
||||
|
||||
RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \
|
||||
echo "Dependencies unbound:" && cat pyproject.toml; \
|
||||
fi
|
||||
|
||||
RUN uv pip install --no-cache ".[all]"
|
||||
|
||||
# Copy the rest of the application code
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: il_sim
|
||||
title: Imitation Learning in Sim
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
- local: integrate_hardware
|
||||
@@ -19,20 +17,37 @@
|
||||
title: Train RL in Simulation
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: multi_gpu_training
|
||||
title: Multi GPU training
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
title: Using LeRobotDataset
|
||||
- local: porting_datasets_v3
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
title: "Datasets"
|
||||
- sections:
|
||||
- local: act
|
||||
title: ACT
|
||||
- local: smolvla
|
||||
title: Finetune SmolVLA
|
||||
title: SmolVLA
|
||||
- local: pi0
|
||||
title: π₀ (Pi0)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
title: "Policies"
|
||||
- sections:
|
||||
- local: il_sim
|
||||
title: Imitation Learning in Sim
|
||||
- local: libero
|
||||
title: Using Libero
|
||||
title: "Policies"
|
||||
|
||||
- local: metaworld
|
||||
title: Using MetaWorld
|
||||
title: "Simulation"
|
||||
- sections:
|
||||
- local: introduction_processors
|
||||
title: Introduction to Robot Processors
|
||||
|
||||
92
docs/source/act.mdx
Normal file
92
docs/source/act.mdx
Normal file
@@ -0,0 +1,92 @@
|
||||
# ACT (Action Chunking with Transformers)
|
||||
|
||||
ACT is a **lightweight and efficient policy for imitation learning**, especially well-suited for fine-grained manipulation tasks. It's the **first model we recommend when you're starting out** with LeRobot due to its fast training time, low computational requirements, and strong performance.
|
||||
|
||||
<div class="video-container">
|
||||
<iframe
|
||||
width="100%"
|
||||
height="415"
|
||||
src="https://www.youtube.com/embed/ft73x0LfGpM"
|
||||
title="LeRobot ACT Tutorial"
|
||||
frameborder="0"
|
||||
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
|
||||
allowfullscreen
|
||||
></iframe>
|
||||
</div>
|
||||
|
||||
_Watch this tutorial from the LeRobot team to learn how ACT works: [LeRobot ACT Tutorial](https://www.youtube.com/watch?v=ft73x0LfGpM)_
|
||||
|
||||
## Model Overview
|
||||
|
||||
Action Chunking with Transformers (ACT) was introduced in the paper [Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware](https://arxiv.org/abs/2304.13705) by Zhao et al. The policy was designed to enable precise, contact-rich manipulation tasks using affordable hardware and minimal demonstration data.
|
||||
|
||||
### Why ACT is Great for Beginners
|
||||
|
||||
ACT stands out as an excellent starting point for several reasons:
|
||||
|
||||
- **Fast Training**: Trains in a few hours on a single GPU
|
||||
- **Lightweight**: Only ~80M parameters, making it efficient and easy to work with
|
||||
- **Data Efficient**: Often achieves high success rates with just 50 demonstrations
|
||||
|
||||
### Architecture
|
||||
|
||||
ACT uses a transformer-based architecture with three main components:
|
||||
|
||||
1. **Vision Backbone**: ResNet-18 processes images from multiple camera viewpoints
|
||||
2. **Transformer Encoder**: Synthesizes information from camera features, joint positions, and a learned latent variable
|
||||
3. **Transformer Decoder**: Generates coherent action sequences using cross-attention
|
||||
|
||||
The policy takes as input:
|
||||
|
||||
- Multiple RGB images (e.g., from wrist cameras, front/top cameras)
|
||||
- Current robot joint positions
|
||||
- A latent style variable `z` (learned during training, set to zero during inference)
|
||||
|
||||
And outputs a chunk of `k` future action sequences.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. ACT is included in the base LeRobot installation, so no additional dependencies are needed!
|
||||
|
||||
## Training ACT
|
||||
|
||||
ACT works seamlessly with the standard LeRobot training pipeline. Here's a complete example for training ACT on your dataset:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/your_dataset \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_your_dataset \
|
||||
--job_name=act_your_dataset \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=${HF_USER}/act_policy
|
||||
```
|
||||
|
||||
### Training Tips
|
||||
|
||||
1. **Start with defaults**: ACT's default hyperparameters work well for most tasks
|
||||
2. **Training duration**: Expect a few hours for 100k training steps on a single GPU
|
||||
3. **Batch size**: Start with batch size 8 and adjust based on your GPU memory
|
||||
|
||||
### Train using Google Colab
|
||||
|
||||
If your local computer doesn't have a powerful GPU, you can utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
## Evaluating ACT
|
||||
|
||||
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_robot \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Your task description" \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
```
|
||||
@@ -31,15 +31,15 @@ Then, spin up a policy server (in one terminal, or in a separate machine) specif
|
||||
You can spin up a policy server running:
|
||||
|
||||
```shell
|
||||
python src/lerobot/async_inference/policy_server.py \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080 \
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
|
||||
This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with:
|
||||
|
||||
```shell
|
||||
python src/lerobot/async_inference/robot_client.py \
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
@@ -113,9 +113,9 @@ As such, spinning up a policy server is as easy as specifying the host address a
|
||||
<hfoptions id="start_policy_server">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.scripts.server.policy_server \
|
||||
--host="localhost" \
|
||||
--port=8080
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
@@ -148,7 +148,7 @@ The `RobotClient` streams observations to the `PolicyServer`, and receives actio
|
||||
<hfoptions id="start_robot_client">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python src/lerobot/async_inference/robot_client.py \
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
|
||||
122
docs/source/groot.mdx
Normal file
122
docs/source/groot.mdx
Normal file
@@ -0,0 +1,122 @@
|
||||
# GR00T N1.5 Policy
|
||||
|
||||
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
|
||||
|
||||
This document outlines the specifics of its integration and usage within the LeRobot framework.
|
||||
|
||||
## Model Overview
|
||||
|
||||
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
|
||||
|
||||
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
|
||||
|
||||
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
|
||||
|
||||
Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes:
|
||||
|
||||
- Real captured data from robots.
|
||||
- Synthetic data generated using NVIDIA Isaac GR00T Blueprint.
|
||||
- Internet-scale video data.
|
||||
|
||||
This approach allows the model to be highly adaptable through post-training for specific embodiments, tasks, and environments.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
As of today, GR00T N1.5 requires flash attention for it's internal working.
|
||||
|
||||
We are working on making this optional, but in the meantime that means that we require an extra installation step and it can only be used in CUDA enabled devices.
|
||||
|
||||
1. Following the Environment Setup of our [Installation Guide](./installation). **Attention** don't install `lerobot` in this step.
|
||||
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for your system
|
||||
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
|
||||
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
|
||||
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
|
||||
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
|
||||
```
|
||||
|
||||
3. Install LeRobot by running:
|
||||
|
||||
```bash
|
||||
pip install lerobot[groot] # consider also installing libero,dev and test tags
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use GR00T in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=groot
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Training Command Example
|
||||
|
||||
Here's a complete training command for finetuning the base GR00T model on your own dataset:
|
||||
|
||||
```bash
|
||||
# Using a multi-GPU setup
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=$NUM_GPUS \
|
||||
$(which lerobot-train) \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--save_checkpoint=true \
|
||||
--batch_size=$BATCH_SIZE \
|
||||
--steps=$NUM_STEPS \
|
||||
--save_freq=$SAVE_FREQ \
|
||||
--log_freq=$LOG_FREQ \
|
||||
--policy.push_to_hub=true \
|
||||
--policy.type=groot \
|
||||
--policy.repo_id=$REPO_ID \
|
||||
--policy.tune_diffusion_model=false \
|
||||
--dataset.repo_id=$DATASET_ID \
|
||||
--wandb.enable=true \
|
||||
--wandb.disable_artifact=true \
|
||||
--job_name=$JOB_NAME
|
||||
```
|
||||
|
||||
## Performance Results
|
||||
|
||||
### Libero Benchmark Results
|
||||
|
||||
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
|
||||
|
||||
| Benchmark | LeRobot Implementation | GR00T Reference |
|
||||
| ------------------ | ---------------------- | --------------- |
|
||||
| **Libero Spatial** | 82.0% | 92.0% |
|
||||
| **Libero Object** | 99.0% | 92.0% |
|
||||
| **Libero Long** | 82.0% | 76.0% |
|
||||
| **Average** | 87.0% | 87.0% |
|
||||
|
||||
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=bi_so100_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
--robot.id=bimanual_follower \
|
||||
--robot.cameras='{ right: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30},
|
||||
left: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
|
||||
top: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30},
|
||||
}' \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm"
|
||||
--policy.path=<user>/groot-bimanual # your trained model
|
||||
--dataset.episode_time_s=30
|
||||
--dataset.reset_time_s=10
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T).
|
||||
@@ -165,7 +165,7 @@ huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
HF_USER=$(hf auth whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
@@ -513,13 +513,14 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.policies.factory import make_processor
|
||||
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -562,7 +563,7 @@ init_rerun(session_name="recording")
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
preprocessor, postprocessor = make_processor(
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
# Installation
|
||||
|
||||
## Install [`miniforge`](https://conda-forge.org/download/)
|
||||
|
||||
```bash
|
||||
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
||||
bash Miniforge3-$(uname)-$(uname -m).sh
|
||||
```
|
||||
|
||||
## Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install)
|
||||
Create a virtual environment with Python 3.10, using conda:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
@@ -14,7 +21,7 @@ Then activate your conda environment, you have to do this each time you open a s
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
@@ -74,6 +81,9 @@ _Replace `[...]` with your desired features._
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install libero or pi, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
@@ -91,7 +101,7 @@ LeRobot provides optional extras for specific functionalities. Multiple extras c
|
||||
|
||||
### Simulations
|
||||
|
||||
Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), `xarm` ([gym-xarm](https://github.com/huggingface/gym-xarm)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht))
|
||||
Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht))
|
||||
Example:
|
||||
|
||||
```bash
|
||||
|
||||
@@ -8,7 +8,7 @@ To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blo
|
||||
|
||||
- Your own robot which exposes a communication interface (e.g. serial, CAN, TCP)
|
||||
- A way to read sensor data and send motor commands programmatically, e.g. manufacturer's SDK or API, or your own protocol implementation.
|
||||
- LeRobot installed in your environment. Follow our [Installation Guide](./installation.mdx).
|
||||
- LeRobot installed in your environment. Follow our [Installation Guide](./installation).
|
||||
|
||||
## Choose your motors
|
||||
|
||||
@@ -65,7 +65,7 @@ class MyCoolRobotConfig(RobotConfig):
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
[Cameras tutorial](./cameras.mdx) to understand how to detect and add your camera.
|
||||
[Cameras tutorial](./cameras) to understand how to detect and add your camera.
|
||||
|
||||
Next, we'll create our actual robot class which inherits from `Robot`. This abstract class defines a contract you must follow for your robot to be usable with the rest of the LeRobot tools.
|
||||
|
||||
@@ -208,34 +208,36 @@ LeRobot supports saving and loading calibration data automatically. This is usef
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
> @property
|
||||
> def is_calibrated(self) -> bool:
|
||||
> return True
|
||||
>
|
||||
> def calibrate(self) -> None:
|
||||
> pass
|
||||
> ```
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
pass
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### `is_calibrated`
|
||||
|
||||
This should reflect whether your robot has the required calibration loaded.
|
||||
|
||||
```
|
||||
<!-- prettier-ignore-end -->python
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.bus.is_calibrated
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### `calibrate()`
|
||||
|
||||
The goal of the calibration is twofold:
|
||||
- Know the physical range of motion of each motors in order to only send commands within this range.
|
||||
- Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
|
||||
|
||||
- Know the physical range of motion of each motors in order to only send commands within this range.
|
||||
- Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
|
||||
|
||||
It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this.
|
||||
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
def calibrate(self) -> None:
|
||||
@@ -335,6 +337,134 @@ For implementing teleoperation devices, we also provide a [`Teleoperator`](https
|
||||
|
||||
The main differences are in the I/O functions: a teleoperator allows you to produce action via `get_action` and can receive feedback actions via `send_feedback`. Feedback could be anything controllable on the teleoperation device that could help the person controlling it understand the consequences of the actions sent. Think motion/force feedback on a leader arm, vibrations on a gamepad controller for example. To implement a teleoperator, you can follow this same tutorial and adapt it for these two methods.
|
||||
|
||||
## Using Your Own `LeRobot` Devices 🔌
|
||||
|
||||
You can easily extend `lerobot` with your own custom hardware—be it a camera, robot, or teleoperation device—by creating a separate, installable Python package. If you follow a few simple conventions, the `lerobot` command-line tools (like `lerobot-teleop` and `lerobot-record`) will **automatically discover and integrate your creations** without requiring any changes to the `lerobot` source code.
|
||||
|
||||
This guide outlines the conventions your plugin must follow.
|
||||
|
||||
### The 4 Core Conventions
|
||||
|
||||
To ensure your custom device is discoverable, you must adhere to the following four rules.
|
||||
|
||||
#### 1\. Create an Installable Package with a Specific Prefix
|
||||
|
||||
Your project must be a standard, installable Python package. Crucially, the name of your package (as defined in `pyproject.toml` or `setup.py`) must begin with one of these prefixes:
|
||||
|
||||
- `lerobot_robot_` for a robot.
|
||||
- `lerobot_camera_` for a camera.
|
||||
- `lerobot_teleoperator_` for a teleoperation device.
|
||||
|
||||
This prefix system is how `lerobot` automatically finds your plugin in the Python environment.
|
||||
|
||||
#### 2\. Follow the `SomethingConfig`/`Something` Naming Pattern
|
||||
|
||||
Your device's implementation class must be named after its configuration class, simply by removing the `Config` suffix.
|
||||
|
||||
- **Config Class:** `MyAwesomeTeleopConfig`
|
||||
- **Device Class:** `MyAwesomeTeleop`
|
||||
|
||||
#### 3\. Place Your Files in a Predictable Structure
|
||||
|
||||
The device class (`MyAwesomeTeleop`) must be located in a predictable module relative to its configuration class (`MyAwesomeTeleopConfig`). `lerobot` will automatically search in these locations:
|
||||
|
||||
- In the **same module** as the config class.
|
||||
- In a **submodule named after the device** (e.g., `my_awesome_teleop.py`).
|
||||
|
||||
The recommended and simplest structure is to place them in separate, clearly named files within the same directory.
|
||||
|
||||
#### 4\. Expose Classes in `__init__.py`
|
||||
|
||||
Your package's `__init__.py` file should import and expose both the configuration and the device classes, making them easily accessible.
|
||||
|
||||
### Putting It All Together: A Complete Example
|
||||
|
||||
Let's create a new teleoperator called `my_awesome_teleop`.
|
||||
|
||||
#### Directory Structure
|
||||
|
||||
Here is what the project folder should look like. The package name, `lerobot_teleoperator_my_awesome_teleop`, follows **Convention \#1**.
|
||||
|
||||
```
|
||||
lerobot_teleoperator_my_awesome_teleop/
|
||||
├── pyproject.toml # (or setup.py) lists lerobot as a dependency
|
||||
└── lerobot_teleoperator_my_awesome_teleop/
|
||||
├── __init__.py
|
||||
├── config_my_awesome_teleop.py
|
||||
└── my_awesome_teleop.py
|
||||
```
|
||||
|
||||
#### File Contents
|
||||
|
||||
- **`config_my_awesome_teleop.py`**: Defines the configuration class. Note the `Config` suffix (**Convention \#2**).
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||
|
||||
@TeleoperatorConfig.register_subclass("my_awesome_teleop")
|
||||
@dataclass
|
||||
class MyAwesomeTeleopConfig(TeleoperatorConfig):
|
||||
# Your configuration fields go here
|
||||
port: str = "192.168.1.1"
|
||||
```
|
||||
|
||||
- **`my_awesome_teleop.py`**: Implements the device. The class name `MyAwesomeTeleop` matches its config class name (**Convention \#2**). This file structure adheres to **Convention \#3**.
|
||||
|
||||
```python
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
|
||||
from .config_my_awesome_teleop import MyAwesomeTeleopConfig
|
||||
|
||||
class MyAwesomeTeleop(Teleoperator):
|
||||
config_class = MyAwesomeTeleopConfig
|
||||
name = "my_awesome_teleop"
|
||||
|
||||
def __init__(self, config: MyAwesomeTeleopConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Your device logic (e.g., connect) goes here
|
||||
```
|
||||
|
||||
- **`__init__.py`**: Exposes the key classes (**Convention \#4**).
|
||||
|
||||
```python
|
||||
from .config_my_awesome_teleop import MyAwesomeTeleopConfig
|
||||
from .my_awesome_teleop import MyAwesomeTeleop
|
||||
```
|
||||
|
||||
### Installation and Usage
|
||||
|
||||
1. **Install your new plugin in your Python environment.** You can install your local plugin package using `pip`'s editable mode or from PyPi.
|
||||
|
||||
```bash
|
||||
# Locally
|
||||
# Navigate to your plugin's root directory and install it
|
||||
cd lerobot_teleoperator_my_awesome_teleop
|
||||
pip install -e .
|
||||
|
||||
# From PyPi
|
||||
pip install lerobot_teleoperator_my_awesome_teleop
|
||||
```
|
||||
|
||||
2. **Use it directly from the command line.** Now, you can use your custom device by referencing its type.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate --teleop.type=my_awesome_teleop \
|
||||
# other arguments
|
||||
```
|
||||
|
||||
And that's it\! Your custom device is now fully integrated.
|
||||
|
||||
### Looking for an example ?
|
||||
|
||||
Check out these two packages from the community:
|
||||
|
||||
- https://github.com/SpesRobotics/lerobot-robot-xarm
|
||||
- https://github.com/SpesRobotics/lerobot-teleoperator-teleop
|
||||
|
||||
## Wrapping Up
|
||||
|
||||
Once your robot class is complete, you can leverage the LeRobot ecosystem:
|
||||
|
||||
@@ -297,9 +297,9 @@ LeRobot provides many registered processor steps. Here are the most commonly use
|
||||
|
||||
### Next Steps
|
||||
|
||||
- **[Implement Your Own Processor](implement_your_own_processor.mdx)** - Create custom processor steps
|
||||
- **[Debug Your Pipeline](debug_processor_pipeline.mdx)** - Troubleshoot and optimize pipelines
|
||||
- **[Processors for Robots and Teleoperators](processors_robots_teleop.mdx)** - Real-world integration patterns
|
||||
- **[Implement Your Own Processor](./implement_your_own_processor)** - Create custom processor steps
|
||||
- **[Debug Your Pipeline](./debug_processor_pipeline)** - Troubleshoot and optimize pipelines
|
||||
- **[Processors for Robots and Teleoperators](./processors_robots_teleop)** - Real-world integration patterns
|
||||
|
||||
## Summary
|
||||
|
||||
|
||||
@@ -279,3 +279,36 @@ python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DAT
|
||||
- Aggregates parquet files: `episode-0000.parquet`, `episode-0001.parquet`, … → **`file-0000.parquet`**, …
|
||||
- Aggregates mp4 files: `episode-0000.mp4`, `episode-0001.mp4`, … → **`file-0000.mp4`**, …
|
||||
- Updates `meta/episodes/*` (chunked Parquet) with per‑episode lengths, tasks, and byte/frame offsets.
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Always call `finalize()` before pushing
|
||||
|
||||
When creating or recording datasets, you **must** call `dataset.finalize()` to properly close parquet writers. See the [PR #1903](https://github.com/huggingface/lerobot/pull/1903) for more details.
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Create dataset and record episodes
|
||||
dataset = LeRobotDataset.create(...)
|
||||
|
||||
for episode in range(num_episodes):
|
||||
# Record frames
|
||||
for frame in episode_data:
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
|
||||
# Call finalize() when done recording and before push_to_hub()
|
||||
dataset.finalize() # Closes parquet writers, writes metadata footers
|
||||
dataset.push_to_hub()
|
||||
```
|
||||
|
||||
**Why is this necessary?**
|
||||
|
||||
Dataset v3.0 uses incremental parquet writing with buffered metadata for efficiency. The `finalize()` method:
|
||||
|
||||
- Flushes any buffered episode metadata to disk
|
||||
- Closes parquet writers to write footer metadata, otherwise the parquet files will be corrupt
|
||||
- Ensures the dataset is valid for loading
|
||||
|
||||
Without calling `finalize()`, your parquet files will be incomplete and the dataset won't load properly.
|
||||
|
||||
@@ -125,3 +125,42 @@ lerobot-train \
|
||||
LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation:
|
||||
|
||||
- `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud)
|
||||
|
||||
## Reproducing π₀.₅ results
|
||||
|
||||
We reproduce the results of π₀.₅ on the LIBERO benchmark using the LeRobot implementation. We take the Physical Intelligence LIBERO base model (`pi05_libero`) and finetune for an additional 6k steps in bfloat16, with batch size of 256 on 8 H100 GPUs using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
|
||||
|
||||
The finetuned model can be found here:
|
||||
|
||||
- **π₀.₅ LIBERO**: [lerobot/pi05_libero_finetuned](https://huggingface.co/lerobot/pi05_libero_finetuned)
|
||||
|
||||
We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--output_dir=/logs/ \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10 \
|
||||
--policy.path=pi05_libero_finetuned \
|
||||
--policy.n_action_steps=10 \
|
||||
--output_dir=./eval_logs/ \
|
||||
--env.max_parallel_tasks=1
|
||||
```
|
||||
|
||||
**Note:** We set `n_action_steps=10`, similar to the original OpenPI implementation.
|
||||
|
||||
### Results
|
||||
|
||||
We obtain the following results on the LIBERO benchmark:
|
||||
|
||||
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
|
||||
| -------- | -------------- | ------------- | ----------- | --------- | -------- |
|
||||
| **π₀.₅** | 97.0 | 99.0 | 98.0 | 96.0 | **97.5** |
|
||||
|
||||
These results are consistent with the original [results](https://github.com/Physical-Intelligence/openpi/tree/main/examples/libero#results) reported by Physical Intelligence:
|
||||
|
||||
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
|
||||
| -------- | -------------- | ------------- | ----------- | --------- | --------- |
|
||||
| **π₀.₅** | 98.8 | 98.2 | 98.0 | 92.4 | **96.85** |
|
||||
|
||||
80
docs/source/metaworld.mdx
Normal file
80
docs/source/metaworld.mdx
Normal file
@@ -0,0 +1,80 @@
|
||||
# Meta-World
|
||||
|
||||
Meta-World is a well-designed, open-source simulation benchmark for multi-task and meta reinforcement learning in continuous-control robotic manipulation. It gives researchers a shared, realistic playground to test whether algorithms can _learn many different tasks_ and _generalize quickly to new ones_ — two central challenges for real-world robotics.
|
||||
|
||||
- 📄 [MetaWorld paper](https://arxiv.org/pdf/1910.10897)
|
||||
- 💻 [Original MetaWorld repo](https://github.com/Farama-Foundation/Metaworld)
|
||||
|
||||

|
||||
|
||||
## Why Meta-World matters
|
||||
|
||||
- **Diverse, realistic tasks.** Meta-World bundles a large suite of simulated manipulation tasks (50 in the MT50 suite) using everyday objects and a common tabletop Sawyer arm. This diversity exposes algorithms to a wide variety of dynamics, contacts and goal specifications while keeping a consistent control and observation structure.
|
||||
- **Focus on generalization and multi-task learning.** By evaluating across task distributions that share structure but differ in goals and objects, Meta-World reveals whether an agent truly learns transferable skills rather than overfitting to a narrow task.
|
||||
- **Standardized evaluation protocol.** It provides clear evaluation modes and difficulty splits, so different methods can be compared fairly across easy, medium, hard and very-hard regimes.
|
||||
- **Empirical insight.** Past evaluations on Meta-World show impressive progress on some fronts, but also highlight that current multi-task and meta-RL methods still struggle with large, diverse task sets. That gap points to important research directions.
|
||||
|
||||
## What it enables in LeRobot
|
||||
|
||||
In LeRobot, you can evaluate any policy or vision-language-action (VLA) model on Meta-World tasks and get a clear success-rate measure. The integration is designed to be straightforward:
|
||||
|
||||
- We provide a LeRobot-ready dataset for Meta-World (MT50) on the HF Hub: `https://huggingface.co/datasets/lerobot/metaworld_mt50`.
|
||||
- This dataset is formatted for the MT50 evaluation that uses all 50 tasks (the most challenging multi-task setting).
|
||||
- MT50 gives the policy a one-hot task vector and uses fixed object/goal positions for consistency.
|
||||
|
||||
- Task descriptions and the exact keys required for evaluation are available in the repo/dataset — use these to ensure your policy outputs the right success signals.
|
||||
|
||||
## Quick start, train a SmolVLA policy on Meta-World
|
||||
|
||||
Example command to train a SmolVLA policy on a subset of tasks:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--policy.repo_id=${HF_USER}/metaworld-test \
|
||||
--policy.load_vlm_weights=true \
|
||||
--dataset.repo_id=lerobot/metaworld_mt50 \
|
||||
--env.type=metaworld \
|
||||
--env.task=assembly-v3,dial-turn-v3,handle-press-side-v3 \
|
||||
--output_dir=./outputs/ \
|
||||
--steps=100000 \
|
||||
--batch_size=4 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval_freq=1000
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- `--env.task` accepts explicit task lists (comma separated) or difficulty groups (e.g., `env.task="hard"`).
|
||||
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
|
||||
- **Gymnasium Assertion Error**: if you encounter an error like
|
||||
`AssertionError: ['human', 'rgb_array', 'depth_array']` when running MetaWorld environments, this comes from a mismatch between MetaWorld and your Gymnasium version.
|
||||
We recommend using:
|
||||
|
||||
```bash
|
||||
pip install "gymnasium==1.1.0"
|
||||
```
|
||||
|
||||
to ensure proper compatibility.
|
||||
|
||||
## Quick start — evaluate a trained policy
|
||||
|
||||
To evaluate a trained policy on the Meta-World medium difficulty split:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path="your-policy-id" \
|
||||
--env.type=metaworld \
|
||||
--env.task=medium \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=2
|
||||
```
|
||||
|
||||
This will run episodes and return per-task success rates using the standard Meta-World evaluation keys.
|
||||
|
||||
## Practical tips
|
||||
|
||||
- If you care about generalization, run on the full MT50 suite — it’s intentionally challenging and reveals strengths/weaknesses better than a few narrow tasks.
|
||||
- Use the one-hot task conditioning for multi-task training (MT10 / MT50 conventions) so policies have explicit task context.
|
||||
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
|
||||
125
docs/source/multi_gpu_training.mdx
Normal file
125
docs/source/multi_gpu_training.mdx
Normal file
@@ -0,0 +1,125 @@
|
||||
# Multi-GPU Training
|
||||
|
||||
This guide shows you how to train policies on multiple GPUs using [Hugging Face Accelerate](https://huggingface.co/docs/accelerate).
|
||||
|
||||
## Installation
|
||||
|
||||
First, ensure you have accelerate installed:
|
||||
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
## Training with Multiple GPUs
|
||||
|
||||
You can launch training in two ways:
|
||||
|
||||
### Option 1: Without config (specify parameters directly)
|
||||
|
||||
You can specify all parameters directly in the command without running `accelerate config`:
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=2 \
|
||||
$(which lerobot-train) \
|
||||
--dataset.repo_id=${HF_USER}/my_dataset \
|
||||
--policy.type=act \
|
||||
--policy.repo_id=${HF_USER}/my_trained_policy \
|
||||
--output_dir=outputs/train/act_multi_gpu \
|
||||
--job_name=act_multi_gpu \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
**Key accelerate parameters:**
|
||||
|
||||
- `--multi_gpu`: Enable multi-GPU training
|
||||
- `--num_processes=2`: Number of GPUs to use
|
||||
- `--mixed_precision=fp16`: Use fp16 mixed precision (or `bf16` if supported)
|
||||
|
||||
### Option 2: Using accelerate config
|
||||
|
||||
If you prefer to save your configuration, you can optionally configure accelerate for your hardware setup by running:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
This interactive setup will ask you questions about your training environment (number of GPUs, mixed precision settings, etc.) and saves the configuration for future use. For a simple multi-GPU setup on a single machine, you can use these recommended settings:
|
||||
|
||||
- Compute environment: This machine
|
||||
- Number of machines: 1
|
||||
- Number of processes: (number of GPUs you want to use)
|
||||
- GPU ids to use: (leave empty to use all)
|
||||
- Mixed precision: fp16 or bf16 (recommended for faster training)
|
||||
|
||||
Then launch training with:
|
||||
|
||||
```bash
|
||||
accelerate launch $(which lerobot-train) \
|
||||
--dataset.repo_id=${HF_USER}/my_dataset \
|
||||
--policy.type=act \
|
||||
--policy.repo_id=${HF_USER}/my_trained_policy \
|
||||
--output_dir=outputs/train/act_multi_gpu \
|
||||
--job_name=act_multi_gpu \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
When you launch training with accelerate:
|
||||
|
||||
1. **Automatic detection**: LeRobot automatically detects if it's running under accelerate
|
||||
2. **Data distribution**: Your batch is automatically split across GPUs
|
||||
3. **Gradient synchronization**: Gradients are synchronized across GPUs during backpropagation
|
||||
4. **Single process logging**: Only the main process logs to wandb and saves checkpoints
|
||||
|
||||
## Learning Rate and Training Steps Scaling
|
||||
|
||||
**Important:** LeRobot does **NOT** automatically scale learning rates or training steps based on the number of GPUs. This gives you full control over your training hyperparameters.
|
||||
|
||||
### Why No Automatic Scaling?
|
||||
|
||||
Many distributed training frameworks automatically scale the learning rate by the number of GPUs (e.g., `lr = base_lr × num_gpus`).
|
||||
However, LeRobot keeps the learning rate exactly as you specify it.
|
||||
|
||||
### When and How to Scale
|
||||
|
||||
If you want to scale your hyperparameters when using multiple GPUs, you should do it manually:
|
||||
|
||||
**Learning Rate Scaling:**
|
||||
|
||||
```bash
|
||||
# Example: 2 GPUs with linear LR scaling
|
||||
# Base LR: 1e-4, with 2 GPUs -> 2e-4
|
||||
accelerate launch --num_processes=2 $(which lerobot-train) \
|
||||
--optimizer.lr=2e-4 \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy=act
|
||||
```
|
||||
|
||||
**Training Steps Scaling:**
|
||||
|
||||
Since the effective batch size `bs` increases with multiple GPUs (batch_size × num_gpus), you may want to reduce the number of training steps proportionally:
|
||||
|
||||
```bash
|
||||
# Example: 2 GPUs with effective batch size 2x larger
|
||||
# Original: batch_size=8, steps=100000
|
||||
# With 2 GPUs: batch_size=8 (16 in total), steps=50000
|
||||
accelerate launch --num_processes=2 $(which lerobot-train) \
|
||||
--batch_size=8 \
|
||||
--steps=50000 \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy=act
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
|
||||
- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts. Non-main processes have console logging disabled to prevent duplicate output.
|
||||
- The effective batch size is `batch_size × num_gpus`. If you use 4 GPUs with `--batch_size=8`, your effective batch size is 32.
|
||||
- Learning rate scheduling is handled correctly across multiple processes—LeRobot sets `step_scheduler_with_optimizer=False` to prevent accelerate from adjusting scheduler steps based on the number of processes.
|
||||
- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility.
|
||||
- WandB integration automatically initializes only on the main process, preventing multiple runs from being created.
|
||||
|
||||
For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate). If you want to learn more about how to train on a large number of GPUs, checkout this awesome guide: [Ultrascale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).
|
||||
328
docs/source/openarms.mdx
Normal file
328
docs/source/openarms.mdx
Normal file
@@ -0,0 +1,328 @@
|
||||
# OpenArms Robot
|
||||
|
||||
OpenArms is a 7 DOF robotic arm with a gripper, designed by [Enactic, Inc.](https://www.enactic.com/) It uses Damiao motors controlled via CAN bus communication and MIT control mode for smooth, precise motion.
|
||||
|
||||
## Hardware Overview
|
||||
|
||||
- **7 DOF per arm** (14 DOF total for dual arm setup)
|
||||
- **1 gripper per arm** (2 grippers total)
|
||||
- **Damiao motors** with 4 different types:
|
||||
- **DM8009** (DM-J8009P-2EC) for shoulders (J1, J2) - high torque
|
||||
- **DM4340** for shoulder rotation and elbow (J3, J4)
|
||||
- **DM4310** (DM-J4310-2EC V1.1) for wrist (J5, J6, J7) and gripper (J8)
|
||||
- **24V power supply** required
|
||||
- **CAN interface device**:
|
||||
- **Linux**: Any SocketCAN-compatible adapter
|
||||
- **macOS**: CANable, PEAK PCAN-USB, or Kvaser USBcan
|
||||
- Proper CAN wiring (CANH, CANL, 120Ω termination)
|
||||
|
||||
|
||||
## Motor Configuration
|
||||
|
||||
Each arm has the following motor configuration based on the [OpenArm setup guide](https://docs.openarm.dev/software/setup/):
|
||||
|
||||
| Joint | Motor | Motor Type | Sender CAN ID | Receiver ID | Description |
|
||||
|-------|-------|------------|---------------|-------------|-------------|
|
||||
| J1 | joint_1 | DM8009 | 0x01 | 0x11 | Shoulder pan |
|
||||
| J2 | joint_2 | DM8009 | 0x02 | 0x12 | Shoulder lift |
|
||||
| J3 | joint_3 | DM4340 | 0x03 | 0x13 | Shoulder rotation |
|
||||
| J4 | joint_4 | DM4340 | 0x04 | 0x14 | Elbow flex |
|
||||
| J5 | joint_5 | DM4310 | 0x05 | 0x15 | Wrist roll |
|
||||
| J6 | joint_6 | DM4310 | 0x06 | 0x16 | Wrist pitch |
|
||||
| J7 | joint_7 | DM4310 | 0x07 | 0x17 | Wrist rotation |
|
||||
| J8 | gripper | DM4310 | 0x08 | 0x18 | Gripper |
|
||||
|
||||
For dual arm setups, the left arm uses IDs 0x09-0x10 for joints 1-8 with the same motor types.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Install system dependencies
|
||||
sudo apt install can-utils iproute2
|
||||
|
||||
# Install LeRobot with OpenArms support
|
||||
pip install -e ".[openarms]"
|
||||
```
|
||||
|
||||
## Setup Guide
|
||||
|
||||
### Step 1: Motor ID Configuration
|
||||
|
||||
**IMPORTANT**: Before using the robot, motors must be configured with the correct CAN IDs.
|
||||
|
||||
Refer to the [OpenArm Motor ID Configuration Guide](https://docs.openarm.dev/software/setup/motor-id) for detailed instructions using the Damiao Debugging Tools on Windows.
|
||||
|
||||
Key points:
|
||||
- Each motor needs a unique **Sender CAN ID** (0x01-0x08)
|
||||
- Each motor needs a unique **Receiver/Master ID** (0x11-0x18)
|
||||
- Use the Damiao Debugging Tools to set these IDs
|
||||
|
||||
### Step 2: Setup CAN Interface
|
||||
|
||||
Configure your CAN interface as described in the [OpenArm CAN Setup Guide](https://docs.openarm.dev/software/setup/can-setup):
|
||||
|
||||
#### Linux (SocketCAN)
|
||||
|
||||
```bash
|
||||
# Find your CAN interface
|
||||
ip link show
|
||||
|
||||
# Configure can0, 1, 2, 3
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 type can bitrate 1000000
|
||||
sudo ip link set can0 up
|
||||
|
||||
sudo ip link set can1 down
|
||||
sudo ip link set can1 type can bitrate 1000000
|
||||
sudo ip link set can1 up
|
||||
|
||||
sudo ip link set can2 down
|
||||
sudo ip link set can2 type can bitrate 1000000
|
||||
sudo ip link set can2 up
|
||||
|
||||
sudo ip link set can3 down
|
||||
sudo ip link set can3 type can bitrate 1000000
|
||||
sudo ip link set can3 up
|
||||
|
||||
# Verify configuration
|
||||
ip link show can0
|
||||
```
|
||||
|
||||
or run:
|
||||
|
||||
`examples/openarms/setup_can.sh`
|
||||
|
||||
### Testing canbus and motor connection
|
||||
|
||||
Please run this script to check if all motors can be found and to find your can-fd speed: `python examples/openarms/debug_can_communication.py`
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Setup
|
||||
|
||||
|
||||
```python
|
||||
from lerobot.robots.openarms import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
|
||||
# Configure for dual arm setup
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
can_interface="socketcan", # Or "auto" for auto-detection
|
||||
id="openarms_dual",
|
||||
is_dual_arm=True,
|
||||
)
|
||||
|
||||
robot = OpenArmsFollower(config)
|
||||
robot.connect()
|
||||
```
|
||||
|
||||
### Calibration
|
||||
|
||||
On first use, you'll need to calibrate the robot:
|
||||
|
||||
```python
|
||||
robot.calibrate()
|
||||
```
|
||||
|
||||
The calibration process will:
|
||||
1. Disable torque on all motors
|
||||
2. Ask you to position arms in **hanging position with grippers closed**
|
||||
3. Set this as the zero position
|
||||
4. Ask you to move each joint through its full range
|
||||
5. Record min/max positions for each joint
|
||||
6. Save calibration to file
|
||||
|
||||
### Reading Observations
|
||||
|
||||
The robot provides comprehensive state information:
|
||||
|
||||
```python
|
||||
observation = robot.get_observation()
|
||||
|
||||
# Observation includes for each motor:
|
||||
# - {motor_name}.pos: Position in degrees
|
||||
# - {motor_name}.vel: Velocity in degrees/second
|
||||
# - {motor_name}.torque: Motor torque
|
||||
# - {camera_name}: Camera images (if configured)
|
||||
|
||||
print(f"Right arm joint 1 position: {observation['right_joint_1.pos']:.1f}°")
|
||||
print(f"Right arm joint 1 velocity: {observation['right_joint_1.vel']:.1f}°/s")
|
||||
print(f"Right arm joint 1 torque: {observation['right_joint_1.torque']:.3f} N·m")
|
||||
```
|
||||
|
||||
### Sending Actions
|
||||
|
||||
```python
|
||||
# Send target positions (in degrees)
|
||||
action = {
|
||||
"right_joint_1.pos": 45.0,
|
||||
"right_joint_2.pos": -30.0,
|
||||
# ... all joints
|
||||
"right_gripper.pos": 45.0, # Half-closed
|
||||
}
|
||||
|
||||
actual_action = robot.send_action(action)
|
||||
```
|
||||
|
||||
### Gripper Control
|
||||
|
||||
```python
|
||||
# Open gripper
|
||||
robot.open_gripper(arm="right")
|
||||
|
||||
# Close gripper
|
||||
robot.close_gripper(arm="right")
|
||||
```
|
||||
|
||||
## Safety Features
|
||||
|
||||
### 1. Maximum Relative Target
|
||||
|
||||
Limits how far a joint can move in a single command to prevent sudden movements:
|
||||
|
||||
```python
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
# Limit all joints to 10 degrees per command
|
||||
max_relative_target=10.0,
|
||||
|
||||
# Or set per-motor limits
|
||||
max_relative_target={
|
||||
"right_joint_1": 15.0, # Slower moving joint
|
||||
"right_joint_2": 10.0,
|
||||
"right_gripper": 5.0, # Very slow gripper
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**How it works**: If current position is 50° and you command 80°, with `max_relative_target=10.0`, the robot will only move to 60° in that step.
|
||||
|
||||
### 2. Torque Limits
|
||||
|
||||
Control maximum torque output, especially important for grippers and teleoperation:
|
||||
|
||||
```python
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
# Gripper torque limit (fraction of motor's max torque)
|
||||
gripper_torque_limit=0.5, # 50% of max torque
|
||||
)
|
||||
```
|
||||
|
||||
Lower torque limits prevent damage when gripping delicate objects.
|
||||
|
||||
### 3. MIT Control Gains
|
||||
|
||||
Control responsiveness and stability via PID-like gains:
|
||||
|
||||
```python
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
position_kp=10.0, # Position gain (higher = more responsive)
|
||||
position_kd=0.5, # Velocity damping (higher = more damped)
|
||||
)
|
||||
```
|
||||
|
||||
**Guidelines**:
|
||||
- **For following (robot)**: Higher gains for responsiveness
|
||||
- `position_kp=10.0`, `position_kd=0.5`
|
||||
- **For teleoperation (leader)**: Lower gains or disable torque for manual movement
|
||||
- `manual_control=True` (torque disabled)
|
||||
|
||||
### 4. Velocity Limits
|
||||
|
||||
Velocity limits are enforced by the Damiao motors based on motor type. For DM4310:
|
||||
- Max velocity: 30 rad/s ≈ 1718°/s
|
||||
|
||||
The motors will automatically limit velocity to safe values.
|
||||
|
||||
## Teleoperation
|
||||
|
||||
### Leader Arm Setup
|
||||
|
||||
The leader arm is moved manually (torque disabled) to generate commands:
|
||||
|
||||
```python
|
||||
from lerobot.teleoperators.openarms import OpenArmsLeader
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
|
||||
config = OpenArmsLeaderConfig(
|
||||
port="can1", # Separate CAN interface for leader
|
||||
id="openarms_leader",
|
||||
manual_control=True, # Torque disabled for manual movement
|
||||
is_dual_arm=True,
|
||||
)
|
||||
|
||||
leader = OpenArmsLeader(config)
|
||||
leader.connect()
|
||||
|
||||
# Read current position as action
|
||||
action = leader.get_action()
|
||||
# action contains positions for all joints in degrees
|
||||
```
|
||||
|
||||
### Safety Considerations for Teleoperation
|
||||
|
||||
1. **Use separate CAN interfaces** for leader and follower to avoid conflicts
|
||||
2. **Enable max_relative_target** on follower to smooth abrupt movements
|
||||
3. **Lower torque limits** on follower to prevent damage from tracking errors
|
||||
4. **Test with one arm** before enabling dual arm teleoperation
|
||||
5. **Have emergency stop** ready (power switch or CAN disable)
|
||||
|
||||
```python
|
||||
# Recommended follower config for teleoperation
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
max_relative_target=5.0, # Small steps for smooth following
|
||||
gripper_torque_limit=0.3, # Low torque for safety
|
||||
position_kp=5.0, # Lower gains for gentler following
|
||||
position_kd=0.3,
|
||||
)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Motor Shaking/Unstable
|
||||
|
||||
- **Lower control gains**: Reduce `position_kp` and `position_kd`
|
||||
- **Check calibration**: Re-run calibration procedure
|
||||
- **Verify power**: Insufficient current can cause instability
|
||||
- **Check mechanical**: Loose connections, binding, or damaged components
|
||||
|
||||
### CAN Bus Errors
|
||||
|
||||
```bash
|
||||
# Check for errors
|
||||
ip -s link show can0
|
||||
|
||||
# Reset CAN interface
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 up
|
||||
```
|
||||
|
||||
### Control Mode
|
||||
|
||||
OpenArms uses **MIT control mode** which allows simultaneous control of:
|
||||
- Position (degrees)
|
||||
- Velocity (degrees/second)
|
||||
- Torque (N·m)
|
||||
- Position gain (Kp)
|
||||
- Velocity damping (Kd)
|
||||
|
||||
### Communication
|
||||
|
||||
- **Protocol**: CAN 2.0 at 1 Mbps (or CAN-FD at 5 Mbps)
|
||||
- **Frame format**: Standard 11-bit IDs
|
||||
- **Update rate**: Typically 50-100 Hz depending on motor count
|
||||
- **Latency**: ~10-20ms per motor command
|
||||
|
||||
## References
|
||||
|
||||
- [OpenArm Official Documentation](https://docs.openarm.dev/)
|
||||
- [OpenArm Setup Guide](https://docs.openarm.dev/software/setup/)
|
||||
- [Motor ID Configuration](https://docs.openarm.dev/software/setup/motor-id)
|
||||
- [CAN Interface Setup](https://docs.openarm.dev/software/setup/can-setup)
|
||||
- [Motor Communication Test](https://docs.openarm.dev/software/setup/configure-test)
|
||||
- [Damiao Motor Documentation](https://wiki.seeedstudio.com/damiao_series/)
|
||||
- [Enactic GitHub](https://github.com/enactic/openarm_can)
|
||||
@@ -79,7 +79,7 @@ After running the example:
|
||||
- Android: after starting the script, open the printed local URL on your phone, tap Start, then press and hold Move.
|
||||
- iOS: open HEBI Mobile I/O first; B1 enables motion. A3 controls the gripper.
|
||||
|
||||
Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop.mdx) guide.
|
||||
Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop) guide.
|
||||
|
||||
- Run this example to record a dataset, which saves absolute end effector observations and actions:
|
||||
|
||||
|
||||
79
docs/source/pi0.mdx
Normal file
79
docs/source/pi0.mdx
Normal file
@@ -0,0 +1,79 @@
|
||||
# π₀ (Pi0)
|
||||
|
||||
π₀ is a **Vision-Language-Action model for general robot control**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository.
|
||||
|
||||
## Model Overview
|
||||
|
||||
π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi0). Unlike traditional robot programs that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
|
||||
|
||||
### The Vision for Physical Intelligence
|
||||
|
||||
As described by Physical Intelligence, while AI has achieved remarkable success in digital domains, from chess-playing to drug discovery, human intelligence still dramatically outpaces AI in the physical world. To paraphrase Moravec's paradox, winning a game of chess represents an "easy" problem for AI, but folding a shirt or cleaning up a table requires solving some of the most difficult engineering problems ever conceived. π₀ represents a first step toward developing artificial physical intelligence that enables users to simply ask robots to perform any task they want, just like they can with large language models.
|
||||
|
||||
### Architecture and Approach
|
||||
|
||||
π₀ combines several key innovations:
|
||||
|
||||
- **Flow Matching**: Uses a novel method to augment pre-trained VLMs with continuous action outputs via flow matching (a variant of diffusion models)
|
||||
- **Cross-Embodiment Training**: Trained on data from 8 distinct robot platforms including UR5e, Bimanual UR5e, Franka, Bimanual Trossen, Bimanual ARX, Mobile Trossen, and Mobile Fibocom
|
||||
- **Internet-Scale Pre-training**: Inherits semantic knowledge from a pre-trained 3B parameter Vision-Language Model
|
||||
- **High-Frequency Control**: Outputs motor commands at up to 50 Hz for real-time dexterous manipulation
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. Install Pi0 dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
## Training Data and Capabilities
|
||||
|
||||
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
|
||||
|
||||
1. **Internet-Scale Pre-training**: Vision-language data from the web for semantic understanding
|
||||
2. **Open X-Embodiment Dataset**: Open-source robot manipulation datasets
|
||||
3. **Physical Intelligence Dataset**: Large and diverse dataset of dexterous tasks across 8 distinct robots
|
||||
|
||||
## Usage
|
||||
|
||||
To use π₀ in LeRobot, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=pi0
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=./outputs/pi0_training \
|
||||
--job_name=pi0_training \
|
||||
--policy.pretrained_path=lerobot/pi0_base \
|
||||
--policy.repo_id=your_repo_id \
|
||||
--policy.compile_model=true \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
- **`--policy.compile_model=true`**: Enables model compilation for faster training
|
||||
- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training
|
||||
- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency
|
||||
- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory
|
||||
- **`--policy.pretrained_path=lerobot/pi0_base`**: The base π₀ model you want to finetune, options are:
|
||||
- [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base)
|
||||
- [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset)
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
107
docs/source/pi05.mdx
Normal file
107
docs/source/pi05.mdx
Normal file
@@ -0,0 +1,107 @@
|
||||
# π₀.₅ (Pi05) Policy
|
||||
|
||||
π₀.₅ is a **Vision-Language-Action model with open-world generalization**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository.
|
||||
|
||||
## Model Overview
|
||||
|
||||
π₀.₅ represents a significant evolution from π₀, developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi05) to address a big challenge in robotics: **open-world generalization**. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
|
||||
|
||||
### The Generalization Challenge
|
||||
|
||||
As Physical Intelligence explains, the fundamental challenge isn't performing tasks of agility or dexterity, but generalization, the ability to correctly perform tasks in new settings with new objects. Consider a robot cleaning different homes: each home has different objects in different places. Generalization must occur at multiple levels:
|
||||
|
||||
- **Physical Level**: Understanding how to pick up a spoon (by the handle) or plate (by the edge), even with unseen objects in cluttered environments
|
||||
- **Semantic Level**: Understanding task semantics, where to put clothes and shoes (laundry hamper, not on the bed), and what tools are appropriate for cleaning spills
|
||||
- **Environmental Level**: Adapting to "messy" real-world environments like homes, grocery stores, offices, and hospitals
|
||||
|
||||
### Co-Training on Heterogeneous Data
|
||||
|
||||
The breakthrough innovation in π₀.₅ is **co-training on heterogeneous data sources**. The model learns from:
|
||||
|
||||
1. **Multimodal Web Data**: Image captioning, visual question answering, object detection
|
||||
2. **Verbal Instructions**: Humans coaching robots through complex tasks step-by-step
|
||||
3. **Subtask Commands**: High-level semantic behavior labels (e.g., "pick up the pillow" for an unmade bed)
|
||||
4. **Cross-Embodiment Robot Data**: Data from various robot platforms with different capabilities
|
||||
5. **Multi-Environment Data**: Static robots deployed across many different homes
|
||||
6. **Mobile Manipulation Data**: ~400 hours of mobile robot demonstrations
|
||||
|
||||
This diverse training mixture creates a "curriculum" that enables generalization across physical, visual, and semantic levels simultaneously.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. Install Pi0.5 dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=pi05
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Training Command Example
|
||||
|
||||
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py\
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi05 \
|
||||
--output_dir=./outputs/pi05_training \
|
||||
--job_name=pi05_training \
|
||||
--policy.repo_id=your_repo_id \
|
||||
--policy.pretrained_path=lerobot/pi05_base \
|
||||
--policy.compile_model=true \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--wandb.enable=true \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
- **`--policy.compile_model=true`**: Enables model compilation for faster training
|
||||
- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training
|
||||
- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency
|
||||
- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory
|
||||
- **`--policy.pretrained_path=lerobot/pi05_base`**: The base π₀.₅ model you want to finetune, options are:
|
||||
- [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base)
|
||||
- [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset)
|
||||
|
||||
If your dataset is not converted with `quantiles`, you can convert it with the following command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset \
|
||||
```
|
||||
|
||||
Or train pi05 with this normalization mapping: `--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'`
|
||||
|
||||
## Performance Results
|
||||
|
||||
### Libero Benchmark Results
|
||||
|
||||
π₀.₅ has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the libero base model for an additional 6k steps on the Libero dataset and compared the results to the OpenPI reference results.
|
||||
|
||||
| Benchmark | LeRobot Implementation | OpenPI Reference |
|
||||
| ------------------ | ---------------------- | ---------------- |
|
||||
| **Libero Spatial** | 97.0% | 98.8% |
|
||||
| **Libero Object** | 99.0% | 98.2% |
|
||||
| **Libero Goal** | 98.0% | 98.0% |
|
||||
| **Libero 10** | 96.0% | 92.4% |
|
||||
| **Average** | 97.5% | 96.85% |
|
||||
|
||||
These results demonstrate π₀.₅'s strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
27
docs/source/policy_groot_README.md
Normal file
27
docs/source/policy_groot_README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
## Research Paper
|
||||
|
||||
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
|
||||
## Repository
|
||||
|
||||
Code: https://github.com/NVIDIA/Isaac-GR00T
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{gr00tn1_2025,
|
||||
archivePrefix = {arxiv},
|
||||
eprint = {2503.14734},
|
||||
title = {{GR00T} {N1}: An Open Foundation Model for Generalist Humanoid Robots},
|
||||
author = {NVIDIA and Johan Bjorck andFernando Castañeda, Nikita Cherniadev and Xingye Da and Runyu Ding and Linxi "Jim" Fan and Yu Fang and Dieter Fox and Fengyuan Hu and Spencer Huang and Joel Jang and Zhenyu Jiang and Jan Kautz and Kaushil Kundalia and Lawrence Lao and Zhiqi Li and Zongyu Lin and Kevin Lin and Guilin Liu and Edith Llontop and Loic Magne and Ajay Mandlekar and Avnish Narayan and Soroush Nasiriany and Scott Reed and You Liang Tan and Guanzhi Wang and Zu Wang and Jing Wang and Qi Wang and Jiannan Xiang and Yuqi Xie and Yinzhen Xu and Zhenjia Xu and Seonghyeon Ye and Zhiding Yu and Ao Zhang and Hao Zhang and Yizhou Zhao and Ruijie Zheng and Yuke Zhu},
|
||||
month = {March},
|
||||
year = {2025},
|
||||
booktitle = {ArXiv Preprint},
|
||||
}
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
Blog: https://developer.nvidia.com/isaac/gr00t
|
||||
|
||||
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
|
||||
@@ -1,4 +1,4 @@
|
||||
# Finetune SmolVLA
|
||||
# SmolVLA
|
||||
|
||||
SmolVLA is Hugging Face’s lightweight foundation model for robotics. Designed for easy fine-tuning on LeRobot datasets, it helps accelerate your development!
|
||||
|
||||
|
||||
102
docs/source/using_dataset_tools.mdx
Normal file
102
docs/source/using_dataset_tools.mdx
Normal file
@@ -0,0 +1,102 @@
|
||||
# Using Dataset Tools
|
||||
|
||||
This guide covers the dataset tools utilities available in LeRobot for modifying and editing existing datasets.
|
||||
|
||||
## Overview
|
||||
|
||||
LeRobot provides several utilities for manipulating datasets:
|
||||
|
||||
1. **Delete Episodes** - Remove specific episodes from a dataset
|
||||
2. **Split Dataset** - Divide a dataset into multiple smaller datasets
|
||||
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
|
||||
## Command-Line Tool: lerobot-edit-dataset
|
||||
|
||||
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
|
||||
|
||||
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
|
||||
|
||||
### Usage Examples
|
||||
|
||||
#### Delete Episodes
|
||||
|
||||
Remove specific episodes from a dataset. This is useful for filtering out undesired data.
|
||||
|
||||
```bash
|
||||
# Delete episodes 0, 2, and 5 (modifies original dataset)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
# Delete episodes and save to a new dataset (preserves original dataset)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_after_deletion \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
```
|
||||
|
||||
#### Split Dataset
|
||||
|
||||
Divide a dataset into multiple subsets.
|
||||
|
||||
```bash
|
||||
# Split by fractions (e.g. 80% train, 20% test, 20% val)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.8, "test": 0.2, "val": 0.2}'
|
||||
|
||||
# Split by specific episode indices
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"task1": [0, 1, 2, 3], "task2": [4, 5]}'
|
||||
```
|
||||
|
||||
There are no constraints on the split names, they can be determined by the user. Resulting datasets are saved under the repo id with the split name appended, e.g. `lerobot/pusht_train`, `lerobot/pusht_task1`, `lerobot/pusht_task2`.
|
||||
|
||||
#### Merge Datasets
|
||||
|
||||
Combine multiple datasets into a single dataset.
|
||||
|
||||
```bash
|
||||
# Merge train and validation splits back into one dataset
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_merged \
|
||||
--operation.type merge \
|
||||
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
||||
```
|
||||
|
||||
#### Remove Features
|
||||
|
||||
Remove features from a dataset.
|
||||
|
||||
```bash
|
||||
# Remove a camera feature
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
```
|
||||
|
||||
### Push to Hub
|
||||
|
||||
Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
|
||||
```bash
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_after_deletion \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
|
||||
@@ -132,17 +132,15 @@ print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||
|
||||
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
|
||||
# PyTorch datasets.
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
if __name__ == "__main__":
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
)
|
||||
for batch in dataloader:
|
||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
|
||||
124
examples/dataset/use_dataset_tools.py
Normal file
124
examples/dataset/use_dataset_tools.py
Normal file
@@ -0,0 +1,124 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example script demonstrating dataset tools utilities.
|
||||
|
||||
This script shows how to:
|
||||
1. Delete episodes from a dataset
|
||||
2. Split a dataset into train/val sets
|
||||
3. Add/remove features
|
||||
4. Merge datasets
|
||||
|
||||
Usage:
|
||||
python examples/dataset/use_dataset_tools.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
add_features,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
modify_features,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def main():
|
||||
dataset = LeRobotDataset("lerobot/pusht")
|
||||
|
||||
print(f"Original dataset: {dataset.meta.total_episodes} episodes, {dataset.meta.total_frames} frames")
|
||||
print(f"Features: {list(dataset.meta.features.keys())}")
|
||||
|
||||
print("\n1. Deleting episodes 0 and 2...")
|
||||
filtered_dataset = delete_episodes(dataset, episode_indices=[0, 2], repo_id="lerobot/pusht_filtered")
|
||||
print(f"Filtered dataset: {filtered_dataset.meta.total_episodes} episodes")
|
||||
|
||||
print("\n2. Splitting dataset into train/val...")
|
||||
splits = split_dataset(
|
||||
dataset,
|
||||
splits={"train": 0.8, "val": 0.2},
|
||||
)
|
||||
print(f"Train split: {splits['train'].meta.total_episodes} episodes")
|
||||
print(f"Val split: {splits['val'].meta.total_episodes} episodes")
|
||||
|
||||
print("\n3. Adding features...")
|
||||
|
||||
reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
|
||||
|
||||
def compute_success(row_dict, episode_index, frame_index):
|
||||
episode_length = 10
|
||||
return float(frame_index >= episode_length - 10)
|
||||
|
||||
dataset_with_features = add_features(
|
||||
dataset,
|
||||
features={
|
||||
"reward": (
|
||||
reward_values,
|
||||
{"dtype": "float32", "shape": (1,), "names": None},
|
||||
),
|
||||
"success": (
|
||||
compute_success,
|
||||
{"dtype": "float32", "shape": (1,), "names": None},
|
||||
),
|
||||
},
|
||||
repo_id="lerobot/pusht_with_features",
|
||||
)
|
||||
|
||||
print(f"New features: {list(dataset_with_features.meta.features.keys())}")
|
||||
|
||||
print("\n4. Removing the success feature...")
|
||||
dataset_cleaned = remove_feature(
|
||||
dataset_with_features, feature_names="success", repo_id="lerobot/pusht_cleaned"
|
||||
)
|
||||
print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")
|
||||
|
||||
print("\n5. Using modify_features to add and remove features simultaneously...")
|
||||
dataset_modified = modify_features(
|
||||
dataset_with_features,
|
||||
add_features={
|
||||
"discount": (
|
||||
np.ones(dataset.meta.total_frames, dtype=np.float32) * 0.99,
|
||||
{"dtype": "float32", "shape": (1,), "names": None},
|
||||
),
|
||||
},
|
||||
remove_features="reward",
|
||||
repo_id="lerobot/pusht_modified",
|
||||
)
|
||||
print(f"Modified features: {list(dataset_modified.meta.features.keys())}")
|
||||
|
||||
print("\n6. Merging train and val splits back together...")
|
||||
merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged")
|
||||
print(f"Merged dataset: {merged.meta.total_episodes} episodes")
|
||||
|
||||
print("\n7. Complex workflow example...")
|
||||
|
||||
if len(dataset.meta.camera_keys) > 1:
|
||||
camera_to_remove = dataset.meta.camera_keys[0]
|
||||
print(f"Removing camera: {camera_to_remove}")
|
||||
dataset_no_cam = remove_feature(
|
||||
dataset, feature_names=camera_to_remove, repo_id="pusht_no_first_camera"
|
||||
)
|
||||
print(f"Remaining cameras: {dataset_no_cam.meta.camera_keys}")
|
||||
|
||||
print("\nDone! Check ~/.cache/huggingface/lerobot/ for the created datasets.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -133,4 +133,6 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -130,4 +130,6 @@ robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
416
examples/openarms/debug_can_communication.py
Normal file
416
examples/openarms/debug_can_communication.py
Normal file
@@ -0,0 +1,416 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive debug script for OpenArms CAN FD communication.
|
||||
Tests all 4 CAN interfaces with CAN FD support.
|
||||
"""
|
||||
|
||||
import can
|
||||
import time
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
def check_can_interface(port):
|
||||
"""Check if CAN interface is UP and configured."""
|
||||
try:
|
||||
result = subprocess.run(['ip', 'link', 'show', port],
|
||||
capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
return False, "Interface not found", None
|
||||
|
||||
output = result.stdout
|
||||
if 'UP' not in output:
|
||||
return False, "Interface is DOWN", None
|
||||
|
||||
# Check if CAN FD is enabled
|
||||
is_fd = 'fd on' in output.lower() or 'canfd' in output.lower()
|
||||
|
||||
return True, "Interface is UP", is_fd
|
||||
except FileNotFoundError:
|
||||
return None, "Cannot check (ip command not found)", None
|
||||
|
||||
|
||||
def test_motor_on_interface(bus, motor_id, timeout=2.0, use_fd=False):
|
||||
"""
|
||||
Test a single motor and return all responses.
|
||||
|
||||
Returns:
|
||||
list of (arbitration_id, data) tuples for all responses received
|
||||
"""
|
||||
# Send enable command
|
||||
enable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||
is_extended_id=False,
|
||||
is_fd=use_fd
|
||||
)
|
||||
|
||||
try:
|
||||
bus.send(enable_msg)
|
||||
except Exception as e:
|
||||
return None, f"Send error: {e}"
|
||||
|
||||
# Listen for responses
|
||||
responses = []
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
msg = bus.recv(timeout=0.1)
|
||||
if msg:
|
||||
responses.append((msg.arbitration_id, msg.data, msg.is_fd if hasattr(msg, 'is_fd') else False))
|
||||
|
||||
# Send disable command
|
||||
disable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD],
|
||||
is_extended_id=False,
|
||||
is_fd=use_fd
|
||||
)
|
||||
try:
|
||||
bus.send(disable_msg)
|
||||
except:
|
||||
pass
|
||||
|
||||
return responses, None
|
||||
|
||||
|
||||
def test_interface(port, interface_type="socketcan", use_can_fd=True):
|
||||
"""Test all 8 motors on a single CAN interface."""
|
||||
|
||||
results = {
|
||||
'interface': port,
|
||||
'status': None,
|
||||
'is_fd': use_can_fd,
|
||||
'motors': {}
|
||||
}
|
||||
|
||||
# Check interface status
|
||||
status_ok, status_msg, interface_has_fd = check_can_interface(port)
|
||||
|
||||
if interface_has_fd is not None:
|
||||
results['interface_fd_enabled'] = interface_has_fd
|
||||
if use_can_fd and not interface_has_fd:
|
||||
status_msg += " (CAN FD NOT enabled on interface!)"
|
||||
elif interface_has_fd:
|
||||
status_msg += " (CAN FD enabled)"
|
||||
|
||||
results['status'] = status_msg
|
||||
|
||||
if status_ok is False:
|
||||
return results
|
||||
|
||||
# Try to connect
|
||||
try:
|
||||
if use_can_fd:
|
||||
print(f" Connecting to {port} with CAN FD (1 Mbps / 5 Mbps)...")
|
||||
bus = can.interface.Bus(
|
||||
channel=port,
|
||||
interface=interface_type,
|
||||
bitrate=1000000,
|
||||
data_bitrate=5000000,
|
||||
fd=True
|
||||
)
|
||||
else:
|
||||
print(f" Connecting to {port} with CAN 2.0 (1 Mbps)...")
|
||||
bus = can.interface.Bus(
|
||||
channel=port,
|
||||
interface=interface_type,
|
||||
bitrate=1000000
|
||||
)
|
||||
except Exception as e:
|
||||
results['status'] = f"Connection failed: {e}"
|
||||
return results
|
||||
|
||||
try:
|
||||
# Clear any pending messages
|
||||
while bus.recv(timeout=0.01):
|
||||
pass
|
||||
|
||||
# Test each motor (0x01 to 0x08)
|
||||
for motor_id in range(0x01, 0x09):
|
||||
responses, error = test_motor_on_interface(bus, motor_id, timeout=1.0, use_fd=use_can_fd)
|
||||
|
||||
if error:
|
||||
results['motors'][motor_id] = {'error': error}
|
||||
elif responses:
|
||||
results['motors'][motor_id] = {
|
||||
'found': True,
|
||||
'responses': responses
|
||||
}
|
||||
else:
|
||||
results['motors'][motor_id] = {
|
||||
'found': False,
|
||||
'responses': []
|
||||
}
|
||||
|
||||
time.sleep(0.05) # Small delay between motors
|
||||
|
||||
finally:
|
||||
bus.shutdown()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def print_results(all_results):
|
||||
"""Print formatted results for all interfaces."""
|
||||
|
||||
print("SUMMARY - Motors Found on Each Interface")
|
||||
|
||||
motor_names = {
|
||||
0x01: "joint_1 (Shoulder pan)",
|
||||
0x02: "joint_2 (Shoulder lift)",
|
||||
0x03: "joint_3 (Shoulder rotation)",
|
||||
0x04: "joint_4 (Elbow flex)",
|
||||
0x05: "joint_5 (Wrist roll)",
|
||||
0x06: "joint_6 (Wrist pitch)",
|
||||
0x07: "joint_7 (Wrist rotation)",
|
||||
0x08: "gripper",
|
||||
}
|
||||
|
||||
total_found = 0
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
status = result['status']
|
||||
|
||||
print(f"{interface}: {status}")
|
||||
if result.get('is_fd'):
|
||||
print(f" Mode: CAN FD")
|
||||
else:
|
||||
print(f" Mode: CAN 2.0")
|
||||
|
||||
if 'Connection failed' in status or 'DOWN' in status:
|
||||
print(f" ⚠ Cannot test {interface}")
|
||||
continue
|
||||
|
||||
motors_found = 0
|
||||
|
||||
for motor_id in range(0x01, 0x09):
|
||||
motor_data = result['motors'].get(motor_id, {})
|
||||
motor_name = motor_names.get(motor_id, "Unknown")
|
||||
|
||||
if motor_data.get('error'):
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {motor_data['error']}")
|
||||
elif motor_data.get('found'):
|
||||
motors_found += 1
|
||||
total_found += 1
|
||||
responses = motor_data['responses']
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND")
|
||||
|
||||
for resp_id, data, is_fd in responses:
|
||||
data_hex = data.hex()
|
||||
fd_flag = " [FD]" if is_fd else " [2.0]"
|
||||
print(f" → Response from 0x{resp_id:02X}{fd_flag}: {data_hex}")
|
||||
else:
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response")
|
||||
|
||||
print(f"\n Summary: {motors_found}/8 motors found on {interface}")
|
||||
|
||||
# Overall summary
|
||||
print("OVERALL SUMMARY")
|
||||
print(f"Total motors found across all interfaces: {total_found}")
|
||||
|
||||
# Analyze configuration
|
||||
print("DIAGNOSIS")
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||
|
||||
if motors_found == 0:
|
||||
print(f"\n⚠ {interface}: NO MOTORS FOUND")
|
||||
print(" Possible issues:")
|
||||
print(" 1. CAN FD mode mismatch (interface vs motor configuration)")
|
||||
print(" 2. Missing 120Ω termination resistors at BOTH cable ends")
|
||||
print(" 3. Motor timeout parameter set incorrectly (should NOT be 0)")
|
||||
print(" 4. CANH/CANL wiring issue")
|
||||
print(" 5. Cable too long (>40m for CAN FD at 5Mbps)")
|
||||
|
||||
# Check FD mismatch
|
||||
if result.get('is_fd') and not result.get('interface_fd_enabled'):
|
||||
print(" ⚠️ CRITICAL: Trying CAN FD but interface NOT configured for FD!")
|
||||
print(f" Fix: sudo ip link set {interface} type can bitrate 1000000 dbitrate 5000000 fd on")
|
||||
|
||||
elif motors_found < 8:
|
||||
print(f"\n⚠ {interface}: Only {motors_found}/8 motors responding")
|
||||
print(" Check power and connections for missing motors")
|
||||
else:
|
||||
print(f"\n✓ {interface}: All 8 motors responding correctly!")
|
||||
|
||||
# Check for unexpected response IDs
|
||||
print("RESPONSE ID ANALYSIS")
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
unexpected = []
|
||||
|
||||
for motor_id, motor_data in result['motors'].items():
|
||||
if motor_data.get('found'):
|
||||
expected_id = motor_id + 0x10
|
||||
actual_ids = [resp[0] for resp in motor_data['responses']]
|
||||
|
||||
if expected_id not in actual_ids:
|
||||
unexpected.append((motor_id, actual_ids))
|
||||
|
||||
if unexpected:
|
||||
print(f"\n⚠ {interface}: Unexpected response IDs detected")
|
||||
for motor_id, actual_ids in unexpected:
|
||||
expected_id = motor_id + 0x10
|
||||
print(f" Motor 0x{motor_id:02X}: Expected 0x{expected_id:02X}, "
|
||||
f"got {[f'0x{id:02X}' for id in actual_ids]}")
|
||||
print(" → Motor Master IDs need reconfiguration")
|
||||
else:
|
||||
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||
if motors_found > 0:
|
||||
print(f"\n✓ {interface}: All responding motors use correct IDs")
|
||||
|
||||
|
||||
def test_communication_speed(interface, motor_id, num_iterations=100):
|
||||
"""
|
||||
Test communication speed with a motor.
|
||||
|
||||
Returns:
|
||||
tuple: (hz, avg_latency_ms) or (None, None) if test failed
|
||||
"""
|
||||
try:
|
||||
# Connect to interface
|
||||
bus = can.interface.Bus(
|
||||
channel=interface,
|
||||
interface="socketcan",
|
||||
bitrate=1000000,
|
||||
data_bitrate=5000000,
|
||||
fd=True
|
||||
)
|
||||
|
||||
# Send refresh commands and measure round-trip time
|
||||
latencies = []
|
||||
successful = 0
|
||||
|
||||
for _ in range(num_iterations):
|
||||
start = time.perf_counter()
|
||||
|
||||
# Send enable command (lightweight operation)
|
||||
enable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||
is_extended_id=False,
|
||||
is_fd=True
|
||||
)
|
||||
bus.send(enable_msg)
|
||||
|
||||
# Wait for response
|
||||
msg = bus.recv(timeout=0.1)
|
||||
|
||||
if msg:
|
||||
latency = (time.perf_counter() - start) * 1000 # Convert to ms
|
||||
latencies.append(latency)
|
||||
successful += 1
|
||||
|
||||
bus.shutdown()
|
||||
|
||||
if successful > 0:
|
||||
avg_latency = sum(latencies) / len(latencies)
|
||||
hz = 1000.0 / avg_latency if avg_latency > 0 else 0
|
||||
return hz, avg_latency
|
||||
|
||||
return None, None
|
||||
|
||||
except Exception as e:
|
||||
print(f" Speed test error: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to test all CAN interfaces with CAN FD."""
|
||||
|
||||
print("\nThis will test all 4 CAN interfaces (can0-can3) with CAN FD")
|
||||
print("Testing motors 0x01-0x08 on each interface")
|
||||
print()
|
||||
print("Make sure:")
|
||||
print(" ✓ Motors are powered (24V)")
|
||||
print(" ✓ CAN interfaces configured with FD mode:")
|
||||
print(" ./examples/openarms/setup_can.sh")
|
||||
print(" ✓ Motor 'timeout' parameter NOT set to 0 (use Damiao tools)")
|
||||
print(" ✓ CAN wiring includes 120Ω termination at BOTH ends")
|
||||
print()
|
||||
|
||||
input("Press ENTER to start testing...")
|
||||
|
||||
# Test all 4 interfaces with CAN FD
|
||||
all_results = []
|
||||
|
||||
for i in range(4):
|
||||
interface = f"can{i}"
|
||||
print(f"Testing {interface}...")
|
||||
|
||||
result = test_interface(interface, use_can_fd=True)
|
||||
all_results.append(result)
|
||||
|
||||
# Quick status
|
||||
if 'Connection failed' in result['status'] or 'DOWN' in result['status']:
|
||||
print(f" ⚠ {interface}: {result['status']}")
|
||||
else:
|
||||
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||
print(f" {interface}: {motors_found}/8 motors found")
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
# Print detailed results
|
||||
print_results(all_results)
|
||||
|
||||
print("Testing Complete!")
|
||||
|
||||
all_found = sum(sum(1 for m in r['motors'].values() if m.get('found')) for r in all_results)
|
||||
|
||||
if all_found == 0:
|
||||
print("\n⚠️ CRITICAL: No motors found on any interface!")
|
||||
print("\nTop issues to check:")
|
||||
print(" 1. Motor 'timeout' parameter (use Damiao tools to set > 0)")
|
||||
print(" 2. CAN FD not enabled (run ./examples/openarms/setup_can.sh)")
|
||||
print(" 3. Missing termination resistors")
|
||||
print("\nTry:")
|
||||
print(" a) Check motor parameters with Damiao Debugging Tools")
|
||||
print(" b) Verify CAN FD is enabled: ip -d link show can0 | grep fd")
|
||||
print(" c) Run setup script: ./examples/openarms/setup_can.sh")
|
||||
else:
|
||||
# Run speed test on interfaces with motors
|
||||
print("COMMUNICATION SPEED TEST")
|
||||
print("\nTesting maximum communication frequency...")
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
|
||||
# Find first responding motor
|
||||
responding_motor = None
|
||||
for motor_id, motor_data in result['motors'].items():
|
||||
if motor_data.get('found'):
|
||||
responding_motor = motor_id
|
||||
break
|
||||
|
||||
if responding_motor:
|
||||
print(f"\n{interface}: Testing with motor 0x{responding_motor:02X}...")
|
||||
hz, latency = test_communication_speed(interface, responding_motor, num_iterations=100)
|
||||
|
||||
if hz:
|
||||
print(f" ✓ Max frequency: {hz:.1f} Hz")
|
||||
print(f" ✓ Avg latency: {latency:.2f} ms")
|
||||
print(f" ✓ Commands per second: ~{int(hz)}")
|
||||
else:
|
||||
print(f" ✗ Speed test failed")
|
||||
else:
|
||||
print(f"\n{interface}: No motors found, skipping speed test")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nTesting interrupted by user.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\nUnexpected error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
360
examples/openarms/evaluate.py
Normal file
360
examples/openarms/evaluate.py
Normal file
@@ -0,0 +1,360 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenArms Policy Evaluation
|
||||
|
||||
Evaluates a trained policy on the OpenArms robot by running inference and recording
|
||||
the evaluation episodes to a dataset. Supports optional leader arm for manual resets.
|
||||
|
||||
Example usage:
|
||||
python examples/openarms/evaluate.py
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
|
||||
HF_MODEL_ID = "lerobot-data-collection/three-folds-pi0" # TODO: Replace with your trained model
|
||||
HF_EVAL_DATASET_ID = "lerobot-data-collection/three-folds-pi0_eval7" # TODO: Replace with your eval dataset name
|
||||
TASK_DESCRIPTION = "three-folds-dataset" # TODO: Replace with your task, this should match!!
|
||||
|
||||
NUM_EPISODES = 1
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 300
|
||||
RESET_TIME_SEC = 60
|
||||
|
||||
# Robot CAN interfaces
|
||||
FOLLOWER_LEFT_PORT = "can0"
|
||||
FOLLOWER_RIGHT_PORT = "can1"
|
||||
|
||||
# If enabled, you can manually reset the environment between evaluation episodes
|
||||
USE_LEADER_FOR_RESETS = True # Set to False if you don't want to use leader
|
||||
LEADER_LEFT_PORT = "can2"
|
||||
LEADER_RIGHT_PORT = "can3"
|
||||
|
||||
# Camera configuration
|
||||
CAMERA_CONFIG = {
|
||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=FPS),
|
||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
|
||||
"base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=FPS),
|
||||
}
|
||||
|
||||
def main():
|
||||
"""Main evaluation function."""
|
||||
print("OpenArms Policy Evaluation")
|
||||
print(f"\nModel: {HF_MODEL_ID}")
|
||||
print(f"Evaluation Dataset: {HF_EVAL_DATASET_ID}")
|
||||
print(f"Task: {TASK_DESCRIPTION}")
|
||||
print(f"Episodes: {NUM_EPISODES}")
|
||||
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
|
||||
print(f"Reset Duration: {RESET_TIME_SEC}s")
|
||||
print(f"Use Leader for Resets: {USE_LEADER_FOR_RESETS}")
|
||||
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left=FOLLOWER_LEFT_PORT,
|
||||
port_right=FOLLOWER_RIGHT_PORT,
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
cameras=CAMERA_CONFIG,
|
||||
)
|
||||
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
follower.connect(calibrate=False)
|
||||
|
||||
if not follower.is_connected:
|
||||
raise RuntimeError("Follower robot failed to connect!")
|
||||
|
||||
|
||||
leader = None
|
||||
if USE_LEADER_FOR_RESETS:
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left=LEADER_LEFT_PORT,
|
||||
port_right=LEADER_RIGHT_PORT,
|
||||
can_interface="socketcan",
|
||||
id="openarms_leader",
|
||||
manual_control=False, # Enable torque control for gravity compensation
|
||||
)
|
||||
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
leader.connect(calibrate=False)
|
||||
|
||||
if not leader.is_connected:
|
||||
raise RuntimeError("Leader robot failed to connect!")
|
||||
|
||||
# Enable gravity compensation
|
||||
if leader.pin_robot is not None:
|
||||
leader.bus_right.enable_torque()
|
||||
leader.bus_left.enable_torque()
|
||||
time.sleep(0.1)
|
||||
print(f"Leader connected with gravity compensation ({LEADER_LEFT_PORT}, {LEADER_RIGHT_PORT})")
|
||||
else:
|
||||
print(f"Leader connected but gravity compensation unavailable (no URDF)")
|
||||
|
||||
# Build default processors for action and observation
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Build dataset features from robot features and processors
|
||||
# For actions, only include positions (no velocity or torque)
|
||||
action_features_hw = {}
|
||||
for key, value in follower.action_features.items():
|
||||
if key.endswith(".pos"):
|
||||
action_features_hw[key] = value
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=action_features_hw),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_observation_processor,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Check if dataset already exists
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / HF_EVAL_DATASET_ID
|
||||
if dataset_path.exists():
|
||||
print(f"Evaluation dataset already exists at: {dataset_path}")
|
||||
print("This will append new episodes to the existing dataset.")
|
||||
choice = input(" Continue? (y/n): ").strip().lower()
|
||||
if choice != 'y':
|
||||
print(" Aborting evaluation.")
|
||||
follower.disconnect()
|
||||
if leader:
|
||||
leader.disconnect()
|
||||
return
|
||||
|
||||
# Create dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_EVAL_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_processes=0,
|
||||
image_writer_threads=12,
|
||||
)
|
||||
|
||||
# Load policy config from pretrained model and create policy using factory
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
policy = make_policy(policy_config, ds_meta=dataset.meta)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy.config,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": str(policy.config.device)}
|
||||
},
|
||||
)
|
||||
|
||||
print(f"\nRunning evaluation...")
|
||||
# Initialize keyboard listener and visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="openarms_evaluation")
|
||||
episode_idx = 0
|
||||
|
||||
try:
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print(f"\nRunning inference for episode {episode_idx + 1}...")
|
||||
|
||||
# Run inference with policy
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Handle re-recording
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Save episode
|
||||
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
|
||||
print(f"Saving episode {episode_idx + 1} ({dataset.episode_buffer['size']} frames)...")
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Reset environment between episodes (if not last episode)
|
||||
if not events["stop_recording"] and episode_idx < NUM_EPISODES:
|
||||
if USE_LEADER_FOR_RESETS and leader:
|
||||
log_say("Reset the environment using leader arms")
|
||||
print(f"\nManual reset period ({RESET_TIME_SEC}s)...")
|
||||
|
||||
# Use leader for manual reset with gravity compensation
|
||||
import numpy as np
|
||||
|
||||
dt = 1 / FPS
|
||||
reset_start_time = time.perf_counter()
|
||||
|
||||
while time.perf_counter() - reset_start_time < RESET_TIME_SEC:
|
||||
if events["exit_early"] or events["stop_recording"]:
|
||||
break
|
||||
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get leader state
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Extract positions and velocities
|
||||
leader_positions_deg = {}
|
||||
leader_velocities_deg_per_sec = {}
|
||||
|
||||
for motor in leader.bus_right.motors:
|
||||
pos_key = f"right_{motor}.pos"
|
||||
vel_key = f"right_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
pos_key = f"left_{motor}.pos"
|
||||
vel_key = f"left_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||
|
||||
# Calculate gravity and friction torques
|
||||
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||
|
||||
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||
leader_velocities_rad_per_sec,
|
||||
friction_scale=1.0
|
||||
)
|
||||
|
||||
# Combine torques
|
||||
leader_total_torques_nm = {}
|
||||
for motor_name in leader_gravity_torques_nm:
|
||||
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||
leader_total_torques_nm[motor_name] = gravity + friction
|
||||
|
||||
# Apply compensation
|
||||
for motor in leader.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_right._mit_control(
|
||||
motor=motor, kp=0.0, kd=kd,
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_left._mit_control(
|
||||
motor=motor, kp=0.0, kd=kd,
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Send leader positions to follower
|
||||
follower_action = {}
|
||||
for joint in leader_positions_deg.keys():
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
follower_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
if follower_action:
|
||||
follower.send_action(follower_action)
|
||||
|
||||
# Maintain loop rate
|
||||
loop_duration = time.perf_counter() - loop_start
|
||||
sleep_time = dt - loop_duration
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
print("Reset complete")
|
||||
else:
|
||||
log_say("Waiting for manual reset")
|
||||
print(f"Manually reset the environment and press ENTER to continue")
|
||||
input("Press ENTER when ready...")
|
||||
|
||||
print(f"Evaluation complete! {episode_idx} episodes recorded")
|
||||
log_say("Evaluation complete", blocking=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nEvaluation interrupted by user")
|
||||
|
||||
finally:
|
||||
if leader:
|
||||
leader.bus_right.disable_torque()
|
||||
leader.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
leader.disconnect()
|
||||
|
||||
follower.disconnect()
|
||||
|
||||
if listener is not None:
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
print("\nUploading to Hugging Face Hub...")
|
||||
dataset.push_to_hub(private=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
216
examples/openarms/friction_compensation.py
Normal file
216
examples/openarms/friction_compensation.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
|
||||
|
||||
# Friction model parameters from OpenArms config/follower.yaml
|
||||
# τ_fric(ω) = Fo + Fv·ω + Fc·tanh(k·ω)
|
||||
# For 8 motors: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||
FRICTION_PARAMS = {
|
||||
"Fc": [0.306, 0.306, 0.40, 0.166, 0.050, 0.093, 0.172, 0.0512], # Coulomb friction [Nm]
|
||||
"k": [28.417, 28.417, 29.065, 130.038, 151.771, 242.287, 7.888, 4.000], # tanh steepness
|
||||
"Fv": [0.063, 0.0630, 0.604, 0.813, 0.029, 0.072, 0.084, 0.084], # Viscous friction [Nm·s/rad]
|
||||
"Fo": [0.088, 0.088, 0.008, -0.058, 0.005, 0.009, -0.059, -0.050], # Offset torque [Nm]
|
||||
}
|
||||
|
||||
# Constants from OpenArms C++ implementation
|
||||
AMP_TMP = 1.0
|
||||
COEF_TMP = 0.1
|
||||
|
||||
FRICTION_SCALE = 1.0 # OpenArms C++ uses 0.3 factor in unilateral mode
|
||||
DAMPING_KD = [0.5, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1, 0.1] # Damping gains for stability
|
||||
|
||||
def compute_friction_torque(velocity_rad_per_sec: float, motor_index: int) -> float:
|
||||
"""
|
||||
Compute friction torque for a single motor using the tanh friction model.
|
||||
|
||||
Args:
|
||||
velocity_rad_per_sec: Angular velocity in rad/s
|
||||
motor_index: Index of the motor (0-7)
|
||||
|
||||
Returns:
|
||||
Friction torque in N·m (scaled for stability)
|
||||
"""
|
||||
|
||||
Fc = FRICTION_PARAMS["Fc"][motor_index]
|
||||
k = FRICTION_PARAMS["k"][motor_index]
|
||||
Fv = FRICTION_PARAMS["Fv"][motor_index]
|
||||
Fo = FRICTION_PARAMS["Fo"][motor_index]
|
||||
|
||||
# Friction model: τ_fric = amp * Fc * tanh(coef * k * ω) + Fv * ω + Fo
|
||||
friction_torque = (
|
||||
AMP_TMP * Fc * np.tanh(COEF_TMP * k * velocity_rad_per_sec) +
|
||||
Fv * velocity_rad_per_sec +
|
||||
Fo
|
||||
)
|
||||
|
||||
# Scale down friction compensation for stability at lower control rates
|
||||
# (OpenArms C++ uses 0.3 factor in unilateral mode)!!
|
||||
friction_torque *= FRICTION_SCALE
|
||||
|
||||
return friction_torque
|
||||
|
||||
|
||||
def main() -> None:
|
||||
config = OpenArmsFollowerConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=5.0,
|
||||
)
|
||||
|
||||
print("Initializing robot...")
|
||||
follower = OpenArmsFollower(config)
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
print(f"Applying friction compensation")
|
||||
print(" 1. Support the arm before starting")
|
||||
print(" 2. The arm will be held in place by friction compensation")
|
||||
print(" 3. You should be able to move it with gentle force")
|
||||
print("\nPress ENTER when ready to start...")
|
||||
input()
|
||||
|
||||
print(f"✓ Motors enabled")
|
||||
print("\nStarting friction compensation loop...")
|
||||
print("Press Ctrl+C to stop\n")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = time.perf_counter()
|
||||
|
||||
# Motor name to index mapping
|
||||
motor_name_to_index = {
|
||||
"joint_1": 0,
|
||||
"joint_2": 1,
|
||||
"joint_3": 2,
|
||||
"joint_4": 3,
|
||||
"joint_5": 4,
|
||||
"joint_6": 5,
|
||||
"joint_7": 6,
|
||||
"gripper": 7,
|
||||
}
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get current joint positions and velocities from robot
|
||||
obs = follower.get_observation()
|
||||
|
||||
# Extract velocities in degrees per second
|
||||
velocities_deg_per_sec = {}
|
||||
positions_deg = {}
|
||||
|
||||
for motor in follower.bus_right.motors:
|
||||
vel_key = f"right_{motor}.vel"
|
||||
pos_key = f"right_{motor}.pos"
|
||||
if vel_key in obs:
|
||||
velocities_deg_per_sec[f"right_{motor}"] = obs[vel_key]
|
||||
if pos_key in obs:
|
||||
positions_deg[f"right_{motor}"] = obs[pos_key]
|
||||
|
||||
for motor in follower.bus_left.motors:
|
||||
vel_key = f"left_{motor}.vel"
|
||||
pos_key = f"left_{motor}.pos"
|
||||
if vel_key in obs:
|
||||
velocities_deg_per_sec[f"left_{motor}"] = obs[vel_key]
|
||||
if pos_key in obs:
|
||||
positions_deg[f"left_{motor}"] = obs[pos_key]
|
||||
|
||||
# Convert velocities to rad/s and compute friction torques
|
||||
friction_torques_nm = {}
|
||||
for motor_full_name, velocity_deg_per_sec in velocities_deg_per_sec.items():
|
||||
# Extract motor name without arm prefix
|
||||
if motor_full_name.startswith("right_"):
|
||||
motor_name = motor_full_name.removeprefix("right_")
|
||||
elif motor_full_name.startswith("left_"):
|
||||
motor_name = motor_full_name.removeprefix("left_")
|
||||
else:
|
||||
continue
|
||||
|
||||
# Get motor index for friction parameters
|
||||
motor_index = motor_name_to_index.get(motor_name, 0)
|
||||
|
||||
# Convert velocity to rad/s
|
||||
velocity_rad_per_sec = np.deg2rad(velocity_deg_per_sec)
|
||||
|
||||
# Compute friction torque
|
||||
friction_torque = compute_friction_torque(velocity_rad_per_sec, motor_index)
|
||||
friction_torques_nm[motor_full_name] = friction_torque
|
||||
|
||||
# Apply friction compensation to right arm (all joints INCLUDING gripper)
|
||||
for motor in follower.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = friction_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get motor index for damping gain
|
||||
motor_index = motor_name_to_index.get(motor, 0)
|
||||
kd = DAMPING_KD[motor_index]
|
||||
|
||||
# Send MIT control command with friction compensation + damping
|
||||
follower.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Apply friction compensation to left arm (all joints INCLUDING gripper)
|
||||
for motor in follower.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = friction_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get motor index for damping gain
|
||||
motor_index = motor_name_to_index.get(motor, 0)
|
||||
kd = DAMPING_KD[motor_index]
|
||||
|
||||
# Send MIT control command with friction compensation + damping
|
||||
follower.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Measure loop time
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Print status every 2 seconds
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
|
||||
print(f"{current_hz:.1f} Hz")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
time.sleep(0.001)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping friction compensation...")
|
||||
|
||||
finally:
|
||||
print("\nDisabling all motors and disconnecting...")
|
||||
follower.bus_right.disable_torque()
|
||||
follower.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
follower.disconnect()
|
||||
print("✓ Safe shutdown complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
142
examples/openarms/gravity_compensation.py
Executable file
142
examples/openarms/gravity_compensation.py
Executable file
@@ -0,0 +1,142 @@
|
||||
import time
|
||||
import numpy as np
|
||||
import pinocchio as pin
|
||||
from os.path import join, dirname, exists, expanduser
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
|
||||
|
||||
def main() -> None:
|
||||
config = OpenArmsFollowerConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=5.0,
|
||||
)
|
||||
|
||||
|
||||
print("Initializing robot...")
|
||||
follower = OpenArmsFollower(config)
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
# Load URDF for Pinocchio dynamics
|
||||
urdf_path = "/home/croissant/Documents/openarm_description/openarm_bimanual_pybullet.urdf"
|
||||
|
||||
pin_robot = pin.RobotWrapper.BuildFromURDF(urdf_path, dirname(urdf_path))
|
||||
pin_robot.data = pin_robot.model.createData()
|
||||
print(f"✓ Loaded Pinocchio model with {pin_robot.nq} DoFs")
|
||||
|
||||
follower.pin_robot = pin_robot
|
||||
|
||||
print(f"Applying gravity compensation")
|
||||
print(" 1. Support the arm before starting")
|
||||
print(" 2. The arm will be held in place by gravity compensation")
|
||||
print(" 3. You should be able to move it with gentle force")
|
||||
print("\nPress ENTER when ready to start...")
|
||||
input()
|
||||
|
||||
print(f"✓ Motors enabled")
|
||||
print("\nStarting gravity compensation loop...")
|
||||
print("Press Ctrl+C to stop\n")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get current joint positions from robot
|
||||
obs = follower.get_observation()
|
||||
|
||||
# Extract positions in degrees
|
||||
positions_deg = {}
|
||||
for motor in follower.bus_right.motors:
|
||||
key = f"right_{motor}.pos"
|
||||
if key in obs:
|
||||
positions_deg[f"right_{motor}"] = obs[key]
|
||||
|
||||
for motor in follower.bus_left.motors:
|
||||
key = f"left_{motor}.pos"
|
||||
if key in obs:
|
||||
positions_deg[f"left_{motor}"] = obs[key]
|
||||
|
||||
# Convert to radians and calculate gravity torques
|
||||
# Use the built-in method from OpenArmsFollower
|
||||
positions_rad = {k: np.deg2rad(v) for k, v in positions_deg.items()}
|
||||
torques_nm = follower._gravity_from_q(positions_rad)
|
||||
|
||||
# Apply gravity compensation to right arm (all joints except gripper)
|
||||
for motor in follower.bus_right.motors:
|
||||
if motor == "gripper":
|
||||
continue # Skip gripper
|
||||
|
||||
full_name = f"right_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Send MIT control command with gravity compensation torque
|
||||
follower.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=0.0, # No velocity damping
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Apply gravity compensation to left arm (all joints except gripper)
|
||||
for motor in follower.bus_left.motors:
|
||||
if motor == "gripper":
|
||||
continue # Skip gripper
|
||||
|
||||
full_name = f"left_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Send MIT control command with gravity compensation torque
|
||||
follower.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=0.0, # No velocity damping
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Measure loop time
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Print status every 2 seconds
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
|
||||
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping gravity compensation...")
|
||||
|
||||
finally:
|
||||
print("\nDisabling all motors and disconnecting...")
|
||||
follower.bus_right.disable_torque()
|
||||
follower.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
follower.disconnect()
|
||||
print("✓ Safe shutdown complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
395
examples/openarms/record_with_compensation.py
Normal file
395
examples/openarms/record_with_compensation.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
OpenArms Dataset Recording with Gravity + Friction Compensation
|
||||
|
||||
Records a dataset using OpenArms follower robot with leader teleoperator.
|
||||
Leader arms have gravity and friction compensation for weightless, easy movement.
|
||||
Includes 3 cameras: left wrist, right wrist, and base camera.
|
||||
|
||||
Uses the same compensation approach as teleop_with_compensation.py
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
# Recording parameters
|
||||
NUM_EPISODES = 1
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 600
|
||||
RESET_TIME_SEC = 120
|
||||
TASK_DESCRIPTION = "OpenArms task description"
|
||||
|
||||
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
|
||||
FRICTION_SCALE = 1.0
|
||||
|
||||
def record_loop_with_compensation(
|
||||
robot,
|
||||
leader,
|
||||
events,
|
||||
fps,
|
||||
dataset,
|
||||
dataset_features,
|
||||
control_time_s,
|
||||
single_task,
|
||||
display_data=True,
|
||||
):
|
||||
"""
|
||||
Custom record loop that applies gravity + friction compensation to leader.
|
||||
Based on record_loop but with integrated compensation.
|
||||
"""
|
||||
dt = 1 / fps
|
||||
episode_start_time = time.perf_counter()
|
||||
|
||||
# All joints (both arms)
|
||||
all_joints = []
|
||||
for motor in leader.bus_right.motors:
|
||||
all_joints.append(f"right_{motor}")
|
||||
for motor in leader.bus_left.motors:
|
||||
all_joints.append(f"left_{motor}")
|
||||
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
elapsed = loop_start - episode_start_time
|
||||
|
||||
# Check if we should exit
|
||||
if elapsed >= control_time_s or events["exit_early"] or events["stop_recording"]:
|
||||
break
|
||||
|
||||
# Get leader state
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Extract positions and velocities in degrees
|
||||
leader_positions_deg = {}
|
||||
leader_velocities_deg_per_sec = {}
|
||||
|
||||
for motor in leader.bus_right.motors:
|
||||
pos_key = f"right_{motor}.pos"
|
||||
vel_key = f"right_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
pos_key = f"left_{motor}.pos"
|
||||
vel_key = f"left_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||
|
||||
# Calculate gravity torques for leader using built-in method
|
||||
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||
|
||||
# Calculate friction torques for leader using built-in method
|
||||
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||
leader_velocities_rad_per_sec,
|
||||
friction_scale=FRICTION_SCALE
|
||||
)
|
||||
|
||||
# Combine gravity + friction torques
|
||||
leader_total_torques_nm = {}
|
||||
for motor_name in leader_gravity_torques_nm:
|
||||
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||
leader_total_torques_nm[motor_name] = gravity + friction
|
||||
|
||||
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
|
||||
for motor in leader.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
|
||||
for motor in leader.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Send leader positions to follower (both arms)
|
||||
follower_action = {}
|
||||
for joint in all_joints:
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
follower_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
# Send action to robot
|
||||
if follower_action:
|
||||
robot.send_action(follower_action)
|
||||
|
||||
# Get observation from robot (includes camera images)
|
||||
observation = robot.get_observation()
|
||||
|
||||
# Add to dataset if we have a dataset
|
||||
if dataset is not None:
|
||||
# Build properly formatted observation frame
|
||||
obs_frame = build_dataset_frame(dataset_features, observation, prefix="observation")
|
||||
|
||||
# Build properly formatted action frame (keep .pos suffix - it matches the feature names)
|
||||
action_frame = build_dataset_frame(dataset_features, follower_action, prefix="action")
|
||||
|
||||
# Combine into single frame
|
||||
frame = {**obs_frame, **action_frame}
|
||||
|
||||
# Add metadata (task is required, timestamp will be auto-calculated by add_frame)
|
||||
frame["task"] = single_task
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# Display data if requested
|
||||
if display_data:
|
||||
log_rerun_data(observation=observation, action=follower_action)
|
||||
|
||||
# Maintain loop rate
|
||||
loop_duration = time.perf_counter() - loop_start
|
||||
sleep_time = dt - loop_duration
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main recording loop with gravity compensation."""
|
||||
|
||||
print("=" * 70)
|
||||
print("OpenArms Dataset Recording with Compensation")
|
||||
print("=" * 70)
|
||||
|
||||
# Create camera configurations (3 cameras: left wrist, right wrist, base)
|
||||
# Using actual device paths found by lerobot-find-cameras opencv
|
||||
camera_config = {
|
||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video0", width=640, height=480, fps=FPS),
|
||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
|
||||
"base": OpenCVCameraConfig(index_or_path="/dev/video7", width=640, height=480, fps=FPS),
|
||||
}
|
||||
|
||||
# Configure follower robot with cameras
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can2",
|
||||
port_right="can3",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
cameras=camera_config,
|
||||
)
|
||||
|
||||
# Configure leader teleoperator (no cameras needed)
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_leader",
|
||||
manual_control=False, # Enable torque control for gravity compensation
|
||||
)
|
||||
|
||||
# Initialize robot and teleoperator
|
||||
print("\nInitializing devices...")
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
|
||||
# Connect devices
|
||||
print("Connecting and calibrating...")
|
||||
follower.connect(calibrate=True)
|
||||
leader.connect(calibrate=True)
|
||||
|
||||
# Verify URDF is loaded for gravity compensation
|
||||
if leader.pin_robot is None:
|
||||
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
|
||||
|
||||
# Configure the dataset features
|
||||
# For actions, we only want to record positions (not velocity or torque)
|
||||
action_features_hw = {}
|
||||
for key, value in follower.action_features.items():
|
||||
if key.endswith(".pos"):
|
||||
action_features_hw[key] = value
|
||||
|
||||
action_features = hw_to_dataset_features(action_features_hw, "action")
|
||||
obs_features = hw_to_dataset_features(follower.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
print("\nCreating dataset...")
|
||||
repo_id = "<hf_username>/<dataset_repo_id>" # TODO: Replace with your Hugging Face repo
|
||||
|
||||
# Check if dataset already exists and prompt user
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
|
||||
while dataset_path.exists():
|
||||
print(f"\nDataset already exists at: {dataset_path}")
|
||||
print("\nOptions:")
|
||||
print(" 1. Overwrite existing dataset")
|
||||
print(" 2. Use a different name")
|
||||
print(" 3. Abort")
|
||||
|
||||
choice = input("\nEnter your choice (1/2/3): ").strip()
|
||||
|
||||
if choice == '1':
|
||||
print(f"Removing existing dataset...")
|
||||
shutil.rmtree(dataset_path)
|
||||
print("✓ Existing dataset removed")
|
||||
break
|
||||
elif choice == '2':
|
||||
print("\nCurrent repo_id:", repo_id)
|
||||
new_repo_id = input("Enter new repo_id (format: <username>/<dataset_name>): ").strip()
|
||||
if new_repo_id and '/' in new_repo_id:
|
||||
repo_id = new_repo_id
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
|
||||
print(f"✓ Using new repo_id: {repo_id}")
|
||||
# Loop will continue if this new path also exists
|
||||
else:
|
||||
print("Invalid repo_id format. Please use format: <username>/<dataset_name>")
|
||||
elif choice == '3':
|
||||
print("Aborting. Please remove the existing dataset manually or restart with a different repo_id.")
|
||||
follower.disconnect()
|
||||
leader.disconnect()
|
||||
return
|
||||
else:
|
||||
print("Invalid choice. Please enter 1, 2, or 3.")
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize keyboard listener and visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="openarms_recording")
|
||||
|
||||
# Enable motors on both leader arms for gravity compensation
|
||||
leader.bus_right.enable_torque()
|
||||
leader.bus_left.enable_torque()
|
||||
time.sleep(0.1)
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print(f"Recording {NUM_EPISODES} episodes")
|
||||
print(f"Task: {TASK_DESCRIPTION}")
|
||||
print("=" * 70)
|
||||
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
|
||||
print("\nKeyboard controls:")
|
||||
print(" - Press 'q' to stop recording")
|
||||
print(" - Press 'r' to re-record current episode")
|
||||
print("=" * 70)
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
try:
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Record episode with compensation active
|
||||
record_loop_with_compensation(
|
||||
robot=follower,
|
||||
leader=leader,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
dataset_features=dataset_features,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop_with_compensation(
|
||||
robot=follower,
|
||||
leader=leader,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=None, # Don't save reset period
|
||||
dataset_features=dataset_features,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Handle re-recording
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Only save episode if frames were recorded
|
||||
if dataset.episode_buffer is not None and dataset.episode_buffer["size"] > 0:
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
else:
|
||||
log_say("No frames recorded, skipping episode save")
|
||||
# Clear the empty buffer
|
||||
dataset.episode_buffer = None
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping recording...")
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
try:
|
||||
leader.bus_right.disable_torque()
|
||||
leader.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
print("✓ Shutdown complete")
|
||||
except Exception as e:
|
||||
print(f"Shutdown error: {e}")
|
||||
|
||||
# Upload dataset
|
||||
print("\nUploading dataset to Hugging Face Hub...")
|
||||
try:
|
||||
dataset.push_to_hub()
|
||||
print("✓ Dataset uploaded successfully")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to upload dataset: {e}")
|
||||
print("You can manually upload later using: dataset.push_to_hub()")
|
||||
|
||||
print("✓ Recording complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
166
examples/openarms/replay.py
Normal file
166
examples/openarms/replay.py
Normal file
@@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenArms Dataset Replay Example
|
||||
|
||||
Replays position actions from a recorded dataset on an OpenArms follower robot.
|
||||
Only position commands (ending with .pos) are replayed, not velocity or torque.
|
||||
|
||||
Example usage:
|
||||
python examples/openarms/replay.py
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
# Configuration
|
||||
EPISODE_IDX = 0
|
||||
DATASET_REPO_ID = "lerobot-data-collection/replay-this-2025-11-02-17-58" # TODO: Replace with your dataset
|
||||
DATASET_ROOT = None # Use default cache location, or specify custom path
|
||||
|
||||
# Robot configuration - adjust these to match your setup
|
||||
ROBOT_CONFIG = OpenArmsFollowerConfig(
|
||||
port_left="can2", # CAN interface for left arm
|
||||
port_right="can3", # CAN interface for right arm
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0, # Safety limit: max degrees to move per step
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main replay function."""
|
||||
print("=" * 70)
|
||||
print("OpenArms Dataset Replay")
|
||||
print("=" * 70)
|
||||
print(f"\nDataset: {DATASET_REPO_ID}")
|
||||
print(f"Episode: {EPISODE_IDX}")
|
||||
print(f"Robot: {ROBOT_CONFIG.id}")
|
||||
print(f" Left arm: {ROBOT_CONFIG.port_left}")
|
||||
print(f" Right arm: {ROBOT_CONFIG.port_right}")
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
# Initialize the robot
|
||||
print("\n[1/3] Initializing robot...")
|
||||
robot = OpenArmsFollower(ROBOT_CONFIG)
|
||||
|
||||
# Load the dataset
|
||||
print(f"\n[2/3] Loading dataset '{DATASET_REPO_ID}'...")
|
||||
dataset = LeRobotDataset(
|
||||
DATASET_REPO_ID,
|
||||
root=DATASET_ROOT,
|
||||
episodes=[EPISODE_IDX]
|
||||
)
|
||||
|
||||
# Filter dataset to only include frames from the specified episode
|
||||
# (required for dataset V3.0 where episodes are chunked)
|
||||
episode_frames = dataset.hf_dataset.filter(
|
||||
lambda x: x["episode_index"] == EPISODE_IDX
|
||||
)
|
||||
|
||||
if len(episode_frames) == 0:
|
||||
raise ValueError(
|
||||
f"No frames found for episode {EPISODE_IDX} in dataset {DATASET_REPO_ID}"
|
||||
)
|
||||
|
||||
print(f" Found {len(episode_frames)} frames in episode {EPISODE_IDX}")
|
||||
|
||||
# Extract action features from dataset
|
||||
action_features = dataset.features.get(ACTION, {})
|
||||
action_names = action_features.get("names", [])
|
||||
|
||||
# Filter to only position actions (ending with .pos)
|
||||
position_action_names = [name for name in action_names if name.endswith(".pos")]
|
||||
|
||||
if not position_action_names:
|
||||
raise ValueError(
|
||||
f"No position actions found in dataset. Action names: {action_names}"
|
||||
)
|
||||
|
||||
print(f" Found {len(position_action_names)} position actions to replay")
|
||||
print(f" Actions: {', '.join(position_action_names[:5])}{'...' if len(position_action_names) > 5 else ''}")
|
||||
|
||||
# Select only action columns from dataset
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
print(f"\n[3/3] Connecting to robot...")
|
||||
robot.connect(calibrate=False) # Skip calibration for replay
|
||||
|
||||
if not robot.is_connected:
|
||||
raise RuntimeError("Robot failed to connect!")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Ready to replay!")
|
||||
print("=" * 70)
|
||||
print("\nThe robot will replay the recorded positions.")
|
||||
print("Press Ctrl+C to stop at any time.\n")
|
||||
|
||||
input("Press ENTER to start replaying...")
|
||||
|
||||
# Replay loop
|
||||
log_say(f"Replaying episode {EPISODE_IDX}", blocking=True)
|
||||
|
||||
try:
|
||||
for idx in range(len(episode_frames)):
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Extract action array from dataset
|
||||
action_array = actions[idx][ACTION]
|
||||
|
||||
# Build action dictionary, but only include position actions
|
||||
action = {}
|
||||
for i, name in enumerate(action_names):
|
||||
# Only include position actions (ending with .pos)
|
||||
if name.endswith(".pos"):
|
||||
action[name] = float(action_array[i])
|
||||
|
||||
# Send action to robot
|
||||
robot.send_action(action)
|
||||
|
||||
# Maintain replay rate (use dataset fps)
|
||||
loop_duration = time.perf_counter() - loop_start
|
||||
dt_s = 1.0 / dataset.fps - loop_duration
|
||||
busy_wait(dt_s)
|
||||
|
||||
# Progress indicator every 100 frames
|
||||
if (idx + 1) % 100 == 0:
|
||||
progress = (idx + 1) / len(episode_frames) * 100
|
||||
print(f"Progress: {idx + 1}/{len(episode_frames)} frames ({progress:.1f}%)")
|
||||
|
||||
print(f"\n✓ Successfully replayed {len(episode_frames)} frames")
|
||||
log_say("Replay complete", blocking=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nReplay interrupted by user")
|
||||
finally:
|
||||
# Disconnect robot
|
||||
print("\nDisconnecting robot...")
|
||||
robot.disconnect()
|
||||
print("✓ Replay complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
73
examples/openarms/setup_can.sh
Executable file
73
examples/openarms/setup_can.sh
Executable file
@@ -0,0 +1,73 @@
|
||||
#!/bin/bash
|
||||
# Setup all OpenArms CAN interfaces with CAN FD
|
||||
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo "OpenArms CAN FD Interface Setup"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "Mode: CAN FD"
|
||||
echo " - Nominal bitrate: 1 Mbps"
|
||||
echo " - Data bitrate: 5 Mbps"
|
||||
echo ""
|
||||
echo "Configuring interfaces can0, can1, can2, can3..."
|
||||
echo ""
|
||||
|
||||
# Configure each CAN interface with CAN FD
|
||||
for i in 0 1 2 3; do
|
||||
interface="can$i"
|
||||
|
||||
# Check if interface exists
|
||||
if ! ip link show "$interface" &> /dev/null; then
|
||||
echo "⚠ $interface: Not found, skipping"
|
||||
continue
|
||||
fi
|
||||
|
||||
# Bring down interface
|
||||
sudo ip link set "$interface" down 2>/dev/null
|
||||
|
||||
# Configure CAN FD mode
|
||||
sudo ip link set "$interface" type can \
|
||||
bitrate 1000000 \
|
||||
dbitrate 5000000 \
|
||||
fd on
|
||||
|
||||
# Bring up interface
|
||||
sudo ip link set "$interface" up
|
||||
|
||||
# Verify configuration
|
||||
if ip link show "$interface" | grep -q "UP"; then
|
||||
echo "✓ $interface: Configured and UP"
|
||||
else
|
||||
echo "✗ $interface: Failed to bring UP"
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Verification"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Show detailed status for each interface
|
||||
for i in 0 1 2 3; do
|
||||
interface="can$i"
|
||||
if ip link show "$interface" &> /dev/null; then
|
||||
echo "$interface:"
|
||||
# Show key parameters
|
||||
ip -d link show "$interface" | grep -E "can|state|bitrate|dbitrate" | head -3
|
||||
echo ""
|
||||
fi
|
||||
done
|
||||
|
||||
echo "=========================================="
|
||||
echo "Setup Complete!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "All interfaces configured for CAN FD mode"
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo " 1. Test motors: python debug_can_communication.py"
|
||||
echo " 2. Run teleoperation: python examples/openarms/teleop.py"
|
||||
echo ""
|
||||
148
examples/openarms/teleop.py
Normal file
148
examples/openarms/teleop.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
OpenArms Teleoperation Example - Full Dual Arms
|
||||
|
||||
This script demonstrates teleoperation of OpenArms follower robot using an OpenArms leader arm.
|
||||
It first calibrates both devices, then enters a teleoperation loop for both arms.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
|
||||
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can2", # CAN interface for follower left arm
|
||||
port_right="can3", # CAN interface for follower right arm
|
||||
can_interface="socketcan", # Linux SocketCAN
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=5.0, # Safety limit
|
||||
)
|
||||
|
||||
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left="can0", # CAN interface for leader left arm
|
||||
port_right="can1", # CAN interface for leader right arm
|
||||
can_interface="socketcan", # Linux SocketCAN
|
||||
id="openarms_leader",
|
||||
manual_control=True, # Enable manual control (torque disabled)
|
||||
)
|
||||
|
||||
print("=" * 60)
|
||||
print("OpenArms Teleoperation - Full Dual Arms")
|
||||
print("=" * 60)
|
||||
|
||||
# Initialize devices
|
||||
print("\n[1/4] Initializing devices...")
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
|
||||
# Connect and calibrate follower
|
||||
print("\n[2/4] Connecting and calibrating follower robot...")
|
||||
print("Note: If you have existing calibration, just press ENTER to use it.")
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
# Connect and calibrate leader
|
||||
print("\n[3/4] Connecting and calibrating leader arm...")
|
||||
print("Note: The leader arm will have torque disabled for manual control.")
|
||||
leader.connect(calibrate=True)
|
||||
|
||||
# Wait for user to be ready
|
||||
print("\n[4/4] Ready for teleoperation!")
|
||||
print("\nBoth arms will be controlled (16 motors total):")
|
||||
print(" RIGHT ARM: joints 1-7 + gripper")
|
||||
print(" LEFT ARM: joints 1-7 + gripper")
|
||||
|
||||
print("\nPress ENTER to start teleoperation...")
|
||||
input()
|
||||
|
||||
print("\nTeleoperation started! Move both leader arms.")
|
||||
print("Press Ctrl+C to stop.\n")
|
||||
|
||||
# All joints for both arms (16 motors total)
|
||||
all_joints = [
|
||||
# Right arm
|
||||
"right_joint_1",
|
||||
"right_joint_2",
|
||||
"right_joint_3",
|
||||
"right_joint_4",
|
||||
"right_joint_5",
|
||||
"right_joint_6",
|
||||
"right_joint_7",
|
||||
"right_gripper",
|
||||
# Left arm
|
||||
"left_joint_1",
|
||||
"left_joint_2",
|
||||
"left_joint_3",
|
||||
"left_joint_4",
|
||||
"left_joint_5",
|
||||
"left_joint_6",
|
||||
"left_joint_7",
|
||||
"left_gripper",
|
||||
]
|
||||
|
||||
# Performance monitoring
|
||||
loop_times = []
|
||||
start_time = time.perf_counter()
|
||||
last_print_time = start_time
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get action from leader
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Filter to only position data for all joints (both arms)
|
||||
joint_action = {}
|
||||
for joint in all_joints:
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
joint_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
# Send action to follower (both arms)
|
||||
if joint_action:
|
||||
follower.send_action(joint_action)
|
||||
|
||||
# Measure loop time
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Print stats every 2 seconds
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
min_time = min(loop_times)
|
||||
max_time = max(loop_times)
|
||||
max_hz = 1.0 / min_time if min_time > 0 else 0
|
||||
min_hz = 1.0 / max_time if max_time > 0 else 0
|
||||
|
||||
print(f"[Hz Stats] Avg: {current_hz:.1f} Hz | "
|
||||
f"Range: {min_hz:.1f}-{max_hz:.1f} Hz | "
|
||||
f"Avg loop time: {avg_time*1000:.1f} ms")
|
||||
|
||||
# Reset for next measurement window
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping teleoperation...")
|
||||
finally:
|
||||
# Disconnect devices
|
||||
print("Disconnecting devices...")
|
||||
try:
|
||||
follower.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting follower: {e}")
|
||||
|
||||
try:
|
||||
leader.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting leader: {e}")
|
||||
|
||||
print("Done!")
|
||||
197
examples/openarms/teleop_openarms_mini.py
Normal file
197
examples/openarms/teleop_openarms_mini.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
OpenArms Mini Teleoperation Example
|
||||
|
||||
This script demonstrates teleoperation of an OpenArms follower robot using
|
||||
an OpenArms Mini leader (Feetech-based) with dual arms (16 motors total).
|
||||
|
||||
The OpenArms Mini has:
|
||||
- Right arm: 8 motors (joint_1 to joint_7 + gripper)
|
||||
- Left arm: 8 motors (joint_1 to joint_7 + gripper)
|
||||
|
||||
Note on gripper normalization:
|
||||
- OpenArms Mini gripper: 0-100 scale (0=closed, 100=open)
|
||||
- OpenArms follower gripper: degrees (0=closed, -65=open)
|
||||
- This script automatically converts between the two ranges
|
||||
"""
|
||||
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.teleoperators.openarms_mini.openarms_mini import OpenArmsMini
|
||||
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
|
||||
# Target control frequency
|
||||
TARGET_FPS = 30
|
||||
|
||||
# Configure the OpenArms follower (Damiao motors on CAN bus)
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can0", # CAN interface for follower left arm
|
||||
port_right="can1", # CAN interface for follower right arm
|
||||
can_interface="socketcan", # Linux SocketCAN
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0, # Safety limit (degrees per step)
|
||||
)
|
||||
|
||||
# Configure the OpenArms Mini leader (Feetech motors on serial)
|
||||
leader_config = OpenArmsMiniConfig(
|
||||
port_right="/dev/ttyACM0", # Serial port for right arm
|
||||
port_left="/dev/ttyACM1", # Serial port for left arm
|
||||
id="openarms_mini",
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
print("OpenArms Mini → OpenArms Follower Teleoperation")
|
||||
|
||||
# Initialize devices
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsMini(leader_config)
|
||||
|
||||
# Connect and calibrate follower
|
||||
print("Note: If you have existing calibration, just press ENTER to use it.")
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
# Connect and calibrate leader
|
||||
print("Note: The leader arms will have torque disabled for manual control.")
|
||||
leader.connect(calibrate=True)
|
||||
|
||||
print("\nPress ENTER to start teleoperation...")
|
||||
input()
|
||||
|
||||
print("Press Ctrl+C to stop.\n")
|
||||
|
||||
# All joints for both arms (16 motors total)
|
||||
all_joints = [
|
||||
# Right arm
|
||||
"right_joint_1",
|
||||
"right_joint_2",
|
||||
"right_joint_3",
|
||||
"right_joint_4",
|
||||
"right_joint_5",
|
||||
"right_joint_6",
|
||||
"right_joint_7",
|
||||
"right_gripper",
|
||||
# Left arm
|
||||
"left_joint_1",
|
||||
"left_joint_2",
|
||||
"left_joint_3",
|
||||
"left_joint_4",
|
||||
"left_joint_5",
|
||||
"left_joint_6",
|
||||
"left_joint_7",
|
||||
"left_gripper",
|
||||
]
|
||||
|
||||
# Performance monitoring
|
||||
loop_times = []
|
||||
avg_loop_time = 0.0
|
||||
min_loop_time = float('inf')
|
||||
max_loop_time = 0.0
|
||||
stats_update_interval = 1.0 # Update stats every 1 second
|
||||
last_stats_update = time.perf_counter()
|
||||
|
||||
|
||||
SWAPPED_JOINTS = {
|
||||
"right_joint_6": "right_joint_7",
|
||||
"right_joint_7": "right_joint_6",
|
||||
"left_joint_6": "left_joint_7",
|
||||
"left_joint_7": "left_joint_6",
|
||||
}
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get actions and observations
|
||||
leader_action = leader.get_action()
|
||||
follower_obs = follower.get_observation()
|
||||
|
||||
joint_action = {}
|
||||
for joint in all_joints:
|
||||
leader_key = f"{joint}.pos"
|
||||
|
||||
# Determine which follower joint this leader joint controls
|
||||
follower_joint = SWAPPED_JOINTS.get(joint, joint)
|
||||
follower_key = f"{follower_joint}.pos"
|
||||
|
||||
# Get leader position (default 0 if missing)
|
||||
pos = leader_action.get(leader_key, 0.0)
|
||||
|
||||
# Convert gripper values: Mini uses 0-100, OpenArms uses 0 to -65 degrees
|
||||
if "gripper" in joint:
|
||||
# Map 0-100 (Mini) to 0 to -65 (OpenArms)
|
||||
# 0 (closed) -> 0°, 100 (open) -> -65°
|
||||
pos = (pos / 100.0) * -65.0
|
||||
|
||||
# Store in action dict for follower
|
||||
joint_action[follower_key] = pos
|
||||
|
||||
follower.send_action(joint_action)
|
||||
|
||||
# Loop timing
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Update stats periodically
|
||||
current_time = time.perf_counter()
|
||||
if current_time - last_stats_update >= stats_update_interval:
|
||||
if loop_times:
|
||||
avg_loop_time = sum(loop_times) / len(loop_times)
|
||||
min_loop_time = min(loop_times)
|
||||
max_loop_time = max(loop_times)
|
||||
loop_times = []
|
||||
last_stats_update = current_time
|
||||
|
||||
# Display everything
|
||||
sys.stdout.write("\033[H\033[J") # Clear screen
|
||||
|
||||
# Show timing stats at the top
|
||||
if avg_loop_time > 0:
|
||||
avg_hz = 1.0 / avg_loop_time
|
||||
min_hz = 1.0 / max_loop_time if max_loop_time > 0 else 0
|
||||
max_hz = 1.0 / min_loop_time if min_loop_time > 0 and min_loop_time < float('inf') else 0
|
||||
print(f"[Performance] Target: {TARGET_FPS} Hz | Avg: {avg_hz:.1f} Hz | Range: {min_hz:.1f}-{max_hz:.1f} Hz | Loop: {avg_loop_time*1000:.1f} ms\n")
|
||||
else:
|
||||
print(f"[Performance] Target: {TARGET_FPS} Hz | Measuring...\n")
|
||||
|
||||
# Show joint positions
|
||||
print(f"{'Joint':<20} {'Leader':>15} {'Follower':>15}")
|
||||
print(f"{'':20} {'(0-100/deg)':>15} {'(deg)':>15}")
|
||||
print("-" * 52)
|
||||
|
||||
for joint in all_joints:
|
||||
leader_key = f"{joint}.pos"
|
||||
follower_joint = SWAPPED_JOINTS.get(joint, joint)
|
||||
follower_key = f"{follower_joint}.pos"
|
||||
|
||||
leader_pos = leader_action.get(leader_key, 0.0)
|
||||
follower_pos = follower_obs.get(follower_key, 0.0)
|
||||
|
||||
print(f"{joint:<20} {leader_pos:>15.2f} {follower_pos:>15.2f}")
|
||||
|
||||
# Smart sleep to maintain target FPS
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
busy_wait(max(0, 1.0 / TARGET_FPS - dt_s))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping teleoperation...")
|
||||
finally:
|
||||
# Disconnect devices
|
||||
print("Disconnecting devices...")
|
||||
try:
|
||||
follower.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting follower: {e}")
|
||||
|
||||
try:
|
||||
leader.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting leader: {e}")
|
||||
|
||||
print("Done!")
|
||||
|
||||
202
examples/openarms/teleop_with_compensation.py
Executable file
202
examples/openarms/teleop_with_compensation.py
Executable file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
OpenArms Teleoperation with Gravity + Friction Compensation
|
||||
|
||||
Leader arms (both LEFT and RIGHT): Gravity + Friction compensation (weightless, easy to move)
|
||||
Follower arms (both LEFT and RIGHT): Mirror leader movements
|
||||
|
||||
Uses the URDF file from the lerobot repository.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
|
||||
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
|
||||
FRICTION_SCALE = 1.0
|
||||
|
||||
def main():
|
||||
"""Main teleoperation loop with gravity compensation"""
|
||||
|
||||
print("=" * 70)
|
||||
print("OpenArms Teleoperation with Gravity Compensation")
|
||||
print("=" * 70)
|
||||
|
||||
# Configuration
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can2",
|
||||
port_right="can3",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
)
|
||||
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_leader",
|
||||
manual_control=False, # Enable torque control for gravity compensation
|
||||
)
|
||||
|
||||
# Initialize and connect
|
||||
print("\nInitializing devices...")
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
|
||||
# URDF is automatically loaded in the leader constructor
|
||||
if leader.pin_robot is None:
|
||||
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
|
||||
|
||||
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
|
||||
print("Press ENTER to start...")
|
||||
input()
|
||||
|
||||
# Enable motors on both leader arms for gravity compensation
|
||||
leader.bus_right.enable_torque()
|
||||
leader.bus_left.enable_torque()
|
||||
time.sleep(0.1)
|
||||
|
||||
print("Press Ctrl+C to stop\n")
|
||||
|
||||
# Main control loop
|
||||
loop_times = []
|
||||
last_print_time = time.perf_counter()
|
||||
|
||||
# All joints (both arms)
|
||||
all_joints = []
|
||||
for motor in leader.bus_right.motors:
|
||||
all_joints.append(f"right_{motor}")
|
||||
for motor in leader.bus_left.motors:
|
||||
all_joints.append(f"left_{motor}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get leader state
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Extract positions and velocities in degrees
|
||||
leader_positions_deg = {}
|
||||
leader_velocities_deg_per_sec = {}
|
||||
|
||||
for motor in leader.bus_right.motors:
|
||||
pos_key = f"right_{motor}.pos"
|
||||
vel_key = f"right_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
pos_key = f"left_{motor}.pos"
|
||||
vel_key = f"left_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||
|
||||
# Calculate gravity torques for leader using built-in method
|
||||
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||
|
||||
# Calculate friction torques for leader using built-in method
|
||||
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||
leader_velocities_rad_per_sec,
|
||||
friction_scale=FRICTION_SCALE
|
||||
)
|
||||
|
||||
# Combine gravity + friction torques
|
||||
leader_total_torques_nm = {}
|
||||
for motor_name in leader_gravity_torques_nm:
|
||||
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||
leader_total_torques_nm[motor_name] = gravity + friction
|
||||
|
||||
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
|
||||
for motor in leader.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
|
||||
for motor in leader.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Send leader positions to follower (both arms)
|
||||
follower_action = {}
|
||||
for joint in all_joints:
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
follower_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
if follower_action:
|
||||
follower.send_action(follower_action)
|
||||
|
||||
# Performance monitoring
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
|
||||
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping...")
|
||||
finally:
|
||||
try:
|
||||
leader.bus_right.disable_torque()
|
||||
leader.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
print("✓ Shutdown complete")
|
||||
except Exception as e:
|
||||
print(f"Shutdown error: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
745
examples/openarms_web_interface/App.css
Normal file
745
examples/openarms_web_interface/App.css
Normal file
@@ -0,0 +1,745 @@
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
|
||||
main {
|
||||
min-height: 100vh;
|
||||
padding: 2rem;
|
||||
}
|
||||
|
||||
header {
|
||||
text-align: center;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 2rem;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
h2 {
|
||||
font-size: 1.25rem;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin: 0 0 1rem 0;
|
||||
}
|
||||
|
||||
h3 {
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
color: #666;
|
||||
margin: 0 0 0.5rem 0;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1920px;
|
||||
margin: 0 auto;
|
||||
display: grid;
|
||||
grid-template-columns: minmax(500px, 600px) 1fr;
|
||||
gap: 2rem;
|
||||
align-items: start;
|
||||
}
|
||||
|
||||
/* Left column container */
|
||||
.left-column {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
/* Right column container */
|
||||
.right-column {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
/* Responsive: Stack on smaller screens */
|
||||
@media (max-width: 1200px) {
|
||||
.container {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
.panel {
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.config-panel {
|
||||
border: 2px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.config-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.config-header:hover {
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
.toggle-icon {
|
||||
font-size: 1rem;
|
||||
color: #6b7280;
|
||||
transition: transform 0.2s;
|
||||
}
|
||||
|
||||
.config-content {
|
||||
margin-top: 1rem;
|
||||
padding-top: 1rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.robot-setup {
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.robot-status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 1rem;
|
||||
border-radius: 6px;
|
||||
font-weight: 500;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.robot-status.ready {
|
||||
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
|
||||
color: #065f46;
|
||||
border: 1px solid #10b981;
|
||||
}
|
||||
|
||||
.robot-status.not-ready {
|
||||
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||
color: #92400e;
|
||||
border: 1px solid #f59e0b;
|
||||
}
|
||||
|
||||
.btn-setup {
|
||||
background: #10b981;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-setup:hover:not(:disabled) {
|
||||
background: #059669;
|
||||
}
|
||||
|
||||
.btn-setup:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-zero {
|
||||
background: #8b5cf6;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-zero:hover:not(:disabled) {
|
||||
background: #7c3aed;
|
||||
}
|
||||
|
||||
.btn-zero:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.zero-position-section {
|
||||
margin-top: 1rem;
|
||||
padding-top: 1rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.btn-zero-large {
|
||||
width: 100%;
|
||||
background: #8b5cf6;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.875rem 1.5rem;
|
||||
border-radius: 8px;
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
box-shadow: 0 2px 4px rgba(139, 92, 246, 0.2);
|
||||
}
|
||||
|
||||
.btn-zero-large:hover:not(:disabled) {
|
||||
background: #7c3aed;
|
||||
box-shadow: 0 4px 8px rgba(139, 92, 246, 0.3);
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.btn-zero-large:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
box-shadow: none;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
.delete-episode-section {
|
||||
margin-top: 1rem;
|
||||
padding-top: 1rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.btn-delete {
|
||||
width: 100%;
|
||||
background: #ef4444;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.875rem 1.5rem;
|
||||
border-radius: 8px;
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
box-shadow: 0 2px 4px rgba(239, 68, 68, 0.2);
|
||||
}
|
||||
|
||||
.btn-delete:hover:not(:disabled) {
|
||||
background: #dc2626;
|
||||
box-shadow: 0 4px 8px rgba(239, 68, 68, 0.3);
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.btn-delete:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
box-shadow: none;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
.delete-info {
|
||||
margin-top: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
color: #666;
|
||||
text-align: center;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.btn-disconnect {
|
||||
background: #ef4444;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-disconnect:hover {
|
||||
background: #dc2626;
|
||||
}
|
||||
|
||||
.btn-refresh {
|
||||
background: #3b82f6;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.4rem 0.8rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-refresh:hover:not(:disabled) {
|
||||
background: #2563eb;
|
||||
}
|
||||
|
||||
.btn-refresh:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.control-panel {
|
||||
border: 2px solid #10b981;
|
||||
}
|
||||
|
||||
.status-banner {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 1rem;
|
||||
padding: 1rem 1.5rem;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 1.5rem;
|
||||
font-weight: 500;
|
||||
font-size: 0.95rem;
|
||||
}
|
||||
|
||||
.status-banner.initializing {
|
||||
background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%);
|
||||
color: #1e40af;
|
||||
border-left: 4px solid #3b82f6;
|
||||
}
|
||||
|
||||
.status-banner.encoding {
|
||||
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||
color: #92400e;
|
||||
border-left: 4px solid #f59e0b;
|
||||
}
|
||||
|
||||
.status-banner.uploading {
|
||||
background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%);
|
||||
color: #3730a3;
|
||||
border-left: 4px solid #6366f1;
|
||||
}
|
||||
|
||||
.status-banner.success {
|
||||
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
|
||||
color: #065f46;
|
||||
border-left: 4px solid #10b981;
|
||||
}
|
||||
|
||||
.status-banner.warning {
|
||||
background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%);
|
||||
color: #991b1b;
|
||||
border-left: 4px solid #ef4444;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
border: 3px solid rgba(0, 0, 0, 0.1);
|
||||
border-top-color: currentColor;
|
||||
border-radius: 50%;
|
||||
animation: spin 0.8s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
to { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
.control-horizontal {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.control-left {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.control-right {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.input-group {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
input[type="text"] {
|
||||
flex: 1;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
input[type="text"]:disabled {
|
||||
background: #f5f5f5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
input[type="text"]:focus {
|
||||
outline: none;
|
||||
border-color: #10b981;
|
||||
}
|
||||
|
||||
button {
|
||||
padding: 0.75rem 1.5rem;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
font-size: 1rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.btn-set-task {
|
||||
background: #3b82f6;
|
||||
color: white;
|
||||
min-width: 120px;
|
||||
}
|
||||
|
||||
.btn-set-task:hover:not(:disabled) {
|
||||
background: #2563eb;
|
||||
}
|
||||
|
||||
.btn-set-task:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-start {
|
||||
background: #10b981;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-start:hover:not(:disabled) {
|
||||
background: #059669;
|
||||
}
|
||||
|
||||
.btn-start:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-stop {
|
||||
background: #ef4444;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-stop:hover {
|
||||
background: #dc2626;
|
||||
}
|
||||
|
||||
.btn-reset {
|
||||
padding: 0.5rem 1rem;
|
||||
background: #6b7280;
|
||||
color: white;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.btn-reset:hover {
|
||||
background: #4b5563;
|
||||
}
|
||||
|
||||
.status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
padding: 1rem;
|
||||
border-radius: 4px;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.status.recording {
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
}
|
||||
|
||||
.status.recording.recording-active {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
background: #dc2626;
|
||||
color: white;
|
||||
padding: 1.5rem;
|
||||
border: 4px solid #991b1b;
|
||||
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.4);
|
||||
font-weight: 700;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
.status.recording.recording-active .indicator {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
background: #fef2f2;
|
||||
animation: pulse-strong 1s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse-strong {
|
||||
0%, 100% {
|
||||
opacity: 1;
|
||||
transform: scale(1);
|
||||
}
|
||||
50% {
|
||||
opacity: 0.7;
|
||||
transform: scale(1.1);
|
||||
}
|
||||
}
|
||||
|
||||
.status.recording.recording-active .time-display {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
font-size: 1.5rem;
|
||||
font-weight: 700;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.fps-display {
|
||||
font-size: 1rem;
|
||||
font-weight: 500;
|
||||
opacity: 0.95;
|
||||
}
|
||||
|
||||
.fps-warning {
|
||||
color: #fef2f2;
|
||||
animation: pulse-warning 1s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse-warning {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.5; }
|
||||
}
|
||||
|
||||
.status.recording.recording-active .btn-stop {
|
||||
align-self: stretch;
|
||||
}
|
||||
|
||||
.ramp-up-countdown {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.countdown-box {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 2rem 3rem;
|
||||
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||
border: 4px solid #f59e0b;
|
||||
border-radius: 16px;
|
||||
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
|
||||
min-width: 280px;
|
||||
animation: pulse-warm 1.5s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse-warm {
|
||||
0%, 100% {
|
||||
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
|
||||
}
|
||||
50% {
|
||||
box-shadow: 0 6px 25px rgba(245, 158, 11, 0.6);
|
||||
}
|
||||
}
|
||||
|
||||
.countdown-label {
|
||||
font-size: 1rem;
|
||||
color: #92400e;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 1.5px;
|
||||
font-weight: 800;
|
||||
margin-bottom: 1rem;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.countdown-value {
|
||||
font-size: 4.5rem;
|
||||
font-weight: 900;
|
||||
color: #d97706;
|
||||
font-family: 'Courier New', monospace;
|
||||
line-height: 1;
|
||||
text-shadow: 2px 2px 6px rgba(0, 0, 0, 0.15);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.countdown-subtitle {
|
||||
font-size: 0.875rem;
|
||||
color: #78350f;
|
||||
font-weight: 600;
|
||||
font-style: italic;
|
||||
text-align: center;
|
||||
margin-top: 0.5rem;
|
||||
}
|
||||
|
||||
.status.idle {
|
||||
background: #f3f4f6;
|
||||
color: #374151;
|
||||
}
|
||||
|
||||
.indicator {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
border-radius: 50%;
|
||||
background: #ef4444;
|
||||
animation: pulse 1.5s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.5; }
|
||||
}
|
||||
|
||||
.counter {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
padding: 1.5rem;
|
||||
background: linear-gradient(135deg, #f9fafb 0%, #f3f4f6 100%);
|
||||
border-radius: 8px;
|
||||
border: 2px solid #e5e7eb;
|
||||
min-width: 200px;
|
||||
}
|
||||
|
||||
.counter-label {
|
||||
font-size: 0.75rem;
|
||||
color: #6b7280;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.counter-value {
|
||||
font-size: 3rem;
|
||||
font-weight: 700;
|
||||
color: #10b981;
|
||||
line-height: 1;
|
||||
}
|
||||
|
||||
.time-display {
|
||||
font-size: 1.5rem;
|
||||
font-weight: 600;
|
||||
font-family: 'Courier New', monospace;
|
||||
}
|
||||
|
||||
.error-box {
|
||||
padding: 1rem;
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
border-radius: 4px;
|
||||
border-left: 4px solid #ef4444;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.config-section {
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.config-section:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.config-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
label {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
color: #374151;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
select {
|
||||
padding: 0.5rem;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
background: white;
|
||||
}
|
||||
|
||||
select:disabled {
|
||||
background: #f5f5f5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
/* Camera Layout */
|
||||
.camera-layout {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.camera-base {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.camera-wrist-container {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(2, 1fr);
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.camera-wrist {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.camera {
|
||||
border: 1px solid #e5e7eb;
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.camera h3 {
|
||||
padding: 0.75rem;
|
||||
background: #f9fafb;
|
||||
border-bottom: 1px solid #e5e7eb;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.camera img {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
display: block;
|
||||
background: #000;
|
||||
min-height: 300px;
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
.camera-placeholder {
|
||||
text-align: center;
|
||||
padding: 4rem 2rem;
|
||||
background: #f9fafb;
|
||||
border-radius: 4px;
|
||||
border: 2px dashed #d1d5db;
|
||||
}
|
||||
|
||||
.camera-placeholder p {
|
||||
margin: 0.5rem 0;
|
||||
font-size: 1rem;
|
||||
color: #6b7280;
|
||||
}
|
||||
|
||||
.camera-placeholder p:first-child {
|
||||
font-size: 1.25rem;
|
||||
font-weight: 500;
|
||||
color: #374151;
|
||||
}
|
||||
|
||||
.hint {
|
||||
margin-top: 0.5rem;
|
||||
font-size: 0.75rem;
|
||||
color: #6b7280;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
857
examples/openarms_web_interface/App.jsx
Normal file
857
examples/openarms_web_interface/App.jsx
Normal file
@@ -0,0 +1,857 @@
|
||||
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import './App.css';
|
||||
|
||||
const API_BASE = 'http://localhost:8000/api';
|
||||
|
||||
function App() {
|
||||
// State
|
||||
const [task, setTask] = useState('');
|
||||
const [isRecording, setIsRecording] = useState(false);
|
||||
const [isInitializing, setIsInitializing] = useState(false);
|
||||
const [isEncoding, setIsEncoding] = useState(false);
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [robotsReady, setRobotsReady] = useState(false);
|
||||
const [elapsedTime, setElapsedTime] = useState(0);
|
||||
const [currentFps, setCurrentFps] = useState(0);
|
||||
const [loopFps, setLoopFps] = useState(0);
|
||||
const [episodeCount, setEpisodeCount] = useState(0);
|
||||
const [error, setError] = useState(null);
|
||||
const [statusMessage, setStatusMessage] = useState('Ready');
|
||||
const [uploadStatus, setUploadStatus] = useState(null);
|
||||
const [rampUpRemaining, setRampUpRemaining] = useState(0);
|
||||
const [movingToZero, setMovingToZero] = useState(false);
|
||||
const [configExpanded, setConfigExpanded] = useState(false);
|
||||
const [latestRepoId, setLatestRepoId] = useState(null);
|
||||
|
||||
// Configuration
|
||||
const [config, setConfig] = useState({
|
||||
leader_type: 'openarms', // 'openarms' or 'openarms_mini'
|
||||
leader_left: 'can0',
|
||||
leader_right: 'can1',
|
||||
follower_left: 'can2',
|
||||
follower_right: 'can3',
|
||||
left_wrist: '/dev/video0',
|
||||
right_wrist: '/dev/video1',
|
||||
base: '/dev/video4'
|
||||
});
|
||||
|
||||
// Available options
|
||||
const [availableCameras, setAvailableCameras] = useState([]);
|
||||
const [availableUsbPorts, setAvailableUsbPorts] = useState([]);
|
||||
const canInterfaces = ['can0', 'can1', 'can2', 'can3'];
|
||||
|
||||
const statusIntervalRef = useRef(null);
|
||||
const hasInitializedRef = useRef(false);
|
||||
|
||||
const loadConfig = () => {
|
||||
try {
|
||||
const saved = localStorage.getItem('openarms_config');
|
||||
if (saved) {
|
||||
const loadedConfig = JSON.parse(saved);
|
||||
setConfig(prev => ({ ...prev, ...loadedConfig }));
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Load config error:', e);
|
||||
}
|
||||
};
|
||||
|
||||
const saveConfig = (newConfig) => {
|
||||
try {
|
||||
localStorage.setItem('openarms_config', JSON.stringify(newConfig || config));
|
||||
} catch (e) {
|
||||
console.error('Save config error:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Fetch status periodically
|
||||
const fetchStatus = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/status`);
|
||||
const data = await response.json();
|
||||
|
||||
setIsRecording(data.is_recording);
|
||||
setIsInitializing(data.is_initializing);
|
||||
setIsEncoding(data.is_encoding);
|
||||
setIsUploading(data.is_uploading);
|
||||
setRobotsReady(data.robots_ready);
|
||||
setElapsedTime(data.elapsed_time);
|
||||
setCurrentFps(data.current_fps || 0);
|
||||
setLoopFps(data.loop_fps || 0);
|
||||
setEpisodeCount(data.episode_count);
|
||||
setError(data.error);
|
||||
setStatusMessage(data.status_message || 'Ready');
|
||||
setUploadStatus(data.upload_status);
|
||||
setRampUpRemaining(data.ramp_up_remaining || 0);
|
||||
setMovingToZero(data.moving_to_zero || false);
|
||||
|
||||
// Track the latest repo_id from the backend
|
||||
if (data.latest_repo_id) {
|
||||
setLatestRepoId(data.latest_repo_id);
|
||||
}
|
||||
|
||||
if (data.config) {
|
||||
// Only merge server config if we don't have a saved config (first load)
|
||||
if (!localStorage.getItem('openarms_config')) {
|
||||
setConfig(prev => {
|
||||
const merged = { ...data.config, ...prev };
|
||||
localStorage.setItem('openarms_config', JSON.stringify(merged));
|
||||
return merged;
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to fetch status:', e);
|
||||
}
|
||||
};
|
||||
|
||||
const setupRobots = async () => {
|
||||
// Show warning to verify camera positions
|
||||
const confirmed = window.confirm(
|
||||
'⚠️ IMPORTANT: Before connecting robots, please verify:\n\n' +
|
||||
'📹 Check that cameras are correctly positioned:\n' +
|
||||
' • LEFT wrist camera is actually on the LEFT arm\n' +
|
||||
' • RIGHT wrist camera is actually on the RIGHT arm\n' +
|
||||
' • BASE camera is actually the BASE/overhead camera\n\n' +
|
||||
'Incorrect camera positioning will result in invalid training data!\n\n' +
|
||||
'Click OK to continue with robot setup, or Cancel to review configuration.'
|
||||
);
|
||||
|
||||
if (!confirmed) {
|
||||
return; // User cancelled, don't proceed
|
||||
}
|
||||
|
||||
setError(null);
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/robots/setup`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(config)
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to setup robots');
|
||||
}
|
||||
|
||||
await response.json();
|
||||
saveConfig(config);
|
||||
} catch (e) {
|
||||
setError(`Robot setup failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Disconnect robots
|
||||
const disconnectRobots = async () => {
|
||||
try {
|
||||
await fetch(`${API_BASE}/robots/disconnect`, { method: 'POST' });
|
||||
setRobotsReady(false);
|
||||
} catch (e) {
|
||||
console.error('Failed to disconnect robots:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Discover cameras
|
||||
const discoverCameras = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/cameras/discover`);
|
||||
const data = await response.json();
|
||||
const cameras = data.cameras || [];
|
||||
setAvailableCameras(cameras);
|
||||
|
||||
// Get list of valid camera IDs
|
||||
const validCameraIds = cameras.map(cam => String(cam.id));
|
||||
|
||||
// Auto-fix config if current values are invalid or not set
|
||||
const updated = { ...config };
|
||||
let changed = false;
|
||||
|
||||
// Auto-fix invalid camera config
|
||||
if (!config.left_wrist || !validCameraIds.includes(config.left_wrist)) {
|
||||
if (cameras.length >= 1) {
|
||||
updated.left_wrist = String(cameras[0].id);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!config.right_wrist || !validCameraIds.includes(config.right_wrist)) {
|
||||
if (cameras.length >= 2) {
|
||||
updated.right_wrist = String(cameras[1].id);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!config.base || !validCameraIds.includes(config.base)) {
|
||||
if (cameras.length >= 3) {
|
||||
updated.base = String(cameras[2].id);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
setConfig(updated);
|
||||
saveConfig(updated);
|
||||
}
|
||||
|
||||
if (cameras.length === 0) {
|
||||
setError('No cameras detected! Please connect cameras and refresh.');
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to discover cameras:', e);
|
||||
setError(`Camera discovery failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Discover USB ports
|
||||
const discoverUsbPorts = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/usb/discover`);
|
||||
const data = await response.json();
|
||||
const ports = data.ports || [];
|
||||
setAvailableUsbPorts(ports);
|
||||
|
||||
// Auto-fix config if OpenArms Mini is selected and ports are invalid
|
||||
if (config.leader_type === 'openarms_mini') {
|
||||
const updated = { ...config };
|
||||
let changed = false;
|
||||
|
||||
if (ports.length >= 1 && !ports.includes(config.leader_left)) {
|
||||
updated.leader_left = ports[0];
|
||||
changed = true;
|
||||
}
|
||||
|
||||
if (ports.length >= 2 && !ports.includes(config.leader_right)) {
|
||||
updated.leader_right = ports[1];
|
||||
changed = true;
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
setConfig(updated);
|
||||
saveConfig(updated);
|
||||
}
|
||||
}
|
||||
|
||||
if (ports.length === 0) {
|
||||
console.warn('No USB ports detected for OpenArms Mini');
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to discover USB ports:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Set task only (for pedal use)
|
||||
const setTaskOnly = async () => {
|
||||
if (!task.trim()) {
|
||||
setError('Please enter a task description');
|
||||
return;
|
||||
}
|
||||
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/set-task`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ task, ...config })
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to set task');
|
||||
}
|
||||
|
||||
const result = await response.json();
|
||||
setStatusMessage(result.message || `Task set: ${task}`);
|
||||
saveConfig(config);
|
||||
|
||||
// Clear success message after 3 seconds
|
||||
setTimeout(() => {
|
||||
if (!isRecording && !isInitializing) {
|
||||
setStatusMessage('Ready');
|
||||
}
|
||||
}, 3000);
|
||||
} catch (e) {
|
||||
setError(e.message);
|
||||
}
|
||||
};
|
||||
|
||||
// Start recording
|
||||
const startRecording = async () => {
|
||||
if (!task.trim()) {
|
||||
setError('Please enter a task description');
|
||||
return;
|
||||
}
|
||||
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/start`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ task, ...config })
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to start recording');
|
||||
}
|
||||
|
||||
await response.json();
|
||||
saveConfig(config);
|
||||
} catch (e) {
|
||||
setError(e.message);
|
||||
}
|
||||
};
|
||||
|
||||
// Stop recording
|
||||
const stopRecording = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/stop`, {
|
||||
method: 'POST'
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to stop recording');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
setError(null);
|
||||
// Update latest repo_id after recording
|
||||
if (data.dataset_name) {
|
||||
setLatestRepoId(`lerobot-data-collection/${data.dataset_name}`);
|
||||
}
|
||||
} catch (e) {
|
||||
setError(e.message);
|
||||
}
|
||||
};
|
||||
|
||||
const deleteLatestEpisode = async () => {
|
||||
if (!latestRepoId) {
|
||||
setError('No episode to delete');
|
||||
return;
|
||||
}
|
||||
|
||||
const confirmed = window.confirm(
|
||||
`WARNING: This will permanently delete the repository:\n\n${latestRepoId}\n\nThis action cannot be undone. Continue?`
|
||||
);
|
||||
|
||||
if (!confirmed) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/delete-latest`, { method: 'POST' });
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to delete episode');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
setLatestRepoId(null);
|
||||
setEpisodeCount(Math.max(0, episodeCount - 1));
|
||||
setStatusMessage(`Deleted: ${data.deleted_repo}`);
|
||||
|
||||
setTimeout(() => {
|
||||
if (!isRecording && !isInitializing) {
|
||||
setStatusMessage('Ready');
|
||||
}
|
||||
}, 3000);
|
||||
} catch (e) {
|
||||
setError(`Delete failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Reset counter
|
||||
const resetCounter = async () => {
|
||||
try {
|
||||
await fetch(`${API_BASE}/counter/reset`, { method: 'POST' });
|
||||
setEpisodeCount(0);
|
||||
} catch (e) {
|
||||
console.error('Failed to reset counter:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Move robot to zero position
|
||||
const moveToZero = async () => {
|
||||
setError(null);
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/robots/move-to-zero`, { method: 'POST' });
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to move to zero position');
|
||||
}
|
||||
await response.json();
|
||||
} catch (e) {
|
||||
setError(`Move to zero failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Format time as MM:SS
|
||||
const formatTime = (seconds) => {
|
||||
const mins = Math.floor(seconds / 60);
|
||||
const secs = Math.floor(seconds % 60);
|
||||
return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}`;
|
||||
};
|
||||
|
||||
// Update config and save
|
||||
const updateConfig = (key, value) => {
|
||||
const updated = { ...config, [key]: value };
|
||||
setConfig(updated);
|
||||
saveConfig(updated);
|
||||
};
|
||||
|
||||
// Initialize on mount only
|
||||
useEffect(() => {
|
||||
// Prevent double-initialization in development
|
||||
if (hasInitializedRef.current) {
|
||||
return;
|
||||
}
|
||||
hasInitializedRef.current = true;
|
||||
|
||||
loadConfig();
|
||||
discoverCameras();
|
||||
discoverUsbPorts();
|
||||
fetchStatus();
|
||||
statusIntervalRef.current = setInterval(fetchStatus, 1000);
|
||||
|
||||
return () => {
|
||||
if (statusIntervalRef.current) {
|
||||
clearInterval(statusIntervalRef.current);
|
||||
}
|
||||
};
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []); // Run only once on mount
|
||||
|
||||
// Discover USB ports when leader type changes to Mini
|
||||
useEffect(() => {
|
||||
if (config.leader_type === 'openarms_mini') {
|
||||
discoverUsbPorts();
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [config.leader_type]);
|
||||
|
||||
return (
|
||||
<main>
|
||||
<header>
|
||||
<h1>OpenArms Recording</h1>
|
||||
</header>
|
||||
|
||||
<div className="container">
|
||||
{/* Left Column: Configuration and Recording Control */}
|
||||
<div className="left-column">
|
||||
{/* Configuration Panel */}
|
||||
<section className="panel config-panel">
|
||||
<div
|
||||
className="config-header"
|
||||
onClick={() => setConfigExpanded(!configExpanded)}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
onKeyDown={(e) => e.key === 'Enter' && setConfigExpanded(!configExpanded)}
|
||||
>
|
||||
<h2>⚙️ Configuration</h2>
|
||||
<span className="toggle-icon">{configExpanded ? '▼' : '▶'}</span>
|
||||
</div>
|
||||
|
||||
{configExpanded && (
|
||||
<div className="config-content">
|
||||
{/* Robot Setup */}
|
||||
<div className="config-section">
|
||||
<h3>🤖 Robot Setup</h3>
|
||||
<div className="robot-setup">
|
||||
{robotsReady ? (
|
||||
<div className="robot-status ready">
|
||||
<span>✅ Robots Ready - Recording will start instantly</span>
|
||||
<button onClick={disconnectRobots} className="btn-disconnect">
|
||||
Disconnect Robots
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="robot-status not-ready">
|
||||
<span>⚠️ Robots not initialized - Recording will take ~10 seconds</span>
|
||||
<button
|
||||
onClick={setupRobots}
|
||||
disabled={isRecording || isInitializing}
|
||||
className="btn-setup"
|
||||
>
|
||||
🚀 Setup Robots
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Leader Type Selection */}
|
||||
<div className="config-section">
|
||||
<h3>🎮 Leader Type</h3>
|
||||
<div className="config-grid">
|
||||
<label style={{gridColumn: '1 / -1'}}>
|
||||
Leader Arm Type
|
||||
<select
|
||||
value={config.leader_type}
|
||||
onChange={(e) => updateConfig('leader_type', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
<option value="openarms">OpenArms (CAN Bus - Damiao Motors)</option>
|
||||
<option value="openarms_mini">OpenArms Mini (USB - Feetech Motors)</option>
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Leader Interfaces (CAN or USB based on type) */}
|
||||
<div className="config-section">
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
|
||||
<h3>
|
||||
{config.leader_type === 'openarms_mini'
|
||||
? `Leader Ports (USB/Serial) ${availableUsbPorts.length > 0 ? `(${availableUsbPorts.length} detected)` : ''}`
|
||||
: 'Leader Interfaces (CAN)'}
|
||||
</h3>
|
||||
{config.leader_type === 'openarms_mini' && (
|
||||
<button
|
||||
onClick={discoverUsbPorts}
|
||||
className="btn-refresh"
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
🔄 Refresh
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="config-grid">
|
||||
<label>
|
||||
Leader Left
|
||||
<select
|
||||
value={config.leader_left}
|
||||
onChange={(e) => updateConfig('leader_left', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{config.leader_type === 'openarms_mini' ? (
|
||||
availableUsbPorts.length > 0 ? (
|
||||
availableUsbPorts.map((port) => (
|
||||
<option key={port} value={port}>{port}</option>
|
||||
))
|
||||
) : (
|
||||
<option value="">No USB ports detected</option>
|
||||
)
|
||||
) : (
|
||||
canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))
|
||||
)}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Leader Right
|
||||
<select
|
||||
value={config.leader_right}
|
||||
onChange={(e) => updateConfig('leader_right', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{config.leader_type === 'openarms_mini' ? (
|
||||
availableUsbPorts.length > 0 ? (
|
||||
availableUsbPorts.map((port) => (
|
||||
<option key={port} value={port}>{port}</option>
|
||||
))
|
||||
) : (
|
||||
<option value="">No USB ports detected</option>
|
||||
)
|
||||
) : (
|
||||
canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))
|
||||
)}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Follower CAN Interfaces */}
|
||||
<div className="config-section">
|
||||
<h3>Follower Interfaces (CAN)</h3>
|
||||
|
||||
<div className="config-grid">
|
||||
<label>
|
||||
Follower Left
|
||||
<select
|
||||
value={config.follower_left}
|
||||
onChange={(e) => updateConfig('follower_left', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Follower Right
|
||||
<select
|
||||
value={config.follower_right}
|
||||
onChange={(e) => updateConfig('follower_right', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Camera Configuration */}
|
||||
<div className="config-section">
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
|
||||
<h3>Cameras {availableCameras.length > 0 && `(${availableCameras.length} detected)`}</h3>
|
||||
<button
|
||||
onClick={discoverCameras}
|
||||
className="btn-refresh"
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
🔄 Refresh
|
||||
</button>
|
||||
</div>
|
||||
<div className="config-grid">
|
||||
<label>
|
||||
Left Wrist
|
||||
<select
|
||||
value={config.left_wrist}
|
||||
onChange={(e) => updateConfig('left_wrist', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{availableCameras.map((cam) => (
|
||||
<option key={cam.id} value={String(cam.id)}>
|
||||
{cam.name || `Camera @ ${cam.id}`}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Right Wrist
|
||||
<select
|
||||
value={config.right_wrist}
|
||||
onChange={(e) => updateConfig('right_wrist', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{availableCameras.map((cam) => (
|
||||
<option key={cam.id} value={String(cam.id)}>
|
||||
{cam.name || `Camera @ ${cam.id}`}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Base Camera
|
||||
<select
|
||||
value={config.base}
|
||||
onChange={(e) => updateConfig('base', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{availableCameras.map((cam) => (
|
||||
<option key={cam.id} value={String(cam.id)}>
|
||||
{cam.name || `Camera @ ${cam.id}`}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
|
||||
{/* Control Panel */}
|
||||
<section className="panel control-panel">
|
||||
<h2>🎬 Recording Control</h2>
|
||||
|
||||
{/* Status Banner - Always show important statuses */}
|
||||
{isInitializing && (
|
||||
<div className="status-banner initializing">
|
||||
<div className="spinner"></div>
|
||||
<span>{statusMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isEncoding && (
|
||||
<div className="status-banner encoding">
|
||||
<div className="spinner"></div>
|
||||
<span>📹 {statusMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isUploading && (
|
||||
<div className="status-banner uploading">
|
||||
<div className="spinner"></div>
|
||||
<span>☁️ {statusMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{uploadStatus && !isRecording && !isEncoding && !isUploading && (
|
||||
<div className={`status-banner ${uploadStatus.startsWith('✓') ? 'success' : 'warning'}`}>
|
||||
<span>{uploadStatus}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="control-horizontal">
|
||||
{/* Task Input and Status */}
|
||||
<div className="control-left">
|
||||
<div className="input-group">
|
||||
<input
|
||||
type="text"
|
||||
value={task}
|
||||
onChange={(e) => setTask(e.target.value)}
|
||||
placeholder="Task description (e.g., 'pick and place')"
|
||||
disabled={isRecording || isInitializing || isEncoding || isUploading}
|
||||
onKeyPress={(e) => {
|
||||
if (e.key === 'Enter' && robotsReady) {
|
||||
setTaskOnly();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<button
|
||||
onClick={setTaskOnly}
|
||||
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
|
||||
className="btn-set-task"
|
||||
title={!robotsReady ? 'Please setup robots first' : 'Store task for pedal use (Enter key)'}
|
||||
>
|
||||
💾 Set Task
|
||||
</button>
|
||||
<button
|
||||
onClick={startRecording}
|
||||
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
|
||||
className="btn-start"
|
||||
title={!robotsReady ? 'Please setup robots first' : ''}
|
||||
>
|
||||
{isInitializing
|
||||
? '⏳ Initializing...'
|
||||
: isRecording
|
||||
? '⏺ Recording...'
|
||||
: robotsReady
|
||||
? '⏺ Start Recording'
|
||||
: '⏺ Setup Robots First'}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Ramp-up Countdown */}
|
||||
{isRecording && rampUpRemaining > 0 && (
|
||||
<div className="ramp-up-countdown">
|
||||
<div className="countdown-box">
|
||||
<div className="countdown-label">⚡ WARMING UP - PID RAMP-UP</div>
|
||||
<div className="countdown-value">{rampUpRemaining.toFixed(1)}s</div>
|
||||
<div className="countdown-subtitle">Recording will start automatically...</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Recording Status - Only show after ramp-up */}
|
||||
{isRecording && rampUpRemaining <= 0 && (
|
||||
<div className="status recording recording-active">
|
||||
<div className="indicator"></div>
|
||||
<div className="time-display">
|
||||
<span>{formatTime(elapsedTime)}</span>
|
||||
<span className="fps-display">
|
||||
Loop: {loopFps.toFixed(1)} Hz
|
||||
{loopFps > 0 && loopFps < 29 && <span className="fps-warning"> ⚠️</span>}
|
||||
</span>
|
||||
<span className="fps-display">Recording: {currentFps.toFixed(1)} FPS</span>
|
||||
</div>
|
||||
<button onClick={stopRecording} className="btn-stop">
|
||||
⏹ Stop
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Episode Counter */}
|
||||
<div className="control-right">
|
||||
<div className="counter">
|
||||
<div className="counter-label">Episodes Recorded</div>
|
||||
<div className="counter-value">{episodeCount}</div>
|
||||
<button onClick={resetCounter} className="btn-reset">
|
||||
Reset
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Delete Latest Episode Button */}
|
||||
{!isRecording && !isInitializing && latestRepoId && (
|
||||
<div className="delete-episode-section">
|
||||
<button
|
||||
onClick={deleteLatestEpisode}
|
||||
className="btn-delete"
|
||||
title="Delete the latest recorded episode from HuggingFace Hub"
|
||||
>
|
||||
Delete Latest Episode
|
||||
</button>
|
||||
<div className="delete-info">Will delete: {latestRepoId}</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Move to Zero Button */}
|
||||
{robotsReady && !isRecording && !isInitializing && (
|
||||
<div className="zero-position-section">
|
||||
<button
|
||||
onClick={moveToZero}
|
||||
disabled={movingToZero}
|
||||
className="btn-zero-large"
|
||||
title="Move both leader and follower robots to zero position (2s)"
|
||||
>
|
||||
{movingToZero ? '⏳ Moving to Zero Position...' : '🎯 Move to Zero Position (Leader + Follower)'}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Error Display */}
|
||||
{error && (
|
||||
<div className="error-box">
|
||||
⚠️ {error}
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
</div>
|
||||
|
||||
{/* Right Column: Camera Feeds */}
|
||||
<div className="right-column">
|
||||
<section className="panel cameras">
|
||||
<h2>📹 Camera Views</h2>
|
||||
{robotsReady || isRecording || isInitializing ? (
|
||||
<div className="camera-layout">
|
||||
{/* Base camera - full width */}
|
||||
<div className="camera camera-base">
|
||||
<h3>Base Camera</h3>
|
||||
<img src={`${API_BASE}/camera/stream/base`} alt="Base Camera" />
|
||||
</div>
|
||||
|
||||
{/* Wrist cameras - side by side */}
|
||||
<div className="camera-wrist-container">
|
||||
<div className="camera camera-wrist">
|
||||
<h3>Left Wrist</h3>
|
||||
<img src={`${API_BASE}/camera/stream/left_wrist`} alt="Left Wrist Camera" />
|
||||
</div>
|
||||
|
||||
<div className="camera camera-wrist">
|
||||
<h3>Right Wrist</h3>
|
||||
<img src={`${API_BASE}/camera/stream/right_wrist`} alt="Right Wrist Camera" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="camera-placeholder">
|
||||
<p>📷 Camera feeds will appear when robots are set up</p>
|
||||
<p className="hint">Click "Setup Robots" above to preview camera feeds</p>
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</main>
|
||||
);
|
||||
}
|
||||
|
||||
export default App;
|
||||
|
||||
41
examples/openarms_web_interface/README.md
Normal file
41
examples/openarms_web_interface/README.md
Normal file
@@ -0,0 +1,41 @@
|
||||
# OpenArms Web Recording Interface
|
||||
|
||||
A web interface for recording OpenArms datasets.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
cd examples/openarms_web_interface
|
||||
npm install
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
**Start everything with one command:**
|
||||
|
||||
```bash
|
||||
./launch.sh
|
||||
```
|
||||
|
||||
This will:
|
||||
- Start the FastAPI backend on port 8000
|
||||
- Start the React frontend on port 5173
|
||||
- Show live logs from both services
|
||||
|
||||
Then open your browser to: **http://localhost:5173**
|
||||
|
||||
**Stop with:** `Ctrl+C`
|
||||
|
||||
---
|
||||
|
||||
## Workflow
|
||||
|
||||
1. **Configure CAN interfaces** and **camera paths** in the dropdowns
|
||||
2. Click **"Setup Robots"** to initialize (once at start)
|
||||
3. Enter a **task description**
|
||||
4. Click **"Start Recording"** to begin an episode
|
||||
5. Click **"Stop Recording"** when done
|
||||
6. Dataset is automatically encoded and uploaded to HuggingFace Hub as **private**
|
||||
7. Repeat steps 3-6 for more episodes (no need to re-setup robots!)
|
||||
|
||||
---
|
||||
12
examples/openarms_web_interface/index.html
Normal file
12
examples/openarms_web_interface/index.html
Normal file
@@ -0,0 +1,12 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>OpenArms Recording Interface</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/main.jsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
142
examples/openarms_web_interface/launch.sh
Executable file
142
examples/openarms_web_interface/launch.sh
Executable file
@@ -0,0 +1,142 @@
|
||||
#!/bin/bash
|
||||
|
||||
# OpenArms Web Interface Launcher
|
||||
# Starts Rerun viewer, FastAPI backend, and React frontend
|
||||
|
||||
set -e
|
||||
|
||||
# Colors for output
|
||||
GREEN='\033[0;32m'
|
||||
BLUE='\033[0;34m'
|
||||
YELLOW='\033[1;33m'
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Get script directory
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
echo -e "${BLUE}╔════════════════════════════════════════╗${NC}"
|
||||
echo -e "${BLUE}║ OpenArms Web Recording Interface ║${NC}"
|
||||
echo -e "${BLUE}╚════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
|
||||
# Function to cleanup on exit
|
||||
cleanup() {
|
||||
echo ""
|
||||
echo -e "${YELLOW}Shutting down services...${NC}"
|
||||
|
||||
# Kill all child processes
|
||||
pkill -P $$ 2>/dev/null || true
|
||||
|
||||
# Kill specific services by port
|
||||
lsof -ti:8000 | xargs kill -9 2>/dev/null || true # Backend
|
||||
lsof -ti:5173 | xargs kill -9 2>/dev/null || true # Frontend
|
||||
lsof -ti:9876 | xargs kill -9 2>/dev/null || true # Rerun (if spawned)
|
||||
|
||||
echo -e "${GREEN}✓ Services stopped${NC}"
|
||||
exit 0
|
||||
}
|
||||
|
||||
# Register cleanup on script exit
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
# Check if required commands exist
|
||||
command -v rerun >/dev/null 2>&1 || {
|
||||
echo -e "${RED}✗ Error: 'rerun' not found. Please install: pip install rerun-sdk${NC}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
command -v python >/dev/null 2>&1 || {
|
||||
echo -e "${RED}✗ Error: 'python' not found${NC}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
command -v npm >/dev/null 2>&1 || {
|
||||
echo -e "${RED}✗ Error: 'npm' not found${NC}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Check if node_modules exists
|
||||
if [ ! -d "node_modules" ]; then
|
||||
echo -e "${YELLOW}⚠ node_modules not found. Running npm install...${NC}"
|
||||
npm install
|
||||
echo -e "${GREEN}✓ Dependencies installed${NC}"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}Starting services...${NC}"
|
||||
echo ""
|
||||
|
||||
# 1. Start FastAPI backend (Rerun will start when recording begins)
|
||||
echo -e "${BLUE}[1/2]${NC} Starting FastAPI backend on port 8000..."
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# Use Python from current environment (if lerobot env is active, it will use that)
|
||||
# Otherwise, check if we need to use conda run
|
||||
if [[ "$CONDA_DEFAULT_ENV" == "lerobot" ]]; then
|
||||
# Already in lerobot environment
|
||||
echo -e "${GREEN}✓ Using active lerobot environment${NC}"
|
||||
PYTHON_CMD="python"
|
||||
elif command -v conda >/dev/null 2>&1 && conda env list | grep -q "^lerobot "; then
|
||||
# lerobot env exists but not active - use conda run
|
||||
echo -e "${YELLOW}Using conda run with lerobot environment...${NC}"
|
||||
PYTHON_CMD="conda run -n lerobot --no-capture-output python"
|
||||
else
|
||||
# Fall back to system python
|
||||
echo -e "${YELLOW}⚠ Warning: lerobot environment not found, using system python${NC}"
|
||||
PYTHON_CMD="python"
|
||||
fi
|
||||
|
||||
$PYTHON_CMD web_record_server.py > /tmp/openarms_backend.log 2>&1 &
|
||||
BACKEND_PID=$!
|
||||
sleep 3
|
||||
|
||||
if ps -p $BACKEND_PID > /dev/null; then
|
||||
echo -e "${GREEN}✓ Backend started${NC} (PID: $BACKEND_PID)"
|
||||
echo -e " URL: ${BLUE}http://localhost:8000${NC}"
|
||||
else
|
||||
echo -e "${RED}✗ Failed to start backend${NC}"
|
||||
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_backend.log${NC}"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# 2. Start React frontend
|
||||
echo -e "${BLUE}[2/2]${NC} Starting React frontend on port 5173..."
|
||||
cd "$SCRIPT_DIR"
|
||||
npm run dev > /tmp/openarms_frontend.log 2>&1 &
|
||||
FRONTEND_PID=$!
|
||||
sleep 3
|
||||
|
||||
if ps -p $FRONTEND_PID > /dev/null; then
|
||||
echo -e "${GREEN}✓ Frontend started${NC} (PID: $FRONTEND_PID)"
|
||||
echo -e " URL: ${BLUE}http://localhost:5173${NC}"
|
||||
else
|
||||
echo -e "${RED}✗ Failed to start frontend${NC}"
|
||||
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_frontend.log${NC}"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Display status
|
||||
echo -e "${GREEN}╔════════════════════════════════════════╗${NC}"
|
||||
echo -e "${GREEN}║ All services running! 🚀 ║${NC}"
|
||||
echo -e "${GREEN}╚════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
echo -e "🔧 ${BLUE}Backend:${NC} http://localhost:8000"
|
||||
echo -e "🌐 ${BLUE}Frontend:${NC} http://localhost:5173"
|
||||
echo -e "📊 ${BLUE}Rerun:${NC} Will spawn automatically when recording starts"
|
||||
echo ""
|
||||
echo -e "${YELLOW}Open your browser to:${NC} ${BLUE}http://localhost:5173${NC}"
|
||||
echo ""
|
||||
echo -e "${YELLOW}Logs:${NC}"
|
||||
echo -e " • Backend: tail -f /tmp/openarms_backend.log"
|
||||
echo -e " • Frontend: tail -f /tmp/openarms_frontend.log"
|
||||
echo ""
|
||||
echo -e "${RED}Press Ctrl+C to stop all services${NC}"
|
||||
echo ""
|
||||
|
||||
# Keep script running and wait for any service to exit
|
||||
wait
|
||||
|
||||
7
examples/openarms_web_interface/main.jsx
Normal file
7
examples/openarms_web_interface/main.jsx
Normal file
@@ -0,0 +1,7 @@
|
||||
import { createRoot } from 'react-dom/client'
|
||||
import App from './App.jsx'
|
||||
|
||||
createRoot(document.getElementById('root')).render(
|
||||
<App />
|
||||
)
|
||||
|
||||
1955
examples/openarms_web_interface/package-lock.json
generated
Normal file
1955
examples/openarms_web_interface/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
21
examples/openarms_web_interface/package.json
Normal file
21
examples/openarms_web_interface/package.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"name": "openarms-web-interface",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react": "^18.3.12",
|
||||
"@types/react-dom": "^18.3.1",
|
||||
"@vitejs/plugin-react": "^4.3.4",
|
||||
"vite": "^6.0.1"
|
||||
}
|
||||
}
|
||||
17
examples/openarms_web_interface/vite.config.js
Normal file
17
examples/openarms_web_interface/vite.config.js
Normal file
@@ -0,0 +1,17 @@
|
||||
import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react'
|
||||
|
||||
// https://vite.dev/config/
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
server: {
|
||||
port: 5173,
|
||||
strictPort: false,
|
||||
host: true,
|
||||
open: false
|
||||
},
|
||||
build: {
|
||||
outDir: 'dist',
|
||||
sourcemap: true
|
||||
}
|
||||
})
|
||||
1533
examples/openarms_web_interface/web_record_server.py
Normal file
1533
examples/openarms_web_interface/web_record_server.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -194,4 +194,6 @@ for episode_idx in range(NUM_EPISODES):
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -200,4 +200,6 @@ log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -362,6 +362,8 @@ def port_droid(
|
||||
lerobot_dataset.save_episode()
|
||||
logging.info("Save_episode")
|
||||
|
||||
lerobot_dataset.finalize()
|
||||
|
||||
if push_to_hub:
|
||||
lerobot_dataset.push_to_hub(
|
||||
# Add openx tag, since it belongs to the openx collection of datasets
|
||||
|
||||
@@ -195,4 +195,6 @@ for episode_idx in range(NUM_EPISODES):
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -199,4 +199,6 @@ log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
98
examples/tutorial/act/act_training_example.py
Normal file
98
examples/tutorial/act/act_training_example.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""This script demonstrates how to train ACT Policy on a real-world dataset."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
|
||||
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
|
||||
if delta_indices is None:
|
||||
return [0]
|
||||
|
||||
return [i / fps for i in delta_indices]
|
||||
|
||||
|
||||
output_directory = Path("outputs/robot_learning_tutorial/act")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
policy = ACTPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
57
examples/tutorial/act/act_using_example.py
Normal file
57
examples/tutorial/act/act_using_example.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_act"
|
||||
model = ACTPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
11
examples/tutorial/async-inf/policy_server.py
Normal file
11
examples/tutorial/async-inf/policy_server.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
host = ... # something like "127.0.0.1" if you're exposing to localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
55
examples/tutorial/async-inf/robot_client.py
Normal file
55
examples/tutorial/async-inf/robot_client.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import threading
|
||||
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
99
examples/tutorial/diffusion/diffusion_training_example.py
Normal file
99
examples/tutorial/diffusion/diffusion_training_example.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""This script demonstrates how to train Diffusion Policy on a real-world dataset."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
|
||||
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
|
||||
if delta_indices is None:
|
||||
return [0]
|
||||
|
||||
return [i / fps for i in delta_indices]
|
||||
|
||||
|
||||
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
policy = DiffusionPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
60
examples/tutorial/diffusion/diffusion_using_example.py
Normal file
60
examples/tutorial/diffusion/diffusion_using_example.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_diffusion"
|
||||
|
||||
model = DiffusionPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config, model_id, dataset_stats=dataset_metadata.stats
|
||||
)
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
67
examples/tutorial/pi0/using_pi0_example.py
Normal file
67
examples/tutorial/pi0/using_pi0_example.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/pi0_base"
|
||||
|
||||
model = PI0Policy.from_pretrained(model_id)
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
345
examples/tutorial/rl/hilserl_example.py
Normal file
345
examples/tutorial/rl/hilserl_example.py
Normal file
@@ -0,0 +1,345 @@
|
||||
import multiprocessing as mp
|
||||
import signal
|
||||
from pathlib import Path
|
||||
from queue import Empty, Full
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
LOG_EVERY = 10
|
||||
SEND_EVERY = 10
|
||||
|
||||
|
||||
def run_learner(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_learner: SACPolicy,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer,
|
||||
lr: float = 3e-4,
|
||||
batch_size: int = 32,
|
||||
device: torch.device = "mps",
|
||||
):
|
||||
"""The learner process - trains SAC policy on transitions streamed from the actor, updating parameters
|
||||
for the actor to adopt."""
|
||||
policy_learner.train()
|
||||
policy_learner.to(device)
|
||||
|
||||
# Create Adam optimizer from scratch - simple and clean
|
||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
||||
|
||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||
|
||||
training_step = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
# retrieve incoming transitions from the actor process
|
||||
try:
|
||||
transitions = transitions_queue.get(timeout=0.1)
|
||||
for transition in transitions:
|
||||
# HIL-SERL: Add ALL transitions to online buffer
|
||||
online_buffer.add(**transition)
|
||||
|
||||
# HIL-SERL: Add ONLY human intervention transitions to offline buffer
|
||||
is_intervention = transition.get("complementary_info", {}).get("is_intervention", False)
|
||||
if is_intervention:
|
||||
offline_buffer.add(**transition)
|
||||
print(
|
||||
f"[LEARNER] Human intervention detected! Added to offline buffer (now {len(offline_buffer)} transitions)"
|
||||
)
|
||||
|
||||
except Empty:
|
||||
pass # No transitions available, continue
|
||||
|
||||
# Train if we have enough data
|
||||
if len(online_buffer) >= policy_learner.config.online_step_before_learning:
|
||||
# Sample from online buffer (autonomous + human data)
|
||||
online_batch = online_buffer.sample(batch_size // 2)
|
||||
|
||||
# Sample from offline buffer (human demonstrations only, either precollected or at runtime)
|
||||
offline_batch = offline_buffer.sample(batch_size // 2)
|
||||
|
||||
# Combine batches - this is the key HIL-SERL mechanism!
|
||||
batch = {}
|
||||
for key in online_batch:
|
||||
if key in offline_batch:
|
||||
batch[key] = torch.cat([online_batch[key], offline_batch[key]], dim=0)
|
||||
else:
|
||||
batch[key] = online_batch[key]
|
||||
|
||||
loss, _ = policy_learner.forward(batch)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
training_step += 1
|
||||
|
||||
if training_step % LOG_EVERY == 0:
|
||||
print(
|
||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||
)
|
||||
|
||||
# Send updated parameters to actor every 10 training steps
|
||||
if training_step % SEND_EVERY == 0:
|
||||
try:
|
||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
||||
parameters_queue.put_nowait(state_dict)
|
||||
print("[LEARNER] Sent updated parameters to actor")
|
||||
except Full:
|
||||
# Missing write due to queue not being consumed (should happen rarely)
|
||||
pass
|
||||
|
||||
print("[LEARNER] Learner process finished")
|
||||
|
||||
|
||||
def run_actor(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_actor: SACPolicy,
|
||||
reward_classifier: Classifier,
|
||||
env_cfg: HILSerlRobotEnvConfig,
|
||||
device: torch.device = "mps",
|
||||
output_directory: Path | None = None,
|
||||
):
|
||||
"""The actor process - interacts with environment and collects data.
|
||||
The policy is frozen and only the parameters are updated, popping the most recent ones from a queue."""
|
||||
policy_actor.eval()
|
||||
policy_actor.to(device)
|
||||
|
||||
reward_classifier.eval()
|
||||
reward_classifier.to(device)
|
||||
|
||||
# Create robot environment inside the actor process
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
|
||||
try:
|
||||
for episode in range(MAX_EPISODES):
|
||||
if shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs, _info = env.reset()
|
||||
episode_reward = 0.0
|
||||
step = 0
|
||||
episode_transitions = []
|
||||
|
||||
print(f"[ACTOR] Starting episode {episode + 1}")
|
||||
|
||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||
try:
|
||||
new_params = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_params)
|
||||
print("[ACTOR] Updated policy parameters from learner")
|
||||
except Empty: # No new updated parameters available from learner, waiting
|
||||
pass
|
||||
|
||||
# Get action from policy
|
||||
policy_obs = make_policy_obs(obs, device=device)
|
||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
||||
action = action_tensor.squeeze(0).cpu().numpy()
|
||||
|
||||
# Step environment
|
||||
next_obs, _env_reward, terminated, truncated, _info = env.step(action)
|
||||
done = terminated or truncated
|
||||
|
||||
# Predict reward
|
||||
policy_next_obs = make_policy_obs(next_obs, device=device)
|
||||
reward = reward_classifier.predict_reward(policy_next_obs)
|
||||
|
||||
if reward >= 1.0 and not done: # success detected! halt episode
|
||||
terminated = True
|
||||
done = True
|
||||
|
||||
# In HIL-SERL, human interventions come from the teleop device
|
||||
is_intervention = False
|
||||
if hasattr(teleop_device, "get_teleop_events"):
|
||||
# Real intervention detection from teleop device
|
||||
teleop_events = teleop_device.get_teleop_events()
|
||||
is_intervention = teleop_events.get(TeleopEvents.IS_INTERVENTION, False)
|
||||
|
||||
# Store transition with intervention metadata
|
||||
transition = {
|
||||
"state": policy_obs,
|
||||
"action": action,
|
||||
"reward": float(reward) if hasattr(reward, "item") else reward,
|
||||
"next_state": policy_next_obs,
|
||||
"done": done,
|
||||
"truncated": truncated,
|
||||
"complementary_info": {
|
||||
"is_intervention": is_intervention,
|
||||
},
|
||||
}
|
||||
|
||||
episode_transitions.append(transition)
|
||||
|
||||
episode_reward += reward
|
||||
step += 1
|
||||
|
||||
obs = next_obs
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Send episode transitions to learner
|
||||
transitions_queue.put_nowait(episode_transitions)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("[ACTOR] Interrupted by user")
|
||||
finally:
|
||||
# Clean up
|
||||
if hasattr(env, "robot") and env.robot.is_connected:
|
||||
env.robot.disconnect()
|
||||
if teleop_device and hasattr(teleop_device, "disconnect"):
|
||||
teleop_device.disconnect()
|
||||
if output_directory is not None:
|
||||
policy_actor.save_pretrained(output_directory)
|
||||
print(f"[ACTOR] Latest actor policy saved at: {output_directory}")
|
||||
|
||||
print("[ACTOR] Actor process finished")
|
||||
|
||||
|
||||
def make_policy_obs(obs, device: torch.device = "cpu"):
|
||||
return {
|
||||
"observation.state": torch.from_numpy(obs["agent_pos"]).float().unsqueeze(0).to(device),
|
||||
**{
|
||||
f"observation.image.{k}": torch.from_numpy(obs["pixels"][k]).float().unsqueeze(0).to(device)
|
||||
for k in obs["pixels"]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
"""Main function - coordinates actor and learner processes."""
|
||||
|
||||
device = "mps" # or "cuda" or "cpu"
|
||||
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ...
|
||||
leader_port = ...
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ...
|
||||
leader_id = ...
|
||||
|
||||
# A pretrained model (to be used in-distribution!)
|
||||
reward_classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
|
||||
|
||||
reward_classifier.to(device)
|
||||
reward_classifier.eval()
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
# Robot and environment configuration
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
|
||||
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
|
||||
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
|
||||
|
||||
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
|
||||
|
||||
# Create robot environment
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
|
||||
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
# Online buffer: initialized from scratch
|
||||
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
|
||||
# Offline buffer: Created from dataset (pre-populated it with demonstrations)
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
|
||||
)
|
||||
|
||||
# Create communication channels between learner and actor processes
|
||||
transitions_queue = mp.Queue(maxsize=10)
|
||||
parameters_queue = mp.Queue(maxsize=2)
|
||||
shutdown_event = mp.Event()
|
||||
|
||||
|
||||
# Signal handler for graceful shutdown
|
||||
def signal_handler(sig):
|
||||
print(f"\nSignal {sig} received, shutting down...")
|
||||
shutdown_event.set()
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Create processes
|
||||
learner_process = mp.Process(
|
||||
target=run_learner,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_learner,
|
||||
online_replay_buffer,
|
||||
offline_replay_buffer,
|
||||
),
|
||||
kwargs={"device": device}, # can run on accelerated hardware for training
|
||||
)
|
||||
|
||||
actor_process = mp.Process(
|
||||
target=run_actor,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_actor,
|
||||
reward_classifier,
|
||||
env_cfg,
|
||||
output_directory,
|
||||
),
|
||||
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
|
||||
)
|
||||
|
||||
learner_process.start()
|
||||
actor_process.start()
|
||||
|
||||
try:
|
||||
# Wait for actor to finish (it controls the episode loop)
|
||||
actor_process.join()
|
||||
shutdown_event.set()
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Main process interrupted")
|
||||
shutdown_event.set()
|
||||
actor_process.join(timeout=5)
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
finally:
|
||||
if learner_process.is_alive():
|
||||
learner_process.terminate()
|
||||
if actor_process.is_alive():
|
||||
actor_process.terminate()
|
||||
62
examples/tutorial/rl/reward_classifier_example.py
Normal file
62
examples/tutorial/rl/reward_classifier_example.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
# Device to use for training
|
||||
device = "mps" # or "cuda", or "cpu"
|
||||
|
||||
# Load the dataset used for training
|
||||
repo_id = "lerobot/example_hil_serl_dataset"
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
# Configure the policy to extract features from the image frames
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
|
||||
config = RewardClassifierConfig(
|
||||
num_cameras=len(camera_keys),
|
||||
device=device,
|
||||
# backbone model to extract features from the image frames
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
|
||||
|
||||
classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
|
||||
# Instantiate a dataloader
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
|
||||
|
||||
# Training loop
|
||||
num_epochs = 5
|
||||
for epoch in range(num_epochs):
|
||||
total_loss = 0
|
||||
total_accuracy = 0
|
||||
for batch in dataloader:
|
||||
# Preprocess the batch and move it to the correct device.
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_accuracy += output_dict["accuracy"]
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
avg_accuracy = total_accuracy / len(dataloader)
|
||||
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
66
examples/tutorial/smolvla/using_smolvla_example.py
Normal file
66
examples/tutorial/smolvla/using_smolvla_example.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/smolvla_base"
|
||||
|
||||
model = SmolVLAPolicy.from_pretrained(model_id)
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
10
loop_datasets.py
Normal file
10
loop_datasets.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from huggingface_hub import HfApi, list_datasets
|
||||
|
||||
api = HfApi()
|
||||
datasets = list_datasets(author="lerobot-data-collection")
|
||||
print('"[', end="")
|
||||
i=0
|
||||
for dataset in datasets:
|
||||
if "three-folds-dataset" in dataset.id:
|
||||
print("'" + dataset.id + "',", end="")
|
||||
print(']"',)
|
||||
183
pyproject.toml
183
pyproject.toml
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.3.4"
|
||||
version = "0.4.1"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
readme = "README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -59,28 +59,30 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
||||
dependencies = [
|
||||
|
||||
# Hugging Face dependencies
|
||||
"datasets>=4.0.0",
|
||||
"diffusers>=0.27.2",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2",
|
||||
"datasets>=4.0.0,<4.2.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
|
||||
# Core dependencies
|
||||
"cmake>=3.29.0.1",
|
||||
"einops>=0.8.0",
|
||||
"opencv-python-headless>=4.9.0",
|
||||
"av>=14.2.0",
|
||||
"jsonlines>=4.0.0",
|
||||
"packaging>=24.2",
|
||||
"pynput>=1.7.7",
|
||||
"pyserial>=3.5",
|
||||
"wandb>=0.20.0",
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"einops>=0.8.0,<0.9.0",
|
||||
"opencv-python-headless>=4.9.0,<4.13.0",
|
||||
"av>=15.0.0,<16.0.0",
|
||||
"jsonlines>=4.0.0,<5.0.0",
|
||||
"packaging>=24.2,<26.0",
|
||||
"pynput>=1.7.7,<1.9.0",
|
||||
"pyserial>=3.5,<4.0",
|
||||
"wandb>=0.20.0,<0.22.0", # TODO: Bumb dependency (compatible with protobuf)
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency
|
||||
"rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
"gymnasium>=1.1.1,<2.0.0",
|
||||
"rerun-sdk>=0.24.0,<0.27.0",
|
||||
|
||||
# Support dependencies
|
||||
"deepdiff>=7.0.1,<9.0.0",
|
||||
@@ -92,63 +94,72 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1"]
|
||||
placo-dep = ["placo>=0.9.6"]
|
||||
transformers-dep = ["transformers>=4.52.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31"]
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
damiao = ["python-can>=4.2.0,<5.0.0"]
|
||||
|
||||
# Robots
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"]
|
||||
openarms = ["lerobot[damiao]"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
|
||||
]
|
||||
phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
|
||||
# stretch = [
|
||||
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
|
||||
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
|
||||
# "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"
|
||||
# ] # TODO: Currently not supported
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
||||
|
||||
# Policies
|
||||
pi0 = ["lerobot[transformers-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
"peft>=0.13.0,<1.0.0",
|
||||
"dm-tree>=0.1.8,<1.0.0",
|
||||
"timm>=1.0.0,<1.1.0",
|
||||
"safetensors>=0.4.3,<1.0.0",
|
||||
"Pillow>=10.0.0,<13.0.0",
|
||||
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||
test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
aloha = ["gym-aloha>=0.1.1"]
|
||||
pusht = ["gym-pusht>=0.1.5", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
xarm = ["gym-xarm>=0.1.1"]
|
||||
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
|
||||
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
|
||||
|
||||
metaworld = ["metaworld==3.0.0"]
|
||||
|
||||
# All
|
||||
all = [
|
||||
"lerobot[dynamixel]",
|
||||
"lerobot[openarms]",
|
||||
"lerobot[gamepad]",
|
||||
"lerobot[hopejr]",
|
||||
"lerobot[lekiwi]",
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[pi0]",
|
||||
"lerobot[pi]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
@@ -156,9 +167,9 @@ all = [
|
||||
"lerobot[video_benchmark]",
|
||||
"lerobot[aloha]",
|
||||
"lerobot[pusht]",
|
||||
"lerobot[xarm]",
|
||||
"lerobot[phone]",
|
||||
"lerobot[libero]",
|
||||
"lerobot[metaworld]",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -175,6 +186,7 @@ lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
|
||||
lerobot-info="lerobot.scripts.lerobot_info:main"
|
||||
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.packages.find]
|
||||
@@ -232,9 +244,6 @@ exclude_dirs = [
|
||||
"tests",
|
||||
"benchmarks",
|
||||
"src/lerobot/datasets/push_dataset_to_hub",
|
||||
"src/lerobot/datasets/v2/convert_dataset_v1_to_v2",
|
||||
"src/lerobot/policies/pi0/conversion_scripts",
|
||||
"src/lerobot/scripts/push_dataset_to_hub.py",
|
||||
]
|
||||
skips = ["B101", "B311", "B404", "B603", "B615"]
|
||||
|
||||
@@ -249,6 +258,8 @@ default.extend-ignore-identifiers-re = [
|
||||
"pn",
|
||||
"ser",
|
||||
"ein",
|
||||
"thw",
|
||||
"inpt",
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
@@ -270,80 +281,88 @@ default.extend-ignore-identifiers-re = [
|
||||
# TODO: Enable mypy gradually module by module across multiple PRs
|
||||
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
|
||||
|
||||
# [tool.mypy]
|
||||
# python_version = "3.10"
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
follow_imports = "skip"
|
||||
# warn_return_any = true
|
||||
# warn_unused_configs = true
|
||||
# ignore_missing_imports = false
|
||||
# strict = true
|
||||
# disallow_untyped_defs = true
|
||||
# disallow_incomplete_defs = true
|
||||
# check_untyped_defs = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.*"
|
||||
ignore_errors = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.envs.*"
|
||||
ignore_errors = false
|
||||
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.utils.*"
|
||||
# # include = "src/lerobot/utils/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.configs.*"
|
||||
ignore_errors = false
|
||||
|
||||
# extra strictness for configs
|
||||
disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.configs.*"
|
||||
# # include = "src/lerobot/configs/**/*.py"
|
||||
# module = "lerobot.optim.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.model.*"
|
||||
ignore_errors = false
|
||||
|
||||
# # Data processing modules
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.processor.*"
|
||||
# # include = "src/lerobot/processor/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.datasets.*"
|
||||
# # include = "src/lerobot/datasets/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# # Core machine learning modules
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.optim.*"
|
||||
# # include = "src/lerobot/optim/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.model.*"
|
||||
# # include = "src/lerobot/model/**/*.py"
|
||||
|
||||
# # Hardware interfaces
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.cameras.*"
|
||||
# # include = "src/lerobot/cameras/**/*.py"
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.cameras.*"
|
||||
ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.motors.*"
|
||||
# # include = "src/lerobot/motors/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.robots.*"
|
||||
# # include = "src/lerobot/robots/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.teleoperators.*"
|
||||
# # include = "src/lerobot/teleoperators/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# # Complex modules (enable these last)
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.policies.*"
|
||||
# # include = "src/lerobot/policies/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.rl.*"
|
||||
# # include = "src/lerobot/rl/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.envs.*"
|
||||
# # include = "src/lerobot/envs/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.async_inference.*"
|
||||
# # include = "src/lerobot/async_inference/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.transport.*"
|
||||
# # include = "src/lerobot/transport/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.scripts.*"
|
||||
# # include = "src/lerobot/scripts/**/*.py"
|
||||
# ignore_errors = false
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# by the following command:
|
||||
#
|
||||
@@ -12,47 +13,62 @@ absl-py==2.3.1
|
||||
# dm-tree
|
||||
# labmaze
|
||||
# mujoco
|
||||
accelerate==1.9.0
|
||||
# via lerobot
|
||||
# tensorboard
|
||||
accelerate==1.11.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.12.15
|
||||
aiohttp==3.13.1
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.11.0
|
||||
# via
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.0
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.3.0
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
# dm-tree
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.0.0
|
||||
av==15.1.0
|
||||
# via lerobot
|
||||
blinker==1.9.0
|
||||
# via flask
|
||||
certifi==2025.7.14
|
||||
bddl==1.0.1
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==1.17.1
|
||||
cffi==2.0.0
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.2
|
||||
charset-normalizer==3.4.4
|
||||
# via requests
|
||||
click==8.2.1
|
||||
click==8.3.0
|
||||
# via
|
||||
# flask
|
||||
# uvicorn
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
# via gymnasium
|
||||
cmake==4.0.3
|
||||
# via
|
||||
# gymnasium
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
# via
|
||||
@@ -94,27 +110,27 @@ coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.10.1
|
||||
coverage[toml]==7.11.0
|
||||
# via pytest-cov
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==3.6.0
|
||||
datasets==4.1.1
|
||||
# via lerobot
|
||||
debugpy==1.8.15
|
||||
debugpy==1.8.17
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
deepdiff==8.5.0
|
||||
deepdiff==8.6.1
|
||||
# via lerobot
|
||||
diffusers==0.34.0
|
||||
diffusers==0.35.2
|
||||
# via lerobot
|
||||
dill==0.3.8
|
||||
dill==0.4.0
|
||||
# via
|
||||
# datasets
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.14
|
||||
dm-control==1.0.34
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -122,29 +138,45 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
# via lerobot
|
||||
dynamixel-sdk==3.7.31
|
||||
dynamixel-sdk==3.8.4
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
# via lerobot
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.13.0
|
||||
# via mujoco
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.0
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.18.0
|
||||
filelock==3.20.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
@@ -152,24 +184,25 @@ filelock==3.18.0
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
flask==3.1.1
|
||||
# via lerobot
|
||||
fonttools==4.59.0
|
||||
fonttools==4.60.1
|
||||
# via matplotlib
|
||||
frozenlist==1.7.0
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.3.0
|
||||
fsspec[http]==2025.9.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
# via wandb
|
||||
glfw==2.9.0
|
||||
glfw==2.10.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
@@ -177,61 +210,79 @@ grpcio==1.73.1
|
||||
# via
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
grpcio-tools==1.73.1
|
||||
# via
|
||||
# lerobot
|
||||
# reachy2-sdk-api
|
||||
gym-aloha==0.1.3
|
||||
# via lerobot
|
||||
gym-aloha==0.1.1
|
||||
gym-hil==0.1.13
|
||||
# via lerobot
|
||||
gym-hil==0.1.10
|
||||
gym-pusht==0.1.6
|
||||
# via lerobot
|
||||
gym-pusht==0.1.5
|
||||
# via lerobot
|
||||
gym-xarm==0.1.1
|
||||
# via lerobot
|
||||
gymnasium==0.29.1
|
||||
gymnasium==1.2.1
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# pettingzoo
|
||||
gymnasium-robotics==1.2.4
|
||||
# via gym-xarm
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via robomimic
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.5
|
||||
hf-xet==1.1.10
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
huggingface-hub[cli,hf-transfer]==0.34.3
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
identify==2.6.12
|
||||
hydra-core==1.3.2
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
# via pre-commit
|
||||
idna==3.10
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robomimic
|
||||
# scikit-image
|
||||
imageio-ffmpeg==0.6.0
|
||||
# via imageio
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
importlib-metadata==8.7.0
|
||||
# via diffusers
|
||||
iniconfig==2.1.0
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
@@ -239,50 +290,71 @@ ipython==8.37.0
|
||||
# via meshcat
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jedi==0.19.2
|
||||
# via ipython
|
||||
jinja2==3.1.6
|
||||
# via
|
||||
# flask
|
||||
# gymnasium-robotics
|
||||
# torch
|
||||
# via torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
kiwisolver==1.4.8
|
||||
jsonschema==4.25.1
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.18.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
# via scikit-image
|
||||
lxml==6.0.0
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
# via numba
|
||||
lxml==6.0.2
|
||||
# via dm-control
|
||||
markupsafe==3.0.2
|
||||
markdown==3.9
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# flask
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.5
|
||||
# via lerobot
|
||||
matplotlib-inline==0.1.7
|
||||
matplotlib==3.10.7
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
matplotlib-inline==0.2.1
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
# via jupytext
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mergedeep==1.3.4
|
||||
# via draccus
|
||||
meshcat==0.3.2
|
||||
# via placo
|
||||
metaworld==3.0.0
|
||||
# via lerobot
|
||||
mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==2.3.7
|
||||
mujoco==3.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
multidict==6.6.3
|
||||
# libero
|
||||
# metaworld
|
||||
# robosuite
|
||||
multidict==6.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -290,17 +362,25 @@ multiprocess==0.70.16
|
||||
# via datasets
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# bddl
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.62.1
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# bddl
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
@@ -309,25 +389,43 @@ numpy==2.2.6
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# gymnasium
|
||||
# gymnasium-robotics
|
||||
# h5py
|
||||
# hebi-py
|
||||
# imageio
|
||||
# labmaze
|
||||
# libero
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
# pettingzoo
|
||||
# peft
|
||||
# pyquaternion
|
||||
# reachy2-sdk
|
||||
# rerun-sdk
|
||||
# robomimic
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# scipy
|
||||
# shapely
|
||||
# teleop
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.12.0.88
|
||||
# via gym-pusht
|
||||
# via
|
||||
# gym-pusht
|
||||
# libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
opencv-python-headless==4.12.0.88
|
||||
# via lerobot
|
||||
orderly-set==5.5.0
|
||||
@@ -337,53 +435,63 @@ packaging==25.0
|
||||
# accelerate
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# hydra-core
|
||||
# jupytext
|
||||
# lazy-loader
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# transformers
|
||||
# wandb
|
||||
pandas==2.3.1
|
||||
pandas==2.3.3
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.4
|
||||
parso==0.8.5
|
||||
# via jedi
|
||||
pettingzoo==1.24.3
|
||||
# via gymnasium-robotics
|
||||
peft==0.17.1
|
||||
# via lerobot
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==11.3.0
|
||||
pillow==12.0.0
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
# via lerobot
|
||||
platformdirs==4.3.8
|
||||
platformdirs==4.5.0
|
||||
# via
|
||||
# jupyter-core
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.2.0
|
||||
pre-commit==4.3.0
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.51
|
||||
prompt-toolkit==3.0.52
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
propcache==0.3.2
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -392,11 +500,17 @@ protobuf==6.31.0
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.0.0
|
||||
psutil==7.1.1
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
# peft
|
||||
# robomimic
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
@@ -405,11 +519,13 @@ pyarrow==21.0.0
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.22
|
||||
pycparser==2.23
|
||||
# via cffi
|
||||
pydantic==2.11.7
|
||||
# via wandb
|
||||
pydantic-core==2.33.2
|
||||
pydantic==2.12.3
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.4
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -424,40 +540,42 @@ pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.2.12
|
||||
pyngrok==7.4.1
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
pyobjc-core==11.1
|
||||
pyobjc-core==12.0
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-cocoa
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-applicationservices==11.1
|
||||
pyobjc-framework-applicationservices==12.0
|
||||
# via pynput
|
||||
pyobjc-framework-cocoa==11.1
|
||||
pyobjc-framework-cocoa==12.0
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-coretext==11.1
|
||||
pyobjc-framework-coretext==12.0
|
||||
# via pyobjc-framework-applicationservices
|
||||
pyobjc-framework-quartz==11.1
|
||||
pyobjc-framework-quartz==12.0
|
||||
# via
|
||||
# pynput
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-coretext
|
||||
pyopengl==3.1.9
|
||||
pyopengl==3.1.10
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.3
|
||||
pyparsing==3.2.5
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via reachy2-sdk
|
||||
pyrealsense2-macosx==2.54.2
|
||||
# via lerobot
|
||||
pyserial==3.5
|
||||
@@ -465,12 +583,14 @@ pyserial==3.5
|
||||
# dynamixel-sdk
|
||||
# feetech-servo-sdk
|
||||
# lerobot
|
||||
pytest==8.4.1
|
||||
pytest==8.4.2
|
||||
# via
|
||||
# bddl
|
||||
# lerobot
|
||||
# pytest-cov
|
||||
# pytest-timeout
|
||||
pytest-cov==6.2.1
|
||||
# teleop
|
||||
pytest-cov==7.0.0
|
||||
# via lerobot
|
||||
pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
@@ -478,46 +598,73 @@ python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-dotenv==1.1.1
|
||||
# via uvicorn
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
pyyaml==6.0.2
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# draccus
|
||||
# hebi-py
|
||||
# huggingface-hub
|
||||
# jupytext
|
||||
# omegaconf
|
||||
# peft
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
pyyaml-include==1.4.1
|
||||
# via draccus
|
||||
pyzmq==27.0.0
|
||||
pyzmq==27.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
regex==2025.7.34
|
||||
reachy2-sdk==1.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2025.10.23
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
requests==2.32.4
|
||||
requests==2.32.5
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.22.1
|
||||
rerun-sdk==0.26.1
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
safetensors==0.5.3
|
||||
robomimic==0.2.0
|
||||
# via libero
|
||||
robosuite==1.4.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.6.2
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
@@ -526,10 +673,12 @@ scikit-image==0.25.2
|
||||
scipy==1.15.3
|
||||
# via
|
||||
# dm-control
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
sentry-sdk==2.34.1
|
||||
sentry-sdk==2.42.1
|
||||
# via wandb
|
||||
shapely==2.1.1
|
||||
shapely==2.1.2
|
||||
# via gym-pusht
|
||||
six==1.17.0
|
||||
# via
|
||||
@@ -537,64 +686,106 @@ six==1.17.0
|
||||
# python-dateutil
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.48.0
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
termcolor==3.1.0
|
||||
teleop==0.1.2
|
||||
# via lerobot
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via libero
|
||||
tifffile==2025.5.10
|
||||
# via scikit-image
|
||||
tokenizers==0.21.4
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.2.1
|
||||
tomli==2.3.0
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
# via
|
||||
# accelerate
|
||||
# lerobot
|
||||
# peft
|
||||
# robomimic
|
||||
# thop
|
||||
# timm
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
# via lerobot
|
||||
tornado==6.5.1
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# peft
|
||||
# robomimic
|
||||
# transformers
|
||||
traitlets==5.14.3
|
||||
# via
|
||||
# ipython
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
transformers==4.51.3
|
||||
# via lerobot
|
||||
typing-extensions==4.14.1
|
||||
# nbformat
|
||||
transformers==4.57.1
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
# starlette
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.1
|
||||
typing-inspection==0.4.2
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via pandas
|
||||
@@ -604,22 +795,36 @@ urllib3==2.5.0
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
virtualenv==20.32.0
|
||||
uvicorn[standard]==0.38.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==20.35.3
|
||||
# via pre-commit
|
||||
wandb==0.21.0
|
||||
# via lerobot
|
||||
wcwidth==0.2.13
|
||||
wandb==0.21.4
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.2.14
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==15.0.1
|
||||
# via uvicorn
|
||||
werkzeug==3.1.3
|
||||
# via flask
|
||||
wrapt==1.17.2
|
||||
# via tensorboard
|
||||
wrapt==2.0.0
|
||||
# via dm-tree
|
||||
xxhash==3.5.0
|
||||
xxhash==3.6.0
|
||||
# via datasets
|
||||
yarl==1.20.1
|
||||
yarl==1.22.0
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via importlib-metadata
|
||||
# via
|
||||
# etils
|
||||
# importlib-metadata
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
|
||||
@@ -13,47 +13,62 @@ absl-py==2.3.1
|
||||
# dm-tree
|
||||
# labmaze
|
||||
# mujoco
|
||||
accelerate==1.9.0
|
||||
# via lerobot
|
||||
# tensorboard
|
||||
accelerate==1.11.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.12.15
|
||||
aiohttp==3.13.1
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.11.0
|
||||
# via
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.0
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.3.0
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
# dm-tree
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.0.0
|
||||
av==15.1.0
|
||||
# via lerobot
|
||||
blinker==1.9.0
|
||||
# via flask
|
||||
certifi==2025.7.14
|
||||
bddl==1.0.1
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==1.17.1
|
||||
cffi==2.0.0
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.2
|
||||
charset-normalizer==3.4.4
|
||||
# via requests
|
||||
click==8.2.1
|
||||
click==8.3.0
|
||||
# via
|
||||
# flask
|
||||
# uvicorn
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
# via gymnasium
|
||||
cmake==4.0.3
|
||||
# via
|
||||
# gymnasium
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
# via
|
||||
@@ -95,27 +110,29 @@ coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.10.1
|
||||
coverage[toml]==7.11.0
|
||||
# via pytest-cov
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==3.6.0
|
||||
datasets==4.1.1
|
||||
# via lerobot
|
||||
debugpy==1.8.15
|
||||
debugpy==1.8.17
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
deepdiff==8.5.0
|
||||
decord==0.6.0
|
||||
# via lerobot
|
||||
diffusers==0.34.0
|
||||
deepdiff==8.6.1
|
||||
# via lerobot
|
||||
dill==0.3.8
|
||||
diffusers==0.35.2
|
||||
# via lerobot
|
||||
dill==0.4.0
|
||||
# via
|
||||
# datasets
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.14
|
||||
dm-control==1.0.34
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -123,31 +140,48 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
# via lerobot
|
||||
dynamixel-sdk==3.7.31
|
||||
dynamixel-sdk==3.8.4
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
# via lerobot
|
||||
# via
|
||||
# flash-attn
|
||||
# lerobot
|
||||
# libero
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.13.0
|
||||
# via mujoco
|
||||
evdev==1.9.2
|
||||
# via pynput
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.0
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.18.0
|
||||
filelock==3.20.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
@@ -155,24 +189,27 @@ filelock==3.18.0
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
flask==3.1.1
|
||||
flash-attn==2.8.3
|
||||
# via lerobot
|
||||
fonttools==4.59.0
|
||||
fonttools==4.60.1
|
||||
# via matplotlib
|
||||
frozenlist==1.7.0
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.3.0
|
||||
fsspec[http]==2025.9.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
# via wandb
|
||||
glfw==2.9.0
|
||||
glfw==2.10.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
@@ -180,61 +217,79 @@ grpcio==1.73.1
|
||||
# via
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
grpcio-tools==1.73.1
|
||||
# via
|
||||
# lerobot
|
||||
# reachy2-sdk-api
|
||||
gym-aloha==0.1.3
|
||||
# via lerobot
|
||||
gym-aloha==0.1.1
|
||||
gym-hil==0.1.13
|
||||
# via lerobot
|
||||
gym-hil==0.1.10
|
||||
gym-pusht==0.1.6
|
||||
# via lerobot
|
||||
gym-pusht==0.1.5
|
||||
# via lerobot
|
||||
gym-xarm==0.1.1
|
||||
# via lerobot
|
||||
gymnasium==0.29.1
|
||||
gymnasium==1.2.1
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# pettingzoo
|
||||
gymnasium-robotics==1.2.4
|
||||
# via gym-xarm
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via robomimic
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.5
|
||||
hf-xet==1.1.10
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
huggingface-hub[cli,hf-transfer]==0.34.3
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
identify==2.6.12
|
||||
hydra-core==1.3.2
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
# via pre-commit
|
||||
idna==3.10
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robomimic
|
||||
# scikit-image
|
||||
imageio-ffmpeg==0.6.0
|
||||
# via imageio
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
importlib-metadata==8.7.0
|
||||
# via diffusers
|
||||
iniconfig==2.1.0
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
@@ -242,50 +297,71 @@ ipython==8.37.0
|
||||
# via meshcat
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jedi==0.19.2
|
||||
# via ipython
|
||||
jinja2==3.1.6
|
||||
# via
|
||||
# flask
|
||||
# gymnasium-robotics
|
||||
# torch
|
||||
# via torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
kiwisolver==1.4.8
|
||||
jsonschema==4.25.1
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.18.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
# via scikit-image
|
||||
lxml==6.0.0
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
# via numba
|
||||
lxml==6.0.2
|
||||
# via dm-control
|
||||
markupsafe==3.0.2
|
||||
markdown==3.9
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# flask
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.5
|
||||
# via lerobot
|
||||
matplotlib-inline==0.1.7
|
||||
matplotlib==3.10.7
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
matplotlib-inline==0.2.1
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
# via jupytext
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mergedeep==1.3.4
|
||||
# via draccus
|
||||
meshcat==0.3.2
|
||||
# via placo
|
||||
metaworld==3.0.0
|
||||
# via lerobot
|
||||
mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==2.3.7
|
||||
mujoco==3.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
multidict==6.6.3
|
||||
# libero
|
||||
# metaworld
|
||||
# robosuite
|
||||
multidict==6.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -293,42 +369,63 @@ multiprocess==0.70.16
|
||||
# via datasets
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# bddl
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.62.1
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# bddl
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
# decord
|
||||
# diffusers
|
||||
# dm-control
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# gymnasium
|
||||
# gymnasium-robotics
|
||||
# h5py
|
||||
# hebi-py
|
||||
# imageio
|
||||
# labmaze
|
||||
# libero
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
# pettingzoo
|
||||
# peft
|
||||
# pyquaternion
|
||||
# reachy2-sdk
|
||||
# rerun-sdk
|
||||
# robomimic
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# scipy
|
||||
# shapely
|
||||
# teleop
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
nvidia-cublas-cu12==12.6.4.1
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
@@ -366,8 +463,14 @@ nvidia-nvjitlink-cu12==12.6.85
|
||||
# torch
|
||||
nvidia-nvtx-cu12==12.6.77
|
||||
# via torch
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.12.0.88
|
||||
# via gym-pusht
|
||||
# via
|
||||
# gym-pusht
|
||||
# libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
opencv-python-headless==4.12.0.88
|
||||
# via lerobot
|
||||
orderly-set==5.5.0
|
||||
@@ -377,53 +480,63 @@ packaging==25.0
|
||||
# accelerate
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# hydra-core
|
||||
# jupytext
|
||||
# lazy-loader
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# transformers
|
||||
# wandb
|
||||
pandas==2.3.1
|
||||
pandas==2.3.3
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.4
|
||||
parso==0.8.5
|
||||
# via jedi
|
||||
pettingzoo==1.24.3
|
||||
# via gymnasium-robotics
|
||||
peft==0.17.1
|
||||
# via lerobot
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==11.3.0
|
||||
pillow==12.0.0
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
# via lerobot
|
||||
platformdirs==4.3.8
|
||||
platformdirs==4.5.0
|
||||
# via
|
||||
# jupyter-core
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.2.0
|
||||
pre-commit==4.3.0
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.51
|
||||
prompt-toolkit==3.0.52
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
propcache==0.3.2
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -432,11 +545,17 @@ protobuf==6.31.0
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.0.0
|
||||
psutil==7.1.1
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
# peft
|
||||
# robomimic
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
@@ -445,11 +564,13 @@ pyarrow==21.0.0
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.22
|
||||
pycparser==2.23
|
||||
# via cffi
|
||||
pydantic==2.11.7
|
||||
# via wandb
|
||||
pydantic-core==2.33.2
|
||||
pydantic==2.12.3
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.4
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -464,20 +585,22 @@ pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.2.12
|
||||
pyngrok==7.4.1
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
pyopengl==3.1.9
|
||||
pyopengl==3.1.10
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.3
|
||||
pyparsing==3.2.5
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via reachy2-sdk
|
||||
pyrealsense2==2.56.5.9235
|
||||
# via lerobot
|
||||
pyserial==3.5
|
||||
@@ -485,12 +608,14 @@ pyserial==3.5
|
||||
# dynamixel-sdk
|
||||
# feetech-servo-sdk
|
||||
# lerobot
|
||||
pytest==8.4.1
|
||||
pytest==8.4.2
|
||||
# via
|
||||
# bddl
|
||||
# lerobot
|
||||
# pytest-cov
|
||||
# pytest-timeout
|
||||
pytest-cov==6.2.1
|
||||
# teleop
|
||||
pytest-cov==7.0.0
|
||||
# via lerobot
|
||||
pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
@@ -498,48 +623,75 @@ python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-dotenv==1.1.1
|
||||
# via uvicorn
|
||||
python-xlib==0.33
|
||||
# via pynput
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
pyyaml==6.0.2
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# draccus
|
||||
# hebi-py
|
||||
# huggingface-hub
|
||||
# jupytext
|
||||
# omegaconf
|
||||
# peft
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
pyyaml-include==1.4.1
|
||||
# via draccus
|
||||
pyzmq==27.0.0
|
||||
pyzmq==27.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
regex==2025.7.34
|
||||
reachy2-sdk==1.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2025.10.23
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
requests==2.32.4
|
||||
requests==2.32.5
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.22.1
|
||||
rerun-sdk==0.26.1
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
safetensors==0.5.3
|
||||
robomimic==0.2.0
|
||||
# via libero
|
||||
robosuite==1.4.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.6.2
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
@@ -548,10 +700,12 @@ scikit-image==0.25.2
|
||||
scipy==1.15.3
|
||||
# via
|
||||
# dm-control
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
sentry-sdk==2.34.1
|
||||
sentry-sdk==2.42.1
|
||||
# via wandb
|
||||
shapely==2.1.1
|
||||
shapely==2.1.2
|
||||
# via gym-pusht
|
||||
six==1.17.0
|
||||
# via
|
||||
@@ -560,66 +714,109 @@ six==1.17.0
|
||||
# python-xlib
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.48.0
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
termcolor==3.1.0
|
||||
teleop==0.1.2
|
||||
# via lerobot
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via libero
|
||||
tifffile==2025.5.10
|
||||
# via scikit-image
|
||||
tokenizers==0.21.4
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.2.1
|
||||
tomli==2.3.0
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
# via
|
||||
# accelerate
|
||||
# flash-attn
|
||||
# lerobot
|
||||
# peft
|
||||
# robomimic
|
||||
# thop
|
||||
# timm
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
# via lerobot
|
||||
tornado==6.5.1
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# peft
|
||||
# robomimic
|
||||
# transformers
|
||||
traitlets==5.14.3
|
||||
# via
|
||||
# ipython
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
transformers==4.51.3
|
||||
# via lerobot
|
||||
# nbformat
|
||||
transformers==4.57.1
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
triton==3.3.1
|
||||
# via torch
|
||||
typing-extensions==4.14.1
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
# starlette
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.1
|
||||
typing-inspection==0.4.2
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via pandas
|
||||
@@ -629,22 +826,36 @@ urllib3==2.5.0
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
virtualenv==20.32.0
|
||||
uvicorn[standard]==0.38.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==20.35.3
|
||||
# via pre-commit
|
||||
wandb==0.21.0
|
||||
# via lerobot
|
||||
wcwidth==0.2.13
|
||||
wandb==0.21.4
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.2.14
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==15.0.1
|
||||
# via uvicorn
|
||||
werkzeug==3.1.3
|
||||
# via flask
|
||||
wrapt==1.17.2
|
||||
# via tensorboard
|
||||
wrapt==2.0.0
|
||||
# via dm-tree
|
||||
xxhash==3.5.0
|
||||
xxhash==3.6.0
|
||||
# via datasets
|
||||
yarl==1.20.1
|
||||
yarl==1.22.0
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via importlib-metadata
|
||||
# via
|
||||
# etils
|
||||
# importlib-metadata
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# requirements.in
|
||||
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 15.5 24F74 arm64).
|
||||
# Darwin MacBook-Pro.local 24.5.0 Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:43 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8132 arm64
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
|
||||
# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
|
||||
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.2 LTS x86_64).
|
||||
# Linux mlerobot-linux 6.14.0-27-generic #27~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 22 17:38:49 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
|
||||
# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
|
||||
-e .[all]
|
||||
|
||||
278
scripts/find_high_mse_episodes.py
Normal file
278
scripts/find_high_mse_episodes.py
Normal file
@@ -0,0 +1,278 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Script to find episodes with highest MSE between observation.state and action pairs.
|
||||
|
||||
This script:
|
||||
1. Downloads a LeRobot dataset (if needed, skipping videos)
|
||||
2. Computes MSE between observation.state and action for each frame
|
||||
3. Aggregates MSE per episode
|
||||
4. Returns the top 1% episodes with highest total MSE
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
|
||||
def compute_episode_mse(
|
||||
dataset: LeRobotDataset,
|
||||
state_key: str = "observation.state",
|
||||
action_key: str = "action",
|
||||
) -> dict[int, float]:
|
||||
"""
|
||||
Compute total MSE between state and action for each episode.
|
||||
|
||||
Args:
|
||||
dataset: LeRobotDataset to analyze
|
||||
state_key: Key for the observation state in the dataset
|
||||
action_key: Key for the action in the dataset
|
||||
|
||||
Returns:
|
||||
Dictionary mapping episode_index to total MSE for that episode
|
||||
"""
|
||||
episode_mse = {}
|
||||
|
||||
# Get all unique episode indices
|
||||
hf_dataset = dataset.hf_dataset
|
||||
|
||||
# Group frames by episode for efficient processing
|
||||
logging.info("Computing MSE for each episode...")
|
||||
|
||||
# Process all frames and accumulate MSE per episode
|
||||
for idx in tqdm(range(len(hf_dataset)), desc="Processing frames"):
|
||||
item = hf_dataset[idx]
|
||||
|
||||
ep_idx = item["episode_index"]
|
||||
if isinstance(ep_idx, torch.Tensor):
|
||||
ep_idx = ep_idx.item()
|
||||
|
||||
state = item[state_key]
|
||||
action = item[action_key]
|
||||
|
||||
if isinstance(state, torch.Tensor):
|
||||
state = state.numpy()
|
||||
if isinstance(action, torch.Tensor):
|
||||
action = action.numpy()
|
||||
|
||||
# Compute MSE for this frame (sum of squared differences across all dimensions)
|
||||
mse = np.mean((state - action) ** 2)
|
||||
|
||||
if ep_idx not in episode_mse:
|
||||
episode_mse[ep_idx] = 0.0
|
||||
episode_mse[ep_idx] += mse
|
||||
|
||||
return episode_mse
|
||||
|
||||
|
||||
def get_top_mse_episodes(
|
||||
episode_mse: dict[int, float],
|
||||
top_percent: float = 1.0,
|
||||
) -> list[int]:
|
||||
"""
|
||||
Get the top X% of episodes with highest total MSE.
|
||||
|
||||
Args:
|
||||
episode_mse: Dictionary mapping episode_index to total MSE
|
||||
top_percent: Percentage of episodes to return (default: 1%)
|
||||
|
||||
Returns:
|
||||
List of episode indices sorted by MSE (highest first)
|
||||
"""
|
||||
# Sort episodes by MSE in descending order
|
||||
sorted_episodes = sorted(episode_mse.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Calculate number of episodes to return
|
||||
num_episodes = len(sorted_episodes)
|
||||
num_top = max(1, int(np.ceil(num_episodes * top_percent / 100)))
|
||||
|
||||
# Extract top episode indices
|
||||
top_episodes = [ep_idx for ep_idx, _ in sorted_episodes[:num_top]]
|
||||
|
||||
return top_episodes
|
||||
|
||||
|
||||
def find_high_mse_episodes(
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
state_key: str = "observation.state",
|
||||
action_key: str = "action",
|
||||
top_percent: float = 1.0,
|
||||
force_download: bool = False,
|
||||
) -> tuple[list[int], dict[int, float]]:
|
||||
"""
|
||||
Find episodes with highest MSE between observation.state and action.
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace dataset repository ID
|
||||
root: Local directory for dataset storage (default: ~/.cache/huggingface/lerobot)
|
||||
state_key: Key for the observation state in the dataset
|
||||
action_key: Key for the action in the dataset
|
||||
top_percent: Percentage of episodes to return (default: 1%)
|
||||
force_download: Force re-download of the dataset
|
||||
|
||||
Returns:
|
||||
Tuple of (list of top episode indices, dict of all episode MSEs)
|
||||
"""
|
||||
logging.info(f"Loading dataset: {repo_id}")
|
||||
|
||||
# Load the dataset (skip video download since we only need state/action data)
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
download_videos=False,
|
||||
force_cache_sync=force_download,
|
||||
)
|
||||
|
||||
# Verify the dataset has the required features
|
||||
if state_key not in dataset.features:
|
||||
raise ValueError(f"Dataset does not contain '{state_key}' feature. "
|
||||
f"Available features: {list(dataset.features.keys())}")
|
||||
if action_key not in dataset.features:
|
||||
raise ValueError(f"Dataset does not contain '{action_key}' feature. "
|
||||
f"Available features: {list(dataset.features.keys())}")
|
||||
|
||||
# Check that state and action have the same shape
|
||||
state_shape = tuple(dataset.features[state_key]["shape"])
|
||||
action_shape = tuple(dataset.features[action_key]["shape"])
|
||||
if state_shape != action_shape:
|
||||
raise ValueError(f"State shape {state_shape} does not match action shape {action_shape}")
|
||||
|
||||
logging.info(f"Dataset loaded successfully:")
|
||||
logging.info(f" - Total episodes: {dataset.meta.total_episodes}")
|
||||
logging.info(f" - Total frames: {dataset.meta.total_frames}")
|
||||
logging.info(f" - State shape: {state_shape}")
|
||||
logging.info(f" - Action shape: {action_shape}")
|
||||
logging.info(f" - Feature names: {dataset.features[state_key].get('names', 'N/A')}")
|
||||
|
||||
# Compute MSE for each episode
|
||||
episode_mse = compute_episode_mse(dataset, state_key, action_key)
|
||||
|
||||
# Get top episodes
|
||||
top_episodes = get_top_mse_episodes(episode_mse, top_percent)
|
||||
|
||||
return top_episodes, episode_mse
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Find episodes with highest MSE between observation.state and action"
|
||||
)
|
||||
parser.add_argument(
|
||||
"repo_id",
|
||||
type=str,
|
||||
help="HuggingFace dataset repository ID (e.g., 'lerobot/aloha_sim_insertion_human')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Local directory for dataset storage (default: ~/.cache/huggingface/lerobot)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--state-key",
|
||||
type=str,
|
||||
default="observation.state",
|
||||
help="Key for observation state feature (default: 'observation.state')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action-key",
|
||||
type=str,
|
||||
default="action",
|
||||
help="Key for action feature (default: 'action')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-percent",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Percentage of episodes to return (default: 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-download",
|
||||
action="store_true",
|
||||
help="Force re-download of the dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-all-mse",
|
||||
action="store_true",
|
||||
help="Show MSE values for all episodes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output file to save results (optional)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Find high MSE episodes
|
||||
top_episodes, all_mse = find_high_mse_episodes(
|
||||
repo_id=args.repo_id,
|
||||
root=args.root,
|
||||
state_key=args.state_key,
|
||||
action_key=args.action_key,
|
||||
top_percent=args.top_percent,
|
||||
force_download=args.force_download,
|
||||
)
|
||||
|
||||
# Print results
|
||||
print("\n" + "=" * 60)
|
||||
print(f"TOP {args.top_percent}% EPISODES WITH HIGHEST MSE")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"\nTotal episodes analyzed: {len(all_mse)}")
|
||||
print(f"Number of top episodes (top {args.top_percent}%): {len(top_episodes)}")
|
||||
|
||||
print(f"\nTop {len(top_episodes)} episode(s) with highest MSE:")
|
||||
print("-" * 40)
|
||||
for i, ep_idx in enumerate(top_episodes, 1):
|
||||
print(f" {i:3d}. Episode {ep_idx:5d} - Total MSE: {all_mse[ep_idx]:.6f}")
|
||||
|
||||
# Statistics
|
||||
all_mse_values = list(all_mse.values())
|
||||
print(f"\nMSE Statistics:")
|
||||
print(f" - Mean MSE: {np.mean(all_mse_values):.6f}")
|
||||
print(f" - Std MSE: {np.std(all_mse_values):.6f}")
|
||||
print(f" - Min MSE: {np.min(all_mse_values):.6f}")
|
||||
print(f" - Max MSE: {np.max(all_mse_values):.6f}")
|
||||
print(f" - Median MSE: {np.median(all_mse_values):.6f}")
|
||||
|
||||
if args.show_all_mse:
|
||||
print(f"\nAll episodes sorted by MSE (descending):")
|
||||
print("-" * 40)
|
||||
sorted_episodes = sorted(all_mse.items(), key=lambda x: x[1], reverse=True)
|
||||
for ep_idx, mse in sorted_episodes:
|
||||
print(f" Episode {ep_idx:5d} - Total MSE: {mse:.6f}")
|
||||
|
||||
# Save results if output file specified
|
||||
if args.output:
|
||||
output_path = Path(args.output)
|
||||
with open(output_path, "w") as f:
|
||||
f.write(f"# High MSE Episodes Analysis\n")
|
||||
f.write(f"# Dataset: {args.repo_id}\n")
|
||||
f.write(f"# State key: {args.state_key}\n")
|
||||
f.write(f"# Action key: {args.action_key}\n")
|
||||
f.write(f"# Top percent: {args.top_percent}%\n\n")
|
||||
|
||||
f.write(f"Top {args.top_percent}% episodes:\n")
|
||||
for ep_idx in top_episodes:
|
||||
f.write(f"{ep_idx},{all_mse[ep_idx]:.6f}\n")
|
||||
|
||||
logging.info(f"Results saved to: {output_path}")
|
||||
|
||||
# Return the list for programmatic use
|
||||
return top_episodes
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -57,7 +57,6 @@ available_tasks_per_env = {
|
||||
"AlohaTransferCube-v0",
|
||||
],
|
||||
"pusht": ["PushT-v0"],
|
||||
"xarm": ["XarmLift-v0"],
|
||||
}
|
||||
available_envs = list(available_tasks_per_env.keys())
|
||||
|
||||
@@ -75,16 +74,6 @@ available_datasets_per_env = {
|
||||
# TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
|
||||
# coupled with tests.
|
||||
"pusht": ["lerobot/pusht", "lerobot/pusht_image"],
|
||||
"xarm": [
|
||||
"lerobot/xarm_lift_medium",
|
||||
"lerobot/xarm_lift_medium_replay",
|
||||
"lerobot/xarm_push_medium",
|
||||
"lerobot/xarm_push_medium_replay",
|
||||
"lerobot/xarm_lift_medium_image",
|
||||
"lerobot/xarm_lift_medium_replay_image",
|
||||
"lerobot/xarm_push_medium_image",
|
||||
"lerobot/xarm_push_medium_replay_image",
|
||||
],
|
||||
}
|
||||
|
||||
available_real_world_datasets = [
|
||||
@@ -195,7 +184,6 @@ available_motors = [
|
||||
available_policies_per_env = {
|
||||
"aloha": ["act"],
|
||||
"pusht": ["diffusion", "vqbet"],
|
||||
"xarm": ["tdmpc"],
|
||||
"koch_real": ["act_koch_real"],
|
||||
"aloha_real": ["act_aloha_real"],
|
||||
}
|
||||
|
||||
@@ -142,11 +142,6 @@ class RobotClientConfig:
|
||||
default=False, metadata={"help": "Visualize the action queue size"}
|
||||
)
|
||||
|
||||
# Verification configuration
|
||||
verify_robot_cameras: bool = field(
|
||||
default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"}
|
||||
)
|
||||
|
||||
@property
|
||||
def environment_dt(self) -> float:
|
||||
"""Environment time step, in seconds"""
|
||||
|
||||
@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
|
||||
# All action chunking policies
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]
|
||||
|
||||
@@ -16,7 +16,7 @@ import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@@ -25,7 +25,14 @@ from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
|
||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
|
||||
from lerobot.policies import ( # noqa: F401
|
||||
ACTConfig,
|
||||
DiffusionConfig,
|
||||
PI0Config,
|
||||
PI05Config,
|
||||
SmolVLAConfig,
|
||||
VQBeTConfig,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.utils import init_logging
|
||||
@@ -55,15 +62,6 @@ def visualize_action_queue_size(action_queue_size: list[int]) -> None:
|
||||
plt.show()
|
||||
|
||||
|
||||
def validate_robot_cameras_for_policy(
|
||||
lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature]
|
||||
) -> None:
|
||||
image_keys = list(filter(is_image_key, lerobot_observation_features))
|
||||
assert set(image_keys) == set(policy_image_features.keys()), (
|
||||
f"Policy image features must match robot cameras! Received {list(policy_image_features.keys())} != {image_keys}"
|
||||
)
|
||||
|
||||
|
||||
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
|
||||
return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)
|
||||
|
||||
@@ -85,11 +83,11 @@ def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int,
|
||||
return resized.squeeze(0)
|
||||
|
||||
|
||||
# TODO(Steven): Consider implementing a pipeline step for this
|
||||
def raw_observation_to_observation(
|
||||
raw_observation: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
policy_image_features: dict[str, PolicyFeature],
|
||||
device: str,
|
||||
) -> Observation:
|
||||
observation = {}
|
||||
|
||||
@@ -98,9 +96,7 @@ def raw_observation_to_observation(
|
||||
if isinstance(v, torch.Tensor): # VLAs present natural-language instructions in observations
|
||||
if "image" in k:
|
||||
# Policy expects images in shape (B, C, H, W)
|
||||
observation[k] = prepare_image(v).unsqueeze(0).to(device)
|
||||
else:
|
||||
observation[k] = v.to(device)
|
||||
observation[k] = prepare_image(v).unsqueeze(0)
|
||||
else:
|
||||
observation[k] = v
|
||||
|
||||
@@ -272,6 +268,7 @@ class RemotePolicyConfig:
|
||||
lerobot_features: dict[str, PolicyFeature]
|
||||
actions_per_chunk: int
|
||||
device: str = "cpu"
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"""
|
||||
Example:
|
||||
```shell
|
||||
python src/lerobot/async_inference/policy_server.py \
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080 \
|
||||
--fps=30 \
|
||||
@@ -32,12 +32,17 @@ from concurrent import futures
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
@@ -82,6 +87,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
self.lerobot_features = None
|
||||
self.actions_per_chunk = None
|
||||
self.policy = None
|
||||
self.preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None
|
||||
self.postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
@@ -146,6 +153,19 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
start = time.perf_counter()
|
||||
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
|
||||
self.policy.to(self.device)
|
||||
|
||||
# Load preprocessor and postprocessor, overriding device to match requested device
|
||||
device_override = {"device": self.device}
|
||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||
self.policy.config,
|
||||
pretrained_path=policy_specs.pretrained_name_or_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": device_override,
|
||||
"rename_observations_processor": {"rename_map": policy_specs.rename_map},
|
||||
},
|
||||
postprocessor_overrides={"device_processor": device_override},
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
|
||||
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
|
||||
@@ -173,7 +193,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
# Calculate FPS metrics
|
||||
fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)
|
||||
|
||||
self.logger.info(
|
||||
self.logger.debug(
|
||||
f"Received observation #{obs_timestep} | "
|
||||
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client
|
||||
f"Target: {fps_metrics['target_fps']:.2f} | "
|
||||
@@ -189,7 +209,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
if not self._enqueue_observation(
|
||||
timed_observation # wrapping a RawObservation
|
||||
):
|
||||
self.logger.info(f"Observation #{obs_timestep} has been filtered out")
|
||||
self.logger.debug(f"Observation #{obs_timestep} has been filtered out")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
@@ -301,23 +321,6 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
for i, action in enumerate(action_chunk)
|
||||
]
|
||||
|
||||
def _prepare_observation(self, observation_t: TimedObservation) -> Observation:
|
||||
"""
|
||||
Prepare observation, ready for policy inference.
|
||||
E.g.: To keep observation sampling rate high (and network packet tiny) we send int8 [0,255] images from the
|
||||
client and then convert them to float32 [0,1] images here, before running inference.
|
||||
"""
|
||||
# RawObservation from robot.get_observation() - wrong keys, wrong dtype, wrong image shape
|
||||
observation: Observation = raw_observation_to_observation(
|
||||
observation_t.get_observation(),
|
||||
self.lerobot_features,
|
||||
self.policy_image_features,
|
||||
self.device,
|
||||
)
|
||||
# processed Observation - right keys, right dtype, right image shape
|
||||
|
||||
return observation
|
||||
|
||||
def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Get an action chunk from the policy. The chunk contains only"""
|
||||
chunk = self.policy.predict_action_chunk(observation)
|
||||
@@ -327,44 +330,76 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
return chunk[:, : self.actions_per_chunk, :]
|
||||
|
||||
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
|
||||
"""Predict an action chunk based on an observation"""
|
||||
inference_starts = time.perf_counter()
|
||||
"""Predict an action chunk based on an observation.
|
||||
|
||||
Pipeline:
|
||||
1. Convert raw observation to LeRobot format
|
||||
2. Apply preprocessor (tokenization, normalization, batching, device placement)
|
||||
3. Run policy inference to get action chunk
|
||||
4. Apply postprocessor (unnormalization, device movement)
|
||||
5. Convert to TimedAction list
|
||||
"""
|
||||
"""1. Prepare observation"""
|
||||
start_time = time.perf_counter()
|
||||
observation = self._prepare_observation(observation_t)
|
||||
preprocessing_time = time.perf_counter() - start_time
|
||||
start_prepare = time.perf_counter()
|
||||
observation: Observation = raw_observation_to_observation(
|
||||
observation_t.get_observation(),
|
||||
self.lerobot_features,
|
||||
self.policy_image_features,
|
||||
)
|
||||
prepare_time = time.perf_counter() - start_prepare
|
||||
|
||||
"""2. Apply preprocessor"""
|
||||
start_preprocess = time.perf_counter()
|
||||
observation = self.preprocessor(observation)
|
||||
self.last_processed_obs: TimedObservation = observation_t
|
||||
preprocessing_time = time.perf_counter() - start_preprocess
|
||||
|
||||
"""2. Get action chunk"""
|
||||
start_time = time.perf_counter()
|
||||
"""3. Get action chunk"""
|
||||
start_inference = time.perf_counter()
|
||||
action_tensor = self._get_action_chunk(observation)
|
||||
inference_time = time.perf_counter() - start_time
|
||||
inference_time = time.perf_counter() - start_inference
|
||||
self.logger.info(
|
||||
f"Preprocessing and inference took {inference_time:.4f}s, action shape: {action_tensor.shape}"
|
||||
)
|
||||
|
||||
"""3. Post-inference processing"""
|
||||
start_time = time.perf_counter()
|
||||
# Move to CPU before serializing
|
||||
action_tensor = action_tensor.cpu().squeeze(0)
|
||||
"""4. Apply postprocessor"""
|
||||
# Apply postprocessor (handles unnormalization and device movement)
|
||||
# Postprocessor expects (B, action_dim) per action, but we have (B, chunk_size, action_dim)
|
||||
# So we process each action in the chunk individually
|
||||
start_postprocess = time.perf_counter()
|
||||
_, chunk_size, _ = action_tensor.shape
|
||||
|
||||
# Process each action in the chunk
|
||||
processed_actions = []
|
||||
for i in range(chunk_size):
|
||||
# Extract action at timestep i: (B, action_dim)
|
||||
single_action = action_tensor[:, i, :]
|
||||
processed_action = self.postprocessor(single_action)
|
||||
processed_actions.append(processed_action)
|
||||
|
||||
# Stack back to (B, chunk_size, action_dim), then remove batch dim
|
||||
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
|
||||
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
|
||||
|
||||
"""5. Convert to TimedAction list"""
|
||||
action_chunk = self._time_action_chunk(
|
||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||
)
|
||||
postprocessing_time = time.perf_counter() - start_time
|
||||
inference_stops = time.perf_counter()
|
||||
postprocess_stops = time.perf_counter()
|
||||
postprocessing_time = postprocess_stops - start_postprocess
|
||||
|
||||
self.logger.info(
|
||||
f"Observation {observation_t.get_timestep()} |"
|
||||
f"Inference time: {1000 * (inference_stops - inference_starts):.2f}ms"
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
# full-process latency breakdown for debugging purposes
|
||||
self.logger.debug(
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Preprocessing time: {1000 * (preprocessing_time - inference_starts):.2f}ms | "
|
||||
f"Inference time: {1000 * (inference_time - preprocessing_time):.2f}ms | "
|
||||
f"Postprocessing time: {1000 * (postprocessing_time - inference_time):.2f}ms | "
|
||||
f"Total time: {1000 * (postprocessing_time - inference_starts):.2f}ms"
|
||||
f"Prepare time: {1000 * prepare_time:.2f}ms | "
|
||||
f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | "
|
||||
f"Inference time: {1000 * inference_time:.2f}ms | "
|
||||
f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
return action_chunk
|
||||
|
||||
@@ -48,10 +48,10 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so100_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
@@ -75,7 +75,6 @@ from .helpers import (
|
||||
TimedObservation,
|
||||
get_logger,
|
||||
map_robot_keys_to_lerobot_features,
|
||||
validate_robot_cameras_for_policy,
|
||||
visualize_action_queue_size,
|
||||
)
|
||||
|
||||
@@ -97,14 +96,6 @@ class RobotClient:
|
||||
|
||||
lerobot_features = map_robot_keys_to_lerobot_features(self.robot)
|
||||
|
||||
if config.verify_robot_cameras:
|
||||
# Load policy config for validation
|
||||
policy_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path)
|
||||
policy_image_features = policy_config.image_features
|
||||
|
||||
# The cameras specified for inference must match the one supported by the policy chosen
|
||||
validate_robot_cameras_for_policy(lerobot_features, policy_image_features)
|
||||
|
||||
# Use environment variable if server_address is not provided in config
|
||||
self.server_address = config.server_address
|
||||
|
||||
@@ -214,7 +205,7 @@ class RobotClient:
|
||||
)
|
||||
_ = self.stub.SendObservations(observation_iterator)
|
||||
obs_timestep = obs.get_timestep()
|
||||
self.logger.info(f"Sent observation #{obs_timestep} | ")
|
||||
self.logger.debug(f"Sent observation #{obs_timestep} | ")
|
||||
|
||||
return True
|
||||
|
||||
@@ -467,7 +458,7 @@ class RobotClient:
|
||||
if self._ready_to_send_observation():
|
||||
_captured_observation = self.control_loop_observation(task, verbose)
|
||||
|
||||
self.logger.info(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
|
||||
self.logger.debug(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
|
||||
# Dynamically adjust sleep time to maintain the desired control frequency
|
||||
time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import abc
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
from .configs import CameraConfig, ColorMode
|
||||
|
||||
@@ -89,7 +89,7 @@ class Camera(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""Capture and return a single frame from the camera.
|
||||
|
||||
Args:
|
||||
@@ -102,7 +102,7 @@ class Camera(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def async_read(self, timeout_ms: float = ...) -> np.ndarray:
|
||||
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
|
||||
"""Asynchronously capture and return a single frame from the camera.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -18,7 +18,7 @@ import abc
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import draccus
|
||||
import draccus # type: ignore # TODO: add type stubs for draccus
|
||||
|
||||
|
||||
class ColorMode(str, Enum):
|
||||
@@ -34,11 +34,11 @@ class Cv2Rotation(int, Enum):
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
return str(self.get_choice_name(self.__class__))
|
||||
|
||||
@@ -14,3 +14,5 @@
|
||||
|
||||
from .camera_opencv import OpenCVCamera
|
||||
from .configuration_opencv import OpenCVCameraConfig
|
||||
|
||||
__all__ = ["OpenCVCamera", "OpenCVCameraConfig"]
|
||||
|
||||
@@ -25,11 +25,12 @@ from pathlib import Path
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
# Fix MSMF hardware transform compatibility for Windows before importing cv2
|
||||
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2
|
||||
import numpy as np
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
@@ -121,7 +122,7 @@ class OpenCVCamera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -140,7 +141,7 @@ class OpenCVCamera(Camera):
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
|
||||
|
||||
def connect(self, warmup: bool = True):
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the OpenCV camera specified in the configuration.
|
||||
|
||||
@@ -180,12 +181,14 @@ class OpenCVCamera(Camera):
|
||||
|
||||
def _configure_capture_settings(self) -> None:
|
||||
"""
|
||||
Applies the specified FPS, width, and height settings to the connected camera.
|
||||
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
|
||||
|
||||
This method attempts to set the camera properties via OpenCV. It checks if
|
||||
the camera successfully applied the settings and raises an error if not.
|
||||
FOURCC is set first (if specified) as it can affect the available FPS and resolution options.
|
||||
|
||||
Args:
|
||||
fourcc: The desired FOURCC code (e.g., "MJPG", "YUYV"). If None, auto-detect.
|
||||
fps: The desired frames per second. If None, the setting is skipped.
|
||||
width: The desired capture width. If None, the setting is skipped.
|
||||
height: The desired capture height. If None, the setting is skipped.
|
||||
@@ -199,10 +202,11 @@ class OpenCVCamera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
|
||||
|
||||
if self.fps is None:
|
||||
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
else:
|
||||
self._validate_fps()
|
||||
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
|
||||
if self.config.fourcc is not None:
|
||||
self._validate_fourcc()
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
|
||||
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
||||
@@ -216,18 +220,56 @@ class OpenCVCamera(Camera):
|
||||
else:
|
||||
self._validate_width_and_height()
|
||||
|
||||
if self.fps is None:
|
||||
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
else:
|
||||
self._validate_fps()
|
||||
|
||||
def _validate_fps(self) -> None:
|
||||
"""Validates and sets the camera's frames per second (FPS)."""
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
if self.fps is None:
|
||||
raise ValueError(f"{self} FPS is not set")
|
||||
|
||||
success = self.videocapture.set(cv2.CAP_PROP_FPS, float(self.fps))
|
||||
actual_fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
# Use math.isclose for robust float comparison
|
||||
if not success or not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
raise RuntimeError(f"{self} failed to set fps={self.fps} ({actual_fps=}).")
|
||||
|
||||
def _validate_fourcc(self) -> None:
|
||||
"""Validates and sets the camera's FOURCC code."""
|
||||
|
||||
fourcc_code = cv2.VideoWriter_fourcc(*self.config.fourcc)
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
success = self.videocapture.set(cv2.CAP_PROP_FOURCC, fourcc_code)
|
||||
actual_fourcc_code = self.videocapture.get(cv2.CAP_PROP_FOURCC)
|
||||
|
||||
# Convert actual FOURCC code back to string for comparison
|
||||
actual_fourcc_code_int = int(actual_fourcc_code)
|
||||
actual_fourcc = "".join([chr((actual_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
|
||||
|
||||
if not success or actual_fourcc != self.config.fourcc:
|
||||
logger.warning(
|
||||
f"{self} failed to set fourcc={self.config.fourcc} (actual={actual_fourcc}, success={success}). "
|
||||
f"Continuing with default format."
|
||||
)
|
||||
|
||||
def _validate_width_and_height(self) -> None:
|
||||
"""Validates and sets the camera's frame capture width and height."""
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
if self.capture_width is None or self.capture_height is None:
|
||||
raise ValueError(f"{self} capture_width or capture_height is not set")
|
||||
|
||||
width_success = self.videocapture.set(cv2.CAP_PROP_FRAME_WIDTH, float(self.capture_width))
|
||||
height_success = self.videocapture.set(cv2.CAP_PROP_FRAME_HEIGHT, float(self.capture_height))
|
||||
|
||||
@@ -258,11 +300,12 @@ class OpenCVCamera(Camera):
|
||||
"""
|
||||
found_cameras_info = []
|
||||
|
||||
targets_to_scan: list[str | int]
|
||||
if platform.system() == "Linux":
|
||||
possible_paths = sorted(Path("/dev").glob("video*"), key=lambda p: p.name)
|
||||
targets_to_scan = [str(p) for p in possible_paths]
|
||||
else:
|
||||
targets_to_scan = list(range(MAX_OPENCV_INDEX))
|
||||
targets_to_scan = [int(i) for i in range(MAX_OPENCV_INDEX)]
|
||||
|
||||
for target in targets_to_scan:
|
||||
camera = cv2.VideoCapture(target)
|
||||
@@ -271,6 +314,12 @@ class OpenCVCamera(Camera):
|
||||
default_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
default_fps = camera.get(cv2.CAP_PROP_FPS)
|
||||
default_format = camera.get(cv2.CAP_PROP_FORMAT)
|
||||
|
||||
# Get FOURCC code and convert to string
|
||||
default_fourcc_code = camera.get(cv2.CAP_PROP_FOURCC)
|
||||
default_fourcc_code_int = int(default_fourcc_code)
|
||||
default_fourcc = "".join([chr((default_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
|
||||
|
||||
camera_info = {
|
||||
"name": f"OpenCV Camera @ {target}",
|
||||
"type": "OpenCV",
|
||||
@@ -278,6 +327,7 @@ class OpenCVCamera(Camera):
|
||||
"backend_api": camera.getBackendName(),
|
||||
"default_stream_profile": {
|
||||
"format": default_format,
|
||||
"fourcc": default_fourcc,
|
||||
"width": default_width,
|
||||
"height": default_height,
|
||||
"fps": default_fps,
|
||||
@@ -289,7 +339,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return found_cameras_info
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
@@ -317,6 +367,9 @@ class OpenCVCamera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
ret, frame = self.videocapture.read()
|
||||
|
||||
if not ret or frame is None:
|
||||
@@ -329,7 +382,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return processed_frame
|
||||
|
||||
def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw frame.
|
||||
|
||||
@@ -372,7 +425,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return processed_image
|
||||
|
||||
def _read_loop(self):
|
||||
def _read_loop(self) -> None:
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
@@ -383,6 +436,9 @@ class OpenCVCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
@@ -419,7 +475,7 @@ class OpenCVCamera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
@@ -462,7 +518,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera and cleans up resources.
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ from pathlib import Path
|
||||
|
||||
from ..configs import CameraConfig, ColorMode, Cv2Rotation
|
||||
|
||||
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("opencv")
|
||||
@dataclass
|
||||
@@ -33,8 +35,9 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
OpenCVCameraConfig(0, 30, 1280, 720) # 1280x720 @ 30FPS
|
||||
OpenCVCameraConfig(/dev/video4, 60, 640, 480) # 640x480 @ 60FPS
|
||||
|
||||
# Advanced configurations
|
||||
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation
|
||||
# Advanced configurations with FOURCC format
|
||||
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90, fourcc="MJPG") # With 90° rotation and MJPG format
|
||||
OpenCVCameraConfig(0, 30, 1280, 720, fourcc="YUYV") # With YUYV format
|
||||
```
|
||||
|
||||
Attributes:
|
||||
@@ -46,17 +49,21 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
|
||||
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
|
||||
warmup_s: Time reading frames before returning from connect (in seconds)
|
||||
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
|
||||
|
||||
Note:
|
||||
- Only 3-channel color output (RGB/BGR) is currently supported.
|
||||
- FOURCC codes must be 4-character strings (e.g., "MJPG", "YUYV"). Some common FOUCC codes: https://learn.microsoft.com/en-us/windows/win32/medfound/video-fourccs#fourcc-constants
|
||||
- Setting FOURCC can help achieve higher frame rates on some cameras.
|
||||
"""
|
||||
|
||||
index_or_path: int | Path
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
|
||||
warmup_s: int = 1
|
||||
fourcc: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
@@ -71,3 +78,8 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
raise ValueError(
|
||||
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
|
||||
)
|
||||
|
||||
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
|
||||
raise ValueError(
|
||||
f"`fourcc` must be a 4-character string (e.g., 'MJPG', 'YUYV'), but '{self.fourcc}' is provided."
|
||||
)
|
||||
|
||||
@@ -16,6 +16,8 @@ from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig, ColorMode
|
||||
|
||||
__all__ = ["CameraConfig", "ColorMode", "Reachy2CameraConfig"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("reachy2_camera")
|
||||
@dataclass
|
||||
@@ -62,7 +64,7 @@ class Reachy2CameraConfig(CameraConfig):
|
||||
port: int = 50065
|
||||
# use_depth: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.name not in ["teleop", "depth"]:
|
||||
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
|
||||
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
|
||||
|
||||
@@ -23,13 +23,17 @@ import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
# Fix MSMF hardware transform compatibility for Windows before importing cv2
|
||||
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2
|
||||
import numpy as np
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
from reachy2_sdk.media.camera_manager import CameraManager
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
from reachy2_sdk.media.camera import CameraView # type: ignore # TODO: add type stubs for reachy2_sdk
|
||||
from reachy2_sdk.media.camera_manager import ( # type: ignore # TODO: add type stubs for reachy2_sdk
|
||||
CameraManager,
|
||||
)
|
||||
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
@@ -73,7 +77,7 @@ class Reachy2Camera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -83,13 +87,17 @@ class Reachy2Camera(Camera):
|
||||
def is_connected(self) -> bool:
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
if self.config.name == "teleop":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
|
||||
return bool(
|
||||
self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
|
||||
)
|
||||
elif self.config.name == "depth":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
|
||||
return bool(
|
||||
self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
|
||||
|
||||
def connect(self, warmup: bool = True):
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the Reachy2 CameraManager as specified in the configuration.
|
||||
"""
|
||||
@@ -131,7 +139,7 @@ class Reachy2Camera(Camera):
|
||||
camera_manager.disconnect()
|
||||
return initialized_cameras
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
@@ -152,7 +160,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
frame = None
|
||||
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
@@ -179,7 +187,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self):
|
||||
def _read_loop(self) -> None:
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
@@ -190,6 +198,9 @@ class Reachy2Camera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
@@ -226,7 +237,7 @@ class Reachy2Camera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
@@ -269,7 +280,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Stops the background read thread (if running).
|
||||
|
||||
|
||||
@@ -21,11 +21,12 @@ import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
try:
|
||||
import pyrealsense2 as rs
|
||||
import pyrealsense2 as rs # type: ignore # TODO: add type stubs for pyrealsense2
|
||||
except Exception as e:
|
||||
logging.info(f"Could not import realsense: {e}")
|
||||
|
||||
@@ -132,7 +133,7 @@ class RealSenseCamera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -150,7 +151,7 @@ class RealSenseCamera(Camera):
|
||||
"""Checks if the camera pipeline is started and streams are active."""
|
||||
return self.rs_pipeline is not None and self.rs_profile is not None
|
||||
|
||||
def connect(self, warmup: bool = True):
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the RealSense camera specified in the configuration.
|
||||
|
||||
@@ -264,7 +265,7 @@ class RealSenseCamera(Camera):
|
||||
serial_number = str(found_devices[0]["serial_number"])
|
||||
return serial_number
|
||||
|
||||
def _configure_rs_pipeline_config(self, rs_config):
|
||||
def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
|
||||
"""Creates and configures the RealSense pipeline configuration object."""
|
||||
rs.config.enable_device(rs_config, self.serial_number)
|
||||
|
||||
@@ -293,6 +294,9 @@ class RealSenseCamera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
|
||||
|
||||
if self.rs_profile is None:
|
||||
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
|
||||
|
||||
stream = self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile()
|
||||
|
||||
if self.fps is None:
|
||||
@@ -308,7 +312,7 @@ class RealSenseCamera(Camera):
|
||||
self.width, self.height = actual_width, actual_height
|
||||
self.capture_width, self.capture_height = actual_width, actual_height
|
||||
|
||||
def read_depth(self, timeout_ms: int = 200) -> np.ndarray:
|
||||
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (depth) synchronously from the camera.
|
||||
|
||||
@@ -336,6 +340,9 @@ class RealSenseCamera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
|
||||
if not ret or frame is None:
|
||||
@@ -351,7 +358,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return depth_map_processed
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> np.ndarray:
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (color) synchronously from the camera.
|
||||
|
||||
@@ -376,6 +383,9 @@ class RealSenseCamera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
|
||||
if not ret or frame is None:
|
||||
@@ -392,8 +402,8 @@ class RealSenseCamera(Camera):
|
||||
return color_image_processed
|
||||
|
||||
def _postprocess_image(
|
||||
self, image: np.ndarray, color_mode: ColorMode | None = None, depth_frame: bool = False
|
||||
) -> np.ndarray:
|
||||
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
|
||||
) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw color frame.
|
||||
|
||||
@@ -438,7 +448,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return processed_image
|
||||
|
||||
def _read_loop(self):
|
||||
def _read_loop(self) -> None:
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
@@ -449,6 +459,9 @@ class RealSenseCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read(timeout_ms=500)
|
||||
@@ -474,7 +487,7 @@ class RealSenseCamera(Camera):
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def _stop_read_thread(self):
|
||||
def _stop_read_thread(self) -> None:
|
||||
"""Signals the background read thread to stop and waits for it to join."""
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
@@ -486,7 +499,7 @@ class RealSenseCamera(Camera):
|
||||
self.stop_event = None
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame data (color) asynchronously.
|
||||
|
||||
@@ -529,7 +542,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ class RealSenseCameraConfig(CameraConfig):
|
||||
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
|
||||
warmup_s: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
|
||||
@@ -15,15 +15,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
import platform
|
||||
from typing import cast
|
||||
|
||||
from lerobot.utils.import_utils import make_device_from_device_class
|
||||
|
||||
from .camera import Camera
|
||||
from .configs import CameraConfig, Cv2Rotation
|
||||
|
||||
|
||||
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
|
||||
cameras = {}
|
||||
cameras: dict[str, Camera] = {}
|
||||
|
||||
for key, cfg in camera_configs.items():
|
||||
# TODO(Steven): Consider just using the make_device_from_device_class for all types
|
||||
if cfg.type == "opencv":
|
||||
from .opencv import OpenCVCamera
|
||||
|
||||
@@ -40,20 +44,23 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
|
||||
cameras[key] = Reachy2Camera(cfg)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
|
||||
try:
|
||||
cameras[key] = cast(Camera, make_device_from_device_class(cfg))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error creating camera {key} with config {cfg}: {e}") from e
|
||||
|
||||
return cameras
|
||||
|
||||
|
||||
def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
|
||||
import cv2
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
|
||||
if rotation == Cv2Rotation.ROTATE_90:
|
||||
return cv2.ROTATE_90_CLOCKWISE
|
||||
return int(cv2.ROTATE_90_CLOCKWISE)
|
||||
elif rotation == Cv2Rotation.ROTATE_180:
|
||||
return cv2.ROTATE_180
|
||||
return int(cv2.ROTATE_180)
|
||||
elif rotation == Cv2Rotation.ROTATE_270:
|
||||
return cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -62,8 +69,8 @@ def get_cv2_backend() -> int:
|
||||
import cv2
|
||||
|
||||
if platform.system() == "Windows":
|
||||
return cv2.CAP_MSMF # Use MSMF for Windows instead of AVFOUNDATION
|
||||
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
|
||||
# elif platform.system() == "Darwin": # macOS
|
||||
# return cv2.CAP_AVFOUNDATION
|
||||
else: # Linux and others
|
||||
return cv2.CAP_ANY
|
||||
return int(cv2.CAP_ANY)
|
||||
|
||||
@@ -57,7 +57,7 @@ class EvalConfig:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
use_async_envs: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.batch_size > self.n_episodes:
|
||||
raise ValueError(
|
||||
"The eval batch size is greater than the number of eval episodes "
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import datetime as dt
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot import envs, policies # noqa: F401
|
||||
@@ -22,6 +22,8 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.default import EvalConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalPipelineConfig:
|
||||
@@ -34,25 +36,31 @@ class EvalPipelineConfig:
|
||||
output_dir: Path | None = None
|
||||
job_name: str | None = None
|
||||
seed: int | None = 1000
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
|
||||
else:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{self.policy.type}"
|
||||
self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}"
|
||||
else:
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
self.job_name = (
|
||||
f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}"
|
||||
)
|
||||
|
||||
logger.warning(f"No job name provided, using '{self.job_name}' as job name.")
|
||||
|
||||
if not self.output_dir:
|
||||
now = dt.datetime.now()
|
||||
|
||||
@@ -16,14 +16,19 @@ import inspect
|
||||
import pkgutil
|
||||
import sys
|
||||
from argparse import ArgumentError
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from pkgutil import ModuleInfo
|
||||
from types import ModuleType
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.utils.utils import has_method
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
PATH_KEY = "path"
|
||||
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||
|
||||
@@ -60,7 +65,7 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
|
||||
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict[str, str]:
|
||||
"""Parse plugin-related arguments from command-line arguments.
|
||||
|
||||
This function extracts arguments from command-line arguments that match a specified suffix pattern.
|
||||
@@ -127,7 +132,7 @@ def load_plugin(plugin_path: str) -> None:
|
||||
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
|
||||
) from e
|
||||
|
||||
def iter_namespace(ns_pkg):
|
||||
def iter_namespace(ns_pkg: ModuleType) -> Iterable[ModuleInfo]:
|
||||
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
|
||||
|
||||
try:
|
||||
@@ -148,6 +153,8 @@ def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | No
|
||||
|
||||
|
||||
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
|
||||
if args is None:
|
||||
return []
|
||||
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
||||
|
||||
|
||||
@@ -171,7 +178,8 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
if isinstance(fields_to_filter, str):
|
||||
fields_to_filter = [fields_to_filter]
|
||||
|
||||
filtered_args = args
|
||||
filtered_args = [] if args is None else list(args)
|
||||
|
||||
for field in fields_to_filter:
|
||||
if get_path_arg(field, args):
|
||||
if get_type_arg(field, args):
|
||||
@@ -184,7 +192,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
return filtered_args
|
||||
|
||||
|
||||
def wrap(config_path: Path | None = None):
|
||||
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
"""
|
||||
HACK: Similar to draccus.wrap but does three additional things:
|
||||
- Will remove '.path' arguments from CLI in order to process them later on.
|
||||
@@ -195,9 +203,9 @@ def wrap(config_path: Path | None = None):
|
||||
from the CLI '.type' arguments
|
||||
"""
|
||||
|
||||
def wrapper_outer(fn):
|
||||
def wrapper_outer(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper_inner(*args, **kwargs):
|
||||
def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
|
||||
argspec = inspect.getfullargspec(fn)
|
||||
argtype = argspec.annotations[argspec.args[0]]
|
||||
if len(args) > 0 and type(args[0]) is argtype:
|
||||
@@ -225,6 +233,6 @@ def wrap(config_path: Path | None = None):
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
return wrapper_inner
|
||||
return cast(F, wrapper_inner)
|
||||
|
||||
return wrapper_outer
|
||||
return cast(Callable[[F], F], wrapper_outer)
|
||||
|
||||
@@ -14,12 +14,12 @@
|
||||
import abc
|
||||
import builtins
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -34,10 +34,11 @@ from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedConfig")
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: ignore[misc,name-defined] #TODO: draccus issue
|
||||
"""
|
||||
Base configuration class for policy models.
|
||||
|
||||
@@ -57,12 +58,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
device: str | None = None # cuda | cpu | mp
|
||||
device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps"
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool = False
|
||||
|
||||
push_to_hub: bool = True
|
||||
push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
|
||||
repo_id: str | None = None
|
||||
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
@@ -71,38 +72,43 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
tags: list[str] | None = None
|
||||
# Add tags to your policy on the hub.
|
||||
license: str | None = None
|
||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||
pretrained_path: Path | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.pretrained_path = None
|
||||
def __post_init__(self) -> None:
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device.type
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
if self.use_amp and not is_amp_available(self.device):
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
choice_name = self.get_choice_name(self.__class__)
|
||||
if not isinstance(choice_name, str):
|
||||
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||
return choice_name
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_delta_indices(self) -> list | None:
|
||||
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -152,13 +158,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = None,
|
||||
proxies: dict | None = None,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**policy_kwargs,
|
||||
**policy_kwargs: Any,
|
||||
) -> T:
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
@@ -166,7 +172,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
if CONFIG_NAME in os.listdir(model_id):
|
||||
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||
else:
|
||||
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
@@ -192,6 +198,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
with draccus.config_type("json"):
|
||||
orig_config = draccus.parse(cls, config_file, args=[])
|
||||
|
||||
if config_file is None:
|
||||
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
|
||||
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import datetime as dt
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -63,18 +64,18 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
self.checkpoint_path = None
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
# Only load the policy config
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
elif self.resume:
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
config_path = parser.parse_arg("config_path")
|
||||
@@ -82,14 +83,22 @@ class TrainPipelineConfig(HubMixin):
|
||||
raise ValueError(
|
||||
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
||||
)
|
||||
|
||||
if not Path(config_path).resolve().exists():
|
||||
raise NotADirectoryError(
|
||||
f"{config_path=} is expected to be a local path. "
|
||||
"Resuming from the hub is not supported for now."
|
||||
)
|
||||
policy_path = Path(config_path).parent
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.checkpoint_path = policy_path.parent
|
||||
|
||||
policy_dir = Path(config_path).parent
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
|
||||
if self.policy is None:
|
||||
raise ValueError(
|
||||
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
@@ -126,8 +135,8 @@ class TrainPipelineConfig(HubMixin):
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return draccus.encode(self)
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
@@ -139,13 +148,13 @@ class TrainPipelineConfig(HubMixin):
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = None,
|
||||
proxies: dict | None = None,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> "TrainPipelineConfig":
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
@@ -181,4 +190,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
@@ -35,9 +35,11 @@ class NormalizationMode(str, Enum):
|
||||
MIN_MAX = "MIN_MAX"
|
||||
MEAN_STD = "MEAN_STD"
|
||||
IDENTITY = "IDENTITY"
|
||||
QUANTILES = "QUANTILES"
|
||||
QUANTILE10 = "QUANTILE10"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyFeature:
|
||||
type: FeatureType
|
||||
shape: tuple
|
||||
shape: tuple[int, ...]
|
||||
|
||||
@@ -31,15 +31,15 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_video_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
|
||||
|
||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
@@ -130,10 +130,34 @@ def update_meta_data(
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
for key, video_idx in videos_idx.items():
|
||||
df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"]
|
||||
df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"]
|
||||
df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||
df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
|
||||
# Store original video file indices before updating
|
||||
orig_chunk_col = f"videos/{key}/chunk_index"
|
||||
orig_file_col = f"videos/{key}/file_index"
|
||||
df["_orig_chunk"] = df[orig_chunk_col].copy()
|
||||
df["_orig_file"] = df[orig_file_col].copy()
|
||||
|
||||
# Update chunk and file indices to point to destination
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
|
||||
# Apply per-source-file timestamp offsets
|
||||
src_to_offset = video_idx.get("src_to_offset", {})
|
||||
if src_to_offset:
|
||||
# Apply offset based on original source file
|
||||
for idx in df.index:
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||
else:
|
||||
# Fallback to simple offset (for backward compatibility)
|
||||
df[f"videos/{key}/from_timestamp"] = (
|
||||
df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||
)
|
||||
df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
|
||||
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
@@ -193,6 +217,10 @@ def aggregate_datasets(
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
root=aggr_root,
|
||||
use_videos=len(video_keys) > 0,
|
||||
chunks_size=chunk_size,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
video_files_size_in_mb=video_files_size_in_mb,
|
||||
)
|
||||
|
||||
logging.info("Find all tasks")
|
||||
@@ -236,6 +264,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
Returns:
|
||||
dict: Updated videos_idx with current chunk and file indices.
|
||||
"""
|
||||
for key in videos_idx:
|
||||
videos_idx[key]["episode_duration"] = 0
|
||||
# Track offset for each source (chunk, file) pair
|
||||
videos_idx[key]["src_to_offset"] = {}
|
||||
|
||||
for key, video_idx in videos_idx.items():
|
||||
unique_chunk_file_pairs = {
|
||||
(chunk, file)
|
||||
@@ -249,6 +282,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
|
||||
chunk_idx = video_idx["chunk"]
|
||||
file_idx = video_idx["file"]
|
||||
current_offset = video_idx["latest_duration"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
@@ -263,21 +297,25 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
# If a new file is created, we don't want to increment the latest_duration
|
||||
update_latest_duration = False
|
||||
src_duration = get_video_duration_in_s(src_path)
|
||||
|
||||
if not dst_path.exists():
|
||||
# First write to this destination file
|
||||
# Store offset before incrementing
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
continue # not accumulating further, already copied the file in place
|
||||
videos_idx[key]["episode_duration"] += src_duration
|
||||
current_offset += src_duration
|
||||
continue
|
||||
|
||||
# Check file sizes before appending
|
||||
src_size = get_video_size_in_mb(src_path)
|
||||
dst_size = get_video_size_in_mb(dst_path)
|
||||
src_size = get_file_size_in_mb(src_path)
|
||||
dst_size = get_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= video_files_size_in_mb:
|
||||
# Rotate to a new chunk/file
|
||||
# Rotate to a new file, this source becomes start of new destination
|
||||
# So its offset should be 0
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
@@ -286,25 +324,22 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
# Reset offset for next file
|
||||
current_offset = src_duration
|
||||
else:
|
||||
# Get the timestamps shift for this video
|
||||
timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"]
|
||||
|
||||
# Append to existing video file
|
||||
# Append to existing video file - use current accumulated offset
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
concatenate_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
)
|
||||
# Update the latest_duration when appending (shifts timestamps!)
|
||||
update_latest_duration = not update_latest_duration
|
||||
current_offset += src_duration
|
||||
|
||||
videos_idx[key]["episode_duration"] += src_duration
|
||||
|
||||
# Update the videos_idx with the final chunk and file indices for this key
|
||||
videos_idx[key]["chunk"] = chunk_idx
|
||||
videos_idx[key]["file"] = file_idx
|
||||
|
||||
if update_latest_duration:
|
||||
videos_idx[key]["latest_duration"] += timestamps_shift_s
|
||||
|
||||
return videos_idx
|
||||
|
||||
|
||||
@@ -389,9 +424,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
videos_idx,
|
||||
)
|
||||
|
||||
for k in videos_idx:
|
||||
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||
|
||||
meta_idx = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
@@ -403,6 +435,10 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
aggr_root=dst_meta.root,
|
||||
)
|
||||
|
||||
# Increment latest_duration by the total duration added from this source dataset
|
||||
for k in videos_idx:
|
||||
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||
|
||||
return meta_idx
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,9 @@ Please, update your dataset to the new format using this command:
|
||||
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id}
|
||||
```
|
||||
|
||||
If you already have a converted version uploaded to the hub, then this error might be because of
|
||||
an older version in your local cache. Consider deleting the cached version and retrying.
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
@@ -17,6 +17,179 @@ import numpy as np
|
||||
|
||||
from lerobot.datasets.utils import load_image_as_numpy
|
||||
|
||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||
|
||||
|
||||
class RunningQuantileStats:
|
||||
"""
|
||||
Maintains running statistics for batches of vectors, including mean,
|
||||
standard deviation, min, max, and approximate quantiles.
|
||||
|
||||
Statistics are computed per feature dimension and updated incrementally
|
||||
as new batches are observed. Quantiles are estimated using histograms,
|
||||
which adapt dynamically if the observed data range expands.
|
||||
"""
|
||||
|
||||
def __init__(self, quantile_list: list[float] | None = None, num_quantile_bins: int = 5000):
|
||||
self._count = 0
|
||||
self._mean = None
|
||||
self._mean_of_squares = None
|
||||
self._min = None
|
||||
self._max = None
|
||||
self._histograms = None
|
||||
self._bin_edges = None
|
||||
self._num_quantile_bins = num_quantile_bins
|
||||
|
||||
self._quantile_list = quantile_list
|
||||
if self._quantile_list is None:
|
||||
self._quantile_list = DEFAULT_QUANTILES
|
||||
self._quantile_keys = [f"q{int(q * 100):02d}" for q in self._quantile_list]
|
||||
|
||||
def update(self, batch: np.ndarray) -> None:
|
||||
"""Update the running statistics with a batch of vectors.
|
||||
|
||||
Args:
|
||||
batch: An array where all dimensions except the last are batch dimensions.
|
||||
"""
|
||||
batch = batch.reshape(-1, batch.shape[-1])
|
||||
num_elements, vector_length = batch.shape
|
||||
|
||||
if self._count == 0:
|
||||
self._mean = np.mean(batch, axis=0)
|
||||
self._mean_of_squares = np.mean(batch**2, axis=0)
|
||||
self._min = np.min(batch, axis=0)
|
||||
self._max = np.max(batch, axis=0)
|
||||
self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
|
||||
self._bin_edges = [
|
||||
np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
|
||||
for i in range(vector_length)
|
||||
]
|
||||
else:
|
||||
if vector_length != self._mean.size:
|
||||
raise ValueError("The length of new vectors does not match the initialized vector length.")
|
||||
|
||||
new_max = np.max(batch, axis=0)
|
||||
new_min = np.min(batch, axis=0)
|
||||
max_changed = np.any(new_max > self._max)
|
||||
min_changed = np.any(new_min < self._min)
|
||||
self._max = np.maximum(self._max, new_max)
|
||||
self._min = np.minimum(self._min, new_min)
|
||||
|
||||
if max_changed or min_changed:
|
||||
self._adjust_histograms()
|
||||
|
||||
self._count += num_elements
|
||||
|
||||
batch_mean = np.mean(batch, axis=0)
|
||||
batch_mean_of_squares = np.mean(batch**2, axis=0)
|
||||
|
||||
# Update running mean and mean of squares
|
||||
self._mean += (batch_mean - self._mean) * (num_elements / self._count)
|
||||
self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (
|
||||
num_elements / self._count
|
||||
)
|
||||
|
||||
self._update_histograms(batch)
|
||||
|
||||
def get_statistics(self) -> dict[str, np.ndarray]:
|
||||
"""Compute and return the statistics of the vectors processed so far.
|
||||
|
||||
Args:
|
||||
quantiles: List of quantiles to compute (e.g., [0.01, 0.10, 0.50, 0.90, 0.99]). If None, no quantiles computed.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the computed statistics.
|
||||
"""
|
||||
if self._count < 2:
|
||||
raise ValueError("Cannot compute statistics for less than 2 vectors.")
|
||||
|
||||
variance = self._mean_of_squares - self._mean**2
|
||||
|
||||
stddev = np.sqrt(np.maximum(0, variance))
|
||||
|
||||
stats = {
|
||||
"min": self._min.copy(),
|
||||
"max": self._max.copy(),
|
||||
"mean": self._mean.copy(),
|
||||
"std": stddev,
|
||||
"count": np.array([self._count]),
|
||||
}
|
||||
|
||||
quantile_results = self._compute_quantiles()
|
||||
for i, q in enumerate(self._quantile_keys):
|
||||
stats[q] = quantile_results[i]
|
||||
|
||||
return stats
|
||||
|
||||
def _adjust_histograms(self):
|
||||
"""Adjust histograms when min or max changes."""
|
||||
for i in range(len(self._histograms)):
|
||||
old_edges = self._bin_edges[i]
|
||||
old_hist = self._histograms[i]
|
||||
|
||||
# Create new edges with small padding to ensure range coverage
|
||||
padding = (self._max[i] - self._min[i]) * 1e-10
|
||||
new_edges = np.linspace(
|
||||
self._min[i] - padding, self._max[i] + padding, self._num_quantile_bins + 1
|
||||
)
|
||||
|
||||
# Redistribute existing histogram counts to new bins
|
||||
# We need to map each old bin center to the new bins
|
||||
old_centers = (old_edges[:-1] + old_edges[1:]) / 2
|
||||
new_hist = np.zeros(self._num_quantile_bins)
|
||||
|
||||
for old_center, count in zip(old_centers, old_hist, strict=False):
|
||||
if count > 0:
|
||||
# Find which new bin this old center belongs to
|
||||
bin_idx = np.searchsorted(new_edges, old_center) - 1
|
||||
bin_idx = max(0, min(bin_idx, self._num_quantile_bins - 1))
|
||||
new_hist[bin_idx] += count
|
||||
|
||||
self._histograms[i] = new_hist
|
||||
self._bin_edges[i] = new_edges
|
||||
|
||||
def _update_histograms(self, batch: np.ndarray) -> None:
|
||||
"""Update histograms with new vectors."""
|
||||
for i in range(batch.shape[1]):
|
||||
hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
|
||||
self._histograms[i] += hist
|
||||
|
||||
def _compute_quantiles(self) -> list[np.ndarray]:
|
||||
"""Compute quantiles based on histograms."""
|
||||
results = []
|
||||
for q in self._quantile_list:
|
||||
target_count = q * self._count
|
||||
q_values = []
|
||||
|
||||
for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
|
||||
q_value = self._compute_single_quantile(hist, edges, target_count)
|
||||
q_values.append(q_value)
|
||||
|
||||
results.append(np.array(q_values))
|
||||
return results
|
||||
|
||||
def _compute_single_quantile(self, hist: np.ndarray, edges: np.ndarray, target_count: float) -> float:
|
||||
"""Compute a single quantile value from histogram and bin edges."""
|
||||
cumsum = np.cumsum(hist)
|
||||
idx = np.searchsorted(cumsum, target_count)
|
||||
|
||||
if idx == 0:
|
||||
return edges[0]
|
||||
if idx >= len(cumsum):
|
||||
return edges[-1]
|
||||
|
||||
# If not edge case, interpolate within the bin
|
||||
count_before = cumsum[idx - 1]
|
||||
count_in_bin = cumsum[idx] - count_before
|
||||
|
||||
# If no samples in this bin, use the bin edge
|
||||
if count_in_bin == 0:
|
||||
return edges[idx]
|
||||
|
||||
# Linear interpolation within the bin
|
||||
fraction = (target_count - count_before) / count_in_bin
|
||||
return edges[idx] + fraction * (edges[idx + 1] - edges[idx])
|
||||
|
||||
|
||||
def estimate_num_samples(
|
||||
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||
@@ -72,33 +245,282 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
return images
|
||||
|
||||
|
||||
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
||||
return {
|
||||
"min": np.min(array, axis=axis, keepdims=keepdims),
|
||||
"max": np.max(array, axis=axis, keepdims=keepdims),
|
||||
"mean": np.mean(array, axis=axis, keepdims=keepdims),
|
||||
"std": np.std(array, axis=axis, keepdims=keepdims),
|
||||
"count": np.array([len(array)]),
|
||||
def _reshape_stats_by_axis(
|
||||
stats: dict[str, np.ndarray],
|
||||
axis: int | tuple[int, ...] | None,
|
||||
keepdims: bool,
|
||||
original_shape: tuple[int, ...],
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Reshape all statistics to match NumPy's output conventions.
|
||||
|
||||
Applies consistent reshaping to all statistics (except 'count') based on the
|
||||
axis and keepdims parameters. This ensures statistics have the correct shape
|
||||
for broadcasting with the original data.
|
||||
|
||||
Args:
|
||||
stats: Dictionary of computed statistics
|
||||
axis: Axis or axes along which statistics were computed
|
||||
keepdims: Whether to keep reduced dimensions as size-1 dimensions
|
||||
original_shape: Shape of the original array
|
||||
|
||||
Returns:
|
||||
Dictionary with reshaped statistics
|
||||
|
||||
Note:
|
||||
The 'count' statistic is never reshaped as it represents metadata
|
||||
rather than per-feature statistics.
|
||||
"""
|
||||
if axis == (1,) and not keepdims:
|
||||
return stats
|
||||
|
||||
result = {}
|
||||
for key, value in stats.items():
|
||||
if key == "count":
|
||||
result[key] = value
|
||||
else:
|
||||
result[key] = _reshape_single_stat(value, axis, keepdims, original_shape)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _reshape_for_image_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
|
||||
"""Reshape statistics for image data (axis=(0,2,3))."""
|
||||
if keepdims and value.ndim == 1:
|
||||
return value.reshape(1, -1, 1, 1)
|
||||
return value
|
||||
|
||||
|
||||
def _reshape_for_vector_stats(
|
||||
value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
|
||||
) -> np.ndarray:
|
||||
"""Reshape statistics for vector data (axis=0 or axis=(0,))."""
|
||||
if not keepdims:
|
||||
return value
|
||||
|
||||
if len(original_shape) == 1 and value.ndim > 0:
|
||||
return value.reshape(1)
|
||||
elif len(original_shape) >= 2 and value.ndim == 1:
|
||||
return value.reshape(1, -1)
|
||||
return value
|
||||
|
||||
|
||||
def _reshape_for_feature_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
|
||||
"""Reshape statistics for feature-wise computation (axis=(1,))."""
|
||||
if not keepdims:
|
||||
return value
|
||||
|
||||
if value.ndim == 0:
|
||||
return value.reshape(1, 1)
|
||||
elif value.ndim == 1:
|
||||
return value.reshape(-1, 1)
|
||||
return value
|
||||
|
||||
|
||||
def _reshape_for_global_stats(
|
||||
value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
|
||||
) -> np.ndarray | float:
|
||||
"""Reshape statistics for global reduction (axis=None)."""
|
||||
if keepdims:
|
||||
target_shape = tuple(1 for _ in original_shape)
|
||||
return value.reshape(target_shape)
|
||||
# Keep at least 1-D arrays to satisfy validator
|
||||
return np.atleast_1d(value)
|
||||
|
||||
|
||||
def _reshape_single_stat(
|
||||
value: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...]
|
||||
) -> np.ndarray | float:
|
||||
"""Apply appropriate reshaping to a single statistic array.
|
||||
|
||||
This function transforms statistic arrays to match expected output shapes
|
||||
based on the axis configuration and keepdims parameter.
|
||||
|
||||
Args:
|
||||
value: The statistic array to reshape
|
||||
axis: Axis or axes that were reduced during computation
|
||||
keepdims: Whether to maintain reduced dimensions as size-1 dimensions
|
||||
original_shape: Shape of the original data before reduction
|
||||
|
||||
Returns:
|
||||
Reshaped array following NumPy broadcasting conventions
|
||||
|
||||
"""
|
||||
if axis == (0, 2, 3):
|
||||
return _reshape_for_image_stats(value, keepdims)
|
||||
|
||||
if axis in [0, (0,)]:
|
||||
return _reshape_for_vector_stats(value, keepdims, original_shape)
|
||||
|
||||
if axis == (1,):
|
||||
return _reshape_for_feature_stats(value, keepdims)
|
||||
|
||||
if axis is None:
|
||||
return _reshape_for_global_stats(value, keepdims, original_shape)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _prepare_array_for_stats(array: np.ndarray, axis: int | tuple[int, ...] | None) -> tuple[np.ndarray, int]:
|
||||
"""Prepare array for statistics computation by reshaping according to axis.
|
||||
|
||||
Args:
|
||||
array: Input data array
|
||||
axis: Axis or axes along which to compute statistics
|
||||
|
||||
Returns:
|
||||
Tuple of (reshaped_array, sample_count)
|
||||
"""
|
||||
if axis == (0, 2, 3): # Image data
|
||||
batch_size, channels, height, width = array.shape
|
||||
reshaped = array.transpose(0, 2, 3, 1).reshape(-1, channels)
|
||||
return reshaped, batch_size
|
||||
|
||||
if axis == 0 or axis == (0,): # Vector data
|
||||
reshaped = array
|
||||
if array.ndim == 1:
|
||||
reshaped = array.reshape(-1, 1)
|
||||
return reshaped, array.shape[0]
|
||||
|
||||
if axis == (1,): # Feature-wise statistics
|
||||
return array.T, array.shape[1]
|
||||
|
||||
if axis is None: # Global statistics
|
||||
reshaped = array.reshape(-1, 1)
|
||||
# For backward compatibility, count represents the first dimension size
|
||||
return reshaped, array.shape[0] if array.ndim > 0 else 1
|
||||
|
||||
raise ValueError(f"Unsupported axis configuration: {axis}")
|
||||
|
||||
|
||||
def _compute_basic_stats(
|
||||
array: np.ndarray, sample_count: int, quantile_list: list[float] | None = None
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Compute basic statistics for arrays with insufficient samples for quantiles.
|
||||
|
||||
Args:
|
||||
array: Reshaped array ready for statistics computation
|
||||
sample_count: Number of samples represented in the data
|
||||
|
||||
Returns:
|
||||
Dictionary with basic statistics and quantiles set to mean values
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
quantile_list_keys = [f"q{int(q * 100):02d}" for q in quantile_list]
|
||||
|
||||
stats = {
|
||||
"min": np.min(array, axis=0),
|
||||
"max": np.max(array, axis=0),
|
||||
"mean": np.mean(array, axis=0),
|
||||
"std": np.std(array, axis=0),
|
||||
"count": np.array([sample_count]),
|
||||
}
|
||||
|
||||
for q in quantile_list_keys:
|
||||
stats[q] = stats["mean"].copy()
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def get_feature_stats(
|
||||
array: np.ndarray,
|
||||
axis: int | tuple[int, ...] | None,
|
||||
keepdims: bool,
|
||||
quantile_list: list[float] | None = None,
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Compute comprehensive statistics for array features along specified axes.
|
||||
|
||||
This function calculates min, max, mean, std, and quantiles (1%, 10%, 50%, 90%, 99%)
|
||||
for the input array along the specified axes. It handles different data layouts:
|
||||
- Image data: axis=(0,2,3) computes per-channel statistics
|
||||
- Vector data: axis=0 computes per-feature statistics
|
||||
- Feature-wise: axis=1 computes statistics across features
|
||||
- Global: axis=None computes statistics over entire array
|
||||
|
||||
Args:
|
||||
array: Input data array with shape appropriate for the specified axis
|
||||
axis: Axis or axes along which to compute statistics
|
||||
- (0, 2, 3): For image data (batch, channels, height, width)
|
||||
- 0 or (0,): For vector/tabular data (samples, features)
|
||||
- (1,): For computing across features
|
||||
- None: For global statistics over entire array
|
||||
keepdims: If True, reduced axes are kept as dimensions with size 1
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- 'min': Minimum values
|
||||
- 'max': Maximum values
|
||||
- 'mean': Mean values
|
||||
- 'std': Standard deviation
|
||||
- 'count': Number of samples (always shape (1,))
|
||||
- 'q01', 'q10', 'q50', 'q90', 'q99': Quantile values
|
||||
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
|
||||
original_shape = array.shape
|
||||
reshaped, sample_count = _prepare_array_for_stats(array, axis)
|
||||
|
||||
if reshaped.shape[0] < 2:
|
||||
stats = _compute_basic_stats(reshaped, sample_count, quantile_list)
|
||||
else:
|
||||
running_stats = RunningQuantileStats()
|
||||
running_stats.update(reshaped)
|
||||
stats = running_stats.get_statistics()
|
||||
stats["count"] = np.array([sample_count])
|
||||
|
||||
stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
|
||||
return stats
|
||||
|
||||
|
||||
def compute_episode_stats(
|
||||
episode_data: dict[str, list[str] | np.ndarray],
|
||||
features: dict,
|
||||
quantile_list: list[float] | None = None,
|
||||
) -> dict:
|
||||
"""Compute comprehensive statistics for all features in an episode.
|
||||
|
||||
Processes different data types appropriately:
|
||||
- Images/videos: Samples from paths, computes per-channel stats, normalizes to [0,1]
|
||||
- Numerical arrays: Computes per-feature statistics
|
||||
- Strings: Skipped (no statistics computed)
|
||||
|
||||
Args:
|
||||
episode_data: Dictionary mapping feature names to data
|
||||
- For images/videos: list of file paths
|
||||
- For numerical data: numpy arrays
|
||||
features: Dictionary describing each feature's dtype and shape
|
||||
|
||||
Returns:
|
||||
Dictionary mapping feature names to their statistics dictionaries.
|
||||
Each statistics dictionary contains min, max, mean, std, count, and quantiles.
|
||||
|
||||
Note:
|
||||
Image statistics are normalized to [0,1] range and have shape (3,1,1) for
|
||||
per-channel values when dtype is 'image' or 'video'.
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
|
||||
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
continue # HACK: we should receive np.arrays of strings
|
||||
elif features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data) # data is a list of image paths
|
||||
axes_to_reduce = (0, 2, 3) # keep channel dim
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data)
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
keepdims = True
|
||||
else:
|
||||
ep_ft_array = data # data is already a np.ndarray
|
||||
axes_to_reduce = 0 # compute stats over the first axis
|
||||
keepdims = data.ndim == 1 # keep as np.array
|
||||
ep_ft_array = data
|
||||
axes_to_reduce = 0
|
||||
keepdims = data.ndim == 1
|
||||
|
||||
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
||||
ep_stats[key] = get_feature_stats(
|
||||
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list
|
||||
)
|
||||
|
||||
# finally, we normalize and remove batch dim for images
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
@@ -107,20 +529,37 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
|
||||
return ep_stats
|
||||
|
||||
|
||||
def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
|
||||
"""Validate a single statistic value."""
|
||||
if not isinstance(value, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Stats must be composed of numpy array, but key '{key}' of feature '{feature_key}' "
|
||||
f"is of type '{type(value)}' instead."
|
||||
)
|
||||
|
||||
if value.ndim == 0:
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
|
||||
if key == "count" and value.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
|
||||
|
||||
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
|
||||
|
||||
|
||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
for i in range(len(stats_list)):
|
||||
for fkey in stats_list[i]:
|
||||
for k, v in stats_list[i][fkey].items():
|
||||
if not isinstance(v, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
|
||||
)
|
||||
if v.ndim == 0:
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
if k == "count" and v.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
|
||||
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
||||
"""Validate that all statistics have correct types and shapes.
|
||||
|
||||
Args:
|
||||
stats_list: List of statistics dictionaries to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If any statistic has incorrect type or shape
|
||||
"""
|
||||
for stats in stats_list:
|
||||
for feature_key, feature_stats in stats.items():
|
||||
for stat_key, stat_value in feature_stats.items():
|
||||
_validate_stat_value(stat_value, stat_key, feature_key)
|
||||
|
||||
|
||||
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
@@ -143,7 +582,7 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
|
||||
weighted_variances = (variances + delta_means**2) * counts
|
||||
total_variance = weighted_variances.sum(axis=0) / total_count
|
||||
|
||||
return {
|
||||
aggregated = {
|
||||
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
|
||||
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
|
||||
"mean": total_mean,
|
||||
@@ -151,6 +590,17 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
|
||||
"count": total_count,
|
||||
}
|
||||
|
||||
if stats_ft_list:
|
||||
quantile_keys = [k for k in stats_ft_list[0] if k.startswith("q") and k[1:].isdigit()]
|
||||
|
||||
for q_key in quantile_keys:
|
||||
if all(q_key in s for s in stats_ft_list):
|
||||
quantile_values = np.stack([s[q_key] for s in stats_ft_list])
|
||||
weighted_quantiles = quantile_values * counts
|
||||
aggregated[q_key] = weighted_quantiles.sum(axis=0) / total_count
|
||||
|
||||
return aggregated
|
||||
|
||||
|
||||
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user