mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
Compare commits
86 Commits
feat/add_m
...
feat/bilat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
020fc12ead | ||
|
|
f816092993 | ||
|
|
1753235a61 | ||
|
|
739aaa8edd | ||
|
|
15678bd51a | ||
|
|
d72b4fe056 | ||
|
|
f9fd0fb841 | ||
|
|
6cf4555081 | ||
|
|
5ec2615b21 | ||
|
|
65c11eb5e6 | ||
|
|
7621acf776 | ||
|
|
3b33f9e34c | ||
|
|
7157794f58 | ||
|
|
88bc763033 | ||
|
|
64172756a7 | ||
|
|
3cd10d3560 | ||
|
|
dc69ae3fc0 | ||
|
|
bb0175e05e | ||
|
|
cff530a17a | ||
|
|
746336f9c8 | ||
|
|
e48d8babe0 | ||
|
|
da71b233be | ||
|
|
485aa2332c | ||
|
|
0bd16432bc | ||
|
|
5ab6505ea8 | ||
|
|
5170862d23 | ||
|
|
101fb02697 | ||
|
|
0664addec1 | ||
|
|
a7391e82c7 | ||
|
|
3521dd93c1 | ||
|
|
6288439d48 | ||
|
|
1cf768e17a | ||
|
|
d11ec6b5ef | ||
|
|
c75455a6de | ||
|
|
f25ac02e6c | ||
|
|
23cb668cac | ||
|
|
2ea3043b1b | ||
|
|
0f61e2415f | ||
|
|
76a425c600 | ||
|
|
df71f3ce24 | ||
|
|
326aca0a48 | ||
|
|
be46bdea8f | ||
|
|
306429a85b | ||
|
|
12f2f35760 | ||
|
|
a024d33750 | ||
|
|
63cd2111ad | ||
|
|
abe9e79825 | ||
|
|
503fc4e9f4 | ||
|
|
92b479f9ac | ||
|
|
b954337ac7 | ||
|
|
5f6f476f32 | ||
|
|
502fdc0630 | ||
|
|
9db6213895 | ||
|
|
aa1d906802 | ||
|
|
eff8a6fd12 | ||
|
|
c54cd529a2 | ||
|
|
a5ca206c49 | ||
|
|
88100943ef | ||
|
|
a95b15ccc0 | ||
|
|
a97d078d95 | ||
|
|
98662e5f24 | ||
|
|
4d8f242af9 | ||
|
|
1ff8986c77 | ||
|
|
f0aeded142 | ||
|
|
da5d2f3e91 | ||
|
|
d6ea3bbce0 | ||
|
|
7aedbbf81a | ||
|
|
1ee8d824f5 | ||
|
|
f7c4f99545 | ||
|
|
92b6254473 | ||
|
|
79137f58d1 | ||
|
|
da9c2e66f4 | ||
|
|
45730cc71e | ||
|
|
5d4af4b0b1 | ||
|
|
0050d7c61c | ||
|
|
cf2897f545 | ||
|
|
2c18210d02 | ||
|
|
44bf283701 | ||
|
|
a51682b266 | ||
|
|
ed49c9935a | ||
|
|
52455d03a7 | ||
|
|
4afb253825 | ||
|
|
96c664e09f | ||
|
|
8bd0aec618 | ||
|
|
e82e7a02e9 | ||
|
|
845b359d39 |
25
.github/workflows/fast_tests.yml
vendored
25
.github/workflows/fast_tests.yml
vendored
@@ -57,11 +57,7 @@ jobs:
|
||||
# It runs everytime we commit to a PR or push to main
|
||||
fast-pytest-tests:
|
||||
name: Fast Pytest Tests
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
@@ -71,21 +67,12 @@ jobs:
|
||||
lfs: true
|
||||
|
||||
# TODO(Steven): Evaluate the need of these dependencies
|
||||
- name: Install dependencies
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
|
||||
brew update && brew install git geos portaudio ffmpeg@7
|
||||
# Add ffmpeg@7 paths for subsequent steps
|
||||
echo "PATH=/opt/homebrew/opt/ffmpeg@7/bin:$PATH" >> $GITHUB_ENV
|
||||
echo "LDFLAGS=-L/opt/homebrew/opt/ffmpeg@7/lib" >> $GITHUB_ENV
|
||||
echo "CPPFLAGS=-I/opt/homebrew/opt/ffmpeg@7/include" >> $GITHUB_ENV
|
||||
echo "PKG_CONFIG_PATH=/opt/homebrew/opt/ffmpeg@7/lib/pkgconfig" >> $GITHUB_ENV
|
||||
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||
fi
|
||||
sudo apt-get update && sudo apt-get install -y build-essential git \
|
||||
curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
|
||||
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev
|
||||
|
||||
- name: Setup uv and Python
|
||||
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
|
||||
21
.github/workflows/full_tests.yml
vendored
21
.github/workflows/full_tests.yml
vendored
@@ -51,11 +51,7 @@ jobs:
|
||||
# It runs everytime a PR is approved or a push to main
|
||||
full-tests:
|
||||
name: Full Tests
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved') ||
|
||||
github.event_name == 'push' ||
|
||||
@@ -68,16 +64,11 @@ jobs:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
|
||||
brew update && brew install git geos portaudio ffmpeg@7
|
||||
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||
fi
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
|
||||
- name: Setup uv and Python
|
||||
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
|
||||
@@ -87,7 +78,7 @@ jobs:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras
|
||||
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
|
||||
34
.github/workflows/nightly.yml
vendored
34
.github/workflows/nightly.yml
vendored
@@ -119,6 +119,7 @@ jobs:
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
container:
|
||||
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
@@ -158,3 +159,36 @@ jobs:
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
run: make test-end-to-end
|
||||
|
||||
# This job runs multi-GPU training tests with 4 GPUs
|
||||
nightly-multi-gpu-tests:
|
||||
name: Nightly Multi-GPU Tests
|
||||
needs: [build-docker-gpu-nightly]
|
||||
runs-on:
|
||||
group: aws-g4dn-12xlarge # Instance with 4 GPUs
|
||||
env:
|
||||
HF_HOME: /home/user_lerobot/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
|
||||
TORCH_HOME: /home/user_lerobot/.cache/torch
|
||||
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
|
||||
CUDA_VISIBLE_DEVICES: "0,1,2,3"
|
||||
container:
|
||||
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
|
||||
options: --gpus all --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Verify GPU availability
|
||||
run: |
|
||||
nvidia-smi
|
||||
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
||||
|
||||
- name: Run multi-GPU training tests
|
||||
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
|
||||
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
|
||||
timeout-minutes: 10
|
||||
|
||||
33
.github/workflows/release.yml
vendored
33
.github/workflows/release.yml
vendored
@@ -82,6 +82,14 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Remove Tags with Git dependencies
|
||||
# TODO(Steven): Temporary patch to remove libero and pi from PyPi 0.4.0 release due to its reliance on git dependencies.
|
||||
run: |
|
||||
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
|
||||
grep -E '@ git\+https|lerobot\[pi\]|lerobot\[libero\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
|
||||
sed -E -i '/@ git\+https|lerobot\[pi\]|lerobot\[libero\]/d' pyproject.toml
|
||||
echo "::info:: Git dependencies removed. Proceeding with build."
|
||||
|
||||
- name: Install build dependencies
|
||||
run: python -m pip install build
|
||||
|
||||
@@ -103,7 +111,7 @@ jobs:
|
||||
- name: Publish to TestPyPI for pre-releases
|
||||
# True for tags like 'v0.2.0-rc1'
|
||||
if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-')
|
||||
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
verbose: true
|
||||
@@ -111,7 +119,7 @@ jobs:
|
||||
|
||||
- name: Publish to PyPI
|
||||
if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
|
||||
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
with:
|
||||
verbose: true
|
||||
print-hash: true
|
||||
@@ -120,11 +128,7 @@ jobs:
|
||||
test-release:
|
||||
name: Test Release
|
||||
needs: [build-and-publish]
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
env:
|
||||
@@ -134,20 +138,15 @@ jobs:
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
- name: Install dependencies
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
|
||||
brew update && brew install git geos portaudio ffmpeg@7
|
||||
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||
fi
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
- name: Setup uv and Python
|
||||
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
enable-cache: true
|
||||
enable-cache: true # zizmor: ignore[cache-poisoning]
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
- name: Create uv virtual environment
|
||||
|
||||
12
.github/workflows/stale.yml
vendored
12
.github/workflows/stale.yml
vendored
@@ -27,15 +27,17 @@ env:
|
||||
This issue was closed because it has been stalled for 14 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
CLOSE_PR_MESSAGE: >
|
||||
This PR was closed because it has been stalled for 14 days with no activity.
|
||||
This PR was closed because it has been stalled for 21 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
WARN_ISSUE_MESSAGE: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
Any change, comment or update to this issue will reset this count.
|
||||
Thank you for your contributions.
|
||||
WARN_PR_MESSAGE: >
|
||||
This PR has been automatically marked as stale because it has not had
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
recent activity (1 year). It will be closed if no further activity occurs.
|
||||
Any change, comment or update to this PR will reset this count.
|
||||
Thank you for your contributions.
|
||||
|
||||
jobs:
|
||||
@@ -56,10 +58,10 @@ jobs:
|
||||
stale-pr-label: stale
|
||||
exempt-issue-labels: never-stale
|
||||
exempt-pr-labels: never-stale
|
||||
days-before-issue-stale: 180 # TODO(Steven): Will modify this to 90 after initial cleanup
|
||||
days-before-issue-stale: 180
|
||||
days-before-issue-close: 14
|
||||
days-before-pr-stale: 180
|
||||
days-before-pr-close: 14
|
||||
days-before-pr-stale: 365
|
||||
days-before-pr-close: 21
|
||||
delete-branch: true
|
||||
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
|
||||
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
|
||||
|
||||
19
.github/workflows/unbound_deps_tests.yml
vendored
19
.github/workflows/unbound_deps_tests.yml
vendored
@@ -42,11 +42,7 @@ jobs:
|
||||
# This job runs the E2E tests + pytest with all unbound extras
|
||||
full-tests:
|
||||
name: Full Unbound Tests
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
@@ -55,16 +51,11 @@ jobs:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
elif [[ "${{ matrix.os }}" == 'macos-latest' ]]; then
|
||||
brew update && brew install git geos portaudio ffmpeg@7
|
||||
echo "DYLD_LIBRARY_PATH=/opt/homebrew/opt/ffmpeg@7/lib:/opt/homebrew/lib:/usr/local/lib:$DYLD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||
fi
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
|
||||
speech-dispatcher libgeos-dev portaudio19-dev
|
||||
|
||||
- name: Setup uv and Python
|
||||
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
|
||||
|
||||
@@ -26,7 +26,7 @@ repos:
|
||||
|
||||
##### General Code Quality & Formatting #####
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
rev: v6.0.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
args: ['--maxkb=1024']
|
||||
@@ -39,20 +39,20 @@ repos:
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.4
|
||||
rev: v0.14.1
|
||||
hooks:
|
||||
- id: ruff-format
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
|
||||
- repo: https://github.com/adhtruong/mirrors-typos
|
||||
rev: v1.34.0
|
||||
rev: v1.38.1
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.20.0
|
||||
rev: v3.21.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py310-plus]
|
||||
@@ -68,12 +68,12 @@ repos:
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.27.2
|
||||
rev: v8.28.0
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.11.0
|
||||
rev: v1.15.2
|
||||
hooks:
|
||||
- id: zizmor
|
||||
|
||||
@@ -87,7 +87,7 @@ repos:
|
||||
# TODO(Steven): Uncomment when ready to use
|
||||
##### Static Analysis & Typing #####
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.16.0
|
||||
rev: v1.18.2
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: [--config-file=pyproject.toml]
|
||||
|
||||
@@ -137,7 +137,7 @@ Follow these steps to start contributing:
|
||||
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
|
||||
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
|
||||
|
||||
Set up a development environment with conda or miniconda:
|
||||
Set up a development environment with conda:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
|
||||
|
||||
19
README.md
19
README.md
@@ -104,14 +104,14 @@ LeRobot works with Python 3.10+ and PyTorch 2.2+.
|
||||
|
||||
### Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/):
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
@@ -185,6 +185,11 @@ _Replace `[...]` with your desired features._
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install libero or pi tags, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
### Weights & Biases
|
||||
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||
@@ -207,13 +212,13 @@ lerobot-dataset-viz \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
or from a dataset in a local folder with the `root` option and the `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--root ./my_local_data_dir \
|
||||
--local-files-only 1 \
|
||||
--mode local \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
@@ -310,7 +315,7 @@ To upload these to the hub, run the following:
|
||||
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
|
||||
```
|
||||
|
||||
See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy.
|
||||
See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy.
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
@@ -337,7 +342,3 @@ If you want, you can cite this work with:
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -17,6 +17,8 @@
|
||||
title: Train RL in Simulation
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: multi_gpu_training
|
||||
title: Multi GPU training
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
@@ -35,6 +37,8 @@
|
||||
title: π₀ (Pi0)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
title: "Policies"
|
||||
- sections:
|
||||
- local: il_sim
|
||||
|
||||
122
docs/source/groot.mdx
Normal file
122
docs/source/groot.mdx
Normal file
@@ -0,0 +1,122 @@
|
||||
# GR00T N1.5 Policy
|
||||
|
||||
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
|
||||
|
||||
This document outlines the specifics of its integration and usage within the LeRobot framework.
|
||||
|
||||
## Model Overview
|
||||
|
||||
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
|
||||
|
||||
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
|
||||
|
||||
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
|
||||
|
||||
Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes:
|
||||
|
||||
- Real captured data from robots.
|
||||
- Synthetic data generated using NVIDIA Isaac GR00T Blueprint.
|
||||
- Internet-scale video data.
|
||||
|
||||
This approach allows the model to be highly adaptable through post-training for specific embodiments, tasks, and environments.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
As of today, GR00T N1.5 requires flash attention for it's internal working.
|
||||
|
||||
We are working on making this optional, but in the meantime that means that we require an extra installation step and it can only be used in CUDA enabled devices.
|
||||
|
||||
1. Following the Environment Setup of our [Installation Guide](./installation). **Attention** don't install `lerobot` in this step.
|
||||
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for your system
|
||||
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
|
||||
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
|
||||
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
|
||||
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
|
||||
```
|
||||
|
||||
3. Install LeRobot by running:
|
||||
|
||||
```bash
|
||||
pip install lerobot[groot] # consider also installing libero,dev and test tags
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use GR00T in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=groot
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Training Command Example
|
||||
|
||||
Here's a complete training command for finetuning the base GR00T model on your own dataset:
|
||||
|
||||
```bash
|
||||
# Using a multi-GPU setup
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=$NUM_GPUS \
|
||||
$(which lerobot-train) \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--save_checkpoint=true \
|
||||
--batch_size=$BATCH_SIZE \
|
||||
--steps=$NUM_STEPS \
|
||||
--save_freq=$SAVE_FREQ \
|
||||
--log_freq=$LOG_FREQ \
|
||||
--policy.push_to_hub=true \
|
||||
--policy.type=groot \
|
||||
--policy.repo_id=$REPO_ID \
|
||||
--policy.tune_diffusion_model=false \
|
||||
--dataset.repo_id=$DATASET_ID \
|
||||
--wandb.enable=true \
|
||||
--wandb.disable_artifact=true \
|
||||
--job_name=$JOB_NAME
|
||||
```
|
||||
|
||||
## Performance Results
|
||||
|
||||
### Libero Benchmark Results
|
||||
|
||||
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
|
||||
|
||||
| Benchmark | LeRobot Implementation | GR00T Reference |
|
||||
| ------------------ | ---------------------- | --------------- |
|
||||
| **Libero Spatial** | 82.0% | 92.0% |
|
||||
| **Libero Object** | 99.0% | 92.0% |
|
||||
| **Libero Long** | 82.0% | 76.0% |
|
||||
| **Average** | 87.0% | 87.0% |
|
||||
|
||||
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=bi_so100_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
--robot.id=bimanual_follower \
|
||||
--robot.cameras='{ right: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30},
|
||||
left: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
|
||||
top: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30},
|
||||
}' \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm"
|
||||
--policy.path=<user>/groot-bimanual # your trained model
|
||||
--dataset.episode_time_s=30
|
||||
--dataset.reset_time_s=10
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T).
|
||||
@@ -165,7 +165,7 @@ huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
HF_USER=$(hf auth whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
# Installation
|
||||
|
||||
## Install [`miniforge`](https://conda-forge.org/download/)
|
||||
|
||||
```bash
|
||||
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
||||
bash Miniforge3-$(uname)-$(uname -m).sh
|
||||
```
|
||||
|
||||
## Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install)
|
||||
Create a virtual environment with Python 3.10, using conda:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
@@ -14,7 +21,7 @@ Then activate your conda environment, you have to do this each time you open a s
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
@@ -74,6 +81,9 @@ _Replace `[...]` with your desired features._
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install libero or pi, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
|
||||
@@ -208,34 +208,36 @@ LeRobot supports saving and loading calibration data automatically. This is usef
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
> @property
|
||||
> def is_calibrated(self) -> bool:
|
||||
> return True
|
||||
>
|
||||
> def calibrate(self) -> None:
|
||||
> pass
|
||||
> ```
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
pass
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### `is_calibrated`
|
||||
|
||||
This should reflect whether your robot has the required calibration loaded.
|
||||
|
||||
```
|
||||
<!-- prettier-ignore-end -->python
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.bus.is_calibrated
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### `calibrate()`
|
||||
|
||||
The goal of the calibration is twofold:
|
||||
- Know the physical range of motion of each motors in order to only send commands within this range.
|
||||
- Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
|
||||
|
||||
- Know the physical range of motion of each motors in order to only send commands within this range.
|
||||
- Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
|
||||
|
||||
It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this.
|
||||
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
def calibrate(self) -> None:
|
||||
|
||||
125
docs/source/multi_gpu_training.mdx
Normal file
125
docs/source/multi_gpu_training.mdx
Normal file
@@ -0,0 +1,125 @@
|
||||
# Multi-GPU Training
|
||||
|
||||
This guide shows you how to train policies on multiple GPUs using [Hugging Face Accelerate](https://huggingface.co/docs/accelerate).
|
||||
|
||||
## Installation
|
||||
|
||||
First, ensure you have accelerate installed:
|
||||
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
## Training with Multiple GPUs
|
||||
|
||||
You can launch training in two ways:
|
||||
|
||||
### Option 1: Without config (specify parameters directly)
|
||||
|
||||
You can specify all parameters directly in the command without running `accelerate config`:
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=2 \
|
||||
$(which lerobot-train) \
|
||||
--dataset.repo_id=${HF_USER}/my_dataset \
|
||||
--policy.type=act \
|
||||
--policy.repo_id=${HF_USER}/my_trained_policy \
|
||||
--output_dir=outputs/train/act_multi_gpu \
|
||||
--job_name=act_multi_gpu \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
**Key accelerate parameters:**
|
||||
|
||||
- `--multi_gpu`: Enable multi-GPU training
|
||||
- `--num_processes=2`: Number of GPUs to use
|
||||
- `--mixed_precision=fp16`: Use fp16 mixed precision (or `bf16` if supported)
|
||||
|
||||
### Option 2: Using accelerate config
|
||||
|
||||
If you prefer to save your configuration, you can optionally configure accelerate for your hardware setup by running:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
This interactive setup will ask you questions about your training environment (number of GPUs, mixed precision settings, etc.) and saves the configuration for future use. For a simple multi-GPU setup on a single machine, you can use these recommended settings:
|
||||
|
||||
- Compute environment: This machine
|
||||
- Number of machines: 1
|
||||
- Number of processes: (number of GPUs you want to use)
|
||||
- GPU ids to use: (leave empty to use all)
|
||||
- Mixed precision: fp16 or bf16 (recommended for faster training)
|
||||
|
||||
Then launch training with:
|
||||
|
||||
```bash
|
||||
accelerate launch $(which lerobot-train) \
|
||||
--dataset.repo_id=${HF_USER}/my_dataset \
|
||||
--policy.type=act \
|
||||
--policy.repo_id=${HF_USER}/my_trained_policy \
|
||||
--output_dir=outputs/train/act_multi_gpu \
|
||||
--job_name=act_multi_gpu \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
When you launch training with accelerate:
|
||||
|
||||
1. **Automatic detection**: LeRobot automatically detects if it's running under accelerate
|
||||
2. **Data distribution**: Your batch is automatically split across GPUs
|
||||
3. **Gradient synchronization**: Gradients are synchronized across GPUs during backpropagation
|
||||
4. **Single process logging**: Only the main process logs to wandb and saves checkpoints
|
||||
|
||||
## Learning Rate and Training Steps Scaling
|
||||
|
||||
**Important:** LeRobot does **NOT** automatically scale learning rates or training steps based on the number of GPUs. This gives you full control over your training hyperparameters.
|
||||
|
||||
### Why No Automatic Scaling?
|
||||
|
||||
Many distributed training frameworks automatically scale the learning rate by the number of GPUs (e.g., `lr = base_lr × num_gpus`).
|
||||
However, LeRobot keeps the learning rate exactly as you specify it.
|
||||
|
||||
### When and How to Scale
|
||||
|
||||
If you want to scale your hyperparameters when using multiple GPUs, you should do it manually:
|
||||
|
||||
**Learning Rate Scaling:**
|
||||
|
||||
```bash
|
||||
# Example: 2 GPUs with linear LR scaling
|
||||
# Base LR: 1e-4, with 2 GPUs -> 2e-4
|
||||
accelerate launch --num_processes=2 $(which lerobot-train) \
|
||||
--optimizer.lr=2e-4 \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy=act
|
||||
```
|
||||
|
||||
**Training Steps Scaling:**
|
||||
|
||||
Since the effective batch size `bs` increases with multiple GPUs (batch_size × num_gpus), you may want to reduce the number of training steps proportionally:
|
||||
|
||||
```bash
|
||||
# Example: 2 GPUs with effective batch size 2x larger
|
||||
# Original: batch_size=8, steps=100000
|
||||
# With 2 GPUs: batch_size=8 (16 in total), steps=50000
|
||||
accelerate launch --num_processes=2 $(which lerobot-train) \
|
||||
--batch_size=8 \
|
||||
--steps=50000 \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy=act
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
|
||||
- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts. Non-main processes have console logging disabled to prevent duplicate output.
|
||||
- The effective batch size is `batch_size × num_gpus`. If you use 4 GPUs with `--batch_size=8`, your effective batch size is 32.
|
||||
- Learning rate scheduling is handled correctly across multiple processes—LeRobot sets `step_scheduler_with_optimizer=False` to prevent accelerate from adjusting scheduler steps based on the number of processes.
|
||||
- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility.
|
||||
- WandB integration automatically initializes only on the main process, preventing multiple runs from being created.
|
||||
|
||||
For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate). If you want to learn more about how to train on a large number of GPUs, checkout this awesome guide: [Ultrascale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).
|
||||
328
docs/source/openarms.mdx
Normal file
328
docs/source/openarms.mdx
Normal file
@@ -0,0 +1,328 @@
|
||||
# OpenArms Robot
|
||||
|
||||
OpenArms is a 7 DOF robotic arm with a gripper, designed by [Enactic, Inc.](https://www.enactic.com/) It uses Damiao motors controlled via CAN bus communication and MIT control mode for smooth, precise motion.
|
||||
|
||||
## Hardware Overview
|
||||
|
||||
- **7 DOF per arm** (14 DOF total for dual arm setup)
|
||||
- **1 gripper per arm** (2 grippers total)
|
||||
- **Damiao motors** with 4 different types:
|
||||
- **DM8009** (DM-J8009P-2EC) for shoulders (J1, J2) - high torque
|
||||
- **DM4340** for shoulder rotation and elbow (J3, J4)
|
||||
- **DM4310** (DM-J4310-2EC V1.1) for wrist (J5, J6, J7) and gripper (J8)
|
||||
- **24V power supply** required
|
||||
- **CAN interface device**:
|
||||
- **Linux**: Any SocketCAN-compatible adapter
|
||||
- **macOS**: CANable, PEAK PCAN-USB, or Kvaser USBcan
|
||||
- Proper CAN wiring (CANH, CANL, 120Ω termination)
|
||||
|
||||
|
||||
## Motor Configuration
|
||||
|
||||
Each arm has the following motor configuration based on the [OpenArm setup guide](https://docs.openarm.dev/software/setup/):
|
||||
|
||||
| Joint | Motor | Motor Type | Sender CAN ID | Receiver ID | Description |
|
||||
|-------|-------|------------|---------------|-------------|-------------|
|
||||
| J1 | joint_1 | DM8009 | 0x01 | 0x11 | Shoulder pan |
|
||||
| J2 | joint_2 | DM8009 | 0x02 | 0x12 | Shoulder lift |
|
||||
| J3 | joint_3 | DM4340 | 0x03 | 0x13 | Shoulder rotation |
|
||||
| J4 | joint_4 | DM4340 | 0x04 | 0x14 | Elbow flex |
|
||||
| J5 | joint_5 | DM4310 | 0x05 | 0x15 | Wrist roll |
|
||||
| J6 | joint_6 | DM4310 | 0x06 | 0x16 | Wrist pitch |
|
||||
| J7 | joint_7 | DM4310 | 0x07 | 0x17 | Wrist rotation |
|
||||
| J8 | gripper | DM4310 | 0x08 | 0x18 | Gripper |
|
||||
|
||||
For dual arm setups, the left arm uses IDs 0x09-0x10 for joints 1-8 with the same motor types.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Install system dependencies
|
||||
sudo apt install can-utils iproute2
|
||||
|
||||
# Install LeRobot with OpenArms support
|
||||
pip install -e ".[openarms]"
|
||||
```
|
||||
|
||||
## Setup Guide
|
||||
|
||||
### Step 1: Motor ID Configuration
|
||||
|
||||
**IMPORTANT**: Before using the robot, motors must be configured with the correct CAN IDs.
|
||||
|
||||
Refer to the [OpenArm Motor ID Configuration Guide](https://docs.openarm.dev/software/setup/motor-id) for detailed instructions using the Damiao Debugging Tools on Windows.
|
||||
|
||||
Key points:
|
||||
- Each motor needs a unique **Sender CAN ID** (0x01-0x08)
|
||||
- Each motor needs a unique **Receiver/Master ID** (0x11-0x18)
|
||||
- Use the Damiao Debugging Tools to set these IDs
|
||||
|
||||
### Step 2: Setup CAN Interface
|
||||
|
||||
Configure your CAN interface as described in the [OpenArm CAN Setup Guide](https://docs.openarm.dev/software/setup/can-setup):
|
||||
|
||||
#### Linux (SocketCAN)
|
||||
|
||||
```bash
|
||||
# Find your CAN interface
|
||||
ip link show
|
||||
|
||||
# Configure can0, 1, 2, 3
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 type can bitrate 1000000
|
||||
sudo ip link set can0 up
|
||||
|
||||
sudo ip link set can1 down
|
||||
sudo ip link set can1 type can bitrate 1000000
|
||||
sudo ip link set can1 up
|
||||
|
||||
sudo ip link set can2 down
|
||||
sudo ip link set can2 type can bitrate 1000000
|
||||
sudo ip link set can2 up
|
||||
|
||||
sudo ip link set can3 down
|
||||
sudo ip link set can3 type can bitrate 1000000
|
||||
sudo ip link set can3 up
|
||||
|
||||
# Verify configuration
|
||||
ip link show can0
|
||||
```
|
||||
|
||||
or run:
|
||||
|
||||
`examples/openarms/setup_can.sh`
|
||||
|
||||
### Testing canbus and motor connection
|
||||
|
||||
Please run this script to check if all motors can be found and to find your can-fd speed: `python examples/openarms/debug_can_communication.py`
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Setup
|
||||
|
||||
|
||||
```python
|
||||
from lerobot.robots.openarms import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
|
||||
# Configure for dual arm setup
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
can_interface="socketcan", # Or "auto" for auto-detection
|
||||
id="openarms_dual",
|
||||
is_dual_arm=True,
|
||||
)
|
||||
|
||||
robot = OpenArmsFollower(config)
|
||||
robot.connect()
|
||||
```
|
||||
|
||||
### Calibration
|
||||
|
||||
On first use, you'll need to calibrate the robot:
|
||||
|
||||
```python
|
||||
robot.calibrate()
|
||||
```
|
||||
|
||||
The calibration process will:
|
||||
1. Disable torque on all motors
|
||||
2. Ask you to position arms in **hanging position with grippers closed**
|
||||
3. Set this as the zero position
|
||||
4. Ask you to move each joint through its full range
|
||||
5. Record min/max positions for each joint
|
||||
6. Save calibration to file
|
||||
|
||||
### Reading Observations
|
||||
|
||||
The robot provides comprehensive state information:
|
||||
|
||||
```python
|
||||
observation = robot.get_observation()
|
||||
|
||||
# Observation includes for each motor:
|
||||
# - {motor_name}.pos: Position in degrees
|
||||
# - {motor_name}.vel: Velocity in degrees/second
|
||||
# - {motor_name}.torque: Motor torque
|
||||
# - {camera_name}: Camera images (if configured)
|
||||
|
||||
print(f"Right arm joint 1 position: {observation['right_joint_1.pos']:.1f}°")
|
||||
print(f"Right arm joint 1 velocity: {observation['right_joint_1.vel']:.1f}°/s")
|
||||
print(f"Right arm joint 1 torque: {observation['right_joint_1.torque']:.3f} N·m")
|
||||
```
|
||||
|
||||
### Sending Actions
|
||||
|
||||
```python
|
||||
# Send target positions (in degrees)
|
||||
action = {
|
||||
"right_joint_1.pos": 45.0,
|
||||
"right_joint_2.pos": -30.0,
|
||||
# ... all joints
|
||||
"right_gripper.pos": 45.0, # Half-closed
|
||||
}
|
||||
|
||||
actual_action = robot.send_action(action)
|
||||
```
|
||||
|
||||
### Gripper Control
|
||||
|
||||
```python
|
||||
# Open gripper
|
||||
robot.open_gripper(arm="right")
|
||||
|
||||
# Close gripper
|
||||
robot.close_gripper(arm="right")
|
||||
```
|
||||
|
||||
## Safety Features
|
||||
|
||||
### 1. Maximum Relative Target
|
||||
|
||||
Limits how far a joint can move in a single command to prevent sudden movements:
|
||||
|
||||
```python
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
# Limit all joints to 10 degrees per command
|
||||
max_relative_target=10.0,
|
||||
|
||||
# Or set per-motor limits
|
||||
max_relative_target={
|
||||
"right_joint_1": 15.0, # Slower moving joint
|
||||
"right_joint_2": 10.0,
|
||||
"right_gripper": 5.0, # Very slow gripper
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**How it works**: If current position is 50° and you command 80°, with `max_relative_target=10.0`, the robot will only move to 60° in that step.
|
||||
|
||||
### 2. Torque Limits
|
||||
|
||||
Control maximum torque output, especially important for grippers and teleoperation:
|
||||
|
||||
```python
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
# Gripper torque limit (fraction of motor's max torque)
|
||||
gripper_torque_limit=0.5, # 50% of max torque
|
||||
)
|
||||
```
|
||||
|
||||
Lower torque limits prevent damage when gripping delicate objects.
|
||||
|
||||
### 3. MIT Control Gains
|
||||
|
||||
Control responsiveness and stability via PID-like gains:
|
||||
|
||||
```python
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
position_kp=10.0, # Position gain (higher = more responsive)
|
||||
position_kd=0.5, # Velocity damping (higher = more damped)
|
||||
)
|
||||
```
|
||||
|
||||
**Guidelines**:
|
||||
- **For following (robot)**: Higher gains for responsiveness
|
||||
- `position_kp=10.0`, `position_kd=0.5`
|
||||
- **For teleoperation (leader)**: Lower gains or disable torque for manual movement
|
||||
- `manual_control=True` (torque disabled)
|
||||
|
||||
### 4. Velocity Limits
|
||||
|
||||
Velocity limits are enforced by the Damiao motors based on motor type. For DM4310:
|
||||
- Max velocity: 30 rad/s ≈ 1718°/s
|
||||
|
||||
The motors will automatically limit velocity to safe values.
|
||||
|
||||
## Teleoperation
|
||||
|
||||
### Leader Arm Setup
|
||||
|
||||
The leader arm is moved manually (torque disabled) to generate commands:
|
||||
|
||||
```python
|
||||
from lerobot.teleoperators.openarms import OpenArmsLeader
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
|
||||
config = OpenArmsLeaderConfig(
|
||||
port="can1", # Separate CAN interface for leader
|
||||
id="openarms_leader",
|
||||
manual_control=True, # Torque disabled for manual movement
|
||||
is_dual_arm=True,
|
||||
)
|
||||
|
||||
leader = OpenArmsLeader(config)
|
||||
leader.connect()
|
||||
|
||||
# Read current position as action
|
||||
action = leader.get_action()
|
||||
# action contains positions for all joints in degrees
|
||||
```
|
||||
|
||||
### Safety Considerations for Teleoperation
|
||||
|
||||
1. **Use separate CAN interfaces** for leader and follower to avoid conflicts
|
||||
2. **Enable max_relative_target** on follower to smooth abrupt movements
|
||||
3. **Lower torque limits** on follower to prevent damage from tracking errors
|
||||
4. **Test with one arm** before enabling dual arm teleoperation
|
||||
5. **Have emergency stop** ready (power switch or CAN disable)
|
||||
|
||||
```python
|
||||
# Recommended follower config for teleoperation
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
max_relative_target=5.0, # Small steps for smooth following
|
||||
gripper_torque_limit=0.3, # Low torque for safety
|
||||
position_kp=5.0, # Lower gains for gentler following
|
||||
position_kd=0.3,
|
||||
)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Motor Shaking/Unstable
|
||||
|
||||
- **Lower control gains**: Reduce `position_kp` and `position_kd`
|
||||
- **Check calibration**: Re-run calibration procedure
|
||||
- **Verify power**: Insufficient current can cause instability
|
||||
- **Check mechanical**: Loose connections, binding, or damaged components
|
||||
|
||||
### CAN Bus Errors
|
||||
|
||||
```bash
|
||||
# Check for errors
|
||||
ip -s link show can0
|
||||
|
||||
# Reset CAN interface
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 up
|
||||
```
|
||||
|
||||
### Control Mode
|
||||
|
||||
OpenArms uses **MIT control mode** which allows simultaneous control of:
|
||||
- Position (degrees)
|
||||
- Velocity (degrees/second)
|
||||
- Torque (N·m)
|
||||
- Position gain (Kp)
|
||||
- Velocity damping (Kd)
|
||||
|
||||
### Communication
|
||||
|
||||
- **Protocol**: CAN 2.0 at 1 Mbps (or CAN-FD at 5 Mbps)
|
||||
- **Frame format**: Standard 11-bit IDs
|
||||
- **Update rate**: Typically 50-100 Hz depending on motor count
|
||||
- **Latency**: ~10-20ms per motor command
|
||||
|
||||
## References
|
||||
|
||||
- [OpenArm Official Documentation](https://docs.openarm.dev/)
|
||||
- [OpenArm Setup Guide](https://docs.openarm.dev/software/setup/)
|
||||
- [Motor ID Configuration](https://docs.openarm.dev/software/setup/motor-id)
|
||||
- [CAN Interface Setup](https://docs.openarm.dev/software/setup/can-setup)
|
||||
- [Motor Communication Test](https://docs.openarm.dev/software/setup/configure-test)
|
||||
- [Damiao Motor Documentation](https://wiki.seeedstudio.com/damiao_series/)
|
||||
- [Enactic GitHub](https://github.com/enactic/openarm_can)
|
||||
27
docs/source/policy_groot_README.md
Normal file
27
docs/source/policy_groot_README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
## Research Paper
|
||||
|
||||
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
|
||||
## Repository
|
||||
|
||||
Code: https://github.com/NVIDIA/Isaac-GR00T
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{gr00tn1_2025,
|
||||
archivePrefix = {arxiv},
|
||||
eprint = {2503.14734},
|
||||
title = {{GR00T} {N1}: An Open Foundation Model for Generalist Humanoid Robots},
|
||||
author = {NVIDIA and Johan Bjorck andFernando Castañeda, Nikita Cherniadev and Xingye Da and Runyu Ding and Linxi "Jim" Fan and Yu Fang and Dieter Fox and Fengyuan Hu and Spencer Huang and Joel Jang and Zhenyu Jiang and Jan Kautz and Kaushil Kundalia and Lawrence Lao and Zhiqi Li and Zongyu Lin and Kevin Lin and Guilin Liu and Edith Llontop and Loic Magne and Ajay Mandlekar and Avnish Narayan and Soroush Nasiriany and Scott Reed and You Liang Tan and Guanzhi Wang and Zu Wang and Jing Wang and Qi Wang and Jiannan Xiang and Yuqi Xie and Yinzhen Xu and Zhenjia Xu and Seonghyeon Ye and Zhiding Yu and Ao Zhang and Hao Zhang and Yizhou Zhao and Ruijie Zheng and Yuke Zhu},
|
||||
month = {March},
|
||||
year = {2025},
|
||||
booktitle = {ArXiv Preprint},
|
||||
}
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
Blog: https://developer.nvidia.com/isaac/gr00t
|
||||
|
||||
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
|
||||
@@ -132,17 +132,15 @@ print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||
|
||||
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
|
||||
# PyTorch datasets.
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
if __name__ == "__main__":
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
)
|
||||
for batch in dataloader:
|
||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
|
||||
140
examples/openarms/bilateral.py
Normal file
140
examples/openarms/bilateral.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import time
|
||||
import numpy as np
|
||||
import pinocchio as pin
|
||||
from os.path import dirname
|
||||
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
|
||||
|
||||
same_direction = {"joint_4", "gripper"}
|
||||
|
||||
idx = {
|
||||
"joint_1": 0,
|
||||
"joint_2": 1,
|
||||
"joint_3": 2,
|
||||
"joint_4": 3,
|
||||
"joint_5": 4,
|
||||
"joint_6": 5,
|
||||
"joint_7": 6,
|
||||
"gripper": 7,
|
||||
}
|
||||
|
||||
# joints to freeze
|
||||
frozen = {"joint_6", "joint_7", "gripper"}
|
||||
initial_pose = {}
|
||||
|
||||
|
||||
def pos_deg(rob, obs):
|
||||
out = {}
|
||||
for side in ("left", "right"):
|
||||
for m in getattr(rob, f"bus_{side}").motors:
|
||||
k = f"{side}_{m}.pos"
|
||||
if k in obs:
|
||||
out[f"{side}_{m}"] = obs[k]
|
||||
return out
|
||||
|
||||
|
||||
def vel_rad(rob, obs):
|
||||
out = {}
|
||||
for side in ("left", "right"):
|
||||
for m in getattr(rob, f"bus_{side}").motors:
|
||||
k = f"{side}_{m}.vel"
|
||||
out[f"{side}_{m}"] = np.deg2rad(obs.get(k, 0.0))
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
cfg = OpenArmsLeaderConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_bilateral",
|
||||
manual_control=False,
|
||||
)
|
||||
|
||||
rob = OpenArmsLeader(cfg)
|
||||
rob.connect(calibrate=True)
|
||||
|
||||
urdf = "/home/yope/Documents/lerobot_g1_integration/openarm_description/openarm_bimanual_pybullet.urdf"
|
||||
rob.pin_robot = pin.RobotWrapper.BuildFromURDF(urdf, dirname(urdf))
|
||||
rob.pin_robot.data = rob.pin_robot.model.createData()
|
||||
|
||||
dt = 0.005
|
||||
grav = 1.0
|
||||
fric = 0.3
|
||||
|
||||
# capture initial pose to freeze selected joints later
|
||||
obs0 = rob.get_action()
|
||||
for side in ("left", "right"):
|
||||
for m in getattr(rob, f"bus_{side}").motors:
|
||||
key = f"{side}_{m}.pos"
|
||||
if key in obs0 and m in frozen:
|
||||
initial_pose[f"{side}_{m}"] = obs0[key]
|
||||
|
||||
try:
|
||||
while True:
|
||||
obs = rob.get_action()
|
||||
|
||||
pdeg = pos_deg(rob, obs)
|
||||
prad = {k: np.deg2rad(v) for k, v in pdeg.items()}
|
||||
vrad = vel_rad(rob, obs)
|
||||
|
||||
tau_g = rob._gravity_from_q(prad)
|
||||
tau_f = rob._friction_from_velocity(vrad, friction_scale=fric)
|
||||
|
||||
# bilateral midpoint calculation
|
||||
cmd = {}
|
||||
for m in rob.bus_right.motors:
|
||||
kl = f"left_{m}.pos"
|
||||
kr = f"right_{m}.pos"
|
||||
if kl not in obs or kr not in obs:
|
||||
continue
|
||||
|
||||
ql = obs[kl]
|
||||
qr = obs[kr]
|
||||
|
||||
if m in same_direction:
|
||||
qmid = 0.5 * (ql + qr)
|
||||
else:
|
||||
qmid = 0.5 * (ql - qr)
|
||||
|
||||
# assign midpoint for both
|
||||
cmd[f"left_{m}"] = qmid
|
||||
cmd[f"right_{m}"] = qmid if m in same_direction else -qmid
|
||||
|
||||
# override midpoint with frozen values
|
||||
for key, val in initial_pose.items():
|
||||
cmd[key] = val
|
||||
|
||||
# single mit control call
|
||||
for side in ("left", "right"):
|
||||
bus = getattr(rob, f"bus_{side}")
|
||||
for m in bus.motors:
|
||||
base_key = f"{side}_{m}"
|
||||
kp = float(cfg.position_kp[idx[m]])
|
||||
kd = float(cfg.position_kd[idx[m]])
|
||||
torque = tau_g.get(base_key, 0.0) * grav + tau_f.get(base_key, 0.0)
|
||||
pos_cmd = cmd.get(base_key, pdeg.get(base_key, 0.0))
|
||||
|
||||
bus._mit_control(
|
||||
motor=m,
|
||||
kp=kp,
|
||||
kd=kd,
|
||||
position_degrees=pos_cmd,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
time.sleep(dt)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
rob.bus_left.disable_torque()
|
||||
rob.bus_right.disable_torque()
|
||||
rob.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
416
examples/openarms/debug_can_communication.py
Normal file
416
examples/openarms/debug_can_communication.py
Normal file
@@ -0,0 +1,416 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive debug script for OpenArms CAN FD communication.
|
||||
Tests all 4 CAN interfaces with CAN FD support.
|
||||
"""
|
||||
|
||||
import can
|
||||
import time
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
def check_can_interface(port):
|
||||
"""Check if CAN interface is UP and configured."""
|
||||
try:
|
||||
result = subprocess.run(['ip', 'link', 'show', port],
|
||||
capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
return False, "Interface not found", None
|
||||
|
||||
output = result.stdout
|
||||
if 'UP' not in output:
|
||||
return False, "Interface is DOWN", None
|
||||
|
||||
# Check if CAN FD is enabled
|
||||
is_fd = 'fd on' in output.lower() or 'canfd' in output.lower()
|
||||
|
||||
return True, "Interface is UP", is_fd
|
||||
except FileNotFoundError:
|
||||
return None, "Cannot check (ip command not found)", None
|
||||
|
||||
|
||||
def test_motor_on_interface(bus, motor_id, timeout=2.0, use_fd=False):
|
||||
"""
|
||||
Test a single motor and return all responses.
|
||||
|
||||
Returns:
|
||||
list of (arbitration_id, data) tuples for all responses received
|
||||
"""
|
||||
# Send enable command
|
||||
enable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||
is_extended_id=False,
|
||||
is_fd=use_fd
|
||||
)
|
||||
|
||||
try:
|
||||
bus.send(enable_msg)
|
||||
except Exception as e:
|
||||
return None, f"Send error: {e}"
|
||||
|
||||
# Listen for responses
|
||||
responses = []
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
msg = bus.recv(timeout=0.1)
|
||||
if msg:
|
||||
responses.append((msg.arbitration_id, msg.data, msg.is_fd if hasattr(msg, 'is_fd') else False))
|
||||
|
||||
# Send disable command
|
||||
disable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD],
|
||||
is_extended_id=False,
|
||||
is_fd=use_fd
|
||||
)
|
||||
try:
|
||||
bus.send(disable_msg)
|
||||
except:
|
||||
pass
|
||||
|
||||
return responses, None
|
||||
|
||||
|
||||
def test_interface(port, interface_type="socketcan", use_can_fd=True):
|
||||
"""Test all 8 motors on a single CAN interface."""
|
||||
|
||||
results = {
|
||||
'interface': port,
|
||||
'status': None,
|
||||
'is_fd': use_can_fd,
|
||||
'motors': {}
|
||||
}
|
||||
|
||||
# Check interface status
|
||||
status_ok, status_msg, interface_has_fd = check_can_interface(port)
|
||||
|
||||
if interface_has_fd is not None:
|
||||
results['interface_fd_enabled'] = interface_has_fd
|
||||
if use_can_fd and not interface_has_fd:
|
||||
status_msg += " (CAN FD NOT enabled on interface!)"
|
||||
elif interface_has_fd:
|
||||
status_msg += " (CAN FD enabled)"
|
||||
|
||||
results['status'] = status_msg
|
||||
|
||||
if status_ok is False:
|
||||
return results
|
||||
|
||||
# Try to connect
|
||||
try:
|
||||
if use_can_fd:
|
||||
print(f" Connecting to {port} with CAN FD (1 Mbps / 5 Mbps)...")
|
||||
bus = can.interface.Bus(
|
||||
channel=port,
|
||||
interface=interface_type,
|
||||
bitrate=1000000,
|
||||
data_bitrate=5000000,
|
||||
fd=True
|
||||
)
|
||||
else:
|
||||
print(f" Connecting to {port} with CAN 2.0 (1 Mbps)...")
|
||||
bus = can.interface.Bus(
|
||||
channel=port,
|
||||
interface=interface_type,
|
||||
bitrate=1000000
|
||||
)
|
||||
except Exception as e:
|
||||
results['status'] = f"Connection failed: {e}"
|
||||
return results
|
||||
|
||||
try:
|
||||
# Clear any pending messages
|
||||
while bus.recv(timeout=0.01):
|
||||
pass
|
||||
|
||||
# Test each motor (0x01 to 0x08)
|
||||
for motor_id in range(0x01, 0x09):
|
||||
responses, error = test_motor_on_interface(bus, motor_id, timeout=1.0, use_fd=use_can_fd)
|
||||
|
||||
if error:
|
||||
results['motors'][motor_id] = {'error': error}
|
||||
elif responses:
|
||||
results['motors'][motor_id] = {
|
||||
'found': True,
|
||||
'responses': responses
|
||||
}
|
||||
else:
|
||||
results['motors'][motor_id] = {
|
||||
'found': False,
|
||||
'responses': []
|
||||
}
|
||||
|
||||
time.sleep(0.05) # Small delay between motors
|
||||
|
||||
finally:
|
||||
bus.shutdown()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def print_results(all_results):
|
||||
"""Print formatted results for all interfaces."""
|
||||
|
||||
print("SUMMARY - Motors Found on Each Interface")
|
||||
|
||||
motor_names = {
|
||||
0x01: "joint_1 (Shoulder pan)",
|
||||
0x02: "joint_2 (Shoulder lift)",
|
||||
0x03: "joint_3 (Shoulder rotation)",
|
||||
0x04: "joint_4 (Elbow flex)",
|
||||
0x05: "joint_5 (Wrist roll)",
|
||||
0x06: "joint_6 (Wrist pitch)",
|
||||
0x07: "joint_7 (Wrist rotation)",
|
||||
0x08: "gripper",
|
||||
}
|
||||
|
||||
total_found = 0
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
status = result['status']
|
||||
|
||||
print(f"{interface}: {status}")
|
||||
if result.get('is_fd'):
|
||||
print(f" Mode: CAN FD")
|
||||
else:
|
||||
print(f" Mode: CAN 2.0")
|
||||
|
||||
if 'Connection failed' in status or 'DOWN' in status:
|
||||
print(f" ⚠ Cannot test {interface}")
|
||||
continue
|
||||
|
||||
motors_found = 0
|
||||
|
||||
for motor_id in range(0x01, 0x09):
|
||||
motor_data = result['motors'].get(motor_id, {})
|
||||
motor_name = motor_names.get(motor_id, "Unknown")
|
||||
|
||||
if motor_data.get('error'):
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {motor_data['error']}")
|
||||
elif motor_data.get('found'):
|
||||
motors_found += 1
|
||||
total_found += 1
|
||||
responses = motor_data['responses']
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND")
|
||||
|
||||
for resp_id, data, is_fd in responses:
|
||||
data_hex = data.hex()
|
||||
fd_flag = " [FD]" if is_fd else " [2.0]"
|
||||
print(f" → Response from 0x{resp_id:02X}{fd_flag}: {data_hex}")
|
||||
else:
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response")
|
||||
|
||||
print(f"\n Summary: {motors_found}/8 motors found on {interface}")
|
||||
|
||||
# Overall summary
|
||||
print("OVERALL SUMMARY")
|
||||
print(f"Total motors found across all interfaces: {total_found}")
|
||||
|
||||
# Analyze configuration
|
||||
print("DIAGNOSIS")
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||
|
||||
if motors_found == 0:
|
||||
print(f"\n⚠ {interface}: NO MOTORS FOUND")
|
||||
print(" Possible issues:")
|
||||
print(" 1. CAN FD mode mismatch (interface vs motor configuration)")
|
||||
print(" 2. Missing 120Ω termination resistors at BOTH cable ends")
|
||||
print(" 3. Motor timeout parameter set incorrectly (should NOT be 0)")
|
||||
print(" 4. CANH/CANL wiring issue")
|
||||
print(" 5. Cable too long (>40m for CAN FD at 5Mbps)")
|
||||
|
||||
# Check FD mismatch
|
||||
if result.get('is_fd') and not result.get('interface_fd_enabled'):
|
||||
print(" ⚠️ CRITICAL: Trying CAN FD but interface NOT configured for FD!")
|
||||
print(f" Fix: sudo ip link set {interface} type can bitrate 1000000 dbitrate 5000000 fd on")
|
||||
|
||||
elif motors_found < 8:
|
||||
print(f"\n⚠ {interface}: Only {motors_found}/8 motors responding")
|
||||
print(" Check power and connections for missing motors")
|
||||
else:
|
||||
print(f"\n✓ {interface}: All 8 motors responding correctly!")
|
||||
|
||||
# Check for unexpected response IDs
|
||||
print("RESPONSE ID ANALYSIS")
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
unexpected = []
|
||||
|
||||
for motor_id, motor_data in result['motors'].items():
|
||||
if motor_data.get('found'):
|
||||
expected_id = motor_id + 0x10
|
||||
actual_ids = [resp[0] for resp in motor_data['responses']]
|
||||
|
||||
if expected_id not in actual_ids:
|
||||
unexpected.append((motor_id, actual_ids))
|
||||
|
||||
if unexpected:
|
||||
print(f"\n⚠ {interface}: Unexpected response IDs detected")
|
||||
for motor_id, actual_ids in unexpected:
|
||||
expected_id = motor_id + 0x10
|
||||
print(f" Motor 0x{motor_id:02X}: Expected 0x{expected_id:02X}, "
|
||||
f"got {[f'0x{id:02X}' for id in actual_ids]}")
|
||||
print(" → Motor Master IDs need reconfiguration")
|
||||
else:
|
||||
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||
if motors_found > 0:
|
||||
print(f"\n✓ {interface}: All responding motors use correct IDs")
|
||||
|
||||
|
||||
def test_communication_speed(interface, motor_id, num_iterations=100):
|
||||
"""
|
||||
Test communication speed with a motor.
|
||||
|
||||
Returns:
|
||||
tuple: (hz, avg_latency_ms) or (None, None) if test failed
|
||||
"""
|
||||
try:
|
||||
# Connect to interface
|
||||
bus = can.interface.Bus(
|
||||
channel=interface,
|
||||
interface="socketcan",
|
||||
bitrate=1000000,
|
||||
data_bitrate=5000000,
|
||||
fd=True
|
||||
)
|
||||
|
||||
# Send refresh commands and measure round-trip time
|
||||
latencies = []
|
||||
successful = 0
|
||||
|
||||
for _ in range(num_iterations):
|
||||
start = time.perf_counter()
|
||||
|
||||
# Send enable command (lightweight operation)
|
||||
enable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||
is_extended_id=False,
|
||||
is_fd=True
|
||||
)
|
||||
bus.send(enable_msg)
|
||||
|
||||
# Wait for response
|
||||
msg = bus.recv(timeout=0.1)
|
||||
|
||||
if msg:
|
||||
latency = (time.perf_counter() - start) * 1000 # Convert to ms
|
||||
latencies.append(latency)
|
||||
successful += 1
|
||||
|
||||
bus.shutdown()
|
||||
|
||||
if successful > 0:
|
||||
avg_latency = sum(latencies) / len(latencies)
|
||||
hz = 1000.0 / avg_latency if avg_latency > 0 else 0
|
||||
return hz, avg_latency
|
||||
|
||||
return None, None
|
||||
|
||||
except Exception as e:
|
||||
print(f" Speed test error: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to test all CAN interfaces with CAN FD."""
|
||||
|
||||
print("\nThis will test all 4 CAN interfaces (can0-can3) with CAN FD")
|
||||
print("Testing motors 0x01-0x08 on each interface")
|
||||
print()
|
||||
print("Make sure:")
|
||||
print(" ✓ Motors are powered (24V)")
|
||||
print(" ✓ CAN interfaces configured with FD mode:")
|
||||
print(" ./examples/openarms/setup_can.sh")
|
||||
print(" ✓ Motor 'timeout' parameter NOT set to 0 (use Damiao tools)")
|
||||
print(" ✓ CAN wiring includes 120Ω termination at BOTH ends")
|
||||
print()
|
||||
|
||||
input("Press ENTER to start testing...")
|
||||
|
||||
# Test all 4 interfaces with CAN FD
|
||||
all_results = []
|
||||
|
||||
for i in range(4):
|
||||
interface = f"can{i}"
|
||||
print(f"Testing {interface}...")
|
||||
|
||||
result = test_interface(interface, use_can_fd=True)
|
||||
all_results.append(result)
|
||||
|
||||
# Quick status
|
||||
if 'Connection failed' in result['status'] or 'DOWN' in result['status']:
|
||||
print(f" ⚠ {interface}: {result['status']}")
|
||||
else:
|
||||
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||
print(f" {interface}: {motors_found}/8 motors found")
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
# Print detailed results
|
||||
print_results(all_results)
|
||||
|
||||
print("Testing Complete!")
|
||||
|
||||
all_found = sum(sum(1 for m in r['motors'].values() if m.get('found')) for r in all_results)
|
||||
|
||||
if all_found == 0:
|
||||
print("\n⚠️ CRITICAL: No motors found on any interface!")
|
||||
print("\nTop issues to check:")
|
||||
print(" 1. Motor 'timeout' parameter (use Damiao tools to set > 0)")
|
||||
print(" 2. CAN FD not enabled (run ./examples/openarms/setup_can.sh)")
|
||||
print(" 3. Missing termination resistors")
|
||||
print("\nTry:")
|
||||
print(" a) Check motor parameters with Damiao Debugging Tools")
|
||||
print(" b) Verify CAN FD is enabled: ip -d link show can0 | grep fd")
|
||||
print(" c) Run setup script: ./examples/openarms/setup_can.sh")
|
||||
else:
|
||||
# Run speed test on interfaces with motors
|
||||
print("COMMUNICATION SPEED TEST")
|
||||
print("\nTesting maximum communication frequency...")
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
|
||||
# Find first responding motor
|
||||
responding_motor = None
|
||||
for motor_id, motor_data in result['motors'].items():
|
||||
if motor_data.get('found'):
|
||||
responding_motor = motor_id
|
||||
break
|
||||
|
||||
if responding_motor:
|
||||
print(f"\n{interface}: Testing with motor 0x{responding_motor:02X}...")
|
||||
hz, latency = test_communication_speed(interface, responding_motor, num_iterations=100)
|
||||
|
||||
if hz:
|
||||
print(f" ✓ Max frequency: {hz:.1f} Hz")
|
||||
print(f" ✓ Avg latency: {latency:.2f} ms")
|
||||
print(f" ✓ Commands per second: ~{int(hz)}")
|
||||
else:
|
||||
print(f" ✗ Speed test failed")
|
||||
else:
|
||||
print(f"\n{interface}: No motors found, skipping speed test")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nTesting interrupted by user.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\nUnexpected error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
360
examples/openarms/evaluate.py
Normal file
360
examples/openarms/evaluate.py
Normal file
@@ -0,0 +1,360 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenArms Policy Evaluation
|
||||
|
||||
Evaluates a trained policy on the OpenArms robot by running inference and recording
|
||||
the evaluation episodes to a dataset. Supports optional leader arm for manual resets.
|
||||
|
||||
Example usage:
|
||||
python examples/openarms/evaluate.py
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
|
||||
HF_MODEL_ID = "lerobot-data-collection/three-folds-pi0" # TODO: Replace with your trained model
|
||||
HF_EVAL_DATASET_ID = "lerobot-data-collection/three-folds-pi0_eval7" # TODO: Replace with your eval dataset name
|
||||
TASK_DESCRIPTION = "three-folds-dataset" # TODO: Replace with your task, this should match!!
|
||||
|
||||
NUM_EPISODES = 1
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 300
|
||||
RESET_TIME_SEC = 60
|
||||
|
||||
# Robot CAN interfaces
|
||||
FOLLOWER_LEFT_PORT = "can0"
|
||||
FOLLOWER_RIGHT_PORT = "can1"
|
||||
|
||||
# If enabled, you can manually reset the environment between evaluation episodes
|
||||
USE_LEADER_FOR_RESETS = True # Set to False if you don't want to use leader
|
||||
LEADER_LEFT_PORT = "can2"
|
||||
LEADER_RIGHT_PORT = "can3"
|
||||
|
||||
# Camera configuration
|
||||
CAMERA_CONFIG = {
|
||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=FPS),
|
||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
|
||||
"base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=FPS),
|
||||
}
|
||||
|
||||
def main():
|
||||
"""Main evaluation function."""
|
||||
print("OpenArms Policy Evaluation")
|
||||
print(f"\nModel: {HF_MODEL_ID}")
|
||||
print(f"Evaluation Dataset: {HF_EVAL_DATASET_ID}")
|
||||
print(f"Task: {TASK_DESCRIPTION}")
|
||||
print(f"Episodes: {NUM_EPISODES}")
|
||||
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
|
||||
print(f"Reset Duration: {RESET_TIME_SEC}s")
|
||||
print(f"Use Leader for Resets: {USE_LEADER_FOR_RESETS}")
|
||||
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left=FOLLOWER_LEFT_PORT,
|
||||
port_right=FOLLOWER_RIGHT_PORT,
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
cameras=CAMERA_CONFIG,
|
||||
)
|
||||
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
follower.connect(calibrate=False)
|
||||
|
||||
if not follower.is_connected:
|
||||
raise RuntimeError("Follower robot failed to connect!")
|
||||
|
||||
|
||||
leader = None
|
||||
if USE_LEADER_FOR_RESETS:
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left=LEADER_LEFT_PORT,
|
||||
port_right=LEADER_RIGHT_PORT,
|
||||
can_interface="socketcan",
|
||||
id="openarms_leader",
|
||||
manual_control=False, # Enable torque control for gravity compensation
|
||||
)
|
||||
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
leader.connect(calibrate=False)
|
||||
|
||||
if not leader.is_connected:
|
||||
raise RuntimeError("Leader robot failed to connect!")
|
||||
|
||||
# Enable gravity compensation
|
||||
if leader.pin_robot is not None:
|
||||
leader.bus_right.enable_torque()
|
||||
leader.bus_left.enable_torque()
|
||||
time.sleep(0.1)
|
||||
print(f"Leader connected with gravity compensation ({LEADER_LEFT_PORT}, {LEADER_RIGHT_PORT})")
|
||||
else:
|
||||
print(f"Leader connected but gravity compensation unavailable (no URDF)")
|
||||
|
||||
# Build default processors for action and observation
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Build dataset features from robot features and processors
|
||||
# For actions, only include positions (no velocity or torque)
|
||||
action_features_hw = {}
|
||||
for key, value in follower.action_features.items():
|
||||
if key.endswith(".pos"):
|
||||
action_features_hw[key] = value
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=action_features_hw),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_observation_processor,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Check if dataset already exists
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / HF_EVAL_DATASET_ID
|
||||
if dataset_path.exists():
|
||||
print(f"Evaluation dataset already exists at: {dataset_path}")
|
||||
print("This will append new episodes to the existing dataset.")
|
||||
choice = input(" Continue? (y/n): ").strip().lower()
|
||||
if choice != 'y':
|
||||
print(" Aborting evaluation.")
|
||||
follower.disconnect()
|
||||
if leader:
|
||||
leader.disconnect()
|
||||
return
|
||||
|
||||
# Create dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_EVAL_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_processes=0,
|
||||
image_writer_threads=12,
|
||||
)
|
||||
|
||||
# Load policy config from pretrained model and create policy using factory
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
policy = make_policy(policy_config, ds_meta=dataset.meta)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy.config,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": str(policy.config.device)}
|
||||
},
|
||||
)
|
||||
|
||||
print(f"\nRunning evaluation...")
|
||||
# Initialize keyboard listener and visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="openarms_evaluation")
|
||||
episode_idx = 0
|
||||
|
||||
try:
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print(f"\nRunning inference for episode {episode_idx + 1}...")
|
||||
|
||||
# Run inference with policy
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Handle re-recording
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Save episode
|
||||
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
|
||||
print(f"Saving episode {episode_idx + 1} ({dataset.episode_buffer['size']} frames)...")
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Reset environment between episodes (if not last episode)
|
||||
if not events["stop_recording"] and episode_idx < NUM_EPISODES:
|
||||
if USE_LEADER_FOR_RESETS and leader:
|
||||
log_say("Reset the environment using leader arms")
|
||||
print(f"\nManual reset period ({RESET_TIME_SEC}s)...")
|
||||
|
||||
# Use leader for manual reset with gravity compensation
|
||||
import numpy as np
|
||||
|
||||
dt = 1 / FPS
|
||||
reset_start_time = time.perf_counter()
|
||||
|
||||
while time.perf_counter() - reset_start_time < RESET_TIME_SEC:
|
||||
if events["exit_early"] or events["stop_recording"]:
|
||||
break
|
||||
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get leader state
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Extract positions and velocities
|
||||
leader_positions_deg = {}
|
||||
leader_velocities_deg_per_sec = {}
|
||||
|
||||
for motor in leader.bus_right.motors:
|
||||
pos_key = f"right_{motor}.pos"
|
||||
vel_key = f"right_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
pos_key = f"left_{motor}.pos"
|
||||
vel_key = f"left_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||
|
||||
# Calculate gravity and friction torques
|
||||
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||
|
||||
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||
leader_velocities_rad_per_sec,
|
||||
friction_scale=1.0
|
||||
)
|
||||
|
||||
# Combine torques
|
||||
leader_total_torques_nm = {}
|
||||
for motor_name in leader_gravity_torques_nm:
|
||||
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||
leader_total_torques_nm[motor_name] = gravity + friction
|
||||
|
||||
# Apply compensation
|
||||
for motor in leader.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_right._mit_control(
|
||||
motor=motor, kp=0.0, kd=kd,
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_left._mit_control(
|
||||
motor=motor, kp=0.0, kd=kd,
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Send leader positions to follower
|
||||
follower_action = {}
|
||||
for joint in leader_positions_deg.keys():
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
follower_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
if follower_action:
|
||||
follower.send_action(follower_action)
|
||||
|
||||
# Maintain loop rate
|
||||
loop_duration = time.perf_counter() - loop_start
|
||||
sleep_time = dt - loop_duration
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
print("Reset complete")
|
||||
else:
|
||||
log_say("Waiting for manual reset")
|
||||
print(f"Manually reset the environment and press ENTER to continue")
|
||||
input("Press ENTER when ready...")
|
||||
|
||||
print(f"Evaluation complete! {episode_idx} episodes recorded")
|
||||
log_say("Evaluation complete", blocking=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nEvaluation interrupted by user")
|
||||
|
||||
finally:
|
||||
if leader:
|
||||
leader.bus_right.disable_torque()
|
||||
leader.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
leader.disconnect()
|
||||
|
||||
follower.disconnect()
|
||||
|
||||
if listener is not None:
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
print("\nUploading to Hugging Face Hub...")
|
||||
dataset.push_to_hub(private=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
216
examples/openarms/friction_compensation.py
Normal file
216
examples/openarms/friction_compensation.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
|
||||
|
||||
# Friction model parameters from OpenArms config/follower.yaml
|
||||
# τ_fric(ω) = Fo + Fv·ω + Fc·tanh(k·ω)
|
||||
# For 8 motors: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||
FRICTION_PARAMS = {
|
||||
"Fc": [0.306, 0.306, 0.40, 0.166, 0.050, 0.093, 0.172, 0.0512], # Coulomb friction [Nm]
|
||||
"k": [28.417, 28.417, 29.065, 130.038, 151.771, 242.287, 7.888, 4.000], # tanh steepness
|
||||
"Fv": [0.063, 0.0630, 0.604, 0.813, 0.029, 0.072, 0.084, 0.084], # Viscous friction [Nm·s/rad]
|
||||
"Fo": [0.088, 0.088, 0.008, -0.058, 0.005, 0.009, -0.059, -0.050], # Offset torque [Nm]
|
||||
}
|
||||
|
||||
# Constants from OpenArms C++ implementation
|
||||
AMP_TMP = 1.0
|
||||
COEF_TMP = 0.1
|
||||
|
||||
FRICTION_SCALE = 1.0 # OpenArms C++ uses 0.3 factor in unilateral mode
|
||||
DAMPING_KD = [0.5, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1, 0.1] # Damping gains for stability
|
||||
|
||||
def compute_friction_torque(velocity_rad_per_sec: float, motor_index: int) -> float:
|
||||
"""
|
||||
Compute friction torque for a single motor using the tanh friction model.
|
||||
|
||||
Args:
|
||||
velocity_rad_per_sec: Angular velocity in rad/s
|
||||
motor_index: Index of the motor (0-7)
|
||||
|
||||
Returns:
|
||||
Friction torque in N·m (scaled for stability)
|
||||
"""
|
||||
|
||||
Fc = FRICTION_PARAMS["Fc"][motor_index]
|
||||
k = FRICTION_PARAMS["k"][motor_index]
|
||||
Fv = FRICTION_PARAMS["Fv"][motor_index]
|
||||
Fo = FRICTION_PARAMS["Fo"][motor_index]
|
||||
|
||||
# Friction model: τ_fric = amp * Fc * tanh(coef * k * ω) + Fv * ω + Fo
|
||||
friction_torque = (
|
||||
AMP_TMP * Fc * np.tanh(COEF_TMP * k * velocity_rad_per_sec) +
|
||||
Fv * velocity_rad_per_sec +
|
||||
Fo
|
||||
)
|
||||
|
||||
# Scale down friction compensation for stability at lower control rates
|
||||
# (OpenArms C++ uses 0.3 factor in unilateral mode)!!
|
||||
friction_torque *= FRICTION_SCALE
|
||||
|
||||
return friction_torque
|
||||
|
||||
|
||||
def main() -> None:
|
||||
config = OpenArmsFollowerConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=5.0,
|
||||
)
|
||||
|
||||
print("Initializing robot...")
|
||||
follower = OpenArmsFollower(config)
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
print(f"Applying friction compensation")
|
||||
print(" 1. Support the arm before starting")
|
||||
print(" 2. The arm will be held in place by friction compensation")
|
||||
print(" 3. You should be able to move it with gentle force")
|
||||
print("\nPress ENTER when ready to start...")
|
||||
input()
|
||||
|
||||
print(f"✓ Motors enabled")
|
||||
print("\nStarting friction compensation loop...")
|
||||
print("Press Ctrl+C to stop\n")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = time.perf_counter()
|
||||
|
||||
# Motor name to index mapping
|
||||
motor_name_to_index = {
|
||||
"joint_1": 0,
|
||||
"joint_2": 1,
|
||||
"joint_3": 2,
|
||||
"joint_4": 3,
|
||||
"joint_5": 4,
|
||||
"joint_6": 5,
|
||||
"joint_7": 6,
|
||||
"gripper": 7,
|
||||
}
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get current joint positions and velocities from robot
|
||||
obs = follower.get_observation()
|
||||
|
||||
# Extract velocities in degrees per second
|
||||
velocities_deg_per_sec = {}
|
||||
positions_deg = {}
|
||||
|
||||
for motor in follower.bus_right.motors:
|
||||
vel_key = f"right_{motor}.vel"
|
||||
pos_key = f"right_{motor}.pos"
|
||||
if vel_key in obs:
|
||||
velocities_deg_per_sec[f"right_{motor}"] = obs[vel_key]
|
||||
if pos_key in obs:
|
||||
positions_deg[f"right_{motor}"] = obs[pos_key]
|
||||
|
||||
for motor in follower.bus_left.motors:
|
||||
vel_key = f"left_{motor}.vel"
|
||||
pos_key = f"left_{motor}.pos"
|
||||
if vel_key in obs:
|
||||
velocities_deg_per_sec[f"left_{motor}"] = obs[vel_key]
|
||||
if pos_key in obs:
|
||||
positions_deg[f"left_{motor}"] = obs[pos_key]
|
||||
|
||||
# Convert velocities to rad/s and compute friction torques
|
||||
friction_torques_nm = {}
|
||||
for motor_full_name, velocity_deg_per_sec in velocities_deg_per_sec.items():
|
||||
# Extract motor name without arm prefix
|
||||
if motor_full_name.startswith("right_"):
|
||||
motor_name = motor_full_name.removeprefix("right_")
|
||||
elif motor_full_name.startswith("left_"):
|
||||
motor_name = motor_full_name.removeprefix("left_")
|
||||
else:
|
||||
continue
|
||||
|
||||
# Get motor index for friction parameters
|
||||
motor_index = motor_name_to_index.get(motor_name, 0)
|
||||
|
||||
# Convert velocity to rad/s
|
||||
velocity_rad_per_sec = np.deg2rad(velocity_deg_per_sec)
|
||||
|
||||
# Compute friction torque
|
||||
friction_torque = compute_friction_torque(velocity_rad_per_sec, motor_index)
|
||||
friction_torques_nm[motor_full_name] = friction_torque
|
||||
|
||||
# Apply friction compensation to right arm (all joints INCLUDING gripper)
|
||||
for motor in follower.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = friction_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get motor index for damping gain
|
||||
motor_index = motor_name_to_index.get(motor, 0)
|
||||
kd = DAMPING_KD[motor_index]
|
||||
|
||||
# Send MIT control command with friction compensation + damping
|
||||
follower.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Apply friction compensation to left arm (all joints INCLUDING gripper)
|
||||
for motor in follower.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = friction_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get motor index for damping gain
|
||||
motor_index = motor_name_to_index.get(motor, 0)
|
||||
kd = DAMPING_KD[motor_index]
|
||||
|
||||
# Send MIT control command with friction compensation + damping
|
||||
follower.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Measure loop time
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Print status every 2 seconds
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
|
||||
print(f"{current_hz:.1f} Hz")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
time.sleep(0.001)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping friction compensation...")
|
||||
|
||||
finally:
|
||||
print("\nDisabling all motors and disconnecting...")
|
||||
follower.bus_right.disable_torque()
|
||||
follower.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
follower.disconnect()
|
||||
print("✓ Safe shutdown complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
139
examples/openarms/gravity_compensation.py
Executable file
139
examples/openarms/gravity_compensation.py
Executable file
@@ -0,0 +1,139 @@
|
||||
import time
|
||||
import numpy as np
|
||||
import pinocchio as pin
|
||||
from os.path import join, dirname, exists, expanduser
|
||||
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
|
||||
|
||||
def main() -> None:
|
||||
config = OpenArmsLeaderConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_leader",
|
||||
manual_control=False, # Enable torque control for gravity compensation
|
||||
)
|
||||
|
||||
|
||||
print("Initializing robot...")
|
||||
follower = OpenArmsLeader(config)
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
# Load URDF for Pinocchio dynamics
|
||||
urdf_path = "/home/yope/Documents/lerobot_g1_integration/openarm_description/openarm_bimanual_pybullet.urdf"
|
||||
pin_robot = pin.RobotWrapper.BuildFromURDF(urdf_path, dirname(urdf_path))
|
||||
pin_robot.data = pin_robot.model.createData()
|
||||
print(f"✓ Loaded Pinocchio model with {pin_robot.nq} DoFs")
|
||||
|
||||
follower.pin_robot = pin_robot
|
||||
|
||||
print(f"Applying gravity compensation")
|
||||
print(" 1. Support the arm before starting")
|
||||
print(" 2. The arm will be held in place by gravity compensation")
|
||||
print(" 3. You should be able to move it with gentle force")
|
||||
print("\nPress ENTER when ready to start...")
|
||||
input()
|
||||
|
||||
print(f"✓ Motors enabled")
|
||||
print("\nStarting gravity compensation loop...")
|
||||
print("Press Ctrl+C to stop\n")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get current joint positions from robot
|
||||
obs = follower.get_action()
|
||||
|
||||
# Extract positions in degrees
|
||||
positions_deg = {}
|
||||
for motor in follower.bus_right.motors:
|
||||
key = f"right_{motor}.pos"
|
||||
if key in obs:
|
||||
positions_deg[f"right_{motor}"] = obs[key]
|
||||
|
||||
for motor in follower.bus_left.motors:
|
||||
key = f"left_{motor}.pos"
|
||||
if key in obs:
|
||||
positions_deg[f"left_{motor}"] = obs[key]
|
||||
|
||||
# Convert to radians and calculate gravity torques
|
||||
# Use the built-in method from OpenArmsFollower
|
||||
positions_rad = {k: np.deg2rad(v) for k, v in positions_deg.items()}
|
||||
torques_nm = follower._gravity_from_q(positions_rad)
|
||||
|
||||
# Apply gravity compensation to right arm (all joints except gripper)
|
||||
for motor in follower.bus_right.motors:
|
||||
|
||||
|
||||
full_name = f"right_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Send MIT control command with gravity compensation torque
|
||||
follower.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=0.0, # No velocity damping
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
|
||||
# Apply gravity compensation to left arm (all joints except gripper)
|
||||
for motor in follower.bus_left.motors:
|
||||
|
||||
|
||||
full_name = f"left_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Send MIT control command with gravity compensation torque
|
||||
follower.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=0.0, # No velocity damping
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Measure loop time
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Print status every 2 seconds
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
|
||||
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping gravity compensation...")
|
||||
|
||||
finally:
|
||||
print("\nDisabling all motors and disconnecting...")
|
||||
follower.bus_right.disable_torque()
|
||||
follower.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
follower.disconnect()
|
||||
print("✓ Safe shutdown complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
618
examples/openarms/openarm_bimanual_pybullet.urdf
Normal file
618
examples/openarms/openarm_bimanual_pybullet.urdf
Normal file
@@ -0,0 +1,618 @@
|
||||
<?xml version='1.0' encoding='utf-8'?>
|
||||
<robot name="openarm">
|
||||
<link name="world" />
|
||||
<joint name="openarm_body_world_joint" type="fixed">
|
||||
<parent link="world" />
|
||||
<child link="openarm_body_link0" />
|
||||
<origin rpy="0 0 0" xyz="0 0 0" />
|
||||
</joint>
|
||||
<link name="openarm_body_link0">
|
||||
<visual name="openarm_body_link0_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/body/v10/visual/body_link0.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_body_link0_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/body/v10/collision/body_link0_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<mass value="13.89" />
|
||||
<inertia ixx="1.653" ixy="0.0" ixz="0.0" iyy="1.653" iyz="0.0" izz="0.051" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_openarm_body_link0_joint" type="fixed">
|
||||
<parent link="openarm_body_link0" />
|
||||
<child link="openarm_left_link0" />
|
||||
<origin rpy="-1.5708 0 0" xyz="0.0 0.031 0.698" />
|
||||
</joint>
|
||||
<link name="openarm_left_link0">
|
||||
<visual name="openarm_left_link0_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link0.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link0_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link0_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0009483362816297526 -0.0001580207020448382 0.03076860287587199" />
|
||||
<mass value="1.1432284943239561" />
|
||||
<inertia ixx="0.001128" ixy="-4e-06" ixz="-3.3e-05" iyy="0.000962" iyz="-7e-06" izz="0.00147" />
|
||||
</inertial>
|
||||
</link>
|
||||
<link name="openarm_left_link1">
|
||||
<visual name="openarm_left_link1_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link1.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link1_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link1_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0011467657911800769 -3.319987657026362e-05 0.05395284380736254" />
|
||||
<mass value="1.1416684646202298" />
|
||||
<inertia ixx="0.001567" ixy="-1e-06" ixz="-2.9e-05" iyy="0.001273" iyz="1e-06" izz="0.001016" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint1" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0 0.0 0.0625" />
|
||||
<parent link="openarm_left_link0" />
|
||||
<child link="openarm_left_link1" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit effort="40" lower="-3.490659" upper="1.3962629999999998" velocity="16.754666" />
|
||||
</joint>
|
||||
<link name="openarm_left_link2">
|
||||
<visual name="openarm_left_link2_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link2.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link2_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link2_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.00839629182351943 2.0145102027597523e-08 0.03256649300522363" />
|
||||
<mass value="0.2775092746011571" />
|
||||
<inertia ixx="0.000359" ixy="1e-06" ixz="-0.000109" iyy="0.000376" iyz="1e-06" izz="0.000232" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint2" type="revolute">
|
||||
<origin rpy="-1.57079632679 0 0" xyz="-0.0301 0.0 0.06" />
|
||||
<parent link="openarm_left_link1" />
|
||||
<child link="openarm_left_link2" />
|
||||
<axis xyz="-1 0 0" />
|
||||
<limit effort="40" lower="-3.3161253267948965" upper="0.17453267320510335" velocity="16.754666" />
|
||||
</joint>
|
||||
<link name="openarm_left_link3">
|
||||
<visual name="openarm_left_link3_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link3.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link3_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link3_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.002104752099628911 -0.0005549085042607548 0.09047470545721961" />
|
||||
<mass value="1.073863338202347" />
|
||||
<inertia ixx="0.004372" ixy="1e-06" ixz="1.1e-05" iyy="0.004319" iyz="-3.6e-05" izz="0.000661" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint3" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0301 0.0 0.06625" />
|
||||
<parent link="openarm_left_link2" />
|
||||
<child link="openarm_left_link3" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit effort="27" lower="-1.570796" upper="1.570796" velocity="5.445426" />
|
||||
</joint>
|
||||
<link name="openarm_left_link4">
|
||||
<visual name="openarm_left_link4_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link4.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link4_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link4_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0029006831074562967 -0.03030575826634669 0.06339637422196209" />
|
||||
<mass value="0.6348534566833373" />
|
||||
<inertia ixx="0.000623" ixy="-1e-06" ixz="-1.9e-05" iyy="0.000511" iyz="3.8e-05" izz="0.000334" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint4" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="-0.0 0.0315 0.15375" />
|
||||
<parent link="openarm_left_link3" />
|
||||
<child link="openarm_left_link4" />
|
||||
<axis xyz="0 1 0" />
|
||||
<limit effort="27" lower="0.0" upper="2.443461" velocity="5.445426" />
|
||||
</joint>
|
||||
<link name="openarm_left_link5">
|
||||
<visual name="openarm_left_link5_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link5.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link5_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link5_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.003049665024221911 -0.0008866902457326625 0.043079803024980934" />
|
||||
<mass value="0.6156588026168502" />
|
||||
<inertia ixx="0.000423" ixy="-8e-06" ixz="6e-06" iyy="0.000445" iyz="-6e-06" izz="0.000324" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint5" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0 -0.0315 0.0955" />
|
||||
<parent link="openarm_left_link4" />
|
||||
<child link="openarm_left_link5" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
|
||||
</joint>
|
||||
<link name="openarm_left_link6">
|
||||
<visual name="openarm_left_link6_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link6.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link6_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link6_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.037136587005447405 -0.00033230528343419053 -9.498374522309838e-05" />
|
||||
<mass value="0.475202773187987" />
|
||||
<inertia ixx="0.000143" ixy="1e-06" ixz="1e-06" iyy="0.000157" iyz="1e-06" izz="0.000159" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint6" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0375 0.0 0.1205" />
|
||||
<parent link="openarm_left_link5" />
|
||||
<child link="openarm_left_link6" />
|
||||
<axis xyz="1 0 0" />
|
||||
<limit effort="7" lower="-0.785398" upper="0.785398" velocity="20.943946" />
|
||||
</joint>
|
||||
<link name="openarm_left_link7">
|
||||
<visual name="openarm_left_link7_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link7.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link7_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link7_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="6.875510271106056e-05 -0.01266175250761268 0.06951945409987448" />
|
||||
<mass value="0.4659771327380578" />
|
||||
<inertia ixx="0.000639" ixy="1e-06" ixz="1e-06" iyy="0.000497" iyz="8.9e-05" izz="0.000342" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint7" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="-0.0375 0.0 0.0" />
|
||||
<parent link="openarm_left_link6" />
|
||||
<child link="openarm_left_link7" />
|
||||
<axis xyz="0 -1 0" />
|
||||
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
|
||||
</joint>
|
||||
<joint name="openarm_right_openarm_body_link0_joint" type="fixed">
|
||||
<parent link="openarm_body_link0" />
|
||||
<child link="openarm_right_link0" />
|
||||
<origin rpy="1.5708 0 0" xyz="0.0 -0.031 0.698" />
|
||||
</joint>
|
||||
<link name="openarm_right_link0">
|
||||
<visual name="openarm_right_link0_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link0.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link0_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link0_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0009483362816297526 0.0001580207020448382 0.03076860287587199" />
|
||||
<mass value="1.1432284943239561" />
|
||||
<inertia ixx="0.001128" ixy="-4e-06" ixz="-3.3e-05" iyy="0.000962" iyz="-7e-06" izz="0.00147" />
|
||||
</inertial>
|
||||
</link>
|
||||
<link name="openarm_right_link1">
|
||||
<visual name="openarm_right_link1_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link1.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link1_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link1_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0011467657911800769 3.319987657026362e-05 0.05395284380736254" />
|
||||
<mass value="1.1416684646202298" />
|
||||
<inertia ixx="0.001567" ixy="-1e-06" ixz="-2.9e-05" iyy="0.001273" iyz="1e-06" izz="0.001016" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint1" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0 0.0 0.0625" />
|
||||
<parent link="openarm_right_link0" />
|
||||
<child link="openarm_right_link1" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit effort="40" lower="-1.396263" upper="3.490659" velocity="16.754666" />
|
||||
</joint>
|
||||
<link name="openarm_right_link2">
|
||||
<visual name="openarm_right_link2_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link2.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link2_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link2_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.00839629182351943 -2.0145102027597523e-08 0.03256649300522363" />
|
||||
<mass value="0.2775092746011571" />
|
||||
<inertia ixx="0.000359" ixy="1e-06" ixz="-0.000109" iyy="0.000376" iyz="1e-06" izz="0.000232" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint2" type="revolute">
|
||||
<origin rpy="1.57079632679 0 0" xyz="-0.0301 0.0 0.06" />
|
||||
<parent link="openarm_right_link1" />
|
||||
<child link="openarm_right_link2" />
|
||||
<axis xyz="-1 0 0" />
|
||||
<limit effort="40" lower="-0.17453267320510335" upper="3.3161253267948965" velocity="16.754666" />
|
||||
</joint>
|
||||
<link name="openarm_right_link3">
|
||||
<visual name="openarm_right_link3_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link3.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link3_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link3_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.002104752099628911 0.0005549085042607548 0.09047470545721961" />
|
||||
<mass value="1.073863338202347" />
|
||||
<inertia ixx="0.004372" ixy="1e-06" ixz="1.1e-05" iyy="0.004319" iyz="-3.6e-05" izz="0.000661" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint3" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0301 0.0 0.06625" />
|
||||
<parent link="openarm_right_link2" />
|
||||
<child link="openarm_right_link3" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit effort="27" lower="-1.570796" upper="1.570796" velocity="5.445426" />
|
||||
</joint>
|
||||
<link name="openarm_right_link4">
|
||||
<visual name="openarm_right_link4_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link4.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link4_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link4_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0029006831074562967 -0.03030575826634669 0.06339637422196209" />
|
||||
<mass value="0.6348534566833373" />
|
||||
<inertia ixx="0.000623" ixy="-1e-06" ixz="-1.9e-05" iyy="0.000511" iyz="3.8e-05" izz="0.000334" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint4" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="-0.0 0.0315 0.15375" />
|
||||
<parent link="openarm_right_link3" />
|
||||
<child link="openarm_right_link4" />
|
||||
<axis xyz="0 1 0" />
|
||||
<limit effort="27" lower="0.0" upper="2.443461" velocity="5.445426" />
|
||||
</joint>
|
||||
<link name="openarm_right_link5">
|
||||
<visual name="openarm_right_link5_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link5.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link5_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link5_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.003049665024221911 0.0008866902457326625 0.043079803024980934" />
|
||||
<mass value="0.6156588026168502" />
|
||||
<inertia ixx="0.000423" ixy="-8e-06" ixz="6e-06" iyy="0.000445" iyz="-6e-06" izz="0.000324" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint5" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0 -0.0315 0.0955" />
|
||||
<parent link="openarm_right_link4" />
|
||||
<child link="openarm_right_link5" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
|
||||
</joint>
|
||||
<link name="openarm_right_link6">
|
||||
<visual name="openarm_right_link6_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link6.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link6_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link6_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.037136587005447405 0.00033230528343419053 -9.498374522309838e-05" />
|
||||
<mass value="0.475202773187987" />
|
||||
<inertia ixx="0.000143" ixy="1e-06" ixz="1e-06" iyy="0.000157" iyz="1e-06" izz="0.000159" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint6" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0375 0.0 0.1205" />
|
||||
<parent link="openarm_right_link5" />
|
||||
<child link="openarm_right_link6" />
|
||||
<axis xyz="1 0 0" />
|
||||
<limit effort="7" lower="-0.785398" upper="0.785398" velocity="20.943946" />
|
||||
</joint>
|
||||
<link name="openarm_right_link7">
|
||||
<visual name="openarm_right_link7_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link7.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link7_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link7_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="6.875510271106056e-05 0.01266175250761268 0.06951945409987448" />
|
||||
<mass value="0.4659771327380578" />
|
||||
<inertia ixx="0.000639" ixy="1e-06" ixz="1e-06" iyy="0.000497" iyz="8.9e-05" izz="0.000342" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint7" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="-0.0375 0.0 0.0" />
|
||||
<parent link="openarm_right_link6" />
|
||||
<child link="openarm_right_link7" />
|
||||
<axis xyz="0 1 0" />
|
||||
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
|
||||
</joint>
|
||||
<link name="openarm_left_hand">
|
||||
<visual name="openarm_left_hand_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/visual/hand.dae" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_hand_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/collision/hand.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0 0 0" xyz="0.0 0.002 0.03" />
|
||||
<mass value="0.35" />
|
||||
<inertia ixx="0.0002473" ixy="1e-06" ixz="1e-06" iyy="1.763e-05" iyz="1e-06" izz="0.0002521" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="left_openarm_hand_joint" type="fixed">
|
||||
<parent link="openarm_left_link7" />
|
||||
<child link="openarm_left_hand" />
|
||||
<origin rpy="0 0 0" xyz="0 -0.0 0.1001" />
|
||||
</joint>
|
||||
<link name="openarm_left_hand_tcp">
|
||||
<inertial>
|
||||
<origin xyz="0 0 0" rpy="0 0 0" />
|
||||
<mass value="0.001" />
|
||||
<inertia ixx="0.000001" ixy="0.0" ixz="0.0" iyy="0.000001" iyz="0.0" izz="0.000001" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_hand_tcp_joint" type="fixed">
|
||||
<origin rpy="0 0 0" xyz="0 -0.0 0.08" />
|
||||
<parent link="openarm_left_hand" />
|
||||
<child link="openarm_left_hand_tcp" />
|
||||
</joint>
|
||||
<link name="openarm_left_left_finger">
|
||||
<visual name="openarm_left_left_finger_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_left_finger_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0 0 0" xyz="0.0064528 0.01702 0.0219685" />
|
||||
<mass value="0.03602545343277134" />
|
||||
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
|
||||
</inertial>
|
||||
</link>
|
||||
<link name="openarm_left_right_finger">
|
||||
<visual name="openarm_left_right_finger_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_right_finger_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0 0 0" xyz="0.0064528 -0.01702 0.0219685" />
|
||||
<mass value="0.03602545343277134" />
|
||||
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_finger_joint1" type="prismatic">
|
||||
<parent link="openarm_left_hand" />
|
||||
<child link="openarm_left_right_finger" />
|
||||
<origin rpy="0 0 0" xyz="0 -0.006 0.015" />
|
||||
<axis xyz="0 -1 0" />
|
||||
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
|
||||
</joint>
|
||||
<joint name="openarm_left_finger_joint2" type="prismatic">
|
||||
<parent link="openarm_left_hand" />
|
||||
<child link="openarm_left_left_finger" />
|
||||
<origin rpy="0 0 0" xyz="0 0.006 0.015" />
|
||||
<axis xyz="0 1 0" />
|
||||
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
|
||||
<mimic joint="openarm_left_finger_joint1" />
|
||||
</joint>
|
||||
<link name="openarm_right_hand">
|
||||
<visual name="openarm_right_hand_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/visual/hand.dae" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_hand_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/collision/hand.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0 0 0" xyz="0.0 0.002 0.03" />
|
||||
<mass value="0.35" />
|
||||
<inertia ixx="0.0002473" ixy="1e-06" ixz="1e-06" iyy="1.763e-05" iyz="1e-06" izz="0.0002521" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="right_openarm_hand_joint" type="fixed">
|
||||
<parent link="openarm_right_link7" />
|
||||
<child link="openarm_right_hand" />
|
||||
<origin rpy="0 0 0" xyz="0 -0.0 0.1001" />
|
||||
</joint>
|
||||
<link name="openarm_right_hand_tcp">
|
||||
<inertial>
|
||||
<origin xyz="0 0 0" rpy="0 0 0" />
|
||||
<mass value="0.001" />
|
||||
<inertia ixx="0.000001" ixy="0.0" ixz="0.0" iyy="0.000001" iyz="0.0" izz="0.000001" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_hand_tcp_joint" type="fixed">
|
||||
<origin rpy="0 0 0" xyz="0 -0.0 0.08" />
|
||||
<parent link="openarm_right_hand" />
|
||||
<child link="openarm_right_hand_tcp" />
|
||||
</joint>
|
||||
<link name="openarm_right_left_finger">
|
||||
<visual name="openarm_right_left_finger_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_left_finger_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0 0 0" xyz="0.0064528 0.01702 0.0219685" />
|
||||
<mass value="0.03602545343277134" />
|
||||
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
|
||||
</inertial>
|
||||
</link>
|
||||
<link name="openarm_right_right_finger">
|
||||
<visual name="openarm_right_right_finger_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_right_finger_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0 0 0" xyz="0.0064528 -0.01702 0.0219685" />
|
||||
<mass value="0.03602545343277134" />
|
||||
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_finger_joint1" type="prismatic">
|
||||
<parent link="openarm_right_hand" />
|
||||
<child link="openarm_right_right_finger" />
|
||||
<origin rpy="0 0 0" xyz="0 -0.006 0.015" />
|
||||
<axis xyz="0 -1 0" />
|
||||
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
|
||||
</joint>
|
||||
<joint name="openarm_right_finger_joint2" type="prismatic">
|
||||
<parent link="openarm_right_hand" />
|
||||
<child link="openarm_right_left_finger" />
|
||||
<origin rpy="0 0 0" xyz="0 0.006 0.015" />
|
||||
<axis xyz="0 1 0" />
|
||||
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
|
||||
<mimic joint="openarm_right_finger_joint1" />
|
||||
</joint>
|
||||
</robot>
|
||||
395
examples/openarms/record_with_compensation.py
Normal file
395
examples/openarms/record_with_compensation.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
OpenArms Dataset Recording with Gravity + Friction Compensation
|
||||
|
||||
Records a dataset using OpenArms follower robot with leader teleoperator.
|
||||
Leader arms have gravity and friction compensation for weightless, easy movement.
|
||||
Includes 3 cameras: left wrist, right wrist, and base camera.
|
||||
|
||||
Uses the same compensation approach as teleop_with_compensation.py
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
# Recording parameters
|
||||
NUM_EPISODES = 1
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 600
|
||||
RESET_TIME_SEC = 120
|
||||
TASK_DESCRIPTION = "OpenArms task description"
|
||||
|
||||
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
|
||||
FRICTION_SCALE = 1.0
|
||||
|
||||
def record_loop_with_compensation(
|
||||
robot,
|
||||
leader,
|
||||
events,
|
||||
fps,
|
||||
dataset,
|
||||
dataset_features,
|
||||
control_time_s,
|
||||
single_task,
|
||||
display_data=True,
|
||||
):
|
||||
"""
|
||||
Custom record loop that applies gravity + friction compensation to leader.
|
||||
Based on record_loop but with integrated compensation.
|
||||
"""
|
||||
dt = 1 / fps
|
||||
episode_start_time = time.perf_counter()
|
||||
|
||||
# All joints (both arms)
|
||||
all_joints = []
|
||||
for motor in leader.bus_right.motors:
|
||||
all_joints.append(f"right_{motor}")
|
||||
for motor in leader.bus_left.motors:
|
||||
all_joints.append(f"left_{motor}")
|
||||
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
elapsed = loop_start - episode_start_time
|
||||
|
||||
# Check if we should exit
|
||||
if elapsed >= control_time_s or events["exit_early"] or events["stop_recording"]:
|
||||
break
|
||||
|
||||
# Get leader state
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Extract positions and velocities in degrees
|
||||
leader_positions_deg = {}
|
||||
leader_velocities_deg_per_sec = {}
|
||||
|
||||
for motor in leader.bus_right.motors:
|
||||
pos_key = f"right_{motor}.pos"
|
||||
vel_key = f"right_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
pos_key = f"left_{motor}.pos"
|
||||
vel_key = f"left_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||
|
||||
# Calculate gravity torques for leader using built-in method
|
||||
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||
|
||||
# Calculate friction torques for leader using built-in method
|
||||
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||
leader_velocities_rad_per_sec,
|
||||
friction_scale=FRICTION_SCALE
|
||||
)
|
||||
|
||||
# Combine gravity + friction torques
|
||||
leader_total_torques_nm = {}
|
||||
for motor_name in leader_gravity_torques_nm:
|
||||
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||
leader_total_torques_nm[motor_name] = gravity + friction
|
||||
|
||||
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
|
||||
for motor in leader.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
|
||||
for motor in leader.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Send leader positions to follower (both arms)
|
||||
follower_action = {}
|
||||
for joint in all_joints:
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
follower_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
# Send action to robot
|
||||
if follower_action:
|
||||
robot.send_action(follower_action)
|
||||
|
||||
# Get observation from robot (includes camera images)
|
||||
observation = robot.get_observation()
|
||||
|
||||
# Add to dataset if we have a dataset
|
||||
if dataset is not None:
|
||||
# Build properly formatted observation frame
|
||||
obs_frame = build_dataset_frame(dataset_features, observation, prefix="observation")
|
||||
|
||||
# Build properly formatted action frame (keep .pos suffix - it matches the feature names)
|
||||
action_frame = build_dataset_frame(dataset_features, follower_action, prefix="action")
|
||||
|
||||
# Combine into single frame
|
||||
frame = {**obs_frame, **action_frame}
|
||||
|
||||
# Add metadata (task is required, timestamp will be auto-calculated by add_frame)
|
||||
frame["task"] = single_task
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# Display data if requested
|
||||
if display_data:
|
||||
log_rerun_data(observation=observation, action=follower_action)
|
||||
|
||||
# Maintain loop rate
|
||||
loop_duration = time.perf_counter() - loop_start
|
||||
sleep_time = dt - loop_duration
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main recording loop with gravity compensation."""
|
||||
|
||||
print("=" * 70)
|
||||
print("OpenArms Dataset Recording with Compensation")
|
||||
print("=" * 70)
|
||||
|
||||
# Create camera configurations (3 cameras: left wrist, right wrist, base)
|
||||
# Using actual device paths found by lerobot-find-cameras opencv
|
||||
camera_config = {
|
||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video0", width=640, height=480, fps=FPS),
|
||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
|
||||
"base": OpenCVCameraConfig(index_or_path="/dev/video7", width=640, height=480, fps=FPS),
|
||||
}
|
||||
|
||||
# Configure follower robot with cameras
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can2",
|
||||
port_right="can3",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
cameras=camera_config,
|
||||
)
|
||||
|
||||
# Configure leader teleoperator (no cameras needed)
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_leader",
|
||||
manual_control=False, # Enable torque control for gravity compensation
|
||||
)
|
||||
|
||||
# Initialize robot and teleoperator
|
||||
print("\nInitializing devices...")
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
|
||||
# Connect devices
|
||||
print("Connecting and calibrating...")
|
||||
follower.connect(calibrate=True)
|
||||
leader.connect(calibrate=True)
|
||||
|
||||
# Verify URDF is loaded for gravity compensation
|
||||
if leader.pin_robot is None:
|
||||
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
|
||||
|
||||
# Configure the dataset features
|
||||
# For actions, we only want to record positions (not velocity or torque)
|
||||
action_features_hw = {}
|
||||
for key, value in follower.action_features.items():
|
||||
if key.endswith(".pos"):
|
||||
action_features_hw[key] = value
|
||||
|
||||
action_features = hw_to_dataset_features(action_features_hw, "action")
|
||||
obs_features = hw_to_dataset_features(follower.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
print("\nCreating dataset...")
|
||||
repo_id = "<hf_username>/<dataset_repo_id>" # TODO: Replace with your Hugging Face repo
|
||||
|
||||
# Check if dataset already exists and prompt user
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
|
||||
while dataset_path.exists():
|
||||
print(f"\nDataset already exists at: {dataset_path}")
|
||||
print("\nOptions:")
|
||||
print(" 1. Overwrite existing dataset")
|
||||
print(" 2. Use a different name")
|
||||
print(" 3. Abort")
|
||||
|
||||
choice = input("\nEnter your choice (1/2/3): ").strip()
|
||||
|
||||
if choice == '1':
|
||||
print(f"Removing existing dataset...")
|
||||
shutil.rmtree(dataset_path)
|
||||
print("✓ Existing dataset removed")
|
||||
break
|
||||
elif choice == '2':
|
||||
print("\nCurrent repo_id:", repo_id)
|
||||
new_repo_id = input("Enter new repo_id (format: <username>/<dataset_name>): ").strip()
|
||||
if new_repo_id and '/' in new_repo_id:
|
||||
repo_id = new_repo_id
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
|
||||
print(f"✓ Using new repo_id: {repo_id}")
|
||||
# Loop will continue if this new path also exists
|
||||
else:
|
||||
print("Invalid repo_id format. Please use format: <username>/<dataset_name>")
|
||||
elif choice == '3':
|
||||
print("Aborting. Please remove the existing dataset manually or restart with a different repo_id.")
|
||||
follower.disconnect()
|
||||
leader.disconnect()
|
||||
return
|
||||
else:
|
||||
print("Invalid choice. Please enter 1, 2, or 3.")
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize keyboard listener and visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="openarms_recording")
|
||||
|
||||
# Enable motors on both leader arms for gravity compensation
|
||||
leader.bus_right.enable_torque()
|
||||
leader.bus_left.enable_torque()
|
||||
time.sleep(0.1)
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print(f"Recording {NUM_EPISODES} episodes")
|
||||
print(f"Task: {TASK_DESCRIPTION}")
|
||||
print("=" * 70)
|
||||
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
|
||||
print("\nKeyboard controls:")
|
||||
print(" - Press 'q' to stop recording")
|
||||
print(" - Press 'r' to re-record current episode")
|
||||
print("=" * 70)
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
try:
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Record episode with compensation active
|
||||
record_loop_with_compensation(
|
||||
robot=follower,
|
||||
leader=leader,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
dataset_features=dataset_features,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop_with_compensation(
|
||||
robot=follower,
|
||||
leader=leader,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=None, # Don't save reset period
|
||||
dataset_features=dataset_features,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Handle re-recording
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Only save episode if frames were recorded
|
||||
if dataset.episode_buffer is not None and dataset.episode_buffer["size"] > 0:
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
else:
|
||||
log_say("No frames recorded, skipping episode save")
|
||||
# Clear the empty buffer
|
||||
dataset.episode_buffer = None
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping recording...")
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
try:
|
||||
leader.bus_right.disable_torque()
|
||||
leader.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
print("✓ Shutdown complete")
|
||||
except Exception as e:
|
||||
print(f"Shutdown error: {e}")
|
||||
|
||||
# Upload dataset
|
||||
print("\nUploading dataset to Hugging Face Hub...")
|
||||
try:
|
||||
dataset.push_to_hub()
|
||||
print("✓ Dataset uploaded successfully")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to upload dataset: {e}")
|
||||
print("You can manually upload later using: dataset.push_to_hub()")
|
||||
|
||||
print("✓ Recording complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
166
examples/openarms/replay.py
Normal file
166
examples/openarms/replay.py
Normal file
@@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenArms Dataset Replay Example
|
||||
|
||||
Replays position actions from a recorded dataset on an OpenArms follower robot.
|
||||
Only position commands (ending with .pos) are replayed, not velocity or torque.
|
||||
|
||||
Example usage:
|
||||
python examples/openarms/replay.py
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
# Configuration
|
||||
EPISODE_IDX = 0
|
||||
DATASET_REPO_ID = "lerobot-data-collection/replay-this-2025-11-02-17-58" # TODO: Replace with your dataset
|
||||
DATASET_ROOT = None # Use default cache location, or specify custom path
|
||||
|
||||
# Robot configuration - adjust these to match your setup
|
||||
ROBOT_CONFIG = OpenArmsFollowerConfig(
|
||||
port_left="can2", # CAN interface for left arm
|
||||
port_right="can3", # CAN interface for right arm
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0, # Safety limit: max degrees to move per step
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main replay function."""
|
||||
print("=" * 70)
|
||||
print("OpenArms Dataset Replay")
|
||||
print("=" * 70)
|
||||
print(f"\nDataset: {DATASET_REPO_ID}")
|
||||
print(f"Episode: {EPISODE_IDX}")
|
||||
print(f"Robot: {ROBOT_CONFIG.id}")
|
||||
print(f" Left arm: {ROBOT_CONFIG.port_left}")
|
||||
print(f" Right arm: {ROBOT_CONFIG.port_right}")
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
# Initialize the robot
|
||||
print("\n[1/3] Initializing robot...")
|
||||
robot = OpenArmsFollower(ROBOT_CONFIG)
|
||||
|
||||
# Load the dataset
|
||||
print(f"\n[2/3] Loading dataset '{DATASET_REPO_ID}'...")
|
||||
dataset = LeRobotDataset(
|
||||
DATASET_REPO_ID,
|
||||
root=DATASET_ROOT,
|
||||
episodes=[EPISODE_IDX]
|
||||
)
|
||||
|
||||
# Filter dataset to only include frames from the specified episode
|
||||
# (required for dataset V3.0 where episodes are chunked)
|
||||
episode_frames = dataset.hf_dataset.filter(
|
||||
lambda x: x["episode_index"] == EPISODE_IDX
|
||||
)
|
||||
|
||||
if len(episode_frames) == 0:
|
||||
raise ValueError(
|
||||
f"No frames found for episode {EPISODE_IDX} in dataset {DATASET_REPO_ID}"
|
||||
)
|
||||
|
||||
print(f" Found {len(episode_frames)} frames in episode {EPISODE_IDX}")
|
||||
|
||||
# Extract action features from dataset
|
||||
action_features = dataset.features.get(ACTION, {})
|
||||
action_names = action_features.get("names", [])
|
||||
|
||||
# Filter to only position actions (ending with .pos)
|
||||
position_action_names = [name for name in action_names if name.endswith(".pos")]
|
||||
|
||||
if not position_action_names:
|
||||
raise ValueError(
|
||||
f"No position actions found in dataset. Action names: {action_names}"
|
||||
)
|
||||
|
||||
print(f" Found {len(position_action_names)} position actions to replay")
|
||||
print(f" Actions: {', '.join(position_action_names[:5])}{'...' if len(position_action_names) > 5 else ''}")
|
||||
|
||||
# Select only action columns from dataset
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
print(f"\n[3/3] Connecting to robot...")
|
||||
robot.connect(calibrate=False) # Skip calibration for replay
|
||||
|
||||
if not robot.is_connected:
|
||||
raise RuntimeError("Robot failed to connect!")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Ready to replay!")
|
||||
print("=" * 70)
|
||||
print("\nThe robot will replay the recorded positions.")
|
||||
print("Press Ctrl+C to stop at any time.\n")
|
||||
|
||||
input("Press ENTER to start replaying...")
|
||||
|
||||
# Replay loop
|
||||
log_say(f"Replaying episode {EPISODE_IDX}", blocking=True)
|
||||
|
||||
try:
|
||||
for idx in range(len(episode_frames)):
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Extract action array from dataset
|
||||
action_array = actions[idx][ACTION]
|
||||
|
||||
# Build action dictionary, but only include position actions
|
||||
action = {}
|
||||
for i, name in enumerate(action_names):
|
||||
# Only include position actions (ending with .pos)
|
||||
if name.endswith(".pos"):
|
||||
action[name] = float(action_array[i])
|
||||
|
||||
# Send action to robot
|
||||
robot.send_action(action)
|
||||
|
||||
# Maintain replay rate (use dataset fps)
|
||||
loop_duration = time.perf_counter() - loop_start
|
||||
dt_s = 1.0 / dataset.fps - loop_duration
|
||||
busy_wait(dt_s)
|
||||
|
||||
# Progress indicator every 100 frames
|
||||
if (idx + 1) % 100 == 0:
|
||||
progress = (idx + 1) / len(episode_frames) * 100
|
||||
print(f"Progress: {idx + 1}/{len(episode_frames)} frames ({progress:.1f}%)")
|
||||
|
||||
print(f"\n✓ Successfully replayed {len(episode_frames)} frames")
|
||||
log_say("Replay complete", blocking=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nReplay interrupted by user")
|
||||
finally:
|
||||
# Disconnect robot
|
||||
print("\nDisconnecting robot...")
|
||||
robot.disconnect()
|
||||
print("✓ Replay complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
73
examples/openarms/setup_can.sh
Executable file
73
examples/openarms/setup_can.sh
Executable file
@@ -0,0 +1,73 @@
|
||||
#!/bin/bash
|
||||
# Setup all OpenArms CAN interfaces with CAN FD
|
||||
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo "OpenArms CAN FD Interface Setup"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "Mode: CAN FD"
|
||||
echo " - Nominal bitrate: 1 Mbps"
|
||||
echo " - Data bitrate: 5 Mbps"
|
||||
echo ""
|
||||
echo "Configuring interfaces can0, can1, can2, can3..."
|
||||
echo ""
|
||||
|
||||
# Configure each CAN interface with CAN FD
|
||||
for i in 0 1 2 3; do
|
||||
interface="can$i"
|
||||
|
||||
# Check if interface exists
|
||||
if ! ip link show "$interface" &> /dev/null; then
|
||||
echo "⚠ $interface: Not found, skipping"
|
||||
continue
|
||||
fi
|
||||
|
||||
# Bring down interface
|
||||
sudo ip link set "$interface" down 2>/dev/null
|
||||
|
||||
# Configure CAN FD mode
|
||||
sudo ip link set "$interface" type can \
|
||||
bitrate 1000000 \
|
||||
dbitrate 5000000 \
|
||||
fd on
|
||||
|
||||
# Bring up interface
|
||||
sudo ip link set "$interface" up
|
||||
|
||||
# Verify configuration
|
||||
if ip link show "$interface" | grep -q "UP"; then
|
||||
echo "✓ $interface: Configured and UP"
|
||||
else
|
||||
echo "✗ $interface: Failed to bring UP"
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Verification"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Show detailed status for each interface
|
||||
for i in 0 1 2 3; do
|
||||
interface="can$i"
|
||||
if ip link show "$interface" &> /dev/null; then
|
||||
echo "$interface:"
|
||||
# Show key parameters
|
||||
ip -d link show "$interface" | grep -E "can|state|bitrate|dbitrate" | head -3
|
||||
echo ""
|
||||
fi
|
||||
done
|
||||
|
||||
echo "=========================================="
|
||||
echo "Setup Complete!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "All interfaces configured for CAN FD mode"
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo " 1. Test motors: python debug_can_communication.py"
|
||||
echo " 2. Run teleoperation: python examples/openarms/teleop.py"
|
||||
echo ""
|
||||
148
examples/openarms/teleop.py
Normal file
148
examples/openarms/teleop.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
OpenArms Teleoperation Example - Full Dual Arms
|
||||
|
||||
This script demonstrates teleoperation of OpenArms follower robot using an OpenArms leader arm.
|
||||
It first calibrates both devices, then enters a teleoperation loop for both arms.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
|
||||
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can2", # CAN interface for follower left arm
|
||||
port_right="can3", # CAN interface for follower right arm
|
||||
can_interface="socketcan", # Linux SocketCAN
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=5.0, # Safety limit
|
||||
)
|
||||
|
||||
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left="can0", # CAN interface for leader left arm
|
||||
port_right="can1", # CAN interface for leader right arm
|
||||
can_interface="socketcan", # Linux SocketCAN
|
||||
id="openarms_leader",
|
||||
manual_control=True, # Enable manual control (torque disabled)
|
||||
)
|
||||
|
||||
print("=" * 60)
|
||||
print("OpenArms Teleoperation - Full Dual Arms")
|
||||
print("=" * 60)
|
||||
|
||||
# Initialize devices
|
||||
print("\n[1/4] Initializing devices...")
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
|
||||
# Connect and calibrate follower
|
||||
print("\n[2/4] Connecting and calibrating follower robot...")
|
||||
print("Note: If you have existing calibration, just press ENTER to use it.")
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
# Connect and calibrate leader
|
||||
print("\n[3/4] Connecting and calibrating leader arm...")
|
||||
print("Note: The leader arm will have torque disabled for manual control.")
|
||||
leader.connect(calibrate=True)
|
||||
|
||||
# Wait for user to be ready
|
||||
print("\n[4/4] Ready for teleoperation!")
|
||||
print("\nBoth arms will be controlled (16 motors total):")
|
||||
print(" RIGHT ARM: joints 1-7 + gripper")
|
||||
print(" LEFT ARM: joints 1-7 + gripper")
|
||||
|
||||
print("\nPress ENTER to start teleoperation...")
|
||||
input()
|
||||
|
||||
print("\nTeleoperation started! Move both leader arms.")
|
||||
print("Press Ctrl+C to stop.\n")
|
||||
|
||||
# All joints for both arms (16 motors total)
|
||||
all_joints = [
|
||||
# Right arm
|
||||
"right_joint_1",
|
||||
"right_joint_2",
|
||||
"right_joint_3",
|
||||
"right_joint_4",
|
||||
"right_joint_5",
|
||||
"right_joint_6",
|
||||
"right_joint_7",
|
||||
"right_gripper",
|
||||
# Left arm
|
||||
"left_joint_1",
|
||||
"left_joint_2",
|
||||
"left_joint_3",
|
||||
"left_joint_4",
|
||||
"left_joint_5",
|
||||
"left_joint_6",
|
||||
"left_joint_7",
|
||||
"left_gripper",
|
||||
]
|
||||
|
||||
# Performance monitoring
|
||||
loop_times = []
|
||||
start_time = time.perf_counter()
|
||||
last_print_time = start_time
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get action from leader
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Filter to only position data for all joints (both arms)
|
||||
joint_action = {}
|
||||
for joint in all_joints:
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
joint_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
# Send action to follower (both arms)
|
||||
if joint_action:
|
||||
follower.send_action(joint_action)
|
||||
|
||||
# Measure loop time
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Print stats every 2 seconds
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
min_time = min(loop_times)
|
||||
max_time = max(loop_times)
|
||||
max_hz = 1.0 / min_time if min_time > 0 else 0
|
||||
min_hz = 1.0 / max_time if max_time > 0 else 0
|
||||
|
||||
print(f"[Hz Stats] Avg: {current_hz:.1f} Hz | "
|
||||
f"Range: {min_hz:.1f}-{max_hz:.1f} Hz | "
|
||||
f"Avg loop time: {avg_time*1000:.1f} ms")
|
||||
|
||||
# Reset for next measurement window
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping teleoperation...")
|
||||
finally:
|
||||
# Disconnect devices
|
||||
print("Disconnecting devices...")
|
||||
try:
|
||||
follower.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting follower: {e}")
|
||||
|
||||
try:
|
||||
leader.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting leader: {e}")
|
||||
|
||||
print("Done!")
|
||||
197
examples/openarms/teleop_openarms_mini.py
Normal file
197
examples/openarms/teleop_openarms_mini.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
OpenArms Mini Teleoperation Example
|
||||
|
||||
This script demonstrates teleoperation of an OpenArms follower robot using
|
||||
an OpenArms Mini leader (Feetech-based) with dual arms (16 motors total).
|
||||
|
||||
The OpenArms Mini has:
|
||||
- Right arm: 8 motors (joint_1 to joint_7 + gripper)
|
||||
- Left arm: 8 motors (joint_1 to joint_7 + gripper)
|
||||
|
||||
Note on gripper normalization:
|
||||
- OpenArms Mini gripper: 0-100 scale (0=closed, 100=open)
|
||||
- OpenArms follower gripper: degrees (0=closed, -65=open)
|
||||
- This script automatically converts between the two ranges
|
||||
"""
|
||||
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.teleoperators.openarms_mini.openarms_mini import OpenArmsMini
|
||||
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
|
||||
# Target control frequency
|
||||
TARGET_FPS = 30
|
||||
|
||||
# Configure the OpenArms follower (Damiao motors on CAN bus)
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can0", # CAN interface for follower left arm
|
||||
port_right="can1", # CAN interface for follower right arm
|
||||
can_interface="socketcan", # Linux SocketCAN
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0, # Safety limit (degrees per step)
|
||||
)
|
||||
|
||||
# Configure the OpenArms Mini leader (Feetech motors on serial)
|
||||
leader_config = OpenArmsMiniConfig(
|
||||
port_right="/dev/ttyACM0", # Serial port for right arm
|
||||
port_left="/dev/ttyACM1", # Serial port for left arm
|
||||
id="openarms_mini",
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
print("OpenArms Mini → OpenArms Follower Teleoperation")
|
||||
|
||||
# Initialize devices
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsMini(leader_config)
|
||||
|
||||
# Connect and calibrate follower
|
||||
print("Note: If you have existing calibration, just press ENTER to use it.")
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
# Connect and calibrate leader
|
||||
print("Note: The leader arms will have torque disabled for manual control.")
|
||||
leader.connect(calibrate=True)
|
||||
|
||||
print("\nPress ENTER to start teleoperation...")
|
||||
input()
|
||||
|
||||
print("Press Ctrl+C to stop.\n")
|
||||
|
||||
# All joints for both arms (16 motors total)
|
||||
all_joints = [
|
||||
# Right arm
|
||||
"right_joint_1",
|
||||
"right_joint_2",
|
||||
"right_joint_3",
|
||||
"right_joint_4",
|
||||
"right_joint_5",
|
||||
"right_joint_6",
|
||||
"right_joint_7",
|
||||
"right_gripper",
|
||||
# Left arm
|
||||
"left_joint_1",
|
||||
"left_joint_2",
|
||||
"left_joint_3",
|
||||
"left_joint_4",
|
||||
"left_joint_5",
|
||||
"left_joint_6",
|
||||
"left_joint_7",
|
||||
"left_gripper",
|
||||
]
|
||||
|
||||
# Performance monitoring
|
||||
loop_times = []
|
||||
avg_loop_time = 0.0
|
||||
min_loop_time = float('inf')
|
||||
max_loop_time = 0.0
|
||||
stats_update_interval = 1.0 # Update stats every 1 second
|
||||
last_stats_update = time.perf_counter()
|
||||
|
||||
|
||||
SWAPPED_JOINTS = {
|
||||
"right_joint_6": "right_joint_7",
|
||||
"right_joint_7": "right_joint_6",
|
||||
"left_joint_6": "left_joint_7",
|
||||
"left_joint_7": "left_joint_6",
|
||||
}
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get actions and observations
|
||||
leader_action = leader.get_action()
|
||||
follower_obs = follower.get_observation()
|
||||
|
||||
joint_action = {}
|
||||
for joint in all_joints:
|
||||
leader_key = f"{joint}.pos"
|
||||
|
||||
# Determine which follower joint this leader joint controls
|
||||
follower_joint = SWAPPED_JOINTS.get(joint, joint)
|
||||
follower_key = f"{follower_joint}.pos"
|
||||
|
||||
# Get leader position (default 0 if missing)
|
||||
pos = leader_action.get(leader_key, 0.0)
|
||||
|
||||
# Convert gripper values: Mini uses 0-100, OpenArms uses 0 to -65 degrees
|
||||
if "gripper" in joint:
|
||||
# Map 0-100 (Mini) to 0 to -65 (OpenArms)
|
||||
# 0 (closed) -> 0°, 100 (open) -> -65°
|
||||
pos = (pos / 100.0) * -65.0
|
||||
|
||||
# Store in action dict for follower
|
||||
joint_action[follower_key] = pos
|
||||
|
||||
follower.send_action(joint_action)
|
||||
|
||||
# Loop timing
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Update stats periodically
|
||||
current_time = time.perf_counter()
|
||||
if current_time - last_stats_update >= stats_update_interval:
|
||||
if loop_times:
|
||||
avg_loop_time = sum(loop_times) / len(loop_times)
|
||||
min_loop_time = min(loop_times)
|
||||
max_loop_time = max(loop_times)
|
||||
loop_times = []
|
||||
last_stats_update = current_time
|
||||
|
||||
# Display everything
|
||||
sys.stdout.write("\033[H\033[J") # Clear screen
|
||||
|
||||
# Show timing stats at the top
|
||||
if avg_loop_time > 0:
|
||||
avg_hz = 1.0 / avg_loop_time
|
||||
min_hz = 1.0 / max_loop_time if max_loop_time > 0 else 0
|
||||
max_hz = 1.0 / min_loop_time if min_loop_time > 0 and min_loop_time < float('inf') else 0
|
||||
print(f"[Performance] Target: {TARGET_FPS} Hz | Avg: {avg_hz:.1f} Hz | Range: {min_hz:.1f}-{max_hz:.1f} Hz | Loop: {avg_loop_time*1000:.1f} ms\n")
|
||||
else:
|
||||
print(f"[Performance] Target: {TARGET_FPS} Hz | Measuring...\n")
|
||||
|
||||
# Show joint positions
|
||||
print(f"{'Joint':<20} {'Leader':>15} {'Follower':>15}")
|
||||
print(f"{'':20} {'(0-100/deg)':>15} {'(deg)':>15}")
|
||||
print("-" * 52)
|
||||
|
||||
for joint in all_joints:
|
||||
leader_key = f"{joint}.pos"
|
||||
follower_joint = SWAPPED_JOINTS.get(joint, joint)
|
||||
follower_key = f"{follower_joint}.pos"
|
||||
|
||||
leader_pos = leader_action.get(leader_key, 0.0)
|
||||
follower_pos = follower_obs.get(follower_key, 0.0)
|
||||
|
||||
print(f"{joint:<20} {leader_pos:>15.2f} {follower_pos:>15.2f}")
|
||||
|
||||
# Smart sleep to maintain target FPS
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
busy_wait(max(0, 1.0 / TARGET_FPS - dt_s))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping teleoperation...")
|
||||
finally:
|
||||
# Disconnect devices
|
||||
print("Disconnecting devices...")
|
||||
try:
|
||||
follower.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting follower: {e}")
|
||||
|
||||
try:
|
||||
leader.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting leader: {e}")
|
||||
|
||||
print("Done!")
|
||||
|
||||
202
examples/openarms/teleop_with_compensation.py
Executable file
202
examples/openarms/teleop_with_compensation.py
Executable file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
OpenArms Teleoperation with Gravity + Friction Compensation
|
||||
|
||||
Leader arms (both LEFT and RIGHT): Gravity + Friction compensation (weightless, easy to move)
|
||||
Follower arms (both LEFT and RIGHT): Mirror leader movements
|
||||
|
||||
Uses the URDF file from the lerobot repository.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
|
||||
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
|
||||
FRICTION_SCALE = 1.0
|
||||
|
||||
def main():
|
||||
"""Main teleoperation loop with gravity compensation"""
|
||||
|
||||
print("=" * 70)
|
||||
print("OpenArms Teleoperation with Gravity Compensation")
|
||||
print("=" * 70)
|
||||
|
||||
# Configuration
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can2",
|
||||
port_right="can3",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
)
|
||||
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_leader",
|
||||
manual_control=False, # Enable torque control for gravity compensation
|
||||
)
|
||||
|
||||
# Initialize and connect
|
||||
print("\nInitializing devices...")
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
|
||||
# URDF is automatically loaded in the leader constructor
|
||||
if leader.pin_robot is None:
|
||||
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
|
||||
|
||||
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
|
||||
print("Press ENTER to start...")
|
||||
input()
|
||||
|
||||
# Enable motors on both leader arms for gravity compensation
|
||||
leader.bus_right.enable_torque()
|
||||
leader.bus_left.enable_torque()
|
||||
time.sleep(0.1)
|
||||
|
||||
print("Press Ctrl+C to stop\n")
|
||||
|
||||
# Main control loop
|
||||
loop_times = []
|
||||
last_print_time = time.perf_counter()
|
||||
|
||||
# All joints (both arms)
|
||||
all_joints = []
|
||||
for motor in leader.bus_right.motors:
|
||||
all_joints.append(f"right_{motor}")
|
||||
for motor in leader.bus_left.motors:
|
||||
all_joints.append(f"left_{motor}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get leader state
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Extract positions and velocities in degrees
|
||||
leader_positions_deg = {}
|
||||
leader_velocities_deg_per_sec = {}
|
||||
|
||||
for motor in leader.bus_right.motors:
|
||||
pos_key = f"right_{motor}.pos"
|
||||
vel_key = f"right_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
pos_key = f"left_{motor}.pos"
|
||||
vel_key = f"left_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||
|
||||
# Calculate gravity torques for leader using built-in method
|
||||
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||
|
||||
# Calculate friction torques for leader using built-in method
|
||||
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||
leader_velocities_rad_per_sec,
|
||||
friction_scale=FRICTION_SCALE
|
||||
)
|
||||
|
||||
# Combine gravity + friction torques
|
||||
leader_total_torques_nm = {}
|
||||
for motor_name in leader_gravity_torques_nm:
|
||||
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||
leader_total_torques_nm[motor_name] = gravity + friction
|
||||
|
||||
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
|
||||
for motor in leader.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
|
||||
for motor in leader.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Send leader positions to follower (both arms)
|
||||
follower_action = {}
|
||||
for joint in all_joints:
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
follower_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
if follower_action:
|
||||
follower.send_action(follower_action)
|
||||
|
||||
# Performance monitoring
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
|
||||
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping...")
|
||||
finally:
|
||||
try:
|
||||
leader.bus_right.disable_torque()
|
||||
leader.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
print("✓ Shutdown complete")
|
||||
except Exception as e:
|
||||
print(f"Shutdown error: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
76
examples/openarms/test_r_arm.py
Normal file
76
examples/openarms/test_r_arm.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import time
|
||||
import math
|
||||
import numpy as np
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
|
||||
|
||||
def main():
|
||||
cfg = OpenArmsFollowerConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_test",
|
||||
manual_control=True, # direct position control
|
||||
)
|
||||
|
||||
print('connecting...')
|
||||
rob = OpenArmsFollower(cfg)
|
||||
rob.connect(calibrate=True)
|
||||
|
||||
# disable left torque fully — keep it still
|
||||
rob.bus_left.disable_torque()
|
||||
|
||||
# desired angular sweep = 1/4 of current joint range
|
||||
sweep_deg = 20.0 # tweak if you want bigger movement
|
||||
|
||||
# frequency of movement
|
||||
hz = 100.0
|
||||
dt = 1.0 / hz
|
||||
move_time = 1.0 # seconds per joint
|
||||
|
||||
print('starting right–arm joint test…')
|
||||
print('support the arm and keep clear')
|
||||
|
||||
time.sleep(1.0)
|
||||
|
||||
# iterate motors except gripper
|
||||
for motor in rob.bus_right.motors:
|
||||
if motor == 'gripper':
|
||||
continue
|
||||
|
||||
print(f'testing {motor} on right arm...')
|
||||
start = time.time()
|
||||
|
||||
# read current position as center
|
||||
obs = rob.get_action()
|
||||
key = f'right_{motor}.pos'
|
||||
center = obs.get(key, 0.0)
|
||||
|
||||
t = 0.0
|
||||
while time.time() - start < move_time:
|
||||
offset = sweep_deg * math.sin(2 * math.pi * t)
|
||||
pos_cmd = center + offset
|
||||
|
||||
rob.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=3.0, # some stiffness so it tracks well
|
||||
kd=0.2,
|
||||
position_degrees=pos_cmd,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=0.0
|
||||
)
|
||||
|
||||
t += dt
|
||||
time.sleep(dt)
|
||||
|
||||
print(f'done {motor}')
|
||||
|
||||
print('\nall right–arm joints tested')
|
||||
print('disabling torque…')
|
||||
rob.bus_right.disable_torque()
|
||||
rob.disconnect()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
745
examples/openarms_web_interface/App.css
Normal file
745
examples/openarms_web_interface/App.css
Normal file
@@ -0,0 +1,745 @@
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
|
||||
main {
|
||||
min-height: 100vh;
|
||||
padding: 2rem;
|
||||
}
|
||||
|
||||
header {
|
||||
text-align: center;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 2rem;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
h2 {
|
||||
font-size: 1.25rem;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin: 0 0 1rem 0;
|
||||
}
|
||||
|
||||
h3 {
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
color: #666;
|
||||
margin: 0 0 0.5rem 0;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1920px;
|
||||
margin: 0 auto;
|
||||
display: grid;
|
||||
grid-template-columns: minmax(500px, 600px) 1fr;
|
||||
gap: 2rem;
|
||||
align-items: start;
|
||||
}
|
||||
|
||||
/* Left column container */
|
||||
.left-column {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
/* Right column container */
|
||||
.right-column {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
/* Responsive: Stack on smaller screens */
|
||||
@media (max-width: 1200px) {
|
||||
.container {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
.panel {
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.config-panel {
|
||||
border: 2px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.config-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.config-header:hover {
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
.toggle-icon {
|
||||
font-size: 1rem;
|
||||
color: #6b7280;
|
||||
transition: transform 0.2s;
|
||||
}
|
||||
|
||||
.config-content {
|
||||
margin-top: 1rem;
|
||||
padding-top: 1rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.robot-setup {
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.robot-status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 1rem;
|
||||
border-radius: 6px;
|
||||
font-weight: 500;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.robot-status.ready {
|
||||
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
|
||||
color: #065f46;
|
||||
border: 1px solid #10b981;
|
||||
}
|
||||
|
||||
.robot-status.not-ready {
|
||||
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||
color: #92400e;
|
||||
border: 1px solid #f59e0b;
|
||||
}
|
||||
|
||||
.btn-setup {
|
||||
background: #10b981;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-setup:hover:not(:disabled) {
|
||||
background: #059669;
|
||||
}
|
||||
|
||||
.btn-setup:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-zero {
|
||||
background: #8b5cf6;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-zero:hover:not(:disabled) {
|
||||
background: #7c3aed;
|
||||
}
|
||||
|
||||
.btn-zero:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.zero-position-section {
|
||||
margin-top: 1rem;
|
||||
padding-top: 1rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.btn-zero-large {
|
||||
width: 100%;
|
||||
background: #8b5cf6;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.875rem 1.5rem;
|
||||
border-radius: 8px;
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
box-shadow: 0 2px 4px rgba(139, 92, 246, 0.2);
|
||||
}
|
||||
|
||||
.btn-zero-large:hover:not(:disabled) {
|
||||
background: #7c3aed;
|
||||
box-shadow: 0 4px 8px rgba(139, 92, 246, 0.3);
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.btn-zero-large:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
box-shadow: none;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
.delete-episode-section {
|
||||
margin-top: 1rem;
|
||||
padding-top: 1rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.btn-delete {
|
||||
width: 100%;
|
||||
background: #ef4444;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.875rem 1.5rem;
|
||||
border-radius: 8px;
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
box-shadow: 0 2px 4px rgba(239, 68, 68, 0.2);
|
||||
}
|
||||
|
||||
.btn-delete:hover:not(:disabled) {
|
||||
background: #dc2626;
|
||||
box-shadow: 0 4px 8px rgba(239, 68, 68, 0.3);
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.btn-delete:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
box-shadow: none;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
.delete-info {
|
||||
margin-top: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
color: #666;
|
||||
text-align: center;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.btn-disconnect {
|
||||
background: #ef4444;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-disconnect:hover {
|
||||
background: #dc2626;
|
||||
}
|
||||
|
||||
.btn-refresh {
|
||||
background: #3b82f6;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.4rem 0.8rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-refresh:hover:not(:disabled) {
|
||||
background: #2563eb;
|
||||
}
|
||||
|
||||
.btn-refresh:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.control-panel {
|
||||
border: 2px solid #10b981;
|
||||
}
|
||||
|
||||
.status-banner {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 1rem;
|
||||
padding: 1rem 1.5rem;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 1.5rem;
|
||||
font-weight: 500;
|
||||
font-size: 0.95rem;
|
||||
}
|
||||
|
||||
.status-banner.initializing {
|
||||
background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%);
|
||||
color: #1e40af;
|
||||
border-left: 4px solid #3b82f6;
|
||||
}
|
||||
|
||||
.status-banner.encoding {
|
||||
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||
color: #92400e;
|
||||
border-left: 4px solid #f59e0b;
|
||||
}
|
||||
|
||||
.status-banner.uploading {
|
||||
background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%);
|
||||
color: #3730a3;
|
||||
border-left: 4px solid #6366f1;
|
||||
}
|
||||
|
||||
.status-banner.success {
|
||||
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
|
||||
color: #065f46;
|
||||
border-left: 4px solid #10b981;
|
||||
}
|
||||
|
||||
.status-banner.warning {
|
||||
background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%);
|
||||
color: #991b1b;
|
||||
border-left: 4px solid #ef4444;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
border: 3px solid rgba(0, 0, 0, 0.1);
|
||||
border-top-color: currentColor;
|
||||
border-radius: 50%;
|
||||
animation: spin 0.8s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
to { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
.control-horizontal {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.control-left {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.control-right {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.input-group {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
input[type="text"] {
|
||||
flex: 1;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
input[type="text"]:disabled {
|
||||
background: #f5f5f5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
input[type="text"]:focus {
|
||||
outline: none;
|
||||
border-color: #10b981;
|
||||
}
|
||||
|
||||
button {
|
||||
padding: 0.75rem 1.5rem;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
font-size: 1rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.btn-set-task {
|
||||
background: #3b82f6;
|
||||
color: white;
|
||||
min-width: 120px;
|
||||
}
|
||||
|
||||
.btn-set-task:hover:not(:disabled) {
|
||||
background: #2563eb;
|
||||
}
|
||||
|
||||
.btn-set-task:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-start {
|
||||
background: #10b981;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-start:hover:not(:disabled) {
|
||||
background: #059669;
|
||||
}
|
||||
|
||||
.btn-start:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-stop {
|
||||
background: #ef4444;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-stop:hover {
|
||||
background: #dc2626;
|
||||
}
|
||||
|
||||
.btn-reset {
|
||||
padding: 0.5rem 1rem;
|
||||
background: #6b7280;
|
||||
color: white;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.btn-reset:hover {
|
||||
background: #4b5563;
|
||||
}
|
||||
|
||||
.status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
padding: 1rem;
|
||||
border-radius: 4px;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.status.recording {
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
}
|
||||
|
||||
.status.recording.recording-active {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
background: #dc2626;
|
||||
color: white;
|
||||
padding: 1.5rem;
|
||||
border: 4px solid #991b1b;
|
||||
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.4);
|
||||
font-weight: 700;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
.status.recording.recording-active .indicator {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
background: #fef2f2;
|
||||
animation: pulse-strong 1s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse-strong {
|
||||
0%, 100% {
|
||||
opacity: 1;
|
||||
transform: scale(1);
|
||||
}
|
||||
50% {
|
||||
opacity: 0.7;
|
||||
transform: scale(1.1);
|
||||
}
|
||||
}
|
||||
|
||||
.status.recording.recording-active .time-display {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
font-size: 1.5rem;
|
||||
font-weight: 700;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.fps-display {
|
||||
font-size: 1rem;
|
||||
font-weight: 500;
|
||||
opacity: 0.95;
|
||||
}
|
||||
|
||||
.fps-warning {
|
||||
color: #fef2f2;
|
||||
animation: pulse-warning 1s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse-warning {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.5; }
|
||||
}
|
||||
|
||||
.status.recording.recording-active .btn-stop {
|
||||
align-self: stretch;
|
||||
}
|
||||
|
||||
.ramp-up-countdown {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.countdown-box {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 2rem 3rem;
|
||||
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||
border: 4px solid #f59e0b;
|
||||
border-radius: 16px;
|
||||
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
|
||||
min-width: 280px;
|
||||
animation: pulse-warm 1.5s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse-warm {
|
||||
0%, 100% {
|
||||
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
|
||||
}
|
||||
50% {
|
||||
box-shadow: 0 6px 25px rgba(245, 158, 11, 0.6);
|
||||
}
|
||||
}
|
||||
|
||||
.countdown-label {
|
||||
font-size: 1rem;
|
||||
color: #92400e;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 1.5px;
|
||||
font-weight: 800;
|
||||
margin-bottom: 1rem;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.countdown-value {
|
||||
font-size: 4.5rem;
|
||||
font-weight: 900;
|
||||
color: #d97706;
|
||||
font-family: 'Courier New', monospace;
|
||||
line-height: 1;
|
||||
text-shadow: 2px 2px 6px rgba(0, 0, 0, 0.15);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.countdown-subtitle {
|
||||
font-size: 0.875rem;
|
||||
color: #78350f;
|
||||
font-weight: 600;
|
||||
font-style: italic;
|
||||
text-align: center;
|
||||
margin-top: 0.5rem;
|
||||
}
|
||||
|
||||
.status.idle {
|
||||
background: #f3f4f6;
|
||||
color: #374151;
|
||||
}
|
||||
|
||||
.indicator {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
border-radius: 50%;
|
||||
background: #ef4444;
|
||||
animation: pulse 1.5s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.5; }
|
||||
}
|
||||
|
||||
.counter {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
padding: 1.5rem;
|
||||
background: linear-gradient(135deg, #f9fafb 0%, #f3f4f6 100%);
|
||||
border-radius: 8px;
|
||||
border: 2px solid #e5e7eb;
|
||||
min-width: 200px;
|
||||
}
|
||||
|
||||
.counter-label {
|
||||
font-size: 0.75rem;
|
||||
color: #6b7280;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.counter-value {
|
||||
font-size: 3rem;
|
||||
font-weight: 700;
|
||||
color: #10b981;
|
||||
line-height: 1;
|
||||
}
|
||||
|
||||
.time-display {
|
||||
font-size: 1.5rem;
|
||||
font-weight: 600;
|
||||
font-family: 'Courier New', monospace;
|
||||
}
|
||||
|
||||
.error-box {
|
||||
padding: 1rem;
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
border-radius: 4px;
|
||||
border-left: 4px solid #ef4444;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.config-section {
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.config-section:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.config-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
label {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
color: #374151;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
select {
|
||||
padding: 0.5rem;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
background: white;
|
||||
}
|
||||
|
||||
select:disabled {
|
||||
background: #f5f5f5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
/* Camera Layout */
|
||||
.camera-layout {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.camera-base {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.camera-wrist-container {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(2, 1fr);
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.camera-wrist {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.camera {
|
||||
border: 1px solid #e5e7eb;
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.camera h3 {
|
||||
padding: 0.75rem;
|
||||
background: #f9fafb;
|
||||
border-bottom: 1px solid #e5e7eb;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.camera img {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
display: block;
|
||||
background: #000;
|
||||
min-height: 300px;
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
.camera-placeholder {
|
||||
text-align: center;
|
||||
padding: 4rem 2rem;
|
||||
background: #f9fafb;
|
||||
border-radius: 4px;
|
||||
border: 2px dashed #d1d5db;
|
||||
}
|
||||
|
||||
.camera-placeholder p {
|
||||
margin: 0.5rem 0;
|
||||
font-size: 1rem;
|
||||
color: #6b7280;
|
||||
}
|
||||
|
||||
.camera-placeholder p:first-child {
|
||||
font-size: 1.25rem;
|
||||
font-weight: 500;
|
||||
color: #374151;
|
||||
}
|
||||
|
||||
.hint {
|
||||
margin-top: 0.5rem;
|
||||
font-size: 0.75rem;
|
||||
color: #6b7280;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
857
examples/openarms_web_interface/App.jsx
Normal file
857
examples/openarms_web_interface/App.jsx
Normal file
@@ -0,0 +1,857 @@
|
||||
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import './App.css';
|
||||
|
||||
const API_BASE = 'http://localhost:8000/api';
|
||||
|
||||
function App() {
|
||||
// State
|
||||
const [task, setTask] = useState('');
|
||||
const [isRecording, setIsRecording] = useState(false);
|
||||
const [isInitializing, setIsInitializing] = useState(false);
|
||||
const [isEncoding, setIsEncoding] = useState(false);
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [robotsReady, setRobotsReady] = useState(false);
|
||||
const [elapsedTime, setElapsedTime] = useState(0);
|
||||
const [currentFps, setCurrentFps] = useState(0);
|
||||
const [loopFps, setLoopFps] = useState(0);
|
||||
const [episodeCount, setEpisodeCount] = useState(0);
|
||||
const [error, setError] = useState(null);
|
||||
const [statusMessage, setStatusMessage] = useState('Ready');
|
||||
const [uploadStatus, setUploadStatus] = useState(null);
|
||||
const [rampUpRemaining, setRampUpRemaining] = useState(0);
|
||||
const [movingToZero, setMovingToZero] = useState(false);
|
||||
const [configExpanded, setConfigExpanded] = useState(false);
|
||||
const [latestRepoId, setLatestRepoId] = useState(null);
|
||||
|
||||
// Configuration
|
||||
const [config, setConfig] = useState({
|
||||
leader_type: 'openarms', // 'openarms' or 'openarms_mini'
|
||||
leader_left: 'can0',
|
||||
leader_right: 'can1',
|
||||
follower_left: 'can2',
|
||||
follower_right: 'can3',
|
||||
left_wrist: '/dev/video0',
|
||||
right_wrist: '/dev/video1',
|
||||
base: '/dev/video4'
|
||||
});
|
||||
|
||||
// Available options
|
||||
const [availableCameras, setAvailableCameras] = useState([]);
|
||||
const [availableUsbPorts, setAvailableUsbPorts] = useState([]);
|
||||
const canInterfaces = ['can0', 'can1', 'can2', 'can3'];
|
||||
|
||||
const statusIntervalRef = useRef(null);
|
||||
const hasInitializedRef = useRef(false);
|
||||
|
||||
const loadConfig = () => {
|
||||
try {
|
||||
const saved = localStorage.getItem('openarms_config');
|
||||
if (saved) {
|
||||
const loadedConfig = JSON.parse(saved);
|
||||
setConfig(prev => ({ ...prev, ...loadedConfig }));
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Load config error:', e);
|
||||
}
|
||||
};
|
||||
|
||||
const saveConfig = (newConfig) => {
|
||||
try {
|
||||
localStorage.setItem('openarms_config', JSON.stringify(newConfig || config));
|
||||
} catch (e) {
|
||||
console.error('Save config error:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Fetch status periodically
|
||||
const fetchStatus = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/status`);
|
||||
const data = await response.json();
|
||||
|
||||
setIsRecording(data.is_recording);
|
||||
setIsInitializing(data.is_initializing);
|
||||
setIsEncoding(data.is_encoding);
|
||||
setIsUploading(data.is_uploading);
|
||||
setRobotsReady(data.robots_ready);
|
||||
setElapsedTime(data.elapsed_time);
|
||||
setCurrentFps(data.current_fps || 0);
|
||||
setLoopFps(data.loop_fps || 0);
|
||||
setEpisodeCount(data.episode_count);
|
||||
setError(data.error);
|
||||
setStatusMessage(data.status_message || 'Ready');
|
||||
setUploadStatus(data.upload_status);
|
||||
setRampUpRemaining(data.ramp_up_remaining || 0);
|
||||
setMovingToZero(data.moving_to_zero || false);
|
||||
|
||||
// Track the latest repo_id from the backend
|
||||
if (data.latest_repo_id) {
|
||||
setLatestRepoId(data.latest_repo_id);
|
||||
}
|
||||
|
||||
if (data.config) {
|
||||
// Only merge server config if we don't have a saved config (first load)
|
||||
if (!localStorage.getItem('openarms_config')) {
|
||||
setConfig(prev => {
|
||||
const merged = { ...data.config, ...prev };
|
||||
localStorage.setItem('openarms_config', JSON.stringify(merged));
|
||||
return merged;
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to fetch status:', e);
|
||||
}
|
||||
};
|
||||
|
||||
const setupRobots = async () => {
|
||||
// Show warning to verify camera positions
|
||||
const confirmed = window.confirm(
|
||||
'⚠️ IMPORTANT: Before connecting robots, please verify:\n\n' +
|
||||
'📹 Check that cameras are correctly positioned:\n' +
|
||||
' • LEFT wrist camera is actually on the LEFT arm\n' +
|
||||
' • RIGHT wrist camera is actually on the RIGHT arm\n' +
|
||||
' • BASE camera is actually the BASE/overhead camera\n\n' +
|
||||
'Incorrect camera positioning will result in invalid training data!\n\n' +
|
||||
'Click OK to continue with robot setup, or Cancel to review configuration.'
|
||||
);
|
||||
|
||||
if (!confirmed) {
|
||||
return; // User cancelled, don't proceed
|
||||
}
|
||||
|
||||
setError(null);
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/robots/setup`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(config)
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to setup robots');
|
||||
}
|
||||
|
||||
await response.json();
|
||||
saveConfig(config);
|
||||
} catch (e) {
|
||||
setError(`Robot setup failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Disconnect robots
|
||||
const disconnectRobots = async () => {
|
||||
try {
|
||||
await fetch(`${API_BASE}/robots/disconnect`, { method: 'POST' });
|
||||
setRobotsReady(false);
|
||||
} catch (e) {
|
||||
console.error('Failed to disconnect robots:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Discover cameras
|
||||
const discoverCameras = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/cameras/discover`);
|
||||
const data = await response.json();
|
||||
const cameras = data.cameras || [];
|
||||
setAvailableCameras(cameras);
|
||||
|
||||
// Get list of valid camera IDs
|
||||
const validCameraIds = cameras.map(cam => String(cam.id));
|
||||
|
||||
// Auto-fix config if current values are invalid or not set
|
||||
const updated = { ...config };
|
||||
let changed = false;
|
||||
|
||||
// Auto-fix invalid camera config
|
||||
if (!config.left_wrist || !validCameraIds.includes(config.left_wrist)) {
|
||||
if (cameras.length >= 1) {
|
||||
updated.left_wrist = String(cameras[0].id);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!config.right_wrist || !validCameraIds.includes(config.right_wrist)) {
|
||||
if (cameras.length >= 2) {
|
||||
updated.right_wrist = String(cameras[1].id);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!config.base || !validCameraIds.includes(config.base)) {
|
||||
if (cameras.length >= 3) {
|
||||
updated.base = String(cameras[2].id);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
setConfig(updated);
|
||||
saveConfig(updated);
|
||||
}
|
||||
|
||||
if (cameras.length === 0) {
|
||||
setError('No cameras detected! Please connect cameras and refresh.');
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to discover cameras:', e);
|
||||
setError(`Camera discovery failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Discover USB ports
|
||||
const discoverUsbPorts = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/usb/discover`);
|
||||
const data = await response.json();
|
||||
const ports = data.ports || [];
|
||||
setAvailableUsbPorts(ports);
|
||||
|
||||
// Auto-fix config if OpenArms Mini is selected and ports are invalid
|
||||
if (config.leader_type === 'openarms_mini') {
|
||||
const updated = { ...config };
|
||||
let changed = false;
|
||||
|
||||
if (ports.length >= 1 && !ports.includes(config.leader_left)) {
|
||||
updated.leader_left = ports[0];
|
||||
changed = true;
|
||||
}
|
||||
|
||||
if (ports.length >= 2 && !ports.includes(config.leader_right)) {
|
||||
updated.leader_right = ports[1];
|
||||
changed = true;
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
setConfig(updated);
|
||||
saveConfig(updated);
|
||||
}
|
||||
}
|
||||
|
||||
if (ports.length === 0) {
|
||||
console.warn('No USB ports detected for OpenArms Mini');
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to discover USB ports:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Set task only (for pedal use)
|
||||
const setTaskOnly = async () => {
|
||||
if (!task.trim()) {
|
||||
setError('Please enter a task description');
|
||||
return;
|
||||
}
|
||||
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/set-task`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ task, ...config })
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to set task');
|
||||
}
|
||||
|
||||
const result = await response.json();
|
||||
setStatusMessage(result.message || `Task set: ${task}`);
|
||||
saveConfig(config);
|
||||
|
||||
// Clear success message after 3 seconds
|
||||
setTimeout(() => {
|
||||
if (!isRecording && !isInitializing) {
|
||||
setStatusMessage('Ready');
|
||||
}
|
||||
}, 3000);
|
||||
} catch (e) {
|
||||
setError(e.message);
|
||||
}
|
||||
};
|
||||
|
||||
// Start recording
|
||||
const startRecording = async () => {
|
||||
if (!task.trim()) {
|
||||
setError('Please enter a task description');
|
||||
return;
|
||||
}
|
||||
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/start`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ task, ...config })
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to start recording');
|
||||
}
|
||||
|
||||
await response.json();
|
||||
saveConfig(config);
|
||||
} catch (e) {
|
||||
setError(e.message);
|
||||
}
|
||||
};
|
||||
|
||||
// Stop recording
|
||||
const stopRecording = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/stop`, {
|
||||
method: 'POST'
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to stop recording');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
setError(null);
|
||||
// Update latest repo_id after recording
|
||||
if (data.dataset_name) {
|
||||
setLatestRepoId(`lerobot-data-collection/${data.dataset_name}`);
|
||||
}
|
||||
} catch (e) {
|
||||
setError(e.message);
|
||||
}
|
||||
};
|
||||
|
||||
const deleteLatestEpisode = async () => {
|
||||
if (!latestRepoId) {
|
||||
setError('No episode to delete');
|
||||
return;
|
||||
}
|
||||
|
||||
const confirmed = window.confirm(
|
||||
`WARNING: This will permanently delete the repository:\n\n${latestRepoId}\n\nThis action cannot be undone. Continue?`
|
||||
);
|
||||
|
||||
if (!confirmed) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/delete-latest`, { method: 'POST' });
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to delete episode');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
setLatestRepoId(null);
|
||||
setEpisodeCount(Math.max(0, episodeCount - 1));
|
||||
setStatusMessage(`Deleted: ${data.deleted_repo}`);
|
||||
|
||||
setTimeout(() => {
|
||||
if (!isRecording && !isInitializing) {
|
||||
setStatusMessage('Ready');
|
||||
}
|
||||
}, 3000);
|
||||
} catch (e) {
|
||||
setError(`Delete failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Reset counter
|
||||
const resetCounter = async () => {
|
||||
try {
|
||||
await fetch(`${API_BASE}/counter/reset`, { method: 'POST' });
|
||||
setEpisodeCount(0);
|
||||
} catch (e) {
|
||||
console.error('Failed to reset counter:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Move robot to zero position
|
||||
const moveToZero = async () => {
|
||||
setError(null);
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/robots/move-to-zero`, { method: 'POST' });
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to move to zero position');
|
||||
}
|
||||
await response.json();
|
||||
} catch (e) {
|
||||
setError(`Move to zero failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Format time as MM:SS
|
||||
const formatTime = (seconds) => {
|
||||
const mins = Math.floor(seconds / 60);
|
||||
const secs = Math.floor(seconds % 60);
|
||||
return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}`;
|
||||
};
|
||||
|
||||
// Update config and save
|
||||
const updateConfig = (key, value) => {
|
||||
const updated = { ...config, [key]: value };
|
||||
setConfig(updated);
|
||||
saveConfig(updated);
|
||||
};
|
||||
|
||||
// Initialize on mount only
|
||||
useEffect(() => {
|
||||
// Prevent double-initialization in development
|
||||
if (hasInitializedRef.current) {
|
||||
return;
|
||||
}
|
||||
hasInitializedRef.current = true;
|
||||
|
||||
loadConfig();
|
||||
discoverCameras();
|
||||
discoverUsbPorts();
|
||||
fetchStatus();
|
||||
statusIntervalRef.current = setInterval(fetchStatus, 1000);
|
||||
|
||||
return () => {
|
||||
if (statusIntervalRef.current) {
|
||||
clearInterval(statusIntervalRef.current);
|
||||
}
|
||||
};
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []); // Run only once on mount
|
||||
|
||||
// Discover USB ports when leader type changes to Mini
|
||||
useEffect(() => {
|
||||
if (config.leader_type === 'openarms_mini') {
|
||||
discoverUsbPorts();
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [config.leader_type]);
|
||||
|
||||
return (
|
||||
<main>
|
||||
<header>
|
||||
<h1>OpenArms Recording</h1>
|
||||
</header>
|
||||
|
||||
<div className="container">
|
||||
{/* Left Column: Configuration and Recording Control */}
|
||||
<div className="left-column">
|
||||
{/* Configuration Panel */}
|
||||
<section className="panel config-panel">
|
||||
<div
|
||||
className="config-header"
|
||||
onClick={() => setConfigExpanded(!configExpanded)}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
onKeyDown={(e) => e.key === 'Enter' && setConfigExpanded(!configExpanded)}
|
||||
>
|
||||
<h2>⚙️ Configuration</h2>
|
||||
<span className="toggle-icon">{configExpanded ? '▼' : '▶'}</span>
|
||||
</div>
|
||||
|
||||
{configExpanded && (
|
||||
<div className="config-content">
|
||||
{/* Robot Setup */}
|
||||
<div className="config-section">
|
||||
<h3>🤖 Robot Setup</h3>
|
||||
<div className="robot-setup">
|
||||
{robotsReady ? (
|
||||
<div className="robot-status ready">
|
||||
<span>✅ Robots Ready - Recording will start instantly</span>
|
||||
<button onClick={disconnectRobots} className="btn-disconnect">
|
||||
Disconnect Robots
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="robot-status not-ready">
|
||||
<span>⚠️ Robots not initialized - Recording will take ~10 seconds</span>
|
||||
<button
|
||||
onClick={setupRobots}
|
||||
disabled={isRecording || isInitializing}
|
||||
className="btn-setup"
|
||||
>
|
||||
🚀 Setup Robots
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Leader Type Selection */}
|
||||
<div className="config-section">
|
||||
<h3>🎮 Leader Type</h3>
|
||||
<div className="config-grid">
|
||||
<label style={{gridColumn: '1 / -1'}}>
|
||||
Leader Arm Type
|
||||
<select
|
||||
value={config.leader_type}
|
||||
onChange={(e) => updateConfig('leader_type', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
<option value="openarms">OpenArms (CAN Bus - Damiao Motors)</option>
|
||||
<option value="openarms_mini">OpenArms Mini (USB - Feetech Motors)</option>
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Leader Interfaces (CAN or USB based on type) */}
|
||||
<div className="config-section">
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
|
||||
<h3>
|
||||
{config.leader_type === 'openarms_mini'
|
||||
? `Leader Ports (USB/Serial) ${availableUsbPorts.length > 0 ? `(${availableUsbPorts.length} detected)` : ''}`
|
||||
: 'Leader Interfaces (CAN)'}
|
||||
</h3>
|
||||
{config.leader_type === 'openarms_mini' && (
|
||||
<button
|
||||
onClick={discoverUsbPorts}
|
||||
className="btn-refresh"
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
🔄 Refresh
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="config-grid">
|
||||
<label>
|
||||
Leader Left
|
||||
<select
|
||||
value={config.leader_left}
|
||||
onChange={(e) => updateConfig('leader_left', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{config.leader_type === 'openarms_mini' ? (
|
||||
availableUsbPorts.length > 0 ? (
|
||||
availableUsbPorts.map((port) => (
|
||||
<option key={port} value={port}>{port}</option>
|
||||
))
|
||||
) : (
|
||||
<option value="">No USB ports detected</option>
|
||||
)
|
||||
) : (
|
||||
canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))
|
||||
)}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Leader Right
|
||||
<select
|
||||
value={config.leader_right}
|
||||
onChange={(e) => updateConfig('leader_right', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{config.leader_type === 'openarms_mini' ? (
|
||||
availableUsbPorts.length > 0 ? (
|
||||
availableUsbPorts.map((port) => (
|
||||
<option key={port} value={port}>{port}</option>
|
||||
))
|
||||
) : (
|
||||
<option value="">No USB ports detected</option>
|
||||
)
|
||||
) : (
|
||||
canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))
|
||||
)}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Follower CAN Interfaces */}
|
||||
<div className="config-section">
|
||||
<h3>Follower Interfaces (CAN)</h3>
|
||||
|
||||
<div className="config-grid">
|
||||
<label>
|
||||
Follower Left
|
||||
<select
|
||||
value={config.follower_left}
|
||||
onChange={(e) => updateConfig('follower_left', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Follower Right
|
||||
<select
|
||||
value={config.follower_right}
|
||||
onChange={(e) => updateConfig('follower_right', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Camera Configuration */}
|
||||
<div className="config-section">
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
|
||||
<h3>Cameras {availableCameras.length > 0 && `(${availableCameras.length} detected)`}</h3>
|
||||
<button
|
||||
onClick={discoverCameras}
|
||||
className="btn-refresh"
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
🔄 Refresh
|
||||
</button>
|
||||
</div>
|
||||
<div className="config-grid">
|
||||
<label>
|
||||
Left Wrist
|
||||
<select
|
||||
value={config.left_wrist}
|
||||
onChange={(e) => updateConfig('left_wrist', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{availableCameras.map((cam) => (
|
||||
<option key={cam.id} value={String(cam.id)}>
|
||||
{cam.name || `Camera @ ${cam.id}`}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Right Wrist
|
||||
<select
|
||||
value={config.right_wrist}
|
||||
onChange={(e) => updateConfig('right_wrist', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{availableCameras.map((cam) => (
|
||||
<option key={cam.id} value={String(cam.id)}>
|
||||
{cam.name || `Camera @ ${cam.id}`}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Base Camera
|
||||
<select
|
||||
value={config.base}
|
||||
onChange={(e) => updateConfig('base', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{availableCameras.map((cam) => (
|
||||
<option key={cam.id} value={String(cam.id)}>
|
||||
{cam.name || `Camera @ ${cam.id}`}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
|
||||
{/* Control Panel */}
|
||||
<section className="panel control-panel">
|
||||
<h2>🎬 Recording Control</h2>
|
||||
|
||||
{/* Status Banner - Always show important statuses */}
|
||||
{isInitializing && (
|
||||
<div className="status-banner initializing">
|
||||
<div className="spinner"></div>
|
||||
<span>{statusMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isEncoding && (
|
||||
<div className="status-banner encoding">
|
||||
<div className="spinner"></div>
|
||||
<span>📹 {statusMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isUploading && (
|
||||
<div className="status-banner uploading">
|
||||
<div className="spinner"></div>
|
||||
<span>☁️ {statusMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{uploadStatus && !isRecording && !isEncoding && !isUploading && (
|
||||
<div className={`status-banner ${uploadStatus.startsWith('✓') ? 'success' : 'warning'}`}>
|
||||
<span>{uploadStatus}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="control-horizontal">
|
||||
{/* Task Input and Status */}
|
||||
<div className="control-left">
|
||||
<div className="input-group">
|
||||
<input
|
||||
type="text"
|
||||
value={task}
|
||||
onChange={(e) => setTask(e.target.value)}
|
||||
placeholder="Task description (e.g., 'pick and place')"
|
||||
disabled={isRecording || isInitializing || isEncoding || isUploading}
|
||||
onKeyPress={(e) => {
|
||||
if (e.key === 'Enter' && robotsReady) {
|
||||
setTaskOnly();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<button
|
||||
onClick={setTaskOnly}
|
||||
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
|
||||
className="btn-set-task"
|
||||
title={!robotsReady ? 'Please setup robots first' : 'Store task for pedal use (Enter key)'}
|
||||
>
|
||||
💾 Set Task
|
||||
</button>
|
||||
<button
|
||||
onClick={startRecording}
|
||||
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
|
||||
className="btn-start"
|
||||
title={!robotsReady ? 'Please setup robots first' : ''}
|
||||
>
|
||||
{isInitializing
|
||||
? '⏳ Initializing...'
|
||||
: isRecording
|
||||
? '⏺ Recording...'
|
||||
: robotsReady
|
||||
? '⏺ Start Recording'
|
||||
: '⏺ Setup Robots First'}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Ramp-up Countdown */}
|
||||
{isRecording && rampUpRemaining > 0 && (
|
||||
<div className="ramp-up-countdown">
|
||||
<div className="countdown-box">
|
||||
<div className="countdown-label">⚡ WARMING UP - PID RAMP-UP</div>
|
||||
<div className="countdown-value">{rampUpRemaining.toFixed(1)}s</div>
|
||||
<div className="countdown-subtitle">Recording will start automatically...</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Recording Status - Only show after ramp-up */}
|
||||
{isRecording && rampUpRemaining <= 0 && (
|
||||
<div className="status recording recording-active">
|
||||
<div className="indicator"></div>
|
||||
<div className="time-display">
|
||||
<span>{formatTime(elapsedTime)}</span>
|
||||
<span className="fps-display">
|
||||
Loop: {loopFps.toFixed(1)} Hz
|
||||
{loopFps > 0 && loopFps < 29 && <span className="fps-warning"> ⚠️</span>}
|
||||
</span>
|
||||
<span className="fps-display">Recording: {currentFps.toFixed(1)} FPS</span>
|
||||
</div>
|
||||
<button onClick={stopRecording} className="btn-stop">
|
||||
⏹ Stop
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Episode Counter */}
|
||||
<div className="control-right">
|
||||
<div className="counter">
|
||||
<div className="counter-label">Episodes Recorded</div>
|
||||
<div className="counter-value">{episodeCount}</div>
|
||||
<button onClick={resetCounter} className="btn-reset">
|
||||
Reset
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Delete Latest Episode Button */}
|
||||
{!isRecording && !isInitializing && latestRepoId && (
|
||||
<div className="delete-episode-section">
|
||||
<button
|
||||
onClick={deleteLatestEpisode}
|
||||
className="btn-delete"
|
||||
title="Delete the latest recorded episode from HuggingFace Hub"
|
||||
>
|
||||
Delete Latest Episode
|
||||
</button>
|
||||
<div className="delete-info">Will delete: {latestRepoId}</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Move to Zero Button */}
|
||||
{robotsReady && !isRecording && !isInitializing && (
|
||||
<div className="zero-position-section">
|
||||
<button
|
||||
onClick={moveToZero}
|
||||
disabled={movingToZero}
|
||||
className="btn-zero-large"
|
||||
title="Move both leader and follower robots to zero position (2s)"
|
||||
>
|
||||
{movingToZero ? '⏳ Moving to Zero Position...' : '🎯 Move to Zero Position (Leader + Follower)'}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Error Display */}
|
||||
{error && (
|
||||
<div className="error-box">
|
||||
⚠️ {error}
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
</div>
|
||||
|
||||
{/* Right Column: Camera Feeds */}
|
||||
<div className="right-column">
|
||||
<section className="panel cameras">
|
||||
<h2>📹 Camera Views</h2>
|
||||
{robotsReady || isRecording || isInitializing ? (
|
||||
<div className="camera-layout">
|
||||
{/* Base camera - full width */}
|
||||
<div className="camera camera-base">
|
||||
<h3>Base Camera</h3>
|
||||
<img src={`${API_BASE}/camera/stream/base`} alt="Base Camera" />
|
||||
</div>
|
||||
|
||||
{/* Wrist cameras - side by side */}
|
||||
<div className="camera-wrist-container">
|
||||
<div className="camera camera-wrist">
|
||||
<h3>Left Wrist</h3>
|
||||
<img src={`${API_BASE}/camera/stream/left_wrist`} alt="Left Wrist Camera" />
|
||||
</div>
|
||||
|
||||
<div className="camera camera-wrist">
|
||||
<h3>Right Wrist</h3>
|
||||
<img src={`${API_BASE}/camera/stream/right_wrist`} alt="Right Wrist Camera" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="camera-placeholder">
|
||||
<p>📷 Camera feeds will appear when robots are set up</p>
|
||||
<p className="hint">Click "Setup Robots" above to preview camera feeds</p>
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</main>
|
||||
);
|
||||
}
|
||||
|
||||
export default App;
|
||||
|
||||
41
examples/openarms_web_interface/README.md
Normal file
41
examples/openarms_web_interface/README.md
Normal file
@@ -0,0 +1,41 @@
|
||||
# OpenArms Web Recording Interface
|
||||
|
||||
A web interface for recording OpenArms datasets.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
cd examples/openarms_web_interface
|
||||
npm install
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
**Start everything with one command:**
|
||||
|
||||
```bash
|
||||
./launch.sh
|
||||
```
|
||||
|
||||
This will:
|
||||
- Start the FastAPI backend on port 8000
|
||||
- Start the React frontend on port 5173
|
||||
- Show live logs from both services
|
||||
|
||||
Then open your browser to: **http://localhost:5173**
|
||||
|
||||
**Stop with:** `Ctrl+C`
|
||||
|
||||
---
|
||||
|
||||
## Workflow
|
||||
|
||||
1. **Configure CAN interfaces** and **camera paths** in the dropdowns
|
||||
2. Click **"Setup Robots"** to initialize (once at start)
|
||||
3. Enter a **task description**
|
||||
4. Click **"Start Recording"** to begin an episode
|
||||
5. Click **"Stop Recording"** when done
|
||||
6. Dataset is automatically encoded and uploaded to HuggingFace Hub as **private**
|
||||
7. Repeat steps 3-6 for more episodes (no need to re-setup robots!)
|
||||
|
||||
---
|
||||
12
examples/openarms_web_interface/index.html
Normal file
12
examples/openarms_web_interface/index.html
Normal file
@@ -0,0 +1,12 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>OpenArms Recording Interface</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/main.jsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
142
examples/openarms_web_interface/launch.sh
Executable file
142
examples/openarms_web_interface/launch.sh
Executable file
@@ -0,0 +1,142 @@
|
||||
#!/bin/bash
|
||||
|
||||
# OpenArms Web Interface Launcher
|
||||
# Starts Rerun viewer, FastAPI backend, and React frontend
|
||||
|
||||
set -e
|
||||
|
||||
# Colors for output
|
||||
GREEN='\033[0;32m'
|
||||
BLUE='\033[0;34m'
|
||||
YELLOW='\033[1;33m'
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Get script directory
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
echo -e "${BLUE}╔════════════════════════════════════════╗${NC}"
|
||||
echo -e "${BLUE}║ OpenArms Web Recording Interface ║${NC}"
|
||||
echo -e "${BLUE}╚════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
|
||||
# Function to cleanup on exit
|
||||
cleanup() {
|
||||
echo ""
|
||||
echo -e "${YELLOW}Shutting down services...${NC}"
|
||||
|
||||
# Kill all child processes
|
||||
pkill -P $$ 2>/dev/null || true
|
||||
|
||||
# Kill specific services by port
|
||||
lsof -ti:8000 | xargs kill -9 2>/dev/null || true # Backend
|
||||
lsof -ti:5173 | xargs kill -9 2>/dev/null || true # Frontend
|
||||
lsof -ti:9876 | xargs kill -9 2>/dev/null || true # Rerun (if spawned)
|
||||
|
||||
echo -e "${GREEN}✓ Services stopped${NC}"
|
||||
exit 0
|
||||
}
|
||||
|
||||
# Register cleanup on script exit
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
# Check if required commands exist
|
||||
command -v rerun >/dev/null 2>&1 || {
|
||||
echo -e "${RED}✗ Error: 'rerun' not found. Please install: pip install rerun-sdk${NC}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
command -v python >/dev/null 2>&1 || {
|
||||
echo -e "${RED}✗ Error: 'python' not found${NC}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
command -v npm >/dev/null 2>&1 || {
|
||||
echo -e "${RED}✗ Error: 'npm' not found${NC}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Check if node_modules exists
|
||||
if [ ! -d "node_modules" ]; then
|
||||
echo -e "${YELLOW}⚠ node_modules not found. Running npm install...${NC}"
|
||||
npm install
|
||||
echo -e "${GREEN}✓ Dependencies installed${NC}"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}Starting services...${NC}"
|
||||
echo ""
|
||||
|
||||
# 1. Start FastAPI backend (Rerun will start when recording begins)
|
||||
echo -e "${BLUE}[1/2]${NC} Starting FastAPI backend on port 8000..."
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# Use Python from current environment (if lerobot env is active, it will use that)
|
||||
# Otherwise, check if we need to use conda run
|
||||
if [[ "$CONDA_DEFAULT_ENV" == "lerobot" ]]; then
|
||||
# Already in lerobot environment
|
||||
echo -e "${GREEN}✓ Using active lerobot environment${NC}"
|
||||
PYTHON_CMD="python"
|
||||
elif command -v conda >/dev/null 2>&1 && conda env list | grep -q "^lerobot "; then
|
||||
# lerobot env exists but not active - use conda run
|
||||
echo -e "${YELLOW}Using conda run with lerobot environment...${NC}"
|
||||
PYTHON_CMD="conda run -n lerobot --no-capture-output python"
|
||||
else
|
||||
# Fall back to system python
|
||||
echo -e "${YELLOW}⚠ Warning: lerobot environment not found, using system python${NC}"
|
||||
PYTHON_CMD="python"
|
||||
fi
|
||||
|
||||
$PYTHON_CMD web_record_server.py > /tmp/openarms_backend.log 2>&1 &
|
||||
BACKEND_PID=$!
|
||||
sleep 3
|
||||
|
||||
if ps -p $BACKEND_PID > /dev/null; then
|
||||
echo -e "${GREEN}✓ Backend started${NC} (PID: $BACKEND_PID)"
|
||||
echo -e " URL: ${BLUE}http://localhost:8000${NC}"
|
||||
else
|
||||
echo -e "${RED}✗ Failed to start backend${NC}"
|
||||
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_backend.log${NC}"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# 2. Start React frontend
|
||||
echo -e "${BLUE}[2/2]${NC} Starting React frontend on port 5173..."
|
||||
cd "$SCRIPT_DIR"
|
||||
npm run dev > /tmp/openarms_frontend.log 2>&1 &
|
||||
FRONTEND_PID=$!
|
||||
sleep 3
|
||||
|
||||
if ps -p $FRONTEND_PID > /dev/null; then
|
||||
echo -e "${GREEN}✓ Frontend started${NC} (PID: $FRONTEND_PID)"
|
||||
echo -e " URL: ${BLUE}http://localhost:5173${NC}"
|
||||
else
|
||||
echo -e "${RED}✗ Failed to start frontend${NC}"
|
||||
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_frontend.log${NC}"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Display status
|
||||
echo -e "${GREEN}╔════════════════════════════════════════╗${NC}"
|
||||
echo -e "${GREEN}║ All services running! 🚀 ║${NC}"
|
||||
echo -e "${GREEN}╚════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
echo -e "🔧 ${BLUE}Backend:${NC} http://localhost:8000"
|
||||
echo -e "🌐 ${BLUE}Frontend:${NC} http://localhost:5173"
|
||||
echo -e "📊 ${BLUE}Rerun:${NC} Will spawn automatically when recording starts"
|
||||
echo ""
|
||||
echo -e "${YELLOW}Open your browser to:${NC} ${BLUE}http://localhost:5173${NC}"
|
||||
echo ""
|
||||
echo -e "${YELLOW}Logs:${NC}"
|
||||
echo -e " • Backend: tail -f /tmp/openarms_backend.log"
|
||||
echo -e " • Frontend: tail -f /tmp/openarms_frontend.log"
|
||||
echo ""
|
||||
echo -e "${RED}Press Ctrl+C to stop all services${NC}"
|
||||
echo ""
|
||||
|
||||
# Keep script running and wait for any service to exit
|
||||
wait
|
||||
|
||||
7
examples/openarms_web_interface/main.jsx
Normal file
7
examples/openarms_web_interface/main.jsx
Normal file
@@ -0,0 +1,7 @@
|
||||
import { createRoot } from 'react-dom/client'
|
||||
import App from './App.jsx'
|
||||
|
||||
createRoot(document.getElementById('root')).render(
|
||||
<App />
|
||||
)
|
||||
|
||||
1955
examples/openarms_web_interface/package-lock.json
generated
Normal file
1955
examples/openarms_web_interface/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
21
examples/openarms_web_interface/package.json
Normal file
21
examples/openarms_web_interface/package.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"name": "openarms-web-interface",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react": "^18.3.12",
|
||||
"@types/react-dom": "^18.3.1",
|
||||
"@vitejs/plugin-react": "^4.3.4",
|
||||
"vite": "^6.0.1"
|
||||
}
|
||||
}
|
||||
17
examples/openarms_web_interface/vite.config.js
Normal file
17
examples/openarms_web_interface/vite.config.js
Normal file
@@ -0,0 +1,17 @@
|
||||
import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react'
|
||||
|
||||
// https://vite.dev/config/
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
server: {
|
||||
port: 5173,
|
||||
strictPort: false,
|
||||
host: true,
|
||||
open: false
|
||||
},
|
||||
build: {
|
||||
outDir: 'dist',
|
||||
sourcemap: true
|
||||
}
|
||||
})
|
||||
1533
examples/openarms_web_interface/web_record_server.py
Normal file
1533
examples/openarms_web_interface/web_record_server.py
Normal file
File diff suppressed because it is too large
Load Diff
98
examples/tutorial/act/act_training_example.py
Normal file
98
examples/tutorial/act/act_training_example.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""This script demonstrates how to train ACT Policy on a real-world dataset."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
|
||||
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
|
||||
if delta_indices is None:
|
||||
return [0]
|
||||
|
||||
return [i / fps for i in delta_indices]
|
||||
|
||||
|
||||
output_directory = Path("outputs/robot_learning_tutorial/act")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
policy = ACTPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
57
examples/tutorial/act/act_using_example.py
Normal file
57
examples/tutorial/act/act_using_example.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_act"
|
||||
model = ACTPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
11
examples/tutorial/async-inf/policy_server.py
Normal file
11
examples/tutorial/async-inf/policy_server.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
host = ... # something like "127.0.0.1" if you're exposing to localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
55
examples/tutorial/async-inf/robot_client.py
Normal file
55
examples/tutorial/async-inf/robot_client.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import threading
|
||||
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
99
examples/tutorial/diffusion/diffusion_training_example.py
Normal file
99
examples/tutorial/diffusion/diffusion_training_example.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""This script demonstrates how to train Diffusion Policy on a real-world dataset."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
|
||||
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
|
||||
if delta_indices is None:
|
||||
return [0]
|
||||
|
||||
return [i / fps for i in delta_indices]
|
||||
|
||||
|
||||
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
policy = DiffusionPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
60
examples/tutorial/diffusion/diffusion_using_example.py
Normal file
60
examples/tutorial/diffusion/diffusion_using_example.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_diffusion"
|
||||
|
||||
model = DiffusionPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config, model_id, dataset_stats=dataset_metadata.stats
|
||||
)
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
67
examples/tutorial/pi0/using_pi0_example.py
Normal file
67
examples/tutorial/pi0/using_pi0_example.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/pi0_base"
|
||||
|
||||
model = PI0Policy.from_pretrained(model_id)
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
345
examples/tutorial/rl/hilserl_example.py
Normal file
345
examples/tutorial/rl/hilserl_example.py
Normal file
@@ -0,0 +1,345 @@
|
||||
import multiprocessing as mp
|
||||
import signal
|
||||
from pathlib import Path
|
||||
from queue import Empty, Full
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
LOG_EVERY = 10
|
||||
SEND_EVERY = 10
|
||||
|
||||
|
||||
def run_learner(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_learner: SACPolicy,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer,
|
||||
lr: float = 3e-4,
|
||||
batch_size: int = 32,
|
||||
device: torch.device = "mps",
|
||||
):
|
||||
"""The learner process - trains SAC policy on transitions streamed from the actor, updating parameters
|
||||
for the actor to adopt."""
|
||||
policy_learner.train()
|
||||
policy_learner.to(device)
|
||||
|
||||
# Create Adam optimizer from scratch - simple and clean
|
||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
||||
|
||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||
|
||||
training_step = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
# retrieve incoming transitions from the actor process
|
||||
try:
|
||||
transitions = transitions_queue.get(timeout=0.1)
|
||||
for transition in transitions:
|
||||
# HIL-SERL: Add ALL transitions to online buffer
|
||||
online_buffer.add(**transition)
|
||||
|
||||
# HIL-SERL: Add ONLY human intervention transitions to offline buffer
|
||||
is_intervention = transition.get("complementary_info", {}).get("is_intervention", False)
|
||||
if is_intervention:
|
||||
offline_buffer.add(**transition)
|
||||
print(
|
||||
f"[LEARNER] Human intervention detected! Added to offline buffer (now {len(offline_buffer)} transitions)"
|
||||
)
|
||||
|
||||
except Empty:
|
||||
pass # No transitions available, continue
|
||||
|
||||
# Train if we have enough data
|
||||
if len(online_buffer) >= policy_learner.config.online_step_before_learning:
|
||||
# Sample from online buffer (autonomous + human data)
|
||||
online_batch = online_buffer.sample(batch_size // 2)
|
||||
|
||||
# Sample from offline buffer (human demonstrations only, either precollected or at runtime)
|
||||
offline_batch = offline_buffer.sample(batch_size // 2)
|
||||
|
||||
# Combine batches - this is the key HIL-SERL mechanism!
|
||||
batch = {}
|
||||
for key in online_batch:
|
||||
if key in offline_batch:
|
||||
batch[key] = torch.cat([online_batch[key], offline_batch[key]], dim=0)
|
||||
else:
|
||||
batch[key] = online_batch[key]
|
||||
|
||||
loss, _ = policy_learner.forward(batch)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
training_step += 1
|
||||
|
||||
if training_step % LOG_EVERY == 0:
|
||||
print(
|
||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||
)
|
||||
|
||||
# Send updated parameters to actor every 10 training steps
|
||||
if training_step % SEND_EVERY == 0:
|
||||
try:
|
||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
||||
parameters_queue.put_nowait(state_dict)
|
||||
print("[LEARNER] Sent updated parameters to actor")
|
||||
except Full:
|
||||
# Missing write due to queue not being consumed (should happen rarely)
|
||||
pass
|
||||
|
||||
print("[LEARNER] Learner process finished")
|
||||
|
||||
|
||||
def run_actor(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_actor: SACPolicy,
|
||||
reward_classifier: Classifier,
|
||||
env_cfg: HILSerlRobotEnvConfig,
|
||||
device: torch.device = "mps",
|
||||
output_directory: Path | None = None,
|
||||
):
|
||||
"""The actor process - interacts with environment and collects data.
|
||||
The policy is frozen and only the parameters are updated, popping the most recent ones from a queue."""
|
||||
policy_actor.eval()
|
||||
policy_actor.to(device)
|
||||
|
||||
reward_classifier.eval()
|
||||
reward_classifier.to(device)
|
||||
|
||||
# Create robot environment inside the actor process
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
|
||||
try:
|
||||
for episode in range(MAX_EPISODES):
|
||||
if shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs, _info = env.reset()
|
||||
episode_reward = 0.0
|
||||
step = 0
|
||||
episode_transitions = []
|
||||
|
||||
print(f"[ACTOR] Starting episode {episode + 1}")
|
||||
|
||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||
try:
|
||||
new_params = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_params)
|
||||
print("[ACTOR] Updated policy parameters from learner")
|
||||
except Empty: # No new updated parameters available from learner, waiting
|
||||
pass
|
||||
|
||||
# Get action from policy
|
||||
policy_obs = make_policy_obs(obs, device=device)
|
||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
||||
action = action_tensor.squeeze(0).cpu().numpy()
|
||||
|
||||
# Step environment
|
||||
next_obs, _env_reward, terminated, truncated, _info = env.step(action)
|
||||
done = terminated or truncated
|
||||
|
||||
# Predict reward
|
||||
policy_next_obs = make_policy_obs(next_obs, device=device)
|
||||
reward = reward_classifier.predict_reward(policy_next_obs)
|
||||
|
||||
if reward >= 1.0 and not done: # success detected! halt episode
|
||||
terminated = True
|
||||
done = True
|
||||
|
||||
# In HIL-SERL, human interventions come from the teleop device
|
||||
is_intervention = False
|
||||
if hasattr(teleop_device, "get_teleop_events"):
|
||||
# Real intervention detection from teleop device
|
||||
teleop_events = teleop_device.get_teleop_events()
|
||||
is_intervention = teleop_events.get(TeleopEvents.IS_INTERVENTION, False)
|
||||
|
||||
# Store transition with intervention metadata
|
||||
transition = {
|
||||
"state": policy_obs,
|
||||
"action": action,
|
||||
"reward": float(reward) if hasattr(reward, "item") else reward,
|
||||
"next_state": policy_next_obs,
|
||||
"done": done,
|
||||
"truncated": truncated,
|
||||
"complementary_info": {
|
||||
"is_intervention": is_intervention,
|
||||
},
|
||||
}
|
||||
|
||||
episode_transitions.append(transition)
|
||||
|
||||
episode_reward += reward
|
||||
step += 1
|
||||
|
||||
obs = next_obs
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Send episode transitions to learner
|
||||
transitions_queue.put_nowait(episode_transitions)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("[ACTOR] Interrupted by user")
|
||||
finally:
|
||||
# Clean up
|
||||
if hasattr(env, "robot") and env.robot.is_connected:
|
||||
env.robot.disconnect()
|
||||
if teleop_device and hasattr(teleop_device, "disconnect"):
|
||||
teleop_device.disconnect()
|
||||
if output_directory is not None:
|
||||
policy_actor.save_pretrained(output_directory)
|
||||
print(f"[ACTOR] Latest actor policy saved at: {output_directory}")
|
||||
|
||||
print("[ACTOR] Actor process finished")
|
||||
|
||||
|
||||
def make_policy_obs(obs, device: torch.device = "cpu"):
|
||||
return {
|
||||
"observation.state": torch.from_numpy(obs["agent_pos"]).float().unsqueeze(0).to(device),
|
||||
**{
|
||||
f"observation.image.{k}": torch.from_numpy(obs["pixels"][k]).float().unsqueeze(0).to(device)
|
||||
for k in obs["pixels"]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
"""Main function - coordinates actor and learner processes."""
|
||||
|
||||
device = "mps" # or "cuda" or "cpu"
|
||||
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ...
|
||||
leader_port = ...
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ...
|
||||
leader_id = ...
|
||||
|
||||
# A pretrained model (to be used in-distribution!)
|
||||
reward_classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
|
||||
|
||||
reward_classifier.to(device)
|
||||
reward_classifier.eval()
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
# Robot and environment configuration
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
|
||||
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
|
||||
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
|
||||
|
||||
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
|
||||
|
||||
# Create robot environment
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
|
||||
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
# Online buffer: initialized from scratch
|
||||
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
|
||||
# Offline buffer: Created from dataset (pre-populated it with demonstrations)
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
|
||||
)
|
||||
|
||||
# Create communication channels between learner and actor processes
|
||||
transitions_queue = mp.Queue(maxsize=10)
|
||||
parameters_queue = mp.Queue(maxsize=2)
|
||||
shutdown_event = mp.Event()
|
||||
|
||||
|
||||
# Signal handler for graceful shutdown
|
||||
def signal_handler(sig):
|
||||
print(f"\nSignal {sig} received, shutting down...")
|
||||
shutdown_event.set()
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Create processes
|
||||
learner_process = mp.Process(
|
||||
target=run_learner,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_learner,
|
||||
online_replay_buffer,
|
||||
offline_replay_buffer,
|
||||
),
|
||||
kwargs={"device": device}, # can run on accelerated hardware for training
|
||||
)
|
||||
|
||||
actor_process = mp.Process(
|
||||
target=run_actor,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_actor,
|
||||
reward_classifier,
|
||||
env_cfg,
|
||||
output_directory,
|
||||
),
|
||||
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
|
||||
)
|
||||
|
||||
learner_process.start()
|
||||
actor_process.start()
|
||||
|
||||
try:
|
||||
# Wait for actor to finish (it controls the episode loop)
|
||||
actor_process.join()
|
||||
shutdown_event.set()
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Main process interrupted")
|
||||
shutdown_event.set()
|
||||
actor_process.join(timeout=5)
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
finally:
|
||||
if learner_process.is_alive():
|
||||
learner_process.terminate()
|
||||
if actor_process.is_alive():
|
||||
actor_process.terminate()
|
||||
62
examples/tutorial/rl/reward_classifier_example.py
Normal file
62
examples/tutorial/rl/reward_classifier_example.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
# Device to use for training
|
||||
device = "mps" # or "cuda", or "cpu"
|
||||
|
||||
# Load the dataset used for training
|
||||
repo_id = "lerobot/example_hil_serl_dataset"
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
# Configure the policy to extract features from the image frames
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
|
||||
config = RewardClassifierConfig(
|
||||
num_cameras=len(camera_keys),
|
||||
device=device,
|
||||
# backbone model to extract features from the image frames
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
|
||||
|
||||
classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
|
||||
# Instantiate a dataloader
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
|
||||
|
||||
# Training loop
|
||||
num_epochs = 5
|
||||
for epoch in range(num_epochs):
|
||||
total_loss = 0
|
||||
total_accuracy = 0
|
||||
for batch in dataloader:
|
||||
# Preprocess the batch and move it to the correct device.
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_accuracy += output_dict["accuracy"]
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
avg_accuracy = total_accuracy / len(dataloader)
|
||||
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
66
examples/tutorial/smolvla/using_smolvla_example.py
Normal file
66
examples/tutorial/smolvla/using_smolvla_example.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/smolvla_base"
|
||||
|
||||
model = SmolVLAPolicy.from_pretrained(model_id)
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
10
loop_datasets.py
Normal file
10
loop_datasets.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from huggingface_hub import HfApi, list_datasets
|
||||
|
||||
api = HfApi()
|
||||
datasets = list_datasets(author="lerobot-data-collection")
|
||||
print('"[', end="")
|
||||
i=0
|
||||
for dataset in datasets:
|
||||
if "three-folds-dataset" in dataset.id:
|
||||
print("'" + dataset.id + "',", end="")
|
||||
print(']"',)
|
||||
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.3.4"
|
||||
version = "0.4.1"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
readme = "README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -62,6 +62,7 @@ dependencies = [
|
||||
"datasets>=4.0.0,<4.2.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
|
||||
# Core dependencies
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
@@ -73,15 +74,15 @@ dependencies = [
|
||||
"packaging>=24.2,<26.0",
|
||||
"pynput>=1.7.7,<1.9.0",
|
||||
"pyserial>=3.5,<4.0",
|
||||
"wandb>=0.20.0,<0.23.0",
|
||||
"wandb>=0.20.0,<0.22.0", # TODO: Bumb dependency (compatible with protobuf)
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"gymnasium>=1.0.0",
|
||||
"rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
"gymnasium>=1.1.1,<2.0.0",
|
||||
"rerun-sdk>=0.24.0,<0.27.0",
|
||||
|
||||
# Support dependencies
|
||||
"deepdiff>=7.0.1,<9.0.0",
|
||||
@@ -96,13 +97,15 @@ dependencies = [
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
damiao = ["python-can>=4.2.0,<5.0.0"]
|
||||
|
||||
# Robots
|
||||
openarms = ["lerobot[damiao]"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
@@ -112,17 +115,23 @@ intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
|
||||
]
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0"]
|
||||
# stretch = [
|
||||
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
|
||||
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
|
||||
# "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"
|
||||
# ] # TODO: Currently not supported
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
||||
|
||||
# Policies
|
||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
"peft>=0.13.0,<1.0.0",
|
||||
"dm-tree>=0.1.8,<1.0.0",
|
||||
"timm>=1.0.0,<1.1.0",
|
||||
"safetensors>=0.4.3,<1.0.0",
|
||||
"Pillow>=10.0.0,<13.0.0",
|
||||
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
@@ -136,11 +145,12 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
|
||||
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
|
||||
metaworld = ["metaworld>=3.0.0"]
|
||||
metaworld = ["metaworld==3.0.0"]
|
||||
|
||||
# All
|
||||
all = [
|
||||
"lerobot[dynamixel]",
|
||||
"lerobot[openarms]",
|
||||
"lerobot[gamepad]",
|
||||
"lerobot[hopejr]",
|
||||
"lerobot[lekiwi]",
|
||||
@@ -149,6 +159,7 @@ all = [
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[pi]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
@@ -233,9 +244,6 @@ exclude_dirs = [
|
||||
"tests",
|
||||
"benchmarks",
|
||||
"src/lerobot/datasets/push_dataset_to_hub",
|
||||
"src/lerobot/datasets/v2/convert_dataset_v1_to_v2",
|
||||
"src/lerobot/policies/pi0/conversion_scripts",
|
||||
"src/lerobot/scripts/push_dataset_to_hub.py",
|
||||
]
|
||||
skips = ["B101", "B311", "B404", "B603", "B615"]
|
||||
|
||||
@@ -250,6 +258,8 @@ default.extend-ignore-identifiers-re = [
|
||||
"pn",
|
||||
"ser",
|
||||
"ein",
|
||||
"thw",
|
||||
"inpt",
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
@@ -288,7 +298,6 @@ ignore_errors = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.envs.*"
|
||||
# Enable type checking only for the envs module
|
||||
ignore_errors = false
|
||||
|
||||
|
||||
@@ -296,17 +305,22 @@ ignore_errors = false
|
||||
# module = "lerobot.utils.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.configs.*"
|
||||
# ignore_errors = false
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.configs.*"
|
||||
ignore_errors = false
|
||||
|
||||
# extra strictness for configs
|
||||
disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.optim.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.model.*"
|
||||
# ignore_errors = false
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.model.*"
|
||||
ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.processor.*"
|
||||
@@ -316,9 +330,9 @@ ignore_errors = false
|
||||
# module = "lerobot.datasets.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.cameras.*"
|
||||
# ignore_errors = false
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.cameras.*"
|
||||
ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.motors.*"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# by the following command:
|
||||
#
|
||||
@@ -12,47 +13,62 @@ absl-py==2.3.1
|
||||
# dm-tree
|
||||
# labmaze
|
||||
# mujoco
|
||||
accelerate==1.9.0
|
||||
# via lerobot
|
||||
# tensorboard
|
||||
accelerate==1.11.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.12.15
|
||||
aiohttp==3.13.1
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.11.0
|
||||
# via
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.0
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.3.0
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
# dm-tree
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.0.0
|
||||
av==15.1.0
|
||||
# via lerobot
|
||||
blinker==1.9.0
|
||||
# via flask
|
||||
certifi==2025.7.14
|
||||
bddl==1.0.1
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==1.17.1
|
||||
cffi==2.0.0
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.2
|
||||
charset-normalizer==3.4.4
|
||||
# via requests
|
||||
click==8.2.1
|
||||
click==8.3.0
|
||||
# via
|
||||
# flask
|
||||
# uvicorn
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
# via gymnasium
|
||||
cmake==4.0.3
|
||||
# via
|
||||
# gymnasium
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
# via
|
||||
@@ -94,27 +110,27 @@ coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.10.1
|
||||
coverage[toml]==7.11.0
|
||||
# via pytest-cov
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==3.6.0
|
||||
datasets==4.1.1
|
||||
# via lerobot
|
||||
debugpy==1.8.15
|
||||
debugpy==1.8.17
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
deepdiff==8.5.0
|
||||
deepdiff==8.6.1
|
||||
# via lerobot
|
||||
diffusers==0.34.0
|
||||
diffusers==0.35.2
|
||||
# via lerobot
|
||||
dill==0.3.8
|
||||
dill==0.4.0
|
||||
# via
|
||||
# datasets
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.14
|
||||
dm-control==1.0.34
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -122,29 +138,45 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
# via lerobot
|
||||
dynamixel-sdk==3.7.31
|
||||
dynamixel-sdk==3.8.4
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
# via lerobot
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.13.0
|
||||
# via mujoco
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.0
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.18.0
|
||||
filelock==3.20.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
@@ -152,24 +184,25 @@ filelock==3.18.0
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
flask==3.1.1
|
||||
# via lerobot
|
||||
fonttools==4.59.0
|
||||
fonttools==4.60.1
|
||||
# via matplotlib
|
||||
frozenlist==1.7.0
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.3.0
|
||||
fsspec[http]==2025.9.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
# via wandb
|
||||
glfw==2.9.0
|
||||
glfw==2.10.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
@@ -177,61 +210,79 @@ grpcio==1.73.1
|
||||
# via
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
grpcio-tools==1.73.1
|
||||
# via
|
||||
# lerobot
|
||||
# reachy2-sdk-api
|
||||
gym-aloha==0.1.3
|
||||
# via lerobot
|
||||
gym-aloha==0.1.1
|
||||
gym-hil==0.1.13
|
||||
# via lerobot
|
||||
gym-hil==0.1.10
|
||||
gym-pusht==0.1.6
|
||||
# via lerobot
|
||||
gym-pusht==0.1.5
|
||||
# via lerobot
|
||||
gym-xarm==0.1.1
|
||||
# via lerobot
|
||||
gymnasium==0.29.1
|
||||
gymnasium==1.2.1
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# pettingzoo
|
||||
gymnasium-robotics==1.2.4
|
||||
# via gym-xarm
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via robomimic
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.5
|
||||
hf-xet==1.1.10
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
huggingface-hub[cli,hf-transfer]==0.34.3
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
identify==2.6.12
|
||||
hydra-core==1.3.2
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
# via pre-commit
|
||||
idna==3.10
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robomimic
|
||||
# scikit-image
|
||||
imageio-ffmpeg==0.6.0
|
||||
# via imageio
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
importlib-metadata==8.7.0
|
||||
# via diffusers
|
||||
iniconfig==2.1.0
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
@@ -239,50 +290,71 @@ ipython==8.37.0
|
||||
# via meshcat
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jedi==0.19.2
|
||||
# via ipython
|
||||
jinja2==3.1.6
|
||||
# via
|
||||
# flask
|
||||
# gymnasium-robotics
|
||||
# torch
|
||||
# via torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
kiwisolver==1.4.8
|
||||
jsonschema==4.25.1
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.18.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
# via scikit-image
|
||||
lxml==6.0.0
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
# via numba
|
||||
lxml==6.0.2
|
||||
# via dm-control
|
||||
markupsafe==3.0.2
|
||||
markdown==3.9
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# flask
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.5
|
||||
# via lerobot
|
||||
matplotlib-inline==0.1.7
|
||||
matplotlib==3.10.7
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
matplotlib-inline==0.2.1
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
# via jupytext
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mergedeep==1.3.4
|
||||
# via draccus
|
||||
meshcat==0.3.2
|
||||
# via placo
|
||||
metaworld==3.0.0
|
||||
# via lerobot
|
||||
mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==2.3.7
|
||||
mujoco==3.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
multidict==6.6.3
|
||||
# libero
|
||||
# metaworld
|
||||
# robosuite
|
||||
multidict==6.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -290,17 +362,25 @@ multiprocess==0.70.16
|
||||
# via datasets
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# bddl
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.62.1
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# bddl
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
@@ -309,25 +389,43 @@ numpy==2.2.6
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# gymnasium
|
||||
# gymnasium-robotics
|
||||
# h5py
|
||||
# hebi-py
|
||||
# imageio
|
||||
# labmaze
|
||||
# libero
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
# pettingzoo
|
||||
# peft
|
||||
# pyquaternion
|
||||
# reachy2-sdk
|
||||
# rerun-sdk
|
||||
# robomimic
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# scipy
|
||||
# shapely
|
||||
# teleop
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.12.0.88
|
||||
# via gym-pusht
|
||||
# via
|
||||
# gym-pusht
|
||||
# libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
opencv-python-headless==4.12.0.88
|
||||
# via lerobot
|
||||
orderly-set==5.5.0
|
||||
@@ -337,53 +435,63 @@ packaging==25.0
|
||||
# accelerate
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# hydra-core
|
||||
# jupytext
|
||||
# lazy-loader
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# transformers
|
||||
# wandb
|
||||
pandas==2.3.1
|
||||
pandas==2.3.3
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.4
|
||||
parso==0.8.5
|
||||
# via jedi
|
||||
pettingzoo==1.24.3
|
||||
# via gymnasium-robotics
|
||||
peft==0.17.1
|
||||
# via lerobot
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==11.3.0
|
||||
pillow==12.0.0
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
# via lerobot
|
||||
platformdirs==4.3.8
|
||||
platformdirs==4.5.0
|
||||
# via
|
||||
# jupyter-core
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.2.0
|
||||
pre-commit==4.3.0
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.51
|
||||
prompt-toolkit==3.0.52
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
propcache==0.3.2
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -392,11 +500,17 @@ protobuf==6.31.0
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.0.0
|
||||
psutil==7.1.1
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
# peft
|
||||
# robomimic
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
@@ -405,11 +519,13 @@ pyarrow==21.0.0
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.22
|
||||
pycparser==2.23
|
||||
# via cffi
|
||||
pydantic==2.11.7
|
||||
# via wandb
|
||||
pydantic-core==2.33.2
|
||||
pydantic==2.12.3
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.4
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -424,40 +540,42 @@ pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.2.12
|
||||
pyngrok==7.4.1
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
pyobjc-core==11.1
|
||||
pyobjc-core==12.0
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-cocoa
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-applicationservices==11.1
|
||||
pyobjc-framework-applicationservices==12.0
|
||||
# via pynput
|
||||
pyobjc-framework-cocoa==11.1
|
||||
pyobjc-framework-cocoa==12.0
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-coretext==11.1
|
||||
pyobjc-framework-coretext==12.0
|
||||
# via pyobjc-framework-applicationservices
|
||||
pyobjc-framework-quartz==11.1
|
||||
pyobjc-framework-quartz==12.0
|
||||
# via
|
||||
# pynput
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-coretext
|
||||
pyopengl==3.1.9
|
||||
pyopengl==3.1.10
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.3
|
||||
pyparsing==3.2.5
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via reachy2-sdk
|
||||
pyrealsense2-macosx==2.54.2
|
||||
# via lerobot
|
||||
pyserial==3.5
|
||||
@@ -465,12 +583,14 @@ pyserial==3.5
|
||||
# dynamixel-sdk
|
||||
# feetech-servo-sdk
|
||||
# lerobot
|
||||
pytest==8.4.1
|
||||
pytest==8.4.2
|
||||
# via
|
||||
# bddl
|
||||
# lerobot
|
||||
# pytest-cov
|
||||
# pytest-timeout
|
||||
pytest-cov==6.2.1
|
||||
# teleop
|
||||
pytest-cov==7.0.0
|
||||
# via lerobot
|
||||
pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
@@ -478,46 +598,73 @@ python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-dotenv==1.1.1
|
||||
# via uvicorn
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
pyyaml==6.0.2
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# draccus
|
||||
# hebi-py
|
||||
# huggingface-hub
|
||||
# jupytext
|
||||
# omegaconf
|
||||
# peft
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
pyyaml-include==1.4.1
|
||||
# via draccus
|
||||
pyzmq==27.0.0
|
||||
pyzmq==27.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
regex==2025.7.34
|
||||
reachy2-sdk==1.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2025.10.23
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
requests==2.32.4
|
||||
requests==2.32.5
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.22.1
|
||||
rerun-sdk==0.26.1
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
safetensors==0.5.3
|
||||
robomimic==0.2.0
|
||||
# via libero
|
||||
robosuite==1.4.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.6.2
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
@@ -526,10 +673,12 @@ scikit-image==0.25.2
|
||||
scipy==1.15.3
|
||||
# via
|
||||
# dm-control
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
sentry-sdk==2.34.1
|
||||
sentry-sdk==2.42.1
|
||||
# via wandb
|
||||
shapely==2.1.1
|
||||
shapely==2.1.2
|
||||
# via gym-pusht
|
||||
six==1.17.0
|
||||
# via
|
||||
@@ -537,64 +686,106 @@ six==1.17.0
|
||||
# python-dateutil
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.48.0
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
termcolor==3.1.0
|
||||
teleop==0.1.2
|
||||
# via lerobot
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via libero
|
||||
tifffile==2025.5.10
|
||||
# via scikit-image
|
||||
tokenizers==0.21.4
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.2.1
|
||||
tomli==2.3.0
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
# via
|
||||
# accelerate
|
||||
# lerobot
|
||||
# peft
|
||||
# robomimic
|
||||
# thop
|
||||
# timm
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
# via lerobot
|
||||
tornado==6.5.1
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# peft
|
||||
# robomimic
|
||||
# transformers
|
||||
traitlets==5.14.3
|
||||
# via
|
||||
# ipython
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
transformers==4.51.3
|
||||
# via lerobot
|
||||
typing-extensions==4.14.1
|
||||
# nbformat
|
||||
transformers==4.57.1
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
# starlette
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.1
|
||||
typing-inspection==0.4.2
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via pandas
|
||||
@@ -604,22 +795,36 @@ urllib3==2.5.0
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
virtualenv==20.32.0
|
||||
uvicorn[standard]==0.38.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==20.35.3
|
||||
# via pre-commit
|
||||
wandb==0.21.0
|
||||
# via lerobot
|
||||
wcwidth==0.2.13
|
||||
wandb==0.21.4
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.2.14
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==15.0.1
|
||||
# via uvicorn
|
||||
werkzeug==3.1.3
|
||||
# via flask
|
||||
wrapt==1.17.2
|
||||
# via tensorboard
|
||||
wrapt==2.0.0
|
||||
# via dm-tree
|
||||
xxhash==3.5.0
|
||||
xxhash==3.6.0
|
||||
# via datasets
|
||||
yarl==1.20.1
|
||||
yarl==1.22.0
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via importlib-metadata
|
||||
# via
|
||||
# etils
|
||||
# importlib-metadata
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
|
||||
@@ -13,47 +13,62 @@ absl-py==2.3.1
|
||||
# dm-tree
|
||||
# labmaze
|
||||
# mujoco
|
||||
accelerate==1.9.0
|
||||
# via lerobot
|
||||
# tensorboard
|
||||
accelerate==1.11.0
|
||||
# via
|
||||
# lerobot
|
||||
# peft
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.12.15
|
||||
aiohttp==3.13.1
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.11.0
|
||||
# via
|
||||
# starlette
|
||||
# watchfiles
|
||||
asttokens==3.0.0
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.3.0
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
# dm-tree
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
av==15.0.0
|
||||
av==15.1.0
|
||||
# via lerobot
|
||||
blinker==1.9.0
|
||||
# via flask
|
||||
certifi==2025.7.14
|
||||
bddl==1.0.1
|
||||
# via libero
|
||||
certifi==2025.10.5
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==1.17.1
|
||||
cffi==2.0.0
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.2
|
||||
charset-normalizer==3.4.4
|
||||
# via requests
|
||||
click==8.2.1
|
||||
click==8.3.0
|
||||
# via
|
||||
# flask
|
||||
# uvicorn
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
# via gymnasium
|
||||
cmake==4.0.3
|
||||
# via
|
||||
# gymnasium
|
||||
# libero
|
||||
cmake==4.1.0
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
# via
|
||||
@@ -95,27 +110,29 @@ coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.10.1
|
||||
coverage[toml]==7.11.0
|
||||
# via pytest-cov
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==3.6.0
|
||||
datasets==4.1.1
|
||||
# via lerobot
|
||||
debugpy==1.8.15
|
||||
debugpy==1.8.17
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
deepdiff==8.5.0
|
||||
decord==0.6.0
|
||||
# via lerobot
|
||||
diffusers==0.34.0
|
||||
deepdiff==8.6.1
|
||||
# via lerobot
|
||||
dill==0.3.8
|
||||
diffusers==0.35.2
|
||||
# via lerobot
|
||||
dill==0.4.0
|
||||
# via
|
||||
# datasets
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.14
|
||||
dm-control==1.0.34
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
@@ -123,31 +140,48 @@ dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# lerobot
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
# via lerobot
|
||||
dynamixel-sdk==3.7.31
|
||||
dynamixel-sdk==3.8.4
|
||||
# via lerobot
|
||||
easydict==1.13
|
||||
# via libero
|
||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
||||
# via
|
||||
# libero
|
||||
# robomimic
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
# via lerobot
|
||||
# via
|
||||
# flash-attn
|
||||
# lerobot
|
||||
# libero
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
etils[epath,epy]==1.13.0
|
||||
# via mujoco
|
||||
evdev==1.9.2
|
||||
# via pynput
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.0
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fastapi==0.119.1
|
||||
# via teleop
|
||||
fastjsonschema==2.21.2
|
||||
# via nbformat
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.18.0
|
||||
filelock==3.20.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
@@ -155,24 +189,27 @@ filelock==3.18.0
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
flask==3.1.1
|
||||
flash-attn==2.8.3
|
||||
# via lerobot
|
||||
fonttools==4.59.0
|
||||
fonttools==4.60.1
|
||||
# via matplotlib
|
||||
frozenlist==1.7.0
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.3.0
|
||||
fsspec[http]==2025.9.0
|
||||
# via
|
||||
# datasets
|
||||
# etils
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via libero
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
# via wandb
|
||||
glfw==2.9.0
|
||||
glfw==2.10.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
@@ -180,61 +217,79 @@ grpcio==1.73.1
|
||||
# via
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
grpcio-tools==1.73.1
|
||||
# via
|
||||
# lerobot
|
||||
# reachy2-sdk-api
|
||||
gym-aloha==0.1.3
|
||||
# via lerobot
|
||||
gym-aloha==0.1.1
|
||||
gym-hil==0.1.13
|
||||
# via lerobot
|
||||
gym-hil==0.1.10
|
||||
gym-pusht==0.1.6
|
||||
# via lerobot
|
||||
gym-pusht==0.1.5
|
||||
# via lerobot
|
||||
gym-xarm==0.1.1
|
||||
# via lerobot
|
||||
gymnasium==0.29.1
|
||||
gymnasium==1.2.1
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# pettingzoo
|
||||
gymnasium-robotics==1.2.4
|
||||
# via gym-xarm
|
||||
# libero
|
||||
# metaworld
|
||||
h11==0.16.0
|
||||
# via uvicorn
|
||||
h5py==3.15.1
|
||||
# via robomimic
|
||||
hebi-py==2.11.0
|
||||
# via lerobot
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.5
|
||||
hf-xet==1.1.10
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
huggingface-hub[cli,hf-transfer]==0.34.3
|
||||
httptools==0.7.1
|
||||
# via uvicorn
|
||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
identify==2.6.12
|
||||
hydra-core==1.3.2
|
||||
# via libero
|
||||
identify==2.6.15
|
||||
# via pre-commit
|
||||
idna==3.10
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# metaworld
|
||||
# robomimic
|
||||
# scikit-image
|
||||
imageio-ffmpeg==0.6.0
|
||||
# via imageio
|
||||
# via
|
||||
# imageio
|
||||
# robomimic
|
||||
importlib-metadata==8.7.0
|
||||
# via diffusers
|
||||
iniconfig==2.1.0
|
||||
importlib-resources==6.5.2
|
||||
# via etils
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
@@ -242,50 +297,71 @@ ipython==8.37.0
|
||||
# via meshcat
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jedi==0.19.2
|
||||
# via ipython
|
||||
jinja2==3.1.6
|
||||
# via
|
||||
# flask
|
||||
# gymnasium-robotics
|
||||
# torch
|
||||
# via torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
kiwisolver==1.4.8
|
||||
jsonschema==4.25.1
|
||||
# via nbformat
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
jupyter-core==5.9.1
|
||||
# via nbformat
|
||||
jupytext==1.18.1
|
||||
# via bddl
|
||||
kiwisolver==1.4.9
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
# via scikit-image
|
||||
lxml==6.0.0
|
||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
||||
# via lerobot
|
||||
llvmlite==0.45.1
|
||||
# via numba
|
||||
lxml==6.0.2
|
||||
# via dm-control
|
||||
markupsafe==3.0.2
|
||||
markdown==3.9
|
||||
# via tensorboard
|
||||
markdown-it-py==4.0.0
|
||||
# via
|
||||
# jupytext
|
||||
# mdit-py-plugins
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# flask
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.5
|
||||
# via lerobot
|
||||
matplotlib-inline==0.1.7
|
||||
matplotlib==3.10.7
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
matplotlib-inline==0.2.1
|
||||
# via ipython
|
||||
mdit-py-plugins==0.5.0
|
||||
# via jupytext
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mergedeep==1.3.4
|
||||
# via draccus
|
||||
meshcat==0.3.2
|
||||
# via placo
|
||||
metaworld==3.0.0
|
||||
# via lerobot
|
||||
mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==2.3.7
|
||||
mujoco==3.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
multidict==6.6.3
|
||||
# libero
|
||||
# metaworld
|
||||
# robosuite
|
||||
multidict==6.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -293,42 +369,63 @@ multiprocess==0.70.16
|
||||
# via datasets
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
nbformat==5.10.4
|
||||
# via jupytext
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# bddl
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.13.0
|
||||
# via lerobot
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numba==0.62.1
|
||||
# via robosuite
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# bddl
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
# decord
|
||||
# diffusers
|
||||
# dm-control
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# gymnasium
|
||||
# gymnasium-robotics
|
||||
# h5py
|
||||
# hebi-py
|
||||
# imageio
|
||||
# labmaze
|
||||
# libero
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# metaworld
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
# pettingzoo
|
||||
# peft
|
||||
# pyquaternion
|
||||
# reachy2-sdk
|
||||
# rerun-sdk
|
||||
# robomimic
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# scipy
|
||||
# shapely
|
||||
# teleop
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
# transforms3d
|
||||
nvidia-cublas-cu12==12.6.4.1
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
@@ -366,8 +463,14 @@ nvidia-nvjitlink-cu12==12.6.85
|
||||
# torch
|
||||
nvidia-nvtx-cu12==12.6.77
|
||||
# via torch
|
||||
omegaconf==2.3.0
|
||||
# via hydra-core
|
||||
opencv-python==4.12.0.88
|
||||
# via gym-pusht
|
||||
# via
|
||||
# gym-pusht
|
||||
# libero
|
||||
# reachy2-sdk
|
||||
# robosuite
|
||||
opencv-python-headless==4.12.0.88
|
||||
# via lerobot
|
||||
orderly-set==5.5.0
|
||||
@@ -377,53 +480,63 @@ packaging==25.0
|
||||
# accelerate
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# hydra-core
|
||||
# jupytext
|
||||
# lazy-loader
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# peft
|
||||
# pytest
|
||||
# reachy2-sdk
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# transformers
|
||||
# wandb
|
||||
pandas==2.3.1
|
||||
pandas==2.3.3
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.4
|
||||
parso==0.8.5
|
||||
# via jedi
|
||||
pettingzoo==1.24.3
|
||||
# via gymnasium-robotics
|
||||
peft==0.17.1
|
||||
# via lerobot
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==11.3.0
|
||||
pillow==12.0.0
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# rerun-sdk
|
||||
# robosuite
|
||||
# scikit-image
|
||||
# tensorboard
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
# via lerobot
|
||||
platformdirs==4.3.8
|
||||
platformdirs==4.5.0
|
||||
# via
|
||||
# jupyter-core
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.2.0
|
||||
pre-commit==4.3.0
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.51
|
||||
prompt-toolkit==3.0.52
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
propcache==0.3.2
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@@ -432,11 +545,17 @@ protobuf==6.31.0
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# reachy2-sdk
|
||||
# reachy2-sdk-api
|
||||
# tensorboard
|
||||
# tensorboardx
|
||||
# wandb
|
||||
psutil==7.0.0
|
||||
psutil==7.1.1
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
# peft
|
||||
# robomimic
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
@@ -445,11 +564,13 @@ pyarrow==21.0.0
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.22
|
||||
pycparser==2.23
|
||||
# via cffi
|
||||
pydantic==2.11.7
|
||||
# via wandb
|
||||
pydantic-core==2.33.2
|
||||
pydantic==2.12.3
|
||||
# via
|
||||
# fastapi
|
||||
# wandb
|
||||
pydantic-core==2.41.4
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
@@ -464,20 +585,22 @@ pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.2.12
|
||||
pyngrok==7.4.1
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
pyopengl==3.1.9
|
||||
pyopengl==3.1.10
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.3
|
||||
pyparsing==3.2.5
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via reachy2-sdk
|
||||
pyrealsense2==2.56.5.9235
|
||||
# via lerobot
|
||||
pyserial==3.5
|
||||
@@ -485,12 +608,14 @@ pyserial==3.5
|
||||
# dynamixel-sdk
|
||||
# feetech-servo-sdk
|
||||
# lerobot
|
||||
pytest==8.4.1
|
||||
pytest==8.4.2
|
||||
# via
|
||||
# bddl
|
||||
# lerobot
|
||||
# pytest-cov
|
||||
# pytest-timeout
|
||||
pytest-cov==6.2.1
|
||||
# teleop
|
||||
pytest-cov==7.0.0
|
||||
# via lerobot
|
||||
pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
@@ -498,48 +623,75 @@ python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-dotenv==1.1.1
|
||||
# via uvicorn
|
||||
python-xlib==0.33
|
||||
# via pynput
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
pyyaml==6.0.2
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# draccus
|
||||
# hebi-py
|
||||
# huggingface-hub
|
||||
# jupytext
|
||||
# omegaconf
|
||||
# peft
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
# wandb
|
||||
pyyaml-include==1.4.1
|
||||
# via draccus
|
||||
pyzmq==27.0.0
|
||||
pyzmq==27.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
regex==2025.7.34
|
||||
reachy2-sdk==1.0.14
|
||||
# via lerobot
|
||||
reachy2-sdk-api==1.0.21
|
||||
# via reachy2-sdk
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2025.10.23
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
requests==2.32.4
|
||||
requests==2.32.5
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# teleop
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.22.1
|
||||
rerun-sdk==0.26.1
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
safetensors==0.5.3
|
||||
robomimic==0.2.0
|
||||
# via libero
|
||||
robosuite==1.4.0
|
||||
# via libero
|
||||
rpds-py==0.28.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.6.2
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
@@ -548,10 +700,12 @@ scikit-image==0.25.2
|
||||
scipy==1.15.3
|
||||
# via
|
||||
# dm-control
|
||||
# metaworld
|
||||
# robosuite
|
||||
# scikit-image
|
||||
sentry-sdk==2.34.1
|
||||
sentry-sdk==2.42.1
|
||||
# via wandb
|
||||
shapely==2.1.1
|
||||
shapely==2.1.2
|
||||
# via gym-pusht
|
||||
six==1.17.0
|
||||
# via
|
||||
@@ -560,66 +714,109 @@ six==1.17.0
|
||||
# python-xlib
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
starlette==0.48.0
|
||||
# via fastapi
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
termcolor==3.1.0
|
||||
teleop==0.1.2
|
||||
# via lerobot
|
||||
tensorboard==2.20.0
|
||||
# via robomimic
|
||||
tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
tensorboardx==2.6.4
|
||||
# via robomimic
|
||||
termcolor==3.1.0
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
thop==0.1.1.post2209072238
|
||||
# via libero
|
||||
tifffile==2025.5.10
|
||||
# via scikit-image
|
||||
tokenizers==0.21.4
|
||||
timm==1.0.20
|
||||
# via lerobot
|
||||
tokenizers==0.22.1
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.2.1
|
||||
tomli==2.3.0
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# jupytext
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
# via
|
||||
# accelerate
|
||||
# flash-attn
|
||||
# lerobot
|
||||
# peft
|
||||
# robomimic
|
||||
# thop
|
||||
# timm
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
# via lerobot
|
||||
tornado==6.5.1
|
||||
# via
|
||||
# lerobot
|
||||
# robomimic
|
||||
# timm
|
||||
tornado==6.5.2
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# peft
|
||||
# robomimic
|
||||
# transformers
|
||||
traitlets==5.14.3
|
||||
# via
|
||||
# ipython
|
||||
# jupyter-core
|
||||
# matplotlib-inline
|
||||
transformers==4.51.3
|
||||
# via lerobot
|
||||
# nbformat
|
||||
transformers==4.57.1
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
# peft
|
||||
transforms3d==0.4.2
|
||||
# via teleop
|
||||
triton==3.3.1
|
||||
# via torch
|
||||
typing-extensions==4.14.1
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# anyio
|
||||
# etils
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# rerun-sdk
|
||||
# starlette
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# virtualenv
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.1
|
||||
typing-inspection==0.4.2
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via pandas
|
||||
@@ -629,22 +826,36 @@ urllib3==2.5.0
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
virtualenv==20.32.0
|
||||
uvicorn[standard]==0.38.0
|
||||
# via teleop
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
virtualenv==20.35.3
|
||||
# via pre-commit
|
||||
wandb==0.21.0
|
||||
# via lerobot
|
||||
wcwidth==0.2.13
|
||||
wandb==0.21.4
|
||||
# via
|
||||
# lerobot
|
||||
# libero
|
||||
watchfiles==1.1.1
|
||||
# via uvicorn
|
||||
wcwidth==0.2.14
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via teleop
|
||||
websockets==15.0.1
|
||||
# via uvicorn
|
||||
werkzeug==3.1.3
|
||||
# via flask
|
||||
wrapt==1.17.2
|
||||
# via tensorboard
|
||||
wrapt==2.0.0
|
||||
# via dm-tree
|
||||
xxhash==3.5.0
|
||||
xxhash==3.6.0
|
||||
# via datasets
|
||||
yarl==1.20.1
|
||||
yarl==1.22.0
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via importlib-metadata
|
||||
# via
|
||||
# etils
|
||||
# importlib-metadata
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# requirements.in
|
||||
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 15.5 24F74 arm64).
|
||||
# Darwin MacBook-Pro.local 24.5.0 Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:43 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8132 arm64
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
|
||||
# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
|
||||
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.2 LTS x86_64).
|
||||
# Linux mlerobot-linux 6.14.0-27-generic #27~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 22 17:38:49 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
|
||||
# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
|
||||
-e .[all]
|
||||
|
||||
@@ -16,7 +16,7 @@ import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@@ -268,6 +268,7 @@ class RemotePolicyConfig:
|
||||
lerobot_features: dict[str, PolicyFeature]
|
||||
actions_per_chunk: int
|
||||
device: str = "cpu"
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
|
||||
|
||||
@@ -159,7 +159,10 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||
self.policy.config,
|
||||
pretrained_path=policy_specs.pretrained_name_or_path,
|
||||
preprocessor_overrides={"device_processor": device_override},
|
||||
preprocessor_overrides={
|
||||
"device_processor": device_override,
|
||||
"rename_observations_processor": {"rename_map": policy_specs.rename_map},
|
||||
},
|
||||
postprocessor_overrides={"device_processor": device_override},
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import abc
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
from .configs import CameraConfig, ColorMode
|
||||
|
||||
@@ -89,7 +89,7 @@ class Camera(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""Capture and return a single frame from the camera.
|
||||
|
||||
Args:
|
||||
@@ -102,7 +102,7 @@ class Camera(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def async_read(self, timeout_ms: float = ...) -> np.ndarray:
|
||||
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
|
||||
"""Asynchronously capture and return a single frame from the camera.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -18,7 +18,7 @@ import abc
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import draccus
|
||||
import draccus # type: ignore # TODO: add type stubs for draccus
|
||||
|
||||
|
||||
class ColorMode(str, Enum):
|
||||
@@ -34,11 +34,11 @@ class Cv2Rotation(int, Enum):
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
return str(self.get_choice_name(self.__class__))
|
||||
|
||||
@@ -14,3 +14,5 @@
|
||||
|
||||
from .camera_opencv import OpenCVCamera
|
||||
from .configuration_opencv import OpenCVCameraConfig
|
||||
|
||||
__all__ = ["OpenCVCamera", "OpenCVCameraConfig"]
|
||||
|
||||
@@ -25,11 +25,12 @@ from pathlib import Path
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
# Fix MSMF hardware transform compatibility for Windows before importing cv2
|
||||
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2
|
||||
import numpy as np
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
@@ -121,7 +122,7 @@ class OpenCVCamera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -140,7 +141,7 @@ class OpenCVCamera(Camera):
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
|
||||
|
||||
def connect(self, warmup: bool = True):
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the OpenCV camera specified in the configuration.
|
||||
|
||||
@@ -180,12 +181,14 @@ class OpenCVCamera(Camera):
|
||||
|
||||
def _configure_capture_settings(self) -> None:
|
||||
"""
|
||||
Applies the specified FPS, width, and height settings to the connected camera.
|
||||
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
|
||||
|
||||
This method attempts to set the camera properties via OpenCV. It checks if
|
||||
the camera successfully applied the settings and raises an error if not.
|
||||
FOURCC is set first (if specified) as it can affect the available FPS and resolution options.
|
||||
|
||||
Args:
|
||||
fourcc: The desired FOURCC code (e.g., "MJPG", "YUYV"). If None, auto-detect.
|
||||
fps: The desired frames per second. If None, the setting is skipped.
|
||||
width: The desired capture width. If None, the setting is skipped.
|
||||
height: The desired capture height. If None, the setting is skipped.
|
||||
@@ -199,10 +202,11 @@ class OpenCVCamera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
|
||||
|
||||
if self.fps is None:
|
||||
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
else:
|
||||
self._validate_fps()
|
||||
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
|
||||
if self.config.fourcc is not None:
|
||||
self._validate_fourcc()
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
|
||||
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
||||
@@ -216,18 +220,56 @@ class OpenCVCamera(Camera):
|
||||
else:
|
||||
self._validate_width_and_height()
|
||||
|
||||
if self.fps is None:
|
||||
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
else:
|
||||
self._validate_fps()
|
||||
|
||||
def _validate_fps(self) -> None:
|
||||
"""Validates and sets the camera's frames per second (FPS)."""
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
if self.fps is None:
|
||||
raise ValueError(f"{self} FPS is not set")
|
||||
|
||||
success = self.videocapture.set(cv2.CAP_PROP_FPS, float(self.fps))
|
||||
actual_fps = self.videocapture.get(cv2.CAP_PROP_FPS)
|
||||
# Use math.isclose for robust float comparison
|
||||
if not success or not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
raise RuntimeError(f"{self} failed to set fps={self.fps} ({actual_fps=}).")
|
||||
|
||||
def _validate_fourcc(self) -> None:
|
||||
"""Validates and sets the camera's FOURCC code."""
|
||||
|
||||
fourcc_code = cv2.VideoWriter_fourcc(*self.config.fourcc)
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
success = self.videocapture.set(cv2.CAP_PROP_FOURCC, fourcc_code)
|
||||
actual_fourcc_code = self.videocapture.get(cv2.CAP_PROP_FOURCC)
|
||||
|
||||
# Convert actual FOURCC code back to string for comparison
|
||||
actual_fourcc_code_int = int(actual_fourcc_code)
|
||||
actual_fourcc = "".join([chr((actual_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
|
||||
|
||||
if not success or actual_fourcc != self.config.fourcc:
|
||||
logger.warning(
|
||||
f"{self} failed to set fourcc={self.config.fourcc} (actual={actual_fourcc}, success={success}). "
|
||||
f"Continuing with default format."
|
||||
)
|
||||
|
||||
def _validate_width_and_height(self) -> None:
|
||||
"""Validates and sets the camera's frame capture width and height."""
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
if self.capture_width is None or self.capture_height is None:
|
||||
raise ValueError(f"{self} capture_width or capture_height is not set")
|
||||
|
||||
width_success = self.videocapture.set(cv2.CAP_PROP_FRAME_WIDTH, float(self.capture_width))
|
||||
height_success = self.videocapture.set(cv2.CAP_PROP_FRAME_HEIGHT, float(self.capture_height))
|
||||
|
||||
@@ -258,11 +300,12 @@ class OpenCVCamera(Camera):
|
||||
"""
|
||||
found_cameras_info = []
|
||||
|
||||
targets_to_scan: list[str | int]
|
||||
if platform.system() == "Linux":
|
||||
possible_paths = sorted(Path("/dev").glob("video*"), key=lambda p: p.name)
|
||||
targets_to_scan = [str(p) for p in possible_paths]
|
||||
else:
|
||||
targets_to_scan = list(range(MAX_OPENCV_INDEX))
|
||||
targets_to_scan = [int(i) for i in range(MAX_OPENCV_INDEX)]
|
||||
|
||||
for target in targets_to_scan:
|
||||
camera = cv2.VideoCapture(target)
|
||||
@@ -271,6 +314,12 @@ class OpenCVCamera(Camera):
|
||||
default_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
default_fps = camera.get(cv2.CAP_PROP_FPS)
|
||||
default_format = camera.get(cv2.CAP_PROP_FORMAT)
|
||||
|
||||
# Get FOURCC code and convert to string
|
||||
default_fourcc_code = camera.get(cv2.CAP_PROP_FOURCC)
|
||||
default_fourcc_code_int = int(default_fourcc_code)
|
||||
default_fourcc = "".join([chr((default_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
|
||||
|
||||
camera_info = {
|
||||
"name": f"OpenCV Camera @ {target}",
|
||||
"type": "OpenCV",
|
||||
@@ -278,6 +327,7 @@ class OpenCVCamera(Camera):
|
||||
"backend_api": camera.getBackendName(),
|
||||
"default_stream_profile": {
|
||||
"format": default_format,
|
||||
"fourcc": default_fourcc,
|
||||
"width": default_width,
|
||||
"height": default_height,
|
||||
"fps": default_fps,
|
||||
@@ -289,7 +339,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return found_cameras_info
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
@@ -317,6 +367,9 @@ class OpenCVCamera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
ret, frame = self.videocapture.read()
|
||||
|
||||
if not ret or frame is None:
|
||||
@@ -329,7 +382,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return processed_frame
|
||||
|
||||
def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw frame.
|
||||
|
||||
@@ -372,7 +425,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return processed_image
|
||||
|
||||
def _read_loop(self):
|
||||
def _read_loop(self) -> None:
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
@@ -383,6 +436,9 @@ class OpenCVCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
@@ -419,7 +475,7 @@ class OpenCVCamera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
@@ -462,7 +518,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera and cleans up resources.
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ from pathlib import Path
|
||||
|
||||
from ..configs import CameraConfig, ColorMode, Cv2Rotation
|
||||
|
||||
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("opencv")
|
||||
@dataclass
|
||||
@@ -33,8 +35,9 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
OpenCVCameraConfig(0, 30, 1280, 720) # 1280x720 @ 30FPS
|
||||
OpenCVCameraConfig(/dev/video4, 60, 640, 480) # 640x480 @ 60FPS
|
||||
|
||||
# Advanced configurations
|
||||
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation
|
||||
# Advanced configurations with FOURCC format
|
||||
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90, fourcc="MJPG") # With 90° rotation and MJPG format
|
||||
OpenCVCameraConfig(0, 30, 1280, 720, fourcc="YUYV") # With YUYV format
|
||||
```
|
||||
|
||||
Attributes:
|
||||
@@ -46,17 +49,21 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
|
||||
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
|
||||
warmup_s: Time reading frames before returning from connect (in seconds)
|
||||
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
|
||||
|
||||
Note:
|
||||
- Only 3-channel color output (RGB/BGR) is currently supported.
|
||||
- FOURCC codes must be 4-character strings (e.g., "MJPG", "YUYV"). Some common FOUCC codes: https://learn.microsoft.com/en-us/windows/win32/medfound/video-fourccs#fourcc-constants
|
||||
- Setting FOURCC can help achieve higher frame rates on some cameras.
|
||||
"""
|
||||
|
||||
index_or_path: int | Path
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
|
||||
warmup_s: int = 1
|
||||
fourcc: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
@@ -71,3 +78,8 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
raise ValueError(
|
||||
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
|
||||
)
|
||||
|
||||
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
|
||||
raise ValueError(
|
||||
f"`fourcc` must be a 4-character string (e.g., 'MJPG', 'YUYV'), but '{self.fourcc}' is provided."
|
||||
)
|
||||
|
||||
@@ -16,6 +16,8 @@ from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig, ColorMode
|
||||
|
||||
__all__ = ["CameraConfig", "ColorMode", "Reachy2CameraConfig"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("reachy2_camera")
|
||||
@dataclass
|
||||
@@ -62,7 +64,7 @@ class Reachy2CameraConfig(CameraConfig):
|
||||
port: int = 50065
|
||||
# use_depth: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.name not in ["teleop", "depth"]:
|
||||
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
|
||||
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
|
||||
|
||||
@@ -23,13 +23,17 @@ import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
# Fix MSMF hardware transform compatibility for Windows before importing cv2
|
||||
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2
|
||||
import numpy as np
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
from reachy2_sdk.media.camera_manager import CameraManager
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
from reachy2_sdk.media.camera import CameraView # type: ignore # TODO: add type stubs for reachy2_sdk
|
||||
from reachy2_sdk.media.camera_manager import ( # type: ignore # TODO: add type stubs for reachy2_sdk
|
||||
CameraManager,
|
||||
)
|
||||
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
@@ -73,7 +77,7 @@ class Reachy2Camera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -83,13 +87,17 @@ class Reachy2Camera(Camera):
|
||||
def is_connected(self) -> bool:
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
if self.config.name == "teleop":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
|
||||
return bool(
|
||||
self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
|
||||
)
|
||||
elif self.config.name == "depth":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
|
||||
return bool(
|
||||
self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
|
||||
|
||||
def connect(self, warmup: bool = True):
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the Reachy2 CameraManager as specified in the configuration.
|
||||
"""
|
||||
@@ -131,7 +139,7 @@ class Reachy2Camera(Camera):
|
||||
camera_manager.disconnect()
|
||||
return initialized_cameras
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
@@ -152,7 +160,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
frame = None
|
||||
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
@@ -179,7 +187,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self):
|
||||
def _read_loop(self) -> None:
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
@@ -190,6 +198,9 @@ class Reachy2Camera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
@@ -226,7 +237,7 @@ class Reachy2Camera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
@@ -269,7 +280,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Stops the background read thread (if running).
|
||||
|
||||
|
||||
@@ -21,11 +21,12 @@ import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
try:
|
||||
import pyrealsense2 as rs
|
||||
import pyrealsense2 as rs # type: ignore # TODO: add type stubs for pyrealsense2
|
||||
except Exception as e:
|
||||
logging.info(f"Could not import realsense: {e}")
|
||||
|
||||
@@ -132,7 +133,7 @@ class RealSenseCamera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -150,7 +151,7 @@ class RealSenseCamera(Camera):
|
||||
"""Checks if the camera pipeline is started and streams are active."""
|
||||
return self.rs_pipeline is not None and self.rs_profile is not None
|
||||
|
||||
def connect(self, warmup: bool = True):
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the RealSense camera specified in the configuration.
|
||||
|
||||
@@ -264,7 +265,7 @@ class RealSenseCamera(Camera):
|
||||
serial_number = str(found_devices[0]["serial_number"])
|
||||
return serial_number
|
||||
|
||||
def _configure_rs_pipeline_config(self, rs_config):
|
||||
def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
|
||||
"""Creates and configures the RealSense pipeline configuration object."""
|
||||
rs.config.enable_device(rs_config, self.serial_number)
|
||||
|
||||
@@ -293,6 +294,9 @@ class RealSenseCamera(Camera):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
|
||||
|
||||
if self.rs_profile is None:
|
||||
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
|
||||
|
||||
stream = self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile()
|
||||
|
||||
if self.fps is None:
|
||||
@@ -308,7 +312,7 @@ class RealSenseCamera(Camera):
|
||||
self.width, self.height = actual_width, actual_height
|
||||
self.capture_width, self.capture_height = actual_width, actual_height
|
||||
|
||||
def read_depth(self, timeout_ms: int = 200) -> np.ndarray:
|
||||
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (depth) synchronously from the camera.
|
||||
|
||||
@@ -336,6 +340,9 @@ class RealSenseCamera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
|
||||
if not ret or frame is None:
|
||||
@@ -351,7 +358,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return depth_map_processed
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> np.ndarray:
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (color) synchronously from the camera.
|
||||
|
||||
@@ -376,6 +383,9 @@ class RealSenseCamera(Camera):
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
|
||||
if not ret or frame is None:
|
||||
@@ -392,8 +402,8 @@ class RealSenseCamera(Camera):
|
||||
return color_image_processed
|
||||
|
||||
def _postprocess_image(
|
||||
self, image: np.ndarray, color_mode: ColorMode | None = None, depth_frame: bool = False
|
||||
) -> np.ndarray:
|
||||
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
|
||||
) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw color frame.
|
||||
|
||||
@@ -438,7 +448,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return processed_image
|
||||
|
||||
def _read_loop(self):
|
||||
def _read_loop(self) -> None:
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
@@ -449,6 +459,9 @@ class RealSenseCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read(timeout_ms=500)
|
||||
@@ -474,7 +487,7 @@ class RealSenseCamera(Camera):
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def _stop_read_thread(self):
|
||||
def _stop_read_thread(self) -> None:
|
||||
"""Signals the background read thread to stop and waits for it to join."""
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
@@ -486,7 +499,7 @@ class RealSenseCamera(Camera):
|
||||
self.stop_event = None
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame data (color) asynchronously.
|
||||
|
||||
@@ -529,7 +542,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ class RealSenseCameraConfig(CameraConfig):
|
||||
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
|
||||
warmup_s: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
|
||||
@@ -53,14 +53,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
|
||||
|
||||
|
||||
def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
|
||||
import cv2
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
|
||||
if rotation == Cv2Rotation.ROTATE_90:
|
||||
return cv2.ROTATE_90_CLOCKWISE
|
||||
return int(cv2.ROTATE_90_CLOCKWISE)
|
||||
elif rotation == Cv2Rotation.ROTATE_180:
|
||||
return cv2.ROTATE_180
|
||||
return int(cv2.ROTATE_180)
|
||||
elif rotation == Cv2Rotation.ROTATE_270:
|
||||
return cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -69,8 +69,8 @@ def get_cv2_backend() -> int:
|
||||
import cv2
|
||||
|
||||
if platform.system() == "Windows":
|
||||
return cv2.CAP_MSMF # Use MSMF for Windows instead of AVFOUNDATION
|
||||
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
|
||||
# elif platform.system() == "Darwin": # macOS
|
||||
# return cv2.CAP_AVFOUNDATION
|
||||
else: # Linux and others
|
||||
return cv2.CAP_ANY
|
||||
return int(cv2.CAP_ANY)
|
||||
|
||||
@@ -57,7 +57,7 @@ class EvalConfig:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
use_async_envs: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.batch_size > self.n_episodes:
|
||||
raise ValueError(
|
||||
"The eval batch size is greater than the number of eval episodes "
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import datetime as dt
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot import envs, policies # noqa: F401
|
||||
@@ -22,6 +22,8 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.default import EvalConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalPipelineConfig:
|
||||
@@ -34,25 +36,31 @@ class EvalPipelineConfig:
|
||||
output_dir: Path | None = None
|
||||
job_name: str | None = None
|
||||
seed: int | None = 1000
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
|
||||
else:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{self.policy.type}"
|
||||
self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}"
|
||||
else:
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
self.job_name = (
|
||||
f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}"
|
||||
)
|
||||
|
||||
logger.warning(f"No job name provided, using '{self.job_name}' as job name.")
|
||||
|
||||
if not self.output_dir:
|
||||
now = dt.datetime.now()
|
||||
|
||||
@@ -16,14 +16,19 @@ import inspect
|
||||
import pkgutil
|
||||
import sys
|
||||
from argparse import ArgumentError
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from pkgutil import ModuleInfo
|
||||
from types import ModuleType
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.utils.utils import has_method
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
PATH_KEY = "path"
|
||||
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||
|
||||
@@ -60,7 +65,7 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
|
||||
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict[str, str]:
|
||||
"""Parse plugin-related arguments from command-line arguments.
|
||||
|
||||
This function extracts arguments from command-line arguments that match a specified suffix pattern.
|
||||
@@ -127,7 +132,7 @@ def load_plugin(plugin_path: str) -> None:
|
||||
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
|
||||
) from e
|
||||
|
||||
def iter_namespace(ns_pkg):
|
||||
def iter_namespace(ns_pkg: ModuleType) -> Iterable[ModuleInfo]:
|
||||
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
|
||||
|
||||
try:
|
||||
@@ -148,6 +153,8 @@ def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | No
|
||||
|
||||
|
||||
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
|
||||
if args is None:
|
||||
return []
|
||||
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
||||
|
||||
|
||||
@@ -171,7 +178,8 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
if isinstance(fields_to_filter, str):
|
||||
fields_to_filter = [fields_to_filter]
|
||||
|
||||
filtered_args = args
|
||||
filtered_args = [] if args is None else list(args)
|
||||
|
||||
for field in fields_to_filter:
|
||||
if get_path_arg(field, args):
|
||||
if get_type_arg(field, args):
|
||||
@@ -184,7 +192,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
return filtered_args
|
||||
|
||||
|
||||
def wrap(config_path: Path | None = None):
|
||||
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
"""
|
||||
HACK: Similar to draccus.wrap but does three additional things:
|
||||
- Will remove '.path' arguments from CLI in order to process them later on.
|
||||
@@ -195,9 +203,9 @@ def wrap(config_path: Path | None = None):
|
||||
from the CLI '.type' arguments
|
||||
"""
|
||||
|
||||
def wrapper_outer(fn):
|
||||
def wrapper_outer(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper_inner(*args, **kwargs):
|
||||
def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
|
||||
argspec = inspect.getfullargspec(fn)
|
||||
argtype = argspec.annotations[argspec.args[0]]
|
||||
if len(args) > 0 and type(args[0]) is argtype:
|
||||
@@ -225,6 +233,6 @@ def wrap(config_path: Path | None = None):
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
return wrapper_inner
|
||||
return cast(F, wrapper_inner)
|
||||
|
||||
return wrapper_outer
|
||||
return cast(Callable[[F], F], wrapper_outer)
|
||||
|
||||
@@ -14,12 +14,12 @@
|
||||
import abc
|
||||
import builtins
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -34,10 +34,11 @@ from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedConfig")
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: ignore[misc,name-defined] #TODO: draccus issue
|
||||
"""
|
||||
Base configuration class for policy models.
|
||||
|
||||
@@ -57,12 +58,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
device: str | None = None # cuda | cpu | mp
|
||||
device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps"
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool = False
|
||||
|
||||
push_to_hub: bool = True
|
||||
push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
|
||||
repo_id: str | None = None
|
||||
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
@@ -73,38 +74,41 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
license: str | None = None
|
||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||
pretrained_path: str | None = None
|
||||
pretrained_path: Path | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device.type
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
if self.use_amp and not is_amp_available(self.device):
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
choice_name = self.get_choice_name(self.__class__)
|
||||
if not isinstance(choice_name, str):
|
||||
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||
return choice_name
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_delta_indices(self) -> list | None:
|
||||
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -154,13 +158,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = None,
|
||||
proxies: dict | None = None,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**policy_kwargs,
|
||||
**policy_kwargs: Any,
|
||||
) -> T:
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
@@ -168,7 +172,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
if CONFIG_NAME in os.listdir(model_id):
|
||||
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||
else:
|
||||
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
@@ -194,6 +198,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
with draccus.config_type("json"):
|
||||
orig_config = draccus.parse(cls, config_file, args=[])
|
||||
|
||||
if config_file is None:
|
||||
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
|
||||
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import datetime as dt
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -63,18 +64,18 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
self.checkpoint_path = None
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
# Only load the policy config
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
elif self.resume:
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
config_path = parser.parse_arg("config_path")
|
||||
@@ -82,14 +83,22 @@ class TrainPipelineConfig(HubMixin):
|
||||
raise ValueError(
|
||||
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
||||
)
|
||||
|
||||
if not Path(config_path).resolve().exists():
|
||||
raise NotADirectoryError(
|
||||
f"{config_path=} is expected to be a local path. "
|
||||
"Resuming from the hub is not supported for now."
|
||||
)
|
||||
policy_path = Path(config_path).parent
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.checkpoint_path = policy_path.parent
|
||||
|
||||
policy_dir = Path(config_path).parent
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
|
||||
if self.policy is None:
|
||||
raise ValueError(
|
||||
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
@@ -126,8 +135,8 @@ class TrainPipelineConfig(HubMixin):
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return draccus.encode(self)
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
@@ -139,13 +148,13 @@ class TrainPipelineConfig(HubMixin):
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = None,
|
||||
proxies: dict | None = None,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> "TrainPipelineConfig":
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
@@ -181,4 +190,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
@@ -42,4 +42,4 @@ class NormalizationMode(str, Enum):
|
||||
@dataclass
|
||||
class PolicyFeature:
|
||||
type: FeatureType
|
||||
shape: tuple
|
||||
shape: tuple[int, ...]
|
||||
|
||||
@@ -22,11 +22,13 @@ from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import os
|
||||
import packaging.version
|
||||
import pandas as pd
|
||||
import PIL.Image
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
import torch
|
||||
import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
@@ -686,6 +688,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episode_buffer = None
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@@ -708,7 +711,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if not self._check_cached_episodes_sufficient():
|
||||
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
@@ -835,14 +839,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return hf_dataset
|
||||
|
||||
def _check_cached_episodes_sufficient(self) -> bool:
|
||||
"""Check if the cached dataset contains all requested episodes."""
|
||||
"""Check if the cached dataset contains all requested episodes and their video files."""
|
||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||
return False
|
||||
|
||||
# Get available episode indices from cached dataset
|
||||
available_episodes = {
|
||||
ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
|
||||
for ep_idx in self.hf_dataset["episode_index"]
|
||||
for ep_idx in self.hf_dataset.unique("episode_index")
|
||||
}
|
||||
|
||||
# Determine requested episodes
|
||||
@@ -854,7 +858,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
requested_episodes = set(self.episodes)
|
||||
|
||||
# Check if all requested episodes are available in cached data
|
||||
return requested_episodes.issubset(available_episodes)
|
||||
if not requested_episodes.issubset(available_episodes):
|
||||
return False
|
||||
|
||||
# Check if all required video files exist
|
||||
if len(self.meta.video_keys) > 0:
|
||||
for ep_idx in requested_episodes:
|
||||
for vid_key in self.meta.video_keys:
|
||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||
if not video_path.exists():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def create_hf_dataset(self) -> datasets.Dataset:
|
||||
features = get_hf_features_from_features(self.features)
|
||||
@@ -1136,8 +1151,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
video_paths = self._encode_multiple_temporary_episode_videos(self.meta.video_keys, episode_index)
|
||||
for (video_key, video_path) in zip(self.meta.video_keys, video_paths):
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, video_path))
|
||||
|
||||
# `meta.save_episode` need to be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
@@ -1231,6 +1247,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
global_frame_index = 0
|
||||
self._current_file_start_frame = 0
|
||||
# However, if the episodes already exists
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
@@ -1242,6 +1259,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# When resuming, move to the next file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||
self._current_file_start_frame = global_frame_index
|
||||
else:
|
||||
# Retrieve information from the latest parquet file
|
||||
latest_ep = self.latest_episode
|
||||
@@ -1252,7 +1270,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
latest_size_in_mb = get_file_size_in_mb(latest_path)
|
||||
|
||||
frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"]
|
||||
frames_in_current_file = global_frame_index - self._current_file_start_frame
|
||||
av_size_per_frame = (
|
||||
latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
|
||||
)
|
||||
@@ -1266,6 +1284,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||
self._close_writer()
|
||||
self._writer_closed_for_reading = False
|
||||
self._current_file_start_frame = global_frame_index
|
||||
|
||||
ep_dict["data/chunk_index"] = chunk_idx
|
||||
ep_dict["data/file_index"] = file_idx
|
||||
@@ -1299,9 +1318,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return metadata
|
||||
|
||||
def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
|
||||
def _save_episode_video(self, video_key: str, episode_index: int, video_path: str | Path | None = None) -> dict:
|
||||
# Encode episode frames into a temporary video
|
||||
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
|
||||
if video_path is None:
|
||||
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
|
||||
else:
|
||||
ep_path = video_path
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
|
||||
@@ -1425,6 +1447,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
|
||||
def _encode_multiple_temporary_episode_videos(self, video_keys, episode_index):
|
||||
temp_paths = []
|
||||
img_dirs = []
|
||||
for video_key in video_keys:
|
||||
temp_paths.append(Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4")
|
||||
img_dirs.append(self._get_image_file_dir(episode_index, video_key))
|
||||
fps = [self.fps]*len(video_keys)
|
||||
|
||||
with ProcessPoolExecutor(max_workers=len(video_keys)) as executor:
|
||||
executor.map(encode_video_frames,img_dirs,temp_paths,fps)
|
||||
|
||||
for img_dir in img_dirs:
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
return temp_paths
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@@ -1472,6 +1510,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
obj._current_file_start_frame = None
|
||||
# Initialize tracking for incremental recording
|
||||
obj._lazy_loading = False
|
||||
obj._recorded_frames = 0
|
||||
|
||||
@@ -206,6 +206,11 @@ class ImageTransformsConfig:
|
||||
type="SharpnessJitter",
|
||||
kwargs={"sharpness": (0.5, 1.5)},
|
||||
),
|
||||
"affine": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="RandomAffine",
|
||||
kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -217,6 +222,8 @@ def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
elif cfg.type == "RandomAffine":
|
||||
return v2.RandomAffine(**cfg.kwargs)
|
||||
else:
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ OLD
|
||||
videos/chunk-000/CAMERA/episode_000000.mp4
|
||||
|
||||
NEW
|
||||
videos/chunk-000/file_000.mp4
|
||||
videos/CAMERA/chunk-000/file_000.mp4
|
||||
-------------------------
|
||||
OLD
|
||||
episodes.jsonl
|
||||
|
||||
@@ -310,7 +310,7 @@ def encode_video_frames(
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
log_level: int | None = av.logging.ERROR,
|
||||
overwrite: bool = False,
|
||||
overwrite: bool = True,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
# Check encoder availability
|
||||
@@ -342,8 +342,8 @@ def encode_video_frames(
|
||||
# Define video output frame size (assuming all input frames are the same size)
|
||||
if len(input_list) == 0:
|
||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
||||
dummy_image = Image.open(input_list[0])
|
||||
width, height = dummy_image.size
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
|
||||
# Define video codec options
|
||||
video_options = {}
|
||||
@@ -354,6 +354,9 @@ def encode_video_frames(
|
||||
if crf is not None:
|
||||
video_options["crf"] = str(crf)
|
||||
|
||||
#TEMPORARY FIX
|
||||
video_options["preset"] = "12"
|
||||
|
||||
if fast_decode:
|
||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
@@ -373,11 +376,12 @@ def encode_video_frames(
|
||||
|
||||
# Loop through input frames and encode them
|
||||
for input_data in input_list:
|
||||
input_image = Image.open(input_data).convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
packet = output_stream.encode(input_frame)
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
with Image.open(input_data) as input_image:
|
||||
input_image = input_image.convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
packet = output_stream.encode(input_frame)
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
|
||||
# Flush the encoder
|
||||
packet = output_stream.encode()
|
||||
|
||||
@@ -37,6 +37,16 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@property
|
||||
def package_name(self) -> str:
|
||||
"""Package name to import if environment not found in gym registry"""
|
||||
return f"gym_{self.type}"
|
||||
|
||||
@property
|
||||
def gym_id(self) -> str:
|
||||
"""ID string used in gym.make() to instantiate the environment"""
|
||||
return f"{self.package_name}/{self.task}"
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def gym_kwargs(self) -> dict:
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
|
||||
|
||||
@@ -84,17 +85,24 @@ def make_env(
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
package_name = f"gym_{cfg.type}"
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
||||
raise e
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.task}"
|
||||
if cfg.gym_id not in gym_registry:
|
||||
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
|
||||
try:
|
||||
importlib.import_module(cfg.package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
|
||||
f"Please install it or check PYTHONPATH."
|
||||
) from e
|
||||
|
||||
if cfg.gym_id not in gym_registry:
|
||||
raise gym.error.NameNotFound(
|
||||
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
|
||||
)
|
||||
|
||||
def _make_one():
|
||||
return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
|
||||
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
|
||||
|
||||
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
|
||||
|
||||
|
||||
@@ -22,18 +22,18 @@ class RobotKinematics:
|
||||
self,
|
||||
urdf_path: str,
|
||||
target_frame_name: str = "gripper_frame_link",
|
||||
joint_names: list[str] = None,
|
||||
joint_names: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize placo-based kinematics solver.
|
||||
|
||||
Args:
|
||||
urdf_path: Path to the robot URDF file
|
||||
target_frame_name: Name of the end-effector frame in the URDF
|
||||
joint_names: List of joint names to use for the kinematics solver
|
||||
urdf_path (str): Path to the robot URDF file
|
||||
target_frame_name (str): Name of the end-effector frame in the URDF
|
||||
joint_names (list[str] | None): List of joint names to use for the kinematics solver
|
||||
"""
|
||||
try:
|
||||
import placo
|
||||
import placo # type: ignore[import-not-found] # C++ library with Python bindings, no type stubs available. TODO: Create stub file or request upstream typing support.
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"placo is required for RobotKinematics. "
|
||||
@@ -52,7 +52,7 @@ class RobotKinematics:
|
||||
# Initialize frame task for IK
|
||||
self.tip_frame = self.solver.add_frame_task(self.target_frame_name, np.eye(4))
|
||||
|
||||
def forward_kinematics(self, joint_pos_deg):
|
||||
def forward_kinematics(self, joint_pos_deg: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Compute forward kinematics for given joint configuration given the target frame name in the constructor.
|
||||
|
||||
@@ -77,8 +77,12 @@ class RobotKinematics:
|
||||
return self.robot.get_T_world_frame(self.target_frame_name)
|
||||
|
||||
def inverse_kinematics(
|
||||
self, current_joint_pos, desired_ee_pose, position_weight=1.0, orientation_weight=0.01
|
||||
):
|
||||
self,
|
||||
current_joint_pos: np.ndarray,
|
||||
desired_ee_pose: np.ndarray,
|
||||
position_weight: float = 1.0,
|
||||
orientation_weight: float = 0.01,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute inverse kinematics using placo solver.
|
||||
|
||||
|
||||
@@ -14,4 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
|
||||
from .motors_bus import (
|
||||
Motor,
|
||||
MotorCalibration,
|
||||
MotorNormMode,
|
||||
MotorsBus, # Backward compatibility (alias for SerialMotorsBus)
|
||||
MotorsBusBase,
|
||||
SerialMotorsBus,
|
||||
)
|
||||
|
||||
@@ -14,5 +14,5 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_viperx import ViperXConfig
|
||||
from .viperx import ViperX
|
||||
from .damiao import DamiaoMotorsBus
|
||||
from .tables import *
|
||||
905
src/lerobot/motors/damiao/damiao.py
Normal file
905
src/lerobot/motors/damiao/damiao.py
Normal file
@@ -0,0 +1,905 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO(pepijn): add license of: https://github.com/cmjang/DM_Control_Python?tab=MIT-1-ov-file#readme
|
||||
|
||||
import logging
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import can
|
||||
import numpy as np
|
||||
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode, MotorsBusBase
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
from .tables import (
|
||||
AVAILABLE_BAUDRATES,
|
||||
CAN_CMD_DISABLE,
|
||||
CAN_CMD_ENABLE,
|
||||
CAN_CMD_REFRESH,
|
||||
CAN_CMD_SET_ZERO,
|
||||
CAN_PARAM_ID,
|
||||
DEFAULT_BAUDRATE,
|
||||
DEFAULT_TIMEOUT_MS,
|
||||
MODEL_RESOLUTION,
|
||||
MOTOR_LIMIT_PARAMS,
|
||||
NORMALIZED_DATA,
|
||||
MotorType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NameOrID = Union[str, int]
|
||||
Value = Union[int, float]
|
||||
|
||||
|
||||
class DamiaoMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
The Damiao implementation for a MotorsBus using CAN bus communication.
|
||||
|
||||
This class uses python-can for CAN bus communication with Damiao motors.
|
||||
For more info, see:
|
||||
- python-can documentation: https://python-can.readthedocs.io/en/stable/
|
||||
- Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/
|
||||
- DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python
|
||||
"""
|
||||
|
||||
# CAN-specific settings
|
||||
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
|
||||
default_baudrate = DEFAULT_BAUDRATE
|
||||
default_timeout = DEFAULT_TIMEOUT_MS
|
||||
|
||||
# Motor configuration
|
||||
model_resolution_table = deepcopy(MODEL_RESOLUTION)
|
||||
normalized_data = deepcopy(NORMALIZED_DATA)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
can_interface: str = "auto",
|
||||
use_can_fd: bool = True,
|
||||
bitrate: int = 1000000,
|
||||
data_bitrate: int | None = 5000000,
|
||||
):
|
||||
"""
|
||||
Initialize the Damiao motors bus.
|
||||
|
||||
Args:
|
||||
port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS)
|
||||
motors: Dictionary mapping motor names to Motor objects
|
||||
calibration: Optional calibration data
|
||||
can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial)
|
||||
use_can_fd: Whether to use CAN FD mode (default: True for OpenArms)
|
||||
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
|
||||
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
|
||||
"""
|
||||
super().__init__(port, motors, calibration)
|
||||
self.port = port
|
||||
self.can_interface = can_interface
|
||||
self.use_can_fd = use_can_fd
|
||||
self.bitrate = bitrate
|
||||
self.data_bitrate = data_bitrate
|
||||
self.canbus = None
|
||||
self._is_connected = False
|
||||
|
||||
# Map motor names to CAN IDs
|
||||
self._motor_can_ids = {}
|
||||
self._recv_id_to_motor = {}
|
||||
|
||||
# Store motor types and recv IDs
|
||||
self._motor_types = {}
|
||||
for name, motor in self.motors.items():
|
||||
if hasattr(motor, "motor_type"):
|
||||
self._motor_types[name] = motor.motor_type
|
||||
else:
|
||||
# Default to DM4310 if not specified
|
||||
self._motor_types[name] = MotorType.DM4310
|
||||
|
||||
# Map recv_id to motor name for filtering responses
|
||||
if hasattr(motor, "recv_id"):
|
||||
self._recv_id_to_motor[motor.recv_id] = name
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the CAN bus is connected."""
|
||||
return self._is_connected and self.canbus is not None
|
||||
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""
|
||||
Open the CAN bus and initialize communication.
|
||||
|
||||
Args:
|
||||
handshake: If True, ping all motors to verify they're present
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected."
|
||||
)
|
||||
|
||||
try:
|
||||
# Auto-detect interface type based on port name
|
||||
if self.can_interface == "auto":
|
||||
if self.port.startswith("/dev/"):
|
||||
# Serial device (macOS/Windows)
|
||||
self.can_interface = "slcan"
|
||||
logger.info(f"Auto-detected slcan interface for port {self.port}")
|
||||
else:
|
||||
# Network interface (Linux)
|
||||
self.can_interface = "socketcan"
|
||||
logger.info(f"Auto-detected socketcan interface for port {self.port}")
|
||||
|
||||
# Connect to CAN bus
|
||||
if self.can_interface == "socketcan":
|
||||
# Linux SocketCAN with CAN FD support
|
||||
if self.use_can_fd and self.data_bitrate is not None:
|
||||
self.canbus = can.interface.Bus(
|
||||
channel=self.port,
|
||||
interface="socketcan",
|
||||
bitrate=self.bitrate,
|
||||
data_bitrate=self.data_bitrate,
|
||||
fd=True
|
||||
)
|
||||
logger.info(f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})")
|
||||
else:
|
||||
self.canbus = can.interface.Bus(
|
||||
channel=self.port,
|
||||
interface="socketcan",
|
||||
bitrate=self.bitrate
|
||||
)
|
||||
logger.info(f"Connected to {self.port} with CAN 2.0 (bitrate={self.bitrate})")
|
||||
elif self.can_interface == "slcan":
|
||||
# Serial Line CAN (macOS, Windows, or USB adapters)
|
||||
# Note: SLCAN typically doesn't support CAN FD
|
||||
self.canbus = can.interface.Bus(
|
||||
channel=self.port,
|
||||
interface="slcan",
|
||||
bitrate=self.bitrate
|
||||
)
|
||||
logger.info(f"Connected to {self.port} with SLCAN (bitrate={self.bitrate})")
|
||||
else:
|
||||
# Generic interface (vector, pcan, etc.)
|
||||
if self.use_can_fd and self.data_bitrate is not None:
|
||||
self.canbus = can.interface.Bus(
|
||||
channel=self.port,
|
||||
interface=self.can_interface,
|
||||
bitrate=self.bitrate,
|
||||
data_bitrate=self.data_bitrate,
|
||||
fd=True
|
||||
)
|
||||
else:
|
||||
self.canbus = can.interface.Bus(
|
||||
channel=self.port,
|
||||
interface=self.can_interface,
|
||||
bitrate=self.bitrate
|
||||
)
|
||||
|
||||
self._is_connected = True
|
||||
|
||||
if handshake:
|
||||
self._handshake()
|
||||
|
||||
logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.")
|
||||
except Exception as e:
|
||||
self._is_connected = False
|
||||
raise ConnectionError(f"Failed to connect to CAN bus: {e}")
|
||||
|
||||
def _handshake(self) -> None:
|
||||
"""Verify all motors are present by refreshing their status."""
|
||||
for motor_name in self.motors:
|
||||
self._refresh_motor(motor_name)
|
||||
time.sleep(0.01) # Small delay between motors
|
||||
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""
|
||||
Close the CAN bus connection.
|
||||
|
||||
Args:
|
||||
disable_torque: If True, disable torque on all motors before disconnecting
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected."
|
||||
)
|
||||
|
||||
if disable_torque:
|
||||
try:
|
||||
self.disable_torque()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to disable torque during disconnect: {e}")
|
||||
|
||||
if self.canbus:
|
||||
self.canbus.shutdown()
|
||||
self.canbus = None
|
||||
self._is_connected = False
|
||||
logger.debug(f"{self.__class__.__name__} disconnected.")
|
||||
|
||||
def configure_motors(self) -> None:
|
||||
"""Configure all motors with default settings."""
|
||||
# Damiao motors don't require much configuration in MIT mode
|
||||
# Just ensure they're enabled
|
||||
for motor in self.motors:
|
||||
self._enable_motor(motor)
|
||||
time.sleep(0.01)
|
||||
|
||||
def _enable_motor(self, motor: NameOrID) -> None:
|
||||
"""Enable a single motor."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [CAN_CMD_ENABLE]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
def _disable_motor(self, motor: NameOrID) -> None:
|
||||
"""Disable a single motor."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [CAN_CMD_DISABLE]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors."""
|
||||
motors = self._get_motors_list(motors)
|
||||
for motor in motors:
|
||||
for _ in range(num_retry + 1):
|
||||
try:
|
||||
self._enable_motor(motor)
|
||||
break
|
||||
except Exception as e:
|
||||
if _ == num_retry:
|
||||
raise e
|
||||
time.sleep(0.01)
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Disable torque on selected motors."""
|
||||
motors = self._get_motors_list(motors)
|
||||
for motor in motors:
|
||||
for _ in range(num_retry + 1):
|
||||
try:
|
||||
self._disable_motor(motor)
|
||||
break
|
||||
except Exception as e:
|
||||
if _ == num_retry:
|
||||
raise e
|
||||
time.sleep(0.01)
|
||||
|
||||
@contextmanager
|
||||
def torque_disabled(self, motors: str | list[str] | None = None):
|
||||
"""
|
||||
Context manager that guarantees torque is re-enabled.
|
||||
|
||||
This helper is useful to temporarily disable torque when configuring motors.
|
||||
|
||||
Examples:
|
||||
>>> with bus.torque_disabled():
|
||||
... # Safe operations here with torque disabled
|
||||
... pass
|
||||
"""
|
||||
self.disable_torque(motors)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.enable_torque(motors)
|
||||
|
||||
def set_zero_position(self, motors: str | list[str] | None = None) -> None:
|
||||
"""Set current position as zero for selected motors."""
|
||||
motors = self._get_motors_list(motors)
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [CAN_CMD_SET_ZERO]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
self._recv_motor_response(expected_recv_id=recv_id)
|
||||
time.sleep(0.01)
|
||||
|
||||
def _refresh_motor(self, motor: NameOrID) -> Optional[can.Message]:
|
||||
"""Refresh motor status and return the response."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
return self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
def _recv_motor_response(self, expected_recv_id: Optional[int] = None, timeout: float = 0.001) -> Optional[can.Message]:
|
||||
"""
|
||||
Receive a response from a motor.
|
||||
|
||||
Args:
|
||||
expected_recv_id: If provided, only return messages from this CAN ID
|
||||
timeout: Timeout in seconds (default: 1ms for high-speed operation)
|
||||
|
||||
Returns:
|
||||
CAN message if received, None otherwise
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
messages_seen = []
|
||||
while time.time() - start_time < timeout:
|
||||
msg = self.canbus.recv(timeout=0.0001) # 100us timeout for fast polling
|
||||
if msg:
|
||||
messages_seen.append(f"0x{msg.arbitration_id:02X}")
|
||||
# If no filter specified, return any message
|
||||
if expected_recv_id is None:
|
||||
return msg
|
||||
# Otherwise, only return if it matches the expected recv_id
|
||||
if msg.arbitration_id == expected_recv_id:
|
||||
return msg
|
||||
else:
|
||||
logger.debug(f"Ignoring message from CAN ID 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}")
|
||||
|
||||
# Only log warnings if we're in debug mode to reduce overhead
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
if messages_seen:
|
||||
logger.debug(f"Received {len(messages_seen)} message(s) from IDs {set(messages_seen)}, but expected 0x{expected_recv_id:02X}")
|
||||
else:
|
||||
logger.debug(f"No CAN messages received (expected from 0x{expected_recv_id:02X})")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to receive CAN message: {e}")
|
||||
return None
|
||||
|
||||
def _recv_all_responses(self, expected_recv_ids: list[int], timeout: float = 0.002) -> dict[int, can.Message]:
|
||||
"""
|
||||
Efficiently receive responses from multiple motors at once.
|
||||
Uses the OpenArms pattern: collect all available messages within timeout.
|
||||
|
||||
Args:
|
||||
expected_recv_ids: List of CAN IDs we expect responses from
|
||||
timeout: Total timeout in seconds (default: 2ms)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping recv_id to CAN message
|
||||
"""
|
||||
responses = {}
|
||||
expected_set = set(expected_recv_ids)
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
|
||||
msg = self.canbus.recv(timeout=0.0002) # 200us poll timeout (increased from 100us for better reliability)
|
||||
if msg and msg.arbitration_id in expected_set:
|
||||
responses[msg.arbitration_id] = msg
|
||||
if len(responses) == len(expected_recv_ids):
|
||||
break # Got all responses, exit early
|
||||
except Exception as e:
|
||||
logger.debug(f"Error receiving responses: {e}")
|
||||
|
||||
return responses
|
||||
|
||||
def _mit_control(
|
||||
self,
|
||||
motor: NameOrID,
|
||||
kp: float,
|
||||
kd: float,
|
||||
position_degrees: float,
|
||||
velocity_deg_per_sec: float,
|
||||
torque: float,
|
||||
) -> None:
|
||||
"""
|
||||
Send MIT control command to a motor.
|
||||
|
||||
Args:
|
||||
motor: Motor name or ID
|
||||
kp: Position gain
|
||||
kd: Velocity gain
|
||||
position_degrees: Target position (degrees)
|
||||
velocity_deg_per_sec: Target velocity (degrees/s)
|
||||
torque: Target torque (N·m)
|
||||
"""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
|
||||
|
||||
# Convert degrees to radians for motor control
|
||||
position_rad = np.radians(position_degrees)
|
||||
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
|
||||
|
||||
# Get motor limits
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
|
||||
# Encode parameters
|
||||
kp_uint = self._float_to_uint(kp, 0, 500, 12)
|
||||
kd_uint = self._float_to_uint(kd, 0, 5, 12)
|
||||
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
|
||||
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
|
||||
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
|
||||
|
||||
# Pack data
|
||||
data = [0] * 8
|
||||
data[0] = (q_uint >> 8) & 0xFF
|
||||
data[1] = q_uint & 0xFF
|
||||
data[2] = dq_uint >> 4
|
||||
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
|
||||
data[4] = kp_uint & 0xFF
|
||||
data[5] = kd_uint >> 4
|
||||
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
|
||||
data[7] = tau_uint & 0xFF
|
||||
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
def _mit_control_batch(
|
||||
self,
|
||||
commands: Dict[NameOrID, Tuple[float, float, float, float, float]],
|
||||
) -> None:
|
||||
"""
|
||||
Send MIT control commands to multiple motors in batch (optimized).
|
||||
Sends all commands first, then collects responses. Much faster than sequential.
|
||||
|
||||
Args:
|
||||
commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque)
|
||||
Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...}
|
||||
"""
|
||||
if not commands:
|
||||
return
|
||||
|
||||
expected_recv_ids = []
|
||||
|
||||
# Step 1: Send all MIT control commands (no waiting)
|
||||
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
|
||||
|
||||
# Convert degrees to radians
|
||||
position_rad = np.radians(position_degrees)
|
||||
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
|
||||
|
||||
# Get motor limits
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
|
||||
# Encode parameters
|
||||
kp_uint = self._float_to_uint(kp, 0, 500, 12)
|
||||
kd_uint = self._float_to_uint(kd, 0, 5, 12)
|
||||
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
|
||||
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
|
||||
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
|
||||
|
||||
# Pack data
|
||||
data = [0] * 8
|
||||
data[0] = (q_uint >> 8) & 0xFF
|
||||
data[1] = q_uint & 0xFF
|
||||
data[2] = dq_uint >> 4
|
||||
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
|
||||
data[4] = kp_uint & 0xFF
|
||||
data[5] = kd_uint >> 4
|
||||
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
|
||||
data[7] = tau_uint & 0xFF
|
||||
|
||||
# Send command
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
|
||||
# Track expected response
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
expected_recv_ids.append(recv_id)
|
||||
|
||||
# Step 2: Collect all responses at once
|
||||
self._recv_all_responses(expected_recv_ids, timeout=0.002)
|
||||
|
||||
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||
"""Convert float to unsigned integer for CAN transmission."""
|
||||
x = max(x_min, min(x_max, x)) # Clamp to range
|
||||
span = x_max - x_min
|
||||
data_norm = (x - x_min) / span
|
||||
return int(data_norm * ((1 << bits) - 1))
|
||||
|
||||
def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float:
|
||||
"""Convert unsigned integer from CAN to float."""
|
||||
span = x_max - x_min
|
||||
data_norm = float(x) / ((1 << bits) - 1)
|
||||
return data_norm * span + x_min
|
||||
|
||||
def _decode_motor_state(self, data: bytes, motor_type: MotorType) -> Tuple[float, float, float, int, int]:
|
||||
"""
|
||||
Decode motor state from CAN data.
|
||||
|
||||
Returns:
|
||||
Tuple of (position_degrees, velocity_deg_per_sec, torque, temp_mos, temp_rotor)
|
||||
"""
|
||||
if len(data) < 8:
|
||||
raise ValueError("Invalid motor state data")
|
||||
|
||||
# Extract encoded values
|
||||
q_uint = (data[1] << 8) | data[2]
|
||||
dq_uint = (data[3] << 4) | (data[4] >> 4)
|
||||
tau_uint = ((data[4] & 0x0F) << 8) | data[5]
|
||||
t_mos = data[6]
|
||||
t_rotor = data[7]
|
||||
|
||||
# Get motor limits
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
|
||||
# Decode to physical values (radians)
|
||||
position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16)
|
||||
velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12)
|
||||
torque = self._uint_to_float(tau_uint, -tmax, tmax, 12)
|
||||
|
||||
# Convert to degrees
|
||||
position_degrees = np.degrees(position_rad)
|
||||
velocity_deg_per_sec = np.degrees(velocity_rad_per_sec)
|
||||
|
||||
return position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor
|
||||
|
||||
def read(
|
||||
self,
|
||||
data_name: str,
|
||||
motor: str,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
) -> Value:
|
||||
"""Read a value from a single motor. Positions are always in degrees."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Refresh motor to get latest state
|
||||
msg = self._refresh_motor(motor)
|
||||
if msg is None:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
raise ConnectionError(
|
||||
f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). "
|
||||
f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, "
|
||||
f"3) Motor IDs are configured correctly using Damiao Debugging Tools"
|
||||
)
|
||||
|
||||
motor_type = self._motor_types.get(motor, MotorType.DM4310)
|
||||
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
|
||||
|
||||
# Return requested data (already in degrees for position/velocity)
|
||||
if data_name == "Present_Position":
|
||||
value = position_degrees
|
||||
elif data_name == "Present_Velocity":
|
||||
value = velocity_deg_per_sec
|
||||
elif data_name == "Present_Torque":
|
||||
value = torque
|
||||
elif data_name == "Temperature_MOS":
|
||||
value = t_mos
|
||||
elif data_name == "Temperature_Rotor":
|
||||
value = t_rotor
|
||||
else:
|
||||
raise ValueError(f"Unknown data_name: {data_name}")
|
||||
|
||||
# For Damiao, positions are always in degrees, no normalization needed
|
||||
# We keep the normalize parameter for compatibility but don't use it
|
||||
return value
|
||||
|
||||
def write(
|
||||
self,
|
||||
data_name: str,
|
||||
motor: str,
|
||||
value: Value,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
) -> None:
|
||||
"""Write a value to a single motor. Positions are always in degrees."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Value is expected to be in degrees for positions
|
||||
if data_name == "Goal_Position":
|
||||
# Use MIT control with position in degrees
|
||||
self._mit_control(motor, 10.0, 0.5, value, 0, 0)
|
||||
else:
|
||||
raise ValueError(f"Writing {data_name} not supported in MIT mode")
|
||||
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
motors: str | list[str] | None = None,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
) -> Dict[str, Value]:
|
||||
"""
|
||||
Read the same value from multiple motors simultaneously.
|
||||
Uses batched operations: sends all refresh commands, then collects all responses.
|
||||
This is MUCH faster than sequential reads (OpenArms pattern).
|
||||
"""
|
||||
motors = self._get_motors_list(motors)
|
||||
result = {}
|
||||
|
||||
# Step 1: Send refresh commands to ALL motors first (no waiting)
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
|
||||
# Step 2: Collect all responses at once (batch receive)
|
||||
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in motors]
|
||||
responses = self._recv_all_responses(expected_recv_ids, timeout=0.01) # 10ms total timeout
|
||||
|
||||
# Step 3: Parse responses
|
||||
for motor in motors:
|
||||
try:
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
msg = responses.get(recv_id)
|
||||
|
||||
if msg is None:
|
||||
logger.warning(f"No response from motor '{motor}' (recv ID: 0x{recv_id:02X})")
|
||||
result[motor] = 0.0
|
||||
continue
|
||||
|
||||
motor_type = self._motor_types.get(motor, MotorType.DM4310)
|
||||
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
|
||||
|
||||
# Return requested data
|
||||
if data_name == "Present_Position":
|
||||
value = position_degrees
|
||||
elif data_name == "Present_Velocity":
|
||||
value = velocity_deg_per_sec
|
||||
elif data_name == "Present_Torque":
|
||||
value = torque
|
||||
elif data_name == "Temperature_MOS":
|
||||
value = t_mos
|
||||
elif data_name == "Temperature_Rotor":
|
||||
value = t_rotor
|
||||
else:
|
||||
raise ValueError(f"Unknown data_name: {data_name}")
|
||||
|
||||
result[motor] = value
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read {data_name} from {motor}: {e}")
|
||||
result[motor] = 0.0
|
||||
|
||||
return result
|
||||
|
||||
def sync_read_all_states(
|
||||
self,
|
||||
motors: str | list[str] | None = None,
|
||||
*,
|
||||
num_retry: int = 0,
|
||||
) -> Dict[str, Dict[str, Value]]:
|
||||
"""
|
||||
Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle.
|
||||
This is 3x faster than calling sync_read() three times separately.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque'
|
||||
Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
|
||||
"""
|
||||
motors = self._get_motors_list(motors)
|
||||
result = {}
|
||||
|
||||
# Step 1: Send refresh commands to ALL motors first (with small delays to reduce bus congestion)
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
time.sleep(0.0001) # 100us delay between commands to reduce bus congestion
|
||||
|
||||
# Step 2: Collect all responses at once (batch receive)
|
||||
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in motors]
|
||||
responses = self._recv_all_responses(expected_recv_ids, timeout=0.015) # 15ms timeout (increased for reliability)
|
||||
|
||||
# Step 3: Parse responses and extract ALL state values
|
||||
for motor in motors:
|
||||
try:
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
msg = responses.get(recv_id)
|
||||
|
||||
if msg is None:
|
||||
logger.warning(f"No response from motor '{motor}' (recv ID: 0x{recv_id:02X})")
|
||||
result[motor] = {"position": 0.0, "velocity": 0.0, "torque": 0.0}
|
||||
continue
|
||||
|
||||
motor_type = self._motor_types.get(motor, MotorType.DM4310)
|
||||
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
|
||||
|
||||
# Return all state values in one dict
|
||||
result[motor] = {
|
||||
"position": position_degrees,
|
||||
"velocity": velocity_deg_per_sec,
|
||||
"torque": torque,
|
||||
"temp_mos": t_mos,
|
||||
"temp_rotor": t_rotor,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read state from {motor}: {e}")
|
||||
result[motor] = {"position": 0.0, "velocity": 0.0, "torque": 0.0}
|
||||
|
||||
return result
|
||||
|
||||
def sync_write(
|
||||
self,
|
||||
data_name: str,
|
||||
values: Dict[str, Value],
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Write different values to multiple motors simultaneously. Positions are always in degrees.
|
||||
Uses batched operations: sends all commands first, then collects responses (OpenArms pattern).
|
||||
"""
|
||||
if data_name == "Goal_Position":
|
||||
# Step 1: Send all MIT control commands first (no waiting)
|
||||
for motor, value_degrees in values.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
|
||||
|
||||
# Convert degrees to radians
|
||||
position_rad = np.radians(value_degrees)
|
||||
|
||||
# Default gains for position control
|
||||
kp, kd = 10.0, 0.5
|
||||
|
||||
# Get motor limits and encode parameters
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
kp_uint = self._float_to_uint(kp, 0, 500, 12)
|
||||
kd_uint = self._float_to_uint(kd, 0, 5, 12)
|
||||
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
|
||||
dq_uint = self._float_to_uint(0, -vmax, vmax, 12)
|
||||
tau_uint = self._float_to_uint(0, -tmax, tmax, 12)
|
||||
|
||||
# Pack data
|
||||
data = [0] * 8
|
||||
data[0] = (q_uint >> 8) & 0xFF
|
||||
data[1] = q_uint & 0xFF
|
||||
data[2] = dq_uint >> 4
|
||||
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
|
||||
data[4] = kp_uint & 0xFF
|
||||
data[5] = kd_uint >> 4
|
||||
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
|
||||
data[7] = tau_uint & 0xFF
|
||||
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
time.sleep(0.0001) # 100us delay between commands to reduce bus congestion
|
||||
|
||||
# Step 2: Collect all responses at once
|
||||
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in values.keys()]
|
||||
self._recv_all_responses(expected_recv_ids, timeout=0.015) # 15ms timeout (increased for reliability)
|
||||
else:
|
||||
# Fall back to individual writes for other data types
|
||||
for motor, value in values.items():
|
||||
self.write(data_name, motor, value, normalize=normalize, num_retry=num_retry)
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
"""Read calibration data from motors."""
|
||||
# Damiao motors don't store calibration internally
|
||||
# Return existing calibration or empty dict
|
||||
return self.calibration if self.calibration else {}
|
||||
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||
"""Write calibration data to motors."""
|
||||
# Damiao motors don't store calibration internally
|
||||
# Just cache it in memory
|
||||
if cache:
|
||||
self.calibration = calibration_dict
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
"""
|
||||
Interactively record the min/max values of each motor in degrees.
|
||||
|
||||
Move the joints by hand (with torque disabled) while the method streams live positions.
|
||||
Press Enter to finish.
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors.keys())
|
||||
elif isinstance(motors, (str, int)):
|
||||
motors = [motors]
|
||||
|
||||
# Disable torque for manual movement
|
||||
self.disable_torque(motors)
|
||||
time.sleep(0.1)
|
||||
|
||||
# Get initial positions (already in degrees)
|
||||
start_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
mins = start_positions.copy()
|
||||
maxes = start_positions.copy()
|
||||
|
||||
print("\nMove joints through their full range of motion. Press ENTER when done.")
|
||||
user_pressed_enter = False
|
||||
|
||||
while not user_pressed_enter:
|
||||
positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
|
||||
for motor in motors:
|
||||
if motor in positions:
|
||||
mins[motor] = min(positions[motor], mins.get(motor, positions[motor]))
|
||||
maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor]))
|
||||
|
||||
if display_values:
|
||||
print("\n" + "=" * 50)
|
||||
print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}")
|
||||
print("-" * 50)
|
||||
for motor in motors:
|
||||
if motor in positions:
|
||||
print(f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}")
|
||||
|
||||
if enter_pressed():
|
||||
user_pressed_enter = True
|
||||
|
||||
if display_values and not user_pressed_enter:
|
||||
# Move cursor up to overwrite the previous output
|
||||
move_cursor_up(len(motors) + 4)
|
||||
|
||||
time.sleep(0.05)
|
||||
|
||||
# Re-enable torque
|
||||
self.enable_torque(motors)
|
||||
|
||||
# Validate ranges
|
||||
for motor in motors:
|
||||
if motor in mins and motor in maxes:
|
||||
if abs(maxes[motor] - mins[motor]) < 5.0: # At least 5 degrees of range
|
||||
raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)")
|
||||
|
||||
return mins, maxes
|
||||
|
||||
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
|
||||
"""Convert motor specification to list of motor names."""
|
||||
if motors is None:
|
||||
return list(self.motors.keys())
|
||||
elif isinstance(motors, str):
|
||||
return [motors]
|
||||
elif isinstance(motors, list):
|
||||
return motors
|
||||
else:
|
||||
raise TypeError(f"Invalid motors type: {type(motors)}")
|
||||
|
||||
def _get_motor_id(self, motor: NameOrID) -> int:
|
||||
"""Get CAN ID for a motor."""
|
||||
if isinstance(motor, str):
|
||||
if motor in self.motors:
|
||||
return self.motors[motor].id
|
||||
else:
|
||||
raise ValueError(f"Unknown motor: {motor}")
|
||||
else:
|
||||
return motor
|
||||
|
||||
def _get_motor_name(self, motor: NameOrID) -> str:
|
||||
"""Get motor name from name or ID."""
|
||||
if isinstance(motor, str):
|
||||
return motor
|
||||
else:
|
||||
for name, m in self.motors.items():
|
||||
if m.id == motor:
|
||||
return name
|
||||
raise ValueError(f"Unknown motor ID: {motor}")
|
||||
|
||||
def _get_motor_recv_id(self, motor: NameOrID) -> Optional[int]:
|
||||
"""Get motor recv_id from name or ID."""
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_obj = self.motors.get(motor_name)
|
||||
if motor_obj and hasattr(motor_obj, "recv_id"):
|
||||
return motor_obj.recv_id
|
||||
return None
|
||||
|
||||
@cached_property
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Check if motors are calibrated."""
|
||||
return bool(self.calibration)
|
||||
209
src/lerobot/motors/damiao/tables.py
Normal file
209
src/lerobot/motors/damiao/tables.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration tables for Damiao motors."""
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
# Motor type definitions
|
||||
class MotorType(IntEnum):
|
||||
DM3507 = 0
|
||||
DM4310 = 1
|
||||
DM4310_48V = 2
|
||||
DM4340 = 3
|
||||
DM4340_48V = 4
|
||||
DM6006 = 5
|
||||
DM8006 = 6
|
||||
DM8009 = 7
|
||||
DM10010L = 8
|
||||
DM10010 = 9
|
||||
DMH3510 = 10
|
||||
DMH6215 = 11
|
||||
DMG6220 = 12
|
||||
|
||||
# Control modes
|
||||
class ControlMode(IntEnum):
|
||||
MIT = 1
|
||||
POS_VEL = 2
|
||||
VEL = 3
|
||||
TORQUE_POS = 4
|
||||
|
||||
# Motor variable IDs (RID)
|
||||
class MotorVariable(IntEnum):
|
||||
UV_VALUE = 0
|
||||
KT_VALUE = 1
|
||||
OT_VALUE = 2
|
||||
OC_VALUE = 3
|
||||
ACC = 4
|
||||
DEC = 5
|
||||
MAX_SPD = 6
|
||||
MST_ID = 7
|
||||
ESC_ID = 8
|
||||
TIMEOUT = 9
|
||||
CTRL_MODE = 10
|
||||
DAMP = 11
|
||||
INERTIA = 12
|
||||
HW_VER = 13
|
||||
SW_VER = 14
|
||||
SN = 15
|
||||
NPP = 16
|
||||
RS = 17
|
||||
LS = 18
|
||||
FLUX = 19
|
||||
GR = 20
|
||||
PMAX = 21
|
||||
VMAX = 22
|
||||
TMAX = 23
|
||||
I_BW = 24
|
||||
KP_ASR = 25
|
||||
KI_ASR = 26
|
||||
KP_APR = 27
|
||||
KI_APR = 28
|
||||
OV_VALUE = 29
|
||||
GREF = 30
|
||||
DETA = 31
|
||||
V_BW = 32
|
||||
IQ_C1 = 33
|
||||
VL_C1 = 34
|
||||
CAN_BR = 35
|
||||
SUB_VER = 36
|
||||
U_OFF = 50
|
||||
V_OFF = 51
|
||||
K1 = 52
|
||||
K2 = 53
|
||||
M_OFF = 54
|
||||
DIR = 55
|
||||
P_M = 80
|
||||
XOUT = 81
|
||||
|
||||
# Motor limit parameters [PMAX, VMAX, TMAX]
|
||||
# PMAX: Maximum position (rad)
|
||||
# VMAX: Maximum velocity (rad/s)
|
||||
# TMAX: Maximum torque (N·m)
|
||||
MOTOR_LIMIT_PARAMS = {
|
||||
MotorType.DM3507: (12.5, 30, 10),
|
||||
MotorType.DM4310: (12.5, 30, 10),
|
||||
MotorType.DM4310_48V: (12.5, 50, 10),
|
||||
MotorType.DM4340: (12.5, 8, 28),
|
||||
MotorType.DM4340_48V: (12.5, 10, 28),
|
||||
MotorType.DM6006: (12.5, 45, 20),
|
||||
MotorType.DM8006: (12.5, 45, 40),
|
||||
MotorType.DM8009: (12.5, 45, 54),
|
||||
MotorType.DM10010L: (12.5, 25, 200),
|
||||
MotorType.DM10010: (12.5, 20, 200),
|
||||
MotorType.DMH3510: (12.5, 280, 1),
|
||||
MotorType.DMH6215: (12.5, 45, 10),
|
||||
MotorType.DMG6220: (12.5, 45, 10),
|
||||
}
|
||||
|
||||
# Motor model names
|
||||
MODEL_NAMES = {
|
||||
MotorType.DM3507: "dm3507",
|
||||
MotorType.DM4310: "dm4310",
|
||||
MotorType.DM4310_48V: "dm4310_48v",
|
||||
MotorType.DM4340: "dm4340",
|
||||
MotorType.DM4340_48V: "dm4340_48v",
|
||||
MotorType.DM6006: "dm6006",
|
||||
MotorType.DM8006: "dm8006",
|
||||
MotorType.DM8009: "dm8009",
|
||||
MotorType.DM10010L: "dm10010l",
|
||||
MotorType.DM10010: "dm10010",
|
||||
MotorType.DMH3510: "dmh3510",
|
||||
MotorType.DMH6215: "dmh6215",
|
||||
MotorType.DMG6220: "dmg6220",
|
||||
}
|
||||
|
||||
# Motor resolution table (encoder counts per revolution)
|
||||
MODEL_RESOLUTION = {
|
||||
"dm3507": 65536,
|
||||
"dm4310": 65536,
|
||||
"dm4310_48v": 65536,
|
||||
"dm4340": 65536,
|
||||
"dm4340_48v": 65536,
|
||||
"dm6006": 65536,
|
||||
"dm8006": 65536,
|
||||
"dm8009": 65536,
|
||||
"dm10010l": 65536,
|
||||
"dm10010": 65536,
|
||||
"dmh3510": 65536,
|
||||
"dmh6215": 65536,
|
||||
"dmg6220": 65536,
|
||||
}
|
||||
|
||||
# CAN baudrates supported by Damiao motors
|
||||
AVAILABLE_BAUDRATES = [
|
||||
125000, # 0: 125 kbps
|
||||
200000, # 1: 200 kbps
|
||||
250000, # 2: 250 kbps
|
||||
500000, # 3: 500 kbps
|
||||
1000000, # 4: 1 mbps (default for OpenArms)
|
||||
2000000, # 5: 2 mbps
|
||||
2500000, # 6: 2.5 mbps
|
||||
3200000, # 7: 3.2 mbps
|
||||
4000000, # 8: 4 mbps
|
||||
5000000, # 9: 5 mbps
|
||||
]
|
||||
DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
|
||||
|
||||
# Default timeout in milliseconds
|
||||
DEFAULT_TIMEOUT_MS = 1000
|
||||
|
||||
# Data that should be normalized
|
||||
NORMALIZED_DATA = ["Present_Position", "Goal_Position"]
|
||||
|
||||
# OpenArms specific configurations
|
||||
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||
# OpenArms has 7 DOF per arm (14 total for dual arm)
|
||||
OPENARMS_ARM_MOTOR_IDS = {
|
||||
"joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan
|
||||
"joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift
|
||||
"joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex
|
||||
"joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex
|
||||
"joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll
|
||||
"joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch
|
||||
"joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation
|
||||
}
|
||||
|
||||
OPENARMS_GRIPPER_MOTOR_IDS = {
|
||||
"gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper
|
||||
}
|
||||
|
||||
# Default motor types for OpenArms
|
||||
OPENARMS_DEFAULT_MOTOR_TYPES = {
|
||||
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
|
||||
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
|
||||
"joint_3": MotorType.DM4340, # Shoulder rotation
|
||||
"joint_4": MotorType.DM4340, # Elbow flex
|
||||
"joint_5": MotorType.DM4310, # Wrist roll
|
||||
"joint_6": MotorType.DM4310, # Wrist pitch
|
||||
"joint_7": MotorType.DM4310, # Wrist rotation
|
||||
"gripper": MotorType.DM4310, # Gripper
|
||||
}
|
||||
|
||||
# MIT control parameter ranges
|
||||
MIT_KP_RANGE = (0.0, 500.0)
|
||||
MIT_KD_RANGE = (0.0, 5.0)
|
||||
|
||||
# CAN frame command IDs
|
||||
CAN_CMD_ENABLE = 0xFC
|
||||
CAN_CMD_DISABLE = 0xFD
|
||||
CAN_CMD_SET_ZERO = 0xFE
|
||||
CAN_CMD_REFRESH = 0xCC
|
||||
CAN_CMD_QUERY_PARAM = 0x33
|
||||
CAN_CMD_WRITE_PARAM = 0x55
|
||||
CAN_CMD_SAVE_PARAM = 0xAA
|
||||
|
||||
# CAN ID for parameter operations
|
||||
CAN_PARAM_ID = 0x7FF
|
||||
@@ -24,7 +24,7 @@ from enum import Enum
|
||||
|
||||
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||
from .tables import (
|
||||
AVAILABLE_BAUDRATES,
|
||||
MODEL_BAUDRATE_TABLE,
|
||||
@@ -60,7 +60,7 @@ class OperatingMode(Enum):
|
||||
|
||||
# This mode controls position. This mode is identical to the Multi-turn Position Control from existing
|
||||
# DYNAMIXEL. 512 turns are supported(-256[rev] ~ 256[rev]). This mode is ideal for multi-turn wrists or
|
||||
# conveyer systems or a system that requires an additional reduction gear. Note that Max Position
|
||||
# conveyor systems or a system that requires an additional reduction gear. Note that Max Position
|
||||
# Limit(48), Min Position Limit(52) are not used on Extended Position Control Mode.
|
||||
EXTENDED_POSITION = 4
|
||||
|
||||
@@ -100,7 +100,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
||||
return data
|
||||
|
||||
|
||||
class DynamixelMotorsBus(MotorsBus):
|
||||
class DynamixelMotorsBus(SerialMotorsBus):
|
||||
"""
|
||||
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
|
||||
the motors. For more info, see the Dynamixel SDK Documentation:
|
||||
|
||||
@@ -19,7 +19,7 @@ from pprint import pformat
|
||||
|
||||
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||
from .tables import (
|
||||
FIRMWARE_MAJOR_VERSION,
|
||||
FIRMWARE_MINOR_VERSION,
|
||||
@@ -96,7 +96,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802
|
||||
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
|
||||
|
||||
|
||||
class FeetechMotorsBus(MotorsBus):
|
||||
class FeetechMotorsBus(SerialMotorsBus):
|
||||
"""
|
||||
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
|
||||
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
|
||||
@@ -165,7 +165,7 @@ class FeetechMotorsBus(MotorsBus):
|
||||
|
||||
def _handshake(self) -> None:
|
||||
self._assert_motors_exist()
|
||||
self._assert_same_firmware()
|
||||
#self._assert_same_firmware()
|
||||
|
||||
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
||||
if self.protocol_version == 0:
|
||||
|
||||
@@ -206,8 +206,12 @@ MODEL_BAUDRATE_TABLE = {
|
||||
# Sign-Magnitude encoding bits
|
||||
STS_SMS_SERIES_ENCODINGS_TABLE = {
|
||||
"Homing_Offset": 11,
|
||||
"Goal_Position": 15,
|
||||
"Goal_Velocity": 15,
|
||||
"Goal_Speed": 15,
|
||||
"Present_Position": 15,
|
||||
"Present_Velocity": 15,
|
||||
"Present_Speed": 15,
|
||||
}
|
||||
|
||||
MODEL_ENCODING_TABLE = {
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
# TODO(aliberts): Add block noqa when feature below is available
|
||||
# https://github.com/astral-sh/ruff/issues/3711
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
@@ -41,6 +43,92 @@ Value: TypeAlias = int | float
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MotorsBusBase(abc.ABC):
|
||||
"""
|
||||
Base class for all motor bus implementations.
|
||||
|
||||
This is a minimal interface that all motor buses must implement, regardless of their
|
||||
communication protocol (serial, CAN, etc.).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.calibration = calibration if calibration else {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""Establish connection to the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""Disconnect from the motors."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected to the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, data_name: str, motor: str, *, normalize: bool = True, num_retry: int = 0) -> Value:
|
||||
"""Read a value from a single motor."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def write(
|
||||
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
|
||||
) -> None:
|
||||
"""Write a value to a single motor."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_read(
|
||||
self, data_name: str, motors: str | list[str] | None = None, *, normalize: bool = True
|
||||
) -> dict[str, Value]:
|
||||
"""Read a value from multiple motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_write(
|
||||
self,
|
||||
data_name: str,
|
||||
values: Value | dict[str, Value],
|
||||
motors: str | list[str] | None = None,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
) -> None:
|
||||
"""Write values to multiple motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Disable torque on selected motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
"""Read calibration parameters from the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||
"""Write calibration parameters to the motors."""
|
||||
pass
|
||||
|
||||
|
||||
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
|
||||
ctrl_table = model_ctrl_table.get(model)
|
||||
if ctrl_table is None:
|
||||
@@ -203,15 +291,15 @@ class GroupSyncWrite(Protocol):
|
||||
def txPacket(self): ...
|
||||
|
||||
|
||||
class MotorsBus(abc.ABC):
|
||||
class SerialMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
A MotorsBus allows to efficiently read and write to the attached motors.
|
||||
A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication.
|
||||
It represents several motors daisy-chained together and connected through a serial port.
|
||||
There are currently two implementations of this abstract class:
|
||||
There are currently two implementations of this class:
|
||||
- DynamixelMotorsBus
|
||||
- FeetechMotorsBus
|
||||
|
||||
Note: This class may evolve in the future should we add support for other types of bus.
|
||||
This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.).
|
||||
|
||||
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
@@ -1212,3 +1300,7 @@ class MotorsBus(abc.ABC):
|
||||
for id_, value in ids_values.items():
|
||||
data = self._serialize_data(value, length)
|
||||
self.sync_writer.addParam(id_, data)
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
MotorsBus = SerialMotorsBus
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
@@ -79,7 +80,11 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
"""Used by Physical Intelligence to train Pi0"""
|
||||
"""Used by Physical Intelligence to train Pi0.
|
||||
|
||||
Automatically scales warmup and decay steps if num_training_steps < num_decay_steps.
|
||||
This ensures the learning rate schedule completes properly even with shorter training runs.
|
||||
"""
|
||||
|
||||
num_warmup_steps: int
|
||||
num_decay_steps: int
|
||||
@@ -87,23 +92,39 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
decay_lr: float
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
del num_training_steps
|
||||
# Auto-scale scheduler parameters if training steps are shorter than configured decay steps
|
||||
actual_warmup_steps = self.num_warmup_steps
|
||||
actual_decay_steps = self.num_decay_steps
|
||||
|
||||
if num_training_steps < self.num_decay_steps:
|
||||
# Calculate scaling factor to fit the schedule into the available training steps
|
||||
scale_factor = num_training_steps / self.num_decay_steps
|
||||
actual_warmup_steps = int(self.num_warmup_steps * scale_factor)
|
||||
actual_decay_steps = num_training_steps
|
||||
|
||||
logging.info(
|
||||
f"Auto-scaling LR scheduler: "
|
||||
f"num_training_steps ({num_training_steps}) < num_decay_steps ({self.num_decay_steps}). "
|
||||
f"Scaling warmup: {self.num_warmup_steps} → {actual_warmup_steps}, "
|
||||
f"decay: {self.num_decay_steps} → {actual_decay_steps} "
|
||||
f"(scale factor: {scale_factor:.3f})"
|
||||
)
|
||||
|
||||
def lr_lambda(current_step):
|
||||
def linear_warmup_schedule(current_step):
|
||||
if current_step <= 0:
|
||||
return 1 / (self.num_warmup_steps + 1)
|
||||
frac = 1 - current_step / self.num_warmup_steps
|
||||
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
|
||||
return 1 / (actual_warmup_steps + 1)
|
||||
frac = 1 - current_step / actual_warmup_steps
|
||||
return (1 / (actual_warmup_steps + 1) - 1) * frac + 1
|
||||
|
||||
def cosine_decay_schedule(current_step):
|
||||
step = min(current_step, self.num_decay_steps)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
|
||||
step = min(current_step, actual_decay_steps)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / actual_decay_steps))
|
||||
alpha = self.decay_lr / self.peak_lr
|
||||
decayed = (1 - alpha) * cosine_decay + alpha
|
||||
return decayed
|
||||
|
||||
if current_step < self.num_warmup_steps:
|
||||
if current_step < actual_warmup_steps:
|
||||
return linear_warmup_schedule(current_step)
|
||||
|
||||
return cosine_decay_schedule(current_step)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
@@ -29,4 +30,5 @@ __all__ = [
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
"GrootConfig",
|
||||
]
|
||||
|
||||
@@ -626,8 +626,8 @@ class ACTDecoderLayer(nn.Module):
|
||||
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
|
||||
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
|
||||
cross-attending with.
|
||||
decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
|
||||
encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
|
||||
encoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
|
||||
decoder_pos_embed: (DS, 1, C) positional embedding for the queries (from the decoder).
|
||||
Returns:
|
||||
(DS, B, C) tensor of decoder output features.
|
||||
"""
|
||||
|
||||
@@ -30,6 +30,7 @@ from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.utils import env_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
@@ -101,6 +102,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "groot":
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
|
||||
return GrootPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -142,6 +147,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
return GrootConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
@@ -199,6 +206,27 @@ def make_pre_post_processors(
|
||||
policy configuration type.
|
||||
"""
|
||||
if pretrained_path:
|
||||
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
|
||||
if isinstance(policy_cfg, GrootConfig):
|
||||
# GROOT handles normalization in groot_pack_inputs_v3 step
|
||||
# Need to override both stats AND normalize_min_max since saved config might be empty
|
||||
preprocessor_overrides = {}
|
||||
postprocessor_overrides = {}
|
||||
preprocessor_overrides["groot_pack_inputs_v3"] = {
|
||||
"stats": kwargs.get("dataset_stats"),
|
||||
"normalize_min_max": True,
|
||||
}
|
||||
|
||||
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
|
||||
env_action_dim = policy_cfg.output_features["action"].shape[0]
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
|
||||
"stats": kwargs.get("dataset_stats"),
|
||||
"normalize_min_max": True,
|
||||
"env_action_dim": env_action_dim,
|
||||
}
|
||||
kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
kwargs["postprocessor_overrides"] = postprocessor_overrides
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
@@ -293,6 +321,14 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, GrootConfig):
|
||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
processors = make_groot_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||
|
||||
@@ -303,6 +339,7 @@ def make_policy(
|
||||
cfg: PreTrainedConfig,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
env_cfg: EnvConfig | None = None,
|
||||
rename_map: dict[str, str] | None = None,
|
||||
) -> PreTrainedPolicy:
|
||||
"""
|
||||
Instantiate a policy model.
|
||||
@@ -319,6 +356,8 @@ def make_policy(
|
||||
statistics for normalization layers.
|
||||
env_cfg: Environment configuration used to infer feature shapes and types.
|
||||
One of `ds_meta` or `env_cfg` must be provided.
|
||||
rename_map: Optional mapping of dataset or environment feature keys to match
|
||||
expected policy feature names (e.g., `"left"` → `"camera1"`).
|
||||
|
||||
Returns:
|
||||
An instantiated and device-placed policy model.
|
||||
@@ -360,8 +399,10 @@ def make_policy(
|
||||
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
|
||||
features = env_to_policy_features(env_cfg)
|
||||
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
if not cfg.output_features:
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
if not cfg.input_features:
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
if cfg.pretrained_path:
|
||||
@@ -378,4 +419,21 @@ def make_policy(
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
if not rename_map:
|
||||
expected_features = set(cfg.input_features.keys()) | set(cfg.output_features.keys())
|
||||
provided_features = set(features.keys())
|
||||
if expected_features and provided_features != expected_features:
|
||||
missing = expected_features - provided_features
|
||||
extra = provided_features - expected_features
|
||||
# TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
|
||||
raise ValueError(
|
||||
f"Feature mismatch between dataset/environment and policy config.\n"
|
||||
f"- Missing features: {sorted(missing) if missing else 'None'}\n"
|
||||
f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
|
||||
f"Please ensure your dataset and policy use consistent feature names.\n"
|
||||
f"If your dataset uses different observation keys (e.g., cameras named differently), "
|
||||
f"use the `--rename_map` argument, for example:\n"
|
||||
f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
|
||||
f'"observation.images.top": "observation.images.camera2"}}\''
|
||||
)
|
||||
return policy
|
||||
|
||||
1
src/lerobot/policies/groot/README.md
Symbolic link
1
src/lerobot/policies/groot/README.md
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_groot_README.md
|
||||
21
src/lerobot/policies/groot/__init__.py
Normal file
21
src/lerobot/policies/groot/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Nvidia and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_groot import GrootConfig
|
||||
from .modeling_groot import GrootPolicy
|
||||
from .processor_groot import make_groot_pre_post_processors
|
||||
|
||||
__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
|
||||
14
src/lerobot/policies/groot/action_head/__init__.py
Normal file
14
src/lerobot/policies/groot/action_head/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
54
src/lerobot/policies/groot/action_head/action_encoder.py
Normal file
54
src/lerobot/policies/groot/action_head/action_encoder.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
"""
|
||||
Produces a sinusoidal encoding of shape (B, T, w)
|
||||
given timesteps of shape (B, T).
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps):
|
||||
# timesteps: shape (B, T)
|
||||
# We'll compute sin/cos frequencies across dim T
|
||||
timesteps = timesteps.float() # ensure float
|
||||
|
||||
b, t = timesteps.shape
|
||||
device = timesteps.device
|
||||
|
||||
half_dim = self.embedding_dim // 2
|
||||
# typical log space frequencies for sinusoidal encoding
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
|
||||
torch.log(torch.tensor(10000.0)) / half_dim
|
||||
)
|
||||
# Expand timesteps to (B, T, 1) then multiply
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
|
||||
|
||||
sin = torch.sin(freqs)
|
||||
cos = torch.cos(freqs)
|
||||
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
|
||||
|
||||
return enc
|
||||
370
src/lerobot/policies/groot/action_head/cross_attention_dit.py
Executable file
370
src/lerobot/policies/groot/action_head/cross_attention_dit.py
Executable file
@@ -0,0 +1,370 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import (
|
||||
SinusoidalPositionalEmbedding,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps):
|
||||
dtype = next(self.parameters()).dtype
|
||||
timesteps_proj = self.time_proj(timesteps).to(dtype)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-5,
|
||||
chunk_dim: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.chunk_dim = chunk_dim
|
||||
output_dim = embedding_dim * 2
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
temb: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
temb = self.linear(self.silu(temb))
|
||||
scale, shift = temb.chunk(2, dim=1)
|
||||
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return x
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: int | None = None,
|
||||
activation_fn: str = "geglu",
|
||||
attention_bias: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
attention_type: str = "default",
|
||||
positional_embeddings: str | None = None,
|
||||
num_positional_embeddings: int | None = None,
|
||||
ff_inner_dim: int | None = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.norm_type = norm_type
|
||||
|
||||
if positional_embeddings and (num_positional_embeddings is None):
|
||||
raise ValueError(
|
||||
"If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
|
||||
)
|
||||
|
||||
if positional_embeddings == "sinusoidal":
|
||||
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
if norm_type == "ada_norm":
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
if final_dropout:
|
||||
self.final_dropout = nn.Dropout(dropout)
|
||||
else:
|
||||
self.final_dropout = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
temb: torch.LongTensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# 0. Self-Attention
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm1(hidden_states, temb)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
# encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
if self.final_dropout:
|
||||
attn_output = self.final_dropout(attn_output)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 4. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DiT(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 8,
|
||||
attention_head_dim: int = 64,
|
||||
output_dim: int = 26,
|
||||
num_layers: int = 12,
|
||||
dropout: float = 0.1,
|
||||
attention_bias: bool = True,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
num_embeds_ada_norm: int | None = 1000,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "ada_norm",
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-5,
|
||||
max_num_positional_embeddings: int = 512,
|
||||
compute_dtype=torch.float32,
|
||||
final_dropout: bool = True,
|
||||
positional_embeddings: str | None = "sinusoidal",
|
||||
interleave_self_attention=False,
|
||||
cross_attention_dim: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Timestep encoder
|
||||
self.timestep_encoder = TimestepEncoder(
|
||||
embedding_dim=self.inner_dim, compute_dtype=self.config.compute_dtype
|
||||
)
|
||||
|
||||
all_blocks = []
|
||||
for idx in range(self.config.num_layers):
|
||||
use_self_attn = idx % 2 == 1 and interleave_self_attention
|
||||
curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None
|
||||
|
||||
all_blocks += [
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.config.num_attention_heads,
|
||||
self.config.attention_head_dim,
|
||||
dropout=self.config.dropout,
|
||||
activation_fn=self.config.activation_fn,
|
||||
attention_bias=self.config.attention_bias,
|
||||
upcast_attention=self.config.upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=self.config.norm_elementwise_affine,
|
||||
norm_eps=self.config.norm_eps,
|
||||
positional_embeddings=positional_embeddings,
|
||||
num_positional_embeddings=self.config.max_num_positional_embeddings,
|
||||
final_dropout=final_dropout,
|
||||
cross_attention_dim=curr_cross_attention_dim,
|
||||
)
|
||||
]
|
||||
self.transformer_blocks = nn.ModuleList(all_blocks)
|
||||
|
||||
# Output blocks
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
|
||||
print(
|
||||
"Total number of DiT parameters: ",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor, # Shape: (B, T, D)
|
||||
encoder_hidden_states: torch.Tensor, # Shape: (B, S, D)
|
||||
timestep: torch.LongTensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
return_all_hidden_states: bool = False,
|
||||
):
|
||||
# Encode timesteps
|
||||
temb = self.timestep_encoder(timestep)
|
||||
|
||||
# Process through transformer blocks - single pass through the blocks
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_hidden_states = encoder_hidden_states.contiguous()
|
||||
|
||||
all_hidden_states = [hidden_states]
|
||||
|
||||
# Process through transformer blocks
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
if idx % 2 == 1 and self.config.interleave_self_attention:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
temb=temb,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=None,
|
||||
temb=temb,
|
||||
)
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
# Output processing
|
||||
conditioning = temb
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
if return_all_hidden_states:
|
||||
return self.proj_out_2(hidden_states), all_hidden_states
|
||||
else:
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 8,
|
||||
attention_head_dim: int = 64,
|
||||
output_dim: int = 26,
|
||||
num_layers: int = 12,
|
||||
dropout: float = 0.1,
|
||||
attention_bias: bool = True,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
num_embeds_ada_norm: int | None = 1000,
|
||||
upcast_attention: bool = False,
|
||||
max_num_positional_embeddings: int = 512,
|
||||
compute_dtype=torch.float32,
|
||||
final_dropout: bool = True,
|
||||
positional_embeddings: str | None = "sinusoidal",
|
||||
interleave_self_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.config.num_attention_heads,
|
||||
self.config.attention_head_dim,
|
||||
dropout=self.config.dropout,
|
||||
activation_fn=self.config.activation_fn,
|
||||
attention_bias=self.config.attention_bias,
|
||||
upcast_attention=self.config.upcast_attention,
|
||||
positional_embeddings=positional_embeddings,
|
||||
num_positional_embeddings=self.config.max_num_positional_embeddings,
|
||||
final_dropout=final_dropout,
|
||||
)
|
||||
for _ in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
print(
|
||||
"Total number of SelfAttentionTransformer parameters: ",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor, # Shape: (B, T, D)
|
||||
return_all_hidden_states: bool = False,
|
||||
):
|
||||
# Process through transformer blocks - single pass through the blocks
|
||||
hidden_states = hidden_states.contiguous()
|
||||
all_hidden_states = [hidden_states]
|
||||
|
||||
# Process through transformer blocks
|
||||
for _idx, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(hidden_states)
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
if return_all_hidden_states:
|
||||
return hidden_states, all_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
@@ -0,0 +1,406 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
PretrainedConfig = object
|
||||
BatchFeature = None
|
||||
|
||||
from lerobot.policies.groot.action_head.action_encoder import (
|
||||
SinusoidalPositionalEncoding,
|
||||
swish,
|
||||
)
|
||||
|
||||
from .cross_attention_dit import DiT, SelfAttentionTransformer
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
# For each category, we have separate weights and biases.
|
||||
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
|
||||
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
selected_w = self.W[cat_ids]
|
||||
selected_b = self.b[cat_ids]
|
||||
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
|
||||
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
hidden = F.relu(self.layer1(x, cat_ids))
|
||||
return self.layer2(hidden, cat_ids)
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim, hidden_size, num_embodiments):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_embodiments = num_embodiments
|
||||
|
||||
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
|
||||
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
|
||||
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
|
||||
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions, timesteps, cat_ids):
|
||||
"""
|
||||
actions: shape (B, T, action_dim)
|
||||
timesteps: shape (B,) -- a single scalar per batch item
|
||||
cat_ids: shape (B,)
|
||||
returns: shape (B, T, hidden_size)
|
||||
"""
|
||||
b, t, _ = actions.shape
|
||||
|
||||
# 1) Expand each batch's single scalar time 'tau' across all T steps
|
||||
# so that shape => (B, T)
|
||||
# e.g. if timesteps is (B,), replicate across T
|
||||
if timesteps.dim() == 1 and timesteps.shape[0] == b:
|
||||
# shape (B,) => (B,T)
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, t)
|
||||
else:
|
||||
raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
|
||||
|
||||
# 2) Standard action MLP step for shape => (B, T, w)
|
||||
a_emb = self.W1(actions, cat_ids)
|
||||
|
||||
# 3) Get the sinusoidal encoding (B, T, w)
|
||||
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
|
||||
|
||||
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
|
||||
x = torch.cat([a_emb, tau_emb], dim=-1)
|
||||
x = swish(self.W2(x, cat_ids))
|
||||
|
||||
# 5) Finally W3 => (B, T, w)
|
||||
x = self.W3(x, cat_ids)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||
|
||||
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
|
||||
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
|
||||
diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."})
|
||||
input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."})
|
||||
backbone_embedding_dim: int = field(
|
||||
default=1536, metadata={"help": "Backbone embedding channel dimension."}
|
||||
)
|
||||
|
||||
hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
|
||||
max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
|
||||
action_dim: int = field(default=None, metadata={"help": "Action dimension."})
|
||||
action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
|
||||
noise_beta_alpha: float = field(default=1.5, metadata={"help": ""})
|
||||
noise_beta_beta: float = field(default=1.0, metadata={"help": ""})
|
||||
noise_s: float = field(default=0.999, metadata={"help": "Flow matching noise Beta distribution s."})
|
||||
num_timestep_buckets: int = field(
|
||||
default=1000, metadata={"help": "Number of timestep discretization buckets."}
|
||||
)
|
||||
num_inference_timesteps: int = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of inference steps for noise diffusion."},
|
||||
)
|
||||
max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
|
||||
tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
|
||||
tune_diffusion_model: bool = field(
|
||||
default=True, metadata={"help": "Whether to tune the diffusion model."}
|
||||
)
|
||||
load_pretrained_det_decode_layer_path: str = field(
|
||||
default=None, metadata={"help": "Path to pretrained detection model."}
|
||||
)
|
||||
detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
|
||||
|
||||
freeze_decode_layer: bool = field(default=False)
|
||||
expand_batch: int = field(default=None)
|
||||
use_vlln: bool = field(default=True)
|
||||
|
||||
vl_self_attention_cfg: dict = field(default=None)
|
||||
num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."})
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class FlowmatchingActionHead(nn.Module):
|
||||
config_class = FlowmatchingActionHeadConfig
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FlowmatchingActionHeadConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.input_embedding_dim = config.input_embedding_dim
|
||||
|
||||
self.model = DiT(**config.diffusion_model_cfg)
|
||||
self.action_dim = config.action_dim
|
||||
self.action_horizon = config.action_horizon
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=config.max_state_dim,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.input_embedding_dim,
|
||||
)
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=config.action_dim,
|
||||
hidden_size=self.input_embedding_dim,
|
||||
num_embodiments=config.max_num_embodiments,
|
||||
)
|
||||
self.action_decoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=self.hidden_size,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.action_dim,
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
|
||||
nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
|
||||
|
||||
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
|
||||
self.vl_self_attention = (
|
||||
SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity()
|
||||
)
|
||||
|
||||
if config.add_pos_embed:
|
||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||
|
||||
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
|
||||
self.num_timestep_buckets = config.num_timestep_buckets
|
||||
self.config = config
|
||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
|
||||
|
||||
def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool):
|
||||
self.tune_projector = tune_projector
|
||||
self.tune_diffusion_model = tune_diffusion_model
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
if not tune_projector:
|
||||
self.state_encoder.requires_grad_(False)
|
||||
self.action_encoder.requires_grad_(False)
|
||||
self.action_decoder.requires_grad_(False)
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.requires_grad_(False)
|
||||
if not tune_diffusion_model:
|
||||
self.model.requires_grad_(False)
|
||||
print(f"Tune action head projector: {self.tune_projector}")
|
||||
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not tune_projector and not tune_diffusion_model:
|
||||
for name, p in self.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(f"Action head trainable parameter: {name}")
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No action head trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
if not self.tune_projector:
|
||||
self.state_encoder.eval()
|
||||
self.action_encoder.eval()
|
||||
self.action_decoder.eval()
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.eval()
|
||||
if not self.tune_diffusion_model:
|
||||
self.model.eval()
|
||||
|
||||
def sample_time(self, batch_size, device, dtype):
|
||||
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
return (self.config.noise_s - sample) / self.config.noise_s
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
|
||||
backbone_features = backbone_output["backbone_features"]
|
||||
backbone_features = self.vlln(backbone_features)
|
||||
backbone_features = self.vl_self_attention(backbone_features)
|
||||
backbone_output["backbone_features"] = backbone_features
|
||||
return backbone_output
|
||||
|
||||
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
# Set frozen modules to eval
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
if self.config.expand_batch is not None:
|
||||
for k, v in backbone_output.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
backbone_output[k] = expanded
|
||||
|
||||
for k, v in action_input.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
action_input[k] = expanded
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
device = vl_embs.device
|
||||
|
||||
# Get embodiment ID.
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
actions = action_input.action
|
||||
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
|
||||
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
|
||||
t = t[:, None, None] # shape (B,1,1) for broadcast
|
||||
|
||||
noisy_trajectory = (1 - t) * noise + t * actions
|
||||
velocity = actions - noise
|
||||
|
||||
# Convert (continuous) t -> discrete if needed
|
||||
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
|
||||
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
|
||||
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
vl_attn_mask = backbone_output.backbone_attention_mask
|
||||
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
encoder_attention_mask=vl_attn_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=False, # NOTE (YL): not using flare now
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
|
||||
# Slice out only the action portion of pred and target.
|
||||
action_mask = action_input.action_mask
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
|
||||
loss = loss.sum() / action_mask.sum()
|
||||
output_dict = {
|
||||
"loss": loss,
|
||||
}
|
||||
return BatchFeature(data=output_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Set initial actions as the sampled noise.
|
||||
batch_size = vl_embs.shape[0]
|
||||
device = vl_embs.device
|
||||
actions = torch.randn(
|
||||
size=(batch_size, self.config.action_horizon, self.config.action_dim),
|
||||
dtype=vl_embs.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
num_steps = self.num_inference_timesteps
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
# Run denoising steps.
|
||||
for t in range(num_steps):
|
||||
t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
|
||||
t_discretized = int(t_cont * self.num_timestep_buckets)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
|
||||
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
# Run model forward.
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
timestep=timesteps_tensor,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
|
||||
pred_velocity = pred[:, -self.action_horizon :]
|
||||
|
||||
# Update actions using euler integration.
|
||||
actions = actions + dt * pred_velocity
|
||||
return BatchFeature(data={"action_pred": actions})
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(iter(self.parameters())).dtype
|
||||
201
src/lerobot/policies/groot/configuration_groot.py
Normal file
201
src/lerobot/policies/groot/configuration_groot.py
Normal file
@@ -0,0 +1,201 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("groot")
|
||||
@dataclass
|
||||
class GrootConfig(PreTrainedConfig):
|
||||
"""Configuration for Groot policy wrapper."""
|
||||
|
||||
# Basic policy settings
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
# Dimension settings (must match pretrained GR00T model expectations)
|
||||
# Maximum state dimension. Shorter states will be zero-padded.
|
||||
max_state_dim: int = 64
|
||||
|
||||
# Maximum action dimension. Shorter actions will be zero-padded.
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Normalization (start with identity, adjust as needed)
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Image preprocessing (adjust to match Groot's expected input)
|
||||
image_size: tuple[int, int] = (224, 224)
|
||||
|
||||
# Groot-specific model parameters (from groot_finetune_script.py)
|
||||
|
||||
# Path or HuggingFace model ID for the base Groot model
|
||||
base_model_path: str = "nvidia/GR00T-N1.5-3B"
|
||||
|
||||
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
|
||||
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
|
||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
|
||||
# Fine-tuning control arguments
|
||||
|
||||
# Whether to fine-tune the llm backbone
|
||||
tune_llm: bool = False
|
||||
|
||||
# Whether to fine-tune the vision tower
|
||||
tune_visual: bool = False
|
||||
|
||||
# Whether to fine-tune the projector
|
||||
tune_projector: bool = True
|
||||
|
||||
# Whether to fine-tune the diffusion model
|
||||
tune_diffusion_model: bool = True
|
||||
|
||||
# LoRA parameters (from groot_finetune_script.py)
|
||||
# Rank for the LORA model. If 0, no LORA will be used.
|
||||
lora_rank: int = 0
|
||||
|
||||
# Alpha value for the LORA model
|
||||
lora_alpha: int = 16
|
||||
|
||||
# Dropout rate for the LORA model
|
||||
lora_dropout: float = 0.1
|
||||
|
||||
# Whether to use the full model for LORA
|
||||
lora_full_model: bool = False
|
||||
|
||||
# Training parameters (matching groot_finetune_script.py)
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-5
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
|
||||
# Dataset parameters
|
||||
# Video backend to use for training ('decord' or 'torchvision_av')
|
||||
video_backend: str = "decord"
|
||||
|
||||
# Whether to balance dataset weights in mixture datasets
|
||||
balance_dataset_weights: bool = True
|
||||
|
||||
# Whether to sample trajectories weighted by their length
|
||||
balance_trajectory_weights: bool = True
|
||||
|
||||
# Optional dataset paths for delegating training to Isaac-GR00T runner
|
||||
dataset_paths: list[str] | None = None
|
||||
output_dir: str = "./tmp/gr00t"
|
||||
save_steps: int = 1000
|
||||
max_steps: int = 10000
|
||||
batch_size: int = 32
|
||||
dataloader_num_workers: int = 8
|
||||
report_to: str = "wandb"
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
# groot_repo_path is now optional since we ported the components
|
||||
# No validation needed
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features for Groot."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"Groot policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if "observation.state" not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,),
|
||||
)
|
||||
self.input_features["observation.state"] = state_feature
|
||||
else:
|
||||
state_shape = self.input_features["observation.state"].shape
|
||||
state_dim = state_shape[0] if state_shape else 0
|
||||
if state_dim > self.max_state_dim:
|
||||
raise ValueError(
|
||||
f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
|
||||
f"Either reduce state dimension or increase max_state_dim in config."
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,),
|
||||
)
|
||||
self.output_features["action"] = action_feature
|
||||
else:
|
||||
action_shape = self.output_features["action"].shape
|
||||
action_dim = action_shape[0] if action_shape else 0
|
||||
if action_dim > self.max_action_dim:
|
||||
raise ValueError(
|
||||
f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
|
||||
f"Either reduce action dimension or increase max_action_dim in config."
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
"""Return optimizer configuration."""
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
"""Return scheduler configuration."""
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
num_warmup_steps=int(10000 * self.warmup_ratio), # 5% warmup by default
|
||||
num_decay_steps=10000, # Adjust based on training steps
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.optimizer_lr * 0.1,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
"""Return indices for delta observations (None for Groot)."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Return indices for delta actions."""
|
||||
return list(range(min(self.chunk_size, 16)))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
"""Return indices for delta rewards (None for Groot)."""
|
||||
return None
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user