mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
Compare commits
11 Commits
v0.4.2
...
feat/conve
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9c74cbe599 | ||
|
|
fa3919a0ff | ||
|
|
e38346316b | ||
|
|
2a2b648891 | ||
|
|
cf36f4b873 | ||
|
|
e1ae51b02a | ||
|
|
797cd2725a | ||
|
|
af4766b602 | ||
|
|
37f43df88a | ||
|
|
5f7b5f2817 | ||
|
|
c55fbe1b3e |
7
.github/workflows/fast_tests.yml
vendored
7
.github/workflows/fast_tests.yml
vendored
@@ -60,12 +60,19 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
lfs: true
|
||||
|
||||
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
|
||||
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
|
||||
- name: Setup /mnt storage
|
||||
run: sudo chown -R $USER:$USER /mnt
|
||||
|
||||
# TODO(Steven): Evaluate the need of these dependencies
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
|
||||
7
.github/workflows/full_tests.yml
vendored
7
.github/workflows/full_tests.yml
vendored
@@ -58,12 +58,19 @@ jobs:
|
||||
github.event_name == 'workflow_dispatch'
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
|
||||
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
|
||||
- name: Setup /mnt storage
|
||||
run: sudo chown -R $USER:$USER /mnt
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
|
||||
7
.github/workflows/unbound_deps_tests.yml
vendored
7
.github/workflows/unbound_deps_tests.yml
vendored
@@ -45,12 +45,19 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
|
||||
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
|
||||
- name: Setup /mnt storage
|
||||
run: sudo chown -R $USER:$USER /mnt
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
|
||||
@@ -79,6 +79,8 @@
|
||||
title: Hope Jr
|
||||
- local: reachy2
|
||||
title: Reachy 2
|
||||
- local: unitree_g1
|
||||
title: Unitree G1
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
|
||||
203
docs/source/unitree_g1.mdx
Normal file
203
docs/source/unitree_g1.mdx
Normal file
@@ -0,0 +1,203 @@
|
||||
# Unitree G1 Robot Setup and Control
|
||||
|
||||
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
|
||||
|
||||
## About the Unitree G1
|
||||
|
||||
We offer support for both 29 and 23 DOF G1. In this first PR we introduce:
|
||||
|
||||
- **`unitree g1` robot class, handling low level communication with the humanoid**
|
||||
- **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin
|
||||
- **GR00T locomotion policy** for bipedal walking and balance
|
||||
|
||||
---
|
||||
|
||||
## Part 1: Connect to Robot over Ethernet
|
||||
|
||||
### Step 1: Configure Your Computer's Ethernet Interface
|
||||
|
||||
Set a static IP on the same subnet as the robot:
|
||||
|
||||
```bash
|
||||
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
|
||||
sudo ip addr flush dev enp131s0
|
||||
sudo ip addr add 192.168.123.200/24 dev enp131s0
|
||||
sudo ip link set enp131s0 up
|
||||
```
|
||||
|
||||
**Note**: The robot's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` where x ≠ 164.
|
||||
|
||||
### Step 2: SSH into the Robot
|
||||
|
||||
```bash
|
||||
ssh unitree@192.168.123.164
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
You should now be connected to the robot's onboard computer.
|
||||
|
||||
---
|
||||
|
||||
## Part 2: Enable WiFi on the Robot
|
||||
|
||||
Once connected via Ethernet, follow these steps to enable WiFi:
|
||||
|
||||
### Step 1: Enable WiFi Hardware
|
||||
|
||||
```bash
|
||||
# Unblock WiFi radio
|
||||
sudo rfkill unblock wifi
|
||||
sudo rfkill unblock all
|
||||
|
||||
# Bring up WiFi interface
|
||||
sudo ip link set wlan0 up
|
||||
|
||||
# Enable NetworkManager control
|
||||
sudo nmcli radio wifi on
|
||||
sudo nmcli device set wlan0 managed yes
|
||||
sudo systemctl restart NetworkManager
|
||||
```
|
||||
|
||||
### Step 2: Enable Internet Forwarding
|
||||
|
||||
**On your laptop:**
|
||||
|
||||
```bash
|
||||
# Enable IP forwarding
|
||||
sudo sysctl -w net.ipv4.ip_forward=1
|
||||
|
||||
# Set up NAT (replace wlp132s0f0 with your WiFi interface)
|
||||
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
|
||||
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
|
||||
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
|
||||
```
|
||||
|
||||
**On the robot:**
|
||||
|
||||
```bash
|
||||
# Add laptop as default gateway
|
||||
sudo ip route del default 2>/dev/null || true
|
||||
sudo ip route add default via 192.168.123.200 dev eth0
|
||||
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
|
||||
|
||||
# Test connection
|
||||
ping -c 3 8.8.8.8
|
||||
```
|
||||
|
||||
### Step 3: Connect to WiFi Network
|
||||
|
||||
```bash
|
||||
# List available networks
|
||||
nmcli device wifi list
|
||||
|
||||
# Connect to your WiFi (example)
|
||||
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
|
||||
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
|
||||
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
|
||||
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
|
||||
sudo nmcli connection up "YourNetwork"
|
||||
|
||||
# Check WiFi IP address
|
||||
ip a show wlan0
|
||||
```
|
||||
|
||||
### Step 4: SSH Over WiFi
|
||||
|
||||
Once connected to WiFi, note the robot's IP address and disconnect the Ethernet cable. You can now SSH over WiFi:
|
||||
|
||||
```bash
|
||||
ssh unitree@<YOUR_ROBOT_IP>
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address (e.g., `172.18.129.215`).
|
||||
|
||||
---
|
||||
|
||||
## Part 3: Robot Server Setup
|
||||
|
||||
### Step 1: Install LeRobot on the Orin
|
||||
|
||||
SSH into the robot and install LeRobot:
|
||||
|
||||
```bash
|
||||
ssh unitree@<YOUR_ROBOT_IP>
|
||||
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
```
|
||||
|
||||
**Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
|
||||
|
||||
### Step 2: Run the Robot Server
|
||||
|
||||
On the robot:
|
||||
|
||||
```bash
|
||||
python src/lerobot/robots/unitree_g1/run_g1_server.py
|
||||
```
|
||||
|
||||
**Important**: Keep this terminal running. The server must be active for remote control.
|
||||
|
||||
---
|
||||
|
||||
## Part 4: Running GR00T Locomotion
|
||||
|
||||
With the robot server running, you can now control the robot from your laptop.
|
||||
|
||||
### Step 1: Install LeRobot on your machine
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
```
|
||||
|
||||
### Step 2: Update Robot IP in Config
|
||||
|
||||
Edit the config file to match your robot's WiFi IP:
|
||||
|
||||
```python
|
||||
# In src/lerobot/robots/unitree_g1/config_unitree_g1.py
|
||||
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
|
||||
```
|
||||
|
||||
**Note**: When running directly on the G1 (not remotely), set `robot_ip: str = "127.0.0.1"` instead.
|
||||
|
||||
### Step 3: Run the Locomotion Policy
|
||||
|
||||
```bash
|
||||
# Run GR00T locomotion controller
|
||||
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
|
||||
```
|
||||
|
||||
### Step 4: Control with Remote
|
||||
|
||||
- **Left stick**: Forward/backward and left/right movement
|
||||
- **Right stick**: Rotation
|
||||
- **R1 button**: Raise waist height
|
||||
- **R2 button**: Lower waist height
|
||||
|
||||
Press `Ctrl+C` to stop the policy.
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python)
|
||||
- [GR00T Policy Repository](https://huggingface.co/nepyope/GR00T-WholeBodyControl_g1)
|
||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
|
||||
|
||||
---
|
||||
|
||||
_Last updated: December 2025_
|
||||
245
examples/dataset/aggregate_egodex.py
Normal file
245
examples/dataset/aggregate_egodex.py
Normal file
@@ -0,0 +1,245 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Aggregate EgoDex shards into a single dataset.
|
||||
|
||||
After distributed processing creates multiple shards, this script combines
|
||||
them into a single unified dataset.
|
||||
|
||||
Reference: https://arxiv.org/abs/2505.11709, https://github.com/apple/ml-egodex
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
|
||||
class AggregateEgoDexDatasets(PipelineStep):
|
||||
"""Datatrove pipeline step for aggregating EgoDex shards."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_ids: list[str],
|
||||
aggregated_repo_id: str,
|
||||
local_dir: Path | str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
hf_repo_id: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.aggr_repo_id = aggregated_repo_id
|
||||
self.local_dir = Path(local_dir) if local_dir else None
|
||||
self.push_to_hub = push_to_hub
|
||||
self.hf_repo_id = hf_repo_id if hf_repo_id else aggregated_repo_id
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import logging
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
|
||||
# Only worker 0 performs aggregation (aggregate_datasets handles parallelism internally)
|
||||
if rank == 0:
|
||||
logging.info(f"Starting aggregation of {len(self.repo_ids)} shards into {self.aggr_repo_id}")
|
||||
|
||||
# Build roots list if local_dir is specified
|
||||
roots = None
|
||||
if self.local_dir:
|
||||
roots = [self.local_dir / repo_id for repo_id in self.repo_ids]
|
||||
# Filter to only existing directories
|
||||
existing_roots = [r for r in roots if r.exists()]
|
||||
if len(existing_roots) != len(self.repo_ids):
|
||||
logging.warning(
|
||||
f"Only {len(existing_roots)} of {len(self.repo_ids)} shard directories found. "
|
||||
"Missing shards will be skipped."
|
||||
)
|
||||
# Update repo_ids to match existing roots
|
||||
existing_repo_ids = [
|
||||
repo_id for repo_id, r in zip(self.repo_ids, roots, strict=False) if r.exists()
|
||||
]
|
||||
roots = existing_roots
|
||||
self.repo_ids = existing_repo_ids
|
||||
|
||||
if len(self.repo_ids) == 0:
|
||||
logging.error("No shard directories found. Nothing to aggregate.")
|
||||
return
|
||||
|
||||
aggr_root = self.local_dir / self.aggr_repo_id if self.local_dir else None
|
||||
|
||||
aggregate_datasets(
|
||||
repo_ids=self.repo_ids,
|
||||
aggr_repo_id=self.aggr_repo_id,
|
||||
roots=roots,
|
||||
aggr_root=aggr_root,
|
||||
)
|
||||
logging.info("Aggregation complete!")
|
||||
|
||||
# Push to Hugging Face Hub if requested
|
||||
if self.push_to_hub:
|
||||
logging.info(f"Pushing to Hugging Face Hub as {self.hf_repo_id}...")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=self.aggr_repo_id,
|
||||
root=aggr_root,
|
||||
)
|
||||
# Update repo_id for pushing to different HF account if specified
|
||||
dataset.repo_id = self.hf_repo_id
|
||||
dataset.push_to_hub(
|
||||
tags=["egodex", "hand", "dexterous", "lerobot"],
|
||||
license="cc-by-nc-nd-4.0",
|
||||
)
|
||||
logging.info("Push to hub complete!")
|
||||
else:
|
||||
logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")
|
||||
|
||||
|
||||
def make_aggregate_executor(
|
||||
repo_id,
|
||||
num_shards,
|
||||
job_name,
|
||||
logs_dir,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
local_dir,
|
||||
push_to_hub,
|
||||
hf_repo_id,
|
||||
slurm=True,
|
||||
):
|
||||
"""Create executor for aggregating EgoDex shards."""
|
||||
# Generate repo IDs for all shards
|
||||
repo_ids = [f"{repo_id}_world_{num_shards}_rank_{rank}" for rank in range(num_shards)]
|
||||
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
AggregateEgoDexDatasets(repo_ids, repo_id, local_dir, push_to_hub, hf_repo_id),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": 1, # Only need 1 task for aggregation
|
||||
"workers": 1, # Only need 1 worker
|
||||
"time": "24:00:00", # 24 hours for aggregation
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
executor = SlurmPipelineExecutor(**kwargs)
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
}
|
||||
)
|
||||
executor = LocalPipelineExecutor(**kwargs)
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Aggregate EgoDex dataset shards into a single unified dataset."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier (base name without shard suffix, e.g., pepijn/egodex-test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-shards",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Number of shards to aggregate (must match --workers from slurm_port_egodex.py)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
default=Path("logs"),
|
||||
help="Path to logs directory for datatrove",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="aggr_egodex",
|
||||
help="Job name used in SLURM",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch over SLURM. Use --slurm 0 to launch locally (for debugging)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
help="SLURM partition (ideally CPU partition)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of CPUs for aggregation task",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="8G",
|
||||
help="Memory per CPU for aggregation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Local directory where shards are stored. If not specified, uses default HF cache.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Push aggregated dataset to Hugging Face Hub after aggregation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Hugging Face repo ID for upload (e.g., username/dataset-name). Defaults to --repo-id.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
|
||||
aggregate_executor = make_aggregate_executor(**kwargs)
|
||||
aggregate_executor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
129
examples/dataset/download_egodex.sh
Executable file
129
examples/dataset/download_egodex.sh
Executable file
@@ -0,0 +1,129 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Download EgoDex dataset
|
||||
# Reference: https://arxiv.org/abs/2505.11709, https://github.com/apple/ml-egodex
|
||||
#
|
||||
# Usage: ./download_egodex.sh [output_dir] [parts...]
|
||||
#
|
||||
# Examples:
|
||||
# ./download_egodex.sh ./data test # Download test set only (16 GB)
|
||||
# ./download_egodex.sh ./data part1 part2 # Download training parts 1 and 2
|
||||
# ./download_egodex.sh ./data all # Download everything (~1.7 TB)
|
||||
#
|
||||
# Available parts:
|
||||
# test - Test set (16 GB)
|
||||
# part1 - Training set part 1 (300 GB)
|
||||
# part2 - Training set part 2 (300 GB)
|
||||
# part3 - Training set part 3 (300 GB)
|
||||
# part4 - Training set part 4 (300 GB)
|
||||
# part5 - Training set part 5 (300 GB)
|
||||
# extra - Additional data (200 GB)
|
||||
# all - Download all parts (~1.7 TB total)
|
||||
|
||||
set -e
|
||||
|
||||
BASE_URL="https://ml-site.cdn-apple.com/datasets/egodex"
|
||||
|
||||
# Map part names to filenames
|
||||
declare -A PART_FILES=(
|
||||
["test"]="test.zip"
|
||||
["part1"]="part1.zip"
|
||||
["part2"]="part2.zip"
|
||||
["part3"]="part3.zip"
|
||||
["part4"]="part4.zip"
|
||||
["part5"]="part5.zip"
|
||||
["extra"]="extra.zip"
|
||||
)
|
||||
|
||||
ALL_PARTS=("test" "part1" "part2" "part3" "part4" "part5" "extra")
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 <output_dir> <parts...>"
|
||||
echo ""
|
||||
echo "Examples:"
|
||||
echo " $0 ./data test # Download test set only (16 GB)"
|
||||
echo " $0 ./data part1 part2 # Download training parts 1 and 2"
|
||||
echo " $0 ./data all # Download everything (~1.7 TB)"
|
||||
echo ""
|
||||
echo "Available parts: test, part1, part2, part3, part4, part5, extra, all"
|
||||
exit 1
|
||||
}
|
||||
|
||||
download_part() {
|
||||
local output_dir="$1"
|
||||
local part="$2"
|
||||
local filename="${PART_FILES[$part]}"
|
||||
local url="${BASE_URL}/${filename}"
|
||||
local output_file="${output_dir}/${filename}"
|
||||
|
||||
echo "----------------------------------------"
|
||||
echo "Downloading: ${part} (${filename})"
|
||||
echo "URL: ${url}"
|
||||
echo "Output: ${output_file}"
|
||||
echo "----------------------------------------"
|
||||
|
||||
# Download with curl, showing progress
|
||||
curl -L --progress-bar "${url}" -o "${output_file}"
|
||||
|
||||
# Unzip
|
||||
echo "Extracting ${filename}..."
|
||||
unzip -q "${output_file}" -d "${output_dir}"
|
||||
|
||||
# Optionally remove zip file to save space
|
||||
# Uncomment the next line if you want to delete zips after extraction
|
||||
# rm "${output_file}"
|
||||
|
||||
echo "Done: ${part}"
|
||||
echo ""
|
||||
}
|
||||
|
||||
# Check arguments
|
||||
if [ $# -lt 2 ]; then
|
||||
usage
|
||||
fi
|
||||
|
||||
OUTPUT_DIR="$1"
|
||||
shift
|
||||
|
||||
# Create output directory
|
||||
mkdir -p "${OUTPUT_DIR}"
|
||||
|
||||
# Determine which parts to download
|
||||
PARTS_TO_DOWNLOAD=()
|
||||
|
||||
for arg in "$@"; do
|
||||
if [ "$arg" == "all" ]; then
|
||||
PARTS_TO_DOWNLOAD=("${ALL_PARTS[@]}")
|
||||
break
|
||||
elif [ -n "${PART_FILES[$arg]}" ]; then
|
||||
PARTS_TO_DOWNLOAD+=("$arg")
|
||||
else
|
||||
echo "Error: Unknown part '${arg}'"
|
||||
echo "Available parts: test, part1, part2, part3, part4, part5, extra, all"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
if [ ${#PARTS_TO_DOWNLOAD[@]} -eq 0 ]; then
|
||||
echo "Error: No valid parts specified"
|
||||
usage
|
||||
fi
|
||||
|
||||
echo "========================================"
|
||||
echo "EgoDex Dataset Download"
|
||||
echo "========================================"
|
||||
echo "Output directory: ${OUTPUT_DIR}"
|
||||
echo "Parts to download: ${PARTS_TO_DOWNLOAD[*]}"
|
||||
echo "========================================"
|
||||
echo ""
|
||||
|
||||
# Download each part
|
||||
for part in "${PARTS_TO_DOWNLOAD[@]}"; do
|
||||
download_part "${OUTPUT_DIR}" "${part}"
|
||||
done
|
||||
|
||||
echo "========================================"
|
||||
echo "Download complete!"
|
||||
echo "Data saved to: ${OUTPUT_DIR}"
|
||||
echo "========================================"
|
||||
|
||||
443
examples/dataset/slurm_port_egodex.py
Normal file
443
examples/dataset/slurm_port_egodex.py
Normal file
@@ -0,0 +1,443 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Distributed EgoDex dataset porting using SLURM and datatrove.
|
||||
|
||||
EgoDex is a large-scale dataset for egocentric dexterous manipulation collected
|
||||
with ARKit on Apple Vision Pro. This script converts EgoDex data to LeRobot format.
|
||||
|
||||
Reference: https://arxiv.org/abs/2505.11709, https://github.com/apple/ml-egodex
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import mediapy as mpy
|
||||
import numpy as np
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Image dimensions
|
||||
DEFAULT_IMAGE_HEIGHT = 1080
|
||||
DEFAULT_IMAGE_WIDTH = 1920
|
||||
|
||||
class PortEgoDexShards(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
raw_dir: Path | str,
|
||||
repo_id: str,
|
||||
local_dir: Path | str = None,
|
||||
percentage: float = 100.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.raw_dir = Path(raw_dir)
|
||||
self.repo_id = repo_id
|
||||
self.local_dir = Path(local_dir) if local_dir else Path("data/local_datasets")
|
||||
self.percentage = percentage
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import mediapy as mpy
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
def _get_state_for_single_frame(transforms_group, frame_idx):
|
||||
"""
|
||||
Construct 48D hand state representation from EgoDex.
|
||||
|
||||
State vector composition (per hand = 24D, total = 48D):
|
||||
- Wrist 3D position (3)
|
||||
- Wrist orientation in 6D representation (6)
|
||||
- 5 fingertip 3D positions (15)
|
||||
"""
|
||||
state_vector = []
|
||||
fingertip_joints = {
|
||||
"left": [
|
||||
"leftThumbTip",
|
||||
"leftIndexFingerTip",
|
||||
"leftMiddleFingerTip",
|
||||
"leftRingFingerTip",
|
||||
"leftLittleFingerTip",
|
||||
],
|
||||
"right": [
|
||||
"rightThumbTip",
|
||||
"rightIndexFingerTip",
|
||||
"rightMiddleFingerTip",
|
||||
"rightRingFingerTip",
|
||||
"rightLittleFingerTip",
|
||||
],
|
||||
}
|
||||
|
||||
for hand_side in ["left", "right"]:
|
||||
hand_key = f"{hand_side}Hand"
|
||||
hand_transform = transforms_group[hand_key][frame_idx]
|
||||
|
||||
# 1. Wrist 3D position
|
||||
hand_position = hand_transform[:3, 3]
|
||||
state_vector.extend(hand_position)
|
||||
|
||||
# 2. Wrist orientation in compact 6D representation
|
||||
rotation_matrix = hand_transform[:3, :3]
|
||||
rotation_6d = np.concatenate([rotation_matrix[:, 0], rotation_matrix[:, 1]])
|
||||
state_vector.extend(rotation_6d)
|
||||
|
||||
# 3. 3D positions of 5 fingertips
|
||||
for fingertip in fingertip_joints[hand_side]:
|
||||
fingertip_transform = transforms_group[fingertip][frame_idx]
|
||||
fingertip_pos = fingertip_transform[:3, 3]
|
||||
state_vector.extend(fingertip_pos)
|
||||
|
||||
# Also return camera extrinsics for optional coordinate frame transformations
|
||||
return np.array(state_vector, dtype=np.float32), transforms_group["camera"][frame_idx]
|
||||
|
||||
def get_state_and_action_from_egodex_annotations(demo):
|
||||
"""
|
||||
Convert EgoDex demo annotations into states and actions.
|
||||
|
||||
The "action" is the state at time t+1 (next-pose prediction).
|
||||
"""
|
||||
transforms_group = demo["transforms"]
|
||||
total_frames = list(transforms_group.values())[0].shape[0]
|
||||
|
||||
states_list, extrinsics_list = [], []
|
||||
for frame_idx in range(total_frames):
|
||||
state_vector, extrinsics = _get_state_for_single_frame(transforms_group, frame_idx)
|
||||
states_list.append(state_vector)
|
||||
extrinsics_list.append(extrinsics.flatten()) # Flatten 4x4 to 16D
|
||||
|
||||
state = np.array(states_list, dtype=np.float32)
|
||||
extrinsics = np.array(extrinsics_list, dtype=np.float32)
|
||||
|
||||
# Shift by 1 timestep to convert state to action
|
||||
action = np.roll(state, -1, axis=0)
|
||||
|
||||
return state, action, extrinsics
|
||||
|
||||
def process_demo(hdf5_file_path, video_path):
|
||||
"""Process a single EgoDex demo and return frames for LeRobot."""
|
||||
video = mpy.read_video(str(video_path))
|
||||
video = np.asarray(video)
|
||||
num_frames = video.shape[0]
|
||||
frames = []
|
||||
|
||||
with h5py.File(hdf5_file_path, "r") as demo:
|
||||
state, action, extrinsics = get_state_and_action_from_egodex_annotations(demo)
|
||||
|
||||
# Get natural language task description
|
||||
if demo.attrs.get("llm_type") == "reversible":
|
||||
direction = demo.attrs.get("which_llm_description", "1")
|
||||
lang_instruction = demo.attrs.get(
|
||||
"llm_description" if direction == "1" else "llm_description2",
|
||||
"manipulation task",
|
||||
)
|
||||
else:
|
||||
lang_instruction = demo.attrs.get("llm_description", "manipulation task")
|
||||
|
||||
for step_idx in range(num_frames):
|
||||
# Resize image to default dimensions
|
||||
image_resized = cv2.resize(
|
||||
video[step_idx],
|
||||
(DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_HEIGHT),
|
||||
interpolation=cv2.INTER_AREA,
|
||||
)
|
||||
frame = {
|
||||
"task": lang_instruction,
|
||||
"observation.image": image_resized,
|
||||
"observation.state": state[step_idx],
|
||||
"observation.extrinsics": extrinsics[step_idx],
|
||||
"action": action[step_idx],
|
||||
}
|
||||
frames.append(frame)
|
||||
|
||||
return frames
|
||||
|
||||
init_logging()
|
||||
|
||||
# Define EgoDex features
|
||||
EGODEX_FEATURES = {
|
||||
"observation.image": {
|
||||
"dtype": "video",
|
||||
"shape": (DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3),
|
||||
"names": ["height", "width", "rgb"],
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (48,),
|
||||
"names": [
|
||||
# Left hand wrist position (3)
|
||||
"left_wrist_x",
|
||||
"left_wrist_y",
|
||||
"left_wrist_z",
|
||||
# Left hand wrist rotation 6D (6)
|
||||
"left_rot_0",
|
||||
"left_rot_1",
|
||||
"left_rot_2",
|
||||
"left_rot_3",
|
||||
"left_rot_4",
|
||||
"left_rot_5",
|
||||
# Left fingertips (15)
|
||||
"left_thumb_x",
|
||||
"left_thumb_y",
|
||||
"left_thumb_z",
|
||||
"left_index_x",
|
||||
"left_index_y",
|
||||
"left_index_z",
|
||||
"left_middle_x",
|
||||
"left_middle_y",
|
||||
"left_middle_z",
|
||||
"left_ring_x",
|
||||
"left_ring_y",
|
||||
"left_ring_z",
|
||||
"left_little_x",
|
||||
"left_little_y",
|
||||
"left_little_z",
|
||||
# Right hand wrist position (3)
|
||||
"right_wrist_x",
|
||||
"right_wrist_y",
|
||||
"right_wrist_z",
|
||||
# Right hand wrist rotation 6D (6)
|
||||
"right_rot_0",
|
||||
"right_rot_1",
|
||||
"right_rot_2",
|
||||
"right_rot_3",
|
||||
"right_rot_4",
|
||||
"right_rot_5",
|
||||
# Right fingertips (15)
|
||||
"right_thumb_x",
|
||||
"right_thumb_y",
|
||||
"right_thumb_z",
|
||||
"right_index_x",
|
||||
"right_index_y",
|
||||
"right_index_z",
|
||||
"right_middle_x",
|
||||
"right_middle_y",
|
||||
"right_middle_z",
|
||||
"right_ring_x",
|
||||
"right_ring_y",
|
||||
"right_ring_z",
|
||||
"right_little_x",
|
||||
"right_little_y",
|
||||
"right_little_z",
|
||||
],
|
||||
},
|
||||
"observation.extrinsics": {
|
||||
"dtype": "float32",
|
||||
"shape": (16,),
|
||||
"names": [f"extrinsic_{i}" for i in range(16)],
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (48,),
|
||||
"names": [f"action_{i}" for i in range(48)],
|
||||
},
|
||||
}
|
||||
|
||||
# 1. Discover all HDF5 files
|
||||
files = sorted(list(self.raw_dir.rglob("*.hdf5")))
|
||||
if not files:
|
||||
print(f"No HDF5 files found in {self.raw_dir}")
|
||||
return
|
||||
|
||||
# 2. Apply percentage filter
|
||||
if self.percentage < 100:
|
||||
num_files = max(1, int(len(files) * self.percentage / 100))
|
||||
files = files[:num_files]
|
||||
print(f"Processing {self.percentage}% of dataset: {num_files} files")
|
||||
|
||||
# 3. Assign files to this worker
|
||||
my_files = files[rank::world_size]
|
||||
if not my_files:
|
||||
print(f"Rank {rank} has no files to process.")
|
||||
return
|
||||
|
||||
print(f"Rank {rank} processing {len(my_files)} files out of {len(files)} total.")
|
||||
|
||||
# 4. Create a LeRobot dataset for this shard
|
||||
shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"
|
||||
shard_root = self.local_dir / shard_repo_id if self.local_dir else None
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=shard_repo_id,
|
||||
fps=30,
|
||||
robot_type="hand",
|
||||
features=EGODEX_FEATURES,
|
||||
root=shard_root,
|
||||
)
|
||||
|
||||
# 5. Process each file
|
||||
for input_h5 in my_files:
|
||||
try:
|
||||
# Derive corresponding video path
|
||||
video_path = input_h5.with_suffix(".mp4")
|
||||
if not video_path.exists():
|
||||
print(f"Warning: Video file not found for {input_h5}, skipping.")
|
||||
continue
|
||||
|
||||
# Process demo and add frames
|
||||
frames = process_demo(input_h5, video_path)
|
||||
for frame in frames:
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
|
||||
# Clean up to avoid OOM
|
||||
del frames
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {input_h5}: {e}")
|
||||
continue
|
||||
|
||||
# 6. Finalize the dataset
|
||||
dataset.finalize()
|
||||
|
||||
|
||||
def make_port_executor(
|
||||
raw_dir,
|
||||
repo_id,
|
||||
job_name,
|
||||
logs_dir,
|
||||
workers,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
local_dir,
|
||||
percentage,
|
||||
slurm=True,
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
PortEgoDexShards(raw_dir, repo_id, local_dir, percentage),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": workers,
|
||||
"workers": workers,
|
||||
"time": "10:00:00", # EgoDex is large, allow more time
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
executor = SlurmPipelineExecutor(**kwargs)
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"tasks": workers,
|
||||
"workers": 1, # Run locally sequentially for debugging
|
||||
}
|
||||
)
|
||||
executor = LocalPipelineExecutor(**kwargs)
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert EgoDex dataset to LeRobot format using SLURM."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--raw-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing input EgoDex data (HDF5 + MP4 files).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier (e.g., user/egodex-lerobot).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
default=Path("logs"),
|
||||
help="Path to logs directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="port_egodex",
|
||||
help="Job name used in SLURM.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch over SLURM. Use --slurm 0 to launch sequentially (useful for debugging).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of SLURM workers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
help="SLURM partition.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of CPUs per worker.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="4G",
|
||||
help="Memory per CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--percentage",
|
||||
type=float,
|
||||
default=100.0,
|
||||
help="Percentage of dataset to process (e.g., 1.0 for 1%%). Useful for testing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Local directory to save the LeRobot dataset. Defaults to data/local_datasets.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
|
||||
port_executor = make_port_executor(**kwargs)
|
||||
port_executor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
347
examples/unitree_g1/gr00t_locomotion.py
Normal file
347
examples/unitree_g1/gr00t_locomotion.py
Normal file
@@ -0,0 +1,347 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Example: GR00T Locomotion with Pre-loaded Policies
|
||||
|
||||
This example demonstrates the NEW pattern for loading GR00T policies externally
|
||||
and passing them to the robot class.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GROOT_DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
|
||||
GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # hip pitch
|
||||
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # knee
|
||||
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # ankle pitch
|
||||
|
||||
MISSING_JOINTS = []
|
||||
G1_MODEL = "g1_23" # or "g1_29"
|
||||
if G1_MODEL == "g1_23":
|
||||
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # waist yaw/pitch, wrist pitch/yaw
|
||||
|
||||
LOCOMOTION_ACTION_SCALE = 0.25
|
||||
|
||||
LOCOMOTION_CONTROL_DT = 0.02
|
||||
|
||||
ANG_VEL_SCALE: float = 0.25
|
||||
DOF_POS_SCALE: float = 1.0
|
||||
DOF_VEL_SCALE: float = 0.05
|
||||
CMD_SCALE: list = [2.0, 2.0, 0.25]
|
||||
|
||||
|
||||
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
|
||||
|
||||
|
||||
def load_groot_policies(
|
||||
repo_id: str = DEFAULT_GROOT_REPO_ID,
|
||||
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
|
||||
"""Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repository ID containing the ONNX policies.
|
||||
"""
|
||||
logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...")
|
||||
|
||||
# Download ONNX policies from Hugging Face Hub
|
||||
balance_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="GR00T-WholeBodyControl-Balance.onnx",
|
||||
)
|
||||
walk_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="GR00T-WholeBodyControl-Walk.onnx",
|
||||
)
|
||||
|
||||
# Load ONNX policies
|
||||
policy_balance = ort.InferenceSession(balance_path)
|
||||
policy_walk = ort.InferenceSession(walk_path)
|
||||
|
||||
logger.info("GR00T policies loaded successfully")
|
||||
|
||||
return policy_balance, policy_walk
|
||||
|
||||
|
||||
class GrootLocomotionController:
|
||||
"""
|
||||
Handles GR00T-style locomotion control for the Unitree G1 robot.
|
||||
|
||||
This controller manages:
|
||||
- Dual-policy system (Balance + Walk)
|
||||
- 29-joint observation processing
|
||||
- 15D action output (legs + waist)
|
||||
- Policy inference and motor command generation
|
||||
"""
|
||||
|
||||
def __init__(self, policy_balance, policy_walk, robot, config):
|
||||
self.policy_balance = policy_balance
|
||||
self.policy_walk = policy_walk
|
||||
self.robot = robot
|
||||
self.config = config
|
||||
|
||||
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
|
||||
|
||||
# GR00T-specific state
|
||||
self.groot_qj_all = np.zeros(29, dtype=np.float32)
|
||||
self.groot_dqj_all = np.zeros(29, dtype=np.float32)
|
||||
self.groot_action = np.zeros(15, dtype=np.float32)
|
||||
self.groot_obs_single = np.zeros(86, dtype=np.float32)
|
||||
self.groot_obs_history = deque(maxlen=6)
|
||||
self.groot_obs_stacked = np.zeros(516, dtype=np.float32)
|
||||
self.groot_height_cmd = 0.74 # Default base height
|
||||
self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
||||
# input to gr00t is 6 frames (6*86D=516)
|
||||
for _ in range(6):
|
||||
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
|
||||
|
||||
# Thread management
|
||||
self.locomotion_running = False
|
||||
self.locomotion_thread = None
|
||||
|
||||
logger.info("GrootLocomotionController initialized")
|
||||
|
||||
def groot_locomotion_run(self):
|
||||
# get current observation
|
||||
robot_state = self.robot.get_observation()
|
||||
|
||||
if robot_state is None:
|
||||
return
|
||||
|
||||
# get command from remote controller
|
||||
if robot_state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||
if self.robot.remote_controller.button[0]: # R1 - raise waist
|
||||
self.groot_height_cmd += 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
if self.robot.remote_controller.button[4]: # R2 - lower waist
|
||||
self.groot_height_cmd -= 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
else:
|
||||
self.robot.remote_controller.lx = 0.0
|
||||
self.robot.remote_controller.ly = 0.0
|
||||
self.robot.remote_controller.rx = 0.0
|
||||
self.robot.remote_controller.ry = 0.0
|
||||
|
||||
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
|
||||
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right
|
||||
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate
|
||||
|
||||
for i in range(29):
|
||||
self.groot_qj_all[i] = robot_state.motor_state[i].q
|
||||
self.groot_dqj_all[i] = robot_state.motor_state[i].dq
|
||||
|
||||
# adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
self.groot_qj_all[idx] = 0.0
|
||||
self.groot_dqj_all[idx] = 0.0
|
||||
|
||||
# Scale joint positions and velocities
|
||||
qj_obs = self.groot_qj_all.copy()
|
||||
dqj_obs = self.groot_dqj_all.copy()
|
||||
|
||||
# express imu data in gravity frame of reference
|
||||
quat = robot_state.imu_state.quaternion
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# scale joint positions and velocities before policy inference
|
||||
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
|
||||
dqj_obs = dqj_obs * DOF_VEL_SCALE
|
||||
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||
|
||||
# build single frame observation
|
||||
self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
|
||||
self.groot_obs_single[3] = self.groot_height_cmd
|
||||
self.groot_obs_single[4:7] = self.groot_orientation_cmd
|
||||
self.groot_obs_single[7:10] = ang_vel_scaled
|
||||
self.groot_obs_single[10:13] = gravity_orientation
|
||||
self.groot_obs_single[13:42] = qj_obs
|
||||
self.groot_obs_single[42:71] = dqj_obs
|
||||
self.groot_obs_single[71:86] = self.groot_action # 15D previous actions
|
||||
|
||||
# Add to history and stack observations (6 frames × 86D = 516D)
|
||||
self.groot_obs_history.append(self.groot_obs_single.copy())
|
||||
|
||||
# Stack all 6 frames into 516D vector
|
||||
for i, obs_frame in enumerate(self.groot_obs_history):
|
||||
start_idx = i * 86
|
||||
end_idx = start_idx + 86
|
||||
self.groot_obs_stacked[start_idx:end_idx] = obs_frame
|
||||
|
||||
# Run policy inference (ONNX) with 516D stacked observation
|
||||
|
||||
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
|
||||
|
||||
selected_policy = (
|
||||
self.policy_balance if cmd_magnitude < 0.05 else self.policy_walk
|
||||
) # balance/standing policy for small commands, walking policy for movement commands
|
||||
|
||||
# run policy inference
|
||||
ort_inputs = {selected_policy.get_inputs()[0].name: np.expand_dims(self.groot_obs_stacked, axis=0)}
|
||||
ort_outs = selected_policy.run(None, ort_inputs)
|
||||
self.groot_action = ort_outs[0].squeeze()
|
||||
|
||||
# transform action back to target joint positions
|
||||
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE
|
||||
|
||||
# command motors
|
||||
for i in range(15):
|
||||
motor_idx = i
|
||||
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# adapt action for g1_23dof
|
||||
for joint_idx in MISSING_JOINTS:
|
||||
self.robot.msg.motor_cmd[joint_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[joint_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[joint_idx].kp = self.robot.kp[joint_idx]
|
||||
self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd[joint_idx]
|
||||
self.robot.msg.motor_cmd[joint_idx].tau = 0
|
||||
|
||||
# send action to robot
|
||||
self.robot.send_action(self.robot.msg)
|
||||
|
||||
def _locomotion_thread_loop(self):
|
||||
"""Background thread that runs the locomotion policy at specified rate."""
|
||||
logger.info("Locomotion thread started")
|
||||
while self.locomotion_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
self.groot_locomotion_run()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in locomotion loop: {e}")
|
||||
|
||||
# Sleep to maintain control rate
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger.info("Locomotion thread stopped")
|
||||
|
||||
def start_locomotion_thread(self):
|
||||
if self.locomotion_running:
|
||||
logger.warning("Locomotion thread already running")
|
||||
return
|
||||
|
||||
logger.info("Starting locomotion control thread...")
|
||||
self.locomotion_running = True
|
||||
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
|
||||
self.locomotion_thread.start()
|
||||
|
||||
logger.info("Locomotion control thread started!")
|
||||
|
||||
def stop_locomotion_thread(self):
|
||||
if not self.locomotion_running:
|
||||
return
|
||||
|
||||
logger.info("Stopping locomotion control thread...")
|
||||
self.locomotion_running = False
|
||||
if self.locomotion_thread:
|
||||
self.locomotion_thread.join(timeout=2.0)
|
||||
logger.info("Locomotion control thread stopped")
|
||||
|
||||
def reset_robot(self):
|
||||
"""Move robot legs to default standing position over 2 seconds (arms are not moved)."""
|
||||
total_time = 3.0
|
||||
num_step = int(total_time / self.robot.control_dt)
|
||||
|
||||
# Only control legs, not arms (first 12 joints)
|
||||
default_pos = GROOT_DEFAULT_ANGLES # First 12 values are leg angles
|
||||
dof_size = len(default_pos)
|
||||
|
||||
# Get current lowstate
|
||||
robot_state = self.robot.get_observation()
|
||||
|
||||
# Record the current leg positions
|
||||
init_dof_pos = np.zeros(dof_size, dtype=np.float32)
|
||||
for i in range(dof_size):
|
||||
init_dof_pos[i] = robot_state.motor_state[i].q
|
||||
|
||||
# Move legs to default pos
|
||||
for i in range(num_step):
|
||||
alpha = i / num_step
|
||||
for motor_idx in range(dof_size):
|
||||
target_pos = default_pos[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = (
|
||||
init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
|
||||
)
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(self.robot.control_dt)
|
||||
logger.info("Reached default position (legs only)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=DEFAULT_GROOT_REPO_ID,
|
||||
help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# load policies
|
||||
policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id)
|
||||
|
||||
# initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
|
||||
# initialize gr00t locomotion controller
|
||||
groot_controller = GrootLocomotionController(
|
||||
policy_balance=policy_balance,
|
||||
policy_walk=policy_walk,
|
||||
robot=robot,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# reset legs and start locomotion thread
|
||||
try:
|
||||
groot_controller.reset_robot()
|
||||
groot_controller.start_locomotion_thread()
|
||||
|
||||
# log status
|
||||
logger.info("Robot initialized with GR00T locomotion policies")
|
||||
logger.info("Locomotion controller running in background thread")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# keep robot alive
|
||||
while True:
|
||||
time.sleep(1.0)
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopping locomotion...")
|
||||
groot_controller.stop_locomotion_thread()
|
||||
print("Done!")
|
||||
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.4.2"
|
||||
version = "0.4.3"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
readme = "README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -107,6 +107,10 @@ dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
unitree_g1 = [
|
||||
"pyzmq>=26.2.1,<28.0.0",
|
||||
"onnxruntime>=1.16.0"
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
|
||||
@@ -538,6 +538,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if config.compile_model:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
||||
# Also compile the main forward pass used during training
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||
|
||||
|
||||
18
src/lerobot/robots/unitree_g1/__init__.py
Normal file
18
src/lerobot/robots/unitree_g1/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
from .unitree_g1 import UnitreeG1
|
||||
55
src/lerobot/robots/unitree_g1/config_unitree_g1.py
Normal file
55
src/lerobot/robots/unitree_g1/config_unitree_g1.py
Normal file
@@ -0,0 +1,55 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
_GAINS: dict[str, dict[str, list[float]]] = {
|
||||
"left_leg": {
|
||||
"kp": [150, 150, 150, 300, 40, 40],
|
||||
"kd": [2, 2, 2, 4, 2, 2],
|
||||
}, # pitch, roll, yaw, knee, ankle_pitch, ankle_roll
|
||||
"right_leg": {"kp": [150, 150, 150, 300, 40, 40], "kd": [2, 2, 2, 4, 2, 2]},
|
||||
"waist": {"kp": [250, 250, 250], "kd": [5, 5, 5]}, # yaw, roll, pitch
|
||||
"left_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow
|
||||
"left_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]}, # roll, pitch, yaw
|
||||
"right_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]},
|
||||
"right_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]},
|
||||
"other": {"kp": [80, 80, 80, 80, 80, 80], "kd": [3, 3, 3, 3, 3, 3]},
|
||||
}
|
||||
|
||||
|
||||
def _build_gains() -> tuple[list[float], list[float]]:
|
||||
"""Build kp and kd lists from body-part groupings."""
|
||||
kp = [v for g in _GAINS.values() for v in g["kp"]]
|
||||
kd = [v for g in _GAINS.values() for v in g["kd"]]
|
||||
return kp, kd
|
||||
|
||||
|
||||
_DEFAULT_KP, _DEFAULT_KD = _build_gains()
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("unitree_g1")
|
||||
@dataclass
|
||||
class UnitreeG1Config(RobotConfig):
|
||||
kp: list[float] = field(default_factory=lambda: _DEFAULT_KP.copy())
|
||||
kd: list[float] = field(default_factory=lambda: _DEFAULT_KD.copy())
|
||||
|
||||
control_dt: float = 1.0 / 250.0 # 250Hz
|
||||
|
||||
# socket config for ZMQ bridge
|
||||
robot_ip: str = "192.168.123.164"
|
||||
89
src/lerobot/robots/unitree_g1/g1_utils.py
Normal file
89
src/lerobot/robots/unitree_g1/g1_utils.py
Normal file
@@ -0,0 +1,89 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
# ruff: noqa: N801, N815
|
||||
|
||||
NUM_MOTORS = 35
|
||||
|
||||
|
||||
class G1_29_JointArmIndex(IntEnum):
|
||||
# Left arm
|
||||
kLeftShoulderPitch = 15
|
||||
kLeftShoulderRoll = 16
|
||||
kLeftShoulderYaw = 17
|
||||
kLeftElbow = 18
|
||||
kLeftWristRoll = 19
|
||||
kLeftWristPitch = 20
|
||||
kLeftWristyaw = 21
|
||||
|
||||
# Right arm
|
||||
kRightShoulderPitch = 22
|
||||
kRightShoulderRoll = 23
|
||||
kRightShoulderYaw = 24
|
||||
kRightElbow = 25
|
||||
kRightWristRoll = 26
|
||||
kRightWristPitch = 27
|
||||
kRightWristYaw = 28
|
||||
|
||||
|
||||
class G1_29_JointIndex(IntEnum):
|
||||
# Left leg
|
||||
kLeftHipPitch = 0
|
||||
kLeftHipRoll = 1
|
||||
kLeftHipYaw = 2
|
||||
kLeftKnee = 3
|
||||
kLeftAnklePitch = 4
|
||||
kLeftAnkleRoll = 5
|
||||
|
||||
# Right leg
|
||||
kRightHipPitch = 6
|
||||
kRightHipRoll = 7
|
||||
kRightHipYaw = 8
|
||||
kRightKnee = 9
|
||||
kRightAnklePitch = 10
|
||||
kRightAnkleRoll = 11
|
||||
|
||||
kWaistYaw = 12
|
||||
kWaistRoll = 13
|
||||
kWaistPitch = 14
|
||||
|
||||
# Left arm
|
||||
kLeftShoulderPitch = 15
|
||||
kLeftShoulderRoll = 16
|
||||
kLeftShoulderYaw = 17
|
||||
kLeftElbow = 18
|
||||
kLeftWristRoll = 19
|
||||
kLeftWristPitch = 20
|
||||
kLeftWristyaw = 21
|
||||
|
||||
# Right arm
|
||||
kRightShoulderPitch = 22
|
||||
kRightShoulderRoll = 23
|
||||
kRightShoulderYaw = 24
|
||||
kRightElbow = 25
|
||||
kRightWristRoll = 26
|
||||
kRightWristPitch = 27
|
||||
kRightWristYaw = 28
|
||||
|
||||
# not used
|
||||
kNotUsedJoint0 = 29
|
||||
kNotUsedJoint1 = 30
|
||||
kNotUsedJoint2 = 31
|
||||
kNotUsedJoint3 = 32
|
||||
kNotUsedJoint4 = 33
|
||||
kNotUsedJoint5 = 34
|
||||
212
src/lerobot/robots/unitree_g1/run_g1_server.py
Normal file
212
src/lerobot/robots/unitree_g1/run_g1_server.py
Normal file
@@ -0,0 +1,212 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
DDS-to-ZMQ bridge server for Unitree G1 robot.
|
||||
|
||||
This server runs on the robot and forwards:
|
||||
- Robot state (LowState) from DDS to ZMQ (for remote clients)
|
||||
- Robot commands (LowCmd) from ZMQ to DDS (from remote clients)
|
||||
|
||||
Uses JSON for secure serialization instead of pickle.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import zmq
|
||||
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import MotionSwitcherClient
|
||||
from unitree_sdk2py.core.channel import ChannelFactoryInitialize, ChannelPublisher, ChannelSubscriber
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
|
||||
# DDS topic names follow Unitree SDK naming conventions
|
||||
# ruff: noqa: N816
|
||||
kTopicLowCommand_Debug = "rt/lowcmd" # action to robot
|
||||
kTopicLowState = "rt/lowstate" # observation from robot
|
||||
|
||||
LOWCMD_PORT = 6000
|
||||
LOWSTATE_PORT = 6001
|
||||
NUM_MOTORS = 35
|
||||
|
||||
|
||||
def lowstate_to_dict(msg: hg_LowState) -> dict[str, Any]:
|
||||
"""Convert LowState SDK message to a JSON-serializable dictionary."""
|
||||
motor_states = []
|
||||
for i in range(NUM_MOTORS):
|
||||
temp = msg.motor_state[i].temperature
|
||||
avg_temp = float(sum(temp) / len(temp)) if isinstance(temp, list) else float(temp)
|
||||
motor_states.append(
|
||||
{
|
||||
"q": float(msg.motor_state[i].q),
|
||||
"dq": float(msg.motor_state[i].dq),
|
||||
"tau_est": float(msg.motor_state[i].tau_est),
|
||||
"temperature": avg_temp,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"motor_state": motor_states,
|
||||
"imu_state": {
|
||||
"quaternion": [float(x) for x in msg.imu_state.quaternion],
|
||||
"gyroscope": [float(x) for x in msg.imu_state.gyroscope],
|
||||
"accelerometer": [float(x) for x in msg.imu_state.accelerometer],
|
||||
"rpy": [float(x) for x in msg.imu_state.rpy],
|
||||
"temperature": float(msg.imu_state.temperature),
|
||||
},
|
||||
# Encode bytes as base64 for JSON compatibility
|
||||
"wireless_remote": base64.b64encode(bytes(msg.wireless_remote)).decode("ascii"),
|
||||
"mode_machine": int(msg.mode_machine),
|
||||
}
|
||||
|
||||
|
||||
def dict_to_lowcmd(data: dict[str, Any]) -> hg_LowCmd:
|
||||
"""Convert dictionary back to LowCmd SDK message."""
|
||||
cmd = unitree_hg_msg_dds__LowCmd_()
|
||||
cmd.mode_pr = data.get("mode_pr", 0)
|
||||
cmd.mode_machine = data.get("mode_machine", 0)
|
||||
|
||||
for i, motor_data in enumerate(data.get("motor_cmd", [])):
|
||||
cmd.motor_cmd[i].mode = motor_data.get("mode", 0)
|
||||
cmd.motor_cmd[i].q = motor_data.get("q", 0.0)
|
||||
cmd.motor_cmd[i].dq = motor_data.get("dq", 0.0)
|
||||
cmd.motor_cmd[i].kp = motor_data.get("kp", 0.0)
|
||||
cmd.motor_cmd[i].kd = motor_data.get("kd", 0.0)
|
||||
cmd.motor_cmd[i].tau = motor_data.get("tau", 0.0)
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def state_forward_loop(
|
||||
lowstate_sub: ChannelSubscriber,
|
||||
lowstate_sock: zmq.Socket,
|
||||
state_period: float,
|
||||
) -> None:
|
||||
"""Read observation from DDS and forward to ZMQ clients."""
|
||||
last_state_time = 0.0
|
||||
|
||||
while True:
|
||||
# read from DDS
|
||||
msg = lowstate_sub.Read()
|
||||
if msg is None:
|
||||
continue
|
||||
|
||||
now = time.time()
|
||||
# optional downsampling (if robot dds rate > state_period)
|
||||
if now - last_state_time >= state_period:
|
||||
# Convert to dict and serialize with JSON
|
||||
state_dict = lowstate_to_dict(msg)
|
||||
payload = json.dumps({"topic": kTopicLowState, "data": state_dict}).encode("utf-8")
|
||||
# if no subscribers / tx buffer full, just drop
|
||||
with contextlib.suppress(zmq.Again):
|
||||
lowstate_sock.send(payload, zmq.NOBLOCK)
|
||||
last_state_time = now
|
||||
|
||||
|
||||
def cmd_forward_loop(
|
||||
lowcmd_sock: zmq.Socket,
|
||||
lowcmd_pub_debug: ChannelPublisher,
|
||||
crc: CRC,
|
||||
) -> None:
|
||||
"""Receive commands from ZMQ and forward to DDS."""
|
||||
while True:
|
||||
payload = lowcmd_sock.recv()
|
||||
msg_dict = json.loads(payload.decode("utf-8"))
|
||||
|
||||
topic = msg_dict.get("topic", "")
|
||||
cmd_data = msg_dict.get("data", {})
|
||||
|
||||
# Reconstruct LowCmd object from dict
|
||||
cmd = dict_to_lowcmd(cmd_data)
|
||||
|
||||
# recompute crc
|
||||
cmd.crc = crc.Crc(cmd)
|
||||
|
||||
if topic == kTopicLowCommand_Debug:
|
||||
lowcmd_pub_debug.Write(cmd)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point for the robot server bridge."""
|
||||
# initialize DDS
|
||||
ChannelFactoryInitialize(0)
|
||||
|
||||
# stop all active publishers on the robot
|
||||
msc = MotionSwitcherClient()
|
||||
msc.SetTimeout(5.0)
|
||||
msc.Init()
|
||||
|
||||
status, result = msc.CheckMode()
|
||||
while result is not None and "name" in result and result["name"]:
|
||||
msc.ReleaseMode()
|
||||
status, result = msc.CheckMode()
|
||||
time.sleep(1.0)
|
||||
|
||||
crc = CRC()
|
||||
|
||||
# initialize DDS publisher
|
||||
lowcmd_pub_debug = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||
lowcmd_pub_debug.Init()
|
||||
|
||||
# initialize DDS subscriber
|
||||
lowstate_sub = ChannelSubscriber(kTopicLowState, hg_LowState)
|
||||
lowstate_sub.Init()
|
||||
|
||||
# initialize ZMQ
|
||||
ctx = zmq.Context.instance()
|
||||
|
||||
# receive commands from remote client
|
||||
lowcmd_sock = ctx.socket(zmq.PULL)
|
||||
lowcmd_sock.bind(f"tcp://0.0.0.0:{LOWCMD_PORT}")
|
||||
|
||||
# publish state to remote clients
|
||||
lowstate_sock = ctx.socket(zmq.PUB)
|
||||
lowstate_sock.bind(f"tcp://0.0.0.0:{LOWSTATE_PORT}")
|
||||
|
||||
state_period = 0.002 # ~500 hz
|
||||
|
||||
# start observation forwarding thread
|
||||
t_state = threading.Thread(
|
||||
target=state_forward_loop,
|
||||
args=(lowstate_sub, lowstate_sock, state_period),
|
||||
daemon=True,
|
||||
)
|
||||
t_state.start()
|
||||
|
||||
# start action forwarding thread
|
||||
t_cmd = threading.Thread(
|
||||
target=cmd_forward_loop,
|
||||
args=(lowcmd_sock, lowcmd_pub_debug, crc),
|
||||
daemon=True,
|
||||
)
|
||||
t_cmd.start()
|
||||
|
||||
print("bridge running (lowstate -> zmq, lowcmd -> dds)")
|
||||
# keep main thread alive so daemon threads don't exit
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1.0)
|
||||
except KeyboardInterrupt:
|
||||
print("shutting down bridge...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
267
src/lerobot/robots/unitree_g1/unitree_g1.py
Normal file
267
src/lerobot/robots/unitree_g1/unitree_g1.py
Normal file
@@ -0,0 +1,267 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
|
||||
LowCmd_ as hg_LowCmd,
|
||||
LowState_ as hg_LowState,
|
||||
)
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
|
||||
ChannelFactoryInitialize,
|
||||
ChannelPublisher,
|
||||
ChannelSubscriber,
|
||||
)
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# DDS topic names follow Unitree SDK naming conventions
|
||||
# ruff: noqa: N816
|
||||
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||
kTopicLowState = "rt/lowstate"
|
||||
|
||||
G1_29_Num_Motors = 35
|
||||
G1_23_Num_Motors = 35
|
||||
H1_2_Num_Motors = 35
|
||||
H1_Num_Motors = 20
|
||||
|
||||
|
||||
@dataclass
|
||||
class MotorState:
|
||||
q: float | None = None # position
|
||||
dq: float | None = None # velocity
|
||||
tau_est: float | None = None # estimated torque
|
||||
temperature: float | None = None # motor temperature
|
||||
|
||||
|
||||
@dataclass
|
||||
class IMUState:
|
||||
quaternion: np.ndarray | None = None # [w, x, y, z]
|
||||
gyroscope: np.ndarray | None = None # [x, y, z] angular velocity (rad/s)
|
||||
accelerometer: np.ndarray | None = None # [x, y, z] linear acceleration (m/s²)
|
||||
rpy: np.ndarray | None = None # [roll, pitch, yaw] (rad)
|
||||
temperature: float | None = None # IMU temperature
|
||||
|
||||
|
||||
# g1 observation class
|
||||
@dataclass
|
||||
class G1_29_LowState: # noqa: N801
|
||||
motor_state: list[MotorState] = field(
|
||||
default_factory=lambda: [MotorState() for _ in range(G1_29_Num_Motors)]
|
||||
)
|
||||
imu_state: IMUState = field(default_factory=IMUState)
|
||||
wireless_remote: Any = None # Raw wireless remote data
|
||||
mode_machine: int = 0 # Robot mode
|
||||
|
||||
|
||||
class DataBuffer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def get_data(self):
|
||||
with self.lock:
|
||||
return self.data
|
||||
|
||||
def set_data(self, data):
|
||||
with self.lock:
|
||||
self.data = data
|
||||
|
||||
|
||||
class UnitreeG1(Robot):
|
||||
config_class = UnitreeG1Config
|
||||
name = "unitree_g1"
|
||||
|
||||
# unitree remote controller
|
||||
class RemoteController:
|
||||
def __init__(self):
|
||||
self.lx = 0
|
||||
self.ly = 0
|
||||
self.rx = 0
|
||||
self.ry = 0
|
||||
self.button = [0] * 16
|
||||
|
||||
def set(self, data):
|
||||
# wireless_remote
|
||||
keys = struct.unpack("H", data[2:4])[0]
|
||||
for i in range(16):
|
||||
self.button[i] = (keys & (1 << i)) >> i
|
||||
self.lx = struct.unpack("f", data[4:8])[0]
|
||||
self.rx = struct.unpack("f", data[8:12])[0]
|
||||
self.ry = struct.unpack("f", data[12:16])[0]
|
||||
self.ly = struct.unpack("f", data[20:24])[0]
|
||||
|
||||
def __init__(self, config: UnitreeG1Config):
|
||||
super().__init__(config)
|
||||
|
||||
logger.info("Initialize UnitreeG1...")
|
||||
|
||||
self.config = config
|
||||
|
||||
self.control_dt = config.control_dt
|
||||
|
||||
# connect robot
|
||||
self.connect()
|
||||
|
||||
# initialize direct motor control interface
|
||||
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||
self.lowcmd_publisher.Init()
|
||||
self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState)
|
||||
self.lowstate_subscriber.Init()
|
||||
self.lowstate_buffer = DataBuffer()
|
||||
|
||||
# initialize subscribe thread to read robot state
|
||||
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
|
||||
self.subscribe_thread.daemon = True
|
||||
self.subscribe_thread.start()
|
||||
|
||||
while not self.is_connected:
|
||||
time.sleep(0.1)
|
||||
|
||||
# initialize hg's lowcmd msg
|
||||
self.crc = CRC()
|
||||
self.msg = unitree_hg_msg_dds__LowCmd_()
|
||||
self.msg.mode_pr = 0
|
||||
|
||||
# Wait for first state message to arrive
|
||||
lowstate = None
|
||||
while lowstate is None:
|
||||
lowstate = self.lowstate_buffer.get_data()
|
||||
if lowstate is None:
|
||||
time.sleep(0.01)
|
||||
logger.warning("[UnitreeG1] Waiting for robot state...")
|
||||
logger.warning("[UnitreeG1] Connected to robot.")
|
||||
self.msg.mode_machine = lowstate.mode_machine
|
||||
|
||||
# initialize all motors with unified kp/kd from config
|
||||
self.kp = np.array(config.kp, dtype=np.float32)
|
||||
self.kd = np.array(config.kd, dtype=np.float32)
|
||||
|
||||
for id in G1_29_JointIndex:
|
||||
self.msg.motor_cmd[id].mode = 1
|
||||
self.msg.motor_cmd[id].kp = self.kp[id.value]
|
||||
self.msg.motor_cmd[id].kd = self.kd[id.value]
|
||||
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
|
||||
|
||||
# Initialize remote controller
|
||||
self.remote_controller = self.RemoteController()
|
||||
|
||||
def _subscribe_motor_state(self): # polls robot state @ 250Hz
|
||||
while True:
|
||||
start_time = time.time()
|
||||
msg = self.lowstate_subscriber.Read()
|
||||
if msg is not None:
|
||||
lowstate = G1_29_LowState()
|
||||
|
||||
# Capture motor states
|
||||
for id in range(G1_29_Num_Motors):
|
||||
lowstate.motor_state[id].q = msg.motor_state[id].q
|
||||
lowstate.motor_state[id].dq = msg.motor_state[id].dq
|
||||
lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est
|
||||
lowstate.motor_state[id].temperature = msg.motor_state[id].temperature
|
||||
|
||||
# Capture IMU state
|
||||
lowstate.imu_state.quaternion = list(msg.imu_state.quaternion)
|
||||
lowstate.imu_state.gyroscope = list(msg.imu_state.gyroscope)
|
||||
lowstate.imu_state.accelerometer = list(msg.imu_state.accelerometer)
|
||||
lowstate.imu_state.rpy = list(msg.imu_state.rpy)
|
||||
lowstate.imu_state.temperature = msg.imu_state.temperature
|
||||
|
||||
# Capture wireless remote data
|
||||
lowstate.wireless_remote = msg.wireless_remote
|
||||
|
||||
# Capture mode_machine
|
||||
lowstate.mode_machine = msg.mode_machine
|
||||
|
||||
self.lowstate_buffer.set_data(lowstate)
|
||||
|
||||
current_time = time.time()
|
||||
all_t_elapsed = current_time - start_time
|
||||
sleep_time = max(0, (self.control_dt - all_t_elapsed)) # maintain constant control dt
|
||||
time.sleep(sleep_time)
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
|
||||
|
||||
def calibrate(self) -> None: # robot is already calibrated
|
||||
pass
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None: # connect to DDS
|
||||
ChannelFactoryInitialize(0)
|
||||
|
||||
def disconnect(self):
|
||||
pass
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
return self.lowstate_buffer.get_data()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.lowstate_buffer.get_data() is not None
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
self.msg.crc = self.crc.Crc(action)
|
||||
self.lowcmd_publisher.Write(action)
|
||||
return action
|
||||
|
||||
def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion
|
||||
"""Get gravity orientation from quaternion."""
|
||||
qw = quaternion[0]
|
||||
qx = quaternion[1]
|
||||
qy = quaternion[2]
|
||||
qz = quaternion[3]
|
||||
|
||||
gravity_orientation = np.zeros(3)
|
||||
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
|
||||
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
|
||||
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
||||
return gravity_orientation
|
||||
168
src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py
Normal file
168
src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py
Normal file
@@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import zmq
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
|
||||
_ctx: zmq.Context | None = None
|
||||
_lowcmd_sock: zmq.Socket | None = None
|
||||
_lowstate_sock: zmq.Socket | None = None
|
||||
|
||||
LOWCMD_PORT = 6000
|
||||
LOWSTATE_PORT = 6001
|
||||
|
||||
# DDS topic names follow Unitree SDK naming conventions
|
||||
# ruff: noqa: N816
|
||||
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||
|
||||
|
||||
class LowStateMsg:
|
||||
"""
|
||||
Wrapper class that mimics the Unitree SDK LowState_ message structure.
|
||||
|
||||
Reconstructs the message from deserialized JSON data to maintain
|
||||
compatibility with existing code that expects SDK message objects.
|
||||
"""
|
||||
|
||||
class MotorState:
|
||||
"""Motor state data for a single joint."""
|
||||
|
||||
def __init__(self, data: dict[str, Any]) -> None:
|
||||
self.q: float = data.get("q", 0.0)
|
||||
self.dq: float = data.get("dq", 0.0)
|
||||
self.tau_est: float = data.get("tau_est", 0.0)
|
||||
self.temperature: float = data.get("temperature", 0.0)
|
||||
|
||||
class IMUState:
|
||||
"""IMU sensor data."""
|
||||
|
||||
def __init__(self, data: dict[str, Any]) -> None:
|
||||
self.quaternion: list[float] = data.get("quaternion", [1.0, 0.0, 0.0, 0.0])
|
||||
self.gyroscope: list[float] = data.get("gyroscope", [0.0, 0.0, 0.0])
|
||||
self.accelerometer: list[float] = data.get("accelerometer", [0.0, 0.0, 0.0])
|
||||
self.rpy: list[float] = data.get("rpy", [0.0, 0.0, 0.0])
|
||||
self.temperature: float = data.get("temperature", 0.0)
|
||||
|
||||
def __init__(self, data: dict[str, Any]) -> None:
|
||||
"""Initialize from deserialized JSON data."""
|
||||
self.motor_state = [self.MotorState(m) for m in data.get("motor_state", [])]
|
||||
self.imu_state = self.IMUState(data.get("imu_state", {}))
|
||||
# Decode base64-encoded wireless_remote bytes
|
||||
wireless_b64 = data.get("wireless_remote", "")
|
||||
self.wireless_remote: bytes = base64.b64decode(wireless_b64) if wireless_b64 else b""
|
||||
self.mode_machine: int = data.get("mode_machine", 0)
|
||||
|
||||
|
||||
def lowcmd_to_dict(topic: str, msg: Any) -> dict[str, Any]:
|
||||
"""Convert LowCmd message to a JSON-serializable dictionary."""
|
||||
motor_cmds = []
|
||||
# Iterate over all motor commands in the message
|
||||
for i in range(len(msg.motor_cmd)):
|
||||
motor_cmds.append(
|
||||
{
|
||||
"mode": int(msg.motor_cmd[i].mode),
|
||||
"q": float(msg.motor_cmd[i].q),
|
||||
"dq": float(msg.motor_cmd[i].dq),
|
||||
"kp": float(msg.motor_cmd[i].kp),
|
||||
"kd": float(msg.motor_cmd[i].kd),
|
||||
"tau": float(msg.motor_cmd[i].tau),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"topic": topic,
|
||||
"data": {
|
||||
"mode_pr": int(msg.mode_pr),
|
||||
"mode_machine": int(msg.mode_machine),
|
||||
"motor_cmd": motor_cmds,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def ChannelFactoryInitialize(*args: Any, **kwargs: Any) -> None: # noqa: N802
|
||||
"""
|
||||
Initialize ZMQ sockets for robot communication.
|
||||
|
||||
This function mimics the Unitree SDK's ChannelFactoryInitialize but uses
|
||||
ZMQ sockets to connect to the robot server bridge instead of DDS.
|
||||
"""
|
||||
global _ctx, _lowcmd_sock, _lowstate_sock
|
||||
|
||||
# read socket config
|
||||
config = UnitreeG1Config()
|
||||
robot_ip = config.robot_ip
|
||||
|
||||
ctx = zmq.Context.instance()
|
||||
_ctx = ctx
|
||||
|
||||
# lowcmd: send robot commands
|
||||
lowcmd_sock = ctx.socket(zmq.PUSH)
|
||||
lowcmd_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
|
||||
lowcmd_sock.connect(f"tcp://{robot_ip}:{LOWCMD_PORT}")
|
||||
_lowcmd_sock = lowcmd_sock
|
||||
|
||||
# lowstate: receive robot observations
|
||||
lowstate_sock = ctx.socket(zmq.SUB)
|
||||
lowstate_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
|
||||
lowstate_sock.connect(f"tcp://{robot_ip}:{LOWSTATE_PORT}")
|
||||
lowstate_sock.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
_lowstate_sock = lowstate_sock
|
||||
|
||||
|
||||
class ChannelPublisher:
|
||||
"""ZMQ-based publisher that sends commands to the robot server."""
|
||||
|
||||
def __init__(self, topic: str, msg_type: type) -> None:
|
||||
self.topic = topic
|
||||
self.msg_type = msg_type
|
||||
|
||||
def Init(self) -> None: # noqa: N802
|
||||
"""Initialize the publisher (no-op for ZMQ)."""
|
||||
pass
|
||||
|
||||
def Write(self, msg: Any) -> None: # noqa: N802
|
||||
"""Serialize and send a command message to the robot."""
|
||||
if _lowcmd_sock is None:
|
||||
raise RuntimeError("ChannelFactoryInitialize must be called first")
|
||||
|
||||
payload = json.dumps(lowcmd_to_dict(self.topic, msg)).encode("utf-8")
|
||||
_lowcmd_sock.send(payload)
|
||||
|
||||
|
||||
class ChannelSubscriber:
|
||||
"""ZMQ-based subscriber that receives state from the robot server."""
|
||||
|
||||
def __init__(self, topic: str, msg_type: type) -> None:
|
||||
self.topic = topic
|
||||
self.msg_type = msg_type
|
||||
|
||||
def Init(self) -> None: # noqa: N802
|
||||
"""Initialize the subscriber (no-op for ZMQ)."""
|
||||
pass
|
||||
|
||||
def Read(self) -> LowStateMsg: # noqa: N802
|
||||
"""Receive and deserialize a state message from the robot."""
|
||||
if _lowstate_sock is None:
|
||||
raise RuntimeError("ChannelFactoryInitialize must be called first")
|
||||
|
||||
payload = _lowstate_sock.recv()
|
||||
msg_dict = json.loads(payload.decode("utf-8"))
|
||||
return LowStateMsg(msg_dict.get("data", {}))
|
||||
@@ -65,7 +65,6 @@ import argparse
|
||||
import gc
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -78,19 +77,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset: LeRobotDataset, episode_index: int):
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
self.frame_ids = range(from_idx, to_idx)
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
return iter(self.frame_ids)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.frame_ids)
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
assert chw_float32_torch.dtype == torch.float32
|
||||
assert chw_float32_torch.ndim == 3
|
||||
@@ -119,12 +105,10 @@ def visualize_dataset(
|
||||
repo_id = dataset.repo_id
|
||||
|
||||
logging.info("Loading dataloader")
|
||||
episode_sampler = EpisodeSampler(dataset, episode_index)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
sampler=episode_sampler,
|
||||
)
|
||||
|
||||
logging.info("Starting Rerun")
|
||||
|
||||
Reference in New Issue
Block a user