mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
Compare commits
14 Commits
security-f
...
feat/multi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c5925399a9 | ||
|
|
f478ae5bfa | ||
|
|
b4d40d0228 | ||
|
|
db5c26f07d | ||
|
|
8904768db4 | ||
|
|
b0efa73520 | ||
|
|
00b662de02 | ||
|
|
5c51a74484 | ||
|
|
db8547e35d | ||
|
|
c17d949531 | ||
|
|
1e131f93f8 | ||
|
|
2fb5c7add0 | ||
|
|
4f2ef024d8 | ||
|
|
6139b133ca |
19
README.md
19
README.md
@@ -135,7 +135,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
|||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
If you use LeRobot in your research, please cite:
|
If you use LeRobot in your project, please cite the GitHub repository to acknowledge the ongoing development and contributors:
|
||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@misc{cadene2024lerobot,
|
@misc{cadene2024lerobot,
|
||||||
@@ -146,6 +146,23 @@ If you use LeRobot in your research, please cite:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you are referencing our research or the academic paper, please also cite our ICLR publication:
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><b>ICLR 2026 Paper</b></summary>
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@inproceedings{cadenelerobot,
|
||||||
|
title={LeRobot: An Open-Source Library for End-to-End Robot Learning},
|
||||||
|
author={Cadene, Remi and Alibert, Simon and Capuano, Francesco and Aractingi, Michel and Zouitine, Adil and Kooijmans, Pepijn and Choghari, Jade and Russi, Martino and Pascal, Caroline and Palma, Steven and Shukor, Mustafa and Moss, Jess and Soare, Alexander and Aubakirova, Dana and Lhoest, Quentin and Gallou\'edec, Quentin and Wolf, Thomas},
|
||||||
|
booktitle={The Fourteenth International Conference on Learning Representations},
|
||||||
|
year={2026},
|
||||||
|
url={https://arxiv.org/abs/2602.22818}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
## Contribute
|
## Contribute
|
||||||
|
|
||||||
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
|
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
|
||||||
|
|||||||
@@ -31,6 +31,8 @@
|
|||||||
title: Using Subtasks in the Dataset
|
title: Using Subtasks in the Dataset
|
||||||
- local: streaming_video_encoding
|
- local: streaming_video_encoding
|
||||||
title: Streaming Video Encoding
|
title: Streaming Video Encoding
|
||||||
|
- local: multi_dataset_training
|
||||||
|
title: Multi-Dataset Training
|
||||||
title: "Datasets"
|
title: "Datasets"
|
||||||
- sections:
|
- sections:
|
||||||
- local: act
|
- local: act
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
# Installation
|
# Installation
|
||||||
|
|
||||||
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
|
This guide uses `conda` (via miniforge) to manage environments (recommended). If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and `ffmpeg` installed with the `libsvtav1` encoder, then skip ahead to [Environment Setup](#step-2-environment-setup).
|
||||||
|
|
||||||
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
|
## Step 1 (`conda` only): Install [`miniforge`](https://conda-forge.org/download/)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
||||||
@@ -11,22 +11,47 @@ bash Miniforge3-$(uname)-$(uname -m).sh
|
|||||||
|
|
||||||
## Step 2: Environment Setup
|
## Step 2: Environment Setup
|
||||||
|
|
||||||
Create a virtual environment with Python 3.12, using conda:
|
Create a virtual environment with Python 3.12:
|
||||||
|
|
||||||
|
<!-- prettier-ignore-start -->
|
||||||
|
<hfoptions id="create_venv">
|
||||||
|
<hfoption id="conda">
|
||||||
```bash
|
```bash
|
||||||
conda create -y -n lerobot python=3.12
|
conda create -y -n lerobot python=3.12
|
||||||
```
|
```
|
||||||
|
</hfoption>
|
||||||
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
|
<hfoption id="uv">
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
conda activate lerobot
|
uv python install 3.12
|
||||||
|
uv venv --python 3.12
|
||||||
```
|
```
|
||||||
|
</hfoption>
|
||||||
|
</hfoptions>
|
||||||
|
<!-- prettier-ignore-end -->
|
||||||
|
|
||||||
|
Then activate your virtual environment, you have to do this each time you open a shell to use lerobot:
|
||||||
|
|
||||||
|
<!-- prettier-ignore-start -->
|
||||||
|
<hfoptions id="activate_venv">
|
||||||
|
<hfoption id="conda">```bash
|
||||||
|
conda activate lerobot
|
||||||
|
```</hfoption>
|
||||||
|
<hfoption id="uv">
|
||||||
|
```bash
|
||||||
|
# Linux/macOSsource
|
||||||
|
source .venv/bin/activate
|
||||||
|
# Windows PowerShell
|
||||||
|
source .venv\Scripts\Activate.ps1
|
||||||
|
```
|
||||||
|
</hfoption>
|
||||||
|
</hfoptions>
|
||||||
|
<!-- prettier-ignore-end -->
|
||||||
|
|
||||||
When using `conda`, install `ffmpeg` in your environment:
|
When using `conda`, install `ffmpeg` in your environment:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
conda install ffmpeg -c conda-forge
|
conda install ffmpeg -c conda-forge
|
||||||
|
ffmpeg -version # ffmpeg 8.X is not yet supported !
|
||||||
```
|
```
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
@@ -47,6 +72,9 @@ conda install ffmpeg -c conda-forge
|
|||||||
> conda install evdev -c conda-forge
|
> conda install evdev -c conda-forge
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> If you are using `uv` you will have to install `ffmpeg` system-wide (outside of the virtual environment). You rely on `uv` and `torchcodec` ability to dynamically link to the system `ffmpeg`.
|
||||||
|
|
||||||
## Step 3: Install LeRobot 🤗
|
## Step 3: Install LeRobot 🤗
|
||||||
|
|
||||||
### From Source
|
### From Source
|
||||||
@@ -60,23 +88,45 @@ cd lerobot
|
|||||||
|
|
||||||
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
|
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
|
||||||
|
|
||||||
|
<!-- prettier-ignore-start -->
|
||||||
|
<hfoptions id="install_lerobot_src">
|
||||||
|
<hfoption id="conda">
|
||||||
```bash
|
```bash
|
||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
</hfoption>
|
||||||
|
<hfoption id="uv">
|
||||||
|
```bash
|
||||||
|
uv pip install -e .
|
||||||
|
```
|
||||||
|
</hfoption>
|
||||||
|
</hfoptions>
|
||||||
|
<!-- prettier-ignore-end -->
|
||||||
|
|
||||||
### Installation from PyPI
|
### Installation from PyPI
|
||||||
|
|
||||||
**Core Library:**
|
**Core Library:**
|
||||||
Install the base package with:
|
Install the base package with:
|
||||||
|
|
||||||
|
<!-- prettier-ignore-start -->
|
||||||
|
<hfoptions id="install_lerobot_pypi">
|
||||||
|
<hfoption id="conda">
|
||||||
```bash
|
```bash
|
||||||
pip install lerobot
|
pip install lerobot
|
||||||
```
|
```
|
||||||
|
</hfoption>
|
||||||
|
<hfoption id="uv">
|
||||||
|
```bash
|
||||||
|
uv pip install lerobot
|
||||||
|
```
|
||||||
|
</hfoption>
|
||||||
|
</hfoptions>
|
||||||
|
<!-- prettier-ignore-end -->
|
||||||
|
|
||||||
_This installs only the default dependencies._
|
_This installs only the default dependencies._
|
||||||
|
|
||||||
**Extra Features:**
|
**Extra Features:**
|
||||||
To install additional functionality, use one of the following:
|
To install additional functionality, use one of the following (If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install 'lerobot[all]' # All available features
|
pip install 'lerobot[all]' # All available features
|
||||||
@@ -93,7 +143,7 @@ https://pypi.org/project/lerobot/
|
|||||||
### Troubleshooting
|
### Troubleshooting
|
||||||
|
|
||||||
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||||
To install these for linux run:
|
To install these for Linux run:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
|
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
|
||||||
@@ -103,7 +153,7 @@ For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/
|
|||||||
|
|
||||||
## Optional dependencies
|
## Optional dependencies
|
||||||
|
|
||||||
LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`.
|
LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`. If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.
|
||||||
|
|
||||||
### Simulations
|
### Simulations
|
||||||
|
|
||||||
|
|||||||
232
docs/source/multi_dataset_training.mdx
Normal file
232
docs/source/multi_dataset_training.mdx
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
# Multi-Dataset Training
|
||||||
|
|
||||||
|
This guide covers how to train a single policy on multiple heterogeneous datasets using `MultiLeRobotDataset`.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Real-world robot learning datasets come from different environments, robots, and camera setups. A RoboCasa dataset might have three cameras named `robot0_agentview_left`, `robot0_agentview_right`, and `robot0_eye_in_hand`, while a LIBERO dataset uses `observation.images.front` and `observation.images.wrist`, and a RoboMME dataset uses bare `image` and `wrist_image` keys. State and action dimensions also differ.
|
||||||
|
|
||||||
|
`MultiLeRobotDataset` lets you train on all of them jointly by:
|
||||||
|
|
||||||
|
- **Mapping** each dataset's feature keys into a shared namespace
|
||||||
|
- **Padding** features that a dataset doesn't have with zeros
|
||||||
|
- **Weighting** how often each dataset is sampled
|
||||||
|
- **Transforming** samples per-dataset (e.g. padding actions to a common dimension)
|
||||||
|
- **Aggregating** statistics across all sub-datasets for normalization
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Multi-dataset training is configured via `MultiDatasetConfig` in a YAML config file. Instead of a single `dataset.repo_id`, you provide a `datasets` list where each entry is a `SubDatasetConfig`.
|
||||||
|
|
||||||
|
### SubDatasetConfig fields
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `repo_id` | `str` | required | HuggingFace repo ID or local dataset name |
|
||||||
|
| `root` | `str \| None` | `None` | Local root directory for the dataset |
|
||||||
|
| `episodes` | `list[int] \| None` | `None` | Subset of episode indices to use |
|
||||||
|
| `revision` | `str \| None` | `None` | Dataset version / revision |
|
||||||
|
| `video_backend` | `str` | auto | Video decoding backend (`pyav`, `torchcodec`, etc.) |
|
||||||
|
| `weight` | `float` | `1.0` | Relative sampling weight for this dataset |
|
||||||
|
| `feature_map` | `dict[str, str]` | `{}` | Maps dataset keys to unified policy keys |
|
||||||
|
| `transforms` | `list` | `None` | Per-dataset transform steps (applied per sample) |
|
||||||
|
|
||||||
|
### Example: Three-dataset config
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
dataset:
|
||||||
|
type: multi
|
||||||
|
use_imagenet_stats: true
|
||||||
|
datasets:
|
||||||
|
# RoboCasa: 3 cameras, state(16), action(12)
|
||||||
|
- repo_id: pepijn223/robocasa_PrepareCoffee
|
||||||
|
root: /data/robocasa_PrepareCoffee
|
||||||
|
weight: 1.0
|
||||||
|
feature_map:
|
||||||
|
observation.images.robot0_agentview_left: observation.images.front_left
|
||||||
|
observation.images.robot0_agentview_right: observation.images.front_right
|
||||||
|
observation.images.robot0_eye_in_hand: observation.images.wrist
|
||||||
|
|
||||||
|
# LIBERO-plus: 2 cameras, state(8), action(7)
|
||||||
|
- repo_id: pepijn223/libero_plus_lerobot
|
||||||
|
root: /data/libero_plus_lerobot
|
||||||
|
weight: 0.5
|
||||||
|
feature_map:
|
||||||
|
observation.images.front: observation.images.front_left
|
||||||
|
observation.images.wrist: observation.images.wrist
|
||||||
|
transforms:
|
||||||
|
- type: pad_action
|
||||||
|
kwargs: {target_dim: 12}
|
||||||
|
- type: pad_state
|
||||||
|
kwargs: {target_dim: 16}
|
||||||
|
|
||||||
|
# RoboMME: 2 cameras (non-standard keys), state(8), action(8)
|
||||||
|
- repo_id: pepijn223/robomme_data_lerobot
|
||||||
|
root: /data/robomme_data_lerobot
|
||||||
|
weight: 0.3
|
||||||
|
feature_map:
|
||||||
|
image: observation.images.front_left
|
||||||
|
wrist_image: observation.images.wrist
|
||||||
|
state: observation.state
|
||||||
|
actions: action
|
||||||
|
transforms:
|
||||||
|
- type: pad_action
|
||||||
|
kwargs: {target_dim: 12}
|
||||||
|
- type: pad_state
|
||||||
|
kwargs: {target_dim: 16}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Feature Mapping
|
||||||
|
|
||||||
|
The `feature_map` dictionary renames dataset-local keys into a shared namespace. Keys not listed pass through unchanged. In the example above, all three datasets end up with the same camera key names (`observation.images.front_left`, `observation.images.wrist`) even though they use different conventions internally.
|
||||||
|
|
||||||
|
After mapping, the **union** of all features across datasets defines the unified schema. If a feature exists in some datasets but not others, it is automatically zero-padded for datasets that lack it, and a boolean `{key}_is_pad` flag is added to the sample so the policy can optionally mask padded features.
|
||||||
|
|
||||||
|
## Automatic Padding
|
||||||
|
|
||||||
|
When a sub-dataset doesn't have a feature that exists in the unified schema:
|
||||||
|
|
||||||
|
- **Images/videos**: padded with a black frame (zeros) matching the expected resolution
|
||||||
|
- **Float tensors** (state, action): padded with zeros
|
||||||
|
- **Integer/bool tensors**: padded with zeros / False
|
||||||
|
|
||||||
|
A companion `{key}_is_pad = True` tensor is added so the model can distinguish real data from padding.
|
||||||
|
|
||||||
|
## Per-Dataset Transforms
|
||||||
|
|
||||||
|
Each sub-dataset can have its own `transforms` pipeline that runs after feature renaming but before cross-dataset padding. This is useful for making shapes compatible before PyTorch's collate function stacks the batch.
|
||||||
|
|
||||||
|
### Built-in transforms
|
||||||
|
|
||||||
|
| Name | Description | Parameters |
|
||||||
|
|------|-------------|------------|
|
||||||
|
| `pad_action` | Zero-pad `action` to a target dimension | `target_dim: int` |
|
||||||
|
| `pad_state` | Zero-pad `observation.state` to a target dimension | `target_dim: int` |
|
||||||
|
| `resize_images` | Resize all `observation.images.*` tensors | `height: int`, `width: int` |
|
||||||
|
|
||||||
|
### Custom transforms
|
||||||
|
|
||||||
|
You can register your own transforms in `lerobot/datasets/transforms.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.datasets.transforms import DatasetTransformStep, register_dataset_transform
|
||||||
|
|
||||||
|
@register_dataset_transform("my_transform")
|
||||||
|
class MyTransform(DatasetTransformStep):
|
||||||
|
def __init__(self, some_param: int):
|
||||||
|
self.some_param = some_param
|
||||||
|
|
||||||
|
def __call__(self, sample: dict) -> dict:
|
||||||
|
# Modify sample in-place or return a new dict
|
||||||
|
sample["action"] = sample["action"] * self.some_param
|
||||||
|
return sample
|
||||||
|
```
|
||||||
|
|
||||||
|
Then reference it in the config:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
transforms:
|
||||||
|
- type: my_transform
|
||||||
|
kwargs: {some_param: 2}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Weighted Sampling
|
||||||
|
|
||||||
|
The `weight` field on each sub-dataset controls how often it is sampled during training. Weights are relative and automatically normalized to probabilities. For example, with weights `[1.0, 0.5, 0.3]`, the first dataset is sampled roughly 56% of the time, the second 28%, and the third 16%.
|
||||||
|
|
||||||
|
This uses `WeightedEpisodeAwareSampler`, which respects episode boundaries (so `drop_n_last_frames` and similar policy settings work correctly) while sampling across datasets proportionally.
|
||||||
|
|
||||||
|
## Stats Aggregation
|
||||||
|
|
||||||
|
Normalization statistics (mean, std, min, max, quantiles) are automatically aggregated across all sub-datasets using the mapped feature keys. The aggregation uses a weighted parallel variance algorithm so that datasets with more frames contribute proportionally to the global statistics.
|
||||||
|
|
||||||
|
The aggregated stats are used by the standard LeRobot preprocessor for normalization during training.
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
Launch training the same way as single-dataset training. The factory and training script automatically detect `MultiDatasetConfig` and set up the weighted sampler:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m lerobot.scripts.lerobot_train \
|
||||||
|
--config_path path/to/multi_dataset_config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
The data flow during training with `MultiLeRobotDataset`:
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────┐
|
||||||
|
│ MultiLeRobotDataset.__getitem__(global_idx) │
|
||||||
|
│ │
|
||||||
|
│ 1. Map global_idx → (dataset_idx, local_idx) │
|
||||||
|
│ 2. Fetch sample from sub-dataset │
|
||||||
|
│ 3. Rename keys via feature_map │
|
||||||
|
│ 4. Apply per-dataset transforms (pad_action, etc.) │
|
||||||
|
│ 5. Zero-pad missing features + add _is_pad flags │
|
||||||
|
│ 6. Add dataset_index tag │
|
||||||
|
└─────────────────────┬───────────────────────────────────┘
|
||||||
|
│
|
||||||
|
┌────────────▼────────────┐
|
||||||
|
│ PyTorch DataLoader │
|
||||||
|
│ (collates into batch) │
|
||||||
|
└────────────┬────────────┘
|
||||||
|
│
|
||||||
|
┌────────────▼────────────┐
|
||||||
|
│ LeRobot Preprocessor │
|
||||||
|
│ (normalize, tokenize) │
|
||||||
|
└────────────┬────────────┘
|
||||||
|
│
|
||||||
|
┌────────────▼────────────┐
|
||||||
|
│ Policy forward + loss │
|
||||||
|
└─────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
### `NewMultiLeRobotDataset`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
|
||||||
|
|
||||||
|
dataset = NewMultiLeRobotDataset(
|
||||||
|
configs=[...], # list[SubDatasetConfig]
|
||||||
|
image_transforms=None, # optional image augmentation
|
||||||
|
delta_timestamps=None, # optional temporal neighbors
|
||||||
|
tolerance_s=1e-4, # timestamp tolerance
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset.num_frames # total frames across all sub-datasets
|
||||||
|
dataset.num_episodes # total episodes
|
||||||
|
dataset.meta # MultiDatasetMeta (stats, features, episodes)
|
||||||
|
dataset.dataset_weights # list of per-dataset weights
|
||||||
|
dataset.features # unified feature dict (union of all mapped features)
|
||||||
|
dataset.camera_keys # unified camera key list
|
||||||
|
```
|
||||||
|
|
||||||
|
### `WeightedEpisodeAwareSampler`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.datasets.sampler import WeightedEpisodeAwareSampler
|
||||||
|
|
||||||
|
sampler = WeightedEpisodeAwareSampler(
|
||||||
|
dataset_from_indices=dataset.meta.episodes["dataset_from_index"],
|
||||||
|
dataset_to_indices=dataset.meta.episodes["dataset_to_index"],
|
||||||
|
dataset_membership=dataset.meta.episodes["dataset_source"],
|
||||||
|
dataset_weights=dataset.dataset_weights,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### `DatasetTransformPipeline`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.datasets.transforms import DatasetTransformPipeline, DatasetTransformStepConfig
|
||||||
|
|
||||||
|
pipeline = DatasetTransformPipeline([
|
||||||
|
DatasetTransformStepConfig(type="pad_action", kwargs={"target_dim": 12}),
|
||||||
|
DatasetTransformStepConfig(type="pad_state", kwargs={"target_dim": 16}),
|
||||||
|
])
|
||||||
|
|
||||||
|
sample = pipeline(sample) # modifies the sample dict
|
||||||
|
```
|
||||||
@@ -1,23 +1,49 @@
|
|||||||
# Unitree G1
|
# Unitree G1
|
||||||
|
|
||||||
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
|
<img
|
||||||
|
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/unitree_thumbnail.jpg"
|
||||||
|
alt="Unitree G1 locomanipulation demo"
|
||||||
|
style={{ width: "100%" }}
|
||||||
|
/>
|
||||||
|
|
||||||
## About
|
The Unitree G1 humanoid is now supported in LeRobot! You can teleoperate, train locomanipulation policies, test in sim, and more. Both 29 and 23 DoF variants are supported.
|
||||||
|
|
||||||
We support both 29 and 23 DOF G1 EDU version. We introduce:
|
|
||||||
|
|
||||||
- **`unitree g1` robot class, handling low level read/write from/to the humanoid**
|
|
||||||
- **ZMQ socket bridge** for remote communication and camera streaming, allowing for remote policy deployment over wlan, eth or directly on the robot
|
|
||||||
- **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma
|
|
||||||
- **Simulation mode** for testing policies without the physical robot in mujoco
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Connection guide
|
## Part 1: Getting Started
|
||||||
|
|
||||||
### Step 1: Configure Ethernet Interface
|
### Install LeRobot on Your Machine
|
||||||
|
|
||||||
Set a static IP on the same subnet as the robot:
|
```bash
|
||||||
|
conda create -y -n lerobot python=3.12
|
||||||
|
conda activate lerobot
|
||||||
|
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||||
|
cd unitree_sdk2_python && pip install -e .
|
||||||
|
git clone https://github.com/huggingface/lerobot.git
|
||||||
|
cd lerobot
|
||||||
|
pip install -e '.[unitree_g1]'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Test the Installation (Simulation)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-teleoperate \
|
||||||
|
--robot.type=unitree_g1 \
|
||||||
|
--robot.is_simulation=true \
|
||||||
|
--teleop.type=unitree_g1 \
|
||||||
|
--teleop.id=wbc_unitree \
|
||||||
|
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||||
|
--display_data=true
|
||||||
|
```
|
||||||
|
|
||||||
|
This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1.
|
||||||
|
|
||||||
|
- Press `9` to release the robot
|
||||||
|
- Press `7` / `8` to increase / decrease waist height
|
||||||
|
|
||||||
|
### Connect to the Robot
|
||||||
|
|
||||||
|
The G1's Ethernet IP is fixed at `192.168.123.164`. Your machine must have a static IP on the same subnet: `192.168.123.x` where `x ≠ 164`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
|
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
|
||||||
@@ -26,272 +52,200 @@ sudo ip addr add 192.168.123.200/24 dev enp131s0
|
|||||||
sudo ip link set enp131s0 up
|
sudo ip link set enp131s0 up
|
||||||
```
|
```
|
||||||
|
|
||||||
**Note**: The G1's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` with x ≠ 164.
|
### SSH into the Robot
|
||||||
|
|
||||||
### Step 2: SSH into the Robot
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ssh unitree@192.168.123.164
|
ssh unitree@192.168.123.164
|
||||||
# Password: 123
|
# Password: 123
|
||||||
```
|
```
|
||||||
|
|
||||||
You should now be connected to the G1's Orin.
|
### Install LeRobot on the G1
|
||||||
|
|
||||||
|
From the robot:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda create -y -n lerobot python=3.12
|
||||||
|
conda activate lerobot
|
||||||
|
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||||
|
cd unitree_sdk2_python && pip install -e .
|
||||||
|
git clone https://github.com/huggingface/lerobot.git
|
||||||
|
cd lerobot
|
||||||
|
pip install -e '.[unitree_g1]'
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Note:** The Unitree SDK requires CycloneDDS v0.10.2. See the [Unitree SDK docs](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Part 2: Enable WiFi on the Robot
|
## Part 2: Enable WiFi on the Robot
|
||||||
|
|
||||||
Wlan0 is disabled by default on the G1. To enable it:
|
Wi-Fi connectivity is blocked by default on the G1. To activate:
|
||||||
|
|
||||||
### Step 1: Enable WiFi Hardware
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sudo rfkill unblock wifi
|
|
||||||
sudo rfkill unblock all
|
sudo rfkill unblock all
|
||||||
|
|
||||||
# Bring up wlan0
|
|
||||||
sudo ip link set wlan0 up
|
sudo ip link set wlan0 up
|
||||||
|
|
||||||
# Enable NetworkManager control of wlan0
|
|
||||||
sudo nmcli radio wifi on
|
sudo nmcli radio wifi on
|
||||||
sudo nmcli device set wlan0 managed yes
|
sudo nmcli device set wlan0 managed yes
|
||||||
sudo systemctl restart NetworkManager
|
sudo systemctl restart NetworkManager
|
||||||
```
|
```
|
||||||
|
|
||||||
### Step 2: Enable Internet Forwarding
|
**On your laptop** (share internet via Ethernet):
|
||||||
|
|
||||||
**On your laptop:**
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Enable IP forwarding
|
|
||||||
sudo sysctl -w net.ipv4.ip_forward=1
|
sudo sysctl -w net.ipv4.ip_forward=1
|
||||||
|
|
||||||
# Set up NAT (replace wlp132s0f0 with your WiFi interface)
|
# Replace wlp132s0f0 with your WiFi interface name
|
||||||
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
|
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
|
||||||
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
|
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
|
||||||
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
|
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
|
||||||
```
|
```
|
||||||
|
|
||||||
**On the G1:**
|
**On the G1** (set default route through your laptop):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Add laptop as default gateway
|
|
||||||
sudo ip route del default 2>/dev/null || true
|
sudo ip route del default 2>/dev/null || true
|
||||||
sudo ip route add default via 192.168.123.200 dev eth0
|
sudo ip route add default via 192.168.123.200 dev eth0
|
||||||
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
|
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
|
||||||
|
|
||||||
# Test connection
|
# Verify
|
||||||
ping -c 3 8.8.8.8
|
ping -c 3 8.8.8.8
|
||||||
```
|
```
|
||||||
|
|
||||||
### Step 3: Connect to WiFi Network
|
**Connect to a WiFi network:**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# List available networks
|
|
||||||
nmcli device wifi list
|
nmcli device wifi list
|
||||||
|
|
||||||
# Connect to your WiFi (example)
|
|
||||||
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
|
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
|
||||||
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
|
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
|
||||||
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
|
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
|
||||||
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
|
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
|
||||||
sudo nmcli connection up "YourNetwork"
|
sudo nmcli connection up "YourNetwork"
|
||||||
|
|
||||||
# Check WiFi IP address
|
|
||||||
ip a show wlan0
|
ip a show wlan0
|
||||||
```
|
```
|
||||||
|
|
||||||
### Step 4: SSH Over WiFi
|
You can now SSH over WiFi:
|
||||||
|
|
||||||
Once connected to WiFi, note the robot's IP address and disconnect the Ethernet cable. You can now SSH over WiFi:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ssh unitree@<YOUR_ROBOT_IP>
|
ssh unitree@<ROBOT_WIFI_IP>
|
||||||
# Password: 123
|
# Password: 123
|
||||||
```
|
```
|
||||||
|
|
||||||
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address.
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Part 3: Robot Server Setup
|
## Part 3: Teleoperation & Locomotion
|
||||||
|
|
||||||
### Step 1: Install LeRobot on the Orin
|
### Run the Robot Server
|
||||||
|
|
||||||
SSH into the robot and install LeRobot:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
ssh unitree@<YOUR_ROBOT_IP>
|
|
||||||
|
|
||||||
conda create -y -n lerobot python=3.12
|
|
||||||
conda activate lerobot
|
|
||||||
git clone https://github.com/huggingface/lerobot.git
|
|
||||||
cd lerobot
|
|
||||||
pip install -e '.[unitree_g1]'
|
|
||||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
|
||||||
cd unitree_sdk2_python && pip install -e .
|
|
||||||
```
|
|
||||||
|
|
||||||
**Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
|
|
||||||
|
|
||||||
### Step 2: Run the Robot Server
|
|
||||||
|
|
||||||
On the robot:
|
On the robot:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/lerobot/robots/unitree_g1/run_g1_server.py
|
python src/lerobot/robots/unitree_g1/run_g1_server.py --camera
|
||||||
```
|
```
|
||||||
|
|
||||||
**Important**: Keep this terminal running. The server must be active for remote control.
|
### Run the Locomotion Policy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-teleoperate \
|
||||||
|
--robot.type=unitree_g1 \
|
||||||
|
--robot.is_simulation=false \
|
||||||
|
--robot.robot_ip=<ROBOT_IP> \
|
||||||
|
--teleop.type=unitree_g1 \
|
||||||
|
--teleop.id=wbc_unitree \
|
||||||
|
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "<ROBOT_IP>", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||||
|
--display_data=true \
|
||||||
|
--robot.controller=HolosomaLocomotionController
|
||||||
|
```
|
||||||
|
|
||||||
|
We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl).
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Part 4: Controlling the robot
|
## Part 4: Loco-Manipulation with the Homunculus Exoskeleton
|
||||||
|
|
||||||
With the robot server running, you can now control the robot remotely. Let's launch a locomotion policy
|
We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Assembly instructions [here](https://github.com/nepyope/hmc_exo).
|
||||||
|
|
||||||
### Step 1: Install LeRobot on your machine
|
### Calibrate
|
||||||
|
|
||||||
```bash
|
|
||||||
conda create -y -n lerobot python=3.12
|
|
||||||
conda activate lerobot
|
|
||||||
git clone https://github.com/huggingface/lerobot.git
|
|
||||||
cd lerobot
|
|
||||||
pip install -e '.[unitree_g1]'
|
|
||||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
|
||||||
cd unitree_sdk2_python && pip install -e .
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 2: Update Robot IP in Config
|
|
||||||
|
|
||||||
Edit the config file to match your robot's WiFi IP:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# In src/lerobot/robots/unitree_g1/config_unitree_g1.py
|
|
||||||
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 3: Run the Locomotion Policy
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run GR00T locomotion controller
|
|
||||||
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
|
|
||||||
|
|
||||||
# Run Holosoma locomotion controller
|
|
||||||
python examples/unitree_g1/holosoma_locomotion.py
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
Press `Ctrl+C` to stop the policy.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Running in Simulation Mode (MuJoCo)
|
|
||||||
|
|
||||||
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
|
|
||||||
|
|
||||||
### Calibrate Exoskeleton Teleoperator
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-calibrate \
|
lerobot-calibrate \
|
||||||
--teleop.type=unitree_g1 \
|
--teleop.type=unitree_g1 \
|
||||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||||
--teleop.id=exo
|
--teleop.id=exo
|
||||||
```
|
```
|
||||||
|
|
||||||
### Teleoperate in Simulation
|
During calibration move each joint through its entire range. After fitting, move the joint in a neutral position and press `n` to advance.
|
||||||
|
|
||||||
```bash
|
### Record a Dataset
|
||||||
lerobot-teleoperate \
|
|
||||||
--robot.type=unitree_g1 \
|
|
||||||
--robot.is_simulation=true \
|
|
||||||
--teleop.type=unitree_g1 \
|
|
||||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
|
||||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
|
||||||
--teleop.id=exo \
|
|
||||||
--fps=100
|
|
||||||
```
|
|
||||||
|
|
||||||
### Record Dataset in Simulation
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-record \
|
lerobot-record \
|
||||||
--robot.type=unitree_g1 \
|
--robot.type=unitree_g1 \
|
||||||
--robot.is_simulation=true \
|
--robot.is_simulation=true \
|
||||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||||
--teleop.type=unitree_g1 \
|
--teleop.type=unitree_g1 \
|
||||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||||
--teleop.id=exo \
|
--teleop.id=exo \
|
||||||
--dataset.repo_id=your-username/dataset-name \
|
--dataset.repo_id=your-username/dataset-name \
|
||||||
--dataset.single_task="Test" \
|
--dataset.single_task="Test" \
|
||||||
--dataset.num_episodes=2 \
|
--dataset.num_episodes=2 \
|
||||||
--dataset.episode_time_s=5 \
|
--dataset.episode_time_s=5 \
|
||||||
--dataset.reset_time_s=5 \
|
--dataset.reset_time_s=5 \
|
||||||
--dataset.push_to_hub=true \
|
--dataset.push_to_hub=true \
|
||||||
--dataset.streaming_encoding=true \
|
--dataset.streaming_encoding=true \
|
||||||
# --dataset.vcodec=auto \
|
--dataset.encoder_threads=2
|
||||||
--dataset.encoder_threads=2
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
|
> **Note:** Omit `--teleop.left_arm_config.port` and `--teleop.right_arm_config.port` if you're only using the joystick.
|
||||||
|
|
||||||
|
Example dataset: [nepyope/unitree_box_move_blue_full](https://huggingface.co/datasets/nepyope/unitree_box_move_blue_full)
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Running on Real Robot
|
## Part 5: Training & Inference
|
||||||
|
|
||||||
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
|
### Train
|
||||||
|
|
||||||
### Start the Camera Server
|
|
||||||
|
|
||||||
On the robot, start the ZMQ image server:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/lerobot/cameras/zmq/image_server.py
|
python src/lerobot/scripts/lerobot_train.py \
|
||||||
|
--dataset.repo_id=your-username/dataset-name \
|
||||||
|
--policy.type=pi05 \
|
||||||
|
--output_dir=./outputs/pi05_training \
|
||||||
|
--job_name=pi05_training \
|
||||||
|
--policy.repo_id=your-username/your-repo-id \
|
||||||
|
--policy.pretrained_path=lerobot/pi05_base \
|
||||||
|
--policy.compile_model=true \
|
||||||
|
--policy.gradient_checkpointing=true \
|
||||||
|
--wandb.enable=true \
|
||||||
|
--policy.dtype=bfloat16 \
|
||||||
|
--policy.freeze_vision_encoder=false \
|
||||||
|
--policy.train_expert_only=false \
|
||||||
|
--steps=3000 \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--batch_size=32
|
||||||
```
|
```
|
||||||
|
|
||||||
Keep this running in a separate terminal for camera streaming during recording.
|
### Inference with RTC
|
||||||
|
|
||||||
### Teleoperate Real Robot
|
Once trained, we recommend deploying policies using inference-time RTC:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-teleoperate \
|
python examples/rtc/eval_with_real_robot.py \
|
||||||
--robot.type=unitree_g1 \
|
--policy.path=your-username/your-repo-id \
|
||||||
--robot.is_simulation=false \
|
--policy.device=cuda \
|
||||||
--teleop.type=unitree_g1 \
|
--robot.type=unitree_g1 \
|
||||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
--robot.is_simulation=false \
|
||||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
--robot.controller=HolosomaLocomotionController \
|
||||||
--teleop.id=exo \
|
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "<ROBOT_IP>", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||||
--fps=100
|
--task="task_description" \
|
||||||
|
--duration=1000 \
|
||||||
|
--fps=30 \
|
||||||
|
--rtc.enabled=true
|
||||||
```
|
```
|
||||||
|
|
||||||
### Record Dataset on Real Robot
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-record \
|
|
||||||
--robot.type=unitree_g1 \
|
|
||||||
--robot.is_simulation=false \
|
|
||||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
|
||||||
--teleop.type=unitree_g1 \
|
|
||||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
|
||||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
|
||||||
--teleop.id=exo \
|
|
||||||
--dataset.repo_id=your-username/dataset-name \
|
|
||||||
--dataset.single_task="Test" \
|
|
||||||
--dataset.num_episodes=2 \
|
|
||||||
--dataset.episode_time_s=5 \
|
|
||||||
--dataset.reset_time_s=5 \
|
|
||||||
--dataset.push_to_hub=true \
|
|
||||||
--dataset.streaming_encoding=true \
|
|
||||||
# --dataset.vcodec=auto \
|
|
||||||
--dataset.encoder_threads=2
|
|
||||||
```
|
|
||||||
|
|
||||||
**Note**: Update `server_address` to match your robot's camera server IP.
|
|
||||||
|
|
||||||
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Additional Resources
|
## Additional Resources
|
||||||
@@ -300,8 +254,8 @@ Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/da
|
|||||||
- [GR00T-WholeBodyControl](https://github.com/NVlabs/GR00T-WholeBodyControl)
|
- [GR00T-WholeBodyControl](https://github.com/NVlabs/GR00T-WholeBodyControl)
|
||||||
- [Holosoma](https://github.com/amazon-far/holosoma)
|
- [Holosoma](https://github.com/amazon-far/holosoma)
|
||||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||||
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
|
- [Unitree IL LeRobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
_Last updated: December 2025_
|
_Last updated: March 2026_
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ from torch import Tensor
|
|||||||
|
|
||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||||
|
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.types import RTCAttentionSchedule
|
from lerobot.configs.types import RTCAttentionSchedule
|
||||||
@@ -97,6 +98,7 @@ from lerobot.robots import ( # noqa: F401
|
|||||||
bi_so_follower,
|
bi_so_follower,
|
||||||
koch_follower,
|
koch_follower,
|
||||||
so_follower,
|
so_follower,
|
||||||
|
unitree_g1,
|
||||||
)
|
)
|
||||||
from lerobot.robots.utils import make_robot_from_config
|
from lerobot.robots.utils import make_robot_from_config
|
||||||
from lerobot.utils.constants import OBS_IMAGES
|
from lerobot.utils.constants import OBS_IMAGES
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "lerobot"
|
name = "lerobot"
|
||||||
version = "0.4.5"
|
version = "0.5.1"
|
||||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||||
dynamic = ["readme"]
|
dynamic = ["readme"]
|
||||||
license = { text = "Apache-2.0" }
|
license = { text = "Apache-2.0" }
|
||||||
@@ -119,11 +119,13 @@ gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
|||||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||||
unitree_g1 = [
|
unitree_g1 = [
|
||||||
|
"unitree-sdk2==1.0.1",
|
||||||
"pyzmq>=26.2.1,<28.0.0",
|
"pyzmq>=26.2.1,<28.0.0",
|
||||||
"onnxruntime>=1.16.0,<2.0.0",
|
"onnxruntime>=1.16.0,<2.0.0",
|
||||||
"pin>=3.0.0,<4.0.0",
|
"pin>=3.0.0,<4.0.0",
|
||||||
"meshcat>=0.3.0,<0.4.0",
|
"meshcat>=0.3.0,<0.4.0",
|
||||||
"lerobot[matplotlib-dep]",
|
"lerobot[matplotlib-dep]",
|
||||||
|
"lerobot[pygame-dep]",
|
||||||
"casadi>=3.6.0,<4.0.0",
|
"casadi>=3.6.0,<4.0.0",
|
||||||
]
|
]
|
||||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||||
@@ -173,6 +175,14 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
|||||||
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||||
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||||
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||||
|
libero_plus = [
|
||||||
|
"lerobot[transformers-dep]",
|
||||||
|
"libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'",
|
||||||
|
"lerobot[scipy-dep]",
|
||||||
|
]
|
||||||
|
robomme = [
|
||||||
|
"robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main ; sys_platform == 'linux'",
|
||||||
|
]
|
||||||
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
|
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||||
|
|
||||||
# All
|
# All
|
||||||
@@ -206,6 +216,7 @@ all = [
|
|||||||
"lerobot[metaworld]",
|
"lerobot[metaworld]",
|
||||||
"lerobot[sarm]",
|
"lerobot[sarm]",
|
||||||
"lerobot[peft]",
|
"lerobot[peft]",
|
||||||
|
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
@@ -1,76 +1,73 @@
|
|||||||
#
|
#
|
||||||
# This file is autogenerated by pip-compile with Python 3.10
|
# This file is autogenerated by pip-compile with Python 3.12
|
||||||
# by the following command:
|
# by the following command:
|
||||||
#
|
#
|
||||||
# pip-compile --output-file=requirements-macos.txt requirements.in
|
# pip-compile --output-file=requirements-macos.txt requirements.in
|
||||||
#
|
#
|
||||||
-e .[all]
|
-e .[all]
|
||||||
# via -[all]
|
# via -[all]
|
||||||
absl-py==2.3.1
|
absl-py==2.4.0
|
||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
# dm-env
|
# dm-env
|
||||||
# dm-tree
|
# dm-tree
|
||||||
# labmaze
|
# labmaze
|
||||||
# mujoco
|
# mujoco
|
||||||
# tensorboard
|
accelerate==1.13.0
|
||||||
accelerate==1.11.0
|
|
||||||
# via
|
# via
|
||||||
# lerobot
|
# lerobot
|
||||||
# peft
|
# peft
|
||||||
aiohappyeyeballs==2.6.1
|
aiohappyeyeballs==2.6.1
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
aiohttp==3.13.1
|
aiohttp==3.13.3
|
||||||
# via fsspec
|
# via fsspec
|
||||||
aiosignal==1.4.0
|
aiosignal==1.4.0
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
|
annotated-doc==0.0.4
|
||||||
|
# via
|
||||||
|
# fastapi
|
||||||
|
# typer
|
||||||
annotated-types==0.7.0
|
annotated-types==0.7.0
|
||||||
# via pydantic
|
# via pydantic
|
||||||
antlr4-python3-runtime==4.9.3
|
anyio==4.12.1
|
||||||
# via
|
|
||||||
# hydra-core
|
|
||||||
# omegaconf
|
|
||||||
anyio==4.11.0
|
|
||||||
# via
|
# via
|
||||||
|
# httpx
|
||||||
# starlette
|
# starlette
|
||||||
# watchfiles
|
# watchfiles
|
||||||
asttokens==3.0.0
|
asttokens==3.0.1
|
||||||
# via stack-data
|
# via stack-data
|
||||||
async-timeout==5.0.1
|
|
||||||
# via aiohttp
|
|
||||||
attrs==25.4.0
|
attrs==25.4.0
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
# dm-tree
|
# dm-tree
|
||||||
# jsonlines
|
# jsonlines
|
||||||
# jsonschema
|
|
||||||
# referencing
|
|
||||||
# rerun-sdk
|
# rerun-sdk
|
||||||
av==15.1.0
|
av==15.1.0
|
||||||
# via lerobot
|
|
||||||
bddl==1.0.1
|
|
||||||
# via libero
|
|
||||||
certifi==2025.10.5
|
|
||||||
# via
|
# via
|
||||||
|
# lerobot
|
||||||
|
# qwen-vl-utils
|
||||||
|
certifi==2026.2.25
|
||||||
|
# via
|
||||||
|
# httpcore
|
||||||
|
# httpx
|
||||||
# requests
|
# requests
|
||||||
# sentry-sdk
|
# sentry-sdk
|
||||||
cffi==2.0.0
|
cffi==2.0.0
|
||||||
# via pymunk
|
# via pymunk
|
||||||
cfgv==3.4.0
|
cfgv==3.5.0
|
||||||
# via pre-commit
|
# via pre-commit
|
||||||
charset-normalizer==3.4.4
|
charset-normalizer==3.4.5
|
||||||
# via requests
|
# via requests
|
||||||
click==8.3.0
|
click==8.3.1
|
||||||
# via
|
# via
|
||||||
|
# typer
|
||||||
# uvicorn
|
# uvicorn
|
||||||
# wandb
|
# wandb
|
||||||
cloudpickle==3.1.1
|
cloudpickle==3.1.2
|
||||||
# via
|
# via gymnasium
|
||||||
# gymnasium
|
cmake==4.1.3
|
||||||
# libero
|
|
||||||
cmake==4.1.0
|
|
||||||
# via lerobot
|
# via lerobot
|
||||||
cmeel==0.57.3
|
cmeel==0.59.0
|
||||||
# via
|
# via
|
||||||
# cmeel-assimp
|
# cmeel-assimp
|
||||||
# cmeel-boost
|
# cmeel-boost
|
||||||
@@ -108,15 +105,17 @@ cmeel-zlib==1.3.1
|
|||||||
# via cmeel-assimp
|
# via cmeel-assimp
|
||||||
coal-library==3.0.1
|
coal-library==3.0.1
|
||||||
# via pin
|
# via pin
|
||||||
contourpy==1.3.2
|
contourpy==1.3.3
|
||||||
# via matplotlib
|
# via
|
||||||
coverage[toml]==7.11.0
|
# lerobot
|
||||||
|
# matplotlib
|
||||||
|
coverage[toml]==7.13.4
|
||||||
# via pytest-cov
|
# via pytest-cov
|
||||||
cycler==0.12.1
|
cycler==0.12.1
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
datasets==4.1.1
|
datasets==4.6.1
|
||||||
# via lerobot
|
# via lerobot
|
||||||
debugpy==1.8.17
|
debugpy==1.8.20
|
||||||
# via lerobot
|
# via lerobot
|
||||||
decorator==5.2.1
|
decorator==5.2.1
|
||||||
# via ipython
|
# via ipython
|
||||||
@@ -130,7 +129,7 @@ dill==0.4.0
|
|||||||
# multiprocess
|
# multiprocess
|
||||||
distlib==0.4.0
|
distlib==0.4.0
|
||||||
# via virtualenv
|
# via virtualenv
|
||||||
dm-control==1.0.34
|
dm-control==1.0.37
|
||||||
# via gym-aloha
|
# via gym-aloha
|
||||||
dm-env==1.6
|
dm-env==1.6
|
||||||
# via dm-control
|
# via dm-control
|
||||||
@@ -138,69 +137,55 @@ dm-tree==0.1.9
|
|||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
# dm-env
|
# dm-env
|
||||||
# lerobot
|
|
||||||
docopt==0.6.2
|
docopt==0.6.2
|
||||||
# via num2words
|
# via num2words
|
||||||
draccus==0.10.0
|
draccus==0.10.0
|
||||||
# via lerobot
|
# via lerobot
|
||||||
dynamixel-sdk==3.8.4
|
dynamixel-sdk==3.8.4
|
||||||
# via lerobot
|
# via lerobot
|
||||||
easydict==1.13
|
|
||||||
# via libero
|
|
||||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
|
||||||
# via
|
|
||||||
# libero
|
|
||||||
# robomimic
|
|
||||||
eigenpy==3.10.3
|
eigenpy==3.10.3
|
||||||
# via coal-library
|
# via coal-library
|
||||||
einops==0.8.1
|
einops==0.8.2
|
||||||
# via
|
# via lerobot
|
||||||
# lerobot
|
|
||||||
# libero
|
|
||||||
eiquadprog==1.2.9
|
eiquadprog==1.2.9
|
||||||
# via placo
|
# via placo
|
||||||
etils[epath,epy]==1.13.0
|
etils[epath,epy]==1.14.0
|
||||||
# via mujoco
|
# via mujoco
|
||||||
exceptiongroup==1.3.0
|
|
||||||
# via
|
|
||||||
# anyio
|
|
||||||
# ipython
|
|
||||||
# pytest
|
|
||||||
executing==2.2.1
|
executing==2.2.1
|
||||||
# via stack-data
|
# via stack-data
|
||||||
|
faker==34.0.2
|
||||||
|
# via lerobot
|
||||||
farama-notifications==0.0.4
|
farama-notifications==0.0.4
|
||||||
# via gymnasium
|
# via gymnasium
|
||||||
fastapi==0.119.1
|
fastapi==0.135.1
|
||||||
# via teleop
|
# via
|
||||||
fastjsonschema==2.21.2
|
# lerobot
|
||||||
# via nbformat
|
# teleop
|
||||||
feetech-servo-sdk==1.0.0
|
feetech-servo-sdk==1.0.0
|
||||||
# via lerobot
|
# via lerobot
|
||||||
filelock==3.20.0
|
filelock==3.25.0
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
# diffusers
|
# diffusers
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
|
# python-discovery
|
||||||
# torch
|
# torch
|
||||||
# transformers
|
|
||||||
# virtualenv
|
# virtualenv
|
||||||
fonttools==4.60.1
|
fonttools==4.61.1
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
frozenlist==1.8.0
|
frozenlist==1.8.0
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
# aiosignal
|
# aiosignal
|
||||||
fsspec[http]==2025.9.0
|
fsspec[http]==2026.2.0
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
# etils
|
# etils
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
# torch
|
# torch
|
||||||
future==1.0.0
|
|
||||||
# via libero
|
|
||||||
gitdb==4.0.12
|
gitdb==4.0.12
|
||||||
# via gitpython
|
# via gitpython
|
||||||
gitpython==3.1.45
|
gitpython==3.1.46
|
||||||
# via wandb
|
# via wandb
|
||||||
glfw==2.10.0
|
glfw==2.10.0
|
||||||
# via
|
# via
|
||||||
@@ -212,7 +197,6 @@ grpcio==1.73.1
|
|||||||
# lerobot
|
# lerobot
|
||||||
# reachy2-sdk
|
# reachy2-sdk
|
||||||
# reachy2-sdk-api
|
# reachy2-sdk-api
|
||||||
# tensorboard
|
|
||||||
grpcio-tools==1.73.1
|
grpcio-tools==1.73.1
|
||||||
# via
|
# via
|
||||||
# lerobot
|
# lerobot
|
||||||
@@ -223,71 +207,67 @@ gym-hil==0.1.13
|
|||||||
# via lerobot
|
# via lerobot
|
||||||
gym-pusht==0.1.6
|
gym-pusht==0.1.6
|
||||||
# via lerobot
|
# via lerobot
|
||||||
gymnasium==1.2.1
|
gymnasium==1.2.3
|
||||||
# via
|
# via
|
||||||
# gym-aloha
|
# gym-aloha
|
||||||
# gym-hil
|
# gym-hil
|
||||||
# gym-pusht
|
# gym-pusht
|
||||||
# lerobot
|
# lerobot
|
||||||
# libero
|
|
||||||
# metaworld
|
# metaworld
|
||||||
h11==0.16.0
|
h11==0.16.0
|
||||||
# via uvicorn
|
# via
|
||||||
h5py==3.15.1
|
# httpcore
|
||||||
# via robomimic
|
# uvicorn
|
||||||
hebi-py==2.11.0
|
hebi-py==2.11.0
|
||||||
# via lerobot
|
# via lerobot
|
||||||
hf-transfer==0.1.9
|
hf-xet==1.3.2
|
||||||
# via huggingface-hub
|
|
||||||
hf-xet==1.1.10
|
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
hidapi==0.14.0.post4
|
hidapi==0.14.0.post4
|
||||||
# via
|
# via
|
||||||
# gym-hil
|
# gym-hil
|
||||||
# lerobot
|
# lerobot
|
||||||
|
httpcore==1.0.9
|
||||||
|
# via httpx
|
||||||
httptools==0.7.1
|
httptools==0.7.1
|
||||||
# via uvicorn
|
# via uvicorn
|
||||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
httpx==0.28.1
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# huggingface-hub
|
||||||
|
huggingface-hub==1.6.0
|
||||||
# via
|
# via
|
||||||
# accelerate
|
# accelerate
|
||||||
# datasets
|
# datasets
|
||||||
# diffusers
|
# diffusers
|
||||||
# lerobot
|
# lerobot
|
||||||
# peft
|
# peft
|
||||||
# timm
|
|
||||||
# tokenizers
|
# tokenizers
|
||||||
# transformers
|
# transformers
|
||||||
hydra-core==1.3.2
|
identify==2.6.17
|
||||||
# via libero
|
|
||||||
identify==2.6.15
|
|
||||||
# via pre-commit
|
# via pre-commit
|
||||||
idna==3.11
|
idna==3.11
|
||||||
# via
|
# via
|
||||||
# anyio
|
# anyio
|
||||||
|
# httpx
|
||||||
# requests
|
# requests
|
||||||
# yarl
|
# yarl
|
||||||
imageio[ffmpeg]==2.37.0
|
imageio[ffmpeg]==2.37.2
|
||||||
# via
|
# via
|
||||||
# gym-aloha
|
# gym-aloha
|
||||||
# gym-hil
|
# gym-hil
|
||||||
# lerobot
|
# lerobot
|
||||||
# metaworld
|
# metaworld
|
||||||
# robomimic
|
|
||||||
# scikit-image
|
# scikit-image
|
||||||
imageio-ffmpeg==0.6.0
|
imageio-ffmpeg==0.6.0
|
||||||
# via
|
# via imageio
|
||||||
# imageio
|
importlib-metadata==8.7.1
|
||||||
# robomimic
|
|
||||||
importlib-metadata==8.7.0
|
|
||||||
# via diffusers
|
# via diffusers
|
||||||
importlib-resources==6.5.2
|
|
||||||
# via etils
|
|
||||||
iniconfig==2.3.0
|
iniconfig==2.3.0
|
||||||
# via pytest
|
# via pytest
|
||||||
inquirerpy==0.3.4
|
ipython==9.11.0
|
||||||
# via huggingface-hub
|
|
||||||
ipython==8.37.0
|
|
||||||
# via meshcat
|
# via meshcat
|
||||||
|
ipython-pygments-lexers==1.1.1
|
||||||
|
# via ipython
|
||||||
ischedule==1.2.7
|
ischedule==1.2.7
|
||||||
# via placo
|
# via placo
|
||||||
jedi==0.19.2
|
jedi==0.19.2
|
||||||
@@ -296,44 +276,24 @@ jinja2==3.1.6
|
|||||||
# via torch
|
# via torch
|
||||||
jsonlines==4.0.0
|
jsonlines==4.0.0
|
||||||
# via lerobot
|
# via lerobot
|
||||||
jsonschema==4.25.1
|
|
||||||
# via nbformat
|
|
||||||
jsonschema-specifications==2025.9.1
|
|
||||||
# via jsonschema
|
|
||||||
jupyter-core==5.9.1
|
|
||||||
# via nbformat
|
|
||||||
jupytext==1.18.1
|
|
||||||
# via bddl
|
|
||||||
kiwisolver==1.4.9
|
kiwisolver==1.4.9
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
labmaze==1.0.6
|
labmaze==1.0.6
|
||||||
# via dm-control
|
# via dm-control
|
||||||
lazy-loader==0.4
|
lazy-loader==0.5
|
||||||
# via scikit-image
|
# via scikit-image
|
||||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
librt==0.8.1
|
||||||
# via lerobot
|
# via mypy
|
||||||
llvmlite==0.45.1
|
|
||||||
# via numba
|
|
||||||
lxml==6.0.2
|
lxml==6.0.2
|
||||||
# via dm-control
|
# via dm-control
|
||||||
markdown==3.9
|
|
||||||
# via tensorboard
|
|
||||||
markdown-it-py==4.0.0
|
markdown-it-py==4.0.0
|
||||||
# via
|
# via rich
|
||||||
# jupytext
|
|
||||||
# mdit-py-plugins
|
|
||||||
markupsafe==3.0.3
|
markupsafe==3.0.3
|
||||||
# via
|
# via jinja2
|
||||||
# jinja2
|
matplotlib==3.10.8
|
||||||
# werkzeug
|
# via lerobot
|
||||||
matplotlib==3.10.7
|
|
||||||
# via
|
|
||||||
# lerobot
|
|
||||||
# libero
|
|
||||||
matplotlib-inline==0.2.1
|
matplotlib-inline==0.2.1
|
||||||
# via ipython
|
# via ipython
|
||||||
mdit-py-plugins==0.5.0
|
|
||||||
# via jupytext
|
|
||||||
mdurl==0.1.2
|
mdurl==0.1.2
|
||||||
# via markdown-it-py
|
# via markdown-it-py
|
||||||
mergedeep==1.3.4
|
mergedeep==1.3.4
|
||||||
@@ -346,41 +306,35 @@ mock-serial==0.0.1
|
|||||||
# via lerobot
|
# via lerobot
|
||||||
mpmath==1.3.0
|
mpmath==1.3.0
|
||||||
# via sympy
|
# via sympy
|
||||||
mujoco==3.3.7
|
mujoco==3.5.0
|
||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
# gym-aloha
|
# gym-aloha
|
||||||
# gym-hil
|
# gym-hil
|
||||||
# libero
|
|
||||||
# metaworld
|
# metaworld
|
||||||
# robosuite
|
multidict==6.7.1
|
||||||
multidict==6.7.0
|
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
# yarl
|
# yarl
|
||||||
multiprocess==0.70.16
|
multiprocess==0.70.18
|
||||||
# via datasets
|
# via datasets
|
||||||
|
mypy==1.19.1
|
||||||
|
# via lerobot
|
||||||
mypy-extensions==1.1.0
|
mypy-extensions==1.1.0
|
||||||
# via typing-inspect
|
|
||||||
nbformat==5.10.4
|
|
||||||
# via jupytext
|
|
||||||
networkx==3.4.2
|
|
||||||
# via
|
# via
|
||||||
# bddl
|
# mypy
|
||||||
|
# typing-inspect
|
||||||
|
networkx==3.6.1
|
||||||
|
# via
|
||||||
# scikit-image
|
# scikit-image
|
||||||
# torch
|
# torch
|
||||||
ninja==1.13.0
|
nodeenv==1.10.0
|
||||||
# via lerobot
|
|
||||||
nodeenv==1.9.1
|
|
||||||
# via pre-commit
|
# via pre-commit
|
||||||
num2words==0.5.14
|
num2words==0.5.14
|
||||||
# via lerobot
|
# via lerobot
|
||||||
numba==0.62.1
|
|
||||||
# via robosuite
|
|
||||||
numpy==2.2.6
|
numpy==2.2.6
|
||||||
# via
|
# via
|
||||||
# accelerate
|
# accelerate
|
||||||
# bddl
|
|
||||||
# cmeel-boost
|
# cmeel-boost
|
||||||
# contourpy
|
# contourpy
|
||||||
# datasets
|
# datasets
|
||||||
@@ -389,16 +343,14 @@ numpy==2.2.6
|
|||||||
# dm-env
|
# dm-env
|
||||||
# dm-tree
|
# dm-tree
|
||||||
# gymnasium
|
# gymnasium
|
||||||
# h5py
|
|
||||||
# hebi-py
|
# hebi-py
|
||||||
# imageio
|
# imageio
|
||||||
# labmaze
|
# labmaze
|
||||||
# libero
|
# lerobot
|
||||||
# matplotlib
|
# matplotlib
|
||||||
# meshcat
|
# meshcat
|
||||||
# metaworld
|
# metaworld
|
||||||
# mujoco
|
# mujoco
|
||||||
# numba
|
|
||||||
# opencv-python
|
# opencv-python
|
||||||
# opencv-python-headless
|
# opencv-python-headless
|
||||||
# pandas
|
# pandas
|
||||||
@@ -406,26 +358,18 @@ numpy==2.2.6
|
|||||||
# pyquaternion
|
# pyquaternion
|
||||||
# reachy2-sdk
|
# reachy2-sdk
|
||||||
# rerun-sdk
|
# rerun-sdk
|
||||||
# robomimic
|
|
||||||
# robosuite
|
|
||||||
# scikit-image
|
# scikit-image
|
||||||
# scipy
|
# scipy
|
||||||
# shapely
|
# shapely
|
||||||
# teleop
|
# teleop
|
||||||
# tensorboard
|
|
||||||
# tensorboardx
|
|
||||||
# tifffile
|
# tifffile
|
||||||
# torchvision
|
# torchvision
|
||||||
# transformers
|
# transformers
|
||||||
# transforms3d
|
# transforms3d
|
||||||
omegaconf==2.3.0
|
opencv-python==4.13.0.92
|
||||||
# via hydra-core
|
|
||||||
opencv-python==4.12.0.88
|
|
||||||
# via
|
# via
|
||||||
# gym-pusht
|
# gym-pusht
|
||||||
# libero
|
|
||||||
# reachy2-sdk
|
# reachy2-sdk
|
||||||
# robosuite
|
|
||||||
opencv-python-headless==4.12.0.88
|
opencv-python-headless==4.12.0.88
|
||||||
# via lerobot
|
# via lerobot
|
||||||
orderly-set==5.5.0
|
orderly-set==5.5.0
|
||||||
@@ -435,97 +379,87 @@ packaging==25.0
|
|||||||
# accelerate
|
# accelerate
|
||||||
# datasets
|
# datasets
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
# hydra-core
|
|
||||||
# jupytext
|
|
||||||
# lazy-loader
|
# lazy-loader
|
||||||
# lerobot
|
# lerobot
|
||||||
# matplotlib
|
# matplotlib
|
||||||
# peft
|
# peft
|
||||||
# pytest
|
# pytest
|
||||||
|
# qwen-vl-utils
|
||||||
# reachy2-sdk
|
# reachy2-sdk
|
||||||
# scikit-image
|
# scikit-image
|
||||||
# tensorboard
|
|
||||||
# tensorboardx
|
|
||||||
# transformers
|
# transformers
|
||||||
# wandb
|
# wandb
|
||||||
pandas==2.3.3
|
pandas==2.3.3
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
# lerobot
|
# lerobot
|
||||||
parso==0.8.5
|
parso==0.8.6
|
||||||
# via jedi
|
# via jedi
|
||||||
peft==0.17.1
|
pathspec==1.0.4
|
||||||
|
# via mypy
|
||||||
|
peft==0.18.1
|
||||||
# via lerobot
|
# via lerobot
|
||||||
pexpect==4.9.0
|
pexpect==4.9.0
|
||||||
# via ipython
|
# via ipython
|
||||||
pfzy==0.3.4
|
pillow==12.1.1
|
||||||
# via inquirerpy
|
|
||||||
pillow==12.0.0
|
|
||||||
# via
|
# via
|
||||||
# diffusers
|
# diffusers
|
||||||
# imageio
|
# imageio
|
||||||
# lerobot
|
|
||||||
# matplotlib
|
# matplotlib
|
||||||
# meshcat
|
# meshcat
|
||||||
|
# qwen-vl-utils
|
||||||
# rerun-sdk
|
# rerun-sdk
|
||||||
# robosuite
|
|
||||||
# scikit-image
|
# scikit-image
|
||||||
# tensorboard
|
|
||||||
# torchvision
|
# torchvision
|
||||||
pin==3.4.0
|
pin==3.4.0
|
||||||
# via placo
|
# via placo
|
||||||
placo==0.9.14
|
placo==0.9.16
|
||||||
# via lerobot
|
# via lerobot
|
||||||
platformdirs==4.5.0
|
platformdirs==4.9.4
|
||||||
# via
|
# via
|
||||||
# jupyter-core
|
# python-discovery
|
||||||
# virtualenv
|
# virtualenv
|
||||||
# wandb
|
# wandb
|
||||||
pluggy==1.6.0
|
pluggy==1.6.0
|
||||||
# via
|
# via
|
||||||
# pytest
|
# pytest
|
||||||
# pytest-cov
|
# pytest-cov
|
||||||
pre-commit==4.3.0
|
pre-commit==4.5.1
|
||||||
# via lerobot
|
# via lerobot
|
||||||
prompt-toolkit==3.0.52
|
prompt-toolkit==3.0.52
|
||||||
# via
|
# via ipython
|
||||||
# inquirerpy
|
|
||||||
# ipython
|
|
||||||
propcache==0.4.1
|
propcache==0.4.1
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
# yarl
|
# yarl
|
||||||
protobuf==6.31.0
|
protobuf==6.31.1
|
||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
# grpcio-tools
|
# grpcio-tools
|
||||||
# lerobot
|
# lerobot
|
||||||
# reachy2-sdk
|
# reachy2-sdk
|
||||||
# reachy2-sdk-api
|
# reachy2-sdk-api
|
||||||
# tensorboard
|
|
||||||
# tensorboardx
|
|
||||||
# wandb
|
# wandb
|
||||||
psutil==7.1.1
|
psutil==7.2.2
|
||||||
# via
|
# via
|
||||||
# accelerate
|
# accelerate
|
||||||
# imageio
|
# imageio
|
||||||
# peft
|
# peft
|
||||||
# robomimic
|
|
||||||
ptyprocess==0.7.0
|
ptyprocess==0.7.0
|
||||||
# via pexpect
|
# via pexpect
|
||||||
pure-eval==0.2.3
|
pure-eval==0.2.3
|
||||||
# via stack-data
|
# via stack-data
|
||||||
pyarrow==21.0.0
|
pyarrow==23.0.1
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
# rerun-sdk
|
# rerun-sdk
|
||||||
pycparser==2.23
|
pycparser==3.0
|
||||||
# via cffi
|
# via cffi
|
||||||
pydantic==2.12.3
|
pydantic==2.12.5
|
||||||
# via
|
# via
|
||||||
# fastapi
|
# fastapi
|
||||||
# wandb
|
# wandb
|
||||||
pydantic-core==2.41.4
|
pydantic-core==2.41.5
|
||||||
# via pydantic
|
# via pydantic
|
||||||
pygame==2.6.1
|
pygame==2.6.1
|
||||||
# via
|
# via
|
||||||
@@ -535,33 +469,35 @@ pygame==2.6.1
|
|||||||
pygments==2.19.2
|
pygments==2.19.2
|
||||||
# via
|
# via
|
||||||
# ipython
|
# ipython
|
||||||
|
# ipython-pygments-lexers
|
||||||
# pytest
|
# pytest
|
||||||
|
# rich
|
||||||
pymunk==6.11.1
|
pymunk==6.11.1
|
||||||
# via
|
# via
|
||||||
# gym-pusht
|
# gym-pusht
|
||||||
# lerobot
|
# lerobot
|
||||||
pyngrok==7.4.1
|
pyngrok==7.5.1
|
||||||
# via meshcat
|
# via meshcat
|
||||||
pynput==1.8.1
|
pynput==1.8.1
|
||||||
# via
|
# via
|
||||||
# gym-hil
|
# gym-hil
|
||||||
# lerobot
|
# lerobot
|
||||||
pyobjc-core==12.0
|
pyobjc-core==12.1
|
||||||
# via
|
# via
|
||||||
# pyobjc-framework-applicationservices
|
# pyobjc-framework-applicationservices
|
||||||
# pyobjc-framework-cocoa
|
# pyobjc-framework-cocoa
|
||||||
# pyobjc-framework-coretext
|
# pyobjc-framework-coretext
|
||||||
# pyobjc-framework-quartz
|
# pyobjc-framework-quartz
|
||||||
pyobjc-framework-applicationservices==12.0
|
pyobjc-framework-applicationservices==12.1
|
||||||
# via pynput
|
# via pynput
|
||||||
pyobjc-framework-cocoa==12.0
|
pyobjc-framework-cocoa==12.1
|
||||||
# via
|
# via
|
||||||
# pyobjc-framework-applicationservices
|
# pyobjc-framework-applicationservices
|
||||||
# pyobjc-framework-coretext
|
# pyobjc-framework-coretext
|
||||||
# pyobjc-framework-quartz
|
# pyobjc-framework-quartz
|
||||||
pyobjc-framework-coretext==12.0
|
pyobjc-framework-coretext==12.1
|
||||||
# via pyobjc-framework-applicationservices
|
# via pyobjc-framework-applicationservices
|
||||||
pyobjc-framework-quartz==12.0
|
pyobjc-framework-quartz==12.1
|
||||||
# via
|
# via
|
||||||
# pynput
|
# pynput
|
||||||
# pyobjc-framework-applicationservices
|
# pyobjc-framework-applicationservices
|
||||||
@@ -570,13 +506,13 @@ pyopengl==3.1.10
|
|||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
# mujoco
|
# mujoco
|
||||||
pyparsing==3.2.5
|
pyparsing==3.3.2
|
||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
# matplotlib
|
# matplotlib
|
||||||
pyquaternion==0.9.9
|
pyquaternion==0.9.9
|
||||||
# via reachy2-sdk
|
# via reachy2-sdk
|
||||||
pyrealsense2-macosx==2.54.2
|
pyrealsense2-macosx==2.56.5
|
||||||
# via lerobot
|
# via lerobot
|
||||||
pyserial==3.5
|
pyserial==3.5
|
||||||
# via
|
# via
|
||||||
@@ -585,7 +521,6 @@ pyserial==3.5
|
|||||||
# lerobot
|
# lerobot
|
||||||
pytest==8.4.2
|
pytest==8.4.2
|
||||||
# via
|
# via
|
||||||
# bddl
|
|
||||||
# lerobot
|
# lerobot
|
||||||
# pytest-cov
|
# pytest-cov
|
||||||
# pytest-timeout
|
# pytest-timeout
|
||||||
@@ -596,11 +531,14 @@ pytest-timeout==2.4.0
|
|||||||
# via lerobot
|
# via lerobot
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
# via
|
# via
|
||||||
|
# faker
|
||||||
# matplotlib
|
# matplotlib
|
||||||
# pandas
|
# pandas
|
||||||
python-dotenv==1.1.1
|
python-discovery==1.1.1
|
||||||
|
# via virtualenv
|
||||||
|
python-dotenv==1.2.2
|
||||||
# via uvicorn
|
# via uvicorn
|
||||||
pytz==2025.2
|
pytz==2026.1.post1
|
||||||
# via pandas
|
# via pandas
|
||||||
pyyaml==6.0.3
|
pyyaml==6.0.3
|
||||||
# via
|
# via
|
||||||
@@ -609,13 +547,10 @@ pyyaml==6.0.3
|
|||||||
# draccus
|
# draccus
|
||||||
# hebi-py
|
# hebi-py
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
# jupytext
|
|
||||||
# omegaconf
|
|
||||||
# peft
|
# peft
|
||||||
# pre-commit
|
# pre-commit
|
||||||
# pyngrok
|
# pyngrok
|
||||||
# pyyaml-include
|
# pyyaml-include
|
||||||
# timm
|
|
||||||
# transformers
|
# transformers
|
||||||
# uvicorn
|
# uvicorn
|
||||||
# wandb
|
# wandb
|
||||||
@@ -625,15 +560,13 @@ pyzmq==27.1.0
|
|||||||
# via
|
# via
|
||||||
# lerobot
|
# lerobot
|
||||||
# meshcat
|
# meshcat
|
||||||
reachy2-sdk==1.0.14
|
qwen-vl-utils==0.0.14
|
||||||
|
# via lerobot
|
||||||
|
reachy2-sdk==1.0.15
|
||||||
# via lerobot
|
# via lerobot
|
||||||
reachy2-sdk-api==1.0.21
|
reachy2-sdk-api==1.0.21
|
||||||
# via reachy2-sdk
|
# via reachy2-sdk
|
||||||
referencing==0.37.0
|
regex==2026.2.28
|
||||||
# via
|
|
||||||
# jsonschema
|
|
||||||
# jsonschema-specifications
|
|
||||||
regex==2025.10.23
|
|
||||||
# via
|
# via
|
||||||
# diffusers
|
# diffusers
|
||||||
# transformers
|
# transformers
|
||||||
@@ -642,184 +575,150 @@ requests==2.32.5
|
|||||||
# datasets
|
# datasets
|
||||||
# diffusers
|
# diffusers
|
||||||
# dm-control
|
# dm-control
|
||||||
# huggingface-hub
|
# qwen-vl-utils
|
||||||
# teleop
|
# teleop
|
||||||
# transformers
|
|
||||||
# wandb
|
# wandb
|
||||||
rerun-sdk==0.26.1
|
rerun-sdk==0.26.2
|
||||||
# via lerobot
|
# via lerobot
|
||||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||||
# via placo
|
# via placo
|
||||||
robomimic==0.2.0
|
rich==14.3.3
|
||||||
# via libero
|
# via typer
|
||||||
robosuite==1.4.0
|
safetensors==0.7.0
|
||||||
# via libero
|
|
||||||
rpds-py==0.28.0
|
|
||||||
# via
|
|
||||||
# jsonschema
|
|
||||||
# referencing
|
|
||||||
safetensors==0.6.2
|
|
||||||
# via
|
# via
|
||||||
# accelerate
|
# accelerate
|
||||||
# diffusers
|
# diffusers
|
||||||
# lerobot
|
# lerobot
|
||||||
# peft
|
# peft
|
||||||
# timm
|
|
||||||
# transformers
|
# transformers
|
||||||
scikit-image==0.25.2
|
scikit-image==0.25.2
|
||||||
# via
|
# via
|
||||||
# gym-pusht
|
# gym-pusht
|
||||||
# lerobot
|
# lerobot
|
||||||
scipy==1.15.3
|
scipy==1.17.1
|
||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
|
# lerobot
|
||||||
# metaworld
|
# metaworld
|
||||||
# robosuite
|
|
||||||
# scikit-image
|
# scikit-image
|
||||||
sentry-sdk==2.42.1
|
# torchdiffeq
|
||||||
|
sentry-sdk==2.54.0
|
||||||
# via wandb
|
# via wandb
|
||||||
shapely==2.1.2
|
shapely==2.1.2
|
||||||
# via gym-pusht
|
# via gym-pusht
|
||||||
|
shellingham==1.5.4
|
||||||
|
# via typer
|
||||||
six==1.17.0
|
six==1.17.0
|
||||||
# via
|
# via
|
||||||
# pynput
|
# pynput
|
||||||
# python-dateutil
|
# python-dateutil
|
||||||
smmap==5.0.2
|
smmap==5.0.3
|
||||||
# via gitdb
|
# via gitdb
|
||||||
sniffio==1.3.1
|
|
||||||
# via anyio
|
|
||||||
stack-data==0.6.3
|
stack-data==0.6.3
|
||||||
# via ipython
|
# via ipython
|
||||||
starlette==0.48.0
|
starlette==0.52.1
|
||||||
# via fastapi
|
# via fastapi
|
||||||
sympy==1.14.0
|
sympy==1.14.0
|
||||||
# via torch
|
# via torch
|
||||||
teleop==0.1.2
|
teleop==0.1.4
|
||||||
# via lerobot
|
# via lerobot
|
||||||
tensorboard==2.20.0
|
termcolor==3.3.0
|
||||||
# via robomimic
|
# via lerobot
|
||||||
tensorboard-data-server==0.7.2
|
tifffile==2026.3.3
|
||||||
# via tensorboard
|
|
||||||
tensorboardx==2.6.4
|
|
||||||
# via robomimic
|
|
||||||
termcolor==3.1.0
|
|
||||||
# via
|
|
||||||
# lerobot
|
|
||||||
# robomimic
|
|
||||||
thop==0.1.1.post2209072238
|
|
||||||
# via libero
|
|
||||||
tifffile==2025.5.10
|
|
||||||
# via scikit-image
|
# via scikit-image
|
||||||
timm==1.0.20
|
tokenizers==0.22.2
|
||||||
# via lerobot
|
|
||||||
tokenizers==0.22.1
|
|
||||||
# via transformers
|
# via transformers
|
||||||
toml==0.10.2
|
toml==0.10.2
|
||||||
# via draccus
|
# via draccus
|
||||||
tomli==2.3.0
|
torch==2.10.0
|
||||||
# via
|
|
||||||
# cmeel
|
|
||||||
# coverage
|
|
||||||
# jupytext
|
|
||||||
# pytest
|
|
||||||
torch==2.7.1
|
|
||||||
# via
|
# via
|
||||||
# accelerate
|
# accelerate
|
||||||
# lerobot
|
# lerobot
|
||||||
# peft
|
# peft
|
||||||
# robomimic
|
# torchdiffeq
|
||||||
# thop
|
|
||||||
# timm
|
|
||||||
# torchvision
|
# torchvision
|
||||||
torchcodec==0.5
|
torchcodec==0.10.0
|
||||||
# via lerobot
|
# via lerobot
|
||||||
torchvision==0.22.1
|
torchdiffeq==0.2.5
|
||||||
# via
|
# via lerobot
|
||||||
# lerobot
|
torchvision==0.25.0
|
||||||
# robomimic
|
# via lerobot
|
||||||
# timm
|
tornado==6.5.4
|
||||||
tornado==6.5.2
|
|
||||||
# via meshcat
|
# via meshcat
|
||||||
tqdm==4.67.1
|
tqdm==4.67.3
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
# dm-control
|
# dm-control
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
# peft
|
# peft
|
||||||
# robomimic
|
|
||||||
# transformers
|
# transformers
|
||||||
traitlets==5.14.3
|
traitlets==5.14.3
|
||||||
# via
|
# via
|
||||||
# ipython
|
# ipython
|
||||||
# jupyter-core
|
|
||||||
# matplotlib-inline
|
# matplotlib-inline
|
||||||
# nbformat
|
transformers==5.3.0
|
||||||
transformers==4.57.1
|
|
||||||
# via
|
# via
|
||||||
# lerobot
|
# lerobot
|
||||||
# libero
|
|
||||||
# peft
|
# peft
|
||||||
transforms3d==0.4.2
|
transforms3d==0.4.2
|
||||||
# via teleop
|
# via teleop
|
||||||
|
typer==0.24.1
|
||||||
|
# via
|
||||||
|
# huggingface-hub
|
||||||
|
# transformers
|
||||||
typing-extensions==4.15.0
|
typing-extensions==4.15.0
|
||||||
# via
|
# via
|
||||||
# aiosignal
|
# aiosignal
|
||||||
# anyio
|
# anyio
|
||||||
# etils
|
# etils
|
||||||
# exceptiongroup
|
# faker
|
||||||
# fastapi
|
# fastapi
|
||||||
# gymnasium
|
# gymnasium
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
# ipython
|
# mypy
|
||||||
# multidict
|
|
||||||
# pydantic
|
# pydantic
|
||||||
# pydantic-core
|
# pydantic-core
|
||||||
# referencing
|
|
||||||
# rerun-sdk
|
# rerun-sdk
|
||||||
# starlette
|
# starlette
|
||||||
# torch
|
# torch
|
||||||
# typing-inspect
|
# typing-inspect
|
||||||
# typing-inspection
|
# typing-inspection
|
||||||
# uvicorn
|
|
||||||
# virtualenv
|
|
||||||
# wandb
|
# wandb
|
||||||
typing-inspect==0.9.0
|
typing-inspect==0.9.0
|
||||||
# via draccus
|
# via draccus
|
||||||
typing-inspection==0.4.2
|
typing-inspection==0.4.2
|
||||||
# via pydantic
|
# via
|
||||||
tzdata==2025.2
|
# fastapi
|
||||||
|
# pydantic
|
||||||
|
tzdata==2025.3
|
||||||
# via pandas
|
# via pandas
|
||||||
u-msgpack-python==2.8.0
|
u-msgpack-python==2.8.0
|
||||||
# via meshcat
|
# via meshcat
|
||||||
urllib3==2.5.0
|
urllib3==2.6.3
|
||||||
# via
|
# via
|
||||||
# requests
|
# requests
|
||||||
# sentry-sdk
|
# sentry-sdk
|
||||||
uvicorn[standard]==0.38.0
|
uvicorn[standard]==0.41.0
|
||||||
# via teleop
|
# via teleop
|
||||||
uvloop==0.22.1
|
uvloop==0.22.1
|
||||||
# via uvicorn
|
# via uvicorn
|
||||||
virtualenv==20.35.3
|
virtualenv==21.1.0
|
||||||
# via pre-commit
|
# via pre-commit
|
||||||
wandb==0.21.4
|
wandb==0.24.2
|
||||||
# via
|
# via lerobot
|
||||||
# lerobot
|
|
||||||
# libero
|
|
||||||
watchfiles==1.1.1
|
watchfiles==1.1.1
|
||||||
# via uvicorn
|
# via uvicorn
|
||||||
wcwidth==0.2.14
|
wcwidth==0.6.0
|
||||||
# via prompt-toolkit
|
# via prompt-toolkit
|
||||||
websocket-client==1.9.0
|
websocket-client==1.9.0
|
||||||
# via teleop
|
# via teleop
|
||||||
websockets==15.0.1
|
websockets==16.0
|
||||||
# via uvicorn
|
# via uvicorn
|
||||||
werkzeug==3.1.3
|
wrapt==2.1.2
|
||||||
# via tensorboard
|
|
||||||
wrapt==2.0.0
|
|
||||||
# via dm-tree
|
# via dm-tree
|
||||||
xxhash==3.6.0
|
xxhash==3.6.0
|
||||||
# via datasets
|
# via datasets
|
||||||
yarl==1.22.0
|
yarl==1.23.0
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
zipp==3.23.0
|
zipp==3.23.0
|
||||||
# via
|
# via
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
#
|
#
|
||||||
# This file is autogenerated by pip-compile with Python 3.10
|
# This file is autogenerated by pip-compile with Python 3.12
|
||||||
# by the following command:
|
# by the following command:
|
||||||
#
|
#
|
||||||
# pip-compile --output-file=requirements-ubuntu.txt requirements.in
|
# pip-compile --output-file=requirements-ubuntu.txt requirements.in
|
||||||
#
|
#
|
||||||
-e .[all]
|
-e .[all]
|
||||||
# via -[all]
|
# via -[all]
|
||||||
absl-py==2.3.1
|
absl-py==2.4.0
|
||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
# dm-env
|
# dm-env
|
||||||
@@ -14,30 +14,33 @@ absl-py==2.3.1
|
|||||||
# labmaze
|
# labmaze
|
||||||
# mujoco
|
# mujoco
|
||||||
# tensorboard
|
# tensorboard
|
||||||
accelerate==1.11.0
|
accelerate==1.13.0
|
||||||
# via
|
# via
|
||||||
# lerobot
|
# lerobot
|
||||||
# peft
|
# peft
|
||||||
aiohappyeyeballs==2.6.1
|
aiohappyeyeballs==2.6.1
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
aiohttp==3.13.1
|
aiohttp==3.13.3
|
||||||
# via fsspec
|
# via fsspec
|
||||||
aiosignal==1.4.0
|
aiosignal==1.4.0
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
|
annotated-doc==0.0.4
|
||||||
|
# via
|
||||||
|
# fastapi
|
||||||
|
# typer
|
||||||
annotated-types==0.7.0
|
annotated-types==0.7.0
|
||||||
# via pydantic
|
# via pydantic
|
||||||
antlr4-python3-runtime==4.9.3
|
antlr4-python3-runtime==4.9.3
|
||||||
# via
|
# via
|
||||||
# hydra-core
|
# hydra-core
|
||||||
# omegaconf
|
# omegaconf
|
||||||
anyio==4.11.0
|
anyio==4.12.1
|
||||||
# via
|
# via
|
||||||
|
# httpx
|
||||||
# starlette
|
# starlette
|
||||||
# watchfiles
|
# watchfiles
|
||||||
asttokens==3.0.0
|
asttokens==3.0.1
|
||||||
# via stack-data
|
# via stack-data
|
||||||
async-timeout==5.0.1
|
|
||||||
# via aiohttp
|
|
||||||
attrs==25.4.0
|
attrs==25.4.0
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
@@ -47,30 +50,35 @@ attrs==25.4.0
|
|||||||
# referencing
|
# referencing
|
||||||
# rerun-sdk
|
# rerun-sdk
|
||||||
av==15.1.0
|
av==15.1.0
|
||||||
# via lerobot
|
|
||||||
bddl==1.0.1
|
|
||||||
# via libero
|
|
||||||
certifi==2025.10.5
|
|
||||||
# via
|
# via
|
||||||
|
# lerobot
|
||||||
|
# qwen-vl-utils
|
||||||
|
bddl==1.0.1
|
||||||
|
# via hf-libero
|
||||||
|
certifi==2026.2.25
|
||||||
|
# via
|
||||||
|
# httpcore
|
||||||
|
# httpx
|
||||||
# requests
|
# requests
|
||||||
# sentry-sdk
|
# sentry-sdk
|
||||||
cffi==2.0.0
|
cffi==2.0.0
|
||||||
# via pymunk
|
# via pymunk
|
||||||
cfgv==3.4.0
|
cfgv==3.5.0
|
||||||
# via pre-commit
|
# via pre-commit
|
||||||
charset-normalizer==3.4.4
|
charset-normalizer==3.4.5
|
||||||
# via requests
|
# via requests
|
||||||
click==8.3.0
|
click==8.3.1
|
||||||
# via
|
# via
|
||||||
|
# typer
|
||||||
# uvicorn
|
# uvicorn
|
||||||
# wandb
|
# wandb
|
||||||
cloudpickle==3.1.1
|
cloudpickle==3.1.2
|
||||||
# via
|
# via
|
||||||
# gymnasium
|
# gymnasium
|
||||||
# libero
|
# hf-libero
|
||||||
cmake==4.1.0
|
cmake==4.1.3
|
||||||
# via lerobot
|
# via lerobot
|
||||||
cmeel==0.57.3
|
cmeel==0.59.0
|
||||||
# via
|
# via
|
||||||
# cmeel-assimp
|
# cmeel-assimp
|
||||||
# cmeel-boost
|
# cmeel-boost
|
||||||
@@ -108,20 +116,24 @@ cmeel-zlib==1.3.1
|
|||||||
# via cmeel-assimp
|
# via cmeel-assimp
|
||||||
coal-library==3.0.1
|
coal-library==3.0.1
|
||||||
# via pin
|
# via pin
|
||||||
contourpy==1.3.2
|
contourpy==1.3.3
|
||||||
# via matplotlib
|
# via
|
||||||
coverage[toml]==7.11.0
|
# lerobot
|
||||||
|
# matplotlib
|
||||||
|
coverage[toml]==7.13.4
|
||||||
# via pytest-cov
|
# via pytest-cov
|
||||||
|
cuda-bindings==12.9.4
|
||||||
|
# via torch
|
||||||
|
cuda-pathfinder==1.4.1
|
||||||
|
# via cuda-bindings
|
||||||
cycler==0.12.1
|
cycler==0.12.1
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
datasets==4.1.1
|
datasets==4.6.1
|
||||||
# via lerobot
|
# via lerobot
|
||||||
debugpy==1.8.17
|
debugpy==1.8.20
|
||||||
# via lerobot
|
# via lerobot
|
||||||
decorator==5.2.1
|
decorator==5.2.1
|
||||||
# via ipython
|
# via ipython
|
||||||
decord==0.6.0
|
|
||||||
# via lerobot
|
|
||||||
deepdiff==8.6.1
|
deepdiff==8.6.1
|
||||||
# via lerobot
|
# via lerobot
|
||||||
diffusers==0.35.2
|
diffusers==0.35.2
|
||||||
@@ -132,7 +144,7 @@ dill==0.4.0
|
|||||||
# multiprocess
|
# multiprocess
|
||||||
distlib==0.4.0
|
distlib==0.4.0
|
||||||
# via virtualenv
|
# via virtualenv
|
||||||
dm-control==1.0.34
|
dm-control==1.0.37
|
||||||
# via gym-aloha
|
# via gym-aloha
|
||||||
dm-env==1.6
|
dm-env==1.6
|
||||||
# via dm-control
|
# via dm-control
|
||||||
@@ -140,7 +152,6 @@ dm-tree==0.1.9
|
|||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
# dm-env
|
# dm-env
|
||||||
# lerobot
|
|
||||||
docopt==0.6.2
|
docopt==0.6.2
|
||||||
# via num2words
|
# via num2words
|
||||||
draccus==0.10.0
|
draccus==0.10.0
|
||||||
@@ -148,66 +159,60 @@ draccus==0.10.0
|
|||||||
dynamixel-sdk==3.8.4
|
dynamixel-sdk==3.8.4
|
||||||
# via lerobot
|
# via lerobot
|
||||||
easydict==1.13
|
easydict==1.13
|
||||||
# via libero
|
# via hf-libero
|
||||||
egl-probe @ git+https://github.com/huggingface/egl_probe.git
|
egl-probe==1.0.2
|
||||||
# via
|
# via robomimic
|
||||||
# libero
|
|
||||||
# robomimic
|
|
||||||
eigenpy==3.10.3
|
eigenpy==3.10.3
|
||||||
# via coal-library
|
# via coal-library
|
||||||
einops==0.8.1
|
einops==0.8.2
|
||||||
# via
|
# via
|
||||||
# flash-attn
|
# hf-libero
|
||||||
# lerobot
|
# lerobot
|
||||||
# libero
|
|
||||||
eiquadprog==1.2.9
|
eiquadprog==1.2.9
|
||||||
# via placo
|
# via placo
|
||||||
etils[epath,epy]==1.13.0
|
etils[epath,epy]==1.14.0
|
||||||
# via mujoco
|
# via mujoco
|
||||||
evdev==1.9.2
|
evdev==1.9.3
|
||||||
# via pynput
|
# via pynput
|
||||||
exceptiongroup==1.3.0
|
|
||||||
# via
|
|
||||||
# anyio
|
|
||||||
# ipython
|
|
||||||
# pytest
|
|
||||||
executing==2.2.1
|
executing==2.2.1
|
||||||
# via stack-data
|
# via stack-data
|
||||||
|
faker==34.0.2
|
||||||
|
# via lerobot
|
||||||
farama-notifications==0.0.4
|
farama-notifications==0.0.4
|
||||||
# via gymnasium
|
# via gymnasium
|
||||||
fastapi==0.119.1
|
fastapi==0.135.1
|
||||||
# via teleop
|
# via
|
||||||
|
# lerobot
|
||||||
|
# teleop
|
||||||
fastjsonschema==2.21.2
|
fastjsonschema==2.21.2
|
||||||
# via nbformat
|
# via nbformat
|
||||||
feetech-servo-sdk==1.0.0
|
feetech-servo-sdk==1.0.0
|
||||||
# via lerobot
|
# via lerobot
|
||||||
filelock==3.20.0
|
filelock==3.25.0
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
# diffusers
|
# diffusers
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
|
# python-discovery
|
||||||
# torch
|
# torch
|
||||||
# transformers
|
|
||||||
# virtualenv
|
# virtualenv
|
||||||
flash-attn==2.8.3
|
fonttools==4.61.1
|
||||||
# via lerobot
|
|
||||||
fonttools==4.60.1
|
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
frozenlist==1.8.0
|
frozenlist==1.8.0
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
# aiosignal
|
# aiosignal
|
||||||
fsspec[http]==2025.9.0
|
fsspec[http]==2026.2.0
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
# etils
|
# etils
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
# torch
|
# torch
|
||||||
future==1.0.0
|
future==1.0.0
|
||||||
# via libero
|
# via hf-libero
|
||||||
gitdb==4.0.12
|
gitdb==4.0.12
|
||||||
# via gitpython
|
# via gitpython
|
||||||
gitpython==3.1.45
|
gitpython==3.1.46
|
||||||
# via wandb
|
# via wandb
|
||||||
glfw==2.10.0
|
glfw==2.10.0
|
||||||
# via
|
# via
|
||||||
@@ -230,50 +235,60 @@ gym-hil==0.1.13
|
|||||||
# via lerobot
|
# via lerobot
|
||||||
gym-pusht==0.1.6
|
gym-pusht==0.1.6
|
||||||
# via lerobot
|
# via lerobot
|
||||||
gymnasium==1.2.1
|
gymnasium==1.2.3
|
||||||
# via
|
# via
|
||||||
# gym-aloha
|
# gym-aloha
|
||||||
# gym-hil
|
# gym-hil
|
||||||
# gym-pusht
|
# gym-pusht
|
||||||
|
# hf-libero
|
||||||
# lerobot
|
# lerobot
|
||||||
# libero
|
|
||||||
# metaworld
|
# metaworld
|
||||||
h11==0.16.0
|
h11==0.16.0
|
||||||
# via uvicorn
|
# via
|
||||||
h5py==3.15.1
|
# httpcore
|
||||||
|
# uvicorn
|
||||||
|
h5py==3.16.0
|
||||||
# via robomimic
|
# via robomimic
|
||||||
hebi-py==2.11.0
|
hebi-py==2.11.0
|
||||||
# via lerobot
|
# via lerobot
|
||||||
hf-transfer==0.1.9
|
hf-egl-probe==1.0.2
|
||||||
# via huggingface-hub
|
# via hf-libero
|
||||||
hf-xet==1.1.10
|
hf-libero==0.1.3
|
||||||
|
# via lerobot
|
||||||
|
hf-xet==1.3.2
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
hidapi==0.14.0.post4
|
hidapi==0.14.0.post4
|
||||||
# via
|
# via
|
||||||
# gym-hil
|
# gym-hil
|
||||||
# lerobot
|
# lerobot
|
||||||
|
httpcore==1.0.9
|
||||||
|
# via httpx
|
||||||
httptools==0.7.1
|
httptools==0.7.1
|
||||||
# via uvicorn
|
# via uvicorn
|
||||||
huggingface-hub[cli,hf-transfer]==0.35.3
|
httpx==0.28.1
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# huggingface-hub
|
||||||
|
huggingface-hub==1.6.0
|
||||||
# via
|
# via
|
||||||
# accelerate
|
# accelerate
|
||||||
# datasets
|
# datasets
|
||||||
# diffusers
|
# diffusers
|
||||||
# lerobot
|
# lerobot
|
||||||
# peft
|
# peft
|
||||||
# timm
|
|
||||||
# tokenizers
|
# tokenizers
|
||||||
# transformers
|
# transformers
|
||||||
hydra-core==1.3.2
|
hydra-core==1.3.2
|
||||||
# via libero
|
# via hf-libero
|
||||||
identify==2.6.15
|
identify==2.6.17
|
||||||
# via pre-commit
|
# via pre-commit
|
||||||
idna==3.11
|
idna==3.11
|
||||||
# via
|
# via
|
||||||
# anyio
|
# anyio
|
||||||
|
# httpx
|
||||||
# requests
|
# requests
|
||||||
# yarl
|
# yarl
|
||||||
imageio[ffmpeg]==2.37.0
|
imageio[ffmpeg]==2.37.2
|
||||||
# via
|
# via
|
||||||
# gym-aloha
|
# gym-aloha
|
||||||
# gym-hil
|
# gym-hil
|
||||||
@@ -285,16 +300,14 @@ imageio-ffmpeg==0.6.0
|
|||||||
# via
|
# via
|
||||||
# imageio
|
# imageio
|
||||||
# robomimic
|
# robomimic
|
||||||
importlib-metadata==8.7.0
|
importlib-metadata==8.7.1
|
||||||
# via diffusers
|
# via diffusers
|
||||||
importlib-resources==6.5.2
|
|
||||||
# via etils
|
|
||||||
iniconfig==2.3.0
|
iniconfig==2.3.0
|
||||||
# via pytest
|
# via pytest
|
||||||
inquirerpy==0.3.4
|
ipython==9.11.0
|
||||||
# via huggingface-hub
|
|
||||||
ipython==8.37.0
|
|
||||||
# via meshcat
|
# via meshcat
|
||||||
|
ipython-pygments-lexers==1.1.1
|
||||||
|
# via ipython
|
||||||
ischedule==1.2.7
|
ischedule==1.2.7
|
||||||
# via placo
|
# via placo
|
||||||
jedi==0.19.2
|
jedi==0.19.2
|
||||||
@@ -303,40 +316,41 @@ jinja2==3.1.6
|
|||||||
# via torch
|
# via torch
|
||||||
jsonlines==4.0.0
|
jsonlines==4.0.0
|
||||||
# via lerobot
|
# via lerobot
|
||||||
jsonschema==4.25.1
|
jsonschema==4.26.0
|
||||||
# via nbformat
|
# via nbformat
|
||||||
jsonschema-specifications==2025.9.1
|
jsonschema-specifications==2025.9.1
|
||||||
# via jsonschema
|
# via jsonschema
|
||||||
jupyter-core==5.9.1
|
jupyter-core==5.9.1
|
||||||
# via nbformat
|
# via nbformat
|
||||||
jupytext==1.18.1
|
jupytext==1.19.1
|
||||||
# via bddl
|
# via bddl
|
||||||
kiwisolver==1.4.9
|
kiwisolver==1.4.9
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
labmaze==1.0.6
|
labmaze==1.0.6
|
||||||
# via dm-control
|
# via dm-control
|
||||||
lazy-loader==0.4
|
lazy-loader==0.5
|
||||||
# via scikit-image
|
# via scikit-image
|
||||||
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
|
librt==0.8.1
|
||||||
# via lerobot
|
# via mypy
|
||||||
llvmlite==0.45.1
|
llvmlite==0.46.0
|
||||||
# via numba
|
# via numba
|
||||||
lxml==6.0.2
|
lxml==6.0.2
|
||||||
# via dm-control
|
# via dm-control
|
||||||
markdown==3.9
|
markdown==3.10.2
|
||||||
# via tensorboard
|
# via tensorboard
|
||||||
markdown-it-py==4.0.0
|
markdown-it-py==4.0.0
|
||||||
# via
|
# via
|
||||||
# jupytext
|
# jupytext
|
||||||
# mdit-py-plugins
|
# mdit-py-plugins
|
||||||
|
# rich
|
||||||
markupsafe==3.0.3
|
markupsafe==3.0.3
|
||||||
# via
|
# via
|
||||||
# jinja2
|
# jinja2
|
||||||
# werkzeug
|
# werkzeug
|
||||||
matplotlib==3.10.7
|
matplotlib==3.10.8
|
||||||
# via
|
# via
|
||||||
|
# hf-libero
|
||||||
# lerobot
|
# lerobot
|
||||||
# libero
|
|
||||||
matplotlib-inline==0.2.1
|
matplotlib-inline==0.2.1
|
||||||
# via ipython
|
# via ipython
|
||||||
mdit-py-plugins==0.5.0
|
mdit-py-plugins==0.5.0
|
||||||
@@ -353,36 +367,38 @@ mock-serial==0.0.1
|
|||||||
# via lerobot
|
# via lerobot
|
||||||
mpmath==1.3.0
|
mpmath==1.3.0
|
||||||
# via sympy
|
# via sympy
|
||||||
mujoco==3.3.7
|
mujoco==3.5.0
|
||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
# gym-aloha
|
# gym-aloha
|
||||||
# gym-hil
|
# gym-hil
|
||||||
# libero
|
# hf-libero
|
||||||
# metaworld
|
# metaworld
|
||||||
# robosuite
|
# robosuite
|
||||||
multidict==6.7.0
|
multidict==6.7.1
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
# yarl
|
# yarl
|
||||||
multiprocess==0.70.16
|
multiprocess==0.70.18
|
||||||
# via datasets
|
# via datasets
|
||||||
|
mypy==1.19.1
|
||||||
|
# via lerobot
|
||||||
mypy-extensions==1.1.0
|
mypy-extensions==1.1.0
|
||||||
# via typing-inspect
|
# via
|
||||||
|
# mypy
|
||||||
|
# typing-inspect
|
||||||
nbformat==5.10.4
|
nbformat==5.10.4
|
||||||
# via jupytext
|
# via jupytext
|
||||||
networkx==3.4.2
|
networkx==3.6.1
|
||||||
# via
|
# via
|
||||||
# bddl
|
# bddl
|
||||||
# scikit-image
|
# scikit-image
|
||||||
# torch
|
# torch
|
||||||
ninja==1.13.0
|
nodeenv==1.10.0
|
||||||
# via lerobot
|
|
||||||
nodeenv==1.9.1
|
|
||||||
# via pre-commit
|
# via pre-commit
|
||||||
num2words==0.5.14
|
num2words==0.5.14
|
||||||
# via lerobot
|
# via lerobot
|
||||||
numba==0.62.1
|
numba==0.64.0
|
||||||
# via robosuite
|
# via robosuite
|
||||||
numpy==2.2.6
|
numpy==2.2.6
|
||||||
# via
|
# via
|
||||||
@@ -391,7 +407,6 @@ numpy==2.2.6
|
|||||||
# cmeel-boost
|
# cmeel-boost
|
||||||
# contourpy
|
# contourpy
|
||||||
# datasets
|
# datasets
|
||||||
# decord
|
|
||||||
# diffusers
|
# diffusers
|
||||||
# dm-control
|
# dm-control
|
||||||
# dm-env
|
# dm-env
|
||||||
@@ -399,9 +414,10 @@ numpy==2.2.6
|
|||||||
# gymnasium
|
# gymnasium
|
||||||
# h5py
|
# h5py
|
||||||
# hebi-py
|
# hebi-py
|
||||||
|
# hf-libero
|
||||||
# imageio
|
# imageio
|
||||||
# labmaze
|
# labmaze
|
||||||
# libero
|
# lerobot
|
||||||
# matplotlib
|
# matplotlib
|
||||||
# meshcat
|
# meshcat
|
||||||
# metaworld
|
# metaworld
|
||||||
@@ -426,49 +442,51 @@ numpy==2.2.6
|
|||||||
# torchvision
|
# torchvision
|
||||||
# transformers
|
# transformers
|
||||||
# transforms3d
|
# transforms3d
|
||||||
nvidia-cublas-cu12==12.6.4.1
|
nvidia-cublas-cu12==12.8.4.1
|
||||||
# via
|
# via
|
||||||
# nvidia-cudnn-cu12
|
# nvidia-cudnn-cu12
|
||||||
# nvidia-cusolver-cu12
|
# nvidia-cusolver-cu12
|
||||||
# torch
|
# torch
|
||||||
nvidia-cuda-cupti-cu12==12.6.80
|
nvidia-cuda-cupti-cu12==12.8.90
|
||||||
# via torch
|
# via torch
|
||||||
nvidia-cuda-nvrtc-cu12==12.6.77
|
nvidia-cuda-nvrtc-cu12==12.8.93
|
||||||
# via torch
|
# via torch
|
||||||
nvidia-cuda-runtime-cu12==12.6.77
|
nvidia-cuda-runtime-cu12==12.8.90
|
||||||
# via torch
|
# via torch
|
||||||
nvidia-cudnn-cu12==9.5.1.17
|
nvidia-cudnn-cu12==9.10.2.21
|
||||||
# via torch
|
# via torch
|
||||||
nvidia-cufft-cu12==11.3.0.4
|
nvidia-cufft-cu12==11.3.3.83
|
||||||
# via torch
|
# via torch
|
||||||
nvidia-cufile-cu12==1.11.1.6
|
nvidia-cufile-cu12==1.13.1.3
|
||||||
# via torch
|
# via torch
|
||||||
nvidia-curand-cu12==10.3.7.77
|
nvidia-curand-cu12==10.3.9.90
|
||||||
# via torch
|
# via torch
|
||||||
nvidia-cusolver-cu12==11.7.1.2
|
nvidia-cusolver-cu12==11.7.3.90
|
||||||
# via torch
|
# via torch
|
||||||
nvidia-cusparse-cu12==12.5.4.2
|
nvidia-cusparse-cu12==12.5.8.93
|
||||||
# via
|
# via
|
||||||
# nvidia-cusolver-cu12
|
# nvidia-cusolver-cu12
|
||||||
# torch
|
# torch
|
||||||
nvidia-cusparselt-cu12==0.6.3
|
nvidia-cusparselt-cu12==0.7.1
|
||||||
# via torch
|
# via torch
|
||||||
nvidia-nccl-cu12==2.26.2
|
nvidia-nccl-cu12==2.27.5
|
||||||
# via torch
|
# via torch
|
||||||
nvidia-nvjitlink-cu12==12.6.85
|
nvidia-nvjitlink-cu12==12.8.93
|
||||||
# via
|
# via
|
||||||
# nvidia-cufft-cu12
|
# nvidia-cufft-cu12
|
||||||
# nvidia-cusolver-cu12
|
# nvidia-cusolver-cu12
|
||||||
# nvidia-cusparse-cu12
|
# nvidia-cusparse-cu12
|
||||||
# torch
|
# torch
|
||||||
nvidia-nvtx-cu12==12.6.77
|
nvidia-nvshmem-cu12==3.4.5
|
||||||
|
# via torch
|
||||||
|
nvidia-nvtx-cu12==12.8.90
|
||||||
# via torch
|
# via torch
|
||||||
omegaconf==2.3.0
|
omegaconf==2.3.0
|
||||||
# via hydra-core
|
# via hydra-core
|
||||||
opencv-python==4.12.0.88
|
opencv-python==4.13.0.92
|
||||||
# via
|
# via
|
||||||
# gym-pusht
|
# gym-pusht
|
||||||
# libero
|
# hf-libero
|
||||||
# reachy2-sdk
|
# reachy2-sdk
|
||||||
# robosuite
|
# robosuite
|
||||||
opencv-python-headless==4.12.0.88
|
opencv-python-headless==4.12.0.88
|
||||||
@@ -487,6 +505,7 @@ packaging==25.0
|
|||||||
# matplotlib
|
# matplotlib
|
||||||
# peft
|
# peft
|
||||||
# pytest
|
# pytest
|
||||||
|
# qwen-vl-utils
|
||||||
# reachy2-sdk
|
# reachy2-sdk
|
||||||
# scikit-image
|
# scikit-image
|
||||||
# tensorboard
|
# tensorboard
|
||||||
@@ -497,21 +516,21 @@ pandas==2.3.3
|
|||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
# lerobot
|
# lerobot
|
||||||
parso==0.8.5
|
parso==0.8.6
|
||||||
# via jedi
|
# via jedi
|
||||||
peft==0.17.1
|
pathspec==1.0.4
|
||||||
|
# via mypy
|
||||||
|
peft==0.18.1
|
||||||
# via lerobot
|
# via lerobot
|
||||||
pexpect==4.9.0
|
pexpect==4.9.0
|
||||||
# via ipython
|
# via ipython
|
||||||
pfzy==0.3.4
|
pillow==12.1.1
|
||||||
# via inquirerpy
|
|
||||||
pillow==12.0.0
|
|
||||||
# via
|
# via
|
||||||
# diffusers
|
# diffusers
|
||||||
# imageio
|
# imageio
|
||||||
# lerobot
|
|
||||||
# matplotlib
|
# matplotlib
|
||||||
# meshcat
|
# meshcat
|
||||||
|
# qwen-vl-utils
|
||||||
# rerun-sdk
|
# rerun-sdk
|
||||||
# robosuite
|
# robosuite
|
||||||
# scikit-image
|
# scikit-image
|
||||||
@@ -519,28 +538,27 @@ pillow==12.0.0
|
|||||||
# torchvision
|
# torchvision
|
||||||
pin==3.4.0
|
pin==3.4.0
|
||||||
# via placo
|
# via placo
|
||||||
placo==0.9.14
|
placo==0.9.16
|
||||||
# via lerobot
|
# via lerobot
|
||||||
platformdirs==4.5.0
|
platformdirs==4.9.4
|
||||||
# via
|
# via
|
||||||
# jupyter-core
|
# jupyter-core
|
||||||
|
# python-discovery
|
||||||
# virtualenv
|
# virtualenv
|
||||||
# wandb
|
# wandb
|
||||||
pluggy==1.6.0
|
pluggy==1.6.0
|
||||||
# via
|
# via
|
||||||
# pytest
|
# pytest
|
||||||
# pytest-cov
|
# pytest-cov
|
||||||
pre-commit==4.3.0
|
pre-commit==4.5.1
|
||||||
# via lerobot
|
# via lerobot
|
||||||
prompt-toolkit==3.0.52
|
prompt-toolkit==3.0.52
|
||||||
# via
|
# via ipython
|
||||||
# inquirerpy
|
|
||||||
# ipython
|
|
||||||
propcache==0.4.1
|
propcache==0.4.1
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
# yarl
|
# yarl
|
||||||
protobuf==6.31.0
|
protobuf==6.31.1
|
||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
# grpcio-tools
|
# grpcio-tools
|
||||||
@@ -550,7 +568,7 @@ protobuf==6.31.0
|
|||||||
# tensorboard
|
# tensorboard
|
||||||
# tensorboardx
|
# tensorboardx
|
||||||
# wandb
|
# wandb
|
||||||
psutil==7.1.1
|
psutil==7.2.2
|
||||||
# via
|
# via
|
||||||
# accelerate
|
# accelerate
|
||||||
# imageio
|
# imageio
|
||||||
@@ -560,17 +578,17 @@ ptyprocess==0.7.0
|
|||||||
# via pexpect
|
# via pexpect
|
||||||
pure-eval==0.2.3
|
pure-eval==0.2.3
|
||||||
# via stack-data
|
# via stack-data
|
||||||
pyarrow==21.0.0
|
pyarrow==23.0.1
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
# rerun-sdk
|
# rerun-sdk
|
||||||
pycparser==2.23
|
pycparser==3.0
|
||||||
# via cffi
|
# via cffi
|
||||||
pydantic==2.12.3
|
pydantic==2.12.5
|
||||||
# via
|
# via
|
||||||
# fastapi
|
# fastapi
|
||||||
# wandb
|
# wandb
|
||||||
pydantic-core==2.41.4
|
pydantic-core==2.41.5
|
||||||
# via pydantic
|
# via pydantic
|
||||||
pygame==2.6.1
|
pygame==2.6.1
|
||||||
# via
|
# via
|
||||||
@@ -580,12 +598,14 @@ pygame==2.6.1
|
|||||||
pygments==2.19.2
|
pygments==2.19.2
|
||||||
# via
|
# via
|
||||||
# ipython
|
# ipython
|
||||||
|
# ipython-pygments-lexers
|
||||||
# pytest
|
# pytest
|
||||||
|
# rich
|
||||||
pymunk==6.11.1
|
pymunk==6.11.1
|
||||||
# via
|
# via
|
||||||
# gym-pusht
|
# gym-pusht
|
||||||
# lerobot
|
# lerobot
|
||||||
pyngrok==7.4.1
|
pyngrok==7.5.1
|
||||||
# via meshcat
|
# via meshcat
|
||||||
pynput==1.8.1
|
pynput==1.8.1
|
||||||
# via
|
# via
|
||||||
@@ -595,7 +615,7 @@ pyopengl==3.1.10
|
|||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
# mujoco
|
# mujoco
|
||||||
pyparsing==3.2.5
|
pyparsing==3.3.2
|
||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
# matplotlib
|
# matplotlib
|
||||||
@@ -621,13 +641,16 @@ pytest-timeout==2.4.0
|
|||||||
# via lerobot
|
# via lerobot
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
# via
|
# via
|
||||||
|
# faker
|
||||||
# matplotlib
|
# matplotlib
|
||||||
# pandas
|
# pandas
|
||||||
python-dotenv==1.1.1
|
python-discovery==1.1.1
|
||||||
|
# via virtualenv
|
||||||
|
python-dotenv==1.2.2
|
||||||
# via uvicorn
|
# via uvicorn
|
||||||
python-xlib==0.33
|
python-xlib==0.33
|
||||||
# via pynput
|
# via pynput
|
||||||
pytz==2025.2
|
pytz==2026.1.post1
|
||||||
# via pandas
|
# via pandas
|
||||||
pyyaml==6.0.3
|
pyyaml==6.0.3
|
||||||
# via
|
# via
|
||||||
@@ -642,7 +665,6 @@ pyyaml==6.0.3
|
|||||||
# pre-commit
|
# pre-commit
|
||||||
# pyngrok
|
# pyngrok
|
||||||
# pyyaml-include
|
# pyyaml-include
|
||||||
# timm
|
|
||||||
# transformers
|
# transformers
|
||||||
# uvicorn
|
# uvicorn
|
||||||
# wandb
|
# wandb
|
||||||
@@ -652,7 +674,9 @@ pyzmq==27.1.0
|
|||||||
# via
|
# via
|
||||||
# lerobot
|
# lerobot
|
||||||
# meshcat
|
# meshcat
|
||||||
reachy2-sdk==1.0.14
|
qwen-vl-utils==0.0.14
|
||||||
|
# via lerobot
|
||||||
|
reachy2-sdk==1.0.15
|
||||||
# via lerobot
|
# via lerobot
|
||||||
reachy2-sdk-api==1.0.21
|
reachy2-sdk-api==1.0.21
|
||||||
# via reachy2-sdk
|
# via reachy2-sdk
|
||||||
@@ -660,7 +684,7 @@ referencing==0.37.0
|
|||||||
# via
|
# via
|
||||||
# jsonschema
|
# jsonschema
|
||||||
# jsonschema-specifications
|
# jsonschema-specifications
|
||||||
regex==2025.10.23
|
regex==2026.2.28
|
||||||
# via
|
# via
|
||||||
# diffusers
|
# diffusers
|
||||||
# transformers
|
# transformers
|
||||||
@@ -669,60 +693,62 @@ requests==2.32.5
|
|||||||
# datasets
|
# datasets
|
||||||
# diffusers
|
# diffusers
|
||||||
# dm-control
|
# dm-control
|
||||||
# huggingface-hub
|
# qwen-vl-utils
|
||||||
# teleop
|
# teleop
|
||||||
# transformers
|
|
||||||
# wandb
|
# wandb
|
||||||
rerun-sdk==0.26.1
|
rerun-sdk==0.26.2
|
||||||
# via lerobot
|
# via lerobot
|
||||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||||
# via placo
|
# via placo
|
||||||
|
rich==14.3.3
|
||||||
|
# via typer
|
||||||
robomimic==0.2.0
|
robomimic==0.2.0
|
||||||
# via libero
|
# via hf-libero
|
||||||
robosuite==1.4.0
|
robosuite==1.4.0
|
||||||
# via libero
|
# via hf-libero
|
||||||
rpds-py==0.28.0
|
rpds-py==0.30.0
|
||||||
# via
|
# via
|
||||||
# jsonschema
|
# jsonschema
|
||||||
# referencing
|
# referencing
|
||||||
safetensors==0.6.2
|
safetensors==0.7.0
|
||||||
# via
|
# via
|
||||||
# accelerate
|
# accelerate
|
||||||
# diffusers
|
# diffusers
|
||||||
# lerobot
|
# lerobot
|
||||||
# peft
|
# peft
|
||||||
# timm
|
|
||||||
# transformers
|
# transformers
|
||||||
scikit-image==0.25.2
|
scikit-image==0.25.2
|
||||||
# via
|
# via
|
||||||
# gym-pusht
|
# gym-pusht
|
||||||
# lerobot
|
# lerobot
|
||||||
scipy==1.15.3
|
scipy==1.17.1
|
||||||
# via
|
# via
|
||||||
# dm-control
|
# dm-control
|
||||||
|
# lerobot
|
||||||
# metaworld
|
# metaworld
|
||||||
# robosuite
|
# robosuite
|
||||||
# scikit-image
|
# scikit-image
|
||||||
sentry-sdk==2.42.1
|
# torchdiffeq
|
||||||
|
sentry-sdk==2.54.0
|
||||||
# via wandb
|
# via wandb
|
||||||
shapely==2.1.2
|
shapely==2.1.2
|
||||||
# via gym-pusht
|
# via gym-pusht
|
||||||
|
shellingham==1.5.4
|
||||||
|
# via typer
|
||||||
six==1.17.0
|
six==1.17.0
|
||||||
# via
|
# via
|
||||||
# pynput
|
# pynput
|
||||||
# python-dateutil
|
# python-dateutil
|
||||||
# python-xlib
|
# python-xlib
|
||||||
smmap==5.0.2
|
smmap==5.0.3
|
||||||
# via gitdb
|
# via gitdb
|
||||||
sniffio==1.3.1
|
|
||||||
# via anyio
|
|
||||||
stack-data==0.6.3
|
stack-data==0.6.3
|
||||||
# via ipython
|
# via ipython
|
||||||
starlette==0.48.0
|
starlette==0.52.1
|
||||||
# via fastapi
|
# via fastapi
|
||||||
sympy==1.14.0
|
sympy==1.14.0
|
||||||
# via torch
|
# via torch
|
||||||
teleop==0.1.2
|
teleop==0.1.4
|
||||||
# via lerobot
|
# via lerobot
|
||||||
tensorboard==2.20.0
|
tensorboard==2.20.0
|
||||||
# via robomimic
|
# via robomimic
|
||||||
@@ -730,46 +756,38 @@ tensorboard-data-server==0.7.2
|
|||||||
# via tensorboard
|
# via tensorboard
|
||||||
tensorboardx==2.6.4
|
tensorboardx==2.6.4
|
||||||
# via robomimic
|
# via robomimic
|
||||||
termcolor==3.1.0
|
termcolor==3.3.0
|
||||||
# via
|
# via
|
||||||
# lerobot
|
# lerobot
|
||||||
# robomimic
|
# robomimic
|
||||||
thop==0.1.1.post2209072238
|
thop==0.1.1.post2209072238
|
||||||
# via libero
|
# via hf-libero
|
||||||
tifffile==2025.5.10
|
tifffile==2026.3.3
|
||||||
# via scikit-image
|
# via scikit-image
|
||||||
timm==1.0.20
|
tokenizers==0.22.2
|
||||||
# via lerobot
|
|
||||||
tokenizers==0.22.1
|
|
||||||
# via transformers
|
# via transformers
|
||||||
toml==0.10.2
|
toml==0.10.2
|
||||||
# via draccus
|
# via draccus
|
||||||
tomli==2.3.0
|
torch==2.10.0
|
||||||
# via
|
|
||||||
# cmeel
|
|
||||||
# coverage
|
|
||||||
# jupytext
|
|
||||||
# pytest
|
|
||||||
torch==2.7.1
|
|
||||||
# via
|
# via
|
||||||
# accelerate
|
# accelerate
|
||||||
# flash-attn
|
|
||||||
# lerobot
|
# lerobot
|
||||||
# peft
|
# peft
|
||||||
# robomimic
|
# robomimic
|
||||||
# thop
|
# thop
|
||||||
# timm
|
# torchdiffeq
|
||||||
# torchvision
|
# torchvision
|
||||||
torchcodec==0.5
|
torchcodec==0.10.0
|
||||||
# via lerobot
|
# via lerobot
|
||||||
torchvision==0.22.1
|
torchdiffeq==0.2.5
|
||||||
|
# via lerobot
|
||||||
|
torchvision==0.25.0
|
||||||
# via
|
# via
|
||||||
# lerobot
|
# lerobot
|
||||||
# robomimic
|
# robomimic
|
||||||
# timm
|
tornado==6.5.4
|
||||||
tornado==6.5.2
|
|
||||||
# via meshcat
|
# via meshcat
|
||||||
tqdm==4.67.1
|
tqdm==4.67.3
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
# dm-control
|
# dm-control
|
||||||
@@ -783,26 +801,29 @@ traitlets==5.14.3
|
|||||||
# jupyter-core
|
# jupyter-core
|
||||||
# matplotlib-inline
|
# matplotlib-inline
|
||||||
# nbformat
|
# nbformat
|
||||||
transformers==4.57.1
|
transformers==5.3.0
|
||||||
# via
|
# via
|
||||||
|
# hf-libero
|
||||||
# lerobot
|
# lerobot
|
||||||
# libero
|
|
||||||
# peft
|
# peft
|
||||||
transforms3d==0.4.2
|
transforms3d==0.4.2
|
||||||
# via teleop
|
# via teleop
|
||||||
triton==3.3.1
|
triton==3.6.0
|
||||||
# via torch
|
# via torch
|
||||||
|
typer==0.24.1
|
||||||
|
# via
|
||||||
|
# huggingface-hub
|
||||||
|
# transformers
|
||||||
typing-extensions==4.15.0
|
typing-extensions==4.15.0
|
||||||
# via
|
# via
|
||||||
# aiosignal
|
# aiosignal
|
||||||
# anyio
|
# anyio
|
||||||
# etils
|
# etils
|
||||||
# exceptiongroup
|
# faker
|
||||||
# fastapi
|
# fastapi
|
||||||
# gymnasium
|
# gymnasium
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
# ipython
|
# mypy
|
||||||
# multidict
|
|
||||||
# pydantic
|
# pydantic
|
||||||
# pydantic-core
|
# pydantic-core
|
||||||
# referencing
|
# referencing
|
||||||
@@ -811,46 +832,46 @@ typing-extensions==4.15.0
|
|||||||
# torch
|
# torch
|
||||||
# typing-inspect
|
# typing-inspect
|
||||||
# typing-inspection
|
# typing-inspection
|
||||||
# uvicorn
|
|
||||||
# virtualenv
|
|
||||||
# wandb
|
# wandb
|
||||||
typing-inspect==0.9.0
|
typing-inspect==0.9.0
|
||||||
# via draccus
|
# via draccus
|
||||||
typing-inspection==0.4.2
|
typing-inspection==0.4.2
|
||||||
# via pydantic
|
# via
|
||||||
tzdata==2025.2
|
# fastapi
|
||||||
|
# pydantic
|
||||||
|
tzdata==2025.3
|
||||||
# via pandas
|
# via pandas
|
||||||
u-msgpack-python==2.8.0
|
u-msgpack-python==2.8.0
|
||||||
# via meshcat
|
# via meshcat
|
||||||
urllib3==2.5.0
|
urllib3==2.6.3
|
||||||
# via
|
# via
|
||||||
# requests
|
# requests
|
||||||
# sentry-sdk
|
# sentry-sdk
|
||||||
uvicorn[standard]==0.38.0
|
uvicorn[standard]==0.41.0
|
||||||
# via teleop
|
# via teleop
|
||||||
uvloop==0.22.1
|
uvloop==0.22.1
|
||||||
# via uvicorn
|
# via uvicorn
|
||||||
virtualenv==20.35.3
|
virtualenv==21.1.0
|
||||||
# via pre-commit
|
# via pre-commit
|
||||||
wandb==0.21.4
|
wandb==0.24.2
|
||||||
# via
|
# via
|
||||||
|
# hf-libero
|
||||||
# lerobot
|
# lerobot
|
||||||
# libero
|
|
||||||
watchfiles==1.1.1
|
watchfiles==1.1.1
|
||||||
# via uvicorn
|
# via uvicorn
|
||||||
wcwidth==0.2.14
|
wcwidth==0.6.0
|
||||||
# via prompt-toolkit
|
# via prompt-toolkit
|
||||||
websocket-client==1.9.0
|
websocket-client==1.9.0
|
||||||
# via teleop
|
# via teleop
|
||||||
websockets==15.0.1
|
websockets==16.0
|
||||||
# via uvicorn
|
# via uvicorn
|
||||||
werkzeug==3.1.3
|
werkzeug==3.1.6
|
||||||
# via tensorboard
|
# via tensorboard
|
||||||
wrapt==2.0.0
|
wrapt==2.1.2
|
||||||
# via dm-tree
|
# via dm-tree
|
||||||
xxhash==3.6.0
|
xxhash==3.6.0
|
||||||
# via datasets
|
# via datasets
|
||||||
yarl==1.22.0
|
yarl==1.23.0
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
zipp==3.23.0
|
zipp==3.23.0
|
||||||
# via
|
# via
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
# requirements.in
|
# requirements.in
|
||||||
|
|
||||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
|
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.3.1 25D2128 arm64).
|
||||||
# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
|
# Darwin MacBook-Pro.local 25.3.0 Darwin Kernel Version 25.3.0: Wed Jan 28 20:54:55 PST 2026; root:xnu-12377.91.3~2/RELEASE_ARM64_T8132 arm64
|
||||||
|
|
||||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
|
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.4 LTS x86_64).
|
||||||
# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
# Linux lerobot-linux 6.17.0-14-generic #14~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Jan 15 15:52:10 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||||
|
|
||||||
-e .[all]
|
-e .[all]
|
||||||
|
|||||||
@@ -49,9 +49,14 @@ import torch
|
|||||||
|
|
||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||||
from lerobot.robots import (
|
from lerobot.robots import ( # noqa: F401
|
||||||
RobotConfig, # noqa: F401
|
Robot,
|
||||||
|
RobotConfig,
|
||||||
|
bi_so_follower,
|
||||||
|
koch_follower,
|
||||||
make_robot_from_config,
|
make_robot_from_config,
|
||||||
|
omx_follower,
|
||||||
|
so_follower,
|
||||||
)
|
)
|
||||||
from lerobot.transport import (
|
from lerobot.transport import (
|
||||||
services_pb2, # type: ignore
|
services_pb2, # type: ignore
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ class ZMQCamera(Camera):
|
|||||||
try:
|
try:
|
||||||
message = self.socket.recv_string()
|
message = self.socket.recv_string()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
|
# zmq is lazy-imported in connect(), so check by name to avoid a top-level import
|
||||||
if type(e).__name__ == "Again":
|
if type(e).__name__ == "Again":
|
||||||
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import base64
|
|||||||
import contextlib
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
@@ -42,10 +43,57 @@ def encode_image(image: np.ndarray, quality: int = 80) -> str:
|
|||||||
return base64.b64encode(buffer).decode("utf-8")
|
return base64.b64encode(buffer).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
class CameraCaptureThread:
|
||||||
|
"""Background thread that continuously captures and encodes frames from a camera."""
|
||||||
|
|
||||||
|
def __init__(self, camera: OpenCVCamera, name: str):
|
||||||
|
self.camera = camera
|
||||||
|
self.name = name
|
||||||
|
self.latest_encoded: str | None = None # Pre-encoded JPEG as base64
|
||||||
|
self.latest_timestamp: float = 0.0
|
||||||
|
self.frame_lock = threading.Lock()
|
||||||
|
self.running = False
|
||||||
|
self.thread: threading.Thread | None = None
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Start the capture thread."""
|
||||||
|
self.running = True
|
||||||
|
self.thread = threading.Thread(target=self._capture_loop, daemon=True)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Stop the capture thread."""
|
||||||
|
self.running = False
|
||||||
|
if self.thread:
|
||||||
|
self.thread.join(timeout=1.0)
|
||||||
|
|
||||||
|
def _capture_loop(self):
|
||||||
|
"""Continuously capture and encode frames at the camera's native rate."""
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
frame = self.camera.read() # Blocks at camera's native rate
|
||||||
|
timestamp = time.time()
|
||||||
|
# Encode immediately in capture thread (this is the slow part)
|
||||||
|
encoded = encode_image(frame)
|
||||||
|
with self.frame_lock:
|
||||||
|
self.latest_encoded = encoded
|
||||||
|
self.latest_timestamp = timestamp
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Camera {self.name} capture error: {e}")
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
def get_latest(self) -> tuple[str | None, float]:
|
||||||
|
"""Get the latest encoded frame and its timestamp."""
|
||||||
|
with self.frame_lock:
|
||||||
|
return self.latest_encoded, self.latest_timestamp
|
||||||
|
|
||||||
|
|
||||||
class ImageServer:
|
class ImageServer:
|
||||||
def __init__(self, config: dict, port: int = 5555):
|
def __init__(self, config: dict, port: int = 5555):
|
||||||
|
# fps controls the publish loop rate (how often frames are sent over ZMQ), not the camera capture rate
|
||||||
self.fps = config.get("fps", 30)
|
self.fps = config.get("fps", 30)
|
||||||
self.cameras: dict[str, OpenCVCamera] = {}
|
self.cameras: dict[str, OpenCVCamera] = {}
|
||||||
|
self.capture_threads: dict[str, CameraCaptureThread] = {}
|
||||||
|
|
||||||
for name, cfg in config.get("cameras", {}).items():
|
for name, cfg in config.get("cameras", {}).items():
|
||||||
shape = cfg.get("shape", [480, 640])
|
shape = cfg.get("shape", [480, 640])
|
||||||
@@ -61,6 +109,10 @@ class ImageServer:
|
|||||||
self.cameras[name] = camera
|
self.cameras[name] = camera
|
||||||
logger.info(f"Camera {name}: {shape[1]}x{shape[0]}")
|
logger.info(f"Camera {name}: {shape[1]}x{shape[0]}")
|
||||||
|
|
||||||
|
# Create capture thread for this camera
|
||||||
|
capture_thread = CameraCaptureThread(camera, name)
|
||||||
|
self.capture_threads[name] = capture_thread
|
||||||
|
|
||||||
# ZMQ PUB socket
|
# ZMQ PUB socket
|
||||||
self.context = zmq.Context()
|
self.context = zmq.Context()
|
||||||
self.socket = self.context.socket(zmq.PUB)
|
self.socket = self.context.socket(zmq.PUB)
|
||||||
@@ -73,6 +125,18 @@ class ImageServer:
|
|||||||
def run(self):
|
def run(self):
|
||||||
frame_count = 0
|
frame_count = 0
|
||||||
frame_times = deque(maxlen=60)
|
frame_times = deque(maxlen=60)
|
||||||
|
last_published_ts: dict[str, float] = {}
|
||||||
|
|
||||||
|
# Start all capture threads
|
||||||
|
for capture_thread in self.capture_threads.values():
|
||||||
|
capture_thread.start()
|
||||||
|
|
||||||
|
# Wait for first frames to be captured and encoded
|
||||||
|
logger.info("Waiting for cameras to start capturing...")
|
||||||
|
for name, capture_thread in self.capture_threads.items():
|
||||||
|
while capture_thread.get_latest()[0] is None:
|
||||||
|
time.sleep(0.01)
|
||||||
|
logger.info(f"Camera {name} ready (capture + encode in background)")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
@@ -80,10 +144,12 @@ class ImageServer:
|
|||||||
|
|
||||||
# Build message
|
# Build message
|
||||||
message = {"timestamps": {}, "images": {}}
|
message = {"timestamps": {}, "images": {}}
|
||||||
for name, cam in self.cameras.items():
|
for name, capture_thread in self.capture_threads.items():
|
||||||
frame = cam.read() # Returns RGB
|
encoded, timestamp = capture_thread.get_latest()
|
||||||
message["timestamps"][name] = time.time()
|
if encoded is not None and timestamp > last_published_ts.get(name, 0.0):
|
||||||
message["images"][name] = encode_image(frame)
|
message["timestamps"][name] = timestamp
|
||||||
|
message["images"][name] = encoded
|
||||||
|
last_published_ts[name] = timestamp
|
||||||
|
|
||||||
# Send as JSON string (suppress if buffer full)
|
# Send as JSON string (suppress if buffer full)
|
||||||
with contextlib.suppress(zmq.Again):
|
with contextlib.suppress(zmq.Again):
|
||||||
@@ -102,6 +168,8 @@ class ImageServer:
|
|||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
|
for capture_thread in self.capture_threads.values():
|
||||||
|
capture_thread.stop()
|
||||||
for cam in self.cameras.values():
|
for cam in self.cameras.values():
|
||||||
cam.disconnect()
|
cam.disconnect()
|
||||||
self.socket.close()
|
self.socket.close()
|
||||||
|
|||||||
@@ -16,18 +16,13 @@
|
|||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from lerobot.datasets.transforms import ImageTransformsConfig
|
from lerobot.datasets.transforms import DatasetTransformStepConfig, ImageTransformsConfig
|
||||||
from lerobot.datasets.video_utils import get_safe_default_codec
|
from lerobot.datasets.video_utils import get_safe_default_codec
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DatasetConfig:
|
class DatasetConfig:
|
||||||
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
|
|
||||||
# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
|
|
||||||
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
|
||||||
# datasets are provided.
|
|
||||||
repo_id: str
|
repo_id: str
|
||||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
|
||||||
root: str | None = None
|
root: str | None = None
|
||||||
episodes: list[int] | None = None
|
episodes: list[int] | None = None
|
||||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||||
@@ -37,6 +32,32 @@ class DatasetConfig:
|
|||||||
streaming: bool = False
|
streaming: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SubDatasetConfig:
|
||||||
|
"""Configuration for a single dataset within a MultiDatasetConfig."""
|
||||||
|
|
||||||
|
repo_id: str
|
||||||
|
root: str | None = None
|
||||||
|
episodes: list[int] | None = None
|
||||||
|
revision: str | None = None
|
||||||
|
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||||
|
weight: float = 1.0
|
||||||
|
# Maps dataset-local feature keys to unified policy keys.
|
||||||
|
# Keys not listed pass through unchanged.
|
||||||
|
feature_map: dict[str, str] = field(default_factory=dict)
|
||||||
|
# Per-dataset transforms applied after feature renaming, before cross-dataset padding.
|
||||||
|
transforms: list[DatasetTransformStepConfig] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MultiDatasetConfig:
|
||||||
|
"""Configuration for training on multiple datasets jointly."""
|
||||||
|
|
||||||
|
datasets: list[SubDatasetConfig] = field(default_factory=list)
|
||||||
|
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||||
|
use_imagenet_stats: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WandBConfig:
|
class WandBConfig:
|
||||||
enable: bool = False
|
enable: bool = False
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError
|
|||||||
|
|
||||||
from lerobot import envs
|
from lerobot import envs
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
from lerobot.configs.default import DatasetConfig, EvalConfig, MultiDatasetConfig, PeftConfig, WandBConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.optim import OptimizerConfig
|
from lerobot.optim import OptimizerConfig
|
||||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||||
@@ -35,7 +35,7 @@ TRAIN_CONFIG_NAME = "train_config.json"
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainPipelineConfig(HubMixin):
|
class TrainPipelineConfig(HubMixin):
|
||||||
dataset: DatasetConfig
|
dataset: DatasetConfig | MultiDatasetConfig
|
||||||
env: envs.EnvConfig | None = None
|
env: envs.EnvConfig | None = None
|
||||||
policy: PreTrainedConfig | None = None
|
policy: PreTrainedConfig | None = None
|
||||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||||
@@ -50,6 +50,9 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||||
# AND for the evaluation environments.
|
# AND for the evaluation environments.
|
||||||
seed: int | None = 1000
|
seed: int | None = 1000
|
||||||
|
# Set to True to use deterministic cuDNN algorithms for reproducibility.
|
||||||
|
# This disables cudnn.benchmark and may reduce training speed by ~10-20%.
|
||||||
|
cudnn_deterministic: bool = False
|
||||||
# Number of workers for the dataloader.
|
# Number of workers for the dataloader.
|
||||||
num_workers: int = 4
|
num_workers: int = 4
|
||||||
batch_size: int = 8
|
batch_size: int = 8
|
||||||
@@ -126,8 +129,9 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
||||||
self.output_dir = Path("outputs/train") / train_dir
|
self.output_dir = Path("outputs/train") / train_dir
|
||||||
|
|
||||||
if isinstance(self.dataset.repo_id, list):
|
if isinstance(self.dataset, MultiDatasetConfig):
|
||||||
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
|
if len(self.dataset.datasets) < 1:
|
||||||
|
raise ValueError("MultiDatasetConfig.datasets must contain at least one sub-dataset.")
|
||||||
|
|
||||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||||
@@ -140,8 +144,7 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
|
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.use_rabc and not self.rabc_progress_path:
|
if self.use_rabc and not self.rabc_progress_path and isinstance(self.dataset, DatasetConfig):
|
||||||
# Auto-detect from dataset path
|
|
||||||
repo_id = self.dataset.repo_id
|
repo_id = self.dataset.repo_id
|
||||||
if self.dataset.root:
|
if self.dataset.root:
|
||||||
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
|
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
|
||||||
|
|||||||
@@ -18,13 +18,14 @@ from pprint import pformat
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.default import DatasetConfig, MultiDatasetConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.datasets.lerobot_dataset import (
|
from lerobot.datasets.lerobot_dataset import (
|
||||||
LeRobotDataset,
|
LeRobotDataset,
|
||||||
LeRobotDatasetMetadata,
|
LeRobotDatasetMetadata,
|
||||||
MultiLeRobotDataset,
|
|
||||||
)
|
)
|
||||||
|
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
|
||||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||||
from lerobot.datasets.transforms import ImageTransforms
|
from lerobot.datasets.transforms import ImageTransforms
|
||||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
||||||
@@ -68,66 +69,81 @@ def resolve_delta_timestamps(
|
|||||||
return delta_timestamps
|
return delta_timestamps
|
||||||
|
|
||||||
|
|
||||||
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
|
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | NewMultiLeRobotDataset:
|
||||||
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
|
"""Create a single or multi-dataset depending on the config type.
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotImplementedError: The MultiLeRobotDataset is currently deactivated.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LeRobotDataset | MultiLeRobotDataset
|
LeRobotDataset | NewMultiLeRobotDataset
|
||||||
"""
|
"""
|
||||||
|
if isinstance(cfg.dataset, MultiDatasetConfig):
|
||||||
|
return _make_multi_dataset(cfg)
|
||||||
|
|
||||||
|
return _make_single_dataset(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_single_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset:
|
||||||
|
ds_cfg: DatasetConfig = cfg.dataset # type: ignore[assignment]
|
||||||
image_transforms = (
|
image_transforms = (
|
||||||
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
|
ImageTransforms(ds_cfg.image_transforms) if ds_cfg.image_transforms.enable else None
|
||||||
)
|
)
|
||||||
|
ds_meta = LeRobotDatasetMetadata(ds_cfg.repo_id, root=ds_cfg.root, revision=ds_cfg.revision)
|
||||||
|
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||||
|
|
||||||
if isinstance(cfg.dataset.repo_id, str):
|
if not ds_cfg.streaming:
|
||||||
ds_meta = LeRobotDatasetMetadata(
|
dataset = LeRobotDataset(
|
||||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
ds_cfg.repo_id,
|
||||||
)
|
root=ds_cfg.root,
|
||||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
episodes=ds_cfg.episodes,
|
||||||
if not cfg.dataset.streaming:
|
delta_timestamps=delta_timestamps,
|
||||||
dataset = LeRobotDataset(
|
|
||||||
cfg.dataset.repo_id,
|
|
||||||
root=cfg.dataset.root,
|
|
||||||
episodes=cfg.dataset.episodes,
|
|
||||||
delta_timestamps=delta_timestamps,
|
|
||||||
image_transforms=image_transforms,
|
|
||||||
revision=cfg.dataset.revision,
|
|
||||||
video_backend=cfg.dataset.video_backend,
|
|
||||||
tolerance_s=cfg.tolerance_s,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
dataset = StreamingLeRobotDataset(
|
|
||||||
cfg.dataset.repo_id,
|
|
||||||
root=cfg.dataset.root,
|
|
||||||
episodes=cfg.dataset.episodes,
|
|
||||||
delta_timestamps=delta_timestamps,
|
|
||||||
image_transforms=image_transforms,
|
|
||||||
revision=cfg.dataset.revision,
|
|
||||||
max_num_shards=cfg.num_workers,
|
|
||||||
tolerance_s=cfg.tolerance_s,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
|
||||||
dataset = MultiLeRobotDataset(
|
|
||||||
cfg.dataset.repo_id,
|
|
||||||
# TODO(aliberts): add proper support for multi dataset
|
|
||||||
# delta_timestamps=delta_timestamps,
|
|
||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
video_backend=cfg.dataset.video_backend,
|
revision=ds_cfg.revision,
|
||||||
|
video_backend=ds_cfg.video_backend,
|
||||||
|
tolerance_s=cfg.tolerance_s,
|
||||||
)
|
)
|
||||||
logging.info(
|
else:
|
||||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
dataset = StreamingLeRobotDataset(
|
||||||
f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
ds_cfg.repo_id,
|
||||||
|
root=ds_cfg.root,
|
||||||
|
episodes=ds_cfg.episodes,
|
||||||
|
delta_timestamps=delta_timestamps,
|
||||||
|
image_transforms=image_transforms,
|
||||||
|
revision=ds_cfg.revision,
|
||||||
|
max_num_shards=cfg.num_workers,
|
||||||
|
tolerance_s=cfg.tolerance_s,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.dataset.use_imagenet_stats:
|
if ds_cfg.use_imagenet_stats:
|
||||||
for key in dataset.meta.camera_keys:
|
for key in dataset.meta.camera_keys:
|
||||||
for stats_type, stats in IMAGENET_STATS.items():
|
for stats_type, stats_val in IMAGENET_STATS.items():
|
||||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def _make_multi_dataset(cfg: TrainPipelineConfig) -> NewMultiLeRobotDataset:
|
||||||
|
multi_cfg: MultiDatasetConfig = cfg.dataset # type: ignore[assignment]
|
||||||
|
image_transforms = (
|
||||||
|
ImageTransforms(multi_cfg.image_transforms) if multi_cfg.image_transforms.enable else None
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = NewMultiLeRobotDataset(
|
||||||
|
configs=multi_cfg.datasets,
|
||||||
|
image_transforms=image_transforms,
|
||||||
|
tolerance_s=cfg.tolerance_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
"MultiLeRobotDataset created with %d sub-datasets:\n%s",
|
||||||
|
len(multi_cfg.datasets),
|
||||||
|
pformat(
|
||||||
|
{i: c.repo_id for i, c in enumerate(multi_cfg.datasets)},
|
||||||
|
indent=2,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if multi_cfg.use_imagenet_stats:
|
||||||
|
for key in dataset.meta.camera_keys:
|
||||||
|
for stats_type, stats_val in IMAGENET_STATS.items():
|
||||||
|
dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|||||||
364
src/lerobot/datasets/multi_dataset.py
Normal file
364
src/lerobot/datasets/multi_dataset.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
"""MultiLeRobotDataset: joint training over heterogeneous LeRobot datasets.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- Per-dataset feature mapping (rename keys to a unified namespace)
|
||||||
|
- Automatic zero-padding for features missing in some datasets
|
||||||
|
- Per-dataset transform pipelines
|
||||||
|
- Weighted sampling via dataset weights
|
||||||
|
- Aggregated stats across all sub-datasets
|
||||||
|
- A ``meta`` shim compatible with EpisodeAwareSampler and make_policy
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
|
||||||
|
from lerobot.configs.default import SubDatasetConfig
|
||||||
|
from lerobot.datasets.compute_stats import aggregate_stats
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.datasets.transforms import DatasetTransformPipeline
|
||||||
|
|
||||||
|
|
||||||
|
class MultiDatasetMeta:
|
||||||
|
"""Lightweight metadata shim that exposes the same interface as ``LeRobotDatasetMetadata``.
|
||||||
|
|
||||||
|
Built by aggregating the metadata of multiple sub-datasets after their
|
||||||
|
feature keys have been mapped to a unified namespace.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
datasets: list[LeRobotDataset],
|
||||||
|
feature_maps: list[dict[str, str]],
|
||||||
|
):
|
||||||
|
self._datasets = datasets
|
||||||
|
self._feature_maps = feature_maps
|
||||||
|
|
||||||
|
self._unified_features = self._build_unified_features()
|
||||||
|
self._episodes = self._build_episodes()
|
||||||
|
self._stats = self._build_stats()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Feature union
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _build_unified_features(self) -> dict[str, dict]:
|
||||||
|
"""Build feature dict as the *union* of all mapped feature keys."""
|
||||||
|
unified: dict[str, dict] = {}
|
||||||
|
for ds, fmap in zip(self._datasets, self._feature_maps):
|
||||||
|
for original_key, feat_info in ds.meta.features.items():
|
||||||
|
mapped_key = fmap.get(original_key, original_key)
|
||||||
|
if mapped_key not in unified:
|
||||||
|
unified[mapped_key] = dict(feat_info)
|
||||||
|
else:
|
||||||
|
existing_shape = tuple(unified[mapped_key]["shape"])
|
||||||
|
new_shape = tuple(feat_info["shape"])
|
||||||
|
if existing_shape != new_shape and unified[mapped_key]["dtype"] == feat_info["dtype"]:
|
||||||
|
logging.warning(
|
||||||
|
"Feature '%s' has shape %s in one dataset but %s in another. "
|
||||||
|
"The larger shape will be used (padding applied automatically).",
|
||||||
|
mapped_key,
|
||||||
|
existing_shape,
|
||||||
|
new_shape,
|
||||||
|
)
|
||||||
|
if np.prod(new_shape) > np.prod(existing_shape):
|
||||||
|
unified[mapped_key] = dict(feat_info)
|
||||||
|
return unified
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Episode metadata (global flat indexing)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _build_episodes(self) -> dict[str, list]:
|
||||||
|
"""Concatenate episode boundaries across sub-datasets with frame offsets.
|
||||||
|
|
||||||
|
Produces the same column structure as ``load_episodes()`` so that
|
||||||
|
``EpisodeAwareSampler`` and ``WeightedEpisodeAwareSampler`` can consume it.
|
||||||
|
"""
|
||||||
|
from_indices: list[int] = []
|
||||||
|
to_indices: list[int] = []
|
||||||
|
dataset_source: list[int] = []
|
||||||
|
|
||||||
|
frame_offset = 0
|
||||||
|
for ds_idx, ds in enumerate(self._datasets):
|
||||||
|
eps = ds.meta.episodes
|
||||||
|
for ep in eps:
|
||||||
|
from_indices.append(ep["dataset_from_index"] + frame_offset)
|
||||||
|
to_indices.append(ep["dataset_to_index"] + frame_offset)
|
||||||
|
dataset_source.append(ds_idx)
|
||||||
|
frame_offset += ds.num_frames
|
||||||
|
|
||||||
|
return {
|
||||||
|
"dataset_from_index": from_indices,
|
||||||
|
"dataset_to_index": to_indices,
|
||||||
|
"dataset_source": dataset_source,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Stats aggregation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _build_stats(self) -> dict[str, dict[str, np.ndarray]]:
|
||||||
|
"""Aggregate stats across sub-datasets using mapped feature keys."""
|
||||||
|
mapped_stats_list: list[dict[str, dict]] = []
|
||||||
|
for ds, fmap in zip(self._datasets, self._feature_maps):
|
||||||
|
reverse_map = {v: k for k, v in fmap.items()}
|
||||||
|
mapped: dict[str, dict] = {}
|
||||||
|
for unified_key in self._unified_features:
|
||||||
|
original_key = reverse_map.get(unified_key, unified_key)
|
||||||
|
if original_key in ds.meta.stats:
|
||||||
|
mapped[unified_key] = ds.meta.stats[original_key]
|
||||||
|
mapped_stats_list.append(mapped)
|
||||||
|
|
||||||
|
return aggregate_stats(mapped_stats_list)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Properties matching LeRobotDatasetMetadata API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@property
|
||||||
|
def features(self) -> dict[str, dict]:
|
||||||
|
return self._unified_features
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_keys(self) -> list[str]:
|
||||||
|
return [k for k, f in self._unified_features.items() if f["dtype"] == "image"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def video_keys(self) -> list[str]:
|
||||||
|
return [k for k, f in self._unified_features.items() if f["dtype"] == "video"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_keys(self) -> list[str]:
|
||||||
|
return [k for k, f in self._unified_features.items() if f["dtype"] in ("video", "image")]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def names(self) -> dict[str, list | dict]:
|
||||||
|
return {k: f["names"] for k, f in self._unified_features.items()}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shapes(self) -> dict[str, tuple]:
|
||||||
|
return {k: tuple(f["shape"]) for k, f in self._unified_features.items()}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fps(self) -> int:
|
||||||
|
fps_values = {ds.meta.fps for ds in self._datasets}
|
||||||
|
if len(fps_values) > 1:
|
||||||
|
logging.warning("Sub-datasets have different FPS values: %s. Using the first.", fps_values)
|
||||||
|
return self._datasets[0].meta.fps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stats(self) -> dict[str, dict[str, np.ndarray]]:
|
||||||
|
return self._stats
|
||||||
|
|
||||||
|
@stats.setter
|
||||||
|
def stats(self, value: dict):
|
||||||
|
self._stats = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def episodes(self) -> dict[str, list]:
|
||||||
|
return self._episodes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_episodes(self) -> int:
|
||||||
|
return sum(ds.meta.total_episodes for ds in self._datasets)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_frames(self) -> int:
|
||||||
|
return sum(ds.meta.total_frames for ds in self._datasets)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_tasks(self) -> int:
|
||||||
|
return sum(ds.meta.total_tasks for ds in self._datasets)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def info(self) -> dict:
|
||||||
|
return {
|
||||||
|
"fps": self.fps,
|
||||||
|
"features": self._unified_features,
|
||||||
|
"total_episodes": self.total_episodes,
|
||||||
|
"total_frames": self.total_frames,
|
||||||
|
"total_tasks": self.total_tasks,
|
||||||
|
"codebase_version": "v3.0",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class NewMultiLeRobotDataset(torch.utils.data.Dataset):
|
||||||
|
"""Dataset that wraps multiple ``LeRobotDataset`` instances with feature mapping and padding.
|
||||||
|
|
||||||
|
Each sub-dataset can have different feature names and shapes. A per-dataset
|
||||||
|
``feature_map`` renames keys into a shared namespace. Features that a given
|
||||||
|
sub-dataset does not provide are zero-padded so every ``__getitem__`` returns
|
||||||
|
the full unified feature set.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
configs: list[SubDatasetConfig],
|
||||||
|
image_transforms: Callable | None = None,
|
||||||
|
delta_timestamps: dict[str, list[float]] | None = None,
|
||||||
|
tolerance_s: float = 1e-4,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self._configs = configs
|
||||||
|
self.image_transforms = image_transforms
|
||||||
|
|
||||||
|
self._datasets: list[LeRobotDataset] = []
|
||||||
|
self._feature_maps: list[dict[str, str]] = []
|
||||||
|
self._transform_pipelines: list[DatasetTransformPipeline | None] = []
|
||||||
|
self._weights: list[float] = []
|
||||||
|
|
||||||
|
for cfg in configs:
|
||||||
|
ds = LeRobotDataset(
|
||||||
|
repo_id=cfg.repo_id,
|
||||||
|
root=cfg.root,
|
||||||
|
episodes=cfg.episodes,
|
||||||
|
image_transforms=image_transforms,
|
||||||
|
delta_timestamps=delta_timestamps,
|
||||||
|
tolerance_s=tolerance_s,
|
||||||
|
revision=cfg.revision,
|
||||||
|
video_backend=cfg.video_backend,
|
||||||
|
)
|
||||||
|
self._datasets.append(ds)
|
||||||
|
self._feature_maps.append(cfg.feature_map or {})
|
||||||
|
self._transform_pipelines.append(
|
||||||
|
DatasetTransformPipeline(cfg.transforms) if cfg.transforms else None
|
||||||
|
)
|
||||||
|
self._weights.append(cfg.weight)
|
||||||
|
|
||||||
|
self._meta = MultiDatasetMeta(self._datasets, self._feature_maps)
|
||||||
|
|
||||||
|
# Pre-compute cumulative frame counts for fast index mapping.
|
||||||
|
self._cumulative_frames: list[int] = []
|
||||||
|
total = 0
|
||||||
|
for ds in self._datasets:
|
||||||
|
total += ds.num_frames
|
||||||
|
self._cumulative_frames.append(total)
|
||||||
|
|
||||||
|
# Build reverse maps (unified_key -> original_key) per dataset for padding.
|
||||||
|
self._reverse_maps: list[dict[str, str]] = []
|
||||||
|
for fmap in self._feature_maps:
|
||||||
|
self._reverse_maps.append({v: k for k, v in fmap.items()})
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
"MultiLeRobotDataset: %d sub-datasets, %d total frames, %d total episodes, "
|
||||||
|
"%d unified features",
|
||||||
|
len(self._datasets),
|
||||||
|
self.num_frames,
|
||||||
|
self.num_episodes,
|
||||||
|
len(self._meta.features),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public interface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@property
|
||||||
|
def meta(self) -> MultiDatasetMeta:
|
||||||
|
return self._meta
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dataset_weights(self) -> list[float]:
|
||||||
|
return self._weights
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_frames(self) -> int:
|
||||||
|
return self._cumulative_frames[-1] if self._cumulative_frames else 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_episodes(self) -> int:
|
||||||
|
return sum(ds.num_episodes for ds in self._datasets)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def episodes(self) -> list[int] | None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fps(self) -> int:
|
||||||
|
return self._meta.fps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def features(self) -> dict[str, dict]:
|
||||||
|
return self._meta.features
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_keys(self) -> list[str]:
|
||||||
|
return self._meta.camera_keys
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Indexing
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _locate(self, idx: int) -> tuple[int, int]:
|
||||||
|
"""Map a global frame index to (dataset_index, local_index)."""
|
||||||
|
for ds_idx, cum in enumerate(self._cumulative_frames):
|
||||||
|
if idx < cum:
|
||||||
|
local = idx - (self._cumulative_frames[ds_idx - 1] if ds_idx > 0 else 0)
|
||||||
|
return ds_idx, local
|
||||||
|
raise IndexError(f"Index {idx} out of range (total {self.num_frames})")
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self.num_frames
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||||
|
ds_idx, local_idx = self._locate(idx)
|
||||||
|
item = self._datasets[ds_idx][local_idx]
|
||||||
|
|
||||||
|
# 1. Rename keys according to feature_map.
|
||||||
|
fmap = self._feature_maps[ds_idx]
|
||||||
|
if fmap:
|
||||||
|
renamed: dict[str, torch.Tensor] = {}
|
||||||
|
for key, value in item.items():
|
||||||
|
renamed[fmap.get(key, key)] = value
|
||||||
|
item = renamed
|
||||||
|
|
||||||
|
# 2. Apply per-dataset transform pipeline.
|
||||||
|
pipeline = self._transform_pipelines[ds_idx]
|
||||||
|
if pipeline is not None:
|
||||||
|
item = pipeline(item)
|
||||||
|
|
||||||
|
# 3. Pad missing features with zeros.
|
||||||
|
reverse_map = self._reverse_maps[ds_idx]
|
||||||
|
ds_features = self._datasets[ds_idx].meta.features
|
||||||
|
for unified_key, feat_info in self._meta.features.items():
|
||||||
|
if unified_key in item:
|
||||||
|
continue
|
||||||
|
original_key = reverse_map.get(unified_key, unified_key)
|
||||||
|
if original_key in ds_features:
|
||||||
|
continue
|
||||||
|
shape = tuple(feat_info["shape"])
|
||||||
|
dtype = feat_info["dtype"]
|
||||||
|
if dtype in ("video", "image"):
|
||||||
|
# Camera tensors are (C, H, W) after transforms.
|
||||||
|
c, h, w = (shape[2], shape[0], shape[1]) if len(shape) == 3 else (3, shape[0], shape[1])
|
||||||
|
item[unified_key] = torch.zeros(c, h, w, dtype=torch.float32)
|
||||||
|
elif dtype in ("float32", "float64"):
|
||||||
|
item[unified_key] = torch.zeros(shape, dtype=torch.float32)
|
||||||
|
elif dtype in ("int32", "int64"):
|
||||||
|
item[unified_key] = torch.zeros(shape, dtype=torch.int64)
|
||||||
|
elif dtype == "bool":
|
||||||
|
item[unified_key] = torch.zeros(shape, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
item[unified_key] = torch.zeros(shape, dtype=torch.float32)
|
||||||
|
item[f"{unified_key}_is_pad"] = torch.tensor(True)
|
||||||
|
|
||||||
|
# 4. Tag which dataset this sample came from.
|
||||||
|
item["dataset_index"] = torch.tensor(ds_idx)
|
||||||
|
return item
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
repo_ids = [c.repo_id for c in self._configs]
|
||||||
|
return (
|
||||||
|
f"NewMultiLeRobotDataset(\n"
|
||||||
|
f" repo_ids={repo_ids},\n"
|
||||||
|
f" num_frames={self.num_frames},\n"
|
||||||
|
f" num_episodes={self.num_episodes},\n"
|
||||||
|
f" unified_features={list(self._meta.features.keys())},\n"
|
||||||
|
f" weights={self._weights},\n"
|
||||||
|
f")"
|
||||||
|
)
|
||||||
@@ -59,3 +59,80 @@ class EpisodeAwareSampler:
|
|||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.indices)
|
return len(self.indices)
|
||||||
|
|
||||||
|
|
||||||
|
class WeightedEpisodeAwareSampler:
|
||||||
|
"""Sampler that draws frames from multiple datasets according to per-dataset weights.
|
||||||
|
|
||||||
|
Each iteration first selects a sub-dataset proportionally to its weight, then
|
||||||
|
uniformly samples a frame from that sub-dataset's valid index set. Episode
|
||||||
|
boundary information is respected so that dropped frames are excluded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_from_indices: Start index for each episode (global, flat).
|
||||||
|
dataset_to_indices: End index (exclusive) for each episode (global, flat).
|
||||||
|
dataset_membership: Which sub-dataset each episode belongs to (integer id).
|
||||||
|
dataset_weights: Relative sampling weight per sub-dataset.
|
||||||
|
episode_indices_to_use: If given, only episodes in this set are used.
|
||||||
|
drop_n_first_frames: Frames to skip at the start of each episode.
|
||||||
|
drop_n_last_frames: Frames to skip at the end of each episode.
|
||||||
|
shuffle: Whether to shuffle within each epoch.
|
||||||
|
num_samples: How many samples per epoch. Defaults to total valid frames.
|
||||||
|
generator: Optional torch.Generator for reproducibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_from_indices: list[int],
|
||||||
|
dataset_to_indices: list[int],
|
||||||
|
dataset_membership: list[int],
|
||||||
|
dataset_weights: list[float],
|
||||||
|
episode_indices_to_use: list | None = None,
|
||||||
|
drop_n_first_frames: int = 0,
|
||||||
|
drop_n_last_frames: int = 0,
|
||||||
|
shuffle: bool = False,
|
||||||
|
num_samples: int | None = None,
|
||||||
|
generator: torch.Generator | None = None,
|
||||||
|
):
|
||||||
|
n_datasets = max(dataset_membership) + 1 if dataset_membership else 0
|
||||||
|
self._per_dataset_indices: list[list[int]] = [[] for _ in range(n_datasets)]
|
||||||
|
|
||||||
|
episodes_to_use = set(episode_indices_to_use) if episode_indices_to_use is not None else None
|
||||||
|
|
||||||
|
for ep_idx, (start, end, ds_id) in enumerate(
|
||||||
|
zip(dataset_from_indices, dataset_to_indices, dataset_membership, strict=True)
|
||||||
|
):
|
||||||
|
if episodes_to_use is not None and ep_idx not in episodes_to_use:
|
||||||
|
continue
|
||||||
|
frame_range = range(start + drop_n_first_frames, end - drop_n_last_frames)
|
||||||
|
self._per_dataset_indices[ds_id].extend(frame_range)
|
||||||
|
|
||||||
|
# Normalise weights (only over datasets that actually have frames).
|
||||||
|
raw_weights = list(dataset_weights[:n_datasets])
|
||||||
|
self._weights = torch.zeros(n_datasets)
|
||||||
|
for i, w in enumerate(raw_weights):
|
||||||
|
if len(self._per_dataset_indices[i]) > 0:
|
||||||
|
self._weights[i] = w
|
||||||
|
total_w = self._weights.sum()
|
||||||
|
if total_w > 0:
|
||||||
|
self._weights /= total_w
|
||||||
|
|
||||||
|
self._total_frames = sum(len(idx) for idx in self._per_dataset_indices)
|
||||||
|
self._num_samples = num_samples if num_samples is not None else self._total_frames
|
||||||
|
self.shuffle = shuffle
|
||||||
|
self._generator = generator
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[int]:
|
||||||
|
if not self.shuffle:
|
||||||
|
for ds_indices in self._per_dataset_indices:
|
||||||
|
yield from ds_indices
|
||||||
|
return
|
||||||
|
|
||||||
|
for _ in range(self._num_samples):
|
||||||
|
ds_id = int(torch.multinomial(self._weights, 1, generator=self._generator).item())
|
||||||
|
indices = self._per_dataset_indices[ds_id]
|
||||||
|
local_idx = int(torch.randint(len(indices), (1,), generator=self._generator).item())
|
||||||
|
yield indices[local_idx]
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self._num_samples
|
||||||
|
|||||||
@@ -14,11 +14,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import collections
|
import collections
|
||||||
|
import logging
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F_nn
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
from torchvision.transforms.v2 import (
|
from torchvision.transforms.v2 import (
|
||||||
Transform,
|
Transform,
|
||||||
@@ -258,3 +260,114 @@ class ImageTransforms(Transform):
|
|||||||
|
|
||||||
def forward(self, *inputs: Any) -> Any:
|
def forward(self, *inputs: Any) -> Any:
|
||||||
return self.tf(*inputs)
|
return self.tf(*inputs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Per-dataset transform pipeline (used by MultiLeRobotDataset)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatasetTransformStepConfig:
|
||||||
|
"""Config for a single per-dataset transform step."""
|
||||||
|
|
||||||
|
type: str
|
||||||
|
kwargs: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
_DATASET_TRANSFORM_REGISTRY: dict[str, type["DatasetTransformStep"]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_dataset_transform(name: str):
|
||||||
|
"""Decorator to register a DatasetTransformStep by name."""
|
||||||
|
|
||||||
|
def decorator(cls: type["DatasetTransformStep"]) -> type["DatasetTransformStep"]:
|
||||||
|
_DATASET_TRANSFORM_REGISTRY[name] = cls
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetTransformStep:
|
||||||
|
"""Base class for a single per-dataset transform applied to a sample dict."""
|
||||||
|
|
||||||
|
def __call__(self, sample: dict) -> dict:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@register_dataset_transform("pad_action")
|
||||||
|
class PadAction(DatasetTransformStep):
|
||||||
|
"""Zero-pad the ``action`` tensor to *target_dim* along the last axis."""
|
||||||
|
|
||||||
|
def __init__(self, target_dim: int):
|
||||||
|
self.target_dim = target_dim
|
||||||
|
|
||||||
|
def __call__(self, sample: dict) -> dict:
|
||||||
|
action = sample.get("action")
|
||||||
|
if action is None:
|
||||||
|
return sample
|
||||||
|
current = action.shape[-1]
|
||||||
|
if current < self.target_dim:
|
||||||
|
sample["action"] = F_nn.pad(action, (0, self.target_dim - current))
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
@register_dataset_transform("pad_state")
|
||||||
|
class PadState(DatasetTransformStep):
|
||||||
|
"""Zero-pad ``observation.state`` to *target_dim* along the last axis."""
|
||||||
|
|
||||||
|
def __init__(self, target_dim: int):
|
||||||
|
self.target_dim = target_dim
|
||||||
|
|
||||||
|
def __call__(self, sample: dict) -> dict:
|
||||||
|
state = sample.get("observation.state")
|
||||||
|
if state is None:
|
||||||
|
return sample
|
||||||
|
current = state.shape[-1]
|
||||||
|
if current < self.target_dim:
|
||||||
|
sample["observation.state"] = F_nn.pad(state, (0, self.target_dim - current))
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
@register_dataset_transform("resize_images")
|
||||||
|
class ResizeImages(DatasetTransformStep):
|
||||||
|
"""Resize all image/video camera tensors to (height, width)."""
|
||||||
|
|
||||||
|
def __init__(self, height: int, width: int):
|
||||||
|
self.size = (height, width)
|
||||||
|
|
||||||
|
def __call__(self, sample: dict) -> dict:
|
||||||
|
for key in list(sample.keys()):
|
||||||
|
if not key.startswith("observation.images."):
|
||||||
|
continue
|
||||||
|
img = sample[key]
|
||||||
|
if not isinstance(img, torch.Tensor) or img.ndim < 3:
|
||||||
|
continue
|
||||||
|
sample[key] = F.resize(img, self.size, antialias=True)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetTransformPipeline:
|
||||||
|
"""Sequential pipeline of DatasetTransformStep instances."""
|
||||||
|
|
||||||
|
def __init__(self, configs: list[DatasetTransformStepConfig] | None = None):
|
||||||
|
self.steps: list[DatasetTransformStep] = []
|
||||||
|
if configs:
|
||||||
|
for cfg in configs:
|
||||||
|
self.steps.append(self._build(cfg))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build(cfg: DatasetTransformStepConfig) -> DatasetTransformStep:
|
||||||
|
cls = _DATASET_TRANSFORM_REGISTRY.get(cfg.type)
|
||||||
|
if cls is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown dataset transform '{cfg.type}'. "
|
||||||
|
f"Available: {list(_DATASET_TRANSFORM_REGISTRY)}"
|
||||||
|
)
|
||||||
|
return cls(**cfg.kwargs)
|
||||||
|
|
||||||
|
def __call__(self, sample: dict) -> dict:
|
||||||
|
for step in self.steps:
|
||||||
|
sample = step(sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"DatasetTransformPipeline(steps={self.steps})"
|
||||||
|
|||||||
@@ -346,6 +346,105 @@ class LiberoEnv(EnvConfig):
|
|||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
@EnvConfig.register_subclass("libero_plus")
|
||||||
|
@dataclass
|
||||||
|
class LiberoPlusEnv(LiberoEnv):
|
||||||
|
"""Alias config for LIBERO-plus benchmarks.
|
||||||
|
|
||||||
|
LIBERO-plus keeps the same Python package/module names as LIBERO, so this
|
||||||
|
config reuses the existing LIBERO env implementation while making intent explicit
|
||||||
|
in experiment configs (`env.type=libero_plus`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
task: str = "libero_spatial"
|
||||||
|
|
||||||
|
|
||||||
|
@EnvConfig.register_subclass("robocasa")
|
||||||
|
@dataclass
|
||||||
|
class RoboCasaEnv(EnvConfig):
|
||||||
|
"""RoboCasa kitchen composite-task environments.
|
||||||
|
|
||||||
|
Wraps ``robocasa.wrappers.gym_wrapper.RoboCasaGymEnv`` with a flat 12-D Box
|
||||||
|
action space and a structured pixel + state observation dict.
|
||||||
|
|
||||||
|
Selected benchmark tasks (3 short + 2 long):
|
||||||
|
Short: PickPlaceCounterToCabinet, PrepareToast, CoffeeSetupMug
|
||||||
|
Long: PrepareCoffee, RestockPantry
|
||||||
|
"""
|
||||||
|
|
||||||
|
task: str = "PickPlaceCounterToCabinet"
|
||||||
|
tasks: list[str] | None = None # multi-task: list of task names (without robocasa/ prefix)
|
||||||
|
fps: int = 20
|
||||||
|
episode_length: int = 500
|
||||||
|
image_size: int = 128
|
||||||
|
split: str = "target" # "pretrain" or "target"
|
||||||
|
features: dict[str, PolicyFeature] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
features_map: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
ACTION: ACTION,
|
||||||
|
"agentview_left": f"{OBS_IMAGES}.agentview_left",
|
||||||
|
"agentview_right": f"{OBS_IMAGES}.agentview_right",
|
||||||
|
"eye_in_hand": f"{OBS_IMAGES}.eye_in_hand",
|
||||||
|
"robot_state": OBS_STATE,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
|
||||||
|
self.features[cam] = PolicyFeature(
|
||||||
|
type=FeatureType.VISUAL, shape=(self.image_size, self.image_size, 3)
|
||||||
|
)
|
||||||
|
self.features["robot_state"] = PolicyFeature(type=FeatureType.STATE, shape=(16,))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gym_kwargs(self) -> dict:
|
||||||
|
return {"split": self.split}
|
||||||
|
|
||||||
|
|
||||||
|
@EnvConfig.register_subclass("robomme")
|
||||||
|
@dataclass
|
||||||
|
class RoboMMEEnv(EnvConfig):
|
||||||
|
"""RoboMME memory-augmented manipulation benchmark (ManiSkill/SAPIEN).
|
||||||
|
|
||||||
|
16 tasks across 4 suites: Counting, Permanence, Reference, Imitation.
|
||||||
|
Uses BenchmarkEnvBuilder from the robomme package.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task: str = "PickXtimes"
|
||||||
|
fps: int = 10
|
||||||
|
episode_length: int = 300
|
||||||
|
action_space: str = "joint_angle"
|
||||||
|
dataset_split: str = "test"
|
||||||
|
task_ids: list[int] | None = None
|
||||||
|
features: dict[str, PolicyFeature] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(8,)),
|
||||||
|
"front_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
||||||
|
"wrist_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
||||||
|
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
features_map: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
ACTION: ACTION,
|
||||||
|
"front_rgb": f"{OBS_IMAGES}.front",
|
||||||
|
"wrist_rgb": f"{OBS_IMAGES}.wrist",
|
||||||
|
OBS_STATE: OBS_STATE,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gym_kwargs(self) -> dict:
|
||||||
|
return {
|
||||||
|
"action_space": self.action_space,
|
||||||
|
"dataset": self.dataset_split,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@EnvConfig.register_subclass("metaworld")
|
@EnvConfig.register_subclass("metaworld")
|
||||||
@dataclass
|
@dataclass
|
||||||
class MetaworldEnv(EnvConfig):
|
class MetaworldEnv(EnvConfig):
|
||||||
|
|||||||
@@ -20,11 +20,21 @@ import gymnasium as gym
|
|||||||
from gymnasium.envs.registration import registry as gym_registry
|
from gymnasium.envs.registration import registry as gym_registry
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
|
from lerobot.envs.configs import (
|
||||||
|
AlohaEnv,
|
||||||
|
EnvConfig,
|
||||||
|
HubEnvConfig,
|
||||||
|
IsaaclabArenaEnv,
|
||||||
|
LiberoEnv,
|
||||||
|
LiberoPlusEnv,
|
||||||
|
PushtEnv,
|
||||||
|
RoboCasaEnv,
|
||||||
|
RoboMMEEnv,
|
||||||
|
)
|
||||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
from lerobot.processor import ProcessorStep
|
from lerobot.processor import ProcessorStep
|
||||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
|
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, RoboCasaProcessorStep
|
||||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
|
|
||||||
|
|
||||||
@@ -35,6 +45,12 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
|||||||
return PushtEnv(**kwargs)
|
return PushtEnv(**kwargs)
|
||||||
elif env_type == "libero":
|
elif env_type == "libero":
|
||||||
return LiberoEnv(**kwargs)
|
return LiberoEnv(**kwargs)
|
||||||
|
elif env_type == "libero_plus":
|
||||||
|
return LiberoPlusEnv(**kwargs)
|
||||||
|
elif env_type == "robocasa":
|
||||||
|
return RoboCasaEnv(**kwargs)
|
||||||
|
elif env_type == "robomme":
|
||||||
|
return RoboMMEEnv(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||||
|
|
||||||
@@ -70,9 +86,13 @@ def make_env_pre_post_processors(
|
|||||||
return make_xvla_libero_pre_post_processors()
|
return make_xvla_libero_pre_post_processors()
|
||||||
|
|
||||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
if isinstance(env_cfg, (LiberoEnv, LiberoPlusEnv)) or "libero" in env_cfg.type:
|
||||||
preprocessor_steps.append(LiberoProcessorStep())
|
preprocessor_steps.append(LiberoProcessorStep())
|
||||||
|
|
||||||
|
# For RoboCasa environments, add the RoboCasaProcessorStep to preprocessor
|
||||||
|
if isinstance(env_cfg, RoboCasaEnv) or "robocasa" in env_cfg.type:
|
||||||
|
preprocessor_steps.append(RoboCasaProcessorStep())
|
||||||
|
|
||||||
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
|
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
|
||||||
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
|
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
|
||||||
# Parse comma-separated keys (handle None for state-based policies)
|
# Parse comma-separated keys (handle None for state-based policies)
|
||||||
@@ -181,6 +201,33 @@ def make_env(
|
|||||||
control_mode=cfg.control_mode,
|
control_mode=cfg.control_mode,
|
||||||
episode_length=cfg.episode_length,
|
episode_length=cfg.episode_length,
|
||||||
)
|
)
|
||||||
|
elif "robocasa" in cfg.type:
|
||||||
|
from lerobot.envs.robocasa import create_robocasa_envs
|
||||||
|
|
||||||
|
tasks = cfg.tasks if cfg.tasks else [cfg.task]
|
||||||
|
return create_robocasa_envs(
|
||||||
|
tasks=tasks,
|
||||||
|
n_envs=n_envs,
|
||||||
|
image_size=cfg.image_size,
|
||||||
|
split=cfg.split,
|
||||||
|
episode_length=cfg.episode_length,
|
||||||
|
gym_kwargs=cfg.gym_kwargs,
|
||||||
|
env_cls=env_cls,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif "robomme" in cfg.type:
|
||||||
|
from lerobot.envs.robomme import create_robomme_envs
|
||||||
|
|
||||||
|
return create_robomme_envs(
|
||||||
|
task=cfg.task,
|
||||||
|
n_envs=n_envs,
|
||||||
|
action_space_type=cfg.action_space,
|
||||||
|
dataset=cfg.dataset_split,
|
||||||
|
episode_length=cfg.episode_length,
|
||||||
|
task_ids=cfg.task_ids,
|
||||||
|
env_cls=env_cls,
|
||||||
|
)
|
||||||
|
|
||||||
elif "metaworld" in cfg.type:
|
elif "metaworld" in cfg.type:
|
||||||
from lerobot.envs.metaworld import create_metaworld_envs
|
from lerobot.envs.metaworld import create_metaworld_envs
|
||||||
|
|
||||||
|
|||||||
@@ -26,8 +26,14 @@ import gymnasium as gym
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
from libero.libero import benchmark, get_libero_path
|
|
||||||
from libero.libero.envs import OffScreenRenderEnv
|
try:
|
||||||
|
from libero.libero import benchmark, get_libero_path
|
||||||
|
from libero.libero.envs import OffScreenRenderEnv
|
||||||
|
except ImportError:
|
||||||
|
# LIBERO-plus may be installed from source with an extra nested package level.
|
||||||
|
from libero.libero.libero import benchmark, get_libero_path
|
||||||
|
from libero.libero.libero.envs import OffScreenRenderEnv
|
||||||
|
|
||||||
from lerobot.processor import RobotObservation
|
from lerobot.processor import RobotObservation
|
||||||
|
|
||||||
|
|||||||
273
src/lerobot/envs/robocasa.py
Normal file
273
src/lerobot/envs/robocasa.py
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
|
||||||
|
# Action layout (flat 12D, normalized to [-1, 1]):
|
||||||
|
# [0:3] end_effector_position (delta x, y, z)
|
||||||
|
# [3:6] end_effector_rotation (delta roll, pitch, yaw)
|
||||||
|
# [6:7] gripper_close (open=-1, close=+1)
|
||||||
|
# [7:11] base_motion (x, y, theta, torso_height)
|
||||||
|
# [11:12] control_mode (arm=-1, base=+1)
|
||||||
|
ACTION_DIM = 12
|
||||||
|
ACTION_LOW = -1.0
|
||||||
|
ACTION_HIGH = 1.0
|
||||||
|
|
||||||
|
# Proprioceptive state layout (flat 16D):
|
||||||
|
# [0:2] gripper_qpos
|
||||||
|
# [2:5] base_position
|
||||||
|
# [5:9] base_rotation (quaternion)
|
||||||
|
# [9:12] end_effector_position_relative
|
||||||
|
# [12:16] end_effector_rotation_relative (quaternion)
|
||||||
|
STATE_DIM = 16
|
||||||
|
|
||||||
|
# Obs dict keys from RoboCasaGymEnv.get_observation()
|
||||||
|
_CAM_KEYS = (
|
||||||
|
"video.robot0_agentview_left",
|
||||||
|
"video.robot0_agentview_right",
|
||||||
|
"video.robot0_eye_in_hand",
|
||||||
|
)
|
||||||
|
_STATE_KEYS_ORDERED = (
|
||||||
|
"state.gripper_qpos", # (2,)
|
||||||
|
"state.base_position", # (3,)
|
||||||
|
"state.base_rotation", # (4,)
|
||||||
|
"state.end_effector_position_relative", # (3,)
|
||||||
|
"state.end_effector_rotation_relative", # (4,)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mapping from video.* key → short image name used in features_map
|
||||||
|
CAM_KEY_TO_NAME = {
|
||||||
|
"video.robot0_agentview_left": "agentview_left",
|
||||||
|
"video.robot0_agentview_right": "agentview_right",
|
||||||
|
"video.robot0_eye_in_hand": "eye_in_hand",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _flat_to_action_dict(flat: np.ndarray) -> dict[str, np.ndarray]:
|
||||||
|
"""Convert a 12D flat action array to the Dict format expected by RoboCasaGymEnv."""
|
||||||
|
return {
|
||||||
|
"action.end_effector_position": flat[0:3],
|
||||||
|
"action.end_effector_rotation": flat[3:6],
|
||||||
|
"action.gripper_close": flat[6:7],
|
||||||
|
"action.base_motion": flat[7:11],
|
||||||
|
"action.control_mode": flat[11:12],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RoboCasaEnv(gym.Env):
|
||||||
|
"""Thin wrapper around RoboCasaGymEnv that provides a flat Box action space
|
||||||
|
and a structured observation dict compatible with LeRobot policies.
|
||||||
|
|
||||||
|
Observations returned by step/reset:
|
||||||
|
{
|
||||||
|
"pixels": {
|
||||||
|
"agentview_left": (H, W, 3) uint8,
|
||||||
|
"agentview_right": (H, W, 3) uint8,
|
||||||
|
"eye_in_hand": (H, W, 3) uint8,
|
||||||
|
},
|
||||||
|
"robot_state": (16,) float32,
|
||||||
|
}
|
||||||
|
|
||||||
|
Actions: flat float32 ndarray of shape (12,), normalized to [-1, 1].
|
||||||
|
"""
|
||||||
|
|
||||||
|
metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
split: str = "target",
|
||||||
|
image_size: int = 128,
|
||||||
|
render_mode: str = "rgb_array",
|
||||||
|
episode_length: int = 500,
|
||||||
|
**gym_kwargs: Any,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Lazy import — robocasa is optional
|
||||||
|
import robocasa.environments # noqa: F401 — registers all gym envs
|
||||||
|
|
||||||
|
self.task = task
|
||||||
|
self.render_mode = render_mode
|
||||||
|
self.image_size = image_size
|
||||||
|
self._max_episode_steps = episode_length
|
||||||
|
self._step_count = 0
|
||||||
|
|
||||||
|
self._env = gym.make(
|
||||||
|
f"robocasa/{task}",
|
||||||
|
split=split,
|
||||||
|
camera_widths=image_size,
|
||||||
|
camera_heights=image_size,
|
||||||
|
**gym_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flat 12D Box action space
|
||||||
|
self.action_space = spaces.Box(
|
||||||
|
low=ACTION_LOW,
|
||||||
|
high=ACTION_HIGH,
|
||||||
|
shape=(ACTION_DIM,),
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
images = {
|
||||||
|
name: spaces.Box(low=0, high=255, shape=(image_size, image_size, 3), dtype=np.uint8)
|
||||||
|
for name in CAM_KEY_TO_NAME.values()
|
||||||
|
}
|
||||||
|
self.observation_space = spaces.Dict(
|
||||||
|
{
|
||||||
|
"pixels": spaces.Dict(images),
|
||||||
|
"robot_state": spaces.Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(STATE_DIM,), dtype=np.float32
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _format_obs(self, raw_obs: dict) -> dict:
|
||||||
|
pixels = {
|
||||||
|
CAM_KEY_TO_NAME[k]: raw_obs[k]
|
||||||
|
for k in _CAM_KEYS
|
||||||
|
if k in raw_obs
|
||||||
|
}
|
||||||
|
state_parts = [
|
||||||
|
np.asarray(raw_obs[k], dtype=np.float32)
|
||||||
|
for k in _STATE_KEYS_ORDERED
|
||||||
|
if k in raw_obs
|
||||||
|
]
|
||||||
|
robot_state = np.concatenate(state_parts) if state_parts else np.zeros(STATE_DIM, dtype=np.float32)
|
||||||
|
return {"pixels": pixels, "robot_state": robot_state}
|
||||||
|
|
||||||
|
def reset(self, seed: int | None = None, **kwargs) -> tuple[dict, dict]:
|
||||||
|
super().reset(seed=seed)
|
||||||
|
self._step_count = 0
|
||||||
|
raw_obs, info = self._env.reset(seed=seed)
|
||||||
|
info.setdefault("is_success", False)
|
||||||
|
info["task"] = self.task
|
||||||
|
return self._format_obs(raw_obs), info
|
||||||
|
|
||||||
|
def step(self, action: np.ndarray) -> tuple[dict, float, bool, bool, dict]:
|
||||||
|
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}"
|
||||||
|
)
|
||||||
|
action_dict = _flat_to_action_dict(action)
|
||||||
|
raw_obs, reward, terminated, truncated, info = self._env.step(action_dict)
|
||||||
|
self._step_count += 1
|
||||||
|
|
||||||
|
is_success = bool(info.get("success", False))
|
||||||
|
terminated = terminated or is_success
|
||||||
|
if self._step_count >= self._max_episode_steps:
|
||||||
|
truncated = True
|
||||||
|
|
||||||
|
info.update({"task": self.task, "is_success": is_success})
|
||||||
|
obs = self._format_obs(raw_obs)
|
||||||
|
|
||||||
|
if terminated or truncated:
|
||||||
|
info["final_info"] = {"task": self.task, "is_success": is_success}
|
||||||
|
|
||||||
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def render(self) -> np.ndarray | None:
|
||||||
|
if self.render_mode == "rgb_array":
|
||||||
|
return self._env.render()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
self._env.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_env_fns(
|
||||||
|
*,
|
||||||
|
task: str,
|
||||||
|
n_envs: int,
|
||||||
|
image_size: int,
|
||||||
|
split: str,
|
||||||
|
episode_length: int,
|
||||||
|
gym_kwargs: dict[str, Any],
|
||||||
|
) -> list[Callable[[], RoboCasaEnv]]:
|
||||||
|
"""Build n_envs factory callables for a single task."""
|
||||||
|
def _make(episode_index: int) -> RoboCasaEnv: # noqa: ARG001
|
||||||
|
return RoboCasaEnv(
|
||||||
|
task=task,
|
||||||
|
split=split,
|
||||||
|
image_size=image_size,
|
||||||
|
episode_length=episode_length,
|
||||||
|
**gym_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [partial(_make, i) for i in range(n_envs)]
|
||||||
|
|
||||||
|
|
||||||
|
def create_robocasa_envs(
|
||||||
|
tasks: str | Sequence[str],
|
||||||
|
n_envs: int,
|
||||||
|
image_size: int = 128,
|
||||||
|
split: str = "target",
|
||||||
|
episode_length: int = 500,
|
||||||
|
gym_kwargs: dict[str, Any] | None = None,
|
||||||
|
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||||
|
) -> dict[str, dict[int, Any]]:
|
||||||
|
"""Create vectorized RoboCasa environments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tasks: A single task name or list of task names (without "robocasa/" prefix).
|
||||||
|
E.g. "PickPlaceCounterToCabinet" or ["BoilPot", "PrepareCoffee"].
|
||||||
|
n_envs: Number of parallel envs per task.
|
||||||
|
image_size: Square image resolution for all cameras.
|
||||||
|
split: RoboCasa dataset split — "pretrain" or "target".
|
||||||
|
episode_length: Max steps per episode before truncation.
|
||||||
|
gym_kwargs: Extra kwargs forwarded to each RoboCasaEnv.
|
||||||
|
env_cls: Callable to wrap list of factory fns (SyncVectorEnv or AsyncVectorEnv).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[task_name][task_id=0] -> vec_env
|
||||||
|
"""
|
||||||
|
if env_cls is None or not callable(env_cls):
|
||||||
|
raise ValueError("env_cls must be a callable wrapping a list of env factory callables.")
|
||||||
|
if not isinstance(n_envs, int) or n_envs <= 0:
|
||||||
|
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
|
||||||
|
|
||||||
|
if isinstance(tasks, str):
|
||||||
|
task_list = [t.strip() for t in tasks.split(",") if t.strip()]
|
||||||
|
else:
|
||||||
|
task_list = [str(t).strip() for t in tasks if str(t).strip()]
|
||||||
|
if not task_list:
|
||||||
|
raise ValueError("`tasks` must contain at least one task name.")
|
||||||
|
|
||||||
|
gym_kwargs = dict(gym_kwargs or {})
|
||||||
|
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||||
|
|
||||||
|
print(f"Creating RoboCasa envs | tasks={task_list} | n_envs(per task)={n_envs} | split={split}")
|
||||||
|
for task in task_list:
|
||||||
|
fns = _make_env_fns(
|
||||||
|
task=task,
|
||||||
|
n_envs=n_envs,
|
||||||
|
image_size=image_size,
|
||||||
|
split=split,
|
||||||
|
episode_length=episode_length,
|
||||||
|
gym_kwargs=gym_kwargs,
|
||||||
|
)
|
||||||
|
out["robocasa"][len(out["robocasa"])] = env_cls(fns)
|
||||||
|
print(f" Built vec env | task={task} | n_envs={n_envs}")
|
||||||
|
|
||||||
|
return {suite: dict(task_map) for suite, task_map in out.items()}
|
||||||
154
src/lerobot/envs/robomme.py
Normal file
154
src/lerobot/envs/robomme.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
"""RoboMME environment wrapper for LeRobot evaluation.
|
||||||
|
|
||||||
|
Wraps the RoboMME ``BenchmarkEnvBuilder`` into a Gymnasium-compatible
|
||||||
|
``VectorEnv`` suitable for ``lerobot_eval``.
|
||||||
|
|
||||||
|
RoboMME tasks:
|
||||||
|
Counting: BinFill, PickXtimes, SwingXtimes, StopCube
|
||||||
|
Permanence: VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap
|
||||||
|
Reference: PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder
|
||||||
|
Imitation: MoveCube, InsertPeg, PatternLock, RouteStick
|
||||||
|
|
||||||
|
Install: pip install robomme (or from source: https://github.com/RoboMME/robomme_benchmark)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
ROBOMME_TASKS = [
|
||||||
|
"BinFill", "PickXtimes", "SwingXtimes", "StopCube",
|
||||||
|
"VideoUnmask", "VideoUnmaskSwap", "ButtonUnmask", "ButtonUnmaskSwap",
|
||||||
|
"PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder",
|
||||||
|
"MoveCube", "InsertPeg", "PatternLock", "RouteStick",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RoboMMEGymEnv(gym.Env):
|
||||||
|
"""Thin Gymnasium wrapper around a single RoboMME episode env."""
|
||||||
|
|
||||||
|
metadata = {"render_modes": ["rgb_array"]}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task: str = "PickXtimes",
|
||||||
|
action_space_type: str = "joint_angle",
|
||||||
|
dataset: str = "test",
|
||||||
|
episode_idx: int = 0,
|
||||||
|
max_steps: int = 300,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
from robomme.env_record_wrapper import BenchmarkEnvBuilder
|
||||||
|
|
||||||
|
self._task = task
|
||||||
|
self._action_space_type = action_space_type
|
||||||
|
self._dataset = dataset
|
||||||
|
self._episode_idx = episode_idx
|
||||||
|
self._max_steps = max_steps
|
||||||
|
|
||||||
|
self._builder = BenchmarkEnvBuilder(
|
||||||
|
env_id=task,
|
||||||
|
dataset=dataset,
|
||||||
|
action_space=action_space_type,
|
||||||
|
gui_render=False,
|
||||||
|
max_steps=max_steps,
|
||||||
|
)
|
||||||
|
self._env = None
|
||||||
|
|
||||||
|
action_dim = 8 if action_space_type == "joint_angle" else 7
|
||||||
|
self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(action_dim,), dtype=np.float32)
|
||||||
|
self.observation_space = spaces.Dict({
|
||||||
|
"front_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
|
||||||
|
"wrist_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
|
||||||
|
"state": spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32),
|
||||||
|
})
|
||||||
|
|
||||||
|
def reset(self, *, seed=None, options=None):
|
||||||
|
super().reset(seed=seed)
|
||||||
|
self._env = self._builder.make_env_for_episode(
|
||||||
|
episode_idx=self._episode_idx, max_steps=self._max_steps,
|
||||||
|
)
|
||||||
|
obs, info = self._env.reset()
|
||||||
|
return self._convert_obs(obs), self._convert_info(info)
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
obs, reward, terminated, truncated, info = self._env.step(action)
|
||||||
|
|
||||||
|
terminated_bool = bool(terminated.item()) if hasattr(terminated, "item") else bool(terminated)
|
||||||
|
truncated_bool = bool(truncated.item()) if hasattr(truncated, "item") else bool(truncated)
|
||||||
|
|
||||||
|
status = info.get("status", "ongoing")
|
||||||
|
is_success = status == "success"
|
||||||
|
conv_info = self._convert_info(info)
|
||||||
|
conv_info["is_success"] = is_success
|
||||||
|
|
||||||
|
return self._convert_obs(obs), float(reward), terminated_bool, truncated_bool, conv_info
|
||||||
|
|
||||||
|
def _convert_obs(self, obs: dict) -> dict:
|
||||||
|
front_rgb = obs["front_rgb_list"][-1] if isinstance(obs["front_rgb_list"], list) else obs["front_rgb_list"]
|
||||||
|
wrist_rgb = obs["wrist_rgb_list"][-1] if isinstance(obs["wrist_rgb_list"], list) else obs["wrist_rgb_list"]
|
||||||
|
joint_state = obs["joint_state_list"][-1] if isinstance(obs["joint_state_list"], list) else obs["joint_state_list"]
|
||||||
|
gripper_state = obs["gripper_state_list"][-1] if isinstance(obs["gripper_state_list"], list) else obs["gripper_state_list"]
|
||||||
|
|
||||||
|
front_rgb = np.asarray(front_rgb, dtype=np.uint8)
|
||||||
|
wrist_rgb = np.asarray(wrist_rgb, dtype=np.uint8)
|
||||||
|
joint = np.asarray(joint_state, dtype=np.float32).flatten()[:7]
|
||||||
|
gripper = np.asarray(gripper_state, dtype=np.float32).flatten()[:1]
|
||||||
|
state = np.concatenate([joint, gripper])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"front_rgb": front_rgb,
|
||||||
|
"wrist_rgb": wrist_rgb,
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _convert_info(self, info: dict) -> dict:
|
||||||
|
return {
|
||||||
|
"status": info.get("status", "ongoing"),
|
||||||
|
"task_goal": info.get("task_goal", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_robomme_envs(
|
||||||
|
task: str,
|
||||||
|
n_envs: int = 1,
|
||||||
|
action_space_type: str = "joint_angle",
|
||||||
|
dataset: str = "test",
|
||||||
|
episode_length: int = 300,
|
||||||
|
task_ids: list[int] | None = None,
|
||||||
|
env_cls=None,
|
||||||
|
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||||
|
"""Create vectorized RoboMME environments for evaluation.
|
||||||
|
|
||||||
|
Returns {suite_name: {task_id: VectorEnv}} matching lerobot's expected format.
|
||||||
|
"""
|
||||||
|
if env_cls is None:
|
||||||
|
env_cls = gym.vector.SyncVectorEnv
|
||||||
|
|
||||||
|
if task_ids is None:
|
||||||
|
task_ids = [0]
|
||||||
|
|
||||||
|
suite_name = "robomme"
|
||||||
|
envs_by_task = {}
|
||||||
|
|
||||||
|
for task_id in task_ids:
|
||||||
|
def _make_one(ep_idx=task_id):
|
||||||
|
return RoboMMEGymEnv(
|
||||||
|
task=task,
|
||||||
|
action_space_type=action_space_type,
|
||||||
|
dataset=dataset,
|
||||||
|
episode_idx=ep_idx,
|
||||||
|
max_steps=episode_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
vec = env_cls(
|
||||||
|
[_make_one for _ in range(n_envs)],
|
||||||
|
autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
|
||||||
|
)
|
||||||
|
envs_by_task[task_id] = vec
|
||||||
|
|
||||||
|
return {suite_name: envs_by_task}
|
||||||
@@ -153,6 +153,44 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register(name="robocasa_processor")
|
||||||
|
class RoboCasaProcessorStep(ObservationProcessorStep):
|
||||||
|
"""
|
||||||
|
Processes RoboCasa observations into LeRobot format.
|
||||||
|
|
||||||
|
The RoboCasaEnv wrapper returns:
|
||||||
|
- ``pixels.<cam_name>``: (B, C, H, W) float32 images (already converted by vectorenv)
|
||||||
|
- ``observation.robot_state``: (B, 16) float32 proprioception
|
||||||
|
|
||||||
|
This step remaps them to:
|
||||||
|
- ``observation.images.<cam_name>`` (unchanged tensor)
|
||||||
|
- ``observation.state`` (robot_state renamed)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _process_observation(self, observation: dict) -> dict:
|
||||||
|
processed = {}
|
||||||
|
obs_prefix = OBS_PREFIX # "observation."
|
||||||
|
|
||||||
|
for key, value in observation.items():
|
||||||
|
if key.startswith(f"{OBS_IMAGES}."):
|
||||||
|
# Already in the right place; pass through
|
||||||
|
processed[key] = value
|
||||||
|
elif key == OBS_STATE or key == f"{obs_prefix}robot_state":
|
||||||
|
# Rename robot_state → observation.state
|
||||||
|
processed[OBS_STATE] = value.float() if hasattr(value, "float") else value
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
def transform_features(
|
||||||
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
def observation(self, observation: dict) -> dict:
|
||||||
|
return self._process_observation(observation)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
|
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
|
||||||
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
|
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
|
||||||
|
|||||||
@@ -16,3 +16,5 @@
|
|||||||
|
|
||||||
from .config_unitree_g1 import UnitreeG1Config
|
from .config_unitree_g1 import UnitreeG1Config
|
||||||
from .unitree_g1 import UnitreeG1
|
from .unitree_g1 import UnitreeG1
|
||||||
|
|
||||||
|
__all__ = ["UnitreeG1", "UnitreeG1Config"]
|
||||||
|
|||||||
@@ -27,11 +27,10 @@ _GAINS: dict[str, dict[str, list[float]]] = {
|
|||||||
}, # pitch, roll, yaw, knee, ankle_pitch, ankle_roll
|
}, # pitch, roll, yaw, knee, ankle_pitch, ankle_roll
|
||||||
"right_leg": {"kp": [150, 150, 150, 300, 40, 40], "kd": [2, 2, 2, 4, 2, 2]},
|
"right_leg": {"kp": [150, 150, 150, 300, 40, 40], "kd": [2, 2, 2, 4, 2, 2]},
|
||||||
"waist": {"kp": [250, 250, 250], "kd": [5, 5, 5]}, # yaw, roll, pitch
|
"waist": {"kp": [250, 250, 250], "kd": [5, 5, 5]}, # yaw, roll, pitch
|
||||||
"left_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow
|
"left_arm": {"kp": [50, 50, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow
|
||||||
"left_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]}, # roll, pitch, yaw
|
"left_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]}, # roll, pitch, yaw
|
||||||
"right_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]},
|
"right_arm": {"kp": [50, 50, 80, 80], "kd": [3, 3, 3, 3]},
|
||||||
"right_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]},
|
"right_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]},
|
||||||
"other": {"kp": [80, 80, 80, 80, 80, 80], "kd": [3, 3, 3, 3, 3, 3]},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -68,3 +67,7 @@ class UnitreeG1Config(RobotConfig):
|
|||||||
|
|
||||||
# Compensates for gravity on the unitree's arms using the arm ik solver
|
# Compensates for gravity on the unitree's arms using the arm ik solver
|
||||||
gravity_compensation: bool = False
|
gravity_compensation: bool = False
|
||||||
|
|
||||||
|
# Lower-body controller class name, e.g. "GrootLocomotionController" or
|
||||||
|
# "HolosomaLocomotionController". None disables it.
|
||||||
|
controller: str | None = None
|
||||||
|
|||||||
@@ -16,13 +16,11 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
from collections import deque
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
parent2_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
sys.path.append(parent2_dir)
|
|
||||||
|
|
||||||
|
|
||||||
class WeightedMovingFilter:
|
class WeightedMovingFilter:
|
||||||
@@ -31,18 +29,14 @@ class WeightedMovingFilter:
|
|||||||
self._weights = np.array(weights)
|
self._weights = np.array(weights)
|
||||||
self._data_size = data_size
|
self._data_size = data_size
|
||||||
self._filtered_data = np.zeros(self._data_size)
|
self._filtered_data = np.zeros(self._data_size)
|
||||||
self._data_queue = []
|
self._data_queue = deque(maxlen=self._window_size)
|
||||||
|
|
||||||
def _apply_filter(self):
|
def _apply_filter(self):
|
||||||
if len(self._data_queue) < self._window_size:
|
if len(self._data_queue) < self._window_size:
|
||||||
return self._data_queue[-1]
|
return self._data_queue[-1]
|
||||||
|
|
||||||
data_array = np.array(self._data_queue)
|
data_array = np.array(self._data_queue)
|
||||||
temp_filtered_data = np.zeros(self._data_size)
|
return data_array.T @ self._weights
|
||||||
for i in range(self._data_size):
|
|
||||||
temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1]
|
|
||||||
|
|
||||||
return temp_filtered_data
|
|
||||||
|
|
||||||
def add_data(self, new_data):
|
def add_data(self, new_data):
|
||||||
assert len(new_data) == self._data_size
|
assert len(new_data) == self._data_size
|
||||||
@@ -52,9 +46,6 @@ class WeightedMovingFilter:
|
|||||||
): # skip duplicate data
|
): # skip duplicate data
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(self._data_queue) >= self._window_size:
|
|
||||||
self._data_queue.pop(0)
|
|
||||||
|
|
||||||
self._data_queue.append(new_data)
|
self._data_queue.append(new_data)
|
||||||
self._filtered_data = self._apply_filter()
|
self._filtered_data = self._apply_filter()
|
||||||
|
|
||||||
@@ -71,8 +62,6 @@ class G1_29_ArmIK: # noqa: N801
|
|||||||
from pinocchio import casadi as cpin
|
from pinocchio import casadi as cpin
|
||||||
|
|
||||||
self._pin = pin
|
self._pin = pin
|
||||||
np.set_printoptions(precision=5, suppress=True, linewidth=200)
|
|
||||||
|
|
||||||
self.unit_test = unit_test
|
self.unit_test = unit_test
|
||||||
|
|
||||||
self.repo_path = snapshot_download("lerobot/unitree-g1-mujoco")
|
self.repo_path = snapshot_download("lerobot/unitree-g1-mujoco")
|
||||||
@@ -249,50 +238,35 @@ class G1_29_ArmIK: # noqa: N801
|
|||||||
self.opti.set_value(self.param_tf_r, right_wrist)
|
self.opti.set_value(self.param_tf_r, right_wrist)
|
||||||
self.opti.set_value(self.var_q_last, self.init_data) # for smooth
|
self.opti.set_value(self.var_q_last, self.init_data) # for smooth
|
||||||
|
|
||||||
|
converged = True
|
||||||
try:
|
try:
|
||||||
self.opti.solve()
|
self.opti.solve()
|
||||||
|
|
||||||
sol_q = self.opti.value(self.var_q)
|
sol_q = self.opti.value(self.var_q)
|
||||||
self.smooth_filter.add_data(sol_q)
|
|
||||||
sol_q = self.smooth_filter.filtered_data
|
|
||||||
|
|
||||||
if current_lr_arm_motor_dq is not None:
|
|
||||||
v = current_lr_arm_motor_dq * 0.0
|
|
||||||
else:
|
|
||||||
v = (sol_q - self.init_data) * 0.0
|
|
||||||
|
|
||||||
self.init_data = sol_q
|
|
||||||
|
|
||||||
sol_tauff = self._pin.rnea(
|
|
||||||
self.reduced_robot.model,
|
|
||||||
self.reduced_robot.data,
|
|
||||||
sol_q,
|
|
||||||
v,
|
|
||||||
np.zeros(self.reduced_robot.model.nv),
|
|
||||||
)
|
|
||||||
|
|
||||||
return sol_q, sol_tauff
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"ERROR in convergence, plotting debug info.{e}")
|
converged = False
|
||||||
|
logger.error(f"IK convergence error: {e}")
|
||||||
sol_q = self.opti.debug.value(self.var_q)
|
sol_q = self.opti.debug.value(self.var_q)
|
||||||
self.smooth_filter.add_data(sol_q)
|
|
||||||
sol_q = self.smooth_filter.filtered_data
|
|
||||||
|
|
||||||
if current_lr_arm_motor_dq is not None:
|
self.smooth_filter.add_data(sol_q)
|
||||||
v = current_lr_arm_motor_dq * 0.0
|
sol_q = self.smooth_filter.filtered_data
|
||||||
else:
|
self.init_data = sol_q
|
||||||
v = (sol_q - self.init_data) * 0.0
|
|
||||||
|
|
||||||
self.init_data = sol_q
|
|
||||||
|
|
||||||
|
if not converged:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}"
|
f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv)
|
return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv)
|
||||||
|
|
||||||
|
sol_tauff = self._pin.rnea(
|
||||||
|
self.reduced_robot.model,
|
||||||
|
self.reduced_robot.data,
|
||||||
|
sol_q,
|
||||||
|
np.zeros(self.reduced_robot.model.nv),
|
||||||
|
np.zeros(self.reduced_robot.model.nv),
|
||||||
|
)
|
||||||
|
|
||||||
|
return sol_q, sol_tauff
|
||||||
|
|
||||||
def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None):
|
def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None):
|
||||||
try:
|
try:
|
||||||
q_g1 = np.array(current_lr_arm_motor_q, dtype=float)
|
q_g1 = np.array(current_lr_arm_motor_q, dtype=float)
|
||||||
@@ -14,12 +14,34 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import importlib
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# ruff: noqa: N801, N815
|
# ruff: noqa: N801, N815
|
||||||
|
|
||||||
NUM_MOTORS = 29
|
NUM_MOTORS = 29
|
||||||
|
|
||||||
|
REMOTE_AXES = ("remote.lx", "remote.ly", "remote.rx", "remote.ry")
|
||||||
|
REMOTE_BUTTONS = tuple(f"remote.button.{i}" for i in range(16))
|
||||||
|
REMOTE_KEYS = REMOTE_AXES + REMOTE_BUTTONS
|
||||||
|
|
||||||
|
|
||||||
|
def default_remote_input() -> dict[str, float]:
|
||||||
|
"""Return a zeroed-out remote input dict (axes + buttons)."""
|
||||||
|
return dict.fromkeys(REMOTE_KEYS, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def get_gravity_orientation(quaternion: list[float] | np.ndarray) -> np.ndarray:
|
||||||
|
"""Get gravity orientation from quaternion [w, x, y, z]."""
|
||||||
|
qw, qx, qy, qz = quaternion
|
||||||
|
gravity_orientation = np.zeros(3, dtype=np.float32)
|
||||||
|
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
|
||||||
|
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
|
||||||
|
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
||||||
|
return gravity_orientation
|
||||||
|
|
||||||
|
|
||||||
class G1_29_JointArmIndex(IntEnum):
|
class G1_29_JointArmIndex(IntEnum):
|
||||||
# Left arm
|
# Left arm
|
||||||
@@ -29,7 +51,7 @@ class G1_29_JointArmIndex(IntEnum):
|
|||||||
kLeftElbow = 18
|
kLeftElbow = 18
|
||||||
kLeftWristRoll = 19
|
kLeftWristRoll = 19
|
||||||
kLeftWristPitch = 20
|
kLeftWristPitch = 20
|
||||||
kLeftWristyaw = 21
|
kLeftWristYaw = 21
|
||||||
|
|
||||||
# Right arm
|
# Right arm
|
||||||
kRightShoulderPitch = 22
|
kRightShoulderPitch = 22
|
||||||
@@ -41,6 +63,21 @@ class G1_29_JointArmIndex(IntEnum):
|
|||||||
kRightWristYaw = 28
|
kRightWristYaw = 28
|
||||||
|
|
||||||
|
|
||||||
|
def make_locomotion_controller(name: str | None):
|
||||||
|
"""Instantiate a locomotion controller by class name. Returns None if name is None."""
|
||||||
|
if name is None:
|
||||||
|
return None
|
||||||
|
controllers = {
|
||||||
|
"GrootLocomotionController": "lerobot.robots.unitree_g1.gr00t_locomotion",
|
||||||
|
"HolosomaLocomotionController": "lerobot.robots.unitree_g1.holosoma_locomotion",
|
||||||
|
}
|
||||||
|
module_path = controllers.get(name)
|
||||||
|
if module_path is None:
|
||||||
|
raise ValueError(f"Unknown controller: {name!r}. Available: {list(controllers)}")
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
return getattr(module, name)()
|
||||||
|
|
||||||
|
|
||||||
class G1_29_JointIndex(IntEnum):
|
class G1_29_JointIndex(IntEnum):
|
||||||
# Left leg
|
# Left leg
|
||||||
kLeftHipPitch = 0
|
kLeftHipPitch = 0
|
||||||
@@ -69,7 +106,7 @@ class G1_29_JointIndex(IntEnum):
|
|||||||
kLeftElbow = 18
|
kLeftElbow = 18
|
||||||
kLeftWristRoll = 19
|
kLeftWristRoll = 19
|
||||||
kLeftWristPitch = 20
|
kLeftWristPitch = 20
|
||||||
kLeftWristyaw = 21
|
kLeftWristYaw = 21
|
||||||
|
|
||||||
# Right arm
|
# Right arm
|
||||||
kRightShoulderPitch = 22
|
kRightShoulderPitch = 22
|
||||||
|
|||||||
@@ -14,20 +14,20 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
from lerobot.robots.unitree_g1.g1_utils import (
|
||||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
REMOTE_AXES,
|
||||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
REMOTE_BUTTONS,
|
||||||
|
G1_29_JointIndex,
|
||||||
|
get_gravity_orientation,
|
||||||
|
)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -36,18 +36,13 @@ GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # Hip pitch
|
|||||||
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # Knee
|
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # Knee
|
||||||
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # Ankle pitch
|
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # Ankle pitch
|
||||||
|
|
||||||
MISSING_JOINTS = []
|
|
||||||
G1_MODEL = "g1_23" # Or "g1_29"
|
|
||||||
if G1_MODEL == "g1_23":
|
|
||||||
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw
|
|
||||||
|
|
||||||
# Control parameters
|
# Control parameters
|
||||||
ACTION_SCALE = 0.25
|
ACTION_SCALE = 0.25
|
||||||
CONTROL_DT = 0.02 # 50Hz
|
CONTROL_DT = 0.02 # 50Hz
|
||||||
ANG_VEL_SCALE: float = 0.25
|
ANG_VEL_SCALE: float = 0.25
|
||||||
DOF_POS_SCALE: float = 1.0
|
DOF_POS_SCALE: float = 1.0
|
||||||
DOF_VEL_SCALE: float = 0.05
|
DOF_VEL_SCALE: float = 0.05
|
||||||
CMD_SCALE: list = [2.0, 2.0, 0.25]
|
CMD_SCALE: list[float] = [2.0, 2.0, 0.25]
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
|
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
|
||||||
@@ -85,11 +80,11 @@ def load_groot_policies(
|
|||||||
class GrootLocomotionController:
|
class GrootLocomotionController:
|
||||||
"""GR00T lower-body locomotion controller for the Unitree G1."""
|
"""GR00T lower-body locomotion controller for the Unitree G1."""
|
||||||
|
|
||||||
def __init__(self, policy_balance, policy_walk, robot, config):
|
control_dt = CONTROL_DT # Expose for unitree_g1.py
|
||||||
self.policy_balance = policy_balance
|
|
||||||
self.policy_walk = policy_walk
|
def __init__(self):
|
||||||
self.robot = robot
|
# Load policies
|
||||||
self.config = config
|
self.policy_balance, self.policy_walk = load_groot_policies()
|
||||||
|
|
||||||
self.cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
|
self.cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
|
||||||
|
|
||||||
@@ -109,45 +104,60 @@ class GrootLocomotionController:
|
|||||||
|
|
||||||
logger.info("GrootLocomotionController initialized")
|
logger.info("GrootLocomotionController initialized")
|
||||||
|
|
||||||
def run_step(self):
|
def reset(self) -> None:
|
||||||
# Get current observation
|
"""Reset internal state for a new episode."""
|
||||||
obs = self.robot.get_observation()
|
self.cmd[:] = 0.0
|
||||||
|
self.groot_qj_all[:] = 0.0
|
||||||
|
self.groot_dqj_all[:] = 0.0
|
||||||
|
self.groot_action[:] = 0.0
|
||||||
|
self.groot_obs_single[:] = 0.0
|
||||||
|
self.groot_obs_stacked[:] = 0.0
|
||||||
|
self.groot_height_cmd = 0.74
|
||||||
|
self.groot_orientation_cmd[:] = 0.0
|
||||||
|
self.groot_obs_history.clear()
|
||||||
|
for _ in range(6):
|
||||||
|
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
|
||||||
|
|
||||||
if not obs:
|
def run_step(self, action: dict, lowstate) -> dict:
|
||||||
return
|
"""Run one step of the locomotion controller.
|
||||||
|
|
||||||
# Get command from remote controller
|
Args:
|
||||||
if obs["remote.buttons"][0]: # R1 - raise waist
|
action: Action dict containing remote.lx/ly/rx/ry and buttons
|
||||||
|
lowstate: Robot lowstate containing motor positions/velocities and IMU
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Action dict for lower body joints (0-14)
|
||||||
|
"""
|
||||||
|
if lowstate is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
buttons = [int(action.get(k, 0)) for k in REMOTE_BUTTONS]
|
||||||
|
if buttons[0]: # R1 - raise waist
|
||||||
self.groot_height_cmd += 0.001
|
self.groot_height_cmd += 0.001
|
||||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||||
if obs["remote.buttons"][4]: # R2 - lower waist
|
if buttons[4]: # R2 - lower waist
|
||||||
self.groot_height_cmd -= 0.001
|
self.groot_height_cmd -= 0.001
|
||||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||||
|
|
||||||
self.cmd[0] = obs["remote.ly"] # Forward/backward
|
lx, ly, rx, _ry = (action.get(k, 0.0) for k in REMOTE_AXES)
|
||||||
self.cmd[1] = obs["remote.lx"] * -1 # Left/right
|
self.cmd[0] = ly # Forward/backward
|
||||||
self.cmd[2] = obs["remote.rx"] * -1 # Rotation rate
|
self.cmd[1] = -lx # Left/right (negated)
|
||||||
|
self.cmd[2] = -rx # Rotation rate (negated)
|
||||||
|
|
||||||
# Get joint positions and velocities from flat dict
|
# Get joint positions and velocities from lowstate
|
||||||
for motor in G1_29_JointIndex:
|
for motor in G1_29_JointIndex:
|
||||||
name = motor.name
|
|
||||||
idx = motor.value
|
idx = motor.value
|
||||||
self.groot_qj_all[idx] = obs[f"{name}.q"]
|
self.groot_qj_all[idx] = lowstate.motor_state[idx].q
|
||||||
self.groot_dqj_all[idx] = obs[f"{name}.dq"]
|
self.groot_dqj_all[idx] = lowstate.motor_state[idx].dq
|
||||||
|
|
||||||
# Adapt observation for g1_23dof
|
|
||||||
for idx in MISSING_JOINTS:
|
|
||||||
self.groot_qj_all[idx] = 0.0
|
|
||||||
self.groot_dqj_all[idx] = 0.0
|
|
||||||
|
|
||||||
# Scale joint positions and velocities
|
# Scale joint positions and velocities
|
||||||
qj_obs = self.groot_qj_all.copy()
|
qj_obs = self.groot_qj_all.copy()
|
||||||
dqj_obs = self.groot_dqj_all.copy()
|
dqj_obs = self.groot_dqj_all.copy()
|
||||||
|
|
||||||
# Express IMU data in gravity frame of reference
|
# Express IMU data in gravity frame of reference
|
||||||
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
|
quat = lowstate.imu_state.quaternion
|
||||||
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
|
ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32)
|
||||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
gravity_orientation = get_gravity_orientation(quat)
|
||||||
|
|
||||||
# Scale joint positions and velocities before policy inference
|
# Scale joint positions and velocities before policy inference
|
||||||
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
|
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
|
||||||
@@ -186,73 +196,10 @@ class GrootLocomotionController:
|
|||||||
# Transform action back to target joint positions
|
# Transform action back to target joint positions
|
||||||
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * ACTION_SCALE
|
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * ACTION_SCALE
|
||||||
|
|
||||||
# Build action dict (only first 15 joints for GR00T)
|
# Build action dict
|
||||||
action_dict = {}
|
action_dict = {}
|
||||||
for i in range(15):
|
for i in range(15):
|
||||||
motor_name = G1_29_JointIndex(i).name
|
motor_name = G1_29_JointIndex(i).name
|
||||||
action_dict[f"{motor_name}.q"] = float(target_dof_pos_15[i])
|
action_dict[f"{motor_name}.q"] = float(target_dof_pos_15[i])
|
||||||
|
|
||||||
# Zero out missing joints for g1_23dof
|
return action_dict
|
||||||
for joint_idx in MISSING_JOINTS:
|
|
||||||
motor_name = G1_29_JointIndex(joint_idx).name
|
|
||||||
action_dict[f"{motor_name}.q"] = 0.0
|
|
||||||
|
|
||||||
# Send action to robot
|
|
||||||
self.robot.send_action(action_dict)
|
|
||||||
|
|
||||||
|
|
||||||
def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None:
|
|
||||||
"""Main function to run the GR00T locomotion controller.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
repo_id: Hugging Face Hub repository ID for GR00T policies.
|
|
||||||
"""
|
|
||||||
# Load policies
|
|
||||||
policy_balance, policy_walk = load_groot_policies(repo_id=repo_id)
|
|
||||||
|
|
||||||
# Initialize robot
|
|
||||||
config = UnitreeG1Config()
|
|
||||||
robot = UnitreeG1(config)
|
|
||||||
|
|
||||||
robot.connect()
|
|
||||||
|
|
||||||
# Initialize gr00T locomotion controller
|
|
||||||
groot_controller = GrootLocomotionController(
|
|
||||||
policy_balance=policy_balance,
|
|
||||||
policy_walk=policy_walk,
|
|
||||||
robot=robot,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
robot.reset(CONTROL_DT, GROOT_DEFAULT_ANGLES)
|
|
||||||
|
|
||||||
logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate, R1=raise waist, R2=lower waist")
|
|
||||||
logger.info("Press Ctrl+C to stop")
|
|
||||||
|
|
||||||
# Run step
|
|
||||||
while not robot._shutdown_event.is_set():
|
|
||||||
start_time = time.time()
|
|
||||||
groot_controller.run_step()
|
|
||||||
elapsed = time.time() - start_time
|
|
||||||
sleep_time = max(0, CONTROL_DT - elapsed)
|
|
||||||
time.sleep(sleep_time)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("Stopping locomotion...")
|
|
||||||
finally:
|
|
||||||
if robot.is_connected:
|
|
||||||
robot.disconnect()
|
|
||||||
logger.info("Done!")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
|
|
||||||
parser.add_argument(
|
|
||||||
"--repo-id",
|
|
||||||
type=str,
|
|
||||||
default=DEFAULT_GROOT_REPO_ID,
|
|
||||||
help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
run(repo_id=args.repo_id)
|
|
||||||
@@ -14,21 +14,21 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnx
|
import onnx
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
from lerobot.robots.unitree_g1.g1_utils import (
|
||||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
REMOTE_AXES,
|
||||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
G1_29_JointArmIndex,
|
||||||
|
G1_29_JointIndex,
|
||||||
|
get_gravity_orientation,
|
||||||
|
)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
|
DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
|
||||||
@@ -40,18 +40,13 @@ DEFAULT_ANGLES[16] = 0.2 # Left shoulder roll
|
|||||||
DEFAULT_ANGLES[23] = -0.2 # Right shoulder roll
|
DEFAULT_ANGLES[23] = -0.2 # Right shoulder roll
|
||||||
DEFAULT_ANGLES[[18, 25]] = 0.6 # Elbow
|
DEFAULT_ANGLES[[18, 25]] = 0.6 # Elbow
|
||||||
|
|
||||||
MISSING_JOINTS = []
|
|
||||||
G1_MODEL = "g1_23" # Or "g1_29"
|
|
||||||
if G1_MODEL == "g1_23":
|
|
||||||
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw
|
|
||||||
|
|
||||||
# Control parameters
|
# Control parameters
|
||||||
ACTION_SCALE = 0.25
|
ACTION_SCALE = 0.25
|
||||||
CONTROL_DT = 0.02 # 50Hz
|
CONTROL_DT = 0.005 # 200Hz
|
||||||
ANG_VEL_SCALE = 0.25
|
ANG_VEL_SCALE = 0.25
|
||||||
DOF_POS_SCALE = 1.0
|
DOF_POS_SCALE = 1.0
|
||||||
DOF_VEL_SCALE = 0.05
|
DOF_VEL_SCALE = 0.05
|
||||||
GAIT_PERIOD = 1.0
|
GAIT_PERIOD = 0.5
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
|
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
|
||||||
@@ -87,7 +82,7 @@ def load_policy(
|
|||||||
logger.info(f"Policy loaded: {policy.get_inputs()[0].shape} → {policy.get_outputs()[0].shape}")
|
logger.info(f"Policy loaded: {policy.get_inputs()[0].shape} → {policy.get_outputs()[0].shape}")
|
||||||
|
|
||||||
# Extract KP/KD from ONNX metadata
|
# Extract KP/KD from ONNX metadata
|
||||||
model = onnx.load(policy_path)
|
model = onnx.load(policy_path, load_external_data=False)
|
||||||
metadata = {prop.key: prop.value for prop in model.metadata_props}
|
metadata = {prop.key: prop.value for prop in model.metadata_props}
|
||||||
|
|
||||||
if "kp" not in metadata or "kd" not in metadata:
|
if "kp" not in metadata or "kd" not in metadata:
|
||||||
@@ -101,15 +96,13 @@ def load_policy(
|
|||||||
|
|
||||||
|
|
||||||
class HolosomaLocomotionController:
|
class HolosomaLocomotionController:
|
||||||
"""Holosoma whole-body locomotion controller for Unitree G1."""
|
"""Holosoma lower-body locomotion controller for Unitree G1."""
|
||||||
|
|
||||||
def __init__(self, policy, robot, kp: np.ndarray, kd: np.ndarray):
|
control_dt = CONTROL_DT # Expose for unitree_g1.py
|
||||||
self.policy = policy
|
|
||||||
self.robot = robot
|
|
||||||
|
|
||||||
# Override robot's PD gains with policy gains
|
def __init__(self):
|
||||||
self.robot.kp = kp
|
# Load policy and gains
|
||||||
self.robot.kd = kd
|
self.policy, self.kp, self.kd = load_policy()
|
||||||
|
|
||||||
self.cmd = np.zeros(3, dtype=np.float32)
|
self.cmd = np.zeros(3, dtype=np.float32)
|
||||||
|
|
||||||
@@ -124,35 +117,55 @@ class HolosomaLocomotionController:
|
|||||||
self.phase_dt = 2 * np.pi / ((1.0 / CONTROL_DT) * GAIT_PERIOD)
|
self.phase_dt = 2 * np.pi / ((1.0 / CONTROL_DT) * GAIT_PERIOD)
|
||||||
self.is_standing = True
|
self.is_standing = True
|
||||||
|
|
||||||
def run_step(self):
|
logger.info("HolosomaLocomotionController initialized")
|
||||||
# Get current observation
|
|
||||||
obs = self.robot.get_observation()
|
|
||||||
|
|
||||||
if not obs:
|
def reset(self) -> None:
|
||||||
return
|
"""Reset internal state for a new episode."""
|
||||||
|
self.cmd[:] = 0.0
|
||||||
|
self.qj[:] = 0.0
|
||||||
|
self.dqj[:] = 0.0
|
||||||
|
self.obs[:] = 0.0
|
||||||
|
self.last_action[:] = 0.0
|
||||||
|
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||||
|
self.is_standing = True
|
||||||
|
|
||||||
# Get command from remote controller
|
def run_step(self, action: dict, lowstate) -> dict:
|
||||||
ly = obs["remote.ly"] if abs(obs["remote.ly"]) > 0.1 else 0.0
|
"""Run one step of the locomotion controller.
|
||||||
lx = obs["remote.lx"] if abs(obs["remote.lx"]) > 0.1 else 0.0
|
|
||||||
rx = obs["remote.rx"] if abs(obs["remote.rx"]) > 0.1 else 0.0
|
Args:
|
||||||
|
action: Action dict containing remote.lx/ly/rx/ry
|
||||||
|
lowstate: Robot lowstate containing motor positions/velocities and IMU
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Action dict for lower body joints (0-14)
|
||||||
|
"""
|
||||||
|
if lowstate is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
lx, ly, rx, _ry = (action.get(k, 0.0) for k in REMOTE_AXES)
|
||||||
|
ly = ly if abs(ly) > 0.1 else 0.0
|
||||||
|
lx = lx if abs(lx) > 0.1 else 0.0
|
||||||
|
rx = rx if abs(rx) > 0.1 else 0.0
|
||||||
|
ly = np.clip(ly, -0.3, 0.3)
|
||||||
|
lx = np.clip(lx, -0.3, 0.3)
|
||||||
self.cmd[:] = [ly, -lx, -rx]
|
self.cmd[:] = [ly, -lx, -rx]
|
||||||
|
|
||||||
# Get joint positions and velocities
|
# Get joint positions and velocities from lowstate
|
||||||
for motor in G1_29_JointIndex:
|
for motor in G1_29_JointIndex:
|
||||||
name = motor.name
|
|
||||||
idx = motor.value
|
idx = motor.value
|
||||||
self.qj[idx] = obs[f"{name}.q"]
|
self.qj[idx] = lowstate.motor_state[idx].q
|
||||||
self.dqj[idx] = obs[f"{name}.dq"]
|
self.dqj[idx] = lowstate.motor_state[idx].dq
|
||||||
|
|
||||||
# Adapt observation for g1_23dof
|
# Hide arm positions from policy (show DEFAULT_ANGLES instead)
|
||||||
for idx in MISSING_JOINTS:
|
# This prevents policy from reacting to teleop arm movements
|
||||||
self.qj[idx] = 0.0
|
for arm_joint in G1_29_JointArmIndex:
|
||||||
self.dqj[idx] = 0.0
|
self.qj[arm_joint.value] = DEFAULT_ANGLES[arm_joint.value]
|
||||||
|
self.dqj[arm_joint.value] = 0.0
|
||||||
|
|
||||||
# Express IMU data in gravity frame of reference
|
# Express IMU data in gravity frame of reference
|
||||||
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
|
quat = lowstate.imu_state.quaternion
|
||||||
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
|
ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32)
|
||||||
gravity = self.robot.get_gravity_orientation(quat)
|
gravity = get_gravity_orientation(quat)
|
||||||
|
|
||||||
# Scale joint positions and velocities before policy inference
|
# Scale joint positions and velocities before policy inference
|
||||||
qj_obs = (self.qj - DEFAULT_ANGLES) * DOF_POS_SCALE
|
qj_obs = (self.qj - DEFAULT_ANGLES) * DOF_POS_SCALE
|
||||||
@@ -186,79 +199,16 @@ class HolosomaLocomotionController:
|
|||||||
# Run policy inference
|
# Run policy inference
|
||||||
ort_in = {self.policy.get_inputs()[0].name: self.obs.reshape(1, -1).astype(np.float32)}
|
ort_in = {self.policy.get_inputs()[0].name: self.obs.reshape(1, -1).astype(np.float32)}
|
||||||
raw_action = self.policy.run(None, ort_in)[0].squeeze()
|
raw_action = self.policy.run(None, ort_in)[0].squeeze()
|
||||||
action = np.clip(raw_action, -100.0, 100.0)
|
policy_action = np.clip(raw_action, -100.0, 100.0)
|
||||||
self.last_action = action.copy()
|
self.last_action = policy_action.copy()
|
||||||
|
|
||||||
# Transform action back to target joint positions
|
# Transform action back to target joint positions
|
||||||
target = DEFAULT_ANGLES + action * ACTION_SCALE
|
target = DEFAULT_ANGLES + policy_action * ACTION_SCALE
|
||||||
|
|
||||||
# Build action dict
|
# Build action dict (first 15 joints only)
|
||||||
action_dict = {}
|
action_dict = {}
|
||||||
for motor in G1_29_JointIndex:
|
for i in range(15):
|
||||||
action_dict[f"{motor.name}.q"] = float(target[motor.value])
|
motor_name = G1_29_JointIndex(i).name
|
||||||
|
action_dict[f"{motor_name}.q"] = float(target[i])
|
||||||
|
|
||||||
# Zero out missing joints for g1_23dof
|
return action_dict
|
||||||
for joint_idx in MISSING_JOINTS:
|
|
||||||
motor_name = G1_29_JointIndex(joint_idx).name
|
|
||||||
action_dict[f"{motor_name}.q"] = 0.0
|
|
||||||
|
|
||||||
# Send action to robot
|
|
||||||
self.robot.send_action(action_dict)
|
|
||||||
|
|
||||||
|
|
||||||
def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -> None:
|
|
||||||
"""Main function to run the Holosoma locomotion controller.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
repo_id: Hugging Face Hub repository ID for Holosoma policies.
|
|
||||||
policy_type: Policy type to use ('fastsac' or 'ppo').
|
|
||||||
"""
|
|
||||||
# Load policy and gains
|
|
||||||
policy, kp, kd = load_policy(repo_id=repo_id, policy_type=policy_type)
|
|
||||||
|
|
||||||
# Initialize robot
|
|
||||||
config = UnitreeG1Config()
|
|
||||||
robot = UnitreeG1(config)
|
|
||||||
robot.connect()
|
|
||||||
|
|
||||||
holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd)
|
|
||||||
|
|
||||||
try:
|
|
||||||
robot.reset(CONTROL_DT, DEFAULT_ANGLES)
|
|
||||||
|
|
||||||
logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate")
|
|
||||||
logger.info("Press Ctrl+C to stop")
|
|
||||||
|
|
||||||
# Run step
|
|
||||||
while not robot._shutdown_event.is_set():
|
|
||||||
start_time = time.time()
|
|
||||||
holosoma_controller.run_step()
|
|
||||||
elapsed = time.time() - start_time
|
|
||||||
sleep_time = max(0, CONTROL_DT - elapsed)
|
|
||||||
time.sleep(sleep_time)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("Stopping locomotion...")
|
|
||||||
finally:
|
|
||||||
if robot.is_connected:
|
|
||||||
robot.disconnect()
|
|
||||||
logger.info("Done!")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1")
|
|
||||||
parser.add_argument(
|
|
||||||
"--repo-id",
|
|
||||||
type=str,
|
|
||||||
default=DEFAULT_HOLOSOMA_REPO_ID,
|
|
||||||
help=f"Hugging Face Hub repo ID for Holosoma policies (default: {DEFAULT_HOLOSOMA_REPO_ID})",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--policy",
|
|
||||||
type=str,
|
|
||||||
choices=["fastsac", "ppo"],
|
|
||||||
default="fastsac",
|
|
||||||
help="Policy type to use: 'fastsac' (default) or 'ppo'",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
run(repo_id=args.repo_id, policy_type=args.policy)
|
|
||||||
@@ -24,6 +24,7 @@ This server runs on the robot and forwards:
|
|||||||
Uses JSON for secure serialization instead of pickle.
|
Uses JSON for secure serialization instead of pickle.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
import base64
|
import base64
|
||||||
import contextlib
|
import contextlib
|
||||||
import json
|
import json
|
||||||
@@ -38,6 +39,8 @@ from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
|||||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
|
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
|
||||||
from unitree_sdk2py.utils.crc import CRC
|
from unitree_sdk2py.utils.crc import CRC
|
||||||
|
|
||||||
|
from lerobot.cameras.zmq.image_server import ImageServer
|
||||||
|
|
||||||
# DDS topic names follow Unitree SDK naming conventions
|
# DDS topic names follow Unitree SDK naming conventions
|
||||||
# ruff: noqa: N816
|
# ruff: noqa: N816
|
||||||
kTopicLowCommand_Debug = "rt/lowcmd" # action to robot
|
kTopicLowCommand_Debug = "rt/lowcmd" # action to robot
|
||||||
@@ -150,6 +153,32 @@ def cmd_forward_loop(
|
|||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
"""Main entry point for the robot server bridge."""
|
"""Main entry point for the robot server bridge."""
|
||||||
|
parser = argparse.ArgumentParser(description="DDS-to-ZMQ bridge server for Unitree G1")
|
||||||
|
parser.add_argument("--camera", action="store_true", help="Also launch camera server")
|
||||||
|
parser.add_argument("--camera-device", type=int, default=4, help="Camera device ID (default: 4)")
|
||||||
|
parser.add_argument("--camera-fps", type=int, default=30, help="Camera FPS (default: 30)")
|
||||||
|
parser.add_argument("--camera-width", type=int, default=640, help="Camera width (default: 640)")
|
||||||
|
parser.add_argument("--camera-height", type=int, default=480, help="Camera height (default: 480)")
|
||||||
|
parser.add_argument("--camera-port", type=int, default=5555, help="Camera ZMQ port (default: 5555)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Optionally start camera server in background thread
|
||||||
|
camera_thread = None
|
||||||
|
if args.camera:
|
||||||
|
camera_config = {
|
||||||
|
"fps": args.camera_fps,
|
||||||
|
"cameras": {
|
||||||
|
"head_camera": {
|
||||||
|
"device_id": args.camera_device,
|
||||||
|
"shape": [args.camera_height, args.camera_width],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
camera_server = ImageServer(camera_config, port=args.camera_port)
|
||||||
|
camera_thread = threading.Thread(target=camera_server.run, daemon=True)
|
||||||
|
camera_thread.start()
|
||||||
|
print(f"Camera server started on port {args.camera_port} (device {args.camera_device})")
|
||||||
|
|
||||||
# initialize DDS
|
# initialize DDS
|
||||||
ChannelFactoryInitialize(0)
|
ChannelFactoryInitialize(0)
|
||||||
|
|
||||||
@@ -206,6 +235,8 @@ def main() -> None:
|
|||||||
shutdown_event.set()
|
shutdown_event.set()
|
||||||
ctx.term() # terminates blocking zmq.recv() calls
|
ctx.term() # terminates blocking zmq.recv() calls
|
||||||
t_state.join(timeout=2.0)
|
t_state.join(timeout=2.0)
|
||||||
|
if camera_thread is not None:
|
||||||
|
camera_thread.join(timeout=2.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -14,27 +14,67 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import struct
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lerobot.cameras.utils import make_cameras_from_configs
|
from lerobot.cameras.utils import make_cameras_from_configs
|
||||||
from lerobot.envs.factory import make_env
|
from lerobot.envs.factory import make_env
|
||||||
from lerobot.processor import RobotAction, RobotObservation
|
from lerobot.processor import RobotAction, RobotObservation
|
||||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex, G1_29_JointIndex
|
from lerobot.robots.unitree_g1.g1_kinematics import G1_29_ArmIK
|
||||||
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
|
from lerobot.robots.unitree_g1.g1_utils import (
|
||||||
|
REMOTE_AXES,
|
||||||
|
REMOTE_KEYS,
|
||||||
|
G1_29_JointArmIndex,
|
||||||
|
G1_29_JointIndex,
|
||||||
|
default_remote_input,
|
||||||
|
make_locomotion_controller,
|
||||||
|
)
|
||||||
|
from lerobot.utils.import_utils import _unitree_sdk_available
|
||||||
|
|
||||||
from ..robot import Robot
|
from ..robot import Robot
|
||||||
from .config_unitree_g1 import UnitreeG1Config
|
from .config_unitree_g1 import UnitreeG1Config
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _unitree_sdk_available:
|
||||||
|
from unitree_sdk2py.core.channel import (
|
||||||
|
ChannelFactoryInitialize as _SDKChannelFactoryInitialize,
|
||||||
|
ChannelPublisher as _SDKChannelPublisher,
|
||||||
|
ChannelSubscriber as _SDKChannelSubscriber,
|
||||||
|
)
|
||||||
|
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||||
|
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
|
||||||
|
LowCmd_ as hg_LowCmd,
|
||||||
|
LowState_ as hg_LowState,
|
||||||
|
)
|
||||||
|
from unitree_sdk2py.utils.crc import CRC
|
||||||
|
else:
|
||||||
|
_SDKChannelFactoryInitialize = None
|
||||||
|
_SDKChannelPublisher = None
|
||||||
|
_SDKChannelSubscriber = None
|
||||||
|
unitree_hg_msg_dds__LowCmd_ = None
|
||||||
|
hg_LowCmd = None
|
||||||
|
hg_LowState = None
|
||||||
|
CRC = None
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class LocomotionController(Protocol):
|
||||||
|
control_dt: float
|
||||||
|
|
||||||
|
def run_step(self, action: dict, lowstate) -> dict: ...
|
||||||
|
|
||||||
|
def reset(self) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
# DDS topic names follow Unitree SDK naming conventions
|
# DDS topic names follow Unitree SDK naming conventions
|
||||||
# ruff: noqa: N816
|
# ruff: noqa: N816
|
||||||
kTopicLowCommand_Debug = "rt/lowcmd"
|
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||||
@@ -63,7 +103,7 @@ class IMUState:
|
|||||||
class G1_29_LowState: # noqa: N801
|
class G1_29_LowState: # noqa: N801
|
||||||
motor_state: list[MotorState] = field(default_factory=lambda: [MotorState() for _ in G1_29_JointIndex])
|
motor_state: list[MotorState] = field(default_factory=lambda: [MotorState() for _ in G1_29_JointIndex])
|
||||||
imu_state: IMUState = field(default_factory=IMUState)
|
imu_state: IMUState = field(default_factory=IMUState)
|
||||||
wireless_remote: Any = None # Raw wireless remote data
|
wireless_remote: bytes | None = None # Raw wireless remote data
|
||||||
mode_machine: int = 0 # Robot mode
|
mode_machine: int = 0 # Robot mode
|
||||||
|
|
||||||
|
|
||||||
@@ -71,25 +111,6 @@ class UnitreeG1(Robot):
|
|||||||
config_class = UnitreeG1Config
|
config_class = UnitreeG1Config
|
||||||
name = "unitree_g1"
|
name = "unitree_g1"
|
||||||
|
|
||||||
# unitree remote controller
|
|
||||||
class RemoteController:
|
|
||||||
def __init__(self):
|
|
||||||
self.lx = 0
|
|
||||||
self.ly = 0
|
|
||||||
self.rx = 0
|
|
||||||
self.ry = 0
|
|
||||||
self.button = [0] * 16
|
|
||||||
|
|
||||||
def set(self, data):
|
|
||||||
# wireless_remote
|
|
||||||
keys = struct.unpack("H", data[2:4])[0]
|
|
||||||
for i in range(16):
|
|
||||||
self.button[i] = (keys & (1 << i)) >> i
|
|
||||||
self.lx = struct.unpack("f", data[4:8])[0]
|
|
||||||
self.rx = struct.unpack("f", data[8:12])[0]
|
|
||||||
self.ry = struct.unpack("f", data[12:16])[0]
|
|
||||||
self.ly = struct.unpack("f", data[20:24])[0]
|
|
||||||
|
|
||||||
def __init__(self, config: UnitreeG1Config):
|
def __init__(self, config: UnitreeG1Config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
@@ -103,11 +124,9 @@ class UnitreeG1(Robot):
|
|||||||
|
|
||||||
# Import channel classes based on mode
|
# Import channel classes based on mode
|
||||||
if config.is_simulation:
|
if config.is_simulation:
|
||||||
from unitree_sdk2py.core.channel import (
|
self._ChannelFactoryInitialize = _SDKChannelFactoryInitialize
|
||||||
ChannelFactoryInitialize,
|
self._ChannelPublisher = _SDKChannelPublisher
|
||||||
ChannelPublisher,
|
self._ChannelSubscriber = _SDKChannelSubscriber
|
||||||
ChannelSubscriber,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
|
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
|
||||||
ChannelFactoryInitialize,
|
ChannelFactoryInitialize,
|
||||||
@@ -115,22 +134,30 @@ class UnitreeG1(Robot):
|
|||||||
ChannelSubscriber,
|
ChannelSubscriber,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store for use in connect()
|
self._ChannelFactoryInitialize = ChannelFactoryInitialize
|
||||||
self._ChannelFactoryInitialize = ChannelFactoryInitialize
|
self._ChannelPublisher = ChannelPublisher
|
||||||
self._ChannelPublisher = ChannelPublisher
|
self._ChannelSubscriber = ChannelSubscriber
|
||||||
self._ChannelSubscriber = ChannelSubscriber
|
|
||||||
|
|
||||||
# Initialize state variables
|
# Initialize state variables
|
||||||
self.sim_env = None
|
self.sim_env = None
|
||||||
self._env_wrapper = None
|
self._env_wrapper = None
|
||||||
self._lowstate = None
|
self._lowstate = None
|
||||||
|
self._lowstate_lock = threading.Lock()
|
||||||
self._shutdown_event = threading.Event()
|
self._shutdown_event = threading.Event()
|
||||||
self.subscribe_thread = None
|
self.subscribe_thread = None
|
||||||
self.remote_controller = self.RemoteController()
|
|
||||||
|
|
||||||
self.arm_ik = G1_29_ArmIK()
|
self.arm_ik = G1_29_ArmIK() if config.gravity_compensation else None
|
||||||
|
|
||||||
def _subscribe_motor_state(self): # polls robot state @ 250Hz
|
# Lower-body controller loaded dynamically
|
||||||
|
self.controller: LocomotionController | None = make_locomotion_controller(config.controller)
|
||||||
|
|
||||||
|
# Controller thread state
|
||||||
|
self._controller_thread = None
|
||||||
|
self._controller_action_lock = threading.Lock()
|
||||||
|
self.controller_input = default_remote_input()
|
||||||
|
self.controller_output = {}
|
||||||
|
|
||||||
|
def _subscribe_lowstate(self): # polls robot state @ 250Hz
|
||||||
while not self._shutdown_event.is_set():
|
while not self._shutdown_event.is_set():
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
@@ -143,11 +170,11 @@ class UnitreeG1(Robot):
|
|||||||
lowstate = G1_29_LowState()
|
lowstate = G1_29_LowState()
|
||||||
|
|
||||||
# Capture motor states using jointindex
|
# Capture motor states using jointindex
|
||||||
for id in G1_29_JointIndex:
|
for joint in G1_29_JointIndex:
|
||||||
lowstate.motor_state[id].q = msg.motor_state[id].q
|
lowstate.motor_state[joint].q = msg.motor_state[joint].q
|
||||||
lowstate.motor_state[id].dq = msg.motor_state[id].dq
|
lowstate.motor_state[joint].dq = msg.motor_state[joint].dq
|
||||||
lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est
|
lowstate.motor_state[joint].tau_est = msg.motor_state[joint].tau_est
|
||||||
lowstate.motor_state[id].temperature = msg.motor_state[id].temperature
|
lowstate.motor_state[joint].temperature = msg.motor_state[joint].temperature
|
||||||
|
|
||||||
# Capture IMU state
|
# Capture IMU state
|
||||||
lowstate.imu_state.quaternion = list(msg.imu_state.quaternion)
|
lowstate.imu_state.quaternion = list(msg.imu_state.quaternion)
|
||||||
@@ -162,31 +189,106 @@ class UnitreeG1(Robot):
|
|||||||
# Capture mode_machine
|
# Capture mode_machine
|
||||||
lowstate.mode_machine = msg.mode_machine
|
lowstate.mode_machine = msg.mode_machine
|
||||||
|
|
||||||
self._lowstate = lowstate
|
with self._lowstate_lock:
|
||||||
|
self._lowstate = lowstate
|
||||||
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
all_t_elapsed = current_time - start_time
|
all_t_elapsed = current_time - start_time
|
||||||
sleep_time = max(0, (self.control_dt - all_t_elapsed)) # maintain constant control dt
|
sleep_time = max(0, (self.control_dt - all_t_elapsed)) # maintain constant control dt
|
||||||
time.sleep(sleep_time)
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
|
def publish_lowcmd(
|
||||||
|
self,
|
||||||
|
action: RobotAction,
|
||||||
|
kp: np.ndarray | list[float] | None = None,
|
||||||
|
kd: np.ndarray | list[float] | None = None,
|
||||||
|
tau: np.ndarray | list[float] | None = None,
|
||||||
|
) -> None: # writes robot command whenever requested
|
||||||
|
for motor in G1_29_JointIndex:
|
||||||
|
key = f"{motor.name}.q"
|
||||||
|
if key in action:
|
||||||
|
self.msg.motor_cmd[motor.value].q = action[key]
|
||||||
|
self.msg.motor_cmd[motor.value].qd = 0
|
||||||
|
self.msg.motor_cmd[motor.value].kp = (
|
||||||
|
kp[motor.value] if kp is not None else self.kp[motor.value]
|
||||||
|
)
|
||||||
|
self.msg.motor_cmd[motor.value].kd = (
|
||||||
|
kd[motor.value] if kd is not None else self.kd[motor.value]
|
||||||
|
)
|
||||||
|
self.msg.motor_cmd[motor.value].tau = tau[motor.value] if tau is not None else 0.0
|
||||||
|
|
||||||
|
self.msg.crc = self.crc.Crc(self.msg)
|
||||||
|
self.lowcmd_publisher.Write(self.msg)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _cameras_ft(self) -> dict[str, tuple]:
|
||||||
|
return {
|
||||||
|
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||||
|
}
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
|
return {**self._motors_ft, **self._cameras_ft}
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
|
if self.controller is None:
|
||||||
|
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
|
||||||
|
|
||||||
def calibrate(self) -> None: # robot is already calibrated
|
arm_features = {f"{G1_29_JointArmIndex(motor).name}.q": float for motor in G1_29_JointArmIndex}
|
||||||
|
remote_features = dict.fromkeys(REMOTE_AXES, float)
|
||||||
|
return {**arm_features, **remote_features}
|
||||||
|
|
||||||
|
def _controller_loop(self):
|
||||||
|
"""Background thread that runs controller at policy's control_dt."""
|
||||||
|
control_dt = self.controller.control_dt
|
||||||
|
logger.info(f"Controller loop starting with control_dt={control_dt} ({1.0 / control_dt:.1f}Hz)")
|
||||||
|
|
||||||
|
loop_count = 0
|
||||||
|
last_log_time = time.time()
|
||||||
|
|
||||||
|
while not self._shutdown_event.is_set():
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
with self._lowstate_lock:
|
||||||
|
lowstate = self._lowstate
|
||||||
|
|
||||||
|
if lowstate is not None and self.controller is not None:
|
||||||
|
loop_count += 1
|
||||||
|
if time.time() - last_log_time >= 5.0: # Log every 5 seconds
|
||||||
|
actual_hz = loop_count / (time.time() - last_log_time)
|
||||||
|
logger.info(
|
||||||
|
f"Controller actual rate: {actual_hz:.1f}Hz (target: {1.0 / control_dt:.1f}Hz)"
|
||||||
|
)
|
||||||
|
loop_count = 0
|
||||||
|
last_log_time = time.time()
|
||||||
|
# Read controller input snapshot
|
||||||
|
with self._controller_action_lock:
|
||||||
|
controller_input = dict(self.controller_input)
|
||||||
|
|
||||||
|
# Run controller step
|
||||||
|
controller_action = self.controller.run_step(controller_input, lowstate)
|
||||||
|
|
||||||
|
# Write controller output snapshot
|
||||||
|
with self._controller_action_lock:
|
||||||
|
self.controller_output = dict(controller_action)
|
||||||
|
|
||||||
|
ctrl_kp = self.controller.kp if hasattr(self.controller, "kp") else None
|
||||||
|
ctrl_kd = self.controller.kd if hasattr(self.controller, "kd") else None
|
||||||
|
self.publish_lowcmd(controller_action, kp=ctrl_kp, kd=ctrl_kd)
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
sleep_time = max(0, control_dt - elapsed)
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
# TODO: implement g1_29 calibration
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def configure(self) -> None:
|
def configure(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def connect(self, calibrate: bool = True) -> None: # connect to DDS
|
def connect(self, calibrate: bool = True) -> None: # connect to DDS
|
||||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
|
||||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
|
|
||||||
LowCmd_ as hg_LowCmd,
|
|
||||||
LowState_ as hg_LowState,
|
|
||||||
)
|
|
||||||
from unitree_sdk2py.utils.crc import CRC
|
|
||||||
|
|
||||||
# Initialize DDS channel and simulation environment
|
# Initialize DDS channel and simulation environment
|
||||||
if self.config.is_simulation:
|
if self.config.is_simulation:
|
||||||
self._ChannelFactoryInitialize(0, "lo")
|
self._ChannelFactoryInitialize(0, "lo")
|
||||||
@@ -194,7 +296,7 @@ class UnitreeG1(Robot):
|
|||||||
# Extract the actual gym env from the dict structure
|
# Extract the actual gym env from the dict structure
|
||||||
self.sim_env = self._env_wrapper["hub_env"][0].envs[0]
|
self.sim_env = self._env_wrapper["hub_env"][0].envs[0]
|
||||||
else:
|
else:
|
||||||
self._ChannelFactoryInitialize(0)
|
self._ChannelFactoryInitialize(0, config=self.config)
|
||||||
|
|
||||||
# Initialize direct motor control interface
|
# Initialize direct motor control interface
|
||||||
self.lowcmd_publisher = self._ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
self.lowcmd_publisher = self._ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||||
@@ -203,7 +305,7 @@ class UnitreeG1(Robot):
|
|||||||
self.lowstate_subscriber.Init()
|
self.lowstate_subscriber.Init()
|
||||||
|
|
||||||
# Start subscribe thread to read robot state
|
# Start subscribe thread to read robot state
|
||||||
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
|
self.subscribe_thread = threading.Thread(target=self._subscribe_lowstate)
|
||||||
self.subscribe_thread.start()
|
self.subscribe_thread.start()
|
||||||
|
|
||||||
# Connect cameras
|
# Connect cameras
|
||||||
@@ -220,25 +322,53 @@ class UnitreeG1(Robot):
|
|||||||
|
|
||||||
# Wait for first state message to arrive
|
# Wait for first state message to arrive
|
||||||
lowstate = None
|
lowstate = None
|
||||||
|
deadline = time.time() + 10.0
|
||||||
while lowstate is None:
|
while lowstate is None:
|
||||||
lowstate = self._lowstate
|
with self._lowstate_lock:
|
||||||
|
lowstate = self._lowstate
|
||||||
if lowstate is None:
|
if lowstate is None:
|
||||||
|
if time.time() > deadline:
|
||||||
|
raise TimeoutError("Timed out waiting for robot state (10s)")
|
||||||
|
logger.warning("[UnitreeG1] Waiting for robot state...")
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
logger.warning("[UnitreeG1] Waiting for robot state...")
|
logger.info("[UnitreeG1] Connected to robot.")
|
||||||
logger.warning("[UnitreeG1] Connected to robot.")
|
|
||||||
self.msg.mode_machine = lowstate.mode_machine
|
self.msg.mode_machine = lowstate.mode_machine
|
||||||
|
|
||||||
# Initialize all motors with unified kp/kd from config
|
|
||||||
self.kp = np.array(self.config.kp, dtype=np.float32)
|
self.kp = np.array(self.config.kp, dtype=np.float32)
|
||||||
self.kd = np.array(self.config.kd, dtype=np.float32)
|
self.kd = np.array(self.config.kd, dtype=np.float32)
|
||||||
|
|
||||||
for id in G1_29_JointIndex:
|
for joint in G1_29_JointIndex:
|
||||||
self.msg.motor_cmd[id].mode = 1
|
self.msg.motor_cmd[joint].mode = 1
|
||||||
self.msg.motor_cmd[id].kp = self.kp[id.value]
|
self.msg.motor_cmd[joint].kp = self.kp[joint.value]
|
||||||
self.msg.motor_cmd[id].kd = self.kd[id.value]
|
self.msg.motor_cmd[joint].kd = self.kd[joint.value]
|
||||||
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
|
self.msg.motor_cmd[joint].q = lowstate.motor_state[joint.value].q
|
||||||
|
|
||||||
|
# Start controller thread if enabled
|
||||||
|
if self.controller is not None:
|
||||||
|
self._controller_thread = threading.Thread(target=self._controller_loop, daemon=True)
|
||||||
|
self._controller_thread.start()
|
||||||
|
fps = int(1.0 / self.controller.control_dt)
|
||||||
|
logger.info(f"Controller thread started ({fps}Hz)")
|
||||||
|
|
||||||
|
def _send_zero_torque(self) -> None:
|
||||||
|
"""Send a zero-gain command to make joints passive before shutting down."""
|
||||||
|
try:
|
||||||
|
with self._lowstate_lock:
|
||||||
|
lowstate = self._lowstate
|
||||||
|
if lowstate is None:
|
||||||
|
return
|
||||||
|
action = {f"{motor.name}.q": lowstate.motor_state[motor.value].q for motor in G1_29_JointIndex}
|
||||||
|
zero_gains = np.zeros(29, dtype=np.float32)
|
||||||
|
self.publish_lowcmd(action, kp=zero_gains, kd=zero_gains, tau=zero_gains)
|
||||||
|
logger.info("Sent zero-torque command for safe shutdown")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to send zero-torque on disconnect: {e}")
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
|
# Put robot in passive mode before stopping threads
|
||||||
|
if not self.config.is_simulation:
|
||||||
|
self._send_zero_torque()
|
||||||
|
|
||||||
# Signal thread to stop and unblock any waits
|
# Signal thread to stop and unblock any waits
|
||||||
self._shutdown_event.set()
|
self._shutdown_event.set()
|
||||||
|
|
||||||
@@ -248,6 +378,12 @@ class UnitreeG1(Robot):
|
|||||||
if self.subscribe_thread.is_alive():
|
if self.subscribe_thread.is_alive():
|
||||||
logger.warning("Subscribe thread did not stop cleanly")
|
logger.warning("Subscribe thread did not stop cleanly")
|
||||||
|
|
||||||
|
# Wait for controller thread to finish
|
||||||
|
if self._controller_thread is not None:
|
||||||
|
self._controller_thread.join(timeout=2.0)
|
||||||
|
if self._controller_thread.is_alive():
|
||||||
|
logger.warning("Controller thread did not stop cleanly")
|
||||||
|
|
||||||
# Close simulation environment
|
# Close simulation environment
|
||||||
if self.config.is_simulation and self.sim_env is not None:
|
if self.config.is_simulation and self.sim_env is not None:
|
||||||
try:
|
try:
|
||||||
@@ -274,7 +410,8 @@ class UnitreeG1(Robot):
|
|||||||
cam.disconnect()
|
cam.disconnect()
|
||||||
|
|
||||||
def get_observation(self) -> RobotObservation:
|
def get_observation(self) -> RobotObservation:
|
||||||
lowstate = self._lowstate
|
with self._lowstate_lock:
|
||||||
|
lowstate = self._lowstate
|
||||||
if lowstate is None:
|
if lowstate is None:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@@ -313,14 +450,9 @@ class UnitreeG1(Robot):
|
|||||||
obs["imu.rpy.pitch"] = lowstate.imu_state.rpy[1]
|
obs["imu.rpy.pitch"] = lowstate.imu_state.rpy[1]
|
||||||
obs["imu.rpy.yaw"] = lowstate.imu_state.rpy[2]
|
obs["imu.rpy.yaw"] = lowstate.imu_state.rpy[2]
|
||||||
|
|
||||||
# Controller - parse wireless_remote and add to obs
|
# Wireless remote (raw bytes for teleoperator)
|
||||||
if lowstate.wireless_remote and len(lowstate.wireless_remote) >= 24:
|
if lowstate.wireless_remote:
|
||||||
self.remote_controller.set(lowstate.wireless_remote)
|
obs["wireless_remote"] = lowstate.wireless_remote
|
||||||
obs["remote.buttons"] = self.remote_controller.button.copy()
|
|
||||||
obs["remote.lx"] = self.remote_controller.lx
|
|
||||||
obs["remote.ly"] = self.remote_controller.ly
|
|
||||||
obs["remote.rx"] = self.remote_controller.rx
|
|
||||||
obs["remote.ry"] = self.remote_controller.ry
|
|
||||||
|
|
||||||
# Cameras - read images from ZMQ cameras
|
# Cameras - read images from ZMQ cameras
|
||||||
for cam_name, cam in self._cameras.items():
|
for cam_name, cam in self._cameras.items():
|
||||||
@@ -328,73 +460,63 @@ class UnitreeG1(Robot):
|
|||||||
|
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
|
def send_action(self, action: RobotAction) -> RobotAction:
|
||||||
|
action_to_publish = action
|
||||||
|
if self.controller is not None:
|
||||||
|
# Controller thread owns legs/waist. Here we only update joystick inputs
|
||||||
|
# and publish arm targets from the teleoperator.
|
||||||
|
self._update_controller_action(action)
|
||||||
|
arm_prefixes = tuple(j.name for j in G1_29_JointArmIndex)
|
||||||
|
action_to_publish = {
|
||||||
|
key: value
|
||||||
|
for key, value in action.items()
|
||||||
|
if key.endswith(".q") and key.startswith(arm_prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
|
tau = None
|
||||||
|
if self.config.gravity_compensation and self.arm_ik is not None:
|
||||||
|
tau = np.zeros(29, dtype=np.float32)
|
||||||
|
action_np = np.array(
|
||||||
|
[
|
||||||
|
action_to_publish.get(f"{joint.name}.q", self.msg.motor_cmd[joint.value].q)
|
||||||
|
for joint in G1_29_JointArmIndex
|
||||||
|
],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
arm_tau = self.arm_ik.solve_tau(action_np)
|
||||||
|
arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value
|
||||||
|
for joint in G1_29_JointArmIndex:
|
||||||
|
local_idx = joint.value - arm_start_idx
|
||||||
|
tau[joint.value] = arm_tau[local_idx]
|
||||||
|
|
||||||
|
self.publish_lowcmd(action_to_publish, tau=tau)
|
||||||
|
return action
|
||||||
|
|
||||||
|
def _update_controller_action(self, action: RobotAction) -> None:
|
||||||
|
"""Update controller input state from incoming teleop action."""
|
||||||
|
with self._controller_action_lock:
|
||||||
|
for key in REMOTE_KEYS:
|
||||||
|
if key in action:
|
||||||
|
self.controller_input[key] = action[key]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_calibrated(self) -> bool:
|
def is_calibrated(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self._lowstate is not None
|
with self._lowstate_lock:
|
||||||
|
return self._lowstate is not None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _motors_ft(self) -> dict[str, type]:
|
def _motors_ft(self) -> dict[str, type]:
|
||||||
|
"""Joint positions for all 29 joints."""
|
||||||
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
|
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cameras(self) -> dict:
|
def cameras(self) -> dict:
|
||||||
return self._cameras
|
return self._cameras
|
||||||
|
|
||||||
@property
|
|
||||||
def _cameras_ft(self) -> dict[str, tuple]:
|
|
||||||
return {
|
|
||||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
|
||||||
}
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def observation_features(self) -> dict[str, type | tuple]:
|
|
||||||
return {**self._motors_ft, **self._cameras_ft}
|
|
||||||
|
|
||||||
def send_action(self, action: RobotAction) -> RobotAction:
|
|
||||||
for motor in G1_29_JointIndex:
|
|
||||||
key = f"{motor.name}.q"
|
|
||||||
if key in action:
|
|
||||||
self.msg.motor_cmd[motor.value].q = action[key]
|
|
||||||
self.msg.motor_cmd[motor.value].qd = 0
|
|
||||||
self.msg.motor_cmd[motor.value].kp = self.kp[motor.value]
|
|
||||||
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
|
|
||||||
self.msg.motor_cmd[motor.value].tau = 0
|
|
||||||
|
|
||||||
if self.config.gravity_compensation:
|
|
||||||
# Build action_np from motor commands (arm joints are indices 15-28, local indices 0-13)
|
|
||||||
action_np = np.zeros(14)
|
|
||||||
arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value # 15
|
|
||||||
for joint in G1_29_JointArmIndex:
|
|
||||||
local_idx = joint.value - arm_start_idx
|
|
||||||
action_np[local_idx] = self.msg.motor_cmd[joint.value].q
|
|
||||||
tau = self.arm_ik.solve_tau(action_np)
|
|
||||||
|
|
||||||
# Apply tau back to motor commands
|
|
||||||
for joint in G1_29_JointArmIndex:
|
|
||||||
local_idx = joint.value - arm_start_idx
|
|
||||||
self.msg.motor_cmd[joint.value].tau = tau[local_idx]
|
|
||||||
|
|
||||||
self.msg.crc = self.crc.Crc(self.msg)
|
|
||||||
self.lowcmd_publisher.Write(self.msg)
|
|
||||||
return action
|
|
||||||
|
|
||||||
def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion
|
|
||||||
"""Get gravity orientation from quaternion."""
|
|
||||||
qw = quaternion[0]
|
|
||||||
qx = quaternion[1]
|
|
||||||
qy = quaternion[2]
|
|
||||||
qz = quaternion[3]
|
|
||||||
|
|
||||||
gravity_orientation = np.zeros(3)
|
|
||||||
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
|
|
||||||
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
|
|
||||||
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
|
||||||
return gravity_orientation
|
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
self,
|
self,
|
||||||
control_dt: float | None = None,
|
control_dt: float | None = None,
|
||||||
@@ -407,15 +529,9 @@ class UnitreeG1(Robot):
|
|||||||
|
|
||||||
if self.config.is_simulation and self.sim_env is not None:
|
if self.config.is_simulation and self.sim_env is not None:
|
||||||
self.sim_env.reset()
|
self.sim_env.reset()
|
||||||
|
self.publish_lowcmd(
|
||||||
for motor in G1_29_JointIndex:
|
{f"{motor.name}.q": float(default_positions[motor.value]) for motor in G1_29_JointIndex}
|
||||||
self.msg.motor_cmd[motor.value].q = default_positions[motor.value]
|
)
|
||||||
self.msg.motor_cmd[motor.value].qd = 0
|
|
||||||
self.msg.motor_cmd[motor.value].kp = self.kp[motor.value]
|
|
||||||
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
|
|
||||||
self.msg.motor_cmd[motor.value].tau = 0
|
|
||||||
self.msg.crc = self.crc.Crc(self.msg)
|
|
||||||
self.lowcmd_publisher.Write(self.msg)
|
|
||||||
else:
|
else:
|
||||||
total_time = 3.0
|
total_time = 3.0
|
||||||
num_steps = int(total_time / control_dt)
|
num_steps = int(total_time / control_dt)
|
||||||
@@ -446,4 +562,8 @@ class UnitreeG1(Robot):
|
|||||||
sleep_time = max(0, control_dt - elapsed)
|
sleep_time = max(0, control_dt - elapsed)
|
||||||
time.sleep(sleep_time)
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
|
# Reset controller internal state (gait phase, obs history, etc.)
|
||||||
|
if self.controller is not None and hasattr(self.controller, "reset"):
|
||||||
|
self.controller.reset()
|
||||||
|
|
||||||
logger.info("Reached default position")
|
logger.info("Reached default position")
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ import zmq
|
|||||||
|
|
||||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||||
|
|
||||||
|
# Module-level ZMQ state mirrors the Unitree SDK's global ChannelFactory Singleton.
|
||||||
|
# Only one robot connection per process is supported.
|
||||||
_ctx: zmq.Context | None = None
|
_ctx: zmq.Context | None = None
|
||||||
_lowcmd_sock: zmq.Socket | None = None
|
_lowcmd_sock: zmq.Socket | None = None
|
||||||
_lowstate_sock: zmq.Socket | None = None
|
_lowstate_sock: zmq.Socket | None = None
|
||||||
@@ -97,17 +99,22 @@ def lowcmd_to_dict(topic: str, msg: Any) -> dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def ChannelFactoryInitialize(*args: Any, **kwargs: Any) -> None: # noqa: N802
|
def ChannelFactoryInitialize(domain_id: int = 0, config: Any = None) -> None: # noqa: N802
|
||||||
"""
|
"""
|
||||||
Initialize ZMQ sockets for robot communication.
|
Initialize ZMQ sockets for robot communication.
|
||||||
|
|
||||||
This function mimics the Unitree SDK's ChannelFactoryInitialize but uses
|
This function mimics the Unitree SDK's ChannelFactoryInitialize but uses
|
||||||
ZMQ sockets to connect to the robot server bridge instead of DDS.
|
ZMQ sockets to connect to the robot server bridge instead of DDS.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain_id: Ignored (for API compatibility with Unitree SDK)
|
||||||
|
config: UnitreeG1Config instance with robot_ip
|
||||||
"""
|
"""
|
||||||
global _ctx, _lowcmd_sock, _lowstate_sock
|
global _ctx, _lowcmd_sock, _lowstate_sock
|
||||||
|
|
||||||
# read socket config
|
# read socket config
|
||||||
config = UnitreeG1Config()
|
if config is None:
|
||||||
|
config = UnitreeG1Config()
|
||||||
robot_ip = config.robot_ip
|
robot_ip = config.robot_ip
|
||||||
|
|
||||||
ctx = zmq.Context.instance()
|
ctx = zmq.Context.instance()
|
||||||
|
|||||||
@@ -369,6 +369,8 @@ def record_loop(
|
|||||||
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
|
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
|
||||||
|
|
||||||
elif policy is None and isinstance(teleop, Teleoperator):
|
elif policy is None and isinstance(teleop, Teleoperator):
|
||||||
|
if robot.name == "unitree_g1":
|
||||||
|
teleop.send_feedback(obs)
|
||||||
act = teleop.get_action()
|
act = teleop.get_action()
|
||||||
|
|
||||||
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
|
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
|
||||||
@@ -556,10 +558,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
|||||||
):
|
):
|
||||||
log_say("Reset the environment", cfg.play_sounds)
|
log_say("Reset the environment", cfg.play_sounds)
|
||||||
|
|
||||||
# reset g1 robot
|
|
||||||
if robot.name == "unitree_g1":
|
|
||||||
robot.reset()
|
|
||||||
|
|
||||||
record_loop(
|
record_loop(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ import rerun as rr
|
|||||||
|
|
||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||||
|
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
RobotAction,
|
RobotAction,
|
||||||
@@ -153,7 +154,6 @@ def teleop_loop(
|
|||||||
|
|
||||||
display_len = max(len(key) for key in robot.action_features)
|
display_len = max(len(key) for key in robot.action_features)
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
loop_start = time.perf_counter()
|
loop_start = time.perf_counter()
|
||||||
|
|
||||||
@@ -163,6 +163,9 @@ def teleop_loop(
|
|||||||
# given that it is the identity processor as default
|
# given that it is the identity processor as default
|
||||||
obs = robot.get_observation()
|
obs = robot.get_observation()
|
||||||
|
|
||||||
|
if robot.name == "unitree_g1":
|
||||||
|
teleop.send_feedback(obs)
|
||||||
|
|
||||||
# Get teleop action
|
# Get teleop action
|
||||||
raw_action = teleop.get_action()
|
raw_action = teleop.get_action()
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,8 @@ from tqdm import tqdm
|
|||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.datasets.factory import make_dataset
|
from lerobot.datasets.factory import make_dataset
|
||||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
|
||||||
|
from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
|
||||||
from lerobot.datasets.utils import cycle
|
from lerobot.datasets.utils import cycle
|
||||||
from lerobot.envs.factory import make_env, make_env_pre_post_processors
|
from lerobot.envs.factory import make_env, make_env_pre_post_processors
|
||||||
from lerobot.envs.utils import close_envs
|
from lerobot.envs.utils import close_envs
|
||||||
@@ -209,7 +210,11 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
|
|
||||||
# Use accelerator's device
|
# Use accelerator's device
|
||||||
device = accelerator.device
|
device = accelerator.device
|
||||||
torch.backends.cudnn.benchmark = True
|
if cfg.cudnn_deterministic:
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
else:
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||||
@@ -339,13 +344,25 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||||
|
|
||||||
# create dataloader for offline training
|
# create dataloader for offline training
|
||||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
drop_n_last = getattr(cfg.policy, "drop_n_last_frames", 0)
|
||||||
|
|
||||||
|
if isinstance(dataset, NewMultiLeRobotDataset):
|
||||||
|
shuffle = False
|
||||||
|
sampler = WeightedEpisodeAwareSampler(
|
||||||
|
dataset.meta.episodes["dataset_from_index"],
|
||||||
|
dataset.meta.episodes["dataset_to_index"],
|
||||||
|
dataset_membership=dataset.meta.episodes["dataset_source"],
|
||||||
|
dataset_weights=dataset.dataset_weights,
|
||||||
|
drop_n_last_frames=drop_n_last,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
elif drop_n_last > 0:
|
||||||
shuffle = False
|
shuffle = False
|
||||||
sampler = EpisodeAwareSampler(
|
sampler = EpisodeAwareSampler(
|
||||||
dataset.meta.episodes["dataset_from_index"],
|
dataset.meta.episodes["dataset_from_index"],
|
||||||
dataset.meta.episodes["dataset_to_index"],
|
dataset.meta.episodes["dataset_to_index"],
|
||||||
episode_indices_to_use=dataset.episodes,
|
episode_indices_to_use=dataset.episodes,
|
||||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
drop_n_last_frames=drop_n_last,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -356,7 +373,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
dataset,
|
dataset,
|
||||||
num_workers=cfg.num_workers,
|
num_workers=cfg.num_workers,
|
||||||
batch_size=cfg.batch_size,
|
batch_size=cfg.batch_size,
|
||||||
shuffle=shuffle and not cfg.dataset.streaming,
|
shuffle=shuffle and not getattr(cfg.dataset, "streaming", False),
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
pin_memory=device.type == "cuda",
|
pin_memory=device.type == "cuda",
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
|
|||||||
@@ -19,3 +19,13 @@ from .exo_calib import ExoskeletonCalibration, ExoskeletonJointCalibration
|
|||||||
from .exo_ik import ExoskeletonIKHelper
|
from .exo_ik import ExoskeletonIKHelper
|
||||||
from .exo_serial import ExoskeletonArm
|
from .exo_serial import ExoskeletonArm
|
||||||
from .unitree_g1 import UnitreeG1Teleoperator
|
from .unitree_g1 import UnitreeG1Teleoperator
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ExoskeletonArmPortConfig",
|
||||||
|
"ExoskeletonCalibration",
|
||||||
|
"ExoskeletonIKHelper",
|
||||||
|
"ExoskeletonJointCalibration",
|
||||||
|
"ExoskeletonArm",
|
||||||
|
"UnitreeG1Teleoperator",
|
||||||
|
"UnitreeG1TeleoperatorConfig",
|
||||||
|
]
|
||||||
|
|||||||
@@ -35,6 +35,9 @@ import serial
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
ADC_MAX = 2**12 - 1
|
||||||
|
ADC_HALF = ADC_MAX / 2
|
||||||
|
|
||||||
# exoskeleton joint names -> ADC channel pairs. TODO: add wrist pitch and wrist yaw
|
# exoskeleton joint names -> ADC channel pairs. TODO: add wrist pitch and wrist yaw
|
||||||
JOINTS = {
|
JOINTS = {
|
||||||
"shoulder_pitch": (0, 1),
|
"shoulder_pitch": (0, 1),
|
||||||
@@ -59,7 +62,7 @@ class ExoskeletonCalibration:
|
|||||||
|
|
||||||
version: int = 2
|
version: int = 2
|
||||||
side: str = ""
|
side: str = ""
|
||||||
adc_max: int = 2**12 - 1
|
adc_max: int = ADC_MAX
|
||||||
joints: list[ExoskeletonJointCalibration] = field(default_factory=list)
|
joints: list[ExoskeletonJointCalibration] = field(default_factory=list)
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
@@ -92,7 +95,7 @@ class ExoskeletonCalibration:
|
|||||||
return cls(
|
return cls(
|
||||||
version=data.get("version", 2),
|
version=data.get("version", 2),
|
||||||
side=data.get("side", ""),
|
side=data.get("side", ""),
|
||||||
adc_max=data.get("adc_max", 2**12 - 1),
|
adc_max=data.get("adc_max", ADC_MAX),
|
||||||
joints=joints,
|
joints=joints,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -112,11 +115,8 @@ class CalibParams:
|
|||||||
|
|
||||||
|
|
||||||
def normalize_angle(angle: float) -> float:
|
def normalize_angle(angle: float) -> float:
|
||||||
while angle > np.pi:
|
"""Normalize angle to [-pi, pi]."""
|
||||||
angle -= 2 * np.pi
|
return float(np.arctan2(np.sin(angle), np.cos(angle)))
|
||||||
while angle < -np.pi:
|
|
||||||
angle += 2 * np.pi
|
|
||||||
return angle
|
|
||||||
|
|
||||||
|
|
||||||
def joint_z_and_angle(raw16: list[int], j: ExoskeletonJointCalibration) -> tuple[np.ndarray, float]:
|
def joint_z_and_angle(raw16: list[int], j: ExoskeletonJointCalibration) -> tuple[np.ndarray, float]:
|
||||||
@@ -125,7 +125,7 @@ def joint_z_and_angle(raw16: list[int], j: ExoskeletonJointCalibration) -> tuple
|
|||||||
"""
|
"""
|
||||||
pair = JOINTS[j.name]
|
pair = JOINTS[j.name]
|
||||||
s, c = raw16[pair[0]], raw16[pair[1]] # get sin and cos
|
s, c = raw16[pair[0]], raw16[pair[1]] # get sin and cos
|
||||||
p = np.array([float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2]) # center the raw values
|
p = np.array([float(c) - ADC_HALF, float(s) - ADC_HALF]) # center the raw values
|
||||||
z = np.asarray(j.T) @ (
|
z = np.asarray(j.T) @ (
|
||||||
p - np.asarray(j.center_fit)
|
p - np.asarray(j.center_fit)
|
||||||
) # center the ellipse and invert the transformation matrix to get unit circle coords
|
) # center the ellipse and invert the transformation matrix to get unit circle coords
|
||||||
@@ -167,7 +167,7 @@ def run_exo_calibration(
|
|||||||
|
|
||||||
def read_joint_point(raw16: list[int], pair: tuple[int, int]):
|
def read_joint_point(raw16: list[int], pair: tuple[int, int]):
|
||||||
s, c = raw16[pair[0]], raw16[pair[1]]
|
s, c = raw16[pair[0]], raw16[pair[1]]
|
||||||
return float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2, float(s), float(c)
|
return float(c) - ADC_HALF, float(s) - ADC_HALF, float(s), float(c)
|
||||||
|
|
||||||
def select_fit_subset(xs, ys):
|
def select_fit_subset(xs, ys):
|
||||||
"""Select and filter points for ellipse fitting. Trims outliers by radius and downsamples."""
|
"""Select and filter points for ellipse fitting. Trims outliers by radius and downsamples."""
|
||||||
@@ -317,7 +317,7 @@ def run_exo_calibration(
|
|||||||
calib = ExoskeletonCalibration(
|
calib = ExoskeletonCalibration(
|
||||||
version=2,
|
version=2,
|
||||||
side=side,
|
side=side,
|
||||||
adc_max=2**12 - 1,
|
adc_max=ADC_MAX,
|
||||||
joints=[
|
joints=[
|
||||||
ExoskeletonJointCalibration(
|
ExoskeletonJointCalibration(
|
||||||
name=j["name"],
|
name=j["name"],
|
||||||
@@ -367,8 +367,8 @@ def run_exo_calibration(
|
|||||||
state["win_s"].append(s_raw)
|
state["win_s"].append(s_raw)
|
||||||
state["win_c"].append(c_raw)
|
state["win_c"].append(c_raw)
|
||||||
if len(state["win_s"]) >= max(3, params.median_window):
|
if len(state["win_s"]) >= max(3, params.median_window):
|
||||||
state["ys"].append(running_median(state["win_s"]) - (2**12 - 1) / 2)
|
state["ys"].append(running_median(state["win_s"]) - ADC_HALF)
|
||||||
state["xs"].append(running_median(state["win_c"]) - (2**12 - 1) / 2)
|
state["xs"].append(running_median(state["win_c"]) - ADC_HALF)
|
||||||
else:
|
else:
|
||||||
jdata = joints_out[-1]
|
jdata = joints_out[-1]
|
||||||
z = np.array(jdata["T"]) @ (np.array([x_raw, y_raw]) - np.array(jdata["center_fit"]))
|
z = np.array(jdata["T"]) @ (np.array([x_raw, y_raw]) - np.array(jdata["center_fit"]))
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.robots.unitree_g1.g1_kinematics import G1_29_ArmIK
|
||||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex
|
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex
|
||||||
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
|
|
||||||
|
|
||||||
from .exo_calib import JOINTS
|
from .exo_calib import JOINTS
|
||||||
|
|
||||||
|
|||||||
@@ -32,25 +32,29 @@ def parse_raw16(line: bytes) -> list[int] | None:
|
|||||||
if len(parts) < 16:
|
if len(parts) < 16:
|
||||||
return None
|
return None
|
||||||
return [int(x) for x in parts[:16]]
|
return [int(x) for x in parts[:16]]
|
||||||
except Exception:
|
except (ValueError, IndexError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def read_raw_from_serial(ser) -> list[int] | None:
|
def read_raw_from_serial(ser) -> list[int] | None:
|
||||||
"""Read latest sample from serial; if buffer is backed up, keep only the newest."""
|
"""Read latest sample from serial; if buffer is backed up, keep only the newest."""
|
||||||
last = None
|
try:
|
||||||
while ser.in_waiting > 0:
|
last = None
|
||||||
b = ser.readline()
|
while ser.in_waiting > 0:
|
||||||
if not b:
|
b = ser.readline()
|
||||||
break
|
if not b:
|
||||||
raw16 = parse_raw16(b)
|
break
|
||||||
if raw16 is not None:
|
raw16 = parse_raw16(b)
|
||||||
last = raw16
|
if raw16 is not None:
|
||||||
if last is None:
|
last = raw16
|
||||||
b = ser.readline()
|
if last is None:
|
||||||
if b:
|
b = ser.readline()
|
||||||
last = parse_raw16(b)
|
if b:
|
||||||
return last
|
last = parse_raw16(b)
|
||||||
|
return last
|
||||||
|
except serial.SerialException as e:
|
||||||
|
logger.warning(f"Serial read error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -115,5 +119,6 @@ class ExoskeletonArm:
|
|||||||
return {} if raw is None else exo_raw_to_angles(raw, self.calibration)
|
return {} if raw is None else exo_raw_to_angles(raw, self.calibration)
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
def calibrate(self) -> None:
|
||||||
ser = self._ser
|
if not self.is_connected:
|
||||||
self.calibration = run_exo_calibration(ser, self.side, self.calibration_fpath)
|
raise RuntimeError("Cannot calibrate: exoskeleton not connected")
|
||||||
|
self.calibration = run_exo_calibration(self._ser, self.side, self.calibration_fpath)
|
||||||
|
|||||||
@@ -17,9 +17,22 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
from lerobot.robots.unitree_g1.g1_utils import REMOTE_AXES, G1_29_JointArmIndex
|
||||||
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS
|
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS
|
||||||
|
from lerobot.utils.import_utils import _unitree_sdk_available
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _unitree_sdk_available:
|
||||||
|
from unitree_sdk2py.utils.joystick import Joystick
|
||||||
|
else:
|
||||||
|
|
||||||
|
class Joystick:
|
||||||
|
def __init__(self):
|
||||||
|
raise ImportError(
|
||||||
|
"unitree_sdk2py is required for RemoteController. Install with: pip install unitree_sdk2py"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
from ..teleoperator import Teleoperator
|
from ..teleoperator import Teleoperator
|
||||||
from .config_unitree_g1 import UnitreeG1TeleoperatorConfig
|
from .config_unitree_g1 import UnitreeG1TeleoperatorConfig
|
||||||
@@ -29,6 +42,120 @@ from .exo_serial import ExoskeletonArm
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteController:
|
||||||
|
"""Unitree remote controller data parser for joystick and button state."""
|
||||||
|
|
||||||
|
# ADC parameters for exoskeleton joystick (12-bit ADC)
|
||||||
|
ADC_MAX = 4095
|
||||||
|
ADC_HALF = ADC_MAX / 2
|
||||||
|
JOYSTICK_X_IDX = 11 # X axis in raw ADC array
|
||||||
|
JOYSTICK_BTN_IDX = 12 # Button in raw ADC array
|
||||||
|
JOYSTICK_Y_IDX = 13 # Y axis in raw ADC array
|
||||||
|
|
||||||
|
# Map SDK named buttons to positional indices matching the wireless_remote
|
||||||
|
# byte layout (little-endian uint16 from bytes 2-3).
|
||||||
|
_BUTTON_MAP: list[str] = [
|
||||||
|
"RB",
|
||||||
|
"LB",
|
||||||
|
"start",
|
||||||
|
"back",
|
||||||
|
"RT",
|
||||||
|
"LT",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"A",
|
||||||
|
"B",
|
||||||
|
"X",
|
||||||
|
"Y",
|
||||||
|
"up",
|
||||||
|
"right",
|
||||||
|
"down",
|
||||||
|
"left",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.lx = 0.0
|
||||||
|
self.ly = 0.0
|
||||||
|
self.rx = 0.0
|
||||||
|
self.ry = 0.0
|
||||||
|
self.button = [0] * 16
|
||||||
|
self.remote_action = dict.fromkeys(REMOTE_AXES, 0.0)
|
||||||
|
|
||||||
|
# SDK joystick parser for wireless remote bytes
|
||||||
|
self._joystick = Joystick()
|
||||||
|
# Disable axis smoothing and deadzone to preserve raw values
|
||||||
|
for axis in (self._joystick.lx, self._joystick.ly, self._joystick.rx, self._joystick.ry):
|
||||||
|
axis.smooth = 1.0
|
||||||
|
axis.deadzone = 0.0
|
||||||
|
|
||||||
|
# Joystick center calibration (read at connect time)
|
||||||
|
self.left_center_x = self.ADC_HALF
|
||||||
|
self.left_center_y = self.ADC_HALF
|
||||||
|
self.right_center_x = self.ADC_HALF
|
||||||
|
self.right_center_y = self.ADC_HALF
|
||||||
|
|
||||||
|
# Whether to use exo joystick (detected at connect time)
|
||||||
|
self.use_left_exo_joystick = False
|
||||||
|
self.use_right_exo_joystick = False
|
||||||
|
|
||||||
|
def _sync_remote_action(self) -> None:
|
||||||
|
self.remote_action.update(zip(REMOTE_AXES, (self.lx, self.ly, self.rx, self.ry), strict=True))
|
||||||
|
|
||||||
|
def calibrate_center(self, raw16: list[int] | None, side: str) -> None:
|
||||||
|
if raw16 is None or len(raw16) < 16:
|
||||||
|
logger.info(f"{side.capitalize()} exo joystick: no data available")
|
||||||
|
return
|
||||||
|
|
||||||
|
btn_val = raw16[self.JOYSTICK_BTN_IDX]
|
||||||
|
logger.info(f"{side.capitalize()} exo joystick button ADC: {btn_val} (threshold: {self.ADC_HALF})")
|
||||||
|
if btn_val <= self.ADC_HALF:
|
||||||
|
logger.info(f"{side.capitalize()} exo joystick not detected (button below threshold)")
|
||||||
|
return
|
||||||
|
|
||||||
|
x = raw16[self.JOYSTICK_X_IDX]
|
||||||
|
y = raw16[self.JOYSTICK_Y_IDX]
|
||||||
|
if side == "left":
|
||||||
|
self.use_left_exo_joystick = True
|
||||||
|
self.left_center_x, self.left_center_y = x, y
|
||||||
|
else:
|
||||||
|
self.use_right_exo_joystick = True
|
||||||
|
self.right_center_x, self.right_center_y = x, y
|
||||||
|
logger.info(f"{side.capitalize()} exo joystick enabled, center: x={x}, y={y}")
|
||||||
|
|
||||||
|
def set_from_exo(self, raw16: list[int] | None, side: str) -> None:
|
||||||
|
if raw16 is None or len(raw16) < 16:
|
||||||
|
return
|
||||||
|
|
||||||
|
if side == "left":
|
||||||
|
if not self.use_left_exo_joystick:
|
||||||
|
return
|
||||||
|
self.lx = (raw16[self.JOYSTICK_X_IDX] - self.left_center_x) / self.ADC_HALF
|
||||||
|
self.ly = (raw16[self.JOYSTICK_Y_IDX] - self.left_center_y) / self.ADC_HALF
|
||||||
|
self.button[4] = 1 if raw16[self.JOYSTICK_BTN_IDX] < self.ADC_HALF else 0
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.use_right_exo_joystick:
|
||||||
|
return
|
||||||
|
self.rx = (raw16[self.JOYSTICK_X_IDX] - self.right_center_x) / self.ADC_HALF
|
||||||
|
self.ry = (raw16[self.JOYSTICK_Y_IDX] - self.right_center_y) / self.ADC_HALF
|
||||||
|
self.button[0] = 1 if raw16[self.JOYSTICK_BTN_IDX] < self.ADC_HALF else 0
|
||||||
|
|
||||||
|
def set_from_wireless(self, wireless_remote: bytes) -> None:
|
||||||
|
"""Parse Unitree wireless remote raw bytes into joystick + button state."""
|
||||||
|
if len(wireless_remote) < 24:
|
||||||
|
return
|
||||||
|
self._joystick.extract(wireless_remote)
|
||||||
|
|
||||||
|
self.lx = self._joystick.lx.data
|
||||||
|
self.ly = self._joystick.ly.data
|
||||||
|
self.rx = self._joystick.rx.data
|
||||||
|
self.ry = self._joystick.ry.data
|
||||||
|
|
||||||
|
for i, name in enumerate(self._BUTTON_MAP):
|
||||||
|
if name:
|
||||||
|
self.button[i] = getattr(self._joystick, name).data
|
||||||
|
|
||||||
|
|
||||||
class UnitreeG1Teleoperator(Teleoperator):
|
class UnitreeG1Teleoperator(Teleoperator):
|
||||||
"""
|
"""
|
||||||
Bimanual exoskeleton arms teleoperator for Unitree G1 arms.
|
Bimanual exoskeleton arms teleoperator for Unitree G1 arms.
|
||||||
@@ -43,6 +170,13 @@ class UnitreeG1Teleoperator(Teleoperator):
|
|||||||
def __init__(self, config: UnitreeG1TeleoperatorConfig):
|
def __init__(self, config: UnitreeG1TeleoperatorConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
left_exo_enabled = bool(config.left_arm_config.port.strip())
|
||||||
|
right_exo_enabled = bool(config.right_arm_config.port.strip())
|
||||||
|
if left_exo_enabled != right_exo_enabled:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid exo config: set both left/right exo ports, or leave both empty for remote-only mode."
|
||||||
|
)
|
||||||
|
self._arm_control_enabled = left_exo_enabled and right_exo_enabled
|
||||||
|
|
||||||
# Setup calibration directory
|
# Setup calibration directory
|
||||||
self.calibration_dir = (
|
self.calibration_dir = (
|
||||||
@@ -70,24 +204,37 @@ class UnitreeG1Teleoperator(Teleoperator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.ik_helper: ExoskeletonIKHelper | None = None
|
self.ik_helper: ExoskeletonIKHelper | None = None
|
||||||
|
self.remote_controller = RemoteController()
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
return {f"{name}.q": float for name in self._g1_joint_names}
|
remote_features = dict.fromkeys(self.remote_controller.remote_action, float)
|
||||||
|
if not self._arm_control_enabled:
|
||||||
|
return remote_features
|
||||||
|
joint_features = {f"{name}.q": float for name in self._g1_arm_joint_names}
|
||||||
|
return {**joint_features, **remote_features}
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def feedback_features(self) -> dict[str, type]:
|
def feedback_features(self) -> dict[str, type]:
|
||||||
return {}
|
return {"wireless_remote": bytes}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
|
if not self._arm_control_enabled:
|
||||||
|
return True
|
||||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_calibrated(self) -> bool:
|
def is_calibrated(self) -> bool:
|
||||||
|
if not self._arm_control_enabled:
|
||||||
|
return True
|
||||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||||
|
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
|
if not self._arm_control_enabled:
|
||||||
|
logger.warning("Exo ports not fully configured; teleop will send joystick only (no arm actions)")
|
||||||
|
return
|
||||||
|
|
||||||
self.left_arm.connect(calibrate)
|
self.left_arm.connect(calibrate)
|
||||||
self.right_arm.connect(calibrate)
|
self.right_arm.connect(calibrate)
|
||||||
|
|
||||||
@@ -95,6 +242,13 @@ class UnitreeG1Teleoperator(Teleoperator):
|
|||||||
self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints)
|
self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints)
|
||||||
logger.info("IK helper initialized")
|
logger.info("IK helper initialized")
|
||||||
|
|
||||||
|
time.sleep(0.1) # Give serial time to populate buffer
|
||||||
|
|
||||||
|
left_raw = self.left_arm.read_raw()
|
||||||
|
right_raw = self.right_arm.read_raw()
|
||||||
|
self.remote_controller.calibrate_center(left_raw, "left")
|
||||||
|
self.remote_controller.calibrate_center(right_raw, "right")
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
def calibrate(self) -> None:
|
||||||
if not self.left_arm.is_calibrated:
|
if not self.left_arm.is_calibrated:
|
||||||
logger.info("Starting calibration for left arm...")
|
logger.info("Starting calibration for left arm...")
|
||||||
@@ -115,12 +269,33 @@ class UnitreeG1Teleoperator(Teleoperator):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def get_action(self) -> dict[str, float]:
|
def get_action(self) -> dict[str, float]:
|
||||||
left_angles = self.left_arm.get_angles()
|
joint_action = {}
|
||||||
right_angles = self.right_arm.get_angles()
|
left_raw = None
|
||||||
return self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles)
|
right_raw = None
|
||||||
|
if self._arm_control_enabled:
|
||||||
|
left_raw = self.left_arm.read_raw()
|
||||||
|
right_raw = self.right_arm.read_raw()
|
||||||
|
|
||||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
left_angles = self.left_arm.get_angles()
|
||||||
raise NotImplementedError("Exoskeleton arms do not support feedback")
|
right_angles = self.right_arm.get_angles()
|
||||||
|
joint_action = self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles)
|
||||||
|
|
||||||
|
# Wireless remote has priority when non-zero; otherwise, use exo joystick.
|
||||||
|
rc = self.remote_controller
|
||||||
|
wireless_active = (
|
||||||
|
abs(rc.lx) > 1e-3 or abs(rc.ly) > 1e-3 or abs(rc.rx) > 1e-3 or abs(rc.ry) > 1e-3
|
||||||
|
) or any(rc.button)
|
||||||
|
if self._arm_control_enabled and not wireless_active:
|
||||||
|
rc.set_from_exo(left_raw, "left")
|
||||||
|
rc.set_from_exo(right_raw, "right")
|
||||||
|
|
||||||
|
rc._sync_remote_action()
|
||||||
|
return {**joint_action, **rc.remote_action}
|
||||||
|
|
||||||
|
def send_feedback(self, feedback: dict[str, Any]) -> None:
|
||||||
|
wireless_remote = feedback.get("wireless_remote")
|
||||||
|
if wireless_remote is not None:
|
||||||
|
self.remote_controller.set_from_wireless(wireless_remote)
|
||||||
|
|
||||||
def disconnect(self) -> None:
|
def disconnect(self) -> None:
|
||||||
self.left_arm.disconnect()
|
self.left_arm.disconnect()
|
||||||
@@ -153,5 +328,5 @@ class UnitreeG1Teleoperator(Teleoperator):
|
|||||||
print("\n\nVisualization stopped.")
|
print("\n\nVisualization stopped.")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def _g1_joint_names(self) -> list[str]:
|
def _g1_arm_joint_names(self) -> list[str]:
|
||||||
return [joint.name for joint in G1_29_JointIndex]
|
return [joint.name for joint in G1_29_JointArmIndex]
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ _peft_available = is_package_available("peft")
|
|||||||
_scipy_available = is_package_available("scipy")
|
_scipy_available = is_package_available("scipy")
|
||||||
_reachy2_sdk_available = is_package_available("reachy2_sdk")
|
_reachy2_sdk_available = is_package_available("reachy2_sdk")
|
||||||
_can_available = is_package_available("python-can", "can")
|
_can_available = is_package_available("python-can", "can")
|
||||||
|
_unitree_sdk_available = is_package_available("unitree-sdk2", "unitree_sdk2py")
|
||||||
|
_pygame_available = is_package_available("pygame")
|
||||||
|
|
||||||
|
|
||||||
def make_device_from_device_class(config: ChoiceRegistry) -> Any:
|
def make_device_from_device_class(config: ChoiceRegistry) -> Any:
|
||||||
|
|||||||
@@ -231,3 +231,39 @@ def test_ready_to_send_observation_with_varying_threshold(robot_client, g_thresh
|
|||||||
robot_client.action_queue.put(act)
|
robot_client.action_queue.put(act)
|
||||||
|
|
||||||
assert robot_client._ready_to_send_observation() is expected
|
assert robot_client._ready_to_send_observation() is expected
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Regression test: robot type registry populated by robot_client imports
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_robot_client_registers_builtin_robot_types():
|
||||||
|
"""Importing robot_client must populate RobotConfig's ChoiceRegistry.
|
||||||
|
|
||||||
|
This is a regression test for a bug introduced in #2425, where removing
|
||||||
|
robot module imports from robot_client.py caused RobotConfig's registry to
|
||||||
|
be empty, breaking CLI argument parsing with:
|
||||||
|
error: argument --robot.type: invalid choice: 'so101_follower' (choose from )
|
||||||
|
|
||||||
|
Robot types are registered via @RobotConfig.register_subclass() decorators
|
||||||
|
at import time, so all supported modules must be explicitly imported.
|
||||||
|
"""
|
||||||
|
import lerobot.async_inference.robot_client # noqa: F401
|
||||||
|
from lerobot.robots.config import RobotConfig
|
||||||
|
|
||||||
|
known_choices = RobotConfig.get_known_choices()
|
||||||
|
|
||||||
|
expected_robot_types = [
|
||||||
|
"so100_follower",
|
||||||
|
"so101_follower",
|
||||||
|
"koch_follower",
|
||||||
|
"omx_follower",
|
||||||
|
"bi_so_follower",
|
||||||
|
]
|
||||||
|
for robot_type in expected_robot_types:
|
||||||
|
assert robot_type in known_choices, (
|
||||||
|
f"Robot type '{robot_type}' is not registered in RobotConfig's ChoiceRegistry. "
|
||||||
|
f"Ensure the corresponding module is imported in robot_client.py. "
|
||||||
|
f"Known choices: {sorted(known_choices)}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -170,6 +170,7 @@ def test_async_read(index_or_path):
|
|||||||
assert isinstance(img, np.ndarray)
|
assert isinstance(img, np.ndarray)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("Skipping test: async_read 0 timeout behavior may be flaky/non-deterministic.")
|
||||||
def test_async_read_timeout():
|
def test_async_read_timeout():
|
||||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
|
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
|
||||||
|
|
||||||
|
|||||||
267
tests/robots/test_unitree_g1.py
Normal file
267
tests/robots/test_unitree_g1.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Tests for Unitree G1 robot. Meant to be run in an environment where the Unitree SDK is installed."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.utils.import_utils import _unitree_sdk_available
|
||||||
|
|
||||||
|
if not _unitree_sdk_available:
|
||||||
|
pytest.skip("Unitree SDK not available", allow_module_level=True)
|
||||||
|
|
||||||
|
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||||
|
from lerobot.robots.unitree_g1.g1_utils import (
|
||||||
|
NUM_MOTORS,
|
||||||
|
REMOTE_AXES,
|
||||||
|
REMOTE_BUTTONS,
|
||||||
|
REMOTE_KEYS,
|
||||||
|
G1_29_JointArmIndex,
|
||||||
|
G1_29_JointIndex,
|
||||||
|
default_remote_input,
|
||||||
|
get_gravity_orientation,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit tests for g1_utils (no SDK needed)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestG1Utils:
|
||||||
|
def test_num_motors(self):
|
||||||
|
assert NUM_MOTORS == 29
|
||||||
|
|
||||||
|
def test_joint_index_count(self):
|
||||||
|
assert len(G1_29_JointIndex) == 29
|
||||||
|
|
||||||
|
def test_joint_arm_index_count(self):
|
||||||
|
assert len(G1_29_JointArmIndex) == 14
|
||||||
|
|
||||||
|
def test_arm_indices_are_subset_of_full(self):
|
||||||
|
full_values = {j.value for j in G1_29_JointIndex}
|
||||||
|
arm_values = {j.value for j in G1_29_JointArmIndex}
|
||||||
|
assert arm_values.issubset(full_values)
|
||||||
|
|
||||||
|
def test_arm_indices_start_at_15(self):
|
||||||
|
assert min(j.value for j in G1_29_JointArmIndex) == 15
|
||||||
|
assert max(j.value for j in G1_29_JointArmIndex) == 28
|
||||||
|
|
||||||
|
def test_enum_naming_consistency(self):
|
||||||
|
"""Verify all wrist joints use consistent PascalCase naming."""
|
||||||
|
wrist_joints = [j for j in G1_29_JointIndex if "Wrist" in j.name]
|
||||||
|
for j in wrist_joints:
|
||||||
|
# Should be "WristYaw", "WristPitch", "WristRoll" — no lowercase after "Wrist"
|
||||||
|
after_wrist = j.name.split("Wrist")[1]
|
||||||
|
assert after_wrist[0].isupper(), f"{j.name} has inconsistent casing after 'Wrist'"
|
||||||
|
|
||||||
|
def test_remote_keys_structure(self):
|
||||||
|
assert len(REMOTE_AXES) == 4
|
||||||
|
assert len(REMOTE_BUTTONS) == 16
|
||||||
|
assert len(REMOTE_KEYS) == 20
|
||||||
|
assert REMOTE_KEYS == REMOTE_AXES + REMOTE_BUTTONS
|
||||||
|
|
||||||
|
def test_default_remote_input(self):
|
||||||
|
d = default_remote_input()
|
||||||
|
assert len(d) == 20
|
||||||
|
assert all(v == 0.0 for v in d.values())
|
||||||
|
assert set(d.keys()) == set(REMOTE_KEYS)
|
||||||
|
|
||||||
|
def test_gravity_orientation_identity(self):
|
||||||
|
"""Quaternion [1, 0, 0, 0] (no rotation) should give gravity along -z."""
|
||||||
|
g = get_gravity_orientation([1.0, 0.0, 0.0, 0.0])
|
||||||
|
assert g.shape == (3,)
|
||||||
|
assert g.dtype == np.float32
|
||||||
|
np.testing.assert_allclose(g, [0.0, 0.0, -1.0], atol=1e-6)
|
||||||
|
|
||||||
|
def test_gravity_orientation_dtype(self):
|
||||||
|
g = get_gravity_orientation(np.array([1.0, 0.0, 0.0, 0.0]))
|
||||||
|
assert g.dtype == np.float32
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit tests for UnitreeG1Config (no SDK needed)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnitreeG1Config:
|
||||||
|
def test_default_config(self):
|
||||||
|
cfg = UnitreeG1Config()
|
||||||
|
assert len(cfg.kp) == 29
|
||||||
|
assert len(cfg.kd) == 29
|
||||||
|
assert len(cfg.default_positions) == 29
|
||||||
|
assert cfg.is_simulation is True
|
||||||
|
assert cfg.controller is None
|
||||||
|
assert cfg.gravity_compensation is False
|
||||||
|
|
||||||
|
def test_gains_are_positive(self):
|
||||||
|
cfg = UnitreeG1Config()
|
||||||
|
assert all(v > 0 for v in cfg.kp)
|
||||||
|
assert all(v > 0 for v in cfg.kd)
|
||||||
|
|
||||||
|
def test_config_copies_gains(self):
|
||||||
|
"""Each config instance should have its own copy of gains."""
|
||||||
|
cfg1 = UnitreeG1Config()
|
||||||
|
cfg2 = UnitreeG1Config()
|
||||||
|
cfg1.kp[0] = 999.0
|
||||||
|
assert cfg2.kp[0] != 999.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Robot mock and integration tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_lowstate_msg_mock():
|
||||||
|
"""Create a mock that mimics the SDK LowState_ message."""
|
||||||
|
msg = MagicMock()
|
||||||
|
for i in range(29):
|
||||||
|
motor = MagicMock()
|
||||||
|
motor.q = float(i) * 0.1
|
||||||
|
motor.dq = float(i) * 0.01
|
||||||
|
motor.tau_est = float(i) * 0.001
|
||||||
|
motor.temperature = 30.0 + i
|
||||||
|
msg.motor_state.__getitem__ = lambda self, idx, _motors={}: _motors.setdefault(
|
||||||
|
idx, MagicMock(q=idx * 0.1, dq=idx * 0.01, tau_est=idx * 0.001, temperature=30.0 + idx)
|
||||||
|
)
|
||||||
|
|
||||||
|
msg.imu_state.quaternion = [1.0, 0.0, 0.0, 0.0]
|
||||||
|
msg.imu_state.gyroscope = [0.1, 0.2, 0.3]
|
||||||
|
msg.imu_state.accelerometer = [0.0, 0.0, 9.81]
|
||||||
|
msg.imu_state.rpy = [0.0, 0.0, 0.0]
|
||||||
|
msg.imu_state.temperature = 25.0
|
||||||
|
msg.wireless_remote = b"\x00" * 40
|
||||||
|
msg.mode_machine = 0
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sdk_mocks():
|
||||||
|
"""Create mocks for the Unitree SDK modules used by UnitreeG1."""
|
||||||
|
lowcmd_default = MagicMock()
|
||||||
|
lowcmd_default.mode_pr = 0
|
||||||
|
lowcmd_default.motor_cmd = [MagicMock() for _ in range(35)]
|
||||||
|
|
||||||
|
crc_mock = MagicMock()
|
||||||
|
crc_mock.Crc.return_value = 0
|
||||||
|
|
||||||
|
lowstate_msg = _make_lowstate_msg_mock()
|
||||||
|
|
||||||
|
subscriber_mock = MagicMock()
|
||||||
|
subscriber_mock.Read.return_value = lowstate_msg
|
||||||
|
|
||||||
|
publisher_mock = MagicMock()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"lowcmd_default": lowcmd_default,
|
||||||
|
"crc_mock": crc_mock,
|
||||||
|
"subscriber_mock": subscriber_mock,
|
||||||
|
"publisher_mock": publisher_mock,
|
||||||
|
"lowstate_msg": lowstate_msg,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def unitree_g1():
|
||||||
|
"""Create a UnitreeG1 robot with all SDK dependencies mocked."""
|
||||||
|
mocks = _make_sdk_mocks()
|
||||||
|
|
||||||
|
mock_channel_init = MagicMock()
|
||||||
|
mock_channel_pub = MagicMock(return_value=mocks["publisher_mock"])
|
||||||
|
mock_channel_sub = MagicMock(return_value=mocks["subscriber_mock"])
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"lerobot.robots.unitree_g1.unitree_g1.make_cameras_from_configs",
|
||||||
|
return_value={},
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"lerobot.robots.unitree_g1.unitree_g1.G1_29_ArmIK",
|
||||||
|
return_value=MagicMock(),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"lerobot.robots.unitree_g1.unitree_g1._SDKChannelFactoryInitialize",
|
||||||
|
mock_channel_init,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"lerobot.robots.unitree_g1.unitree_g1._SDKChannelPublisher",
|
||||||
|
mock_channel_pub,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"lerobot.robots.unitree_g1.unitree_g1._SDKChannelSubscriber",
|
||||||
|
mock_channel_sub,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"lerobot.robots.unitree_g1.unitree_g1.unitree_hg_msg_dds__LowCmd_",
|
||||||
|
MagicMock(return_value=mocks["lowcmd_default"]),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"lerobot.robots.unitree_g1.unitree_g1.hg_LowCmd",
|
||||||
|
MagicMock,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"lerobot.robots.unitree_g1.unitree_g1.hg_LowState",
|
||||||
|
MagicMock,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"lerobot.robots.unitree_g1.unitree_g1.CRC",
|
||||||
|
MagicMock(return_value=mocks["crc_mock"]),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||||
|
|
||||||
|
cfg = UnitreeG1Config(is_simulation=True, gravity_compensation=False)
|
||||||
|
robot = UnitreeG1(cfg)
|
||||||
|
yield robot, mocks
|
||||||
|
if robot.is_connected:
|
||||||
|
robot.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_state(unitree_g1):
|
||||||
|
robot, _ = unitree_g1
|
||||||
|
assert not robot.is_connected
|
||||||
|
assert robot.controller is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_observation_features(unitree_g1):
|
||||||
|
robot, _ = unitree_g1
|
||||||
|
features = robot.observation_features
|
||||||
|
# Should have .q for all 29 joints (no cameras configured)
|
||||||
|
assert len(features) == 29
|
||||||
|
for joint in G1_29_JointIndex:
|
||||||
|
assert f"{joint.name}.q" in features
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_features_no_controller(unitree_g1):
|
||||||
|
robot, _ = unitree_g1
|
||||||
|
features = robot.action_features
|
||||||
|
# Without controller: all 29 joints
|
||||||
|
assert len(features) == 29
|
||||||
|
for joint in G1_29_JointIndex:
|
||||||
|
assert f"{joint.name}.q" in features
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_observation_before_connect(unitree_g1):
|
||||||
|
robot, _ = unitree_g1
|
||||||
|
obs = robot.get_observation()
|
||||||
|
assert obs == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_disconnect_idempotent(unitree_g1):
|
||||||
|
robot, _ = unitree_g1
|
||||||
|
# Should not raise even when not connected
|
||||||
|
robot.disconnect()
|
||||||
309
tests/teleoperators/test_unitree_g1_teleoperator.py
Normal file
309
tests/teleoperators/test_unitree_g1_teleoperator.py
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Tests for Unitree G1 teleoperator. Meant to be run in an environment where the Unitree SDK is installed."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.utils.import_utils import _unitree_sdk_available
|
||||||
|
|
||||||
|
if not _unitree_sdk_available:
|
||||||
|
pytest.skip("Unitree SDK not available", allow_module_level=True)
|
||||||
|
|
||||||
|
from lerobot.robots.unitree_g1.g1_utils import REMOTE_AXES
|
||||||
|
from lerobot.teleoperators.unitree_g1.config_unitree_g1 import (
|
||||||
|
ExoskeletonArmPortConfig,
|
||||||
|
UnitreeG1TeleoperatorConfig,
|
||||||
|
)
|
||||||
|
from lerobot.teleoperators.unitree_g1.unitree_g1 import RemoteController, UnitreeG1Teleoperator
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests for RemoteController
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_joystick_mock():
|
||||||
|
"""Create a mock Joystick class matching the SDK interface."""
|
||||||
|
joystick = MagicMock()
|
||||||
|
# Axes are Axis objects with .data attribute
|
||||||
|
joystick.lx = MagicMock(data=0.0, smooth=0.03, deadzone=0.01)
|
||||||
|
joystick.ly = MagicMock(data=0.0, smooth=0.03, deadzone=0.01)
|
||||||
|
joystick.rx = MagicMock(data=0.0, smooth=0.03, deadzone=0.01)
|
||||||
|
joystick.ry = MagicMock(data=0.0, smooth=0.03, deadzone=0.01)
|
||||||
|
# Buttons are Button objects with .data attribute
|
||||||
|
for name in ["RB", "LB", "start", "back", "RT", "LT", "A", "B", "X", "Y", "up", "right", "down", "left"]:
|
||||||
|
setattr(joystick, name, MagicMock(data=0))
|
||||||
|
return joystick
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def remote_controller():
|
||||||
|
"""Create a RemoteController with a mocked Joystick."""
|
||||||
|
mock_joystick = _make_joystick_mock()
|
||||||
|
|
||||||
|
rc = RemoteController()
|
||||||
|
rc._joystick = mock_joystick
|
||||||
|
yield rc, mock_joystick
|
||||||
|
|
||||||
|
|
||||||
|
def test_remote_controller_init(remote_controller):
|
||||||
|
rc, _ = remote_controller
|
||||||
|
assert rc.lx == 0.0
|
||||||
|
assert rc.ly == 0.0
|
||||||
|
assert rc.rx == 0.0
|
||||||
|
assert rc.ry == 0.0
|
||||||
|
assert len(rc.button) == 16
|
||||||
|
assert all(b == 0 for b in rc.button)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_remote_action(remote_controller):
|
||||||
|
rc, _ = remote_controller
|
||||||
|
rc.lx = 0.5
|
||||||
|
rc.ly = -0.3
|
||||||
|
rc.rx = 0.1
|
||||||
|
rc.ry = 0.0
|
||||||
|
rc._sync_remote_action()
|
||||||
|
|
||||||
|
assert rc.remote_action["remote.lx"] == 0.5
|
||||||
|
assert rc.remote_action["remote.ly"] == -0.3
|
||||||
|
assert rc.remote_action["remote.rx"] == 0.1
|
||||||
|
assert rc.remote_action["remote.ry"] == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_from_wireless_calls_extract(remote_controller):
|
||||||
|
rc, mock_joystick = remote_controller
|
||||||
|
# Set up the mock to populate data after extract
|
||||||
|
mock_joystick.lx.data = 0.5
|
||||||
|
mock_joystick.ly.data = -0.3
|
||||||
|
mock_joystick.rx.data = 0.1
|
||||||
|
mock_joystick.ry.data = 0.0
|
||||||
|
|
||||||
|
wireless_data = b"\x00" * 40
|
||||||
|
rc.set_from_wireless(wireless_data)
|
||||||
|
|
||||||
|
mock_joystick.extract.assert_called_once_with(wireless_data)
|
||||||
|
assert rc.lx == 0.5
|
||||||
|
assert rc.ly == -0.3
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_from_wireless_short_data(remote_controller):
|
||||||
|
rc, mock_joystick = remote_controller
|
||||||
|
rc.set_from_wireless(b"\x00" * 10) # Too short
|
||||||
|
mock_joystick.extract.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_from_wireless_buttons(remote_controller):
|
||||||
|
rc, mock_joystick = remote_controller
|
||||||
|
# Simulate RB pressed
|
||||||
|
mock_joystick.RB.data = 1
|
||||||
|
mock_joystick.lx.data = 0.0
|
||||||
|
mock_joystick.ly.data = 0.0
|
||||||
|
mock_joystick.rx.data = 0.0
|
||||||
|
mock_joystick.ry.data = 0.0
|
||||||
|
|
||||||
|
rc.set_from_wireless(b"\x00" * 40)
|
||||||
|
assert rc.button[0] == 1 # RB maps to button[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_from_exo_left(remote_controller):
|
||||||
|
rc, _ = remote_controller
|
||||||
|
rc.use_left_exo_joystick = True
|
||||||
|
rc.left_center_x = 2048
|
||||||
|
rc.left_center_y = 2048
|
||||||
|
|
||||||
|
raw16 = [0] * 16
|
||||||
|
raw16[11] = 3048 # X axis: (3048 - 2048) / 2047.5 ≈ 0.488
|
||||||
|
raw16[13] = 1048 # Y axis: (1048 - 2048) / 2047.5 ≈ -0.488
|
||||||
|
raw16[12] = 0 # Button pressed (below ADC_HALF)
|
||||||
|
|
||||||
|
rc.set_from_exo(raw16, "left")
|
||||||
|
assert rc.lx == pytest.approx((3048 - 2048) / 2047.5, abs=1e-3)
|
||||||
|
assert rc.ly == pytest.approx((1048 - 2048) / 2047.5, abs=1e-3)
|
||||||
|
assert rc.button[4] == 1 # Left button maps to button[4]
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_from_exo_clears_button(remote_controller):
|
||||||
|
rc, _ = remote_controller
|
||||||
|
rc.use_left_exo_joystick = True
|
||||||
|
rc.button[4] = 1 # Pre-set
|
||||||
|
|
||||||
|
raw16 = [0] * 16
|
||||||
|
raw16[12] = 4000 # Button NOT pressed (above ADC_HALF)
|
||||||
|
|
||||||
|
rc.set_from_exo(raw16, "left")
|
||||||
|
assert rc.button[4] == 0 # Should be cleared
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_from_exo_ignored_when_not_enabled(remote_controller):
|
||||||
|
rc, _ = remote_controller
|
||||||
|
rc.use_left_exo_joystick = False
|
||||||
|
raw16 = [0] * 16
|
||||||
|
raw16[11] = 3000
|
||||||
|
|
||||||
|
rc.set_from_exo(raw16, "left")
|
||||||
|
assert rc.lx == 0.0 # Unchanged
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests for UnitreeG1TeleoperatorConfig (no SDK needed)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTeleoperatorConfig:
|
||||||
|
def test_default_config(self):
|
||||||
|
cfg = UnitreeG1TeleoperatorConfig()
|
||||||
|
assert cfg.left_arm_config.port == ""
|
||||||
|
assert cfg.right_arm_config.port == ""
|
||||||
|
assert cfg.frozen_joints == ""
|
||||||
|
|
||||||
|
def test_config_with_ports(self):
|
||||||
|
cfg = UnitreeG1TeleoperatorConfig(
|
||||||
|
left_arm_config=ExoskeletonArmPortConfig(port="/dev/ttyACM0"),
|
||||||
|
right_arm_config=ExoskeletonArmPortConfig(port="/dev/ttyACM1"),
|
||||||
|
)
|
||||||
|
assert cfg.left_arm_config.port == "/dev/ttyACM0"
|
||||||
|
assert cfg.right_arm_config.port == "/dev/ttyACM1"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests for UnitreeG1Teleoperator
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def teleop_remote_only():
|
||||||
|
"""Create a UnitreeG1Teleoperator in remote-only mode (no exo arms)."""
|
||||||
|
cfg = UnitreeG1TeleoperatorConfig() # No ports = remote-only mode
|
||||||
|
teleop = UnitreeG1Teleoperator(cfg)
|
||||||
|
yield teleop
|
||||||
|
|
||||||
|
|
||||||
|
def test_remote_only_connect(teleop_remote_only):
|
||||||
|
"""Remote-only mode should connect immediately without serial ports."""
|
||||||
|
teleop = teleop_remote_only
|
||||||
|
teleop.connect()
|
||||||
|
assert teleop.is_connected
|
||||||
|
assert not teleop._arm_control_enabled
|
||||||
|
|
||||||
|
|
||||||
|
def test_remote_only_action_features(teleop_remote_only):
|
||||||
|
teleop = teleop_remote_only
|
||||||
|
features = teleop.action_features
|
||||||
|
# Remote-only: just the 4 remote axes
|
||||||
|
assert set(features.keys()) == set(REMOTE_AXES)
|
||||||
|
|
||||||
|
|
||||||
|
def test_feedback_features(teleop_remote_only):
|
||||||
|
teleop = teleop_remote_only
|
||||||
|
features = teleop.feedback_features
|
||||||
|
assert "wireless_remote" in features
|
||||||
|
assert features["wireless_remote"] is bytes
|
||||||
|
|
||||||
|
|
||||||
|
def test_remote_only_get_action(teleop_remote_only):
|
||||||
|
teleop = teleop_remote_only
|
||||||
|
teleop.connect()
|
||||||
|
action = teleop.get_action()
|
||||||
|
assert set(action.keys()) == set(REMOTE_AXES)
|
||||||
|
assert all(isinstance(v, float) for v in action.values())
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_feedback(teleop_remote_only):
|
||||||
|
teleop = teleop_remote_only
|
||||||
|
teleop.connect()
|
||||||
|
# Should not raise
|
||||||
|
teleop.send_feedback({"wireless_remote": b"\x00" * 40})
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_feedback_missing_key(teleop_remote_only):
|
||||||
|
teleop = teleop_remote_only
|
||||||
|
teleop.connect()
|
||||||
|
# Should not raise even with missing key
|
||||||
|
teleop.send_feedback({"other_key": 42})
|
||||||
|
|
||||||
|
|
||||||
|
def test_asymmetric_exo_ports_raises():
|
||||||
|
"""Configuring only one exo port should raise ValueError."""
|
||||||
|
cfg = UnitreeG1TeleoperatorConfig(
|
||||||
|
left_arm_config=ExoskeletonArmPortConfig(port="/dev/ttyACM0"),
|
||||||
|
# right_arm_config left empty
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="set both left/right"):
|
||||||
|
UnitreeG1Teleoperator(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests for ExoskeletonArm (needs serial mock)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExoskeletonArm:
|
||||||
|
def test_parse_raw16_valid(self):
|
||||||
|
from lerobot.teleoperators.unitree_g1.exo_serial import parse_raw16
|
||||||
|
|
||||||
|
line = b"100 200 300 400 500 600 700 800 900 1000 1100 1200 1300 1400 1500 1600\n"
|
||||||
|
result = parse_raw16(line)
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) == 16
|
||||||
|
assert result[0] == 100
|
||||||
|
assert result[15] == 1600
|
||||||
|
|
||||||
|
def test_parse_raw16_too_short(self):
|
||||||
|
from lerobot.teleoperators.unitree_g1.exo_serial import parse_raw16
|
||||||
|
|
||||||
|
line = b"100 200 300\n"
|
||||||
|
assert parse_raw16(line) is None
|
||||||
|
|
||||||
|
def test_parse_raw16_garbage(self):
|
||||||
|
from lerobot.teleoperators.unitree_g1.exo_serial import parse_raw16
|
||||||
|
|
||||||
|
assert parse_raw16(b"not numbers at all\n") is None
|
||||||
|
assert parse_raw16(b"\xff\xfe\xfd\n") is None
|
||||||
|
assert parse_raw16(b"") is None
|
||||||
|
|
||||||
|
def test_calibrate_requires_connection(self):
|
||||||
|
from lerobot.teleoperators.unitree_g1.exo_serial import ExoskeletonArm
|
||||||
|
|
||||||
|
arm = ExoskeletonArm(
|
||||||
|
port="/dev/null",
|
||||||
|
calibration_fpath=MagicMock(is_file=MagicMock(return_value=False)),
|
||||||
|
side="left",
|
||||||
|
)
|
||||||
|
with pytest.raises(RuntimeError, match="not connected"):
|
||||||
|
arm.calibrate()
|
||||||
|
|
||||||
|
def test_is_connected_false_by_default(self):
|
||||||
|
from lerobot.teleoperators.unitree_g1.exo_serial import ExoskeletonArm
|
||||||
|
|
||||||
|
arm = ExoskeletonArm(
|
||||||
|
port="/dev/null",
|
||||||
|
calibration_fpath=MagicMock(is_file=MagicMock(return_value=False)),
|
||||||
|
side="left",
|
||||||
|
)
|
||||||
|
assert not arm.is_connected
|
||||||
|
assert not arm.is_calibrated
|
||||||
|
|
||||||
|
def test_read_raw_when_disconnected(self):
|
||||||
|
from lerobot.teleoperators.unitree_g1.exo_serial import ExoskeletonArm
|
||||||
|
|
||||||
|
arm = ExoskeletonArm(
|
||||||
|
port="/dev/null",
|
||||||
|
calibration_fpath=MagicMock(is_file=MagicMock(return_value=False)),
|
||||||
|
side="left",
|
||||||
|
)
|
||||||
|
assert arm.read_raw() is None
|
||||||
176
tests/test_robocasa_env.py
Normal file
176
tests/test_robocasa_env.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for RoboCasa LeRobot integration.
|
||||||
|
|
||||||
|
Requires: robocasa installed + kitchen assets downloaded.
|
||||||
|
Tests are skipped automatically if robocasa is not available.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Skip entire module if robocasa is not installed or assets are missing
|
||||||
|
robocasa = pytest.importorskip("robocasa", reason="robocasa not installed")
|
||||||
|
|
||||||
|
from lerobot.envs.robocasa import ACTION_DIM, STATE_DIM, CAM_KEY_TO_NAME, RoboCasaEnv, create_robocasa_envs
|
||||||
|
|
||||||
|
# The 5 benchmark tasks (3 short + 2 long)
|
||||||
|
BENCHMARK_TASKS = [
|
||||||
|
"PickPlaceCounterToCabinet", # short
|
||||||
|
"PrepareToast", # short
|
||||||
|
"CoffeeSetupMug", # short
|
||||||
|
"PrepareCoffee", # long
|
||||||
|
"RestockPantry", # long
|
||||||
|
]
|
||||||
|
SHORT_TASKS = BENCHMARK_TASKS[:3]
|
||||||
|
LONG_TASKS = BENCHMARK_TASKS[3:]
|
||||||
|
|
||||||
|
IMAGE_SIZE = 64 # small for fast tests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def single_env():
|
||||||
|
"""Shared env instance for lightweight tests."""
|
||||||
|
env = RoboCasaEnv(task="PickPlaceCounterToCabinet", image_size=IMAGE_SIZE)
|
||||||
|
yield env
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoboCasaEnvSpaces:
|
||||||
|
def test_action_space_is_flat_box(self, single_env):
|
||||||
|
import gymnasium as gym
|
||||||
|
|
||||||
|
assert isinstance(single_env.action_space, gym.spaces.Box)
|
||||||
|
assert single_env.action_space.shape == (ACTION_DIM,)
|
||||||
|
assert single_env.action_space.dtype == np.float32
|
||||||
|
|
||||||
|
def test_action_bounds(self, single_env):
|
||||||
|
assert np.all(single_env.action_space.low == -1.0)
|
||||||
|
assert np.all(single_env.action_space.high == 1.0)
|
||||||
|
|
||||||
|
def test_observation_space_has_pixels_and_state(self, single_env):
|
||||||
|
import gymnasium as gym
|
||||||
|
|
||||||
|
assert isinstance(single_env.observation_space, gym.spaces.Dict)
|
||||||
|
assert "pixels" in single_env.observation_space.spaces
|
||||||
|
assert "robot_state" in single_env.observation_space.spaces
|
||||||
|
|
||||||
|
def test_observation_space_cameras(self, single_env):
|
||||||
|
pixels_space = single_env.observation_space["pixels"]
|
||||||
|
expected_cams = set(CAM_KEY_TO_NAME.values())
|
||||||
|
assert set(pixels_space.spaces.keys()) == expected_cams
|
||||||
|
|
||||||
|
def test_state_dim(self, single_env):
|
||||||
|
state_space = single_env.observation_space["robot_state"]
|
||||||
|
assert state_space.shape == (STATE_DIM,)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoboCasaEnvReset:
|
||||||
|
def test_reset_returns_obs_and_info(self, single_env):
|
||||||
|
obs, info = single_env.reset()
|
||||||
|
assert isinstance(obs, dict)
|
||||||
|
assert isinstance(info, dict)
|
||||||
|
|
||||||
|
def test_reset_obs_has_pixels(self, single_env):
|
||||||
|
obs, _ = single_env.reset()
|
||||||
|
assert "pixels" in obs
|
||||||
|
for cam_name in CAM_KEY_TO_NAME.values():
|
||||||
|
assert cam_name in obs["pixels"], f"Missing camera: {cam_name}"
|
||||||
|
|
||||||
|
def test_reset_obs_image_shape(self, single_env):
|
||||||
|
obs, _ = single_env.reset()
|
||||||
|
for cam_name, img in obs["pixels"].items():
|
||||||
|
assert img.shape == (IMAGE_SIZE, IMAGE_SIZE, 3), f"Bad shape for {cam_name}: {img.shape}"
|
||||||
|
assert img.dtype == np.uint8
|
||||||
|
|
||||||
|
def test_reset_obs_state_shape(self, single_env):
|
||||||
|
obs, _ = single_env.reset()
|
||||||
|
assert obs["robot_state"].shape == (STATE_DIM,)
|
||||||
|
assert obs["robot_state"].dtype == np.float32
|
||||||
|
|
||||||
|
def test_reset_info_has_task(self, single_env):
|
||||||
|
_, info = single_env.reset()
|
||||||
|
assert "task" in info
|
||||||
|
assert info["task"] == "PickPlaceCounterToCabinet"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoboCasaEnvStep:
|
||||||
|
def test_step_10_random_actions(self, single_env):
|
||||||
|
single_env.reset()
|
||||||
|
for _ in range(10):
|
||||||
|
action = single_env.action_space.sample()
|
||||||
|
obs, reward, terminated, truncated, info = single_env.step(action)
|
||||||
|
assert obs["robot_state"].shape == (STATE_DIM,)
|
||||||
|
assert isinstance(reward, float)
|
||||||
|
assert isinstance(terminated, bool)
|
||||||
|
assert isinstance(truncated, bool)
|
||||||
|
|
||||||
|
def test_step_bad_action_raises(self, single_env):
|
||||||
|
single_env.reset()
|
||||||
|
with pytest.raises(ValueError, match="Expected 1-D action"):
|
||||||
|
single_env.step(np.zeros((2, ACTION_DIM)))
|
||||||
|
|
||||||
|
def test_step_info_has_is_success(self, single_env):
|
||||||
|
single_env.reset()
|
||||||
|
_, _, _, _, info = single_env.step(single_env.action_space.sample())
|
||||||
|
assert "is_success" in info
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoboCasaConfig:
|
||||||
|
def test_robocasa_env_config(self):
|
||||||
|
from lerobot.envs.configs import RoboCasaEnv as RoboCasaEnvConfig
|
||||||
|
from lerobot.configs.types import FeatureType
|
||||||
|
|
||||||
|
cfg = RoboCasaEnvConfig(task="PickPlaceCounterToCabinet", image_size=IMAGE_SIZE)
|
||||||
|
assert cfg.type == "robocasa"
|
||||||
|
# action feature
|
||||||
|
assert "action" in cfg.features
|
||||||
|
assert cfg.features["action"].shape == (ACTION_DIM,)
|
||||||
|
# camera features
|
||||||
|
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
|
||||||
|
assert cam in cfg.features
|
||||||
|
assert cfg.features[cam].type == FeatureType.VISUAL
|
||||||
|
assert cfg.features[cam].shape == (IMAGE_SIZE, IMAGE_SIZE, 3)
|
||||||
|
# state feature
|
||||||
|
assert "robot_state" in cfg.features
|
||||||
|
assert cfg.features["robot_state"].shape == (STATE_DIM,)
|
||||||
|
|
||||||
|
def test_make_env_config_robocasa(self):
|
||||||
|
from lerobot.envs.factory import make_env_config
|
||||||
|
cfg = make_env_config("robocasa", task="PickPlaceCounterToCabinet")
|
||||||
|
assert cfg.type == "robocasa"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoboCasaProcessorStep:
|
||||||
|
def test_processor_remaps_keys(self):
|
||||||
|
import torch
|
||||||
|
from lerobot.processor.env_processor import RoboCasaProcessorStep
|
||||||
|
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
step = RoboCasaProcessorStep()
|
||||||
|
B = 2
|
||||||
|
obs = {
|
||||||
|
f"{OBS_IMAGES}.agentview_left": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||||
|
f"{OBS_IMAGES}.agentview_right": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||||
|
f"{OBS_IMAGES}.eye_in_hand": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||||
|
f"observation.robot_state": torch.zeros(B, STATE_DIM),
|
||||||
|
}
|
||||||
|
out = step._process_observation(obs)
|
||||||
|
assert OBS_STATE in out
|
||||||
|
assert out[OBS_STATE].dtype == torch.float32
|
||||||
|
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
|
||||||
|
assert f"{OBS_IMAGES}.{cam}" in out
|
||||||
Reference in New Issue
Block a user