mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
Compare commits
66 Commits
v0.4.2
...
openarms_t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
08d2ed8015 | ||
|
|
4bcd14b8de | ||
|
|
c34935090d | ||
|
|
9cfd56587e | ||
|
|
ff8584a025 | ||
|
|
6bc1e5186a | ||
|
|
69dc8165ae | ||
|
|
021bca2ad9 | ||
|
|
4e0ee0d643 | ||
|
|
0a8aa85871 | ||
|
|
76ddd8b948 | ||
|
|
bf08733068 | ||
|
|
e38f56c071 | ||
|
|
19fe69dac0 | ||
|
|
14319ee608 | ||
|
|
9b04fd25b6 | ||
|
|
40e98ba690 | ||
|
|
894d65d58a | ||
|
|
f58d508df2 | ||
|
|
e22b909e7c | ||
|
|
09f1673cbf | ||
|
|
4744f99990 | ||
|
|
01c1735739 | ||
|
|
6808a42455 | ||
|
|
fff719cb4f | ||
|
|
e2c00f6ed8 | ||
|
|
0f90db23c5 | ||
|
|
96b192f2ae | ||
|
|
ecdc34a699 | ||
|
|
fa6a2fb9b7 | ||
|
|
b011643dc9 | ||
|
|
30c10c1c6e | ||
|
|
56e2360072 | ||
|
|
92fdbe9bbf | ||
|
|
b303d1ab38 | ||
|
|
b1d162f333 | ||
|
|
2b304eeb84 | ||
|
|
4e6048a221 | ||
|
|
81ebcac8d7 | ||
|
|
a6c3a0fa09 | ||
|
|
c2fb644613 | ||
|
|
1d07a4aefd | ||
|
|
ce348a3460 | ||
|
|
cb920235c4 | ||
|
|
7f40b3bf82 | ||
|
|
2e9c9fd832 | ||
|
|
f9cb5e659c | ||
|
|
0217e1e3ad | ||
|
|
d79dd6d31f | ||
|
|
56b43cc888 | ||
|
|
77fe5a09ed | ||
|
|
89ae7813a7 | ||
|
|
e003108cf8 | ||
|
|
5766eea377 | ||
|
|
f8a4cf225b | ||
|
|
43b0f17eb9 | ||
|
|
b0b755471b | ||
|
|
35c5a27352 | ||
|
|
afb90e17e7 | ||
|
|
9ec9ee781a | ||
|
|
0b497fc37d | ||
|
|
797cd2725a | ||
|
|
af4766b602 | ||
|
|
37f43df88a | ||
|
|
5f7b5f2817 | ||
|
|
c55fbe1b3e |
@@ -31,7 +31,8 @@ jobs:
|
|||||||
name: Upload Preview and Comment
|
name: Upload Preview and Comment
|
||||||
if: >
|
if: >
|
||||||
github.event.workflow_run.event == 'pull_request' &&
|
github.event.workflow_run.event == 'pull_request' &&
|
||||||
github.event.workflow_run.conclusion == 'success'
|
github.event.workflow_run.conclusion == 'success' &&
|
||||||
|
github.repository == 'huggingface/lerobot'
|
||||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
|
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
|
||||||
with:
|
with:
|
||||||
package_name: lerobot
|
package_name: lerobot
|
||||||
|
|||||||
6
.github/workflows/documentation.yml
vendored
6
.github/workflows/documentation.yml
vendored
@@ -42,7 +42,9 @@ jobs:
|
|||||||
# This job builds and deploys the official documentation.
|
# This job builds and deploys the official documentation.
|
||||||
build_main_docs:
|
build_main_docs:
|
||||||
name: Build Main Docs
|
name: Build Main Docs
|
||||||
if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
|
if: >
|
||||||
|
(github.event_name == 'push' || github.event_name == 'workflow_dispatch') &&
|
||||||
|
github.repository == 'huggingface/lerobot'
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
|
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
|
||||||
@@ -58,7 +60,7 @@ jobs:
|
|||||||
# The result of this job triggers the 'Upload PR Documentation' workflow.
|
# The result of this job triggers the 'Upload PR Documentation' workflow.
|
||||||
build_pr_docs:
|
build_pr_docs:
|
||||||
name: Build PR Docs
|
name: Build PR Docs
|
||||||
if: github.event_name == 'pull_request'
|
if: github.event_name == 'pull_request' && github.repository == 'huggingface/lerobot'
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
pull-requests: write
|
pull-requests: write
|
||||||
|
|||||||
8
.github/workflows/fast_tests.yml
vendored
8
.github/workflows/fast_tests.yml
vendored
@@ -45,7 +45,6 @@ permissions:
|
|||||||
env:
|
env:
|
||||||
UV_VERSION: "0.8.0"
|
UV_VERSION: "0.8.0"
|
||||||
PYTHON_VERSION: "3.10"
|
PYTHON_VERSION: "3.10"
|
||||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
|
|
||||||
|
|
||||||
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
||||||
concurrency:
|
concurrency:
|
||||||
@@ -60,12 +59,19 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
MUJOCO_GL: egl
|
MUJOCO_GL: egl
|
||||||
|
HF_HOME: /mnt/cache/.cache/huggingface
|
||||||
|
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
lfs: true
|
lfs: true
|
||||||
|
|
||||||
|
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
|
||||||
|
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
|
||||||
|
- name: Setup /mnt storage
|
||||||
|
run: sudo chown -R $USER:$USER /mnt
|
||||||
|
|
||||||
# TODO(Steven): Evaluate the need of these dependencies
|
# TODO(Steven): Evaluate the need of these dependencies
|
||||||
- name: Install apt dependencies
|
- name: Install apt dependencies
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
7
.github/workflows/full_tests.yml
vendored
7
.github/workflows/full_tests.yml
vendored
@@ -58,12 +58,19 @@ jobs:
|
|||||||
github.event_name == 'workflow_dispatch'
|
github.event_name == 'workflow_dispatch'
|
||||||
env:
|
env:
|
||||||
MUJOCO_GL: egl
|
MUJOCO_GL: egl
|
||||||
|
HF_HOME: /mnt/cache/.cache/huggingface
|
||||||
|
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
lfs: true
|
lfs: true
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
|
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
|
||||||
|
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
|
||||||
|
- name: Setup /mnt storage
|
||||||
|
run: sudo chown -R $USER:$USER /mnt
|
||||||
|
|
||||||
- name: Install apt dependencies
|
- name: Install apt dependencies
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||||
|
|||||||
2
.github/workflows/nightly.yml
vendored
2
.github/workflows/nightly.yml
vendored
@@ -43,6 +43,7 @@ jobs:
|
|||||||
name: Build CPU Docker for Nightly
|
name: Build CPU Docker for Nightly
|
||||||
runs-on:
|
runs-on:
|
||||||
group: aws-general-8-plus
|
group: aws-general-8-plus
|
||||||
|
if: github.repository == 'huggingface/lerobot'
|
||||||
outputs:
|
outputs:
|
||||||
image_tag: ${{ env.DOCKER_IMAGE_NAME_CPU }}
|
image_tag: ${{ env.DOCKER_IMAGE_NAME_CPU }}
|
||||||
steps:
|
steps:
|
||||||
@@ -77,6 +78,7 @@ jobs:
|
|||||||
name: Build GPU Docker for Nightly
|
name: Build GPU Docker for Nightly
|
||||||
runs-on:
|
runs-on:
|
||||||
group: aws-general-8-plus
|
group: aws-general-8-plus
|
||||||
|
if: github.repository == 'huggingface/lerobot'
|
||||||
outputs:
|
outputs:
|
||||||
image_tag: ${{ env.DOCKER_IMAGE_NAME_GPU }}
|
image_tag: ${{ env.DOCKER_IMAGE_NAME_GPU }}
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
1
.github/workflows/release.yml
vendored
1
.github/workflows/release.yml
vendored
@@ -29,6 +29,7 @@ jobs:
|
|||||||
build-and-publish:
|
build-and-publish:
|
||||||
name: Build and publish Python distributions
|
name: Build and publish Python distributions
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
if: github.repository == 'huggingface/lerobot'
|
||||||
outputs:
|
outputs:
|
||||||
version: ${{ steps.extract_info.outputs.tag_version }}
|
version: ${{ steps.extract_info.outputs.tag_version }}
|
||||||
permissions:
|
permissions:
|
||||||
|
|||||||
1
.github/workflows/stale.yml
vendored
1
.github/workflows/stale.yml
vendored
@@ -45,6 +45,7 @@ jobs:
|
|||||||
stale:
|
stale:
|
||||||
name: Close Stale Issues and PRs
|
name: Close Stale Issues and PRs
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
if: github.repository == 'huggingface/lerobot'
|
||||||
permissions:
|
permissions:
|
||||||
actions: write
|
actions: write
|
||||||
contents: write # only for delete-branch option
|
contents: write # only for delete-branch option
|
||||||
|
|||||||
8
.github/workflows/unbound_deps_tests.yml
vendored
8
.github/workflows/unbound_deps_tests.yml
vendored
@@ -43,14 +43,22 @@ jobs:
|
|||||||
full-tests:
|
full-tests:
|
||||||
name: Full Unbound Tests
|
name: Full Unbound Tests
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
if: github.repository == 'huggingface/lerobot'
|
||||||
env:
|
env:
|
||||||
MUJOCO_GL: egl
|
MUJOCO_GL: egl
|
||||||
|
HF_HOME: /mnt/cache/.cache/huggingface
|
||||||
|
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
lfs: true
|
lfs: true
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
|
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
|
||||||
|
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
|
||||||
|
- name: Setup /mnt storage
|
||||||
|
run: sudo chown -R $USER:$USER /mnt
|
||||||
|
|
||||||
- name: Install apt dependencies
|
- name: Install apt dependencies
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||||
|
|||||||
@@ -9,6 +9,8 @@
|
|||||||
title: Imitation Learning for Robots
|
title: Imitation Learning for Robots
|
||||||
- local: cameras
|
- local: cameras
|
||||||
title: Cameras
|
title: Cameras
|
||||||
|
- local: bring_your_own_policies
|
||||||
|
title: Bring Your Own Policies
|
||||||
- local: integrate_hardware
|
- local: integrate_hardware
|
||||||
title: Bring Your Own Hardware
|
title: Bring Your Own Hardware
|
||||||
- local: hilserl
|
- local: hilserl
|
||||||
@@ -37,6 +39,8 @@
|
|||||||
title: π₀.₅ (Pi05)
|
title: π₀.₅ (Pi05)
|
||||||
- local: groot
|
- local: groot
|
||||||
title: NVIDIA GR00T N1.5
|
title: NVIDIA GR00T N1.5
|
||||||
|
- local: xvla
|
||||||
|
title: X-VLA
|
||||||
title: "Policies"
|
title: "Policies"
|
||||||
- sections:
|
- sections:
|
||||||
- local: async
|
- local: async
|
||||||
@@ -79,11 +83,19 @@
|
|||||||
title: Hope Jr
|
title: Hope Jr
|
||||||
- local: reachy2
|
- local: reachy2
|
||||||
title: Reachy 2
|
title: Reachy 2
|
||||||
|
- local: unitree_g1
|
||||||
|
title: Unitree G1
|
||||||
|
- local: earthrover_mini_plus
|
||||||
|
title: Earth Rover Mini
|
||||||
title: "Robots"
|
title: "Robots"
|
||||||
- sections:
|
- sections:
|
||||||
- local: phone_teleop
|
- local: phone_teleop
|
||||||
title: Phone
|
title: Phone
|
||||||
title: "Teleoperators"
|
title: "Teleoperators"
|
||||||
|
- sections:
|
||||||
|
- local: torch_accelerators
|
||||||
|
title: PyTorch accelerators
|
||||||
|
title: "Supported Hardware"
|
||||||
- sections:
|
- sections:
|
||||||
- local: notebooks
|
- local: notebooks
|
||||||
title: Notebooks
|
title: Notebooks
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ We found the default values of `actions_per_chunk` and `chunk_size_threshold` to
|
|||||||
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
|
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
|
||||||
3. **Adjust `chunk_size_threshold`**.
|
3. **Adjust `chunk_size_threshold`**.
|
||||||
- Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model).
|
- Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model).
|
||||||
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug-visualize-queue-size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
|
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug_visualize_queue_size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img
|
<img
|
||||||
@@ -289,7 +289,7 @@ We found the default values of `actions_per_chunk` and `chunk_size_threshold` to
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
<i>
|
<i>
|
||||||
The action queue size is plotted at runtime when the
|
The action queue size is plotted at runtime when the
|
||||||
`--debug-visualize-queue-size` flag is passed, for various levels of
|
`--debug_visualize_queue_size` flag is passed, for various levels of
|
||||||
`chunk_size_threshold` (`g` in the SmolVLA paper).
|
`chunk_size_threshold` (`g` in the SmolVLA paper).
|
||||||
</i>
|
</i>
|
||||||
</p>
|
</p>
|
||||||
|
|||||||
175
docs/source/bring_your_own_policies.mdx
Normal file
175
docs/source/bring_your_own_policies.mdx
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
# Bring Your Own Policies
|
||||||
|
|
||||||
|
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
|
||||||
|
|
||||||
|
## Step 1: Create a Policy Package
|
||||||
|
|
||||||
|
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
|
||||||
|
|
||||||
|
### Package Structure
|
||||||
|
|
||||||
|
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot_policy_my_custom_policy/
|
||||||
|
├── pyproject.toml
|
||||||
|
└── src/
|
||||||
|
└── lerobot_policy_my_custom_policy/
|
||||||
|
├── __init__.py
|
||||||
|
├── configuration_my_custom_policy.py
|
||||||
|
├── modeling_my_custom_policy.py
|
||||||
|
└── processor_my_custom_policy.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Package Configuration
|
||||||
|
|
||||||
|
Set up your `pyproject.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[project]
|
||||||
|
name = "lerobot_policy_my_custom_policy"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
# your policy-specific dependencies
|
||||||
|
]
|
||||||
|
requires-python = ">= 3.11"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
build-backend = # your-build-backend
|
||||||
|
requires = # your-build-system
|
||||||
|
```
|
||||||
|
|
||||||
|
## Step 2: Define the Policy Configuration
|
||||||
|
|
||||||
|
Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# configuration_my_custom_policy.py
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.types import NormalizationMode
|
||||||
|
|
||||||
|
@PreTrainedConfig.register_subclass("my_custom_policy")
|
||||||
|
@dataclass
|
||||||
|
class MyCustomPolicyConfig(PreTrainedConfig):
|
||||||
|
"""Configuration class for MyCustomPolicy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_obs_steps: Number of observation steps to use as input
|
||||||
|
horizon: Action prediction horizon
|
||||||
|
n_action_steps: Number of action steps to execute
|
||||||
|
hidden_dim: Hidden dimension for the policy network
|
||||||
|
# Add your policy-specific parameters here
|
||||||
|
"""
|
||||||
|
# ...PreTrainedConfig fields...
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
# Add any validation logic here
|
||||||
|
|
||||||
|
def validate_features(self) -> None:
|
||||||
|
"""Validate input/output feature compatibility."""
|
||||||
|
# Implement validation logic for your policy's requirements
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
## Step 3: Implement the Policy Class
|
||||||
|
|
||||||
|
Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# modeling_my_custom_policy.py
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||||
|
|
||||||
|
class MyCustomPolicy(PreTrainedPolicy):
|
||||||
|
config_class = MyCustomPolicyConfig
|
||||||
|
name = "my_custom_policy"
|
||||||
|
|
||||||
|
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
|
||||||
|
super().__init__(config, dataset_stats)
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Step 4: Add Data Processors
|
||||||
|
|
||||||
|
Create processor functions:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# processor_my_custom_policy.py
|
||||||
|
from typing import Dict, Any
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def make_my_custom_policy_pre_post_processors(
|
||||||
|
config,
|
||||||
|
) -> tuple[
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
|
]:
|
||||||
|
"""Create preprocessing and postprocessing functions for your policy."""
|
||||||
|
pass # Define your preprocessing and postprocessing logic here
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## Step 5: Package Initialization
|
||||||
|
|
||||||
|
Expose your classes in the package's `__init__.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# __init__.py
|
||||||
|
"""Custom policy package for LeRobot."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
import lerobot # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"lerobot is not installed. Please install lerobot to use this policy package."
|
||||||
|
)
|
||||||
|
|
||||||
|
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||||
|
from .modeling_my_custom_policy import MyCustomPolicy
|
||||||
|
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MyCustomPolicyConfig",
|
||||||
|
"MyCustomPolicy",
|
||||||
|
"make_my_custom_policy_pre_post_processors",
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Step 6: Installation and Usage
|
||||||
|
|
||||||
|
### Install Your Policy Package
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd lerobot_policy_my_custom_policy
|
||||||
|
pip install -e .
|
||||||
|
|
||||||
|
# Or install from PyPI if published
|
||||||
|
pip install lerobot_policy_my_custom_policy
|
||||||
|
```
|
||||||
|
|
||||||
|
### Use Your Policy
|
||||||
|
|
||||||
|
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.type my_custom_policy \
|
||||||
|
--env.type pusht \
|
||||||
|
--steps 200000
|
||||||
|
```
|
||||||
|
|
||||||
|
## Examples and Community Contributions
|
||||||
|
|
||||||
|
Check out these example policy implementations:
|
||||||
|
|
||||||
|
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||||
|
|
||||||
|
Share your policy implementations with the community! 🤗
|
||||||
206
docs/source/earthrover_mini_plus.mdx
Normal file
206
docs/source/earthrover_mini_plus.mdx
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
# EarthRover Mini Plus
|
||||||
|
|
||||||
|
The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models.
|
||||||
|
|
||||||
|
## What You Need
|
||||||
|
|
||||||
|
### Hardware
|
||||||
|
|
||||||
|
- EarthRover Mini robot
|
||||||
|
- Computer with Python 3.10 or newer
|
||||||
|
- Internet connection
|
||||||
|
|
||||||
|
### Setting Up the Frodobots SDK
|
||||||
|
|
||||||
|
The robot needs the [Frodobots SDK](https://github.com/Frodobots/earth-rovers-sdk) running on your computer. Here's how:
|
||||||
|
|
||||||
|
1. Download and install the SDK:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/Frodobots/earth-rovers-sdk.git
|
||||||
|
cd earth-rovers-sdk
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start the SDK:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
hypercorn main:app --reload
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Open your web browser and go to `http://localhost:8000`, then click "Join"
|
||||||
|
|
||||||
|
The SDK gives you:
|
||||||
|
|
||||||
|
- Live video from front and rear cameras
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> The SDK must be running before you can use the robot.
|
||||||
|
|
||||||
|
## Install LeRobot
|
||||||
|
|
||||||
|
Follow our [Installation Guide](./installation) to install LeRobot.
|
||||||
|
|
||||||
|
In addition to the base installation, install the EarthRover Mini dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
## How It Works
|
||||||
|
|
||||||
|
The robot uses the internet to communicate:
|
||||||
|
|
||||||
|
- **Movement commands**: Sent through the SDK
|
||||||
|
- **Camera video**: Received from the SDK
|
||||||
|
- **Robot info**: Battery, location, speed from the SDK
|
||||||
|
|
||||||
|
You don't need to plug anything in - it all works through the SDK.
|
||||||
|
|
||||||
|
## Calibration
|
||||||
|
|
||||||
|
No calibration needed! The robot is ready to use as soon as the SDK is running.
|
||||||
|
|
||||||
|
## Controlling the Robot
|
||||||
|
|
||||||
|
You control the robot using your keyboard - just like playing a video game with WASD keys.
|
||||||
|
|
||||||
|
### Keyboard Controls
|
||||||
|
|
||||||
|
| Key | Action |
|
||||||
|
| --- | -------------------------------- |
|
||||||
|
| W | Move forward |
|
||||||
|
| S | Move backward |
|
||||||
|
| A | Turn left (with forward motion) |
|
||||||
|
| D | Turn right (with forward motion) |
|
||||||
|
| Q | Rotate left in place |
|
||||||
|
| E | Rotate right in place |
|
||||||
|
| X | Stop all movement |
|
||||||
|
| +/= | Increase speed |
|
||||||
|
| - | Decrease speed |
|
||||||
|
| ESC | Disconnect |
|
||||||
|
|
||||||
|
### Speed Settings
|
||||||
|
|
||||||
|
You can adjust how fast the robot moves:
|
||||||
|
|
||||||
|
- **Forward/backward speed**: Default is full speed (1.0)
|
||||||
|
- **Turning speed**: Default is full speed (1.0)
|
||||||
|
- **Speed changes**: Use +/- keys to adjust by 0.1 each time
|
||||||
|
|
||||||
|
### Try It Out
|
||||||
|
|
||||||
|
Test driving the robot before recording data:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.robots.earthrover_mini_plus import EarthRoverMiniPlus, EarthRoverMiniPlusConfig
|
||||||
|
from lerobot.teleoperators.keyboard import KeyboardRoverTeleop, KeyboardRoverTeleopConfig
|
||||||
|
|
||||||
|
# Initialize robot
|
||||||
|
robot_config = EarthRoverMiniPlusConfig()
|
||||||
|
robot = EarthRoverMiniPlus(robot_config)
|
||||||
|
|
||||||
|
# Initialize teleoperator
|
||||||
|
teleop_config = KeyboardRoverTeleopConfig(
|
||||||
|
linear_speed=1.0,
|
||||||
|
angular_speed=1.0,
|
||||||
|
speed_increment=0.1
|
||||||
|
)
|
||||||
|
teleop = KeyboardRoverTeleop(teleop_config)
|
||||||
|
|
||||||
|
# Connect
|
||||||
|
robot.connect()
|
||||||
|
teleop.connect()
|
||||||
|
|
||||||
|
# Teleoperate (use keyboard controls)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
action = teleop.get_action()
|
||||||
|
robot.send_action(action)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
robot.disconnect()
|
||||||
|
teleop.disconnect()
|
||||||
|
```
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> If you're using a Mac, you might need to give Terminal permission to access your keyboard for teleoperation. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
|
||||||
|
|
||||||
|
## Recording Data
|
||||||
|
|
||||||
|
Once you can drive the robot well, you can start recording data to train AI models. The system records:
|
||||||
|
|
||||||
|
- **What you do**: How you move the robot (forward, backward, turning)
|
||||||
|
- **What the robot sees**:
|
||||||
|
- Videos from both cameras
|
||||||
|
- Robot speed and direction
|
||||||
|
- Battery level and location
|
||||||
|
- GPS position and signal
|
||||||
|
- Other sensor data
|
||||||
|
- **When it happened**: Timestamps for everything
|
||||||
|
|
||||||
|
### Setting Up Hugging Face
|
||||||
|
|
||||||
|
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||||
|
```
|
||||||
|
|
||||||
|
Store your Hugging Face username:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||||
|
echo $HF_USER
|
||||||
|
```
|
||||||
|
|
||||||
|
### Start Recording
|
||||||
|
|
||||||
|
Use the standard recording command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python src/lerobot/scripts/lerobot_record.py \
|
||||||
|
--robot.type=earthrover_mini_plus \
|
||||||
|
--teleop.type=keyboard_rover \
|
||||||
|
--dataset.repo_id=your_username/dataset_name \
|
||||||
|
--dataset.num_episodes=2 \
|
||||||
|
--dataset.fps=10 \
|
||||||
|
--dataset.single_task="Navigate around obstacles" \
|
||||||
|
--display_data=true
|
||||||
|
```
|
||||||
|
|
||||||
|
Replace `your_username/dataset_name` with your Hugging Face username and a name for your dataset.
|
||||||
|
|
||||||
|
### What Gets Saved
|
||||||
|
|
||||||
|
Your dataset includes:
|
||||||
|
|
||||||
|
**Your Actions (2 things)**:
|
||||||
|
|
||||||
|
- How much you moved forward/backward
|
||||||
|
- How much you turned left/right
|
||||||
|
|
||||||
|
**Robot Observations (12 things)**:
|
||||||
|
|
||||||
|
- Front camera video
|
||||||
|
- Rear camera video
|
||||||
|
- Current speed
|
||||||
|
- Battery level
|
||||||
|
- Which way the robot is facing
|
||||||
|
- GPS location (latitude, longitude, signal strength)
|
||||||
|
- Network signal strength
|
||||||
|
- Vibration level
|
||||||
|
- Lamp status (on/off)
|
||||||
|
|
||||||
|
### Where Your Data Goes
|
||||||
|
|
||||||
|
On your computer: `~/.cache/huggingface/lerobot/{repo-id}`
|
||||||
|
|
||||||
|
After recording, your data automatically uploads to your Hugging Face page:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
echo https://huggingface.co/datasets/${HF_USER}/earthrover-navigation
|
||||||
|
```
|
||||||
|
|
||||||
|
Your dataset will be tagged with `LeRobot` for community discovery.
|
||||||
@@ -201,7 +201,8 @@ from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
|||||||
from lerobot.utils.control_utils import init_keyboard_listener
|
from lerobot.utils.control_utils import init_keyboard_listener
|
||||||
from lerobot.utils.utils import log_say
|
from lerobot.utils.utils import log_say
|
||||||
from lerobot.utils.visualization_utils import init_rerun
|
from lerobot.utils.visualization_utils import init_rerun
|
||||||
from lerobot.record import record_loop
|
from lerobot.scripts.lerobot_record import record_loop
|
||||||
|
from lerobot.processor import make_default_processors
|
||||||
|
|
||||||
NUM_EPISODES = 5
|
NUM_EPISODES = 5
|
||||||
FPS = 30
|
FPS = 30
|
||||||
@@ -209,12 +210,19 @@ EPISODE_TIME_SEC = 60
|
|||||||
RESET_TIME_SEC = 10
|
RESET_TIME_SEC = 10
|
||||||
TASK_DESCRIPTION = "My task description"
|
TASK_DESCRIPTION = "My task description"
|
||||||
|
|
||||||
# Create the robot and teleoperator configurations
|
# Create robot configuration
|
||||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
|
||||||
robot_config = SO100FollowerConfig(
|
robot_config = SO100FollowerConfig(
|
||||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
|
id="my_awesome_follower_arm",
|
||||||
|
cameras={
|
||||||
|
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
||||||
|
},
|
||||||
|
port="/dev/tty.usbmodem58760434471",
|
||||||
|
)
|
||||||
|
|
||||||
|
teleop_config = SO100LeaderConfig(
|
||||||
|
id="my_awesome_leader_arm",
|
||||||
|
port="/dev/tty.usbmodem585A0077581",
|
||||||
)
|
)
|
||||||
teleop_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
|
||||||
|
|
||||||
# Initialize the robot and teleoperator
|
# Initialize the robot and teleoperator
|
||||||
robot = SO100Follower(robot_config)
|
robot = SO100Follower(robot_config)
|
||||||
@@ -243,6 +251,9 @@ init_rerun(session_name="recording")
|
|||||||
robot.connect()
|
robot.connect()
|
||||||
teleop.connect()
|
teleop.connect()
|
||||||
|
|
||||||
|
# Create the required processors
|
||||||
|
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||||
|
|
||||||
episode_idx = 0
|
episode_idx = 0
|
||||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||||
@@ -251,6 +262,9 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
|||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=teleop_action_processor,
|
||||||
|
robot_action_processor=robot_action_processor,
|
||||||
|
robot_observation_processor=robot_observation_processor,
|
||||||
teleop=teleop,
|
teleop=teleop,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
@@ -265,6 +279,9 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
|||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=teleop_action_processor,
|
||||||
|
robot_action_processor=robot_action_processor,
|
||||||
|
robot_observation_processor=robot_observation_processor,
|
||||||
teleop=teleop,
|
teleop=teleop,
|
||||||
control_time_s=RESET_TIME_SEC,
|
control_time_s=RESET_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
@@ -428,7 +445,7 @@ Your robot should replicate movements similar to those you recorded. For example
|
|||||||
|
|
||||||
## Train a policy
|
## Train a policy
|
||||||
|
|
||||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_train.py) script. A few arguments are required. Here is an example command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-train \
|
lerobot-train \
|
||||||
@@ -485,7 +502,7 @@ huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
|
|||||||
|
|
||||||
## Run inference and evaluate your policy
|
## Run inference and evaluate your policy
|
||||||
|
|
||||||
You can use the `record` script from [`lerobot/record.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
|
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
|
||||||
|
|
||||||
<hfoptions id="eval">
|
<hfoptions id="eval">
|
||||||
<hfoption id="Command">
|
<hfoption id="Command">
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ If you encounter build errors, you may need to install additional dependencies:
|
|||||||
To install these for linux run:
|
To install these for linux run:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config
|
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
|
||||||
```
|
```
|
||||||
|
|
||||||
For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
||||||
|
|||||||
@@ -62,6 +62,11 @@ lerobot-eval \
|
|||||||
|
|
||||||
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
|
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
|
||||||
|
|
||||||
|
### Control Mode
|
||||||
|
|
||||||
|
LIBERO now supports two control modes: relative and absolute. This matters because different VLA checkpoints are trained with different mode of action to output hence control parameterizations.
|
||||||
|
You can switch them with: `env.control_mode = "relative"` and `env.control_mode = "absolute"`
|
||||||
|
|
||||||
### Policy inputs and outputs
|
### Policy inputs and outputs
|
||||||
|
|
||||||
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
|
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
|
||||||
|
|||||||
328
docs/source/openarms.mdx
Normal file
328
docs/source/openarms.mdx
Normal file
@@ -0,0 +1,328 @@
|
|||||||
|
# OpenArms Robot
|
||||||
|
|
||||||
|
OpenArms is a 7 DOF robotic arm with a gripper, designed by [Enactic, Inc.](https://www.enactic.com/) It uses Damiao motors controlled via CAN bus communication and MIT control mode for smooth, precise motion.
|
||||||
|
|
||||||
|
## Hardware Overview
|
||||||
|
|
||||||
|
- **7 DOF per arm** (14 DOF total for dual arm setup)
|
||||||
|
- **1 gripper per arm** (2 grippers total)
|
||||||
|
- **Damiao motors** with 4 different types:
|
||||||
|
- **DM8009** (DM-J8009P-2EC) for shoulders (J1, J2) - high torque
|
||||||
|
- **DM4340** for shoulder rotation and elbow (J3, J4)
|
||||||
|
- **DM4310** (DM-J4310-2EC V1.1) for wrist (J5, J6, J7) and gripper (J8)
|
||||||
|
- **24V power supply** required
|
||||||
|
- **CAN interface device**:
|
||||||
|
- **Linux**: Any SocketCAN-compatible adapter
|
||||||
|
- **macOS**: CANable, PEAK PCAN-USB, or Kvaser USBcan
|
||||||
|
- Proper CAN wiring (CANH, CANL, 120Ω termination)
|
||||||
|
|
||||||
|
|
||||||
|
## Motor Configuration
|
||||||
|
|
||||||
|
Each arm has the following motor configuration based on the [OpenArm setup guide](https://docs.openarm.dev/software/setup/):
|
||||||
|
|
||||||
|
| Joint | Motor | Motor Type | Sender CAN ID | Receiver ID | Description |
|
||||||
|
|-------|-------|------------|---------------|-------------|-------------|
|
||||||
|
| J1 | joint_1 | DM8009 | 0x01 | 0x11 | Shoulder pan |
|
||||||
|
| J2 | joint_2 | DM8009 | 0x02 | 0x12 | Shoulder lift |
|
||||||
|
| J3 | joint_3 | DM4340 | 0x03 | 0x13 | Shoulder rotation |
|
||||||
|
| J4 | joint_4 | DM4340 | 0x04 | 0x14 | Elbow flex |
|
||||||
|
| J5 | joint_5 | DM4310 | 0x05 | 0x15 | Wrist roll |
|
||||||
|
| J6 | joint_6 | DM4310 | 0x06 | 0x16 | Wrist pitch |
|
||||||
|
| J7 | joint_7 | DM4310 | 0x07 | 0x17 | Wrist rotation |
|
||||||
|
| J8 | gripper | DM4310 | 0x08 | 0x18 | Gripper |
|
||||||
|
|
||||||
|
For dual arm setups, the left arm uses IDs 0x09-0x10 for joints 1-8 with the same motor types.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install system dependencies
|
||||||
|
sudo apt install can-utils iproute2
|
||||||
|
|
||||||
|
# Install LeRobot with OpenArms support
|
||||||
|
pip install -e ".[openarms]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Setup Guide
|
||||||
|
|
||||||
|
### Step 1: Motor ID Configuration
|
||||||
|
|
||||||
|
**IMPORTANT**: Before using the robot, motors must be configured with the correct CAN IDs.
|
||||||
|
|
||||||
|
Refer to the [OpenArm Motor ID Configuration Guide](https://docs.openarm.dev/software/setup/motor-id) for detailed instructions using the Damiao Debugging Tools on Windows.
|
||||||
|
|
||||||
|
Key points:
|
||||||
|
- Each motor needs a unique **Sender CAN ID** (0x01-0x08)
|
||||||
|
- Each motor needs a unique **Receiver/Master ID** (0x11-0x18)
|
||||||
|
- Use the Damiao Debugging Tools to set these IDs
|
||||||
|
|
||||||
|
### Step 2: Setup CAN Interface
|
||||||
|
|
||||||
|
Configure your CAN interface as described in the [OpenArm CAN Setup Guide](https://docs.openarm.dev/software/setup/can-setup):
|
||||||
|
|
||||||
|
#### Linux (SocketCAN)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Find your CAN interface
|
||||||
|
ip link show
|
||||||
|
|
||||||
|
# Configure can0, 1, 2, 3
|
||||||
|
sudo ip link set can0 down
|
||||||
|
sudo ip link set can0 type can bitrate 1000000
|
||||||
|
sudo ip link set can0 up
|
||||||
|
|
||||||
|
sudo ip link set can1 down
|
||||||
|
sudo ip link set can1 type can bitrate 1000000
|
||||||
|
sudo ip link set can1 up
|
||||||
|
|
||||||
|
sudo ip link set can2 down
|
||||||
|
sudo ip link set can2 type can bitrate 1000000
|
||||||
|
sudo ip link set can2 up
|
||||||
|
|
||||||
|
sudo ip link set can3 down
|
||||||
|
sudo ip link set can3 type can bitrate 1000000
|
||||||
|
sudo ip link set can3 up
|
||||||
|
|
||||||
|
# Verify configuration
|
||||||
|
ip link show can0
|
||||||
|
```
|
||||||
|
|
||||||
|
or run:
|
||||||
|
|
||||||
|
`examples/openarms/setup_can.sh`
|
||||||
|
|
||||||
|
### Testing canbus and motor connection
|
||||||
|
|
||||||
|
Please run this script to check if all motors can be found and to find your can-fd speed: `python examples/openarms/debug_can_communication.py`
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Basic Setup
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.robots.openarms import OpenArmsFollower
|
||||||
|
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||||
|
|
||||||
|
# Configure for dual arm setup
|
||||||
|
config = OpenArmsFollowerConfig(
|
||||||
|
port="can0",
|
||||||
|
can_interface="socketcan", # Or "auto" for auto-detection
|
||||||
|
id="openarms_dual",
|
||||||
|
is_dual_arm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
robot = OpenArmsFollower(config)
|
||||||
|
robot.connect()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Calibration
|
||||||
|
|
||||||
|
On first use, you'll need to calibrate the robot:
|
||||||
|
|
||||||
|
```python
|
||||||
|
robot.calibrate()
|
||||||
|
```
|
||||||
|
|
||||||
|
The calibration process will:
|
||||||
|
1. Disable torque on all motors
|
||||||
|
2. Ask you to position arms in **hanging position with grippers closed**
|
||||||
|
3. Set this as the zero position
|
||||||
|
4. Ask you to move each joint through its full range
|
||||||
|
5. Record min/max positions for each joint
|
||||||
|
6. Save calibration to file
|
||||||
|
|
||||||
|
### Reading Observations
|
||||||
|
|
||||||
|
The robot provides comprehensive state information:
|
||||||
|
|
||||||
|
```python
|
||||||
|
observation = robot.get_observation()
|
||||||
|
|
||||||
|
# Observation includes for each motor:
|
||||||
|
# - {motor_name}.pos: Position in degrees
|
||||||
|
# - {motor_name}.vel: Velocity in degrees/second
|
||||||
|
# - {motor_name}.torque: Motor torque
|
||||||
|
# - {camera_name}: Camera images (if configured)
|
||||||
|
|
||||||
|
print(f"Right arm joint 1 position: {observation['right_joint_1.pos']:.1f}°")
|
||||||
|
print(f"Right arm joint 1 velocity: {observation['right_joint_1.vel']:.1f}°/s")
|
||||||
|
print(f"Right arm joint 1 torque: {observation['right_joint_1.torque']:.3f} N·m")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Sending Actions
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Send target positions (in degrees)
|
||||||
|
action = {
|
||||||
|
"right_joint_1.pos": 45.0,
|
||||||
|
"right_joint_2.pos": -30.0,
|
||||||
|
# ... all joints
|
||||||
|
"right_gripper.pos": 45.0, # Half-closed
|
||||||
|
}
|
||||||
|
|
||||||
|
actual_action = robot.send_action(action)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Gripper Control
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Open gripper
|
||||||
|
robot.open_gripper(arm="right")
|
||||||
|
|
||||||
|
# Close gripper
|
||||||
|
robot.close_gripper(arm="right")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Safety Features
|
||||||
|
|
||||||
|
### 1. Maximum Relative Target
|
||||||
|
|
||||||
|
Limits how far a joint can move in a single command to prevent sudden movements:
|
||||||
|
|
||||||
|
```python
|
||||||
|
config = OpenArmsFollowerConfig(
|
||||||
|
port="can0",
|
||||||
|
# Limit all joints to 10 degrees per command
|
||||||
|
max_relative_target=10.0,
|
||||||
|
|
||||||
|
# Or set per-motor limits
|
||||||
|
max_relative_target={
|
||||||
|
"right_joint_1": 15.0, # Slower moving joint
|
||||||
|
"right_joint_2": 10.0,
|
||||||
|
"right_gripper": 5.0, # Very slow gripper
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**How it works**: If current position is 50° and you command 80°, with `max_relative_target=10.0`, the robot will only move to 60° in that step.
|
||||||
|
|
||||||
|
### 2. Torque Limits
|
||||||
|
|
||||||
|
Control maximum torque output, especially important for grippers and teleoperation:
|
||||||
|
|
||||||
|
```python
|
||||||
|
config = OpenArmsFollowerConfig(
|
||||||
|
port="can0",
|
||||||
|
# Gripper torque limit (fraction of motor's max torque)
|
||||||
|
gripper_torque_limit=0.5, # 50% of max torque
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Lower torque limits prevent damage when gripping delicate objects.
|
||||||
|
|
||||||
|
### 3. MIT Control Gains
|
||||||
|
|
||||||
|
Control responsiveness and stability via PID-like gains:
|
||||||
|
|
||||||
|
```python
|
||||||
|
config = OpenArmsFollowerConfig(
|
||||||
|
port="can0",
|
||||||
|
position_kp=10.0, # Position gain (higher = more responsive)
|
||||||
|
position_kd=0.5, # Velocity damping (higher = more damped)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Guidelines**:
|
||||||
|
- **For following (robot)**: Higher gains for responsiveness
|
||||||
|
- `position_kp=10.0`, `position_kd=0.5`
|
||||||
|
- **For teleoperation (leader)**: Lower gains or disable torque for manual movement
|
||||||
|
- `manual_control=True` (torque disabled)
|
||||||
|
|
||||||
|
### 4. Velocity Limits
|
||||||
|
|
||||||
|
Velocity limits are enforced by the Damiao motors based on motor type. For DM4310:
|
||||||
|
- Max velocity: 30 rad/s ≈ 1718°/s
|
||||||
|
|
||||||
|
The motors will automatically limit velocity to safe values.
|
||||||
|
|
||||||
|
## Teleoperation
|
||||||
|
|
||||||
|
### Leader Arm Setup
|
||||||
|
|
||||||
|
The leader arm is moved manually (torque disabled) to generate commands:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.teleoperators.openarms import OpenArmsLeader
|
||||||
|
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||||
|
|
||||||
|
config = OpenArmsLeaderConfig(
|
||||||
|
port="can1", # Separate CAN interface for leader
|
||||||
|
id="openarms_leader",
|
||||||
|
manual_control=True, # Torque disabled for manual movement
|
||||||
|
is_dual_arm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
leader = OpenArmsLeader(config)
|
||||||
|
leader.connect()
|
||||||
|
|
||||||
|
# Read current position as action
|
||||||
|
action = leader.get_action()
|
||||||
|
# action contains positions for all joints in degrees
|
||||||
|
```
|
||||||
|
|
||||||
|
### Safety Considerations for Teleoperation
|
||||||
|
|
||||||
|
1. **Use separate CAN interfaces** for leader and follower to avoid conflicts
|
||||||
|
2. **Enable max_relative_target** on follower to smooth abrupt movements
|
||||||
|
3. **Lower torque limits** on follower to prevent damage from tracking errors
|
||||||
|
4. **Test with one arm** before enabling dual arm teleoperation
|
||||||
|
5. **Have emergency stop** ready (power switch or CAN disable)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Recommended follower config for teleoperation
|
||||||
|
follower_config = OpenArmsFollowerConfig(
|
||||||
|
port="can0",
|
||||||
|
max_relative_target=5.0, # Small steps for smooth following
|
||||||
|
gripper_torque_limit=0.3, # Low torque for safety
|
||||||
|
position_kp=5.0, # Lower gains for gentler following
|
||||||
|
position_kd=0.3,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Motor Shaking/Unstable
|
||||||
|
|
||||||
|
- **Lower control gains**: Reduce `position_kp` and `position_kd`
|
||||||
|
- **Check calibration**: Re-run calibration procedure
|
||||||
|
- **Verify power**: Insufficient current can cause instability
|
||||||
|
- **Check mechanical**: Loose connections, binding, or damaged components
|
||||||
|
|
||||||
|
### CAN Bus Errors
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check for errors
|
||||||
|
ip -s link show can0
|
||||||
|
|
||||||
|
# Reset CAN interface
|
||||||
|
sudo ip link set can0 down
|
||||||
|
sudo ip link set can0 up
|
||||||
|
```
|
||||||
|
|
||||||
|
### Control Mode
|
||||||
|
|
||||||
|
OpenArms uses **MIT control mode** which allows simultaneous control of:
|
||||||
|
- Position (degrees)
|
||||||
|
- Velocity (degrees/second)
|
||||||
|
- Torque (N·m)
|
||||||
|
- Position gain (Kp)
|
||||||
|
- Velocity damping (Kd)
|
||||||
|
|
||||||
|
### Communication
|
||||||
|
|
||||||
|
- **Protocol**: CAN 2.0 at 1 Mbps (or CAN-FD at 5 Mbps)
|
||||||
|
- **Frame format**: Standard 11-bit IDs
|
||||||
|
- **Update rate**: Typically 50-100 Hz depending on motor count
|
||||||
|
- **Latency**: ~10-20ms per motor command
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [OpenArm Official Documentation](https://docs.openarm.dev/)
|
||||||
|
- [OpenArm Setup Guide](https://docs.openarm.dev/software/setup/)
|
||||||
|
- [Motor ID Configuration](https://docs.openarm.dev/software/setup/motor-id)
|
||||||
|
- [CAN Interface Setup](https://docs.openarm.dev/software/setup/can-setup)
|
||||||
|
- [Motor Communication Test](https://docs.openarm.dev/software/setup/configure-test)
|
||||||
|
- [Damiao Motor Documentation](https://wiki.seeedstudio.com/damiao_series/)
|
||||||
|
- [Enactic GitHub](https://github.com/enactic/openarm_can)
|
||||||
@@ -30,131 +30,6 @@ The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader, however,
|
|||||||
| Wrist Roll | 5 | 1 / 147 |
|
| Wrist Roll | 5 | 1 / 147 |
|
||||||
| Gripper | 6 | 1 / 147 |
|
| Gripper | 6 | 1 / 147 |
|
||||||
|
|
||||||
### Clean Parts
|
|
||||||
|
|
||||||
Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
|
|
||||||
|
|
||||||
It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly.
|
|
||||||
|
|
||||||
### Joint 1
|
|
||||||
|
|
||||||
- Place the first motor into the base.
|
|
||||||
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
|
|
||||||
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
|
|
||||||
- Install both motor horns, securing the top horn with a M3x6mm screw.
|
|
||||||
- Attach the shoulder part.
|
|
||||||
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
|
|
||||||
- Add the shoulder motor holder.
|
|
||||||
|
|
||||||
<div class="video-container">
|
|
||||||
<video controls width="600">
|
|
||||||
<source
|
|
||||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint1_v2.mp4"
|
|
||||||
type="video/mp4"
|
|
||||||
/>
|
|
||||||
</video>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
### Joint 2
|
|
||||||
|
|
||||||
- Slide the second motor in from the top.
|
|
||||||
- Fasten the second motor with 4 M2x6mm screws.
|
|
||||||
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
|
|
||||||
- Attach the upper arm with 4 M3x6mm screws on each side.
|
|
||||||
|
|
||||||
<div class="video-container">
|
|
||||||
<video controls width="600">
|
|
||||||
<source
|
|
||||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint2_v2.mp4"
|
|
||||||
type="video/mp4"
|
|
||||||
/>
|
|
||||||
</video>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
### Joint 3
|
|
||||||
|
|
||||||
- Insert motor 3 and fasten using 4 M2x6mm screws
|
|
||||||
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
|
|
||||||
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
|
|
||||||
|
|
||||||
<div class="video-container">
|
|
||||||
<video controls width="600">
|
|
||||||
<source
|
|
||||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint3_v2.mp4"
|
|
||||||
type="video/mp4"
|
|
||||||
/>
|
|
||||||
</video>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
### Joint 4
|
|
||||||
|
|
||||||
- Slide over motor holder 4.
|
|
||||||
- Slide in motor 4.
|
|
||||||
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
|
|
||||||
|
|
||||||
<div class="video-container">
|
|
||||||
<video controls width="600">
|
|
||||||
<source
|
|
||||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint4_v2.mp4"
|
|
||||||
type="video/mp4"
|
|
||||||
/>
|
|
||||||
</video>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
### Joint 5
|
|
||||||
|
|
||||||
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
|
|
||||||
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
|
|
||||||
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
|
|
||||||
|
|
||||||
<div class="video-container">
|
|
||||||
<video controls width="600">
|
|
||||||
<source
|
|
||||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint5_v2.mp4"
|
|
||||||
type="video/mp4"
|
|
||||||
/>
|
|
||||||
</video>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
### Gripper / Handle
|
|
||||||
|
|
||||||
<hfoptions id="assembly">
|
|
||||||
<hfoption id="Follower">
|
|
||||||
|
|
||||||
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
|
|
||||||
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
|
|
||||||
- Attach the motor horns and again use a M3x6mm horn screw.
|
|
||||||
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
|
|
||||||
|
|
||||||
<div class="video-container">
|
|
||||||
<video controls width="600">
|
|
||||||
<source
|
|
||||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Gripper_v2.mp4"
|
|
||||||
type="video/mp4"
|
|
||||||
/>
|
|
||||||
</video>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
</hfoption>
|
|
||||||
<hfoption id="Leader">
|
|
||||||
|
|
||||||
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
|
|
||||||
- Attach the handle to motor 5 using 1 M2x6mm screw.
|
|
||||||
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
|
|
||||||
- Attach the follower trigger with 4 M3x6mm screws.
|
|
||||||
|
|
||||||
<div class="video-container">
|
|
||||||
<video controls width="600">
|
|
||||||
<source
|
|
||||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Leader_v2.mp4"
|
|
||||||
type="video/mp4"
|
|
||||||
/>
|
|
||||||
</video>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
</hfoption>
|
|
||||||
</hfoptions>
|
|
||||||
|
|
||||||
## Configure the motors
|
## Configure the motors
|
||||||
|
|
||||||
### 1. Find the USB ports associated with each arm
|
### 1. Find the USB ports associated with each arm
|
||||||
@@ -340,6 +215,131 @@ leader.setup_motors()
|
|||||||
</hfoption>
|
</hfoption>
|
||||||
</hfoptions>
|
</hfoptions>
|
||||||
|
|
||||||
|
### Clean Parts
|
||||||
|
|
||||||
|
Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
|
||||||
|
|
||||||
|
It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly.
|
||||||
|
|
||||||
|
### Joint 1
|
||||||
|
|
||||||
|
- Place the first motor into the base.
|
||||||
|
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
|
||||||
|
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
|
||||||
|
- Install both motor horns, securing the top horn with a M3x6mm screw.
|
||||||
|
- Attach the shoulder part.
|
||||||
|
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
|
||||||
|
- Add the shoulder motor holder.
|
||||||
|
|
||||||
|
<div class="video-container">
|
||||||
|
<video controls width="600">
|
||||||
|
<source
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint1_v2.mp4"
|
||||||
|
type="video/mp4"
|
||||||
|
/>
|
||||||
|
</video>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
### Joint 2
|
||||||
|
|
||||||
|
- Slide the second motor in from the top.
|
||||||
|
- Fasten the second motor with 4 M2x6mm screws.
|
||||||
|
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
|
||||||
|
- Attach the upper arm with 4 M3x6mm screws on each side.
|
||||||
|
|
||||||
|
<div class="video-container">
|
||||||
|
<video controls width="600">
|
||||||
|
<source
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint2_v2.mp4"
|
||||||
|
type="video/mp4"
|
||||||
|
/>
|
||||||
|
</video>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
### Joint 3
|
||||||
|
|
||||||
|
- Insert motor 3 and fasten using 4 M2x6mm screws
|
||||||
|
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
|
||||||
|
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
|
||||||
|
|
||||||
|
<div class="video-container">
|
||||||
|
<video controls width="600">
|
||||||
|
<source
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint3_v2.mp4"
|
||||||
|
type="video/mp4"
|
||||||
|
/>
|
||||||
|
</video>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
### Joint 4
|
||||||
|
|
||||||
|
- Slide over motor holder 4.
|
||||||
|
- Slide in motor 4.
|
||||||
|
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
|
||||||
|
|
||||||
|
<div class="video-container">
|
||||||
|
<video controls width="600">
|
||||||
|
<source
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint4_v2.mp4"
|
||||||
|
type="video/mp4"
|
||||||
|
/>
|
||||||
|
</video>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
### Joint 5
|
||||||
|
|
||||||
|
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
|
||||||
|
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
|
||||||
|
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
|
||||||
|
|
||||||
|
<div class="video-container">
|
||||||
|
<video controls width="600">
|
||||||
|
<source
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint5_v2.mp4"
|
||||||
|
type="video/mp4"
|
||||||
|
/>
|
||||||
|
</video>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
### Gripper / Handle
|
||||||
|
|
||||||
|
<hfoptions id="assembly">
|
||||||
|
<hfoption id="Follower">
|
||||||
|
|
||||||
|
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
|
||||||
|
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
|
||||||
|
- Attach the motor horns and again use a M3x6mm horn screw.
|
||||||
|
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
|
||||||
|
|
||||||
|
<div class="video-container">
|
||||||
|
<video controls width="600">
|
||||||
|
<source
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Gripper_v2.mp4"
|
||||||
|
type="video/mp4"
|
||||||
|
/>
|
||||||
|
</video>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</hfoption>
|
||||||
|
<hfoption id="Leader">
|
||||||
|
|
||||||
|
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
|
||||||
|
- Attach the handle to motor 5 using 1 M2x6mm screw.
|
||||||
|
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
|
||||||
|
- Attach the follower trigger with 4 M3x6mm screws.
|
||||||
|
|
||||||
|
<div class="video-container">
|
||||||
|
<video controls width="600">
|
||||||
|
<source
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Leader_v2.mp4"
|
||||||
|
type="video/mp4"
|
||||||
|
/>
|
||||||
|
</video>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</hfoption>
|
||||||
|
</hfoptions>
|
||||||
|
|
||||||
## Calibrate
|
## Calibrate
|
||||||
|
|
||||||
Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
|
Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
|
||||||
|
|||||||
42
docs/source/torch_accelerators.mdx
Normal file
42
docs/source/torch_accelerators.mdx
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# PyTorch accelerators
|
||||||
|
|
||||||
|
LeRobot supports multiple hardware acceleration options for both training and inference.
|
||||||
|
|
||||||
|
These options include:
|
||||||
|
|
||||||
|
- **CPU**: CPU executes all computations, no dedicated accelerator is used
|
||||||
|
- **CUDA**: acceleration with NVIDIA & AMD GPUs
|
||||||
|
- **MPS**: acceleration with Apple Silicon GPUs
|
||||||
|
- **XPU**: acceleration with Intel integrated and discrete GPUs
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
To use particular accelerator, a suitable version of PyTorch should be installed.
|
||||||
|
|
||||||
|
For CPU, CUDA, and MPS backends follow instructions provided on [PyTorch installation page](https://pytorch.org/get-started/locally).
|
||||||
|
For XPU backend, follow instructions from [PyTorch documentation](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html).
|
||||||
|
|
||||||
|
### Verifying the installation
|
||||||
|
|
||||||
|
After installation, accelerator availability can be verified by running
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
print(torch.<backend_name>.is_available()) # <backend_name> is cuda, mps, or xpu
|
||||||
|
```
|
||||||
|
|
||||||
|
## How to run training or evaluation
|
||||||
|
|
||||||
|
To select the desired accelerator, use the `--policy.device` flag when running `lerobot-train` or `lerobot-eval`. For example, to use MPS on Apple Silicon, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train
|
||||||
|
--policy.device=mps ...
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.device=mps ...
|
||||||
|
```
|
||||||
|
|
||||||
|
However, in most cases, presence of an accelerator is detected automatically and `policy.device` parameter can be omitted from CLI commands.
|
||||||
208
docs/source/unitree_g1.mdx
Normal file
208
docs/source/unitree_g1.mdx
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
# Unitree G1 Robot Setup and Control
|
||||||
|
|
||||||
|
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
|
||||||
|
|
||||||
|
## About the Unitree G1
|
||||||
|
|
||||||
|
We offer support for both 29 and 23 DOF G1. We introduce:
|
||||||
|
|
||||||
|
- **`unitree g1` robot class, handling low level communication with the humanoid**
|
||||||
|
- **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin
|
||||||
|
- **GR00T locomotion policy** for bipedal walking and balance
|
||||||
|
- **MuJoCo simulation mode** for testing policies without the physical robot
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Part 1: Connect to Robot over Ethernet
|
||||||
|
|
||||||
|
### Step 1: Configure Your Computer's Ethernet Interface
|
||||||
|
|
||||||
|
Set a static IP on the same subnet as the robot:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
|
||||||
|
sudo ip addr flush dev enp131s0
|
||||||
|
sudo ip addr add 192.168.123.200/24 dev enp131s0
|
||||||
|
sudo ip link set enp131s0 up
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note**: The robot's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` where x ≠ 164.
|
||||||
|
|
||||||
|
### Step 2: SSH into the Robot
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh unitree@192.168.123.164
|
||||||
|
# Password: 123
|
||||||
|
```
|
||||||
|
|
||||||
|
You should now be connected to the robot's onboard computer.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Part 2: Enable WiFi on the Robot
|
||||||
|
|
||||||
|
Once connected via Ethernet, follow these steps to enable WiFi:
|
||||||
|
|
||||||
|
### Step 1: Enable WiFi Hardware
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Unblock WiFi radio
|
||||||
|
sudo rfkill unblock wifi
|
||||||
|
sudo rfkill unblock all
|
||||||
|
|
||||||
|
# Bring up WiFi interface
|
||||||
|
sudo ip link set wlan0 up
|
||||||
|
|
||||||
|
# Enable NetworkManager control
|
||||||
|
sudo nmcli radio wifi on
|
||||||
|
sudo nmcli device set wlan0 managed yes
|
||||||
|
sudo systemctl restart NetworkManager
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Enable Internet Forwarding
|
||||||
|
|
||||||
|
**On your laptop:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Enable IP forwarding
|
||||||
|
sudo sysctl -w net.ipv4.ip_forward=1
|
||||||
|
|
||||||
|
# Set up NAT (replace wlp132s0f0 with your WiFi interface)
|
||||||
|
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
|
||||||
|
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
|
||||||
|
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
|
||||||
|
```
|
||||||
|
|
||||||
|
**On the robot:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Add laptop as default gateway
|
||||||
|
sudo ip route del default 2>/dev/null || true
|
||||||
|
sudo ip route add default via 192.168.123.200 dev eth0
|
||||||
|
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
|
||||||
|
|
||||||
|
# Test connection
|
||||||
|
ping -c 3 8.8.8.8
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 3: Connect to WiFi Network
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# List available networks
|
||||||
|
nmcli device wifi list
|
||||||
|
|
||||||
|
# Connect to your WiFi (example)
|
||||||
|
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
|
||||||
|
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
|
||||||
|
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
|
||||||
|
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
|
||||||
|
sudo nmcli connection up "YourNetwork"
|
||||||
|
|
||||||
|
# Check WiFi IP address
|
||||||
|
ip a show wlan0
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 4: SSH Over WiFi
|
||||||
|
|
||||||
|
Once connected to WiFi, note the robot's IP address and disconnect the Ethernet cable. You can now SSH over WiFi:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh unitree@<YOUR_ROBOT_IP>
|
||||||
|
# Password: 123
|
||||||
|
```
|
||||||
|
|
||||||
|
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address (e.g., `172.18.129.215`).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Part 3: Robot Server Setup
|
||||||
|
|
||||||
|
### Step 1: Install LeRobot on the Orin
|
||||||
|
|
||||||
|
SSH into the robot and install LeRobot:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh unitree@<YOUR_ROBOT_IP>
|
||||||
|
|
||||||
|
conda create -y -n lerobot python=3.10
|
||||||
|
conda activate lerobot
|
||||||
|
git clone https://github.com/huggingface/lerobot.git
|
||||||
|
cd lerobot
|
||||||
|
pip install -e '.[unitree_g1]'
|
||||||
|
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||||
|
cd unitree_sdk2_python && pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
|
||||||
|
|
||||||
|
### Step 2: Run the Robot Server
|
||||||
|
|
||||||
|
On the robot:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python src/lerobot/robots/unitree_g1/run_g1_server.py
|
||||||
|
```
|
||||||
|
|
||||||
|
**Important**: Keep this terminal running. The server must be active for remote control.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Part 4: Running GR00T Locomotion
|
||||||
|
|
||||||
|
With the robot server running, you can now control the robot from your laptop.
|
||||||
|
|
||||||
|
### Step 1: Install LeRobot on your machine
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda create -y -n lerobot python=3.10
|
||||||
|
conda activate lerobot
|
||||||
|
git clone https://github.com/huggingface/lerobot.git
|
||||||
|
cd lerobot
|
||||||
|
pip install -e '.[unitree_g1]'
|
||||||
|
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||||
|
cd unitree_sdk2_python && pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Update Robot IP in Config
|
||||||
|
|
||||||
|
Edit the config file to match your robot's WiFi IP:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# In src/lerobot/robots/unitree_g1/config_unitree_g1.py
|
||||||
|
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note**: When running directly on the G1 (not remotely), set `robot_ip: str = "127.0.0.1"` instead.
|
||||||
|
|
||||||
|
### Step 3: Run the Locomotion Policy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run GR00T locomotion controller
|
||||||
|
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 4: Control with Remote
|
||||||
|
|
||||||
|
- **Left stick**: Forward/backward and left/right movement
|
||||||
|
- **Right stick**: Rotation
|
||||||
|
- **R1 button**: Raise waist height
|
||||||
|
- **R2 button**: Lower waist height
|
||||||
|
|
||||||
|
Press `Ctrl+C` to stop the policy.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Extra: Running in Simulation Mode (MuJoCo)
|
||||||
|
|
||||||
|
You can now test and develop policies without a physical robot using MuJoCo. to do so set `is_simulation=True` in config.
|
||||||
|
|
||||||
|
## Additional Resources
|
||||||
|
|
||||||
|
- [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python)
|
||||||
|
- [GR00T Policy Repository](https://huggingface.co/nepyope/GR00T-WholeBodyControl_g1)
|
||||||
|
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||||
|
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
_Last updated: December 2025_
|
||||||
@@ -11,13 +11,14 @@ LeRobot provides several utilities for manipulating datasets:
|
|||||||
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
|
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
|
||||||
4. **Add Features** - Add new features to a dataset
|
4. **Add Features** - Add new features to a dataset
|
||||||
5. **Remove Features** - Remove features from a dataset
|
5. **Remove Features** - Remove features from a dataset
|
||||||
|
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||||
|
|
||||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||||
|
|
||||||
## Command-Line Tool: lerobot-edit-dataset
|
## Command-Line Tool: lerobot-edit-dataset
|
||||||
|
|
||||||
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
|
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, remove features, and convert image datasets to video format.
|
||||||
|
|
||||||
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
|
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
|
||||||
|
|
||||||
@@ -86,9 +87,71 @@ lerobot-edit-dataset \
|
|||||||
--operation.feature_names "['observation.images.top']"
|
--operation.feature_names "['observation.images.top']"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Convert to Video
|
||||||
|
|
||||||
|
Convert an image-based dataset to video format, creating a new LeRobotDataset where images are stored as videos. This is useful for reducing storage requirements and improving data loading performance. The new dataset will have the exact same structure as the original, but with images encoded as MP4 videos in the proper LeRobot format.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Local-only: Save to a custom output directory (no hub push)
|
||||||
|
lerobot-edit-dataset \
|
||||||
|
--repo_id lerobot/pusht_image \
|
||||||
|
--operation.type convert_to_video \
|
||||||
|
--operation.output_dir /path/to/output/pusht_video
|
||||||
|
|
||||||
|
# Save with new repo_id (local storage)
|
||||||
|
lerobot-edit-dataset \
|
||||||
|
--repo_id lerobot/pusht_image \
|
||||||
|
--new_repo_id lerobot/pusht_video \
|
||||||
|
--operation.type convert_to_video
|
||||||
|
|
||||||
|
# Convert and push to Hugging Face Hub
|
||||||
|
lerobot-edit-dataset \
|
||||||
|
--repo_id lerobot/pusht_image \
|
||||||
|
--new_repo_id lerobot/pusht_video \
|
||||||
|
--operation.type convert_to_video \
|
||||||
|
--push_to_hub true
|
||||||
|
|
||||||
|
# Convert with custom video codec and quality settings
|
||||||
|
lerobot-edit-dataset \
|
||||||
|
--repo_id lerobot/pusht_image \
|
||||||
|
--operation.type convert_to_video \
|
||||||
|
--operation.output_dir outputs/pusht_video \
|
||||||
|
--operation.vcodec libsvtav1 \
|
||||||
|
--operation.pix_fmt yuv420p \
|
||||||
|
--operation.g 2 \
|
||||||
|
--operation.crf 30
|
||||||
|
|
||||||
|
# Convert only specific episodes
|
||||||
|
lerobot-edit-dataset \
|
||||||
|
--repo_id lerobot/pusht_image \
|
||||||
|
--operation.type convert_to_video \
|
||||||
|
--operation.output_dir outputs/pusht_video \
|
||||||
|
--operation.episode_indices "[0, 1, 2, 5, 10]"
|
||||||
|
|
||||||
|
# Convert with multiple workers for parallel processing
|
||||||
|
lerobot-edit-dataset \
|
||||||
|
--repo_id lerobot/pusht_image \
|
||||||
|
--operation.type convert_to_video \
|
||||||
|
--operation.output_dir outputs/pusht_video \
|
||||||
|
--operation.num_workers 8
|
||||||
|
```
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
|
||||||
|
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||||
|
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
|
||||||
|
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||||
|
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
|
||||||
|
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
|
||||||
|
- `fast_decode`: Fast decode tuning option (default: 0)
|
||||||
|
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||||
|
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||||
|
|
||||||
|
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
|
||||||
|
|
||||||
### Push to Hub
|
### Push to Hub
|
||||||
|
|
||||||
Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-edit-dataset \
|
lerobot-edit-dataset \
|
||||||
@@ -96,7 +159,45 @@ lerobot-edit-dataset \
|
|||||||
--new_repo_id lerobot/pusht_after_deletion \
|
--new_repo_id lerobot/pusht_after_deletion \
|
||||||
--operation.type delete_episodes \
|
--operation.type delete_episodes \
|
||||||
--operation.episode_indices "[0, 2, 5]" \
|
--operation.episode_indices "[0, 2, 5]" \
|
||||||
--push_to_hub
|
--push_to_hub true
|
||||||
```
|
```
|
||||||
|
|
||||||
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
|
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
|
||||||
|
|
||||||
|
# Dataset Visualization
|
||||||
|
|
||||||
|
## Online Visualization
|
||||||
|
|
||||||
|
When you record a dataset using `lerobot`, it automatically uploads to the Hugging Face Hub unless you specify otherwise. To view the dataset online, use our **LeRobot Dataset Visualizer**, available at:
|
||||||
|
https://huggingface.co/spaces/lerobot/visualize_dataset
|
||||||
|
|
||||||
|
## Local Visualization
|
||||||
|
|
||||||
|
You can also visualize episodes from a dataset locally using our command-line tool.
|
||||||
|
|
||||||
|
**From the Hugging Face Hub:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-dataset-viz \
|
||||||
|
--repo-id lerobot/pusht \
|
||||||
|
--episode-index 0
|
||||||
|
```
|
||||||
|
|
||||||
|
**From a local folder:**
|
||||||
|
Add the `--root` option and set `--mode local`. For example, to search in `./my_local_data_dir/lerobot/pusht`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-dataset-viz \
|
||||||
|
--repo-id lerobot/pusht \
|
||||||
|
--root ./my_local_data_dir \
|
||||||
|
--mode local \
|
||||||
|
--episode-index 0
|
||||||
|
```
|
||||||
|
|
||||||
|
Once executed, the tool opens `rerun.io` and displays the camera streams, robot states, and actions for the selected episode.
|
||||||
|
|
||||||
|
For advanced usage—including visualizing datasets stored on a remote server—run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-dataset-viz --help
|
||||||
|
```
|
||||||
|
|||||||
528
docs/source/xvla.mdx
Normal file
528
docs/source/xvla.mdx
Normal file
@@ -0,0 +1,528 @@
|
|||||||
|
# X-VLA: The First Soft-Prompted Robot Foundation Model for Any Robot, Any Task
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
For years, robotics has aspired to build agents that can follow natural human instructions and operate dexterously across many environments and robot bodies. Recent breakthroughs in LLMs and VLMs suggest a path forward: extend these foundation-model architectures to embodied control by grounding them in actions. This has led to the rise of Vision-Language-Action (VLA) models, with the hope that a single generalist model could combine broad semantic understanding with robust manipulation skills.
|
||||||
|
|
||||||
|
But training such models is difficult. Robot data is fragmented across platforms, sensors, embodiments, and collection protocols. Heterogeneity appears everywhere: different arm configurations, different action spaces, different camera setups, different visual domains, and different task distributions. These inconsistencies create major distribution shifts that make pretraining unstable and adaptation unreliable.
|
||||||
|
|
||||||
|
Inspired by meta-learning and prompt learning, we ask: **"What if a VLA model could learn the structure of each robot and dataset the same way LLMs learn tasks, through prompts?"**
|
||||||
|
|
||||||
|
**X-VLA** is a soft-prompted, flow-matching VLA framework that treats each hardware setup as a "task" and encodes it using a small set of learnable embeddings. These **Soft Prompts** capture embodiment and domain-specific variations, guiding the Transformer from the earliest stages of multimodal fusion. With this mechanism, X-VLA can reconcile diverse robot morphologies, data types, and sensor setups within a single unified architecture.
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture.png"
|
||||||
|
alt="XVLA Architecture"
|
||||||
|
style="max-width: 100%; height: auto; width: 800px;"
|
||||||
|
/>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
Built from pure Transformer encoders, X-VLA scales naturally with model size and dataset diversity. Across 6 simulation benchmarks and 3 real robots, Soft Prompts consistently outperform existing methods in handling hardware and domain differences. X-VLA-0.9B, trained on 290K episodes spanning seven robotic platforms, learns an embodiment-agnostic generalist policy in Phase I, and adapts efficiently to new robots in Phase II simply by learning a new set of prompts, while keeping the backbone frozen.
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture2.png"
|
||||||
|
alt="XVLA Architecture 2"
|
||||||
|
style="width: 60%; height: auto;"
|
||||||
|
/>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
With only 1% of parameters tuned (9M), X-VLA-0.9B achieves near-π₀ performance on LIBERO and Simpler-WidowX, despite using **300× fewer trainable parameters**. It also demonstrates strong real-world dexterity with minimal demonstrations, including folding cloths in under two minutes.
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-fold.png"
|
||||||
|
alt="XVLA fold visualization"
|
||||||
|
style="width: 95%; max-width: 1100px; height: auto;"
|
||||||
|
/>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
X-VLA shows that generalist robot intelligence does not require increasingly complex architectures, only the right way to absorb heterogeneity. Soft Prompts offer a simple, scalable mechanism for unifying diverse robotic data, paving the way toward adaptable, cross-embodiment robot foundation models.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
After installing LeRobot, install the X-VLA dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e .[xvla]
|
||||||
|
```
|
||||||
|
|
||||||
|
After the new release, you'll be able to do:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install lerobot[xvla]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
To use X-VLA in your LeRobot configuration, specify the policy type as:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
policy.type=xvla
|
||||||
|
```
|
||||||
|
|
||||||
|
### Evaluating Pre-trained Checkpoints
|
||||||
|
|
||||||
|
Example evaluation with LIBERO:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path="lerobot/xvla-libero" \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_spatial,libero_goal,libero_10 \
|
||||||
|
--env.control_mode=absolute \
|
||||||
|
--eval.batch_size=1 \
|
||||||
|
--eval.n_episodes=1 \
|
||||||
|
--env.episode_length=800 \
|
||||||
|
--seed=142
|
||||||
|
```
|
||||||
|
|
||||||
|
## Available Checkpoints
|
||||||
|
|
||||||
|
### 🎯 Base Model
|
||||||
|
|
||||||
|
**[lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base)**
|
||||||
|
|
||||||
|
A 0.9B parameter instantiation of X-VLA, trained with a carefully designed data processing and learning recipe. The training pipeline consists of two phases:
|
||||||
|
|
||||||
|
- **Phase I: Pretraining** - Pretrained on 290K episodes from Droid, Robomind, and Agibot, spanning seven platforms across five types of robotic arms (single-arm to bi-manual setups). By leveraging soft prompts to absorb embodiment-specific variations, the model learns an embodiment-agnostic generalist policy.
|
||||||
|
|
||||||
|
- **Phase II: Domain Adaptation** - Adapted to deployable policies for target domains. A new set of soft prompts is introduced and optimized to encode the hardware configuration of the novel domain, while the pretrained backbone remains frozen.
|
||||||
|
|
||||||
|
### Simulation Checkpoints
|
||||||
|
|
||||||
|
**[lerobot/xvla-libero](https://huggingface.co/lerobot/xvla-libero)**
|
||||||
|
|
||||||
|
Achieves 93% success rate on LIBERO benchmarks. Fine-tuned from the base model for simulation tasks.
|
||||||
|
|
||||||
|
**[lerobot/xvla-widowx](https://huggingface.co/lerobot/xvla-widowx)**
|
||||||
|
|
||||||
|
Fine-tuned on BridgeData for pick-and-place experiments on compact WidowX platforms. Demonstrates robust manipulation capabilities.
|
||||||
|
|
||||||
|
### 🤖 Real-World Checkpoints
|
||||||
|
|
||||||
|
**[lerobot/xvla-folding](https://huggingface.co/lerobot/xvla-folding)**
|
||||||
|
|
||||||
|
A fine-tuned dexterous manipulation model trained on the high-quality Soft-FOLD cloth folding dataset. Achieves 100% success rate over 2 hours of continuous cloth folding.
|
||||||
|
|
||||||
|
**[lerobot/xvla-agibot-world](https://huggingface.co/lerobot/xvla-agibot-world)**
|
||||||
|
|
||||||
|
Optimized for AgileX robot dexterous manipulation tasks.
|
||||||
|
|
||||||
|
**[lerobot/xvla-google-robot](https://huggingface.co/lerobot/xvla-google-robot)**
|
||||||
|
|
||||||
|
Adapted for Google Robot platforms.
|
||||||
|
|
||||||
|
## Training X-VLA
|
||||||
|
|
||||||
|
### Recommended Training Configuration
|
||||||
|
|
||||||
|
When fine-tuning X-VLA for a new embodiment or task, we recommend not freezing the VLM, and also setting the `policy.dtype=bfloat16` to not hit OOM errors.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=YOUR_DATASET \
|
||||||
|
--output_dir=./outputs/xvla_training \
|
||||||
|
--job_name=xvla_training \
|
||||||
|
--policy.path="lerobot/xvla-base" \
|
||||||
|
--policy.repo_id="HF_USER/xvla-your-robot" \
|
||||||
|
--policy.dtype=bfloat16 \
|
||||||
|
--policy.action_mode=auto \
|
||||||
|
--steps=20000 \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.freeze_vision_encoder=false \
|
||||||
|
--policy.freeze_language_encoder=false \
|
||||||
|
--policy.train_policy_transformer=true \
|
||||||
|
--policy.train_soft_prompts=true \
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training Parameters Explained
|
||||||
|
|
||||||
|
| Parameter | Default | Description |
|
||||||
|
| -------------------------- | ------- | ---------------------------------------------- |
|
||||||
|
| `freeze_vision_encoder` | `false` | Do not freeze the VLM vision encoder weights |
|
||||||
|
| `freeze_language_encoder` | `false` | Do not freeze the VLM language encoder weights |
|
||||||
|
| `train_policy_transformer` | `true` | Allow policy transformer layers to train |
|
||||||
|
| `train_soft_prompts` | `true` | Allow soft prompts to train |
|
||||||
|
|
||||||
|
**💡 Best Practice**: For Phase II adaptation to new embodiments, do not freeze the VLM encoders and also train the policy transformer and soft prompts.
|
||||||
|
|
||||||
|
### Example: Training on Bimanual Robot
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
|
||||||
|
--output_dir=./outputs/xvla_bimanual \
|
||||||
|
--job_name=xvla_so101_training \
|
||||||
|
--policy.path="lerobot/xvla-base" \
|
||||||
|
--policy.dtype=bfloat16 \
|
||||||
|
--policy.repo_id="YOUR_USERNAME/xvla-biso101" \
|
||||||
|
--steps=3000 \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.action_mode=so101_bimanual \
|
||||||
|
--policy.freeze_vision_encoder=false \
|
||||||
|
--policy.freeze_language_encoder=false \
|
||||||
|
--policy.train_policy_transformer=true \
|
||||||
|
--policy.train_soft_prompts=true
|
||||||
|
```
|
||||||
|
|
||||||
|
💡 **Best Performance:** If you have sufficient computational resources and want to achieve best X-VLA finetuning performance, you should follow the official finetuning strategy:
|
||||||
|
|
||||||
|
**🔥 Full-finetune all components with a custom learning-rate scheme**
|
||||||
|
|
||||||
|
To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
|
||||||
|
This LR ratio is crucial for achieving strong and stable finetuning performance. This is already done for you by default.
|
||||||
|
❕Note
|
||||||
|
|
||||||
|
Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements.
|
||||||
|
We encourage implementing this in your customized training pipeline for optimal results.
|
||||||
|
|
||||||
|
## Core Concepts
|
||||||
|
|
||||||
|
### 1. Action Modes
|
||||||
|
|
||||||
|
X-VLA uses an **Action Registry** system to handle different action spaces and embodiments. The `action_mode` parameter defines how actions are processed, what loss functions are used, and how predictions are post-processed.
|
||||||
|
|
||||||
|
#### Available Action Modes
|
||||||
|
|
||||||
|
| Action Mode | Action Dim | Description | Use Case |
|
||||||
|
| ---------------- | ----------------------- | ------------------------------------------- | ------------------------------------ |
|
||||||
|
| `ee6d` | 20 | End-effector with xyz, 6D rotation, gripper | Dual-arm setups with spatial control |
|
||||||
|
| `joint` | 14 | Joint-space with gripper | Direct joint control robots |
|
||||||
|
| `agibot_ee6d` | 20 | AGI-bot variant with MSE loss | AGI-bot platforms |
|
||||||
|
| `so101_bimanual` | 20 (model), 12 (real) | SO101 bimanual robot | Bimanual manipulation tasks |
|
||||||
|
| `auto` | 20 (model), auto (real) | Auto-detects action dim from dataset | **Recommended** for new robots |
|
||||||
|
|
||||||
|
#### Why Action Modes Matter
|
||||||
|
|
||||||
|
When you have a pretrained checkpoint like `lerobot/xvla-base` trained with `action_dim=20`, and you want to train on a dataset with a different action dimension (e.g., 14 for bimanual arms), you can't simply trim the action dimension. The action mode orchestrates:
|
||||||
|
|
||||||
|
1. **Loss Computation**: Different loss functions for different action components (MSE for joints, BCE for grippers, etc.)
|
||||||
|
2. **Preprocessing**: Zeroing out gripper channels, padding dimensions
|
||||||
|
3. **Postprocessing**: Applying sigmoid to gripper logits, trimming padding
|
||||||
|
|
||||||
|
#### Example: BimanualSO101 Action Space
|
||||||
|
|
||||||
|
The `so101_bimanual` action mode handles the mismatch between model output (20D) and real robot control (12D):
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Model outputs 20 dimensions for compatibility
|
||||||
|
dim_action = 20
|
||||||
|
|
||||||
|
# Real robot only needs 12 dimensions
|
||||||
|
# [left_arm (6), right_arm (6)] = [joints (5) + gripper (1)] × 2
|
||||||
|
REAL_DIM = 12
|
||||||
|
|
||||||
|
# Preprocessing: Pad 12D actions to 20D for training
|
||||||
|
# Postprocessing: Trim 20D predictions to 12D for deployment
|
||||||
|
```
|
||||||
|
|
||||||
|
See the [action_hub.py](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py) implementation for details.
|
||||||
|
|
||||||
|
#### Auto Action Mode (Recommended)
|
||||||
|
|
||||||
|
The `auto` action mode is the easiest way to use X-VLA with any robot. It automatically detects your dataset's action dimension and handles padding/trimming:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.path="lerobot/xvla-base" \
|
||||||
|
--policy.action_mode=auto \
|
||||||
|
--policy.max_action_dim=20 \
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
**How it works:**
|
||||||
|
|
||||||
|
- Reads `action_feature.shape[-1]` from your dataset (e.g., 7 for Franka)
|
||||||
|
- Model outputs `max_action_dim` (default 20) for pretrained compatibility
|
||||||
|
- Loss is computed **only on the real dimensions**: `MSE(pred[:,:,:real_dim], target[:,:,:real_dim])`
|
||||||
|
- Postprocess trims output back to `real_dim` for robot control
|
||||||
|
|
||||||
|
This eliminates the need to create custom action modes for most robots.
|
||||||
|
|
||||||
|
### 2. Domain IDs
|
||||||
|
|
||||||
|
Domain IDs are learnable identifiers for different robot configurations and camera setups. They allow X-VLA to distinguish between:
|
||||||
|
|
||||||
|
- Different robots (Robot 1 vs Robot 2)
|
||||||
|
- Different camera configurations (cam1 vs cam2)
|
||||||
|
- Different combinations (Robot1-cam1-cam2 vs Robot1-cam1 vs Robot2-cam1)
|
||||||
|
|
||||||
|
#### Setting Domain IDs
|
||||||
|
|
||||||
|
**During Training**: By default, domain_id is set to 0 for general training.
|
||||||
|
|
||||||
|
**During Evaluation**: Specify the domain_id that matches your checkpoint's training configuration.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Example: LIBERO checkpoint uses domain_id=3
|
||||||
|
domain_id = 3
|
||||||
|
```
|
||||||
|
|
||||||
|
The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.
|
||||||
|
|
||||||
|
The `lerobot/xvla-base` model has been trained on the following domain IDs. It is recommended to choose one that most resembles your robot/configuration:
|
||||||
|
|
||||||
|
#### Fine-tuning Datasets
|
||||||
|
|
||||||
|
| Dataset Name | Domain ID |
|
||||||
|
| ---------------- | --------- |
|
||||||
|
| Bridge | 0 |
|
||||||
|
| RT1 | 1 |
|
||||||
|
| Calvin | 2 |
|
||||||
|
| libero | 3 |
|
||||||
|
| widowx-air | 4 |
|
||||||
|
| AIR-AGILEX-HQ | 5 |
|
||||||
|
| robotwin2_abs_ee | 6 |
|
||||||
|
| robotwin2_clean | 6 |
|
||||||
|
| robocasa-human | 7 |
|
||||||
|
| VLABench | 8 |
|
||||||
|
| AGIBOT-challenge | 9 |
|
||||||
|
| AIR-AGILEX | 10 |
|
||||||
|
| AIRBOT | 18 |
|
||||||
|
|
||||||
|
### 3. Processor Steps
|
||||||
|
|
||||||
|
X-VLA requires specific preprocessing and postprocessing steps for proper operation.
|
||||||
|
|
||||||
|
#### Required Preprocessing Steps
|
||||||
|
|
||||||
|
1. **XVLAImageToFloatProcessorStep**: Converts images from [0, 255] to [0, 1] range
|
||||||
|
2. **XVLAImageNetNormalizeProcessorStep**: Applies ImageNet normalization (required for VLM backbone)
|
||||||
|
3. **XVLAAddDomainIdProcessorStep**: Adds domain_id to observations
|
||||||
|
|
||||||
|
#### Example Custom Processor
|
||||||
|
|
||||||
|
For LIBERO environments, a custom processor handles the specific observation format:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.policies.xvla.processor_xvla import LiberoProcessorStep
|
||||||
|
|
||||||
|
processor = LiberoProcessorStep()
|
||||||
|
# Handles robot_state dictionary, converts rotation matrices to 6D representation
|
||||||
|
# Applies 180° image rotation for camera convention
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Configuration Parameters
|
||||||
|
|
||||||
|
Key configuration parameters for X-VLA:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Observation and action
|
||||||
|
n_obs_steps: int = 1 # Number of observation timesteps
|
||||||
|
chunk_size: int = 32 # Action sequence length
|
||||||
|
n_action_steps: int = 32 # Number of action steps to execute
|
||||||
|
|
||||||
|
# Model architecture
|
||||||
|
hidden_size: int = 1024 # Transformer hidden dimension
|
||||||
|
depth: int = 24 # Number of transformer layers
|
||||||
|
num_heads: int = 16 # Number of attention heads
|
||||||
|
num_domains: int = 30 # Maximum number of domain IDs
|
||||||
|
len_soft_prompts: int = 32 # Length of soft prompt embeddings
|
||||||
|
|
||||||
|
# Action space
|
||||||
|
action_mode: str = "ee6d" # Action space type (use "auto" for auto-detection)
|
||||||
|
use_proprio: bool = True # Use proprioceptive state
|
||||||
|
max_state_dim: int = 32 # Maximum state dimension
|
||||||
|
max_action_dim: int = 20 # Max action dim for padding (used by "auto" mode)
|
||||||
|
|
||||||
|
# Vision
|
||||||
|
num_image_views: int | None # Number of camera views
|
||||||
|
resize_imgs_with_padding: tuple[int, int] | None # Target image size with padding
|
||||||
|
|
||||||
|
# Training
|
||||||
|
num_denoising_steps: int = 10 # Flow matching denoising steps
|
||||||
|
```
|
||||||
|
|
||||||
|
## Creating Custom Action Modes
|
||||||
|
|
||||||
|
If your robot has a unique action space, you can create a custom action mode:
|
||||||
|
|
||||||
|
### Step 1: Define Your Action Space
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.policies.xvla.action_hub import BaseActionSpace, register_action
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
@register_action("my_custom_robot")
|
||||||
|
class MyCustomActionSpace(BaseActionSpace):
|
||||||
|
"""Custom action space for my robot."""
|
||||||
|
|
||||||
|
dim_action = 15 # Your robot's action dimension
|
||||||
|
gripper_idx = (7, 14) # Gripper channel indices
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.mse = nn.MSELoss()
|
||||||
|
self.bce = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
|
def compute_loss(self, pred, target):
|
||||||
|
"""Define your loss computation."""
|
||||||
|
# Example: MSE for joints, BCE for grippers
|
||||||
|
joints_loss = self.mse(pred[:, :, :7], target[:, :, :7])
|
||||||
|
gripper_loss = self.bce(pred[:, :, self.gripper_idx],
|
||||||
|
target[:, :, self.gripper_idx])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"joints_loss": joints_loss,
|
||||||
|
"gripper_loss": gripper_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
def preprocess(self, proprio, action, mode="train"):
|
||||||
|
"""Preprocess actions before training."""
|
||||||
|
# Example: Zero out grippers in proprioception
|
||||||
|
proprio_m = proprio.clone()
|
||||||
|
action_m = action.clone() if action is not None else None
|
||||||
|
proprio_m[..., self.gripper_idx] = 0.0
|
||||||
|
if action_m is not None:
|
||||||
|
action_m[..., self.gripper_idx] = 0.0
|
||||||
|
return proprio_m, action_m
|
||||||
|
|
||||||
|
def postprocess(self, action):
|
||||||
|
"""Post-process predictions for deployment."""
|
||||||
|
# Example: Apply sigmoid to gripper logits
|
||||||
|
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||||
|
return action
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Use Your Custom Action Mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.action_mode=my_custom_robot \
|
||||||
|
--dataset.repo_id=YOUR_DATASET \
|
||||||
|
--policy.path="lerobot/xvla-base" \
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced Topics
|
||||||
|
|
||||||
|
### Multi-Camera Support
|
||||||
|
|
||||||
|
X-VLA supports multiple camera views through the `num_image_views` parameter:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Configure for 3 camera views
|
||||||
|
policy.num_image_views=3
|
||||||
|
|
||||||
|
# Add empty cameras if you have fewer physical cameras
|
||||||
|
policy.empty_cameras=1 # Adds 1 zero-padded camera view
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Preprocessing Pipeline
|
||||||
|
|
||||||
|
Create a custom preprocessing pipeline for your environment:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.processor import PolicyProcessorPipeline
|
||||||
|
from lerobot.policies.xvla.processor_xvla import (
|
||||||
|
XVLAImageToFloatProcessorStep,
|
||||||
|
XVLAImageNetNormalizeProcessorStep,
|
||||||
|
XVLAAddDomainIdProcessorStep,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build custom pipeline
|
||||||
|
preprocessor = PolicyProcessorPipeline(
|
||||||
|
steps=[
|
||||||
|
YourCustomProcessorStep(), # Your custom processing
|
||||||
|
XVLAImageToFloatProcessorStep(), # Required: convert to float
|
||||||
|
XVLAImageNetNormalizeProcessorStep(), # Required: ImageNet norm
|
||||||
|
XVLAAddDomainIdProcessorStep(domain_id=5), # Your domain ID
|
||||||
|
]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Handling Different Action Dimensions
|
||||||
|
|
||||||
|
When your dataset has fewer action dimensions than the pretrained model:
|
||||||
|
|
||||||
|
**Option 1 (Recommended)**: Use `auto` action mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Automatically detects your dataset's action dimension
|
||||||
|
# Works with any robot without custom code
|
||||||
|
policy.action_mode=auto
|
||||||
|
policy.max_action_dim=20 # Match pretrained model
|
||||||
|
```
|
||||||
|
|
||||||
|
**Option 2**: Use a predefined action mode with built-in padding
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Model expects 20D, dataset has 12D
|
||||||
|
# Action mode handles padding internally
|
||||||
|
action_mode = "so101_bimanual" # Pads 12 → 20
|
||||||
|
```
|
||||||
|
|
||||||
|
**Option 2**: Create a custom action mode that maps dimensions explicitly
|
||||||
|
|
||||||
|
```python
|
||||||
|
@register_action("my_mapped_action")
|
||||||
|
class MappedActionSpace(BaseActionSpace):
|
||||||
|
dim_action = 20
|
||||||
|
REAL_DIM = 12
|
||||||
|
|
||||||
|
def _pad_to_model_dim(self, x):
|
||||||
|
# Custom padding logic
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
**Issue**: "Action dimension mismatch"
|
||||||
|
|
||||||
|
- **Solution**: Check that your `action_mode` matches your robot's action space. Create a custom action mode if needed.
|
||||||
|
|
||||||
|
**Issue**: "Image values outside [0, 1] range"
|
||||||
|
|
||||||
|
- **Solution**: Ensure images are preprocessed with `XVLAImageToFloatProcessorStep` before normalization.
|
||||||
|
|
||||||
|
**Issue**: "Domain ID not found"
|
||||||
|
|
||||||
|
- **Solution**: Make sure `XVLAAddDomainIdProcessorStep` is in your preprocessing pipeline with the correct domain_id.
|
||||||
|
|
||||||
|
**Issue**: "Low success rate on new embodiment"
|
||||||
|
|
||||||
|
- **Solution**:
|
||||||
|
1. Verify your action_mode is correct
|
||||||
|
2. Check that soft prompts are being trained (`train_soft_prompts=True`)
|
||||||
|
3. Ensure proper preprocessing (ImageNet normalization, domain_id)
|
||||||
|
4. Consider increasing training steps
|
||||||
|
|
||||||
|
**Issue**: "Out of memory during training"
|
||||||
|
|
||||||
|
- **Solution**:
|
||||||
|
1. Reduce `chunk_size` (e.g., from 32 to 16)
|
||||||
|
2. Enable gradient checkpointing
|
||||||
|
3. Reduce batch size
|
||||||
|
4. Freeze more components
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
If you use X-VLA in your research, please cite:
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{zheng2025x,
|
||||||
|
title = {X-VLA: Soft-Prompted Transformer as Scalable Cross-Embodiment Vision-Language-Action Model},
|
||||||
|
author = {Zheng, Jinliang and Li, Jianxiong and Wang, Zhihao and Liu, Dongxiu and Kang, Xirui
|
||||||
|
and Feng, Yuchun and Zheng, Yinan and Zou, Jiayin and Chen, Yilun and Zeng, Jia and others},
|
||||||
|
journal = {arXiv preprint arXiv:2510.10274},
|
||||||
|
year = {2025}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Additional Resources
|
||||||
|
|
||||||
|
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
|
||||||
|
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||||
|
- [Action Registry Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/action_hub.py)
|
||||||
|
- [Processor Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/processor_xvla.py)
|
||||||
|
- [Model Configuration](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/configuration_xvla.py)
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
We welcome contributions! If you've implemented a new action mode or processor for your robot, please consider submitting a PR to help the community.
|
||||||
416
examples/openarms/debug_can_communication.py
Normal file
416
examples/openarms/debug_can_communication.py
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Comprehensive debug script for OpenArms CAN FD communication.
|
||||||
|
Tests all 4 CAN interfaces with CAN FD support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import can
|
||||||
|
import time
|
||||||
|
import sys
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
def check_can_interface(port):
|
||||||
|
"""Check if CAN interface is UP and configured."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(['ip', 'link', 'show', port],
|
||||||
|
capture_output=True, text=True)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return False, "Interface not found", None
|
||||||
|
|
||||||
|
output = result.stdout
|
||||||
|
if 'UP' not in output:
|
||||||
|
return False, "Interface is DOWN", None
|
||||||
|
|
||||||
|
# Check if CAN FD is enabled
|
||||||
|
is_fd = 'fd on' in output.lower() or 'canfd' in output.lower()
|
||||||
|
|
||||||
|
return True, "Interface is UP", is_fd
|
||||||
|
except FileNotFoundError:
|
||||||
|
return None, "Cannot check (ip command not found)", None
|
||||||
|
|
||||||
|
|
||||||
|
def test_motor_on_interface(bus, motor_id, timeout=2.0, use_fd=False):
|
||||||
|
"""
|
||||||
|
Test a single motor and return all responses.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of (arbitration_id, data) tuples for all responses received
|
||||||
|
"""
|
||||||
|
# Send enable command
|
||||||
|
enable_msg = can.Message(
|
||||||
|
arbitration_id=motor_id,
|
||||||
|
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||||
|
is_extended_id=False,
|
||||||
|
is_fd=use_fd
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
bus.send(enable_msg)
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"Send error: {e}"
|
||||||
|
|
||||||
|
# Listen for responses
|
||||||
|
responses = []
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
msg = bus.recv(timeout=0.1)
|
||||||
|
if msg:
|
||||||
|
responses.append((msg.arbitration_id, msg.data, msg.is_fd if hasattr(msg, 'is_fd') else False))
|
||||||
|
|
||||||
|
# Send disable command
|
||||||
|
disable_msg = can.Message(
|
||||||
|
arbitration_id=motor_id,
|
||||||
|
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD],
|
||||||
|
is_extended_id=False,
|
||||||
|
is_fd=use_fd
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
bus.send(disable_msg)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return responses, None
|
||||||
|
|
||||||
|
|
||||||
|
def test_interface(port, interface_type="socketcan", use_can_fd=True):
|
||||||
|
"""Test all 8 motors on a single CAN interface."""
|
||||||
|
|
||||||
|
results = {
|
||||||
|
'interface': port,
|
||||||
|
'status': None,
|
||||||
|
'is_fd': use_can_fd,
|
||||||
|
'motors': {}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check interface status
|
||||||
|
status_ok, status_msg, interface_has_fd = check_can_interface(port)
|
||||||
|
|
||||||
|
if interface_has_fd is not None:
|
||||||
|
results['interface_fd_enabled'] = interface_has_fd
|
||||||
|
if use_can_fd and not interface_has_fd:
|
||||||
|
status_msg += " (CAN FD NOT enabled on interface!)"
|
||||||
|
elif interface_has_fd:
|
||||||
|
status_msg += " (CAN FD enabled)"
|
||||||
|
|
||||||
|
results['status'] = status_msg
|
||||||
|
|
||||||
|
if status_ok is False:
|
||||||
|
return results
|
||||||
|
|
||||||
|
# Try to connect
|
||||||
|
try:
|
||||||
|
if use_can_fd:
|
||||||
|
print(f" Connecting to {port} with CAN FD (1 Mbps / 5 Mbps)...")
|
||||||
|
bus = can.interface.Bus(
|
||||||
|
channel=port,
|
||||||
|
interface=interface_type,
|
||||||
|
bitrate=1000000,
|
||||||
|
data_bitrate=5000000,
|
||||||
|
fd=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(f" Connecting to {port} with CAN 2.0 (1 Mbps)...")
|
||||||
|
bus = can.interface.Bus(
|
||||||
|
channel=port,
|
||||||
|
interface=interface_type,
|
||||||
|
bitrate=1000000
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
results['status'] = f"Connection failed: {e}"
|
||||||
|
return results
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Clear any pending messages
|
||||||
|
while bus.recv(timeout=0.01):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Test each motor (0x01 to 0x08)
|
||||||
|
for motor_id in range(0x01, 0x09):
|
||||||
|
responses, error = test_motor_on_interface(bus, motor_id, timeout=1.0, use_fd=use_can_fd)
|
||||||
|
|
||||||
|
if error:
|
||||||
|
results['motors'][motor_id] = {'error': error}
|
||||||
|
elif responses:
|
||||||
|
results['motors'][motor_id] = {
|
||||||
|
'found': True,
|
||||||
|
'responses': responses
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
results['motors'][motor_id] = {
|
||||||
|
'found': False,
|
||||||
|
'responses': []
|
||||||
|
}
|
||||||
|
|
||||||
|
time.sleep(0.05) # Small delay between motors
|
||||||
|
|
||||||
|
finally:
|
||||||
|
bus.shutdown()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def print_results(all_results):
|
||||||
|
"""Print formatted results for all interfaces."""
|
||||||
|
|
||||||
|
print("SUMMARY - Motors Found on Each Interface")
|
||||||
|
|
||||||
|
motor_names = {
|
||||||
|
0x01: "joint_1 (Shoulder pan)",
|
||||||
|
0x02: "joint_2 (Shoulder lift)",
|
||||||
|
0x03: "joint_3 (Shoulder rotation)",
|
||||||
|
0x04: "joint_4 (Elbow flex)",
|
||||||
|
0x05: "joint_5 (Wrist roll)",
|
||||||
|
0x06: "joint_6 (Wrist pitch)",
|
||||||
|
0x07: "joint_7 (Wrist rotation)",
|
||||||
|
0x08: "gripper",
|
||||||
|
}
|
||||||
|
|
||||||
|
total_found = 0
|
||||||
|
|
||||||
|
for result in all_results:
|
||||||
|
interface = result['interface']
|
||||||
|
status = result['status']
|
||||||
|
|
||||||
|
print(f"{interface}: {status}")
|
||||||
|
if result.get('is_fd'):
|
||||||
|
print(f" Mode: CAN FD")
|
||||||
|
else:
|
||||||
|
print(f" Mode: CAN 2.0")
|
||||||
|
|
||||||
|
if 'Connection failed' in status or 'DOWN' in status:
|
||||||
|
print(f" ⚠ Cannot test {interface}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
motors_found = 0
|
||||||
|
|
||||||
|
for motor_id in range(0x01, 0x09):
|
||||||
|
motor_data = result['motors'].get(motor_id, {})
|
||||||
|
motor_name = motor_names.get(motor_id, "Unknown")
|
||||||
|
|
||||||
|
if motor_data.get('error'):
|
||||||
|
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {motor_data['error']}")
|
||||||
|
elif motor_data.get('found'):
|
||||||
|
motors_found += 1
|
||||||
|
total_found += 1
|
||||||
|
responses = motor_data['responses']
|
||||||
|
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND")
|
||||||
|
|
||||||
|
for resp_id, data, is_fd in responses:
|
||||||
|
data_hex = data.hex()
|
||||||
|
fd_flag = " [FD]" if is_fd else " [2.0]"
|
||||||
|
print(f" → Response from 0x{resp_id:02X}{fd_flag}: {data_hex}")
|
||||||
|
else:
|
||||||
|
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response")
|
||||||
|
|
||||||
|
print(f"\n Summary: {motors_found}/8 motors found on {interface}")
|
||||||
|
|
||||||
|
# Overall summary
|
||||||
|
print("OVERALL SUMMARY")
|
||||||
|
print(f"Total motors found across all interfaces: {total_found}")
|
||||||
|
|
||||||
|
# Analyze configuration
|
||||||
|
print("DIAGNOSIS")
|
||||||
|
|
||||||
|
for result in all_results:
|
||||||
|
interface = result['interface']
|
||||||
|
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||||
|
|
||||||
|
if motors_found == 0:
|
||||||
|
print(f"\n⚠ {interface}: NO MOTORS FOUND")
|
||||||
|
print(" Possible issues:")
|
||||||
|
print(" 1. CAN FD mode mismatch (interface vs motor configuration)")
|
||||||
|
print(" 2. Missing 120Ω termination resistors at BOTH cable ends")
|
||||||
|
print(" 3. Motor timeout parameter set incorrectly (should NOT be 0)")
|
||||||
|
print(" 4. CANH/CANL wiring issue")
|
||||||
|
print(" 5. Cable too long (>40m for CAN FD at 5Mbps)")
|
||||||
|
|
||||||
|
# Check FD mismatch
|
||||||
|
if result.get('is_fd') and not result.get('interface_fd_enabled'):
|
||||||
|
print(" ⚠️ CRITICAL: Trying CAN FD but interface NOT configured for FD!")
|
||||||
|
print(f" Fix: sudo ip link set {interface} type can bitrate 1000000 dbitrate 5000000 fd on")
|
||||||
|
|
||||||
|
elif motors_found < 8:
|
||||||
|
print(f"\n⚠ {interface}: Only {motors_found}/8 motors responding")
|
||||||
|
print(" Check power and connections for missing motors")
|
||||||
|
else:
|
||||||
|
print(f"\n✓ {interface}: All 8 motors responding correctly!")
|
||||||
|
|
||||||
|
# Check for unexpected response IDs
|
||||||
|
print("RESPONSE ID ANALYSIS")
|
||||||
|
|
||||||
|
for result in all_results:
|
||||||
|
interface = result['interface']
|
||||||
|
unexpected = []
|
||||||
|
|
||||||
|
for motor_id, motor_data in result['motors'].items():
|
||||||
|
if motor_data.get('found'):
|
||||||
|
expected_id = motor_id + 0x10
|
||||||
|
actual_ids = [resp[0] for resp in motor_data['responses']]
|
||||||
|
|
||||||
|
if expected_id not in actual_ids:
|
||||||
|
unexpected.append((motor_id, actual_ids))
|
||||||
|
|
||||||
|
if unexpected:
|
||||||
|
print(f"\n⚠ {interface}: Unexpected response IDs detected")
|
||||||
|
for motor_id, actual_ids in unexpected:
|
||||||
|
expected_id = motor_id + 0x10
|
||||||
|
print(f" Motor 0x{motor_id:02X}: Expected 0x{expected_id:02X}, "
|
||||||
|
f"got {[f'0x{id:02X}' for id in actual_ids]}")
|
||||||
|
print(" → Motor Master IDs need reconfiguration")
|
||||||
|
else:
|
||||||
|
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||||
|
if motors_found > 0:
|
||||||
|
print(f"\n✓ {interface}: All responding motors use correct IDs")
|
||||||
|
|
||||||
|
|
||||||
|
def test_communication_speed(interface, motor_id, num_iterations=100):
|
||||||
|
"""
|
||||||
|
Test communication speed with a motor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (hz, avg_latency_ms) or (None, None) if test failed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Connect to interface
|
||||||
|
bus = can.interface.Bus(
|
||||||
|
channel=interface,
|
||||||
|
interface="socketcan",
|
||||||
|
bitrate=1000000,
|
||||||
|
data_bitrate=5000000,
|
||||||
|
fd=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send refresh commands and measure round-trip time
|
||||||
|
latencies = []
|
||||||
|
successful = 0
|
||||||
|
|
||||||
|
for _ in range(num_iterations):
|
||||||
|
start = time.perf_counter()
|
||||||
|
|
||||||
|
# Send enable command (lightweight operation)
|
||||||
|
enable_msg = can.Message(
|
||||||
|
arbitration_id=motor_id,
|
||||||
|
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||||
|
is_extended_id=False,
|
||||||
|
is_fd=True
|
||||||
|
)
|
||||||
|
bus.send(enable_msg)
|
||||||
|
|
||||||
|
# Wait for response
|
||||||
|
msg = bus.recv(timeout=0.1)
|
||||||
|
|
||||||
|
if msg:
|
||||||
|
latency = (time.perf_counter() - start) * 1000 # Convert to ms
|
||||||
|
latencies.append(latency)
|
||||||
|
successful += 1
|
||||||
|
|
||||||
|
bus.shutdown()
|
||||||
|
|
||||||
|
if successful > 0:
|
||||||
|
avg_latency = sum(latencies) / len(latencies)
|
||||||
|
hz = 1000.0 / avg_latency if avg_latency > 0 else 0
|
||||||
|
return hz, avg_latency
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Speed test error: {e}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function to test all CAN interfaces with CAN FD."""
|
||||||
|
|
||||||
|
print("\nThis will test all 4 CAN interfaces (can0-can3) with CAN FD")
|
||||||
|
print("Testing motors 0x01-0x08 on each interface")
|
||||||
|
print()
|
||||||
|
print("Make sure:")
|
||||||
|
print(" ✓ Motors are powered (24V)")
|
||||||
|
print(" ✓ CAN interfaces configured with FD mode:")
|
||||||
|
print(" ./examples/openarms/setup_can.sh")
|
||||||
|
print(" ✓ Motor 'timeout' parameter NOT set to 0 (use Damiao tools)")
|
||||||
|
print(" ✓ CAN wiring includes 120Ω termination at BOTH ends")
|
||||||
|
print()
|
||||||
|
|
||||||
|
input("Press ENTER to start testing...")
|
||||||
|
|
||||||
|
# Test all 4 interfaces with CAN FD
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
for i in range(4):
|
||||||
|
interface = f"can{i}"
|
||||||
|
print(f"Testing {interface}...")
|
||||||
|
|
||||||
|
result = test_interface(interface, use_can_fd=True)
|
||||||
|
all_results.append(result)
|
||||||
|
|
||||||
|
# Quick status
|
||||||
|
if 'Connection failed' in result['status'] or 'DOWN' in result['status']:
|
||||||
|
print(f" ⚠ {interface}: {result['status']}")
|
||||||
|
else:
|
||||||
|
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||||
|
print(f" {interface}: {motors_found}/8 motors found")
|
||||||
|
|
||||||
|
time.sleep(0.2)
|
||||||
|
|
||||||
|
# Print detailed results
|
||||||
|
print_results(all_results)
|
||||||
|
|
||||||
|
print("Testing Complete!")
|
||||||
|
|
||||||
|
all_found = sum(sum(1 for m in r['motors'].values() if m.get('found')) for r in all_results)
|
||||||
|
|
||||||
|
if all_found == 0:
|
||||||
|
print("\n⚠️ CRITICAL: No motors found on any interface!")
|
||||||
|
print("\nTop issues to check:")
|
||||||
|
print(" 1. Motor 'timeout' parameter (use Damiao tools to set > 0)")
|
||||||
|
print(" 2. CAN FD not enabled (run ./examples/openarms/setup_can.sh)")
|
||||||
|
print(" 3. Missing termination resistors")
|
||||||
|
print("\nTry:")
|
||||||
|
print(" a) Check motor parameters with Damiao Debugging Tools")
|
||||||
|
print(" b) Verify CAN FD is enabled: ip -d link show can0 | grep fd")
|
||||||
|
print(" c) Run setup script: ./examples/openarms/setup_can.sh")
|
||||||
|
else:
|
||||||
|
# Run speed test on interfaces with motors
|
||||||
|
print("COMMUNICATION SPEED TEST")
|
||||||
|
print("\nTesting maximum communication frequency...")
|
||||||
|
|
||||||
|
for result in all_results:
|
||||||
|
interface = result['interface']
|
||||||
|
|
||||||
|
# Find first responding motor
|
||||||
|
responding_motor = None
|
||||||
|
for motor_id, motor_data in result['motors'].items():
|
||||||
|
if motor_data.get('found'):
|
||||||
|
responding_motor = motor_id
|
||||||
|
break
|
||||||
|
|
||||||
|
if responding_motor:
|
||||||
|
print(f"\n{interface}: Testing with motor 0x{responding_motor:02X}...")
|
||||||
|
hz, latency = test_communication_speed(interface, responding_motor, num_iterations=100)
|
||||||
|
|
||||||
|
if hz:
|
||||||
|
print(f" ✓ Max frequency: {hz:.1f} Hz")
|
||||||
|
print(f" ✓ Avg latency: {latency:.2f} ms")
|
||||||
|
print(f" ✓ Commands per second: ~{int(hz)}")
|
||||||
|
else:
|
||||||
|
print(f" ✗ Speed test failed")
|
||||||
|
else:
|
||||||
|
print(f"\n{interface}: No motors found, skipping speed test")
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
main()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nTesting interrupted by user.")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nUnexpected error: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
360
examples/openarms/evaluate.py
Normal file
360
examples/openarms/evaluate.py
Normal file
@@ -0,0 +1,360 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
OpenArms Policy Evaluation
|
||||||
|
|
||||||
|
Evaluates a trained policy on the OpenArms robot by running inference and recording
|
||||||
|
the evaluation episodes to a dataset. Supports optional leader arm for manual resets.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
python examples/openarms/evaluate.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||||
|
from lerobot.datasets.utils import combine_feature_dicts
|
||||||
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
|
from lerobot.processor import make_default_processors
|
||||||
|
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||||
|
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||||
|
from lerobot.scripts.lerobot_record import record_loop
|
||||||
|
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||||
|
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||||
|
from lerobot.utils.control_utils import init_keyboard_listener
|
||||||
|
from lerobot.utils.utils import log_say
|
||||||
|
from lerobot.utils.visualization_utils import init_rerun
|
||||||
|
|
||||||
|
|
||||||
|
HF_MODEL_ID = "lerobot-data-collection/three-folds-pi0" # TODO: Replace with your trained model
|
||||||
|
HF_EVAL_DATASET_ID = "lerobot-data-collection/three-folds-pi0_eval7" # TODO: Replace with your eval dataset name
|
||||||
|
TASK_DESCRIPTION = "three-folds-dataset" # TODO: Replace with your task, this should match!!
|
||||||
|
|
||||||
|
NUM_EPISODES = 1
|
||||||
|
FPS = 30
|
||||||
|
EPISODE_TIME_SEC = 300
|
||||||
|
RESET_TIME_SEC = 60
|
||||||
|
|
||||||
|
# Robot CAN interfaces
|
||||||
|
FOLLOWER_LEFT_PORT = "can0"
|
||||||
|
FOLLOWER_RIGHT_PORT = "can1"
|
||||||
|
|
||||||
|
# If enabled, you can manually reset the environment between evaluation episodes
|
||||||
|
USE_LEADER_FOR_RESETS = True # Set to False if you don't want to use leader
|
||||||
|
LEADER_LEFT_PORT = "can2"
|
||||||
|
LEADER_RIGHT_PORT = "can3"
|
||||||
|
|
||||||
|
# Camera configuration
|
||||||
|
CAMERA_CONFIG = {
|
||||||
|
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=FPS),
|
||||||
|
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
|
||||||
|
"base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=FPS),
|
||||||
|
}
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main evaluation function."""
|
||||||
|
print("OpenArms Policy Evaluation")
|
||||||
|
print(f"\nModel: {HF_MODEL_ID}")
|
||||||
|
print(f"Evaluation Dataset: {HF_EVAL_DATASET_ID}")
|
||||||
|
print(f"Task: {TASK_DESCRIPTION}")
|
||||||
|
print(f"Episodes: {NUM_EPISODES}")
|
||||||
|
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
|
||||||
|
print(f"Reset Duration: {RESET_TIME_SEC}s")
|
||||||
|
print(f"Use Leader for Resets: {USE_LEADER_FOR_RESETS}")
|
||||||
|
|
||||||
|
follower_config = OpenArmsFollowerConfig(
|
||||||
|
port_left=FOLLOWER_LEFT_PORT,
|
||||||
|
port_right=FOLLOWER_RIGHT_PORT,
|
||||||
|
can_interface="socketcan",
|
||||||
|
id="openarms_follower",
|
||||||
|
disable_torque_on_disconnect=True,
|
||||||
|
max_relative_target=10.0,
|
||||||
|
cameras=CAMERA_CONFIG,
|
||||||
|
)
|
||||||
|
|
||||||
|
follower = OpenArmsFollower(follower_config)
|
||||||
|
follower.connect(calibrate=False)
|
||||||
|
|
||||||
|
if not follower.is_connected:
|
||||||
|
raise RuntimeError("Follower robot failed to connect!")
|
||||||
|
|
||||||
|
|
||||||
|
leader = None
|
||||||
|
if USE_LEADER_FOR_RESETS:
|
||||||
|
leader_config = OpenArmsLeaderConfig(
|
||||||
|
port_left=LEADER_LEFT_PORT,
|
||||||
|
port_right=LEADER_RIGHT_PORT,
|
||||||
|
can_interface="socketcan",
|
||||||
|
id="openarms_leader",
|
||||||
|
manual_control=False, # Enable torque control for gravity compensation
|
||||||
|
)
|
||||||
|
|
||||||
|
leader = OpenArmsLeader(leader_config)
|
||||||
|
leader.connect(calibrate=False)
|
||||||
|
|
||||||
|
if not leader.is_connected:
|
||||||
|
raise RuntimeError("Leader robot failed to connect!")
|
||||||
|
|
||||||
|
# Enable gravity compensation
|
||||||
|
if leader.pin_robot is not None:
|
||||||
|
leader.bus_right.enable_torque()
|
||||||
|
leader.bus_left.enable_torque()
|
||||||
|
time.sleep(0.1)
|
||||||
|
print(f"Leader connected with gravity compensation ({LEADER_LEFT_PORT}, {LEADER_RIGHT_PORT})")
|
||||||
|
else:
|
||||||
|
print(f"Leader connected but gravity compensation unavailable (no URDF)")
|
||||||
|
|
||||||
|
# Build default processors for action and observation
|
||||||
|
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||||
|
|
||||||
|
# Build dataset features from robot features and processors
|
||||||
|
# For actions, only include positions (no velocity or torque)
|
||||||
|
action_features_hw = {}
|
||||||
|
for key, value in follower.action_features.items():
|
||||||
|
if key.endswith(".pos"):
|
||||||
|
action_features_hw[key] = value
|
||||||
|
|
||||||
|
dataset_features = combine_feature_dicts(
|
||||||
|
aggregate_pipeline_dataset_features(
|
||||||
|
pipeline=teleop_action_processor,
|
||||||
|
initial_features=create_initial_features(action=action_features_hw),
|
||||||
|
use_videos=True,
|
||||||
|
),
|
||||||
|
aggregate_pipeline_dataset_features(
|
||||||
|
pipeline=robot_observation_processor,
|
||||||
|
initial_features=create_initial_features(observation=follower.observation_features),
|
||||||
|
use_videos=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if dataset already exists
|
||||||
|
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / HF_EVAL_DATASET_ID
|
||||||
|
if dataset_path.exists():
|
||||||
|
print(f"Evaluation dataset already exists at: {dataset_path}")
|
||||||
|
print("This will append new episodes to the existing dataset.")
|
||||||
|
choice = input(" Continue? (y/n): ").strip().lower()
|
||||||
|
if choice != 'y':
|
||||||
|
print(" Aborting evaluation.")
|
||||||
|
follower.disconnect()
|
||||||
|
if leader:
|
||||||
|
leader.disconnect()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create dataset
|
||||||
|
dataset = LeRobotDataset.create(
|
||||||
|
repo_id=HF_EVAL_DATASET_ID,
|
||||||
|
fps=FPS,
|
||||||
|
features=dataset_features,
|
||||||
|
robot_type=follower.name,
|
||||||
|
use_videos=True,
|
||||||
|
image_writer_processes=0,
|
||||||
|
image_writer_threads=12,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load policy config from pretrained model and create policy using factory
|
||||||
|
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||||
|
policy_config.pretrained_path = HF_MODEL_ID
|
||||||
|
policy = make_policy(policy_config, ds_meta=dataset.meta)
|
||||||
|
|
||||||
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
|
policy_cfg=policy.config,
|
||||||
|
pretrained_path=HF_MODEL_ID,
|
||||||
|
dataset_stats=dataset.meta.stats,
|
||||||
|
preprocessor_overrides={
|
||||||
|
"device_processor": {"device": str(policy.config.device)}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nRunning evaluation...")
|
||||||
|
# Initialize keyboard listener and visualization
|
||||||
|
listener, events = init_keyboard_listener()
|
||||||
|
init_rerun(session_name="openarms_evaluation")
|
||||||
|
episode_idx = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||||
|
log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||||
|
print(f"\nRunning inference for episode {episode_idx + 1}...")
|
||||||
|
|
||||||
|
# Run inference with policy
|
||||||
|
record_loop(
|
||||||
|
robot=follower,
|
||||||
|
events=events,
|
||||||
|
fps=FPS,
|
||||||
|
teleop_action_processor=teleop_action_processor,
|
||||||
|
robot_action_processor=robot_action_processor,
|
||||||
|
robot_observation_processor=robot_observation_processor,
|
||||||
|
policy=policy,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
dataset=dataset,
|
||||||
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
|
single_task=TASK_DESCRIPTION,
|
||||||
|
display_data=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle re-recording
|
||||||
|
if events["rerecord_episode"]:
|
||||||
|
log_say("Re-recording episode")
|
||||||
|
events["rerecord_episode"] = False
|
||||||
|
events["exit_early"] = False
|
||||||
|
dataset.clear_episode_buffer()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Save episode
|
||||||
|
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
|
||||||
|
print(f"Saving episode {episode_idx + 1} ({dataset.episode_buffer['size']} frames)...")
|
||||||
|
dataset.save_episode()
|
||||||
|
episode_idx += 1
|
||||||
|
|
||||||
|
# Reset environment between episodes (if not last episode)
|
||||||
|
if not events["stop_recording"] and episode_idx < NUM_EPISODES:
|
||||||
|
if USE_LEADER_FOR_RESETS and leader:
|
||||||
|
log_say("Reset the environment using leader arms")
|
||||||
|
print(f"\nManual reset period ({RESET_TIME_SEC}s)...")
|
||||||
|
|
||||||
|
# Use leader for manual reset with gravity compensation
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
dt = 1 / FPS
|
||||||
|
reset_start_time = time.perf_counter()
|
||||||
|
|
||||||
|
while time.perf_counter() - reset_start_time < RESET_TIME_SEC:
|
||||||
|
if events["exit_early"] or events["stop_recording"]:
|
||||||
|
break
|
||||||
|
|
||||||
|
loop_start = time.perf_counter()
|
||||||
|
|
||||||
|
# Get leader state
|
||||||
|
leader_action = leader.get_action()
|
||||||
|
|
||||||
|
# Extract positions and velocities
|
||||||
|
leader_positions_deg = {}
|
||||||
|
leader_velocities_deg_per_sec = {}
|
||||||
|
|
||||||
|
for motor in leader.bus_right.motors:
|
||||||
|
pos_key = f"right_{motor}.pos"
|
||||||
|
vel_key = f"right_{motor}.vel"
|
||||||
|
if pos_key in leader_action:
|
||||||
|
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||||
|
if vel_key in leader_action:
|
||||||
|
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||||
|
|
||||||
|
for motor in leader.bus_left.motors:
|
||||||
|
pos_key = f"left_{motor}.pos"
|
||||||
|
vel_key = f"left_{motor}.vel"
|
||||||
|
if pos_key in leader_action:
|
||||||
|
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||||
|
if vel_key in leader_action:
|
||||||
|
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||||
|
|
||||||
|
# Calculate gravity and friction torques
|
||||||
|
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||||
|
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||||
|
|
||||||
|
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||||
|
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||||
|
leader_velocities_rad_per_sec,
|
||||||
|
friction_scale=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine torques
|
||||||
|
leader_total_torques_nm = {}
|
||||||
|
for motor_name in leader_gravity_torques_nm:
|
||||||
|
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||||
|
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||||
|
leader_total_torques_nm[motor_name] = gravity + friction
|
||||||
|
|
||||||
|
# Apply compensation
|
||||||
|
for motor in leader.bus_right.motors:
|
||||||
|
full_name = f"right_{motor}"
|
||||||
|
position = leader_positions_deg.get(full_name, 0.0)
|
||||||
|
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||||
|
kd = leader.get_damping_kd(motor)
|
||||||
|
|
||||||
|
leader.bus_right._mit_control(
|
||||||
|
motor=motor, kp=0.0, kd=kd,
|
||||||
|
position_degrees=position,
|
||||||
|
velocity_deg_per_sec=0.0,
|
||||||
|
torque=torque,
|
||||||
|
)
|
||||||
|
|
||||||
|
for motor in leader.bus_left.motors:
|
||||||
|
full_name = f"left_{motor}"
|
||||||
|
position = leader_positions_deg.get(full_name, 0.0)
|
||||||
|
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||||
|
kd = leader.get_damping_kd(motor)
|
||||||
|
|
||||||
|
leader.bus_left._mit_control(
|
||||||
|
motor=motor, kp=0.0, kd=kd,
|
||||||
|
position_degrees=position,
|
||||||
|
velocity_deg_per_sec=0.0,
|
||||||
|
torque=torque,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send leader positions to follower
|
||||||
|
follower_action = {}
|
||||||
|
for joint in leader_positions_deg.keys():
|
||||||
|
pos_key = f"{joint}.pos"
|
||||||
|
if pos_key in leader_action:
|
||||||
|
follower_action[pos_key] = leader_action[pos_key]
|
||||||
|
|
||||||
|
if follower_action:
|
||||||
|
follower.send_action(follower_action)
|
||||||
|
|
||||||
|
# Maintain loop rate
|
||||||
|
loop_duration = time.perf_counter() - loop_start
|
||||||
|
sleep_time = dt - loop_duration
|
||||||
|
if sleep_time > 0:
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
|
print("Reset complete")
|
||||||
|
else:
|
||||||
|
log_say("Waiting for manual reset")
|
||||||
|
print(f"Manually reset the environment and press ENTER to continue")
|
||||||
|
input("Press ENTER when ready...")
|
||||||
|
|
||||||
|
print(f"Evaluation complete! {episode_idx} episodes recorded")
|
||||||
|
log_say("Evaluation complete", blocking=True)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nEvaluation interrupted by user")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if leader:
|
||||||
|
leader.bus_right.disable_torque()
|
||||||
|
leader.bus_left.disable_torque()
|
||||||
|
time.sleep(0.1)
|
||||||
|
leader.disconnect()
|
||||||
|
|
||||||
|
follower.disconnect()
|
||||||
|
|
||||||
|
if listener is not None:
|
||||||
|
listener.stop()
|
||||||
|
|
||||||
|
dataset.finalize()
|
||||||
|
print("\nUploading to Hugging Face Hub...")
|
||||||
|
dataset.push_to_hub(private=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
653
examples/openarms/evaluate_with_rtc.py
Normal file
653
examples/openarms/evaluate_with_rtc.py
Normal file
@@ -0,0 +1,653 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
OpenArms Policy Evaluation with Real-Time Chunking (RTC)
|
||||||
|
|
||||||
|
Evaluates a trained policy on the OpenArms robot using RTC for smooth, continuous motion.
|
||||||
|
RTC enables large flow-matching policies (Pi0, Pi0.5, SmolVLA) to produce reactive motion
|
||||||
|
despite high inference latency by asynchronously generating action chunks.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Thread-based asynchronous action generation and execution
|
||||||
|
- RTC for smooth transitions between action chunks
|
||||||
|
- Dataset recording for evaluation episodes
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
python examples/openarms/evaluate_with_rtc.py
|
||||||
|
|
||||||
|
# With custom RTC parameters
|
||||||
|
python examples/openarms/evaluate_with_rtc.py \
|
||||||
|
--rtc.execution_horizon=12 \
|
||||||
|
--rtc.max_guidance_weight=10.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Event, Lock, Thread
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||||
|
from lerobot.configs import parser
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.types import RTCAttentionSchedule
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||||
|
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
||||||
|
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||||
|
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||||
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||||
|
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||||
|
from lerobot.processor import make_default_processors
|
||||||
|
from lerobot.rl.process import ProcessSignalHandler
|
||||||
|
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||||
|
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||||
|
from lerobot.utils.hub import HubMixin
|
||||||
|
from lerobot.utils.utils import init_logging, log_say
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Default Configuration Constants
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
DEFAULT_HF_MODEL_ID = "lerobot-data-collection/three-folds-pi0"
|
||||||
|
DEFAULT_HF_EVAL_DATASET_ID = "lerobot-data-collection/three-folds-pi0_eval_rtc"
|
||||||
|
DEFAULT_TASK_DESCRIPTION = "three-folds-dataset"
|
||||||
|
|
||||||
|
DEFAULT_NUM_EPISODES = 1
|
||||||
|
DEFAULT_FPS = 30
|
||||||
|
DEFAULT_EPISODE_TIME_SEC = 300
|
||||||
|
DEFAULT_RESET_TIME_SEC = 60
|
||||||
|
|
||||||
|
DEFAULT_FOLLOWER_LEFT_PORT = "can0"
|
||||||
|
DEFAULT_FOLLOWER_RIGHT_PORT = "can1"
|
||||||
|
|
||||||
|
DEFAULT_CAMERA_CONFIG = {
|
||||||
|
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=DEFAULT_FPS),
|
||||||
|
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=DEFAULT_FPS),
|
||||||
|
"base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=DEFAULT_FPS),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Thread-Safe Robot Wrapper
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class RobotWrapper:
|
||||||
|
"""Thread-safe wrapper for robot operations."""
|
||||||
|
|
||||||
|
def __init__(self, robot: OpenArmsFollower):
|
||||||
|
self.robot = robot
|
||||||
|
self.lock = Lock()
|
||||||
|
|
||||||
|
def get_observation(self) -> dict[str, Tensor]:
|
||||||
|
with self.lock:
|
||||||
|
return self.robot.get_observation()
|
||||||
|
|
||||||
|
def send_action(self, action: dict) -> None:
|
||||||
|
with self.lock:
|
||||||
|
self.robot.send_action(action)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_features(self) -> dict:
|
||||||
|
with self.lock:
|
||||||
|
return self.robot.observation_features
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_features(self) -> dict:
|
||||||
|
with self.lock:
|
||||||
|
return self.robot.action_features
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self.robot.name
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Configuration
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OpenArmsRTCEvalConfig(HubMixin):
|
||||||
|
"""Configuration for OpenArms evaluation with RTC."""
|
||||||
|
|
||||||
|
policy: PreTrainedConfig | None = None
|
||||||
|
|
||||||
|
rtc: RTCConfig = field(
|
||||||
|
default_factory=lambda: RTCConfig(
|
||||||
|
enabled=True,
|
||||||
|
execution_horizon=10,
|
||||||
|
max_guidance_weight=10.0,
|
||||||
|
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model_id: str = DEFAULT_HF_MODEL_ID
|
||||||
|
eval_dataset_id: str = DEFAULT_HF_EVAL_DATASET_ID
|
||||||
|
task: str = DEFAULT_TASK_DESCRIPTION
|
||||||
|
|
||||||
|
num_episodes: int = DEFAULT_NUM_EPISODES
|
||||||
|
fps: float = DEFAULT_FPS
|
||||||
|
episode_time_sec: float = DEFAULT_EPISODE_TIME_SEC
|
||||||
|
reset_time_sec: float = DEFAULT_RESET_TIME_SEC
|
||||||
|
|
||||||
|
follower_left_port: str = DEFAULT_FOLLOWER_LEFT_PORT
|
||||||
|
follower_right_port: str = DEFAULT_FOLLOWER_RIGHT_PORT
|
||||||
|
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
|
# Should be higher than inference_delay + execution_horizon
|
||||||
|
action_queue_size_to_get_new_actions: int = 30
|
||||||
|
|
||||||
|
record_dataset: bool = True
|
||||||
|
push_to_hub: bool = True
|
||||||
|
|
||||||
|
use_torch_compile: bool = False
|
||||||
|
torch_compile_backend: str = "inductor"
|
||||||
|
torch_compile_mode: str = "default"
|
||||||
|
torch_compile_disable_cudagraphs: bool = True
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
policy_path = parser.get_path_arg("policy")
|
||||||
|
if policy_path:
|
||||||
|
cli_overrides = parser.get_cli_overrides("policy")
|
||||||
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
|
self.policy.pretrained_path = policy_path
|
||||||
|
self.model_id = policy_path
|
||||||
|
elif self.model_id:
|
||||||
|
self.policy = PreTrainedConfig.from_pretrained(self.model_id)
|
||||||
|
self.policy.pretrained_path = self.model_id
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_path_fields__(cls) -> list[str]:
|
||||||
|
return ["policy"]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Action Generation Thread
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def get_actions_thread(
|
||||||
|
policy,
|
||||||
|
robot: RobotWrapper,
|
||||||
|
robot_observation_processor,
|
||||||
|
action_queue: ActionQueue,
|
||||||
|
shutdown_event: Event,
|
||||||
|
cfg: OpenArmsRTCEvalConfig,
|
||||||
|
episode_active: Event,
|
||||||
|
):
|
||||||
|
"""Thread function to asynchronously generate action chunks from the policy."""
|
||||||
|
try:
|
||||||
|
logger.info("[GET_ACTIONS] Starting action generation thread")
|
||||||
|
|
||||||
|
latency_tracker = LatencyTracker()
|
||||||
|
time_per_chunk = 1.0 / cfg.fps
|
||||||
|
|
||||||
|
hw_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||||
|
policy_device = policy.config.device
|
||||||
|
|
||||||
|
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
|
||||||
|
|
||||||
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
|
policy_cfg=cfg.policy,
|
||||||
|
pretrained_path=cfg.policy.pretrained_path,
|
||||||
|
dataset_stats=None,
|
||||||
|
preprocessor_overrides={
|
||||||
|
"device_processor": {"device": cfg.device},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully")
|
||||||
|
|
||||||
|
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||||
|
if not cfg.rtc.enabled:
|
||||||
|
get_actions_threshold = 0
|
||||||
|
|
||||||
|
while not shutdown_event.is_set():
|
||||||
|
if not episode_active.is_set():
|
||||||
|
time.sleep(0.01)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if action_queue.qsize() <= get_actions_threshold:
|
||||||
|
current_time = time.perf_counter()
|
||||||
|
action_index_before_inference = action_queue.get_action_index()
|
||||||
|
prev_actions = action_queue.get_left_over()
|
||||||
|
|
||||||
|
inference_latency = latency_tracker.max()
|
||||||
|
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
|
||||||
|
|
||||||
|
obs = robot.get_observation()
|
||||||
|
obs_processed = robot_observation_processor(obs)
|
||||||
|
|
||||||
|
obs_with_policy_features = build_dataset_frame(
|
||||||
|
hw_features, obs_processed, prefix="observation"
|
||||||
|
)
|
||||||
|
|
||||||
|
for name in obs_with_policy_features:
|
||||||
|
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||||
|
if "image" in name:
|
||||||
|
obs_with_policy_features[name] = (
|
||||||
|
obs_with_policy_features[name].type(torch.float32) / 255
|
||||||
|
)
|
||||||
|
obs_with_policy_features[name] = (
|
||||||
|
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||||
|
)
|
||||||
|
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||||
|
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||||
|
|
||||||
|
obs_with_policy_features["task"] = [cfg.task]
|
||||||
|
obs_with_policy_features["robot_type"] = robot.name
|
||||||
|
|
||||||
|
preprocessed_obs = preprocessor(obs_with_policy_features)
|
||||||
|
|
||||||
|
actions = policy.predict_action_chunk(
|
||||||
|
preprocessed_obs,
|
||||||
|
inference_delay=inference_delay,
|
||||||
|
prev_chunk_left_over=prev_actions,
|
||||||
|
)
|
||||||
|
|
||||||
|
original_actions = actions.squeeze(0).clone()
|
||||||
|
postprocessed_actions = postprocessor(actions).squeeze(0)
|
||||||
|
|
||||||
|
new_latency = time.perf_counter() - current_time
|
||||||
|
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||||
|
latency_tracker.add(new_latency)
|
||||||
|
|
||||||
|
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
||||||
|
logger.warning(
|
||||||
|
"[GET_ACTIONS] action_queue_size_to_get_new_actions too small. "
|
||||||
|
"Should be higher than inference delay + execution horizon."
|
||||||
|
)
|
||||||
|
|
||||||
|
action_queue.merge(
|
||||||
|
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"[GET_ACTIONS] Generated chunk, latency={new_latency:.3f}s, "
|
||||||
|
f"delay={new_delay}, queue_size={action_queue.qsize()}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
logger.info("[GET_ACTIONS] Action generation thread shutting down")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[GET_ACTIONS] Fatal exception: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
shutdown_event.set()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Action Execution Thread
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def actor_thread(
|
||||||
|
robot: RobotWrapper,
|
||||||
|
robot_action_processor,
|
||||||
|
action_queue: ActionQueue,
|
||||||
|
shutdown_event: Event,
|
||||||
|
cfg: OpenArmsRTCEvalConfig,
|
||||||
|
episode_active: Event,
|
||||||
|
dataset: LeRobotDataset | None,
|
||||||
|
dataset_lock: Lock,
|
||||||
|
teleop_action_processor,
|
||||||
|
robot_observation_processor,
|
||||||
|
):
|
||||||
|
"""Thread function to execute actions on the robot."""
|
||||||
|
try:
|
||||||
|
logger.info("[ACTOR] Starting actor thread")
|
||||||
|
|
||||||
|
action_count = 0
|
||||||
|
action_interval = 1.0 / cfg.fps
|
||||||
|
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
||||||
|
|
||||||
|
while not shutdown_event.is_set():
|
||||||
|
if not episode_active.is_set():
|
||||||
|
time.sleep(0.01)
|
||||||
|
continue
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
action = action_queue.get()
|
||||||
|
|
||||||
|
if action is not None:
|
||||||
|
action = action.cpu()
|
||||||
|
|
||||||
|
action_dict = {}
|
||||||
|
for i, key in enumerate(action_keys):
|
||||||
|
if i < len(action):
|
||||||
|
action_dict[key] = action[i].item()
|
||||||
|
|
||||||
|
action_processed = robot_action_processor((action_dict, None))
|
||||||
|
robot.send_action(action_processed)
|
||||||
|
|
||||||
|
if cfg.record_dataset and dataset is not None:
|
||||||
|
with dataset_lock:
|
||||||
|
obs = robot.get_observation()
|
||||||
|
obs_processed = robot_observation_processor(obs)
|
||||||
|
action_for_dataset = teleop_action_processor((action_dict, None))
|
||||||
|
|
||||||
|
frame = {}
|
||||||
|
for key, value in obs_processed.items():
|
||||||
|
frame[f"observation.{key}"] = value
|
||||||
|
for key, value in action_for_dataset.items():
|
||||||
|
frame[f"action.{key}"] = value
|
||||||
|
frame["task"] = cfg.task
|
||||||
|
|
||||||
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
|
action_count += 1
|
||||||
|
|
||||||
|
dt_s = time.perf_counter() - start_time
|
||||||
|
sleep_time = max(0, action_interval - dt_s - 0.001)
|
||||||
|
if sleep_time > 0:
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
|
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ACTOR] Fatal exception: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
shutdown_event.set()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Main Evaluation Function
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_torch_compile(policy, cfg: OpenArmsRTCEvalConfig):
|
||||||
|
"""Apply torch.compile to the policy's predict_action_chunk method."""
|
||||||
|
if policy.name in ["pi05", "pi0"]:
|
||||||
|
return policy
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not hasattr(torch, "compile"):
|
||||||
|
logger.warning(
|
||||||
|
f"torch.compile not available. Requires PyTorch 2.0+. "
|
||||||
|
f"Current version: {torch.__version__}. Skipping compilation."
|
||||||
|
)
|
||||||
|
return policy
|
||||||
|
|
||||||
|
logger.info("Applying torch.compile to predict_action_chunk...")
|
||||||
|
|
||||||
|
compile_kwargs = {
|
||||||
|
"backend": cfg.torch_compile_backend,
|
||||||
|
"mode": cfg.torch_compile_mode,
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.torch_compile_disable_cudagraphs:
|
||||||
|
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||||
|
|
||||||
|
original_method = policy.predict_action_chunk
|
||||||
|
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||||
|
policy.predict_action_chunk = compiled_method
|
||||||
|
logger.info("Successfully compiled predict_action_chunk")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to apply torch.compile: {e}")
|
||||||
|
logger.warning("Continuing without torch.compile")
|
||||||
|
|
||||||
|
return policy
|
||||||
|
|
||||||
|
|
||||||
|
@parser.wrap()
|
||||||
|
def main(cfg: OpenArmsRTCEvalConfig):
|
||||||
|
"""Main evaluation function with RTC."""
|
||||||
|
init_logging()
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("OpenArms Policy Evaluation with RTC")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"\nModel: {cfg.model_id}")
|
||||||
|
print(f"Evaluation Dataset: {cfg.eval_dataset_id}")
|
||||||
|
print(f"Task: {cfg.task}")
|
||||||
|
print(f"Episodes: {cfg.num_episodes}")
|
||||||
|
print(f"Episode Duration: {cfg.episode_time_sec}s")
|
||||||
|
print(f"RTC Enabled: {cfg.rtc.enabled}")
|
||||||
|
print(f"RTC Execution Horizon: {cfg.rtc.execution_horizon}")
|
||||||
|
print(f"RTC Max Guidance Weight: {cfg.rtc.max_guidance_weight}")
|
||||||
|
print(f"Device: {cfg.device}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
||||||
|
shutdown_event = signal_handler.shutdown_event
|
||||||
|
episode_active = Event()
|
||||||
|
|
||||||
|
# Initialize Robot
|
||||||
|
follower_config = OpenArmsFollowerConfig(
|
||||||
|
port_left=cfg.follower_left_port,
|
||||||
|
port_right=cfg.follower_right_port,
|
||||||
|
can_interface="socketcan",
|
||||||
|
id="openarms_follower",
|
||||||
|
disable_torque_on_disconnect=True,
|
||||||
|
max_relative_target=10.0,
|
||||||
|
cameras=DEFAULT_CAMERA_CONFIG,
|
||||||
|
)
|
||||||
|
|
||||||
|
follower = OpenArmsFollower(follower_config)
|
||||||
|
follower.connect(calibrate=False)
|
||||||
|
|
||||||
|
if not follower.is_connected:
|
||||||
|
raise RuntimeError("Follower robot failed to connect!")
|
||||||
|
|
||||||
|
robot = RobotWrapper(follower)
|
||||||
|
logger.info("Follower robot connected")
|
||||||
|
|
||||||
|
# Build Processors and Dataset Features
|
||||||
|
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||||
|
|
||||||
|
action_features_hw = {}
|
||||||
|
for key, value in follower.action_features.items():
|
||||||
|
if key.endswith(".pos"):
|
||||||
|
action_features_hw[key] = value
|
||||||
|
|
||||||
|
dataset_features = combine_feature_dicts(
|
||||||
|
aggregate_pipeline_dataset_features(
|
||||||
|
pipeline=teleop_action_processor,
|
||||||
|
initial_features=create_initial_features(action=action_features_hw),
|
||||||
|
use_videos=True,
|
||||||
|
),
|
||||||
|
aggregate_pipeline_dataset_features(
|
||||||
|
pipeline=robot_observation_processor,
|
||||||
|
initial_features=create_initial_features(observation=follower.observation_features),
|
||||||
|
use_videos=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create or Load Dataset
|
||||||
|
dataset = None
|
||||||
|
dataset_lock = Lock()
|
||||||
|
|
||||||
|
if cfg.record_dataset:
|
||||||
|
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / cfg.eval_dataset_id
|
||||||
|
if dataset_path.exists():
|
||||||
|
logger.info(f"Evaluation dataset exists at: {dataset_path}")
|
||||||
|
logger.info("New episodes will be appended.")
|
||||||
|
choice = input("Continue? (y/n): ").strip().lower()
|
||||||
|
if choice != "y":
|
||||||
|
logger.info("Aborting evaluation.")
|
||||||
|
follower.disconnect()
|
||||||
|
return
|
||||||
|
|
||||||
|
dataset = LeRobotDataset.create(
|
||||||
|
repo_id=cfg.eval_dataset_id,
|
||||||
|
fps=int(cfg.fps),
|
||||||
|
features=dataset_features,
|
||||||
|
robot_type=follower.name,
|
||||||
|
use_videos=True,
|
||||||
|
image_writer_processes=0,
|
||||||
|
image_writer_threads=12,
|
||||||
|
)
|
||||||
|
logger.info(f"Dataset created: {cfg.eval_dataset_id}")
|
||||||
|
|
||||||
|
# Load Policy
|
||||||
|
logger.info(f"Loading policy from: {cfg.model_id}")
|
||||||
|
|
||||||
|
policy_class = get_policy_class(cfg.policy.type)
|
||||||
|
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||||
|
|
||||||
|
if cfg.policy.type in ["pi05", "pi0"]:
|
||||||
|
config.compile_model = cfg.use_torch_compile
|
||||||
|
|
||||||
|
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||||
|
|
||||||
|
policy.config.rtc_config = cfg.rtc
|
||||||
|
policy.init_rtc_processor()
|
||||||
|
|
||||||
|
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
||||||
|
|
||||||
|
policy = policy.to(cfg.device)
|
||||||
|
policy.eval()
|
||||||
|
|
||||||
|
if cfg.use_torch_compile:
|
||||||
|
policy = _apply_torch_compile(policy, cfg)
|
||||||
|
|
||||||
|
logger.info(f"Policy loaded: {policy.name}")
|
||||||
|
|
||||||
|
# Create Action Queue and Start Threads
|
||||||
|
action_queue = ActionQueue(cfg.rtc)
|
||||||
|
|
||||||
|
get_actions_t = Thread(
|
||||||
|
target=get_actions_thread,
|
||||||
|
args=(
|
||||||
|
policy,
|
||||||
|
robot,
|
||||||
|
robot_observation_processor,
|
||||||
|
action_queue,
|
||||||
|
shutdown_event,
|
||||||
|
cfg,
|
||||||
|
episode_active,
|
||||||
|
),
|
||||||
|
daemon=True,
|
||||||
|
name="GetActions",
|
||||||
|
)
|
||||||
|
get_actions_t.start()
|
||||||
|
logger.info("Started action generation thread")
|
||||||
|
|
||||||
|
actor_t = Thread(
|
||||||
|
target=actor_thread,
|
||||||
|
args=(
|
||||||
|
robot,
|
||||||
|
robot_action_processor,
|
||||||
|
action_queue,
|
||||||
|
shutdown_event,
|
||||||
|
cfg,
|
||||||
|
episode_active,
|
||||||
|
dataset,
|
||||||
|
dataset_lock,
|
||||||
|
teleop_action_processor,
|
||||||
|
robot_observation_processor,
|
||||||
|
),
|
||||||
|
daemon=True,
|
||||||
|
name="Actor",
|
||||||
|
)
|
||||||
|
actor_t.start()
|
||||||
|
logger.info("Started actor thread")
|
||||||
|
|
||||||
|
# Run Evaluation Episodes
|
||||||
|
episode_idx = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
while episode_idx < cfg.num_episodes and not shutdown_event.is_set():
|
||||||
|
log_say(f"Evaluating episode {episode_idx + 1} of {cfg.num_episodes}")
|
||||||
|
logger.info(f"\n{'='*40}")
|
||||||
|
logger.info(f"Episode {episode_idx + 1} / {cfg.num_episodes}")
|
||||||
|
logger.info(f"{'='*40}")
|
||||||
|
|
||||||
|
action_queue = ActionQueue(cfg.rtc)
|
||||||
|
episode_active.set()
|
||||||
|
episode_start_time = time.time()
|
||||||
|
|
||||||
|
while (time.time() - episode_start_time) < cfg.episode_time_sec:
|
||||||
|
if shutdown_event.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
elapsed = time.time() - episode_start_time
|
||||||
|
if int(elapsed) % 10 == 0 and int(elapsed) > 0:
|
||||||
|
logger.info(
|
||||||
|
f"[MAIN] Episode progress: {elapsed:.0f}/{cfg.episode_time_sec}s, "
|
||||||
|
f"queue_size={action_queue.qsize()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
episode_active.clear()
|
||||||
|
logger.info(f"Episode {episode_idx + 1} completed")
|
||||||
|
|
||||||
|
if cfg.record_dataset and dataset is not None:
|
||||||
|
with dataset_lock:
|
||||||
|
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
|
||||||
|
logger.info(
|
||||||
|
f"Saving episode {episode_idx + 1} "
|
||||||
|
f"({dataset.episode_buffer['size']} frames)"
|
||||||
|
)
|
||||||
|
dataset.save_episode()
|
||||||
|
|
||||||
|
episode_idx += 1
|
||||||
|
|
||||||
|
# Manual reset between episodes
|
||||||
|
if not shutdown_event.is_set() and episode_idx < cfg.num_episodes:
|
||||||
|
log_say("Waiting for manual reset")
|
||||||
|
logger.info("Manually reset the environment and press ENTER to continue")
|
||||||
|
input("Press ENTER when ready...")
|
||||||
|
|
||||||
|
logger.info(f"Evaluation complete! {episode_idx} episodes recorded")
|
||||||
|
log_say("Evaluation complete", blocking=True)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("\n\nEvaluation interrupted by user")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
shutdown_event.set()
|
||||||
|
episode_active.clear()
|
||||||
|
|
||||||
|
if get_actions_t.is_alive():
|
||||||
|
logger.info("Waiting for action generation thread to finish...")
|
||||||
|
get_actions_t.join(timeout=5.0)
|
||||||
|
|
||||||
|
if actor_t.is_alive():
|
||||||
|
logger.info("Waiting for actor thread to finish...")
|
||||||
|
actor_t.join(timeout=5.0)
|
||||||
|
|
||||||
|
follower.disconnect()
|
||||||
|
logger.info("Follower disconnected")
|
||||||
|
|
||||||
|
if cfg.record_dataset and dataset is not None:
|
||||||
|
dataset.finalize()
|
||||||
|
if cfg.push_to_hub:
|
||||||
|
logger.info("Uploading to Hugging Face Hub...")
|
||||||
|
dataset.push_to_hub(private=True)
|
||||||
|
|
||||||
|
logger.info("Cleanup completed")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
216
examples/openarms/friction_compensation.py
Normal file
216
examples/openarms/friction_compensation.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||||
|
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||||
|
|
||||||
|
|
||||||
|
# Friction model parameters from OpenArms config/follower.yaml
|
||||||
|
# τ_fric(ω) = Fo + Fv·ω + Fc·tanh(k·ω)
|
||||||
|
# For 8 motors: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||||
|
FRICTION_PARAMS = {
|
||||||
|
"Fc": [0.306, 0.306, 0.40, 0.166, 0.050, 0.093, 0.172, 0.0512], # Coulomb friction [Nm]
|
||||||
|
"k": [28.417, 28.417, 29.065, 130.038, 151.771, 242.287, 7.888, 4.000], # tanh steepness
|
||||||
|
"Fv": [0.063, 0.0630, 0.604, 0.813, 0.029, 0.072, 0.084, 0.084], # Viscous friction [Nm·s/rad]
|
||||||
|
"Fo": [0.088, 0.088, 0.008, -0.058, 0.005, 0.009, -0.059, -0.050], # Offset torque [Nm]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Constants from OpenArms C++ implementation
|
||||||
|
AMP_TMP = 1.0
|
||||||
|
COEF_TMP = 0.1
|
||||||
|
|
||||||
|
FRICTION_SCALE = 1.0 # OpenArms C++ uses 0.3 factor in unilateral mode
|
||||||
|
DAMPING_KD = [0.5, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1, 0.1] # Damping gains for stability
|
||||||
|
|
||||||
|
def compute_friction_torque(velocity_rad_per_sec: float, motor_index: int) -> float:
|
||||||
|
"""
|
||||||
|
Compute friction torque for a single motor using the tanh friction model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
velocity_rad_per_sec: Angular velocity in rad/s
|
||||||
|
motor_index: Index of the motor (0-7)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Friction torque in N·m (scaled for stability)
|
||||||
|
"""
|
||||||
|
|
||||||
|
Fc = FRICTION_PARAMS["Fc"][motor_index]
|
||||||
|
k = FRICTION_PARAMS["k"][motor_index]
|
||||||
|
Fv = FRICTION_PARAMS["Fv"][motor_index]
|
||||||
|
Fo = FRICTION_PARAMS["Fo"][motor_index]
|
||||||
|
|
||||||
|
# Friction model: τ_fric = amp * Fc * tanh(coef * k * ω) + Fv * ω + Fo
|
||||||
|
friction_torque = (
|
||||||
|
AMP_TMP * Fc * np.tanh(COEF_TMP * k * velocity_rad_per_sec) +
|
||||||
|
Fv * velocity_rad_per_sec +
|
||||||
|
Fo
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scale down friction compensation for stability at lower control rates
|
||||||
|
# (OpenArms C++ uses 0.3 factor in unilateral mode)!!
|
||||||
|
friction_torque *= FRICTION_SCALE
|
||||||
|
|
||||||
|
return friction_torque
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
config = OpenArmsFollowerConfig(
|
||||||
|
port_left="can0",
|
||||||
|
port_right="can1",
|
||||||
|
can_interface="socketcan",
|
||||||
|
id="openarms_follower",
|
||||||
|
disable_torque_on_disconnect=True,
|
||||||
|
max_relative_target=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Initializing robot...")
|
||||||
|
follower = OpenArmsFollower(config)
|
||||||
|
follower.connect(calibrate=True)
|
||||||
|
|
||||||
|
print(f"Applying friction compensation")
|
||||||
|
print(" 1. Support the arm before starting")
|
||||||
|
print(" 2. The arm will be held in place by friction compensation")
|
||||||
|
print(" 3. You should be able to move it with gentle force")
|
||||||
|
print("\nPress ENTER when ready to start...")
|
||||||
|
input()
|
||||||
|
|
||||||
|
print(f"✓ Motors enabled")
|
||||||
|
print("\nStarting friction compensation loop...")
|
||||||
|
print("Press Ctrl+C to stop\n")
|
||||||
|
|
||||||
|
loop_times = []
|
||||||
|
last_print_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Motor name to index mapping
|
||||||
|
motor_name_to_index = {
|
||||||
|
"joint_1": 0,
|
||||||
|
"joint_2": 1,
|
||||||
|
"joint_3": 2,
|
||||||
|
"joint_4": 3,
|
||||||
|
"joint_5": 4,
|
||||||
|
"joint_6": 5,
|
||||||
|
"joint_7": 6,
|
||||||
|
"gripper": 7,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
loop_start = time.perf_counter()
|
||||||
|
|
||||||
|
# Get current joint positions and velocities from robot
|
||||||
|
obs = follower.get_observation()
|
||||||
|
|
||||||
|
# Extract velocities in degrees per second
|
||||||
|
velocities_deg_per_sec = {}
|
||||||
|
positions_deg = {}
|
||||||
|
|
||||||
|
for motor in follower.bus_right.motors:
|
||||||
|
vel_key = f"right_{motor}.vel"
|
||||||
|
pos_key = f"right_{motor}.pos"
|
||||||
|
if vel_key in obs:
|
||||||
|
velocities_deg_per_sec[f"right_{motor}"] = obs[vel_key]
|
||||||
|
if pos_key in obs:
|
||||||
|
positions_deg[f"right_{motor}"] = obs[pos_key]
|
||||||
|
|
||||||
|
for motor in follower.bus_left.motors:
|
||||||
|
vel_key = f"left_{motor}.vel"
|
||||||
|
pos_key = f"left_{motor}.pos"
|
||||||
|
if vel_key in obs:
|
||||||
|
velocities_deg_per_sec[f"left_{motor}"] = obs[vel_key]
|
||||||
|
if pos_key in obs:
|
||||||
|
positions_deg[f"left_{motor}"] = obs[pos_key]
|
||||||
|
|
||||||
|
# Convert velocities to rad/s and compute friction torques
|
||||||
|
friction_torques_nm = {}
|
||||||
|
for motor_full_name, velocity_deg_per_sec in velocities_deg_per_sec.items():
|
||||||
|
# Extract motor name without arm prefix
|
||||||
|
if motor_full_name.startswith("right_"):
|
||||||
|
motor_name = motor_full_name.removeprefix("right_")
|
||||||
|
elif motor_full_name.startswith("left_"):
|
||||||
|
motor_name = motor_full_name.removeprefix("left_")
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get motor index for friction parameters
|
||||||
|
motor_index = motor_name_to_index.get(motor_name, 0)
|
||||||
|
|
||||||
|
# Convert velocity to rad/s
|
||||||
|
velocity_rad_per_sec = np.deg2rad(velocity_deg_per_sec)
|
||||||
|
|
||||||
|
# Compute friction torque
|
||||||
|
friction_torque = compute_friction_torque(velocity_rad_per_sec, motor_index)
|
||||||
|
friction_torques_nm[motor_full_name] = friction_torque
|
||||||
|
|
||||||
|
# Apply friction compensation to right arm (all joints INCLUDING gripper)
|
||||||
|
for motor in follower.bus_right.motors:
|
||||||
|
full_name = f"right_{motor}"
|
||||||
|
position = positions_deg.get(full_name, 0.0)
|
||||||
|
torque = friction_torques_nm.get(full_name, 0.0)
|
||||||
|
|
||||||
|
# Get motor index for damping gain
|
||||||
|
motor_index = motor_name_to_index.get(motor, 0)
|
||||||
|
kd = DAMPING_KD[motor_index]
|
||||||
|
|
||||||
|
# Send MIT control command with friction compensation + damping
|
||||||
|
follower.bus_right._mit_control(
|
||||||
|
motor=motor,
|
||||||
|
kp=0.0, # No position control
|
||||||
|
kd=kd, # Add damping for stability
|
||||||
|
position_degrees=position,
|
||||||
|
velocity_deg_per_sec=0.0,
|
||||||
|
torque=torque
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply friction compensation to left arm (all joints INCLUDING gripper)
|
||||||
|
for motor in follower.bus_left.motors:
|
||||||
|
full_name = f"left_{motor}"
|
||||||
|
position = positions_deg.get(full_name, 0.0)
|
||||||
|
torque = friction_torques_nm.get(full_name, 0.0)
|
||||||
|
|
||||||
|
# Get motor index for damping gain
|
||||||
|
motor_index = motor_name_to_index.get(motor, 0)
|
||||||
|
kd = DAMPING_KD[motor_index]
|
||||||
|
|
||||||
|
# Send MIT control command with friction compensation + damping
|
||||||
|
follower.bus_left._mit_control(
|
||||||
|
motor=motor,
|
||||||
|
kp=0.0, # No position control
|
||||||
|
kd=kd, # Add damping for stability
|
||||||
|
position_degrees=position,
|
||||||
|
velocity_deg_per_sec=0.0,
|
||||||
|
torque=torque
|
||||||
|
)
|
||||||
|
|
||||||
|
# Measure loop time
|
||||||
|
loop_end = time.perf_counter()
|
||||||
|
loop_time = loop_end - loop_start
|
||||||
|
loop_times.append(loop_time)
|
||||||
|
|
||||||
|
# Print status every 2 seconds
|
||||||
|
if loop_end - last_print_time >= 2.0:
|
||||||
|
if loop_times:
|
||||||
|
avg_time = sum(loop_times) / len(loop_times)
|
||||||
|
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||||
|
|
||||||
|
print(f"{current_hz:.1f} Hz")
|
||||||
|
|
||||||
|
loop_times = []
|
||||||
|
last_print_time = loop_end
|
||||||
|
|
||||||
|
time.sleep(0.001)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nStopping friction compensation...")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
print("\nDisabling all motors and disconnecting...")
|
||||||
|
follower.bus_right.disable_torque()
|
||||||
|
follower.bus_left.disable_torque()
|
||||||
|
time.sleep(0.1)
|
||||||
|
follower.disconnect()
|
||||||
|
print("✓ Safe shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
142
examples/openarms/gravity_compensation.py
Executable file
142
examples/openarms/gravity_compensation.py
Executable file
@@ -0,0 +1,142 @@
|
|||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import pinocchio as pin
|
||||||
|
from os.path import join, dirname, exists, expanduser
|
||||||
|
|
||||||
|
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||||
|
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
config = OpenArmsFollowerConfig(
|
||||||
|
port_left="can0",
|
||||||
|
port_right="can1",
|
||||||
|
can_interface="socketcan",
|
||||||
|
id="openarms_follower",
|
||||||
|
disable_torque_on_disconnect=True,
|
||||||
|
max_relative_target=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
print("Initializing robot...")
|
||||||
|
follower = OpenArmsFollower(config)
|
||||||
|
follower.connect(calibrate=True)
|
||||||
|
|
||||||
|
# Load URDF for Pinocchio dynamics
|
||||||
|
urdf_path = "/home/croissant/Documents/openarm_description/openarm_bimanual_pybullet.urdf"
|
||||||
|
|
||||||
|
pin_robot = pin.RobotWrapper.BuildFromURDF(urdf_path, dirname(urdf_path))
|
||||||
|
pin_robot.data = pin_robot.model.createData()
|
||||||
|
print(f"✓ Loaded Pinocchio model with {pin_robot.nq} DoFs")
|
||||||
|
|
||||||
|
follower.pin_robot = pin_robot
|
||||||
|
|
||||||
|
print(f"Applying gravity compensation")
|
||||||
|
print(" 1. Support the arm before starting")
|
||||||
|
print(" 2. The arm will be held in place by gravity compensation")
|
||||||
|
print(" 3. You should be able to move it with gentle force")
|
||||||
|
print("\nPress ENTER when ready to start...")
|
||||||
|
input()
|
||||||
|
|
||||||
|
print(f"✓ Motors enabled")
|
||||||
|
print("\nStarting gravity compensation loop...")
|
||||||
|
print("Press Ctrl+C to stop\n")
|
||||||
|
|
||||||
|
loop_times = []
|
||||||
|
last_print_time = time.perf_counter()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
loop_start = time.perf_counter()
|
||||||
|
|
||||||
|
# Get current joint positions from robot
|
||||||
|
obs = follower.get_observation()
|
||||||
|
|
||||||
|
# Extract positions in degrees
|
||||||
|
positions_deg = {}
|
||||||
|
for motor in follower.bus_right.motors:
|
||||||
|
key = f"right_{motor}.pos"
|
||||||
|
if key in obs:
|
||||||
|
positions_deg[f"right_{motor}"] = obs[key]
|
||||||
|
|
||||||
|
for motor in follower.bus_left.motors:
|
||||||
|
key = f"left_{motor}.pos"
|
||||||
|
if key in obs:
|
||||||
|
positions_deg[f"left_{motor}"] = obs[key]
|
||||||
|
|
||||||
|
# Convert to radians and calculate gravity torques
|
||||||
|
# Use the built-in method from OpenArmsFollower
|
||||||
|
positions_rad = {k: np.deg2rad(v) for k, v in positions_deg.items()}
|
||||||
|
torques_nm = follower._gravity_from_q(positions_rad)
|
||||||
|
|
||||||
|
# Apply gravity compensation to right arm (all joints except gripper)
|
||||||
|
for motor in follower.bus_right.motors:
|
||||||
|
if motor == "gripper":
|
||||||
|
continue # Skip gripper
|
||||||
|
|
||||||
|
full_name = f"right_{motor}"
|
||||||
|
position = positions_deg.get(full_name, 0.0)
|
||||||
|
torque = torques_nm.get(full_name, 0.0)
|
||||||
|
|
||||||
|
# Send MIT control command with gravity compensation torque
|
||||||
|
follower.bus_right._mit_control(
|
||||||
|
motor=motor,
|
||||||
|
kp=0.0, # No position control
|
||||||
|
kd=0.0, # No velocity damping
|
||||||
|
position_degrees=position,
|
||||||
|
velocity_deg_per_sec=0.0,
|
||||||
|
torque=torque
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply gravity compensation to left arm (all joints except gripper)
|
||||||
|
for motor in follower.bus_left.motors:
|
||||||
|
if motor == "gripper":
|
||||||
|
continue # Skip gripper
|
||||||
|
|
||||||
|
full_name = f"left_{motor}"
|
||||||
|
position = positions_deg.get(full_name, 0.0)
|
||||||
|
torque = torques_nm.get(full_name, 0.0)
|
||||||
|
|
||||||
|
# Send MIT control command with gravity compensation torque
|
||||||
|
follower.bus_left._mit_control(
|
||||||
|
motor=motor,
|
||||||
|
kp=0.0, # No position control
|
||||||
|
kd=0.0, # No velocity damping
|
||||||
|
position_degrees=position,
|
||||||
|
velocity_deg_per_sec=0.0,
|
||||||
|
torque=torque
|
||||||
|
)
|
||||||
|
|
||||||
|
# Measure loop time
|
||||||
|
loop_end = time.perf_counter()
|
||||||
|
loop_time = loop_end - loop_start
|
||||||
|
loop_times.append(loop_time)
|
||||||
|
|
||||||
|
# Print status every 2 seconds
|
||||||
|
if loop_end - last_print_time >= 2.0:
|
||||||
|
if loop_times:
|
||||||
|
avg_time = sum(loop_times) / len(loop_times)
|
||||||
|
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||||
|
|
||||||
|
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
|
||||||
|
|
||||||
|
loop_times = []
|
||||||
|
last_print_time = loop_end
|
||||||
|
|
||||||
|
time.sleep(0.005)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nStopping gravity compensation...")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
print("\nDisabling all motors and disconnecting...")
|
||||||
|
follower.bus_right.disable_torque()
|
||||||
|
follower.bus_left.disable_torque()
|
||||||
|
time.sleep(0.1)
|
||||||
|
follower.disconnect()
|
||||||
|
print("✓ Safe shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
395
examples/openarms/record_with_compensation.py
Normal file
395
examples/openarms/record_with_compensation.py
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
"""
|
||||||
|
OpenArms Dataset Recording with Gravity + Friction Compensation
|
||||||
|
|
||||||
|
Records a dataset using OpenArms follower robot with leader teleoperator.
|
||||||
|
Leader arms have gravity and friction compensation for weightless, easy movement.
|
||||||
|
Includes 3 cameras: left wrist, right wrist, and base camera.
|
||||||
|
|
||||||
|
Uses the same compensation approach as teleop_with_compensation.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||||
|
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||||
|
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||||
|
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||||
|
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||||
|
from lerobot.utils.control_utils import init_keyboard_listener
|
||||||
|
from lerobot.utils.utils import log_say
|
||||||
|
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||||
|
|
||||||
|
# Recording parameters
|
||||||
|
NUM_EPISODES = 1
|
||||||
|
FPS = 30
|
||||||
|
EPISODE_TIME_SEC = 600
|
||||||
|
RESET_TIME_SEC = 120
|
||||||
|
TASK_DESCRIPTION = "OpenArms task description"
|
||||||
|
|
||||||
|
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
|
||||||
|
FRICTION_SCALE = 1.0
|
||||||
|
|
||||||
|
def record_loop_with_compensation(
|
||||||
|
robot,
|
||||||
|
leader,
|
||||||
|
events,
|
||||||
|
fps,
|
||||||
|
dataset,
|
||||||
|
dataset_features,
|
||||||
|
control_time_s,
|
||||||
|
single_task,
|
||||||
|
display_data=True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Custom record loop that applies gravity + friction compensation to leader.
|
||||||
|
Based on record_loop but with integrated compensation.
|
||||||
|
"""
|
||||||
|
dt = 1 / fps
|
||||||
|
episode_start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# All joints (both arms)
|
||||||
|
all_joints = []
|
||||||
|
for motor in leader.bus_right.motors:
|
||||||
|
all_joints.append(f"right_{motor}")
|
||||||
|
for motor in leader.bus_left.motors:
|
||||||
|
all_joints.append(f"left_{motor}")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
loop_start = time.perf_counter()
|
||||||
|
elapsed = loop_start - episode_start_time
|
||||||
|
|
||||||
|
# Check if we should exit
|
||||||
|
if elapsed >= control_time_s or events["exit_early"] or events["stop_recording"]:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Get leader state
|
||||||
|
leader_action = leader.get_action()
|
||||||
|
|
||||||
|
# Extract positions and velocities in degrees
|
||||||
|
leader_positions_deg = {}
|
||||||
|
leader_velocities_deg_per_sec = {}
|
||||||
|
|
||||||
|
for motor in leader.bus_right.motors:
|
||||||
|
pos_key = f"right_{motor}.pos"
|
||||||
|
vel_key = f"right_{motor}.vel"
|
||||||
|
if pos_key in leader_action:
|
||||||
|
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||||
|
if vel_key in leader_action:
|
||||||
|
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||||
|
|
||||||
|
for motor in leader.bus_left.motors:
|
||||||
|
pos_key = f"left_{motor}.pos"
|
||||||
|
vel_key = f"left_{motor}.vel"
|
||||||
|
if pos_key in leader_action:
|
||||||
|
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||||
|
if vel_key in leader_action:
|
||||||
|
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||||
|
|
||||||
|
# Calculate gravity torques for leader using built-in method
|
||||||
|
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||||
|
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||||
|
|
||||||
|
# Calculate friction torques for leader using built-in method
|
||||||
|
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||||
|
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||||
|
leader_velocities_rad_per_sec,
|
||||||
|
friction_scale=FRICTION_SCALE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine gravity + friction torques
|
||||||
|
leader_total_torques_nm = {}
|
||||||
|
for motor_name in leader_gravity_torques_nm:
|
||||||
|
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||||
|
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||||
|
leader_total_torques_nm[motor_name] = gravity + friction
|
||||||
|
|
||||||
|
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
|
||||||
|
for motor in leader.bus_right.motors:
|
||||||
|
full_name = f"right_{motor}"
|
||||||
|
position = leader_positions_deg.get(full_name, 0.0)
|
||||||
|
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||||
|
|
||||||
|
# Get damping gain for stability
|
||||||
|
kd = leader.get_damping_kd(motor)
|
||||||
|
|
||||||
|
leader.bus_right._mit_control(
|
||||||
|
motor=motor,
|
||||||
|
kp=0.0,
|
||||||
|
kd=kd, # Add damping for stability
|
||||||
|
position_degrees=position,
|
||||||
|
velocity_deg_per_sec=0.0,
|
||||||
|
torque=torque,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
|
||||||
|
for motor in leader.bus_left.motors:
|
||||||
|
full_name = f"left_{motor}"
|
||||||
|
position = leader_positions_deg.get(full_name, 0.0)
|
||||||
|
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||||
|
|
||||||
|
# Get damping gain for stability
|
||||||
|
kd = leader.get_damping_kd(motor)
|
||||||
|
|
||||||
|
leader.bus_left._mit_control(
|
||||||
|
motor=motor,
|
||||||
|
kp=0.0,
|
||||||
|
kd=kd, # Add damping for stability
|
||||||
|
position_degrees=position,
|
||||||
|
velocity_deg_per_sec=0.0,
|
||||||
|
torque=torque,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send leader positions to follower (both arms)
|
||||||
|
follower_action = {}
|
||||||
|
for joint in all_joints:
|
||||||
|
pos_key = f"{joint}.pos"
|
||||||
|
if pos_key in leader_action:
|
||||||
|
follower_action[pos_key] = leader_action[pos_key]
|
||||||
|
|
||||||
|
# Send action to robot
|
||||||
|
if follower_action:
|
||||||
|
robot.send_action(follower_action)
|
||||||
|
|
||||||
|
# Get observation from robot (includes camera images)
|
||||||
|
observation = robot.get_observation()
|
||||||
|
|
||||||
|
# Add to dataset if we have a dataset
|
||||||
|
if dataset is not None:
|
||||||
|
# Build properly formatted observation frame
|
||||||
|
obs_frame = build_dataset_frame(dataset_features, observation, prefix="observation")
|
||||||
|
|
||||||
|
# Build properly formatted action frame (keep .pos suffix - it matches the feature names)
|
||||||
|
action_frame = build_dataset_frame(dataset_features, follower_action, prefix="action")
|
||||||
|
|
||||||
|
# Combine into single frame
|
||||||
|
frame = {**obs_frame, **action_frame}
|
||||||
|
|
||||||
|
# Add metadata (task is required, timestamp will be auto-calculated by add_frame)
|
||||||
|
frame["task"] = single_task
|
||||||
|
|
||||||
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
|
# Display data if requested
|
||||||
|
if display_data:
|
||||||
|
log_rerun_data(observation=observation, action=follower_action)
|
||||||
|
|
||||||
|
# Maintain loop rate
|
||||||
|
loop_duration = time.perf_counter() - loop_start
|
||||||
|
sleep_time = dt - loop_duration
|
||||||
|
if sleep_time > 0:
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main recording loop with gravity compensation."""
|
||||||
|
|
||||||
|
print("=" * 70)
|
||||||
|
print("OpenArms Dataset Recording with Compensation")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
# Create camera configurations (3 cameras: left wrist, right wrist, base)
|
||||||
|
# Using actual device paths found by lerobot-find-cameras opencv
|
||||||
|
camera_config = {
|
||||||
|
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video0", width=640, height=480, fps=FPS),
|
||||||
|
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
|
||||||
|
"base": OpenCVCameraConfig(index_or_path="/dev/video7", width=640, height=480, fps=FPS),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Configure follower robot with cameras
|
||||||
|
follower_config = OpenArmsFollowerConfig(
|
||||||
|
port_left="can2",
|
||||||
|
port_right="can3",
|
||||||
|
can_interface="socketcan",
|
||||||
|
id="openarms_follower",
|
||||||
|
disable_torque_on_disconnect=True,
|
||||||
|
max_relative_target=10.0,
|
||||||
|
cameras=camera_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure leader teleoperator (no cameras needed)
|
||||||
|
leader_config = OpenArmsLeaderConfig(
|
||||||
|
port_left="can0",
|
||||||
|
port_right="can1",
|
||||||
|
can_interface="socketcan",
|
||||||
|
id="openarms_leader",
|
||||||
|
manual_control=False, # Enable torque control for gravity compensation
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize robot and teleoperator
|
||||||
|
print("\nInitializing devices...")
|
||||||
|
follower = OpenArmsFollower(follower_config)
|
||||||
|
leader = OpenArmsLeader(leader_config)
|
||||||
|
|
||||||
|
# Connect devices
|
||||||
|
print("Connecting and calibrating...")
|
||||||
|
follower.connect(calibrate=True)
|
||||||
|
leader.connect(calibrate=True)
|
||||||
|
|
||||||
|
# Verify URDF is loaded for gravity compensation
|
||||||
|
if leader.pin_robot is None:
|
||||||
|
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
|
||||||
|
|
||||||
|
# Configure the dataset features
|
||||||
|
# For actions, we only want to record positions (not velocity or torque)
|
||||||
|
action_features_hw = {}
|
||||||
|
for key, value in follower.action_features.items():
|
||||||
|
if key.endswith(".pos"):
|
||||||
|
action_features_hw[key] = value
|
||||||
|
|
||||||
|
action_features = hw_to_dataset_features(action_features_hw, "action")
|
||||||
|
obs_features = hw_to_dataset_features(follower.observation_features, "observation")
|
||||||
|
dataset_features = {**action_features, **obs_features}
|
||||||
|
|
||||||
|
# Create the dataset
|
||||||
|
print("\nCreating dataset...")
|
||||||
|
repo_id = "<hf_username>/<dataset_repo_id>" # TODO: Replace with your Hugging Face repo
|
||||||
|
|
||||||
|
# Check if dataset already exists and prompt user
|
||||||
|
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
|
||||||
|
while dataset_path.exists():
|
||||||
|
print(f"\nDataset already exists at: {dataset_path}")
|
||||||
|
print("\nOptions:")
|
||||||
|
print(" 1. Overwrite existing dataset")
|
||||||
|
print(" 2. Use a different name")
|
||||||
|
print(" 3. Abort")
|
||||||
|
|
||||||
|
choice = input("\nEnter your choice (1/2/3): ").strip()
|
||||||
|
|
||||||
|
if choice == '1':
|
||||||
|
print(f"Removing existing dataset...")
|
||||||
|
shutil.rmtree(dataset_path)
|
||||||
|
print("✓ Existing dataset removed")
|
||||||
|
break
|
||||||
|
elif choice == '2':
|
||||||
|
print("\nCurrent repo_id:", repo_id)
|
||||||
|
new_repo_id = input("Enter new repo_id (format: <username>/<dataset_name>): ").strip()
|
||||||
|
if new_repo_id and '/' in new_repo_id:
|
||||||
|
repo_id = new_repo_id
|
||||||
|
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
|
||||||
|
print(f"✓ Using new repo_id: {repo_id}")
|
||||||
|
# Loop will continue if this new path also exists
|
||||||
|
else:
|
||||||
|
print("Invalid repo_id format. Please use format: <username>/<dataset_name>")
|
||||||
|
elif choice == '3':
|
||||||
|
print("Aborting. Please remove the existing dataset manually or restart with a different repo_id.")
|
||||||
|
follower.disconnect()
|
||||||
|
leader.disconnect()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
print("Invalid choice. Please enter 1, 2, or 3.")
|
||||||
|
|
||||||
|
dataset = LeRobotDataset.create(
|
||||||
|
repo_id=repo_id,
|
||||||
|
fps=FPS,
|
||||||
|
features=dataset_features,
|
||||||
|
robot_type=follower.name,
|
||||||
|
use_videos=True,
|
||||||
|
image_writer_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize keyboard listener and visualization
|
||||||
|
_, events = init_keyboard_listener()
|
||||||
|
init_rerun(session_name="openarms_recording")
|
||||||
|
|
||||||
|
# Enable motors on both leader arms for gravity compensation
|
||||||
|
leader.bus_right.enable_torque()
|
||||||
|
leader.bus_left.enable_torque()
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print(f"Recording {NUM_EPISODES} episodes")
|
||||||
|
print(f"Task: {TASK_DESCRIPTION}")
|
||||||
|
print("=" * 70)
|
||||||
|
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
|
||||||
|
print("\nKeyboard controls:")
|
||||||
|
print(" - Press 'q' to stop recording")
|
||||||
|
print(" - Press 'r' to re-record current episode")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
episode_idx = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||||
|
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||||
|
|
||||||
|
# Record episode with compensation active
|
||||||
|
record_loop_with_compensation(
|
||||||
|
robot=follower,
|
||||||
|
leader=leader,
|
||||||
|
events=events,
|
||||||
|
fps=FPS,
|
||||||
|
dataset=dataset,
|
||||||
|
dataset_features=dataset_features,
|
||||||
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
|
single_task=TASK_DESCRIPTION,
|
||||||
|
display_data=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reset the environment if not stopping or re-recording
|
||||||
|
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||||
|
log_say("Reset the environment")
|
||||||
|
record_loop_with_compensation(
|
||||||
|
robot=follower,
|
||||||
|
leader=leader,
|
||||||
|
events=events,
|
||||||
|
fps=FPS,
|
||||||
|
dataset=None, # Don't save reset period
|
||||||
|
dataset_features=dataset_features,
|
||||||
|
control_time_s=RESET_TIME_SEC,
|
||||||
|
single_task=TASK_DESCRIPTION,
|
||||||
|
display_data=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle re-recording
|
||||||
|
if events["rerecord_episode"]:
|
||||||
|
log_say("Re-recording episode")
|
||||||
|
events["rerecord_episode"] = False
|
||||||
|
events["exit_early"] = False
|
||||||
|
dataset.clear_episode_buffer()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Only save episode if frames were recorded
|
||||||
|
if dataset.episode_buffer is not None and dataset.episode_buffer["size"] > 0:
|
||||||
|
dataset.save_episode()
|
||||||
|
episode_idx += 1
|
||||||
|
else:
|
||||||
|
log_say("No frames recorded, skipping episode save")
|
||||||
|
# Clear the empty buffer
|
||||||
|
dataset.episode_buffer = None
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nStopping recording...")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up
|
||||||
|
log_say("Stop recording")
|
||||||
|
try:
|
||||||
|
leader.bus_right.disable_torque()
|
||||||
|
leader.bus_left.disable_torque()
|
||||||
|
time.sleep(0.1)
|
||||||
|
leader.disconnect()
|
||||||
|
follower.disconnect()
|
||||||
|
print("✓ Shutdown complete")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Shutdown error: {e}")
|
||||||
|
|
||||||
|
# Upload dataset
|
||||||
|
print("\nUploading dataset to Hugging Face Hub...")
|
||||||
|
try:
|
||||||
|
dataset.push_to_hub()
|
||||||
|
print("✓ Dataset uploaded successfully")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to upload dataset: {e}")
|
||||||
|
print("You can manually upload later using: dataset.push_to_hub()")
|
||||||
|
|
||||||
|
print("✓ Recording complete!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
166
examples/openarms/replay.py
Normal file
166
examples/openarms/replay.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
OpenArms Dataset Replay Example
|
||||||
|
|
||||||
|
Replays position actions from a recorded dataset on an OpenArms follower robot.
|
||||||
|
Only position commands (ending with .pos) are replayed, not velocity or torque.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
python examples/openarms/replay.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||||
|
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||||
|
from lerobot.utils.constants import ACTION
|
||||||
|
from lerobot.utils.robot_utils import busy_wait
|
||||||
|
from lerobot.utils.utils import log_say
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
EPISODE_IDX = 0
|
||||||
|
DATASET_REPO_ID = "lerobot-data-collection/replay-this-2025-11-02-17-58" # TODO: Replace with your dataset
|
||||||
|
DATASET_ROOT = None # Use default cache location, or specify custom path
|
||||||
|
|
||||||
|
# Robot configuration - adjust these to match your setup
|
||||||
|
ROBOT_CONFIG = OpenArmsFollowerConfig(
|
||||||
|
port_left="can2", # CAN interface for left arm
|
||||||
|
port_right="can3", # CAN interface for right arm
|
||||||
|
can_interface="socketcan",
|
||||||
|
id="openarms_follower",
|
||||||
|
disable_torque_on_disconnect=True,
|
||||||
|
max_relative_target=10.0, # Safety limit: max degrees to move per step
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main replay function."""
|
||||||
|
print("=" * 70)
|
||||||
|
print("OpenArms Dataset Replay")
|
||||||
|
print("=" * 70)
|
||||||
|
print(f"\nDataset: {DATASET_REPO_ID}")
|
||||||
|
print(f"Episode: {EPISODE_IDX}")
|
||||||
|
print(f"Robot: {ROBOT_CONFIG.id}")
|
||||||
|
print(f" Left arm: {ROBOT_CONFIG.port_left}")
|
||||||
|
print(f" Right arm: {ROBOT_CONFIG.port_right}")
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
|
||||||
|
# Initialize the robot
|
||||||
|
print("\n[1/3] Initializing robot...")
|
||||||
|
robot = OpenArmsFollower(ROBOT_CONFIG)
|
||||||
|
|
||||||
|
# Load the dataset
|
||||||
|
print(f"\n[2/3] Loading dataset '{DATASET_REPO_ID}'...")
|
||||||
|
dataset = LeRobotDataset(
|
||||||
|
DATASET_REPO_ID,
|
||||||
|
root=DATASET_ROOT,
|
||||||
|
episodes=[EPISODE_IDX]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter dataset to only include frames from the specified episode
|
||||||
|
# (required for dataset V3.0 where episodes are chunked)
|
||||||
|
episode_frames = dataset.hf_dataset.filter(
|
||||||
|
lambda x: x["episode_index"] == EPISODE_IDX
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(episode_frames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No frames found for episode {EPISODE_IDX} in dataset {DATASET_REPO_ID}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" Found {len(episode_frames)} frames in episode {EPISODE_IDX}")
|
||||||
|
|
||||||
|
# Extract action features from dataset
|
||||||
|
action_features = dataset.features.get(ACTION, {})
|
||||||
|
action_names = action_features.get("names", [])
|
||||||
|
|
||||||
|
# Filter to only position actions (ending with .pos)
|
||||||
|
position_action_names = [name for name in action_names if name.endswith(".pos")]
|
||||||
|
|
||||||
|
if not position_action_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"No position actions found in dataset. Action names: {action_names}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" Found {len(position_action_names)} position actions to replay")
|
||||||
|
print(f" Actions: {', '.join(position_action_names[:5])}{'...' if len(position_action_names) > 5 else ''}")
|
||||||
|
|
||||||
|
# Select only action columns from dataset
|
||||||
|
actions = episode_frames.select_columns(ACTION)
|
||||||
|
|
||||||
|
# Connect to the robot
|
||||||
|
print(f"\n[3/3] Connecting to robot...")
|
||||||
|
robot.connect(calibrate=False) # Skip calibration for replay
|
||||||
|
|
||||||
|
if not robot.is_connected:
|
||||||
|
raise RuntimeError("Robot failed to connect!")
|
||||||
|
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("Ready to replay!")
|
||||||
|
print("=" * 70)
|
||||||
|
print("\nThe robot will replay the recorded positions.")
|
||||||
|
print("Press Ctrl+C to stop at any time.\n")
|
||||||
|
|
||||||
|
input("Press ENTER to start replaying...")
|
||||||
|
|
||||||
|
# Replay loop
|
||||||
|
log_say(f"Replaying episode {EPISODE_IDX}", blocking=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for idx in range(len(episode_frames)):
|
||||||
|
loop_start = time.perf_counter()
|
||||||
|
|
||||||
|
# Extract action array from dataset
|
||||||
|
action_array = actions[idx][ACTION]
|
||||||
|
|
||||||
|
# Build action dictionary, but only include position actions
|
||||||
|
action = {}
|
||||||
|
for i, name in enumerate(action_names):
|
||||||
|
# Only include position actions (ending with .pos)
|
||||||
|
if name.endswith(".pos"):
|
||||||
|
action[name] = float(action_array[i])
|
||||||
|
|
||||||
|
# Send action to robot
|
||||||
|
robot.send_action(action)
|
||||||
|
|
||||||
|
# Maintain replay rate (use dataset fps)
|
||||||
|
loop_duration = time.perf_counter() - loop_start
|
||||||
|
dt_s = 1.0 / dataset.fps - loop_duration
|
||||||
|
busy_wait(dt_s)
|
||||||
|
|
||||||
|
# Progress indicator every 100 frames
|
||||||
|
if (idx + 1) % 100 == 0:
|
||||||
|
progress = (idx + 1) / len(episode_frames) * 100
|
||||||
|
print(f"Progress: {idx + 1}/{len(episode_frames)} frames ({progress:.1f}%)")
|
||||||
|
|
||||||
|
print(f"\n✓ Successfully replayed {len(episode_frames)} frames")
|
||||||
|
log_say("Replay complete", blocking=True)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nReplay interrupted by user")
|
||||||
|
finally:
|
||||||
|
# Disconnect robot
|
||||||
|
print("\nDisconnecting robot...")
|
||||||
|
robot.disconnect()
|
||||||
|
print("✓ Replay complete!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
73
examples/openarms/setup_can.sh
Executable file
73
examples/openarms/setup_can.sh
Executable file
@@ -0,0 +1,73 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Setup all OpenArms CAN interfaces with CAN FD
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
echo "=========================================="
|
||||||
|
echo "OpenArms CAN FD Interface Setup"
|
||||||
|
echo "=========================================="
|
||||||
|
echo ""
|
||||||
|
echo "Mode: CAN FD"
|
||||||
|
echo " - Nominal bitrate: 1 Mbps"
|
||||||
|
echo " - Data bitrate: 5 Mbps"
|
||||||
|
echo ""
|
||||||
|
echo "Configuring interfaces can0, can1, can2, can3..."
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Configure each CAN interface with CAN FD
|
||||||
|
for i in 0 1 2 3; do
|
||||||
|
interface="can$i"
|
||||||
|
|
||||||
|
# Check if interface exists
|
||||||
|
if ! ip link show "$interface" &> /dev/null; then
|
||||||
|
echo "⚠ $interface: Not found, skipping"
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Bring down interface
|
||||||
|
sudo ip link set "$interface" down 2>/dev/null
|
||||||
|
|
||||||
|
# Configure CAN FD mode
|
||||||
|
sudo ip link set "$interface" type can \
|
||||||
|
bitrate 1000000 \
|
||||||
|
dbitrate 5000000 \
|
||||||
|
fd on
|
||||||
|
|
||||||
|
# Bring up interface
|
||||||
|
sudo ip link set "$interface" up
|
||||||
|
|
||||||
|
# Verify configuration
|
||||||
|
if ip link show "$interface" | grep -q "UP"; then
|
||||||
|
echo "✓ $interface: Configured and UP"
|
||||||
|
else
|
||||||
|
echo "✗ $interface: Failed to bring UP"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=========================================="
|
||||||
|
echo "Verification"
|
||||||
|
echo "=========================================="
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Show detailed status for each interface
|
||||||
|
for i in 0 1 2 3; do
|
||||||
|
interface="can$i"
|
||||||
|
if ip link show "$interface" &> /dev/null; then
|
||||||
|
echo "$interface:"
|
||||||
|
# Show key parameters
|
||||||
|
ip -d link show "$interface" | grep -E "can|state|bitrate|dbitrate" | head -3
|
||||||
|
echo ""
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "=========================================="
|
||||||
|
echo "Setup Complete!"
|
||||||
|
echo "=========================================="
|
||||||
|
echo ""
|
||||||
|
echo "All interfaces configured for CAN FD mode"
|
||||||
|
echo ""
|
||||||
|
echo "Next steps:"
|
||||||
|
echo " 1. Test motors: python debug_can_communication.py"
|
||||||
|
echo " 2. Run teleoperation: python examples/openarms/teleop.py"
|
||||||
|
echo ""
|
||||||
148
examples/openarms/teleop.py
Normal file
148
examples/openarms/teleop.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""
|
||||||
|
OpenArms Teleoperation Example - Full Dual Arms
|
||||||
|
|
||||||
|
This script demonstrates teleoperation of OpenArms follower robot using an OpenArms leader arm.
|
||||||
|
It first calibrates both devices, then enters a teleoperation loop for both arms.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||||
|
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||||
|
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||||
|
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||||
|
|
||||||
|
|
||||||
|
follower_config = OpenArmsFollowerConfig(
|
||||||
|
port_left="can2", # CAN interface for follower left arm
|
||||||
|
port_right="can3", # CAN interface for follower right arm
|
||||||
|
can_interface="socketcan", # Linux SocketCAN
|
||||||
|
id="openarms_follower",
|
||||||
|
disable_torque_on_disconnect=True,
|
||||||
|
max_relative_target=5.0, # Safety limit
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
leader_config = OpenArmsLeaderConfig(
|
||||||
|
port_left="can0", # CAN interface for leader left arm
|
||||||
|
port_right="can1", # CAN interface for leader right arm
|
||||||
|
can_interface="socketcan", # Linux SocketCAN
|
||||||
|
id="openarms_leader",
|
||||||
|
manual_control=True, # Enable manual control (torque disabled)
|
||||||
|
)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("OpenArms Teleoperation - Full Dual Arms")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Initialize devices
|
||||||
|
print("\n[1/4] Initializing devices...")
|
||||||
|
follower = OpenArmsFollower(follower_config)
|
||||||
|
leader = OpenArmsLeader(leader_config)
|
||||||
|
|
||||||
|
# Connect and calibrate follower
|
||||||
|
print("\n[2/4] Connecting and calibrating follower robot...")
|
||||||
|
print("Note: If you have existing calibration, just press ENTER to use it.")
|
||||||
|
follower.connect(calibrate=True)
|
||||||
|
|
||||||
|
# Connect and calibrate leader
|
||||||
|
print("\n[3/4] Connecting and calibrating leader arm...")
|
||||||
|
print("Note: The leader arm will have torque disabled for manual control.")
|
||||||
|
leader.connect(calibrate=True)
|
||||||
|
|
||||||
|
# Wait for user to be ready
|
||||||
|
print("\n[4/4] Ready for teleoperation!")
|
||||||
|
print("\nBoth arms will be controlled (16 motors total):")
|
||||||
|
print(" RIGHT ARM: joints 1-7 + gripper")
|
||||||
|
print(" LEFT ARM: joints 1-7 + gripper")
|
||||||
|
|
||||||
|
print("\nPress ENTER to start teleoperation...")
|
||||||
|
input()
|
||||||
|
|
||||||
|
print("\nTeleoperation started! Move both leader arms.")
|
||||||
|
print("Press Ctrl+C to stop.\n")
|
||||||
|
|
||||||
|
# All joints for both arms (16 motors total)
|
||||||
|
all_joints = [
|
||||||
|
# Right arm
|
||||||
|
"right_joint_1",
|
||||||
|
"right_joint_2",
|
||||||
|
"right_joint_3",
|
||||||
|
"right_joint_4",
|
||||||
|
"right_joint_5",
|
||||||
|
"right_joint_6",
|
||||||
|
"right_joint_7",
|
||||||
|
"right_gripper",
|
||||||
|
# Left arm
|
||||||
|
"left_joint_1",
|
||||||
|
"left_joint_2",
|
||||||
|
"left_joint_3",
|
||||||
|
"left_joint_4",
|
||||||
|
"left_joint_5",
|
||||||
|
"left_joint_6",
|
||||||
|
"left_joint_7",
|
||||||
|
"left_gripper",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Performance monitoring
|
||||||
|
loop_times = []
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
last_print_time = start_time
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
loop_start = time.perf_counter()
|
||||||
|
|
||||||
|
# Get action from leader
|
||||||
|
leader_action = leader.get_action()
|
||||||
|
|
||||||
|
# Filter to only position data for all joints (both arms)
|
||||||
|
joint_action = {}
|
||||||
|
for joint in all_joints:
|
||||||
|
pos_key = f"{joint}.pos"
|
||||||
|
if pos_key in leader_action:
|
||||||
|
joint_action[pos_key] = leader_action[pos_key]
|
||||||
|
|
||||||
|
# Send action to follower (both arms)
|
||||||
|
if joint_action:
|
||||||
|
follower.send_action(joint_action)
|
||||||
|
|
||||||
|
# Measure loop time
|
||||||
|
loop_end = time.perf_counter()
|
||||||
|
loop_time = loop_end - loop_start
|
||||||
|
loop_times.append(loop_time)
|
||||||
|
|
||||||
|
# Print stats every 2 seconds
|
||||||
|
if loop_end - last_print_time >= 2.0:
|
||||||
|
if loop_times:
|
||||||
|
avg_time = sum(loop_times) / len(loop_times)
|
||||||
|
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||||
|
min_time = min(loop_times)
|
||||||
|
max_time = max(loop_times)
|
||||||
|
max_hz = 1.0 / min_time if min_time > 0 else 0
|
||||||
|
min_hz = 1.0 / max_time if max_time > 0 else 0
|
||||||
|
|
||||||
|
print(f"[Hz Stats] Avg: {current_hz:.1f} Hz | "
|
||||||
|
f"Range: {min_hz:.1f}-{max_hz:.1f} Hz | "
|
||||||
|
f"Avg loop time: {avg_time*1000:.1f} ms")
|
||||||
|
|
||||||
|
# Reset for next measurement window
|
||||||
|
loop_times = []
|
||||||
|
last_print_time = loop_end
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nStopping teleoperation...")
|
||||||
|
finally:
|
||||||
|
# Disconnect devices
|
||||||
|
print("Disconnecting devices...")
|
||||||
|
try:
|
||||||
|
follower.disconnect()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error disconnecting follower: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
leader.disconnect()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error disconnecting leader: {e}")
|
||||||
|
|
||||||
|
print("Done!")
|
||||||
197
examples/openarms/teleop_openarms_mini.py
Normal file
197
examples/openarms/teleop_openarms_mini.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""
|
||||||
|
OpenArms Mini Teleoperation Example
|
||||||
|
|
||||||
|
This script demonstrates teleoperation of an OpenArms follower robot using
|
||||||
|
an OpenArms Mini leader (Feetech-based) with dual arms (16 motors total).
|
||||||
|
|
||||||
|
The OpenArms Mini has:
|
||||||
|
- Right arm: 8 motors (joint_1 to joint_7 + gripper)
|
||||||
|
- Left arm: 8 motors (joint_1 to joint_7 + gripper)
|
||||||
|
|
||||||
|
Note on gripper normalization:
|
||||||
|
- OpenArms Mini gripper: 0-100 scale (0=closed, 100=open)
|
||||||
|
- OpenArms follower gripper: degrees (0=closed, -65=open)
|
||||||
|
- This script automatically converts between the two ranges
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||||
|
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||||
|
from lerobot.teleoperators.openarms_mini.openarms_mini import OpenArmsMini
|
||||||
|
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig
|
||||||
|
from lerobot.utils.robot_utils import busy_wait
|
||||||
|
|
||||||
|
# Target control frequency
|
||||||
|
TARGET_FPS = 30
|
||||||
|
|
||||||
|
# Configure the OpenArms follower (Damiao motors on CAN bus)
|
||||||
|
follower_config = OpenArmsFollowerConfig(
|
||||||
|
port_left="can0", # CAN interface for follower left arm
|
||||||
|
port_right="can1", # CAN interface for follower right arm
|
||||||
|
can_interface="socketcan", # Linux SocketCAN
|
||||||
|
id="openarms_follower",
|
||||||
|
disable_torque_on_disconnect=True,
|
||||||
|
max_relative_target=10.0, # Safety limit (degrees per step)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure the OpenArms Mini leader (Feetech motors on serial)
|
||||||
|
leader_config = OpenArmsMiniConfig(
|
||||||
|
port_right="/dev/ttyACM0", # Serial port for right arm
|
||||||
|
port_left="/dev/ttyACM1", # Serial port for left arm
|
||||||
|
id="openarms_mini",
|
||||||
|
use_degrees=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("OpenArms Mini → OpenArms Follower Teleoperation")
|
||||||
|
|
||||||
|
# Initialize devices
|
||||||
|
follower = OpenArmsFollower(follower_config)
|
||||||
|
leader = OpenArmsMini(leader_config)
|
||||||
|
|
||||||
|
# Connect and calibrate follower
|
||||||
|
print("Note: If you have existing calibration, just press ENTER to use it.")
|
||||||
|
follower.connect(calibrate=True)
|
||||||
|
|
||||||
|
# Connect and calibrate leader
|
||||||
|
print("Note: The leader arms will have torque disabled for manual control.")
|
||||||
|
leader.connect(calibrate=True)
|
||||||
|
|
||||||
|
print("\nPress ENTER to start teleoperation...")
|
||||||
|
input()
|
||||||
|
|
||||||
|
print("Press Ctrl+C to stop.\n")
|
||||||
|
|
||||||
|
# All joints for both arms (16 motors total)
|
||||||
|
all_joints = [
|
||||||
|
# Right arm
|
||||||
|
"right_joint_1",
|
||||||
|
"right_joint_2",
|
||||||
|
"right_joint_3",
|
||||||
|
"right_joint_4",
|
||||||
|
"right_joint_5",
|
||||||
|
"right_joint_6",
|
||||||
|
"right_joint_7",
|
||||||
|
"right_gripper",
|
||||||
|
# Left arm
|
||||||
|
"left_joint_1",
|
||||||
|
"left_joint_2",
|
||||||
|
"left_joint_3",
|
||||||
|
"left_joint_4",
|
||||||
|
"left_joint_5",
|
||||||
|
"left_joint_6",
|
||||||
|
"left_joint_7",
|
||||||
|
"left_gripper",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Performance monitoring
|
||||||
|
loop_times = []
|
||||||
|
avg_loop_time = 0.0
|
||||||
|
min_loop_time = float('inf')
|
||||||
|
max_loop_time = 0.0
|
||||||
|
stats_update_interval = 1.0 # Update stats every 1 second
|
||||||
|
last_stats_update = time.perf_counter()
|
||||||
|
|
||||||
|
|
||||||
|
SWAPPED_JOINTS = {
|
||||||
|
"right_joint_6": "right_joint_7",
|
||||||
|
"right_joint_7": "right_joint_6",
|
||||||
|
"left_joint_6": "left_joint_7",
|
||||||
|
"left_joint_7": "left_joint_6",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
loop_start = time.perf_counter()
|
||||||
|
|
||||||
|
# Get actions and observations
|
||||||
|
leader_action = leader.get_action()
|
||||||
|
follower_obs = follower.get_observation()
|
||||||
|
|
||||||
|
joint_action = {}
|
||||||
|
for joint in all_joints:
|
||||||
|
leader_key = f"{joint}.pos"
|
||||||
|
|
||||||
|
# Determine which follower joint this leader joint controls
|
||||||
|
follower_joint = SWAPPED_JOINTS.get(joint, joint)
|
||||||
|
follower_key = f"{follower_joint}.pos"
|
||||||
|
|
||||||
|
# Get leader position (default 0 if missing)
|
||||||
|
pos = leader_action.get(leader_key, 0.0)
|
||||||
|
|
||||||
|
# Convert gripper values: Mini uses 0-100, OpenArms uses 0 to -65 degrees
|
||||||
|
if "gripper" in joint:
|
||||||
|
# Map 0-100 (Mini) to 0 to -65 (OpenArms)
|
||||||
|
# 0 (closed) -> 0°, 100 (open) -> -65°
|
||||||
|
pos = (pos / 100.0) * -65.0
|
||||||
|
|
||||||
|
# Store in action dict for follower
|
||||||
|
joint_action[follower_key] = pos
|
||||||
|
|
||||||
|
follower.send_action(joint_action)
|
||||||
|
|
||||||
|
# Loop timing
|
||||||
|
loop_end = time.perf_counter()
|
||||||
|
loop_time = loop_end - loop_start
|
||||||
|
loop_times.append(loop_time)
|
||||||
|
|
||||||
|
# Update stats periodically
|
||||||
|
current_time = time.perf_counter()
|
||||||
|
if current_time - last_stats_update >= stats_update_interval:
|
||||||
|
if loop_times:
|
||||||
|
avg_loop_time = sum(loop_times) / len(loop_times)
|
||||||
|
min_loop_time = min(loop_times)
|
||||||
|
max_loop_time = max(loop_times)
|
||||||
|
loop_times = []
|
||||||
|
last_stats_update = current_time
|
||||||
|
|
||||||
|
# Display everything
|
||||||
|
sys.stdout.write("\033[H\033[J") # Clear screen
|
||||||
|
|
||||||
|
# Show timing stats at the top
|
||||||
|
if avg_loop_time > 0:
|
||||||
|
avg_hz = 1.0 / avg_loop_time
|
||||||
|
min_hz = 1.0 / max_loop_time if max_loop_time > 0 else 0
|
||||||
|
max_hz = 1.0 / min_loop_time if min_loop_time > 0 and min_loop_time < float('inf') else 0
|
||||||
|
print(f"[Performance] Target: {TARGET_FPS} Hz | Avg: {avg_hz:.1f} Hz | Range: {min_hz:.1f}-{max_hz:.1f} Hz | Loop: {avg_loop_time*1000:.1f} ms\n")
|
||||||
|
else:
|
||||||
|
print(f"[Performance] Target: {TARGET_FPS} Hz | Measuring...\n")
|
||||||
|
|
||||||
|
# Show joint positions
|
||||||
|
print(f"{'Joint':<20} {'Leader':>15} {'Follower':>15}")
|
||||||
|
print(f"{'':20} {'(0-100/deg)':>15} {'(deg)':>15}")
|
||||||
|
print("-" * 52)
|
||||||
|
|
||||||
|
for joint in all_joints:
|
||||||
|
leader_key = f"{joint}.pos"
|
||||||
|
follower_joint = SWAPPED_JOINTS.get(joint, joint)
|
||||||
|
follower_key = f"{follower_joint}.pos"
|
||||||
|
|
||||||
|
leader_pos = leader_action.get(leader_key, 0.0)
|
||||||
|
follower_pos = follower_obs.get(follower_key, 0.0)
|
||||||
|
|
||||||
|
print(f"{joint:<20} {leader_pos:>15.2f} {follower_pos:>15.2f}")
|
||||||
|
|
||||||
|
# Smart sleep to maintain target FPS
|
||||||
|
dt_s = time.perf_counter() - loop_start
|
||||||
|
busy_wait(max(0, 1.0 / TARGET_FPS - dt_s))
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nStopping teleoperation...")
|
||||||
|
finally:
|
||||||
|
# Disconnect devices
|
||||||
|
print("Disconnecting devices...")
|
||||||
|
try:
|
||||||
|
follower.disconnect()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error disconnecting follower: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
leader.disconnect()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error disconnecting leader: {e}")
|
||||||
|
|
||||||
|
print("Done!")
|
||||||
|
|
||||||
202
examples/openarms/teleop_with_compensation.py
Executable file
202
examples/openarms/teleop_with_compensation.py
Executable file
@@ -0,0 +1,202 @@
|
|||||||
|
"""
|
||||||
|
OpenArms Teleoperation with Gravity + Friction Compensation
|
||||||
|
|
||||||
|
Leader arms (both LEFT and RIGHT): Gravity + Friction compensation (weightless, easy to move)
|
||||||
|
Follower arms (both LEFT and RIGHT): Mirror leader movements
|
||||||
|
|
||||||
|
Uses the URDF file from the lerobot repository.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||||
|
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||||
|
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||||
|
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||||
|
|
||||||
|
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
|
||||||
|
FRICTION_SCALE = 1.0
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main teleoperation loop with gravity compensation"""
|
||||||
|
|
||||||
|
print("=" * 70)
|
||||||
|
print("OpenArms Teleoperation with Gravity Compensation")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
follower_config = OpenArmsFollowerConfig(
|
||||||
|
port_left="can2",
|
||||||
|
port_right="can3",
|
||||||
|
can_interface="socketcan",
|
||||||
|
id="openarms_follower",
|
||||||
|
disable_torque_on_disconnect=True,
|
||||||
|
max_relative_target=10.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
leader_config = OpenArmsLeaderConfig(
|
||||||
|
port_left="can0",
|
||||||
|
port_right="can1",
|
||||||
|
can_interface="socketcan",
|
||||||
|
id="openarms_leader",
|
||||||
|
manual_control=False, # Enable torque control for gravity compensation
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize and connect
|
||||||
|
print("\nInitializing devices...")
|
||||||
|
follower = OpenArmsFollower(follower_config)
|
||||||
|
leader = OpenArmsLeader(leader_config)
|
||||||
|
|
||||||
|
follower.connect()
|
||||||
|
leader.connect()
|
||||||
|
|
||||||
|
# URDF is automatically loaded in the leader constructor
|
||||||
|
if leader.pin_robot is None:
|
||||||
|
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
|
||||||
|
|
||||||
|
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
|
||||||
|
print("Press ENTER to start...")
|
||||||
|
input()
|
||||||
|
|
||||||
|
# Enable motors on both leader arms for gravity compensation
|
||||||
|
leader.bus_right.enable_torque()
|
||||||
|
leader.bus_left.enable_torque()
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
print("Press Ctrl+C to stop\n")
|
||||||
|
|
||||||
|
# Main control loop
|
||||||
|
loop_times = []
|
||||||
|
last_print_time = time.perf_counter()
|
||||||
|
|
||||||
|
# All joints (both arms)
|
||||||
|
all_joints = []
|
||||||
|
for motor in leader.bus_right.motors:
|
||||||
|
all_joints.append(f"right_{motor}")
|
||||||
|
for motor in leader.bus_left.motors:
|
||||||
|
all_joints.append(f"left_{motor}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
loop_start = time.perf_counter()
|
||||||
|
|
||||||
|
# Get leader state
|
||||||
|
leader_action = leader.get_action()
|
||||||
|
|
||||||
|
# Extract positions and velocities in degrees
|
||||||
|
leader_positions_deg = {}
|
||||||
|
leader_velocities_deg_per_sec = {}
|
||||||
|
|
||||||
|
for motor in leader.bus_right.motors:
|
||||||
|
pos_key = f"right_{motor}.pos"
|
||||||
|
vel_key = f"right_{motor}.vel"
|
||||||
|
if pos_key in leader_action:
|
||||||
|
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||||
|
if vel_key in leader_action:
|
||||||
|
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||||
|
|
||||||
|
for motor in leader.bus_left.motors:
|
||||||
|
pos_key = f"left_{motor}.pos"
|
||||||
|
vel_key = f"left_{motor}.vel"
|
||||||
|
if pos_key in leader_action:
|
||||||
|
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||||
|
if vel_key in leader_action:
|
||||||
|
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||||
|
|
||||||
|
# Calculate gravity torques for leader using built-in method
|
||||||
|
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||||
|
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||||
|
|
||||||
|
# Calculate friction torques for leader using built-in method
|
||||||
|
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||||
|
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||||
|
leader_velocities_rad_per_sec,
|
||||||
|
friction_scale=FRICTION_SCALE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine gravity + friction torques
|
||||||
|
leader_total_torques_nm = {}
|
||||||
|
for motor_name in leader_gravity_torques_nm:
|
||||||
|
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||||
|
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||||
|
leader_total_torques_nm[motor_name] = gravity + friction
|
||||||
|
|
||||||
|
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
|
||||||
|
for motor in leader.bus_right.motors:
|
||||||
|
full_name = f"right_{motor}"
|
||||||
|
position = leader_positions_deg.get(full_name, 0.0)
|
||||||
|
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||||
|
|
||||||
|
# Get damping gain for stability
|
||||||
|
kd = leader.get_damping_kd(motor)
|
||||||
|
|
||||||
|
leader.bus_right._mit_control(
|
||||||
|
motor=motor,
|
||||||
|
kp=0.0,
|
||||||
|
kd=kd, # Add damping for stability
|
||||||
|
position_degrees=position,
|
||||||
|
velocity_deg_per_sec=0.0,
|
||||||
|
torque=torque,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
|
||||||
|
for motor in leader.bus_left.motors:
|
||||||
|
full_name = f"left_{motor}"
|
||||||
|
position = leader_positions_deg.get(full_name, 0.0)
|
||||||
|
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||||
|
|
||||||
|
# Get damping gain for stability
|
||||||
|
kd = leader.get_damping_kd(motor)
|
||||||
|
|
||||||
|
leader.bus_left._mit_control(
|
||||||
|
motor=motor,
|
||||||
|
kp=0.0,
|
||||||
|
kd=kd, # Add damping for stability
|
||||||
|
position_degrees=position,
|
||||||
|
velocity_deg_per_sec=0.0,
|
||||||
|
torque=torque,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send leader positions to follower (both arms)
|
||||||
|
follower_action = {}
|
||||||
|
for joint in all_joints:
|
||||||
|
pos_key = f"{joint}.pos"
|
||||||
|
if pos_key in leader_action:
|
||||||
|
follower_action[pos_key] = leader_action[pos_key]
|
||||||
|
|
||||||
|
if follower_action:
|
||||||
|
follower.send_action(follower_action)
|
||||||
|
|
||||||
|
# Performance monitoring
|
||||||
|
loop_end = time.perf_counter()
|
||||||
|
loop_time = loop_end - loop_start
|
||||||
|
loop_times.append(loop_time)
|
||||||
|
|
||||||
|
if loop_end - last_print_time >= 2.0:
|
||||||
|
if loop_times:
|
||||||
|
avg_time = sum(loop_times) / len(loop_times)
|
||||||
|
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||||
|
|
||||||
|
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
|
||||||
|
|
||||||
|
loop_times = []
|
||||||
|
last_print_time = loop_end
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nStopping...")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
leader.bus_right.disable_torque()
|
||||||
|
leader.bus_left.disable_torque()
|
||||||
|
time.sleep(0.1)
|
||||||
|
leader.disconnect()
|
||||||
|
follower.disconnect()
|
||||||
|
print("✓ Shutdown complete")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Shutdown error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
745
examples/openarms_web_interface/App.css
Normal file
745
examples/openarms_web_interface/App.css
Normal file
@@ -0,0 +1,745 @@
|
|||||||
|
body {
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||||
|
background: #f5f5f5;
|
||||||
|
}
|
||||||
|
|
||||||
|
main {
|
||||||
|
min-height: 100vh;
|
||||||
|
padding: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
header {
|
||||||
|
text-align: center;
|
||||||
|
margin-bottom: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 {
|
||||||
|
font-size: 2rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: #333;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
h2 {
|
||||||
|
font-size: 1.25rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: #333;
|
||||||
|
margin: 0 0 1rem 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
h3 {
|
||||||
|
font-size: 0.875rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: #666;
|
||||||
|
margin: 0 0 0.5rem 0;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 0.5px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.container {
|
||||||
|
max-width: 1920px;
|
||||||
|
margin: 0 auto;
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: minmax(500px, 600px) 1fr;
|
||||||
|
gap: 2rem;
|
||||||
|
align-items: start;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Left column container */
|
||||||
|
.left-column {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Right column container */
|
||||||
|
.right-column {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Responsive: Stack on smaller screens */
|
||||||
|
@media (max-width: 1200px) {
|
||||||
|
.container {
|
||||||
|
grid-template-columns: 1fr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.panel {
|
||||||
|
background: white;
|
||||||
|
border-radius: 8px;
|
||||||
|
padding: 1.5rem;
|
||||||
|
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.config-panel {
|
||||||
|
border: 2px solid #e5e7eb;
|
||||||
|
}
|
||||||
|
|
||||||
|
.config-header {
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
align-items: center;
|
||||||
|
cursor: pointer;
|
||||||
|
user-select: none;
|
||||||
|
padding: 0.5rem 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.config-header:hover {
|
||||||
|
opacity: 0.7;
|
||||||
|
}
|
||||||
|
|
||||||
|
.toggle-icon {
|
||||||
|
font-size: 1rem;
|
||||||
|
color: #6b7280;
|
||||||
|
transition: transform 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.config-content {
|
||||||
|
margin-top: 1rem;
|
||||||
|
padding-top: 1rem;
|
||||||
|
border-top: 1px solid #e5e7eb;
|
||||||
|
}
|
||||||
|
|
||||||
|
.robot-setup {
|
||||||
|
margin-bottom: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.robot-status {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: space-between;
|
||||||
|
padding: 1rem;
|
||||||
|
border-radius: 6px;
|
||||||
|
font-weight: 500;
|
||||||
|
gap: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.robot-status.ready {
|
||||||
|
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
|
||||||
|
color: #065f46;
|
||||||
|
border: 1px solid #10b981;
|
||||||
|
}
|
||||||
|
|
||||||
|
.robot-status.not-ready {
|
||||||
|
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||||
|
color: #92400e;
|
||||||
|
border: 1px solid #f59e0b;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-setup {
|
||||||
|
background: #10b981;
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
padding: 0.5rem 1rem;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
font-weight: 500;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: background 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-setup:hover:not(:disabled) {
|
||||||
|
background: #059669;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-setup:disabled {
|
||||||
|
background: #d1d5db;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-zero {
|
||||||
|
background: #8b5cf6;
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
padding: 0.5rem 1rem;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
font-weight: 500;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: background 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-zero:hover:not(:disabled) {
|
||||||
|
background: #7c3aed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-zero:disabled {
|
||||||
|
background: #d1d5db;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.zero-position-section {
|
||||||
|
margin-top: 1rem;
|
||||||
|
padding-top: 1rem;
|
||||||
|
border-top: 1px solid #e5e7eb;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-zero-large {
|
||||||
|
width: 100%;
|
||||||
|
background: #8b5cf6;
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
padding: 0.875rem 1.5rem;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 1rem;
|
||||||
|
font-weight: 600;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.2s;
|
||||||
|
box-shadow: 0 2px 4px rgba(139, 92, 246, 0.2);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-zero-large:hover:not(:disabled) {
|
||||||
|
background: #7c3aed;
|
||||||
|
box-shadow: 0 4px 8px rgba(139, 92, 246, 0.3);
|
||||||
|
transform: translateY(-1px);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-zero-large:disabled {
|
||||||
|
background: #d1d5db;
|
||||||
|
cursor: not-allowed;
|
||||||
|
box-shadow: none;
|
||||||
|
transform: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.delete-episode-section {
|
||||||
|
margin-top: 1rem;
|
||||||
|
padding-top: 1rem;
|
||||||
|
border-top: 1px solid #e5e7eb;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-delete {
|
||||||
|
width: 100%;
|
||||||
|
background: #ef4444;
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
padding: 0.875rem 1.5rem;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 1rem;
|
||||||
|
font-weight: 600;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.2s;
|
||||||
|
box-shadow: 0 2px 4px rgba(239, 68, 68, 0.2);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-delete:hover:not(:disabled) {
|
||||||
|
background: #dc2626;
|
||||||
|
box-shadow: 0 4px 8px rgba(239, 68, 68, 0.3);
|
||||||
|
transform: translateY(-1px);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-delete:disabled {
|
||||||
|
background: #d1d5db;
|
||||||
|
cursor: not-allowed;
|
||||||
|
box-shadow: none;
|
||||||
|
transform: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.delete-info {
|
||||||
|
margin-top: 0.5rem;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
color: #666;
|
||||||
|
text-align: center;
|
||||||
|
font-style: italic;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-disconnect {
|
||||||
|
background: #ef4444;
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
padding: 0.5rem 1rem;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
font-weight: 500;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: background 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-disconnect:hover {
|
||||||
|
background: #dc2626;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-refresh {
|
||||||
|
background: #3b82f6;
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
padding: 0.4rem 0.8rem;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
font-weight: 500;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: background 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-refresh:hover:not(:disabled) {
|
||||||
|
background: #2563eb;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-refresh:disabled {
|
||||||
|
background: #d1d5db;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-panel {
|
||||||
|
border: 2px solid #10b981;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-banner {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 1rem;
|
||||||
|
padding: 1rem 1.5rem;
|
||||||
|
border-radius: 6px;
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
font-weight: 500;
|
||||||
|
font-size: 0.95rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-banner.initializing {
|
||||||
|
background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%);
|
||||||
|
color: #1e40af;
|
||||||
|
border-left: 4px solid #3b82f6;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-banner.encoding {
|
||||||
|
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||||
|
color: #92400e;
|
||||||
|
border-left: 4px solid #f59e0b;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-banner.uploading {
|
||||||
|
background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%);
|
||||||
|
color: #3730a3;
|
||||||
|
border-left: 4px solid #6366f1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-banner.success {
|
||||||
|
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
|
||||||
|
color: #065f46;
|
||||||
|
border-left: 4px solid #10b981;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-banner.warning {
|
||||||
|
background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%);
|
||||||
|
color: #991b1b;
|
||||||
|
border-left: 4px solid #ef4444;
|
||||||
|
}
|
||||||
|
|
||||||
|
.spinner {
|
||||||
|
width: 20px;
|
||||||
|
height: 20px;
|
||||||
|
border: 3px solid rgba(0, 0, 0, 0.1);
|
||||||
|
border-top-color: currentColor;
|
||||||
|
border-radius: 50%;
|
||||||
|
animation: spin 0.8s linear infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes spin {
|
||||||
|
to { transform: rotate(360deg); }
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-horizontal {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-left {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-right {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.input-group {
|
||||||
|
display: flex;
|
||||||
|
gap: 0.5rem;
|
||||||
|
margin-bottom: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
input[type="text"] {
|
||||||
|
flex: 1;
|
||||||
|
padding: 0.75rem;
|
||||||
|
border: 1px solid #ddd;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
input[type="text"]:disabled {
|
||||||
|
background: #f5f5f5;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
input[type="text"]:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #10b981;
|
||||||
|
}
|
||||||
|
|
||||||
|
button {
|
||||||
|
padding: 0.75rem 1.5rem;
|
||||||
|
border: none;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 1rem;
|
||||||
|
font-weight: 500;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-set-task {
|
||||||
|
background: #3b82f6;
|
||||||
|
color: white;
|
||||||
|
min-width: 120px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-set-task:hover:not(:disabled) {
|
||||||
|
background: #2563eb;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-set-task:disabled {
|
||||||
|
background: #d1d5db;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-start {
|
||||||
|
background: #10b981;
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-start:hover:not(:disabled) {
|
||||||
|
background: #059669;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-start:disabled {
|
||||||
|
background: #d1d5db;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-stop {
|
||||||
|
background: #ef4444;
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-stop:hover {
|
||||||
|
background: #dc2626;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-reset {
|
||||||
|
padding: 0.5rem 1rem;
|
||||||
|
background: #6b7280;
|
||||||
|
color: white;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-reset:hover {
|
||||||
|
background: #4b5563;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.75rem;
|
||||||
|
padding: 1rem;
|
||||||
|
border-radius: 4px;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status.recording {
|
||||||
|
background: #fee2e2;
|
||||||
|
color: #991b1b;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status.recording.recording-active {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 1rem;
|
||||||
|
background: #dc2626;
|
||||||
|
color: white;
|
||||||
|
padding: 1.5rem;
|
||||||
|
border: 4px solid #991b1b;
|
||||||
|
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.4);
|
||||||
|
font-weight: 700;
|
||||||
|
font-size: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status.recording.recording-active .indicator {
|
||||||
|
width: 20px;
|
||||||
|
height: 20px;
|
||||||
|
background: #fef2f2;
|
||||||
|
animation: pulse-strong 1s ease-in-out infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse-strong {
|
||||||
|
0%, 100% {
|
||||||
|
opacity: 1;
|
||||||
|
transform: scale(1);
|
||||||
|
}
|
||||||
|
50% {
|
||||||
|
opacity: 0.7;
|
||||||
|
transform: scale(1.1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.status.recording.recording-active .time-display {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.5rem;
|
||||||
|
font-size: 1.5rem;
|
||||||
|
font-weight: 700;
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.fps-display {
|
||||||
|
font-size: 1rem;
|
||||||
|
font-weight: 500;
|
||||||
|
opacity: 0.95;
|
||||||
|
}
|
||||||
|
|
||||||
|
.fps-warning {
|
||||||
|
color: #fef2f2;
|
||||||
|
animation: pulse-warning 1s ease-in-out infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse-warning {
|
||||||
|
0%, 100% { opacity: 1; }
|
||||||
|
50% { opacity: 0.5; }
|
||||||
|
}
|
||||||
|
|
||||||
|
.status.recording.recording-active .btn-stop {
|
||||||
|
align-self: stretch;
|
||||||
|
}
|
||||||
|
|
||||||
|
.ramp-up-countdown {
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.countdown-box {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
padding: 2rem 3rem;
|
||||||
|
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||||
|
border: 4px solid #f59e0b;
|
||||||
|
border-radius: 16px;
|
||||||
|
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
|
||||||
|
min-width: 280px;
|
||||||
|
animation: pulse-warm 1.5s ease-in-out infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse-warm {
|
||||||
|
0%, 100% {
|
||||||
|
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
|
||||||
|
}
|
||||||
|
50% {
|
||||||
|
box-shadow: 0 6px 25px rgba(245, 158, 11, 0.6);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.countdown-label {
|
||||||
|
font-size: 1rem;
|
||||||
|
color: #92400e;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 1.5px;
|
||||||
|
font-weight: 800;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.countdown-value {
|
||||||
|
font-size: 4.5rem;
|
||||||
|
font-weight: 900;
|
||||||
|
color: #d97706;
|
||||||
|
font-family: 'Courier New', monospace;
|
||||||
|
line-height: 1;
|
||||||
|
text-shadow: 2px 2px 6px rgba(0, 0, 0, 0.15);
|
||||||
|
margin-bottom: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.countdown-subtitle {
|
||||||
|
font-size: 0.875rem;
|
||||||
|
color: #78350f;
|
||||||
|
font-weight: 600;
|
||||||
|
font-style: italic;
|
||||||
|
text-align: center;
|
||||||
|
margin-top: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status.idle {
|
||||||
|
background: #f3f4f6;
|
||||||
|
color: #374151;
|
||||||
|
}
|
||||||
|
|
||||||
|
.indicator {
|
||||||
|
width: 12px;
|
||||||
|
height: 12px;
|
||||||
|
border-radius: 50%;
|
||||||
|
background: #ef4444;
|
||||||
|
animation: pulse 1.5s ease-in-out infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse {
|
||||||
|
0%, 100% { opacity: 1; }
|
||||||
|
50% { opacity: 0.5; }
|
||||||
|
}
|
||||||
|
|
||||||
|
.counter {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.75rem;
|
||||||
|
padding: 1.5rem;
|
||||||
|
background: linear-gradient(135deg, #f9fafb 0%, #f3f4f6 100%);
|
||||||
|
border-radius: 8px;
|
||||||
|
border: 2px solid #e5e7eb;
|
||||||
|
min-width: 200px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.counter-label {
|
||||||
|
font-size: 0.75rem;
|
||||||
|
color: #6b7280;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 0.5px;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
.counter-value {
|
||||||
|
font-size: 3rem;
|
||||||
|
font-weight: 700;
|
||||||
|
color: #10b981;
|
||||||
|
line-height: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.time-display {
|
||||||
|
font-size: 1.5rem;
|
||||||
|
font-weight: 600;
|
||||||
|
font-family: 'Courier New', monospace;
|
||||||
|
}
|
||||||
|
|
||||||
|
.error-box {
|
||||||
|
padding: 1rem;
|
||||||
|
background: #fee2e2;
|
||||||
|
color: #991b1b;
|
||||||
|
border-radius: 4px;
|
||||||
|
border-left: 4px solid #ef4444;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.config-section {
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.config-section:last-child {
|
||||||
|
margin-bottom: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.config-grid {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||||
|
gap: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
label {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.5rem;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
color: #374151;
|
||||||
|
font-weight: 500;
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
padding: 0.5rem;
|
||||||
|
border: 1px solid #ddd;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
background: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
select:disabled {
|
||||||
|
background: #f5f5f5;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Camera Layout */
|
||||||
|
.camera-layout {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.camera-base {
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.camera-wrist-container {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: repeat(2, 1fr);
|
||||||
|
gap: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.camera-wrist {
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.camera {
|
||||||
|
border: 1px solid #e5e7eb;
|
||||||
|
border-radius: 4px;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.camera h3 {
|
||||||
|
padding: 0.75rem;
|
||||||
|
background: #f9fafb;
|
||||||
|
border-bottom: 1px solid #e5e7eb;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.camera img {
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
display: block;
|
||||||
|
background: #000;
|
||||||
|
min-height: 300px;
|
||||||
|
object-fit: cover;
|
||||||
|
}
|
||||||
|
|
||||||
|
.camera-placeholder {
|
||||||
|
text-align: center;
|
||||||
|
padding: 4rem 2rem;
|
||||||
|
background: #f9fafb;
|
||||||
|
border-radius: 4px;
|
||||||
|
border: 2px dashed #d1d5db;
|
||||||
|
}
|
||||||
|
|
||||||
|
.camera-placeholder p {
|
||||||
|
margin: 0.5rem 0;
|
||||||
|
font-size: 1rem;
|
||||||
|
color: #6b7280;
|
||||||
|
}
|
||||||
|
|
||||||
|
.camera-placeholder p:first-child {
|
||||||
|
font-size: 1.25rem;
|
||||||
|
font-weight: 500;
|
||||||
|
color: #374151;
|
||||||
|
}
|
||||||
|
|
||||||
|
.hint {
|
||||||
|
margin-top: 0.5rem;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
color: #6b7280;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.5rem;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
}
|
||||||
|
|
||||||
857
examples/openarms_web_interface/App.jsx
Normal file
857
examples/openarms_web_interface/App.jsx
Normal file
@@ -0,0 +1,857 @@
|
|||||||
|
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||||
|
import './App.css';
|
||||||
|
|
||||||
|
const API_BASE = 'http://localhost:8000/api';
|
||||||
|
|
||||||
|
function App() {
|
||||||
|
// State
|
||||||
|
const [task, setTask] = useState('');
|
||||||
|
const [isRecording, setIsRecording] = useState(false);
|
||||||
|
const [isInitializing, setIsInitializing] = useState(false);
|
||||||
|
const [isEncoding, setIsEncoding] = useState(false);
|
||||||
|
const [isUploading, setIsUploading] = useState(false);
|
||||||
|
const [robotsReady, setRobotsReady] = useState(false);
|
||||||
|
const [elapsedTime, setElapsedTime] = useState(0);
|
||||||
|
const [currentFps, setCurrentFps] = useState(0);
|
||||||
|
const [loopFps, setLoopFps] = useState(0);
|
||||||
|
const [episodeCount, setEpisodeCount] = useState(0);
|
||||||
|
const [error, setError] = useState(null);
|
||||||
|
const [statusMessage, setStatusMessage] = useState('Ready');
|
||||||
|
const [uploadStatus, setUploadStatus] = useState(null);
|
||||||
|
const [rampUpRemaining, setRampUpRemaining] = useState(0);
|
||||||
|
const [movingToZero, setMovingToZero] = useState(false);
|
||||||
|
const [configExpanded, setConfigExpanded] = useState(false);
|
||||||
|
const [latestRepoId, setLatestRepoId] = useState(null);
|
||||||
|
|
||||||
|
// Configuration
|
||||||
|
const [config, setConfig] = useState({
|
||||||
|
leader_type: 'openarms', // 'openarms' or 'openarms_mini'
|
||||||
|
leader_left: 'can0',
|
||||||
|
leader_right: 'can1',
|
||||||
|
follower_left: 'can2',
|
||||||
|
follower_right: 'can3',
|
||||||
|
left_wrist: '/dev/video0',
|
||||||
|
right_wrist: '/dev/video1',
|
||||||
|
base: '/dev/video4'
|
||||||
|
});
|
||||||
|
|
||||||
|
// Available options
|
||||||
|
const [availableCameras, setAvailableCameras] = useState([]);
|
||||||
|
const [availableUsbPorts, setAvailableUsbPorts] = useState([]);
|
||||||
|
const canInterfaces = ['can0', 'can1', 'can2', 'can3'];
|
||||||
|
|
||||||
|
const statusIntervalRef = useRef(null);
|
||||||
|
const hasInitializedRef = useRef(false);
|
||||||
|
|
||||||
|
const loadConfig = () => {
|
||||||
|
try {
|
||||||
|
const saved = localStorage.getItem('openarms_config');
|
||||||
|
if (saved) {
|
||||||
|
const loadedConfig = JSON.parse(saved);
|
||||||
|
setConfig(prev => ({ ...prev, ...loadedConfig }));
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Load config error:', e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const saveConfig = (newConfig) => {
|
||||||
|
try {
|
||||||
|
localStorage.setItem('openarms_config', JSON.stringify(newConfig || config));
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Save config error:', e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Fetch status periodically
|
||||||
|
const fetchStatus = async () => {
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${API_BASE}/status`);
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
setIsRecording(data.is_recording);
|
||||||
|
setIsInitializing(data.is_initializing);
|
||||||
|
setIsEncoding(data.is_encoding);
|
||||||
|
setIsUploading(data.is_uploading);
|
||||||
|
setRobotsReady(data.robots_ready);
|
||||||
|
setElapsedTime(data.elapsed_time);
|
||||||
|
setCurrentFps(data.current_fps || 0);
|
||||||
|
setLoopFps(data.loop_fps || 0);
|
||||||
|
setEpisodeCount(data.episode_count);
|
||||||
|
setError(data.error);
|
||||||
|
setStatusMessage(data.status_message || 'Ready');
|
||||||
|
setUploadStatus(data.upload_status);
|
||||||
|
setRampUpRemaining(data.ramp_up_remaining || 0);
|
||||||
|
setMovingToZero(data.moving_to_zero || false);
|
||||||
|
|
||||||
|
// Track the latest repo_id from the backend
|
||||||
|
if (data.latest_repo_id) {
|
||||||
|
setLatestRepoId(data.latest_repo_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data.config) {
|
||||||
|
// Only merge server config if we don't have a saved config (first load)
|
||||||
|
if (!localStorage.getItem('openarms_config')) {
|
||||||
|
setConfig(prev => {
|
||||||
|
const merged = { ...data.config, ...prev };
|
||||||
|
localStorage.setItem('openarms_config', JSON.stringify(merged));
|
||||||
|
return merged;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Failed to fetch status:', e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const setupRobots = async () => {
|
||||||
|
// Show warning to verify camera positions
|
||||||
|
const confirmed = window.confirm(
|
||||||
|
'⚠️ IMPORTANT: Before connecting robots, please verify:\n\n' +
|
||||||
|
'📹 Check that cameras are correctly positioned:\n' +
|
||||||
|
' • LEFT wrist camera is actually on the LEFT arm\n' +
|
||||||
|
' • RIGHT wrist camera is actually on the RIGHT arm\n' +
|
||||||
|
' • BASE camera is actually the BASE/overhead camera\n\n' +
|
||||||
|
'Incorrect camera positioning will result in invalid training data!\n\n' +
|
||||||
|
'Click OK to continue with robot setup, or Cancel to review configuration.'
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!confirmed) {
|
||||||
|
return; // User cancelled, don't proceed
|
||||||
|
}
|
||||||
|
|
||||||
|
setError(null);
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${API_BASE}/robots/setup`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify(config)
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const data = await response.json();
|
||||||
|
throw new Error(data.detail || 'Failed to setup robots');
|
||||||
|
}
|
||||||
|
|
||||||
|
await response.json();
|
||||||
|
saveConfig(config);
|
||||||
|
} catch (e) {
|
||||||
|
setError(`Robot setup failed: ${e.message}`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Disconnect robots
|
||||||
|
const disconnectRobots = async () => {
|
||||||
|
try {
|
||||||
|
await fetch(`${API_BASE}/robots/disconnect`, { method: 'POST' });
|
||||||
|
setRobotsReady(false);
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Failed to disconnect robots:', e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Discover cameras
|
||||||
|
const discoverCameras = async () => {
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${API_BASE}/cameras/discover`);
|
||||||
|
const data = await response.json();
|
||||||
|
const cameras = data.cameras || [];
|
||||||
|
setAvailableCameras(cameras);
|
||||||
|
|
||||||
|
// Get list of valid camera IDs
|
||||||
|
const validCameraIds = cameras.map(cam => String(cam.id));
|
||||||
|
|
||||||
|
// Auto-fix config if current values are invalid or not set
|
||||||
|
const updated = { ...config };
|
||||||
|
let changed = false;
|
||||||
|
|
||||||
|
// Auto-fix invalid camera config
|
||||||
|
if (!config.left_wrist || !validCameraIds.includes(config.left_wrist)) {
|
||||||
|
if (cameras.length >= 1) {
|
||||||
|
updated.left_wrist = String(cameras[0].id);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.right_wrist || !validCameraIds.includes(config.right_wrist)) {
|
||||||
|
if (cameras.length >= 2) {
|
||||||
|
updated.right_wrist = String(cameras[1].id);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.base || !validCameraIds.includes(config.base)) {
|
||||||
|
if (cameras.length >= 3) {
|
||||||
|
updated.base = String(cameras[2].id);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (changed) {
|
||||||
|
setConfig(updated);
|
||||||
|
saveConfig(updated);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cameras.length === 0) {
|
||||||
|
setError('No cameras detected! Please connect cameras and refresh.');
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Failed to discover cameras:', e);
|
||||||
|
setError(`Camera discovery failed: ${e.message}`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Discover USB ports
|
||||||
|
const discoverUsbPorts = async () => {
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${API_BASE}/usb/discover`);
|
||||||
|
const data = await response.json();
|
||||||
|
const ports = data.ports || [];
|
||||||
|
setAvailableUsbPorts(ports);
|
||||||
|
|
||||||
|
// Auto-fix config if OpenArms Mini is selected and ports are invalid
|
||||||
|
if (config.leader_type === 'openarms_mini') {
|
||||||
|
const updated = { ...config };
|
||||||
|
let changed = false;
|
||||||
|
|
||||||
|
if (ports.length >= 1 && !ports.includes(config.leader_left)) {
|
||||||
|
updated.leader_left = ports[0];
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ports.length >= 2 && !ports.includes(config.leader_right)) {
|
||||||
|
updated.leader_right = ports[1];
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (changed) {
|
||||||
|
setConfig(updated);
|
||||||
|
saveConfig(updated);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ports.length === 0) {
|
||||||
|
console.warn('No USB ports detected for OpenArms Mini');
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Failed to discover USB ports:', e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Set task only (for pedal use)
|
||||||
|
const setTaskOnly = async () => {
|
||||||
|
if (!task.trim()) {
|
||||||
|
setError('Please enter a task description');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setError(null);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${API_BASE}/recording/set-task`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({ task, ...config })
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const data = await response.json();
|
||||||
|
throw new Error(data.detail || 'Failed to set task');
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = await response.json();
|
||||||
|
setStatusMessage(result.message || `Task set: ${task}`);
|
||||||
|
saveConfig(config);
|
||||||
|
|
||||||
|
// Clear success message after 3 seconds
|
||||||
|
setTimeout(() => {
|
||||||
|
if (!isRecording && !isInitializing) {
|
||||||
|
setStatusMessage('Ready');
|
||||||
|
}
|
||||||
|
}, 3000);
|
||||||
|
} catch (e) {
|
||||||
|
setError(e.message);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Start recording
|
||||||
|
const startRecording = async () => {
|
||||||
|
if (!task.trim()) {
|
||||||
|
setError('Please enter a task description');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setError(null);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${API_BASE}/recording/start`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({ task, ...config })
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const data = await response.json();
|
||||||
|
throw new Error(data.detail || 'Failed to start recording');
|
||||||
|
}
|
||||||
|
|
||||||
|
await response.json();
|
||||||
|
saveConfig(config);
|
||||||
|
} catch (e) {
|
||||||
|
setError(e.message);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Stop recording
|
||||||
|
const stopRecording = async () => {
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${API_BASE}/recording/stop`, {
|
||||||
|
method: 'POST'
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const data = await response.json();
|
||||||
|
throw new Error(data.detail || 'Failed to stop recording');
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
setError(null);
|
||||||
|
// Update latest repo_id after recording
|
||||||
|
if (data.dataset_name) {
|
||||||
|
setLatestRepoId(`lerobot-data-collection/${data.dataset_name}`);
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
setError(e.message);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const deleteLatestEpisode = async () => {
|
||||||
|
if (!latestRepoId) {
|
||||||
|
setError('No episode to delete');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const confirmed = window.confirm(
|
||||||
|
`WARNING: This will permanently delete the repository:\n\n${latestRepoId}\n\nThis action cannot be undone. Continue?`
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!confirmed) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${API_BASE}/recording/delete-latest`, { method: 'POST' });
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const data = await response.json();
|
||||||
|
throw new Error(data.detail || 'Failed to delete episode');
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
setLatestRepoId(null);
|
||||||
|
setEpisodeCount(Math.max(0, episodeCount - 1));
|
||||||
|
setStatusMessage(`Deleted: ${data.deleted_repo}`);
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
if (!isRecording && !isInitializing) {
|
||||||
|
setStatusMessage('Ready');
|
||||||
|
}
|
||||||
|
}, 3000);
|
||||||
|
} catch (e) {
|
||||||
|
setError(`Delete failed: ${e.message}`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Reset counter
|
||||||
|
const resetCounter = async () => {
|
||||||
|
try {
|
||||||
|
await fetch(`${API_BASE}/counter/reset`, { method: 'POST' });
|
||||||
|
setEpisodeCount(0);
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Failed to reset counter:', e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Move robot to zero position
|
||||||
|
const moveToZero = async () => {
|
||||||
|
setError(null);
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${API_BASE}/robots/move-to-zero`, { method: 'POST' });
|
||||||
|
if (!response.ok) {
|
||||||
|
const data = await response.json();
|
||||||
|
throw new Error(data.detail || 'Failed to move to zero position');
|
||||||
|
}
|
||||||
|
await response.json();
|
||||||
|
} catch (e) {
|
||||||
|
setError(`Move to zero failed: ${e.message}`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Format time as MM:SS
|
||||||
|
const formatTime = (seconds) => {
|
||||||
|
const mins = Math.floor(seconds / 60);
|
||||||
|
const secs = Math.floor(seconds % 60);
|
||||||
|
return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}`;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Update config and save
|
||||||
|
const updateConfig = (key, value) => {
|
||||||
|
const updated = { ...config, [key]: value };
|
||||||
|
setConfig(updated);
|
||||||
|
saveConfig(updated);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Initialize on mount only
|
||||||
|
useEffect(() => {
|
||||||
|
// Prevent double-initialization in development
|
||||||
|
if (hasInitializedRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
hasInitializedRef.current = true;
|
||||||
|
|
||||||
|
loadConfig();
|
||||||
|
discoverCameras();
|
||||||
|
discoverUsbPorts();
|
||||||
|
fetchStatus();
|
||||||
|
statusIntervalRef.current = setInterval(fetchStatus, 1000);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
if (statusIntervalRef.current) {
|
||||||
|
clearInterval(statusIntervalRef.current);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, []); // Run only once on mount
|
||||||
|
|
||||||
|
// Discover USB ports when leader type changes to Mini
|
||||||
|
useEffect(() => {
|
||||||
|
if (config.leader_type === 'openarms_mini') {
|
||||||
|
discoverUsbPorts();
|
||||||
|
}
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, [config.leader_type]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<main>
|
||||||
|
<header>
|
||||||
|
<h1>OpenArms Recording</h1>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
<div className="container">
|
||||||
|
{/* Left Column: Configuration and Recording Control */}
|
||||||
|
<div className="left-column">
|
||||||
|
{/* Configuration Panel */}
|
||||||
|
<section className="panel config-panel">
|
||||||
|
<div
|
||||||
|
className="config-header"
|
||||||
|
onClick={() => setConfigExpanded(!configExpanded)}
|
||||||
|
role="button"
|
||||||
|
tabIndex={0}
|
||||||
|
onKeyDown={(e) => e.key === 'Enter' && setConfigExpanded(!configExpanded)}
|
||||||
|
>
|
||||||
|
<h2>⚙️ Configuration</h2>
|
||||||
|
<span className="toggle-icon">{configExpanded ? '▼' : '▶'}</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{configExpanded && (
|
||||||
|
<div className="config-content">
|
||||||
|
{/* Robot Setup */}
|
||||||
|
<div className="config-section">
|
||||||
|
<h3>🤖 Robot Setup</h3>
|
||||||
|
<div className="robot-setup">
|
||||||
|
{robotsReady ? (
|
||||||
|
<div className="robot-status ready">
|
||||||
|
<span>✅ Robots Ready - Recording will start instantly</span>
|
||||||
|
<button onClick={disconnectRobots} className="btn-disconnect">
|
||||||
|
Disconnect Robots
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="robot-status not-ready">
|
||||||
|
<span>⚠️ Robots not initialized - Recording will take ~10 seconds</span>
|
||||||
|
<button
|
||||||
|
onClick={setupRobots}
|
||||||
|
disabled={isRecording || isInitializing}
|
||||||
|
className="btn-setup"
|
||||||
|
>
|
||||||
|
🚀 Setup Robots
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Leader Type Selection */}
|
||||||
|
<div className="config-section">
|
||||||
|
<h3>🎮 Leader Type</h3>
|
||||||
|
<div className="config-grid">
|
||||||
|
<label style={{gridColumn: '1 / -1'}}>
|
||||||
|
Leader Arm Type
|
||||||
|
<select
|
||||||
|
value={config.leader_type}
|
||||||
|
onChange={(e) => updateConfig('leader_type', e.target.value)}
|
||||||
|
disabled={isRecording || robotsReady}
|
||||||
|
>
|
||||||
|
<option value="openarms">OpenArms (CAN Bus - Damiao Motors)</option>
|
||||||
|
<option value="openarms_mini">OpenArms Mini (USB - Feetech Motors)</option>
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Leader Interfaces (CAN or USB based on type) */}
|
||||||
|
<div className="config-section">
|
||||||
|
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
|
||||||
|
<h3>
|
||||||
|
{config.leader_type === 'openarms_mini'
|
||||||
|
? `Leader Ports (USB/Serial) ${availableUsbPorts.length > 0 ? `(${availableUsbPorts.length} detected)` : ''}`
|
||||||
|
: 'Leader Interfaces (CAN)'}
|
||||||
|
</h3>
|
||||||
|
{config.leader_type === 'openarms_mini' && (
|
||||||
|
<button
|
||||||
|
onClick={discoverUsbPorts}
|
||||||
|
className="btn-refresh"
|
||||||
|
disabled={isRecording || robotsReady}
|
||||||
|
>
|
||||||
|
🔄 Refresh
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="config-grid">
|
||||||
|
<label>
|
||||||
|
Leader Left
|
||||||
|
<select
|
||||||
|
value={config.leader_left}
|
||||||
|
onChange={(e) => updateConfig('leader_left', e.target.value)}
|
||||||
|
disabled={isRecording || robotsReady}
|
||||||
|
>
|
||||||
|
{config.leader_type === 'openarms_mini' ? (
|
||||||
|
availableUsbPorts.length > 0 ? (
|
||||||
|
availableUsbPorts.map((port) => (
|
||||||
|
<option key={port} value={port}>{port}</option>
|
||||||
|
))
|
||||||
|
) : (
|
||||||
|
<option value="">No USB ports detected</option>
|
||||||
|
)
|
||||||
|
) : (
|
||||||
|
canInterfaces.map((iface) => (
|
||||||
|
<option key={iface} value={iface}>{iface}</option>
|
||||||
|
))
|
||||||
|
)}
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
|
||||||
|
<label>
|
||||||
|
Leader Right
|
||||||
|
<select
|
||||||
|
value={config.leader_right}
|
||||||
|
onChange={(e) => updateConfig('leader_right', e.target.value)}
|
||||||
|
disabled={isRecording || robotsReady}
|
||||||
|
>
|
||||||
|
{config.leader_type === 'openarms_mini' ? (
|
||||||
|
availableUsbPorts.length > 0 ? (
|
||||||
|
availableUsbPorts.map((port) => (
|
||||||
|
<option key={port} value={port}>{port}</option>
|
||||||
|
))
|
||||||
|
) : (
|
||||||
|
<option value="">No USB ports detected</option>
|
||||||
|
)
|
||||||
|
) : (
|
||||||
|
canInterfaces.map((iface) => (
|
||||||
|
<option key={iface} value={iface}>{iface}</option>
|
||||||
|
))
|
||||||
|
)}
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Follower CAN Interfaces */}
|
||||||
|
<div className="config-section">
|
||||||
|
<h3>Follower Interfaces (CAN)</h3>
|
||||||
|
|
||||||
|
<div className="config-grid">
|
||||||
|
<label>
|
||||||
|
Follower Left
|
||||||
|
<select
|
||||||
|
value={config.follower_left}
|
||||||
|
onChange={(e) => updateConfig('follower_left', e.target.value)}
|
||||||
|
disabled={isRecording || robotsReady}
|
||||||
|
>
|
||||||
|
{canInterfaces.map((iface) => (
|
||||||
|
<option key={iface} value={iface}>{iface}</option>
|
||||||
|
))}
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
|
||||||
|
<label>
|
||||||
|
Follower Right
|
||||||
|
<select
|
||||||
|
value={config.follower_right}
|
||||||
|
onChange={(e) => updateConfig('follower_right', e.target.value)}
|
||||||
|
disabled={isRecording || robotsReady}
|
||||||
|
>
|
||||||
|
{canInterfaces.map((iface) => (
|
||||||
|
<option key={iface} value={iface}>{iface}</option>
|
||||||
|
))}
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Camera Configuration */}
|
||||||
|
<div className="config-section">
|
||||||
|
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
|
||||||
|
<h3>Cameras {availableCameras.length > 0 && `(${availableCameras.length} detected)`}</h3>
|
||||||
|
<button
|
||||||
|
onClick={discoverCameras}
|
||||||
|
className="btn-refresh"
|
||||||
|
disabled={isRecording || robotsReady}
|
||||||
|
>
|
||||||
|
🔄 Refresh
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div className="config-grid">
|
||||||
|
<label>
|
||||||
|
Left Wrist
|
||||||
|
<select
|
||||||
|
value={config.left_wrist}
|
||||||
|
onChange={(e) => updateConfig('left_wrist', e.target.value)}
|
||||||
|
disabled={isRecording || robotsReady}
|
||||||
|
>
|
||||||
|
{availableCameras.map((cam) => (
|
||||||
|
<option key={cam.id} value={String(cam.id)}>
|
||||||
|
{cam.name || `Camera @ ${cam.id}`}
|
||||||
|
</option>
|
||||||
|
))}
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
|
||||||
|
<label>
|
||||||
|
Right Wrist
|
||||||
|
<select
|
||||||
|
value={config.right_wrist}
|
||||||
|
onChange={(e) => updateConfig('right_wrist', e.target.value)}
|
||||||
|
disabled={isRecording || robotsReady}
|
||||||
|
>
|
||||||
|
{availableCameras.map((cam) => (
|
||||||
|
<option key={cam.id} value={String(cam.id)}>
|
||||||
|
{cam.name || `Camera @ ${cam.id}`}
|
||||||
|
</option>
|
||||||
|
))}
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
|
||||||
|
<label>
|
||||||
|
Base Camera
|
||||||
|
<select
|
||||||
|
value={config.base}
|
||||||
|
onChange={(e) => updateConfig('base', e.target.value)}
|
||||||
|
disabled={isRecording || robotsReady}
|
||||||
|
>
|
||||||
|
{availableCameras.map((cam) => (
|
||||||
|
<option key={cam.id} value={String(cam.id)}>
|
||||||
|
{cam.name || `Camera @ ${cam.id}`}
|
||||||
|
</option>
|
||||||
|
))}
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</section>
|
||||||
|
|
||||||
|
{/* Control Panel */}
|
||||||
|
<section className="panel control-panel">
|
||||||
|
<h2>🎬 Recording Control</h2>
|
||||||
|
|
||||||
|
{/* Status Banner - Always show important statuses */}
|
||||||
|
{isInitializing && (
|
||||||
|
<div className="status-banner initializing">
|
||||||
|
<div className="spinner"></div>
|
||||||
|
<span>{statusMessage}</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{isEncoding && (
|
||||||
|
<div className="status-banner encoding">
|
||||||
|
<div className="spinner"></div>
|
||||||
|
<span>📹 {statusMessage}</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{isUploading && (
|
||||||
|
<div className="status-banner uploading">
|
||||||
|
<div className="spinner"></div>
|
||||||
|
<span>☁️ {statusMessage}</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{uploadStatus && !isRecording && !isEncoding && !isUploading && (
|
||||||
|
<div className={`status-banner ${uploadStatus.startsWith('✓') ? 'success' : 'warning'}`}>
|
||||||
|
<span>{uploadStatus}</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="control-horizontal">
|
||||||
|
{/* Task Input and Status */}
|
||||||
|
<div className="control-left">
|
||||||
|
<div className="input-group">
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={task}
|
||||||
|
onChange={(e) => setTask(e.target.value)}
|
||||||
|
placeholder="Task description (e.g., 'pick and place')"
|
||||||
|
disabled={isRecording || isInitializing || isEncoding || isUploading}
|
||||||
|
onKeyPress={(e) => {
|
||||||
|
if (e.key === 'Enter' && robotsReady) {
|
||||||
|
setTaskOnly();
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
onClick={setTaskOnly}
|
||||||
|
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
|
||||||
|
className="btn-set-task"
|
||||||
|
title={!robotsReady ? 'Please setup robots first' : 'Store task for pedal use (Enter key)'}
|
||||||
|
>
|
||||||
|
💾 Set Task
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onClick={startRecording}
|
||||||
|
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
|
||||||
|
className="btn-start"
|
||||||
|
title={!robotsReady ? 'Please setup robots first' : ''}
|
||||||
|
>
|
||||||
|
{isInitializing
|
||||||
|
? '⏳ Initializing...'
|
||||||
|
: isRecording
|
||||||
|
? '⏺ Recording...'
|
||||||
|
: robotsReady
|
||||||
|
? '⏺ Start Recording'
|
||||||
|
: '⏺ Setup Robots First'}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Ramp-up Countdown */}
|
||||||
|
{isRecording && rampUpRemaining > 0 && (
|
||||||
|
<div className="ramp-up-countdown">
|
||||||
|
<div className="countdown-box">
|
||||||
|
<div className="countdown-label">⚡ WARMING UP - PID RAMP-UP</div>
|
||||||
|
<div className="countdown-value">{rampUpRemaining.toFixed(1)}s</div>
|
||||||
|
<div className="countdown-subtitle">Recording will start automatically...</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Recording Status - Only show after ramp-up */}
|
||||||
|
{isRecording && rampUpRemaining <= 0 && (
|
||||||
|
<div className="status recording recording-active">
|
||||||
|
<div className="indicator"></div>
|
||||||
|
<div className="time-display">
|
||||||
|
<span>{formatTime(elapsedTime)}</span>
|
||||||
|
<span className="fps-display">
|
||||||
|
Loop: {loopFps.toFixed(1)} Hz
|
||||||
|
{loopFps > 0 && loopFps < 29 && <span className="fps-warning"> ⚠️</span>}
|
||||||
|
</span>
|
||||||
|
<span className="fps-display">Recording: {currentFps.toFixed(1)} FPS</span>
|
||||||
|
</div>
|
||||||
|
<button onClick={stopRecording} className="btn-stop">
|
||||||
|
⏹ Stop
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Episode Counter */}
|
||||||
|
<div className="control-right">
|
||||||
|
<div className="counter">
|
||||||
|
<div className="counter-label">Episodes Recorded</div>
|
||||||
|
<div className="counter-value">{episodeCount}</div>
|
||||||
|
<button onClick={resetCounter} className="btn-reset">
|
||||||
|
Reset
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Delete Latest Episode Button */}
|
||||||
|
{!isRecording && !isInitializing && latestRepoId && (
|
||||||
|
<div className="delete-episode-section">
|
||||||
|
<button
|
||||||
|
onClick={deleteLatestEpisode}
|
||||||
|
className="btn-delete"
|
||||||
|
title="Delete the latest recorded episode from HuggingFace Hub"
|
||||||
|
>
|
||||||
|
Delete Latest Episode
|
||||||
|
</button>
|
||||||
|
<div className="delete-info">Will delete: {latestRepoId}</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Move to Zero Button */}
|
||||||
|
{robotsReady && !isRecording && !isInitializing && (
|
||||||
|
<div className="zero-position-section">
|
||||||
|
<button
|
||||||
|
onClick={moveToZero}
|
||||||
|
disabled={movingToZero}
|
||||||
|
className="btn-zero-large"
|
||||||
|
title="Move both leader and follower robots to zero position (2s)"
|
||||||
|
>
|
||||||
|
{movingToZero ? '⏳ Moving to Zero Position...' : '🎯 Move to Zero Position (Leader + Follower)'}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Error Display */}
|
||||||
|
{error && (
|
||||||
|
<div className="error-box">
|
||||||
|
⚠️ {error}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</section>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Right Column: Camera Feeds */}
|
||||||
|
<div className="right-column">
|
||||||
|
<section className="panel cameras">
|
||||||
|
<h2>📹 Camera Views</h2>
|
||||||
|
{robotsReady || isRecording || isInitializing ? (
|
||||||
|
<div className="camera-layout">
|
||||||
|
{/* Base camera - full width */}
|
||||||
|
<div className="camera camera-base">
|
||||||
|
<h3>Base Camera</h3>
|
||||||
|
<img src={`${API_BASE}/camera/stream/base`} alt="Base Camera" />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Wrist cameras - side by side */}
|
||||||
|
<div className="camera-wrist-container">
|
||||||
|
<div className="camera camera-wrist">
|
||||||
|
<h3>Left Wrist</h3>
|
||||||
|
<img src={`${API_BASE}/camera/stream/left_wrist`} alt="Left Wrist Camera" />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="camera camera-wrist">
|
||||||
|
<h3>Right Wrist</h3>
|
||||||
|
<img src={`${API_BASE}/camera/stream/right_wrist`} alt="Right Wrist Camera" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="camera-placeholder">
|
||||||
|
<p>📷 Camera feeds will appear when robots are set up</p>
|
||||||
|
<p className="hint">Click "Setup Robots" above to preview camera feeds</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</section>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</main>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default App;
|
||||||
|
|
||||||
41
examples/openarms_web_interface/README.md
Normal file
41
examples/openarms_web_interface/README.md
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# OpenArms Web Recording Interface
|
||||||
|
|
||||||
|
A web interface for recording OpenArms datasets.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd examples/openarms_web_interface
|
||||||
|
npm install
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
**Start everything with one command:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./launch.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
This will:
|
||||||
|
- Start the FastAPI backend on port 8000
|
||||||
|
- Start the React frontend on port 5173
|
||||||
|
- Show live logs from both services
|
||||||
|
|
||||||
|
Then open your browser to: **http://localhost:5173**
|
||||||
|
|
||||||
|
**Stop with:** `Ctrl+C`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Workflow
|
||||||
|
|
||||||
|
1. **Configure CAN interfaces** and **camera paths** in the dropdowns
|
||||||
|
2. Click **"Setup Robots"** to initialize (once at start)
|
||||||
|
3. Enter a **task description**
|
||||||
|
4. Click **"Start Recording"** to begin an episode
|
||||||
|
5. Click **"Stop Recording"** when done
|
||||||
|
6. Dataset is automatically encoded and uploaded to HuggingFace Hub as **private**
|
||||||
|
7. Repeat steps 3-6 for more episodes (no need to re-setup robots!)
|
||||||
|
|
||||||
|
---
|
||||||
12
examples/openarms_web_interface/index.html
Normal file
12
examples/openarms_web_interface/index.html
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<title>OpenArms Recording Interface</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="root"></div>
|
||||||
|
<script type="module" src="/main.jsx"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
142
examples/openarms_web_interface/launch.sh
Executable file
142
examples/openarms_web_interface/launch.sh
Executable file
@@ -0,0 +1,142 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# OpenArms Web Interface Launcher
|
||||||
|
# Starts Rerun viewer, FastAPI backend, and React frontend
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Colors for output
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
BLUE='\033[0;34m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
RED='\033[0;31m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
# Get script directory
|
||||||
|
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||||
|
cd "$SCRIPT_DIR"
|
||||||
|
|
||||||
|
echo -e "${BLUE}╔════════════════════════════════════════╗${NC}"
|
||||||
|
echo -e "${BLUE}║ OpenArms Web Recording Interface ║${NC}"
|
||||||
|
echo -e "${BLUE}╚════════════════════════════════════════╝${NC}"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Function to cleanup on exit
|
||||||
|
cleanup() {
|
||||||
|
echo ""
|
||||||
|
echo -e "${YELLOW}Shutting down services...${NC}"
|
||||||
|
|
||||||
|
# Kill all child processes
|
||||||
|
pkill -P $$ 2>/dev/null || true
|
||||||
|
|
||||||
|
# Kill specific services by port
|
||||||
|
lsof -ti:8000 | xargs kill -9 2>/dev/null || true # Backend
|
||||||
|
lsof -ti:5173 | xargs kill -9 2>/dev/null || true # Frontend
|
||||||
|
lsof -ti:9876 | xargs kill -9 2>/dev/null || true # Rerun (if spawned)
|
||||||
|
|
||||||
|
echo -e "${GREEN}✓ Services stopped${NC}"
|
||||||
|
exit 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# Register cleanup on script exit
|
||||||
|
trap cleanup EXIT INT TERM
|
||||||
|
|
||||||
|
# Check if required commands exist
|
||||||
|
command -v rerun >/dev/null 2>&1 || {
|
||||||
|
echo -e "${RED}✗ Error: 'rerun' not found. Please install: pip install rerun-sdk${NC}"
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
command -v python >/dev/null 2>&1 || {
|
||||||
|
echo -e "${RED}✗ Error: 'python' not found${NC}"
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
command -v npm >/dev/null 2>&1 || {
|
||||||
|
echo -e "${RED}✗ Error: 'npm' not found${NC}"
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if node_modules exists
|
||||||
|
if [ ! -d "node_modules" ]; then
|
||||||
|
echo -e "${YELLOW}⚠ node_modules not found. Running npm install...${NC}"
|
||||||
|
npm install
|
||||||
|
echo -e "${GREEN}✓ Dependencies installed${NC}"
|
||||||
|
echo ""
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo -e "${GREEN}Starting services...${NC}"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# 1. Start FastAPI backend (Rerun will start when recording begins)
|
||||||
|
echo -e "${BLUE}[1/2]${NC} Starting FastAPI backend on port 8000..."
|
||||||
|
cd "$SCRIPT_DIR"
|
||||||
|
|
||||||
|
# Use Python from current environment (if lerobot env is active, it will use that)
|
||||||
|
# Otherwise, check if we need to use conda run
|
||||||
|
if [[ "$CONDA_DEFAULT_ENV" == "lerobot" ]]; then
|
||||||
|
# Already in lerobot environment
|
||||||
|
echo -e "${GREEN}✓ Using active lerobot environment${NC}"
|
||||||
|
PYTHON_CMD="python"
|
||||||
|
elif command -v conda >/dev/null 2>&1 && conda env list | grep -q "^lerobot "; then
|
||||||
|
# lerobot env exists but not active - use conda run
|
||||||
|
echo -e "${YELLOW}Using conda run with lerobot environment...${NC}"
|
||||||
|
PYTHON_CMD="conda run -n lerobot --no-capture-output python"
|
||||||
|
else
|
||||||
|
# Fall back to system python
|
||||||
|
echo -e "${YELLOW}⚠ Warning: lerobot environment not found, using system python${NC}"
|
||||||
|
PYTHON_CMD="python"
|
||||||
|
fi
|
||||||
|
|
||||||
|
$PYTHON_CMD web_record_server.py > /tmp/openarms_backend.log 2>&1 &
|
||||||
|
BACKEND_PID=$!
|
||||||
|
sleep 3
|
||||||
|
|
||||||
|
if ps -p $BACKEND_PID > /dev/null; then
|
||||||
|
echo -e "${GREEN}✓ Backend started${NC} (PID: $BACKEND_PID)"
|
||||||
|
echo -e " URL: ${BLUE}http://localhost:8000${NC}"
|
||||||
|
else
|
||||||
|
echo -e "${RED}✗ Failed to start backend${NC}"
|
||||||
|
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_backend.log${NC}"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# 2. Start React frontend
|
||||||
|
echo -e "${BLUE}[2/2]${NC} Starting React frontend on port 5173..."
|
||||||
|
cd "$SCRIPT_DIR"
|
||||||
|
npm run dev > /tmp/openarms_frontend.log 2>&1 &
|
||||||
|
FRONTEND_PID=$!
|
||||||
|
sleep 3
|
||||||
|
|
||||||
|
if ps -p $FRONTEND_PID > /dev/null; then
|
||||||
|
echo -e "${GREEN}✓ Frontend started${NC} (PID: $FRONTEND_PID)"
|
||||||
|
echo -e " URL: ${BLUE}http://localhost:5173${NC}"
|
||||||
|
else
|
||||||
|
echo -e "${RED}✗ Failed to start frontend${NC}"
|
||||||
|
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_frontend.log${NC}"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Display status
|
||||||
|
echo -e "${GREEN}╔════════════════════════════════════════╗${NC}"
|
||||||
|
echo -e "${GREEN}║ All services running! 🚀 ║${NC}"
|
||||||
|
echo -e "${GREEN}╚════════════════════════════════════════╝${NC}"
|
||||||
|
echo ""
|
||||||
|
echo -e "🔧 ${BLUE}Backend:${NC} http://localhost:8000"
|
||||||
|
echo -e "🌐 ${BLUE}Frontend:${NC} http://localhost:5173"
|
||||||
|
echo -e "📊 ${BLUE}Rerun:${NC} Will spawn automatically when recording starts"
|
||||||
|
echo ""
|
||||||
|
echo -e "${YELLOW}Open your browser to:${NC} ${BLUE}http://localhost:5173${NC}"
|
||||||
|
echo ""
|
||||||
|
echo -e "${YELLOW}Logs:${NC}"
|
||||||
|
echo -e " • Backend: tail -f /tmp/openarms_backend.log"
|
||||||
|
echo -e " • Frontend: tail -f /tmp/openarms_frontend.log"
|
||||||
|
echo ""
|
||||||
|
echo -e "${RED}Press Ctrl+C to stop all services${NC}"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Keep script running and wait for any service to exit
|
||||||
|
wait
|
||||||
|
|
||||||
7
examples/openarms_web_interface/main.jsx
Normal file
7
examples/openarms_web_interface/main.jsx
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
import { createRoot } from 'react-dom/client'
|
||||||
|
import App from './App.jsx'
|
||||||
|
|
||||||
|
createRoot(document.getElementById('root')).render(
|
||||||
|
<App />
|
||||||
|
)
|
||||||
|
|
||||||
1955
examples/openarms_web_interface/package-lock.json
generated
Normal file
1955
examples/openarms_web_interface/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
21
examples/openarms_web_interface/package.json
Normal file
21
examples/openarms_web_interface/package.json
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"name": "openarms-web-interface",
|
||||||
|
"private": true,
|
||||||
|
"version": "0.0.0",
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "vite build",
|
||||||
|
"preview": "vite preview"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"react": "^18.3.1",
|
||||||
|
"react-dom": "^18.3.1"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@types/react": "^18.3.12",
|
||||||
|
"@types/react-dom": "^18.3.1",
|
||||||
|
"@vitejs/plugin-react": "^4.3.4",
|
||||||
|
"vite": "^6.0.1"
|
||||||
|
}
|
||||||
|
}
|
||||||
17
examples/openarms_web_interface/vite.config.js
Normal file
17
examples/openarms_web_interface/vite.config.js
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import { defineConfig } from 'vite'
|
||||||
|
import react from '@vitejs/plugin-react'
|
||||||
|
|
||||||
|
// https://vite.dev/config/
|
||||||
|
export default defineConfig({
|
||||||
|
plugins: [react()],
|
||||||
|
server: {
|
||||||
|
port: 5173,
|
||||||
|
strictPort: false,
|
||||||
|
host: true,
|
||||||
|
open: false
|
||||||
|
},
|
||||||
|
build: {
|
||||||
|
outDir: 'dist',
|
||||||
|
sourcemap: true
|
||||||
|
}
|
||||||
|
})
|
||||||
1533
examples/openarms_web_interface/web_record_server.py
Normal file
1533
examples/openarms_web_interface/web_record_server.py
Normal file
File diff suppressed because it is too large
Load Diff
347
examples/unitree_g1/gr00t_locomotion.py
Normal file
347
examples/unitree_g1/gr00t_locomotion.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Example: GR00T Locomotion with Pre-loaded Policies
|
||||||
|
|
||||||
|
This example demonstrates the NEW pattern for loading GR00T policies externally
|
||||||
|
and passing them to the robot class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||||
|
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
GROOT_DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
|
||||||
|
GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # hip pitch
|
||||||
|
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # knee
|
||||||
|
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # ankle pitch
|
||||||
|
|
||||||
|
MISSING_JOINTS = []
|
||||||
|
G1_MODEL = "g1_23" # or "g1_29"
|
||||||
|
if G1_MODEL == "g1_23":
|
||||||
|
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # waist yaw/pitch, wrist pitch/yaw
|
||||||
|
|
||||||
|
LOCOMOTION_ACTION_SCALE = 0.25
|
||||||
|
|
||||||
|
LOCOMOTION_CONTROL_DT = 0.02
|
||||||
|
|
||||||
|
ANG_VEL_SCALE: float = 0.25
|
||||||
|
DOF_POS_SCALE: float = 1.0
|
||||||
|
DOF_VEL_SCALE: float = 0.05
|
||||||
|
CMD_SCALE: list = [2.0, 2.0, 0.25]
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
|
||||||
|
|
||||||
|
|
||||||
|
def load_groot_policies(
|
||||||
|
repo_id: str = DEFAULT_GROOT_REPO_ID,
|
||||||
|
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
|
||||||
|
"""Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id: Hugging Face Hub repository ID containing the ONNX policies.
|
||||||
|
"""
|
||||||
|
logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...")
|
||||||
|
|
||||||
|
# Download ONNX policies from Hugging Face Hub
|
||||||
|
balance_path = hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
filename="GR00T-WholeBodyControl-Balance.onnx",
|
||||||
|
)
|
||||||
|
walk_path = hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
filename="GR00T-WholeBodyControl-Walk.onnx",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load ONNX policies
|
||||||
|
policy_balance = ort.InferenceSession(balance_path)
|
||||||
|
policy_walk = ort.InferenceSession(walk_path)
|
||||||
|
|
||||||
|
logger.info("GR00T policies loaded successfully")
|
||||||
|
|
||||||
|
return policy_balance, policy_walk
|
||||||
|
|
||||||
|
|
||||||
|
class GrootLocomotionController:
|
||||||
|
"""
|
||||||
|
Handles GR00T-style locomotion control for the Unitree G1 robot.
|
||||||
|
|
||||||
|
This controller manages:
|
||||||
|
- Dual-policy system (Balance + Walk)
|
||||||
|
- 29-joint observation processing
|
||||||
|
- 15D action output (legs + waist)
|
||||||
|
- Policy inference and motor command generation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, policy_balance, policy_walk, robot, config):
|
||||||
|
self.policy_balance = policy_balance
|
||||||
|
self.policy_walk = policy_walk
|
||||||
|
self.robot = robot
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
|
||||||
|
|
||||||
|
# GR00T-specific state
|
||||||
|
self.groot_qj_all = np.zeros(29, dtype=np.float32)
|
||||||
|
self.groot_dqj_all = np.zeros(29, dtype=np.float32)
|
||||||
|
self.groot_action = np.zeros(15, dtype=np.float32)
|
||||||
|
self.groot_obs_single = np.zeros(86, dtype=np.float32)
|
||||||
|
self.groot_obs_history = deque(maxlen=6)
|
||||||
|
self.groot_obs_stacked = np.zeros(516, dtype=np.float32)
|
||||||
|
self.groot_height_cmd = 0.74 # Default base height
|
||||||
|
self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||||
|
|
||||||
|
# input to gr00t is 6 frames (6*86D=516)
|
||||||
|
for _ in range(6):
|
||||||
|
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
|
||||||
|
|
||||||
|
# Thread management
|
||||||
|
self.locomotion_running = False
|
||||||
|
self.locomotion_thread = None
|
||||||
|
|
||||||
|
logger.info("GrootLocomotionController initialized")
|
||||||
|
|
||||||
|
def groot_locomotion_run(self):
|
||||||
|
# get current observation
|
||||||
|
robot_state = self.robot.get_observation()
|
||||||
|
|
||||||
|
if robot_state is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# get command from remote controller
|
||||||
|
if robot_state.wireless_remote is not None:
|
||||||
|
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||||
|
if self.robot.remote_controller.button[0]: # R1 - raise waist
|
||||||
|
self.groot_height_cmd += 0.001
|
||||||
|
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||||
|
if self.robot.remote_controller.button[4]: # R2 - lower waist
|
||||||
|
self.groot_height_cmd -= 0.001
|
||||||
|
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||||
|
else:
|
||||||
|
self.robot.remote_controller.lx = 0.0
|
||||||
|
self.robot.remote_controller.ly = 0.0
|
||||||
|
self.robot.remote_controller.rx = 0.0
|
||||||
|
self.robot.remote_controller.ry = 0.0
|
||||||
|
|
||||||
|
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
|
||||||
|
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right
|
||||||
|
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate
|
||||||
|
|
||||||
|
for i in range(29):
|
||||||
|
self.groot_qj_all[i] = robot_state.motor_state[i].q
|
||||||
|
self.groot_dqj_all[i] = robot_state.motor_state[i].dq
|
||||||
|
|
||||||
|
# adapt observation for g1_23dof
|
||||||
|
for idx in MISSING_JOINTS:
|
||||||
|
self.groot_qj_all[idx] = 0.0
|
||||||
|
self.groot_dqj_all[idx] = 0.0
|
||||||
|
|
||||||
|
# Scale joint positions and velocities
|
||||||
|
qj_obs = self.groot_qj_all.copy()
|
||||||
|
dqj_obs = self.groot_dqj_all.copy()
|
||||||
|
|
||||||
|
# express imu data in gravity frame of reference
|
||||||
|
quat = robot_state.imu_state.quaternion
|
||||||
|
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||||
|
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||||
|
|
||||||
|
# scale joint positions and velocities before policy inference
|
||||||
|
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
|
||||||
|
dqj_obs = dqj_obs * DOF_VEL_SCALE
|
||||||
|
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||||
|
|
||||||
|
# build single frame observation
|
||||||
|
self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
|
||||||
|
self.groot_obs_single[3] = self.groot_height_cmd
|
||||||
|
self.groot_obs_single[4:7] = self.groot_orientation_cmd
|
||||||
|
self.groot_obs_single[7:10] = ang_vel_scaled
|
||||||
|
self.groot_obs_single[10:13] = gravity_orientation
|
||||||
|
self.groot_obs_single[13:42] = qj_obs
|
||||||
|
self.groot_obs_single[42:71] = dqj_obs
|
||||||
|
self.groot_obs_single[71:86] = self.groot_action # 15D previous actions
|
||||||
|
|
||||||
|
# Add to history and stack observations (6 frames × 86D = 516D)
|
||||||
|
self.groot_obs_history.append(self.groot_obs_single.copy())
|
||||||
|
|
||||||
|
# Stack all 6 frames into 516D vector
|
||||||
|
for i, obs_frame in enumerate(self.groot_obs_history):
|
||||||
|
start_idx = i * 86
|
||||||
|
end_idx = start_idx + 86
|
||||||
|
self.groot_obs_stacked[start_idx:end_idx] = obs_frame
|
||||||
|
|
||||||
|
# Run policy inference (ONNX) with 516D stacked observation
|
||||||
|
|
||||||
|
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
|
||||||
|
|
||||||
|
selected_policy = (
|
||||||
|
self.policy_balance if cmd_magnitude < 0.05 else self.policy_walk
|
||||||
|
) # balance/standing policy for small commands, walking policy for movement commands
|
||||||
|
|
||||||
|
# run policy inference
|
||||||
|
ort_inputs = {selected_policy.get_inputs()[0].name: np.expand_dims(self.groot_obs_stacked, axis=0)}
|
||||||
|
ort_outs = selected_policy.run(None, ort_inputs)
|
||||||
|
self.groot_action = ort_outs[0].squeeze()
|
||||||
|
|
||||||
|
# transform action back to target joint positions
|
||||||
|
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE
|
||||||
|
|
||||||
|
# command motors
|
||||||
|
for i in range(15):
|
||||||
|
motor_idx = i
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
# adapt action for g1_23dof
|
||||||
|
for joint_idx in MISSING_JOINTS:
|
||||||
|
self.robot.msg.motor_cmd[joint_idx].q = 0.0
|
||||||
|
self.robot.msg.motor_cmd[joint_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[joint_idx].kp = self.robot.kp[joint_idx]
|
||||||
|
self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd[joint_idx]
|
||||||
|
self.robot.msg.motor_cmd[joint_idx].tau = 0
|
||||||
|
|
||||||
|
# send action to robot
|
||||||
|
self.robot.send_action(self.robot.msg)
|
||||||
|
|
||||||
|
def _locomotion_thread_loop(self):
|
||||||
|
"""Background thread that runs the locomotion policy at specified rate."""
|
||||||
|
logger.info("Locomotion thread started")
|
||||||
|
while self.locomotion_running:
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
self.groot_locomotion_run()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in locomotion loop: {e}")
|
||||||
|
|
||||||
|
# Sleep to maintain control rate
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
logger.info("Locomotion thread stopped")
|
||||||
|
|
||||||
|
def start_locomotion_thread(self):
|
||||||
|
if self.locomotion_running:
|
||||||
|
logger.warning("Locomotion thread already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Starting locomotion control thread...")
|
||||||
|
self.locomotion_running = True
|
||||||
|
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
|
||||||
|
self.locomotion_thread.start()
|
||||||
|
|
||||||
|
logger.info("Locomotion control thread started!")
|
||||||
|
|
||||||
|
def stop_locomotion_thread(self):
|
||||||
|
if not self.locomotion_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Stopping locomotion control thread...")
|
||||||
|
self.locomotion_running = False
|
||||||
|
if self.locomotion_thread:
|
||||||
|
self.locomotion_thread.join(timeout=2.0)
|
||||||
|
logger.info("Locomotion control thread stopped")
|
||||||
|
|
||||||
|
def reset_robot(self):
|
||||||
|
"""Move robot legs to default standing position over 2 seconds (arms are not moved)."""
|
||||||
|
total_time = 3.0
|
||||||
|
num_step = int(total_time / self.robot.control_dt)
|
||||||
|
|
||||||
|
# Only control legs, not arms (first 12 joints)
|
||||||
|
default_pos = GROOT_DEFAULT_ANGLES # First 12 values are leg angles
|
||||||
|
dof_size = len(default_pos)
|
||||||
|
|
||||||
|
# Get current lowstate
|
||||||
|
robot_state = self.robot.get_observation()
|
||||||
|
|
||||||
|
# Record the current leg positions
|
||||||
|
init_dof_pos = np.zeros(dof_size, dtype=np.float32)
|
||||||
|
for i in range(dof_size):
|
||||||
|
init_dof_pos[i] = robot_state.motor_state[i].q
|
||||||
|
|
||||||
|
# Move legs to default pos
|
||||||
|
for i in range(num_step):
|
||||||
|
alpha = i / num_step
|
||||||
|
for motor_idx in range(dof_size):
|
||||||
|
target_pos = default_pos[motor_idx]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = (
|
||||||
|
init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
|
||||||
|
)
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||||
|
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||||
|
time.sleep(self.robot.control_dt)
|
||||||
|
logger.info("Reached default position (legs only)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
|
||||||
|
parser.add_argument(
|
||||||
|
"--repo-id",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_GROOT_REPO_ID,
|
||||||
|
help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# load policies
|
||||||
|
policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id)
|
||||||
|
|
||||||
|
# initialize robot
|
||||||
|
config = UnitreeG1Config()
|
||||||
|
robot = UnitreeG1(config)
|
||||||
|
|
||||||
|
# initialize gr00t locomotion controller
|
||||||
|
groot_controller = GrootLocomotionController(
|
||||||
|
policy_balance=policy_balance,
|
||||||
|
policy_walk=policy_walk,
|
||||||
|
robot=robot,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# reset legs and start locomotion thread
|
||||||
|
try:
|
||||||
|
groot_controller.reset_robot()
|
||||||
|
groot_controller.start_locomotion_thread()
|
||||||
|
|
||||||
|
# log status
|
||||||
|
logger.info("Robot initialized with GR00T locomotion policies")
|
||||||
|
logger.info("Locomotion controller running in background thread")
|
||||||
|
logger.info("Press Ctrl+C to stop")
|
||||||
|
|
||||||
|
# keep robot alive
|
||||||
|
while True:
|
||||||
|
time.sleep(1.0)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nStopping locomotion...")
|
||||||
|
groot_controller.stop_locomotion_thread()
|
||||||
|
print("Done!")
|
||||||
10
loop_datasets.py
Normal file
10
loop_datasets.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from huggingface_hub import HfApi, list_datasets
|
||||||
|
|
||||||
|
api = HfApi()
|
||||||
|
datasets = list_datasets(author="lerobot-data-collection")
|
||||||
|
print('"[', end="")
|
||||||
|
i=0
|
||||||
|
for dataset in datasets:
|
||||||
|
if "three-folds-dataset" in dataset.id:
|
||||||
|
print("'" + dataset.id + "',", end="")
|
||||||
|
print(']"',)
|
||||||
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "lerobot"
|
name = "lerobot"
|
||||||
version = "0.4.2"
|
version = "0.4.3"
|
||||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { text = "Apache-2.0" }
|
license = { text = "Apache-2.0" }
|
||||||
@@ -102,11 +102,17 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (com
|
|||||||
# Motors
|
# Motors
|
||||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||||
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||||
|
damiao = ["python-can>=4.2.0,<5.0.0"]
|
||||||
|
|
||||||
# Robots
|
# Robots
|
||||||
|
openarms = ["lerobot[damiao]"]
|
||||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||||
|
unitree_g1 = [
|
||||||
|
"pyzmq>=26.2.1,<28.0.0",
|
||||||
|
"onnxruntime>=1.16.0"
|
||||||
|
]
|
||||||
reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
|
reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
|
||||||
kinematics = ["lerobot[placo-dep]"]
|
kinematics = ["lerobot[placo-dep]"]
|
||||||
intelrealsense = [
|
intelrealsense = [
|
||||||
@@ -129,6 +135,7 @@ groot = [
|
|||||||
"ninja>=1.11.1,<2.0.0",
|
"ninja>=1.11.1,<2.0.0",
|
||||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||||
]
|
]
|
||||||
|
xvla = ["lerobot[transformers-dep]"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
@@ -148,6 +155,7 @@ metaworld = ["metaworld==3.0.0"]
|
|||||||
# All
|
# All
|
||||||
all = [
|
all = [
|
||||||
"lerobot[dynamixel]",
|
"lerobot[dynamixel]",
|
||||||
|
"lerobot[openarms]",
|
||||||
"lerobot[gamepad]",
|
"lerobot[gamepad]",
|
||||||
"lerobot[hopejr]",
|
"lerobot[hopejr]",
|
||||||
"lerobot[lekiwi]",
|
"lerobot[lekiwi]",
|
||||||
@@ -157,6 +165,7 @@ all = [
|
|||||||
"lerobot[pi]",
|
"lerobot[pi]",
|
||||||
"lerobot[smolvla]",
|
"lerobot[smolvla]",
|
||||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||||
|
"lerobot[xvla]",
|
||||||
"lerobot[hilserl]",
|
"lerobot[hilserl]",
|
||||||
"lerobot[async]",
|
"lerobot[async]",
|
||||||
"lerobot[dev]",
|
"lerobot[dev]",
|
||||||
@@ -257,6 +266,7 @@ default.extend-ignore-identifiers-re = [
|
|||||||
"ein",
|
"ein",
|
||||||
"thw",
|
"thw",
|
||||||
"inpt",
|
"inpt",
|
||||||
|
"ROBOTIS",
|
||||||
]
|
]
|
||||||
|
|
||||||
# TODO: Uncomment when ready to use
|
# TODO: Uncomment when ready to use
|
||||||
@@ -356,9 +366,9 @@ ignore_errors = false
|
|||||||
# module = "lerobot.async_inference.*"
|
# module = "lerobot.async_inference.*"
|
||||||
# ignore_errors = false
|
# ignore_errors = false
|
||||||
|
|
||||||
# [[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
# module = "lerobot.transport.*"
|
module = "lerobot.transport.*"
|
||||||
# ignore_errors = false
|
ignore_errors = false
|
||||||
|
|
||||||
# [[tool.mypy.overrides]]
|
# [[tool.mypy.overrides]]
|
||||||
# module = "lerobot.scripts.*"
|
# module = "lerobot.scripts.*"
|
||||||
|
|||||||
@@ -26,4 +26,4 @@ DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
|||||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
|
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
|
||||||
|
|
||||||
# TODO: Add all other robots
|
# TODO: Add all other robots
|
||||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]
|
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower", "omx_follower"]
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ from lerobot.robots import ( # noqa: F401
|
|||||||
bi_so100_follower,
|
bi_so100_follower,
|
||||||
koch_follower,
|
koch_follower,
|
||||||
make_robot_from_config,
|
make_robot_from_config,
|
||||||
|
omx_follower,
|
||||||
so100_follower,
|
so100_follower,
|
||||||
so101_follower,
|
so101_follower,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
steps: int = 100_000
|
steps: int = 100_000
|
||||||
eval_freq: int = 20_000
|
eval_freq: int = 20_000
|
||||||
log_freq: int = 200
|
log_freq: int = 200
|
||||||
|
tolerance_s: float = 1e-4
|
||||||
save_checkpoint: bool = True
|
save_checkpoint: bool = True
|
||||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||||
save_freq: int = 20_000
|
save_freq: int = 20_000
|
||||||
|
|||||||
@@ -136,21 +136,40 @@ def update_meta_data(
|
|||||||
df["_orig_chunk"] = df[orig_chunk_col].copy()
|
df["_orig_chunk"] = df[orig_chunk_col].copy()
|
||||||
df["_orig_file"] = df[orig_file_col].copy()
|
df["_orig_file"] = df[orig_file_col].copy()
|
||||||
|
|
||||||
# Update chunk and file indices to point to destination
|
# Get mappings for this video key
|
||||||
df[orig_chunk_col] = video_idx["chunk"]
|
|
||||||
df[orig_file_col] = video_idx["file"]
|
|
||||||
|
|
||||||
# Apply per-source-file timestamp offsets
|
|
||||||
src_to_offset = video_idx.get("src_to_offset", {})
|
src_to_offset = video_idx.get("src_to_offset", {})
|
||||||
if src_to_offset:
|
src_to_dst = video_idx.get("src_to_dst", {})
|
||||||
# Apply offset based on original source file
|
|
||||||
|
# Apply per-source-file mappings
|
||||||
|
if src_to_dst:
|
||||||
|
# Map each episode to its correct destination file and apply offset
|
||||||
for idx in df.index:
|
for idx in df.index:
|
||||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||||
|
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||||
|
|
||||||
|
# Get destination chunk/file for this source file
|
||||||
|
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
|
||||||
|
df.at[idx, orig_chunk_col] = dst_chunk
|
||||||
|
df.at[idx, orig_file_col] = dst_file
|
||||||
|
|
||||||
|
# Apply timestamp offset
|
||||||
|
offset = src_to_offset.get(src_key, 0)
|
||||||
|
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||||
|
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||||
|
elif src_to_offset:
|
||||||
|
# Fallback: use same destination for all, but apply per-file offsets
|
||||||
|
df[orig_chunk_col] = video_idx["chunk"]
|
||||||
|
df[orig_file_col] = video_idx["file"]
|
||||||
|
for idx in df.index:
|
||||||
|
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||||
|
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||||
offset = src_to_offset.get(src_key, 0)
|
offset = src_to_offset.get(src_key, 0)
|
||||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||||
else:
|
else:
|
||||||
# Fallback to simple offset (for backward compatibility)
|
# Fallback to simple offset (for backward compatibility)
|
||||||
|
df[orig_chunk_col] = video_idx["chunk"]
|
||||||
|
df[orig_file_col] = video_idx["file"]
|
||||||
df[f"videos/{key}/from_timestamp"] = (
|
df[f"videos/{key}/from_timestamp"] = (
|
||||||
df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||||
)
|
)
|
||||||
@@ -268,6 +287,12 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
|||||||
videos_idx[key]["episode_duration"] = 0
|
videos_idx[key]["episode_duration"] = 0
|
||||||
# Track offset for each source (chunk, file) pair
|
# Track offset for each source (chunk, file) pair
|
||||||
videos_idx[key]["src_to_offset"] = {}
|
videos_idx[key]["src_to_offset"] = {}
|
||||||
|
# Track destination (chunk, file) for each source (chunk, file) pair
|
||||||
|
videos_idx[key]["src_to_dst"] = {}
|
||||||
|
# Initialize dst_file_durations if not present
|
||||||
|
# dst_file_durations tracks duration of each destination file
|
||||||
|
if "dst_file_durations" not in videos_idx[key]:
|
||||||
|
videos_idx[key]["dst_file_durations"] = {}
|
||||||
|
|
||||||
for key, video_idx in videos_idx.items():
|
for key, video_idx in videos_idx.items():
|
||||||
unique_chunk_file_pairs = {
|
unique_chunk_file_pairs = {
|
||||||
@@ -282,9 +307,13 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
|||||||
|
|
||||||
chunk_idx = video_idx["chunk"]
|
chunk_idx = video_idx["chunk"]
|
||||||
file_idx = video_idx["file"]
|
file_idx = video_idx["file"]
|
||||||
current_offset = video_idx["latest_duration"]
|
dst_file_durations = video_idx["dst_file_durations"]
|
||||||
|
|
||||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||||
|
# Convert to Python int to ensure consistent dict keys
|
||||||
|
src_chunk_idx = int(src_chunk_idx)
|
||||||
|
src_file_idx = int(src_file_idx)
|
||||||
|
|
||||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||||
video_key=key,
|
video_key=key,
|
||||||
chunk_index=src_chunk_idx,
|
chunk_index=src_chunk_idx,
|
||||||
@@ -298,14 +327,17 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
|||||||
)
|
)
|
||||||
|
|
||||||
src_duration = get_video_duration_in_s(src_path)
|
src_duration = get_video_duration_in_s(src_path)
|
||||||
|
dst_key = (chunk_idx, file_idx)
|
||||||
|
|
||||||
if not dst_path.exists():
|
if not dst_path.exists():
|
||||||
# Store offset before incrementing
|
# New destination file: offset is 0
|
||||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||||
|
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
shutil.copy(str(src_path), str(dst_path))
|
shutil.copy(str(src_path), str(dst_path))
|
||||||
|
# Track duration of this destination file
|
||||||
|
dst_file_durations[dst_key] = src_duration
|
||||||
videos_idx[key]["episode_duration"] += src_duration
|
videos_idx[key]["episode_duration"] += src_duration
|
||||||
current_offset += src_duration
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check file sizes before appending
|
# Check file sizes before appending
|
||||||
@@ -313,10 +345,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
|||||||
dst_size = get_file_size_in_mb(dst_path)
|
dst_size = get_file_size_in_mb(dst_path)
|
||||||
|
|
||||||
if dst_size + src_size >= video_files_size_in_mb:
|
if dst_size + src_size >= video_files_size_in_mb:
|
||||||
# Rotate to a new file, this source becomes start of new destination
|
# Rotate to a new file - offset is 0
|
||||||
# So its offset should be 0
|
|
||||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
|
||||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||||
|
dst_key = (chunk_idx, file_idx)
|
||||||
|
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||||
|
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||||
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
|
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||||
video_key=key,
|
video_key=key,
|
||||||
chunk_index=chunk_idx,
|
chunk_index=chunk_idx,
|
||||||
@@ -324,16 +357,20 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
|||||||
)
|
)
|
||||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
shutil.copy(str(src_path), str(dst_path))
|
shutil.copy(str(src_path), str(dst_path))
|
||||||
# Reset offset for next file
|
# Track duration of this new destination file
|
||||||
current_offset = src_duration
|
dst_file_durations[dst_key] = src_duration
|
||||||
else:
|
else:
|
||||||
# Append to existing video file - use current accumulated offset
|
# Append to existing destination file
|
||||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
# Offset is the current duration of this destination file
|
||||||
|
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
||||||
|
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
||||||
|
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||||
concatenate_video_files(
|
concatenate_video_files(
|
||||||
[dst_path, src_path],
|
[dst_path, src_path],
|
||||||
dst_path,
|
dst_path,
|
||||||
)
|
)
|
||||||
current_offset += src_duration
|
# Update duration of this destination file
|
||||||
|
dst_file_durations[dst_key] = current_dst_duration + src_duration
|
||||||
|
|
||||||
videos_idx[key]["episode_duration"] += src_duration
|
videos_idx[key]["episode_duration"] += src_duration
|
||||||
|
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
revision=cfg.dataset.revision,
|
revision=cfg.dataset.revision,
|
||||||
video_backend=cfg.dataset.video_backend,
|
video_backend=cfg.dataset.video_backend,
|
||||||
|
tolerance_s=cfg.tolerance_s,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dataset = StreamingLeRobotDataset(
|
dataset = StreamingLeRobotDataset(
|
||||||
@@ -108,6 +109,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
revision=cfg.dataset.revision,
|
revision=cfg.dataset.revision,
|
||||||
max_num_shards=cfg.num_workers,
|
max_num_shards=cfg.num_workers,
|
||||||
|
tolerance_s=cfg.tolerance_s,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||||
|
|||||||
@@ -23,11 +23,13 @@ from pathlib import Path
|
|||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
import packaging.version
|
import packaging.version
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
import torch
|
import torch
|
||||||
import torch.utils
|
import torch.utils
|
||||||
from huggingface_hub import HfApi, snapshot_download
|
from huggingface_hub import HfApi, snapshot_download
|
||||||
@@ -1199,6 +1201,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
use_batched_encoding = self.batch_encoding_size > 1
|
use_batched_encoding = self.batch_encoding_size > 1
|
||||||
|
|
||||||
if has_video_keys and not use_batched_encoding:
|
if has_video_keys and not use_batched_encoding:
|
||||||
|
video_paths = self._encode_multiple_temporary_episode_videos(self.meta.video_keys, episode_index)
|
||||||
|
for (video_key, video_path) in zip(self.meta.video_keys, video_paths):
|
||||||
|
ep_metadata.update(self._save_episode_video(video_key, episode_index, video_path))
|
||||||
num_cameras = len(self.meta.video_keys)
|
num_cameras = len(self.meta.video_keys)
|
||||||
if parallel_encoding and num_cameras > 1:
|
if parallel_encoding and num_cameras > 1:
|
||||||
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
|
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
|
||||||
@@ -1528,6 +1533,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
"""
|
"""
|
||||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps)
|
return _encode_video_worker(video_key, episode_index, self.root, self.fps)
|
||||||
|
|
||||||
|
def _encode_multiple_temporary_episode_videos(self, video_keys, episode_index):
|
||||||
|
temp_paths = []
|
||||||
|
img_dirs = []
|
||||||
|
for video_key in video_keys:
|
||||||
|
temp_paths.append(Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4")
|
||||||
|
img_dirs.append(self._get_image_file_dir(episode_index, video_key))
|
||||||
|
fps = [self.fps]*len(video_keys)
|
||||||
|
|
||||||
|
with ProcessPoolExecutor(max_workers=len(video_keys)) as executor:
|
||||||
|
executor.map(encode_video_frames,img_dirs,temp_paths,fps)
|
||||||
|
|
||||||
|
for img_dir in img_dirs:
|
||||||
|
shutil.rmtree(img_dir)
|
||||||
|
|
||||||
|
return temp_paths
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@@ -310,7 +310,7 @@ def encode_video_frames(
|
|||||||
crf: int | None = 30,
|
crf: int | None = 30,
|
||||||
fast_decode: int = 0,
|
fast_decode: int = 0,
|
||||||
log_level: int | None = av.logging.ERROR,
|
log_level: int | None = av.logging.ERROR,
|
||||||
overwrite: bool = False,
|
overwrite: bool = True,
|
||||||
preset: int | None = None,
|
preset: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||||
@@ -355,6 +355,9 @@ def encode_video_frames(
|
|||||||
if crf is not None:
|
if crf is not None:
|
||||||
video_options["crf"] = str(crf)
|
video_options["crf"] = str(crf)
|
||||||
|
|
||||||
|
#TEMPORARY FIX
|
||||||
|
video_options["preset"] = "12"
|
||||||
|
|
||||||
if fast_decode:
|
if fast_decode:
|
||||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
|||||||
class LiberoEnv(EnvConfig):
|
class LiberoEnv(EnvConfig):
|
||||||
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
|
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
|
||||||
fps: int = 30
|
fps: int = 30
|
||||||
episode_length: int = 520
|
episode_length: int | None = None
|
||||||
obs_type: str = "pixels_agent_pos"
|
obs_type: str = "pixels_agent_pos"
|
||||||
render_mode: str = "rgb_array"
|
render_mode: str = "rgb_array"
|
||||||
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
||||||
@@ -272,6 +272,7 @@ class LiberoEnv(EnvConfig):
|
|||||||
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
|
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
control_mode: str = "relative" # or "absolute"
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.obs_type == "pixels":
|
if self.obs_type == "pixels":
|
||||||
|
|||||||
@@ -19,8 +19,10 @@ from typing import Any
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.envs.registration import registry as gym_registry
|
from gymnasium.envs.registration import registry as gym_registry
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
|
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
|
||||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||||
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
from lerobot.processor import ProcessorStep
|
from lerobot.processor import ProcessorStep
|
||||||
from lerobot.processor.env_processor import LiberoProcessorStep
|
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
@@ -39,6 +41,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
|||||||
|
|
||||||
def make_env_pre_post_processors(
|
def make_env_pre_post_processors(
|
||||||
env_cfg: EnvConfig,
|
env_cfg: EnvConfig,
|
||||||
|
policy_cfg: PreTrainedConfig,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
@@ -61,6 +64,10 @@ def make_env_pre_post_processors(
|
|||||||
# Preprocessor and Postprocessor steps are Identity for most environments
|
# Preprocessor and Postprocessor steps are Identity for most environments
|
||||||
preprocessor_steps: list[ProcessorStep] = []
|
preprocessor_steps: list[ProcessorStep] = []
|
||||||
postprocessor_steps: list[ProcessorStep] = []
|
postprocessor_steps: list[ProcessorStep] = []
|
||||||
|
if isinstance(policy_cfg, XVLAConfig):
|
||||||
|
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
|
||||||
|
|
||||||
|
return make_xvla_libero_pre_post_processors()
|
||||||
|
|
||||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||||
@@ -136,6 +143,8 @@ def make_env(
|
|||||||
init_states=cfg.init_states,
|
init_states=cfg.init_states,
|
||||||
gym_kwargs=cfg.gym_kwargs,
|
gym_kwargs=cfg.gym_kwargs,
|
||||||
env_cls=env_cls,
|
env_cls=env_cls,
|
||||||
|
control_mode=cfg.control_mode,
|
||||||
|
episode_length=cfg.episode_length,
|
||||||
)
|
)
|
||||||
elif "metaworld" in cfg.type:
|
elif "metaworld" in cfg.type:
|
||||||
from lerobot.envs.metaworld import create_metaworld_envs
|
from lerobot.envs.metaworld import create_metaworld_envs
|
||||||
|
|||||||
@@ -80,10 +80,7 @@ def get_libero_dummy_action():
|
|||||||
return [0, 0, 0, 0, 0, 0, -1]
|
return [0, 0, 0, 0, 0, 0, -1]
|
||||||
|
|
||||||
|
|
||||||
OBS_STATE_DIM = 8
|
|
||||||
ACTION_DIM = 7
|
ACTION_DIM = 7
|
||||||
AGENT_POS_LOW = -1000.0
|
|
||||||
AGENT_POS_HIGH = 1000.0
|
|
||||||
ACTION_LOW = -1.0
|
ACTION_LOW = -1.0
|
||||||
ACTION_HIGH = 1.0
|
ACTION_HIGH = 1.0
|
||||||
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
||||||
@@ -103,6 +100,7 @@ class LiberoEnv(gym.Env):
|
|||||||
task_suite: Any,
|
task_suite: Any,
|
||||||
task_id: int,
|
task_id: int,
|
||||||
task_suite_name: str,
|
task_suite_name: str,
|
||||||
|
episode_length: int | None = None,
|
||||||
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||||
obs_type: str = "pixels",
|
obs_type: str = "pixels",
|
||||||
render_mode: str = "rgb_array",
|
render_mode: str = "rgb_array",
|
||||||
@@ -114,6 +112,7 @@ class LiberoEnv(gym.Env):
|
|||||||
episode_index: int = 0,
|
episode_index: int = 0,
|
||||||
camera_name_mapping: dict[str, str] | None = None,
|
camera_name_mapping: dict[str, str] | None = None,
|
||||||
num_steps_wait: int = 10,
|
num_steps_wait: int = 10,
|
||||||
|
control_mode: str = "relative",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.task_id = task_id
|
self.task_id = task_id
|
||||||
@@ -141,14 +140,19 @@ class LiberoEnv(gym.Env):
|
|||||||
self.camera_name_mapping = camera_name_mapping
|
self.camera_name_mapping = camera_name_mapping
|
||||||
self.num_steps_wait = num_steps_wait
|
self.num_steps_wait = num_steps_wait
|
||||||
self.episode_index = episode_index
|
self.episode_index = episode_index
|
||||||
|
self.episode_length = episode_length
|
||||||
# Load once and keep
|
# Load once and keep
|
||||||
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
||||||
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||||
|
|
||||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||||
default_steps = 500
|
default_steps = 500
|
||||||
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
self._max_episode_steps = (
|
||||||
|
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||||
|
if self.episode_length is None
|
||||||
|
else self.episode_length
|
||||||
|
)
|
||||||
|
self.control_mode = control_mode
|
||||||
images = {}
|
images = {}
|
||||||
for cam in self.camera_name:
|
for cam in self.camera_name:
|
||||||
images[self.camera_name_mapping[cam]] = spaces.Box(
|
images[self.camera_name_mapping[cam]] = spaces.Box(
|
||||||
@@ -296,6 +300,15 @@ class LiberoEnv(gym.Env):
|
|||||||
# Increasing this value can improve determinism and reproducibility across resets.
|
# Increasing this value can improve determinism and reproducibility across resets.
|
||||||
for _ in range(self.num_steps_wait):
|
for _ in range(self.num_steps_wait):
|
||||||
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
|
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
|
||||||
|
|
||||||
|
if self.control_mode == "absolute":
|
||||||
|
for robot in self._env.robots:
|
||||||
|
robot.controller.use_delta = False
|
||||||
|
elif self.control_mode == "relative":
|
||||||
|
for robot in self._env.robots:
|
||||||
|
robot.controller.use_delta = True
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid control mode: {self.control_mode}")
|
||||||
observation = self._format_raw_obs(raw_obs)
|
observation = self._format_raw_obs(raw_obs)
|
||||||
info = {"is_success": False}
|
info = {"is_success": False}
|
||||||
return observation, info
|
return observation, info
|
||||||
@@ -341,8 +354,10 @@ def _make_env_fns(
|
|||||||
task_id: int,
|
task_id: int,
|
||||||
n_envs: int,
|
n_envs: int,
|
||||||
camera_names: list[str],
|
camera_names: list[str],
|
||||||
|
episode_length: int | None,
|
||||||
init_states: bool,
|
init_states: bool,
|
||||||
gym_kwargs: Mapping[str, Any],
|
gym_kwargs: Mapping[str, Any],
|
||||||
|
control_mode: str,
|
||||||
) -> list[Callable[[], LiberoEnv]]:
|
) -> list[Callable[[], LiberoEnv]]:
|
||||||
"""Build n_envs factory callables for a single (suite, task_id)."""
|
"""Build n_envs factory callables for a single (suite, task_id)."""
|
||||||
|
|
||||||
@@ -354,7 +369,9 @@ def _make_env_fns(
|
|||||||
task_suite_name=suite_name,
|
task_suite_name=suite_name,
|
||||||
camera_name=camera_names,
|
camera_name=camera_names,
|
||||||
init_states=init_states,
|
init_states=init_states,
|
||||||
|
episode_length=episode_length,
|
||||||
episode_index=episode_index,
|
episode_index=episode_index,
|
||||||
|
control_mode=control_mode,
|
||||||
**local_kwargs,
|
**local_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -374,6 +391,8 @@ def create_libero_envs(
|
|||||||
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||||
init_states: bool = True,
|
init_states: bool = True,
|
||||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||||
|
control_mode: str = "relative",
|
||||||
|
episode_length: int | None = None,
|
||||||
) -> dict[str, dict[int, Any]]:
|
) -> dict[str, dict[int, Any]]:
|
||||||
"""
|
"""
|
||||||
Create vectorized LIBERO environments with a consistent return shape.
|
Create vectorized LIBERO environments with a consistent return shape.
|
||||||
@@ -415,12 +434,14 @@ def create_libero_envs(
|
|||||||
for tid in selected:
|
for tid in selected:
|
||||||
fns = _make_env_fns(
|
fns = _make_env_fns(
|
||||||
suite=suite,
|
suite=suite,
|
||||||
|
episode_length=episode_length,
|
||||||
suite_name=suite_name,
|
suite_name=suite_name,
|
||||||
task_id=tid,
|
task_id=tid,
|
||||||
n_envs=n_envs,
|
n_envs=n_envs,
|
||||||
camera_names=camera_names,
|
camera_names=camera_names,
|
||||||
init_states=init_states,
|
init_states=init_states,
|
||||||
gym_kwargs=gym_kwargs,
|
gym_kwargs=gym_kwargs,
|
||||||
|
control_mode=control_mode,
|
||||||
)
|
)
|
||||||
out[suite_name][tid] = env_cls(fns)
|
out[suite_name][tid] = env_cls(fns)
|
||||||
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
||||||
|
|||||||
@@ -14,4 +14,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
|
from .motors_bus import (
|
||||||
|
Motor,
|
||||||
|
MotorCalibration,
|
||||||
|
MotorNormMode,
|
||||||
|
MotorsBus, # Backward compatibility (alias for SerialMotorsBus)
|
||||||
|
MotorsBusBase,
|
||||||
|
SerialMotorsBus,
|
||||||
|
)
|
||||||
|
|||||||
18
src/lerobot/motors/damiao/__init__.py
Normal file
18
src/lerobot/motors/damiao/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .damiao import DamiaoMotorsBus
|
||||||
|
from .tables import *
|
||||||
905
src/lerobot/motors/damiao/damiao.py
Normal file
905
src/lerobot/motors/damiao/damiao.py
Normal file
@@ -0,0 +1,905 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# TODO(pepijn): add license of: https://github.com/cmjang/DM_Control_Python?tab=MIT-1-ov-file#readme
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from copy import deepcopy
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import can
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.motors import Motor, MotorCalibration, MotorNormMode, MotorsBusBase
|
||||||
|
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||||
|
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||||
|
|
||||||
|
from .tables import (
|
||||||
|
AVAILABLE_BAUDRATES,
|
||||||
|
CAN_CMD_DISABLE,
|
||||||
|
CAN_CMD_ENABLE,
|
||||||
|
CAN_CMD_REFRESH,
|
||||||
|
CAN_CMD_SET_ZERO,
|
||||||
|
CAN_PARAM_ID,
|
||||||
|
DEFAULT_BAUDRATE,
|
||||||
|
DEFAULT_TIMEOUT_MS,
|
||||||
|
MODEL_RESOLUTION,
|
||||||
|
MOTOR_LIMIT_PARAMS,
|
||||||
|
NORMALIZED_DATA,
|
||||||
|
MotorType,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
NameOrID = Union[str, int]
|
||||||
|
Value = Union[int, float]
|
||||||
|
|
||||||
|
|
||||||
|
class DamiaoMotorsBus(MotorsBusBase):
|
||||||
|
"""
|
||||||
|
The Damiao implementation for a MotorsBus using CAN bus communication.
|
||||||
|
|
||||||
|
This class uses python-can for CAN bus communication with Damiao motors.
|
||||||
|
For more info, see:
|
||||||
|
- python-can documentation: https://python-can.readthedocs.io/en/stable/
|
||||||
|
- Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/
|
||||||
|
- DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python
|
||||||
|
"""
|
||||||
|
|
||||||
|
# CAN-specific settings
|
||||||
|
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
|
||||||
|
default_baudrate = DEFAULT_BAUDRATE
|
||||||
|
default_timeout = DEFAULT_TIMEOUT_MS
|
||||||
|
|
||||||
|
# Motor configuration
|
||||||
|
model_resolution_table = deepcopy(MODEL_RESOLUTION)
|
||||||
|
normalized_data = deepcopy(NORMALIZED_DATA)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
port: str,
|
||||||
|
motors: dict[str, Motor],
|
||||||
|
calibration: dict[str, MotorCalibration] | None = None,
|
||||||
|
can_interface: str = "auto",
|
||||||
|
use_can_fd: bool = True,
|
||||||
|
bitrate: int = 1000000,
|
||||||
|
data_bitrate: int | None = 5000000,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the Damiao motors bus.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS)
|
||||||
|
motors: Dictionary mapping motor names to Motor objects
|
||||||
|
calibration: Optional calibration data
|
||||||
|
can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial)
|
||||||
|
use_can_fd: Whether to use CAN FD mode (default: True for OpenArms)
|
||||||
|
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
|
||||||
|
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
|
||||||
|
"""
|
||||||
|
super().__init__(port, motors, calibration)
|
||||||
|
self.port = port
|
||||||
|
self.can_interface = can_interface
|
||||||
|
self.use_can_fd = use_can_fd
|
||||||
|
self.bitrate = bitrate
|
||||||
|
self.data_bitrate = data_bitrate
|
||||||
|
self.canbus = None
|
||||||
|
self._is_connected = False
|
||||||
|
|
||||||
|
# Map motor names to CAN IDs
|
||||||
|
self._motor_can_ids = {}
|
||||||
|
self._recv_id_to_motor = {}
|
||||||
|
|
||||||
|
# Store motor types and recv IDs
|
||||||
|
self._motor_types = {}
|
||||||
|
for name, motor in self.motors.items():
|
||||||
|
if hasattr(motor, "motor_type"):
|
||||||
|
self._motor_types[name] = motor.motor_type
|
||||||
|
else:
|
||||||
|
# Default to DM4310 if not specified
|
||||||
|
self._motor_types[name] = MotorType.DM4310
|
||||||
|
|
||||||
|
# Map recv_id to motor name for filtering responses
|
||||||
|
if hasattr(motor, "recv_id"):
|
||||||
|
self._recv_id_to_motor[motor.recv_id] = name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""Check if the CAN bus is connected."""
|
||||||
|
return self._is_connected and self.canbus is not None
|
||||||
|
|
||||||
|
def connect(self, handshake: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Open the CAN bus and initialize communication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
handshake: If True, ping all motors to verify they're present
|
||||||
|
"""
|
||||||
|
if self.is_connected:
|
||||||
|
raise DeviceAlreadyConnectedError(
|
||||||
|
f"{self.__class__.__name__}('{self.port}') is already connected."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Auto-detect interface type based on port name
|
||||||
|
if self.can_interface == "auto":
|
||||||
|
if self.port.startswith("/dev/"):
|
||||||
|
# Serial device (macOS/Windows)
|
||||||
|
self.can_interface = "slcan"
|
||||||
|
logger.info(f"Auto-detected slcan interface for port {self.port}")
|
||||||
|
else:
|
||||||
|
# Network interface (Linux)
|
||||||
|
self.can_interface = "socketcan"
|
||||||
|
logger.info(f"Auto-detected socketcan interface for port {self.port}")
|
||||||
|
|
||||||
|
# Connect to CAN bus
|
||||||
|
if self.can_interface == "socketcan":
|
||||||
|
# Linux SocketCAN with CAN FD support
|
||||||
|
if self.use_can_fd and self.data_bitrate is not None:
|
||||||
|
self.canbus = can.interface.Bus(
|
||||||
|
channel=self.port,
|
||||||
|
interface="socketcan",
|
||||||
|
bitrate=self.bitrate,
|
||||||
|
data_bitrate=self.data_bitrate,
|
||||||
|
fd=True
|
||||||
|
)
|
||||||
|
logger.info(f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})")
|
||||||
|
else:
|
||||||
|
self.canbus = can.interface.Bus(
|
||||||
|
channel=self.port,
|
||||||
|
interface="socketcan",
|
||||||
|
bitrate=self.bitrate
|
||||||
|
)
|
||||||
|
logger.info(f"Connected to {self.port} with CAN 2.0 (bitrate={self.bitrate})")
|
||||||
|
elif self.can_interface == "slcan":
|
||||||
|
# Serial Line CAN (macOS, Windows, or USB adapters)
|
||||||
|
# Note: SLCAN typically doesn't support CAN FD
|
||||||
|
self.canbus = can.interface.Bus(
|
||||||
|
channel=self.port,
|
||||||
|
interface="slcan",
|
||||||
|
bitrate=self.bitrate
|
||||||
|
)
|
||||||
|
logger.info(f"Connected to {self.port} with SLCAN (bitrate={self.bitrate})")
|
||||||
|
else:
|
||||||
|
# Generic interface (vector, pcan, etc.)
|
||||||
|
if self.use_can_fd and self.data_bitrate is not None:
|
||||||
|
self.canbus = can.interface.Bus(
|
||||||
|
channel=self.port,
|
||||||
|
interface=self.can_interface,
|
||||||
|
bitrate=self.bitrate,
|
||||||
|
data_bitrate=self.data_bitrate,
|
||||||
|
fd=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.canbus = can.interface.Bus(
|
||||||
|
channel=self.port,
|
||||||
|
interface=self.can_interface,
|
||||||
|
bitrate=self.bitrate
|
||||||
|
)
|
||||||
|
|
||||||
|
self._is_connected = True
|
||||||
|
|
||||||
|
if handshake:
|
||||||
|
self._handshake()
|
||||||
|
|
||||||
|
logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.")
|
||||||
|
except Exception as e:
|
||||||
|
self._is_connected = False
|
||||||
|
raise ConnectionError(f"Failed to connect to CAN bus: {e}")
|
||||||
|
|
||||||
|
def _handshake(self) -> None:
|
||||||
|
"""Verify all motors are present by refreshing their status."""
|
||||||
|
for motor_name in self.motors:
|
||||||
|
self._refresh_motor(motor_name)
|
||||||
|
time.sleep(0.01) # Small delay between motors
|
||||||
|
|
||||||
|
def disconnect(self, disable_torque: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Close the CAN bus connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
disable_torque: If True, disable torque on all motors before disconnecting
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(
|
||||||
|
f"{self.__class__.__name__}('{self.port}') is not connected."
|
||||||
|
)
|
||||||
|
|
||||||
|
if disable_torque:
|
||||||
|
try:
|
||||||
|
self.disable_torque()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to disable torque during disconnect: {e}")
|
||||||
|
|
||||||
|
if self.canbus:
|
||||||
|
self.canbus.shutdown()
|
||||||
|
self.canbus = None
|
||||||
|
self._is_connected = False
|
||||||
|
logger.debug(f"{self.__class__.__name__} disconnected.")
|
||||||
|
|
||||||
|
def configure_motors(self) -> None:
|
||||||
|
"""Configure all motors with default settings."""
|
||||||
|
# Damiao motors don't require much configuration in MIT mode
|
||||||
|
# Just ensure they're enabled
|
||||||
|
for motor in self.motors:
|
||||||
|
self._enable_motor(motor)
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
def _enable_motor(self, motor: NameOrID) -> None:
|
||||||
|
"""Enable a single motor."""
|
||||||
|
motor_id = self._get_motor_id(motor)
|
||||||
|
recv_id = self._get_motor_recv_id(motor)
|
||||||
|
data = [0xFF] * 7 + [CAN_CMD_ENABLE]
|
||||||
|
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||||
|
self.canbus.send(msg)
|
||||||
|
self._recv_motor_response(expected_recv_id=recv_id)
|
||||||
|
|
||||||
|
def _disable_motor(self, motor: NameOrID) -> None:
|
||||||
|
"""Disable a single motor."""
|
||||||
|
motor_id = self._get_motor_id(motor)
|
||||||
|
recv_id = self._get_motor_recv_id(motor)
|
||||||
|
data = [0xFF] * 7 + [CAN_CMD_DISABLE]
|
||||||
|
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||||
|
self.canbus.send(msg)
|
||||||
|
self._recv_motor_response(expected_recv_id=recv_id)
|
||||||
|
|
||||||
|
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||||
|
"""Enable torque on selected motors."""
|
||||||
|
motors = self._get_motors_list(motors)
|
||||||
|
for motor in motors:
|
||||||
|
for _ in range(num_retry + 1):
|
||||||
|
try:
|
||||||
|
self._enable_motor(motor)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
if _ == num_retry:
|
||||||
|
raise e
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||||
|
"""Disable torque on selected motors."""
|
||||||
|
motors = self._get_motors_list(motors)
|
||||||
|
for motor in motors:
|
||||||
|
for _ in range(num_retry + 1):
|
||||||
|
try:
|
||||||
|
self._disable_motor(motor)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
if _ == num_retry:
|
||||||
|
raise e
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def torque_disabled(self, motors: str | list[str] | None = None):
|
||||||
|
"""
|
||||||
|
Context manager that guarantees torque is re-enabled.
|
||||||
|
|
||||||
|
This helper is useful to temporarily disable torque when configuring motors.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> with bus.torque_disabled():
|
||||||
|
... # Safe operations here with torque disabled
|
||||||
|
... pass
|
||||||
|
"""
|
||||||
|
self.disable_torque(motors)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
self.enable_torque(motors)
|
||||||
|
|
||||||
|
def set_zero_position(self, motors: str | list[str] | None = None) -> None:
|
||||||
|
"""Set current position as zero for selected motors."""
|
||||||
|
motors = self._get_motors_list(motors)
|
||||||
|
for motor in motors:
|
||||||
|
motor_id = self._get_motor_id(motor)
|
||||||
|
recv_id = self._get_motor_recv_id(motor)
|
||||||
|
data = [0xFF] * 7 + [CAN_CMD_SET_ZERO]
|
||||||
|
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||||
|
self.canbus.send(msg)
|
||||||
|
self._recv_motor_response(expected_recv_id=recv_id)
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
def _refresh_motor(self, motor: NameOrID) -> Optional[can.Message]:
|
||||||
|
"""Refresh motor status and return the response."""
|
||||||
|
motor_id = self._get_motor_id(motor)
|
||||||
|
recv_id = self._get_motor_recv_id(motor)
|
||||||
|
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||||
|
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||||
|
self.canbus.send(msg)
|
||||||
|
return self._recv_motor_response(expected_recv_id=recv_id)
|
||||||
|
|
||||||
|
def _recv_motor_response(self, expected_recv_id: Optional[int] = None, timeout: float = 0.001) -> Optional[can.Message]:
|
||||||
|
"""
|
||||||
|
Receive a response from a motor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expected_recv_id: If provided, only return messages from this CAN ID
|
||||||
|
timeout: Timeout in seconds (default: 1ms for high-speed operation)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CAN message if received, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
messages_seen = []
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
msg = self.canbus.recv(timeout=0.0001) # 100us timeout for fast polling
|
||||||
|
if msg:
|
||||||
|
messages_seen.append(f"0x{msg.arbitration_id:02X}")
|
||||||
|
# If no filter specified, return any message
|
||||||
|
if expected_recv_id is None:
|
||||||
|
return msg
|
||||||
|
# Otherwise, only return if it matches the expected recv_id
|
||||||
|
if msg.arbitration_id == expected_recv_id:
|
||||||
|
return msg
|
||||||
|
else:
|
||||||
|
logger.debug(f"Ignoring message from CAN ID 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}")
|
||||||
|
|
||||||
|
# Only log warnings if we're in debug mode to reduce overhead
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
if messages_seen:
|
||||||
|
logger.debug(f"Received {len(messages_seen)} message(s) from IDs {set(messages_seen)}, but expected 0x{expected_recv_id:02X}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"No CAN messages received (expected from 0x{expected_recv_id:02X})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to receive CAN message: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _recv_all_responses(self, expected_recv_ids: list[int], timeout: float = 0.002) -> dict[int, can.Message]:
|
||||||
|
"""
|
||||||
|
Efficiently receive responses from multiple motors at once.
|
||||||
|
Uses the OpenArms pattern: collect all available messages within timeout.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expected_recv_ids: List of CAN IDs we expect responses from
|
||||||
|
timeout: Total timeout in seconds (default: 2ms)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping recv_id to CAN message
|
||||||
|
"""
|
||||||
|
responses = {}
|
||||||
|
expected_set = set(expected_recv_ids)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
|
||||||
|
msg = self.canbus.recv(timeout=0.0002) # 200us poll timeout (increased from 100us for better reliability)
|
||||||
|
if msg and msg.arbitration_id in expected_set:
|
||||||
|
responses[msg.arbitration_id] = msg
|
||||||
|
if len(responses) == len(expected_recv_ids):
|
||||||
|
break # Got all responses, exit early
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error receiving responses: {e}")
|
||||||
|
|
||||||
|
return responses
|
||||||
|
|
||||||
|
def _mit_control(
|
||||||
|
self,
|
||||||
|
motor: NameOrID,
|
||||||
|
kp: float,
|
||||||
|
kd: float,
|
||||||
|
position_degrees: float,
|
||||||
|
velocity_deg_per_sec: float,
|
||||||
|
torque: float,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Send MIT control command to a motor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
motor: Motor name or ID
|
||||||
|
kp: Position gain
|
||||||
|
kd: Velocity gain
|
||||||
|
position_degrees: Target position (degrees)
|
||||||
|
velocity_deg_per_sec: Target velocity (degrees/s)
|
||||||
|
torque: Target torque (N·m)
|
||||||
|
"""
|
||||||
|
motor_id = self._get_motor_id(motor)
|
||||||
|
motor_name = self._get_motor_name(motor)
|
||||||
|
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
|
||||||
|
|
||||||
|
# Convert degrees to radians for motor control
|
||||||
|
position_rad = np.radians(position_degrees)
|
||||||
|
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
|
||||||
|
|
||||||
|
# Get motor limits
|
||||||
|
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||||
|
|
||||||
|
# Encode parameters
|
||||||
|
kp_uint = self._float_to_uint(kp, 0, 500, 12)
|
||||||
|
kd_uint = self._float_to_uint(kd, 0, 5, 12)
|
||||||
|
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
|
||||||
|
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
|
||||||
|
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
|
||||||
|
|
||||||
|
# Pack data
|
||||||
|
data = [0] * 8
|
||||||
|
data[0] = (q_uint >> 8) & 0xFF
|
||||||
|
data[1] = q_uint & 0xFF
|
||||||
|
data[2] = dq_uint >> 4
|
||||||
|
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
|
||||||
|
data[4] = kp_uint & 0xFF
|
||||||
|
data[5] = kd_uint >> 4
|
||||||
|
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
|
||||||
|
data[7] = tau_uint & 0xFF
|
||||||
|
|
||||||
|
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||||
|
self.canbus.send(msg)
|
||||||
|
recv_id = self._get_motor_recv_id(motor)
|
||||||
|
self._recv_motor_response(expected_recv_id=recv_id)
|
||||||
|
|
||||||
|
def _mit_control_batch(
|
||||||
|
self,
|
||||||
|
commands: Dict[NameOrID, Tuple[float, float, float, float, float]],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Send MIT control commands to multiple motors in batch (optimized).
|
||||||
|
Sends all commands first, then collects responses. Much faster than sequential.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque)
|
||||||
|
Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...}
|
||||||
|
"""
|
||||||
|
if not commands:
|
||||||
|
return
|
||||||
|
|
||||||
|
expected_recv_ids = []
|
||||||
|
|
||||||
|
# Step 1: Send all MIT control commands (no waiting)
|
||||||
|
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
|
||||||
|
motor_id = self._get_motor_id(motor)
|
||||||
|
motor_name = self._get_motor_name(motor)
|
||||||
|
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
|
||||||
|
|
||||||
|
# Convert degrees to radians
|
||||||
|
position_rad = np.radians(position_degrees)
|
||||||
|
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
|
||||||
|
|
||||||
|
# Get motor limits
|
||||||
|
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||||
|
|
||||||
|
# Encode parameters
|
||||||
|
kp_uint = self._float_to_uint(kp, 0, 500, 12)
|
||||||
|
kd_uint = self._float_to_uint(kd, 0, 5, 12)
|
||||||
|
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
|
||||||
|
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
|
||||||
|
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
|
||||||
|
|
||||||
|
# Pack data
|
||||||
|
data = [0] * 8
|
||||||
|
data[0] = (q_uint >> 8) & 0xFF
|
||||||
|
data[1] = q_uint & 0xFF
|
||||||
|
data[2] = dq_uint >> 4
|
||||||
|
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
|
||||||
|
data[4] = kp_uint & 0xFF
|
||||||
|
data[5] = kd_uint >> 4
|
||||||
|
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
|
||||||
|
data[7] = tau_uint & 0xFF
|
||||||
|
|
||||||
|
# Send command
|
||||||
|
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||||
|
self.canbus.send(msg)
|
||||||
|
|
||||||
|
# Track expected response
|
||||||
|
recv_id = self._get_motor_recv_id(motor)
|
||||||
|
expected_recv_ids.append(recv_id)
|
||||||
|
|
||||||
|
# Step 2: Collect all responses at once
|
||||||
|
self._recv_all_responses(expected_recv_ids, timeout=0.002)
|
||||||
|
|
||||||
|
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||||
|
"""Convert float to unsigned integer for CAN transmission."""
|
||||||
|
x = max(x_min, min(x_max, x)) # Clamp to range
|
||||||
|
span = x_max - x_min
|
||||||
|
data_norm = (x - x_min) / span
|
||||||
|
return int(data_norm * ((1 << bits) - 1))
|
||||||
|
|
||||||
|
def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float:
|
||||||
|
"""Convert unsigned integer from CAN to float."""
|
||||||
|
span = x_max - x_min
|
||||||
|
data_norm = float(x) / ((1 << bits) - 1)
|
||||||
|
return data_norm * span + x_min
|
||||||
|
|
||||||
|
def _decode_motor_state(self, data: bytes, motor_type: MotorType) -> Tuple[float, float, float, int, int]:
|
||||||
|
"""
|
||||||
|
Decode motor state from CAN data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (position_degrees, velocity_deg_per_sec, torque, temp_mos, temp_rotor)
|
||||||
|
"""
|
||||||
|
if len(data) < 8:
|
||||||
|
raise ValueError("Invalid motor state data")
|
||||||
|
|
||||||
|
# Extract encoded values
|
||||||
|
q_uint = (data[1] << 8) | data[2]
|
||||||
|
dq_uint = (data[3] << 4) | (data[4] >> 4)
|
||||||
|
tau_uint = ((data[4] & 0x0F) << 8) | data[5]
|
||||||
|
t_mos = data[6]
|
||||||
|
t_rotor = data[7]
|
||||||
|
|
||||||
|
# Get motor limits
|
||||||
|
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||||
|
|
||||||
|
# Decode to physical values (radians)
|
||||||
|
position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16)
|
||||||
|
velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12)
|
||||||
|
torque = self._uint_to_float(tau_uint, -tmax, tmax, 12)
|
||||||
|
|
||||||
|
# Convert to degrees
|
||||||
|
position_degrees = np.degrees(position_rad)
|
||||||
|
velocity_deg_per_sec = np.degrees(velocity_rad_per_sec)
|
||||||
|
|
||||||
|
return position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor
|
||||||
|
|
||||||
|
def read(
|
||||||
|
self,
|
||||||
|
data_name: str,
|
||||||
|
motor: str,
|
||||||
|
*,
|
||||||
|
normalize: bool = True,
|
||||||
|
num_retry: int = 0,
|
||||||
|
) -> Value:
|
||||||
|
"""Read a value from a single motor. Positions are always in degrees."""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|
||||||
|
# Refresh motor to get latest state
|
||||||
|
msg = self._refresh_motor(motor)
|
||||||
|
if msg is None:
|
||||||
|
motor_id = self._get_motor_id(motor)
|
||||||
|
recv_id = self._get_motor_recv_id(motor)
|
||||||
|
raise ConnectionError(
|
||||||
|
f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). "
|
||||||
|
f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, "
|
||||||
|
f"3) Motor IDs are configured correctly using Damiao Debugging Tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
motor_type = self._motor_types.get(motor, MotorType.DM4310)
|
||||||
|
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
|
||||||
|
|
||||||
|
# Return requested data (already in degrees for position/velocity)
|
||||||
|
if data_name == "Present_Position":
|
||||||
|
value = position_degrees
|
||||||
|
elif data_name == "Present_Velocity":
|
||||||
|
value = velocity_deg_per_sec
|
||||||
|
elif data_name == "Present_Torque":
|
||||||
|
value = torque
|
||||||
|
elif data_name == "Temperature_MOS":
|
||||||
|
value = t_mos
|
||||||
|
elif data_name == "Temperature_Rotor":
|
||||||
|
value = t_rotor
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown data_name: {data_name}")
|
||||||
|
|
||||||
|
# For Damiao, positions are always in degrees, no normalization needed
|
||||||
|
# We keep the normalize parameter for compatibility but don't use it
|
||||||
|
return value
|
||||||
|
|
||||||
|
def write(
|
||||||
|
self,
|
||||||
|
data_name: str,
|
||||||
|
motor: str,
|
||||||
|
value: Value,
|
||||||
|
*,
|
||||||
|
normalize: bool = True,
|
||||||
|
num_retry: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Write a value to a single motor. Positions are always in degrees."""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|
||||||
|
# Value is expected to be in degrees for positions
|
||||||
|
if data_name == "Goal_Position":
|
||||||
|
# Use MIT control with position in degrees
|
||||||
|
self._mit_control(motor, 10.0, 0.5, value, 0, 0)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Writing {data_name} not supported in MIT mode")
|
||||||
|
|
||||||
|
def sync_read(
|
||||||
|
self,
|
||||||
|
data_name: str,
|
||||||
|
motors: str | list[str] | None = None,
|
||||||
|
*,
|
||||||
|
normalize: bool = True,
|
||||||
|
num_retry: int = 0,
|
||||||
|
) -> Dict[str, Value]:
|
||||||
|
"""
|
||||||
|
Read the same value from multiple motors simultaneously.
|
||||||
|
Uses batched operations: sends all refresh commands, then collects all responses.
|
||||||
|
This is MUCH faster than sequential reads (OpenArms pattern).
|
||||||
|
"""
|
||||||
|
motors = self._get_motors_list(motors)
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# Step 1: Send refresh commands to ALL motors first (no waiting)
|
||||||
|
for motor in motors:
|
||||||
|
motor_id = self._get_motor_id(motor)
|
||||||
|
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||||
|
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||||
|
self.canbus.send(msg)
|
||||||
|
|
||||||
|
# Step 2: Collect all responses at once (batch receive)
|
||||||
|
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in motors]
|
||||||
|
responses = self._recv_all_responses(expected_recv_ids, timeout=0.01) # 10ms total timeout
|
||||||
|
|
||||||
|
# Step 3: Parse responses
|
||||||
|
for motor in motors:
|
||||||
|
try:
|
||||||
|
recv_id = self._get_motor_recv_id(motor)
|
||||||
|
msg = responses.get(recv_id)
|
||||||
|
|
||||||
|
if msg is None:
|
||||||
|
logger.warning(f"No response from motor '{motor}' (recv ID: 0x{recv_id:02X})")
|
||||||
|
result[motor] = 0.0
|
||||||
|
continue
|
||||||
|
|
||||||
|
motor_type = self._motor_types.get(motor, MotorType.DM4310)
|
||||||
|
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
|
||||||
|
|
||||||
|
# Return requested data
|
||||||
|
if data_name == "Present_Position":
|
||||||
|
value = position_degrees
|
||||||
|
elif data_name == "Present_Velocity":
|
||||||
|
value = velocity_deg_per_sec
|
||||||
|
elif data_name == "Present_Torque":
|
||||||
|
value = torque
|
||||||
|
elif data_name == "Temperature_MOS":
|
||||||
|
value = t_mos
|
||||||
|
elif data_name == "Temperature_Rotor":
|
||||||
|
value = t_rotor
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown data_name: {data_name}")
|
||||||
|
|
||||||
|
result[motor] = value
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to read {data_name} from {motor}: {e}")
|
||||||
|
result[motor] = 0.0
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def sync_read_all_states(
|
||||||
|
self,
|
||||||
|
motors: str | list[str] | None = None,
|
||||||
|
*,
|
||||||
|
num_retry: int = 0,
|
||||||
|
) -> Dict[str, Dict[str, Value]]:
|
||||||
|
"""
|
||||||
|
Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle.
|
||||||
|
This is 3x faster than calling sync_read() three times separately.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque'
|
||||||
|
Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
|
||||||
|
"""
|
||||||
|
motors = self._get_motors_list(motors)
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# Step 1: Send refresh commands to ALL motors first (with small delays to reduce bus congestion)
|
||||||
|
for motor in motors:
|
||||||
|
motor_id = self._get_motor_id(motor)
|
||||||
|
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||||
|
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||||
|
self.canbus.send(msg)
|
||||||
|
time.sleep(0.0001) # 100us delay between commands to reduce bus congestion
|
||||||
|
|
||||||
|
# Step 2: Collect all responses at once (batch receive)
|
||||||
|
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in motors]
|
||||||
|
responses = self._recv_all_responses(expected_recv_ids, timeout=0.015) # 15ms timeout (increased for reliability)
|
||||||
|
|
||||||
|
# Step 3: Parse responses and extract ALL state values
|
||||||
|
for motor in motors:
|
||||||
|
try:
|
||||||
|
recv_id = self._get_motor_recv_id(motor)
|
||||||
|
msg = responses.get(recv_id)
|
||||||
|
|
||||||
|
if msg is None:
|
||||||
|
logger.warning(f"No response from motor '{motor}' (recv ID: 0x{recv_id:02X})")
|
||||||
|
result[motor] = {"position": 0.0, "velocity": 0.0, "torque": 0.0}
|
||||||
|
continue
|
||||||
|
|
||||||
|
motor_type = self._motor_types.get(motor, MotorType.DM4310)
|
||||||
|
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
|
||||||
|
|
||||||
|
# Return all state values in one dict
|
||||||
|
result[motor] = {
|
||||||
|
"position": position_degrees,
|
||||||
|
"velocity": velocity_deg_per_sec,
|
||||||
|
"torque": torque,
|
||||||
|
"temp_mos": t_mos,
|
||||||
|
"temp_rotor": t_rotor,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to read state from {motor}: {e}")
|
||||||
|
result[motor] = {"position": 0.0, "velocity": 0.0, "torque": 0.0}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def sync_write(
|
||||||
|
self,
|
||||||
|
data_name: str,
|
||||||
|
values: Dict[str, Value],
|
||||||
|
*,
|
||||||
|
normalize: bool = True,
|
||||||
|
num_retry: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Write different values to multiple motors simultaneously. Positions are always in degrees.
|
||||||
|
Uses batched operations: sends all commands first, then collects responses (OpenArms pattern).
|
||||||
|
"""
|
||||||
|
if data_name == "Goal_Position":
|
||||||
|
# Step 1: Send all MIT control commands first (no waiting)
|
||||||
|
for motor, value_degrees in values.items():
|
||||||
|
motor_id = self._get_motor_id(motor)
|
||||||
|
motor_name = self._get_motor_name(motor)
|
||||||
|
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
|
||||||
|
|
||||||
|
# Convert degrees to radians
|
||||||
|
position_rad = np.radians(value_degrees)
|
||||||
|
|
||||||
|
# Default gains for position control
|
||||||
|
kp, kd = 10.0, 0.5
|
||||||
|
|
||||||
|
# Get motor limits and encode parameters
|
||||||
|
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||||
|
kp_uint = self._float_to_uint(kp, 0, 500, 12)
|
||||||
|
kd_uint = self._float_to_uint(kd, 0, 5, 12)
|
||||||
|
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
|
||||||
|
dq_uint = self._float_to_uint(0, -vmax, vmax, 12)
|
||||||
|
tau_uint = self._float_to_uint(0, -tmax, tmax, 12)
|
||||||
|
|
||||||
|
# Pack data
|
||||||
|
data = [0] * 8
|
||||||
|
data[0] = (q_uint >> 8) & 0xFF
|
||||||
|
data[1] = q_uint & 0xFF
|
||||||
|
data[2] = dq_uint >> 4
|
||||||
|
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
|
||||||
|
data[4] = kp_uint & 0xFF
|
||||||
|
data[5] = kd_uint >> 4
|
||||||
|
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
|
||||||
|
data[7] = tau_uint & 0xFF
|
||||||
|
|
||||||
|
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||||
|
self.canbus.send(msg)
|
||||||
|
time.sleep(0.0001) # 100us delay between commands to reduce bus congestion
|
||||||
|
|
||||||
|
# Step 2: Collect all responses at once
|
||||||
|
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in values.keys()]
|
||||||
|
self._recv_all_responses(expected_recv_ids, timeout=0.015) # 15ms timeout (increased for reliability)
|
||||||
|
else:
|
||||||
|
# Fall back to individual writes for other data types
|
||||||
|
for motor, value in values.items():
|
||||||
|
self.write(data_name, motor, value, normalize=normalize, num_retry=num_retry)
|
||||||
|
|
||||||
|
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||||
|
"""Read calibration data from motors."""
|
||||||
|
# Damiao motors don't store calibration internally
|
||||||
|
# Return existing calibration or empty dict
|
||||||
|
return self.calibration if self.calibration else {}
|
||||||
|
|
||||||
|
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||||
|
"""Write calibration data to motors."""
|
||||||
|
# Damiao motors don't store calibration internally
|
||||||
|
# Just cache it in memory
|
||||||
|
if cache:
|
||||||
|
self.calibration = calibration_dict
|
||||||
|
|
||||||
|
def record_ranges_of_motion(
|
||||||
|
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
|
||||||
|
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||||
|
"""
|
||||||
|
Interactively record the min/max values of each motor in degrees.
|
||||||
|
|
||||||
|
Move the joints by hand (with torque disabled) while the method streams live positions.
|
||||||
|
Press Enter to finish.
|
||||||
|
"""
|
||||||
|
if motors is None:
|
||||||
|
motors = list(self.motors.keys())
|
||||||
|
elif isinstance(motors, (str, int)):
|
||||||
|
motors = [motors]
|
||||||
|
|
||||||
|
# Disable torque for manual movement
|
||||||
|
self.disable_torque(motors)
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Get initial positions (already in degrees)
|
||||||
|
start_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||||
|
mins = start_positions.copy()
|
||||||
|
maxes = start_positions.copy()
|
||||||
|
|
||||||
|
print("\nMove joints through their full range of motion. Press ENTER when done.")
|
||||||
|
user_pressed_enter = False
|
||||||
|
|
||||||
|
while not user_pressed_enter:
|
||||||
|
positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||||
|
|
||||||
|
for motor in motors:
|
||||||
|
if motor in positions:
|
||||||
|
mins[motor] = min(positions[motor], mins.get(motor, positions[motor]))
|
||||||
|
maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor]))
|
||||||
|
|
||||||
|
if display_values:
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}")
|
||||||
|
print("-" * 50)
|
||||||
|
for motor in motors:
|
||||||
|
if motor in positions:
|
||||||
|
print(f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}")
|
||||||
|
|
||||||
|
if enter_pressed():
|
||||||
|
user_pressed_enter = True
|
||||||
|
|
||||||
|
if display_values and not user_pressed_enter:
|
||||||
|
# Move cursor up to overwrite the previous output
|
||||||
|
move_cursor_up(len(motors) + 4)
|
||||||
|
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
# Re-enable torque
|
||||||
|
self.enable_torque(motors)
|
||||||
|
|
||||||
|
# Validate ranges
|
||||||
|
for motor in motors:
|
||||||
|
if motor in mins and motor in maxes:
|
||||||
|
if abs(maxes[motor] - mins[motor]) < 5.0: # At least 5 degrees of range
|
||||||
|
raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)")
|
||||||
|
|
||||||
|
return mins, maxes
|
||||||
|
|
||||||
|
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
|
||||||
|
"""Convert motor specification to list of motor names."""
|
||||||
|
if motors is None:
|
||||||
|
return list(self.motors.keys())
|
||||||
|
elif isinstance(motors, str):
|
||||||
|
return [motors]
|
||||||
|
elif isinstance(motors, list):
|
||||||
|
return motors
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Invalid motors type: {type(motors)}")
|
||||||
|
|
||||||
|
def _get_motor_id(self, motor: NameOrID) -> int:
|
||||||
|
"""Get CAN ID for a motor."""
|
||||||
|
if isinstance(motor, str):
|
||||||
|
if motor in self.motors:
|
||||||
|
return self.motors[motor].id
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown motor: {motor}")
|
||||||
|
else:
|
||||||
|
return motor
|
||||||
|
|
||||||
|
def _get_motor_name(self, motor: NameOrID) -> str:
|
||||||
|
"""Get motor name from name or ID."""
|
||||||
|
if isinstance(motor, str):
|
||||||
|
return motor
|
||||||
|
else:
|
||||||
|
for name, m in self.motors.items():
|
||||||
|
if m.id == motor:
|
||||||
|
return name
|
||||||
|
raise ValueError(f"Unknown motor ID: {motor}")
|
||||||
|
|
||||||
|
def _get_motor_recv_id(self, motor: NameOrID) -> Optional[int]:
|
||||||
|
"""Get motor recv_id from name or ID."""
|
||||||
|
motor_name = self._get_motor_name(motor)
|
||||||
|
motor_obj = self.motors.get(motor_name)
|
||||||
|
if motor_obj and hasattr(motor_obj, "recv_id"):
|
||||||
|
return motor_obj.recv_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
"""Check if motors are calibrated."""
|
||||||
|
return bool(self.calibration)
|
||||||
209
src/lerobot/motors/damiao/tables.py
Normal file
209
src/lerobot/motors/damiao/tables.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Configuration tables for Damiao motors."""
|
||||||
|
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
# Motor type definitions
|
||||||
|
class MotorType(IntEnum):
|
||||||
|
DM3507 = 0
|
||||||
|
DM4310 = 1
|
||||||
|
DM4310_48V = 2
|
||||||
|
DM4340 = 3
|
||||||
|
DM4340_48V = 4
|
||||||
|
DM6006 = 5
|
||||||
|
DM8006 = 6
|
||||||
|
DM8009 = 7
|
||||||
|
DM10010L = 8
|
||||||
|
DM10010 = 9
|
||||||
|
DMH3510 = 10
|
||||||
|
DMH6215 = 11
|
||||||
|
DMG6220 = 12
|
||||||
|
|
||||||
|
# Control modes
|
||||||
|
class ControlMode(IntEnum):
|
||||||
|
MIT = 1
|
||||||
|
POS_VEL = 2
|
||||||
|
VEL = 3
|
||||||
|
TORQUE_POS = 4
|
||||||
|
|
||||||
|
# Motor variable IDs (RID)
|
||||||
|
class MotorVariable(IntEnum):
|
||||||
|
UV_VALUE = 0
|
||||||
|
KT_VALUE = 1
|
||||||
|
OT_VALUE = 2
|
||||||
|
OC_VALUE = 3
|
||||||
|
ACC = 4
|
||||||
|
DEC = 5
|
||||||
|
MAX_SPD = 6
|
||||||
|
MST_ID = 7
|
||||||
|
ESC_ID = 8
|
||||||
|
TIMEOUT = 9
|
||||||
|
CTRL_MODE = 10
|
||||||
|
DAMP = 11
|
||||||
|
INERTIA = 12
|
||||||
|
HW_VER = 13
|
||||||
|
SW_VER = 14
|
||||||
|
SN = 15
|
||||||
|
NPP = 16
|
||||||
|
RS = 17
|
||||||
|
LS = 18
|
||||||
|
FLUX = 19
|
||||||
|
GR = 20
|
||||||
|
PMAX = 21
|
||||||
|
VMAX = 22
|
||||||
|
TMAX = 23
|
||||||
|
I_BW = 24
|
||||||
|
KP_ASR = 25
|
||||||
|
KI_ASR = 26
|
||||||
|
KP_APR = 27
|
||||||
|
KI_APR = 28
|
||||||
|
OV_VALUE = 29
|
||||||
|
GREF = 30
|
||||||
|
DETA = 31
|
||||||
|
V_BW = 32
|
||||||
|
IQ_C1 = 33
|
||||||
|
VL_C1 = 34
|
||||||
|
CAN_BR = 35
|
||||||
|
SUB_VER = 36
|
||||||
|
U_OFF = 50
|
||||||
|
V_OFF = 51
|
||||||
|
K1 = 52
|
||||||
|
K2 = 53
|
||||||
|
M_OFF = 54
|
||||||
|
DIR = 55
|
||||||
|
P_M = 80
|
||||||
|
XOUT = 81
|
||||||
|
|
||||||
|
# Motor limit parameters [PMAX, VMAX, TMAX]
|
||||||
|
# PMAX: Maximum position (rad)
|
||||||
|
# VMAX: Maximum velocity (rad/s)
|
||||||
|
# TMAX: Maximum torque (N·m)
|
||||||
|
MOTOR_LIMIT_PARAMS = {
|
||||||
|
MotorType.DM3507: (12.5, 30, 10),
|
||||||
|
MotorType.DM4310: (12.5, 30, 10),
|
||||||
|
MotorType.DM4310_48V: (12.5, 50, 10),
|
||||||
|
MotorType.DM4340: (12.5, 8, 28),
|
||||||
|
MotorType.DM4340_48V: (12.5, 10, 28),
|
||||||
|
MotorType.DM6006: (12.5, 45, 20),
|
||||||
|
MotorType.DM8006: (12.5, 45, 40),
|
||||||
|
MotorType.DM8009: (12.5, 45, 54),
|
||||||
|
MotorType.DM10010L: (12.5, 25, 200),
|
||||||
|
MotorType.DM10010: (12.5, 20, 200),
|
||||||
|
MotorType.DMH3510: (12.5, 280, 1),
|
||||||
|
MotorType.DMH6215: (12.5, 45, 10),
|
||||||
|
MotorType.DMG6220: (12.5, 45, 10),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Motor model names
|
||||||
|
MODEL_NAMES = {
|
||||||
|
MotorType.DM3507: "dm3507",
|
||||||
|
MotorType.DM4310: "dm4310",
|
||||||
|
MotorType.DM4310_48V: "dm4310_48v",
|
||||||
|
MotorType.DM4340: "dm4340",
|
||||||
|
MotorType.DM4340_48V: "dm4340_48v",
|
||||||
|
MotorType.DM6006: "dm6006",
|
||||||
|
MotorType.DM8006: "dm8006",
|
||||||
|
MotorType.DM8009: "dm8009",
|
||||||
|
MotorType.DM10010L: "dm10010l",
|
||||||
|
MotorType.DM10010: "dm10010",
|
||||||
|
MotorType.DMH3510: "dmh3510",
|
||||||
|
MotorType.DMH6215: "dmh6215",
|
||||||
|
MotorType.DMG6220: "dmg6220",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Motor resolution table (encoder counts per revolution)
|
||||||
|
MODEL_RESOLUTION = {
|
||||||
|
"dm3507": 65536,
|
||||||
|
"dm4310": 65536,
|
||||||
|
"dm4310_48v": 65536,
|
||||||
|
"dm4340": 65536,
|
||||||
|
"dm4340_48v": 65536,
|
||||||
|
"dm6006": 65536,
|
||||||
|
"dm8006": 65536,
|
||||||
|
"dm8009": 65536,
|
||||||
|
"dm10010l": 65536,
|
||||||
|
"dm10010": 65536,
|
||||||
|
"dmh3510": 65536,
|
||||||
|
"dmh6215": 65536,
|
||||||
|
"dmg6220": 65536,
|
||||||
|
}
|
||||||
|
|
||||||
|
# CAN baudrates supported by Damiao motors
|
||||||
|
AVAILABLE_BAUDRATES = [
|
||||||
|
125000, # 0: 125 kbps
|
||||||
|
200000, # 1: 200 kbps
|
||||||
|
250000, # 2: 250 kbps
|
||||||
|
500000, # 3: 500 kbps
|
||||||
|
1000000, # 4: 1 mbps (default for OpenArms)
|
||||||
|
2000000, # 5: 2 mbps
|
||||||
|
2500000, # 6: 2.5 mbps
|
||||||
|
3200000, # 7: 3.2 mbps
|
||||||
|
4000000, # 8: 4 mbps
|
||||||
|
5000000, # 9: 5 mbps
|
||||||
|
]
|
||||||
|
DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
|
||||||
|
|
||||||
|
# Default timeout in milliseconds
|
||||||
|
DEFAULT_TIMEOUT_MS = 1000
|
||||||
|
|
||||||
|
# Data that should be normalized
|
||||||
|
NORMALIZED_DATA = ["Present_Position", "Goal_Position"]
|
||||||
|
|
||||||
|
# OpenArms specific configurations
|
||||||
|
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||||
|
# OpenArms has 7 DOF per arm (14 total for dual arm)
|
||||||
|
OPENARMS_ARM_MOTOR_IDS = {
|
||||||
|
"joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan
|
||||||
|
"joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift
|
||||||
|
"joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex
|
||||||
|
"joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex
|
||||||
|
"joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll
|
||||||
|
"joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch
|
||||||
|
"joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation
|
||||||
|
}
|
||||||
|
|
||||||
|
OPENARMS_GRIPPER_MOTOR_IDS = {
|
||||||
|
"gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper
|
||||||
|
}
|
||||||
|
|
||||||
|
# Default motor types for OpenArms
|
||||||
|
OPENARMS_DEFAULT_MOTOR_TYPES = {
|
||||||
|
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
|
||||||
|
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
|
||||||
|
"joint_3": MotorType.DM4340, # Shoulder rotation
|
||||||
|
"joint_4": MotorType.DM4340, # Elbow flex
|
||||||
|
"joint_5": MotorType.DM4310, # Wrist roll
|
||||||
|
"joint_6": MotorType.DM4310, # Wrist pitch
|
||||||
|
"joint_7": MotorType.DM4310, # Wrist rotation
|
||||||
|
"gripper": MotorType.DM4310, # Gripper
|
||||||
|
}
|
||||||
|
|
||||||
|
# MIT control parameter ranges
|
||||||
|
MIT_KP_RANGE = (0.0, 500.0)
|
||||||
|
MIT_KD_RANGE = (0.0, 5.0)
|
||||||
|
|
||||||
|
# CAN frame command IDs
|
||||||
|
CAN_CMD_ENABLE = 0xFC
|
||||||
|
CAN_CMD_DISABLE = 0xFD
|
||||||
|
CAN_CMD_SET_ZERO = 0xFE
|
||||||
|
CAN_CMD_REFRESH = 0xCC
|
||||||
|
CAN_CMD_QUERY_PARAM = 0x33
|
||||||
|
CAN_CMD_WRITE_PARAM = 0x55
|
||||||
|
CAN_CMD_SAVE_PARAM = 0xAA
|
||||||
|
|
||||||
|
# CAN ID for parameter operations
|
||||||
|
CAN_PARAM_ID = 0x7FF
|
||||||
@@ -24,7 +24,7 @@ from enum import Enum
|
|||||||
|
|
||||||
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
|
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
|
||||||
|
|
||||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||||
from .tables import (
|
from .tables import (
|
||||||
AVAILABLE_BAUDRATES,
|
AVAILABLE_BAUDRATES,
|
||||||
MODEL_BAUDRATE_TABLE,
|
MODEL_BAUDRATE_TABLE,
|
||||||
@@ -100,7 +100,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class DynamixelMotorsBus(MotorsBus):
|
class DynamixelMotorsBus(SerialMotorsBus):
|
||||||
"""
|
"""
|
||||||
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
|
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
|
||||||
the motors. For more info, see the Dynamixel SDK Documentation:
|
the motors. For more info, see the Dynamixel SDK Documentation:
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from pprint import pformat
|
|||||||
|
|
||||||
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||||
|
|
||||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||||
from .tables import (
|
from .tables import (
|
||||||
FIRMWARE_MAJOR_VERSION,
|
FIRMWARE_MAJOR_VERSION,
|
||||||
FIRMWARE_MINOR_VERSION,
|
FIRMWARE_MINOR_VERSION,
|
||||||
@@ -96,7 +96,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802
|
|||||||
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
|
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
|
||||||
|
|
||||||
|
|
||||||
class FeetechMotorsBus(MotorsBus):
|
class FeetechMotorsBus(SerialMotorsBus):
|
||||||
"""
|
"""
|
||||||
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
|
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
|
||||||
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
|
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
|
||||||
@@ -165,7 +165,7 @@ class FeetechMotorsBus(MotorsBus):
|
|||||||
|
|
||||||
def _handshake(self) -> None:
|
def _handshake(self) -> None:
|
||||||
self._assert_motors_exist()
|
self._assert_motors_exist()
|
||||||
self._assert_same_firmware()
|
#self._assert_same_firmware()
|
||||||
|
|
||||||
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
||||||
if self.protocol_version == 0:
|
if self.protocol_version == 0:
|
||||||
|
|||||||
@@ -19,6 +19,8 @@
|
|||||||
# TODO(aliberts): Add block noqa when feature below is available
|
# TODO(aliberts): Add block noqa when feature below is available
|
||||||
# https://github.com/astral-sh/ruff/issues/3711
|
# https://github.com/astral-sh/ruff/issues/3711
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
@@ -41,6 +43,92 @@ Value: TypeAlias = int | float
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MotorsBusBase(abc.ABC):
|
||||||
|
"""
|
||||||
|
Base class for all motor bus implementations.
|
||||||
|
|
||||||
|
This is a minimal interface that all motor buses must implement, regardless of their
|
||||||
|
communication protocol (serial, CAN, etc.).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
port: str,
|
||||||
|
motors: dict[str, Motor],
|
||||||
|
calibration: dict[str, MotorCalibration] | None = None,
|
||||||
|
):
|
||||||
|
self.port = port
|
||||||
|
self.motors = motors
|
||||||
|
self.calibration = calibration if calibration else {}
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def connect(self, handshake: bool = True) -> None:
|
||||||
|
"""Establish connection to the motors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def disconnect(self, disable_torque: bool = True) -> None:
|
||||||
|
"""Disconnect from the motors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""Check if connected to the motors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def read(self, data_name: str, motor: str, *, normalize: bool = True, num_retry: int = 0) -> Value:
|
||||||
|
"""Read a value from a single motor."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def write(
|
||||||
|
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
|
||||||
|
) -> None:
|
||||||
|
"""Write a value to a single motor."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def sync_read(
|
||||||
|
self, data_name: str, motors: str | list[str] | None = None, *, normalize: bool = True
|
||||||
|
) -> dict[str, Value]:
|
||||||
|
"""Read a value from multiple motors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def sync_write(
|
||||||
|
self,
|
||||||
|
data_name: str,
|
||||||
|
values: Value | dict[str, Value],
|
||||||
|
motors: str | list[str] | None = None,
|
||||||
|
*,
|
||||||
|
normalize: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Write values to multiple motors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||||
|
"""Enable torque on selected motors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||||
|
"""Disable torque on selected motors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||||
|
"""Read calibration parameters from the motors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||||
|
"""Write calibration parameters to the motors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
|
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
|
||||||
ctrl_table = model_ctrl_table.get(model)
|
ctrl_table = model_ctrl_table.get(model)
|
||||||
if ctrl_table is None:
|
if ctrl_table is None:
|
||||||
@@ -203,15 +291,15 @@ class GroupSyncWrite(Protocol):
|
|||||||
def txPacket(self): ...
|
def txPacket(self): ...
|
||||||
|
|
||||||
|
|
||||||
class MotorsBus(abc.ABC):
|
class SerialMotorsBus(MotorsBusBase):
|
||||||
"""
|
"""
|
||||||
A MotorsBus allows to efficiently read and write to the attached motors.
|
A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication.
|
||||||
It represents several motors daisy-chained together and connected through a serial port.
|
It represents several motors daisy-chained together and connected through a serial port.
|
||||||
There are currently two implementations of this abstract class:
|
There are currently two implementations of this class:
|
||||||
- DynamixelMotorsBus
|
- DynamixelMotorsBus
|
||||||
- FeetechMotorsBus
|
- FeetechMotorsBus
|
||||||
|
|
||||||
Note: This class may evolve in the future should we add support for other types of bus.
|
This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.).
|
||||||
|
|
||||||
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||||
To find the port, you can run our utility script:
|
To find the port, you can run our utility script:
|
||||||
@@ -1212,3 +1300,7 @@ class MotorsBus(abc.ABC):
|
|||||||
for id_, value in ids_values.items():
|
for id_, value in ids_values.items():
|
||||||
data = self._serialize_data(value, length)
|
data = self._serialize_data(value, length)
|
||||||
self.sync_writer.addParam(id_, data)
|
self.sync_writer.addParam(id_, data)
|
||||||
|
|
||||||
|
|
||||||
|
# Backward compatibility alias
|
||||||
|
MotorsBus = SerialMotorsBus
|
||||||
|
|||||||
@@ -104,6 +104,107 @@ class SGDConfig(OptimizerConfig):
|
|||||||
return torch.optim.SGD(params, **kwargs)
|
return torch.optim.SGD(params, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@OptimizerConfig.register_subclass("xvla-adamw")
|
||||||
|
@dataclass
|
||||||
|
class XVLAAdamWConfig(OptimizerConfig):
|
||||||
|
"""Custom AdamW optimizer for XVLA with differential learning rates.
|
||||||
|
|
||||||
|
The Vision-Language Model (VLM) is trained with 1/10 of the base learning rate
|
||||||
|
for stable optimization, while all other components use the full LR.
|
||||||
|
|
||||||
|
This LR ratio is crucial for achieving strong and stable finetuning performance.
|
||||||
|
|
||||||
|
Soft-prompts can optionally use a separate learning rate with warm-up support.
|
||||||
|
Set `soft_prompt_lr_scale` to a value < 1.0 (e.g., 0.1) to start soft-prompts
|
||||||
|
at a lower LR. Combine with a warmup scheduler for optimal results.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Completely matching official reported performance may require an additional
|
||||||
|
warm-up LR schedule for soft-prompts, which can bring minor improvements.
|
||||||
|
When `soft_prompt_warmup_lr_scale` is set, soft-prompts start at
|
||||||
|
`lr * soft_prompt_warmup_lr_scale` and should be warmed up via the scheduler.
|
||||||
|
|
||||||
|
Parameter Groups:
|
||||||
|
- Group 0 (vlm): VLM parameters at lr * 0.1, weight_decay * 0.1
|
||||||
|
- Group 1 (soft_prompts): Soft-prompt parameters at lr * soft_prompt_lr_scale
|
||||||
|
- Group 2 (other): All other parameters at full lr
|
||||||
|
"""
|
||||||
|
|
||||||
|
lr: float = 1e-4
|
||||||
|
betas: tuple[float, float] = (0.9, 0.99)
|
||||||
|
eps: float = 1e-8
|
||||||
|
weight_decay: float = 0.0
|
||||||
|
grad_clip_norm: float = 10.0
|
||||||
|
# Soft-prompt specific settings
|
||||||
|
soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR (1.0 = same as base LR)
|
||||||
|
soft_prompt_warmup_lr_scale: float | None = None # If set, start soft-prompts at this scale (e.g., 0.01)
|
||||||
|
|
||||||
|
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||||
|
"""
|
||||||
|
Build AdamW optimizer with differential learning rates.
|
||||||
|
|
||||||
|
Expects `named_parameters()` as input (dict of name -> param).
|
||||||
|
Applies:
|
||||||
|
- lr * 0.1 for all VLM-related parameters
|
||||||
|
- lr * soft_prompt_lr_scale for soft-prompt parameters (with optional warmup)
|
||||||
|
- full lr for all other parameters
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: Dictionary of parameter names to parameters (from named_parameters())
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AdamW optimizer with parameter groups for VLM, soft-prompts, and other components
|
||||||
|
"""
|
||||||
|
assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."
|
||||||
|
|
||||||
|
vlm_group, soft_prompt_group, other_group = [], [], []
|
||||||
|
for name, p in params.items():
|
||||||
|
if not p.requires_grad:
|
||||||
|
continue
|
||||||
|
if "vlm" in name.lower():
|
||||||
|
vlm_group.append(p)
|
||||||
|
elif "soft_prompt" in name.lower():
|
||||||
|
soft_prompt_group.append(p)
|
||||||
|
else:
|
||||||
|
other_group.append(p)
|
||||||
|
|
||||||
|
# Determine soft-prompt LR
|
||||||
|
soft_prompt_lr = self.lr * self.soft_prompt_lr_scale
|
||||||
|
if self.soft_prompt_warmup_lr_scale is not None:
|
||||||
|
# Start at warmup scale, scheduler will warm up to soft_prompt_lr
|
||||||
|
soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale
|
||||||
|
|
||||||
|
param_groups = [
|
||||||
|
{
|
||||||
|
"params": vlm_group,
|
||||||
|
"lr": self.lr * 0.1,
|
||||||
|
"weight_decay": self.weight_decay * 0.1,
|
||||||
|
"name": "vlm",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": soft_prompt_group,
|
||||||
|
"lr": soft_prompt_lr,
|
||||||
|
"weight_decay": self.weight_decay,
|
||||||
|
"name": "soft_prompts",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": other_group,
|
||||||
|
"lr": self.lr,
|
||||||
|
"weight_decay": self.weight_decay,
|
||||||
|
"name": "other",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Filter out empty groups
|
||||||
|
param_groups = [g for g in param_groups if len(g["params"]) > 0]
|
||||||
|
|
||||||
|
return torch.optim.AdamW(
|
||||||
|
param_groups,
|
||||||
|
betas=self.betas,
|
||||||
|
eps=self.eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@OptimizerConfig.register_subclass("multi_adam")
|
@OptimizerConfig.register_subclass("multi_adam")
|
||||||
@dataclass
|
@dataclass
|
||||||
class MultiAdamConfig(OptimizerConfig):
|
class MultiAdamConfig(OptimizerConfig):
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
|||||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||||
|
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ACTConfig",
|
"ACTConfig",
|
||||||
@@ -31,4 +32,5 @@ __all__ = [
|
|||||||
"TDMPCConfig",
|
"TDMPCConfig",
|
||||||
"VQBeTConfig",
|
"VQBeTConfig",
|
||||||
"GrootConfig",
|
"GrootConfig",
|
||||||
|
"XVLAConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
@@ -40,6 +41,7 @@ from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
|||||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.policies.utils import validate_visual_features_consistency
|
from lerobot.policies.utils import validate_visual_features_consistency
|
||||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||||
from lerobot.processor.converters import (
|
from lerobot.processor.converters import (
|
||||||
batch_to_transition,
|
batch_to_transition,
|
||||||
@@ -107,8 +109,15 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||||
|
|
||||||
return GrootPolicy
|
return GrootPolicy
|
||||||
|
elif name == "xvla":
|
||||||
|
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
|
||||||
|
|
||||||
|
return XVLAPolicy
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
try:
|
||||||
|
return _get_policy_cls_from_policy_name(name=name)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Policy type '{name}' is not available.") from e
|
||||||
|
|
||||||
|
|
||||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||||
@@ -150,8 +159,14 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
return RewardClassifierConfig(**kwargs)
|
return RewardClassifierConfig(**kwargs)
|
||||||
elif policy_type == "groot":
|
elif policy_type == "groot":
|
||||||
return GrootConfig(**kwargs)
|
return GrootConfig(**kwargs)
|
||||||
|
elif policy_type == "xvla":
|
||||||
|
return XVLAConfig(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
try:
|
||||||
|
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||||
|
return config_cls(**kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Policy type '{policy_type}' is not available.") from e
|
||||||
|
|
||||||
|
|
||||||
class ProcessorConfigKwargs(TypedDict, total=False):
|
class ProcessorConfigKwargs(TypedDict, total=False):
|
||||||
@@ -329,9 +344,24 @@ def make_pre_post_processors(
|
|||||||
config=policy_cfg,
|
config=policy_cfg,
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
elif isinstance(policy_cfg, XVLAConfig):
|
||||||
|
from lerobot.policies.xvla.processor_xvla import (
|
||||||
|
make_xvla_pre_post_processors,
|
||||||
|
)
|
||||||
|
|
||||||
|
processors = make_xvla_pre_post_processors(
|
||||||
|
config=policy_cfg,
|
||||||
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
try:
|
||||||
|
processors = _make_processors_from_policy_config(
|
||||||
|
config=policy_cfg,
|
||||||
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") from e
|
||||||
|
|
||||||
return processors
|
return processors
|
||||||
|
|
||||||
@@ -400,8 +430,7 @@ def make_policy(
|
|||||||
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
|
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
|
||||||
features = env_to_policy_features(env_cfg)
|
features = env_to_policy_features(env_cfg)
|
||||||
|
|
||||||
if not cfg.output_features:
|
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
|
||||||
if not cfg.input_features:
|
if not cfg.input_features:
|
||||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||||
kwargs["config"] = cfg
|
kwargs["config"] = cfg
|
||||||
@@ -425,3 +454,65 @@ def make_policy(
|
|||||||
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
|
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
|
|
||||||
|
def _get_policy_cls_from_policy_name(name: str) -> type[PreTrainedConfig]:
|
||||||
|
"""Get policy class from its registered name using dynamic imports.
|
||||||
|
|
||||||
|
This is used as a helper function to import policies from 3rd party lerobot plugins.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name of the policy.
|
||||||
|
Returns:
|
||||||
|
The policy class corresponding to the given name.
|
||||||
|
"""
|
||||||
|
if name not in PreTrainedConfig.get_known_choices():
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown policy name '{name}'. Available policies: {PreTrainedConfig.get_known_choices()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
config_cls = PreTrainedConfig.get_choice_class(name)
|
||||||
|
config_cls_name = config_cls.__name__
|
||||||
|
|
||||||
|
model_name = config_cls_name.removesuffix("Config") # e.g., DiffusionConfig -> Diffusion
|
||||||
|
if model_name == config_cls_name:
|
||||||
|
raise ValueError(
|
||||||
|
f"The config class name '{config_cls_name}' does not follow the expected naming convention."
|
||||||
|
f"Make sure it ends with 'Config'!"
|
||||||
|
)
|
||||||
|
cls_name = model_name + "Policy" # e.g., DiffusionConfig -> DiffusionPolicy
|
||||||
|
module_path = config_cls.__module__.replace(
|
||||||
|
"configuration_", "modeling_"
|
||||||
|
) # e.g., configuration_diffusion -> modeling_diffusion
|
||||||
|
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
policy_cls = getattr(module, cls_name)
|
||||||
|
return policy_cls
|
||||||
|
|
||||||
|
|
||||||
|
def _make_processors_from_policy_config(
|
||||||
|
config: PreTrainedConfig,
|
||||||
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
|
) -> tuple[Any, Any]:
|
||||||
|
"""Create pre- and post-processors from a policy configuration using dynamic imports.
|
||||||
|
|
||||||
|
This is used as a helper function to import processor factories from 3rd party lerobot plugins.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The policy configuration object.
|
||||||
|
dataset_stats: Dataset statistics for normalization.
|
||||||
|
Returns:
|
||||||
|
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
|
||||||
|
"""
|
||||||
|
|
||||||
|
policy_type = config.type
|
||||||
|
function_name = f"make_{policy_type}_pre_post_processors"
|
||||||
|
module_path = config.__class__.__module__.replace(
|
||||||
|
"configuration_", "processor_"
|
||||||
|
) # e.g., configuration_diffusion -> processor_diffusion
|
||||||
|
logging.debug(
|
||||||
|
f"Instantiating pre/post processors using function '{function_name}' from module '{module_path}'"
|
||||||
|
)
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
function = getattr(module, function_name)
|
||||||
|
return function(config, dataset_stats=dataset_stats)
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
|||||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||||
from lerobot.utils.constants import OBS_IMAGES
|
from lerobot.utils.constants import OBS_IMAGES
|
||||||
|
|
||||||
|
DEFAULT_IMAGE_SIZE = 224
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("pi0")
|
@PreTrainedConfig.register_subclass("pi0")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -51,7 +53,10 @@ class PI0Config(PreTrainedConfig):
|
|||||||
# Real-Time Chunking (RTC) configuration
|
# Real-Time Chunking (RTC) configuration
|
||||||
rtc_config: RTCConfig | None = None
|
rtc_config: RTCConfig | None = None
|
||||||
|
|
||||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
image_resolution: tuple[int, int] = (
|
||||||
|
DEFAULT_IMAGE_SIZE,
|
||||||
|
DEFAULT_IMAGE_SIZE,
|
||||||
|
) # see openpi `preprocessing_pytorch.py`
|
||||||
|
|
||||||
# Add empty images. Used to add empty cameras when no image features are present.
|
# Add empty images. Used to add empty cameras when no image features are present.
|
||||||
empty_cameras: int = 0
|
empty_cameras: int = 0
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ else:
|
|||||||
PaliGemmaForConditionalGeneration = None
|
PaliGemmaForConditionalGeneration = None
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
@@ -337,6 +337,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
action_expert_config,
|
action_expert_config,
|
||||||
use_adarms=None,
|
use_adarms=None,
|
||||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||||
|
image_size: int = DEFAULT_IMAGE_SIZE,
|
||||||
):
|
):
|
||||||
if use_adarms is None:
|
if use_adarms is None:
|
||||||
use_adarms = [False, False]
|
use_adarms = [False, False]
|
||||||
@@ -356,6 +357,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
vlm_config_hf.text_config.vocab_size = 257152
|
vlm_config_hf.text_config.vocab_size = 257152
|
||||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||||
|
vlm_config_hf.vision_config.image_size = image_size
|
||||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||||
vlm_config_hf.vision_config.projection_dim = 2048
|
vlm_config_hf.vision_config.projection_dim = 2048
|
||||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||||
@@ -519,11 +521,17 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||||
|
|
||||||
|
if config.image_resolution[0] != config.image_resolution[1]:
|
||||||
|
raise ValueError(
|
||||||
|
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
|
||||||
|
)
|
||||||
|
|
||||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||||
paligemma_config,
|
paligemma_config,
|
||||||
action_expert_config,
|
action_expert_config,
|
||||||
use_adarms=[False, False],
|
use_adarms=[False, False],
|
||||||
precision=config.dtype,
|
precision=config.dtype,
|
||||||
|
image_size=config.image_resolution[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
||||||
@@ -812,16 +820,13 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
)
|
)
|
||||||
|
|
||||||
dt = -1.0 / num_steps
|
dt = -1.0 / num_steps
|
||||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
|
||||||
|
|
||||||
x_t = noise
|
x_t = noise
|
||||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
for step in range(num_steps):
|
||||||
while time >= -dt / 2:
|
time = 1.0 + step * dt
|
||||||
expanded_time = time.expand(bsize)
|
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||||
|
|
||||||
# Define a closure function to properly capture expanded_time
|
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
|
||||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
|
||||||
return self.denoise_step(
|
return self.denoise_step(
|
||||||
state=state,
|
state=state,
|
||||||
prefix_pad_masks=prefix_pad_masks,
|
prefix_pad_masks=prefix_pad_masks,
|
||||||
@@ -846,15 +851,11 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
else:
|
else:
|
||||||
v_t = denoise_step_partial_call(x_t)
|
v_t = denoise_step_partial_call(x_t)
|
||||||
|
|
||||||
# Euler step
|
x_t = x_t + dt * v_t
|
||||||
x_t += dt * v_t
|
|
||||||
|
|
||||||
# Record x_t and v_t after Euler step
|
|
||||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||||
|
|
||||||
time += dt
|
|
||||||
|
|
||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
def denoise_step(
|
def denoise_step(
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ from lerobot.optim.optimizers import AdamWConfig
|
|||||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||||
|
|
||||||
|
DEFAULT_IMAGE_SIZE = 224
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("pi05")
|
@PreTrainedConfig.register_subclass("pi05")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -50,7 +52,10 @@ class PI05Config(PreTrainedConfig):
|
|||||||
# Real-Time Chunking (RTC) configuration
|
# Real-Time Chunking (RTC) configuration
|
||||||
rtc_config: RTCConfig | None = None
|
rtc_config: RTCConfig | None = None
|
||||||
|
|
||||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
image_resolution: tuple[int, int] = (
|
||||||
|
DEFAULT_IMAGE_SIZE,
|
||||||
|
DEFAULT_IMAGE_SIZE,
|
||||||
|
) # see openpi `preprocessing_pytorch.py`
|
||||||
|
|
||||||
# Add empty images. Used to add empty cameras when no image features are present.
|
# Add empty images. Used to add empty cameras when no image features are present.
|
||||||
empty_cameras: int = 0
|
empty_cameras: int = 0
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ else:
|
|||||||
PaliGemmaForConditionalGeneration = None
|
PaliGemmaForConditionalGeneration = None
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
@@ -336,6 +336,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
action_expert_config,
|
action_expert_config,
|
||||||
use_adarms=None,
|
use_adarms=None,
|
||||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||||
|
image_size: int = DEFAULT_IMAGE_SIZE,
|
||||||
):
|
):
|
||||||
if use_adarms is None:
|
if use_adarms is None:
|
||||||
use_adarms = [False, False]
|
use_adarms = [False, False]
|
||||||
@@ -355,6 +356,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
vlm_config_hf.text_config.vocab_size = 257152
|
vlm_config_hf.text_config.vocab_size = 257152
|
||||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||||
|
vlm_config_hf.vision_config.image_size = image_size
|
||||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||||
vlm_config_hf.vision_config.projection_dim = 2048
|
vlm_config_hf.vision_config.projection_dim = 2048
|
||||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||||
@@ -518,11 +520,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||||
|
|
||||||
|
if config.image_resolution[0] != config.image_resolution[1]:
|
||||||
|
raise ValueError(
|
||||||
|
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
|
||||||
|
)
|
||||||
|
|
||||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||||
paligemma_config,
|
paligemma_config,
|
||||||
action_expert_config,
|
action_expert_config,
|
||||||
use_adarms=[False, True],
|
use_adarms=[False, True],
|
||||||
precision=config.dtype,
|
precision=config.dtype,
|
||||||
|
image_size=config.image_resolution[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
||||||
@@ -538,6 +546,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
if config.compile_model:
|
if config.compile_model:
|
||||||
torch.set_float32_matmul_precision("high")
|
torch.set_float32_matmul_precision("high")
|
||||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
||||||
|
# Also compile the main forward pass used during training
|
||||||
|
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||||
|
|
||||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||||
|
|
||||||
@@ -785,16 +795,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
)
|
)
|
||||||
|
|
||||||
dt = -1.0 / num_steps
|
dt = -1.0 / num_steps
|
||||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
|
||||||
|
|
||||||
x_t = noise
|
x_t = noise
|
||||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
for step in range(num_steps):
|
||||||
while time >= -dt / 2:
|
time = 1.0 + step * dt
|
||||||
expanded_time = time.expand(bsize)
|
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||||
|
|
||||||
# Define a closure function to properly capture expanded_time
|
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
|
||||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
|
||||||
return self.denoise_step(
|
return self.denoise_step(
|
||||||
prefix_pad_masks=prefix_pad_masks,
|
prefix_pad_masks=prefix_pad_masks,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
@@ -818,15 +825,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
else:
|
else:
|
||||||
v_t = denoise_step_partial_call(x_t)
|
v_t = denoise_step_partial_call(x_t)
|
||||||
|
|
||||||
# Euler step
|
x_t = x_t + dt * v_t
|
||||||
x_t += dt * v_t
|
|
||||||
|
|
||||||
# Record x_t and v_t after Euler step
|
|
||||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||||
|
|
||||||
time += dt
|
|
||||||
|
|
||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
def denoise_step(
|
def denoise_step(
|
||||||
|
|||||||
@@ -527,6 +527,7 @@ class VLAFlowMatching(nn.Module):
|
|||||||
num_vlm_layers=self.config.num_vlm_layers,
|
num_vlm_layers=self.config.num_vlm_layers,
|
||||||
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
||||||
expert_width_multiplier=self.config.expert_width_multiplier,
|
expert_width_multiplier=self.config.expert_width_multiplier,
|
||||||
|
device=self.config.device if self.config.device is not None else "auto",
|
||||||
)
|
)
|
||||||
self.state_proj = nn.Linear(
|
self.state_proj = nn.Linear(
|
||||||
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
|
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
|
||||||
@@ -783,18 +784,15 @@ class VLAFlowMatching(nn.Module):
|
|||||||
use_cache=self.config.use_cache,
|
use_cache=self.config.use_cache,
|
||||||
fill_kv_cache=True,
|
fill_kv_cache=True,
|
||||||
)
|
)
|
||||||
dt = -1.0 / self.config.num_steps
|
num_steps = self.config.num_steps
|
||||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
dt = -1.0 / num_steps
|
||||||
|
|
||||||
x_t = noise
|
x_t = noise
|
||||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
for step in range(num_steps):
|
||||||
|
time = 1.0 + step * dt
|
||||||
|
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||||
|
|
||||||
while time >= -dt / 2:
|
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||||
expanded_time = time.expand(bsize)
|
|
||||||
|
|
||||||
# Define a closure function to properly capture expanded_time
|
|
||||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
|
||||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
|
||||||
return self.denoise_step(
|
return self.denoise_step(
|
||||||
x_t=input_x_t,
|
x_t=input_x_t,
|
||||||
prefix_pad_masks=prefix_pad_masks,
|
prefix_pad_masks=prefix_pad_masks,
|
||||||
@@ -818,15 +816,11 @@ class VLAFlowMatching(nn.Module):
|
|||||||
else:
|
else:
|
||||||
v_t = denoise_step_partial_call(x_t)
|
v_t = denoise_step_partial_call(x_t)
|
||||||
|
|
||||||
# Euler step
|
x_t = x_t + dt * v_t
|
||||||
x_t += dt * v_t
|
|
||||||
|
|
||||||
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
|
|
||||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||||
|
|
||||||
time += dt
|
|
||||||
|
|
||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
def denoise_step(
|
def denoise_step(
|
||||||
|
|||||||
6
src/lerobot/policies/xvla/__init__.py
Normal file
6
src/lerobot/policies/xvla/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# register the processor steps
|
||||||
|
from lerobot.policies.xvla.processor_xvla import (
|
||||||
|
XVLAAddDomainIdProcessorStep,
|
||||||
|
XVLAImageNetNormalizeProcessorStep,
|
||||||
|
XVLAImageToFloatProcessorStep,
|
||||||
|
)
|
||||||
588
src/lerobot/policies/xvla/action_hub.py
Normal file
588
src/lerobot/policies/xvla/action_hub.py
Normal file
@@ -0,0 +1,588 @@
|
|||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Copyright 2025 2toINF and HuggingFace Inc. (https://github.com/2toINF)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Registry
|
||||||
|
# =============================================================================
|
||||||
|
ACTION_REGISTRY: dict[str, type[BaseActionSpace]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_action(name: str):
|
||||||
|
"""Decorator for registering a new action space."""
|
||||||
|
|
||||||
|
def _wrap(cls):
|
||||||
|
key = name.lower()
|
||||||
|
if key in ACTION_REGISTRY:
|
||||||
|
raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")
|
||||||
|
ACTION_REGISTRY[key] = cls
|
||||||
|
cls.name = key
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return _wrap
|
||||||
|
|
||||||
|
|
||||||
|
def build_action_space(name: str, **kwargs) -> BaseActionSpace:
|
||||||
|
"""Instantiate a registered action space by name."""
|
||||||
|
key = name.lower()
|
||||||
|
if key not in ACTION_REGISTRY:
|
||||||
|
raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}")
|
||||||
|
return ACTION_REGISTRY[key](**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Base class
|
||||||
|
# =============================================================================
|
||||||
|
class BaseActionSpace(nn.Module):
|
||||||
|
"""
|
||||||
|
Abstract base class for all action-space definitions.
|
||||||
|
|
||||||
|
Each subclass defines:
|
||||||
|
- `dim_action`: dimension of the action vector.
|
||||||
|
- `gripper_idx`: indices of gripper channels.
|
||||||
|
- `compute_loss(pred, target)`: supervised loss for this space.
|
||||||
|
- `preprocess(proprio, action, mode)`: pre-step modifications.
|
||||||
|
- `postprocess(action)`: post-step corrections (e.g. apply sigmoid).
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "base"
|
||||||
|
dim_action: int = 0
|
||||||
|
gripper_idx: tuple[int, ...] = ()
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------
|
||||||
|
# Core supervised loss
|
||||||
|
# ---------------------------------------------------------------------
|
||||||
|
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||||
|
"""Alias for compute_loss."""
|
||||||
|
return self.compute_loss(pred, target)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------
|
||||||
|
# Space-level hooks
|
||||||
|
# ---------------------------------------------------------------------
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
proprio: torch.Tensor,
|
||||||
|
action: torch.Tensor,
|
||||||
|
mode: str = "train",
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Default: return unchanged."""
|
||||||
|
return proprio, action
|
||||||
|
|
||||||
|
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Default: return unchanged."""
|
||||||
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Utilities
|
||||||
|
# =============================================================================
|
||||||
|
def _ensure_indices_valid(dim_action: int, idx: Iterable[int], name: str) -> None:
|
||||||
|
bad = [i for i in idx if i < 0 or i >= dim_action]
|
||||||
|
if bad:
|
||||||
|
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim dim_action={dim_action}")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Implementations
|
||||||
|
# =============================================================================
|
||||||
|
@register_action("ee6d")
|
||||||
|
class EE6DActionSpace(BaseActionSpace):
|
||||||
|
"""End-effector layout with xyz, 6D rotation, and gripper channels."""
|
||||||
|
|
||||||
|
dim_action = 20
|
||||||
|
gripper_idx = (9, 19)
|
||||||
|
GRIPPER_SCALE = 1.0
|
||||||
|
XYZ_SCALE = 500.0
|
||||||
|
ROT_SCALE = 10.0
|
||||||
|
|
||||||
|
POS_IDX_1 = (0, 1, 2)
|
||||||
|
POS_IDX_2 = (10, 11, 12)
|
||||||
|
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
|
||||||
|
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.mse = nn.MSELoss()
|
||||||
|
self.bce = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
|
def compute_loss(self, pred, target):
|
||||||
|
assert pred.shape == target.shape, "pred/target shapes must match"
|
||||||
|
batch_size, seq_len, action_dim = pred.shape
|
||||||
|
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
|
||||||
|
|
||||||
|
# Gripper BCE
|
||||||
|
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
|
||||||
|
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
|
||||||
|
|
||||||
|
# XYZ position
|
||||||
|
pos_loss = (
|
||||||
|
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
|
||||||
|
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
||||||
|
) * self.XYZ_SCALE
|
||||||
|
|
||||||
|
# Rotation 6D
|
||||||
|
rot_loss = (
|
||||||
|
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
|
||||||
|
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
||||||
|
) * self.ROT_SCALE
|
||||||
|
|
||||||
|
return {
|
||||||
|
"position_loss": pos_loss,
|
||||||
|
"rotate6D_loss": rot_loss,
|
||||||
|
"gripper_loss": gripper_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
def preprocess(self, proprio, action, mode="train"):
|
||||||
|
"""Zero-out gripper channels in proprio/action."""
|
||||||
|
proprio_m = proprio.clone()
|
||||||
|
action_m = action.clone()
|
||||||
|
proprio_m[..., self.gripper_idx] = 0.0
|
||||||
|
action_m[..., self.gripper_idx] = 0.0
|
||||||
|
return proprio_m, action_m
|
||||||
|
|
||||||
|
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Apply sigmoid to gripper logits."""
|
||||||
|
if action.size(-1) > max(self.gripper_idx):
|
||||||
|
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||||
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
@register_action("joint")
|
||||||
|
class JointActionSpace(BaseActionSpace):
|
||||||
|
"""Joint-space layout with joints + gripper only."""
|
||||||
|
|
||||||
|
dim_action = 14
|
||||||
|
gripper_idx = (6, 13)
|
||||||
|
GRIPPER_SCALE = 0.1
|
||||||
|
JOINTS_SCALE = 1.0
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.mse = nn.MSELoss()
|
||||||
|
self.bce = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
|
def compute_loss(self, pred, target):
|
||||||
|
assert pred.shape == target.shape
|
||||||
|
batch_size, seq_len, action_dim = pred.shape
|
||||||
|
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
|
||||||
|
|
||||||
|
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
|
||||||
|
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
|
||||||
|
|
||||||
|
joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx))
|
||||||
|
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
|
||||||
|
|
||||||
|
return {
|
||||||
|
"joints_loss": joints_loss,
|
||||||
|
"gripper_loss": gripper_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
def preprocess(self, proprio, action, mode="train"):
|
||||||
|
"""Zero-out gripper channels in proprio/action."""
|
||||||
|
proprio_m = proprio.clone()
|
||||||
|
action_m = action.clone()
|
||||||
|
proprio_m[..., self.gripper_idx] = 0.0
|
||||||
|
action_m[..., self.gripper_idx] = 0.0
|
||||||
|
return proprio_m, action_m
|
||||||
|
|
||||||
|
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Apply sigmoid to gripper logits."""
|
||||||
|
if action.size(-1) > max(self.gripper_idx):
|
||||||
|
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||||
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
@register_action("agibot_ee6d")
|
||||||
|
class AGIBOTEE6DActionSpace(BaseActionSpace):
|
||||||
|
"""AGI-bot variant of EE6DActionSpace using MSE for all components."""
|
||||||
|
|
||||||
|
dim_action = 20
|
||||||
|
gripper_idx = (9, 19)
|
||||||
|
GRIPPER_SCALE = 10.0
|
||||||
|
XYZ_SCALE = 500.0
|
||||||
|
ROT_SCALE = 10.0
|
||||||
|
POS_IDX_1 = (0, 1, 2)
|
||||||
|
POS_IDX_2 = (10, 11, 12)
|
||||||
|
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
|
||||||
|
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.mse = nn.MSELoss()
|
||||||
|
|
||||||
|
def compute_loss(self, pred, target):
|
||||||
|
assert pred.shape == target.shape
|
||||||
|
batch_size, seq_len, action_dim = pred.shape
|
||||||
|
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
|
||||||
|
|
||||||
|
gripper_loss = (
|
||||||
|
self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
|
||||||
|
)
|
||||||
|
pos_loss = (
|
||||||
|
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
|
||||||
|
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
||||||
|
) * self.XYZ_SCALE
|
||||||
|
rot_loss = (
|
||||||
|
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
|
||||||
|
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
||||||
|
) * self.ROT_SCALE
|
||||||
|
|
||||||
|
return {
|
||||||
|
"position_loss": pos_loss,
|
||||||
|
"rotate6D_loss": rot_loss,
|
||||||
|
"gripper_loss": gripper_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
def preprocess(self, proprio, action, mode="train"):
|
||||||
|
"""No preprocessing applied in AGIBOT variant."""
|
||||||
|
return proprio, action
|
||||||
|
|
||||||
|
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""AGIBOT does not postprocess."""
|
||||||
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
@register_action("franka_joint7")
|
||||||
|
class FrankaJoint7ActionSpace(BaseActionSpace):
|
||||||
|
"""
|
||||||
|
Franka Panda joint-space: 7 joints, with gripper.
|
||||||
|
|
||||||
|
- Real robot action dim: 7
|
||||||
|
- Model-facing dim: 20 (padded with zeros)
|
||||||
|
compatible with pretrained VLA models expecting 20D.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dim_action = 20 # model dimension
|
||||||
|
REAL_DIM = 7 # actual Franka joints
|
||||||
|
|
||||||
|
JOINTS_SCALE = 1.0
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.mse = nn.MSELoss()
|
||||||
|
|
||||||
|
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Pad 7 → 20 dims (zeros for the dummy channels)."""
|
||||||
|
if x is None:
|
||||||
|
return None
|
||||||
|
if x.size(-1) == self.dim_action:
|
||||||
|
return x
|
||||||
|
if x.size(-1) != self.REAL_DIM:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM] # 13 zeros
|
||||||
|
pad = x.new_zeros(pad_shape)
|
||||||
|
return torch.cat([x, pad], dim=-1)
|
||||||
|
|
||||||
|
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Trim model output 20 → 7 dims."""
|
||||||
|
return x[..., : self.REAL_DIM]
|
||||||
|
|
||||||
|
def compute_loss(self, pred, target):
|
||||||
|
"""
|
||||||
|
pred : [B, T, 20]
|
||||||
|
target : [B, T, 7] or [B, T, 20]
|
||||||
|
|
||||||
|
Only compute MSE on the first 7 dims.
|
||||||
|
"""
|
||||||
|
pred = self._pad_to_model_dim(pred)
|
||||||
|
target = self._pad_to_model_dim(target)
|
||||||
|
|
||||||
|
assert pred.shape == target.shape
|
||||||
|
|
||||||
|
joints_loss = (
|
||||||
|
self.mse(
|
||||||
|
pred[:, :, : self.REAL_DIM], # use only the first 7 joints
|
||||||
|
target[:, :, : self.REAL_DIM],
|
||||||
|
)
|
||||||
|
* self.JOINTS_SCALE
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"joints_loss": joints_loss}
|
||||||
|
|
||||||
|
def preprocess(self, proprio, action, mode="train"):
|
||||||
|
"""
|
||||||
|
During training:
|
||||||
|
- Pad [7] → [20]
|
||||||
|
"""
|
||||||
|
return proprio, self._pad_to_model_dim(action)
|
||||||
|
|
||||||
|
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
After model prediction:
|
||||||
|
- Trim [20] → [7] for real robot control.
|
||||||
|
"""
|
||||||
|
return self._trim_to_real_dim(action)
|
||||||
|
|
||||||
|
|
||||||
|
@register_action("auto")
|
||||||
|
class AutoActionSpace(BaseActionSpace):
|
||||||
|
"""
|
||||||
|
Auto-detecting action space that adapts to any action dimension.
|
||||||
|
|
||||||
|
- Auto-detects the real action dimension from the policy feature
|
||||||
|
- Model outputs max_dim for compatibility with pretrained models
|
||||||
|
- Loss is computed only on the first real_dim dimensions
|
||||||
|
- Postprocess trims output back to real_dim
|
||||||
|
|
||||||
|
Args:
|
||||||
|
real_dim: The actual action dimension from the dataset/policy feature
|
||||||
|
max_dim: The model's output dimension for pretrained VLA compatibility
|
||||||
|
"""
|
||||||
|
|
||||||
|
JOINTS_SCALE = 1.0
|
||||||
|
|
||||||
|
def __init__(self, real_dim: int, max_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.real_dim = real_dim
|
||||||
|
self.dim_action = max_dim # Model-facing dimension
|
||||||
|
self.mse = nn.MSELoss()
|
||||||
|
|
||||||
|
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Pad real_dim → max_dim (zeros for the dummy channels)."""
|
||||||
|
if x is None:
|
||||||
|
return None
|
||||||
|
if x.size(-1) == self.dim_action:
|
||||||
|
return x
|
||||||
|
if x.size(-1) != self.real_dim:
|
||||||
|
# If dimension doesn't match either, pad/trim to real_dim first
|
||||||
|
if x.size(-1) < self.real_dim:
|
||||||
|
pad_shape = list(x.shape[:-1]) + [self.real_dim - x.size(-1)]
|
||||||
|
pad = x.new_zeros(pad_shape)
|
||||||
|
x = torch.cat([x, pad], dim=-1)
|
||||||
|
else:
|
||||||
|
x = x[..., : self.real_dim]
|
||||||
|
|
||||||
|
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.real_dim]
|
||||||
|
pad = x.new_zeros(pad_shape)
|
||||||
|
return torch.cat([x, pad], dim=-1)
|
||||||
|
|
||||||
|
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Trim model output max_dim → real_dim."""
|
||||||
|
return x[..., : self.real_dim]
|
||||||
|
|
||||||
|
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Compute loss only on the first real_dim dimensions.
|
||||||
|
|
||||||
|
pred: [B, T, max_dim] from the model
|
||||||
|
target: [B, T, real_dim] or [B, T, max_dim]
|
||||||
|
|
||||||
|
Loss = MSE(pred[:,:,:real_dim], target[:,:,:real_dim])
|
||||||
|
"""
|
||||||
|
pred = self._pad_to_model_dim(pred)
|
||||||
|
target = self._pad_to_model_dim(target)
|
||||||
|
assert pred.shape == target.shape, f"Shape mismatch: pred {pred.shape} vs target {target.shape}"
|
||||||
|
|
||||||
|
# only compute loss on the real dimensions
|
||||||
|
joints_loss = (
|
||||||
|
self.mse(
|
||||||
|
pred[:, :, : self.real_dim],
|
||||||
|
target[:, :, : self.real_dim],
|
||||||
|
)
|
||||||
|
* self.JOINTS_SCALE
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"joints_loss": joints_loss}
|
||||||
|
|
||||||
|
def preprocess(self, proprio: torch.Tensor, action: torch.Tensor, mode: str = "train"):
|
||||||
|
"""
|
||||||
|
Pad action from real_dim to max_dim for the model.
|
||||||
|
"""
|
||||||
|
return proprio, self._pad_to_model_dim(action)
|
||||||
|
|
||||||
|
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Trim model output from max_dim to real_dim for real robot control.
|
||||||
|
"""
|
||||||
|
return self._trim_to_real_dim(action)
|
||||||
|
|
||||||
|
|
||||||
|
@register_action("so101_bimanual")
|
||||||
|
class BimanualSO101ActionSpace(BaseActionSpace):
|
||||||
|
"""
|
||||||
|
Bimanual SO101 robot: 2 arms with 5 joints each + gripper.
|
||||||
|
|
||||||
|
Layout (real robot):
|
||||||
|
[left_arm (5 joints + gripper), right_arm (5 joints + gripper)]
|
||||||
|
- Left arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
|
||||||
|
- Right arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
|
||||||
|
|
||||||
|
Real action dim: 12
|
||||||
|
Model-facing dim: 20 (extra 8 dummy dims at the end)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Model output / training dimension (to match pretrained policy)
|
||||||
|
dim_action = 20
|
||||||
|
|
||||||
|
# Real robot action dimension
|
||||||
|
REAL_DIM = 12
|
||||||
|
|
||||||
|
# Indices of real vs dummy channels
|
||||||
|
REAL_IDXS = tuple(range(REAL_DIM)) # 0..11
|
||||||
|
DUMMY_IDXS = tuple(range(REAL_DIM, dim_action)) # 12..19
|
||||||
|
|
||||||
|
# Grippers live in the real part
|
||||||
|
gripper_idx = (5, 11) # left_gripper at idx 5, right_gripper at idx 11
|
||||||
|
GRIPPER_SCALE = 1.0
|
||||||
|
JOINTS_SCALE = 1.0
|
||||||
|
|
||||||
|
# Indices for left and right arm joints (excluding grippers)
|
||||||
|
LEFT_ARM_JOINTS = (0, 1, 2, 3, 4)
|
||||||
|
RIGHT_ARM_JOINTS = (6, 7, 8, 9, 10)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.mse = nn.MSELoss()
|
||||||
|
self.bce = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
|
# ---------- helpers ----------
|
||||||
|
|
||||||
|
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""If last dim is REAL_DIM (12), pad zeros to reach dim_action (20)."""
|
||||||
|
if x is None:
|
||||||
|
return None
|
||||||
|
if x.size(-1) == self.dim_action:
|
||||||
|
return x
|
||||||
|
if x.size(-1) != self.REAL_DIM:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
|
||||||
|
)
|
||||||
|
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM]
|
||||||
|
pad = x.new_zeros(pad_shape)
|
||||||
|
return torch.cat([x, pad], dim=-1)
|
||||||
|
|
||||||
|
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Keep only the first REAL_DIM (12) dims for the real robot."""
|
||||||
|
return x[..., : self.REAL_DIM]
|
||||||
|
|
||||||
|
# ---------- loss ----------
|
||||||
|
|
||||||
|
def compute_loss(self, pred, target):
|
||||||
|
"""
|
||||||
|
pred: [B, T, 20] from the model
|
||||||
|
target: [B, T, 12] or [B, T, 20]
|
||||||
|
We pad target → 20 and compute loss only on the real dims.
|
||||||
|
"""
|
||||||
|
# Ensure both are [B, T, 20]
|
||||||
|
pred = self._pad_to_model_dim(pred)
|
||||||
|
target = self._pad_to_model_dim(target)
|
||||||
|
assert pred.shape == target.shape
|
||||||
|
|
||||||
|
# ---- MSE for all real dims (0–11) ----
|
||||||
|
real_dims = 12
|
||||||
|
|
||||||
|
joints_loss = (
|
||||||
|
self.mse(
|
||||||
|
pred[:, :, :real_dims],
|
||||||
|
target[:, :, :real_dims],
|
||||||
|
)
|
||||||
|
* self.JOINTS_SCALE
|
||||||
|
)
|
||||||
|
|
||||||
|
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
|
||||||
|
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
|
||||||
|
|
||||||
|
gripper_loss = (
|
||||||
|
self.mse(
|
||||||
|
pred[:, :, [5, 11]],
|
||||||
|
target[:, :, [5, 11]],
|
||||||
|
)
|
||||||
|
* self.GRIPPER_SCALE
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"joints_loss": joints_loss,
|
||||||
|
"gripper_loss": gripper_loss,
|
||||||
|
"left_arm_loss": left_arm_loss,
|
||||||
|
"right_arm_loss": right_arm_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------- preprocess / postprocess ----------
|
||||||
|
|
||||||
|
def preprocess(self, proprio, action, mode="train"):
|
||||||
|
"""
|
||||||
|
- If proprio/action are 12-dim, pad them to 20 for the model.
|
||||||
|
- Zero-out gripper channels in proprio/action to focus learning on joints.
|
||||||
|
"""
|
||||||
|
proprio_m = self._pad_to_model_dim(proprio.clone())
|
||||||
|
action_m = self._pad_to_model_dim(action.clone()) if action is not None else None
|
||||||
|
|
||||||
|
proprio_m[..., self.gripper_idx] = 0.0
|
||||||
|
if action_m is not None:
|
||||||
|
action_m[..., self.gripper_idx] = 0.0
|
||||||
|
|
||||||
|
return proprio_m, action_m
|
||||||
|
|
||||||
|
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
- Model outputs [*, 20]
|
||||||
|
- Apply sigmoid to gripper logits
|
||||||
|
- Return only the first 12 dims for the real robot:
|
||||||
|
["left_shoulder_pan.pos",
|
||||||
|
"left_shoulder_lift.pos",
|
||||||
|
"left_elbow_flex.pos",
|
||||||
|
"left_wrist_flex.pos",
|
||||||
|
"left_wrist_roll.pos",
|
||||||
|
"left_gripper.pos",
|
||||||
|
"right_shoulder_pan.pos",
|
||||||
|
"right_shoulder_lift.pos",
|
||||||
|
"right_elbow_flex.pos",
|
||||||
|
"right_wrist_flex.pos",
|
||||||
|
"right_wrist_roll.pos",
|
||||||
|
"right_gripper.pos"]
|
||||||
|
"""
|
||||||
|
# Ensure we at least have the real dims + grippers
|
||||||
|
if action.size(-1) < self.REAL_DIM:
|
||||||
|
raise ValueError(f"Expected at least {self.REAL_DIM} dims in action, got {action.size(-1)}")
|
||||||
|
|
||||||
|
# Apply sigmoid on gripper channels in model space (indices 5 and 11)
|
||||||
|
if action.size(-1) > max(self.gripper_idx):
|
||||||
|
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||||
|
|
||||||
|
# Return only the real 12-dim control vector for the env
|
||||||
|
return self._trim_to_real_dim(action)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Exports
|
||||||
|
# =============================================================================
|
||||||
|
__all__ = [
|
||||||
|
"BaseActionSpace",
|
||||||
|
"build_action_space",
|
||||||
|
"register_action",
|
||||||
|
"EE6DActionSpace",
|
||||||
|
"JointActionSpace",
|
||||||
|
"AGIBOTEE6DActionSpace",
|
||||||
|
"FrankaJoint7ActionSpace",
|
||||||
|
"AutoActionSpace",
|
||||||
|
"BimanualSO101ActionSpace",
|
||||||
|
"ACTION_REGISTRY",
|
||||||
|
]
|
||||||
353
src/lerobot/policies/xvla/configuration_florence2.py
Normal file
353
src/lerobot/policies/xvla/configuration_florence2.py
Normal file
@@ -0,0 +1,353 @@
|
|||||||
|
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
""" Florence-2 configuration"""
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Florence2VisionConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
|
||||||
|
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
|
defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
drop_path_rate (`float`, *optional*, defaults to 0.1):
|
||||||
|
The dropout rate of the drop path layer.
|
||||||
|
patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
|
||||||
|
The patch size of the image.
|
||||||
|
patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
|
||||||
|
The patch stride of the image.
|
||||||
|
patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
|
||||||
|
The patch padding of the image.
|
||||||
|
patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
|
||||||
|
Whether to apply layer normalization before the patch embedding layer.
|
||||||
|
enable_checkpoint (`bool`, *optional*, defaults to False):
|
||||||
|
Whether to enable checkpointing.
|
||||||
|
dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
|
||||||
|
The dimension of the embedding layer.
|
||||||
|
num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
|
||||||
|
The number of attention heads.
|
||||||
|
num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
|
||||||
|
The number of groups.
|
||||||
|
depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
|
||||||
|
The depth of the model.
|
||||||
|
window_size (`int`, *optional*, defaults to 12):
|
||||||
|
The window size of the model.
|
||||||
|
projection_dim (`int`, *optional*, defaults to 1024):
|
||||||
|
The dimension of the projection layer.
|
||||||
|
visual_temporal_embedding (`dict`, *optional*):
|
||||||
|
The configuration of the visual temporal embedding.
|
||||||
|
image_pos_embed (`dict`, *optional*):
|
||||||
|
The configuration of the image position embedding.
|
||||||
|
image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
|
||||||
|
The source of the image feature.
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import Florence2VisionConfig, Florence2VisionModel
|
||||||
|
|
||||||
|
>>> # Initializing a Florence2 Vision style configuration
|
||||||
|
>>> configuration = Florence2VisionConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model (with random weights)
|
||||||
|
>>> model = Florence2VisionModel(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "davit"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
drop_path_rate=0.1,
|
||||||
|
patch_size=None,
|
||||||
|
patch_stride=None,
|
||||||
|
patch_padding=None,
|
||||||
|
patch_prenorm=None,
|
||||||
|
enable_checkpoint=False,
|
||||||
|
dim_embed=None,
|
||||||
|
num_heads=None,
|
||||||
|
num_groups=None,
|
||||||
|
depths=None,
|
||||||
|
window_size=12,
|
||||||
|
projection_dim=1024,
|
||||||
|
visual_temporal_embedding=None,
|
||||||
|
image_pos_embed=None,
|
||||||
|
image_feature_source=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.drop_path_rate = drop_path_rate
|
||||||
|
self.patch_size = patch_size if patch_size is not None else [7, 3, 3, 3]
|
||||||
|
self.patch_stride = patch_stride if patch_stride is not None else [4, 2, 2, 2]
|
||||||
|
self.patch_padding = patch_padding if patch_padding is not None else [3, 1, 1, 1]
|
||||||
|
self.patch_prenorm = patch_prenorm if patch_prenorm is not None else [False, True, True, True]
|
||||||
|
self.enable_checkpoint = enable_checkpoint
|
||||||
|
self.dim_embed = dim_embed if dim_embed is not None else [256, 512, 1024, 2048]
|
||||||
|
self.num_heads = num_heads if num_heads is not None else [8, 16, 32, 64]
|
||||||
|
self.num_groups = num_groups if num_groups is not None else [8, 16, 32, 64]
|
||||||
|
self.depths = depths if depths is not None else [1, 1, 9, 1]
|
||||||
|
self.window_size = window_size
|
||||||
|
self.projection_dim = projection_dim
|
||||||
|
|
||||||
|
if visual_temporal_embedding is None:
|
||||||
|
visual_temporal_embedding = {
|
||||||
|
"type": "COSINE",
|
||||||
|
"max_temporal_embeddings": 100,
|
||||||
|
}
|
||||||
|
self.visual_temporal_embedding = visual_temporal_embedding
|
||||||
|
|
||||||
|
if image_pos_embed is None:
|
||||||
|
image_pos_embed = {
|
||||||
|
"type": "learned_abs_2d",
|
||||||
|
"max_pos_embeddings": 1000,
|
||||||
|
}
|
||||||
|
self.image_pos_embed = image_pos_embed
|
||||||
|
|
||||||
|
self.image_feature_source = (
|
||||||
|
image_feature_source
|
||||||
|
if image_feature_source is not None
|
||||||
|
else ["spatial_avg_pool", "temporal_avg_pool"]
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Florence2LanguageConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
|
||||||
|
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
|
defaults will yield a similar configuration to that of the BART
|
||||||
|
[facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 51289):
|
||||||
|
Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`Florence2LanguageModel`].
|
||||||
|
d_model (`int`, *optional*, defaults to 1024):
|
||||||
|
Dimensionality of the layers and the pooler layer.
|
||||||
|
encoder_layers (`int`, *optional*, defaults to 12):
|
||||||
|
Number of encoder layers.
|
||||||
|
decoder_layers (`int`, *optional*, defaults to 12):
|
||||||
|
Number of decoder layers.
|
||||||
|
encoder_attention_heads (`int`, *optional*, defaults to 16):
|
||||||
|
Number of attention heads for each attention layer in the Transformer encoder.
|
||||||
|
decoder_attention_heads (`int`, *optional*, defaults to 16):
|
||||||
|
Number of attention heads for each attention layer in the Transformer decoder.
|
||||||
|
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
||||||
|
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||||
|
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
||||||
|
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||||
|
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||||
|
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||||
|
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||||
|
dropout (`float`, *optional*, defaults to 0.1):
|
||||||
|
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
activation_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for activations inside the fully connected layer.
|
||||||
|
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for classifier.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
||||||
|
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||||
|
just in case (e.g., 512 or 1024 or 2048).
|
||||||
|
init_std (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
||||||
|
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
||||||
|
for more details.
|
||||||
|
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
||||||
|
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
||||||
|
for more details.
|
||||||
|
scale_embedding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Scale embeddings by diving by sqrt(d_model).
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||||
|
num_labels (`int`, *optional*, defaults to 3):
|
||||||
|
The number of labels to use in [`Florence2LanguageForSequenceClassification`].
|
||||||
|
forced_eos_token_id (`int`, *optional*, defaults to 2):
|
||||||
|
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
|
||||||
|
`eos_token_id`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
|
||||||
|
|
||||||
|
>>> # Initializing a Florence2 Language style configuration
|
||||||
|
>>> configuration = Florence2LanguageConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model (with random weights)
|
||||||
|
>>> model = Florence2LanguageModel(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "florence2_language"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=51289,
|
||||||
|
max_position_embeddings=1024,
|
||||||
|
encoder_layers=12,
|
||||||
|
encoder_ffn_dim=4096,
|
||||||
|
encoder_attention_heads=16,
|
||||||
|
decoder_layers=12,
|
||||||
|
decoder_ffn_dim=4096,
|
||||||
|
decoder_attention_heads=16,
|
||||||
|
encoder_layerdrop=0.0,
|
||||||
|
decoder_layerdrop=0.0,
|
||||||
|
activation_function="gelu",
|
||||||
|
d_model=1024,
|
||||||
|
dropout=0.1,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
activation_dropout=0.0,
|
||||||
|
init_std=0.02,
|
||||||
|
classifier_dropout=0.0,
|
||||||
|
scale_embedding=False,
|
||||||
|
use_cache=True,
|
||||||
|
num_labels=3,
|
||||||
|
pad_token_id=1,
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=2,
|
||||||
|
is_encoder_decoder=True,
|
||||||
|
decoder_start_token_id=2,
|
||||||
|
forced_eos_token_id=2,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.d_model = d_model
|
||||||
|
self.encoder_ffn_dim = encoder_ffn_dim
|
||||||
|
self.encoder_layers = encoder_layers
|
||||||
|
self.encoder_attention_heads = encoder_attention_heads
|
||||||
|
self.decoder_ffn_dim = decoder_ffn_dim
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.decoder_attention_heads = decoder_attention_heads
|
||||||
|
self.dropout = dropout
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.activation_dropout = activation_dropout
|
||||||
|
self.activation_function = activation_function
|
||||||
|
self.init_std = init_std
|
||||||
|
self.encoder_layerdrop = encoder_layerdrop
|
||||||
|
self.decoder_layerdrop = decoder_layerdrop
|
||||||
|
self.classifier_dropout = classifier_dropout
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.num_hidden_layers = encoder_layers
|
||||||
|
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
num_labels=num_labels,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
|
forced_eos_token_id=forced_eos_token_id,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ensure backward compatibility for BART CNN models
|
||||||
|
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
||||||
|
self.forced_bos_token_id = self.bos_token_id
|
||||||
|
warnings.warn(
|
||||||
|
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
|
||||||
|
"The config can simply be saved and uploaded again to be fixed.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Florence2Config(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
|
||||||
|
Florence-2 model according to the specified arguments, defining the model architecture.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vision_config (`Florence2VisionConfig`, *optional*):
|
||||||
|
Custom vision config or dict
|
||||||
|
text_config (`Union[AutoConfig, dict]`, *optional*):
|
||||||
|
The config object of the text backbone.
|
||||||
|
ignore_index (`int`, *optional*, defaults to -100):
|
||||||
|
The ignore index for the loss function.
|
||||||
|
vocab_size (`int`, *optional*, defaults to 51289):
|
||||||
|
Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
|
||||||
|
projection_dim (`int`, *optional*, defaults to 1024):
|
||||||
|
Dimension of the multimodal projection space.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
|
||||||
|
|
||||||
|
>>> # Initializing a clip-like vision config
|
||||||
|
>>> vision_config = CLIPVisionConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a Bart config
|
||||||
|
>>> text_config = BartConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a Florence-2 configuration
|
||||||
|
>>> configuration = Florence2Config(vision_config, text_config)
|
||||||
|
|
||||||
|
>>> # Initializing a model from the florence-2 configuration
|
||||||
|
>>> model = Florence2ForConditionalGeneration(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "florence2"
|
||||||
|
is_composition = False
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vision_config=None,
|
||||||
|
text_config=None,
|
||||||
|
ignore_index=-100,
|
||||||
|
vocab_size=51289,
|
||||||
|
projection_dim=1024,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.ignore_index = ignore_index
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.projection_dim = projection_dim
|
||||||
|
if vision_config is not None:
|
||||||
|
vision_config = Florence2VisionConfig(**vision_config)
|
||||||
|
self.vision_config = vision_config
|
||||||
|
|
||||||
|
self.text_config = text_config
|
||||||
|
if text_config is not None:
|
||||||
|
self.text_config = Florence2LanguageConfig(**text_config)
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
203
src/lerobot/policies/xvla/configuration_xvla.py
Normal file
203
src/lerobot/policies/xvla/configuration_xvla.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
from lerobot.optim.optimizers import XVLAAdamWConfig
|
||||||
|
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||||
|
from lerobot.utils.constants import OBS_IMAGES
|
||||||
|
|
||||||
|
# Conditional import for type checking and lazy loading
|
||||||
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from .configuration_florence2 import Florence2Config
|
||||||
|
else:
|
||||||
|
Florence2Config = None
|
||||||
|
|
||||||
|
|
||||||
|
@PreTrainedConfig.register_subclass("xvla")
|
||||||
|
@dataclass
|
||||||
|
class XVLAConfig(PreTrainedConfig):
|
||||||
|
"""
|
||||||
|
Configuration class for the XVLA (Extended Vision-Language-Action) policy so it can
|
||||||
|
plug into the LeRobot training stack.
|
||||||
|
|
||||||
|
The config mirrors the knobs exposed in the original XVLA repository but also
|
||||||
|
declares the input/output feature contract required by LeRobot.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Input / output structure
|
||||||
|
n_obs_steps: int = 1
|
||||||
|
chunk_size: int = 32
|
||||||
|
n_action_steps: int = 32
|
||||||
|
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||||
|
|
||||||
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
|
"STATE": NormalizationMode.IDENTITY,
|
||||||
|
"ACTION": NormalizationMode.IDENTITY,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Florence2 backbone and tokenizer configuration
|
||||||
|
florence_config: dict[str, Any] = field(default_factory=dict)
|
||||||
|
tokenizer_name: str = "facebook/bart-large"
|
||||||
|
tokenizer_max_length: int = 64
|
||||||
|
tokenizer_padding_side: str = "right"
|
||||||
|
pad_language_to: str = "max_length"
|
||||||
|
|
||||||
|
# Transformer head
|
||||||
|
hidden_size: int = 1024
|
||||||
|
depth: int = 24
|
||||||
|
num_heads: int = 16
|
||||||
|
mlp_ratio: float = 4.0
|
||||||
|
num_domains: int = 30
|
||||||
|
len_soft_prompts: int = 32
|
||||||
|
dim_time: int = 32
|
||||||
|
max_len_seq: int = 512
|
||||||
|
use_hetero_proj: bool = False
|
||||||
|
|
||||||
|
# Action & proprioception
|
||||||
|
action_mode: str = "ee6d"
|
||||||
|
num_denoising_steps: int = 10
|
||||||
|
use_proprio: bool = True
|
||||||
|
max_state_dim: int = 32
|
||||||
|
max_action_dim: int = 20 # Maximum action dimension for padding (used by "auto" action mode)
|
||||||
|
domain_feature_key: str | None = None
|
||||||
|
|
||||||
|
# Vision preprocessing
|
||||||
|
resize_imgs_with_padding: tuple[int, int] | None = None
|
||||||
|
num_image_views: int | None = None
|
||||||
|
empty_cameras: int = 0
|
||||||
|
|
||||||
|
# Freezing options for VLM components
|
||||||
|
# By default, VLM encoders are frozen and only policy transformer + soft prompts train
|
||||||
|
freeze_vision_encoder: bool = False # Freeze VLM vision encoder weights
|
||||||
|
freeze_language_encoder: bool = False # Freeze VLM language encoder weights
|
||||||
|
train_policy_transformer: bool = True # Allow policy transformer to train
|
||||||
|
train_soft_prompts: bool = True # Allow soft prompts to train
|
||||||
|
|
||||||
|
# Training presets
|
||||||
|
optimizer_lr: float = 1e-4
|
||||||
|
optimizer_betas: tuple[float, float] = (0.9, 0.99)
|
||||||
|
optimizer_eps: float = 1e-8
|
||||||
|
optimizer_weight_decay: float = 0.0
|
||||||
|
optimizer_grad_clip_norm: float = 10.0
|
||||||
|
# Soft-prompt LR settings (for optional warm-up)
|
||||||
|
optimizer_soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR
|
||||||
|
optimizer_soft_prompt_warmup_lr_scale: float | None = None # Start scale for warmup (e.g., 0.01)
|
||||||
|
|
||||||
|
scheduler_warmup_steps: int = 1_000
|
||||||
|
scheduler_decay_steps: int = 30_000
|
||||||
|
scheduler_decay_lr: float = 2.5e-6
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
super().__post_init__()
|
||||||
|
|
||||||
|
if self.chunk_size <= 0:
|
||||||
|
raise ValueError("`chunk_size` must be strictly positive.")
|
||||||
|
if self.n_action_steps > self.chunk_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
|
||||||
|
)
|
||||||
|
if self.num_image_views is not None and self.num_image_views <= 0:
|
||||||
|
raise ValueError("`num_image_views` must be > 0 when specified.")
|
||||||
|
if self.dtype not in ["bfloat16", "float32"]:
|
||||||
|
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||||
|
self._florence_config_obj: Florence2Config | None = None
|
||||||
|
|
||||||
|
def get_florence_config(self) -> Florence2Config:
|
||||||
|
"""
|
||||||
|
Build (and cache) the Florence2 transformer config that should back the VLM.
|
||||||
|
"""
|
||||||
|
if self._florence_config_obj is None:
|
||||||
|
config_dict = dict(self.florence_config)
|
||||||
|
if "vision_config" not in config_dict or config_dict["vision_config"] is None:
|
||||||
|
raise ValueError("vision_config is required")
|
||||||
|
|
||||||
|
if "text_config" not in config_dict or config_dict["text_config"] is None:
|
||||||
|
raise ValueError("text_config is required")
|
||||||
|
self._florence_config_obj = Florence2Config(**config_dict)
|
||||||
|
return self._florence_config_obj
|
||||||
|
|
||||||
|
def validate_features(self) -> None:
|
||||||
|
if not self.image_features:
|
||||||
|
raise ValueError("XVLA requires at least one visual feature in the inputs.")
|
||||||
|
if self.use_proprio and self.robot_state_feature is None:
|
||||||
|
raise ValueError("`use_proprio=True` requires a proprioceptive state feature.")
|
||||||
|
if self.num_image_views is None:
|
||||||
|
self.num_image_views = len(self.image_features) + self.empty_cameras
|
||||||
|
else:
|
||||||
|
self.num_image_views = max(self.num_image_views, len(self.image_features) + self.empty_cameras)
|
||||||
|
|
||||||
|
if self.empty_cameras > 0:
|
||||||
|
height, width = (480, 640)
|
||||||
|
if self.resize_imgs_with_padding is not None:
|
||||||
|
height, width = self.resize_imgs_with_padding
|
||||||
|
for idx in range(self.empty_cameras):
|
||||||
|
key = f"{OBS_IMAGES}.empty_camera_{idx}"
|
||||||
|
if key not in self.input_features:
|
||||||
|
self.input_features[key] = PolicyFeature(
|
||||||
|
type=FeatureType.VISUAL,
|
||||||
|
shape=(3, height, width),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_optimizer_preset(self) -> XVLAAdamWConfig:
|
||||||
|
"""Return the XVLA-specific optimizer with differential learning rates.
|
||||||
|
|
||||||
|
This optimizer applies:
|
||||||
|
- 1/10 LR for VLM parameters (stable optimization)
|
||||||
|
- Full LR for transformer/action head
|
||||||
|
- Configurable LR for soft-prompts (with optional warm-up)
|
||||||
|
"""
|
||||||
|
return XVLAAdamWConfig(
|
||||||
|
lr=self.optimizer_lr,
|
||||||
|
betas=self.optimizer_betas,
|
||||||
|
eps=self.optimizer_eps,
|
||||||
|
weight_decay=self.optimizer_weight_decay,
|
||||||
|
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||||
|
soft_prompt_lr_scale=self.optimizer_soft_prompt_lr_scale,
|
||||||
|
soft_prompt_warmup_lr_scale=self.optimizer_soft_prompt_warmup_lr_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||||
|
return CosineDecayWithWarmupSchedulerConfig(
|
||||||
|
peak_lr=self.optimizer_lr,
|
||||||
|
decay_lr=self.scheduler_decay_lr,
|
||||||
|
num_warmup_steps=self.scheduler_warmup_steps,
|
||||||
|
num_decay_steps=self.scheduler_decay_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_delta_indices(self) -> list[int] | None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_delta_indices(self) -> list[int]:
|
||||||
|
return list(range(self.chunk_size))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reward_delta_indices(self) -> list[int] | None:
|
||||||
|
return None
|
||||||
2757
src/lerobot/policies/xvla/modeling_florence2.py
Normal file
2757
src/lerobot/policies/xvla/modeling_florence2.py
Normal file
File diff suppressed because it is too large
Load Diff
548
src/lerobot/policies/xvla/modeling_xvla.py
Normal file
548
src/lerobot/policies/xvla/modeling_xvla.py
Normal file
@@ -0,0 +1,548 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import builtins
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from collections import deque
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
|
from lerobot.policies.utils import populate_queues
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||||
|
|
||||||
|
from .action_hub import build_action_space
|
||||||
|
from .configuration_florence2 import Florence2Config
|
||||||
|
from .configuration_xvla import XVLAConfig
|
||||||
|
from .modeling_florence2 import Florence2ForConditionalGeneration
|
||||||
|
from .soft_transformer import SoftPromptedTransformer
|
||||||
|
|
||||||
|
|
||||||
|
class XVLAModel(nn.Module):
|
||||||
|
"""
|
||||||
|
XVLA backbone that stitches Florence-2 embeddings with the temporal/action transformer head.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: XVLAConfig,
|
||||||
|
florence_config: Florence2Config,
|
||||||
|
proprio_dim: int,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.chunk_size: int = config.chunk_size
|
||||||
|
self.use_proprio: bool = config.use_proprio
|
||||||
|
|
||||||
|
# Build action space with auto-detection for "auto" mode
|
||||||
|
if config.action_mode.lower() == "auto":
|
||||||
|
# Auto-detect real action dim from config.action_feature
|
||||||
|
real_dim = (
|
||||||
|
config.action_feature.shape[-1]
|
||||||
|
if config.action_feature is not None
|
||||||
|
else config.max_action_dim
|
||||||
|
)
|
||||||
|
self.action_space = build_action_space(
|
||||||
|
config.action_mode.lower(),
|
||||||
|
real_dim=real_dim,
|
||||||
|
max_dim=config.max_action_dim,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.action_space = build_action_space(config.action_mode.lower())
|
||||||
|
|
||||||
|
self.dim_action = self.action_space.dim_action
|
||||||
|
self.dim_proprio = proprio_dim
|
||||||
|
|
||||||
|
self.vlm = Florence2ForConditionalGeneration(florence_config)
|
||||||
|
if hasattr(self.vlm, "language_model"):
|
||||||
|
lm = self.vlm.language_model
|
||||||
|
if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
|
||||||
|
del lm.model.decoder
|
||||||
|
if hasattr(lm, "lm_head"):
|
||||||
|
del lm.lm_head
|
||||||
|
|
||||||
|
projection_dim = getattr(self.vlm.config, "projection_dim", None)
|
||||||
|
if projection_dim is None:
|
||||||
|
raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
|
||||||
|
|
||||||
|
self.transformer = SoftPromptedTransformer(
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
multi_modal_input_size=projection_dim,
|
||||||
|
depth=config.depth,
|
||||||
|
num_heads=config.num_heads,
|
||||||
|
mlp_ratio=config.mlp_ratio,
|
||||||
|
num_domains=config.num_domains,
|
||||||
|
dim_action=self.dim_action,
|
||||||
|
dim_propio=self.dim_proprio,
|
||||||
|
len_soft_prompts=config.len_soft_prompts,
|
||||||
|
dim_time=config.dim_time,
|
||||||
|
max_len_seq=config.max_len_seq,
|
||||||
|
use_hetero_proj=config.use_hetero_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply freezing based on config
|
||||||
|
self._apply_freezing()
|
||||||
|
|
||||||
|
# Apply dtype casting based on config
|
||||||
|
self._apply_dtype()
|
||||||
|
|
||||||
|
def _get_target_dtype(self) -> torch.dtype:
|
||||||
|
"""Get the target dtype based on config."""
|
||||||
|
if self.config.dtype == "bfloat16":
|
||||||
|
return torch.bfloat16
|
||||||
|
return torch.float32
|
||||||
|
|
||||||
|
def _apply_dtype(self) -> None:
|
||||||
|
"""
|
||||||
|
Apply dtype casting to model components based on config.
|
||||||
|
"""
|
||||||
|
target_dtype = self._get_target_dtype()
|
||||||
|
self.to(dtype=target_dtype)
|
||||||
|
|
||||||
|
def _apply_freezing(self) -> None:
|
||||||
|
"""
|
||||||
|
Freeze VLM vision and language encoders based on config options.
|
||||||
|
Keep only policy transformer and soft prompts trainable.
|
||||||
|
"""
|
||||||
|
# Freeze vision encoder
|
||||||
|
if self.config.freeze_vision_encoder and hasattr(self.vlm, "vision_tower"):
|
||||||
|
for param in self.vlm.vision_tower.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Freeze language encoder
|
||||||
|
if self.config.freeze_language_encoder and hasattr(self.vlm, "language_model"):
|
||||||
|
lm = self.vlm.language_model
|
||||||
|
# Freeze encoder
|
||||||
|
if hasattr(lm, "model") and hasattr(lm.model, "encoder"):
|
||||||
|
for param in lm.model.encoder.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
# Freeze shared embeddings
|
||||||
|
if hasattr(lm, "model") and hasattr(lm.model, "shared"):
|
||||||
|
for param in lm.model.shared.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Freeze or unfreeze policy transformer
|
||||||
|
if not self.config.train_policy_transformer:
|
||||||
|
for name, param in self.transformer.named_parameters():
|
||||||
|
if "soft_prompts" not in name:
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Freeze or unfreeze soft prompts
|
||||||
|
if not self.config.train_soft_prompts and hasattr(self.transformer, "soft_prompt_hub"):
|
||||||
|
for param in self.transformer.soft_prompt_hub.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def forward_vlm(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
pixel_values: torch.FloatTensor,
|
||||||
|
image_mask: torch.Tensor,
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Encode text and multi-view images via Florence2 encoder.
|
||||||
|
"""
|
||||||
|
batch_size, num_views = pixel_values.shape[:2]
|
||||||
|
flat_mask = image_mask.view(-1).to(dtype=torch.bool)
|
||||||
|
flat_images = pixel_values.flatten(0, 1)
|
||||||
|
num_valid = int(flat_mask.sum().item())
|
||||||
|
if num_valid == 0:
|
||||||
|
raise ValueError("At least one image view must be valid per batch.")
|
||||||
|
|
||||||
|
valid_images = flat_images[flat_mask]
|
||||||
|
valid_feats = self.vlm._encode_image(valid_images)
|
||||||
|
tokens_per_view, hidden_dim = valid_feats.shape[1:]
|
||||||
|
|
||||||
|
image_features = valid_feats.new_zeros((batch_size * num_views, tokens_per_view, hidden_dim))
|
||||||
|
image_features[flat_mask] = valid_feats
|
||||||
|
image_features = image_features.view(batch_size, num_views, tokens_per_view, hidden_dim)
|
||||||
|
inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
|
||||||
|
merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
|
||||||
|
image_features[:, 0],
|
||||||
|
inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
|
enc_out = self.vlm.language_model.model.encoder(
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
inputs_embeds=merged_embeds,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
aux_visual_inputs = image_features[:, 1:].reshape(batch_size, -1, hidden_dim)
|
||||||
|
return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
image_input: torch.FloatTensor,
|
||||||
|
image_mask: torch.Tensor,
|
||||||
|
domain_id: torch.LongTensor,
|
||||||
|
proprio: torch.Tensor,
|
||||||
|
action: torch.Tensor,
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Forward pass for the XVLA model.
|
||||||
|
"""
|
||||||
|
target_dtype = self._get_target_dtype()
|
||||||
|
image_input = image_input.to(dtype=target_dtype)
|
||||||
|
proprio = proprio.to(dtype=target_dtype)
|
||||||
|
action = action.to(dtype=target_dtype)
|
||||||
|
|
||||||
|
enc = self.forward_vlm(input_ids, image_input, image_mask)
|
||||||
|
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
t = (
|
||||||
|
torch.rand(1, device=input_ids.device, dtype=target_dtype)
|
||||||
|
+ torch.arange(batch_size, device=input_ids.device, dtype=target_dtype) / batch_size
|
||||||
|
) % (1 - 1e-5)
|
||||||
|
|
||||||
|
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||||
|
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
|
||||||
|
|
||||||
|
pred_action = self.transformer(
|
||||||
|
domain_id=domain_id,
|
||||||
|
action_with_noise=action_noisy_m,
|
||||||
|
t=t,
|
||||||
|
proprio=proprio_m,
|
||||||
|
**enc,
|
||||||
|
)
|
||||||
|
return self.action_space.compute_loss(pred_action, action)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def generate_actions(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
image_input: torch.FloatTensor,
|
||||||
|
image_mask: torch.Tensor,
|
||||||
|
domain_id: torch.LongTensor,
|
||||||
|
proprio: torch.Tensor,
|
||||||
|
steps: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
target_dtype = self._get_target_dtype()
|
||||||
|
image_input = image_input.to(dtype=target_dtype)
|
||||||
|
proprio = proprio.to(dtype=target_dtype)
|
||||||
|
|
||||||
|
enc = self.forward_vlm(input_ids, image_input, image_mask)
|
||||||
|
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
action_dim = self.dim_action
|
||||||
|
|
||||||
|
x1 = torch.randn(batch_size, self.chunk_size, action_dim, device=proprio.device, dtype=target_dtype)
|
||||||
|
action = torch.zeros_like(x1)
|
||||||
|
|
||||||
|
steps = max(1, int(steps))
|
||||||
|
for i in range(steps, 0, -1):
|
||||||
|
t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=target_dtype)
|
||||||
|
x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||||
|
proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
|
||||||
|
action = self.transformer(
|
||||||
|
domain_id=domain_id,
|
||||||
|
action_with_noise=x_t_m,
|
||||||
|
proprio=proprio_m,
|
||||||
|
t=t,
|
||||||
|
**enc,
|
||||||
|
)
|
||||||
|
return self.action_space.postprocess(action)
|
||||||
|
|
||||||
|
|
||||||
|
class XVLAPolicy(PreTrainedPolicy):
|
||||||
|
"""LeRobot-compliant wrapper built around the XVLA model."""
|
||||||
|
|
||||||
|
config_class = XVLAConfig
|
||||||
|
name = "xvla"
|
||||||
|
|
||||||
|
def __init__(self, config: XVLAConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
config.validate_features()
|
||||||
|
florence_config = config.get_florence_config()
|
||||||
|
proprio_dim = config.max_state_dim if config.use_proprio else 0
|
||||||
|
self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self._queues = {
|
||||||
|
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_optim_params(self) -> dict:
|
||||||
|
"""Return trainable named parameters for optimization.
|
||||||
|
|
||||||
|
Returns a dict of name -> param for all trainable parameters.
|
||||||
|
This enables the xvla-adamw optimizer to apply differential learning rates
|
||||||
|
based on parameter names (e.g., 1/10 LR for VLM components).
|
||||||
|
"""
|
||||||
|
return dict(filter(lambda kv: kv[1].requires_grad, self.named_parameters()))
|
||||||
|
|
||||||
|
def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
|
||||||
|
if not self.config.use_proprio or OBS_STATE not in batch:
|
||||||
|
return torch.zeros(batch_size, 0, device=device)
|
||||||
|
state = batch[OBS_STATE]
|
||||||
|
if state.ndim > 2:
|
||||||
|
state = state[:, -1, :]
|
||||||
|
return pad_vector(state, self.model.dim_proprio)
|
||||||
|
|
||||||
|
def _prepare_images(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
||||||
|
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||||
|
if len(present_img_keys) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"All image features are missing from the batch. "
|
||||||
|
f"Batch keys: {list(batch.keys())}, expected at least one of {list(self.config.image_features)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
images = []
|
||||||
|
masks = []
|
||||||
|
for key in present_img_keys:
|
||||||
|
img = batch[key][:, -1] if batch[key].ndim == 5 else batch[key]
|
||||||
|
if self.config.resize_imgs_with_padding is not None:
|
||||||
|
img = resize_with_pad(img, *self.config.resize_imgs_with_padding)
|
||||||
|
images.append(img)
|
||||||
|
masks.append(torch.ones(img.size(0), dtype=torch.bool, device=img.device))
|
||||||
|
|
||||||
|
stacked_imgs = torch.stack(images, dim=1)
|
||||||
|
stacked_masks = torch.stack(masks, dim=1)
|
||||||
|
|
||||||
|
total_views = self.config.num_image_views or stacked_imgs.size(1)
|
||||||
|
total_views = max(total_views, stacked_imgs.size(1))
|
||||||
|
num_pad = total_views - stacked_imgs.size(1)
|
||||||
|
if num_pad > 0:
|
||||||
|
pad_shape = (stacked_imgs.size(0), num_pad, *stacked_imgs.shape[2:])
|
||||||
|
pad_imgs = stacked_imgs.new_zeros(pad_shape)
|
||||||
|
pad_masks = stacked_masks.new_zeros((stacked_masks.size(0), num_pad))
|
||||||
|
stacked_imgs = torch.cat([stacked_imgs, pad_imgs], dim=1)
|
||||||
|
stacked_masks = torch.cat([stacked_masks, pad_masks], dim=1)
|
||||||
|
|
||||||
|
return stacked_imgs, stacked_masks
|
||||||
|
|
||||||
|
def _get_domain_id(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
|
||||||
|
candidate = None
|
||||||
|
if self.config.domain_feature_key and self.config.domain_feature_key in batch:
|
||||||
|
candidate = batch[self.config.domain_feature_key]
|
||||||
|
elif "domain_id" in batch:
|
||||||
|
candidate = batch["domain_id"]
|
||||||
|
|
||||||
|
if candidate is None:
|
||||||
|
return torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
if not isinstance(candidate, torch.Tensor):
|
||||||
|
candidate = torch.as_tensor(candidate, device=device)
|
||||||
|
else:
|
||||||
|
candidate = candidate.to(device=device)
|
||||||
|
|
||||||
|
if candidate.ndim == 0:
|
||||||
|
candidate = candidate.expand(batch_size)
|
||||||
|
if candidate.ndim > 1:
|
||||||
|
candidate = candidate.view(candidate.shape[0], -1)[:, 0]
|
||||||
|
if candidate.shape[0] != batch_size:
|
||||||
|
candidate = candidate.expand(batch_size)
|
||||||
|
return candidate.to(dtype=torch.long)
|
||||||
|
|
||||||
|
def _prepare_action_targets(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
|
if ACTION not in batch:
|
||||||
|
raise ValueError("Batch is missing action targets required for training.")
|
||||||
|
actions = batch[ACTION]
|
||||||
|
if actions.ndim == 2:
|
||||||
|
actions = actions.unsqueeze(1)
|
||||||
|
actions = pad_tensor_along_dim(actions, self.config.chunk_size, dim=1)
|
||||||
|
if actions.shape[-1] != self.model.dim_action:
|
||||||
|
actions = pad_vector(actions, self.model.dim_action)
|
||||||
|
return actions
|
||||||
|
|
||||||
|
def _build_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
|
input_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
images, image_mask = self._prepare_images(batch)
|
||||||
|
domain_id = self._get_domain_id(batch, batch_size, images.device)
|
||||||
|
proprio = self._prepare_state(batch, batch_size, images.device)
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"image_input": images,
|
||||||
|
"image_mask": image_mask,
|
||||||
|
"domain_id": domain_id,
|
||||||
|
"proprio": proprio,
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||||
|
inputs = self._build_model_inputs(batch)
|
||||||
|
targets = self._prepare_action_targets(batch)
|
||||||
|
losses = self.model(action=targets, **inputs)
|
||||||
|
total_loss = sum(losses.values())
|
||||||
|
|
||||||
|
log_dict = {k: v.detach().item() for k, v in losses.items()}
|
||||||
|
log_dict["loss"] = total_loss.detach().item()
|
||||||
|
return total_loss, log_dict
|
||||||
|
|
||||||
|
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
|
inputs = self._build_model_inputs(batch)
|
||||||
|
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
|
||||||
|
return actions
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
|
||||||
|
self.eval()
|
||||||
|
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||||
|
return self._get_action_chunk(batch)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
|
||||||
|
self.eval()
|
||||||
|
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||||
|
|
||||||
|
if len(self._queues[ACTION]) == 0:
|
||||||
|
actions = self._get_action_chunk(batch)
|
||||||
|
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||||
|
|
||||||
|
return self._queues[ACTION].popleft()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls: builtins.type[T],
|
||||||
|
pretrained_name_or_path: str | Path,
|
||||||
|
*,
|
||||||
|
config: PreTrainedConfig | None = None,
|
||||||
|
force_download: bool = False,
|
||||||
|
resume_download: bool | None = None,
|
||||||
|
proxies: dict | None = None,
|
||||||
|
token: str | bool | None = None,
|
||||||
|
cache_dir: str | Path | None = None,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
revision: str | None = None,
|
||||||
|
strict: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Loads XVLA model weights with:
|
||||||
|
- automatic prefix 'model.' added to all keys
|
||||||
|
- skip list for layers that should remain randomly initialized
|
||||||
|
"""
|
||||||
|
import safetensors.torch
|
||||||
|
|
||||||
|
# step 1: load config
|
||||||
|
# TODO: jadechoghari, fix this
|
||||||
|
if config is None:
|
||||||
|
config = PreTrainedConfig.from_pretrained(
|
||||||
|
pretrained_name_or_path=pretrained_name_or_path,
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
token=token,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
revision=revision,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_id = str(pretrained_name_or_path)
|
||||||
|
instance = cls(config, **kwargs)
|
||||||
|
# step 2: locate model.safetensors
|
||||||
|
if os.path.isdir(model_id):
|
||||||
|
logging.info("Loading weights from local directory")
|
||||||
|
model_file = os.path.join(model_id, "model.safetensors")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from huggingface_hub.utils import HfHubHTTPError
|
||||||
|
|
||||||
|
model_file = hf_hub_download(
|
||||||
|
repo_id=model_id,
|
||||||
|
filename="model.safetensors",
|
||||||
|
revision=revision,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
token=token,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
except HfHubHTTPError as e:
|
||||||
|
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
|
||||||
|
|
||||||
|
logging.info(f"Loading checkpoint from {model_file}")
|
||||||
|
# step 3: load state dict
|
||||||
|
state_dict = safetensors.torch.load_file(model_file)
|
||||||
|
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
|
||||||
|
shared_key = "model.vlm.language_model.model.shared.weight"
|
||||||
|
if encoder_key in state_dict:
|
||||||
|
state_dict[shared_key] = state_dict[encoder_key]
|
||||||
|
# or deepcopy
|
||||||
|
# step 4: load into instance
|
||||||
|
instance.load_state_dict(state_dict, strict=True)
|
||||||
|
logging.info("Loaded XVLA checkpoint")
|
||||||
|
# step 5: finalize
|
||||||
|
# Reapply dtype after loading state dict
|
||||||
|
instance.model._apply_dtype()
|
||||||
|
instance.to(config.device)
|
||||||
|
instance.eval()
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float = 0.0) -> torch.Tensor:
|
||||||
|
if img.ndim != 4:
|
||||||
|
raise ValueError(f"(b,c,h,w) expected, but got {img.shape}")
|
||||||
|
|
||||||
|
current_height, current_width = img.shape[2:]
|
||||||
|
if current_height == height and current_width == width:
|
||||||
|
return img
|
||||||
|
|
||||||
|
ratio = max(current_width / width, current_height / height)
|
||||||
|
resized_height = int(current_height / ratio)
|
||||||
|
resized_width = int(current_width / ratio)
|
||||||
|
resized_img = F.interpolate(
|
||||||
|
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||||
|
)
|
||||||
|
|
||||||
|
pad_height = max(0, height - resized_height)
|
||||||
|
pad_width = max(0, width - resized_width)
|
||||||
|
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||||
|
return padded_img
|
||||||
|
|
||||||
|
|
||||||
|
def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
|
||||||
|
if vector.shape[-1] == new_dim:
|
||||||
|
return vector
|
||||||
|
if new_dim == 0:
|
||||||
|
shape = list(vector.shape)
|
||||||
|
shape[-1] = 0
|
||||||
|
return vector.new_zeros(*shape)
|
||||||
|
shape = list(vector.shape)
|
||||||
|
current_dim = shape[-1]
|
||||||
|
shape[-1] = new_dim
|
||||||
|
new_vector = vector.new_zeros(*shape)
|
||||||
|
length = min(current_dim, new_dim)
|
||||||
|
new_vector[..., :length] = vector[..., :length]
|
||||||
|
return new_vector
|
||||||
|
|
||||||
|
|
||||||
|
def pad_tensor_along_dim(tensor: Tensor, target_len: int, dim: int = 1) -> Tensor:
|
||||||
|
current_len = tensor.size(dim)
|
||||||
|
if current_len == target_len:
|
||||||
|
return tensor
|
||||||
|
if current_len > target_len:
|
||||||
|
slices = [slice(None)] * tensor.dim()
|
||||||
|
slices[dim] = slice(0, target_len)
|
||||||
|
return tensor[tuple(slices)]
|
||||||
|
pad_shape = list(tensor.shape)
|
||||||
|
pad_shape[dim] = target_len - current_len
|
||||||
|
pad_tensor = tensor.new_zeros(pad_shape)
|
||||||
|
return torch.cat([tensor, pad_tensor], dim=dim)
|
||||||
554
src/lerobot/policies/xvla/processor_xvla.py
Normal file
554
src/lerobot/policies/xvla/processor_xvla.py
Normal file
@@ -0,0 +1,554 @@
|
|||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
|
from lerobot.datasets.factory import IMAGENET_STATS
|
||||||
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
|
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
|
||||||
|
from lerobot.processor import (
|
||||||
|
AddBatchDimensionProcessorStep,
|
||||||
|
DeviceProcessorStep,
|
||||||
|
NormalizerProcessorStep,
|
||||||
|
ObservationProcessorStep,
|
||||||
|
PolicyAction,
|
||||||
|
PolicyProcessorPipeline,
|
||||||
|
ProcessorStep,
|
||||||
|
ProcessorStepRegistry,
|
||||||
|
RenameObservationsProcessorStep,
|
||||||
|
TokenizerProcessorStep,
|
||||||
|
UnnormalizerProcessorStep,
|
||||||
|
)
|
||||||
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
|
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||||
|
from lerobot.utils.constants import (
|
||||||
|
OBS_IMAGES,
|
||||||
|
OBS_STATE,
|
||||||
|
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
|
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_xvla_pre_post_processors(
|
||||||
|
config: XVLAConfig,
|
||||||
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
|
) -> tuple[
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Build the LeRobot processor pipelines for XVLA.
|
||||||
|
"""
|
||||||
|
|
||||||
|
features = {**config.input_features, **config.output_features}
|
||||||
|
input_steps = [
|
||||||
|
RenameObservationsProcessorStep(rename_map={}),
|
||||||
|
AddBatchDimensionProcessorStep(),
|
||||||
|
TokenizerProcessorStep(
|
||||||
|
tokenizer_name=config.tokenizer_name,
|
||||||
|
max_length=config.tokenizer_max_length,
|
||||||
|
padding=config.pad_language_to,
|
||||||
|
padding_side=config.tokenizer_padding_side,
|
||||||
|
),
|
||||||
|
XVLAImageToFloatProcessorStep(),
|
||||||
|
XVLAImageNetNormalizeProcessorStep(),
|
||||||
|
XVLAAddDomainIdProcessorStep(),
|
||||||
|
DeviceProcessorStep(device=config.device),
|
||||||
|
NormalizerProcessorStep(
|
||||||
|
features=features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||||
|
),
|
||||||
|
]
|
||||||
|
output_steps = [
|
||||||
|
UnnormalizerProcessorStep(
|
||||||
|
features=config.output_features,
|
||||||
|
norm_map=config.normalization_mapping,
|
||||||
|
stats=dataset_stats,
|
||||||
|
),
|
||||||
|
DeviceProcessorStep(device="cpu"),
|
||||||
|
]
|
||||||
|
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||||
|
steps=input_steps,
|
||||||
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
|
),
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||||
|
steps=output_steps,
|
||||||
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
|
to_transition=policy_action_to_transition,
|
||||||
|
to_output=transition_to_policy_action,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Custom XVLA processor steps
|
||||||
|
@dataclass
|
||||||
|
class LiberoProcessorStep(ObservationProcessorStep):
|
||||||
|
"""
|
||||||
|
Processes LIBERO observations into the LeRobot format.
|
||||||
|
|
||||||
|
This step handles the specific observation structure from LIBERO environments,
|
||||||
|
which includes nested robot_state dictionaries and image observations.
|
||||||
|
|
||||||
|
**State Processing:**
|
||||||
|
- Processes the `robot_state` dictionary which contains nested end-effector,
|
||||||
|
gripper, and joint information.
|
||||||
|
- Extracts and concatenates:
|
||||||
|
- End-effector position (3D)
|
||||||
|
- End-effector quaternion converted to axis-angle (3D)
|
||||||
|
- Gripper joint positions (2D)
|
||||||
|
- Maps the concatenated state to `"observation.state"`.
|
||||||
|
|
||||||
|
**Image Processing:**
|
||||||
|
- Rotates images by 180 degrees by flipping both height and width dimensions.
|
||||||
|
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _process_observation(self, observation):
|
||||||
|
"""
|
||||||
|
Processes both image and robot_state observations from LIBERO.
|
||||||
|
"""
|
||||||
|
processed_obs = observation.copy()
|
||||||
|
for key in list(processed_obs.keys()):
|
||||||
|
if key.startswith(f"{OBS_IMAGES}."):
|
||||||
|
img = processed_obs[key]
|
||||||
|
|
||||||
|
if key == f"{OBS_IMAGES}.image":
|
||||||
|
# Flip both H and W
|
||||||
|
img = torch.flip(img, dims=[2, 3])
|
||||||
|
|
||||||
|
processed_obs[key] = img
|
||||||
|
# Process robot_state into a flat state vector
|
||||||
|
if "observation.robot_state" in processed_obs:
|
||||||
|
robot_state = processed_obs.pop("observation.robot_state")
|
||||||
|
|
||||||
|
# Extract components
|
||||||
|
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
|
||||||
|
eef_mat = robot_state["eef"]["mat"] # (B, 3, 3)
|
||||||
|
eef_rot6d = self._mat_to_rotate6d(eef_mat) # (B, 6)
|
||||||
|
|
||||||
|
extra = torch.zeros((eef_pos.shape[0], 1), dtype=torch.float32, device=eef_pos.device)
|
||||||
|
|
||||||
|
proprio_state = torch.cat((eef_pos, eef_rot6d, extra), dim=-1) # (B, 10)
|
||||||
|
state = torch.cat((proprio_state, torch.zeros_like(proprio_state)), dim=-1) # (B, 20)
|
||||||
|
# ensure float32
|
||||||
|
state = state.float()
|
||||||
|
if state.dim() == 1:
|
||||||
|
state = state.unsqueeze(0)
|
||||||
|
|
||||||
|
processed_obs[OBS_STATE] = state
|
||||||
|
return processed_obs
|
||||||
|
|
||||||
|
def transform_features(
|
||||||
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
|
"""
|
||||||
|
Transforms feature keys from the LIBERO format to the LeRobot standard.
|
||||||
|
"""
|
||||||
|
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
|
||||||
|
|
||||||
|
# copy over non-STATE features
|
||||||
|
for ft, feats in features.items():
|
||||||
|
if ft != PipelineFeatureType.STATE:
|
||||||
|
new_features[ft] = feats.copy()
|
||||||
|
|
||||||
|
# rebuild STATE features
|
||||||
|
state_feats = {}
|
||||||
|
|
||||||
|
# add our new flattened state
|
||||||
|
state_feats["observation.state"] = PolicyFeature(
|
||||||
|
key="observation.state",
|
||||||
|
shape=(20,),
|
||||||
|
dtype="float32",
|
||||||
|
)
|
||||||
|
|
||||||
|
new_features[PipelineFeatureType.STATE] = state_feats
|
||||||
|
|
||||||
|
return new_features
|
||||||
|
|
||||||
|
def _mat_to_rotate6d(self, rot_mats: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Convert batched rotation matrices (B, 3, 3) into 6D rotation representation (B, 6).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rot_mats (Tensor): Rotation matrices of shape (B, 3, 3)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: 6D rotation representation, shape (B, 6)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if input is not a torch tensor
|
||||||
|
ValueError: if shape is not (B, 3, 3)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(rot_mats, torch.Tensor):
|
||||||
|
raise TypeError(f"mat_to_rot6d expects a torch.Tensor, got {type(rot_mats)}")
|
||||||
|
|
||||||
|
if rot_mats.ndim != 3 or rot_mats.shape[1:] != (3, 3):
|
||||||
|
raise ValueError(f"mat_to_rot6d expects shape (B, 3, 3), got {tuple(rot_mats.shape)}")
|
||||||
|
|
||||||
|
rot_mats = rot_mats.to(torch.float32)
|
||||||
|
|
||||||
|
col1 = rot_mats[:, :3, 0] # (B, 3)
|
||||||
|
col2 = rot_mats[:, :3, 1] # (B, 3)
|
||||||
|
|
||||||
|
rot6d = torch.cat([col1, col2], dim=-1) # (B, 6)
|
||||||
|
|
||||||
|
return rot6d
|
||||||
|
|
||||||
|
def observation(self, observation):
|
||||||
|
return self._process_observation(observation)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register(name="xvla_image_scale")
|
||||||
|
class XVLAImageScaleProcessorStep(ProcessorStep):
|
||||||
|
"""Scale image observations by 255 to convert from [0, 1] to [0, 255] range.
|
||||||
|
|
||||||
|
This processor step multiplies all image observations by 255, which is required
|
||||||
|
for XVLA models that expect images in uint8-like range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_keys: List of observation keys that contain images to scale.
|
||||||
|
If None, will automatically detect keys starting with "observation.images."
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_keys: list[str] | None = None
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
"""Scale image observations by 255."""
|
||||||
|
new_transition = transition.copy()
|
||||||
|
obs = new_transition.get(TransitionKey.OBSERVATION, {})
|
||||||
|
if obs is None:
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
# Make a copy of observations to avoid modifying the original
|
||||||
|
obs = obs.copy()
|
||||||
|
|
||||||
|
# Determine which keys to scale
|
||||||
|
keys_to_scale = self.image_keys
|
||||||
|
if keys_to_scale is None:
|
||||||
|
# Auto-detect image keys
|
||||||
|
keys_to_scale = [k for k in obs if k.startswith("observation.images.")]
|
||||||
|
|
||||||
|
# Scale each image
|
||||||
|
for key in keys_to_scale:
|
||||||
|
if key in obs and isinstance(obs[key], torch.Tensor):
|
||||||
|
obs[key] = obs[key] * 255
|
||||||
|
|
||||||
|
new_transition[TransitionKey.OBSERVATION] = obs
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def transform_features(self, features):
|
||||||
|
"""Image scaling doesn't change feature structure."""
|
||||||
|
return features
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
"""Return serializable configuration."""
|
||||||
|
return {
|
||||||
|
"image_keys": self.image_keys,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register(name="xvla_image_to_float")
|
||||||
|
class XVLAImageToFloatProcessorStep(ProcessorStep):
|
||||||
|
"""Convert image observations from [0, 255] to [0, 1] range.
|
||||||
|
|
||||||
|
This processor step divides image observations by 255 to convert from uint8-like
|
||||||
|
range [0, 255] to float range [0, 1]. This is typically used when loading images
|
||||||
|
that are stored as uint8 values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_keys: List of observation keys that contain images to convert.
|
||||||
|
If None, will automatically detect keys starting with "observation.images."
|
||||||
|
validate_range: If True, validates that input values are in [0, 255] range (default: True)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If validate_range is True and image values are not in [0, 255] range.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_keys: list[str] | None = None
|
||||||
|
validate_range: bool = True
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
"""Convert image observations from [0, 255] to [0, 1]."""
|
||||||
|
new_transition = transition.copy()
|
||||||
|
obs = new_transition.get(TransitionKey.OBSERVATION, {})
|
||||||
|
if obs is None:
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
# Make a copy of observations to avoid modifying the original
|
||||||
|
obs = obs.copy()
|
||||||
|
|
||||||
|
# Determine which keys to convert
|
||||||
|
keys_to_convert = self.image_keys
|
||||||
|
if keys_to_convert is None:
|
||||||
|
# Auto-detect image keys
|
||||||
|
keys_to_convert = [k for k in obs if k.startswith("observation.images.")]
|
||||||
|
|
||||||
|
# Convert each image
|
||||||
|
for key in keys_to_convert:
|
||||||
|
if key in obs and isinstance(obs[key], torch.Tensor):
|
||||||
|
tensor = obs[key]
|
||||||
|
|
||||||
|
min_val = tensor.min().item()
|
||||||
|
max_val = tensor.max().item()
|
||||||
|
|
||||||
|
if max_val <= 1.0:
|
||||||
|
obs[key] = tensor.float() # ensure float dtype, but no division
|
||||||
|
continue
|
||||||
|
# Validate that values are in [0, 255] range if requested
|
||||||
|
if self.validate_range and (min_val < 0.0 or max_val > 255.0):
|
||||||
|
raise ValueError(
|
||||||
|
f"Image '{key}' has values outside [0, 255] range: "
|
||||||
|
f"min={min_val:.4f}, max={max_val:.4f}. "
|
||||||
|
f"Cannot convert to [0, 1] range."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to float and divide by 255
|
||||||
|
obs[key] = tensor.float() / 255.0
|
||||||
|
|
||||||
|
new_transition[TransitionKey.OBSERVATION] = obs
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def transform_features(self, features):
|
||||||
|
"""Image conversion doesn't change feature structure."""
|
||||||
|
return features
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
"""Return serializable configuration."""
|
||||||
|
return {
|
||||||
|
"image_keys": self.image_keys,
|
||||||
|
"validate_range": self.validate_range,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register(name="xvla_imagenet_normalize")
|
||||||
|
class XVLAImageNetNormalizeProcessorStep(ProcessorStep):
|
||||||
|
"""Normalize image observations using ImageNet statistics.
|
||||||
|
|
||||||
|
This processor step applies ImageNet normalization (mean and std) to image observations.
|
||||||
|
It validates that input values are in the [0, 1] range before normalizing.
|
||||||
|
|
||||||
|
The normalization formula is: (image - mean) / std
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_keys: List of observation keys that contain images to normalize.
|
||||||
|
If None, will automatically detect keys starting with "observation.images."
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If image values are not in the [0, 1] range.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_keys: list[str] | None = None
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
"""Normalize image observations using ImageNet statistics."""
|
||||||
|
new_transition = transition.copy()
|
||||||
|
obs = new_transition.get(TransitionKey.OBSERVATION, {})
|
||||||
|
if obs is None:
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
# Make a copy of observations to avoid modifying the original
|
||||||
|
obs = obs.copy()
|
||||||
|
|
||||||
|
# Determine which keys to normalize
|
||||||
|
keys_to_normalize = self.image_keys
|
||||||
|
if keys_to_normalize is None:
|
||||||
|
# Auto-detect image keys
|
||||||
|
keys_to_normalize = [k for k in obs if k.startswith("observation.images.")]
|
||||||
|
|
||||||
|
# Normalize each image
|
||||||
|
for key in keys_to_normalize:
|
||||||
|
if key in obs and isinstance(obs[key], torch.Tensor):
|
||||||
|
tensor = obs[key]
|
||||||
|
|
||||||
|
# Validate that values are in [0, 1] range
|
||||||
|
min_val = tensor.min().item()
|
||||||
|
max_val = tensor.max().item()
|
||||||
|
if min_val < 0.0 or max_val > 1.0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Image '{key}' has values outside [0, 1] range: "
|
||||||
|
f"min={min_val:.4f}, max={max_val:.4f}. "
|
||||||
|
f"ImageNet normalization requires input values in [0, 1]."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply ImageNet normalization
|
||||||
|
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
|
||||||
|
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
|
||||||
|
|
||||||
|
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
|
||||||
|
while mean.dim() < tensor.dim():
|
||||||
|
mean = mean.unsqueeze(0)
|
||||||
|
std = std.unsqueeze(0)
|
||||||
|
|
||||||
|
# Normalize: (image - mean) / std
|
||||||
|
obs[key] = (tensor - mean) / std
|
||||||
|
|
||||||
|
new_transition[TransitionKey.OBSERVATION] = obs
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def transform_features(self, features):
|
||||||
|
"""ImageNet normalization doesn't change feature structure."""
|
||||||
|
return features
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
"""Return serializable configuration."""
|
||||||
|
return {
|
||||||
|
"image_keys": self.image_keys,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register(name="xvla_add_domain_id")
|
||||||
|
class XVLAAddDomainIdProcessorStep(ProcessorStep):
|
||||||
|
"""Add domain_id to complementary data.
|
||||||
|
|
||||||
|
This processor step adds a domain_id tensor to the complementary data,
|
||||||
|
which is used by XVLA to identify different robot embodiments or task domains.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain_id: The domain ID to add (default: 3)
|
||||||
|
"""
|
||||||
|
|
||||||
|
domain_id: int = 0
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
"""Add domain_id to complementary data."""
|
||||||
|
new_transition = transition.copy()
|
||||||
|
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||||
|
comp = {} if comp is None else comp.copy()
|
||||||
|
|
||||||
|
# Infer batch size from observation tensors
|
||||||
|
obs = new_transition.get(TransitionKey.OBSERVATION, {})
|
||||||
|
batch_size = 1
|
||||||
|
if obs:
|
||||||
|
for v in obs.values():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
batch_size = v.shape[0]
|
||||||
|
break
|
||||||
|
|
||||||
|
# Add domain_id tensor
|
||||||
|
comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long)
|
||||||
|
|
||||||
|
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def transform_features(self, features):
|
||||||
|
"""Domain ID addition doesn't change feature structure."""
|
||||||
|
return features
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
"""Return serializable configuration."""
|
||||||
|
return {
|
||||||
|
"domain_id": self.domain_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register(name="xvla_rotation_6d_to_axis_angle")
|
||||||
|
class XVLARotation6DToAxisAngleProcessorStep(ProcessorStep):
|
||||||
|
"""Convert 6D rotation representation to axis-angle and reorganize action dimensions.
|
||||||
|
|
||||||
|
This processor step takes actions with 6D rotation representation and converts them to
|
||||||
|
axis-angle representation, reorganizing the action dimensions as:
|
||||||
|
- action[:, :3] -> target_eef (end-effector position)
|
||||||
|
- action[:, 3:9] -> 6D rotation (converted to axis-angle, 3D)
|
||||||
|
- action[:, 9:10] -> gripper action
|
||||||
|
|
||||||
|
Final output: [target_eef (3), axis_angle (3), gripper (1)] = 7D action
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expected_action_dim: Expected input action dimension (default: 10, supports 6D rotation + extras)
|
||||||
|
"""
|
||||||
|
|
||||||
|
expected_action_dim: int = 10
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
"""Convert 6D rotation to axis-angle in action."""
|
||||||
|
new_transition = transition.copy()
|
||||||
|
action = new_transition.get(TransitionKey.ACTION)
|
||||||
|
|
||||||
|
if action is None or not isinstance(action, torch.Tensor):
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
# Convert to numpy for processing
|
||||||
|
device = action.device
|
||||||
|
dtype = action.dtype
|
||||||
|
action_np = action.cpu().numpy()
|
||||||
|
|
||||||
|
# Extract components
|
||||||
|
# action shape: (B, D) where D >= 10
|
||||||
|
target_eef = action_np[:, :3] # (B, 3)
|
||||||
|
rotation_6d = action_np[:, 3:9] # (B, 6)
|
||||||
|
target_act = action_np[:, 9:10] # (B, 1)
|
||||||
|
|
||||||
|
# Convert 6D rotation to axis-angle
|
||||||
|
target_axis = rotate6d_to_axis_angle(rotation_6d) # (B, 3)
|
||||||
|
|
||||||
|
# Concatenate: [eef (3), axis_angle (3), gripper (1)] = 7D
|
||||||
|
action_np = np.concatenate([target_eef, target_axis, target_act], axis=-1)
|
||||||
|
|
||||||
|
# Convert gripper action to -1 or 1
|
||||||
|
action_np[:, -1] = np.where(action_np[:, -1] > 0.5, 1.0, -1.0)
|
||||||
|
|
||||||
|
# Convert back to tensor
|
||||||
|
action = torch.from_numpy(action_np).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
new_transition[TransitionKey.ACTION] = action
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def transform_features(self, features):
|
||||||
|
"""Rotation conversion changes action dimension from 10 to 7."""
|
||||||
|
# Note: This is a simplified version. In practice, you might want to
|
||||||
|
# update the action feature shape in the features dict.
|
||||||
|
return features
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
"""Return serializable configuration."""
|
||||||
|
return {
|
||||||
|
"expected_action_dim": self.expected_action_dim,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def make_xvla_libero_pre_post_processors() -> tuple[
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Build the LeRobot processor pipelines for XVLA with LIBERO environment.
|
||||||
|
"""
|
||||||
|
pre_processor_steps: list[ProcessorStep] = []
|
||||||
|
post_processor_steps: list[ProcessorStep] = []
|
||||||
|
pre_processor_steps.extend(
|
||||||
|
[LiberoProcessorStep(), XVLAImageNetNormalizeProcessorStep(), XVLAAddDomainIdProcessorStep()]
|
||||||
|
)
|
||||||
|
post_processor_steps.extend([XVLARotation6DToAxisAngleProcessorStep()])
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||||
|
steps=pre_processor_steps,
|
||||||
|
),
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||||
|
steps=post_processor_steps,
|
||||||
|
),
|
||||||
|
)
|
||||||
415
src/lerobot/policies/xvla/soft_transformer.py
Normal file
415
src/lerobot/policies/xvla/soft_transformer.py
Normal file
@@ -0,0 +1,415 @@
|
|||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Copyright 2025 2toINF (https://github.com/2toINF)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from functools import partial
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as functional
|
||||||
|
|
||||||
|
# ------------------------------- Small utils ----------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _to_2tuple(x) -> tuple:
|
||||||
|
"""Minimal replacement for timm.layers.to_2tuple."""
|
||||||
|
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
|
||||||
|
t = tuple(x)
|
||||||
|
return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
|
||||||
|
return (x, x)
|
||||||
|
|
||||||
|
|
||||||
|
def _has_sdp_attention() -> bool:
|
||||||
|
"""Check if we can use PyTorch fused scaled_dot_product_attention."""
|
||||||
|
return hasattr(functional, "scaled_dot_product_attention")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------- MLP --------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
"""
|
||||||
|
MLP used in ViT-style blocks.
|
||||||
|
|
||||||
|
Supports Linear or 1x1 Conv 'linear_layer' for token/channel mixing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
hidden_features: int | None = None,
|
||||||
|
out_features: int | None = None,
|
||||||
|
norm_layer: type[nn.Module] | None = None,
|
||||||
|
bias: bool | tuple[bool, bool] = True,
|
||||||
|
drop: float | tuple[float, float] = 0.0,
|
||||||
|
use_conv: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or in_features
|
||||||
|
bias = _to_2tuple(bias)
|
||||||
|
drop_probs = _to_2tuple(drop)
|
||||||
|
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
||||||
|
|
||||||
|
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
||||||
|
self.act = nn.GELU(approximate="tanh")
|
||||||
|
self.drop1 = nn.Dropout(drop_probs[0])
|
||||||
|
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
||||||
|
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
||||||
|
self.drop2 = nn.Dropout(drop_probs[1])
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# Expect [B, T, C] for Linear variant; caller is responsible for shapes.
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.drop1(x)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.drop2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------- Attention ----------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
"""
|
||||||
|
Multi-Head Self-Attention with optional fused SDPA fallback.
|
||||||
|
|
||||||
|
If PyTorch provides `scaled_dot_product_attention`, it will be used
|
||||||
|
(usually faster and more stable); otherwise we use a manual implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
fused_attn: Final[bool]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_heads: int = 8,
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
qk_norm: bool = False,
|
||||||
|
attn_drop: float = 0.0,
|
||||||
|
proj_drop: float = 0.0,
|
||||||
|
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.scale = self.head_dim**-0.5
|
||||||
|
self.fused_attn = _has_sdp_attention()
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||||
|
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor, shape [batch_size, seq_len, channels]
|
||||||
|
Input sequence.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor, shape [batch_size, seq_len, channels]
|
||||||
|
Output sequence after MHSA + projection.
|
||||||
|
"""
|
||||||
|
batch_size, seq_len, channels = x.shape
|
||||||
|
qkv = (
|
||||||
|
self.qkv(x)
|
||||||
|
.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
|
||||||
|
.permute(2, 0, 3, 1, 4) # 3 x [batch_size, num_heads, seq_len, head_dim]
|
||||||
|
)
|
||||||
|
q, k, v = qkv.unbind(0) # each: [batch_size, num_heads, seq_len, head_dim]
|
||||||
|
q, k = self.q_norm(q), self.k_norm(k)
|
||||||
|
|
||||||
|
if self.fused_attn:
|
||||||
|
x = functional.scaled_dot_product_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||||
|
) # [batch_size, num_heads, seq_len, head_dim]
|
||||||
|
else:
|
||||||
|
q = q * self.scale
|
||||||
|
attn = q @ k.transpose(-2, -1) # [batch_size, num_heads, seq_len, seq_len]
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
x = attn @ v # [batch_size, num_heads, seq_len, head_dim]
|
||||||
|
|
||||||
|
x = x.transpose(1, 2).reshape(batch_size, seq_len, channels) # [batch_size, seq_len, channels]
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------- Utilities -----------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def basic_init(module: nn.Module) -> None:
|
||||||
|
"""
|
||||||
|
Apply a basic initialization scheme to Linear layers.
|
||||||
|
|
||||||
|
- Weight: Xavier uniform initialization.
|
||||||
|
- Bias: Set to zero.
|
||||||
|
"""
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
nn.init.xavier_uniform_(module.weight)
|
||||||
|
if module.bias is not None:
|
||||||
|
nn.init.constant_(module.bias, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Create sinusoidal timestep embeddings.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
t : torch.Tensor
|
||||||
|
Shape [B]. Each element is a timestep index, may be fractional.
|
||||||
|
dim : int
|
||||||
|
Dimensionality of the output embedding.
|
||||||
|
max_period : int, default=100
|
||||||
|
Controls the minimum frequency of the sinusoids.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Shape [B, dim]. Sinusoidal embeddings.
|
||||||
|
"""
|
||||||
|
half = dim // 2
|
||||||
|
freqs = torch.exp(
|
||||||
|
-math.log(max_period) * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device) / half
|
||||||
|
)
|
||||||
|
args = t[:, None] * freqs[None]
|
||||||
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
|
if dim % 2 == 1:
|
||||||
|
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------- Core Layers ----------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class DomainAwareLinear(nn.Module):
|
||||||
|
"""
|
||||||
|
Linear layer with domain-conditioned parameters (per-sample).
|
||||||
|
|
||||||
|
Each domain has its own weight and bias vectors, stored in embeddings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
self.output_size = output_size
|
||||||
|
self.fc = nn.Embedding(num_domains, output_size * input_size)
|
||||||
|
self.bias = nn.Embedding(num_domains, output_size)
|
||||||
|
nn.init.xavier_uniform_(self.fc.weight)
|
||||||
|
nn.init.zeros_(self.bias.weight)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor
|
||||||
|
[B, I] or [B, T, I]
|
||||||
|
domain_id : LongTensor
|
||||||
|
[B], domain indices.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor
|
||||||
|
[batch_size, output_size] or [batch_size, seq_len, output_size]
|
||||||
|
"""
|
||||||
|
batch_size = domain_id.shape[0]
|
||||||
|
squeeze_seq = False
|
||||||
|
if x.dim() == 2:
|
||||||
|
x = x.unsqueeze(1)
|
||||||
|
squeeze_seq = True
|
||||||
|
weight = self.fc(domain_id).view(batch_size, self.input_size, self.output_size)
|
||||||
|
bias = self.bias(domain_id).view(batch_size, self.output_size)
|
||||||
|
y = torch.matmul(x, weight) + bias.view(batch_size, 1, self.output_size)
|
||||||
|
if squeeze_seq:
|
||||||
|
y = y.squeeze(1)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Standard Transformer block (pre-LN): LN → MHSA → residual, LN → MLP → residual.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = nn.LayerNorm(hidden_size)
|
||||||
|
self.norm2 = nn.LayerNorm(hidden_size)
|
||||||
|
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
|
||||||
|
self.mlp = Mlp(
|
||||||
|
in_features=hidden_size,
|
||||||
|
hidden_features=int(hidden_size * mlp_ratio),
|
||||||
|
drop=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor, [B, T, H]
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor, [B, T, H]
|
||||||
|
"""
|
||||||
|
x = x + self.attn(self.norm1(x))
|
||||||
|
x = x + self.mlp(self.norm2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------- Main Model ---------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class SoftPromptedTransformer(nn.Module):
|
||||||
|
"""
|
||||||
|
Multi-modal, domain-aware Transformer with optional soft prompts.
|
||||||
|
|
||||||
|
See parameter and forward I/O descriptions inside the docstrings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int = 768,
|
||||||
|
multi_modal_input_size: int = 768,
|
||||||
|
depth: int = 24,
|
||||||
|
num_heads: int = 16,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
num_domains: int = 20,
|
||||||
|
dim_action: int = 20,
|
||||||
|
dim_propio: int = 20,
|
||||||
|
dim_time: int = 32,
|
||||||
|
len_soft_prompts: int = 32,
|
||||||
|
max_len_seq: int = 512,
|
||||||
|
use_hetero_proj: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.dim_action = dim_action
|
||||||
|
self.dim_time = dim_time
|
||||||
|
self.len_soft_prompts = len_soft_prompts
|
||||||
|
self.use_hetero_proj = use_hetero_proj
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList(
|
||||||
|
[TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_hetero_proj:
|
||||||
|
self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
|
||||||
|
self.aux_visual_proj = DomainAwareLinear(
|
||||||
|
multi_modal_input_size, hidden_size, num_domains=num_domains
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
|
||||||
|
self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
|
||||||
|
|
||||||
|
self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
|
||||||
|
nn.init.normal_(self.pos_emb, std=0.02)
|
||||||
|
|
||||||
|
self.norm = nn.LayerNorm(hidden_size)
|
||||||
|
self.action_encoder = DomainAwareLinear(
|
||||||
|
dim_action + dim_time + dim_propio, hidden_size, num_domains=num_domains
|
||||||
|
)
|
||||||
|
self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)
|
||||||
|
|
||||||
|
if len_soft_prompts > 0:
|
||||||
|
self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
|
||||||
|
nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)
|
||||||
|
|
||||||
|
self.apply(basic_init)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
domain_id: torch.LongTensor,
|
||||||
|
vlm_features: torch.Tensor,
|
||||||
|
aux_visual_inputs: torch.Tensor,
|
||||||
|
action_with_noise: torch.Tensor,
|
||||||
|
proprio: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass.
|
||||||
|
|
||||||
|
Inputs
|
||||||
|
------
|
||||||
|
domain_id : [B]
|
||||||
|
vlm_features : [B, T_vlm, D]
|
||||||
|
aux_visual_inputs : [B, T_aux, D]
|
||||||
|
action_with_noise : [B, T_action, dim_action]
|
||||||
|
proprio : [B, dim_propio]
|
||||||
|
t : [B]
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor
|
||||||
|
Predicted actions, [batch_size, num_actions, dim_action]
|
||||||
|
"""
|
||||||
|
batch_size, num_actions = action_with_noise.shape[:2]
|
||||||
|
|
||||||
|
# Encode (action + proprio + time) → tokens
|
||||||
|
time_emb = timestep_embedding(t, self.dim_time) # [batch_size, dim_time]
|
||||||
|
time_tokens = time_emb.unsqueeze(1).expand(batch_size, num_actions, self.dim_time)
|
||||||
|
proprio_tokens = proprio.unsqueeze(1).expand(batch_size, num_actions, proprio.shape[-1])
|
||||||
|
action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
|
||||||
|
x = self.action_encoder(action_tokens, domain_id) # [batch_size, num_actions, hidden_size]
|
||||||
|
|
||||||
|
# Project visual streams and concatenate
|
||||||
|
if self.use_hetero_proj:
|
||||||
|
x = torch.cat(
|
||||||
|
[
|
||||||
|
x,
|
||||||
|
self.vlm_proj(vlm_features, domain_id),
|
||||||
|
self.aux_visual_proj(aux_visual_inputs, domain_id),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = torch.cat([x, self.vlm_proj(vlm_features), self.aux_visual_proj(aux_visual_inputs)], dim=1)
|
||||||
|
|
||||||
|
# Add positional embeddings (truncate if needed)
|
||||||
|
seq_len = x.shape[1]
|
||||||
|
if seq_len > self.pos_emb.shape[1]:
|
||||||
|
raise ValueError(f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}.")
|
||||||
|
x = x + self.pos_emb[:, :seq_len, :]
|
||||||
|
|
||||||
|
# Append soft prompts
|
||||||
|
if self.len_soft_prompts > 0:
|
||||||
|
soft_prompts = self.soft_prompt_hub(domain_id).view(
|
||||||
|
batch_size, self.len_soft_prompts, self.hidden_size
|
||||||
|
)
|
||||||
|
x = torch.cat([x, soft_prompts], dim=1)
|
||||||
|
|
||||||
|
# Transformer backbone
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
# Decode only the action segment
|
||||||
|
return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)
|
||||||
138
src/lerobot/policies/xvla/utils.py
Normal file
138
src/lerobot/policies/xvla/utils.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def mat2quat(rmat):
|
||||||
|
"""
|
||||||
|
Converts given rotation matrix to quaternion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rmat (np.array): 3x3 rotation matrix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array: (x,y,z,w) float quaternion angles
|
||||||
|
"""
|
||||||
|
mat = np.asarray(rmat).astype(np.float32)[:3, :3]
|
||||||
|
|
||||||
|
m00 = mat[0, 0]
|
||||||
|
m01 = mat[0, 1]
|
||||||
|
m02 = mat[0, 2]
|
||||||
|
m10 = mat[1, 0]
|
||||||
|
m11 = mat[1, 1]
|
||||||
|
m12 = mat[1, 2]
|
||||||
|
m20 = mat[2, 0]
|
||||||
|
m21 = mat[2, 1]
|
||||||
|
m22 = mat[2, 2]
|
||||||
|
# symmetric matrix k
|
||||||
|
k = np.array(
|
||||||
|
[
|
||||||
|
[m00 - m11 - m22, np.float32(0.0), np.float32(0.0), np.float32(0.0)],
|
||||||
|
[m01 + m10, m11 - m00 - m22, np.float32(0.0), np.float32(0.0)],
|
||||||
|
[m02 + m20, m12 + m21, m22 - m00 - m11, np.float32(0.0)],
|
||||||
|
[m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
k /= 3.0
|
||||||
|
# quaternion is Eigen vector of k that corresponds to largest eigenvalue
|
||||||
|
w, v = np.linalg.eigh(k)
|
||||||
|
inds = np.array([3, 0, 1, 2])
|
||||||
|
q1 = v[inds, np.argmax(w)]
|
||||||
|
if q1[0] < 0.0:
|
||||||
|
np.negative(q1, q1)
|
||||||
|
inds = np.array([1, 2, 3, 0])
|
||||||
|
return q1[inds]
|
||||||
|
|
||||||
|
|
||||||
|
def quat2axisangle(quat):
|
||||||
|
"""
|
||||||
|
Converts quaternion to axis-angle format.
|
||||||
|
Returns a unit vector direction scaled by its angle in radians.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
quat (np.array): (x,y,z,w) vec4 float angles
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array: (ax,ay,az) axis-angle exponential coordinates
|
||||||
|
"""
|
||||||
|
# clip quaternion
|
||||||
|
if quat[3] > 1.0:
|
||||||
|
quat[3] = 1.0
|
||||||
|
elif quat[3] < -1.0:
|
||||||
|
quat[3] = -1.0
|
||||||
|
|
||||||
|
den = np.sqrt(1.0 - quat[3] * quat[3])
|
||||||
|
if math.isclose(den, 0.0):
|
||||||
|
# This is (close to) a zero degree rotation, immediately return
|
||||||
|
return np.zeros(3)
|
||||||
|
|
||||||
|
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
|
||||||
|
|
||||||
|
|
||||||
|
def rotate6d_to_axis_angle(r6d):
|
||||||
|
"""
|
||||||
|
r6d: np.ndarray, shape (N, 6)
|
||||||
|
return: np.ndarray, shape (N, 3), axis-angle vectors
|
||||||
|
"""
|
||||||
|
flag = 0
|
||||||
|
if len(r6d.shape) == 1:
|
||||||
|
r6d = r6d[None, ...]
|
||||||
|
flag = 1
|
||||||
|
|
||||||
|
a1 = r6d[:, 0:3]
|
||||||
|
a2 = r6d[:, 3:6]
|
||||||
|
|
||||||
|
# b1
|
||||||
|
b1 = a1 / (np.linalg.norm(a1, axis=-1, keepdims=True) + 1e-6)
|
||||||
|
|
||||||
|
# b2
|
||||||
|
dot_prod = np.sum(b1 * a2, axis=-1, keepdims=True)
|
||||||
|
b2_orth = a2 - dot_prod * b1
|
||||||
|
b2 = b2_orth / (np.linalg.norm(b2_orth, axis=-1, keepdims=True) + 1e-6)
|
||||||
|
|
||||||
|
# b3
|
||||||
|
b3 = np.cross(b1, b2, axis=-1)
|
||||||
|
|
||||||
|
rotation_matrix = np.stack([b1, b2, b3], axis=-1) # shape: (N, 3, 3)
|
||||||
|
|
||||||
|
axis_angle_list = []
|
||||||
|
for i in range(rotation_matrix.shape[0]):
|
||||||
|
quat = mat2quat(rotation_matrix[i])
|
||||||
|
axis_angle = quat2axisangle(quat)
|
||||||
|
axis_angle_list.append(axis_angle)
|
||||||
|
|
||||||
|
axis_angle_array = np.stack(axis_angle_list, axis=0) # shape: (N, 3)
|
||||||
|
|
||||||
|
if flag == 1:
|
||||||
|
axis_angle_array = axis_angle_array[0]
|
||||||
|
|
||||||
|
return axis_angle_array
|
||||||
|
|
||||||
|
|
||||||
|
def mat_to_rotate6d(abs_action):
|
||||||
|
if len(abs_action.shape) == 2:
|
||||||
|
return np.concatenate([abs_action[:3, 0], abs_action[:3, 1]], axis=-1)
|
||||||
|
elif len(abs_action.shape) == 3:
|
||||||
|
return np.concatenate([abs_action[:, :3, 0], abs_action[:, :3, 1]], axis=-1)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
|
||||||
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
|
|
||||||
|
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||||
|
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||||
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||||
|
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||||
|
'survival rate' as the argument.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if drop_prob == 0.0 or not training:
|
||||||
|
return x
|
||||||
|
keep_prob = 1 - drop_prob
|
||||||
|
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
|
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||||
|
if keep_prob > 0.0 and scale_by_keep:
|
||||||
|
random_tensor.div_(keep_prob)
|
||||||
|
return x * random_tensor
|
||||||
20
src/lerobot/robots/earthrover_mini_plus/__init__.py
Normal file
20
src/lerobot/robots/earthrover_mini_plus/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
|
||||||
|
from .robot_earthrover_mini_plus import EarthRoverMiniPlus
|
||||||
|
|
||||||
|
__all__ = ["EarthRoverMiniPlus", "EarthRoverMiniPlusConfig"]
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Configuration for EarthRover Mini Plus robot."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from ..config import RobotConfig
|
||||||
|
|
||||||
|
|
||||||
|
@RobotConfig.register_subclass("earthrover_mini_plus")
|
||||||
|
@dataclass
|
||||||
|
class EarthRoverMiniPlusConfig(RobotConfig):
|
||||||
|
"""Configuration for EarthRover Mini Plus robot using Frodobots SDK.
|
||||||
|
|
||||||
|
This robot uses cloud-based control via the Frodobots SDK HTTP API.
|
||||||
|
Camera frames are accessed directly through SDK HTTP endpoints.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
sdk_url: URL of the Frodobots SDK server (default: http://localhost:8000)
|
||||||
|
"""
|
||||||
|
|
||||||
|
sdk_url: str = "http://localhost:8000"
|
||||||
1
src/lerobot/robots/earthrover_mini_plus/earthrover_mini_plus.mdx
Symbolic link
1
src/lerobot/robots/earthrover_mini_plus/earthrover_mini_plus.mdx
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../../../../docs/source/earthrover_mini_plus.mdx
|
||||||
@@ -0,0 +1,473 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""EarthRover Mini Plus robot using Frodobots SDK."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||||
|
|
||||||
|
from ..robot import Robot
|
||||||
|
from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Action feature keys
|
||||||
|
ACTION_LINEAR_VEL = "linear.vel"
|
||||||
|
ACTION_ANGULAR_VEL = "angular.vel"
|
||||||
|
|
||||||
|
# Observation feature keys
|
||||||
|
OBS_FRONT = "front"
|
||||||
|
OBS_REAR = "rear"
|
||||||
|
OBS_LINEAR_VEL = "linear.vel"
|
||||||
|
OBS_BATTERY_LEVEL = "battery.level"
|
||||||
|
OBS_ORIENTATION_DEG = "orientation.deg"
|
||||||
|
OBS_GPS_LATITUDE = "gps.latitude"
|
||||||
|
OBS_GPS_LONGITUDE = "gps.longitude"
|
||||||
|
OBS_GPS_SIGNAL = "gps.signal"
|
||||||
|
OBS_SIGNAL_LEVEL = "signal.level"
|
||||||
|
OBS_VIBRATION = "vibration"
|
||||||
|
OBS_LAMP_STATE = "lamp.state"
|
||||||
|
|
||||||
|
|
||||||
|
class EarthRoverMiniPlus(Robot):
|
||||||
|
"""
|
||||||
|
EarthRover Mini Plus robot controlled via Frodobots SDK HTTP API.
|
||||||
|
|
||||||
|
This robot uses cloud-based control through the Frodobots SDK instead of direct
|
||||||
|
hardware connection. Cameras stream via WebRTC through Agora cloud, and control
|
||||||
|
commands are sent via HTTP POST requests.
|
||||||
|
|
||||||
|
The robot supports:
|
||||||
|
- Dual cameras (front and rear) accessed via SDK HTTP endpoints
|
||||||
|
- Linear and angular velocity control
|
||||||
|
- Battery and orientation telemetry
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
config: Robot configuration
|
||||||
|
sdk_base_url: URL of the Frodobots SDK server (default: http://localhost:8000)
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = EarthRoverMiniPlusConfig
|
||||||
|
name = "earthrover_mini_plus"
|
||||||
|
|
||||||
|
def __init__(self, config: EarthRoverMiniPlusConfig):
|
||||||
|
"""Initialize EarthRover Mini Plus robot.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Robot configuration including SDK URL
|
||||||
|
"""
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.sdk_base_url = "http://localhost:8000"
|
||||||
|
|
||||||
|
# Empty cameras dict for compatibility with recording script
|
||||||
|
# Cameras are accessed directly via SDK, not through Camera objects
|
||||||
|
self.cameras = {}
|
||||||
|
self._is_connected = False
|
||||||
|
|
||||||
|
# Cache for camera frames (fallback when requests fail)
|
||||||
|
self._last_front_frame = None
|
||||||
|
self._last_rear_frame = None
|
||||||
|
|
||||||
|
# Cache for robot telemetry data (fallback when requests fail)
|
||||||
|
self._last_robot_data = None
|
||||||
|
|
||||||
|
logger.info(f"Initialized {self.name} with SDK at {self.sdk_base_url}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""Check if robot is connected to SDK."""
|
||||||
|
return self._is_connected
|
||||||
|
|
||||||
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
|
"""Connect to robot via Frodobots SDK.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
calibrate: Not used for SDK-based robot (kept for API compatibility)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DeviceAlreadyConnectedError: If robot is already connected
|
||||||
|
DeviceNotConnectedError: If cannot connect to SDK server
|
||||||
|
"""
|
||||||
|
if self._is_connected:
|
||||||
|
raise DeviceAlreadyConnectedError(f"{self.name} is already connected")
|
||||||
|
|
||||||
|
# Verify SDK is running and accessible
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{self.sdk_base_url}/data", timeout=10.0)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise DeviceNotConnectedError(
|
||||||
|
f"Cannot connect to SDK at {self.sdk_base_url}. "
|
||||||
|
"Make sure it's running: hypercorn main:app --reload"
|
||||||
|
)
|
||||||
|
except requests.RequestException as e:
|
||||||
|
raise DeviceNotConnectedError(f"Cannot connect to SDK at {self.sdk_base_url}: {e}") from e
|
||||||
|
|
||||||
|
self._is_connected = True
|
||||||
|
logger.info(f"{self.name} connected to SDK")
|
||||||
|
|
||||||
|
if calibrate:
|
||||||
|
self.calibrate()
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
"""Calibration not needed for SDK-based robot."""
|
||||||
|
logger.info("Calibration not required for SDK-based robot")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
"""SDK robot doesn't require calibration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Always True for SDK-based robots
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def configure(self) -> None:
|
||||||
|
"""Configure robot (no-op for SDK-based robot)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
|
"""Define the observation space for dataset recording.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Observation features with types/shapes:
|
||||||
|
- front: (480, 640, 3) - Front camera RGB image
|
||||||
|
- rear: (480, 640, 3) - Rear camera RGB image
|
||||||
|
- linear.vel: float - Current speed (0-1, SDK reports only positive speeds)
|
||||||
|
- battery.level: float - Battery level (0-1, normalized from 0-100)
|
||||||
|
- orientation.deg: float - Robot orientation (0-1, normalized from raw value)
|
||||||
|
- gps.latitude: float - GPS latitude coordinate
|
||||||
|
- gps.longitude: float - GPS longitude coordinate
|
||||||
|
- gps.signal: float - GPS signal strength (0-1, normalized from percentage)
|
||||||
|
- signal.level: float - Network signal level (0-1, normalized from 0-5)
|
||||||
|
- vibration: float - Vibration sensor reading
|
||||||
|
- lamp.state: float - Lamp state (0=off, 1=on)
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
# Cameras (height, width, channels)
|
||||||
|
OBS_FRONT: (480, 640, 3),
|
||||||
|
OBS_REAR: (480, 640, 3),
|
||||||
|
# Motion state
|
||||||
|
OBS_LINEAR_VEL: float,
|
||||||
|
# Robot state
|
||||||
|
OBS_BATTERY_LEVEL: float,
|
||||||
|
OBS_ORIENTATION_DEG: float,
|
||||||
|
# GPS
|
||||||
|
OBS_GPS_LATITUDE: float,
|
||||||
|
OBS_GPS_LONGITUDE: float,
|
||||||
|
OBS_GPS_SIGNAL: float,
|
||||||
|
# Sensors
|
||||||
|
OBS_SIGNAL_LEVEL: float,
|
||||||
|
OBS_VIBRATION: float,
|
||||||
|
OBS_LAMP_STATE: float,
|
||||||
|
}
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def action_features(self) -> dict[str, type]:
|
||||||
|
"""Define the action space.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Action features with types:
|
||||||
|
- linear.vel: float - Target linear velocity
|
||||||
|
- angular.vel: float - Target angular velocity
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
ACTION_LINEAR_VEL: float,
|
||||||
|
ACTION_ANGULAR_VEL: float,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_observation(self) -> dict[str, Any]:
|
||||||
|
"""Get current robot observation from SDK.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Observation containing:
|
||||||
|
- front: Front camera image (480, 640, 3) in RGB format
|
||||||
|
- rear: Rear camera image (480, 640, 3) in RGB format
|
||||||
|
- linear.vel: Current speed (0-1, SDK reports only positive speeds)
|
||||||
|
- battery.level: Battery level (0-1, normalized from 0-100)
|
||||||
|
- orientation.deg: Robot orientation (0-1, normalized from raw value)
|
||||||
|
- gps.latitude: GPS latitude coordinate
|
||||||
|
- gps.longitude: GPS longitude coordinate
|
||||||
|
- gps.signal: GPS signal strength (0-1, normalized from percentage)
|
||||||
|
- signal.level: Network signal level (0-1, normalized from 0-5)
|
||||||
|
- vibration: Vibration sensor reading
|
||||||
|
- lamp.state: Lamp state (0=off, 1=on)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DeviceNotConnectedError: If robot is not connected
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Camera frames are retrieved from SDK endpoints /v2/front and /v2/rear.
|
||||||
|
Frames are decoded from base64 and converted from BGR to RGB format.
|
||||||
|
Robot telemetry is retrieved from /data endpoint.
|
||||||
|
All SDK values are normalized to appropriate ranges for dataset recording.
|
||||||
|
"""
|
||||||
|
if not self._is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||||
|
|
||||||
|
observation = {}
|
||||||
|
|
||||||
|
# Get camera images from SDK
|
||||||
|
frames = self._get_camera_frames()
|
||||||
|
observation[OBS_FRONT] = frames["front"]
|
||||||
|
observation[OBS_REAR] = frames["rear"]
|
||||||
|
|
||||||
|
# Get robot state from SDK
|
||||||
|
robot_data = self._get_robot_data()
|
||||||
|
|
||||||
|
# Motion state
|
||||||
|
observation[OBS_LINEAR_VEL] = robot_data["speed"] / 100.0 # Normalize 0-100 to 0-1
|
||||||
|
|
||||||
|
# Robot state
|
||||||
|
observation[OBS_BATTERY_LEVEL] = robot_data["battery"] / 100.0 # Normalize 0-100 to 0-1
|
||||||
|
observation[OBS_ORIENTATION_DEG] = robot_data["orientation"] / 360.0 # Normalize to 0-1
|
||||||
|
|
||||||
|
# GPS data
|
||||||
|
observation[OBS_GPS_LATITUDE] = robot_data["latitude"]
|
||||||
|
observation[OBS_GPS_LONGITUDE] = robot_data["longitude"]
|
||||||
|
observation[OBS_GPS_SIGNAL] = robot_data["gps_signal"] / 100.0 # Normalize percentage to 0-1
|
||||||
|
|
||||||
|
# Sensors
|
||||||
|
observation[OBS_SIGNAL_LEVEL] = robot_data["signal_level"] / 5.0 # Normalize 0-5 to 0-1
|
||||||
|
observation[OBS_VIBRATION] = robot_data["vibration"]
|
||||||
|
observation[OBS_LAMP_STATE] = float(robot_data["lamp"]) # 0 or 1
|
||||||
|
|
||||||
|
return observation
|
||||||
|
|
||||||
|
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Send action to robot via SDK.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: Action dict with keys:
|
||||||
|
- linear.vel: Target linear velocity (-1 to 1)
|
||||||
|
- angular.vel: Target angular velocity (-1 to 1)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The action that was sent (matches action_features keys)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DeviceNotConnectedError: If robot is not connected
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Actions are sent to SDK via POST /control endpoint.
|
||||||
|
SDK expects commands in range [-1, 1].
|
||||||
|
"""
|
||||||
|
if not self._is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||||
|
|
||||||
|
# Extract action values and convert to float
|
||||||
|
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
|
||||||
|
angular = float(action.get(ACTION_ANGULAR_VEL, 0.0))
|
||||||
|
|
||||||
|
# Send command to SDK
|
||||||
|
try:
|
||||||
|
self._send_command_to_sdk(linear, angular)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error sending action: {e}")
|
||||||
|
|
||||||
|
# Return action in format matching action_features
|
||||||
|
return {
|
||||||
|
ACTION_LINEAR_VEL: linear,
|
||||||
|
ACTION_ANGULAR_VEL: angular,
|
||||||
|
}
|
||||||
|
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
"""Disconnect from robot.
|
||||||
|
|
||||||
|
Stops the robot and closes connection to SDK.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DeviceNotConnectedError: If robot is not connected
|
||||||
|
"""
|
||||||
|
if not self._is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||||
|
|
||||||
|
# Stop the robot before disconnecting
|
||||||
|
try:
|
||||||
|
self._send_command_to_sdk(0.0, 0.0)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to stop robot during disconnect: {e}")
|
||||||
|
|
||||||
|
self._is_connected = False
|
||||||
|
logger.info(f"{self.name} disconnected")
|
||||||
|
|
||||||
|
# Private helper methods for SDK communication
|
||||||
|
|
||||||
|
def _get_camera_frames(self) -> dict[str, np.ndarray]:
|
||||||
|
"""Get camera frames from SDK using v2 endpoints with caching fallback.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Dictionary with 'front' and 'rear' keys containing:
|
||||||
|
- Current frame (if request succeeds)
|
||||||
|
- Cached frame (if request fails but cache exists)
|
||||||
|
- Zero array (if request fails and no cache exists yet)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Uses /v2/front and /v2/rear endpoints which are 15x faster than /screenshot.
|
||||||
|
Images are base64 encoded, resized to 640x480, and converted from BGR to RGB.
|
||||||
|
If request fails, returns the last successfully retrieved frame (cached).
|
||||||
|
"""
|
||||||
|
frames = {}
|
||||||
|
|
||||||
|
# Get front camera
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{self.sdk_base_url}/v2/front", timeout=2.0)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
if "front_frame" in data and data["front_frame"]:
|
||||||
|
front_img = self._decode_base64_image(data["front_frame"])
|
||||||
|
if front_img is not None:
|
||||||
|
# Resize and convert BGR to RGB
|
||||||
|
front_img = cv2.resize(front_img, (640, 480))
|
||||||
|
front_rgb = cv2.cvtColor(front_img, cv2.COLOR_BGR2RGB)
|
||||||
|
frames["front"] = front_rgb
|
||||||
|
# Cache the successful frame
|
||||||
|
self._last_front_frame = front_rgb
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error fetching front camera: {e}")
|
||||||
|
|
||||||
|
# Fallback: use cache or zero array
|
||||||
|
if "front" not in frames:
|
||||||
|
if self._last_front_frame is not None:
|
||||||
|
frames["front"] = self._last_front_frame
|
||||||
|
else:
|
||||||
|
frames["front"] = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
# Get rear camera
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{self.sdk_base_url}/v2/rear", timeout=2.0)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
if "rear_frame" in data and data["rear_frame"]:
|
||||||
|
rear_img = self._decode_base64_image(data["rear_frame"])
|
||||||
|
if rear_img is not None:
|
||||||
|
# Resize and convert BGR to RGB
|
||||||
|
rear_img = cv2.resize(rear_img, (640, 480))
|
||||||
|
rear_rgb = cv2.cvtColor(rear_img, cv2.COLOR_BGR2RGB)
|
||||||
|
frames["rear"] = rear_rgb
|
||||||
|
# Cache the successful frame
|
||||||
|
self._last_rear_frame = rear_rgb
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error fetching rear camera: {e}")
|
||||||
|
|
||||||
|
# Fallback: use cache or zero array
|
||||||
|
if "rear" not in frames:
|
||||||
|
if self._last_rear_frame is not None:
|
||||||
|
frames["rear"] = self._last_rear_frame
|
||||||
|
else:
|
||||||
|
frames["rear"] = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
return frames
|
||||||
|
|
||||||
|
def _decode_base64_image(self, base64_string: str) -> np.ndarray | None:
|
||||||
|
"""Decode base64 string to image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base64_string: Base64 encoded image string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Decoded image in BGR format (OpenCV default), or None if decoding fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
img_bytes = base64.b64decode(base64_string)
|
||||||
|
nparr = np.frombuffer(img_bytes, np.uint8)
|
||||||
|
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||||
|
return img # Return in BGR format (OpenCV default)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error decoding image: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_robot_data(self) -> dict:
|
||||||
|
"""Get robot telemetry data from SDK.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Robot telemetry data including battery, speed, orientation, GPS, etc:
|
||||||
|
- Current data (if request succeeds)
|
||||||
|
- Cached data (if request fails but cache exists)
|
||||||
|
- Default values (if request fails and no cache exists yet)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Uses /data endpoint which provides comprehensive robot state.
|
||||||
|
If request fails, returns the last successfully retrieved data (cached).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{self.sdk_base_url}/data", timeout=2.0)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
# Cache the successful data
|
||||||
|
self._last_robot_data = data
|
||||||
|
return data
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error fetching robot data: {e}")
|
||||||
|
|
||||||
|
# Fallback: use cache or default values
|
||||||
|
if self._last_robot_data is not None:
|
||||||
|
return self._last_robot_data
|
||||||
|
else:
|
||||||
|
# Return dict with default values (used only on first failure before any cache exists)
|
||||||
|
return {
|
||||||
|
"speed": 0,
|
||||||
|
"battery": 0,
|
||||||
|
"orientation": 0,
|
||||||
|
"latitude": 0.0,
|
||||||
|
"longitude": 0.0,
|
||||||
|
"gps_signal": 0,
|
||||||
|
"signal_level": 0,
|
||||||
|
"vibration": 0.0,
|
||||||
|
"lamp": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _send_command_to_sdk(self, linear: float, angular: float, lamp: int = 0) -> bool:
|
||||||
|
"""Send control command to SDK.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
linear: Linear velocity command (-1 to 1)
|
||||||
|
angular: Angular velocity command (-1 to 1)
|
||||||
|
lamp: Lamp control (0=off, 1=on)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if command sent successfully, False otherwise
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Uses POST /control endpoint. Commands are sent as JSON payload.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
payload = {
|
||||||
|
"command": {
|
||||||
|
"linear": linear,
|
||||||
|
"angular": angular,
|
||||||
|
"lamp": lamp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.sdk_base_url}/control",
|
||||||
|
json=payload,
|
||||||
|
timeout=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.status_code == 200
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error sending command: {e}")
|
||||||
|
return False
|
||||||
21
src/lerobot/robots/omx_follower/__init__.py
Normal file
21
src/lerobot/robots/omx_follower/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# OMX is a fully open-source robot from ROBOTIS.
|
||||||
|
# More information at: https://ai.robotis.com/omx/introduction_omx.html
|
||||||
|
|
||||||
|
from .config_omx_follower import OmxFollowerConfig
|
||||||
|
from .omx_follower import OmxFollower
|
||||||
39
src/lerobot/robots/omx_follower/config_omx_follower.py
Normal file
39
src/lerobot/robots/omx_follower/config_omx_follower.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from lerobot.cameras import CameraConfig
|
||||||
|
|
||||||
|
from ..config import RobotConfig
|
||||||
|
|
||||||
|
|
||||||
|
@RobotConfig.register_subclass("omx_follower")
|
||||||
|
@dataclass
|
||||||
|
class OmxFollowerConfig(RobotConfig):
|
||||||
|
# Port to connect to the arm
|
||||||
|
port: str
|
||||||
|
|
||||||
|
disable_torque_on_disconnect: bool = True
|
||||||
|
|
||||||
|
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||||
|
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||||
|
# names to the max_relative_target value for that motor.
|
||||||
|
max_relative_target: float | dict[str, float] | None = None
|
||||||
|
|
||||||
|
# cameras
|
||||||
|
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Set to `True` for backward compatibility with previous policies/dataset
|
||||||
|
use_degrees: bool = False
|
||||||
225
src/lerobot/robots/omx_follower/omx_follower.py
Normal file
225
src/lerobot/robots/omx_follower/omx_follower.py
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from lerobot.cameras.utils import make_cameras_from_configs
|
||||||
|
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
|
from lerobot.motors.dynamixel import (
|
||||||
|
DriveMode,
|
||||||
|
DynamixelMotorsBus,
|
||||||
|
OperatingMode,
|
||||||
|
)
|
||||||
|
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||||
|
|
||||||
|
from ..robot import Robot
|
||||||
|
from ..utils import ensure_safe_goal_position
|
||||||
|
from .config_omx_follower import OmxFollowerConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OmxFollower(Robot):
|
||||||
|
"""
|
||||||
|
- [OMX](https://github.com/ROBOTIS-GIT/open_manipulator),
|
||||||
|
expansion, developed by Woojin Wie and Junha Cha from [ROBOTIS](https://ai.robotis.com/)
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = OmxFollowerConfig
|
||||||
|
name = "omx_follower"
|
||||||
|
|
||||||
|
def __init__(self, config: OmxFollowerConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100
|
||||||
|
self.bus = DynamixelMotorsBus(
|
||||||
|
port=self.config.port,
|
||||||
|
motors={
|
||||||
|
"shoulder_pan": Motor(11, "xl430-w250", norm_mode_body),
|
||||||
|
"shoulder_lift": Motor(12, "xl430-w250", norm_mode_body),
|
||||||
|
"elbow_flex": Motor(13, "xl430-w250", norm_mode_body),
|
||||||
|
"wrist_flex": Motor(14, "xl330-m288", norm_mode_body),
|
||||||
|
"wrist_roll": Motor(15, "xl330-m288", norm_mode_body),
|
||||||
|
"gripper": Motor(16, "xl330-m288", MotorNormMode.RANGE_0_100),
|
||||||
|
},
|
||||||
|
calibration=self.calibration,
|
||||||
|
)
|
||||||
|
self.cameras = make_cameras_from_configs(config.cameras)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _motors_ft(self) -> dict[str, type]:
|
||||||
|
return {f"{motor}.pos": float for motor in self.bus.motors}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _cameras_ft(self) -> dict[str, tuple]:
|
||||||
|
return {
|
||||||
|
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||||
|
}
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
|
return {**self._motors_ft, **self._cameras_ft}
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def action_features(self) -> dict[str, type]:
|
||||||
|
return self._motors_ft
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||||
|
|
||||||
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
For OMX robots that come pre-calibrated:
|
||||||
|
- If default calibration from package doesn't match motors, read from motors and save
|
||||||
|
- This allows using pre-calibrated robots without manual calibration
|
||||||
|
- If no calibration file exists, use factory default values (homing_offset=0, range_min=0, range_max=4095)
|
||||||
|
"""
|
||||||
|
if self.is_connected:
|
||||||
|
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||||
|
|
||||||
|
self.bus.connect()
|
||||||
|
if not self.is_calibrated and calibrate:
|
||||||
|
logger.info(
|
||||||
|
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
|
||||||
|
)
|
||||||
|
self.calibrate()
|
||||||
|
|
||||||
|
for cam in self.cameras.values():
|
||||||
|
cam.connect()
|
||||||
|
|
||||||
|
self.configure()
|
||||||
|
logger.info(f"{self} connected.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
return self.bus.is_calibrated
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
self.bus.disable_torque()
|
||||||
|
logger.info(f"\nUsing factory default calibration values for {self}")
|
||||||
|
logger.info(f"\nWriting default configuration of {self} to the motors")
|
||||||
|
for motor in self.bus.motors:
|
||||||
|
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||||
|
|
||||||
|
for motor in self.bus.motors:
|
||||||
|
self.bus.write("Drive_Mode", motor, DriveMode.NON_INVERTED.value)
|
||||||
|
|
||||||
|
self.calibration = {}
|
||||||
|
for motor, m in self.bus.motors.items():
|
||||||
|
self.calibration[motor] = MotorCalibration(
|
||||||
|
id=m.id,
|
||||||
|
drive_mode=0,
|
||||||
|
homing_offset=0,
|
||||||
|
range_min=0,
|
||||||
|
range_max=4095,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.bus.write_calibration(self.calibration)
|
||||||
|
self._save_calibration()
|
||||||
|
logger.info(f"Calibration saved to {self.calibration_fpath}")
|
||||||
|
|
||||||
|
def configure(self) -> None:
|
||||||
|
with self.bus.torque_disabled():
|
||||||
|
self.bus.configure_motors()
|
||||||
|
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
|
||||||
|
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling
|
||||||
|
# the arm, you could end up with a servo with a position 0 or 4095 at a crucial point
|
||||||
|
for motor in self.bus.motors:
|
||||||
|
if motor != "gripper":
|
||||||
|
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||||
|
|
||||||
|
# Use 'position control current based' for gripper to be limited by the limit of the current. For
|
||||||
|
# the follower gripper, it means it can grasp an object without forcing too much even tho, its
|
||||||
|
# goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
|
||||||
|
# For the leader gripper, it means we can use it as a physical trigger, since we can force with
|
||||||
|
# our finger to make it move, and it will move back to its original target position when we
|
||||||
|
# release the force.
|
||||||
|
self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
|
||||||
|
|
||||||
|
# Set better PID values to close the gap between recorded states and actions
|
||||||
|
# TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
|
||||||
|
self.bus.write("Position_P_Gain", "elbow_flex", 1500)
|
||||||
|
self.bus.write("Position_I_Gain", "elbow_flex", 0)
|
||||||
|
self.bus.write("Position_D_Gain", "elbow_flex", 600)
|
||||||
|
|
||||||
|
def setup_motors(self) -> None:
|
||||||
|
for motor in reversed(self.bus.motors):
|
||||||
|
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||||
|
self.bus.setup_motor(motor)
|
||||||
|
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||||
|
|
||||||
|
def get_observation(self) -> dict[str, Any]:
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|
||||||
|
# Read arm position
|
||||||
|
start = time.perf_counter()
|
||||||
|
obs_dict = self.bus.sync_read("Present_Position")
|
||||||
|
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
|
||||||
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
|
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||||
|
|
||||||
|
# Capture images from cameras
|
||||||
|
for cam_key, cam in self.cameras.items():
|
||||||
|
start = time.perf_counter()
|
||||||
|
obs_dict[cam_key] = cam.async_read()
|
||||||
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
|
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||||
|
|
||||||
|
return obs_dict
|
||||||
|
|
||||||
|
def send_action(self, action: dict[str, float]) -> dict[str, float]:
|
||||||
|
"""Command arm to move to a target joint configuration.
|
||||||
|
|
||||||
|
The relative action magnitude may be clipped depending on the configuration parameter
|
||||||
|
`max_relative_target`. In this case, the action sent differs from original action.
|
||||||
|
Thus, this function always returns the action actually sent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action (dict[str, float]): The goal positions for the motors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, float]: The action sent to the motors, potentially clipped.
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|
||||||
|
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||||
|
|
||||||
|
# Cap goal position when too far away from present position.
|
||||||
|
# /!\ Slower fps expected due to reading from the follower.
|
||||||
|
if self.config.max_relative_target is not None:
|
||||||
|
present_pos = self.bus.sync_read("Present_Position")
|
||||||
|
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
|
||||||
|
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||||
|
|
||||||
|
# Send goal position to the arm
|
||||||
|
self.bus.sync_write("Goal_Position", goal_pos)
|
||||||
|
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|
||||||
|
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||||
|
for cam in self.cameras.values():
|
||||||
|
cam.disconnect()
|
||||||
|
|
||||||
|
logger.info(f"{self} disconnected.")
|
||||||
20
src/lerobot/robots/openarms/__init__.py
Normal file
20
src/lerobot/robots/openarms/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .config_openarms_follower import OpenArmsFollowerConfig
|
||||||
|
from .openarms_follower import OpenArmsFollower
|
||||||
|
|
||||||
|
__all__ = ["OpenArmsFollower", "OpenArmsFollowerConfig"]
|
||||||
118
src/lerobot/robots/openarms/config_openarms_follower.py
Normal file
118
src/lerobot/robots/openarms/config_openarms_follower.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from lerobot.cameras import CameraConfig
|
||||||
|
from lerobot.motors.damiao.tables import MotorType
|
||||||
|
|
||||||
|
from ..config import RobotConfig
|
||||||
|
|
||||||
|
|
||||||
|
@RobotConfig.register_subclass("openarms_follower")
|
||||||
|
@dataclass
|
||||||
|
class OpenArmsFollowerConfig(RobotConfig):
|
||||||
|
"""Configuration for the OpenArms follower robot with Damiao motors."""
|
||||||
|
|
||||||
|
# CAN interfaces - one per arm
|
||||||
|
# Right arm CAN interface (e.g., "can0")
|
||||||
|
# Left arm CAN interface (e.g., "can1")
|
||||||
|
# Linux: "can0", "can1", etc.
|
||||||
|
# macOS: "/dev/cu.usbmodem*" (serial device)
|
||||||
|
port_right: str = "can0" # CAN interface for right arm
|
||||||
|
port_left: str = "can1" # CAN interface for left arm
|
||||||
|
|
||||||
|
# CAN interface type: "socketcan" (Linux), "slcan" (macOS/serial), or "auto" (auto-detect)
|
||||||
|
can_interface: str = "socketcan"
|
||||||
|
|
||||||
|
# CAN FD settings (OpenArms uses CAN FD by default)
|
||||||
|
use_can_fd: bool = True
|
||||||
|
can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps)
|
||||||
|
can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps)
|
||||||
|
|
||||||
|
# Whether to disable torque when disconnecting
|
||||||
|
disable_torque_on_disconnect: bool = True
|
||||||
|
|
||||||
|
# Safety limit for relative target positions
|
||||||
|
# Set to a positive scalar for all motors, or a dict mapping motor names to limits
|
||||||
|
max_relative_target: Optional[float | Dict[str, float]] = None
|
||||||
|
|
||||||
|
# Camera configurations
|
||||||
|
cameras: Dict[str, CameraConfig] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Motor configuration for OpenArms (7 DOF per arm)
|
||||||
|
# Maps motor names to (send_can_id, recv_can_id, motor_type)
|
||||||
|
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||||
|
# OpenArms uses 4 types of motors:
|
||||||
|
# - DM8009 (DM-J8009P-2EC) for shoulders (high torque)
|
||||||
|
# - DM4340P and DM4340 for shoulder rotation and elbow
|
||||||
|
# - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper
|
||||||
|
motor_config: Dict[str, tuple[int, int, str]] = field(default_factory=lambda: {
|
||||||
|
"joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009)
|
||||||
|
"joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009)
|
||||||
|
"joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340)
|
||||||
|
"joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340)
|
||||||
|
"joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310)
|
||||||
|
"joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310)
|
||||||
|
"joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310)
|
||||||
|
"gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310)
|
||||||
|
})
|
||||||
|
|
||||||
|
# MIT control parameters for position control (used in send_action)
|
||||||
|
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||||
|
position_kp: list[float] = field(default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 25.0])
|
||||||
|
position_kd: list[float] = field(default_factory=lambda: [3.0, 3.0, 3.0, 3.0, 0.2, 0.2, 0.2, 0.2])
|
||||||
|
|
||||||
|
# Damping gains for stability when applying torque compensation (gravity/friction)
|
||||||
|
# Used when kp=0 and only torque is applied
|
||||||
|
damping_kd: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1, 0.1])
|
||||||
|
|
||||||
|
# Friction model parameters: τ_fric(ω) = Fo + Fv·ω + Fc·tanh(k·ω)
|
||||||
|
# From OpenArms config/follower.yaml
|
||||||
|
friction_fc: list[float] = field(default_factory=lambda: [0.306, 0.306, 0.40, 0.166, 0.050, 0.093, 0.172, 0.0512]) # Coulomb friction [Nm]
|
||||||
|
friction_k: list[float] = field(default_factory=lambda: [28.417, 28.417, 29.065, 130.038, 151.771, 242.287, 7.888, 4.000]) # tanh steepness
|
||||||
|
friction_fv: list[float] = field(default_factory=lambda: [0.063, 0.0630, 0.604, 0.813, 0.029, 0.072, 0.084, 0.084]) # Viscous friction [Nm·s/rad]
|
||||||
|
friction_fo: list[float] = field(default_factory=lambda: [0.088, 0.088, 0.008, -0.058, 0.005, 0.009, -0.059, -0.050]) # Offset torque [Nm]
|
||||||
|
|
||||||
|
# Calibration parameters
|
||||||
|
calibration_mode: str = "manual" # "manual" or "auto"
|
||||||
|
zero_position_on_connect: bool = False # Set zero position on connect
|
||||||
|
|
||||||
|
# Joint limits for position clipping (degrees)
|
||||||
|
# Format: [min, max] for each joint
|
||||||
|
# These limits clip commands in send_action to prevent mechanical damage
|
||||||
|
joint_limits_right: Dict[str, tuple[float, float]] = field(default_factory=lambda: {
|
||||||
|
"joint_1": (-75.0, 75.0),
|
||||||
|
"joint_2": (-9.0, 90.0),
|
||||||
|
"joint_3": (-85.0, 85.0),
|
||||||
|
"joint_4": (0.0, 135.0),
|
||||||
|
"joint_5": (-85.0, 85.0),
|
||||||
|
"joint_6": (-40.0, 40.0),
|
||||||
|
"joint_7": (-80.0, 80.0),
|
||||||
|
"gripper": (-65.0, 0.0),
|
||||||
|
})
|
||||||
|
|
||||||
|
joint_limits_left: Dict[str, tuple[float, float]] = field(default_factory=lambda: {
|
||||||
|
"joint_1": (-75.0, 75.0),
|
||||||
|
"joint_2": (-90.0, 9.0),
|
||||||
|
"joint_3": (-85.0, 85.0),
|
||||||
|
"joint_4": (0.0, 135.0),
|
||||||
|
"joint_5": (-85.0, 85.0),
|
||||||
|
"joint_6": (-40.0, 40.0),
|
||||||
|
"joint_7": (-80.0, 80.0),
|
||||||
|
"gripper": (-65.0, 0.0),
|
||||||
|
})
|
||||||
698
src/lerobot/robots/openarms/openarms_follower.py
Normal file
698
src/lerobot/robots/openarms/openarms_follower.py
Normal file
@@ -0,0 +1,698 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pinocchio as pin
|
||||||
|
|
||||||
|
from lerobot.cameras.utils import make_cameras_from_configs
|
||||||
|
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
|
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||||
|
from lerobot.motors.damiao.tables import MotorType
|
||||||
|
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||||
|
|
||||||
|
from ..robot import Robot
|
||||||
|
from ..utils import ensure_safe_goal_position
|
||||||
|
from .config_openarms_follower import OpenArmsFollowerConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenArmsFollower(Robot):
|
||||||
|
"""
|
||||||
|
OpenArms Follower Robot which uses CAN bus communication to control 7 DOF arm with a gripper.
|
||||||
|
The arm uses Damiao motors in MIT control mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = OpenArmsFollowerConfig
|
||||||
|
name = "openarms_follower"
|
||||||
|
|
||||||
|
def __init__(self, config: OpenArmsFollowerConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
norm_mode_body = MotorNormMode.DEGREES # Always use degrees for Damiao motors
|
||||||
|
|
||||||
|
# Right arm motors (on port_right)
|
||||||
|
# Each arm uses the same CAN IDs since they're on separate buses
|
||||||
|
motors_right = {}
|
||||||
|
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
|
||||||
|
motor = Motor(send_id, motor_type_str, norm_mode_body)
|
||||||
|
motor.recv_id = recv_id
|
||||||
|
motor.motor_type = getattr(MotorType, motor_type_str.upper().replace("-", "_"))
|
||||||
|
motors_right[motor_name] = motor
|
||||||
|
|
||||||
|
# Left arm motors (on port_left, same IDs as right since separate bus)
|
||||||
|
motors_left = {}
|
||||||
|
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
|
||||||
|
motor = Motor(send_id, motor_type_str, norm_mode_body)
|
||||||
|
motor.recv_id = recv_id
|
||||||
|
motor.motor_type = getattr(MotorType, motor_type_str.upper().replace("-", "_"))
|
||||||
|
motors_left[motor_name] = motor
|
||||||
|
|
||||||
|
# Initialize separate Damiao motors buses (one per arm) with CAN FD support
|
||||||
|
self.bus_right = DamiaoMotorsBus(
|
||||||
|
port=self.config.port_right,
|
||||||
|
motors=motors_right,
|
||||||
|
calibration={k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")},
|
||||||
|
can_interface=self.config.can_interface,
|
||||||
|
use_can_fd=self.config.use_can_fd,
|
||||||
|
bitrate=self.config.can_bitrate,
|
||||||
|
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.bus_left = DamiaoMotorsBus(
|
||||||
|
port=self.config.port_left,
|
||||||
|
motors=motors_left,
|
||||||
|
calibration={k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")},
|
||||||
|
can_interface=self.config.can_interface,
|
||||||
|
use_can_fd=self.config.use_can_fd,
|
||||||
|
bitrate=self.config.can_bitrate,
|
||||||
|
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize cameras
|
||||||
|
self.cameras = make_cameras_from_configs(config.cameras)
|
||||||
|
# Cache for last valid camera frames (to avoid blocking on slow USB reads)
|
||||||
|
self.camera_frame_cache = {key: None for key in self.cameras.keys()}
|
||||||
|
|
||||||
|
# Initialize Pinocchio robot model for dynamics (optional)
|
||||||
|
self.pin_robot = None
|
||||||
|
try:
|
||||||
|
# Load URDF - try external path first (with meshes), then repository
|
||||||
|
import os
|
||||||
|
from os.path import expanduser, dirname
|
||||||
|
|
||||||
|
# Try external URDF with meshes first
|
||||||
|
external_urdf_path = expanduser("~/Documents/openarm_description/openarm_bimanual_pybullet.urdf")
|
||||||
|
if os.path.exists(external_urdf_path):
|
||||||
|
urdf_path = external_urdf_path
|
||||||
|
urdf_dir = dirname(urdf_path)
|
||||||
|
|
||||||
|
self.pin_robot = pin.RobotWrapper.BuildFromURDF(urdf_path, urdf_dir)
|
||||||
|
self.pin_robot.data = self.pin_robot.model.createData()
|
||||||
|
logger.info(f"Loaded OpenArms URDF for dynamics computation from {urdf_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not load URDF for dynamics: {e}. Gravity compensation will not be available.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _motors_ft(self) -> Dict[str, type]:
|
||||||
|
"""Motor features for observation and action spaces."""
|
||||||
|
features = {}
|
||||||
|
# Right arm motors - only positions stored in dataset
|
||||||
|
for motor in self.bus_right.motors:
|
||||||
|
features[f"right_{motor}.pos"] = float
|
||||||
|
# Left arm motors - only positions stored in dataset
|
||||||
|
for motor in self.bus_left.motors:
|
||||||
|
features[f"left_{motor}.pos"] = float
|
||||||
|
return features
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _cameras_ft(self) -> Dict[str, tuple]:
|
||||||
|
"""Camera features for observation space."""
|
||||||
|
return {
|
||||||
|
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3)
|
||||||
|
for cam in self.cameras
|
||||||
|
}
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def observation_features(self) -> Dict[str, type | tuple]:
|
||||||
|
"""Combined observation features from motors and cameras."""
|
||||||
|
return {**self._motors_ft, **self._cameras_ft}
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def action_features(self) -> Dict[str, type]:
|
||||||
|
"""Action features (motor positions only)."""
|
||||||
|
return self._motors_ft
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""Check if robot is connected."""
|
||||||
|
return (self.bus_right.is_connected and
|
||||||
|
self.bus_left.is_connected and
|
||||||
|
all(cam.is_connected for cam in self.cameras.values()))
|
||||||
|
|
||||||
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Connect to the robot and optionally calibrate.
|
||||||
|
|
||||||
|
We assume that at connection time, the arms are in a safe rest position,
|
||||||
|
and torque can be safely disabled to run calibration if needed.
|
||||||
|
"""
|
||||||
|
if self.is_connected:
|
||||||
|
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||||
|
|
||||||
|
# Connect to both CAN buses
|
||||||
|
logger.info(f"Connecting right arm on {self.config.port_right}...")
|
||||||
|
self.bus_right.connect()
|
||||||
|
logger.info(f"Connecting left arm on {self.config.port_left}...")
|
||||||
|
self.bus_left.connect()
|
||||||
|
|
||||||
|
# Run calibration if needed
|
||||||
|
if calibrate:
|
||||||
|
logger.info(
|
||||||
|
"No calibration found or overwriting calibration. Running calibration..."
|
||||||
|
)
|
||||||
|
self.calibrate()
|
||||||
|
|
||||||
|
# Connect cameras
|
||||||
|
for cam in self.cameras.values():
|
||||||
|
cam.connect()
|
||||||
|
|
||||||
|
# Configure motors
|
||||||
|
self.configure()
|
||||||
|
|
||||||
|
# Optionally set zero position
|
||||||
|
if self.config.zero_position_on_connect:
|
||||||
|
logger.info("Setting current position as zero...")
|
||||||
|
self.bus_right.set_zero_position()
|
||||||
|
self.bus_left.set_zero_position()
|
||||||
|
|
||||||
|
logger.info(f"{self} connected.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
"""Check if robot is calibrated."""
|
||||||
|
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
"""
|
||||||
|
Run calibration procedure for OpenArms robot.
|
||||||
|
|
||||||
|
The calibration procedure:
|
||||||
|
1. Disable torque
|
||||||
|
2. Ask user to position arms in hanging position with grippers closed
|
||||||
|
3. Set this as zero position
|
||||||
|
4. Record range of motion for each joint
|
||||||
|
5. Save calibration
|
||||||
|
"""
|
||||||
|
if self.calibration:
|
||||||
|
# Ask user whether to use existing calibration
|
||||||
|
user_input = input(
|
||||||
|
f"Press ENTER to use existing calibration for {self.id}, "
|
||||||
|
f"or type 'c' and press ENTER to run new calibration: "
|
||||||
|
)
|
||||||
|
if user_input.strip().lower() != "c":
|
||||||
|
logger.info(f"Using existing calibration for {self.id}")
|
||||||
|
# Split calibration for each bus
|
||||||
|
cal_right = {k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")}
|
||||||
|
cal_left = {k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")}
|
||||||
|
self.bus_right.write_calibration(cal_right)
|
||||||
|
self.bus_left.write_calibration(cal_left)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"\nRunning calibration for {self}")
|
||||||
|
|
||||||
|
# Calibrate each arm separately
|
||||||
|
self._calibrate_arm("right", self.bus_right)
|
||||||
|
self._calibrate_arm("left", self.bus_left)
|
||||||
|
|
||||||
|
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||||
|
|
||||||
|
def _calibrate_arm(self, arm_name: str, bus: DamiaoMotorsBus) -> None:
|
||||||
|
"""Calibrate a single arm."""
|
||||||
|
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
|
||||||
|
|
||||||
|
# Disable torque for manual positioning
|
||||||
|
bus.disable_torque()
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Step 1: Set zero position
|
||||||
|
input(
|
||||||
|
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
|
||||||
|
"Position the arm in the following configuration:\n"
|
||||||
|
" - Arm hanging straight down\n"
|
||||||
|
" - Gripper closed\n"
|
||||||
|
"Press ENTER when ready..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set current position as zero for all motors
|
||||||
|
bus.set_zero_position()
|
||||||
|
logger.info(f"{arm_name.capitalize()} arm zero position set.")
|
||||||
|
|
||||||
|
# Automatically set range to -90° to +90° for all joints
|
||||||
|
print(
|
||||||
|
f"\nAutomatically setting range: -90° to +90° for all joints"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create calibration data with fixed ranges
|
||||||
|
if self.calibration is None:
|
||||||
|
self.calibration = {}
|
||||||
|
|
||||||
|
for motor_name, motor in bus.motors.items():
|
||||||
|
# Prefix motor name with arm name for storage
|
||||||
|
prefixed_name = f"{arm_name}_{motor_name}"
|
||||||
|
|
||||||
|
# Use -90 to +90 for all joints and gripper (integers required)
|
||||||
|
self.calibration[prefixed_name] = MotorCalibration(
|
||||||
|
id=motor.id,
|
||||||
|
drive_mode=0, # Normal direction
|
||||||
|
homing_offset=0, # Already set via set_zero_position
|
||||||
|
range_min=-90, # -90 degrees (integer)
|
||||||
|
range_max=90, # +90 degrees (integer)
|
||||||
|
)
|
||||||
|
logger.info(f" {prefixed_name}: range set to [-90°, +90°]")
|
||||||
|
|
||||||
|
# Write calibration to this arm's motors
|
||||||
|
cal_for_bus = {k.replace(f"{arm_name}_", ""): v for k, v in self.calibration.items() if k.startswith(f"{arm_name}_")}
|
||||||
|
bus.write_calibration(cal_for_bus)
|
||||||
|
|
||||||
|
# Re-enable torque
|
||||||
|
bus.enable_torque()
|
||||||
|
|
||||||
|
# Save calibration after each arm
|
||||||
|
self._save_calibration()
|
||||||
|
|
||||||
|
def configure(self) -> None:
|
||||||
|
"""Configure motors with appropriate settings."""
|
||||||
|
# Configure right arm
|
||||||
|
with self.bus_right.torque_disabled():
|
||||||
|
self.bus_right.configure_motors()
|
||||||
|
|
||||||
|
# Configure left arm
|
||||||
|
with self.bus_left.torque_disabled():
|
||||||
|
self.bus_left.configure_motors()
|
||||||
|
|
||||||
|
def setup_motors(self) -> None:
|
||||||
|
raise NotImplementedError("Motor ID configuration is typically done via manufacturer tools for CAN motors.")
|
||||||
|
|
||||||
|
def get_observation(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get current observation from robot including position, velocity, and torque.
|
||||||
|
|
||||||
|
OPTIMIZED: Reads all motor states (pos/vel/torque) in one CAN refresh cycle
|
||||||
|
instead of 3 separate reads.
|
||||||
|
|
||||||
|
Note: Velocity and torque are read but not stored in dataset (only used for
|
||||||
|
internal calculations). Only positions and camera images are stored.
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|
||||||
|
obs_dict = {}
|
||||||
|
|
||||||
|
# Detailed profiling for bottleneck analysis
|
||||||
|
timings = {}
|
||||||
|
|
||||||
|
# OPTIMIZED: Use sync_read_all_states to get pos/vel/torque in one go
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
right_states = self.bus_right.sync_read_all_states()
|
||||||
|
timings["right_motors"] = (time.perf_counter() - t0) * 1000
|
||||||
|
|
||||||
|
for motor in self.bus_right.motors:
|
||||||
|
state = right_states.get(motor, {})
|
||||||
|
obs_dict[f"right_{motor}.pos"] = state.get("position", 0.0)
|
||||||
|
obs_dict[f"right_{motor}.vel"] = state.get("velocity", 0.0)
|
||||||
|
obs_dict[f"right_{motor}.torque"] = state.get("torque", 0.0)
|
||||||
|
|
||||||
|
# OPTIMIZED: Use sync_read_all_states to get pos/vel/torque in one go
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
left_states = self.bus_left.sync_read_all_states()
|
||||||
|
timings["left_motors"] = (time.perf_counter() - t0) * 1000
|
||||||
|
|
||||||
|
for motor in self.bus_left.motors:
|
||||||
|
state = left_states.get(motor, {})
|
||||||
|
obs_dict[f"left_{motor}.pos"] = state.get("position", 0.0)
|
||||||
|
obs_dict[f"left_{motor}.vel"] = state.get("velocity", 0.0)
|
||||||
|
obs_dict[f"left_{motor}.torque"] = state.get("torque", 0.0)
|
||||||
|
|
||||||
|
# Capture images from cameras (with individual timing)
|
||||||
|
# Use async_read with very short timeout to avoid blocking on slow USB cameras
|
||||||
|
for cam_key, cam in self.cameras.items():
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
try:
|
||||||
|
# Use 5ms timeout - if frame isn't ready, reuse last frame
|
||||||
|
frame = cam.async_read(timeout_ms=5)
|
||||||
|
self.camera_frame_cache[cam_key] = frame # Update cache
|
||||||
|
obs_dict[cam_key] = frame
|
||||||
|
except TimeoutError:
|
||||||
|
# If no new frame available, reuse last valid frame from cache
|
||||||
|
# This prevents blocking the entire control loop on slow USB reads
|
||||||
|
if self.camera_frame_cache[cam_key] is not None:
|
||||||
|
obs_dict[cam_key] = self.camera_frame_cache[cam_key]
|
||||||
|
logger.debug(f"Camera {cam_key} timeout, reusing cached frame")
|
||||||
|
|
||||||
|
# Store timing with padded name to align output (e.g. "left_wrist ")
|
||||||
|
timings[f"{cam_key:14s}"] = (time.perf_counter() - t0) * 1000
|
||||||
|
|
||||||
|
# Log detailed timings (for debugging slow observations)
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
total_time = sum(timings.values())
|
||||||
|
breakdown = " | ".join([f"{k}: {v:.1f}ms" for k, v in timings.items()])
|
||||||
|
logger.debug(f"{self} get_observation: {total_time:.1f}ms total | {breakdown}")
|
||||||
|
|
||||||
|
# Store timings in obs_dict for external profiling
|
||||||
|
obs_dict["_timing_breakdown"] = timings
|
||||||
|
|
||||||
|
return obs_dict
|
||||||
|
|
||||||
|
def send_action(
|
||||||
|
self,
|
||||||
|
action: Dict[str, Any],
|
||||||
|
custom_kp: Optional[Dict[str, float]] = None,
|
||||||
|
custom_kd: Optional[Dict[str, float]] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Send action command to robot.
|
||||||
|
|
||||||
|
The action magnitude may be clipped based on safety limits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: Dictionary with motor positions (e.g., "right_joint_1.pos", "left_joint_2.pos")
|
||||||
|
custom_kp: Optional custom kp gains per motor (e.g., {"right_joint_1": 120.0, "left_joint_2": 150.0})
|
||||||
|
custom_kd: Optional custom kd gains per motor (e.g., {"right_joint_1": 1.5, "left_joint_2": 2.0})
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The action actually sent (potentially clipped)
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|
||||||
|
# Extract motor positions from action and split by arm
|
||||||
|
goal_pos_right = {}
|
||||||
|
goal_pos_left = {}
|
||||||
|
|
||||||
|
for key, val in action.items():
|
||||||
|
if key.endswith(".pos"):
|
||||||
|
motor_name = key.removesuffix(".pos")
|
||||||
|
if motor_name.startswith("right_"):
|
||||||
|
# Remove "right_" prefix for bus access
|
||||||
|
goal_pos_right[motor_name.removeprefix("right_")] = val
|
||||||
|
elif motor_name.startswith("left_"):
|
||||||
|
# Remove "left_" prefix for bus access
|
||||||
|
goal_pos_left[motor_name.removeprefix("left_")] = val
|
||||||
|
|
||||||
|
# Apply joint limit clipping to right arm
|
||||||
|
for motor_name, position in goal_pos_right.items():
|
||||||
|
if motor_name in self.config.joint_limits_right:
|
||||||
|
min_limit, max_limit = self.config.joint_limits_right[motor_name]
|
||||||
|
clipped_position = max(min_limit, min(max_limit, position))
|
||||||
|
if clipped_position != position:
|
||||||
|
logger.debug(f"Clipped right_{motor_name} from {position:.2f}° to {clipped_position:.2f}°")
|
||||||
|
goal_pos_right[motor_name] = clipped_position
|
||||||
|
|
||||||
|
# Apply joint limit clipping to left arm
|
||||||
|
for motor_name, position in goal_pos_left.items():
|
||||||
|
if motor_name in self.config.joint_limits_left:
|
||||||
|
min_limit, max_limit = self.config.joint_limits_left[motor_name]
|
||||||
|
clipped_position = max(min_limit, min(max_limit, position))
|
||||||
|
if clipped_position != position:
|
||||||
|
logger.debug(f"Clipped left_{motor_name} from {position:.2f}° to {clipped_position:.2f}°")
|
||||||
|
goal_pos_left[motor_name] = clipped_position
|
||||||
|
|
||||||
|
# Apply safety limits if configured
|
||||||
|
if self.config.max_relative_target is not None:
|
||||||
|
# Get current positions
|
||||||
|
present_pos_right = self.bus_right.sync_read("Present_Position")
|
||||||
|
present_pos_left = self.bus_left.sync_read("Present_Position")
|
||||||
|
|
||||||
|
# Apply safety limits to right arm
|
||||||
|
if goal_pos_right:
|
||||||
|
goal_present_pos_right = {
|
||||||
|
key: (g_pos, present_pos_right.get(key, 0.0))
|
||||||
|
for key, g_pos in goal_pos_right.items()
|
||||||
|
}
|
||||||
|
goal_pos_right = ensure_safe_goal_position(
|
||||||
|
goal_present_pos_right,
|
||||||
|
self.config.max_relative_target
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply safety limits to left arm
|
||||||
|
if goal_pos_left:
|
||||||
|
goal_present_pos_left = {
|
||||||
|
key: (g_pos, present_pos_left.get(key, 0.0))
|
||||||
|
for key, g_pos in goal_pos_left.items()
|
||||||
|
}
|
||||||
|
goal_pos_left = ensure_safe_goal_position(
|
||||||
|
goal_present_pos_left,
|
||||||
|
self.config.max_relative_target
|
||||||
|
)
|
||||||
|
|
||||||
|
# Motor name to index mapping for gains
|
||||||
|
motor_index = {
|
||||||
|
"joint_1": 0,
|
||||||
|
"joint_2": 1,
|
||||||
|
"joint_3": 2,
|
||||||
|
"joint_4": 3,
|
||||||
|
"joint_5": 4,
|
||||||
|
"joint_6": 5,
|
||||||
|
"joint_7": 6,
|
||||||
|
"gripper": 7,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use batch MIT control for right arm (sends all commands, then collects responses)
|
||||||
|
if goal_pos_right:
|
||||||
|
commands_right = {}
|
||||||
|
for motor_name, position_degrees in goal_pos_right.items():
|
||||||
|
idx = motor_index.get(motor_name, 0)
|
||||||
|
|
||||||
|
# Use custom gains if provided, otherwise use config defaults
|
||||||
|
full_motor_name = f"right_{motor_name}"
|
||||||
|
if custom_kp is not None and full_motor_name in custom_kp:
|
||||||
|
kp = custom_kp[full_motor_name]
|
||||||
|
else:
|
||||||
|
kp = self.config.position_kp[idx] if isinstance(self.config.position_kp, list) else self.config.position_kp
|
||||||
|
|
||||||
|
if custom_kd is not None and full_motor_name in custom_kd:
|
||||||
|
kd = custom_kd[full_motor_name]
|
||||||
|
else:
|
||||||
|
kd = self.config.position_kd[idx] if isinstance(self.config.position_kd, list) else self.config.position_kd
|
||||||
|
|
||||||
|
commands_right[motor_name] = (kp, kd, position_degrees, 0.0, 0.0)
|
||||||
|
self.bus_right._mit_control_batch(commands_right)
|
||||||
|
|
||||||
|
# Use batch MIT control for left arm (sends all commands, then collects responses)
|
||||||
|
if goal_pos_left:
|
||||||
|
commands_left = {}
|
||||||
|
for motor_name, position_degrees in goal_pos_left.items():
|
||||||
|
idx = motor_index.get(motor_name, 0)
|
||||||
|
|
||||||
|
# Use custom gains if provided, otherwise use config defaults
|
||||||
|
full_motor_name = f"left_{motor_name}"
|
||||||
|
if custom_kp is not None and full_motor_name in custom_kp:
|
||||||
|
kp = custom_kp[full_motor_name]
|
||||||
|
else:
|
||||||
|
kp = self.config.position_kp[idx] if isinstance(self.config.position_kp, list) else self.config.position_kp
|
||||||
|
|
||||||
|
if custom_kd is not None and full_motor_name in custom_kd:
|
||||||
|
kd = custom_kd[full_motor_name]
|
||||||
|
else:
|
||||||
|
kd = self.config.position_kd[idx] if isinstance(self.config.position_kd, list) else self.config.position_kd
|
||||||
|
|
||||||
|
commands_left[motor_name] = (kp, kd, position_degrees, 0.0, 0.0)
|
||||||
|
self.bus_left._mit_control_batch(commands_left)
|
||||||
|
|
||||||
|
# Return the actions that were actually sent
|
||||||
|
result = {}
|
||||||
|
for motor, val in goal_pos_right.items():
|
||||||
|
result[f"right_{motor}.pos"] = val
|
||||||
|
for motor, val in goal_pos_left.items():
|
||||||
|
result[f"left_{motor}.pos"] = val
|
||||||
|
return result
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
"""Disconnect from robot."""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||||
|
|
||||||
|
# Disconnect from CAN buses
|
||||||
|
self.bus_right.disconnect(self.config.disable_torque_on_disconnect)
|
||||||
|
self.bus_left.disconnect(self.config.disable_torque_on_disconnect)
|
||||||
|
|
||||||
|
# Disconnect cameras
|
||||||
|
for cam in self.cameras.values():
|
||||||
|
cam.disconnect()
|
||||||
|
|
||||||
|
logger.info(f"{self} disconnected.")
|
||||||
|
|
||||||
|
def _deg_to_rad(self, deg: Dict[str, float | int]) -> Dict[str, float]:
|
||||||
|
"""Convert degrees to radians for all motors."""
|
||||||
|
return {m: np.deg2rad(float(v)) for m, v in deg.items()}
|
||||||
|
|
||||||
|
def _gravity_from_q(self, q_rad: Dict[str, float]) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Compute g(q) [N·m] for all joints in the robot.
|
||||||
|
The order of joints in the URDF matches the concatenated motor lists (right then left).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q_rad: Dictionary mapping motor names (with arm prefix) to positions in radians
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping motor names to gravity torques in N·m
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If URDF model is not loaded
|
||||||
|
"""
|
||||||
|
if self.pin_robot is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot compute gravity: URDF model not loaded. "
|
||||||
|
"Ensure urdf/openarms.urdf exists and is valid."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build position vector in the order of motors (left arm, then right arm)
|
||||||
|
# This order must match the URDF joint order
|
||||||
|
# URDF has: left_joint1-7, left_finger_joint1-2, right_joint1-7, right_finger_joint1-2
|
||||||
|
q = np.zeros(self.pin_robot.model.nq)
|
||||||
|
idx = 0
|
||||||
|
|
||||||
|
# Left arm motors (first in URDF) - joints 1-7
|
||||||
|
for motor_name in self.bus_left.motors:
|
||||||
|
if motor_name == "gripper":
|
||||||
|
continue # Skip gripper, will be handled separately
|
||||||
|
full_name = f"left_{motor_name}"
|
||||||
|
q[idx] = q_rad.get(full_name, 0.0)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
# Skip left finger joints (leave as zeros)
|
||||||
|
idx += 2
|
||||||
|
|
||||||
|
# Right arm motors (second in URDF) - joints 1-7
|
||||||
|
for motor_name in self.bus_right.motors:
|
||||||
|
if motor_name == "gripper":
|
||||||
|
continue # Skip gripper, will be handled separately
|
||||||
|
full_name = f"right_{motor_name}"
|
||||||
|
q[idx] = q_rad.get(full_name, 0.0)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
# Skip right finger joints (leave as zeros)
|
||||||
|
idx += 2
|
||||||
|
|
||||||
|
# Compute generalized gravity vector
|
||||||
|
g = pin.computeGeneralizedGravity(self.pin_robot.model, self.pin_robot.data, q)
|
||||||
|
|
||||||
|
# Map back to motor names (only arm joints, not fingers)
|
||||||
|
result = {}
|
||||||
|
idx = 0
|
||||||
|
|
||||||
|
# Left arm torques (joints 1-7)
|
||||||
|
for motor_name in self.bus_left.motors:
|
||||||
|
if motor_name == "gripper":
|
||||||
|
result["left_gripper"] = 0.0 # No gravity compensation for gripper
|
||||||
|
continue
|
||||||
|
result[f"left_{motor_name}"] = float(g[idx])
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
# Skip left finger joint torques in output
|
||||||
|
idx += 2
|
||||||
|
|
||||||
|
# Right arm torques (joints 1-7)
|
||||||
|
for motor_name in self.bus_right.motors:
|
||||||
|
if motor_name == "gripper":
|
||||||
|
result["right_gripper"] = 0.0 # No gravity compensation for gripper
|
||||||
|
continue
|
||||||
|
result[f"right_{motor_name}"] = float(g[idx])
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
# Skip right finger joint torques in output
|
||||||
|
idx += 2
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _friction_from_velocity(
|
||||||
|
self,
|
||||||
|
velocity_rad_per_sec: Dict[str, float],
|
||||||
|
friction_scale: float = 1.0,
|
||||||
|
amp_tmp: float = 1.0,
|
||||||
|
coef_tmp: float = 0.1
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Compute friction torques for all joints in the robot using tanh friction model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
velocity_rad_per_sec: Dictionary mapping motor names (with arm prefix) to velocities in rad/s
|
||||||
|
friction_scale: Scale factor for friction compensation (default 1.0, use 0.3 for stability)
|
||||||
|
amp_tmp: Amplitude factor for tanh term (default 1.0)
|
||||||
|
coef_tmp: Coefficient for tanh steepness (default 0.1)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping motor names to friction torques in N·m
|
||||||
|
"""
|
||||||
|
# Motor name to index mapping
|
||||||
|
motor_name_to_index = {
|
||||||
|
"joint_1": 0,
|
||||||
|
"joint_2": 1,
|
||||||
|
"joint_3": 2,
|
||||||
|
"joint_4": 3,
|
||||||
|
"joint_5": 4,
|
||||||
|
"joint_6": 5,
|
||||||
|
"joint_7": 6,
|
||||||
|
"gripper": 7,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# Process all motors (left and right)
|
||||||
|
for motor_full_name, velocity in velocity_rad_per_sec.items():
|
||||||
|
# Extract motor name without arm prefix
|
||||||
|
if motor_full_name.startswith("right_"):
|
||||||
|
motor_name = motor_full_name.removeprefix("right_")
|
||||||
|
elif motor_full_name.startswith("left_"):
|
||||||
|
motor_name = motor_full_name.removeprefix("left_")
|
||||||
|
else:
|
||||||
|
result[motor_full_name] = 0.0
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get motor index for friction parameters
|
||||||
|
motor_index = motor_name_to_index.get(motor_name, 0)
|
||||||
|
|
||||||
|
# Get friction parameters from config
|
||||||
|
Fc = self.config.friction_fc[motor_index]
|
||||||
|
k = self.config.friction_k[motor_index]
|
||||||
|
Fv = self.config.friction_fv[motor_index]
|
||||||
|
Fo = self.config.friction_fo[motor_index]
|
||||||
|
|
||||||
|
# Friction model: τ_fric = amp * Fc * tanh(coef * k * ω) + Fv * ω + Fo
|
||||||
|
friction_torque = (
|
||||||
|
amp_tmp * Fc * np.tanh(coef_tmp * k * velocity) +
|
||||||
|
Fv * velocity +
|
||||||
|
Fo
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply scale factor
|
||||||
|
friction_torque *= friction_scale
|
||||||
|
|
||||||
|
result[motor_full_name] = float(friction_torque)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_damping_kd(self, motor_name: str) -> float:
|
||||||
|
"""
|
||||||
|
Get damping gain (Kd) for a specific motor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
motor_name: Motor name without arm prefix (e.g., "joint_1", "gripper")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Damping gain value
|
||||||
|
"""
|
||||||
|
motor_name_to_index = {
|
||||||
|
"joint_1": 0,
|
||||||
|
"joint_2": 1,
|
||||||
|
"joint_3": 2,
|
||||||
|
"joint_4": 3,
|
||||||
|
"joint_5": 4,
|
||||||
|
"joint_6": 5,
|
||||||
|
"joint_7": 6,
|
||||||
|
"gripper": 7,
|
||||||
|
}
|
||||||
|
|
||||||
|
motor_index = motor_name_to_index.get(motor_name, 0)
|
||||||
|
return self.config.damping_kd[motor_index]
|
||||||
|
|
||||||
18
src/lerobot/robots/unitree_g1/__init__.py
Normal file
18
src/lerobot/robots/unitree_g1/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .config_unitree_g1 import UnitreeG1Config
|
||||||
|
from .unitree_g1 import UnitreeG1
|
||||||
58
src/lerobot/robots/unitree_g1/config_unitree_g1.py
Normal file
58
src/lerobot/robots/unitree_g1/config_unitree_g1.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from ..config import RobotConfig
|
||||||
|
|
||||||
|
_GAINS: dict[str, dict[str, list[float]]] = {
|
||||||
|
"left_leg": {
|
||||||
|
"kp": [150, 150, 150, 300, 40, 40],
|
||||||
|
"kd": [2, 2, 2, 4, 2, 2],
|
||||||
|
}, # pitch, roll, yaw, knee, ankle_pitch, ankle_roll
|
||||||
|
"right_leg": {"kp": [150, 150, 150, 300, 40, 40], "kd": [2, 2, 2, 4, 2, 2]},
|
||||||
|
"waist": {"kp": [250, 250, 250], "kd": [5, 5, 5]}, # yaw, roll, pitch
|
||||||
|
"left_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow
|
||||||
|
"left_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]}, # roll, pitch, yaw
|
||||||
|
"right_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]},
|
||||||
|
"right_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]},
|
||||||
|
"other": {"kp": [80, 80, 80, 80, 80, 80], "kd": [3, 3, 3, 3, 3, 3]},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gains() -> tuple[list[float], list[float]]:
|
||||||
|
"""Build kp and kd lists from body-part groupings."""
|
||||||
|
kp = [v for g in _GAINS.values() for v in g["kp"]]
|
||||||
|
kd = [v for g in _GAINS.values() for v in g["kd"]]
|
||||||
|
return kp, kd
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_KP, _DEFAULT_KD = _build_gains()
|
||||||
|
|
||||||
|
|
||||||
|
@RobotConfig.register_subclass("unitree_g1")
|
||||||
|
@dataclass
|
||||||
|
class UnitreeG1Config(RobotConfig):
|
||||||
|
kp: list[float] = field(default_factory=lambda: _DEFAULT_KP.copy())
|
||||||
|
kd: list[float] = field(default_factory=lambda: _DEFAULT_KD.copy())
|
||||||
|
|
||||||
|
control_dt: float = 1.0 / 250.0 # 250Hz
|
||||||
|
|
||||||
|
# launch mujoco simulation
|
||||||
|
is_simulation: bool = True
|
||||||
|
|
||||||
|
# socket config for ZMQ bridge
|
||||||
|
robot_ip: str = "192.168.123.164"
|
||||||
89
src/lerobot/robots/unitree_g1/g1_utils.py
Normal file
89
src/lerobot/robots/unitree_g1/g1_utils.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
# ruff: noqa: N801, N815
|
||||||
|
|
||||||
|
NUM_MOTORS = 35
|
||||||
|
|
||||||
|
|
||||||
|
class G1_29_JointArmIndex(IntEnum):
|
||||||
|
# Left arm
|
||||||
|
kLeftShoulderPitch = 15
|
||||||
|
kLeftShoulderRoll = 16
|
||||||
|
kLeftShoulderYaw = 17
|
||||||
|
kLeftElbow = 18
|
||||||
|
kLeftWristRoll = 19
|
||||||
|
kLeftWristPitch = 20
|
||||||
|
kLeftWristyaw = 21
|
||||||
|
|
||||||
|
# Right arm
|
||||||
|
kRightShoulderPitch = 22
|
||||||
|
kRightShoulderRoll = 23
|
||||||
|
kRightShoulderYaw = 24
|
||||||
|
kRightElbow = 25
|
||||||
|
kRightWristRoll = 26
|
||||||
|
kRightWristPitch = 27
|
||||||
|
kRightWristYaw = 28
|
||||||
|
|
||||||
|
|
||||||
|
class G1_29_JointIndex(IntEnum):
|
||||||
|
# Left leg
|
||||||
|
kLeftHipPitch = 0
|
||||||
|
kLeftHipRoll = 1
|
||||||
|
kLeftHipYaw = 2
|
||||||
|
kLeftKnee = 3
|
||||||
|
kLeftAnklePitch = 4
|
||||||
|
kLeftAnkleRoll = 5
|
||||||
|
|
||||||
|
# Right leg
|
||||||
|
kRightHipPitch = 6
|
||||||
|
kRightHipRoll = 7
|
||||||
|
kRightHipYaw = 8
|
||||||
|
kRightKnee = 9
|
||||||
|
kRightAnklePitch = 10
|
||||||
|
kRightAnkleRoll = 11
|
||||||
|
|
||||||
|
kWaistYaw = 12
|
||||||
|
kWaistRoll = 13
|
||||||
|
kWaistPitch = 14
|
||||||
|
|
||||||
|
# Left arm
|
||||||
|
kLeftShoulderPitch = 15
|
||||||
|
kLeftShoulderRoll = 16
|
||||||
|
kLeftShoulderYaw = 17
|
||||||
|
kLeftElbow = 18
|
||||||
|
kLeftWristRoll = 19
|
||||||
|
kLeftWristPitch = 20
|
||||||
|
kLeftWristyaw = 21
|
||||||
|
|
||||||
|
# Right arm
|
||||||
|
kRightShoulderPitch = 22
|
||||||
|
kRightShoulderRoll = 23
|
||||||
|
kRightShoulderYaw = 24
|
||||||
|
kRightElbow = 25
|
||||||
|
kRightWristRoll = 26
|
||||||
|
kRightWristPitch = 27
|
||||||
|
kRightWristYaw = 28
|
||||||
|
|
||||||
|
# not used
|
||||||
|
kNotUsedJoint0 = 29
|
||||||
|
kNotUsedJoint1 = 30
|
||||||
|
kNotUsedJoint2 = 31
|
||||||
|
kNotUsedJoint3 = 32
|
||||||
|
kNotUsedJoint4 = 33
|
||||||
|
kNotUsedJoint5 = 34
|
||||||
212
src/lerobot/robots/unitree_g1/run_g1_server.py
Normal file
212
src/lerobot/robots/unitree_g1/run_g1_server.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
DDS-to-ZMQ bridge server for Unitree G1 robot.
|
||||||
|
|
||||||
|
This server runs on the robot and forwards:
|
||||||
|
- Robot state (LowState) from DDS to ZMQ (for remote clients)
|
||||||
|
- Robot commands (LowCmd) from ZMQ to DDS (from remote clients)
|
||||||
|
|
||||||
|
Uses JSON for secure serialization instead of pickle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import contextlib
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import MotionSwitcherClient
|
||||||
|
from unitree_sdk2py.core.channel import ChannelFactoryInitialize, ChannelPublisher, ChannelSubscriber
|
||||||
|
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||||
|
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
|
||||||
|
from unitree_sdk2py.utils.crc import CRC
|
||||||
|
|
||||||
|
# DDS topic names follow Unitree SDK naming conventions
|
||||||
|
# ruff: noqa: N816
|
||||||
|
kTopicLowCommand_Debug = "rt/lowcmd" # action to robot
|
||||||
|
kTopicLowState = "rt/lowstate" # observation from robot
|
||||||
|
|
||||||
|
LOWCMD_PORT = 6000
|
||||||
|
LOWSTATE_PORT = 6001
|
||||||
|
NUM_MOTORS = 35
|
||||||
|
|
||||||
|
|
||||||
|
def lowstate_to_dict(msg: hg_LowState) -> dict[str, Any]:
|
||||||
|
"""Convert LowState SDK message to a JSON-serializable dictionary."""
|
||||||
|
motor_states = []
|
||||||
|
for i in range(NUM_MOTORS):
|
||||||
|
temp = msg.motor_state[i].temperature
|
||||||
|
avg_temp = float(sum(temp) / len(temp)) if isinstance(temp, list) else float(temp)
|
||||||
|
motor_states.append(
|
||||||
|
{
|
||||||
|
"q": float(msg.motor_state[i].q),
|
||||||
|
"dq": float(msg.motor_state[i].dq),
|
||||||
|
"tau_est": float(msg.motor_state[i].tau_est),
|
||||||
|
"temperature": avg_temp,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"motor_state": motor_states,
|
||||||
|
"imu_state": {
|
||||||
|
"quaternion": [float(x) for x in msg.imu_state.quaternion],
|
||||||
|
"gyroscope": [float(x) for x in msg.imu_state.gyroscope],
|
||||||
|
"accelerometer": [float(x) for x in msg.imu_state.accelerometer],
|
||||||
|
"rpy": [float(x) for x in msg.imu_state.rpy],
|
||||||
|
"temperature": float(msg.imu_state.temperature),
|
||||||
|
},
|
||||||
|
# Encode bytes as base64 for JSON compatibility
|
||||||
|
"wireless_remote": base64.b64encode(bytes(msg.wireless_remote)).decode("ascii"),
|
||||||
|
"mode_machine": int(msg.mode_machine),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def dict_to_lowcmd(data: dict[str, Any]) -> hg_LowCmd:
|
||||||
|
"""Convert dictionary back to LowCmd SDK message."""
|
||||||
|
cmd = unitree_hg_msg_dds__LowCmd_()
|
||||||
|
cmd.mode_pr = data.get("mode_pr", 0)
|
||||||
|
cmd.mode_machine = data.get("mode_machine", 0)
|
||||||
|
|
||||||
|
for i, motor_data in enumerate(data.get("motor_cmd", [])):
|
||||||
|
cmd.motor_cmd[i].mode = motor_data.get("mode", 0)
|
||||||
|
cmd.motor_cmd[i].q = motor_data.get("q", 0.0)
|
||||||
|
cmd.motor_cmd[i].dq = motor_data.get("dq", 0.0)
|
||||||
|
cmd.motor_cmd[i].kp = motor_data.get("kp", 0.0)
|
||||||
|
cmd.motor_cmd[i].kd = motor_data.get("kd", 0.0)
|
||||||
|
cmd.motor_cmd[i].tau = motor_data.get("tau", 0.0)
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def state_forward_loop(
|
||||||
|
lowstate_sub: ChannelSubscriber,
|
||||||
|
lowstate_sock: zmq.Socket,
|
||||||
|
state_period: float,
|
||||||
|
shutdown_event: threading.Event,
|
||||||
|
) -> None:
|
||||||
|
"""Read observation from DDS and forward to ZMQ clients."""
|
||||||
|
last_state_time = 0.0
|
||||||
|
|
||||||
|
while not shutdown_event.is_set():
|
||||||
|
# read from DDS
|
||||||
|
msg = lowstate_sub.Read()
|
||||||
|
if msg is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
# optional downsampling (if robot dds rate > state_period)
|
||||||
|
if now - last_state_time >= state_period:
|
||||||
|
# Convert to dict and serialize with JSON
|
||||||
|
state_dict = lowstate_to_dict(msg)
|
||||||
|
payload = json.dumps({"topic": kTopicLowState, "data": state_dict}).encode("utf-8")
|
||||||
|
# if no subscribers / tx buffer full, just drop
|
||||||
|
with contextlib.suppress(zmq.Again):
|
||||||
|
lowstate_sock.send(payload, zmq.NOBLOCK)
|
||||||
|
last_state_time = now
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_forward_loop(
|
||||||
|
lowcmd_sock: zmq.Socket,
|
||||||
|
lowcmd_pub_debug: ChannelPublisher,
|
||||||
|
crc: CRC,
|
||||||
|
) -> None:
|
||||||
|
"""Receive commands from ZMQ and forward to DDS."""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
payload = lowcmd_sock.recv()
|
||||||
|
except zmq.ContextTerminated:
|
||||||
|
break
|
||||||
|
msg_dict = json.loads(payload.decode("utf-8"))
|
||||||
|
|
||||||
|
topic = msg_dict.get("topic", "")
|
||||||
|
cmd_data = msg_dict.get("data", {})
|
||||||
|
|
||||||
|
# Reconstruct LowCmd object from dict
|
||||||
|
cmd = dict_to_lowcmd(cmd_data)
|
||||||
|
|
||||||
|
# recompute crc
|
||||||
|
cmd.crc = crc.Crc(cmd)
|
||||||
|
|
||||||
|
if topic == kTopicLowCommand_Debug:
|
||||||
|
lowcmd_pub_debug.Write(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Main entry point for the robot server bridge."""
|
||||||
|
# initialize DDS
|
||||||
|
ChannelFactoryInitialize(0)
|
||||||
|
|
||||||
|
# stop all active publishers on the robot
|
||||||
|
msc = MotionSwitcherClient()
|
||||||
|
msc.SetTimeout(5.0)
|
||||||
|
msc.Init()
|
||||||
|
|
||||||
|
status, result = msc.CheckMode()
|
||||||
|
while result is not None and "name" in result and result["name"]:
|
||||||
|
msc.ReleaseMode()
|
||||||
|
status, result = msc.CheckMode()
|
||||||
|
time.sleep(1.0)
|
||||||
|
|
||||||
|
crc = CRC()
|
||||||
|
|
||||||
|
# initialize DDS publisher
|
||||||
|
lowcmd_pub_debug = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||||
|
lowcmd_pub_debug.Init()
|
||||||
|
|
||||||
|
# initialize DDS subscriber
|
||||||
|
lowstate_sub = ChannelSubscriber(kTopicLowState, hg_LowState)
|
||||||
|
lowstate_sub.Init()
|
||||||
|
|
||||||
|
# initialize ZMQ
|
||||||
|
ctx = zmq.Context.instance()
|
||||||
|
|
||||||
|
# receive commands from remote client
|
||||||
|
lowcmd_sock = ctx.socket(zmq.PULL)
|
||||||
|
lowcmd_sock.bind(f"tcp://0.0.0.0:{LOWCMD_PORT}")
|
||||||
|
|
||||||
|
# publish state to remote clients
|
||||||
|
lowstate_sock = ctx.socket(zmq.PUB)
|
||||||
|
lowstate_sock.bind(f"tcp://0.0.0.0:{LOWSTATE_PORT}")
|
||||||
|
|
||||||
|
state_period = 0.002 # ~500 hz
|
||||||
|
shutdown_event = threading.Event()
|
||||||
|
|
||||||
|
# start observation forwarding in background thread
|
||||||
|
t_state = threading.Thread(
|
||||||
|
target=state_forward_loop,
|
||||||
|
args=(lowstate_sub, lowstate_sock, state_period, shutdown_event),
|
||||||
|
)
|
||||||
|
t_state.start()
|
||||||
|
|
||||||
|
print("bridge running (lowstate -> zmq, lowcmd -> dds)")
|
||||||
|
|
||||||
|
# run command forwarding in main thread
|
||||||
|
try:
|
||||||
|
cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("shutting down bridge...")
|
||||||
|
finally:
|
||||||
|
shutdown_event.set()
|
||||||
|
ctx.term() # terminates blocking zmq.recv() calls
|
||||||
|
t_state.join(timeout=2.0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
284
src/lerobot/robots/unitree_g1/unitree_g1.py
Normal file
284
src/lerobot/robots/unitree_g1/unitree_g1.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import struct
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||||
|
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
|
||||||
|
LowCmd_ as hg_LowCmd,
|
||||||
|
LowState_ as hg_LowState,
|
||||||
|
)
|
||||||
|
from unitree_sdk2py.utils.crc import CRC
|
||||||
|
|
||||||
|
from lerobot.envs.factory import make_env
|
||||||
|
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||||
|
|
||||||
|
from ..robot import Robot
|
||||||
|
from .config_unitree_g1 import UnitreeG1Config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# DDS topic names follow Unitree SDK naming conventions
|
||||||
|
# ruff: noqa: N816
|
||||||
|
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||||
|
kTopicLowState = "rt/lowstate"
|
||||||
|
|
||||||
|
G1_29_Num_Motors = 35
|
||||||
|
G1_23_Num_Motors = 35
|
||||||
|
H1_2_Num_Motors = 35
|
||||||
|
H1_Num_Motors = 20
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MotorState:
|
||||||
|
q: float | None = None # position
|
||||||
|
dq: float | None = None # velocity
|
||||||
|
tau_est: float | None = None # estimated torque
|
||||||
|
temperature: float | None = None # motor temperature
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class IMUState:
|
||||||
|
quaternion: np.ndarray | None = None # [w, x, y, z]
|
||||||
|
gyroscope: np.ndarray | None = None # [x, y, z] angular velocity (rad/s)
|
||||||
|
accelerometer: np.ndarray | None = None # [x, y, z] linear acceleration (m/s²)
|
||||||
|
rpy: np.ndarray | None = None # [roll, pitch, yaw] (rad)
|
||||||
|
temperature: float | None = None # IMU temperature
|
||||||
|
|
||||||
|
|
||||||
|
# g1 observation class
|
||||||
|
@dataclass
|
||||||
|
class G1_29_LowState: # noqa: N801
|
||||||
|
motor_state: list[MotorState] = field(
|
||||||
|
default_factory=lambda: [MotorState() for _ in range(G1_29_Num_Motors)]
|
||||||
|
)
|
||||||
|
imu_state: IMUState = field(default_factory=IMUState)
|
||||||
|
wireless_remote: Any = None # Raw wireless remote data
|
||||||
|
mode_machine: int = 0 # Robot mode
|
||||||
|
|
||||||
|
|
||||||
|
class DataBuffer:
|
||||||
|
def __init__(self):
|
||||||
|
self.data = None
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
def get_data(self):
|
||||||
|
with self.lock:
|
||||||
|
return self.data
|
||||||
|
|
||||||
|
def set_data(self, data):
|
||||||
|
with self.lock:
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
|
||||||
|
class UnitreeG1(Robot):
|
||||||
|
config_class = UnitreeG1Config
|
||||||
|
name = "unitree_g1"
|
||||||
|
|
||||||
|
# unitree remote controller
|
||||||
|
class RemoteController:
|
||||||
|
def __init__(self):
|
||||||
|
self.lx = 0
|
||||||
|
self.ly = 0
|
||||||
|
self.rx = 0
|
||||||
|
self.ry = 0
|
||||||
|
self.button = [0] * 16
|
||||||
|
|
||||||
|
def set(self, data):
|
||||||
|
# wireless_remote
|
||||||
|
keys = struct.unpack("H", data[2:4])[0]
|
||||||
|
for i in range(16):
|
||||||
|
self.button[i] = (keys & (1 << i)) >> i
|
||||||
|
self.lx = struct.unpack("f", data[4:8])[0]
|
||||||
|
self.rx = struct.unpack("f", data[8:12])[0]
|
||||||
|
self.ry = struct.unpack("f", data[12:16])[0]
|
||||||
|
self.ly = struct.unpack("f", data[20:24])[0]
|
||||||
|
|
||||||
|
def __init__(self, config: UnitreeG1Config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
logger.info("Initialize UnitreeG1...")
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.control_dt = config.control_dt
|
||||||
|
|
||||||
|
if config.is_simulation:
|
||||||
|
from unitree_sdk2py.core.channel import (
|
||||||
|
ChannelFactoryInitialize,
|
||||||
|
ChannelPublisher,
|
||||||
|
ChannelSubscriber,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
|
||||||
|
ChannelFactoryInitialize,
|
||||||
|
ChannelPublisher,
|
||||||
|
ChannelSubscriber,
|
||||||
|
)
|
||||||
|
|
||||||
|
# connect robot
|
||||||
|
self.ChannelFactoryInitialize = ChannelFactoryInitialize
|
||||||
|
self.connect()
|
||||||
|
|
||||||
|
# initialize direct motor control interface
|
||||||
|
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||||
|
self.lowcmd_publisher.Init()
|
||||||
|
self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState)
|
||||||
|
self.lowstate_subscriber.Init()
|
||||||
|
self.lowstate_buffer = DataBuffer()
|
||||||
|
|
||||||
|
# initialize subscribe thread to read robot state
|
||||||
|
self._shutdown_event = threading.Event()
|
||||||
|
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
|
||||||
|
self.subscribe_thread.start()
|
||||||
|
|
||||||
|
while not self.is_connected:
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# initialize hg's lowcmd msg
|
||||||
|
self.crc = CRC()
|
||||||
|
self.msg = unitree_hg_msg_dds__LowCmd_()
|
||||||
|
self.msg.mode_pr = 0
|
||||||
|
|
||||||
|
# Wait for first state message to arrive
|
||||||
|
lowstate = None
|
||||||
|
while lowstate is None:
|
||||||
|
lowstate = self.lowstate_buffer.get_data()
|
||||||
|
if lowstate is None:
|
||||||
|
time.sleep(0.01)
|
||||||
|
logger.warning("[UnitreeG1] Waiting for robot state...")
|
||||||
|
logger.warning("[UnitreeG1] Connected to robot.")
|
||||||
|
self.msg.mode_machine = lowstate.mode_machine
|
||||||
|
|
||||||
|
# initialize all motors with unified kp/kd from config
|
||||||
|
self.kp = np.array(config.kp, dtype=np.float32)
|
||||||
|
self.kd = np.array(config.kd, dtype=np.float32)
|
||||||
|
|
||||||
|
for id in G1_29_JointIndex:
|
||||||
|
self.msg.motor_cmd[id].mode = 1
|
||||||
|
self.msg.motor_cmd[id].kp = self.kp[id.value]
|
||||||
|
self.msg.motor_cmd[id].kd = self.kd[id.value]
|
||||||
|
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
|
||||||
|
|
||||||
|
# Initialize remote controller
|
||||||
|
self.remote_controller = self.RemoteController()
|
||||||
|
|
||||||
|
def _subscribe_motor_state(self): # polls robot state @ 250Hz
|
||||||
|
while not self._shutdown_event.is_set():
|
||||||
|
start_time = time.time()
|
||||||
|
msg = self.lowstate_subscriber.Read()
|
||||||
|
if msg is not None:
|
||||||
|
lowstate = G1_29_LowState()
|
||||||
|
|
||||||
|
# Capture motor states
|
||||||
|
for id in range(G1_29_Num_Motors):
|
||||||
|
lowstate.motor_state[id].q = msg.motor_state[id].q
|
||||||
|
lowstate.motor_state[id].dq = msg.motor_state[id].dq
|
||||||
|
lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est
|
||||||
|
lowstate.motor_state[id].temperature = msg.motor_state[id].temperature
|
||||||
|
|
||||||
|
# Capture IMU state
|
||||||
|
lowstate.imu_state.quaternion = list(msg.imu_state.quaternion)
|
||||||
|
lowstate.imu_state.gyroscope = list(msg.imu_state.gyroscope)
|
||||||
|
lowstate.imu_state.accelerometer = list(msg.imu_state.accelerometer)
|
||||||
|
lowstate.imu_state.rpy = list(msg.imu_state.rpy)
|
||||||
|
lowstate.imu_state.temperature = msg.imu_state.temperature
|
||||||
|
|
||||||
|
# Capture wireless remote data
|
||||||
|
lowstate.wireless_remote = msg.wireless_remote
|
||||||
|
|
||||||
|
# Capture mode_machine
|
||||||
|
lowstate.mode_machine = msg.mode_machine
|
||||||
|
|
||||||
|
self.lowstate_buffer.set_data(lowstate)
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
all_t_elapsed = current_time - start_time
|
||||||
|
sleep_time = max(0, (self.control_dt - all_t_elapsed)) # maintain constant control dt
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def action_features(self) -> dict[str, type]:
|
||||||
|
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
|
||||||
|
|
||||||
|
def calibrate(self) -> None: # robot is already calibrated
|
||||||
|
pass
|
||||||
|
|
||||||
|
def configure(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def connect(self, calibrate: bool = True) -> None: # connect to DDS
|
||||||
|
if self.config.is_simulation:
|
||||||
|
self.ChannelFactoryInitialize(0, "lo")
|
||||||
|
self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
|
||||||
|
else:
|
||||||
|
self.ChannelFactoryInitialize(0)
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
self._shutdown_event.set()
|
||||||
|
self.subscribe_thread.join(timeout=2.0)
|
||||||
|
if self.config.is_simulation:
|
||||||
|
self.mujoco_env["hub_env"][0].envs[0].kill_sim()
|
||||||
|
|
||||||
|
def get_observation(self) -> dict[str, Any]:
|
||||||
|
return self.lowstate_buffer.get_data()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self.lowstate_buffer.get_data() is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _motors_ft(self) -> dict[str, type]:
|
||||||
|
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _cameras_ft(self) -> dict[str, tuple]:
|
||||||
|
return {
|
||||||
|
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||||
|
}
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
|
return {**self._motors_ft, **self._cameras_ft}
|
||||||
|
|
||||||
|
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
self.msg.crc = self.crc.Crc(action)
|
||||||
|
self.lowcmd_publisher.Write(action)
|
||||||
|
return action
|
||||||
|
|
||||||
|
def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion
|
||||||
|
"""Get gravity orientation from quaternion."""
|
||||||
|
qw = quaternion[0]
|
||||||
|
qx = quaternion[1]
|
||||||
|
qy = quaternion[2]
|
||||||
|
qz = quaternion[3]
|
||||||
|
|
||||||
|
gravity_orientation = np.zeros(3)
|
||||||
|
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
|
||||||
|
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
|
||||||
|
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
||||||
|
return gravity_orientation
|
||||||
168
src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py
Normal file
168
src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||||
|
|
||||||
|
_ctx: zmq.Context | None = None
|
||||||
|
_lowcmd_sock: zmq.Socket | None = None
|
||||||
|
_lowstate_sock: zmq.Socket | None = None
|
||||||
|
|
||||||
|
LOWCMD_PORT = 6000
|
||||||
|
LOWSTATE_PORT = 6001
|
||||||
|
|
||||||
|
# DDS topic names follow Unitree SDK naming conventions
|
||||||
|
# ruff: noqa: N816
|
||||||
|
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||||
|
|
||||||
|
|
||||||
|
class LowStateMsg:
|
||||||
|
"""
|
||||||
|
Wrapper class that mimics the Unitree SDK LowState_ message structure.
|
||||||
|
|
||||||
|
Reconstructs the message from deserialized JSON data to maintain
|
||||||
|
compatibility with existing code that expects SDK message objects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class MotorState:
|
||||||
|
"""Motor state data for a single joint."""
|
||||||
|
|
||||||
|
def __init__(self, data: dict[str, Any]) -> None:
|
||||||
|
self.q: float = data.get("q", 0.0)
|
||||||
|
self.dq: float = data.get("dq", 0.0)
|
||||||
|
self.tau_est: float = data.get("tau_est", 0.0)
|
||||||
|
self.temperature: float = data.get("temperature", 0.0)
|
||||||
|
|
||||||
|
class IMUState:
|
||||||
|
"""IMU sensor data."""
|
||||||
|
|
||||||
|
def __init__(self, data: dict[str, Any]) -> None:
|
||||||
|
self.quaternion: list[float] = data.get("quaternion", [1.0, 0.0, 0.0, 0.0])
|
||||||
|
self.gyroscope: list[float] = data.get("gyroscope", [0.0, 0.0, 0.0])
|
||||||
|
self.accelerometer: list[float] = data.get("accelerometer", [0.0, 0.0, 0.0])
|
||||||
|
self.rpy: list[float] = data.get("rpy", [0.0, 0.0, 0.0])
|
||||||
|
self.temperature: float = data.get("temperature", 0.0)
|
||||||
|
|
||||||
|
def __init__(self, data: dict[str, Any]) -> None:
|
||||||
|
"""Initialize from deserialized JSON data."""
|
||||||
|
self.motor_state = [self.MotorState(m) for m in data.get("motor_state", [])]
|
||||||
|
self.imu_state = self.IMUState(data.get("imu_state", {}))
|
||||||
|
# Decode base64-encoded wireless_remote bytes
|
||||||
|
wireless_b64 = data.get("wireless_remote", "")
|
||||||
|
self.wireless_remote: bytes = base64.b64decode(wireless_b64) if wireless_b64 else b""
|
||||||
|
self.mode_machine: int = data.get("mode_machine", 0)
|
||||||
|
|
||||||
|
|
||||||
|
def lowcmd_to_dict(topic: str, msg: Any) -> dict[str, Any]:
|
||||||
|
"""Convert LowCmd message to a JSON-serializable dictionary."""
|
||||||
|
motor_cmds = []
|
||||||
|
# Iterate over all motor commands in the message
|
||||||
|
for i in range(len(msg.motor_cmd)):
|
||||||
|
motor_cmds.append(
|
||||||
|
{
|
||||||
|
"mode": int(msg.motor_cmd[i].mode),
|
||||||
|
"q": float(msg.motor_cmd[i].q),
|
||||||
|
"dq": float(msg.motor_cmd[i].dq),
|
||||||
|
"kp": float(msg.motor_cmd[i].kp),
|
||||||
|
"kd": float(msg.motor_cmd[i].kd),
|
||||||
|
"tau": float(msg.motor_cmd[i].tau),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"topic": topic,
|
||||||
|
"data": {
|
||||||
|
"mode_pr": int(msg.mode_pr),
|
||||||
|
"mode_machine": int(msg.mode_machine),
|
||||||
|
"motor_cmd": motor_cmds,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def ChannelFactoryInitialize(*args: Any, **kwargs: Any) -> None: # noqa: N802
|
||||||
|
"""
|
||||||
|
Initialize ZMQ sockets for robot communication.
|
||||||
|
|
||||||
|
This function mimics the Unitree SDK's ChannelFactoryInitialize but uses
|
||||||
|
ZMQ sockets to connect to the robot server bridge instead of DDS.
|
||||||
|
"""
|
||||||
|
global _ctx, _lowcmd_sock, _lowstate_sock
|
||||||
|
|
||||||
|
# read socket config
|
||||||
|
config = UnitreeG1Config()
|
||||||
|
robot_ip = config.robot_ip
|
||||||
|
|
||||||
|
ctx = zmq.Context.instance()
|
||||||
|
_ctx = ctx
|
||||||
|
|
||||||
|
# lowcmd: send robot commands
|
||||||
|
lowcmd_sock = ctx.socket(zmq.PUSH)
|
||||||
|
lowcmd_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
|
||||||
|
lowcmd_sock.connect(f"tcp://{robot_ip}:{LOWCMD_PORT}")
|
||||||
|
_lowcmd_sock = lowcmd_sock
|
||||||
|
|
||||||
|
# lowstate: receive robot observations
|
||||||
|
lowstate_sock = ctx.socket(zmq.SUB)
|
||||||
|
lowstate_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
|
||||||
|
lowstate_sock.connect(f"tcp://{robot_ip}:{LOWSTATE_PORT}")
|
||||||
|
lowstate_sock.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||||
|
_lowstate_sock = lowstate_sock
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelPublisher:
|
||||||
|
"""ZMQ-based publisher that sends commands to the robot server."""
|
||||||
|
|
||||||
|
def __init__(self, topic: str, msg_type: type) -> None:
|
||||||
|
self.topic = topic
|
||||||
|
self.msg_type = msg_type
|
||||||
|
|
||||||
|
def Init(self) -> None: # noqa: N802
|
||||||
|
"""Initialize the publisher (no-op for ZMQ)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def Write(self, msg: Any) -> None: # noqa: N802
|
||||||
|
"""Serialize and send a command message to the robot."""
|
||||||
|
if _lowcmd_sock is None:
|
||||||
|
raise RuntimeError("ChannelFactoryInitialize must be called first")
|
||||||
|
|
||||||
|
payload = json.dumps(lowcmd_to_dict(self.topic, msg)).encode("utf-8")
|
||||||
|
_lowcmd_sock.send(payload)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelSubscriber:
|
||||||
|
"""ZMQ-based subscriber that receives state from the robot server."""
|
||||||
|
|
||||||
|
def __init__(self, topic: str, msg_type: type) -> None:
|
||||||
|
self.topic = topic
|
||||||
|
self.msg_type = msg_type
|
||||||
|
|
||||||
|
def Init(self) -> None: # noqa: N802
|
||||||
|
"""Initialize the subscriber (no-op for ZMQ)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def Read(self) -> LowStateMsg: # noqa: N802
|
||||||
|
"""Receive and deserialize a state message from the robot."""
|
||||||
|
if _lowstate_sock is None:
|
||||||
|
raise RuntimeError("ChannelFactoryInitialize must be called first")
|
||||||
|
|
||||||
|
payload = _lowstate_sock.recv()
|
||||||
|
msg_dict = json.loads(payload.decode("utf-8"))
|
||||||
|
return LowStateMsg(msg_dict.get("data", {}))
|
||||||
@@ -28,6 +28,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
|||||||
from .koch_follower import KochFollower
|
from .koch_follower import KochFollower
|
||||||
|
|
||||||
return KochFollower(config)
|
return KochFollower(config)
|
||||||
|
elif config.type == "omx_follower":
|
||||||
|
from .omx_follower import OmxFollower
|
||||||
|
|
||||||
|
return OmxFollower(config)
|
||||||
elif config.type == "so100_follower":
|
elif config.type == "so100_follower":
|
||||||
from .so100_follower import SO100Follower
|
from .so100_follower import SO100Follower
|
||||||
|
|
||||||
|
|||||||
41
src/lerobot/scratch.txt
Normal file
41
src/lerobot/scratch.txt
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# Eun visualizer locally
|
||||||
|
|
||||||
|
# login to hf an set your access token
|
||||||
|
hf auth login
|
||||||
|
# if not installed, install with: pip install huggingface_hub
|
||||||
|
git clone https://github.com/huggingface/lerobot-dataset-visualizer.git
|
||||||
|
cd lerobot-dataset-visualizer
|
||||||
|
python -m lerobot_dataset_viz --repo-id lerobot-data-collection/repo-id-nez --episode-index 0
|
||||||
|
git checkout feat/private_repo_viz
|
||||||
|
npm install
|
||||||
|
npm run dev
|
||||||
|
# open http://localhost:3000 in your browser
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================
|
||||||
|
|
||||||
|
|
||||||
|
# default merge command; copy your list of datasets ids in repo_ids
|
||||||
|
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot-data-collection/repo-id-nez \
|
||||||
|
--operation.type merge --push_to_hub true \
|
||||||
|
--operation.repo_ids "[]"
|
||||||
|
|
||||||
|
|
||||||
|
# merge test datasets into one
|
||||||
|
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot-data-collection/test-2025-11-03-merged \
|
||||||
|
--operation.type merge --push_to_hub true \
|
||||||
|
--operation.repo_ids "['lerobot-data-collection/test-2025-11-03-13-18', 'lerobot-data-collection/test-2025-11-03-13-19', 'lerobot-data-collection/test-2025-11-03-13-20', 'lerobot-data-collection/test-2025-11-03-13-21', 'lerobot-data-collection/test-2025-11-03-13-23', 'lerobot-data-collection/test-2025-11-03-13-24', 'lerobot-data-collection/test-2025-11-03-13-25', 'lerobot-data-collection/test-2025-11-03-13-26', 'lerobot-data-collection/test-2025-11-03-13-27', 'lerobot-data-collection/test-2025-11-03-13-29', 'lerobot-data-collection/test-2025-11-03-13-30', 'lerobot-data-collection/test-2025-11-03-13-31', 'lerobot-data-collection/test-2025-11-03-13-34', 'lerobot-data-collection/test-2025-11-03-13-41', 'lerobot-data-collection/test-2025-11-03-13-42', 'lerobot-data-collection/test-2025-11-03-13-43', 'lerobot-data-collection/test-2025-11-03-13-44', 'lerobot-data-collection/test-2025-11-03-13-45', 'lerobot-data-collection/test-2025-11-03-13-46', 'lerobot-data-collection/test-2025-11-03-13-47', 'lerobot-data-collection/test-2025-11-03-13-48', 'lerobot-data-collection/test-2025-11-03-13-49']"
|
||||||
|
|
||||||
|
# RUN loop_dataset.py to get your repo_ids
|
||||||
|
|
||||||
|
# ========================================================= Two folds datasets
|
||||||
|
|
||||||
|
#merge
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot-data-collection/two-folds-dataset-full-11-04 \
|
||||||
|
--operation.type merge --push_to_hub true \
|
||||||
|
--operation.repo_ids "['lerobot-data-collection/two-folds-dataset-2025-11-04-15-06', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-08', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-10', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-11', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-12', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-14', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-16', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-18', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-20', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-22', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-24', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-25', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-27', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-28', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-29', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-33', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-34', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-35', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-36', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-52', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-53', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-54', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-55', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-56', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-57', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-59', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-00', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-01', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-02', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-03', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-04', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-05', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-06', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-07', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-08', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-09', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-26', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-28', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-29', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-30']"
|
||||||
@@ -40,6 +40,7 @@ from lerobot.robots import ( # noqa: F401
|
|||||||
koch_follower,
|
koch_follower,
|
||||||
lekiwi,
|
lekiwi,
|
||||||
make_robot_from_config,
|
make_robot_from_config,
|
||||||
|
omx_follower,
|
||||||
so100_follower,
|
so100_follower,
|
||||||
so101_follower,
|
so101_follower,
|
||||||
)
|
)
|
||||||
@@ -49,10 +50,11 @@ from lerobot.teleoperators import ( # noqa: F401
|
|||||||
homunculus,
|
homunculus,
|
||||||
koch_leader,
|
koch_leader,
|
||||||
make_teleoperator_from_config,
|
make_teleoperator_from_config,
|
||||||
|
omx_leader,
|
||||||
so100_leader,
|
so100_leader,
|
||||||
so101_leader,
|
so101_leader,
|
||||||
)
|
)
|
||||||
from lerobot.utils.import_utils import register_third_party_devices
|
from lerobot.utils.import_utils import register_third_party_plugins
|
||||||
from lerobot.utils.utils import init_logging
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
|
||||||
@@ -84,7 +86,7 @@ def calibrate(cfg: CalibrateConfig):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
register_third_party_devices()
|
register_third_party_plugins()
|
||||||
calibrate()
|
calibrate()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,6 @@ import argparse
|
|||||||
import gc
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Iterator
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -78,19 +77,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|||||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||||
|
|
||||||
|
|
||||||
class EpisodeSampler(torch.utils.data.Sampler):
|
|
||||||
def __init__(self, dataset: LeRobotDataset, episode_index: int):
|
|
||||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
|
||||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
|
||||||
self.frame_ids = range(from_idx, to_idx)
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator:
|
|
||||||
return iter(self.frame_ids)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.frame_ids)
|
|
||||||
|
|
||||||
|
|
||||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||||
assert chw_float32_torch.dtype == torch.float32
|
assert chw_float32_torch.dtype == torch.float32
|
||||||
assert chw_float32_torch.ndim == 3
|
assert chw_float32_torch.ndim == 3
|
||||||
@@ -119,12 +105,10 @@ def visualize_dataset(
|
|||||||
repo_id = dataset.repo_id
|
repo_id = dataset.repo_id
|
||||||
|
|
||||||
logging.info("Loading dataloader")
|
logging.info("Loading dataloader")
|
||||||
episode_sampler = EpisodeSampler(dataset, episode_index)
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
sampler=episode_sampler,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Starting Rerun")
|
logging.info("Starting Rerun")
|
||||||
|
|||||||
@@ -18,7 +18,8 @@
|
|||||||
Edit LeRobot datasets using various transformation tools.
|
Edit LeRobot datasets using various transformation tools.
|
||||||
|
|
||||||
This script allows you to delete episodes, split datasets, merge datasets,
|
This script allows you to delete episodes, split datasets, merge datasets,
|
||||||
and remove features. When new_repo_id is specified, creates a new dataset.
|
remove features, and convert image datasets to video format.
|
||||||
|
When new_repo_id is specified, creates a new dataset.
|
||||||
|
|
||||||
Usage Examples:
|
Usage Examples:
|
||||||
|
|
||||||
@@ -65,6 +66,25 @@ Remove camera feature:
|
|||||||
--operation.type remove_feature \
|
--operation.type remove_feature \
|
||||||
--operation.feature_names "['observation.images.top']"
|
--operation.feature_names "['observation.images.top']"
|
||||||
|
|
||||||
|
Convert image dataset to video format (saves locally):
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht_image \
|
||||||
|
--operation.type convert_to_video \
|
||||||
|
--operation.output_dir /path/to/output/pusht_video
|
||||||
|
|
||||||
|
Convert image dataset and save with new repo_id:
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht_image \
|
||||||
|
--new_repo_id lerobot/pusht_video \
|
||||||
|
--operation.type convert_to_video
|
||||||
|
|
||||||
|
Convert and push to hub:
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht_image \
|
||||||
|
--new_repo_id lerobot/pusht_video \
|
||||||
|
--operation.type convert_to_video \
|
||||||
|
--push_to_hub true
|
||||||
|
|
||||||
Using JSON config file:
|
Using JSON config file:
|
||||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
--config_path path/to/edit_config.json
|
--config_path path/to/edit_config.json
|
||||||
@@ -72,9 +92,13 @@ Using JSON config file:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.datasets.dataset_tools import (
|
from lerobot.datasets.dataset_tools import (
|
||||||
delete_episodes,
|
delete_episodes,
|
||||||
@@ -82,8 +106,10 @@ from lerobot.datasets.dataset_tools import (
|
|||||||
remove_feature,
|
remove_feature,
|
||||||
split_dataset,
|
split_dataset,
|
||||||
)
|
)
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
from lerobot.datasets.utils import write_stats, write_tasks
|
||||||
|
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||||
|
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
|
||||||
from lerobot.utils.utils import init_logging
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
|
||||||
@@ -111,10 +137,23 @@ class RemoveFeatureConfig:
|
|||||||
feature_names: list[str] | None = None
|
feature_names: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConvertToVideoConfig:
|
||||||
|
type: str = "convert_to_video"
|
||||||
|
output_dir: str | None = None
|
||||||
|
vcodec: str = "libsvtav1"
|
||||||
|
pix_fmt: str = "yuv420p"
|
||||||
|
g: int = 2
|
||||||
|
crf: int = 30
|
||||||
|
fast_decode: int = 0
|
||||||
|
episode_indices: list[int] | None = None
|
||||||
|
num_workers: int = 4
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EditDatasetConfig:
|
class EditDatasetConfig:
|
||||||
repo_id: str
|
repo_id: str
|
||||||
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig
|
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertToVideoConfig
|
||||||
root: str | None = None
|
root: str | None = None
|
||||||
new_repo_id: str | None = None
|
new_repo_id: str | None = None
|
||||||
push_to_hub: bool = False
|
push_to_hub: bool = False
|
||||||
@@ -258,6 +297,415 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
|||||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
|
def save_episode_images_for_video(
|
||||||
|
dataset: LeRobotDataset,
|
||||||
|
imgs_dir: Path,
|
||||||
|
img_key: str,
|
||||||
|
episode_index: int,
|
||||||
|
num_workers: int = 4,
|
||||||
|
) -> None:
|
||||||
|
"""Save images from a specific episode and camera to disk for video encoding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: The LeRobot dataset to extract images from
|
||||||
|
imgs_dir: Directory to save images to
|
||||||
|
img_key: The image key (camera) to extract
|
||||||
|
episode_index: Index of the episode to save
|
||||||
|
num_workers: Number of threads for parallel image saving
|
||||||
|
"""
|
||||||
|
# Create directory
|
||||||
|
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Get dataset without torch format for PIL image access
|
||||||
|
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||||
|
|
||||||
|
# Select only this camera's images
|
||||||
|
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||||
|
|
||||||
|
# Get episode start and end indices
|
||||||
|
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||||
|
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||||
|
|
||||||
|
# Get all items for this episode
|
||||||
|
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||||
|
|
||||||
|
# Define function to save a single image
|
||||||
|
def save_single_image(i_item_tuple):
|
||||||
|
i, item = i_item_tuple
|
||||||
|
img = item[img_key]
|
||||||
|
# Use frame-XXXXXX.png format to match encode_video_frames expectations
|
||||||
|
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||||
|
return i
|
||||||
|
|
||||||
|
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
|
||||||
|
items = list(enumerate(episode_dataset))
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||||
|
futures = [executor.submit(save_single_image, item) for item in items]
|
||||||
|
for future in as_completed(futures):
|
||||||
|
future.result() # This will raise any exceptions that occurred
|
||||||
|
|
||||||
|
|
||||||
|
def encode_episode_videos(
|
||||||
|
dataset: LeRobotDataset,
|
||||||
|
new_meta: LeRobotDatasetMetadata,
|
||||||
|
episode_index: int,
|
||||||
|
vcodec: str,
|
||||||
|
pix_fmt: str,
|
||||||
|
g: int,
|
||||||
|
crf: int,
|
||||||
|
fast_decode: int,
|
||||||
|
temp_dir: Path,
|
||||||
|
num_image_workers: int = 4,
|
||||||
|
) -> dict[str, dict]:
|
||||||
|
"""Encode videos for a single episode and return video metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Source dataset with images
|
||||||
|
new_meta: Metadata object for the new video dataset
|
||||||
|
episode_index: Episode index to process
|
||||||
|
vcodec: Video codec
|
||||||
|
pix_fmt: Pixel format
|
||||||
|
g: Group of pictures size
|
||||||
|
crf: Constant rate factor
|
||||||
|
fast_decode: Fast decode tuning
|
||||||
|
temp_dir: Temporary directory for images
|
||||||
|
num_image_workers: Number of workers for saving images
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping video keys to their metadata (chunk_index, file_index, timestamps)
|
||||||
|
"""
|
||||||
|
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||||
|
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||||
|
|
||||||
|
video_metadata = {}
|
||||||
|
fps = int(dataset.fps) # Convert to int for PyAV compatibility
|
||||||
|
episode_length = dataset.meta.episodes["length"][episode_index]
|
||||||
|
episode_duration = episode_length / dataset.fps # Use original fps for duration calculation
|
||||||
|
|
||||||
|
for img_key in img_keys:
|
||||||
|
# Save images temporarily
|
||||||
|
imgs_dir = temp_dir / f"episode_{episode_index:06d}" / img_key
|
||||||
|
save_episode_images_for_video(dataset, imgs_dir, img_key, episode_index, num_image_workers)
|
||||||
|
|
||||||
|
# Determine chunk and file indices
|
||||||
|
# For simplicity, we'll put each episode in its own file
|
||||||
|
chunk_idx = episode_index // new_meta.chunks_size
|
||||||
|
file_idx = episode_index % new_meta.chunks_size
|
||||||
|
|
||||||
|
# Create video path in the new dataset structure
|
||||||
|
video_path = new_meta.root / new_meta.video_path.format(
|
||||||
|
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
|
||||||
|
)
|
||||||
|
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Encode video
|
||||||
|
encode_video_frames(
|
||||||
|
imgs_dir=imgs_dir,
|
||||||
|
video_path=video_path,
|
||||||
|
fps=fps,
|
||||||
|
vcodec=vcodec,
|
||||||
|
pix_fmt=pix_fmt,
|
||||||
|
g=g,
|
||||||
|
crf=crf,
|
||||||
|
fast_decode=fast_decode,
|
||||||
|
overwrite=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up temporary images
|
||||||
|
shutil.rmtree(imgs_dir)
|
||||||
|
|
||||||
|
# Store video metadata
|
||||||
|
video_metadata[img_key] = {
|
||||||
|
f"videos/{img_key}/chunk_index": chunk_idx,
|
||||||
|
f"videos/{img_key}/file_index": file_idx,
|
||||||
|
f"videos/{img_key}/from_timestamp": 0.0,
|
||||||
|
f"videos/{img_key}/to_timestamp": episode_duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
return video_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def convert_dataset_to_videos(
|
||||||
|
dataset: LeRobotDataset,
|
||||||
|
output_dir: Path,
|
||||||
|
repo_id: str | None = None,
|
||||||
|
vcodec: str = "libsvtav1",
|
||||||
|
pix_fmt: str = "yuv420p",
|
||||||
|
g: int = 2,
|
||||||
|
crf: int = 30,
|
||||||
|
fast_decode: int = 0,
|
||||||
|
episode_indices: list[int] | None = None,
|
||||||
|
num_workers: int = 4,
|
||||||
|
) -> LeRobotDataset:
|
||||||
|
"""Convert image-based dataset to video-based dataset.
|
||||||
|
|
||||||
|
Creates a new LeRobotDataset with videos instead of images, following the proper
|
||||||
|
LeRobot dataset structure with videos stored in chunked MP4 files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: The source LeRobot dataset with images
|
||||||
|
output_dir: Directory to save the new video dataset
|
||||||
|
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
||||||
|
vcodec: Video codec (default: libsvtav1)
|
||||||
|
pix_fmt: Pixel format (default: yuv420p)
|
||||||
|
g: Group of pictures size (default: 2)
|
||||||
|
crf: Constant rate factor (default: 30)
|
||||||
|
fast_decode: Fast decode tuning (default: 0)
|
||||||
|
episode_indices: List of episode indices to convert (None = all episodes)
|
||||||
|
num_workers: Number of threads for parallel processing (default: 4)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New LeRobotDataset with videos
|
||||||
|
"""
|
||||||
|
# Check that it's an image dataset
|
||||||
|
if len(dataset.meta.video_keys) > 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get all image keys
|
||||||
|
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||||
|
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||||
|
|
||||||
|
if len(img_keys) == 0:
|
||||||
|
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
|
||||||
|
|
||||||
|
# Determine which episodes to process
|
||||||
|
if episode_indices is None:
|
||||||
|
episode_indices = list(range(dataset.meta.total_episodes))
|
||||||
|
|
||||||
|
if repo_id is None:
|
||||||
|
repo_id = f"{dataset.repo_id}_video"
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||||
|
)
|
||||||
|
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||||
|
|
||||||
|
# Create new features dict, converting image features to video features
|
||||||
|
new_features = {}
|
||||||
|
for key, value in dataset.meta.features.items():
|
||||||
|
if key not in img_keys:
|
||||||
|
new_features[key] = value
|
||||||
|
else:
|
||||||
|
# Convert image key to video format
|
||||||
|
new_features[key] = value.copy()
|
||||||
|
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
|
||||||
|
# Video info will be updated after episodes are encoded
|
||||||
|
|
||||||
|
# Create new metadata for video dataset
|
||||||
|
new_meta = LeRobotDatasetMetadata.create(
|
||||||
|
repo_id=repo_id,
|
||||||
|
fps=dataset.meta.fps,
|
||||||
|
features=new_features,
|
||||||
|
robot_type=dataset.meta.robot_type,
|
||||||
|
root=output_dir,
|
||||||
|
use_videos=True,
|
||||||
|
chunks_size=dataset.meta.chunks_size,
|
||||||
|
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
|
||||||
|
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create temporary directory for image extraction
|
||||||
|
temp_dir = output_dir / "temp_images"
|
||||||
|
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Process each episode
|
||||||
|
all_episode_metadata = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
for ep_idx in tqdm(episode_indices, desc="Converting episodes to videos"):
|
||||||
|
# Get episode metadata from source
|
||||||
|
src_episode = dataset.meta.episodes[ep_idx]
|
||||||
|
|
||||||
|
# Encode videos for this episode
|
||||||
|
video_metadata = encode_episode_videos(
|
||||||
|
dataset=dataset,
|
||||||
|
new_meta=new_meta,
|
||||||
|
episode_index=ep_idx,
|
||||||
|
vcodec=vcodec,
|
||||||
|
pix_fmt=pix_fmt,
|
||||||
|
g=g,
|
||||||
|
crf=crf,
|
||||||
|
fast_decode=fast_decode,
|
||||||
|
temp_dir=temp_dir,
|
||||||
|
num_image_workers=num_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build episode metadata
|
||||||
|
episode_meta = {
|
||||||
|
"episode_index": ep_idx,
|
||||||
|
"length": src_episode["length"],
|
||||||
|
"dataset_from_index": ep_idx * src_episode["length"],
|
||||||
|
"dataset_to_index": (ep_idx + 1) * src_episode["length"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add video metadata
|
||||||
|
for img_key in img_keys:
|
||||||
|
episode_meta.update(video_metadata[img_key])
|
||||||
|
|
||||||
|
# Add data chunk/file info (using same structure as source)
|
||||||
|
if "data/chunk_index" in src_episode:
|
||||||
|
episode_meta["data/chunk_index"] = src_episode["data/chunk_index"]
|
||||||
|
episode_meta["data/file_index"] = src_episode["data/file_index"]
|
||||||
|
|
||||||
|
all_episode_metadata.append(episode_meta)
|
||||||
|
|
||||||
|
# Copy and transform data files (removing image columns)
|
||||||
|
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
|
||||||
|
|
||||||
|
# Save episode metadata
|
||||||
|
episodes_df = pd.DataFrame(all_episode_metadata)
|
||||||
|
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
|
||||||
|
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
episodes_df.to_parquet(episodes_path, index=False)
|
||||||
|
|
||||||
|
# Update metadata info
|
||||||
|
new_meta.info["total_episodes"] = len(episode_indices)
|
||||||
|
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata)
|
||||||
|
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||||
|
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||||
|
|
||||||
|
# Update video info for all image keys (now videos)
|
||||||
|
# We need to manually set video info since update_video_info() checks video_keys first
|
||||||
|
for img_key in img_keys:
|
||||||
|
if not new_meta.features[img_key].get("info", None):
|
||||||
|
video_path = new_meta.root / new_meta.video_path.format(
|
||||||
|
video_key=img_key, chunk_index=0, file_index=0
|
||||||
|
)
|
||||||
|
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||||
|
|
||||||
|
from lerobot.datasets.utils import write_info
|
||||||
|
|
||||||
|
write_info(new_meta.info, new_meta.root)
|
||||||
|
|
||||||
|
# Copy stats and tasks
|
||||||
|
if dataset.meta.stats is not None:
|
||||||
|
# Remove image stats
|
||||||
|
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
|
||||||
|
write_stats(new_stats, new_meta.root)
|
||||||
|
|
||||||
|
if dataset.meta.tasks is not None:
|
||||||
|
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up temporary directory
|
||||||
|
if temp_dir.exists():
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
|
logging.info(f"✓ Completed converting {dataset.repo_id} to video format")
|
||||||
|
logging.info(f"New dataset saved to: {output_dir}")
|
||||||
|
|
||||||
|
# Return new dataset
|
||||||
|
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def _copy_data_without_images(
|
||||||
|
src_dataset: LeRobotDataset,
|
||||||
|
dst_meta: LeRobotDatasetMetadata,
|
||||||
|
episode_indices: list[int],
|
||||||
|
img_keys: list[str],
|
||||||
|
) -> None:
|
||||||
|
"""Copy data files without image columns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src_dataset: Source dataset
|
||||||
|
dst_meta: Destination metadata
|
||||||
|
episode_indices: Episodes to include
|
||||||
|
img_keys: Image keys to remove
|
||||||
|
"""
|
||||||
|
from lerobot.datasets.utils import DATA_DIR
|
||||||
|
|
||||||
|
data_dir = src_dataset.root / DATA_DIR
|
||||||
|
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||||
|
|
||||||
|
if not parquet_files:
|
||||||
|
raise ValueError(f"No parquet files found in {data_dir}")
|
||||||
|
|
||||||
|
episode_set = set(episode_indices)
|
||||||
|
|
||||||
|
for src_path in tqdm(parquet_files, desc="Processing data files"):
|
||||||
|
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||||
|
|
||||||
|
# Filter to only include selected episodes
|
||||||
|
df = df[df["episode_index"].isin(episode_set)].copy()
|
||||||
|
|
||||||
|
if len(df) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Remove image columns
|
||||||
|
columns_to_drop = [col for col in img_keys if col in df.columns]
|
||||||
|
if columns_to_drop:
|
||||||
|
df = df.drop(columns=columns_to_drop)
|
||||||
|
|
||||||
|
# Get chunk and file indices from path
|
||||||
|
relative_path = src_path.relative_to(src_dataset.root)
|
||||||
|
chunk_dir = relative_path.parts[1]
|
||||||
|
file_name = relative_path.parts[2]
|
||||||
|
chunk_idx = int(chunk_dir.split("-")[1])
|
||||||
|
file_idx = int(file_name.split("-")[1].split(".")[0])
|
||||||
|
|
||||||
|
# Write to destination without pandas index
|
||||||
|
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
|
||||||
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
df.to_parquet(dst_path, index=False)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
|
||||||
|
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||||
|
# instead of checking isinstance()
|
||||||
|
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||||
|
|
||||||
|
# Determine output directory and repo_id
|
||||||
|
# Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name
|
||||||
|
output_dir_config = getattr(cfg.operation, "output_dir", None)
|
||||||
|
|
||||||
|
if cfg.new_repo_id:
|
||||||
|
# Use new_repo_id for both local storage and hub push
|
||||||
|
output_repo_id = cfg.new_repo_id
|
||||||
|
output_dir = Path(cfg.root) / cfg.new_repo_id if cfg.root else HF_LEROBOT_HOME / cfg.new_repo_id
|
||||||
|
logging.info(f"Saving to new dataset: {cfg.new_repo_id}")
|
||||||
|
elif output_dir_config:
|
||||||
|
# Use custom output directory for local-only storage
|
||||||
|
output_dir = Path(output_dir_config)
|
||||||
|
# Extract repo name from output_dir for the dataset
|
||||||
|
output_repo_id = output_dir.name
|
||||||
|
logging.info(f"Saving to local directory: {output_dir}")
|
||||||
|
else:
|
||||||
|
# Auto-generate name: append "_video" to original repo_id
|
||||||
|
output_repo_id = f"{cfg.repo_id}_video"
|
||||||
|
output_dir = Path(cfg.root) / output_repo_id if cfg.root else HF_LEROBOT_HOME / output_repo_id
|
||||||
|
logging.info(f"Saving to auto-generated location: {output_dir}")
|
||||||
|
|
||||||
|
logging.info(f"Converting dataset {cfg.repo_id} to video format")
|
||||||
|
|
||||||
|
new_dataset = convert_dataset_to_videos(
|
||||||
|
dataset=dataset,
|
||||||
|
output_dir=output_dir,
|
||||||
|
repo_id=output_repo_id,
|
||||||
|
vcodec=getattr(cfg.operation, "vcodec", "libsvtav1"),
|
||||||
|
pix_fmt=getattr(cfg.operation, "pix_fmt", "yuv420p"),
|
||||||
|
g=getattr(cfg.operation, "g", 2),
|
||||||
|
crf=getattr(cfg.operation, "crf", 30),
|
||||||
|
fast_decode=getattr(cfg.operation, "fast_decode", 0),
|
||||||
|
episode_indices=getattr(cfg.operation, "episode_indices", None),
|
||||||
|
num_workers=getattr(cfg.operation, "num_workers", 4),
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Video dataset created successfully!")
|
||||||
|
logging.info(f"Location: {output_dir}")
|
||||||
|
logging.info(f"Episodes: {new_dataset.meta.total_episodes}")
|
||||||
|
logging.info(f"Frames: {new_dataset.meta.total_frames}")
|
||||||
|
|
||||||
|
if cfg.push_to_hub:
|
||||||
|
logging.info(f"Pushing to hub as {output_repo_id}...")
|
||||||
|
new_dataset.push_to_hub()
|
||||||
|
logging.info("✓ Successfully pushed to hub!")
|
||||||
|
else:
|
||||||
|
logging.info("Dataset saved locally (not pushed to hub)")
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def edit_dataset(cfg: EditDatasetConfig) -> None:
|
def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||||
operation_type = cfg.operation.type
|
operation_type = cfg.operation.type
|
||||||
@@ -270,10 +718,12 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
|||||||
handle_merge(cfg)
|
handle_merge(cfg)
|
||||||
elif operation_type == "remove_feature":
|
elif operation_type == "remove_feature":
|
||||||
handle_remove_feature(cfg)
|
handle_remove_feature(cfg)
|
||||||
|
elif operation_type == "convert_to_video":
|
||||||
|
handle_convert_to_video(cfg)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown operation type: {operation_type}\n"
|
f"Unknown operation type: {operation_type}\n"
|
||||||
f"Available operations: delete_episodes, split, merge, remove_feature"
|
f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user