Compare commits

..

13 Commits

Author SHA1 Message Date
CarolinePascal
92f08933ec fix(features): allowing for sequence of shape (1,) when a names list is provided 2026-03-11 11:03:04 +01:00
CarolinePascal
09c93c4aa1 fix(revision exception): adding uncaught exception when requested dataset revision does not exist remotely 2026-03-11 11:03:04 +01:00
Silvio Traversaro
19c6adef85 chore(dependencies): Increase opencv-python-headless upper bound (#3120)
Signed-off-by: Silvio Traversaro <silvio@traversaro.it>
2026-03-09 23:27:18 +01:00
Johnson Sun
96b7f3dae0 Parse HF_USER with NO_COLOR to avoid incorrectly capturing bash ANSI codes (#3119) 2026-03-09 18:47:58 +01:00
Martino Russi
885ef91892 fix(unitree_g1): correct SDK detection and update installation docs (#3115)
* update docs

* update toml / docs

* update docs

* fix joystick

* Update pyproject.toml

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>

* update toml and docs

* update docs

* clarify robot

* update docs

* update docs

* update pinocchio deps

* final touches

* Update docs/source/unitree_g1.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>

* move envhub dependencies to docs

* point to unitree_sdk docs

* upper bound on onnx

* chore(docs): small details unitree docs

* chore(deps): add version pin and unitree_sdk hint

---------

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2026-03-09 18:47:12 +01:00
Steven Palma
b0efa73520 chore(dependencies): Bump lerobot to 0.5.1 (#3118) 2026-03-09 12:43:32 +01:00
Steven Palma
00b662de02 chore(dependencies): Bump lerobot to 0.5.0 (#3117) 2026-03-09 11:34:52 +01:00
Steven Palma
5c51a74484 chore(deps): update requirements file (#3114) 2026-03-09 11:18:05 +01:00
Steven Palma
db8547e35d test(cameras): skip flaky async_read test (#3106) 2026-03-08 14:02:33 +01:00
Steven Palma
c17d949531 chore(readme): update citation with ICLR26 paper (#3107)
* peer reviewed citation 🎉

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

* add iclr year

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

* fix quentin's spelling name

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

* docs(readme): update citation

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
Co-authored-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
2026-03-08 14:01:43 +01:00
Steven Palma
1e131f93f8 chore(docs): add uv installation instructions (#3105)
* chore(docs): add uv installation instructions

* fix(docs): format tabs

* chore(docs): small details

* chore(docs): last details uv installation instructions

* chore(docs): last detail

---

Co-authored-by: sahilmaniyar888 <156301258+sahilmaniyar888@users.noreply.github.com>
2026-03-08 13:00:06 +01:00
Ignat Georgiev
2fb5c7add0 feat(train): add cudnn_deterministic option for reproducible training (#3102)
Add a `cudnn_deterministic` flag to `TrainPipelineConfig` (default: False)
that sets `torch.backends.cudnn.deterministic = True` and disables benchmark
mode, eliminating CUDA floating-point non-determinism at the cost of ~10-20%
training speed. When False (default) the existing benchmark=True behaviour
is preserved.
2026-03-08 12:29:33 +01:00
Martino Russi
4f2ef024d8 feat(robots): Unitree G1 WBC implementation (#2876)
* move locomotion from examples to robot, move controller to teleoperator class

* modify teleoperate to send back actions to robot

* whole body controller

* add holosoma to locomotros

* various updates

* update joint zeroing etc

* ensure safefail with locomotion

* add unitree locomotion

* launch camera from g1 server

* publish at varying framerates

* fix async read in camera

* attempting to fix camera lag

* test camera speedup

* training

* inference works

* remove logging from pi0

* remove logging

* push local changes

* testing

* final changes

* revert control_utils

* revert utils

* revert

* revert g1

* revert again:

* revert utils

* push recents

* remove examples

* remove junk

* remove mjlog

* revergt edit_dataset

* Update lerobot_edit_dataset.py

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>

* undo teleop changes

* revert logging

* remove loggings

* remove loogs

* revert dataset tools

* Update dataset_tools.py

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>

* move gravity to utils

* revert changes

* remove matplotlib viewer (rerun works fine)

* factory revert

* send policy action directly

* recent changes

* implement flexible action space

* send empty command if arms are missing

* rename locomotion to controller

* add init

* implement feedback

* add feedback for teleoperator

* fix ruff

* fix ruff

* use read_latest

* fix zmq camera

* revert exo_serial

* simplify PR

* revert exo_changes

* revert camera_zmq

* Update camera_zmq.py

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>

* remove frame duplication from zmq server

* revert channerfactoryinitialize

* keep channelfactoryinitialize

* remove zeroing out logic

* fix typo

* refactor teleop class

* simplify teleop further

* import armindex at the top

* fix visualizer again

* revert ik helper

* push stuff

* simplify image_server

* update image_server

* asd

* add threading logic

* simplify ik helper stuff

* simplify holosoma

* fix names

* fix docs

* revert leg override

* clean connect

* fix controller

* fix ruff

* clean teleoperator

* set_from_wireless

* avoid double initializations

* refactor robot class

* fix pre-commit

* update docs

* update docs format

* add teleop instructions

* unitree_g1 specific exception in record/teleoperate

* add thumbnail to docs

* add thumbnail to doc

* refactor(unitree): multiple improvements (#3103)

* refactor(unitree): multiple improvements

* test(unitree): added tests + improved installation instructions

* refactor(robots): minor changes unitree robot kinematic

* chore(robots): rename g1 kinematics file

---------

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2026-03-08 11:33:24 +01:00
46 changed files with 2050 additions and 3391 deletions

View File

@@ -100,11 +100,11 @@ lerobot-train \
--dataset.repo_id=lerobot/aloha_mobile_cabinet
```
| Category | Models |
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
| Category | Models |
| -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
@@ -135,7 +135,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
## Citation
If you use LeRobot in your research, please cite:
If you use LeRobot in your project, please cite the GitHub repository to acknowledge the ongoing development and contributors:
```bibtex
@misc{cadene2024lerobot,
@@ -146,6 +146,23 @@ If you use LeRobot in your research, please cite:
}
```
If you are referencing our research or the academic paper, please also cite our ICLR publication:
<details>
<summary><b>ICLR 2026 Paper</b></summary>
```bibtex
@inproceedings{cadenelerobot,
title={LeRobot: An Open-Source Library for End-to-End Robot Learning},
author={Cadene, Remi and Alibert, Simon and Capuano, Francesco and Aractingi, Michel and Zouitine, Adil and Kooijmans, Pepijn and Choghari, Jade and Russi, Martino and Pascal, Caroline and Palma, Steven and Shukor, Mustafa and Moss, Jess and Soare, Alexander and Aubakirova, Dana and Lhoest, Quentin and Gallou\'edec, Quentin and Wolf, Thomas},
booktitle={The Fourteenth International Conference on Learning Representations},
year={2026},
url={https://arxiv.org/abs/2602.22818}
}
```
</details>
## Contribute
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!

View File

@@ -47,8 +47,6 @@
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
- local: multi_task_dit
title: Multitask DiT Policy
- local: walloss
title: WALL-OSS
title: "Policies"

View File

@@ -165,7 +165,7 @@ hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
HF_USER=$(NO_COLOR=1 hf auth whoami | awk -F': *' 'NR==1 {print $2}')
echo $HF_USER
```

View File

@@ -1,8 +1,8 @@
# Installation
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
This guide uses `conda` (via miniforge) to manage environments (recommended). If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and `ffmpeg` installed with the `libsvtav1` encoder, then skip ahead to [Environment Setup](#step-2-environment-setup).
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
## Step 1 (`conda` only): Install [`miniforge`](https://conda-forge.org/download/)
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
@@ -11,22 +11,47 @@ bash Miniforge3-$(uname)-$(uname -m).sh
## Step 2: Environment Setup
Create a virtual environment with Python 3.12, using conda:
Create a virtual environment with Python 3.12:
<!-- prettier-ignore-start -->
<hfoptions id="create_venv">
<hfoption id="conda">
```bash
conda create -y -n lerobot python=3.12
```
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
</hfoption>
<hfoption id="uv">
```bash
conda activate lerobot
uv python install 3.12
uv venv --python 3.12
```
</hfoption>
</hfoptions>
<!-- prettier-ignore-end -->
Then activate your virtual environment, you have to do this each time you open a shell to use lerobot:
<!-- prettier-ignore-start -->
<hfoptions id="activate_venv">
<hfoption id="conda">```bash
conda activate lerobot
```</hfoption>
<hfoption id="uv">
```bash
# Linux/macOSsource
source .venv/bin/activate
# Windows PowerShell
source .venv\Scripts\Activate.ps1
```
</hfoption>
</hfoptions>
<!-- prettier-ignore-end -->
When using `conda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
ffmpeg -version # ffmpeg 8.X is not yet supported !
```
> [!TIP]
@@ -47,6 +72,9 @@ conda install ffmpeg -c conda-forge
> conda install evdev -c conda-forge
> ```
> [!IMPORTANT]
> If you are using `uv` you will have to install `ffmpeg` system-wide (outside of the virtual environment). You rely on `uv` and `torchcodec` ability to dynamically link to the system `ffmpeg`.
## Step 3: Install LeRobot 🤗
### From Source
@@ -60,23 +88,45 @@ cd lerobot
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
<!-- prettier-ignore-start -->
<hfoptions id="install_lerobot_src">
<hfoption id="conda">
```bash
pip install -e .
```
</hfoption>
<hfoption id="uv">
```bash
uv pip install -e .
```
</hfoption>
</hfoptions>
<!-- prettier-ignore-end -->
### Installation from PyPI
**Core Library:**
Install the base package with:
<!-- prettier-ignore-start -->
<hfoptions id="install_lerobot_pypi">
<hfoption id="conda">
```bash
pip install lerobot
```
</hfoption>
<hfoption id="uv">
```bash
uv pip install lerobot
```
</hfoption>
</hfoptions>
<!-- prettier-ignore-end -->
_This installs only the default dependencies._
**Extra Features:**
To install additional functionality, use one of the following:
To install additional functionality, use one of the following (If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.):
```bash
pip install 'lerobot[all]' # All available features
@@ -93,7 +143,7 @@ https://pypi.org/project/lerobot/
### Troubleshooting
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
To install these for linux run:
To install these for Linux run:
```bash
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
@@ -103,7 +153,7 @@ For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/
## Optional dependencies
LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`.
LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`. If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.
### Simulations

View File

@@ -1,340 +0,0 @@
# Multitask DiT Policy
Multitask Diffusion Transformer (DiT) Policy is an evolution of the original Diffusion Policy architecture, which leverages a large DiT with text and vision conditioning for multitask robot learning. This implementation supports both diffusion and flow matching objectives for action generation, enabling robots to perform diverse manipulation tasks conditioned on language instructions.
## Model Overview
The model uses:
- **CLIP Vision Encoder**: Processes RGB images from multiple camera views
- **CLIP Text Encoder**: Encodes language task instructions (frozen weights with learnable projection)
- **Diffusion Transformer**: Predicts action sequences conditioned on observations and language
- **Two Objectives**: Supports both diffusion (DDPM/DDIM) and flow matching for action generation
This model is exciting because you can achieve extremely high dexterity, competitive with multi-billion parameter
VLAs, with only ~450M parameters and significantly less training.
## Installation Requirements
Multitask DiT Policy has additional dependencies. Install it with:
```bash
pip install lerobot[multi_task_dit]
```
This will install all necessary dependencies including the HuggingFace Transformers library for CLIP models.
## Usage
To use Multitask DiT in your LeRobot configuration, specify the policy type as:
```python
policy.type=multi_task_dit
```
## Training
### Basic Training Command
Here's a complete training command for training Multitask DiT on your dataset:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/multitask_dit_training \
--batch_size=32 \
--steps=5000 \
--save_freq=500 \
--log_freq=100 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```
### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency)
For reliable performance, start with these suggested default hyperparameters:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/mutitask_dit_training \
--batch_size=320 \
--steps=30000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_train_timesteps=100 \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```
**Key Parameters:**
- **Batch Size**: 192-320 - If you have access to a GPU that can support this, you will get the best training dynamics
- **Horizon**: 32 - number of action steps to predict, ~1.0 sec at 30Hz
- **n_action_steps**: 24 - ~0.8 seconds at 30Hz
- **Objective**: `diffusion` - start with diffusion and experiment with flow matching if generation quality is poor
- **Training Steps**: >30k steps recommended for a single task
### Training Configuration Parameters
#### Objective Selection
Choose between diffusion and flow matching:
```bash
# Diffusion objective (default)
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \ # or "DDIM"
--policy.num_train_timesteps=100 \
--policy.num_inference_steps=10 \ # For faster inference
--policy.beta_schedule=squaredcos_cap_v2 \ # Noise schedule type
--policy.prediction_type=epsilon \ # "epsilon" (predict noise) or "sample" (predict clean)
--policy.clip_sample=true \ # Clip samples during denoising
--policy.clip_sample_range=1.0 # Clipping range [-x, x]
# Flow matching objective
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \ # or "uniform" | the beta sampling strategy performance appears much better in practice
--policy.num_integration_steps=100 \
--policy.integration_method=euler \ # or "rk4"
--policy.sigma_min=0.0 # Minimum noise in flow interpolation path
```
#### Transformer Architecture
Adjust model capacity based on dataset size:
```bash
# Small datasets (< 100 examples)
--policy.num_layers=4 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
# Medium datasets (100-5k examples) - default
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
# Large datasets (> 5k examples)
--policy.num_layers=8 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
```
**Positional Encoding Options:**
The model supports two positional encoding methods for action sequences:
```bash
# Rotary Position Embedding (RoPE) - default, recommended
--policy.use_rope=true \
--policy.rope_base=10000.0 # Base frequency for RoPE
# Absolute positional encoding
--policy.use_positional_encoding=true # Disables RoPE when true
```
**Other Transformer Parameters:**
```bash
--policy.dropout=0.1 # Dropout rate for DiT blocks (0.0-1.0)
--policy.timestep_embed_dim=256 # Timestep embedding dimension
```
#### Vision Encoder Configuration
```bash
# Use different CLIP model for more expressivity at the cost of inference time
# experiment with larger or smaller models depending on the complexity of your tasks and size of dataset
--policy.vision_encoder_name=openai/clip-vit-large-patch14
# Use separate vision encoder per camera
# This may be useful when cameras have significantly different characteristics, but
# be wary of increased VRAM footprint.
--policy.use_separate_rgb_encoder_per_camera=true
# Image preprocessing
--policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed ups
--policy.image_crop_shape=[224,224] \
--policy.image_crop_is_random=true # Random during training, center at inference
```
#### Text Encoder Configuration
```bash
# Use different CLIP text encoder model
# same as vision: experiment with larger or smaller models depending on the
# complexity of your tasks and size of dataset
--policy.text_encoder_name=openai/clip-vit-large-patch14
```
#### Learning Rate Configuration
The vision encoder uses a separate learning rate multiplier, where 1/10th is suggested to be the ideal staritng point:
```bash
--policy.optimizer_lr=2e-5 \
--policy.vision_encoder_lr_multiplier=0.1 # Vision encoder LR = 0.1 * optimizer_lr
```
### Training Tuning Guidelines
#### 1. Flow Matching with Beta Sampling
The original diffusion implementation here is based on the work described in [TRI's LBM paper](https://arxiv.org/abs/2507.05331)
Additionally, we have implemented a flow-matching objective, which is described at a high-level in [Boston Dynamics blog post](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/).
Consider testing the flow-matching objective and evaluating performance differences for your task:
```bash
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \
--policy.timestep_sampling_alpha=1.5 \
--policy.timestep_sampling_beta=1.0 \
--policy.timestep_sampling_s=0.999
```
This hasn't been shown to be a silver bullet across every user case, but it occasionally results in smoother and more consistent actions.
#### 2. Number of Transformer Layers
Match model capacity to your dataset size:
- **Small datasets** (< 100 examples): Reduce to 4 layers
- **Large datasets** (> 5k examples): Increase to 8 layers
#### 3. `horizon` Tuning
The model can be sensitive to the horizon you choose. Start with around a 1 second horizon based on your control frequency:
- **30 Hz frequency**: `horizon=30`
- **10 Hz frequency**: `horizon=10`
Then experiment with increasing from there. The horizon determines how far into the future the model predicts actions.
#### 4. `n_action_steps` Sensitivity
The model can also be very sensitive to `n_action_steps`. Start with it being around 0.8 seconds based on your control frequency and tune from there:
- **Lower values**: More reactive but potentially less stable for long-horizon tasks
- **Higher values**: Better for long-horizon execution but open-loop failures are limited in their recovery
### Inference Tuning
For faster inference, use DDIM with fewer sampling steps:
```bash
--policy.noise_scheduler_type=DDIM \
--policy.num_inference_steps=10
```
### Resuming Training
To resume training from a checkpoint:
```bash
lerobot-train \
--config_path=./outputs/mutitask_dit_training/checkpoints/last/pretrained_model/train_config.json \
--resume=true
```
The checkpoint directory should contain `model.safetensors` and `config.json` files (saved automatically during training). When resuming, the configuration is loaded from the checkpoint, so you don't need to specify other parameters.
## Common Failure Modes and Debugging
Training these models can be finicky. Here are common failure modes and debugging approaches:
### Idling / No Motion
The model may "collapse" during inference, resulting in static or no motion. This can occur when:
1. **Insufficient training data**: If you only have 20-50 examples, try to roughly double your dataset size. Once you have above 300 examples, if you're still seeing this, the task may be too complex.
2. **Multiple similar tasks**: When your dataset contains multiple similar tasks (e.g., picking up 2 different objects), the model may rely too heavily on language conditioning which might not be rich enough.
**Debugging tips:**
- Increase dataset size (double until you get to over 300 examples)
- Train for longer, up to 100k steps, even when the loss flatlines
- Check if the model is receiving proper language instructions or increase diversity of instruction
### Executing the Wrong Task
Sometimes the robot will completely ignore your instruction and perform some other task. This generally only happens if you have trained on multiple tasks.
**Potential causes:**
- Language instruction ambiguity
- Insufficient task-specific training data
- Model confusion between similar tasks in the multitask dataset
**Debugging tips:**
- Verify language instruction specificity, especially if descriptions are similar between multiple tasks
- Check task distribution in your training dataset and add weighting to the failing/ignored task
- Consider task-specific fine-tuning
### Training Instability
If training loss is unstable or diverging:
- Try adjusting learning rate between `1e-5` and `3e-4`
- Increase batch size if possible
- Check that your dataset normalization is correct
- Verify image preprocessing is working correctly
## Performance Considerations
### GPU Requirements
- **Inference**: At least an RTX 5070 Ti (or equivalent GPU) is recommended for reasonable speed performance
- **Training**: A GPU with enough VRAM to load batch sizes of >64 is ideal, which will vary depending on the number of image observations, etc
### Batch Size Recommendations
- **Minimum**: 64 (less than this may result in unstable training)
- **Recommended**: 256-320 (best performance, requires larger GPU)
## Example: Training on Custom Dataset
Here's a complete example training on a custom dataset:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/mutitask_dit_training \
--batch_size=320 \
--steps=30000 \
--save_freq=1000 \
--log_freq=100 \
--eval_freq=1000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.vision_encoder_name=openai/clip-vit-base-patch16 \
--policy.image_resize_shape=[320,240] \
--policy.image_crop_shape=[224,224] \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true \
--wandb.project=multitask_dit
```
## References
For more details on the technical implementation and architecture, see:
- [A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation](https://arxiv.org/abs/2507.05331)
- [Large Behavior Models and Atlas Find New Footing](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/)
- [Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy](https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy)

View File

@@ -1,37 +0,0 @@
# Multitask DiT Policy
## Citation
If you use this work, please cite the following works:
```bibtex
@misc{jones2025multitaskditpolicy,
author = {Bryson Jones},
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
year = {2025},
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
note = {Blog post}
}
```
```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
author = {TRI LBM Team},
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
year = {2025},
eprint = {arXiv:2507.05331},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2507.05331}
}
```
```bibtex
@misc{bostondynamics2025largebehaviormodelsatlas,
author = {Boston Dynamics and TRI Research Team},
title = {Large Behavior Models and Atlas Find New Footing},
year = {2025},
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
note = {Blog post}
}
```

View File

@@ -1,23 +1,72 @@
# Unitree G1
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/unitree_thumbnail.jpg"
alt="Unitree G1 locomanipulation demo"
style={{ width: "100%" }}
/>
## About
We support both 29 and 23 DOF G1 EDU version. We introduce:
- **`unitree g1` robot class, handling low level read/write from/to the humanoid**
- **ZMQ socket bridge** for remote communication and camera streaming, allowing for remote policy deployment over wlan, eth or directly on the robot
- **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma
- **Simulation mode** for testing policies without the physical robot in mujoco
The Unitree G1 humanoid is now supported in LeRobot! You can teleoperate, train locomanipulation policies, test in sim, and more. Both 29 and 23 DoF variants are supported.
---
## Connection guide
## Part 1: Getting Started
### Step 1: Configure Ethernet Interface
### Install the Unitree SDK
Set a static IP on the same subnet as the robot:
Follow the [unitree_sdk2_python installation guide](https://github.com/unitreerobotics/unitree_sdk2_python#installation). Tested with `unitree_sdk2py==1.0.1` and `cyclonedds==0.10.2`:
```bash
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python
pip install -e .
cd ..
```
### Install LeRobot
```bash
conda install ffmpeg -c conda-forge
conda install -c conda-forge "pinocchio>=3.0.0,<4.0.0"
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
```
<Tip>
For now, pinocchio must be installed from conda-forge (not pip) to include the
CasADi bindings needed for arm IK.
</Tip>
### Test the Installation (Simulation)
The simulation environment has its own dependencies. Check the Simulation environment dependencies: [Unitree G1 Mujoco EnvHub](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main).
```bash
pip install mujoco loguru msgpack msgpack-numpy
```
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--teleop.type=unitree_g1 \
--teleop.id=wbc_unitree \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30, "warmup_s": 5}}' \
--display_data=true \
--robot.controller=GrootLocomotionController
```
This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1. You can connect a gamepad to your machine before launching in order to control the robot's locomotion in sim. We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl) via `--robot.controller`.
- Press `9` to release the robot
- Press `7` / `8` to increase / decrease waist height
### Connect to the Physical Robot
The G1's Ethernet IP is fixed at `192.168.123.164`. Your machine must have a static IP on the same subnet: `192.168.123.x` where `x ≠ 164`.
```bash
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
@@ -26,47 +75,23 @@ sudo ip addr add 192.168.123.200/24 dev enp131s0
sudo ip link set enp131s0 up
```
**Note**: The G1's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` with x ≠ 164.
### Step 2: SSH into the Robot
### SSH into the Robot
```bash
ssh unitree@192.168.123.164
# Password: 123
```
You should now be connected to the G1's Orin.
### Share Internet via Ethernet
---
## Part 2: Enable WiFi on the Robot
Wlan0 is disabled by default on the G1. To enable it:
### Step 1: Enable WiFi Hardware
```bash
sudo rfkill unblock wifi
sudo rfkill unblock all
# Bring up wlan0
sudo ip link set wlan0 up
# Enable NetworkManager control of wlan0
sudo nmcli radio wifi on
sudo nmcli device set wlan0 managed yes
sudo systemctl restart NetworkManager
```
### Step 2: Enable Internet Forwarding
The G1 needs internet access to clone repos and install packages. Share your laptop's connection over Ethernet:
**On your laptop:**
```bash
# Enable IP forwarding
sudo sysctl -w net.ipv4.ip_forward=1
# Set up NAT (replace wlp132s0f0 with your WiFi interface)
# Replace wlp132s0f0 with your WiFi interface name
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
@@ -75,223 +100,193 @@ sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
**On the G1:**
```bash
# Add laptop as default gateway
sudo ip route del default 2>/dev/null || true
sudo ip route add default via 192.168.123.200 dev eth0
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
# Test connection
# Verify
ping -c 3 8.8.8.8
```
### Step 3: Connect to WiFi Network
### Install the Unitree SDK on the G1
Follow the [unitree_sdk2_python installation guide](https://github.com/unitreerobotics/unitree_sdk2_python#installation):
```bash
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python
python -m pip install -e .
cd ..
```
### Install LeRobot on the G1
```bash
git clone https://github.com/huggingface/lerobot.git
cd lerobot
conda install -c conda-forge "pinocchio>=3.0.0,<4.0.0"
python -m pip install -e '.[unitree_g1]'
```
<Tip>
For now, pinocchio must be installed from conda-forge (not pip) to include the
CasADi bindings needed for arm IK.
</Tip>
### (Optional) Enable WiFi on the Robot
For wireless SSH access, you can enable WiFi on the G1 (it's blocked by default):
```bash
sudo rfkill unblock all
sudo ip link set wlan0 up
sudo nmcli radio wifi on
sudo nmcli device set wlan0 managed yes
sudo systemctl restart NetworkManager
```
**Connect to a WiFi network:**
```bash
# List available networks
nmcli device wifi list
# Connect to your WiFi (example)
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
sudo nmcli connection up "YourNetwork"
# Check WiFi IP address
ip a show wlan0
```
### Step 4: SSH Over WiFi
Once connected to WiFi, note the robot's IP address and disconnect the Ethernet cable. You can now SSH over WiFi:
You can then SSH over WiFi instead of Ethernet:
```bash
ssh unitree@<YOUR_ROBOT_IP>
ssh unitree@<ROBOT_WIFI_IP>
# Password: 123
```
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address.
---
## Part 2: Teleoperation & Locomotion
### Run the Robot Server
On the robot (from `~/lerobot`):
```bash
cd ~/lerobot
python src/lerobot/robots/unitree_g1/run_g1_server.py --camera
```
### Run the Locomotion Policy
You can run the teleoperation client from your laptop over Ethernet, over WiFi (experimental), or directly on the robot itself. Mind potential latency introduced by your network.
**From your laptop:**
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--robot.robot_ip=<ROBOT_IP> \
--teleop.type=unitree_g1 \
--teleop.id=wbc_unitree \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "<ROBOT_IP>", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--display_data=true \
--robot.controller=HolosomaLocomotionController
```
We support both [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl) and [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) via `--robot.controller`.
---
## Part 3: Robot Server Setup
## Part 3: Loco-Manipulation with the Homunculus Exoskeleton
### Step 1: Install LeRobot on the Orin
We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Check it out [here](https://github.com/nepyope/hmc_exo).
SSH into the robot and install LeRobot:
```bash
ssh unitree@<YOUR_ROBOT_IP>
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python && pip install -e .
```
**Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
### Step 2: Run the Robot Server
On the robot:
```bash
python src/lerobot/robots/unitree_g1/run_g1_server.py
```
**Important**: Keep this terminal running. The server must be active for remote control.
---
## Part 4: Controlling the robot
With the robot server running, you can now control the robot remotely. Let's launch a locomotion policy
### Step 1: Install LeRobot on your machine
```bash
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python && pip install -e .
```
### Step 2: Update Robot IP in Config
Edit the config file to match your robot's WiFi IP:
```python
# In src/lerobot/robots/unitree_g1/config_unitree_g1.py
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
```
### Step 3: Run the Locomotion Policy
```bash
# Run GR00T locomotion controller
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
# Run Holosoma locomotion controller
python examples/unitree_g1/holosoma_locomotion.py
```
Press `Ctrl+C` to stop the policy.
---
## Running in Simulation Mode (MuJoCo)
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
### Calibrate Exoskeleton Teleoperator
### Calibrate
```bash
lerobot-calibrate \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo
```
### Teleoperate in Simulation
During calibration move each joint through its entire range. After fitting, move the joint in a neutral position and press `n` to advance.
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--fps=100
```
### Record Dataset in Simulation
### Record a Dataset
```bash
lerobot-record \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2
```
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
> **Note:** Omit `--teleop.left_arm_config.port` and `--teleop.right_arm_config.port` if you're only using the joystick.
Example dataset: [nepyope/unitree_box_move_blue_full](https://huggingface.co/datasets/nepyope/unitree_box_move_blue_full)
---
## Running on Real Robot
## Part 4: Training & Inference
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
### Start the Camera Server
On the robot, start the ZMQ image server:
### Train
```bash
python src/lerobot/cameras/zmq/image_server.py
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/dataset-name \
--policy.type=pi05 \
--output_dir=./outputs/pi05_training \
--job_name=pi05_training \
--policy.repo_id=your-username/your-repo-id \
--policy.pretrained_path=lerobot/pi05_base \
--policy.compile_model=true \
--policy.gradient_checkpointing=true \
--wandb.enable=true \
--policy.dtype=bfloat16 \
--policy.freeze_vision_encoder=false \
--policy.train_expert_only=false \
--steps=3000 \
--policy.device=cuda \
--batch_size=32
```
Keep this running in a separate terminal for camera streaming during recording.
### Inference with RTC
### Teleoperate Real Robot
Once trained, we recommend deploying policies using inference-time RTC:
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--fps=100
python examples/rtc/eval_with_real_robot.py \
--policy.path=your-username/your-repo-id \
--policy.device=cuda \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--robot.controller=HolosomaLocomotionController \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "<ROBOT_IP>", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--task="task_description" \
--duration=1000 \
--fps=30 \
--rtc.enabled=true
```
### Record Dataset on Real Robot
```bash
lerobot-record \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
```
**Note**: Update `server_address` to match your robot's camera server IP.
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
---
## Additional Resources
@@ -300,8 +295,8 @@ Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/da
- [GR00T-WholeBodyControl](https://github.com/NVlabs/GR00T-WholeBodyControl)
- [Holosoma](https://github.com/amazon-far/holosoma)
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
- [Unitree IL LeRobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
---
_Last updated: December 2025_
_Last updated: March 2026_

View File

@@ -78,6 +78,7 @@ from torch import Tensor
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
@@ -97,6 +98,7 @@ from lerobot.robots import ( # noqa: F401
bi_so_follower,
koch_follower,
so_follower,
unitree_g1,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES

View File

@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.5"
version = "0.5.1"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
dynamic = ["readme"]
license = { text = "Apache-2.0" }
@@ -76,7 +76,7 @@ dependencies = [
"torchvision>=0.21.0,<0.26.0",
"einops>=0.8.0,<0.9.0",
"opencv-python-headless>=4.9.0,<4.13.0",
"opencv-python-headless>=4.9.0,<4.14.0",
"av>=15.0.0,<16.0.0",
"jsonlines>=4.0.0,<5.0.0",
"pynput>=1.7.8,<1.9.0",
@@ -119,12 +119,13 @@ gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
unitree_g1 = [
# "unitree-sdk2==1.0.1",
"pyzmq>=26.2.1,<28.0.0",
"onnxruntime>=1.16.0,<2.0.0",
"pin>=3.0.0,<4.0.0",
"onnx>=1.16.0,<2.0.0",
"meshcat>=0.3.0,<0.4.0",
"lerobot[matplotlib-dep]",
"casadi>=3.6.0,<4.0.0",
"lerobot[pygame-dep]",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
kinematics = ["lerobot[placo-dep]"]
@@ -144,7 +145,6 @@ wallx = [
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
multi_task_dit = ["lerobot[transformers-dep]"]
groot = [
"lerobot[transformers-dep]",
"lerobot[peft]",
@@ -207,6 +207,7 @@ all = [
"lerobot[metaworld]",
"lerobot[sarm]",
"lerobot[peft]",
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
]
[project.scripts]

View File

@@ -1,76 +1,73 @@
#
# This file is autogenerated by pip-compile with Python 3.10
# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
# pip-compile --output-file=requirements-macos.txt requirements.in
#
-e .[all]
# via -[all]
absl-py==2.3.1
absl-py==2.4.0
# via
# dm-control
# dm-env
# dm-tree
# labmaze
# mujoco
# tensorboard
accelerate==1.11.0
accelerate==1.13.0
# via
# lerobot
# peft
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.13.1
aiohttp==3.13.3
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-doc==0.0.4
# via
# fastapi
# typer
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.11.0
anyio==4.12.1
# via
# httpx
# starlette
# watchfiles
asttokens==3.0.0
asttokens==3.0.1
# via stack-data
async-timeout==5.0.1
# via aiohttp
attrs==25.4.0
# via
# aiohttp
# dm-tree
# jsonlines
# jsonschema
# referencing
# rerun-sdk
av==15.1.0
# via lerobot
bddl==1.0.1
# via libero
certifi==2025.10.5
# via
# lerobot
# qwen-vl-utils
certifi==2026.2.25
# via
# httpcore
# httpx
# requests
# sentry-sdk
cffi==2.0.0
# via pymunk
cfgv==3.4.0
cfgv==3.5.0
# via pre-commit
charset-normalizer==3.4.4
charset-normalizer==3.4.5
# via requests
click==8.3.0
click==8.3.1
# via
# typer
# uvicorn
# wandb
cloudpickle==3.1.1
# via
# gymnasium
# libero
cmake==4.1.0
cloudpickle==3.1.2
# via gymnasium
cmake==4.1.3
# via lerobot
cmeel==0.57.3
cmeel==0.59.0
# via
# cmeel-assimp
# cmeel-boost
@@ -108,15 +105,17 @@ cmeel-zlib==1.3.1
# via cmeel-assimp
coal-library==3.0.1
# via pin
contourpy==1.3.2
# via matplotlib
coverage[toml]==7.11.0
contourpy==1.3.3
# via
# lerobot
# matplotlib
coverage[toml]==7.13.4
# via pytest-cov
cycler==0.12.1
# via matplotlib
datasets==4.1.1
datasets==4.6.1
# via lerobot
debugpy==1.8.17
debugpy==1.8.20
# via lerobot
decorator==5.2.1
# via ipython
@@ -130,7 +129,7 @@ dill==0.4.0
# multiprocess
distlib==0.4.0
# via virtualenv
dm-control==1.0.34
dm-control==1.0.37
# via gym-aloha
dm-env==1.6
# via dm-control
@@ -138,69 +137,55 @@ dm-tree==0.1.9
# via
# dm-control
# dm-env
# lerobot
docopt==0.6.2
# via num2words
draccus==0.10.0
# via lerobot
dynamixel-sdk==3.8.4
# via lerobot
easydict==1.13
# via libero
egl-probe @ git+https://github.com/huggingface/egl_probe.git
# via
# libero
# robomimic
eigenpy==3.10.3
# via coal-library
einops==0.8.1
# via
# lerobot
# libero
einops==0.8.2
# via lerobot
eiquadprog==1.2.9
# via placo
etils[epath,epy]==1.13.0
etils[epath,epy]==1.14.0
# via mujoco
exceptiongroup==1.3.0
# via
# anyio
# ipython
# pytest
executing==2.2.1
# via stack-data
faker==34.0.2
# via lerobot
farama-notifications==0.0.4
# via gymnasium
fastapi==0.119.1
# via teleop
fastjsonschema==2.21.2
# via nbformat
fastapi==0.135.1
# via
# lerobot
# teleop
feetech-servo-sdk==1.0.0
# via lerobot
filelock==3.20.0
filelock==3.25.0
# via
# datasets
# diffusers
# huggingface-hub
# python-discovery
# torch
# transformers
# virtualenv
fonttools==4.60.1
fonttools==4.61.1
# via matplotlib
frozenlist==1.8.0
# via
# aiohttp
# aiosignal
fsspec[http]==2025.9.0
fsspec[http]==2026.2.0
# via
# datasets
# etils
# huggingface-hub
# torch
future==1.0.0
# via libero
gitdb==4.0.12
# via gitpython
gitpython==3.1.45
gitpython==3.1.46
# via wandb
glfw==2.10.0
# via
@@ -212,7 +197,6 @@ grpcio==1.73.1
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
grpcio-tools==1.73.1
# via
# lerobot
@@ -223,71 +207,67 @@ gym-hil==0.1.13
# via lerobot
gym-pusht==0.1.6
# via lerobot
gymnasium==1.2.1
gymnasium==1.2.3
# via
# gym-aloha
# gym-hil
# gym-pusht
# lerobot
# libero
# metaworld
h11==0.16.0
# via uvicorn
h5py==3.15.1
# via robomimic
# via
# httpcore
# uvicorn
hebi-py==2.11.0
# via lerobot
hf-transfer==0.1.9
# via huggingface-hub
hf-xet==1.1.10
hf-xet==1.3.2
# via huggingface-hub
hidapi==0.14.0.post4
# via
# gym-hil
# lerobot
httpcore==1.0.9
# via httpx
httptools==0.7.1
# via uvicorn
huggingface-hub[cli,hf-transfer]==0.35.3
httpx==0.28.1
# via
# datasets
# huggingface-hub
huggingface-hub==1.6.0
# via
# accelerate
# datasets
# diffusers
# lerobot
# peft
# timm
# tokenizers
# transformers
hydra-core==1.3.2
# via libero
identify==2.6.15
identify==2.6.17
# via pre-commit
idna==3.11
# via
# anyio
# httpx
# requests
# yarl
imageio[ffmpeg]==2.37.0
imageio[ffmpeg]==2.37.2
# via
# gym-aloha
# gym-hil
# lerobot
# metaworld
# robomimic
# scikit-image
imageio-ffmpeg==0.6.0
# via
# imageio
# robomimic
importlib-metadata==8.7.0
# via imageio
importlib-metadata==8.7.1
# via diffusers
importlib-resources==6.5.2
# via etils
iniconfig==2.3.0
# via pytest
inquirerpy==0.3.4
# via huggingface-hub
ipython==8.37.0
ipython==9.11.0
# via meshcat
ipython-pygments-lexers==1.1.1
# via ipython
ischedule==1.2.7
# via placo
jedi==0.19.2
@@ -296,44 +276,24 @@ jinja2==3.1.6
# via torch
jsonlines==4.0.0
# via lerobot
jsonschema==4.25.1
# via nbformat
jsonschema-specifications==2025.9.1
# via jsonschema
jupyter-core==5.9.1
# via nbformat
jupytext==1.18.1
# via bddl
kiwisolver==1.4.9
# via matplotlib
labmaze==1.0.6
# via dm-control
lazy-loader==0.4
lazy-loader==0.5
# via scikit-image
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
# via lerobot
llvmlite==0.45.1
# via numba
librt==0.8.1
# via mypy
lxml==6.0.2
# via dm-control
markdown==3.9
# via tensorboard
markdown-it-py==4.0.0
# via
# jupytext
# mdit-py-plugins
# via rich
markupsafe==3.0.3
# via
# jinja2
# werkzeug
matplotlib==3.10.7
# via
# lerobot
# libero
# via jinja2
matplotlib==3.10.8
# via lerobot
matplotlib-inline==0.2.1
# via ipython
mdit-py-plugins==0.5.0
# via jupytext
mdurl==0.1.2
# via markdown-it-py
mergedeep==1.3.4
@@ -346,41 +306,35 @@ mock-serial==0.0.1
# via lerobot
mpmath==1.3.0
# via sympy
mujoco==3.3.7
mujoco==3.5.0
# via
# dm-control
# gym-aloha
# gym-hil
# libero
# metaworld
# robosuite
multidict==6.7.0
multidict==6.7.1
# via
# aiohttp
# yarl
multiprocess==0.70.16
multiprocess==0.70.18
# via datasets
mypy==1.19.1
# via lerobot
mypy-extensions==1.1.0
# via typing-inspect
nbformat==5.10.4
# via jupytext
networkx==3.4.2
# via
# bddl
# mypy
# typing-inspect
networkx==3.6.1
# via
# scikit-image
# torch
ninja==1.13.0
# via lerobot
nodeenv==1.9.1
nodeenv==1.10.0
# via pre-commit
num2words==0.5.14
# via lerobot
numba==0.62.1
# via robosuite
numpy==2.2.6
# via
# accelerate
# bddl
# cmeel-boost
# contourpy
# datasets
@@ -389,16 +343,14 @@ numpy==2.2.6
# dm-env
# dm-tree
# gymnasium
# h5py
# hebi-py
# imageio
# labmaze
# libero
# lerobot
# matplotlib
# meshcat
# metaworld
# mujoco
# numba
# opencv-python
# opencv-python-headless
# pandas
@@ -406,26 +358,18 @@ numpy==2.2.6
# pyquaternion
# reachy2-sdk
# rerun-sdk
# robomimic
# robosuite
# scikit-image
# scipy
# shapely
# teleop
# tensorboard
# tensorboardx
# tifffile
# torchvision
# transformers
# transforms3d
omegaconf==2.3.0
# via hydra-core
opencv-python==4.12.0.88
opencv-python==4.13.0.92
# via
# gym-pusht
# libero
# reachy2-sdk
# robosuite
opencv-python-headless==4.12.0.88
# via lerobot
orderly-set==5.5.0
@@ -435,97 +379,87 @@ packaging==25.0
# accelerate
# datasets
# huggingface-hub
# hydra-core
# jupytext
# lazy-loader
# lerobot
# matplotlib
# peft
# pytest
# qwen-vl-utils
# reachy2-sdk
# scikit-image
# tensorboard
# tensorboardx
# transformers
# wandb
pandas==2.3.3
# via
# datasets
# lerobot
parso==0.8.5
parso==0.8.6
# via jedi
peft==0.17.1
pathspec==1.0.4
# via mypy
peft==0.18.1
# via lerobot
pexpect==4.9.0
# via ipython
pfzy==0.3.4
# via inquirerpy
pillow==12.0.0
pillow==12.1.1
# via
# diffusers
# imageio
# lerobot
# matplotlib
# meshcat
# qwen-vl-utils
# rerun-sdk
# robosuite
# scikit-image
# tensorboard
# torchvision
pin==3.4.0
# via placo
placo==0.9.14
placo==0.9.16
# via lerobot
platformdirs==4.5.0
platformdirs==4.9.4
# via
# jupyter-core
# python-discovery
# virtualenv
# wandb
pluggy==1.6.0
# via
# pytest
# pytest-cov
pre-commit==4.3.0
pre-commit==4.5.1
# via lerobot
prompt-toolkit==3.0.52
# via
# inquirerpy
# ipython
# via ipython
propcache==0.4.1
# via
# aiohttp
# yarl
protobuf==6.31.0
protobuf==6.31.1
# via
# dm-control
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
# tensorboardx
# wandb
psutil==7.1.1
psutil==7.2.2
# via
# accelerate
# imageio
# peft
# robomimic
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
# via stack-data
pyarrow==21.0.0
pyarrow==23.0.1
# via
# datasets
# rerun-sdk
pycparser==2.23
pycparser==3.0
# via cffi
pydantic==2.12.3
pydantic==2.12.5
# via
# fastapi
# wandb
pydantic-core==2.41.4
pydantic-core==2.41.5
# via pydantic
pygame==2.6.1
# via
@@ -535,33 +469,35 @@ pygame==2.6.1
pygments==2.19.2
# via
# ipython
# ipython-pygments-lexers
# pytest
# rich
pymunk==6.11.1
# via
# gym-pusht
# lerobot
pyngrok==7.4.1
pyngrok==7.5.1
# via meshcat
pynput==1.8.1
# via
# gym-hil
# lerobot
pyobjc-core==12.0
pyobjc-core==12.1
# via
# pyobjc-framework-applicationservices
# pyobjc-framework-cocoa
# pyobjc-framework-coretext
# pyobjc-framework-quartz
pyobjc-framework-applicationservices==12.0
pyobjc-framework-applicationservices==12.1
# via pynput
pyobjc-framework-cocoa==12.0
pyobjc-framework-cocoa==12.1
# via
# pyobjc-framework-applicationservices
# pyobjc-framework-coretext
# pyobjc-framework-quartz
pyobjc-framework-coretext==12.0
pyobjc-framework-coretext==12.1
# via pyobjc-framework-applicationservices
pyobjc-framework-quartz==12.0
pyobjc-framework-quartz==12.1
# via
# pynput
# pyobjc-framework-applicationservices
@@ -570,13 +506,13 @@ pyopengl==3.1.10
# via
# dm-control
# mujoco
pyparsing==3.2.5
pyparsing==3.3.2
# via
# dm-control
# matplotlib
pyquaternion==0.9.9
# via reachy2-sdk
pyrealsense2-macosx==2.54.2
pyrealsense2-macosx==2.56.5
# via lerobot
pyserial==3.5
# via
@@ -585,7 +521,6 @@ pyserial==3.5
# lerobot
pytest==8.4.2
# via
# bddl
# lerobot
# pytest-cov
# pytest-timeout
@@ -596,11 +531,14 @@ pytest-timeout==2.4.0
# via lerobot
python-dateutil==2.9.0.post0
# via
# faker
# matplotlib
# pandas
python-dotenv==1.1.1
python-discovery==1.1.1
# via virtualenv
python-dotenv==1.2.2
# via uvicorn
pytz==2025.2
pytz==2026.1.post1
# via pandas
pyyaml==6.0.3
# via
@@ -609,13 +547,10 @@ pyyaml==6.0.3
# draccus
# hebi-py
# huggingface-hub
# jupytext
# omegaconf
# peft
# pre-commit
# pyngrok
# pyyaml-include
# timm
# transformers
# uvicorn
# wandb
@@ -625,15 +560,13 @@ pyzmq==27.1.0
# via
# lerobot
# meshcat
reachy2-sdk==1.0.14
qwen-vl-utils==0.0.14
# via lerobot
reachy2-sdk==1.0.15
# via lerobot
reachy2-sdk-api==1.0.21
# via reachy2-sdk
referencing==0.37.0
# via
# jsonschema
# jsonschema-specifications
regex==2025.10.23
regex==2026.2.28
# via
# diffusers
# transformers
@@ -642,184 +575,150 @@ requests==2.32.5
# datasets
# diffusers
# dm-control
# huggingface-hub
# qwen-vl-utils
# teleop
# transformers
# wandb
rerun-sdk==0.26.1
rerun-sdk==0.26.2
# via lerobot
rhoban-cmeel-jsoncpp==1.9.4.9
# via placo
robomimic==0.2.0
# via libero
robosuite==1.4.0
# via libero
rpds-py==0.28.0
# via
# jsonschema
# referencing
safetensors==0.6.2
rich==14.3.3
# via typer
safetensors==0.7.0
# via
# accelerate
# diffusers
# lerobot
# peft
# timm
# transformers
scikit-image==0.25.2
# via
# gym-pusht
# lerobot
scipy==1.15.3
scipy==1.17.1
# via
# dm-control
# lerobot
# metaworld
# robosuite
# scikit-image
sentry-sdk==2.42.1
# torchdiffeq
sentry-sdk==2.54.0
# via wandb
shapely==2.1.2
# via gym-pusht
shellingham==1.5.4
# via typer
six==1.17.0
# via
# pynput
# python-dateutil
smmap==5.0.2
smmap==5.0.3
# via gitdb
sniffio==1.3.1
# via anyio
stack-data==0.6.3
# via ipython
starlette==0.48.0
starlette==0.52.1
# via fastapi
sympy==1.14.0
# via torch
teleop==0.1.2
teleop==0.1.4
# via lerobot
tensorboard==2.20.0
# via robomimic
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4
# via robomimic
termcolor==3.1.0
# via
# lerobot
# robomimic
thop==0.1.1.post2209072238
# via libero
tifffile==2025.5.10
termcolor==3.3.0
# via lerobot
tifffile==2026.3.3
# via scikit-image
timm==1.0.20
# via lerobot
tokenizers==0.22.1
tokenizers==0.22.2
# via transformers
toml==0.10.2
# via draccus
tomli==2.3.0
# via
# cmeel
# coverage
# jupytext
# pytest
torch==2.7.1
torch==2.10.0
# via
# accelerate
# lerobot
# peft
# robomimic
# thop
# timm
# torchdiffeq
# torchvision
torchcodec==0.5
torchcodec==0.10.0
# via lerobot
torchvision==0.22.1
# via
# lerobot
# robomimic
# timm
tornado==6.5.2
torchdiffeq==0.2.5
# via lerobot
torchvision==0.25.0
# via lerobot
tornado==6.5.4
# via meshcat
tqdm==4.67.1
tqdm==4.67.3
# via
# datasets
# dm-control
# huggingface-hub
# peft
# robomimic
# transformers
traitlets==5.14.3
# via
# ipython
# jupyter-core
# matplotlib-inline
# nbformat
transformers==4.57.1
transformers==5.3.0
# via
# lerobot
# libero
# peft
transforms3d==0.4.2
# via teleop
typer==0.24.1
# via
# huggingface-hub
# transformers
typing-extensions==4.15.0
# via
# aiosignal
# anyio
# etils
# exceptiongroup
# faker
# fastapi
# gymnasium
# huggingface-hub
# ipython
# multidict
# mypy
# pydantic
# pydantic-core
# referencing
# rerun-sdk
# starlette
# torch
# typing-inspect
# typing-inspection
# uvicorn
# virtualenv
# wandb
typing-inspect==0.9.0
# via draccus
typing-inspection==0.4.2
# via pydantic
tzdata==2025.2
# via
# fastapi
# pydantic
tzdata==2025.3
# via pandas
u-msgpack-python==2.8.0
# via meshcat
urllib3==2.5.0
urllib3==2.6.3
# via
# requests
# sentry-sdk
uvicorn[standard]==0.38.0
uvicorn[standard]==0.41.0
# via teleop
uvloop==0.22.1
# via uvicorn
virtualenv==20.35.3
virtualenv==21.1.0
# via pre-commit
wandb==0.21.4
# via
# lerobot
# libero
wandb==0.24.2
# via lerobot
watchfiles==1.1.1
# via uvicorn
wcwidth==0.2.14
wcwidth==0.6.0
# via prompt-toolkit
websocket-client==1.9.0
# via teleop
websockets==15.0.1
websockets==16.0
# via uvicorn
werkzeug==3.1.3
# via tensorboard
wrapt==2.0.0
wrapt==2.1.2
# via dm-tree
xxhash==3.6.0
# via datasets
yarl==1.22.0
yarl==1.23.0
# via aiohttp
zipp==3.23.0
# via

View File

@@ -1,12 +1,12 @@
#
# This file is autogenerated by pip-compile with Python 3.10
# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
# pip-compile --output-file=requirements-ubuntu.txt requirements.in
#
-e .[all]
# via -[all]
absl-py==2.3.1
absl-py==2.4.0
# via
# dm-control
# dm-env
@@ -14,30 +14,33 @@ absl-py==2.3.1
# labmaze
# mujoco
# tensorboard
accelerate==1.11.0
accelerate==1.13.0
# via
# lerobot
# peft
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.13.1
aiohttp==3.13.3
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-doc==0.0.4
# via
# fastapi
# typer
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.11.0
anyio==4.12.1
# via
# httpx
# starlette
# watchfiles
asttokens==3.0.0
asttokens==3.0.1
# via stack-data
async-timeout==5.0.1
# via aiohttp
attrs==25.4.0
# via
# aiohttp
@@ -47,30 +50,35 @@ attrs==25.4.0
# referencing
# rerun-sdk
av==15.1.0
# via lerobot
bddl==1.0.1
# via libero
certifi==2025.10.5
# via
# lerobot
# qwen-vl-utils
bddl==1.0.1
# via hf-libero
certifi==2026.2.25
# via
# httpcore
# httpx
# requests
# sentry-sdk
cffi==2.0.0
# via pymunk
cfgv==3.4.0
cfgv==3.5.0
# via pre-commit
charset-normalizer==3.4.4
charset-normalizer==3.4.5
# via requests
click==8.3.0
click==8.3.1
# via
# typer
# uvicorn
# wandb
cloudpickle==3.1.1
cloudpickle==3.1.2
# via
# gymnasium
# libero
cmake==4.1.0
# hf-libero
cmake==4.1.3
# via lerobot
cmeel==0.57.3
cmeel==0.59.0
# via
# cmeel-assimp
# cmeel-boost
@@ -108,20 +116,24 @@ cmeel-zlib==1.3.1
# via cmeel-assimp
coal-library==3.0.1
# via pin
contourpy==1.3.2
# via matplotlib
coverage[toml]==7.11.0
contourpy==1.3.3
# via
# lerobot
# matplotlib
coverage[toml]==7.13.4
# via pytest-cov
cuda-bindings==12.9.4
# via torch
cuda-pathfinder==1.4.1
# via cuda-bindings
cycler==0.12.1
# via matplotlib
datasets==4.1.1
datasets==4.6.1
# via lerobot
debugpy==1.8.17
debugpy==1.8.20
# via lerobot
decorator==5.2.1
# via ipython
decord==0.6.0
# via lerobot
deepdiff==8.6.1
# via lerobot
diffusers==0.35.2
@@ -132,7 +144,7 @@ dill==0.4.0
# multiprocess
distlib==0.4.0
# via virtualenv
dm-control==1.0.34
dm-control==1.0.37
# via gym-aloha
dm-env==1.6
# via dm-control
@@ -140,7 +152,6 @@ dm-tree==0.1.9
# via
# dm-control
# dm-env
# lerobot
docopt==0.6.2
# via num2words
draccus==0.10.0
@@ -148,66 +159,60 @@ draccus==0.10.0
dynamixel-sdk==3.8.4
# via lerobot
easydict==1.13
# via libero
egl-probe @ git+https://github.com/huggingface/egl_probe.git
# via
# libero
# robomimic
# via hf-libero
egl-probe==1.0.2
# via robomimic
eigenpy==3.10.3
# via coal-library
einops==0.8.1
einops==0.8.2
# via
# flash-attn
# hf-libero
# lerobot
# libero
eiquadprog==1.2.9
# via placo
etils[epath,epy]==1.13.0
etils[epath,epy]==1.14.0
# via mujoco
evdev==1.9.2
evdev==1.9.3
# via pynput
exceptiongroup==1.3.0
# via
# anyio
# ipython
# pytest
executing==2.2.1
# via stack-data
faker==34.0.2
# via lerobot
farama-notifications==0.0.4
# via gymnasium
fastapi==0.119.1
# via teleop
fastapi==0.135.1
# via
# lerobot
# teleop
fastjsonschema==2.21.2
# via nbformat
feetech-servo-sdk==1.0.0
# via lerobot
filelock==3.20.0
filelock==3.25.0
# via
# datasets
# diffusers
# huggingface-hub
# python-discovery
# torch
# transformers
# virtualenv
flash-attn==2.8.3
# via lerobot
fonttools==4.60.1
fonttools==4.61.1
# via matplotlib
frozenlist==1.8.0
# via
# aiohttp
# aiosignal
fsspec[http]==2025.9.0
fsspec[http]==2026.2.0
# via
# datasets
# etils
# huggingface-hub
# torch
future==1.0.0
# via libero
# via hf-libero
gitdb==4.0.12
# via gitpython
gitpython==3.1.45
gitpython==3.1.46
# via wandb
glfw==2.10.0
# via
@@ -230,50 +235,60 @@ gym-hil==0.1.13
# via lerobot
gym-pusht==0.1.6
# via lerobot
gymnasium==1.2.1
gymnasium==1.2.3
# via
# gym-aloha
# gym-hil
# gym-pusht
# hf-libero
# lerobot
# libero
# metaworld
h11==0.16.0
# via uvicorn
h5py==3.15.1
# via
# httpcore
# uvicorn
h5py==3.16.0
# via robomimic
hebi-py==2.11.0
# via lerobot
hf-transfer==0.1.9
# via huggingface-hub
hf-xet==1.1.10
hf-egl-probe==1.0.2
# via hf-libero
hf-libero==0.1.3
# via lerobot
hf-xet==1.3.2
# via huggingface-hub
hidapi==0.14.0.post4
# via
# gym-hil
# lerobot
httpcore==1.0.9
# via httpx
httptools==0.7.1
# via uvicorn
huggingface-hub[cli,hf-transfer]==0.35.3
httpx==0.28.1
# via
# datasets
# huggingface-hub
huggingface-hub==1.6.0
# via
# accelerate
# datasets
# diffusers
# lerobot
# peft
# timm
# tokenizers
# transformers
hydra-core==1.3.2
# via libero
identify==2.6.15
# via hf-libero
identify==2.6.17
# via pre-commit
idna==3.11
# via
# anyio
# httpx
# requests
# yarl
imageio[ffmpeg]==2.37.0
imageio[ffmpeg]==2.37.2
# via
# gym-aloha
# gym-hil
@@ -285,16 +300,14 @@ imageio-ffmpeg==0.6.0
# via
# imageio
# robomimic
importlib-metadata==8.7.0
importlib-metadata==8.7.1
# via diffusers
importlib-resources==6.5.2
# via etils
iniconfig==2.3.0
# via pytest
inquirerpy==0.3.4
# via huggingface-hub
ipython==8.37.0
ipython==9.11.0
# via meshcat
ipython-pygments-lexers==1.1.1
# via ipython
ischedule==1.2.7
# via placo
jedi==0.19.2
@@ -303,40 +316,41 @@ jinja2==3.1.6
# via torch
jsonlines==4.0.0
# via lerobot
jsonschema==4.25.1
jsonschema==4.26.0
# via nbformat
jsonschema-specifications==2025.9.1
# via jsonschema
jupyter-core==5.9.1
# via nbformat
jupytext==1.18.1
jupytext==1.19.1
# via bddl
kiwisolver==1.4.9
# via matplotlib
labmaze==1.0.6
# via dm-control
lazy-loader==0.4
lazy-loader==0.5
# via scikit-image
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
# via lerobot
llvmlite==0.45.1
librt==0.8.1
# via mypy
llvmlite==0.46.0
# via numba
lxml==6.0.2
# via dm-control
markdown==3.9
markdown==3.10.2
# via tensorboard
markdown-it-py==4.0.0
# via
# jupytext
# mdit-py-plugins
# rich
markupsafe==3.0.3
# via
# jinja2
# werkzeug
matplotlib==3.10.7
matplotlib==3.10.8
# via
# hf-libero
# lerobot
# libero
matplotlib-inline==0.2.1
# via ipython
mdit-py-plugins==0.5.0
@@ -353,36 +367,38 @@ mock-serial==0.0.1
# via lerobot
mpmath==1.3.0
# via sympy
mujoco==3.3.7
mujoco==3.5.0
# via
# dm-control
# gym-aloha
# gym-hil
# libero
# hf-libero
# metaworld
# robosuite
multidict==6.7.0
multidict==6.7.1
# via
# aiohttp
# yarl
multiprocess==0.70.16
multiprocess==0.70.18
# via datasets
mypy==1.19.1
# via lerobot
mypy-extensions==1.1.0
# via typing-inspect
# via
# mypy
# typing-inspect
nbformat==5.10.4
# via jupytext
networkx==3.4.2
networkx==3.6.1
# via
# bddl
# scikit-image
# torch
ninja==1.13.0
# via lerobot
nodeenv==1.9.1
nodeenv==1.10.0
# via pre-commit
num2words==0.5.14
# via lerobot
numba==0.62.1
numba==0.64.0
# via robosuite
numpy==2.2.6
# via
@@ -391,7 +407,6 @@ numpy==2.2.6
# cmeel-boost
# contourpy
# datasets
# decord
# diffusers
# dm-control
# dm-env
@@ -399,9 +414,10 @@ numpy==2.2.6
# gymnasium
# h5py
# hebi-py
# hf-libero
# imageio
# labmaze
# libero
# lerobot
# matplotlib
# meshcat
# metaworld
@@ -426,49 +442,51 @@ numpy==2.2.6
# torchvision
# transformers
# transforms3d
nvidia-cublas-cu12==12.6.4.1
nvidia-cublas-cu12==12.8.4.1
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-cupti-cu12==12.8.90
# via torch
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-nvrtc-cu12==12.8.93
# via torch
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.8.90
# via torch
nvidia-cudnn-cu12==9.5.1.17
nvidia-cudnn-cu12==9.10.2.21
# via torch
nvidia-cufft-cu12==11.3.0.4
nvidia-cufft-cu12==11.3.3.83
# via torch
nvidia-cufile-cu12==1.11.1.6
nvidia-cufile-cu12==1.13.1.3
# via torch
nvidia-curand-cu12==10.3.7.77
nvidia-curand-cu12==10.3.9.90
# via torch
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusolver-cu12==11.7.3.90
# via torch
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparse-cu12==12.5.8.93
# via
# nvidia-cusolver-cu12
# torch
nvidia-cusparselt-cu12==0.6.3
nvidia-cusparselt-cu12==0.7.1
# via torch
nvidia-nccl-cu12==2.26.2
nvidia-nccl-cu12==2.27.5
# via torch
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvjitlink-cu12==12.8.93
# via
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvtx-cu12==12.6.77
nvidia-nvshmem-cu12==3.4.5
# via torch
nvidia-nvtx-cu12==12.8.90
# via torch
omegaconf==2.3.0
# via hydra-core
opencv-python==4.12.0.88
opencv-python==4.13.0.92
# via
# gym-pusht
# libero
# hf-libero
# reachy2-sdk
# robosuite
opencv-python-headless==4.12.0.88
@@ -487,6 +505,7 @@ packaging==25.0
# matplotlib
# peft
# pytest
# qwen-vl-utils
# reachy2-sdk
# scikit-image
# tensorboard
@@ -497,21 +516,21 @@ pandas==2.3.3
# via
# datasets
# lerobot
parso==0.8.5
parso==0.8.6
# via jedi
peft==0.17.1
pathspec==1.0.4
# via mypy
peft==0.18.1
# via lerobot
pexpect==4.9.0
# via ipython
pfzy==0.3.4
# via inquirerpy
pillow==12.0.0
pillow==12.1.1
# via
# diffusers
# imageio
# lerobot
# matplotlib
# meshcat
# qwen-vl-utils
# rerun-sdk
# robosuite
# scikit-image
@@ -519,28 +538,27 @@ pillow==12.0.0
# torchvision
pin==3.4.0
# via placo
placo==0.9.14
placo==0.9.16
# via lerobot
platformdirs==4.5.0
platformdirs==4.9.4
# via
# jupyter-core
# python-discovery
# virtualenv
# wandb
pluggy==1.6.0
# via
# pytest
# pytest-cov
pre-commit==4.3.0
pre-commit==4.5.1
# via lerobot
prompt-toolkit==3.0.52
# via
# inquirerpy
# ipython
# via ipython
propcache==0.4.1
# via
# aiohttp
# yarl
protobuf==6.31.0
protobuf==6.31.1
# via
# dm-control
# grpcio-tools
@@ -550,7 +568,7 @@ protobuf==6.31.0
# tensorboard
# tensorboardx
# wandb
psutil==7.1.1
psutil==7.2.2
# via
# accelerate
# imageio
@@ -560,17 +578,17 @@ ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
# via stack-data
pyarrow==21.0.0
pyarrow==23.0.1
# via
# datasets
# rerun-sdk
pycparser==2.23
pycparser==3.0
# via cffi
pydantic==2.12.3
pydantic==2.12.5
# via
# fastapi
# wandb
pydantic-core==2.41.4
pydantic-core==2.41.5
# via pydantic
pygame==2.6.1
# via
@@ -580,12 +598,14 @@ pygame==2.6.1
pygments==2.19.2
# via
# ipython
# ipython-pygments-lexers
# pytest
# rich
pymunk==6.11.1
# via
# gym-pusht
# lerobot
pyngrok==7.4.1
pyngrok==7.5.1
# via meshcat
pynput==1.8.1
# via
@@ -595,7 +615,7 @@ pyopengl==3.1.10
# via
# dm-control
# mujoco
pyparsing==3.2.5
pyparsing==3.3.2
# via
# dm-control
# matplotlib
@@ -621,13 +641,16 @@ pytest-timeout==2.4.0
# via lerobot
python-dateutil==2.9.0.post0
# via
# faker
# matplotlib
# pandas
python-dotenv==1.1.1
python-discovery==1.1.1
# via virtualenv
python-dotenv==1.2.2
# via uvicorn
python-xlib==0.33
# via pynput
pytz==2025.2
pytz==2026.1.post1
# via pandas
pyyaml==6.0.3
# via
@@ -642,7 +665,6 @@ pyyaml==6.0.3
# pre-commit
# pyngrok
# pyyaml-include
# timm
# transformers
# uvicorn
# wandb
@@ -652,7 +674,9 @@ pyzmq==27.1.0
# via
# lerobot
# meshcat
reachy2-sdk==1.0.14
qwen-vl-utils==0.0.14
# via lerobot
reachy2-sdk==1.0.15
# via lerobot
reachy2-sdk-api==1.0.21
# via reachy2-sdk
@@ -660,7 +684,7 @@ referencing==0.37.0
# via
# jsonschema
# jsonschema-specifications
regex==2025.10.23
regex==2026.2.28
# via
# diffusers
# transformers
@@ -669,60 +693,62 @@ requests==2.32.5
# datasets
# diffusers
# dm-control
# huggingface-hub
# qwen-vl-utils
# teleop
# transformers
# wandb
rerun-sdk==0.26.1
rerun-sdk==0.26.2
# via lerobot
rhoban-cmeel-jsoncpp==1.9.4.9
# via placo
rich==14.3.3
# via typer
robomimic==0.2.0
# via libero
# via hf-libero
robosuite==1.4.0
# via libero
rpds-py==0.28.0
# via hf-libero
rpds-py==0.30.0
# via
# jsonschema
# referencing
safetensors==0.6.2
safetensors==0.7.0
# via
# accelerate
# diffusers
# lerobot
# peft
# timm
# transformers
scikit-image==0.25.2
# via
# gym-pusht
# lerobot
scipy==1.15.3
scipy==1.17.1
# via
# dm-control
# lerobot
# metaworld
# robosuite
# scikit-image
sentry-sdk==2.42.1
# torchdiffeq
sentry-sdk==2.54.0
# via wandb
shapely==2.1.2
# via gym-pusht
shellingham==1.5.4
# via typer
six==1.17.0
# via
# pynput
# python-dateutil
# python-xlib
smmap==5.0.2
smmap==5.0.3
# via gitdb
sniffio==1.3.1
# via anyio
stack-data==0.6.3
# via ipython
starlette==0.48.0
starlette==0.52.1
# via fastapi
sympy==1.14.0
# via torch
teleop==0.1.2
teleop==0.1.4
# via lerobot
tensorboard==2.20.0
# via robomimic
@@ -730,46 +756,38 @@ tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4
# via robomimic
termcolor==3.1.0
termcolor==3.3.0
# via
# lerobot
# robomimic
thop==0.1.1.post2209072238
# via libero
tifffile==2025.5.10
# via hf-libero
tifffile==2026.3.3
# via scikit-image
timm==1.0.20
# via lerobot
tokenizers==0.22.1
tokenizers==0.22.2
# via transformers
toml==0.10.2
# via draccus
tomli==2.3.0
# via
# cmeel
# coverage
# jupytext
# pytest
torch==2.7.1
torch==2.10.0
# via
# accelerate
# flash-attn
# lerobot
# peft
# robomimic
# thop
# timm
# torchdiffeq
# torchvision
torchcodec==0.5
torchcodec==0.10.0
# via lerobot
torchvision==0.22.1
torchdiffeq==0.2.5
# via lerobot
torchvision==0.25.0
# via
# lerobot
# robomimic
# timm
tornado==6.5.2
tornado==6.5.4
# via meshcat
tqdm==4.67.1
tqdm==4.67.3
# via
# datasets
# dm-control
@@ -783,26 +801,29 @@ traitlets==5.14.3
# jupyter-core
# matplotlib-inline
# nbformat
transformers==4.57.1
transformers==5.3.0
# via
# hf-libero
# lerobot
# libero
# peft
transforms3d==0.4.2
# via teleop
triton==3.3.1
triton==3.6.0
# via torch
typer==0.24.1
# via
# huggingface-hub
# transformers
typing-extensions==4.15.0
# via
# aiosignal
# anyio
# etils
# exceptiongroup
# faker
# fastapi
# gymnasium
# huggingface-hub
# ipython
# multidict
# mypy
# pydantic
# pydantic-core
# referencing
@@ -811,46 +832,46 @@ typing-extensions==4.15.0
# torch
# typing-inspect
# typing-inspection
# uvicorn
# virtualenv
# wandb
typing-inspect==0.9.0
# via draccus
typing-inspection==0.4.2
# via pydantic
tzdata==2025.2
# via
# fastapi
# pydantic
tzdata==2025.3
# via pandas
u-msgpack-python==2.8.0
# via meshcat
urllib3==2.5.0
urllib3==2.6.3
# via
# requests
# sentry-sdk
uvicorn[standard]==0.38.0
uvicorn[standard]==0.41.0
# via teleop
uvloop==0.22.1
# via uvicorn
virtualenv==20.35.3
virtualenv==21.1.0
# via pre-commit
wandb==0.21.4
wandb==0.24.2
# via
# hf-libero
# lerobot
# libero
watchfiles==1.1.1
# via uvicorn
wcwidth==0.2.14
wcwidth==0.6.0
# via prompt-toolkit
websocket-client==1.9.0
# via teleop
websockets==15.0.1
websockets==16.0
# via uvicorn
werkzeug==3.1.3
werkzeug==3.1.6
# via tensorboard
wrapt==2.0.0
wrapt==2.1.2
# via dm-tree
xxhash==3.6.0
# via datasets
yarl==1.22.0
yarl==1.23.0
# via aiohttp
zipp==3.23.0
# via

View File

@@ -1,9 +1,9 @@
# requirements.in
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.3.1 25D2128 arm64).
# Darwin MacBook-Pro.local 25.3.0 Darwin Kernel Version 25.3.0: Wed Jan 28 20:54:55 PST 2026; root:xnu-12377.91.3~2/RELEASE_ARM64_T8132 arm64
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.4 LTS x86_64).
# Linux lerobot-linux 6.17.0-14-generic #14~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Jan 15 15:52:10 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
-e .[all]

View File

@@ -181,7 +181,7 @@ class ZMQCamera(Camera):
try:
message = self.socket.recv_string()
except Exception as e:
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
# zmq is lazy-imported in connect(), so check by name to avoid a top-level import
if type(e).__name__ == "Again":
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
raise

View File

@@ -23,6 +23,7 @@ import base64
import contextlib
import json
import logging
import threading
import time
from collections import deque
@@ -42,10 +43,57 @@ def encode_image(image: np.ndarray, quality: int = 80) -> str:
return base64.b64encode(buffer).decode("utf-8")
class CameraCaptureThread:
"""Background thread that continuously captures and encodes frames from a camera."""
def __init__(self, camera: OpenCVCamera, name: str):
self.camera = camera
self.name = name
self.latest_encoded: str | None = None # Pre-encoded JPEG as base64
self.latest_timestamp: float = 0.0
self.frame_lock = threading.Lock()
self.running = False
self.thread: threading.Thread | None = None
def start(self):
"""Start the capture thread."""
self.running = True
self.thread = threading.Thread(target=self._capture_loop, daemon=True)
self.thread.start()
def stop(self):
"""Stop the capture thread."""
self.running = False
if self.thread:
self.thread.join(timeout=1.0)
def _capture_loop(self):
"""Continuously capture and encode frames at the camera's native rate."""
while self.running:
try:
frame = self.camera.read() # Blocks at camera's native rate
timestamp = time.time()
# Encode immediately in capture thread (this is the slow part)
encoded = encode_image(frame)
with self.frame_lock:
self.latest_encoded = encoded
self.latest_timestamp = timestamp
except Exception as e:
logger.warning(f"Camera {self.name} capture error: {e}")
time.sleep(0.01)
def get_latest(self) -> tuple[str | None, float]:
"""Get the latest encoded frame and its timestamp."""
with self.frame_lock:
return self.latest_encoded, self.latest_timestamp
class ImageServer:
def __init__(self, config: dict, port: int = 5555):
# fps controls the publish loop rate (how often frames are sent over ZMQ), not the camera capture rate
self.fps = config.get("fps", 30)
self.cameras: dict[str, OpenCVCamera] = {}
self.capture_threads: dict[str, CameraCaptureThread] = {}
for name, cfg in config.get("cameras", {}).items():
shape = cfg.get("shape", [480, 640])
@@ -61,6 +109,10 @@ class ImageServer:
self.cameras[name] = camera
logger.info(f"Camera {name}: {shape[1]}x{shape[0]}")
# Create capture thread for this camera
capture_thread = CameraCaptureThread(camera, name)
self.capture_threads[name] = capture_thread
# ZMQ PUB socket
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUB)
@@ -73,6 +125,18 @@ class ImageServer:
def run(self):
frame_count = 0
frame_times = deque(maxlen=60)
last_published_ts: dict[str, float] = {}
# Start all capture threads
for capture_thread in self.capture_threads.values():
capture_thread.start()
# Wait for first frames to be captured and encoded
logger.info("Waiting for cameras to start capturing...")
for name, capture_thread in self.capture_threads.items():
while capture_thread.get_latest()[0] is None:
time.sleep(0.01)
logger.info(f"Camera {name} ready (capture + encode in background)")
try:
while True:
@@ -80,10 +144,12 @@ class ImageServer:
# Build message
message = {"timestamps": {}, "images": {}}
for name, cam in self.cameras.items():
frame = cam.read() # Returns RGB
message["timestamps"][name] = time.time()
message["images"][name] = encode_image(frame)
for name, capture_thread in self.capture_threads.items():
encoded, timestamp = capture_thread.get_latest()
if encoded is not None and timestamp > last_published_ts.get(name, 0.0):
message["timestamps"][name] = timestamp
message["images"][name] = encoded
last_published_ts[name] = timestamp
# Send as JSON string (suppress if buffer full)
with contextlib.suppress(zmq.Again):
@@ -102,6 +168,8 @@ class ImageServer:
except KeyboardInterrupt:
pass
finally:
for capture_thread in self.capture_threads.values():
capture_thread.stop()
for cam in self.cameras.values():
cam.disconnect()
self.socket.close()

View File

@@ -50,6 +50,9 @@ class TrainPipelineConfig(HubMixin):
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: int | None = 1000
# Set to True to use deterministic cuDNN algorithms for reproducibility.
# This disables cudnn.benchmark and may reduce training speed by ~10-20%.
cudnn_deterministic: bool = False
# Number of workers for the dataloader.
num_workers: int = 4
batch_size: int = 8

View File

@@ -578,7 +578,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
elif ft["shape"] == (1,):
elif ft["shape"] == (1,) and ft["names"] is None:
hf_features[key] = datasets.Value(dtype=ft["dtype"])
elif len(ft["shape"]) == 1:
hf_features[key] = datasets.Sequence(

View File

@@ -57,6 +57,7 @@ import pyarrow as pa
import tqdm
from datasets import Dataset, Features, Image
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from requests import HTTPError
from lerobot.datasets.compute_stats import aggregate_stats
@@ -511,7 +512,7 @@ def convert_dataset(
hub_api = HfApi()
try:
hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
except HTTPError as e:
except(HTTPError,RevisionNotFoundError) as e:
print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
pass
hub_api.delete_files(

View File

@@ -15,7 +15,6 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
@@ -29,7 +28,6 @@ from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
__all__ = [
"ACTConfig",
"DiffusionConfig",
"MultiTaskDiTConfig",
"PI0Config",
"PI05Config",
"PI0FastConfig",

View File

@@ -31,7 +31,6 @@ from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -67,7 +66,8 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -86,10 +86,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.act.modeling_act import ACTPolicy
return ACTPolicy
elif name == "multi_task_dit":
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
return MultiTaskDiTPolicy
elif name == "vqbet":
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
@@ -150,8 +146,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"smolvla", "reward_classifier", "wall_x".
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -166,8 +162,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return DiffusionConfig(**kwargs)
elif policy_type == "act":
return ACTConfig(**kwargs)
elif policy_type == "multi_task_dit":
return MultiTaskDiTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
@@ -314,16 +308,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, MultiTaskDiTConfig):
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
make_multi_task_dit_pre_post_processors,
)
processors = make_multi_task_dit_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, VQBeTConfig):
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors

View File

@@ -1,37 +0,0 @@
# Multitask DiT Policy
## Citation
If you use this work, please cite the following works:
```bibtex
@misc{jones2025multitaskditpolicy,
author = {Bryson Jones},
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
year = {2025},
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
note = {Blog post}
}
```
```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
author = {TRI LBM Team},
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
year = {2025},
eprint = {arXiv:2507.05331},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2507.05331}
}
```
```bibtex
@misc{bostondynamics2025largebehaviormodelsatlas,
author = {Boston Dynamics and TRI Research Team},
title = {Large Behavior Models and Atlas Find New Footing},
year = {2025},
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
note = {Blog post}
}
```

View File

@@ -1,21 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_multi_task_dit import MultiTaskDiTConfig
from .modeling_multi_task_dit import MultiTaskDiTPolicy
from .processor_multi_task_dit import make_multi_task_dit_pre_post_processors
__all__ = ["MultiTaskDiTConfig", "MultiTaskDiTPolicy", "make_multi_task_dit_pre_post_processors"]

View File

@@ -1,256 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig
@PreTrainedConfig.register_subclass("multi_task_dit")
@dataclass
class MultiTaskDiTConfig(PreTrainedConfig):
"""Configuration for the Multi-Task Diffusion Transformer (DiT) policy.
A transformer-based policy that supports both diffusion and flow matching objectives
for multi-task robot learning with text and vision conditioning.
"""
n_obs_steps: int = 2 # Number of observation steps for temporal context
horizon: int = 32 # Number of action steps to predict
n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz)
# Objective Selection
objective: str = "diffusion" # "diffusion" or "flow_matching"
# --- Diffusion-specific (used when objective="diffusion") ---
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
num_train_timesteps: int = 100 # Number of diffusion timesteps
beta_schedule: str = "squaredcos_cap_v2" # Noise schedule type
beta_start: float = 0.0001 # Starting noise level
beta_end: float = 0.02 # Ending noise level
prediction_type: str = "epsilon" # "epsilon" (predict noise) or "sample" (predict clean)
clip_sample: bool = True # Clip samples during denoising
clip_sample_range: float = 1.0 # Clipping range [-x, x]
num_inference_steps: int | None = None # Denoising steps at inference (defaults to num_train_timesteps)
# --- Flow Matching-specific (used when objective="flow_matching") ---
sigma_min: float = 0.0 # Minimum noise in flow interpolation path
num_integration_steps: int = 100 # ODE integration steps at inference
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
timestep_sampling_strategy: str = "beta" # "uniform" or "beta"
timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold
timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha
timestep_sampling_beta: float = 1.0 # (beta only) Beta distribution beta
# Transformer Architecture
hidden_dim: int = 512 # Transformer hidden dimension
num_layers: int = 6 # Number of transformer layers
num_heads: int = 8 # Number of attention heads
dropout: float = 0.1 # Dropout rate
use_positional_encoding: bool = False # Use absolute positional encoding
timestep_embed_dim: int = 256 # Timestep embedding dimension
use_rope: bool = True # Use Rotary Position Embedding
rope_base: float = 10000.0 # RoPE base frequency
# Vision Encoder (CLIP)
vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
use_separate_rgb_encoder_per_camera: bool = False # Separate encoder per camera view
vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder
image_resize_shape: tuple[int, int] | None = None # Resize images before crop
image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default)
image_crop_is_random: bool = True # Random crop during training, center at inference
# Text Encoder (CLIP)
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
tokenizer_max_length: int = 77 # Max length for tokenized text (CLIP default is 77)
tokenizer_padding: str = "max_length" # Padding strategy: "max_length" or "longest"
tokenizer_padding_side: str = "right" # Padding side: "left" or "right"
tokenizer_truncation: bool = True # Whether to truncate sequences longer than max_length
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# Training/Optimizer
optimizer_lr: float = 2e-5
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.0
scheduler_name: str = "cosine"
scheduler_warmup_steps: int = 0
do_mask_loss_for_padding: bool = False
# Auto-calculated
drop_n_last_frames: int | None = None
def __post_init__(self):
super().__post_init__()
if self.drop_n_last_frames is None:
self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1
self._validate()
def _validate(self):
"""Validate configuration parameters."""
# Objective validation
if self.objective not in ["diffusion", "flow_matching"]:
raise ValueError(f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'")
# Transformer validation
if self.hidden_dim <= 0:
raise ValueError("hidden_dim must be positive")
if self.num_layers <= 0:
raise ValueError("num_layers must be positive")
if self.num_heads <= 0:
raise ValueError("num_heads must be positive")
if self.hidden_dim % self.num_heads != 0:
raise ValueError("hidden_dim must be divisible by num_heads")
if not (0.0 <= self.dropout <= 1.0):
raise ValueError("dropout must be between 0.0 and 1.0")
# Vision encoder validation
if "clip" not in self.vision_encoder_name.lower():
raise ValueError(
f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'"
)
if (
self.image_resize_shape
and self.image_crop_shape
and (
self.image_crop_shape[0] > self.image_resize_shape[0]
or self.image_crop_shape[1] > self.image_resize_shape[1]
)
):
logging.warning(
"image_crop_shape %s must be <= image_resize_shape %s; disabling cropping.",
self.image_crop_shape,
self.image_resize_shape,
)
self.image_crop_shape = None
# Text encoder validation
if "clip" not in self.text_encoder_name.lower():
raise ValueError(
f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'"
)
# Objective-specific validation
if self.objective == "diffusion":
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
raise ValueError(
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
)
if self.prediction_type not in ["epsilon", "sample"]:
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
if self.num_train_timesteps <= 0:
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}")
elif self.objective == "flow_matching":
if not (0.0 <= self.sigma_min <= 1.0):
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
if self.num_integration_steps <= 0:
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
if self.integration_method not in ["euler", "rk4"]:
raise ValueError(
f"integration_method must be 'euler' or 'rk4', got {self.integration_method}"
)
if self.timestep_sampling_strategy not in ["uniform", "beta"]:
raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'")
if self.timestep_sampling_strategy == "beta":
if not (0.0 < self.timestep_sampling_s <= 1.0):
raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}")
if self.timestep_sampling_alpha <= 0:
raise ValueError("timestep_sampling_alpha must be positive")
if self.timestep_sampling_beta <= 0:
raise ValueError("timestep_sampling_beta must be positive")
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
return DiffuserSchedulerConfig(
name=self.scheduler_name,
num_warmup_steps=self.scheduler_warmup_steps,
)
def validate_features(self) -> None:
"""Validate that required input features are present and properly configured."""
# If the configured crop doesn't fit, disable cropping instead of erroring.
# Note: if image_resize_shape is set, cropping is applied *after* resizing.
if self.image_crop_shape is not None:
for key, image_ft in self.image_features.items():
# image_ft.shape is (C, H, W)
effective_h, effective_w = (
self.image_resize_shape
if self.image_resize_shape is not None
else (image_ft.shape[1], image_ft.shape[2])
)
if self.image_crop_shape[0] > effective_h or self.image_crop_shape[1] > effective_w:
logging.warning(
"image_crop_shape %s doesn't fit within effective image shape (%s, %s) for '%s'; disabling cropping.",
self.image_crop_shape,
effective_h,
effective_w,
key,
)
self.image_crop_shape = None
break
if len(self.image_features) > 0:
first_key, first_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_ft.shape:
raise ValueError(
f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}"
)
@property
def is_diffusion(self) -> bool:
return self.objective == "diffusion"
@property
def is_flow_matching(self) -> bool:
return self.objective == "flow_matching"
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -1,803 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Task Diffusion Transformer (DiT) Policy
Transformer-based diffusion policy for multi-task robot learning with text and vision conditioning.
Supports both diffusion and flow matching objectives for action generation.
References:
- https://arxiv.org/abs/2507.05331
- https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/
- https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy
"""
import math
from collections import deque
from typing import TYPE_CHECKING
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import CLIPTextModel, CLIPVisionModel
else:
CLIPTextModel = None
CLIPVisionModel = None
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
# -- Policy --
class MultiTaskDiTPolicy(PreTrainedPolicy):
config_class = MultiTaskDiTConfig
name = "multi_task_dit"
def __init__(self, config: MultiTaskDiTConfig, **kwargs):
super().__init__(config)
config.validate_features()
self.config = config
self._queues = None
self.observation_encoder = ObservationEncoder(config)
conditioning_dim = self.observation_encoder.conditioning_dim
self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim)
action_dim = config.action_feature.shape[0]
horizon = config.horizon
if config.is_diffusion:
self.objective = DiffusionObjective(
config,
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
elif config.is_flow_matching:
self.objective = FlowMatchingObjective(
config,
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
else:
raise ValueError(f"Unsupported objective: {config.objective}")
self.reset()
def get_optim_params(self) -> list:
"""Returns parameter groups with different learning rates for vision vs non-vision parameters"""
non_vision_params = []
vision_encoder_params = []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
if "observation_encoder.vision_encoder" in name:
vision_encoder_params.append(param)
else:
non_vision_params.append(param)
return [
{"params": non_vision_params},
{
"params": vision_encoder_params,
"lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier,
},
]
def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
conditioning_vec = self.observation_encoder.encode(batch)
actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec)
start = n_obs_steps - 1
end = start + self.config.n_action_steps
actions = actions[:, start:end]
return actions
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
ACTION: deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations"""
self.eval()
for k in batch:
if k in self._queues:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
actions = self._generate_actions(batch)
return actions
def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Prepare batch by stacking image features if needed."""
if self.config.image_features:
batch = dict(batch) # shallow copy to avoid modifying original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
return batch
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations"""
if ACTION in batch:
batch = dict(batch) # shallow copy to avoid modifying original
batch.pop(ACTION)
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1))
action = self._queues[ACTION].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Run the batch through the model and compute the loss for training"""
batch = self._prepare_batch(batch)
conditioning_vec = self.observation_encoder.encode(batch)
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
return loss, None
# -- Observation Encoders --
class CLIPVisionEncoder(nn.Module):
"""CLIP vision encoder using the CLS token for global image representation."""
def __init__(self, model_name: str):
super().__init__()
self.model_name = model_name
self.model = CLIPVisionModel.from_pretrained(self.model_name)
self.num_non_spatial_tokens = 1
self.embed_dim = self.model.config.hidden_size
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to CLS token."""
outputs = self.model(pixel_values=x, output_hidden_states=False)
cls_token = outputs.last_hidden_state[:, 0]
b, embed_dim = cls_token.shape
return cls_token.reshape(b, embed_dim, 1, 1)
def get_output_shape(self) -> tuple:
return (self.embed_dim, 1, 1)
class CLIPTextEncoder(nn.Module):
"""CLIP text encoder with frozen weights and a learnable projection layer.
Accepts pre-tokenized inputs (input_ids and attention_mask) from the processor pipeline. See the processor
pipeline to see how the tokenization is handled.
"""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__()
self.model_name = model_name
self.projection_dim = projection_dim
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
for param in self.text_encoder.parameters():
param.requires_grad = False
self.text_embed_dim = self.text_encoder.config.hidden_size
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
"""Encode pre-tokenized text to feature vectors."""
# Ensure inputs are on the same device as the model
device = next(self.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
with torch.no_grad():
outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
clip_features = outputs.pooler_output
return self.projection(clip_features)
class ObservationEncoder(nn.Module):
"""Handles all observation processing for the conditioning vector."""
def __init__(self, config):
super().__init__()
self.config = config
self._setup_preprocessing(config)
if config.image_features:
self.num_cameras = len(config.image_features)
self.camera_names = list(config.image_features.keys())
if config.use_separate_rgb_encoder_per_camera:
self.vision_encoders = nn.ModuleList(
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
)
self.vision_encoder = None
else:
self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
self.vision_encoders = None
else:
self.vision_encoder = None
self.vision_encoders = None
self.camera_names = []
self.num_cameras = 0
if hasattr(config, "robot_state_feature") and config.robot_state_feature:
self.robot_state_dim = config.robot_state_feature.shape[0]
else:
self.robot_state_dim = 0
self.text_dim = config.hidden_dim
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
self._setup_vector_output()
def _apply_preprocessing(self, images: Tensor) -> Tensor:
if self.do_resize:
images = self.resize(images)
if self.do_crop:
images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
return images
def _setup_preprocessing(self, config):
if config.image_resize_shape is not None:
self.do_resize = True
self.resize = torchvision.transforms.Resize(
size=config.image_resize_shape,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True,
)
else:
self.do_resize = False
if config.image_crop_shape is not None:
self.do_crop = True
self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
if config.image_crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
def _setup_vector_output(self):
total_dim = 0
if self.vision_encoder is not None or self.vision_encoders is not None:
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
feature_map_shape = encoder_to_check.get_output_shape()
c, h, w = feature_map_shape
spatial_feature_dim = c * h * w
total_dim += spatial_feature_dim * self.num_cameras
total_dim += self.robot_state_dim
total_dim += self.text_dim
self.conditioning_dim = total_dim * self.config.n_obs_steps
def encode(self, batch: dict) -> Tensor:
"""Encode observations to vector format."""
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
conditioning_feats = []
conditioning_feats.append(batch[OBS_STATE])
if self.vision_encoder is not None or self.vision_encoders is not None:
images = batch[OBS_IMAGES]
if len(images.shape) == 5:
images = images.unsqueeze(1)
if self.config.use_separate_rgb_encoder_per_camera:
camera_features = []
for cam_idx in range(self.num_cameras):
cam_images = images[:, :, cam_idx]
cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w")
cam_images_flat = self._apply_preprocessing(cam_images_flat)
cam_features = self.vision_encoders[cam_idx](cam_images_flat)
cam_visual_features = cam_features.flatten(start_dim=1)
cam_features_reshaped = einops.rearrange(
cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps
)
camera_features.append(cam_features_reshaped)
img_features = torch.cat(camera_features, dim=-1)
conditioning_feats.append(img_features)
else:
images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
images_flat = self._apply_preprocessing(images_flat)
visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
img_features = einops.rearrange(
visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
)
conditioning_feats.append(img_features)
if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch:
input_ids = batch[OBS_LANGUAGE_TOKENS] # [batch_size, seq_length]
attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK] # [batch_size, seq_length]
text_features = self.text_encoder(input_ids, attention_mask)
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
conditioning_feats.append(text_features)
combined_features = torch.cat(conditioning_feats, dim=-1)
return combined_features.flatten(start_dim=1)
# -- Transformer Components --
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
"""Modulate input with shift and scale for AdaLN-Zero."""
return x * (1 + scale) + shift
class SinusoidalPosEmb(nn.Module):
"""Sinusoidal positional embeddings for timesteps."""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class RotaryPositionalEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE) for transformers."""
def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
super().__init__()
assert head_dim % 2 == 0, "head_dim must be even for RoPE"
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.base = base
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._precompute_cache(max_seq_len)
def _precompute_cache(self, seq_len: int):
t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)
def _rotate_half(self, x: Tensor) -> Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
seq_len = q.shape[2]
if seq_len > self.max_seq_len:
raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}.")
cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype)
sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype)
q_rotated = (q * cos) + (self._rotate_half(q) * sin)
k_rotated = (k * cos) + (self._rotate_half(k) * sin)
return q_rotated, k_rotated
class RoPEAttention(nn.Module):
"""Multi-head self-attention with Rotary Position Embedding (RoPE)."""
def __init__(
self,
hidden_size: int,
num_heads: int,
dropout: float = 0.0,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
super().__init__()
assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5
self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)
def forward(self, x: Tensor) -> Tensor:
B, T, _ = x.shape # noqa: N806
qkv = self.qkv_proj(x)
qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q, k = self.rope(q, k)
attn_out = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0,
)
attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size)
return self.out_proj(attn_out)
class TransformerBlock(nn.Module):
"""DiT-style transformer block with AdaLN-Zero."""
def __init__(
self,
hidden_size: int = 128,
num_heads: int = 4,
num_features: int = 128,
dropout: float = 0.0,
use_rope: bool = False,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
super().__init__()
self.use_rope = use_rope
if use_rope:
self.attn = RoPEAttention(
hidden_size=hidden_size,
num_heads=num_heads,
dropout=dropout,
max_seq_len=max_seq_len,
rope_base=rope_base,
)
else:
self.multihead_attn = nn.MultiheadAttention(
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
)
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.GELU(approximate="tanh"),
nn.Linear(hidden_size * 4, hidden_size),
)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True))
def forward(self, x: Tensor, features: Tensor) -> Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
features
).chunk(6, dim=1)
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
if self.use_rope:
attn_out = self.attn(attn_input)
else:
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
x = x + gate_msa.unsqueeze(1) * attn_out
mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1))
mlp_out = self.mlp(mlp_input)
x = x + gate_mlp.unsqueeze(1) * mlp_out
return x
class DiffusionTransformer(nn.Module):
"""Transformer-based diffusion noise prediction model."""
def __init__(self, config, conditioning_dim: int):
super().__init__()
self.config = config
self.conditioning_dim = conditioning_dim
self.action_dim = config.action_feature.shape[0]
self.horizon = config.horizon
self.hidden_size = config.hidden_dim
self.num_layers = config.num_layers
self.num_heads = config.num_heads
self.dropout = config.dropout
self.use_rope = config.use_rope
self.timestep_embed_dim = config.timestep_embed_dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(self.timestep_embed_dim),
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
nn.GELU(),
nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim),
nn.GELU(),
)
self.cond_dim = self.timestep_embed_dim + conditioning_dim
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
if config.use_positional_encoding:
self.pos_embedding = nn.Parameter(
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
)
else:
self.pos_embedding = None
self.transformer_blocks = nn.ModuleList(
[
TransformerBlock(
hidden_size=self.hidden_size,
num_heads=self.num_heads,
num_features=self.cond_dim,
dropout=self.dropout,
use_rope=self.use_rope,
max_seq_len=self.horizon,
rope_base=config.rope_base,
)
for _ in range(self.num_layers)
]
)
self.output_proj = nn.Linear(self.hidden_size, self.action_dim)
self._initialize_weights()
def _initialize_weights(self):
for block in self.transformer_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor:
_, seq_len, _ = x.shape
timestep_features = self.time_mlp(timestep)
cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1)
hidden_seq = self.input_proj(x)
if self.pos_embedding is not None:
hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :]
for block in self.transformer_blocks:
hidden_seq = block(hidden_seq, cond_features)
return self.output_proj(hidden_seq)
# -- Objectives --
class DiffusionObjective(nn.Module):
"""Standard diffusion (DDPM/DDIM) objective implementation."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__()
self.config = config
self.action_dim = action_dim
self.horizon = horizon
self.do_mask_loss_for_padding = do_mask_loss_for_padding
scheduler_kwargs = {
"num_train_timesteps": config.num_train_timesteps,
"beta_start": config.beta_start,
"beta_end": config.beta_end,
"beta_schedule": config.beta_schedule,
"clip_sample": config.clip_sample,
"clip_sample_range": config.clip_sample_range,
"prediction_type": config.prediction_type,
}
if config.noise_scheduler_type == "DDPM":
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
elif config.noise_scheduler_type == "DDIM":
self.noise_scheduler = DDIMScheduler(**scheduler_kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}")
self.num_inference_steps = (
config.num_inference_steps
if config.num_inference_steps is not None
else self.noise_scheduler.config.num_train_timesteps
)
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
clean_actions = batch[ACTION]
noise = torch.randn_like(clean_actions)
timesteps = torch.randint(
low=0,
high=self.noise_scheduler.config.num_train_timesteps,
size=(clean_actions.shape[0],),
device=clean_actions.device,
).long()
noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps)
prediction_type = self.noise_scheduler.config.prediction_type
if prediction_type == "epsilon":
target = noise
elif prediction_type == "sample":
target = clean_actions
else:
raise ValueError(f"Unsupported prediction type: {prediction_type}")
predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted, target, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_actions = ~batch["action_is_pad"]
loss = loss * valid_actions.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
sample = torch.randn(
size=(batch_size, self.horizon, self.action_dim),
dtype=dtype,
device=device,
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
for t in self.noise_scheduler.timesteps:
model_output = model(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
conditioning_vec=conditioning_vec,
)
sample = self.noise_scheduler.step(model_output, t, sample).prev_sample
return sample
class FlowMatchingObjective(nn.Module):
"""Flow matching objective: trains a model to predict velocity fields."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__()
self.config = config
self.action_dim = action_dim
self.horizon = horizon
self.do_mask_loss_for_padding = do_mask_loss_for_padding
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
if self.config.timestep_sampling_strategy == "uniform":
return torch.rand(batch_size, device=device)
elif self.config.timestep_sampling_strategy == "beta":
beta_dist = torch.distributions.Beta(
self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
)
u = beta_dist.sample((batch_size,)).to(device)
return self.config.timestep_sampling_s * (1.0 - u)
else:
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
data = batch[ACTION]
batch_size = data.shape[0]
device = data.device
noise = torch.randn_like(data)
t = self._sample_timesteps(batch_size, device)
t_expanded = t.view(-1, 1, 1)
x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise
target_velocity = data - (1 - self.config.sigma_min) * noise
predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_mask = ~batch["action_is_pad"]
loss = loss * valid_mask.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device)
num_steps = self.config.num_integration_steps
time_grid = torch.linspace(0, 1, num_steps + 1, device=device)
if self.config.integration_method == "euler":
x = self._euler_integrate(model, x, time_grid, conditioning_vec)
elif self.config.integration_method == "rk4":
x = self._rk4_integrate(model, x, time_grid, conditioning_vec)
else:
raise ValueError(f"Unknown integration method: {self.config.integration_method}")
return x
def _euler_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
x = x_init
for i in range(len(time_grid) - 1):
t_scalar = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device)
with torch.no_grad():
velocity = model(x, t_batch, conditioning_vec=conditioning_vec)
x = x + dt * velocity
return x
def _rk4_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
x = x_init
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
with torch.no_grad():
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
for i in range(len(time_grid) - 1):
t = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
k1 = dynamics(x, t)
k2 = dynamics(x + dt * k1 / 2, t + dt / 2)
k3 = dynamics(x + dt * k2 / 2, t + dt / 2)
k4 = dynamics(x + dt * k3, t + dt)
x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
return x

View File

@@ -1,105 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_multi_task_dit_pre_post_processors(
config: MultiTaskDiTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for a Multi-Task DiT policy.
The pre-processing pipeline prepares the input data for the model by:
1. Renaming features.
2. Adding a batch dimension.
3. Tokenizing the language task description (if present).
4. Moving the data to the specified device.
5. Normalizing the input and output features based on dataset statistics.
The post-processing pipeline handles the model's output by:
1. Unnormalizing the output features to their original scale.
2. Moving the data to the CPU.
Args:
config: The configuration object for the Multi-Task DiT policy,
containing feature definitions, normalization mappings, and device information.
dataset_stats: A dictionary of statistics used for normalization.
Defaults to None.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
TokenizerProcessorStep(
tokenizer_name=config.text_encoder_name,
padding=config.tokenizer_padding,
padding_side=config.tokenizer_padding_side,
max_length=config.tokenizer_max_length,
truncation=config.tokenizer_truncation,
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
device=config.device,
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)

View File

@@ -16,3 +16,5 @@
from .config_unitree_g1 import UnitreeG1Config
from .unitree_g1 import UnitreeG1
__all__ = ["UnitreeG1", "UnitreeG1Config"]

View File

@@ -27,11 +27,10 @@ _GAINS: dict[str, dict[str, list[float]]] = {
}, # pitch, roll, yaw, knee, ankle_pitch, ankle_roll
"right_leg": {"kp": [150, 150, 150, 300, 40, 40], "kd": [2, 2, 2, 4, 2, 2]},
"waist": {"kp": [250, 250, 250], "kd": [5, 5, 5]}, # yaw, roll, pitch
"left_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow
"left_arm": {"kp": [50, 50, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow
"left_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]}, # roll, pitch, yaw
"right_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]},
"right_arm": {"kp": [50, 50, 80, 80], "kd": [3, 3, 3, 3]},
"right_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]},
"other": {"kp": [80, 80, 80, 80, 80, 80], "kd": [3, 3, 3, 3, 3, 3]},
}
@@ -68,3 +67,7 @@ class UnitreeG1Config(RobotConfig):
# Compensates for gravity on the unitree's arms using the arm ik solver
gravity_compensation: bool = False
# Lower-body controller class name, e.g. "GrootLocomotionController" or
# "HolosomaLocomotionController". None disables it.
controller: str | None = None

View File

@@ -16,13 +16,11 @@
import logging
import os
import sys
from collections import deque
import numpy as np
logger = logging.getLogger(__name__)
parent2_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(parent2_dir)
class WeightedMovingFilter:
@@ -31,18 +29,14 @@ class WeightedMovingFilter:
self._weights = np.array(weights)
self._data_size = data_size
self._filtered_data = np.zeros(self._data_size)
self._data_queue = []
self._data_queue = deque(maxlen=self._window_size)
def _apply_filter(self):
if len(self._data_queue) < self._window_size:
return self._data_queue[-1]
data_array = np.array(self._data_queue)
temp_filtered_data = np.zeros(self._data_size)
for i in range(self._data_size):
temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1]
return temp_filtered_data
return data_array.T @ self._weights
def add_data(self, new_data):
assert len(new_data) == self._data_size
@@ -52,9 +46,6 @@ class WeightedMovingFilter:
): # skip duplicate data
return
if len(self._data_queue) >= self._window_size:
self._data_queue.pop(0)
self._data_queue.append(new_data)
self._filtered_data = self._apply_filter()
@@ -71,8 +62,6 @@ class G1_29_ArmIK: # noqa: N801
from pinocchio import casadi as cpin
self._pin = pin
np.set_printoptions(precision=5, suppress=True, linewidth=200)
self.unit_test = unit_test
self.repo_path = snapshot_download("lerobot/unitree-g1-mujoco")
@@ -249,50 +238,35 @@ class G1_29_ArmIK: # noqa: N801
self.opti.set_value(self.param_tf_r, right_wrist)
self.opti.set_value(self.var_q_last, self.init_data) # for smooth
converged = True
try:
self.opti.solve()
sol_q = self.opti.value(self.var_q)
self.smooth_filter.add_data(sol_q)
sol_q = self.smooth_filter.filtered_data
if current_lr_arm_motor_dq is not None:
v = current_lr_arm_motor_dq * 0.0
else:
v = (sol_q - self.init_data) * 0.0
self.init_data = sol_q
sol_tauff = self._pin.rnea(
self.reduced_robot.model,
self.reduced_robot.data,
sol_q,
v,
np.zeros(self.reduced_robot.model.nv),
)
return sol_q, sol_tauff
except Exception as e:
logger.error(f"ERROR in convergence, plotting debug info.{e}")
converged = False
logger.error(f"IK convergence error: {e}")
sol_q = self.opti.debug.value(self.var_q)
self.smooth_filter.add_data(sol_q)
sol_q = self.smooth_filter.filtered_data
if current_lr_arm_motor_dq is not None:
v = current_lr_arm_motor_dq * 0.0
else:
v = (sol_q - self.init_data) * 0.0
self.init_data = sol_q
self.smooth_filter.add_data(sol_q)
sol_q = self.smooth_filter.filtered_data
self.init_data = sol_q
if not converged:
logger.error(
f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}"
)
return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv)
sol_tauff = self._pin.rnea(
self.reduced_robot.model,
self.reduced_robot.data,
sol_q,
np.zeros(self.reduced_robot.model.nv),
np.zeros(self.reduced_robot.model.nv),
)
return sol_q, sol_tauff
def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None):
try:
q_g1 = np.array(current_lr_arm_motor_q, dtype=float)

View File

@@ -14,12 +14,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from enum import IntEnum
import numpy as np
# ruff: noqa: N801, N815
NUM_MOTORS = 29
REMOTE_AXES = ("remote.lx", "remote.ly", "remote.rx", "remote.ry")
REMOTE_BUTTONS = tuple(f"remote.button.{i}" for i in range(16))
REMOTE_KEYS = REMOTE_AXES + REMOTE_BUTTONS
def default_remote_input() -> dict[str, float]:
"""Return a zeroed-out remote input dict (axes + buttons)."""
return dict.fromkeys(REMOTE_KEYS, 0.0)
def get_gravity_orientation(quaternion: list[float] | np.ndarray) -> np.ndarray:
"""Get gravity orientation from quaternion [w, x, y, z]."""
qw, qx, qy, qz = quaternion
gravity_orientation = np.zeros(3, dtype=np.float32)
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
return gravity_orientation
class G1_29_JointArmIndex(IntEnum):
# Left arm
@@ -29,7 +51,7 @@ class G1_29_JointArmIndex(IntEnum):
kLeftElbow = 18
kLeftWristRoll = 19
kLeftWristPitch = 20
kLeftWristyaw = 21
kLeftWristYaw = 21
# Right arm
kRightShoulderPitch = 22
@@ -41,6 +63,21 @@ class G1_29_JointArmIndex(IntEnum):
kRightWristYaw = 28
def make_locomotion_controller(name: str | None):
"""Instantiate a locomotion controller by class name. Returns None if name is None."""
if name is None:
return None
controllers = {
"GrootLocomotionController": "lerobot.robots.unitree_g1.gr00t_locomotion",
"HolosomaLocomotionController": "lerobot.robots.unitree_g1.holosoma_locomotion",
}
module_path = controllers.get(name)
if module_path is None:
raise ValueError(f"Unknown controller: {name!r}. Available: {list(controllers)}")
module = importlib.import_module(module_path)
return getattr(module, name)()
class G1_29_JointIndex(IntEnum):
# Left leg
kLeftHipPitch = 0
@@ -69,7 +106,7 @@ class G1_29_JointIndex(IntEnum):
kLeftElbow = 18
kLeftWristRoll = 19
kLeftWristPitch = 20
kLeftWristyaw = 21
kLeftWristYaw = 21
# Right arm
kRightShoulderPitch = 22

View File

@@ -14,20 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import time
from collections import deque
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
from lerobot.robots.unitree_g1.g1_utils import (
REMOTE_AXES,
REMOTE_BUTTONS,
G1_29_JointIndex,
get_gravity_orientation,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -36,18 +36,13 @@ GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # Hip pitch
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # Knee
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # Ankle pitch
MISSING_JOINTS = []
G1_MODEL = "g1_23" # Or "g1_29"
if G1_MODEL == "g1_23":
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw
# Control parameters
ACTION_SCALE = 0.25
CONTROL_DT = 0.02 # 50Hz
ANG_VEL_SCALE: float = 0.25
DOF_POS_SCALE: float = 1.0
DOF_VEL_SCALE: float = 0.05
CMD_SCALE: list = [2.0, 2.0, 0.25]
CMD_SCALE: list[float] = [2.0, 2.0, 0.25]
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
@@ -85,11 +80,11 @@ def load_groot_policies(
class GrootLocomotionController:
"""GR00T lower-body locomotion controller for the Unitree G1."""
def __init__(self, policy_balance, policy_walk, robot, config):
self.policy_balance = policy_balance
self.policy_walk = policy_walk
self.robot = robot
self.config = config
control_dt = CONTROL_DT # Expose for unitree_g1.py
def __init__(self):
# Load policies
self.policy_balance, self.policy_walk = load_groot_policies()
self.cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
@@ -109,45 +104,60 @@ class GrootLocomotionController:
logger.info("GrootLocomotionController initialized")
def run_step(self):
# Get current observation
obs = self.robot.get_observation()
def reset(self) -> None:
"""Reset internal state for a new episode."""
self.cmd[:] = 0.0
self.groot_qj_all[:] = 0.0
self.groot_dqj_all[:] = 0.0
self.groot_action[:] = 0.0
self.groot_obs_single[:] = 0.0
self.groot_obs_stacked[:] = 0.0
self.groot_height_cmd = 0.74
self.groot_orientation_cmd[:] = 0.0
self.groot_obs_history.clear()
for _ in range(6):
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
if not obs:
return
def run_step(self, action: dict, lowstate) -> dict:
"""Run one step of the locomotion controller.
# Get command from remote controller
if obs["remote.buttons"][0]: # R1 - raise waist
Args:
action: Action dict containing remote.lx/ly/rx/ry and buttons
lowstate: Robot lowstate containing motor positions/velocities and IMU
Returns:
Action dict for lower body joints (0-14)
"""
if lowstate is None:
return {}
buttons = [int(action.get(k, 0)) for k in REMOTE_BUTTONS]
if buttons[0]: # R1 - raise waist
self.groot_height_cmd += 0.001
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
if obs["remote.buttons"][4]: # R2 - lower waist
if buttons[4]: # R2 - lower waist
self.groot_height_cmd -= 0.001
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
self.cmd[0] = obs["remote.ly"] # Forward/backward
self.cmd[1] = obs["remote.lx"] * -1 # Left/right
self.cmd[2] = obs["remote.rx"] * -1 # Rotation rate
lx, ly, rx, _ry = (action.get(k, 0.0) for k in REMOTE_AXES)
self.cmd[0] = ly # Forward/backward
self.cmd[1] = -lx # Left/right (negated)
self.cmd[2] = -rx # Rotation rate (negated)
# Get joint positions and velocities from flat dict
# Get joint positions and velocities from lowstate
for motor in G1_29_JointIndex:
name = motor.name
idx = motor.value
self.groot_qj_all[idx] = obs[f"{name}.q"]
self.groot_dqj_all[idx] = obs[f"{name}.dq"]
# Adapt observation for g1_23dof
for idx in MISSING_JOINTS:
self.groot_qj_all[idx] = 0.0
self.groot_dqj_all[idx] = 0.0
self.groot_qj_all[idx] = lowstate.motor_state[idx].q
self.groot_dqj_all[idx] = lowstate.motor_state[idx].dq
# Scale joint positions and velocities
qj_obs = self.groot_qj_all.copy()
dqj_obs = self.groot_dqj_all.copy()
# Express IMU data in gravity frame of reference
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
gravity_orientation = self.robot.get_gravity_orientation(quat)
quat = lowstate.imu_state.quaternion
ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32)
gravity_orientation = get_gravity_orientation(quat)
# Scale joint positions and velocities before policy inference
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
@@ -186,73 +196,10 @@ class GrootLocomotionController:
# Transform action back to target joint positions
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * ACTION_SCALE
# Build action dict (only first 15 joints for GR00T)
# Build action dict
action_dict = {}
for i in range(15):
motor_name = G1_29_JointIndex(i).name
action_dict[f"{motor_name}.q"] = float(target_dof_pos_15[i])
# Zero out missing joints for g1_23dof
for joint_idx in MISSING_JOINTS:
motor_name = G1_29_JointIndex(joint_idx).name
action_dict[f"{motor_name}.q"] = 0.0
# Send action to robot
self.robot.send_action(action_dict)
def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None:
"""Main function to run the GR00T locomotion controller.
Args:
repo_id: Hugging Face Hub repository ID for GR00T policies.
"""
# Load policies
policy_balance, policy_walk = load_groot_policies(repo_id=repo_id)
# Initialize robot
config = UnitreeG1Config()
robot = UnitreeG1(config)
robot.connect()
# Initialize gr00T locomotion controller
groot_controller = GrootLocomotionController(
policy_balance=policy_balance,
policy_walk=policy_walk,
robot=robot,
config=config,
)
try:
robot.reset(CONTROL_DT, GROOT_DEFAULT_ANGLES)
logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate, R1=raise waist, R2=lower waist")
logger.info("Press Ctrl+C to stop")
# Run step
while not robot._shutdown_event.is_set():
start_time = time.time()
groot_controller.run_step()
elapsed = time.time() - start_time
sleep_time = max(0, CONTROL_DT - elapsed)
time.sleep(sleep_time)
except KeyboardInterrupt:
logger.info("Stopping locomotion...")
finally:
if robot.is_connected:
robot.disconnect()
logger.info("Done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
parser.add_argument(
"--repo-id",
type=str,
default=DEFAULT_GROOT_REPO_ID,
help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
)
args = parser.parse_args()
run(repo_id=args.repo_id)
return action_dict

View File

@@ -14,21 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import time
import numpy as np
import onnx
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
from lerobot.robots.unitree_g1.g1_utils import (
REMOTE_AXES,
G1_29_JointArmIndex,
G1_29_JointIndex,
get_gravity_orientation,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
@@ -40,18 +40,13 @@ DEFAULT_ANGLES[16] = 0.2 # Left shoulder roll
DEFAULT_ANGLES[23] = -0.2 # Right shoulder roll
DEFAULT_ANGLES[[18, 25]] = 0.6 # Elbow
MISSING_JOINTS = []
G1_MODEL = "g1_23" # Or "g1_29"
if G1_MODEL == "g1_23":
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw
# Control parameters
ACTION_SCALE = 0.25
CONTROL_DT = 0.02 # 50Hz
CONTROL_DT = 0.005 # 200Hz
ANG_VEL_SCALE = 0.25
DOF_POS_SCALE = 1.0
DOF_VEL_SCALE = 0.05
GAIT_PERIOD = 1.0
GAIT_PERIOD = 0.5
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
@@ -87,7 +82,7 @@ def load_policy(
logger.info(f"Policy loaded: {policy.get_inputs()[0].shape}{policy.get_outputs()[0].shape}")
# Extract KP/KD from ONNX metadata
model = onnx.load(policy_path)
model = onnx.load(policy_path, load_external_data=False)
metadata = {prop.key: prop.value for prop in model.metadata_props}
if "kp" not in metadata or "kd" not in metadata:
@@ -101,15 +96,13 @@ def load_policy(
class HolosomaLocomotionController:
"""Holosoma whole-body locomotion controller for Unitree G1."""
"""Holosoma lower-body locomotion controller for Unitree G1."""
def __init__(self, policy, robot, kp: np.ndarray, kd: np.ndarray):
self.policy = policy
self.robot = robot
control_dt = CONTROL_DT # Expose for unitree_g1.py
# Override robot's PD gains with policy gains
self.robot.kp = kp
self.robot.kd = kd
def __init__(self):
# Load policy and gains
self.policy, self.kp, self.kd = load_policy()
self.cmd = np.zeros(3, dtype=np.float32)
@@ -124,35 +117,55 @@ class HolosomaLocomotionController:
self.phase_dt = 2 * np.pi / ((1.0 / CONTROL_DT) * GAIT_PERIOD)
self.is_standing = True
def run_step(self):
# Get current observation
obs = self.robot.get_observation()
logger.info("HolosomaLocomotionController initialized")
if not obs:
return
def reset(self) -> None:
"""Reset internal state for a new episode."""
self.cmd[:] = 0.0
self.qj[:] = 0.0
self.dqj[:] = 0.0
self.obs[:] = 0.0
self.last_action[:] = 0.0
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
self.is_standing = True
# Get command from remote controller
ly = obs["remote.ly"] if abs(obs["remote.ly"]) > 0.1 else 0.0
lx = obs["remote.lx"] if abs(obs["remote.lx"]) > 0.1 else 0.0
rx = obs["remote.rx"] if abs(obs["remote.rx"]) > 0.1 else 0.0
def run_step(self, action: dict, lowstate) -> dict:
"""Run one step of the locomotion controller.
Args:
action: Action dict containing remote.lx/ly/rx/ry
lowstate: Robot lowstate containing motor positions/velocities and IMU
Returns:
Action dict for lower body joints (0-14)
"""
if lowstate is None:
return {}
lx, ly, rx, _ry = (action.get(k, 0.0) for k in REMOTE_AXES)
ly = ly if abs(ly) > 0.1 else 0.0
lx = lx if abs(lx) > 0.1 else 0.0
rx = rx if abs(rx) > 0.1 else 0.0
ly = np.clip(ly, -0.3, 0.3)
lx = np.clip(lx, -0.3, 0.3)
self.cmd[:] = [ly, -lx, -rx]
# Get joint positions and velocities
# Get joint positions and velocities from lowstate
for motor in G1_29_JointIndex:
name = motor.name
idx = motor.value
self.qj[idx] = obs[f"{name}.q"]
self.dqj[idx] = obs[f"{name}.dq"]
self.qj[idx] = lowstate.motor_state[idx].q
self.dqj[idx] = lowstate.motor_state[idx].dq
# Adapt observation for g1_23dof
for idx in MISSING_JOINTS:
self.qj[idx] = 0.0
self.dqj[idx] = 0.0
# Hide arm positions from policy (show DEFAULT_ANGLES instead)
# This prevents policy from reacting to teleop arm movements
for arm_joint in G1_29_JointArmIndex:
self.qj[arm_joint.value] = DEFAULT_ANGLES[arm_joint.value]
self.dqj[arm_joint.value] = 0.0
# Express IMU data in gravity frame of reference
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
gravity = self.robot.get_gravity_orientation(quat)
quat = lowstate.imu_state.quaternion
ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32)
gravity = get_gravity_orientation(quat)
# Scale joint positions and velocities before policy inference
qj_obs = (self.qj - DEFAULT_ANGLES) * DOF_POS_SCALE
@@ -186,79 +199,16 @@ class HolosomaLocomotionController:
# Run policy inference
ort_in = {self.policy.get_inputs()[0].name: self.obs.reshape(1, -1).astype(np.float32)}
raw_action = self.policy.run(None, ort_in)[0].squeeze()
action = np.clip(raw_action, -100.0, 100.0)
self.last_action = action.copy()
policy_action = np.clip(raw_action, -100.0, 100.0)
self.last_action = policy_action.copy()
# Transform action back to target joint positions
target = DEFAULT_ANGLES + action * ACTION_SCALE
target = DEFAULT_ANGLES + policy_action * ACTION_SCALE
# Build action dict
# Build action dict (first 15 joints only)
action_dict = {}
for motor in G1_29_JointIndex:
action_dict[f"{motor.name}.q"] = float(target[motor.value])
for i in range(15):
motor_name = G1_29_JointIndex(i).name
action_dict[f"{motor_name}.q"] = float(target[i])
# Zero out missing joints for g1_23dof
for joint_idx in MISSING_JOINTS:
motor_name = G1_29_JointIndex(joint_idx).name
action_dict[f"{motor_name}.q"] = 0.0
# Send action to robot
self.robot.send_action(action_dict)
def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -> None:
"""Main function to run the Holosoma locomotion controller.
Args:
repo_id: Hugging Face Hub repository ID for Holosoma policies.
policy_type: Policy type to use ('fastsac' or 'ppo').
"""
# Load policy and gains
policy, kp, kd = load_policy(repo_id=repo_id, policy_type=policy_type)
# Initialize robot
config = UnitreeG1Config()
robot = UnitreeG1(config)
robot.connect()
holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd)
try:
robot.reset(CONTROL_DT, DEFAULT_ANGLES)
logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate")
logger.info("Press Ctrl+C to stop")
# Run step
while not robot._shutdown_event.is_set():
start_time = time.time()
holosoma_controller.run_step()
elapsed = time.time() - start_time
sleep_time = max(0, CONTROL_DT - elapsed)
time.sleep(sleep_time)
except KeyboardInterrupt:
logger.info("Stopping locomotion...")
finally:
if robot.is_connected:
robot.disconnect()
logger.info("Done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1")
parser.add_argument(
"--repo-id",
type=str,
default=DEFAULT_HOLOSOMA_REPO_ID,
help=f"Hugging Face Hub repo ID for Holosoma policies (default: {DEFAULT_HOLOSOMA_REPO_ID})",
)
parser.add_argument(
"--policy",
type=str,
choices=["fastsac", "ppo"],
default="fastsac",
help="Policy type to use: 'fastsac' (default) or 'ppo'",
)
args = parser.parse_args()
run(repo_id=args.repo_id, policy_type=args.policy)
return action_dict

View File

@@ -24,6 +24,7 @@ This server runs on the robot and forwards:
Uses JSON for secure serialization instead of pickle.
"""
import argparse
import base64
import contextlib
import json
@@ -38,6 +39,8 @@ from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
from unitree_sdk2py.utils.crc import CRC
from lerobot.cameras.zmq.image_server import ImageServer
# DDS topic names follow Unitree SDK naming conventions
# ruff: noqa: N816
kTopicLowCommand_Debug = "rt/lowcmd" # action to robot
@@ -150,6 +153,32 @@ def cmd_forward_loop(
def main() -> None:
"""Main entry point for the robot server bridge."""
parser = argparse.ArgumentParser(description="DDS-to-ZMQ bridge server for Unitree G1")
parser.add_argument("--camera", action="store_true", help="Also launch camera server")
parser.add_argument("--camera-device", type=int, default=4, help="Camera device ID (default: 4)")
parser.add_argument("--camera-fps", type=int, default=30, help="Camera FPS (default: 30)")
parser.add_argument("--camera-width", type=int, default=640, help="Camera width (default: 640)")
parser.add_argument("--camera-height", type=int, default=480, help="Camera height (default: 480)")
parser.add_argument("--camera-port", type=int, default=5555, help="Camera ZMQ port (default: 5555)")
args = parser.parse_args()
# Optionally start camera server in background thread
camera_thread = None
if args.camera:
camera_config = {
"fps": args.camera_fps,
"cameras": {
"head_camera": {
"device_id": args.camera_device,
"shape": [args.camera_height, args.camera_width],
}
},
}
camera_server = ImageServer(camera_config, port=args.camera_port)
camera_thread = threading.Thread(target=camera_server.run, daemon=True)
camera_thread.start()
print(f"Camera server started on port {args.camera_port} (device {args.camera_device})")
# initialize DDS
ChannelFactoryInitialize(0)
@@ -206,6 +235,8 @@ def main() -> None:
shutdown_event.set()
ctx.term() # terminates blocking zmq.recv() calls
t_state.join(timeout=2.0)
if camera_thread is not None:
camera_thread.join(timeout=2.0)
if __name__ == "__main__":

View File

@@ -14,27 +14,67 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import struct
import threading
import time
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any
from typing import TYPE_CHECKING, Protocol, runtime_checkable
import numpy as np
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.envs.factory import make_env
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex, G1_29_JointIndex
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
from lerobot.robots.unitree_g1.g1_kinematics import G1_29_ArmIK
from lerobot.robots.unitree_g1.g1_utils import (
REMOTE_AXES,
REMOTE_KEYS,
G1_29_JointArmIndex,
G1_29_JointIndex,
default_remote_input,
make_locomotion_controller,
)
from lerobot.utils.import_utils import _unitree_sdk_available
from ..robot import Robot
from .config_unitree_g1 import UnitreeG1Config
if TYPE_CHECKING or _unitree_sdk_available:
from unitree_sdk2py.core.channel import (
ChannelFactoryInitialize as _SDKChannelFactoryInitialize,
ChannelPublisher as _SDKChannelPublisher,
ChannelSubscriber as _SDKChannelSubscriber,
)
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
LowCmd_ as hg_LowCmd,
LowState_ as hg_LowState,
)
from unitree_sdk2py.utils.crc import CRC
else:
_SDKChannelFactoryInitialize = None
_SDKChannelPublisher = None
_SDKChannelSubscriber = None
unitree_hg_msg_dds__LowCmd_ = None
hg_LowCmd = None
hg_LowState = None
CRC = None
logger = logging.getLogger(__name__)
@runtime_checkable
class LocomotionController(Protocol):
control_dt: float
def run_step(self, action: dict, lowstate) -> dict: ...
def reset(self) -> None: ...
# DDS topic names follow Unitree SDK naming conventions
# ruff: noqa: N816
kTopicLowCommand_Debug = "rt/lowcmd"
@@ -63,7 +103,7 @@ class IMUState:
class G1_29_LowState: # noqa: N801
motor_state: list[MotorState] = field(default_factory=lambda: [MotorState() for _ in G1_29_JointIndex])
imu_state: IMUState = field(default_factory=IMUState)
wireless_remote: Any = None # Raw wireless remote data
wireless_remote: bytes | None = None # Raw wireless remote data
mode_machine: int = 0 # Robot mode
@@ -71,25 +111,6 @@ class UnitreeG1(Robot):
config_class = UnitreeG1Config
name = "unitree_g1"
# unitree remote controller
class RemoteController:
def __init__(self):
self.lx = 0
self.ly = 0
self.rx = 0
self.ry = 0
self.button = [0] * 16
def set(self, data):
# wireless_remote
keys = struct.unpack("H", data[2:4])[0]
for i in range(16):
self.button[i] = (keys & (1 << i)) >> i
self.lx = struct.unpack("f", data[4:8])[0]
self.rx = struct.unpack("f", data[8:12])[0]
self.ry = struct.unpack("f", data[12:16])[0]
self.ly = struct.unpack("f", data[20:24])[0]
def __init__(self, config: UnitreeG1Config):
super().__init__(config)
@@ -103,11 +124,9 @@ class UnitreeG1(Robot):
# Import channel classes based on mode
if config.is_simulation:
from unitree_sdk2py.core.channel import (
ChannelFactoryInitialize,
ChannelPublisher,
ChannelSubscriber,
)
self._ChannelFactoryInitialize = _SDKChannelFactoryInitialize
self._ChannelPublisher = _SDKChannelPublisher
self._ChannelSubscriber = _SDKChannelSubscriber
else:
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
ChannelFactoryInitialize,
@@ -115,22 +134,30 @@ class UnitreeG1(Robot):
ChannelSubscriber,
)
# Store for use in connect()
self._ChannelFactoryInitialize = ChannelFactoryInitialize
self._ChannelPublisher = ChannelPublisher
self._ChannelSubscriber = ChannelSubscriber
self._ChannelFactoryInitialize = ChannelFactoryInitialize
self._ChannelPublisher = ChannelPublisher
self._ChannelSubscriber = ChannelSubscriber
# Initialize state variables
self.sim_env = None
self._env_wrapper = None
self._lowstate = None
self._lowstate_lock = threading.Lock()
self._shutdown_event = threading.Event()
self.subscribe_thread = None
self.remote_controller = self.RemoteController()
self.arm_ik = G1_29_ArmIK()
self.arm_ik = G1_29_ArmIK() if config.gravity_compensation else None
def _subscribe_motor_state(self): # polls robot state @ 250Hz
# Lower-body controller loaded dynamically
self.controller: LocomotionController | None = make_locomotion_controller(config.controller)
# Controller thread state
self._controller_thread = None
self._controller_action_lock = threading.Lock()
self.controller_input = default_remote_input()
self.controller_output = {}
def _subscribe_lowstate(self): # polls robot state @ 250Hz
while not self._shutdown_event.is_set():
start_time = time.time()
@@ -143,11 +170,11 @@ class UnitreeG1(Robot):
lowstate = G1_29_LowState()
# Capture motor states using jointindex
for id in G1_29_JointIndex:
lowstate.motor_state[id].q = msg.motor_state[id].q
lowstate.motor_state[id].dq = msg.motor_state[id].dq
lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est
lowstate.motor_state[id].temperature = msg.motor_state[id].temperature
for joint in G1_29_JointIndex:
lowstate.motor_state[joint].q = msg.motor_state[joint].q
lowstate.motor_state[joint].dq = msg.motor_state[joint].dq
lowstate.motor_state[joint].tau_est = msg.motor_state[joint].tau_est
lowstate.motor_state[joint].temperature = msg.motor_state[joint].temperature
# Capture IMU state
lowstate.imu_state.quaternion = list(msg.imu_state.quaternion)
@@ -162,31 +189,106 @@ class UnitreeG1(Robot):
# Capture mode_machine
lowstate.mode_machine = msg.mode_machine
self._lowstate = lowstate
with self._lowstate_lock:
self._lowstate = lowstate
current_time = time.time()
all_t_elapsed = current_time - start_time
sleep_time = max(0, (self.control_dt - all_t_elapsed)) # maintain constant control dt
time.sleep(sleep_time)
def publish_lowcmd(
self,
action: RobotAction,
kp: np.ndarray | list[float] | None = None,
kd: np.ndarray | list[float] | None = None,
tau: np.ndarray | list[float] | None = None,
) -> None: # writes robot command whenever requested
for motor in G1_29_JointIndex:
key = f"{motor.name}.q"
if key in action:
self.msg.motor_cmd[motor.value].q = action[key]
self.msg.motor_cmd[motor.value].qd = 0
self.msg.motor_cmd[motor.value].kp = (
kp[motor.value] if kp is not None else self.kp[motor.value]
)
self.msg.motor_cmd[motor.value].kd = (
kd[motor.value] if kd is not None else self.kd[motor.value]
)
self.msg.motor_cmd[motor.value].tau = tau[motor.value] if tau is not None else 0.0
self.msg.crc = self.crc.Crc(self.msg)
self.lowcmd_publisher.Write(self.msg)
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
if self.controller is None:
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
def calibrate(self) -> None: # robot is already calibrated
arm_features = {f"{G1_29_JointArmIndex(motor).name}.q": float for motor in G1_29_JointArmIndex}
remote_features = dict.fromkeys(REMOTE_AXES, float)
return {**arm_features, **remote_features}
def _controller_loop(self):
"""Background thread that runs controller at policy's control_dt."""
control_dt = self.controller.control_dt
logger.info(f"Controller loop starting with control_dt={control_dt} ({1.0 / control_dt:.1f}Hz)")
loop_count = 0
last_log_time = time.time()
while not self._shutdown_event.is_set():
start_time = time.time()
with self._lowstate_lock:
lowstate = self._lowstate
if lowstate is not None and self.controller is not None:
loop_count += 1
if time.time() - last_log_time >= 5.0: # Log every 5 seconds
actual_hz = loop_count / (time.time() - last_log_time)
logger.info(
f"Controller actual rate: {actual_hz:.1f}Hz (target: {1.0 / control_dt:.1f}Hz)"
)
loop_count = 0
last_log_time = time.time()
# Read controller input snapshot
with self._controller_action_lock:
controller_input = dict(self.controller_input)
# Run controller step
controller_action = self.controller.run_step(controller_input, lowstate)
# Write controller output snapshot
with self._controller_action_lock:
self.controller_output = dict(controller_action)
ctrl_kp = self.controller.kp if hasattr(self.controller, "kp") else None
ctrl_kd = self.controller.kd if hasattr(self.controller, "kd") else None
self.publish_lowcmd(controller_action, kp=ctrl_kp, kd=ctrl_kd)
elapsed = time.time() - start_time
sleep_time = max(0, control_dt - elapsed)
time.sleep(sleep_time)
def calibrate(self) -> None:
# TODO: implement g1_29 calibration
pass
def configure(self) -> None:
pass
def connect(self, calibrate: bool = True) -> None: # connect to DDS
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
LowCmd_ as hg_LowCmd,
LowState_ as hg_LowState,
)
from unitree_sdk2py.utils.crc import CRC
# Initialize DDS channel and simulation environment
if self.config.is_simulation:
self._ChannelFactoryInitialize(0, "lo")
@@ -194,7 +296,7 @@ class UnitreeG1(Robot):
# Extract the actual gym env from the dict structure
self.sim_env = self._env_wrapper["hub_env"][0].envs[0]
else:
self._ChannelFactoryInitialize(0)
self._ChannelFactoryInitialize(0, config=self.config)
# Initialize direct motor control interface
self.lowcmd_publisher = self._ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
@@ -203,7 +305,7 @@ class UnitreeG1(Robot):
self.lowstate_subscriber.Init()
# Start subscribe thread to read robot state
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
self.subscribe_thread = threading.Thread(target=self._subscribe_lowstate)
self.subscribe_thread.start()
# Connect cameras
@@ -220,25 +322,53 @@ class UnitreeG1(Robot):
# Wait for first state message to arrive
lowstate = None
deadline = time.time() + 10.0
while lowstate is None:
lowstate = self._lowstate
with self._lowstate_lock:
lowstate = self._lowstate
if lowstate is None:
if time.time() > deadline:
raise TimeoutError("Timed out waiting for robot state (10s)")
logger.warning("[UnitreeG1] Waiting for robot state...")
time.sleep(0.01)
logger.warning("[UnitreeG1] Waiting for robot state...")
logger.warning("[UnitreeG1] Connected to robot.")
logger.info("[UnitreeG1] Connected to robot.")
self.msg.mode_machine = lowstate.mode_machine
# Initialize all motors with unified kp/kd from config
self.kp = np.array(self.config.kp, dtype=np.float32)
self.kd = np.array(self.config.kd, dtype=np.float32)
for id in G1_29_JointIndex:
self.msg.motor_cmd[id].mode = 1
self.msg.motor_cmd[id].kp = self.kp[id.value]
self.msg.motor_cmd[id].kd = self.kd[id.value]
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
for joint in G1_29_JointIndex:
self.msg.motor_cmd[joint].mode = 1
self.msg.motor_cmd[joint].kp = self.kp[joint.value]
self.msg.motor_cmd[joint].kd = self.kd[joint.value]
self.msg.motor_cmd[joint].q = lowstate.motor_state[joint.value].q
# Start controller thread if enabled
if self.controller is not None:
self._controller_thread = threading.Thread(target=self._controller_loop, daemon=True)
self._controller_thread.start()
fps = int(1.0 / self.controller.control_dt)
logger.info(f"Controller thread started ({fps}Hz)")
def _send_zero_torque(self) -> None:
"""Send a zero-gain command to make joints passive before shutting down."""
try:
with self._lowstate_lock:
lowstate = self._lowstate
if lowstate is None:
return
action = {f"{motor.name}.q": lowstate.motor_state[motor.value].q for motor in G1_29_JointIndex}
zero_gains = np.zeros(29, dtype=np.float32)
self.publish_lowcmd(action, kp=zero_gains, kd=zero_gains, tau=zero_gains)
logger.info("Sent zero-torque command for safe shutdown")
except Exception as e:
logger.warning(f"Failed to send zero-torque on disconnect: {e}")
def disconnect(self):
# Put robot in passive mode before stopping threads
if not self.config.is_simulation:
self._send_zero_torque()
# Signal thread to stop and unblock any waits
self._shutdown_event.set()
@@ -248,6 +378,12 @@ class UnitreeG1(Robot):
if self.subscribe_thread.is_alive():
logger.warning("Subscribe thread did not stop cleanly")
# Wait for controller thread to finish
if self._controller_thread is not None:
self._controller_thread.join(timeout=2.0)
if self._controller_thread.is_alive():
logger.warning("Controller thread did not stop cleanly")
# Close simulation environment
if self.config.is_simulation and self.sim_env is not None:
try:
@@ -274,7 +410,8 @@ class UnitreeG1(Robot):
cam.disconnect()
def get_observation(self) -> RobotObservation:
lowstate = self._lowstate
with self._lowstate_lock:
lowstate = self._lowstate
if lowstate is None:
return {}
@@ -313,14 +450,9 @@ class UnitreeG1(Robot):
obs["imu.rpy.pitch"] = lowstate.imu_state.rpy[1]
obs["imu.rpy.yaw"] = lowstate.imu_state.rpy[2]
# Controller - parse wireless_remote and add to obs
if lowstate.wireless_remote and len(lowstate.wireless_remote) >= 24:
self.remote_controller.set(lowstate.wireless_remote)
obs["remote.buttons"] = self.remote_controller.button.copy()
obs["remote.lx"] = self.remote_controller.lx
obs["remote.ly"] = self.remote_controller.ly
obs["remote.rx"] = self.remote_controller.rx
obs["remote.ry"] = self.remote_controller.ry
# Wireless remote (raw bytes for teleoperator)
if lowstate.wireless_remote:
obs["wireless_remote"] = lowstate.wireless_remote
# Cameras - read images from ZMQ cameras
for cam_name, cam in self._cameras.items():
@@ -328,73 +460,63 @@ class UnitreeG1(Robot):
return obs
def send_action(self, action: RobotAction) -> RobotAction:
action_to_publish = action
if self.controller is not None:
# Controller thread owns legs/waist. Here we only update joystick inputs
# and publish arm targets from the teleoperator.
self._update_controller_action(action)
arm_prefixes = tuple(j.name for j in G1_29_JointArmIndex)
action_to_publish = {
key: value
for key, value in action.items()
if key.endswith(".q") and key.startswith(arm_prefixes)
}
tau = None
if self.config.gravity_compensation and self.arm_ik is not None:
tau = np.zeros(29, dtype=np.float32)
action_np = np.array(
[
action_to_publish.get(f"{joint.name}.q", self.msg.motor_cmd[joint.value].q)
for joint in G1_29_JointArmIndex
],
dtype=np.float32,
)
arm_tau = self.arm_ik.solve_tau(action_np)
arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value
for joint in G1_29_JointArmIndex:
local_idx = joint.value - arm_start_idx
tau[joint.value] = arm_tau[local_idx]
self.publish_lowcmd(action_to_publish, tau=tau)
return action
def _update_controller_action(self, action: RobotAction) -> None:
"""Update controller input state from incoming teleop action."""
with self._controller_action_lock:
for key in REMOTE_KEYS:
if key in action:
self.controller_input[key] = action[key]
@property
def is_calibrated(self) -> bool:
return True
@property
def is_connected(self) -> bool:
return self._lowstate is not None
with self._lowstate_lock:
return self._lowstate is not None
@property
def _motors_ft(self) -> dict[str, type]:
"""Joint positions for all 29 joints."""
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
@property
def cameras(self) -> dict:
return self._cameras
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
def send_action(self, action: RobotAction) -> RobotAction:
for motor in G1_29_JointIndex:
key = f"{motor.name}.q"
if key in action:
self.msg.motor_cmd[motor.value].q = action[key]
self.msg.motor_cmd[motor.value].qd = 0
self.msg.motor_cmd[motor.value].kp = self.kp[motor.value]
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
self.msg.motor_cmd[motor.value].tau = 0
if self.config.gravity_compensation:
# Build action_np from motor commands (arm joints are indices 15-28, local indices 0-13)
action_np = np.zeros(14)
arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value # 15
for joint in G1_29_JointArmIndex:
local_idx = joint.value - arm_start_idx
action_np[local_idx] = self.msg.motor_cmd[joint.value].q
tau = self.arm_ik.solve_tau(action_np)
# Apply tau back to motor commands
for joint in G1_29_JointArmIndex:
local_idx = joint.value - arm_start_idx
self.msg.motor_cmd[joint.value].tau = tau[local_idx]
self.msg.crc = self.crc.Crc(self.msg)
self.lowcmd_publisher.Write(self.msg)
return action
def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion
"""Get gravity orientation from quaternion."""
qw = quaternion[0]
qx = quaternion[1]
qy = quaternion[2]
qz = quaternion[3]
gravity_orientation = np.zeros(3)
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
return gravity_orientation
def reset(
self,
control_dt: float | None = None,
@@ -407,15 +529,9 @@ class UnitreeG1(Robot):
if self.config.is_simulation and self.sim_env is not None:
self.sim_env.reset()
for motor in G1_29_JointIndex:
self.msg.motor_cmd[motor.value].q = default_positions[motor.value]
self.msg.motor_cmd[motor.value].qd = 0
self.msg.motor_cmd[motor.value].kp = self.kp[motor.value]
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
self.msg.motor_cmd[motor.value].tau = 0
self.msg.crc = self.crc.Crc(self.msg)
self.lowcmd_publisher.Write(self.msg)
self.publish_lowcmd(
{f"{motor.name}.q": float(default_positions[motor.value]) for motor in G1_29_JointIndex}
)
else:
total_time = 3.0
num_steps = int(total_time / control_dt)
@@ -446,4 +562,8 @@ class UnitreeG1(Robot):
sleep_time = max(0, control_dt - elapsed)
time.sleep(sleep_time)
# Reset controller internal state (gait phase, obs history, etc.)
if self.controller is not None and hasattr(self.controller, "reset"):
self.controller.reset()
logger.info("Reached default position")

View File

@@ -22,6 +22,8 @@ import zmq
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
# Module-level ZMQ state mirrors the Unitree SDK's global ChannelFactory Singleton.
# Only one robot connection per process is supported.
_ctx: zmq.Context | None = None
_lowcmd_sock: zmq.Socket | None = None
_lowstate_sock: zmq.Socket | None = None
@@ -97,17 +99,22 @@ def lowcmd_to_dict(topic: str, msg: Any) -> dict[str, Any]:
}
def ChannelFactoryInitialize(*args: Any, **kwargs: Any) -> None: # noqa: N802
def ChannelFactoryInitialize(domain_id: int = 0, config: Any = None) -> None: # noqa: N802
"""
Initialize ZMQ sockets for robot communication.
This function mimics the Unitree SDK's ChannelFactoryInitialize but uses
ZMQ sockets to connect to the robot server bridge instead of DDS.
Args:
domain_id: Ignored (for API compatibility with Unitree SDK)
config: UnitreeG1Config instance with robot_ip
"""
global _ctx, _lowcmd_sock, _lowstate_sock
# read socket config
config = UnitreeG1Config()
if config is None:
config = UnitreeG1Config()
robot_ip = config.robot_ip
ctx = zmq.Context.instance()

View File

@@ -369,6 +369,8 @@ def record_loop(
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
elif policy is None and isinstance(teleop, Teleoperator):
if robot.name == "unitree_g1":
teleop.send_feedback(obs)
act = teleop.get_action()
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
@@ -556,10 +558,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
):
log_say("Reset the environment", cfg.play_sounds)
# reset g1 robot
if robot.name == "unitree_g1":
robot.reset()
record_loop(
robot=robot,
events=events,

View File

@@ -60,6 +60,7 @@ import rerun as rr
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.processor import (
RobotAction,
@@ -153,7 +154,6 @@ def teleop_loop(
display_len = max(len(key) for key in robot.action_features)
start = time.perf_counter()
while True:
loop_start = time.perf_counter()
@@ -163,6 +163,9 @@ def teleop_loop(
# given that it is the identity processor as default
obs = robot.get_observation()
if robot.name == "unitree_g1":
teleop.send_feedback(obs)
# Get teleop action
raw_action = teleop.get_action()

View File

@@ -209,7 +209,11 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Use accelerator's device
device = accelerator.device
torch.backends.cudnn.benchmark = True
if cfg.cudnn_deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
else:
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: main process downloads first to avoid race conditions

View File

@@ -19,3 +19,13 @@ from .exo_calib import ExoskeletonCalibration, ExoskeletonJointCalibration
from .exo_ik import ExoskeletonIKHelper
from .exo_serial import ExoskeletonArm
from .unitree_g1 import UnitreeG1Teleoperator
__all__ = [
"ExoskeletonArmPortConfig",
"ExoskeletonCalibration",
"ExoskeletonIKHelper",
"ExoskeletonJointCalibration",
"ExoskeletonArm",
"UnitreeG1Teleoperator",
"UnitreeG1TeleoperatorConfig",
]

View File

@@ -35,6 +35,9 @@ import serial
logger = logging.getLogger(__name__)
ADC_MAX = 2**12 - 1
ADC_HALF = ADC_MAX / 2
# exoskeleton joint names -> ADC channel pairs. TODO: add wrist pitch and wrist yaw
JOINTS = {
"shoulder_pitch": (0, 1),
@@ -59,7 +62,7 @@ class ExoskeletonCalibration:
version: int = 2
side: str = ""
adc_max: int = 2**12 - 1
adc_max: int = ADC_MAX
joints: list[ExoskeletonJointCalibration] = field(default_factory=list)
def to_dict(self) -> dict:
@@ -92,7 +95,7 @@ class ExoskeletonCalibration:
return cls(
version=data.get("version", 2),
side=data.get("side", ""),
adc_max=data.get("adc_max", 2**12 - 1),
adc_max=data.get("adc_max", ADC_MAX),
joints=joints,
)
@@ -112,11 +115,8 @@ class CalibParams:
def normalize_angle(angle: float) -> float:
while angle > np.pi:
angle -= 2 * np.pi
while angle < -np.pi:
angle += 2 * np.pi
return angle
"""Normalize angle to [-pi, pi]."""
return float(np.arctan2(np.sin(angle), np.cos(angle)))
def joint_z_and_angle(raw16: list[int], j: ExoskeletonJointCalibration) -> tuple[np.ndarray, float]:
@@ -125,7 +125,7 @@ def joint_z_and_angle(raw16: list[int], j: ExoskeletonJointCalibration) -> tuple
"""
pair = JOINTS[j.name]
s, c = raw16[pair[0]], raw16[pair[1]] # get sin and cos
p = np.array([float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2]) # center the raw values
p = np.array([float(c) - ADC_HALF, float(s) - ADC_HALF]) # center the raw values
z = np.asarray(j.T) @ (
p - np.asarray(j.center_fit)
) # center the ellipse and invert the transformation matrix to get unit circle coords
@@ -167,7 +167,7 @@ def run_exo_calibration(
def read_joint_point(raw16: list[int], pair: tuple[int, int]):
s, c = raw16[pair[0]], raw16[pair[1]]
return float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2, float(s), float(c)
return float(c) - ADC_HALF, float(s) - ADC_HALF, float(s), float(c)
def select_fit_subset(xs, ys):
"""Select and filter points for ellipse fitting. Trims outliers by radius and downsamples."""
@@ -317,7 +317,7 @@ def run_exo_calibration(
calib = ExoskeletonCalibration(
version=2,
side=side,
adc_max=2**12 - 1,
adc_max=ADC_MAX,
joints=[
ExoskeletonJointCalibration(
name=j["name"],
@@ -367,8 +367,8 @@ def run_exo_calibration(
state["win_s"].append(s_raw)
state["win_c"].append(c_raw)
if len(state["win_s"]) >= max(3, params.median_window):
state["ys"].append(running_median(state["win_s"]) - (2**12 - 1) / 2)
state["xs"].append(running_median(state["win_c"]) - (2**12 - 1) / 2)
state["ys"].append(running_median(state["win_s"]) - ADC_HALF)
state["xs"].append(running_median(state["win_c"]) - ADC_HALF)
else:
jdata = joints_out[-1]
z = np.array(jdata["T"]) @ (np.array([x_raw, y_raw]) - np.array(jdata["center_fit"]))

View File

@@ -25,8 +25,8 @@ from dataclasses import dataclass
import numpy as np
from lerobot.robots.unitree_g1.g1_kinematics import G1_29_ArmIK
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
from .exo_calib import JOINTS

View File

@@ -32,25 +32,29 @@ def parse_raw16(line: bytes) -> list[int] | None:
if len(parts) < 16:
return None
return [int(x) for x in parts[:16]]
except Exception:
except (ValueError, IndexError):
return None
def read_raw_from_serial(ser) -> list[int] | None:
"""Read latest sample from serial; if buffer is backed up, keep only the newest."""
last = None
while ser.in_waiting > 0:
b = ser.readline()
if not b:
break
raw16 = parse_raw16(b)
if raw16 is not None:
last = raw16
if last is None:
b = ser.readline()
if b:
last = parse_raw16(b)
return last
try:
last = None
while ser.in_waiting > 0:
b = ser.readline()
if not b:
break
raw16 = parse_raw16(b)
if raw16 is not None:
last = raw16
if last is None:
b = ser.readline()
if b:
last = parse_raw16(b)
return last
except serial.SerialException as e:
logger.warning(f"Serial read error: {e}")
return None
@dataclass
@@ -115,5 +119,6 @@ class ExoskeletonArm:
return {} if raw is None else exo_raw_to_angles(raw, self.calibration)
def calibrate(self) -> None:
ser = self._ser
self.calibration = run_exo_calibration(ser, self.side, self.calibration_fpath)
if not self.is_connected:
raise RuntimeError("Cannot calibrate: exoskeleton not connected")
self.calibration = run_exo_calibration(self._ser, self.side, self.calibration_fpath)

View File

@@ -17,9 +17,22 @@
import logging
import time
from functools import cached_property
from typing import TYPE_CHECKING, Any
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from lerobot.robots.unitree_g1.g1_utils import REMOTE_AXES, G1_29_JointArmIndex
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS
from lerobot.utils.import_utils import _unitree_sdk_available
if TYPE_CHECKING or _unitree_sdk_available:
from unitree_sdk2py.utils.joystick import Joystick
else:
class Joystick:
def __init__(self):
raise ImportError(
"unitree_sdk2py is required for RemoteController. Install with: pip install unitree_sdk2py"
)
from ..teleoperator import Teleoperator
from .config_unitree_g1 import UnitreeG1TeleoperatorConfig
@@ -29,6 +42,120 @@ from .exo_serial import ExoskeletonArm
logger = logging.getLogger(__name__)
class RemoteController:
"""Unitree remote controller data parser for joystick and button state."""
# ADC parameters for exoskeleton joystick (12-bit ADC)
ADC_MAX = 4095
ADC_HALF = ADC_MAX / 2
JOYSTICK_X_IDX = 11 # X axis in raw ADC array
JOYSTICK_BTN_IDX = 12 # Button in raw ADC array
JOYSTICK_Y_IDX = 13 # Y axis in raw ADC array
# Map SDK named buttons to positional indices matching the wireless_remote
# byte layout (little-endian uint16 from bytes 2-3).
_BUTTON_MAP: list[str] = [
"RB",
"LB",
"start",
"back",
"RT",
"LT",
"",
"",
"A",
"B",
"X",
"Y",
"up",
"right",
"down",
"left",
]
def __init__(self):
self.lx = 0.0
self.ly = 0.0
self.rx = 0.0
self.ry = 0.0
self.button = [0] * 16
self.remote_action = dict.fromkeys(REMOTE_AXES, 0.0)
# SDK joystick parser for wireless remote bytes
self._joystick = Joystick()
# Disable axis smoothing and deadzone to preserve raw values
for axis in (self._joystick.lx, self._joystick.ly, self._joystick.rx, self._joystick.ry):
axis.smooth = 1.0
axis.deadzone = 0.0
# Joystick center calibration (read at connect time)
self.left_center_x = self.ADC_HALF
self.left_center_y = self.ADC_HALF
self.right_center_x = self.ADC_HALF
self.right_center_y = self.ADC_HALF
# Whether to use exo joystick (detected at connect time)
self.use_left_exo_joystick = False
self.use_right_exo_joystick = False
def _sync_remote_action(self) -> None:
self.remote_action.update(zip(REMOTE_AXES, (self.lx, self.ly, self.rx, self.ry), strict=True))
def calibrate_center(self, raw16: list[int] | None, side: str) -> None:
if raw16 is None or len(raw16) < 16:
logger.info(f"{side.capitalize()} exo joystick: no data available")
return
btn_val = raw16[self.JOYSTICK_BTN_IDX]
logger.info(f"{side.capitalize()} exo joystick button ADC: {btn_val} (threshold: {self.ADC_HALF})")
if btn_val <= self.ADC_HALF:
logger.info(f"{side.capitalize()} exo joystick not detected (button below threshold)")
return
x = raw16[self.JOYSTICK_X_IDX]
y = raw16[self.JOYSTICK_Y_IDX]
if side == "left":
self.use_left_exo_joystick = True
self.left_center_x, self.left_center_y = x, y
else:
self.use_right_exo_joystick = True
self.right_center_x, self.right_center_y = x, y
logger.info(f"{side.capitalize()} exo joystick enabled, center: x={x}, y={y}")
def set_from_exo(self, raw16: list[int] | None, side: str) -> None:
if raw16 is None or len(raw16) < 16:
return
if side == "left":
if not self.use_left_exo_joystick:
return
self.lx = (raw16[self.JOYSTICK_X_IDX] - self.left_center_x) / self.ADC_HALF
self.ly = (raw16[self.JOYSTICK_Y_IDX] - self.left_center_y) / self.ADC_HALF
self.button[4] = 1 if raw16[self.JOYSTICK_BTN_IDX] < self.ADC_HALF else 0
return
if not self.use_right_exo_joystick:
return
self.rx = (raw16[self.JOYSTICK_X_IDX] - self.right_center_x) / self.ADC_HALF
self.ry = (raw16[self.JOYSTICK_Y_IDX] - self.right_center_y) / self.ADC_HALF
self.button[0] = 1 if raw16[self.JOYSTICK_BTN_IDX] < self.ADC_HALF else 0
def set_from_wireless(self, wireless_remote: bytes) -> None:
"""Parse Unitree wireless remote raw bytes into joystick + button state."""
if len(wireless_remote) < 24:
return
self._joystick.extract(wireless_remote)
self.lx = self._joystick.lx.data
self.ly = self._joystick.ly.data
self.rx = self._joystick.rx.data
self.ry = self._joystick.ry.data
for i, name in enumerate(self._BUTTON_MAP):
if name:
self.button[i] = getattr(self._joystick, name).data
class UnitreeG1Teleoperator(Teleoperator):
"""
Bimanual exoskeleton arms teleoperator for Unitree G1 arms.
@@ -43,6 +170,13 @@ class UnitreeG1Teleoperator(Teleoperator):
def __init__(self, config: UnitreeG1TeleoperatorConfig):
super().__init__(config)
self.config = config
left_exo_enabled = bool(config.left_arm_config.port.strip())
right_exo_enabled = bool(config.right_arm_config.port.strip())
if left_exo_enabled != right_exo_enabled:
raise ValueError(
"Invalid exo config: set both left/right exo ports, or leave both empty for remote-only mode."
)
self._arm_control_enabled = left_exo_enabled and right_exo_enabled
# Setup calibration directory
self.calibration_dir = (
@@ -70,24 +204,37 @@ class UnitreeG1Teleoperator(Teleoperator):
)
self.ik_helper: ExoskeletonIKHelper | None = None
self.remote_controller = RemoteController()
@cached_property
def action_features(self) -> dict[str, type]:
return {f"{name}.q": float for name in self._g1_joint_names}
remote_features = dict.fromkeys(self.remote_controller.remote_action, float)
if not self._arm_control_enabled:
return remote_features
joint_features = {f"{name}.q": float for name in self._g1_arm_joint_names}
return {**joint_features, **remote_features}
@cached_property
def feedback_features(self) -> dict[str, type]:
return {}
return {"wireless_remote": bytes}
@property
def is_connected(self) -> bool:
if not self._arm_control_enabled:
return True
return self.left_arm.is_connected and self.right_arm.is_connected
@property
def is_calibrated(self) -> bool:
if not self._arm_control_enabled:
return True
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def connect(self, calibrate: bool = True) -> None:
if not self._arm_control_enabled:
logger.warning("Exo ports not fully configured; teleop will send joystick only (no arm actions)")
return
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@@ -95,6 +242,13 @@ class UnitreeG1Teleoperator(Teleoperator):
self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints)
logger.info("IK helper initialized")
time.sleep(0.1) # Give serial time to populate buffer
left_raw = self.left_arm.read_raw()
right_raw = self.right_arm.read_raw()
self.remote_controller.calibrate_center(left_raw, "left")
self.remote_controller.calibrate_center(right_raw, "right")
def calibrate(self) -> None:
if not self.left_arm.is_calibrated:
logger.info("Starting calibration for left arm...")
@@ -115,12 +269,33 @@ class UnitreeG1Teleoperator(Teleoperator):
pass
def get_action(self) -> dict[str, float]:
left_angles = self.left_arm.get_angles()
right_angles = self.right_arm.get_angles()
return self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles)
joint_action = {}
left_raw = None
right_raw = None
if self._arm_control_enabled:
left_raw = self.left_arm.read_raw()
right_raw = self.right_arm.read_raw()
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Exoskeleton arms do not support feedback")
left_angles = self.left_arm.get_angles()
right_angles = self.right_arm.get_angles()
joint_action = self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles)
# Wireless remote has priority when non-zero; otherwise, use exo joystick.
rc = self.remote_controller
wireless_active = (
abs(rc.lx) > 1e-3 or abs(rc.ly) > 1e-3 or abs(rc.rx) > 1e-3 or abs(rc.ry) > 1e-3
) or any(rc.button)
if self._arm_control_enabled and not wireless_active:
rc.set_from_exo(left_raw, "left")
rc.set_from_exo(right_raw, "right")
rc._sync_remote_action()
return {**joint_action, **rc.remote_action}
def send_feedback(self, feedback: dict[str, Any]) -> None:
wireless_remote = feedback.get("wireless_remote")
if wireless_remote is not None:
self.remote_controller.set_from_wireless(wireless_remote)
def disconnect(self) -> None:
self.left_arm.disconnect()
@@ -153,5 +328,5 @@ class UnitreeG1Teleoperator(Teleoperator):
print("\n\nVisualization stopped.")
@cached_property
def _g1_joint_names(self) -> list[str]:
return [joint.name for joint in G1_29_JointIndex]
def _g1_arm_joint_names(self) -> list[str]:
return [joint.name for joint in G1_29_JointArmIndex]

View File

@@ -74,6 +74,8 @@ _peft_available = is_package_available("peft")
_scipy_available = is_package_available("scipy")
_reachy2_sdk_available = is_package_available("reachy2_sdk")
_can_available = is_package_available("python-can", "can")
_unitree_sdk_available = is_package_available("unitree-sdk2py", "unitree_sdk2py")
_pygame_available = is_package_available("pygame")
def make_device_from_device_class(config: ChoiceRegistry) -> Any:

View File

@@ -170,6 +170,7 @@ def test_async_read(index_or_path):
assert isinstance(img, np.ndarray)
@pytest.mark.skip("Skipping test: async_read 0 timeout behavior may be flaky/non-deterministic.")
def test_async_read_timeout():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)

View File

@@ -1,624 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: E402
"""Test script for Multi-Task DiT policy.
To run tests locally:
python -m pytest tests/policies/multi_task_dit/test_multi_task_dit.py -v
"""
import os
import pytest
import torch
from torch import Tensor
pytest.importorskip("transformers")
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local transformers installation and is not meant for CI",
)
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
make_multi_task_dit_pre_post_processors,
)
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
from lerobot.utils.random_utils import seeded_context, set_seed
@pytest.fixture(autouse=True)
def set_random_seed():
seed = 17
set_seed(seed)
def create_train_batch(
batch_size: int = 2,
n_obs_steps: int = 2,
horizon: int = 16,
state_dim: int = 10,
action_dim: int = 10,
height: int = 224,
width: int = 224,
) -> dict[str, Tensor]:
"""Create a training batch with visual input and text."""
return {
"observation.state": torch.randn(batch_size, n_obs_steps, state_dim),
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, n_obs_steps, 3, height, width),
ACTION: torch.randn(batch_size, horizon, action_dim),
"task": ["pick up the cube"] * batch_size,
}
def create_observation_batch(
batch_size: int = 2, state_dim: int = 10, height: int = 224, width: int = 224
) -> dict:
"""Create observation batch for inference for a single timestep."""
return {
"observation.state": torch.randn(batch_size, state_dim),
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, height, width),
"task": ["pick up the red cube"] * batch_size,
}
def create_config(
state_dim: int = 10,
action_dim: int = 10,
n_obs_steps: int = 2,
horizon: int = 16,
n_action_steps: int = 8,
with_visual: bool = True,
height: int = 224,
width: int = 224,
) -> MultiTaskDiTConfig:
"""Create a MultiTaskDiT config for testing.
Args:
state_dim: Dimension of state observations
action_dim: Dimension of actions
n_obs_steps: Number of observation steps
horizon: Action prediction horizon
n_action_steps: Number of action steps to execute
with_visual: Whether to include visual input (default: True)
height: Image height (only used if with_visual=True)
width: Image width (only used if with_visual=True)
"""
input_features = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}
if with_visual:
input_features[f"{OBS_IMAGES}.laptop"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(3, height, width)
)
config = MultiTaskDiTConfig(
input_features=input_features,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
# Use smaller model for faster tests
hidden_dim=128,
num_layers=2,
num_heads=4,
)
config.validate_features()
return config
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_dim: int):
"""Test forward pass (training mode)."""
n_obs_steps = 2
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
policy = MultiTaskDiTPolicy(config=config)
policy.train()
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
horizon=horizon,
state_dim=state_dim,
action_dim=action_dim,
)
# Process batch through preprocessor to tokenize task text
processed_batch = preprocessor(batch)
# Test forward pass
loss, _ = policy.forward(processed_batch)
assert loss is not None
assert loss.item() is not None
assert loss.shape == ()
# Test backward pass
loss.backward()
def test_multi_task_dit_pre_post_processors():
"""Test pre and post processors for Multi-Task DiT policy."""
state_dim = 10
action_dim = 8
n_obs_steps = 2
horizon = 16
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=8,
)
config.device = "cpu"
# Set normalization mode to match the stats we're providing
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats
"ACTION": NormalizationMode.MIN_MAX,
}
# Create dataset stats for normalization
dataset_stats = {
"observation.state": {
"mean": torch.zeros(state_dim),
"std": torch.ones(state_dim),
},
"action": {
"min": torch.full((action_dim,), -1.0),
"max": torch.ones(action_dim),
},
}
# Create processors
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(
config=config, dataset_stats=dataset_stats
)
# Test preprocessor with sample data
batch = {
"observation.state": torch.randn(state_dim),
f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224),
ACTION: torch.randn(action_dim),
"task": "pick up the cube",
}
processed_batch = preprocessor(batch)
# Check that data is batched
assert processed_batch["observation.state"].shape == (1, state_dim)
assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224)
assert processed_batch[ACTION].shape == (1, action_dim)
# Check that task text was tokenized
assert OBS_LANGUAGE_TOKENS in processed_batch
assert OBS_LANGUAGE_ATTENTION_MASK in processed_batch
assert processed_batch[OBS_LANGUAGE_TOKENS].shape[0] == 1 # batch dimension
assert processed_batch[OBS_LANGUAGE_ATTENTION_MASK].shape[0] == 1 # batch dimension
# Check that data is on correct device
assert processed_batch["observation.state"].device.type == "cpu"
assert processed_batch[f"{OBS_IMAGES}.laptop"].device.type == "cpu"
assert processed_batch[ACTION].device.type == "cpu"
# Test postprocessor with sample action (PolicyAction is just a torch.Tensor)
action = torch.randn(1, action_dim)
processed_action = postprocessor(action)
# Check that action is unnormalized and on CPU
assert processed_action.shape == (1, action_dim)
assert processed_action.device.type == "cpu"
def test_multi_task_dit_pre_post_processors_normalization():
"""Test that normalization and unnormalization work correctly with simple sanity check numbers."""
state_dim = 3
action_dim = 2
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=2,
horizon=16,
n_action_steps=8,
)
config.device = "cpu"
# Set normalization mode to match the stats we're providing
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats
"ACTION": NormalizationMode.MIN_MAX,
}
# Use simple stats that will actually transform the values
dataset_stats = {
"observation.state": {
"mean": torch.full((state_dim,), 5.0),
"std": torch.full((state_dim,), 2.0),
},
"action": {
"min": torch.zeros(action_dim),
"max": torch.full((action_dim,), 2.0),
},
}
# Create processors
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(
config=config, dataset_stats=dataset_stats
)
# Use simple input values
input_state = torch.tensor([7.0, 5.0, 3.0]) # Will normalize to [1.0, 0.0, -1.0]
input_action = torch.tensor([1.0, 2.0]) # Will normalize to [0.0, 1.0]
batch = {
"observation.state": input_state,
f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224),
ACTION: input_action,
"task": "test task",
}
# Process through preprocessor
processed_batch = preprocessor(batch)
# State normalization: (x - mean) / std
expected_normalized_state = torch.tensor([1.0, 0.0, -1.0])
assert torch.allclose(processed_batch["observation.state"][0], expected_normalized_state, atol=1e-5)
# Action normalization: (x - min) / (max - min) * 2 - 1
expected_normalized_action = torch.tensor([0.0, 1.0])
assert torch.allclose(processed_batch[ACTION][0], expected_normalized_action, atol=1e-5)
# Test unnormalization: should recover original values
normalized_action_tensor = processed_batch[ACTION][0:1] # Keep batch dimension
unnormalized_action = postprocessor(normalized_action_tensor)
# Should recover original action values
assert torch.allclose(unnormalized_action[0], input_action, atol=1e-4)
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
"""Test select_action (inference mode)."""
n_obs_steps = 2
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
policy = MultiTaskDiTPolicy(config=config)
policy.eval()
policy.reset() # Reset queues before inference
# Create processors - use IDENTITY normalization when no stats provided
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Process observation through preprocessor
processed_obs = preprocessor(observation_batch)
selected_action = policy.select_action(processed_obs)
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
processed_action = postprocessor(selected_action)
assert processed_action.shape == (batch_size, action_dim)
def test_multi_task_dit_policy_diffusion_objective():
"""Test policy with diffusion objective."""
batch_size = 2
state_dim = 10
action_dim = 10
n_obs_steps = 2
horizon = 16
n_action_steps = 8
input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config = MultiTaskDiTConfig(
input_features=input_features,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
# Use diffusion objective
objective="diffusion",
noise_scheduler_type="DDPM",
num_train_timesteps=100,
num_inference_steps=10,
# Smaller model for tests
hidden_dim=128,
num_layers=2,
num_heads=4,
)
config.validate_features()
policy = MultiTaskDiTPolicy(config=config)
policy.train()
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
horizon=horizon,
state_dim=state_dim,
action_dim=action_dim,
)
# Process batch through preprocessor to tokenize task text
processed_batch = preprocessor(batch)
# Test forward pass
loss, _ = policy.forward(processed_batch)
assert loss is not None
assert loss.item() is not None
# Test inference
policy.eval()
# Use IDENTITY normalization when no stats provided
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Process observation through preprocessor
processed_obs = preprocessor(observation_batch)
selected_action = policy.select_action(processed_obs)
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
processed_action = postprocessor(selected_action)
assert processed_action.shape == (batch_size, action_dim)
def test_multi_task_dit_policy_flow_matching_objective():
"""Test policy with flow matching objective."""
batch_size = 2
state_dim = 10
action_dim = 10
n_obs_steps = 2
horizon = 16
n_action_steps = 8
input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config = MultiTaskDiTConfig(
input_features=input_features,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
# Use flow matching objective
objective="flow_matching",
sigma_min=0.0,
num_integration_steps=10, # Fewer steps for faster tests
integration_method="euler",
# Smaller model for tests
hidden_dim=128,
num_layers=2,
num_heads=4,
)
config.validate_features()
policy = MultiTaskDiTPolicy(config=config)
policy.train()
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
horizon=horizon,
state_dim=state_dim,
action_dim=action_dim,
)
# Process batch through preprocessor to tokenize task text
processed_batch = preprocessor(batch)
# Test forward pass
loss, _ = policy.forward(processed_batch)
assert loss is not None
assert loss.item() is not None
# Test inference
policy.eval()
# Use IDENTITY normalization when no stats provided
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Process observation through preprocessor
processed_obs = preprocessor(observation_batch)
selected_action = policy.select_action(processed_obs)
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
processed_action = postprocessor(selected_action)
assert processed_action.shape == (batch_size, action_dim)
def test_multi_task_dit_policy_save_and_load(tmp_path):
"""Test that the policy can be saved and loaded correctly."""
root = tmp_path / "test_multi_task_dit_save_and_load"
state_dim = 10
action_dim = 10
batch_size = 2
n_obs_steps = 2
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
policy = MultiTaskDiTPolicy(config=config)
policy.eval()
# Get device before saving
device = next(policy.parameters()).device
policy.save_pretrained(root)
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
# Explicitly move loaded_policy to the same device
loaded_policy.to(device)
loaded_policy.eval()
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
horizon=horizon,
state_dim=state_dim,
action_dim=action_dim,
)
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
with torch.no_grad():
with seeded_context(12):
# Process batch through preprocessor
processed_batch = preprocessor(batch)
# Move batch to the same device as the policy
for key in processed_batch:
if isinstance(processed_batch[key], torch.Tensor):
processed_batch[key] = processed_batch[key].to(device)
# Collect policy values before saving
loss, _ = policy.forward(processed_batch)
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Process observation through preprocessor
processed_obs = preprocessor(observation_batch)
actions = policy.select_action(processed_obs)
with seeded_context(12):
# Process batch through preprocessor
processed_batch = preprocessor(batch)
# Collect policy values after loading
loaded_loss, _ = loaded_policy.forward(processed_batch)
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
processed_obs = preprocessor(loaded_observation_batch)
loaded_actions = loaded_policy.select_action(processed_obs)
# Compare state dicts
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
for k in policy.state_dict():
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
# Compare values before and after saving and loading
assert torch.allclose(loss, loaded_loss)
assert torch.allclose(actions, loaded_actions)
def test_multi_task_dit_policy_get_optim_params():
"""Test that the policy returns correct optimizer parameter groups."""
config = create_config(
state_dim=10,
action_dim=10,
n_obs_steps=2,
horizon=16,
n_action_steps=8,
)
policy = MultiTaskDiTPolicy(config=config)
param_groups = policy.get_optim_params()
# Should have 2 parameter groups: non-vision and vision encoder
assert len(param_groups) == 2
# First group is non-vision params (no lr specified, will use default)
assert "params" in param_groups[0]
assert len(param_groups[0]["params"]) > 0
# Second group is vision encoder params with different lr
assert "params" in param_groups[1]
assert "lr" in param_groups[1]
expected_lr = config.optimizer_lr * config.vision_encoder_lr_multiplier
assert param_groups[1]["lr"] == expected_lr

View File

@@ -0,0 +1,267 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Unitree G1 robot. Meant to be run in an environment where the Unitree SDK is installed."""
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from lerobot.utils.import_utils import _unitree_sdk_available
if not _unitree_sdk_available:
pytest.skip("Unitree SDK not available", allow_module_level=True)
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.g1_utils import (
NUM_MOTORS,
REMOTE_AXES,
REMOTE_BUTTONS,
REMOTE_KEYS,
G1_29_JointArmIndex,
G1_29_JointIndex,
default_remote_input,
get_gravity_orientation,
)
# ---------------------------------------------------------------------------
# Unit tests for g1_utils (no SDK needed)
# ---------------------------------------------------------------------------
class TestG1Utils:
def test_num_motors(self):
assert NUM_MOTORS == 29
def test_joint_index_count(self):
assert len(G1_29_JointIndex) == 29
def test_joint_arm_index_count(self):
assert len(G1_29_JointArmIndex) == 14
def test_arm_indices_are_subset_of_full(self):
full_values = {j.value for j in G1_29_JointIndex}
arm_values = {j.value for j in G1_29_JointArmIndex}
assert arm_values.issubset(full_values)
def test_arm_indices_start_at_15(self):
assert min(j.value for j in G1_29_JointArmIndex) == 15
assert max(j.value for j in G1_29_JointArmIndex) == 28
def test_enum_naming_consistency(self):
"""Verify all wrist joints use consistent PascalCase naming."""
wrist_joints = [j for j in G1_29_JointIndex if "Wrist" in j.name]
for j in wrist_joints:
# Should be "WristYaw", "WristPitch", "WristRoll" — no lowercase after "Wrist"
after_wrist = j.name.split("Wrist")[1]
assert after_wrist[0].isupper(), f"{j.name} has inconsistent casing after 'Wrist'"
def test_remote_keys_structure(self):
assert len(REMOTE_AXES) == 4
assert len(REMOTE_BUTTONS) == 16
assert len(REMOTE_KEYS) == 20
assert REMOTE_KEYS == REMOTE_AXES + REMOTE_BUTTONS
def test_default_remote_input(self):
d = default_remote_input()
assert len(d) == 20
assert all(v == 0.0 for v in d.values())
assert set(d.keys()) == set(REMOTE_KEYS)
def test_gravity_orientation_identity(self):
"""Quaternion [1, 0, 0, 0] (no rotation) should give gravity along -z."""
g = get_gravity_orientation([1.0, 0.0, 0.0, 0.0])
assert g.shape == (3,)
assert g.dtype == np.float32
np.testing.assert_allclose(g, [0.0, 0.0, -1.0], atol=1e-6)
def test_gravity_orientation_dtype(self):
g = get_gravity_orientation(np.array([1.0, 0.0, 0.0, 0.0]))
assert g.dtype == np.float32
# ---------------------------------------------------------------------------
# Unit tests for UnitreeG1Config (no SDK needed)
# ---------------------------------------------------------------------------
class TestUnitreeG1Config:
def test_default_config(self):
cfg = UnitreeG1Config()
assert len(cfg.kp) == 29
assert len(cfg.kd) == 29
assert len(cfg.default_positions) == 29
assert cfg.is_simulation is True
assert cfg.controller is None
assert cfg.gravity_compensation is False
def test_gains_are_positive(self):
cfg = UnitreeG1Config()
assert all(v > 0 for v in cfg.kp)
assert all(v > 0 for v in cfg.kd)
def test_config_copies_gains(self):
"""Each config instance should have its own copy of gains."""
cfg1 = UnitreeG1Config()
cfg2 = UnitreeG1Config()
cfg1.kp[0] = 999.0
assert cfg2.kp[0] != 999.0
# ---------------------------------------------------------------------------
# Robot mock and integration tests
# ---------------------------------------------------------------------------
def _make_lowstate_msg_mock():
"""Create a mock that mimics the SDK LowState_ message."""
msg = MagicMock()
for i in range(29):
motor = MagicMock()
motor.q = float(i) * 0.1
motor.dq = float(i) * 0.01
motor.tau_est = float(i) * 0.001
motor.temperature = 30.0 + i
msg.motor_state.__getitem__ = lambda self, idx, _motors={}: _motors.setdefault(
idx, MagicMock(q=idx * 0.1, dq=idx * 0.01, tau_est=idx * 0.001, temperature=30.0 + idx)
)
msg.imu_state.quaternion = [1.0, 0.0, 0.0, 0.0]
msg.imu_state.gyroscope = [0.1, 0.2, 0.3]
msg.imu_state.accelerometer = [0.0, 0.0, 9.81]
msg.imu_state.rpy = [0.0, 0.0, 0.0]
msg.imu_state.temperature = 25.0
msg.wireless_remote = b"\x00" * 40
msg.mode_machine = 0
return msg
def _make_sdk_mocks():
"""Create mocks for the Unitree SDK modules used by UnitreeG1."""
lowcmd_default = MagicMock()
lowcmd_default.mode_pr = 0
lowcmd_default.motor_cmd = [MagicMock() for _ in range(35)]
crc_mock = MagicMock()
crc_mock.Crc.return_value = 0
lowstate_msg = _make_lowstate_msg_mock()
subscriber_mock = MagicMock()
subscriber_mock.Read.return_value = lowstate_msg
publisher_mock = MagicMock()
return {
"lowcmd_default": lowcmd_default,
"crc_mock": crc_mock,
"subscriber_mock": subscriber_mock,
"publisher_mock": publisher_mock,
"lowstate_msg": lowstate_msg,
}
@pytest.fixture
def unitree_g1():
"""Create a UnitreeG1 robot with all SDK dependencies mocked."""
mocks = _make_sdk_mocks()
mock_channel_init = MagicMock()
mock_channel_pub = MagicMock(return_value=mocks["publisher_mock"])
mock_channel_sub = MagicMock(return_value=mocks["subscriber_mock"])
with (
patch(
"lerobot.robots.unitree_g1.unitree_g1.make_cameras_from_configs",
return_value={},
),
patch(
"lerobot.robots.unitree_g1.unitree_g1.G1_29_ArmIK",
return_value=MagicMock(),
),
patch(
"lerobot.robots.unitree_g1.unitree_g1._SDKChannelFactoryInitialize",
mock_channel_init,
),
patch(
"lerobot.robots.unitree_g1.unitree_g1._SDKChannelPublisher",
mock_channel_pub,
),
patch(
"lerobot.robots.unitree_g1.unitree_g1._SDKChannelSubscriber",
mock_channel_sub,
),
patch(
"lerobot.robots.unitree_g1.unitree_g1.unitree_hg_msg_dds__LowCmd_",
MagicMock(return_value=mocks["lowcmd_default"]),
),
patch(
"lerobot.robots.unitree_g1.unitree_g1.hg_LowCmd",
MagicMock,
),
patch(
"lerobot.robots.unitree_g1.unitree_g1.hg_LowState",
MagicMock,
),
patch(
"lerobot.robots.unitree_g1.unitree_g1.CRC",
MagicMock(return_value=mocks["crc_mock"]),
),
):
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
cfg = UnitreeG1Config(is_simulation=True, gravity_compensation=False)
robot = UnitreeG1(cfg)
yield robot, mocks
if robot.is_connected:
robot.disconnect()
def test_init_state(unitree_g1):
robot, _ = unitree_g1
assert not robot.is_connected
assert robot.controller is None
def test_observation_features(unitree_g1):
robot, _ = unitree_g1
features = robot.observation_features
# Should have .q for all 29 joints (no cameras configured)
assert len(features) == 29
for joint in G1_29_JointIndex:
assert f"{joint.name}.q" in features
def test_action_features_no_controller(unitree_g1):
robot, _ = unitree_g1
features = robot.action_features
# Without controller: all 29 joints
assert len(features) == 29
for joint in G1_29_JointIndex:
assert f"{joint.name}.q" in features
def test_get_observation_before_connect(unitree_g1):
robot, _ = unitree_g1
obs = robot.get_observation()
assert obs == {}
def test_disconnect_idempotent(unitree_g1):
robot, _ = unitree_g1
# Should not raise even when not connected
robot.disconnect()

View File

@@ -0,0 +1,309 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Unitree G1 teleoperator. Meant to be run in an environment where the Unitree SDK is installed."""
from unittest.mock import MagicMock
import pytest
from lerobot.utils.import_utils import _unitree_sdk_available
if not _unitree_sdk_available:
pytest.skip("Unitree SDK not available", allow_module_level=True)
from lerobot.robots.unitree_g1.g1_utils import REMOTE_AXES
from lerobot.teleoperators.unitree_g1.config_unitree_g1 import (
ExoskeletonArmPortConfig,
UnitreeG1TeleoperatorConfig,
)
from lerobot.teleoperators.unitree_g1.unitree_g1 import RemoteController, UnitreeG1Teleoperator
# ---------------------------------------------------------------------------
# Tests for RemoteController
# ---------------------------------------------------------------------------
def _make_joystick_mock():
"""Create a mock Joystick class matching the SDK interface."""
joystick = MagicMock()
# Axes are Axis objects with .data attribute
joystick.lx = MagicMock(data=0.0, smooth=0.03, deadzone=0.01)
joystick.ly = MagicMock(data=0.0, smooth=0.03, deadzone=0.01)
joystick.rx = MagicMock(data=0.0, smooth=0.03, deadzone=0.01)
joystick.ry = MagicMock(data=0.0, smooth=0.03, deadzone=0.01)
# Buttons are Button objects with .data attribute
for name in ["RB", "LB", "start", "back", "RT", "LT", "A", "B", "X", "Y", "up", "right", "down", "left"]:
setattr(joystick, name, MagicMock(data=0))
return joystick
@pytest.fixture
def remote_controller():
"""Create a RemoteController with a mocked Joystick."""
mock_joystick = _make_joystick_mock()
rc = RemoteController()
rc._joystick = mock_joystick
yield rc, mock_joystick
def test_remote_controller_init(remote_controller):
rc, _ = remote_controller
assert rc.lx == 0.0
assert rc.ly == 0.0
assert rc.rx == 0.0
assert rc.ry == 0.0
assert len(rc.button) == 16
assert all(b == 0 for b in rc.button)
def test_sync_remote_action(remote_controller):
rc, _ = remote_controller
rc.lx = 0.5
rc.ly = -0.3
rc.rx = 0.1
rc.ry = 0.0
rc._sync_remote_action()
assert rc.remote_action["remote.lx"] == 0.5
assert rc.remote_action["remote.ly"] == -0.3
assert rc.remote_action["remote.rx"] == 0.1
assert rc.remote_action["remote.ry"] == 0.0
def test_set_from_wireless_calls_extract(remote_controller):
rc, mock_joystick = remote_controller
# Set up the mock to populate data after extract
mock_joystick.lx.data = 0.5
mock_joystick.ly.data = -0.3
mock_joystick.rx.data = 0.1
mock_joystick.ry.data = 0.0
wireless_data = b"\x00" * 40
rc.set_from_wireless(wireless_data)
mock_joystick.extract.assert_called_once_with(wireless_data)
assert rc.lx == 0.5
assert rc.ly == -0.3
def test_set_from_wireless_short_data(remote_controller):
rc, mock_joystick = remote_controller
rc.set_from_wireless(b"\x00" * 10) # Too short
mock_joystick.extract.assert_not_called()
def test_set_from_wireless_buttons(remote_controller):
rc, mock_joystick = remote_controller
# Simulate RB pressed
mock_joystick.RB.data = 1
mock_joystick.lx.data = 0.0
mock_joystick.ly.data = 0.0
mock_joystick.rx.data = 0.0
mock_joystick.ry.data = 0.0
rc.set_from_wireless(b"\x00" * 40)
assert rc.button[0] == 1 # RB maps to button[0]
def test_set_from_exo_left(remote_controller):
rc, _ = remote_controller
rc.use_left_exo_joystick = True
rc.left_center_x = 2048
rc.left_center_y = 2048
raw16 = [0] * 16
raw16[11] = 3048 # X axis: (3048 - 2048) / 2047.5 ≈ 0.488
raw16[13] = 1048 # Y axis: (1048 - 2048) / 2047.5 ≈ -0.488
raw16[12] = 0 # Button pressed (below ADC_HALF)
rc.set_from_exo(raw16, "left")
assert rc.lx == pytest.approx((3048 - 2048) / 2047.5, abs=1e-3)
assert rc.ly == pytest.approx((1048 - 2048) / 2047.5, abs=1e-3)
assert rc.button[4] == 1 # Left button maps to button[4]
def test_set_from_exo_clears_button(remote_controller):
rc, _ = remote_controller
rc.use_left_exo_joystick = True
rc.button[4] = 1 # Pre-set
raw16 = [0] * 16
raw16[12] = 4000 # Button NOT pressed (above ADC_HALF)
rc.set_from_exo(raw16, "left")
assert rc.button[4] == 0 # Should be cleared
def test_set_from_exo_ignored_when_not_enabled(remote_controller):
rc, _ = remote_controller
rc.use_left_exo_joystick = False
raw16 = [0] * 16
raw16[11] = 3000
rc.set_from_exo(raw16, "left")
assert rc.lx == 0.0 # Unchanged
# ---------------------------------------------------------------------------
# Tests for UnitreeG1TeleoperatorConfig (no SDK needed)
# ---------------------------------------------------------------------------
class TestTeleoperatorConfig:
def test_default_config(self):
cfg = UnitreeG1TeleoperatorConfig()
assert cfg.left_arm_config.port == ""
assert cfg.right_arm_config.port == ""
assert cfg.frozen_joints == ""
def test_config_with_ports(self):
cfg = UnitreeG1TeleoperatorConfig(
left_arm_config=ExoskeletonArmPortConfig(port="/dev/ttyACM0"),
right_arm_config=ExoskeletonArmPortConfig(port="/dev/ttyACM1"),
)
assert cfg.left_arm_config.port == "/dev/ttyACM0"
assert cfg.right_arm_config.port == "/dev/ttyACM1"
# ---------------------------------------------------------------------------
# Tests for UnitreeG1Teleoperator
# ---------------------------------------------------------------------------
@pytest.fixture
def teleop_remote_only():
"""Create a UnitreeG1Teleoperator in remote-only mode (no exo arms)."""
cfg = UnitreeG1TeleoperatorConfig() # No ports = remote-only mode
teleop = UnitreeG1Teleoperator(cfg)
yield teleop
def test_remote_only_connect(teleop_remote_only):
"""Remote-only mode should connect immediately without serial ports."""
teleop = teleop_remote_only
teleop.connect()
assert teleop.is_connected
assert not teleop._arm_control_enabled
def test_remote_only_action_features(teleop_remote_only):
teleop = teleop_remote_only
features = teleop.action_features
# Remote-only: just the 4 remote axes
assert set(features.keys()) == set(REMOTE_AXES)
def test_feedback_features(teleop_remote_only):
teleop = teleop_remote_only
features = teleop.feedback_features
assert "wireless_remote" in features
assert features["wireless_remote"] is bytes
def test_remote_only_get_action(teleop_remote_only):
teleop = teleop_remote_only
teleop.connect()
action = teleop.get_action()
assert set(action.keys()) == set(REMOTE_AXES)
assert all(isinstance(v, float) for v in action.values())
def test_send_feedback(teleop_remote_only):
teleop = teleop_remote_only
teleop.connect()
# Should not raise
teleop.send_feedback({"wireless_remote": b"\x00" * 40})
def test_send_feedback_missing_key(teleop_remote_only):
teleop = teleop_remote_only
teleop.connect()
# Should not raise even with missing key
teleop.send_feedback({"other_key": 42})
def test_asymmetric_exo_ports_raises():
"""Configuring only one exo port should raise ValueError."""
cfg = UnitreeG1TeleoperatorConfig(
left_arm_config=ExoskeletonArmPortConfig(port="/dev/ttyACM0"),
# right_arm_config left empty
)
with pytest.raises(ValueError, match="set both left/right"):
UnitreeG1Teleoperator(cfg)
# ---------------------------------------------------------------------------
# Tests for ExoskeletonArm (needs serial mock)
# ---------------------------------------------------------------------------
class TestExoskeletonArm:
def test_parse_raw16_valid(self):
from lerobot.teleoperators.unitree_g1.exo_serial import parse_raw16
line = b"100 200 300 400 500 600 700 800 900 1000 1100 1200 1300 1400 1500 1600\n"
result = parse_raw16(line)
assert result is not None
assert len(result) == 16
assert result[0] == 100
assert result[15] == 1600
def test_parse_raw16_too_short(self):
from lerobot.teleoperators.unitree_g1.exo_serial import parse_raw16
line = b"100 200 300\n"
assert parse_raw16(line) is None
def test_parse_raw16_garbage(self):
from lerobot.teleoperators.unitree_g1.exo_serial import parse_raw16
assert parse_raw16(b"not numbers at all\n") is None
assert parse_raw16(b"\xff\xfe\xfd\n") is None
assert parse_raw16(b"") is None
def test_calibrate_requires_connection(self):
from lerobot.teleoperators.unitree_g1.exo_serial import ExoskeletonArm
arm = ExoskeletonArm(
port="/dev/null",
calibration_fpath=MagicMock(is_file=MagicMock(return_value=False)),
side="left",
)
with pytest.raises(RuntimeError, match="not connected"):
arm.calibrate()
def test_is_connected_false_by_default(self):
from lerobot.teleoperators.unitree_g1.exo_serial import ExoskeletonArm
arm = ExoskeletonArm(
port="/dev/null",
calibration_fpath=MagicMock(is_file=MagicMock(return_value=False)),
side="left",
)
assert not arm.is_connected
assert not arm.is_calibrated
def test_read_raw_when_disconnected(self):
from lerobot.teleoperators.unitree_g1.exo_serial import ExoskeletonArm
arm = ExoskeletonArm(
port="/dev/null",
calibration_fpath=MagicMock(is_file=MagicMock(return_value=False)),
side="left",
)
assert arm.read_raw() is None