Compare commits

..

24 Commits

Author SHA1 Message Date
Martino Russi
64ee8aa620 remove transform_imu_data 2025-11-28 15:47:20 +01:00
Martino Russi
6dd86b9f43 temperature can be list, average in such case 2025-11-28 14:38:47 +01:00
Martino Russi
30ea470f1c ensure robot is connected before changing mode 2025-11-28 13:43:28 +01:00
Martino Russi
4c834aa059 cast robot data to int/float 2025-11-28 13:38:32 +01:00
Michel Aractingi
b697e767ac remove globals use 2025-11-28 13:26:34 +01:00
Michel Aractingi
942eec9332 nit in docs 2025-11-27 17:35:37 +01:00
Michel Aractingi
414df3ef5b feat(robots): add Unitree G1 humanoid support with ZMQ bridge
- Use JSON + base64 serialization for secure communication instead of pickle
- Add documentation section
- Rename robot_server to run_g1_server
- Add dependecies to pyproject.toml
2025-11-27 17:32:24 +01:00
Martino Russi
53c678c29f push utils 2025-11-27 14:25:03 +01:00
Martino Russi
7fab4103ed add docs 2025-11-27 14:24:42 +01:00
Martino Russi
998d108424 fix linter 5 2025-11-27 13:15:31 +01:00
Martino Russi
d8e69c637a [done] make precommit happy 2025-11-27 13:12:27 +01:00
Martino Russi
fad06afe9b linter pt4 2025-11-27 13:09:11 +01:00
Martino Russi
0213011fec linter pt3 2025-11-27 11:02:55 +01:00
Martino Russi
bf0e7f4c63 make precommit happy, add ignore flags 2025-11-27 10:57:13 +01:00
Martino Russi
9a90b7dcb2 fix linter 2025-11-27 10:43:15 +01:00
Michel Aractingi
36ed02adfa download policy from the hub in examples/unitree_g1/gr00t_locomotion 2025-11-27 10:23:02 +01:00
Martino Russi
288cfc7f8e ready to review 2025-11-26 22:00:17 +01:00
Martino Russi
3ec332fabc properly comment config, example locomotion and unitree_g1 class 2025-11-26 21:27:45 +01:00
Martino Russi
55ee13aec1 format config 2025-11-26 18:26:56 +01:00
Martino Russi
dc2ebd4e12 remove leftover locomotion variable, unify kp kd 2025-11-26 18:26:02 +01:00
Martino Russi
3385350f2d separate groot locomotion logic 2025-11-26 17:17:02 +01:00
Martino Russi
d7481f653e precommit 2025-11-26 16:26:14 +01:00
Martino Russi
1bd91a04ce finish locomotion loading code 2025-11-26 16:12:53 +01:00
Martino Russi
d07c65eb9a add unitree_g1_robot_class 2025-11-26 15:51:11 +01:00
52 changed files with 1765 additions and 6503 deletions

View File

@@ -60,17 +60,12 @@ jobs:
runs-on: ubuntu-latest
env:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
lfs: true
- name: Setup /mnt storage
run: sudo chown -R $USER:$USER /mnt
# TODO(Steven): Evaluate the need of these dependencies
- name: Install apt dependencies
run: |
@@ -85,14 +80,8 @@ jobs:
version: ${{ env.UV_VERSION }}
python-version: ${{ env.PYTHON_VERSION }}
- name: Check disk usage
run: df -h
- name: Install lerobot with test extras
run: uv sync --extra "test"
- name: Check disk usage
run: df -h
- name: Run pytest
run: uv run pytest tests -vv --maxfail=10

View File

@@ -58,17 +58,12 @@ jobs:
github.event_name == 'workflow_dispatch'
env:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
steps:
- uses: actions/checkout@v4
with:
lfs: true
persist-credentials: false
- name: Setup /mnt storage
run: sudo chown -R $USER:$USER /mnt
- name: Install apt dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential \
@@ -85,21 +80,12 @@ jobs:
- name: Install lerobot with all extras
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
- name: Check disk usage
run: df -h
- name: Run pytest (all extras)
run: uv run pytest tests -vv --maxfail=10
- name: Check disk usage
run: df -h
- name: Run end-to-end tests
run: uv run make test-end-to-end
- name: Check disk usage
run: df -h
# This job builds a GPU enabled image for testing
# It runs everytime a PR is approved or a push to main
# TODO(Steven): For now we skip this job for community PRs

View File

@@ -45,15 +45,11 @@ jobs:
runs-on: ubuntu-latest
env:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
steps:
- uses: actions/checkout@v4
with:
lfs: true
persist-credentials: false
- name: Setup /mnt storage
run: sudo chown -R $USER:$USER /mnt
- name: Install apt dependencies
run: |

View File

@@ -0,0 +1,94 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from contextlib import ContextDecorator
class TimeBenchmark(ContextDecorator):
"""
Measures execution time using a context manager or decorator.
This class supports both context manager and decorator usage, and is thread-safe for multithreaded
environments.
Args:
print: If True, prints the elapsed time upon exiting the context or completing the function. Defaults
to False.
Examples:
Using as a context manager:
>>> benchmark = TimeBenchmark()
>>> with benchmark:
... time.sleep(1)
>>> print(f"Block took {benchmark.result:.4f} seconds")
Block took approximately 1.0000 seconds
Using with multithreading:
```python
import threading
benchmark = TimeBenchmark()
def context_manager_example():
with benchmark:
time.sleep(0.01)
print(f"Block took {benchmark.result_ms:.2f} milliseconds")
threads = []
for _ in range(3):
t1 = threading.Thread(target=context_manager_example)
threads.append(t1)
for t in threads:
t.start()
for t in threads:
t.join()
```
Expected output:
Block took approximately 10.00 milliseconds
Block took approximately 10.00 milliseconds
Block took approximately 10.00 milliseconds
"""
def __init__(self, print=False):
self.local = threading.local()
self.print_time = print
def __enter__(self):
self.local.start_time = time.perf_counter()
return self
def __exit__(self, *exc):
self.local.end_time = time.perf_counter()
self.local.elapsed_time = self.local.end_time - self.local.start_time
if self.print_time:
print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
return False
@property
def result(self):
return getattr(self.local, "elapsed_time", None)
@property
def result_ms(self):
return self.result * 1e3

View File

@@ -0,0 +1,102 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Capture video feed from a camera as raw images."""
import argparse
import datetime as dt
import os
import time
from pathlib import Path
import cv2
import rerun as rr
# see https://rerun.io/docs/howto/visualization/limit-ram
RERUN_MEMORY_LIMIT = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "5%")
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int, duration: int):
rr.init("lerobot_capture_camera_feed")
rr.spawn(memory_limit=RERUN_MEMORY_LIMIT)
now = dt.datetime.now()
capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
if not capture_dir.exists():
capture_dir.mkdir(parents=True, exist_ok=True)
# Opens the default webcam
cap = cv2.VideoCapture(0)
if not cap.isOpened():
print("Error: Could not open video stream.")
return
cap.set(cv2.CAP_PROP_FPS, fps)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
frame_index = 0
start_time = time.time()
while time.time() - start_time < duration:
ret, frame = cap.read()
if not ret:
print("Error: Could not read frame.")
break
rr.log("video/stream", rr.Image(frame), static=True)
cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
frame_index += 1
# Release the capture
cap.release()
# TODO(Steven): Add a graceful shutdown via a close() method for the Viewer context, though not currently supported in the Rerun API.
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--output-dir",
type=Path,
default=Path("outputs/cam_capture/"),
help="Directory where the capture images are written. A subfolder named with the current date & time will be created inside it for each capture.",
)
parser.add_argument(
"--fps",
type=int,
default=30,
help="Frames Per Second of the capture.",
)
parser.add_argument(
"--width",
type=int,
default=1280,
help="Width of the captured images.",
)
parser.add_argument(
"--height",
type=int,
default=720,
help="Height of the captured images.",
)
parser.add_argument(
"--duration",
type=int,
default=20,
help="Duration in seconds for which the video stream should be captured.",
)
args = parser.parse_args()
display_and_save_video_stream(**vars(args))

View File

@@ -21,13 +21,11 @@ See the provided README.md or run `python benchmark/video/run_video_benchmark.py
import argparse
import datetime as dt
import itertools
import random
import shutil
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock
import einops
import numpy as np
@@ -37,13 +35,13 @@ import torch
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from tqdm import tqdm
from benchmarks.video.benchmark import TimeBenchmark
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import (
decode_video_frames,
decode_video_frames_torchvision,
encode_video_frames,
)
from lerobot.utils.constants import OBS_IMAGE
from lerobot.utils.utils import TimerManager
BASE_ENCODING = OrderedDict(
[
@@ -88,7 +86,7 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
frames = []
for ts in timestamps:
idx = int(ts * fps)
frame = PIL.Image.open(imgs_dir / f"frame-{idx:06d}.png")
frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
frame = torch.from_numpy(np.array(frame))
frame = frame.type(torch.float32) / 255
frame = einops.rearrange(frame, "h w c -> c h w")
@@ -99,21 +97,21 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
def save_decoded_frames(
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
) -> None:
if save_dir.exists() and len(list(save_dir.glob("frame-*.png"))) == len(timestamps):
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
return
save_dir.mkdir(parents=True, exist_ok=True)
for i, ts in enumerate(timestamps):
idx = int(ts * fps)
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame-{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame-{idx:06d}.png", save_dir / f"frame-{idx:06d}_original.png")
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
episode_index = 0
ep_num_images = dataset.meta.episodes["length"][episode_index]
if imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
return
imgs_dir.mkdir(parents=True, exist_ok=True)
@@ -127,7 +125,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
):
img = item[img_keys[0]]
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
if i >= ep_num_images - 1:
break
@@ -151,6 +149,18 @@ def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> lis
return [idx / fps for idx in frame_indexes]
def decode_video_frames(
video_path: str,
timestamps: list[float],
tolerance_s: float,
backend: str,
) -> torch.Tensor:
if backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise NotImplementedError(backend)
def benchmark_decoding(
imgs_dir: Path,
video_path: Path,
@@ -162,8 +172,8 @@ def benchmark_decoding(
num_workers: int = 4,
save_frames: bool = False,
) -> dict:
def process_sample(sample: int, lock: Lock):
time_benchmark = TimerManager(log=False)
def process_sample(sample: int):
time_benchmark = TimeBenchmark()
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
num_frames = len(timestamps)
result = {
@@ -172,13 +182,13 @@ def benchmark_decoding(
"mse_values": [],
}
with time_benchmark, lock:
with time_benchmark:
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
result["load_time_video_ms"] = (time_benchmark.last * 1000) / num_frames
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
with time_benchmark:
original_frames = load_original_frames(imgs_dir, timestamps, fps)
result["load_time_images_ms"] = (time_benchmark.last * 1000) / num_frames
result["load_time_images_ms"] = time_benchmark.result_ms / num_frames
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
for i in range(num_frames):
@@ -205,10 +215,8 @@ def benchmark_decoding(
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
# Use a single shared lock for all worker threads
shared_lock = Lock()
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(process_sample, i, shared_lock) for i in range(num_samples)]
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
result = future.result()
load_times_video_ms.append(result["load_time_video_ms"])
@@ -350,27 +358,24 @@ def main(
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
# We only use the first episode
save_first_episode(imgs_dir, dataset)
for duet in [
dict(zip(encoding_benchmarks.keys(), unique_combination, strict=False))
for unique_combination in itertools.product(*encoding_benchmarks.values())
]:
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
for key, value in duet.items():
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
encoding_cfg[key] = value
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
imgs_dir,
encoding_cfg,
decoding_benchmarks,
num_samples,
num_workers,
save_frames,
)
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
imgs_dir,
encoding_cfg,
decoding_benchmarks,
num_samples,
num_workers,
save_frames,
)
# Save intermediate results
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
@@ -404,9 +409,9 @@ if __name__ == "__main__":
nargs="*",
default=[
"lerobot/pusht_image",
"lerobot/aloha_mobile_shrimp_image",
"lerobot/paris_street",
"lerobot/kitchen",
"aliberts/aloha_mobile_shrimp_image",
"aliberts/paris_street",
"aliberts/kitchen",
],
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
)
@@ -414,7 +419,7 @@ if __name__ == "__main__":
"--vcodec",
type=str,
nargs="*",
default=["h264", "hevc", "libsvtav1"],
default=["libx264", "hevc", "libsvtav1"],
help="Video codecs to be tested",
)
parser.add_argument(
@@ -463,7 +468,7 @@ if __name__ == "__main__":
"--backends",
type=str,
nargs="*",
default=["torchcodec", "pyav"],
default=["pyav", "video_reader"],
help="Torchvision decoding backend to be tested.",
)
parser.add_argument(

View File

@@ -37,8 +37,6 @@
title: π₀.₅ (Pi05)
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
title: "Policies"
- sections:
- local: async
@@ -81,6 +79,8 @@
title: Hope Jr
- local: reachy2
title: Reachy 2
- local: unitree_g1
title: Unitree G1
title: "Robots"
- sections:
- local: phone_teleop

View File

@@ -139,7 +139,7 @@ from lerobot.teleoperators import ( # noqa: F401
make_teleoperator_from_config,
so101_leader,
)
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import init_logging
from lerobot.envs.factory import make_env
@@ -196,7 +196,7 @@ def teleop_loop(teleop: Teleoperator, env: gym.Env, fps: int):
obs, info = env.reset()
dt_s = time.perf_counter() - loop_start
precise_sleep(1 / fps - dt_s)
busy_wait(1 / fps - dt_s)
loop_s = time.perf_counter() - loop_start
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")

View File

@@ -393,7 +393,7 @@ import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
episode_idx = 0
@@ -415,7 +415,7 @@ for idx in range(dataset.num_frames):
}
robot.send_action(action)
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
robot.disconnect()
```

View File

@@ -62,11 +62,6 @@ lerobot-eval \
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
### Control Mode
LIBERO now supports two control modes: relative and absolute. This matters because different VLA checkpoints are trained with different mode of action to output hence control parameterizations.
You can switch them with: `env.control_mode = "relative"` and `env.control_mode = "absolute"`
### Policy inputs and outputs
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:

240
docs/source/unitree_g1.mdx Normal file
View File

@@ -0,0 +1,240 @@
# Unitree G1 Robot Setup and Control
This guide covers the complete setup process for the Unitree G1 humanoid robot, from initial connection to running locomotion policies.
## 🤖 About the Unitree G1
The Unitree G1 humanoid comes in two flavors: 29-DOF and 23-DOF humanoid robot capable of whole-body control, manipulation, and locomotion. In this first PR we introduce:
- **Low-level motor control** via DDS (Data Distribution Service)
- **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin
- **GR00T locomotion policiey** for bipedal walking and balance
---
## Part 1: Connect to Robot over Ethernet
### Step 1: Configure Your Computer's Ethernet Interface
Set a static IP on the same subnet as the robot:
```bash
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
sudo ip addr flush dev enp131s0
sudo ip addr add 192.168.123.200/24 dev enp131s0
sudo ip link set enp131s0 up
```
> **Note**: The robot's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` where x ≠ 164.
### Step 2: SSH into the Robot
```bash
ssh unitree@192.168.123.164
# Password: 123
```
You should now be connected to the robot's onboard computer.
---
## Part 2: Enable WiFi on the Robot
Once connected via Ethernet, follow these steps to enable WiFi:
### Step 1: Enable WiFi Hardware
```bash
# Unblock WiFi radio
sudo rfkill unblock wifi
sudo rfkill unblock all
# Bring up WiFi interface
sudo ip link set wlan0 up
# Enable NetworkManager control
sudo nmcli radio wifi on
sudo nmcli device set wlan0 managed yes
sudo systemctl restart NetworkManager
```
### Step 2: Enable Internet Forwarding
**On your laptop:**
```bash
# Enable IP forwarding
sudo sysctl -w net.ipv4.ip_forward=1
# Set up NAT (replace wlp132s0f0 with your WiFi interface)
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
```
**On the robot:**
```bash
# Add laptop as default gateway
sudo ip route del default 2>/dev/null || true
sudo ip route add default via 192.168.123.200 dev eth0
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
# Test connection
ping -c 3 8.8.8.8
```
### Step 3: Connect to WiFi Network
```bash
# List available networks
nmcli device wifi list
# Connect to your WiFi (example)
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
sudo nmcli connection up "YourNetwork"
# Check WiFi IP address
ip a show wlan0
```
### Step 4: SSH Over WiFi
Once connected to WiFi, note the robot's IP address (e.g., `172.18.129.215`) and disconnect the Ethernet cable. You can now SSH over WiFi:
```bash
ssh unitree@172.18.129.215
# Password: 123
```
---
## Part 3: Robot Server Setup
The robot server introduced here acts as a DDS-to-ZMQ bridge, allowing your one to control the robot wirelessly.
### Step 1: Copy Server Script to Robot
From your laptop, copy the robot server script:
```bash
# Copy the server script and its dependencies
scp src/lerobot/robots/unitree_g1/run_g1_server.py unitree@172.18.129.215:~/run_g1_server.py
scp src/lerobot/robots/unitree_g1/g1_utils.py unitree@172.18.129.215:~/g1_utils.py
```
### Step 2: Install Dependencies on Robot
SSH into the robot and install required packages:
```bash
ssh unitree@172.18.129.215
# Install build tools and Python dependencies
sudo apt update
sudo apt install -y build-essential python3-dev python3-pip
# Install Python packages (pyzmq and Unitree SDK)
pip3 install pyzmq
pip3 install git+https://github.com/unitreerobotics/unitree_sdk2_python.git
```
> **Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
### Step 3: Run the Robot Server
On the robot:
```bash
python3 ~/run_g1_server.py
```
You should see output like:
```
Robot server listening on:
Commands: tcp://*:6000 (PULL)
State: tcp://*:6001 (PUB)
DDS initialized, forwarding started...
```
> **Important**: Keep this terminal running. The server must be active for remote control.
---
## 🚶 Part 4: Running GR00T Locomotion
With the robot server running, you can now control the robot from your laptop.
### Step 1: Install LeRobot with Unitree G1 Support (on your laptop)
```bash
pip install -e '.[unitree_g1]'
```
### Step 2: Update Robot IP in Config
Edit the config file to match your robot's WiFi IP:
```python
# In src/lerobot/robots/unitree_g1/config_unitree_g1.py
robot_ip: str = "172.18.129.215" # Your robot's WiFi IP
```
### Step 3: Run the Locomotion Policy
```bash
# Run GR00T locomotion controller (downloads policies from HuggingFace)
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
# Or use the default repo (same as above):
python examples/unitree_g1/gr00t_locomotion.py
```
The script will:
1. Download Balance and Walk policies from the Hub (cached locally after first run)
2. Connect to the robot server over WiFi/ZMQ
3. Initialize the robot and locomotion controller
4. Gradually move legs to default standing position (3 seconds)
5. Start locomotion control loop at 50Hz in background thread
6. Accept commands from the wireless remote controller
**Expected output:**
```
INFO - Loading GR00T Balance policy...
INFO - Loading GR00T Walk policy...
INFO - [UnitreeG1] Initialize UnitreeG1...
INFO - [UnitreeG1] Connected to robot.
INFO - Reached default position (legs only)
INFO - Locomotion control thread started!
INFO - Robot initialized with GR00T locomotion policies
INFO - Locomotion controller running in background thread
INFO - Press Ctrl+C to stop
```
### Step 4: Control with Remote
- **Left stick**: Forward/backward and left/right movement
- **Right stick**: Rotation
- **R1 button**: Raise waist height
- **R2 button**: Lower waist height
To stop, press `Ctrl+C` in the terminal.
---
## Additional Resources
- [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python)
- [GR00T Policy Repository](https://huggingface.co/nepyope/GR00T-WholeBodyControl_g1)
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
---
_Last updated: November 2025_

View File

@@ -1,543 +0,0 @@
# X-VLA: The First Soft-Prompted Robot Foundation Model for Any Robot, Any Task
## Overview
For years, robotics has aspired to build agents that can follow natural human instructions and operate dexterously across many environments and robot bodies. Recent breakthroughs in LLMs and VLMs suggest a path forward: extend these foundation-model architectures to embodied control by grounding them in actions. This has led to the rise of Vision-Language-Action (VLA) models, with the hope that a single generalist model could combine broad semantic understanding with robust manipulation skills.
But training such models is difficult. Robot data is fragmented across platforms, sensors, embodiments, and collection protocols. Heterogeneity appears everywhere: different arm configurations, different action spaces, different camera setups, different visual domains, and different task distributions. These inconsistencies create major distribution shifts that make pretraining unstable and adaptation unreliable.
Inspired by meta-learning and prompt learning, we ask: **"What if a VLA model could learn the structure of each robot and dataset the same way LLMs learn tasks, through prompts?"**
**X-VLA** is a soft-prompted, flow-matching VLA framework that treats each hardware setup as a "task" and encodes it using a small set of learnable embeddings. These **Soft Prompts** capture embodiment and domain-specific variations, guiding the Transformer from the earliest stages of multimodal fusion. With this mechanism, X-VLA can reconcile diverse robot morphologies, data types, and sensor setups within a single unified architecture.
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture.png" width="400">
Built from pure Transformer encoders, X-VLA scales naturally with model size and dataset diversity. Across 6 simulation benchmarks and 3 real robots, Soft Prompts consistently outperform existing methods in handling hardware and domain differences. X-VLA-0.9B, trained on 290K episodes spanning seven robotic platforms, learns an embodiment-agnostic generalist policy in Phase I, and adapts efficiently to new robots in Phase II simply by learning a new set of prompts, while keeping the backbone frozen.
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture2.png" width="400">
With only 1% of parameters tuned (9M), X-VLA-0.9B achieves near-π₀ performance on LIBERO and Simpler-WidowX, despite using **300× fewer trainable parameters**. It also demonstrates strong real-world dexterity with minimal demonstrations, including folding cloths in under two minutes.
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-fold.png" width="400">
X-VLA shows that generalist robot intelligence does not require increasingly complex architectures, only the right way to absorb heterogeneity. Soft Prompts offer a simple, scalable mechanism for unifying diverse robotic data, paving the way toward adaptable, cross-embodiment robot foundation models.
---
## Installation
After installing LeRobot, install the X-VLA dependencies:
```bash
pip install -e .[xvla]
```
After the new release, you'll be able to do:
```bash
pip install lerobot[xvla]
```
---
## Quick Start
### Basic Usage
To use X-VLA in your LeRobot configuration, specify the policy type as:
```bash
policy.type=xvla
```
### Evaluating Pre-trained Checkpoints
Example evaluation with LIBERO:
```bash
lerobot-eval \
--policy.path="lerobot/xvla-libero" \
--env.type=libero \
--env.task=libero_spatial,libero_goal,libero_10 \
--env.control_mode=absolute \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--env.episode_length=800 \
--seed=142
```
---
## Available Checkpoints
### 🎯 Base Model
**[lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base)**
A 0.9B parameter instantiation of X-VLA, trained with a carefully designed data processing and learning recipe. The training pipeline consists of two phases:
- **Phase I: Pretraining** - Pretrained on 290K episodes from Droid, Robomind, and Agibot, spanning seven platforms across five types of robotic arms (single-arm to bi-manual setups). By leveraging soft prompts to absorb embodiment-specific variations, the model learns an embodiment-agnostic generalist policy.
- **Phase II: Domain Adaptation** - Adapted to deployable policies for target domains. A new set of soft prompts is introduced and optimized to encode the hardware configuration of the novel domain, while the pretrained backbone remains frozen.
### 🎮 Simulation Checkpoints
**[lerobot/xvla-libero](https://huggingface.co/lerobot/xvla-libero)**
Achieves 93% success rate on LIBERO benchmarks. Fine-tuned from the base model for simulation tasks.
**[lerobot/xvla-widowx](https://huggingface.co/lerobot/xvla-widowx)**
Fine-tuned on BridgeData for pick-and-place experiments on compact WidowX platforms. Demonstrates robust manipulation capabilities.
### 🤖 Real-World Checkpoints
**[lerobot/xvla-folding](https://huggingface.co/lerobot/xvla-folding)**
A fine-tuned dexterous manipulation model trained on the high-quality Soft-FOLD cloth folding dataset. Achieves 100% success rate over 2 hours of continuous cloth folding.
**[lerobot/xvla-agibot-world](https://huggingface.co/lerobot/xvla-agibot-world)**
Optimized for AgileX robot dexterous manipulation tasks.
**[lerobot/xvla-google-robot](https://huggingface.co/lerobot/xvla-google-robot)**
Adapted for Google Robot platforms.
---
## Training X-VLA
### Recommended Training Configuration
When fine-tuning X-VLA for a new embodiment or task, we recommend the following freezing strategy:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/xvla_training \
--job_name=xvla_training \
--policy.path="lerobot/xvla-base" \
--policy.repo_id="HF_USER/xvla-your-robot" \
--steps=3000 \
--policy.device=cuda \
--policy.freeze_vision_encoder=True \
--policy.freeze_language_encoder=True \
--policy.train_policy_transformer=True \
--policy.train_soft_prompts=True \
--policy.action_mode=YOUR_ACTION_MODE
```
### Training Parameters Explained
| Parameter | Default | Description |
| -------------------------- | ------- | ---------------------------------------- |
| `freeze_vision_encoder` | `True` | Freeze the VLM vision encoder weights |
| `freeze_language_encoder` | `True` | Freeze the VLM language encoder weights |
| `train_policy_transformer` | `True` | Allow policy transformer layers to train |
| `train_soft_prompts` | `True` | Allow soft prompts to train |
**💡 Best Practice**: For Phase II adaptation to new embodiments, freeze the VLM encoders and only train the policy transformer and soft prompts. This provides excellent sample efficiency with minimal compute.
### Example: Training on Bimanual Robot
```bash
lerobot-train \
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
--output_dir=./outputs/xvla_bimanual \
--job_name=xvla_so101_training \
--policy.path="lerobot/xvla-base" \
--policy.repo_id="YOUR_USERNAME/xvla-biso101" \
--steps=3000 \
--policy.device=cuda \
--policy.action_mode=so101_bimanual \
--policy.freeze_vision_encoder=True \
--policy.freeze_language_encoder=True \
--policy.train_policy_transformer=True \
--policy.train_soft_prompts=True
```
💡 **Best Performance:** If you have sufficient computational resources and want to achieve best X-VLA finetuning performance, you should follow the official finetuning strategy:
**🔥 Full-finetune all components with a custom learning-rate scheme**
To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
This LR ratio is crucial for achieving strong and stable finetuning performance.
To enable this behavior, you must:
1. Implement a custom optimizer and register it in your training config
```
from dataclasses import dataclass, asdict
from lerobot.optim.optimizers import OptimizerConfig
import torch
@OptimizerConfig.register_subclass("xvla-adamw")
@dataclass
class XVLAAdamW(OptimizerConfig):
lr: float = 1e-4
betas: tuple[float, float] = (0.9, 0.99)
eps: float = 1e-8
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
"""
Expect `named_parameters()` as input.
Apply lr = lr / 10 for all VLM-related parameters.
"""
assert isinstance(params, dict), \
"Custom LR optimizer requires `named_parameters()` as inputs."
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
vlm_group, other_group = [], []
for name, p in params.items():
if not p.requires_grad:
continue
if "vlm" in name.lower():
vlm_group.append(p)
else:
other_group.append(p)
param_groups = [
{"params": vlm_group, "lr": self.lr * 0.1, "weight_decay": self.weight_decay * 0.1},
{"params": other_group, "lr": self.lr, "weight_decay": self.weight_decay},
]
return torch.optim.AdamW(param_groups, **kwargs)
```
2. Modify X-VLAs get_optim_params to return named parameters
Replace:
```
def get_optim_params(self) -> dict:
"""Return only trainable parameters for optimization."""
return filter(lambda p: p.requires_grad, self.parameters())
```
with:
```
def get_optim_params(self):
"""Return trainable named parameters."""
return filter(lambda kv: kv[1].requires_grad, self.named_parameters())
```
This ensures the optimizer receives a dict of named parameters, allowing it to correctly detect VLM modules and apply the 1/10 LR rule.
❕Note
Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements.
We encourage implementing this in your customized training pipeline for optimal results.
---
## Core Concepts
### 1. Action Modes
X-VLA uses an **Action Registry** system to handle different action spaces and embodiments. The `action_mode` parameter defines how actions are processed, what loss functions are used, and how predictions are post-processed.
#### Available Action Modes
| Action Mode | Action Dim | Description | Use Case |
| ---------------- | --------------------- | ------------------------------------------- | ------------------------------------ |
| `ee6d` | 20 | End-effector with xyz, 6D rotation, gripper | Dual-arm setups with spatial control |
| `joint` | 14 | Joint-space with gripper | Direct joint control robots |
| `agibot_ee6d` | 20 | AGI-bot variant with MSE loss | AGI-bot platforms |
| `franka_joint7` | 7 | Franka Panda 7-joint control | Franka robots without gripper |
| `so101_bimanual` | 20 (model), 12 (real) | SO101 bimanual robot | Bimanual manipulation tasks |
#### Why Action Modes Matter
When you have a pretrained checkpoint like `lerobot/xvla-base` trained with `action_dim=20`, and you want to train on a dataset with a different action dimension (e.g., 14 for bimanual arms), you can't simply trim the action dimension. The action mode orchestrates:
1. **Loss Computation**: Different loss functions for different action components (MSE for joints, BCE for grippers, etc.)
2. **Preprocessing**: Zeroing out gripper channels, padding dimensions
3. **Postprocessing**: Applying sigmoid to gripper logits, trimming padding
#### Example: BimanualSO101 Action Space
The `so101_bimanual` action mode handles the mismatch between model output (20D) and real robot control (12D):
```python
# Model outputs 20 dimensions for compatibility
dim_action = 20
# Real robot only needs 12 dimensions
# [left_arm (6), right_arm (6)] = [joints (5) + gripper (1)] × 2
REAL_DIM = 12
# Preprocessing: Pad 12D actions to 20D for training
# Postprocessing: Trim 20D predictions to 12D for deployment
```
See the [action_hub.py](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py) implementation for details.
### 2. Domain IDs
Domain IDs are learnable identifiers for different robot configurations and camera setups. They allow X-VLA to distinguish between:
- Different robots (Robot 1 vs Robot 2)
- Different camera configurations (cam1 vs cam2)
- Different combinations (Robot1-cam1-cam2 vs Robot1-cam1 vs Robot2-cam1)
#### Setting Domain IDs
**During Training**: By default, domain_id is set to 0 for general training.
**During Evaluation**: Specify the domain_id that matches your checkpoint's training configuration.
```python
# Example: LIBERO checkpoint uses domain_id=3
domain_id = 3
```
The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.
### 3. Processor Steps
X-VLA requires specific preprocessing and postprocessing steps for proper operation.
#### Required Preprocessing Steps
1. **XVLAImageToFloatProcessorStep**: Converts images from [0, 255] to [0, 1] range
2. **XVLAImageNetNormalizeProcessorStep**: Applies ImageNet normalization (required for VLM backbone)
3. **XVLAAddDomainIdProcessorStep**: Adds domain_id to observations
#### Example Custom Processor
For LIBERO environments, a custom processor handles the specific observation format:
```python
from lerobot.policies.xvla.processor_xvla import LiberoProcessorStep
processor = LiberoProcessorStep()
# Handles robot_state dictionary, converts rotation matrices to 6D representation
# Applies 180° image rotation for camera convention
```
### 4. Configuration Parameters
Key configuration parameters for X-VLA:
```python
# Observation and action
n_obs_steps: int = 1 # Number of observation timesteps
chunk_size: int = 32 # Action sequence length
n_action_steps: int = 32 # Number of action steps to execute
# Model architecture
hidden_size: int = 1024 # Transformer hidden dimension
depth: int = 24 # Number of transformer layers
num_heads: int = 16 # Number of attention heads
num_domains: int = 30 # Maximum number of domain IDs
len_soft_prompts: int = 32 # Length of soft prompt embeddings
# Action space
action_mode: str = "ee6d" # Action space type
use_proprio: bool = True # Use proprioceptive state
max_state_dim: int = 32 # Maximum state dimension
# Vision
num_image_views: int | None # Number of camera views
resize_imgs_with_padding: tuple[int, int] | None # Target image size with padding
# Training
num_denoising_steps: int = 10 # Flow matching denoising steps
```
---
## Creating Custom Action Modes
If your robot has a unique action space, you can create a custom action mode:
### Step 1: Define Your Action Space
```python
from lerobot.policies.xvla.action_hub import BaseActionSpace, register_action
import torch.nn as nn
@register_action("my_custom_robot")
class MyCustomActionSpace(BaseActionSpace):
"""Custom action space for my robot."""
dim_action = 15 # Your robot's action dimension
gripper_idx = (7, 14) # Gripper channel indices
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
def compute_loss(self, pred, target):
"""Define your loss computation."""
# Example: MSE for joints, BCE for grippers
joints_loss = self.mse(pred[:, :, :7], target[:, :, :7])
gripper_loss = self.bce(pred[:, :, self.gripper_idx],
target[:, :, self.gripper_idx])
return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""Preprocess actions before training."""
# Example: Zero out grippers in proprioception
proprio_m = proprio.clone()
action_m = action.clone() if action is not None else None
proprio_m[..., self.gripper_idx] = 0.0
if action_m is not None:
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action):
"""Post-process predictions for deployment."""
# Example: Apply sigmoid to gripper logits
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action
```
### Step 2: Use Your Custom Action Mode
```bash
lerobot-train \
--policy.action_mode=my_custom_robot \
--dataset.repo_id=YOUR_DATASET \
--policy.path="lerobot/xvla-base" \
...
```
---
## Advanced Topics
### Multi-Camera Support
X-VLA supports multiple camera views through the `num_image_views` parameter:
```python
# Configure for 3 camera views
policy.num_image_views=3
# Add empty cameras if you have fewer physical cameras
policy.empty_cameras=1 # Adds 1 zero-padded camera view
```
### Custom Preprocessing Pipeline
Create a custom preprocessing pipeline for your environment:
```python
from lerobot.processor import PolicyProcessorPipeline
from lerobot.policies.xvla.processor_xvla import (
XVLAImageToFloatProcessorStep,
XVLAImageNetNormalizeProcessorStep,
XVLAAddDomainIdProcessorStep,
)
# Build custom pipeline
preprocessor = PolicyProcessorPipeline(
steps=[
YourCustomProcessorStep(), # Your custom processing
XVLAImageToFloatProcessorStep(), # Required: convert to float
XVLAImageNetNormalizeProcessorStep(), # Required: ImageNet norm
XVLAAddDomainIdProcessorStep(domain_id=5), # Your domain ID
]
)
```
### Handling Different Action Dimensions
When your dataset has fewer action dimensions than the pretrained model:
**Option 1**: Use padding (automatic in most action modes)
```python
# Model expects 20D, dataset has 12D
# Action mode handles padding internally
action_mode = "so101_bimanual" # Pads 12 → 20
```
**Option 2**: Create a custom action mode that maps dimensions explicitly
```python
@register_action("my_mapped_action")
class MappedActionSpace(BaseActionSpace):
dim_action = 20
REAL_DIM = 12
def _pad_to_model_dim(self, x):
# Custom padding logic
...
```
---
## Troubleshooting
### Common Issues
**Issue**: "Action dimension mismatch"
- **Solution**: Check that your `action_mode` matches your robot's action space. Create a custom action mode if needed.
**Issue**: "Image values outside [0, 1] range"
- **Solution**: Ensure images are preprocessed with `XVLAImageToFloatProcessorStep` before normalization.
**Issue**: "Domain ID not found"
- **Solution**: Make sure `XVLAAddDomainIdProcessorStep` is in your preprocessing pipeline with the correct domain_id.
**Issue**: "Low success rate on new embodiment"
- **Solution**:
1. Verify your action_mode is correct
2. Check that soft prompts are being trained (`train_soft_prompts=True`)
3. Ensure proper preprocessing (ImageNet normalization, domain_id)
4. Consider increasing training steps
**Issue**: "Out of memory during training"
- **Solution**:
1. Reduce `chunk_size` (e.g., from 32 to 16)
2. Enable gradient checkpointing
3. Reduce batch size
4. Freeze more components
---
## Citation
If you use X-VLA in your research, please cite:
```bibtex
@article{zheng2025x,
title = {X-VLA: Soft-Prompted Transformer as Scalable Cross-Embodiment Vision-Language-Action Model},
author = {Zheng, Jinliang and Li, Jianxiong and Wang, Zhihao and Liu, Dongxiu and Kang, Xirui
and Feng, Yuchun and Zheng, Yinan and Zou, Jiayin and Chen, Yilun and Zeng, Jia and others},
journal = {arXiv preprint arXiv:2510.10274},
year = {2025}
}
```
---
## Additional Resources
- [X-VLA Paper](https://arxiv.org) (coming soon)
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
- [Action Registry Implementation](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py)
- [Processor Implementation](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/processor_xvla.py)
- [Model Configuration](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/configuration_xvla.py)
---
## Contributing
We welcome contributions! If you've implemented a new action mode or processor for your robot, please consider submitting a PR to help the community.

View File

@@ -45,7 +45,7 @@ from lerobot.robots import ( # noqa: F401
so101_follower,
)
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
init_logging,
log_say,
@@ -97,7 +97,7 @@ def replay(cfg: ReplayConfig):
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
precise_sleep(1 / dataset.fps - dt_s)
busy_wait(1 / dataset.fps - dt_s)
robot.disconnect()

View File

@@ -20,7 +20,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
EPISODE_IDX = 0
@@ -58,7 +58,7 @@ def main():
# Send action to robot
_ = robot.send_action(action)
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
busy_wait(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
robot.disconnect()

View File

@@ -19,7 +19,7 @@ import time
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
FPS = 30
@@ -71,7 +71,7 @@ def main():
# Visualize
log_rerun_data(observation=observation, action=action)
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
if __name__ == "__main__":

View File

@@ -29,7 +29,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
EPISODE_IDX = 0
@@ -96,7 +96,7 @@ def main():
# Send action to robot
_ = robot.send_action(joint_action)
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
# Clean up
robot.disconnect()

View File

@@ -32,7 +32,7 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
FPS = 30
@@ -114,7 +114,7 @@ def main():
# Visualize
log_rerun_data(observation=phone_obs, action=joint_action)
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
if __name__ == "__main__":

View File

@@ -30,7 +30,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
EPISODE_IDX = 0
@@ -97,7 +97,7 @@ def main():
# Send action to robot
_ = robot.send_action(joint_action)
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
# Clean up
robot.disconnect()

View File

@@ -32,7 +32,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
FPS = 30
@@ -120,7 +120,7 @@ def main():
# Visualize
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
if __name__ == "__main__":

View File

@@ -0,0 +1,367 @@
#!/usr/bin/env python
"""
Example: GR00T Locomotion with Pre-loaded Policies
This example demonstrates the NEW pattern for loading GR00T policies externally
and passing them to the robot class.
"""
import argparse
import logging
import threading
import time
from collections import deque
import numpy as np
import onnxruntime as ort
import torch
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logger = logging.getLogger(__name__)
GROOT_DEFAULT_ANGLES = np.array(
[
-0.1,
0.0,
0.0,
0.3,
-0.2,
0.0, # left leg
-0.1,
0.0,
0.0,
0.3,
-0.2,
0.0, # right leg
0.0,
0.0,
0.0, # waist
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0, # left arm
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0, # right arm
],
dtype=np.float32,
)
G1_MODEL = "g1_23"
if G1_MODEL == "g1_23":
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # waist yaw/pitch, wrist pitch/yaw
elif G1_MODEL == "g1_29":
MISSING_JOINTS = [] # waist yaw/pitch, wrist pitch/yaw
LOCOMOTION_ACTION_SCALE = 0.25
LOCOMOTION_CONTROL_DT = 0.02
ANG_VEL_SCALE: float = 0.25
DOF_POS_SCALE: float = 1.0
DOF_VEL_SCALE: float = 0.05
CMD_SCALE: list = [2.0, 2.0, 0.25]
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
def load_groot_policies(
repo_id: str = DEFAULT_GROOT_REPO_ID,
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
"""Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub.
Args:
repo_id: Hugging Face Hub repository ID containing the ONNX policies.
"""
logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...")
# Download ONNX policies from Hugging Face Hub
balance_path = hf_hub_download(
repo_id=repo_id,
filename="GR00T-WholeBodyControl-Balance.onnx",
)
walk_path = hf_hub_download(
repo_id=repo_id,
filename="GR00T-WholeBodyControl-Walk.onnx",
)
# Load ONNX policies
policy_balance = ort.InferenceSession(balance_path)
policy_walk = ort.InferenceSession(walk_path)
logger.info("GR00T policies loaded successfully")
return policy_balance, policy_walk
class GrootLocomotionController:
"""
Handles GR00T-style locomotion control for the Unitree G1 robot.
This controller manages:
- Dual-policy system (Balance + Walk)
- 29-joint observation processing
- 15D action output (legs + waist)
- Policy inference and motor command generation
"""
def __init__(self, policy_balance, policy_walk, robot, config):
self.policy_balance = policy_balance
self.policy_walk = policy_walk
self.robot = robot
self.config = config
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
# GR00T-specific state
self.groot_qj_all = np.zeros(29, dtype=np.float32)
self.groot_dqj_all = np.zeros(29, dtype=np.float32)
self.groot_action = np.zeros(15, dtype=np.float32)
self.groot_obs_single = np.zeros(86, dtype=np.float32)
self.groot_obs_history = deque(maxlen=6)
self.groot_obs_stacked = np.zeros(516, dtype=np.float32)
self.groot_height_cmd = 0.74 # Default base height
self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
# input to gr00t is 6 frames (6*86D=516)
for _ in range(6):
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
# Thread management
self.locomotion_running = False
self.locomotion_thread = None
logger.info("GrootLocomotionController initialized")
def groot_locomotion_run(self):
# get current observation
robot_state = self.robot.get_observation()
if robot_state is None:
return
# get command from remote controller
if robot_state.wireless_remote is not None:
self.robot.remote_controller.set(robot_state.wireless_remote)
if self.robot.remote_controller.button[0]: # R1 - raise waist
self.groot_height_cmd += 0.001
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
if self.robot.remote_controller.button[4]: # R2 - lower waist
self.groot_height_cmd -= 0.001
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
else:
self.robot.remote_controller.lx = 0.0
self.robot.remote_controller.ly = 0.0
self.robot.remote_controller.rx = 0.0
self.robot.remote_controller.ry = 0.0
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate
for i in range(29):
self.groot_qj_all[i] = robot_state.motor_state[i].q
self.groot_dqj_all[i] = robot_state.motor_state[i].dq
# adapt observation for g1_23dof
for idx in MISSING_JOINTS:
self.groot_qj_all[idx] = 0.0
self.groot_dqj_all[idx] = 0.0
# Scale joint positions and velocities
qj_obs = self.groot_qj_all.copy()
dqj_obs = self.groot_dqj_all.copy()
# express imu data in gravity frame of reference
quat = robot_state.imu_state.quaternion
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
gravity_orientation = self.robot.get_gravity_orientation(quat)
# scale joint positions and velocities before policy inference
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
dqj_obs = dqj_obs * DOF_VEL_SCALE
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
# build single frame observation
self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
self.groot_obs_single[3] = self.groot_height_cmd
self.groot_obs_single[4:7] = self.groot_orientation_cmd
self.groot_obs_single[7:10] = ang_vel_scaled
self.groot_obs_single[10:13] = gravity_orientation
self.groot_obs_single[13:42] = qj_obs
self.groot_obs_single[42:71] = dqj_obs
self.groot_obs_single[71:86] = self.groot_action # 15D previous actions
# Add to history and stack observations (6 frames × 86D = 516D)
self.groot_obs_history.append(self.groot_obs_single.copy())
# Stack all 6 frames into 516D vector
for i, obs_frame in enumerate(self.groot_obs_history):
start_idx = i * 86
end_idx = start_idx + 86
self.groot_obs_stacked[start_idx:end_idx] = obs_frame
# Run policy inference (ONNX) with 516D stacked observation
obs_tensor = torch.from_numpy(self.groot_obs_stacked).unsqueeze(0)
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
selected_policy = (
self.policy_balance if cmd_magnitude < 0.05 else self.policy_walk
) # balance/standing policy for small commands, walking policy for movement commands
# run policy inference
ort_inputs = {selected_policy.get_inputs()[0].name: obs_tensor.cpu().numpy()}
ort_outs = selected_policy.run(None, ort_inputs)
self.groot_action = ort_outs[0].squeeze()
# transform action back to target joint positions
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE
# command motors
for i in range(15):
motor_idx = i
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# adapt action for g1_23dof
for joint_idx in MISSING_JOINTS:
self.robot.msg.motor_cmd[joint_idx].q = 0.0
self.robot.msg.motor_cmd[joint_idx].qd = 0
self.robot.msg.motor_cmd[joint_idx].kp = self.robot.kp[joint_idx]
self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd[joint_idx]
self.robot.msg.motor_cmd[joint_idx].tau = 0
# send action to robot
self.robot.send_action(self.robot.msg)
def _locomotion_thread_loop(self):
"""Background thread that runs the locomotion policy at specified rate."""
logger.info("Locomotion thread started")
while self.locomotion_running:
start_time = time.time()
try:
self.groot_locomotion_run()
except Exception as e:
logger.error(f"Error in locomotion loop: {e}")
# Sleep to maintain control rate
elapsed = time.time() - start_time
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
time.sleep(sleep_time)
logger.info("Locomotion thread stopped")
def start_locomotion_thread(self):
if self.locomotion_running:
logger.warning("Locomotion thread already running")
return
logger.info("Starting locomotion control thread...")
self.locomotion_running = True
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
self.locomotion_thread.start()
logger.info("Locomotion control thread started!")
def stop_locomotion_thread(self):
if not self.locomotion_running:
return
logger.info("Stopping locomotion control thread...")
self.locomotion_running = False
if self.locomotion_thread:
self.locomotion_thread.join(timeout=2.0)
logger.info("Locomotion control thread stopped")
def reset_robot(self):
"""Move robot legs to default standing position over 2 seconds (arms are not moved)."""
total_time = 3.0
num_step = int(total_time / self.robot.control_dt)
# Only control legs, not arms (first 12 joints)
default_pos = GROOT_DEFAULT_ANGLES # First 12 values are leg angles
dof_size = len(default_pos)
# Get current lowstate
robot_state = self.robot.get_observation()
# Record the current leg positions
init_dof_pos = np.zeros(dof_size, dtype=np.float32)
for i in range(dof_size):
init_dof_pos[i] = robot_state.motor_state[i].q
# Move legs to default pos
for i in range(num_step):
alpha = i / num_step
for motor_idx in range(dof_size):
target_pos = default_pos[motor_idx]
self.robot.msg.motor_cmd[motor_idx].q = (
init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
)
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
self.robot.msg.motor_cmd[motor_idx].tau = 0
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
self.robot.lowcmd_publisher.Write(self.robot.msg)
time.sleep(self.robot.control_dt)
logger.info("Reached default position (legs only)")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
parser.add_argument(
"--repo-id",
type=str,
default=DEFAULT_GROOT_REPO_ID,
help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
)
args = parser.parse_args()
# load policies
policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id)
# initialize robot
config = UnitreeG1Config()
robot = UnitreeG1(config)
# initialize gr00t locomotion controller
groot_controller = GrootLocomotionController(
policy_balance=policy_balance,
policy_walk=policy_walk,
robot=robot,
config=config,
)
# reset legs and start locomotion thread
try:
groot_controller.reset_robot()
groot_controller.start_locomotion_thread()
# log status
logger.info("Robot initialized with GR00T locomotion policies")
logger.info("Locomotion controller running in background thread")
logger.info("Press Ctrl+C to stop")
# keep robot alive
while True:
time.sleep(1.0)
except KeyboardInterrupt:
print("\nStopping locomotion...")
groot_controller.stop_locomotion_thread()
print("Done!")

View File

@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.3"
version = "0.4.2"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
readme = "README.md"
license = { text = "Apache-2.0" }
@@ -107,6 +107,10 @@ dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
unitree_g1 = [
"pyzmq>=26.2.1,<28.0.0",
"unitree_sdk2py @ git+https://github.com/unitreerobotics/unitree_sdk2_python.git",
]
reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
kinematics = ["lerobot[placo-dep]"]
intelrealsense = [
@@ -129,7 +133,6 @@ groot = [
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
@@ -158,7 +161,6 @@ all = [
"lerobot[pi]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[async]",
"lerobot[dev]",

View File

@@ -245,7 +245,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
class LiberoEnv(EnvConfig):
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
fps: int = 30
episode_length: int | None = None
episode_length: int = 520
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
@@ -272,7 +272,6 @@ class LiberoEnv(EnvConfig):
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
}
)
control_mode: str = "relative" # or "absolute"
def __post_init__(self):
if self.obs_type == "pixels":

View File

@@ -19,10 +19,8 @@ from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
@@ -41,7 +39,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
def make_env_pre_post_processors(
env_cfg: EnvConfig,
policy_cfg: PreTrainedConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
@@ -64,10 +61,6 @@ def make_env_pre_post_processors(
# Preprocessor and Postprocessor steps are Identity for most environments
preprocessor_steps: list[ProcessorStep] = []
postprocessor_steps: list[ProcessorStep] = []
if isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
return make_xvla_libero_pre_post_processors()
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
@@ -143,8 +136,6 @@ def make_env(
init_states=cfg.init_states,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
control_mode=cfg.control_mode,
episode_length=cfg.episode_length,
)
elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs

View File

@@ -80,7 +80,10 @@ def get_libero_dummy_action():
return [0, 0, 0, 0, 0, 0, -1]
OBS_STATE_DIM = 8
ACTION_DIM = 7
AGENT_POS_LOW = -1000.0
AGENT_POS_HIGH = 1000.0
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
TASK_SUITE_MAX_STEPS: dict[str, int] = {
@@ -100,7 +103,6 @@ class LiberoEnv(gym.Env):
task_suite: Any,
task_id: int,
task_suite_name: str,
episode_length: int | None = None,
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
obs_type: str = "pixels",
render_mode: str = "rgb_array",
@@ -112,7 +114,6 @@ class LiberoEnv(gym.Env):
episode_index: int = 0,
camera_name_mapping: dict[str, str] | None = None,
num_steps_wait: int = 10,
control_mode: str = "relative",
):
super().__init__()
self.task_id = task_id
@@ -140,19 +141,14 @@ class LiberoEnv(gym.Env):
self.camera_name_mapping = camera_name_mapping
self.num_steps_wait = num_steps_wait
self.episode_index = episode_index
self.episode_length = episode_length
# Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500
self._max_episode_steps = (
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
if self.episode_length is None
else self.episode_length
)
self.control_mode = control_mode
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
images = {}
for cam in self.camera_name:
images[self.camera_name_mapping[cam]] = spaces.Box(
@@ -300,15 +296,6 @@ class LiberoEnv(gym.Env):
# Increasing this value can improve determinism and reproducibility across resets.
for _ in range(self.num_steps_wait):
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
if self.control_mode == "absolute":
for robot in self._env.robots:
robot.controller.use_delta = False
elif self.control_mode == "relative":
for robot in self._env.robots:
robot.controller.use_delta = True
else:
raise ValueError(f"Invalid control mode: {self.control_mode}")
observation = self._format_raw_obs(raw_obs)
info = {"is_success": False}
return observation, info
@@ -354,10 +341,8 @@ def _make_env_fns(
task_id: int,
n_envs: int,
camera_names: list[str],
episode_length: int | None,
init_states: bool,
gym_kwargs: Mapping[str, Any],
control_mode: str,
) -> list[Callable[[], LiberoEnv]]:
"""Build n_envs factory callables for a single (suite, task_id)."""
@@ -369,9 +354,7 @@ def _make_env_fns(
task_suite_name=suite_name,
camera_name=camera_names,
init_states=init_states,
episode_length=episode_length,
episode_index=episode_index,
control_mode=control_mode,
**local_kwargs,
)
@@ -391,8 +374,6 @@ def create_libero_envs(
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
init_states: bool = True,
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
control_mode: str = "relative",
episode_length: int | None = None,
) -> dict[str, dict[int, Any]]:
"""
Create vectorized LIBERO environments with a consistent return shape.
@@ -434,14 +415,12 @@ def create_libero_envs(
for tid in selected:
fns = _make_env_fns(
suite=suite,
episode_length=episode_length,
suite_name=suite_name,
task_id=tid,
n_envs=n_envs,
camera_names=camera_names,
init_states=init_states,
gym_kwargs=gym_kwargs,
control_mode=control_mode,
)
out[suite_name][tid] = env_cls(fns)
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")

View File

@@ -21,7 +21,6 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
__all__ = [
"ACTConfig",
@@ -32,5 +31,4 @@ __all__ = [
"TDMPCConfig",
"VQBeTConfig",
"GrootConfig",
"XVLAConfig",
]

View File

@@ -40,7 +40,6 @@ from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
batch_to_transition,
@@ -108,10 +107,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.groot.modeling_groot import GrootPolicy
return GrootPolicy
elif name == "xvla":
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
return XVLAPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -155,8 +150,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot":
return GrootConfig(**kwargs)
elif policy_type == "xvla":
return XVLAConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
@@ -336,15 +329,6 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import (
make_xvla_pre_post_processors,
)
processors = make_xvla_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")

View File

@@ -1,6 +0,0 @@
# register the processor steps
from lerobot.policies.xvla.processor_xvla import (
XVLAAddDomainIdProcessorStep,
XVLAImageNetNormalizeProcessorStep,
XVLAImageToFloatProcessorStep,
)

View File

@@ -1,454 +0,0 @@
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF and HuggingFace Inc. (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import Iterable
import torch
import torch.nn as nn
# =============================================================================
# Registry
# =============================================================================
ACTION_REGISTRY: dict[str, type[BaseActionSpace]] = {}
def register_action(name: str):
"""Decorator for registering a new action space."""
def _wrap(cls):
key = name.lower()
if key in ACTION_REGISTRY:
raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")
ACTION_REGISTRY[key] = cls
cls.name = key
return cls
return _wrap
def build_action_space(name: str, **kwargs) -> BaseActionSpace:
"""Instantiate a registered action space by name."""
key = name.lower()
if key not in ACTION_REGISTRY:
raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}")
return ACTION_REGISTRY[key](**kwargs)
# =============================================================================
# Base class
# =============================================================================
class BaseActionSpace(nn.Module):
"""
Abstract base class for all action-space definitions.
Each subclass defines:
- `dim_action`: dimension of the action vector.
- `gripper_idx`: indices of gripper channels.
- `compute_loss(pred, target)`: supervised loss for this space.
- `preprocess(proprio, action, mode)`: pre-step modifications.
- `postprocess(action)`: post-step corrections (e.g. apply sigmoid).
"""
name: str = "base"
dim_action: int = 0
gripper_idx: tuple[int, ...] = ()
def __init__(self):
super().__init__()
# ---------------------------------------------------------------------
# Core supervised loss
# ---------------------------------------------------------------------
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
raise NotImplementedError
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
"""Alias for compute_loss."""
return self.compute_loss(pred, target)
# ---------------------------------------------------------------------
# Space-level hooks
# ---------------------------------------------------------------------
def preprocess(
self,
proprio: torch.Tensor,
action: torch.Tensor,
mode: str = "train",
) -> tuple[torch.Tensor, torch.Tensor]:
"""Default: return unchanged."""
return proprio, action
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Default: return unchanged."""
return action
# =============================================================================
# Utilities
# =============================================================================
def _ensure_indices_valid(dim_action: int, idx: Iterable[int], name: str) -> None:
bad = [i for i in idx if i < 0 or i >= dim_action]
if bad:
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim dim_action={dim_action}")
# =============================================================================
# Implementations
# =============================================================================
@register_action("ee6d")
class EE6DActionSpace(BaseActionSpace):
"""End-effector layout with xyz, 6D rotation, and gripper channels."""
dim_action = 20
gripper_idx = (9, 19)
GRIPPER_SCALE = 1.0
XYZ_SCALE = 500.0
ROT_SCALE = 10.0
POS_IDX_1 = (0, 1, 2)
POS_IDX_2 = (10, 11, 12)
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape, "pred/target shapes must match"
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
# Gripper BCE
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
# XYZ position
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE
# Rotation 6D
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE
return {
"position_loss": pos_loss,
"rotate6D_loss": rot_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""Zero-out gripper channels in proprio/action."""
proprio_m = proprio.clone()
action_m = action.clone()
proprio_m[..., self.gripper_idx] = 0.0
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Apply sigmoid to gripper logits."""
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action
@register_action("joint")
class JointActionSpace(BaseActionSpace):
"""Joint-space layout with joints + gripper only."""
dim_action = 14
gripper_idx = (6, 13)
GRIPPER_SCALE = 0.1
JOINTS_SCALE = 1.0
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx))
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""Zero-out gripper channels in proprio/action."""
proprio_m = proprio.clone()
action_m = action.clone()
proprio_m[..., self.gripper_idx] = 0.0
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Apply sigmoid to gripper logits."""
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action
@register_action("agibot_ee6d")
class AGIBOTEE6DActionSpace(BaseActionSpace):
"""AGI-bot variant of EE6DActionSpace using MSE for all components."""
dim_action = 20
gripper_idx = (9, 19)
GRIPPER_SCALE = 10.0
XYZ_SCALE = 500.0
ROT_SCALE = 10.0
POS_IDX_1 = (0, 1, 2)
POS_IDX_2 = (10, 11, 12)
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
gripper_loss = (
self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
)
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE
return {
"position_loss": pos_loss,
"rotate6D_loss": rot_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""No preprocessing applied in AGIBOT variant."""
return proprio, action
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""AGIBOT does not postprocess."""
return action
@register_action("franka_joint7")
class FrankaJoint7ActionSpace(BaseActionSpace):
"""Franka Panda joint-space: 7 joints, no gripper."""
dim_action = 7
JOINTS_SCALE = 1.0
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape, "pred/target shapes must match"
joints_loss = self.mse(pred, target) * self.JOINTS_SCALE
return {"joints_loss": joints_loss}
def preprocess(self, proprio, action, mode="train"):
"""No preprocessing needed for 7 joint actions."""
return proprio, action
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Return directly (no sigmoid since no gripper)."""
return action
@register_action("so101_bimanual")
class BimanualSO101ActionSpace(BaseActionSpace):
"""
Bimanual SO101 robot: 2 arms with 5 joints each + gripper.
Layout (real robot):
[left_arm (5 joints + gripper), right_arm (5 joints + gripper)]
- Left arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
- Right arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
Real action dim: 12
Model-facing dim: 20 (extra 8 dummy dims at the end)
"""
# Model output / training dimension (to match pretrained policy)
dim_action = 20
# Real robot action dimension
REAL_DIM = 12
# Indices of real vs dummy channels
REAL_IDXS = tuple(range(REAL_DIM)) # 0..11
DUMMY_IDXS = tuple(range(REAL_DIM, dim_action)) # 12..19
# Grippers live in the real part
gripper_idx = (5, 11) # left_gripper at idx 5, right_gripper at idx 11
GRIPPER_SCALE = 1.0
JOINTS_SCALE = 1.0
# Indices for left and right arm joints (excluding grippers)
LEFT_ARM_JOINTS = (0, 1, 2, 3, 4)
RIGHT_ARM_JOINTS = (6, 7, 8, 9, 10)
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
# ---------- helpers ----------
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
"""If last dim is REAL_DIM (12), pad zeros to reach dim_action (20)."""
if x is None:
return None
if x.size(-1) == self.dim_action:
return x
if x.size(-1) != self.REAL_DIM:
raise ValueError(
f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
)
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM]
pad = x.new_zeros(pad_shape)
return torch.cat([x, pad], dim=-1)
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
"""Keep only the first REAL_DIM (12) dims for the real robot."""
return x[..., : self.REAL_DIM]
# ---------- loss ----------
def compute_loss(self, pred, target):
"""
pred: [B, T, 20] from the model
target: [B, T, 12] or [B, T, 20]
We pad target → 20 and compute loss only on the real dims.
"""
# Ensure both are [B, T, 20]
pred = self._pad_to_model_dim(pred)
target = self._pad_to_model_dim(target)
assert pred.shape == target.shape
# ---- MSE for all real dims (011) ----
real_dims = 12
joints_loss = (
self.mse(
pred[:, :, :real_dims],
target[:, :, :real_dims],
)
* self.JOINTS_SCALE
)
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
gripper_loss = (
self.mse(
pred[:, :, [5, 11]],
target[:, :, [5, 11]],
)
* self.GRIPPER_SCALE
)
return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,
"left_arm_loss": left_arm_loss,
"right_arm_loss": right_arm_loss,
}
# ---------- preprocess / postprocess ----------
def preprocess(self, proprio, action, mode="train"):
"""
- If proprio/action are 12-dim, pad them to 20 for the model.
- Zero-out gripper channels in proprio/action to focus learning on joints.
"""
proprio_m = self._pad_to_model_dim(proprio.clone())
action_m = self._pad_to_model_dim(action.clone()) if action is not None else None
proprio_m[..., self.gripper_idx] = 0.0
if action_m is not None:
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""
- Model outputs [*, 20]
- Apply sigmoid to gripper logits
- Return only the first 12 dims for the real robot:
["left_shoulder_pan.pos",
"left_shoulder_lift.pos",
"left_elbow_flex.pos",
"left_wrist_flex.pos",
"left_wrist_roll.pos",
"left_gripper.pos",
"right_shoulder_pan.pos",
"right_shoulder_lift.pos",
"right_elbow_flex.pos",
"right_wrist_flex.pos",
"right_wrist_roll.pos",
"right_gripper.pos"]
"""
# Ensure we at least have the real dims + grippers
if action.size(-1) < self.REAL_DIM:
raise ValueError(f"Expected at least {self.REAL_DIM} dims in action, got {action.size(-1)}")
# Apply sigmoid on gripper channels in model space (indices 5 and 11)
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
# Return only the real 12-dim control vector for the env
return self._trim_to_real_dim(action)
# =============================================================================
# Exports
# =============================================================================
__all__ = [
"BaseActionSpace",
"build_action_space",
"register_action",
"EE6DActionSpace",
"JointActionSpace",
"AGIBOTEE6DActionSpace",
"FrankaJoint7ActionSpace",
"BimanualSO101ActionSpace",
"ACTION_REGISTRY",
]

View File

@@ -1,353 +0,0 @@
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
""" Florence-2 configuration"""
logger = logging.get_logger(__name__)
class Florence2VisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
drop_path_rate (`float`, *optional*, defaults to 0.1):
The dropout rate of the drop path layer.
patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
The patch size of the image.
patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
The patch stride of the image.
patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
The patch padding of the image.
patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
Whether to apply layer normalization before the patch embedding layer.
enable_checkpoint (`bool`, *optional*, defaults to False):
Whether to enable checkpointing.
dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
The dimension of the embedding layer.
num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
The number of attention heads.
num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
The number of groups.
depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
The depth of the model.
window_size (`int`, *optional*, defaults to 12):
The window size of the model.
projection_dim (`int`, *optional*, defaults to 1024):
The dimension of the projection layer.
visual_temporal_embedding (`dict`, *optional*):
The configuration of the visual temporal embedding.
image_pos_embed (`dict`, *optional*):
The configuration of the image position embedding.
image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
The source of the image feature.
Example:
```python
>>> from transformers import Florence2VisionConfig, Florence2VisionModel
>>> # Initializing a Florence2 Vision style configuration
>>> configuration = Florence2VisionConfig()
>>> # Initializing a model (with random weights)
>>> model = Florence2VisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "davit"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
drop_path_rate=0.1,
patch_size=None,
patch_stride=None,
patch_padding=None,
patch_prenorm=None,
enable_checkpoint=False,
dim_embed=None,
num_heads=None,
num_groups=None,
depths=None,
window_size=12,
projection_dim=1024,
visual_temporal_embedding=None,
image_pos_embed=None,
image_feature_source=None,
**kwargs,
):
self.drop_path_rate = drop_path_rate
self.patch_size = patch_size if patch_size is not None else [7, 3, 3, 3]
self.patch_stride = patch_stride if patch_stride is not None else [4, 2, 2, 2]
self.patch_padding = patch_padding if patch_padding is not None else [3, 1, 1, 1]
self.patch_prenorm = patch_prenorm if patch_prenorm is not None else [False, True, True, True]
self.enable_checkpoint = enable_checkpoint
self.dim_embed = dim_embed if dim_embed is not None else [256, 512, 1024, 2048]
self.num_heads = num_heads if num_heads is not None else [8, 16, 32, 64]
self.num_groups = num_groups if num_groups is not None else [8, 16, 32, 64]
self.depths = depths if depths is not None else [1, 1, 9, 1]
self.window_size = window_size
self.projection_dim = projection_dim
if visual_temporal_embedding is None:
visual_temporal_embedding = {
"type": "COSINE",
"max_temporal_embeddings": 100,
}
self.visual_temporal_embedding = visual_temporal_embedding
if image_pos_embed is None:
image_pos_embed = {
"type": "learned_abs_2d",
"max_pos_embeddings": 1000,
}
self.image_pos_embed = image_pos_embed
self.image_feature_source = (
image_feature_source
if image_feature_source is not None
else ["spatial_avg_pool", "temporal_avg_pool"]
)
super().__init__(**kwargs)
class Florence2LanguageConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the BART
[facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 51289):
Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Florence2LanguageModel`].
d_model (`int`, *optional*, defaults to 1024):
Dimensionality of the layers and the pooler layer.
encoder_layers (`int`, *optional*, defaults to 12):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 12):
Number of decoder layers.
encoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for classifier.
max_position_embeddings (`int`, *optional*, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
for more details.
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
for more details.
scale_embedding (`bool`, *optional*, defaults to `False`):
Scale embeddings by diving by sqrt(d_model).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
num_labels (`int`, *optional*, defaults to 3):
The number of labels to use in [`Florence2LanguageForSequenceClassification`].
forced_eos_token_id (`int`, *optional*, defaults to 2):
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
`eos_token_id`.
Example:
```python
>>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
>>> # Initializing a Florence2 Language style configuration
>>> configuration = Florence2LanguageConfig()
>>> # Initializing a model (with random weights)
>>> model = Florence2LanguageModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "florence2_language"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
def __init__(
self,
vocab_size=51289,
max_position_embeddings=1024,
encoder_layers=12,
encoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_layers=12,
decoder_ffn_dim=4096,
decoder_attention_heads=16,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
activation_function="gelu",
d_model=1024,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=False,
use_cache=True,
num_labels=3,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
is_encoder_decoder=True,
decoder_start_token_id=2,
forced_eos_token_id=2,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
super().__init__(
num_labels=num_labels,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
# ensure backward compatibility for BART CNN models
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
self.forced_bos_token_id = self.bos_token_id
warnings.warn(
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
"The config can simply be saved and uploaded again to be fixed.",
stacklevel=2,
)
class Florence2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
Florence-2 model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`Florence2VisionConfig`, *optional*):
Custom vision config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
vocab_size (`int`, *optional*, defaults to 51289):
Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
projection_dim (`int`, *optional*, defaults to 1024):
Dimension of the multimodal projection space.
Example:
```python
>>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
>>> # Initializing a clip-like vision config
>>> vision_config = CLIPVisionConfig()
>>> # Initializing a Bart config
>>> text_config = BartConfig()
>>> # Initializing a Florence-2 configuration
>>> configuration = Florence2Config(vision_config, text_config)
>>> # Initializing a model from the florence-2 configuration
>>> model = Florence2ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "florence2"
is_composition = False
def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
vocab_size=51289,
projection_dim=1024,
**kwargs,
):
self.ignore_index = ignore_index
self.vocab_size = vocab_size
self.projection_dim = projection_dim
if vision_config is not None:
vision_config = Florence2VisionConfig(**vision_config)
self.vision_config = vision_config
self.text_config = text_config
if text_config is not None:
self.text_config = Florence2LanguageConfig(**text_config)
super().__init__(**kwargs)

View File

@@ -1,190 +0,0 @@
#!/usr/bin/env python
# ------------------------------------------------------------------------------
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES
# Conditional import for type checking and lazy loading
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from .configuration_florence2 import Florence2Config
else:
Florence2Config = None
@PreTrainedConfig.register_subclass("xvla")
@dataclass
class XVLAConfig(PreTrainedConfig):
"""
Configuration class for the XVLA (Extended Vision-Language-Action) policy so it can
plug into the LeRobot training stack.
The config mirrors the knobs exposed in the original XVLA repository but also
declares the input/output feature contract required by LeRobot.
"""
# Input / output structure
n_obs_steps: int = 1
chunk_size: int = 32
n_action_steps: int = 32
dtype: str = "float32" # Options: "bfloat16", "float32"
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Florence2 backbone and tokenizer configuration
florence_config: dict[str, Any] = field(default_factory=dict)
tokenizer_name: str = "facebook/bart-large"
tokenizer_max_length: int = 64
tokenizer_padding_side: str = "right"
pad_language_to: str = "max_length"
# Transformer head
hidden_size: int = 1024
depth: int = 24
num_heads: int = 16
mlp_ratio: float = 4.0
num_domains: int = 30
len_soft_prompts: int = 32
dim_time: int = 32
max_len_seq: int = 512
use_hetero_proj: bool = False
# Action & proprioception
action_mode: str = "ee6d"
num_denoising_steps: int = 10
use_proprio: bool = True
max_state_dim: int = 32
domain_feature_key: str | None = None
# Vision preprocessing
resize_imgs_with_padding: tuple[int, int] | None = None
num_image_views: int | None = None
empty_cameras: int = 0
# Freezing options for VLM components
# By default, VLM encoders are frozen and only policy transformer + soft prompts train
freeze_vision_encoder: bool = True # Freeze VLM vision encoder weights
freeze_language_encoder: bool = True # Freeze VLM language encoder weights
train_policy_transformer: bool = True # Allow policy transformer to train
train_soft_prompts: bool = True # Allow soft prompts to train
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-4
optimizer_grad_clip_norm: float = 10.0
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self) -> None:
super().__post_init__()
if self.chunk_size <= 0:
raise ValueError("`chunk_size` must be strictly positive.")
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
)
if self.num_image_views is not None and self.num_image_views <= 0:
raise ValueError("`num_image_views` must be > 0 when specified.")
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
self._florence_config_obj: Florence2Config | None = None
def get_florence_config(self) -> Florence2Config:
"""
Build (and cache) the Florence2 transformer config that should back the VLM.
"""
if self._florence_config_obj is None:
config_dict = dict(self.florence_config)
if "vision_config" not in config_dict or config_dict["vision_config"] is None:
raise ValueError("vision_config is required")
if "text_config" not in config_dict or config_dict["text_config"] is None:
raise ValueError("text_config is required")
self._florence_config_obj = Florence2Config(**config_dict)
return self._florence_config_obj
def validate_features(self) -> None:
if not self.image_features:
raise ValueError("XVLA requires at least one visual feature in the inputs.")
if self.use_proprio and self.robot_state_feature is None:
raise ValueError("`use_proprio=True` requires a proprioceptive state feature.")
if self.num_image_views is None:
self.num_image_views = len(self.image_features) + self.empty_cameras
else:
self.num_image_views = max(self.num_image_views, len(self.image_features) + self.empty_cameras)
if self.empty_cameras > 0:
height, width = (480, 640)
if self.resize_imgs_with_padding is not None:
height, width = self.resize_imgs_with_padding
for idx in range(self.empty_cameras):
key = f"{OBS_IMAGES}.empty_camera_{idx}"
if key not in self.input_features:
self.input_features[key] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, height, width),
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int] | None:
return None
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> list[int] | None:
return None

File diff suppressed because it is too large Load Diff

View File

@@ -1,526 +0,0 @@
#!/usr/bin/env python
# ------------------------------------------------------------------------------
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
import builtins
import logging
import os
from collections import deque
from pathlib import Path
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
from .action_hub import build_action_space
from .configuration_florence2 import Florence2Config
from .configuration_xvla import XVLAConfig
from .modeling_florence2 import Florence2ForConditionalGeneration
from .soft_transformer import SoftPromptedTransformer
class XVLAModel(nn.Module):
"""
XVLA backbone that stitches Florence-2 embeddings with the temporal/action transformer head.
"""
def __init__(
self,
config: XVLAConfig,
florence_config: Florence2Config,
proprio_dim: int,
) -> None:
super().__init__()
self.config = config
self.chunk_size: int = config.chunk_size
self.use_proprio: bool = config.use_proprio
self.action_space = build_action_space(config.action_mode.lower())
self.dim_action = self.action_space.dim_action
self.dim_proprio = proprio_dim
self.vlm = Florence2ForConditionalGeneration(florence_config)
if hasattr(self.vlm, "language_model"):
lm = self.vlm.language_model
if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
del lm.model.decoder
if hasattr(lm, "lm_head"):
del lm.lm_head
projection_dim = getattr(self.vlm.config, "projection_dim", None)
if projection_dim is None:
raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
self.transformer = SoftPromptedTransformer(
hidden_size=config.hidden_size,
multi_modal_input_size=projection_dim,
depth=config.depth,
num_heads=config.num_heads,
mlp_ratio=config.mlp_ratio,
num_domains=config.num_domains,
dim_action=self.dim_action,
dim_propio=self.dim_proprio,
len_soft_prompts=config.len_soft_prompts,
dim_time=config.dim_time,
max_len_seq=config.max_len_seq,
use_hetero_proj=config.use_hetero_proj,
)
# Apply freezing based on config
self._apply_freezing()
# Apply dtype casting based on config
self._apply_dtype()
def _get_target_dtype(self) -> torch.dtype:
"""Get the target dtype based on config."""
if self.config.dtype == "bfloat16":
return torch.bfloat16
return torch.float32
def _apply_dtype(self) -> None:
"""
Apply dtype casting to model components based on config.
"""
target_dtype = self._get_target_dtype()
self.to(dtype=target_dtype)
def _apply_freezing(self) -> None:
"""
Freeze VLM vision and language encoders based on config options.
Keep only policy transformer and soft prompts trainable.
"""
# Freeze vision encoder
if self.config.freeze_vision_encoder and hasattr(self.vlm, "vision_tower"):
for param in self.vlm.vision_tower.parameters():
param.requires_grad = False
# Freeze language encoder
if self.config.freeze_language_encoder and hasattr(self.vlm, "language_model"):
lm = self.vlm.language_model
# Freeze encoder
if hasattr(lm, "model") and hasattr(lm.model, "encoder"):
for param in lm.model.encoder.parameters():
param.requires_grad = False
# Freeze shared embeddings
if hasattr(lm, "model") and hasattr(lm.model, "shared"):
for param in lm.model.shared.parameters():
param.requires_grad = False
# Freeze or unfreeze policy transformer
if not self.config.train_policy_transformer:
for name, param in self.transformer.named_parameters():
if "soft_prompts" not in name:
param.requires_grad = False
# Freeze or unfreeze soft prompts
if not self.config.train_soft_prompts and hasattr(self.transformer, "soft_prompt_hub"):
for param in self.transformer.soft_prompt_hub.parameters():
param.requires_grad = False
def forward_vlm(
self,
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
image_mask: torch.Tensor,
) -> dict[str, torch.Tensor]:
"""
Encode text and multi-view images via Florence2 encoder.
"""
batch_size, num_views = pixel_values.shape[:2]
flat_mask = image_mask.view(-1).to(dtype=torch.bool)
flat_images = pixel_values.flatten(0, 1)
num_valid = int(flat_mask.sum().item())
if num_valid == 0:
raise ValueError("At least one image view must be valid per batch.")
valid_images = flat_images[flat_mask]
valid_feats = self.vlm._encode_image(valid_images)
tokens_per_view, hidden_dim = valid_feats.shape[1:]
image_features = valid_feats.new_zeros((batch_size * num_views, tokens_per_view, hidden_dim))
image_features[flat_mask] = valid_feats
image_features = image_features.view(batch_size, num_views, tokens_per_view, hidden_dim)
inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
image_features[:, 0],
inputs_embeds,
)
enc_out = self.vlm.language_model.model.encoder(
attention_mask=attention_mask,
inputs_embeds=merged_embeds,
)[0]
aux_visual_inputs = image_features[:, 1:].reshape(batch_size, -1, hidden_dim)
return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}
def forward(
self,
input_ids: torch.LongTensor,
image_input: torch.FloatTensor,
image_mask: torch.Tensor,
domain_id: torch.LongTensor,
proprio: torch.Tensor,
action: torch.Tensor,
) -> dict[str, torch.Tensor]:
enc = self.forward_vlm(input_ids, image_input, image_mask)
batch_size = input_ids.shape[0]
t = (
torch.rand(1, device=input_ids.device)
+ torch.arange(batch_size, device=input_ids.device) / batch_size
) % (1 - 1e-5)
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
pred_action = self.transformer(
domain_id=domain_id,
action_with_noise=action_noisy_m,
t=t,
proprio=proprio_m,
**enc,
)
return self.action_space.compute_loss(pred_action, action)
@torch.no_grad()
def generate_actions(
self,
input_ids: torch.LongTensor,
image_input: torch.FloatTensor,
image_mask: torch.Tensor,
domain_id: torch.LongTensor,
proprio: torch.Tensor,
steps: int,
) -> torch.Tensor:
self.eval()
enc = self.forward_vlm(input_ids, image_input, image_mask)
batch_size = input_ids.shape[0]
action_dim = self.dim_action
x1 = torch.randn(batch_size, self.chunk_size, action_dim, device=proprio.device, dtype=proprio.dtype)
action = torch.zeros_like(x1)
steps = max(1, int(steps))
for i in range(steps, 0, -1):
t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=proprio.dtype)
x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
action = self.transformer(
domain_id=domain_id,
action_with_noise=x_t_m,
proprio=proprio_m,
t=t,
**enc,
)
return self.action_space.postprocess(action)
class XVLAPolicy(PreTrainedPolicy):
"""LeRobot-compliant wrapper built around the XVLA model."""
config_class = XVLAConfig
name = "xvla"
def __init__(self, config: XVLAConfig):
super().__init__(config)
config.validate_features()
florence_config = config.get_florence_config()
proprio_dim = config.max_state_dim if config.use_proprio else 0
self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
self.reset()
def reset(self) -> None:
self._queues = {
ACTION: deque(maxlen=self.config.n_action_steps),
}
def get_optim_params(self) -> dict:
"""Return only trainable parameters for optimization."""
return filter(lambda p: p.requires_grad, self.parameters())
def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
if not self.config.use_proprio or OBS_STATE not in batch:
return torch.zeros(batch_size, 0, device=device)
state = batch[OBS_STATE]
if state.ndim > 2:
state = state[:, -1, :]
return pad_vector(state, self.model.dim_proprio)
def _prepare_images(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
present_img_keys = [key for key in self.config.image_features if key in batch]
if len(present_img_keys) == 0:
raise ValueError(
"All image features are missing from the batch. "
f"Batch keys: {list(batch.keys())}, expected at least one of {list(self.config.image_features)}."
)
images = []
masks = []
for key in present_img_keys:
img = batch[key][:, -1] if batch[key].ndim == 5 else batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding)
images.append(img)
masks.append(torch.ones(img.size(0), dtype=torch.bool, device=img.device))
stacked_imgs = torch.stack(images, dim=1)
stacked_masks = torch.stack(masks, dim=1)
total_views = self.config.num_image_views or stacked_imgs.size(1)
total_views = max(total_views, stacked_imgs.size(1))
num_pad = total_views - stacked_imgs.size(1)
if num_pad > 0:
pad_shape = (stacked_imgs.size(0), num_pad, *stacked_imgs.shape[2:])
pad_imgs = stacked_imgs.new_zeros(pad_shape)
pad_masks = stacked_masks.new_zeros((stacked_masks.size(0), num_pad))
stacked_imgs = torch.cat([stacked_imgs, pad_imgs], dim=1)
stacked_masks = torch.cat([stacked_masks, pad_masks], dim=1)
return stacked_imgs, stacked_masks
def _get_domain_id(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
candidate = None
if self.config.domain_feature_key and self.config.domain_feature_key in batch:
candidate = batch[self.config.domain_feature_key]
elif "domain_id" in batch:
candidate = batch["domain_id"]
if candidate is None:
return torch.zeros(batch_size, dtype=torch.long, device=device)
if not isinstance(candidate, torch.Tensor):
candidate = torch.as_tensor(candidate, device=device)
else:
candidate = candidate.to(device=device)
if candidate.ndim == 0:
candidate = candidate.expand(batch_size)
if candidate.ndim > 1:
candidate = candidate.view(candidate.shape[0], -1)[:, 0]
if candidate.shape[0] != batch_size:
candidate = candidate.expand(batch_size)
return candidate.to(dtype=torch.long)
def _prepare_action_targets(self, batch: dict[str, Tensor]) -> Tensor:
if ACTION not in batch:
raise ValueError("Batch is missing action targets required for training.")
actions = batch[ACTION]
if actions.ndim == 2:
actions = actions.unsqueeze(1)
actions = pad_tensor_along_dim(actions, self.config.chunk_size, dim=1)
if actions.shape[-1] != self.model.dim_action:
actions = pad_vector(actions, self.model.dim_action)
return actions
def _build_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
input_ids = batch[OBS_LANGUAGE_TOKENS]
batch_size = input_ids.shape[0]
images, image_mask = self._prepare_images(batch)
domain_id = self._get_domain_id(batch, batch_size, images.device)
proprio = self._prepare_state(batch, batch_size, images.device)
return {
"input_ids": input_ids,
"image_input": images,
"image_mask": image_mask,
"domain_id": domain_id,
"proprio": proprio,
}
def _trim_action_dim(self, actions: Tensor) -> Tensor:
feature = self.config.action_feature
if feature is None:
return actions
desired_dim = self.model.dim_action
if desired_dim == actions.shape[-1]:
return actions
if desired_dim < actions.shape[-1]:
return actions[..., :desired_dim]
return pad_vector(actions, desired_dim)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
inputs = self._build_model_inputs(batch)
targets = self._prepare_action_targets(batch)
losses = self.model(action=targets, **inputs)
total_loss = sum(losses.values())
log_dict = {k: v.detach().item() for k, v in losses.items()}
log_dict["loss"] = total_loss.detach().item()
return total_loss, log_dict
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
inputs = self._build_model_inputs(batch)
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
actions = self._trim_action_dim(actions)
return actions
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
return self._get_action_chunk(batch)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self._get_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = False,
**kwargs,
):
"""
Loads XVLA model weights with:
- automatic prefix 'model.' added to all keys
- skip list for layers that should remain randomly initialized
"""
import safetensors.torch
# step 1: load config
# TODO: jadechoghari, fix this
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
model_id = str(pretrained_name_or_path)
instance = cls(config, **kwargs)
# step 2: locate model.safetensors
if os.path.isdir(model_id):
logging.info("Loading weights from local directory")
model_file = os.path.join(model_id, "model.safetensors")
else:
try:
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
model_file = hf_hub_download(
repo_id=model_id,
filename="model.safetensors",
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
logging.info(f"Loading checkpoint from {model_file}")
# step 3: load state dict
state_dict = safetensors.torch.load_file(model_file)
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
shared_key = "model.vlm.language_model.model.shared.weight"
if encoder_key in state_dict:
state_dict[shared_key] = state_dict[encoder_key]
# or deepcopy
# step 4: load into instance
instance.load_state_dict(state_dict, strict=True)
logging.info("Loaded XVLA checkpoint")
# step 5: finalize
# Reapply dtype after loading state dict
instance.model._apply_dtype()
instance.to(config.device)
instance.eval()
return instance
def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float = 0.0) -> torch.Tensor:
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but got {img.shape}")
current_height, current_width = img.shape[2:]
if current_height == height and current_width == width:
return img
ratio = max(current_width / width, current_height / height)
resized_height = int(current_height / ratio)
resized_width = int(current_width / ratio)
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, height - resized_height)
pad_width = max(0, width - resized_width)
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
if vector.shape[-1] == new_dim:
return vector
if new_dim == 0:
shape = list(vector.shape)
shape[-1] = 0
return vector.new_zeros(*shape)
shape = list(vector.shape)
current_dim = shape[-1]
shape[-1] = new_dim
new_vector = vector.new_zeros(*shape)
length = min(current_dim, new_dim)
new_vector[..., :length] = vector[..., :length]
return new_vector
def pad_tensor_along_dim(tensor: Tensor, target_len: int, dim: int = 1) -> Tensor:
current_len = tensor.size(dim)
if current_len == target_len:
return tensor
if current_len > target_len:
slices = [slice(None)] * tensor.dim()
slices[dim] = slice(0, target_len)
return tensor[tuple(slices)]
pad_shape = list(tensor.shape)
pad_shape[dim] = target_len - current_len
pad_tensor = tensor.new_zeros(pad_shape)
return torch.cat([tensor, pad_tensor], dim=dim)

View File

@@ -1,551 +0,0 @@
# ------------------------------------------------------------------------------
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.datasets.factory import IMAGENET_STATS
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
ObservationProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_IMAGES,
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
def make_xvla_pre_post_processors(
config: XVLAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Build the LeRobot processor pipelines for XVLA.
"""
features = {**config.input_features, **config.output_features}
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
TokenizerProcessorStep(
tokenizer_name=config.tokenizer_name,
max_length=config.tokenizer_max_length,
padding=config.pad_language_to,
padding_side=config.tokenizer_padding_side,
),
XVLAImageToFloatProcessorStep(),
XVLAImageNetNormalizeProcessorStep(),
XVLAAddDomainIdProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features=features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
# Custom XVLA processor steps
@dataclass
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
This step handles the specific observation structure from LIBERO environments,
which includes nested robot_state dictionaries and image observations.
**State Processing:**
- Processes the `robot_state` dictionary which contains nested end-effector,
gripper, and joint information.
- Extracts and concatenates:
- End-effector position (3D)
- End-effector quaternion converted to axis-angle (3D)
- Gripper joint positions (2D)
- Maps the concatenated state to `"observation.state"`.
**Image Processing:**
- Rotates images by 180 degrees by flipping both height and width dimensions.
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
"""
def _process_observation(self, observation):
"""
Processes both image and robot_state observations from LIBERO.
"""
processed_obs = observation.copy()
for key in list(processed_obs.keys()):
if key.startswith(f"{OBS_IMAGES}."):
img = processed_obs[key]
if key == f"{OBS_IMAGES}.image":
# Flip both H and W
img = torch.flip(img, dims=[2, 3])
processed_obs[key] = img
# Process robot_state into a flat state vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
# Extract components
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
eef_mat = robot_state["eef"]["mat"] # (B, 3, 3)
eef_rot6d = self._mat_to_rotate6d(eef_mat) # (B, 6)
extra = torch.zeros((eef_pos.shape[0], 1), dtype=torch.float32, device=eef_pos.device)
proprio_state = torch.cat((eef_pos, eef_rot6d, extra), dim=-1) # (B, 10)
state = torch.cat((proprio_state, torch.zeros_like(proprio_state)), dim=-1) # (B, 20)
# ensure float32
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
processed_obs[OBS_STATE] = state
return processed_obs
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Transforms feature keys from the LIBERO format to the LeRobot standard.
"""
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
state_feats = {}
# add our new flattened state
state_feats["observation.state"] = PolicyFeature(
key="observation.state",
shape=(20,),
dtype="float32",
)
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
def _mat_to_rotate6d(self, rot_mats: torch.Tensor) -> torch.Tensor:
"""
Convert batched rotation matrices (B, 3, 3) into 6D rotation representation (B, 6).
Args:
rot_mats (Tensor): Rotation matrices of shape (B, 3, 3)
Returns:
Tensor: 6D rotation representation, shape (B, 6)
Raises:
TypeError: if input is not a torch tensor
ValueError: if shape is not (B, 3, 3)
"""
if not isinstance(rot_mats, torch.Tensor):
raise TypeError(f"mat_to_rot6d expects a torch.Tensor, got {type(rot_mats)}")
if rot_mats.ndim != 3 or rot_mats.shape[1:] != (3, 3):
raise ValueError(f"mat_to_rot6d expects shape (B, 3, 3), got {tuple(rot_mats.shape)}")
rot_mats = rot_mats.to(torch.float32)
col1 = rot_mats[:, :3, 0] # (B, 3)
col2 = rot_mats[:, :3, 1] # (B, 3)
rot6d = torch.cat([col1, col2], dim=-1) # (B, 6)
return rot6d
def observation(self, observation):
return self._process_observation(observation)
@dataclass
@ProcessorStepRegistry.register(name="xvla_image_scale")
class XVLAImageScaleProcessorStep(ProcessorStep):
"""Scale image observations by 255 to convert from [0, 1] to [0, 255] range.
This processor step multiplies all image observations by 255, which is required
for XVLA models that expect images in uint8-like range.
Args:
image_keys: List of observation keys that contain images to scale.
If None, will automatically detect keys starting with "observation.images."
"""
image_keys: list[str] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Scale image observations by 255."""
new_transition = transition.copy()
obs = new_transition.get(TransitionKey.OBSERVATION, {})
if obs is None:
return new_transition
# Make a copy of observations to avoid modifying the original
obs = obs.copy()
# Determine which keys to scale
keys_to_scale = self.image_keys
if keys_to_scale is None:
# Auto-detect image keys
keys_to_scale = [k for k in obs if k.startswith("observation.images.")]
# Scale each image
for key in keys_to_scale:
if key in obs and isinstance(obs[key], torch.Tensor):
obs[key] = obs[key] * 255
new_transition[TransitionKey.OBSERVATION] = obs
return new_transition
def transform_features(self, features):
"""Image scaling doesn't change feature structure."""
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"image_keys": self.image_keys,
}
@dataclass
@ProcessorStepRegistry.register(name="xvla_image_to_float")
class XVLAImageToFloatProcessorStep(ProcessorStep):
"""Convert image observations from [0, 255] to [0, 1] range.
This processor step divides image observations by 255 to convert from uint8-like
range [0, 255] to float range [0, 1]. This is typically used when loading images
that are stored as uint8 values.
Args:
image_keys: List of observation keys that contain images to convert.
If None, will automatically detect keys starting with "observation.images."
validate_range: If True, validates that input values are in [0, 255] range (default: True)
Raises:
ValueError: If validate_range is True and image values are not in [0, 255] range.
"""
image_keys: list[str] | None = None
validate_range: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Convert image observations from [0, 255] to [0, 1]."""
new_transition = transition.copy()
obs = new_transition.get(TransitionKey.OBSERVATION, {})
if obs is None:
return new_transition
# Make a copy of observations to avoid modifying the original
obs = obs.copy()
# Determine which keys to convert
keys_to_convert = self.image_keys
if keys_to_convert is None:
# Auto-detect image keys
keys_to_convert = [k for k in obs if k.startswith("observation.images.")]
# Convert each image
for key in keys_to_convert:
if key in obs and isinstance(obs[key], torch.Tensor):
tensor = obs[key]
# Validate that values are in [0, 255] range if requested
if self.validate_range:
min_val = tensor.min().item()
max_val = tensor.max().item()
if min_val < 0.0 or max_val > 255.0:
raise ValueError(
f"Image '{key}' has values outside [0, 255] range: "
f"min={min_val:.4f}, max={max_val:.4f}. "
f"Cannot convert to [0, 1] range."
)
# Convert to float and divide by 255
obs[key] = tensor.float() / 255.0
new_transition[TransitionKey.OBSERVATION] = obs
return new_transition
def transform_features(self, features):
"""Image conversion doesn't change feature structure."""
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"image_keys": self.image_keys,
"validate_range": self.validate_range,
}
@dataclass
@ProcessorStepRegistry.register(name="xvla_imagenet_normalize")
class XVLAImageNetNormalizeProcessorStep(ProcessorStep):
"""Normalize image observations using ImageNet statistics.
This processor step applies ImageNet normalization (mean and std) to image observations.
It validates that input values are in the [0, 1] range before normalizing.
The normalization formula is: (image - mean) / std
Args:
image_keys: List of observation keys that contain images to normalize.
If None, will automatically detect keys starting with "observation.images."
Raises:
ValueError: If image values are not in the [0, 1] range.
"""
image_keys: list[str] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Normalize image observations using ImageNet statistics."""
new_transition = transition.copy()
obs = new_transition.get(TransitionKey.OBSERVATION, {})
if obs is None:
return new_transition
# Make a copy of observations to avoid modifying the original
obs = obs.copy()
# Determine which keys to normalize
keys_to_normalize = self.image_keys
if keys_to_normalize is None:
# Auto-detect image keys
keys_to_normalize = [k for k in obs if k.startswith("observation.images.")]
# Normalize each image
for key in keys_to_normalize:
if key in obs and isinstance(obs[key], torch.Tensor):
tensor = obs[key]
# Validate that values are in [0, 1] range
min_val = tensor.min().item()
max_val = tensor.max().item()
if min_val < 0.0 or max_val > 1.0:
raise ValueError(
f"Image '{key}' has values outside [0, 1] range: "
f"min={min_val:.4f}, max={max_val:.4f}. "
f"ImageNet normalization requires input values in [0, 1]."
)
# Apply ImageNet normalization
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
while mean.dim() < tensor.dim():
mean = mean.unsqueeze(0)
std = std.unsqueeze(0)
# Normalize: (image - mean) / std
obs[key] = (tensor - mean) / std
new_transition[TransitionKey.OBSERVATION] = obs
return new_transition
def transform_features(self, features):
"""ImageNet normalization doesn't change feature structure."""
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"image_keys": self.image_keys,
}
@dataclass
@ProcessorStepRegistry.register(name="xvla_add_domain_id")
class XVLAAddDomainIdProcessorStep(ProcessorStep):
"""Add domain_id to complementary data.
This processor step adds a domain_id tensor to the complementary data,
which is used by XVLA to identify different robot embodiments or task domains.
Args:
domain_id: The domain ID to add (default: 3)
"""
domain_id: int = 0
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Add domain_id to complementary data."""
new_transition = transition.copy()
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp = {} if comp is None else comp.copy()
# Infer batch size from observation tensors
obs = new_transition.get(TransitionKey.OBSERVATION, {})
batch_size = 1
if obs:
for v in obs.values():
if isinstance(v, torch.Tensor):
batch_size = v.shape[0]
break
# Add domain_id tensor
comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return new_transition
def transform_features(self, features):
"""Domain ID addition doesn't change feature structure."""
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"domain_id": self.domain_id,
}
@dataclass
@ProcessorStepRegistry.register(name="xvla_rotation_6d_to_axis_angle")
class XVLARotation6DToAxisAngleProcessorStep(ProcessorStep):
"""Convert 6D rotation representation to axis-angle and reorganize action dimensions.
This processor step takes actions with 6D rotation representation and converts them to
axis-angle representation, reorganizing the action dimensions as:
- action[:, :3] -> target_eef (end-effector position)
- action[:, 3:9] -> 6D rotation (converted to axis-angle, 3D)
- action[:, 9:10] -> gripper action
Final output: [target_eef (3), axis_angle (3), gripper (1)] = 7D action
Args:
expected_action_dim: Expected input action dimension (default: 10, supports 6D rotation + extras)
"""
expected_action_dim: int = 10
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Convert 6D rotation to axis-angle in action."""
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is None or not isinstance(action, torch.Tensor):
return new_transition
# Convert to numpy for processing
device = action.device
dtype = action.dtype
action_np = action.cpu().numpy()
# Extract components
# action shape: (B, D) where D >= 10
target_eef = action_np[:, :3] # (B, 3)
rotation_6d = action_np[:, 3:9] # (B, 6)
target_act = action_np[:, 9:10] # (B, 1)
# Convert 6D rotation to axis-angle
target_axis = rotate6d_to_axis_angle(rotation_6d) # (B, 3)
# Concatenate: [eef (3), axis_angle (3), gripper (1)] = 7D
action_np = np.concatenate([target_eef, target_axis, target_act], axis=-1)
# Convert gripper action to -1 or 1
action_np[:, -1] = np.where(action_np[:, -1] > 0.5, 1.0, -1.0)
# Convert back to tensor
action = torch.from_numpy(action_np).to(device=device, dtype=dtype)
new_transition[TransitionKey.ACTION] = action
return new_transition
def transform_features(self, features):
"""Rotation conversion changes action dimension from 10 to 7."""
# Note: This is a simplified version. In practice, you might want to
# update the action feature shape in the features dict.
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"expected_action_dim": self.expected_action_dim,
}
def make_xvla_libero_pre_post_processors() -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Build the LeRobot processor pipelines for XVLA with LIBERO environment.
"""
pre_processor_steps: list[ProcessorStep] = []
post_processor_steps: list[ProcessorStep] = []
pre_processor_steps.extend(
[LiberoProcessorStep(), XVLAImageNetNormalizeProcessorStep(), XVLAAddDomainIdProcessorStep()]
)
post_processor_steps.extend([XVLARotation6DToAxisAngleProcessorStep()])
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=pre_processor_steps,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=post_processor_steps,
),
)

View File

@@ -1,415 +0,0 @@
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
import math
from collections.abc import Iterable
from functools import partial
from typing import Final
import torch
import torch.nn as nn
import torch.nn.functional as functional
# ------------------------------- Small utils ----------------------------------
def _to_2tuple(x) -> tuple:
"""Minimal replacement for timm.layers.to_2tuple."""
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
t = tuple(x)
return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
return (x, x)
def _has_sdp_attention() -> bool:
"""Check if we can use PyTorch fused scaled_dot_product_attention."""
return hasattr(functional, "scaled_dot_product_attention")
# ---------------------------------- MLP --------------------------------------
class Mlp(nn.Module):
"""
MLP used in ViT-style blocks.
Supports Linear or 1x1 Conv 'linear_layer' for token/channel mixing.
"""
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
norm_layer: type[nn.Module] | None = None,
bias: bool | tuple[bool, bool] = True,
drop: float | tuple[float, float] = 0.0,
use_conv: bool = False,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = _to_2tuple(bias)
drop_probs = _to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = nn.GELU(approximate="tanh")
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Expect [B, T, C] for Linear variant; caller is responsible for shapes.
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
# -------------------------------- Attention ----------------------------------
class Attention(nn.Module):
"""
Multi-Head Self-Attention with optional fused SDPA fallback.
If PyTorch provides `scaled_dot_product_attention`, it will be used
(usually faster and more stable); otherwise we use a manual implementation.
"""
fused_attn: Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: type[nn.Module] = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.fused_attn = _has_sdp_attention()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Parameters
----------
x : Tensor, shape [batch_size, seq_len, channels]
Input sequence.
Returns
-------
Tensor, shape [batch_size, seq_len, channels]
Output sequence after MHSA + projection.
"""
batch_size, seq_len, channels = x.shape
qkv = (
self.qkv(x)
.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4) # 3 x [batch_size, num_heads, seq_len, head_dim]
)
q, k, v = qkv.unbind(0) # each: [batch_size, num_heads, seq_len, head_dim]
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = functional.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
) # [batch_size, num_heads, seq_len, head_dim]
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1) # [batch_size, num_heads, seq_len, seq_len]
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v # [batch_size, num_heads, seq_len, head_dim]
x = x.transpose(1, 2).reshape(batch_size, seq_len, channels) # [batch_size, seq_len, channels]
x = self.proj(x)
x = self.proj_drop(x)
return x
# ------------------------------- Utilities -----------------------------------
def basic_init(module: nn.Module) -> None:
"""
Apply a basic initialization scheme to Linear layers.
- Weight: Xavier uniform initialization.
- Bias: Set to zero.
"""
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0.0)
def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
"""
Create sinusoidal timestep embeddings.
Parameters
----------
t : torch.Tensor
Shape [B]. Each element is a timestep index, may be fractional.
dim : int
Dimensionality of the output embedding.
max_period : int, default=100
Controls the minimum frequency of the sinusoids.
Returns
-------
torch.Tensor
Shape [B, dim]. Sinusoidal embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device) / half
)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2 == 1:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
# ------------------------------- Core Layers ----------------------------------
class DomainAwareLinear(nn.Module):
"""
Linear layer with domain-conditioned parameters (per-sample).
Each domain has its own weight and bias vectors, stored in embeddings.
"""
def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.fc = nn.Embedding(num_domains, output_size * input_size)
self.bias = nn.Embedding(num_domains, output_size)
nn.init.xavier_uniform_(self.fc.weight)
nn.init.zeros_(self.bias.weight)
def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
"""
Parameters
----------
x : Tensor
[B, I] or [B, T, I]
domain_id : LongTensor
[B], domain indices.
Returns
-------
Tensor
[batch_size, output_size] or [batch_size, seq_len, output_size]
"""
batch_size = domain_id.shape[0]
squeeze_seq = False
if x.dim() == 2:
x = x.unsqueeze(1)
squeeze_seq = True
weight = self.fc(domain_id).view(batch_size, self.input_size, self.output_size)
bias = self.bias(domain_id).view(batch_size, self.output_size)
y = torch.matmul(x, weight) + bias.view(batch_size, 1, self.output_size)
if squeeze_seq:
y = y.squeeze(1)
return y
class TransformerBlock(nn.Module):
"""
Standard Transformer block (pre-LN): LN → MHSA → residual, LN → MLP → residual.
"""
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size)
self.norm2 = nn.LayerNorm(hidden_size)
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=int(hidden_size * mlp_ratio),
drop=0.1,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Parameters
----------
x : Tensor, [B, T, H]
Returns
-------
Tensor, [B, T, H]
"""
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
# --------------------------- Main Model ---------------------------------------
class SoftPromptedTransformer(nn.Module):
"""
Multi-modal, domain-aware Transformer with optional soft prompts.
See parameter and forward I/O descriptions inside the docstrings.
"""
def __init__(
self,
hidden_size: int = 768,
multi_modal_input_size: int = 768,
depth: int = 24,
num_heads: int = 16,
mlp_ratio: float = 4.0,
num_domains: int = 20,
dim_action: int = 20,
dim_propio: int = 20,
dim_time: int = 32,
len_soft_prompts: int = 32,
max_len_seq: int = 512,
use_hetero_proj: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.dim_action = dim_action
self.dim_time = dim_time
self.len_soft_prompts = len_soft_prompts
self.use_hetero_proj = use_hetero_proj
self.blocks = nn.ModuleList(
[TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
)
if use_hetero_proj:
self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
self.aux_visual_proj = DomainAwareLinear(
multi_modal_input_size, hidden_size, num_domains=num_domains
)
else:
self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
nn.init.normal_(self.pos_emb, std=0.02)
self.norm = nn.LayerNorm(hidden_size)
self.action_encoder = DomainAwareLinear(
dim_action + dim_time + dim_propio, hidden_size, num_domains=num_domains
)
self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)
if len_soft_prompts > 0:
self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)
self.apply(basic_init)
def forward(
self,
domain_id: torch.LongTensor,
vlm_features: torch.Tensor,
aux_visual_inputs: torch.Tensor,
action_with_noise: torch.Tensor,
proprio: torch.Tensor,
t: torch.Tensor,
) -> torch.Tensor:
"""
Forward pass.
Inputs
------
domain_id : [B]
vlm_features : [B, T_vlm, D]
aux_visual_inputs : [B, T_aux, D]
action_with_noise : [B, T_action, dim_action]
proprio : [B, dim_propio]
t : [B]
Returns
-------
Tensor
Predicted actions, [batch_size, num_actions, dim_action]
"""
batch_size, num_actions = action_with_noise.shape[:2]
# Encode (action + proprio + time) → tokens
time_emb = timestep_embedding(t, self.dim_time) # [batch_size, dim_time]
time_tokens = time_emb.unsqueeze(1).expand(batch_size, num_actions, self.dim_time)
proprio_tokens = proprio.unsqueeze(1).expand(batch_size, num_actions, proprio.shape[-1])
action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
x = self.action_encoder(action_tokens, domain_id) # [batch_size, num_actions, hidden_size]
# Project visual streams and concatenate
if self.use_hetero_proj:
x = torch.cat(
[
x,
self.vlm_proj(vlm_features, domain_id),
self.aux_visual_proj(aux_visual_inputs, domain_id),
],
dim=1,
)
else:
x = torch.cat([x, self.vlm_proj(vlm_features), self.aux_visual_proj(aux_visual_inputs)], dim=1)
# Add positional embeddings (truncate if needed)
seq_len = x.shape[1]
if seq_len > self.pos_emb.shape[1]:
raise ValueError(f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}.")
x = x + self.pos_emb[:, :seq_len, :]
# Append soft prompts
if self.len_soft_prompts > 0:
soft_prompts = self.soft_prompt_hub(domain_id).view(
batch_size, self.len_soft_prompts, self.hidden_size
)
x = torch.cat([x, soft_prompts], dim=1)
# Transformer backbone
for block in self.blocks:
x = block(x)
# Decode only the action segment
return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)

View File

@@ -1,138 +0,0 @@
import math
import numpy as np
def mat2quat(rmat):
"""
Converts given rotation matrix to quaternion.
Args:
rmat (np.array): 3x3 rotation matrix
Returns:
np.array: (x,y,z,w) float quaternion angles
"""
mat = np.asarray(rmat).astype(np.float32)[:3, :3]
m00 = mat[0, 0]
m01 = mat[0, 1]
m02 = mat[0, 2]
m10 = mat[1, 0]
m11 = mat[1, 1]
m12 = mat[1, 2]
m20 = mat[2, 0]
m21 = mat[2, 1]
m22 = mat[2, 2]
# symmetric matrix k
k = np.array(
[
[m00 - m11 - m22, np.float32(0.0), np.float32(0.0), np.float32(0.0)],
[m01 + m10, m11 - m00 - m22, np.float32(0.0), np.float32(0.0)],
[m02 + m20, m12 + m21, m22 - m00 - m11, np.float32(0.0)],
[m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
]
)
k /= 3.0
# quaternion is Eigen vector of k that corresponds to largest eigenvalue
w, v = np.linalg.eigh(k)
inds = np.array([3, 0, 1, 2])
q1 = v[inds, np.argmax(w)]
if q1[0] < 0.0:
np.negative(q1, q1)
inds = np.array([1, 2, 3, 0])
return q1[inds]
def quat2axisangle(quat):
"""
Converts quaternion to axis-angle format.
Returns a unit vector direction scaled by its angle in radians.
Args:
quat (np.array): (x,y,z,w) vec4 float angles
Returns:
np.array: (ax,ay,az) axis-angle exponential coordinates
"""
# clip quaternion
if quat[3] > 1.0:
quat[3] = 1.0
elif quat[3] < -1.0:
quat[3] = -1.0
den = np.sqrt(1.0 - quat[3] * quat[3])
if math.isclose(den, 0.0):
# This is (close to) a zero degree rotation, immediately return
return np.zeros(3)
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
def rotate6d_to_axis_angle(r6d):
"""
r6d: np.ndarray, shape (N, 6)
return: np.ndarray, shape (N, 3), axis-angle vectors
"""
flag = 0
if len(r6d.shape) == 1:
r6d = r6d[None, ...]
flag = 1
a1 = r6d[:, 0:3]
a2 = r6d[:, 3:6]
# b1
b1 = a1 / (np.linalg.norm(a1, axis=-1, keepdims=True) + 1e-6)
# b2
dot_prod = np.sum(b1 * a2, axis=-1, keepdims=True)
b2_orth = a2 - dot_prod * b1
b2 = b2_orth / (np.linalg.norm(b2_orth, axis=-1, keepdims=True) + 1e-6)
# b3
b3 = np.cross(b1, b2, axis=-1)
rotation_matrix = np.stack([b1, b2, b3], axis=-1) # shape: (N, 3, 3)
axis_angle_list = []
for i in range(rotation_matrix.shape[0]):
quat = mat2quat(rotation_matrix[i])
axis_angle = quat2axisangle(quat)
axis_angle_list.append(axis_angle)
axis_angle_array = np.stack(axis_angle_list, axis=0) # shape: (N, 3)
if flag == 1:
axis_angle_array = axis_angle_array[0]
return axis_angle_array
def mat_to_rotate6d(abs_action):
if len(abs_action.shape) == 2:
return np.concatenate([abs_action[:3, 0], abs_action[:3, 1]], axis=-1)
elif len(abs_action.shape) == 3:
return np.concatenate([abs_action[:, :3, 0], abs_action[:, :3, 1]], axis=-1)
else:
raise NotImplementedError
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor

View File

@@ -78,7 +78,7 @@ from lerobot.transport.utils import (
transitions_to_bytes,
)
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.transition import (
Transition,
move_state_dict_to_device,
@@ -398,7 +398,7 @@ def act_with_policy(
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
precise_sleep(1 / cfg.env.fps - dt_time)
busy_wait(1 / cfg.env.fps - dt_time)
# Communication Functions - Group all gRPC/messaging functions

View File

@@ -74,7 +74,7 @@ from lerobot.teleoperators import (
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
logging.basicConfig(level=logging.INFO)
@@ -114,7 +114,7 @@ def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> No
for pose in trajectory:
action_dict = dict(zip(current_position_dict, pose, strict=False))
robot_arm.bus.sync_write("Goal_Position", action_dict)
precise_sleep(0.015)
busy_wait(0.015)
class RobotEnv(gym.Env):
@@ -238,7 +238,7 @@ class RobotEnv(gym.Env):
reset_follower_position(self.robot, np.array(self.reset_pose))
log_say("Reset the environment done.", play_sounds=True)
precise_sleep(self.reset_time_s - (time.perf_counter() - start_time))
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
super().reset(seed=seed, options=options)
@@ -713,7 +713,7 @@ def control_loop(
transition = env_processor(transition)
# Maintain fps timing
precise_sleep(dt - (time.perf_counter() - step_start_time))
busy_wait(dt - (time.perf_counter() - step_start_time))
if dataset is not None and cfg.dataset.push_to_hub:
logging.info("Pushing dataset to hub")
@@ -745,7 +745,7 @@ def replay_trajectory(
)
transition = action_processor(transition)
env.step(transition[TransitionKey.ACTION])
precise_sleep(1 / cfg.env.fps - (time.perf_counter() - start_time))
busy_wait(1 / cfg.env.fps - (time.perf_counter() - start_time))
@parser.wrap()

View File

@@ -0,0 +1,18 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_unitree_g1 import UnitreeG1Config
from .unitree_g1 import UnitreeG1

View File

@@ -0,0 +1,110 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from ..config import RobotConfig
@RobotConfig.register_subclass("unitree_g1")
@dataclass
class UnitreeG1Config(RobotConfig):
# id: str = "unitree_g1"
kp: list = field(
default_factory=lambda: [
150,
150,
150,
300,
40,
40, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll
150,
150,
150,
300,
40,
40, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll
250,
250,
250, # Waist yaw, roll, pitch
80,
80,
80,
80, # Left shoulder pitch, roll, yaw, elbow (kp_low)
40,
40,
40, # Left wrist roll, pitch, yaw (kp_wrist)
80,
80,
80,
80, # Right shoulder pitch, roll, yaw, elbow (kp_low)
40,
40,
40, # Right wrist roll, pitch, yaw (kp_wrist)
80,
80,
80,
80,
80,
80, # Other
]
)
kd: list = field(
default_factory=lambda: [
2,
2,
2,
4,
2,
2, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll
2,
2,
2,
4,
2,
2, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll
5,
5,
5, # Waist yaw, roll, pitch
3,
3,
3,
3, # Left shoulder pitch, roll, yaw, elbow (kd_low)
1.5,
1.5,
1.5, # Left wrist roll, pitch, yaw (kd_wrist)
3,
3,
3,
3, # Right shoulder pitch, roll, yaw, elbow (kd_low)
1.5,
1.5,
1.5, # Right wrist roll, pitch, yaw (kd_wrist)
3,
3,
3,
3,
3,
3, # Other
]
)
control_dt = 1.0 / 250.0 # 250Hz
# socket config for ZMQ bridge
robot_ip: str = "172.18.129.215"

View File

@@ -0,0 +1,73 @@
from enum import IntEnum
# ruff: noqa: N801, N815
NUM_MOTORS = 35
class G1_29_JointArmIndex(IntEnum):
# Left arm
kLeftShoulderPitch = 15
kLeftShoulderRoll = 16
kLeftShoulderYaw = 17
kLeftElbow = 18
kLeftWristRoll = 19
kLeftWristPitch = 20
kLeftWristyaw = 21
# Right arm
kRightShoulderPitch = 22
kRightShoulderRoll = 23
kRightShoulderYaw = 24
kRightElbow = 25
kRightWristRoll = 26
kRightWristPitch = 27
kRightWristYaw = 28
class G1_29_JointIndex(IntEnum):
# Left leg
kLeftHipPitch = 0
kLeftHipRoll = 1
kLeftHipYaw = 2
kLeftKnee = 3
kLeftAnklePitch = 4
kLeftAnkleRoll = 5
# Right leg
kRightHipPitch = 6
kRightHipRoll = 7
kRightHipYaw = 8
kRightKnee = 9
kRightAnklePitch = 10
kRightAnkleRoll = 11
kWaistYaw = 12
kWaistRoll = 13
kWaistPitch = 14
# Left arm
kLeftShoulderPitch = 15
kLeftShoulderRoll = 16
kLeftShoulderYaw = 17
kLeftElbow = 18
kLeftWristRoll = 19
kLeftWristPitch = 20
kLeftWristyaw = 21
# Right arm
kRightShoulderPitch = 22
kRightShoulderRoll = 23
kRightShoulderYaw = 24
kRightElbow = 25
kRightWristRoll = 26
kRightWristPitch = 27
kRightWristYaw = 28
# not used
kNotUsedJoint0 = 29
kNotUsedJoint1 = 30
kNotUsedJoint2 = 31
kNotUsedJoint3 = 32
kNotUsedJoint4 = 33
kNotUsedJoint5 = 34

View File

@@ -0,0 +1,212 @@
#!/usr/bin/env python3
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DDS-to-ZMQ bridge server for Unitree G1 robot.
This server runs on the robot and forwards:
- Robot state (LowState) from DDS to ZMQ (for remote clients)
- Robot commands (LowCmd) from ZMQ to DDS (from remote clients)
Uses JSON for secure serialization instead of pickle.
"""
import base64
import contextlib
import json
import threading
import time
from typing import Any
import zmq
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import MotionSwitcherClient
from unitree_sdk2py.core.channel import ChannelFactoryInitialize, ChannelPublisher, ChannelSubscriber
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
from unitree_sdk2py.utils.crc import CRC
# DDS topic names follow Unitree SDK naming conventions
# ruff: noqa: N816
kTopicLowCommand_Debug = "rt/lowcmd" # action to robot
kTopicLowState = "rt/lowstate" # observation from robot
LOWCMD_PORT = 6000
LOWSTATE_PORT = 6001
NUM_MOTORS = 35
def lowstate_to_dict(msg: hg_LowState) -> dict[str, Any]:
"""Convert LowState SDK message to a JSON-serializable dictionary."""
motor_states = []
for i in range(NUM_MOTORS):
temp = msg.motor_state[i].temperature
avg_temp = float(sum(temp) / len(temp)) if isinstance(temp, list) else float(temp)
motor_states.append(
{
"q": float(msg.motor_state[i].q),
"dq": float(msg.motor_state[i].dq),
"tau_est": float(msg.motor_state[i].tau_est),
"temperature": avg_temp,
}
)
return {
"motor_state": motor_states,
"imu_state": {
"quaternion": [float(x) for x in msg.imu_state.quaternion],
"gyroscope": [float(x) for x in msg.imu_state.gyroscope],
"accelerometer": [float(x) for x in msg.imu_state.accelerometer],
"rpy": [float(x) for x in msg.imu_state.rpy],
"temperature": float(msg.imu_state.temperature),
},
# Encode bytes as base64 for JSON compatibility
"wireless_remote": base64.b64encode(bytes(msg.wireless_remote)).decode("ascii"),
"mode_machine": int(msg.mode_machine),
}
def dict_to_lowcmd(data: dict[str, Any]) -> hg_LowCmd:
"""Convert dictionary back to LowCmd SDK message."""
cmd = unitree_hg_msg_dds__LowCmd_()
cmd.mode_pr = data.get("mode_pr", 0)
cmd.mode_machine = data.get("mode_machine", 0)
for i, motor_data in enumerate(data.get("motor_cmd", [])):
cmd.motor_cmd[i].mode = motor_data.get("mode", 0)
cmd.motor_cmd[i].q = motor_data.get("q", 0.0)
cmd.motor_cmd[i].dq = motor_data.get("dq", 0.0)
cmd.motor_cmd[i].kp = motor_data.get("kp", 0.0)
cmd.motor_cmd[i].kd = motor_data.get("kd", 0.0)
cmd.motor_cmd[i].tau = motor_data.get("tau", 0.0)
return cmd
def state_forward_loop(
lowstate_sub: ChannelSubscriber,
lowstate_sock: zmq.Socket,
state_period: float,
) -> None:
"""Read observation from DDS and forward to ZMQ clients."""
last_state_time = 0.0
while True:
# read from DDS
msg = lowstate_sub.Read()
if msg is None:
continue
now = time.time()
# optional downsampling (if robot dds rate > state_period)
if now - last_state_time >= state_period:
# Convert to dict and serialize with JSON
state_dict = lowstate_to_dict(msg)
payload = json.dumps({"topic": kTopicLowState, "data": state_dict}).encode("utf-8")
# if no subscribers / tx buffer full, just drop
with contextlib.suppress(zmq.Again):
lowstate_sock.send(payload, zmq.NOBLOCK)
last_state_time = now
def cmd_forward_loop(
lowcmd_sock: zmq.Socket,
lowcmd_pub_debug: ChannelPublisher,
crc: CRC,
) -> None:
"""Receive commands from ZMQ and forward to DDS."""
while True:
payload = lowcmd_sock.recv()
msg_dict = json.loads(payload.decode("utf-8"))
topic = msg_dict.get("topic", "")
cmd_data = msg_dict.get("data", {})
# Reconstruct LowCmd object from dict
cmd = dict_to_lowcmd(cmd_data)
# recompute crc
cmd.crc = crc.Crc(cmd)
if topic == kTopicLowCommand_Debug:
lowcmd_pub_debug.Write(cmd)
def main() -> None:
"""Main entry point for the robot server bridge."""
# initialize DDS
ChannelFactoryInitialize(0)
# stop all active publishers on the robot
msc = MotionSwitcherClient()
msc.SetTimeout(5.0)
msc.Init()
status, result = msc.CheckMode()
while result is not None and "name" in result and result["name"]:
msc.ReleaseMode()
status, result = msc.CheckMode()
time.sleep(1.0)
crc = CRC()
# initialize DDS publisher
lowcmd_pub_debug = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
lowcmd_pub_debug.Init()
# initialize DDS subscriber
lowstate_sub = ChannelSubscriber(kTopicLowState, hg_LowState)
lowstate_sub.Init()
# initialize ZMQ
ctx = zmq.Context.instance()
# receive commands from remote client
lowcmd_sock = ctx.socket(zmq.PULL)
lowcmd_sock.bind(f"tcp://0.0.0.0:{LOWCMD_PORT}")
# publish state to remote clients
lowstate_sock = ctx.socket(zmq.PUB)
lowstate_sock.bind(f"tcp://0.0.0.0:{LOWSTATE_PORT}")
state_period = 0.002 # ~500 hz
# start observation forwarding thread
t_state = threading.Thread(
target=state_forward_loop,
args=(lowstate_sub, lowstate_sock, state_period),
daemon=True,
)
t_state.start()
# start action forwarding thread
t_cmd = threading.Thread(
target=cmd_forward_loop,
args=(lowcmd_sock, lowcmd_pub_debug, crc),
daemon=True,
)
t_cmd.start()
print("bridge running (lowstate -> zmq, lowcmd -> dds)")
# keep main thread alive so daemon threads don't exit
try:
while True:
time.sleep(1.0)
except KeyboardInterrupt:
print("shutting down bridge...")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,264 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import struct
import threading
import time
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any
import numpy as np
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
LowCmd_ as hg_LowCmd,
LowState_ as hg_LowState,
)
from unitree_sdk2py.utils.crc import CRC
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
ChannelFactoryInitialize,
ChannelPublisher,
ChannelSubscriber,
)
from ..robot import Robot
from .config_unitree_g1 import UnitreeG1Config
logger = logging.getLogger(__name__)
# DDS topic names follow Unitree SDK naming conventions
# ruff: noqa: N816
kTopicLowCommand_Debug = "rt/lowcmd"
kTopicLowState = "rt/lowstate"
G1_29_Num_Motors = 35
G1_23_Num_Motors = 35
H1_2_Num_Motors = 35
H1_Num_Motors = 20
@dataclass
class MotorState:
q: float | None = None # position
dq: float | None = None # velocity
tau_est: float | None = None # estimated torque
temperature: float | None = None # motor temperature
@dataclass
class IMUState:
quaternion: np.ndarray | None = None # [w, x, y, z]
gyroscope: np.ndarray | None = None # [x, y, z] angular velocity (rad/s)
accelerometer: np.ndarray | None = None # [x, y, z] linear acceleration (m/s²)
rpy: np.ndarray | None = None # [roll, pitch, yaw] (rad)
temperature: float | None = None # IMU temperature
# g1 observation class
@dataclass
class G1_29_LowState: # noqa: N801
motor_state: list[MotorState] = field(default_factory=lambda: [MotorState() for _ in range(G1_29_Num_Motors)])
imu_state: IMUState = field(default_factory=IMUState)
wireless_remote: Any = None # Raw wireless remote data
mode_machine: int = 0 # Robot mode
class DataBuffer:
def __init__(self):
self.data = None
self.lock = threading.Lock()
def get_data(self):
with self.lock:
return self.data
def set_data(self, data):
with self.lock:
self.data = data
class UnitreeG1(Robot):
config_class = UnitreeG1Config
name = "unitree_g1"
# unitree remote controller
class RemoteController:
def __init__(self):
self.lx = 0
self.ly = 0
self.rx = 0
self.ry = 0
self.button = [0] * 16
def set(self, data):
# wireless_remote
keys = struct.unpack("H", data[2:4])[0]
for i in range(16):
self.button[i] = (keys & (1 << i)) >> i
self.lx = struct.unpack("f", data[4:8])[0]
self.rx = struct.unpack("f", data[8:12])[0]
self.ry = struct.unpack("f", data[12:16])[0]
self.ly = struct.unpack("f", data[20:24])[0]
def __init__(self, config: UnitreeG1Config):
super().__init__(config)
logger.info("Initialize UnitreeG1...")
self.config = config
self.control_dt = config.control_dt
# connect robot
self.connect()
# initialize direct motor control interface
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
self.lowcmd_publisher.Init()
self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState)
self.lowstate_subscriber.Init()
self.lowstate_buffer = DataBuffer()
# initialize subscribe thread to read robot state
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
self.subscribe_thread.daemon = True
self.subscribe_thread.start()
while not self.is_connected:
time.sleep(0.1)
# initialize hg's lowcmd msg
self.crc = CRC()
self.msg = unitree_hg_msg_dds__LowCmd_()
self.msg.mode_pr = 0
# Wait for first state message to arrive
lowstate = None
while lowstate is None:
lowstate = self.lowstate_buffer.get_data()
if lowstate is None:
time.sleep(0.01)
logger.warning("[UnitreeG1] Waiting for robot state...")
logger.warning("[UnitreeG1] Connected to robot.")
self.msg.mode_machine = lowstate.mode_machine
# initialize all motors with unified kp/kd from config
self.kp = np.array(config.kp, dtype=np.float32)
self.kd = np.array(config.kd, dtype=np.float32)
for id in G1_29_JointIndex:
self.msg.motor_cmd[id].mode = 1
self.msg.motor_cmd[id].kp = self.kp[id.value]
self.msg.motor_cmd[id].kd = self.kd[id.value]
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
# Initialize remote controller
self.remote_controller = self.RemoteController()
def _subscribe_motor_state(self): # polls robot state @ 250Hz
while True:
start_time = time.time()
msg = self.lowstate_subscriber.Read()
if msg is not None:
lowstate = G1_29_LowState()
# Capture motor states
for id in range(G1_29_Num_Motors):
lowstate.motor_state[id].q = msg.motor_state[id].q
lowstate.motor_state[id].dq = msg.motor_state[id].dq
lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est
lowstate.motor_state[id].temperature = msg.motor_state[id].temperature
# Capture IMU state
lowstate.imu_state.quaternion = list(msg.imu_state.quaternion)
lowstate.imu_state.gyroscope = list(msg.imu_state.gyroscope)
lowstate.imu_state.accelerometer = list(msg.imu_state.accelerometer)
lowstate.imu_state.rpy = list(msg.imu_state.rpy)
lowstate.imu_state.temperature = msg.imu_state.temperature
# Capture wireless remote data
lowstate.wireless_remote = msg.wireless_remote
# Capture mode_machine
lowstate.mode_machine = msg.mode_machine
self.lowstate_buffer.set_data(lowstate)
current_time = time.time()
all_t_elapsed = current_time - start_time
sleep_time = max(0, (self.control_dt - all_t_elapsed)) # maintina constant control dt
time.sleep(sleep_time)
@cached_property
def action_features(self) -> dict[str, type]:
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
def calibrate(self) -> None: # robot is already calibrated
pass
def configure(self) -> None:
pass
def connect(self, calibrate: bool = True) -> None: # connect to DDS
ChannelFactoryInitialize(0)
def disconnect(self):
pass
def get_observation(self) -> dict[str, Any]:
return self.lowstate_buffer.get_data()
@property
def is_calibrated(self) -> bool:
return True
@property
def is_connected(self) -> bool:
return self.lowstate_buffer.get_data() is not None
@property
def _motors_ft(self) -> dict[str, type]:
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
self.msg.crc = self.crc.Crc(action)
self.lowcmd_publisher.Write(action)
def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion
"""Get gravity orientation from quaternion."""
qw = quaternion[0]
qx = quaternion[1]
qy = quaternion[2]
qz = quaternion[3]
gravity_orientation = np.zeros(3)
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
return gravity_orientation

View File

@@ -0,0 +1,162 @@
"""
ZMQ socket wrapper that mimics the Unitree SDK Channel interface.
This module provides a drop-in replacement for the Unitree SDK's DDS-based
ChannelPublisher and ChannelSubscriber, using ZMQ sockets instead. This allows
remote communication with the robot over WiFi via the robot_server bridge.
Uses JSON for secure serialization instead of pickle.
"""
import base64
import json
from typing import Any
import zmq
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
_ctx: zmq.Context | None = None
_lowcmd_sock: zmq.Socket | None = None
_lowstate_sock: zmq.Socket | None = None
LOWCMD_PORT = 6000
LOWSTATE_PORT = 6001
# DDS topic names follow Unitree SDK naming conventions
# ruff: noqa: N816
kTopicLowCommand_Debug = "rt/lowcmd"
class LowStateMsg:
"""
Wrapper class that mimics the Unitree SDK LowState_ message structure.
Reconstructs the message from deserialized JSON data to maintain
compatibility with existing code that expects SDK message objects.
"""
class MotorState:
"""Motor state data for a single joint."""
def __init__(self, data: dict[str, Any]) -> None:
self.q: float = data.get("q", 0.0)
self.dq: float = data.get("dq", 0.0)
self.tau_est: float = data.get("tau_est", 0.0)
self.temperature: float = data.get("temperature", 0.0)
class IMUState:
"""IMU sensor data."""
def __init__(self, data: dict[str, Any]) -> None:
self.quaternion: list[float] = data.get("quaternion", [1.0, 0.0, 0.0, 0.0])
self.gyroscope: list[float] = data.get("gyroscope", [0.0, 0.0, 0.0])
self.accelerometer: list[float] = data.get("accelerometer", [0.0, 0.0, 0.0])
self.rpy: list[float] = data.get("rpy", [0.0, 0.0, 0.0])
self.temperature: float = data.get("temperature", 0.0)
def __init__(self, data: dict[str, Any]) -> None:
"""Initialize from deserialized JSON data."""
self.motor_state = [self.MotorState(m) for m in data.get("motor_state", [])]
self.imu_state = self.IMUState(data.get("imu_state", {}))
# Decode base64-encoded wireless_remote bytes
wireless_b64 = data.get("wireless_remote", "")
self.wireless_remote: bytes = base64.b64decode(wireless_b64) if wireless_b64 else b""
self.mode_machine: int = data.get("mode_machine", 0)
def lowcmd_to_dict(topic: str, msg: Any) -> dict[str, Any]:
"""Convert LowCmd message to a JSON-serializable dictionary."""
motor_cmds = []
# Iterate over all motor commands in the message
for i in range(len(msg.motor_cmd)):
motor_cmds.append(
{
"mode": int(msg.motor_cmd[i].mode),
"q": float(msg.motor_cmd[i].q),
"dq": float(msg.motor_cmd[i].dq),
"kp": float(msg.motor_cmd[i].kp),
"kd": float(msg.motor_cmd[i].kd),
"tau": float(msg.motor_cmd[i].tau),
}
)
return {
"topic": topic,
"data": {
"mode_pr": int(msg.mode_pr),
"mode_machine": int(msg.mode_machine),
"motor_cmd": motor_cmds,
},
}
def ChannelFactoryInitialize(*args: Any, **kwargs: Any) -> None: # noqa: N802
"""
Initialize ZMQ sockets for robot communication.
This function mimics the Unitree SDK's ChannelFactoryInitialize but uses
ZMQ sockets to connect to the robot server bridge instead of DDS.
"""
global _ctx, _lowcmd_sock, _lowstate_sock
# read socket config
config = UnitreeG1Config()
robot_ip = config.robot_ip
ctx = zmq.Context.instance()
_ctx = ctx
# lowcmd: send robot commands
lowcmd_sock = ctx.socket(zmq.PUSH)
lowcmd_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
lowcmd_sock.connect(f"tcp://{robot_ip}:{LOWCMD_PORT}")
_lowcmd_sock = lowcmd_sock
# lowstate: receive robot observations
lowstate_sock = ctx.socket(zmq.SUB)
lowstate_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
lowstate_sock.connect(f"tcp://{robot_ip}:{LOWSTATE_PORT}")
lowstate_sock.setsockopt_string(zmq.SUBSCRIBE, "")
_lowstate_sock = lowstate_sock
class ChannelPublisher:
"""ZMQ-based publisher that sends commands to the robot server."""
def __init__(self, topic: str, msg_type: type) -> None:
self.topic = topic
self.msg_type = msg_type
def Init(self) -> None: # noqa: N802
"""Initialize the publisher (no-op for ZMQ)."""
pass
def Write(self, msg: Any) -> None: # noqa: N802
"""Serialize and send a command message to the robot."""
if _lowcmd_sock is None:
raise RuntimeError("ChannelFactoryInitialize must be called first")
payload = json.dumps(lowcmd_to_dict(self.topic, msg)).encode("utf-8")
_lowcmd_sock.send(payload)
class ChannelSubscriber:
"""ZMQ-based subscriber that receives state from the robot server."""
def __init__(self, topic: str, msg_type: type) -> None:
self.topic = topic
self.msg_type = msg_type
def Init(self) -> None: # noqa: N802
"""Initialize the subscriber (no-op for ZMQ)."""
pass
def Read(self) -> LowStateMsg: # noqa: N802
"""Receive and deserialize a state message from the robot."""
if _lowstate_sock is None:
raise RuntimeError("ChannelFactoryInitialize must be called first")
payload = _lowstate_sock.recv()
msg_dict = json.loads(payload.decode("utf-8"))
return LowStateMsg(msg_dict.get("data", {}))

View File

@@ -65,6 +65,7 @@ import argparse
import gc
import logging
import time
from collections.abc import Iterator
from pathlib import Path
import numpy as np
@@ -77,6 +78,19 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset: LeRobotDataset, episode_index: int):
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
self.frame_ids = range(from_idx, to_idx)
def __iter__(self) -> Iterator:
return iter(self.frame_ids)
def __len__(self) -> int:
return len(self.frame_ids)
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
@@ -105,10 +119,12 @@ def visualize_dataset(
repo_id = dataset.repo_id
logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=batch_size,
sampler=episode_sampler,
)
logging.info("Starting Rerun")

View File

@@ -533,7 +533,7 @@ def eval_main(cfg: EvalPipelineConfig):
)
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(

View File

@@ -50,7 +50,7 @@ from lerobot.teleoperators import ( # noqa: F401
make_teleoperator_from_config,
so100_leader,
)
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
@dataclass
@@ -114,7 +114,7 @@ def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig):
print(f"Min joint pos position {np.round(min_pos, 4).tolist()}")
break
precise_sleep(0.01)
busy_wait(0.01)
def main():

View File

@@ -119,7 +119,7 @@ from lerobot.utils.control_utils import (
sanity_check_dataset_robot_compatibility,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
get_safe_torch_device,
init_logging,
@@ -364,7 +364,7 @@ def record_loop(
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
precise_sleep(1 / fps - dt_s)
busy_wait(1 / fps - dt_s)
timestamp = time.perf_counter() - start_episode_t

View File

@@ -62,7 +62,7 @@ from lerobot.robots import ( # noqa: F401
)
from lerobot.utils.constants import ACTION
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
init_logging,
log_say,
@@ -121,7 +121,7 @@ def replay(cfg: ReplayConfig):
_ = robot.send_action(processed_action)
dt_s = time.perf_counter() - start_episode_t
precise_sleep(1 / dataset.fps - dt_s)
busy_wait(1 / dataset.fps - dt_s)
robot.disconnect()

View File

@@ -89,7 +89,7 @@ from lerobot.teleoperators import ( # noqa: F401
so101_leader,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import init_logging, move_cursor_up
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
@@ -170,13 +170,12 @@ def teleop_loop(
# Display the final robot action that was sent
for motor, value in robot_action_to_send.items():
print(f"{motor:<{display_len}} | {value:>7.2f}")
move_cursor_up(len(robot_action_to_send) + 3)
move_cursor_up(len(robot_action_to_send) + 5)
dt_s = time.perf_counter() - loop_start
precise_sleep(1 / fps - dt_s)
busy_wait(1 / fps - dt_s)
loop_s = time.perf_counter() - loop_start
print(f"Teleop loop time: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
move_cursor_up(1)
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
if duration is not None and time.perf_counter() - start >= duration:
return

View File

@@ -260,9 +260,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info("Creating environment processors")
env_preprocessor, env_postprocessor = make_env_pre_post_processors(
env_cfg=cfg.env, policy_cfg=cfg.policy
)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")

View File

@@ -16,40 +16,14 @@ import platform
import time
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.003):
"""
Wait for `seconds` with better precision than time.sleep alone at the expense of more CPU usage.
Parameters:
- seconds: duration to wait
- spin_threshold: if remaining <= spin_threshold -> spin; otherwise sleep (seconds). Default 10ms
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 3ms
Note:
The default parameters are chosen to prioritize timing accuracy over CPU usage for the common 30 FPS use case.
"""
if seconds <= 0:
return
system = platform.system()
# On macOS and Windows the scheduler / sleep granularity can make
# short sleeps inaccurate. Instead of burning CPU for the whole
# duration, sleep for most of the time and spin for the final few
# milliseconds to achieve good accuracy with much lower CPU usage.
if system in ("Darwin", "Windows"):
def busy_wait(seconds):
if platform.system() == "Darwin" or platform.system() == "Windows":
# On Mac and Windows, `time.sleep` is not accurate and we need to use this while loop trick,
# but it consumes CPU cycles.
end_time = time.perf_counter() + seconds
while True:
remaining = end_time - time.perf_counter()
if remaining <= 0:
break
# If there's more than a couple milliseconds left, sleep most
# of the remaining time and leave a small margin for the final spin.
if remaining > spin_threshold:
# Sleep but avoid sleeping past the end by leaving a small margin.
time.sleep(max(remaining - sleep_margin, 0))
else:
# Final short spin to hit precise timing without long sleeps.
pass
while time.perf_counter() < end_time:
pass
else:
# On Linux time.sleep is accurate enough for most uses
time.sleep(seconds)
# On Linux time.sleep is accurate
if seconds > 0:
time.sleep(seconds)

View File

@@ -1,361 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script to verify XVLA policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
# ruff: noqa: E402
import gc
import random
from copy import deepcopy
from typing import Any
import numpy as np
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
from tests.utils import require_cuda # noqa: E402
# Constants
DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension
DUMMY_STATE_DIM = 20 # Proprioceptive state dimension
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224
NUM_VIEWS = 2 # Number of camera views
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH_LEROBOT = "lerobot/xvla-widowx"
LIBERO_DOMAIN_ID = 0 # Domain ID for examples purposes
# Expected values from original XVLA implementation (reference values)
EXPECTED_ACTIONS_SHAPE = (30, 20)
EXPECTED_ACTIONS_MEAN = 0.117606
EXPECTED_ACTIONS_STD = 0.245411
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.2742, 0.4977, 0.0500, 0.7040, -0.2653])
def cleanup_memory():
"""Clean up GPU/MPS memory to prevent OOM errors between tests."""
print("\nCleaning up memory...")
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
print("Memory cleanup complete.")
def set_seed_all(seed: int):
"""Set random seed for all RNG sources to ensure reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Set deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True, warn_only=True)
def instantiate_lerobot_xvla(
from_pretrained: bool = False,
model_path: str = MODEL_PATH_LEROBOT,
) -> tuple[
Any, # Policy
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Instantiate LeRobot XVLA policy with preprocessor and postprocessor."""
if from_pretrained:
policy = XVLAPolicy.from_pretrained(
pretrained_name_or_path=model_path,
strict=False,
)
else:
config = XVLAConfig(
base_model_path=model_path,
n_action_steps=DUMMY_ACTION_DIM,
chunk_size=DUMMY_ACTION_DIM,
device=DEVICE,
num_image_views=NUM_VIEWS,
) # add resize_imgs_with_padding=IMAGE_SIZE, IMAGE_SIZE?
policy = XVLAPolicy(config)
policy.to(DEVICE)
policy.config.device = DEVICE
preprocessor, postprocessor = make_xvla_pre_post_processors(
config=policy.config,
dataset_stats=None, # Pass None for dataset_stats to disable normalization (original XVLA doesn't normalize)
)
return policy, preprocessor, postprocessor
def create_dummy_data(device=DEVICE):
"""Create dummy data for testing both implementations."""
batch_size = 1
prompt = "Pick up the red block and place it in the bin"
# Create random RGB images in [0, 255] uint8 range (as PIL images would be)
# Then convert to [0, 1] float32 range for LeRobot
def fake_rgb(h, w):
arr = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
return t
batch = {
f"{OBS_IMAGES}.image": torch.stack(
[fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]
).to(device),
f"{OBS_IMAGES}.image2": torch.stack(
[fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]
).to(device),
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
"task": [prompt for _ in range(batch_size)],
}
return batch
# Pytest fixtures
@pytest.fixture(scope="module")
def xvla_components():
"""Fixture to instantiate and provide all XVLA components for tests."""
print(f"\nTesting with DEVICE='{DEVICE}'")
print("\n[Setup] Instantiating LeRobot XVLA policy...")
policy_obj, preprocessor_obj, postprocessor_obj = instantiate_lerobot_xvla(from_pretrained=True)
print("✔️ Model loaded successfully")
yield policy_obj, preprocessor_obj, postprocessor_obj
cleanup_memory()
@pytest.fixture(scope="module")
def policy(xvla_components):
"""Fixture to provide the XVLA policy for tests."""
return xvla_components[0]
@pytest.fixture(scope="module")
def preprocessor(xvla_components):
"""Fixture to provide the XVLA preprocessor for tests."""
return xvla_components[1]
@require_cuda
def test_xvla_preprocessor_alignment(policy, preprocessor):
"""Test that LeRobot XVLA preprocessor produces expected outputs."""
print("\n" + "=" * 80)
print("Test: XVLA Preprocessor Outputs")
print("=" * 80)
set_seed_all(42)
print("\nCreating dummy data...")
batch = create_dummy_data()
print("\n[LeRobot] Preprocessing...")
lerobot_observation = preprocessor(deepcopy(batch))
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
print("\nVerifying preprocessor outputs:")
print("-" * 80)
# Expected shapes from tester.txt
expected_shapes = {
"domain_id": (1,),
"input_ids": (1, 50),
"proprio": (1, 20),
"image_mask": (1, 2),
"image_input": (1, 2, 3, 224, 224),
}
for key, expected_shape in expected_shapes.items():
if key in lerobot_inputs:
actual_shape = tuple(lerobot_inputs[key].shape)
print(f"\nKey: {key}")
print(f"Expected shape: {expected_shape}")
print(f"Actual shape: {actual_shape}")
if actual_shape == expected_shape:
print("Shape matches!")
else:
print("Shape mismatch!")
assert actual_shape == expected_shape, f"Shape mismatch for {key}"
else:
print(f"\nKey '{key}' not found in inputs!")
print("\nAll preprocessor outputs have correct shapes!")
@require_cuda
def test_xvla_action_generation(policy, preprocessor):
"""Test XVLA LeRobot implementation generates expected actions."""
print("\n" + "=" * 80)
print("Test: XVLA Action Generation Against Expected Values")
print("=" * 80)
set_seed_all(42)
print("\nCreating dummy data...")
batch = create_dummy_data()
print("\n[LeRobot] Running inference...")
lerobot_observation = preprocessor(deepcopy(batch))
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
# Reset seed for inference
torch.manual_seed(42)
with torch.no_grad():
lerobot_actions = policy.model.generate_actions(**lerobot_inputs, steps=10)
lerobot_actions = lerobot_actions.squeeze(0).float().cpu()
print(f"LeRobot actions shape: {lerobot_actions.shape}")
print(f"LeRobot actions mean: {lerobot_actions.mean().item():.6f}")
print(f"LeRobot actions std: {lerobot_actions.std().item():.6f}")
print(f"LeRobot actions first 5: {lerobot_actions[0, :5]}")
print("\nExpected values (from original XVLA):")
print(f"Expected actions shape: {EXPECTED_ACTIONS_SHAPE}")
print(f"Expected actions mean: {EXPECTED_ACTIONS_MEAN:.6f}")
print(f"Expected actions std: {EXPECTED_ACTIONS_STD:.6f}")
print(f"Expected actions first 5: {EXPECTED_ACTIONS_FIRST_5}")
print("\nAction Comparison:")
print("-" * 80)
# Compare shapes
actual_shape = tuple(lerobot_actions.shape)
assert actual_shape == EXPECTED_ACTIONS_SHAPE, (
f"Shape mismatch: {actual_shape} vs {EXPECTED_ACTIONS_SHAPE}"
)
print(f"✔️ Shape matches: {actual_shape}")
# Compare statistics
actual_mean = lerobot_actions.mean().item()
actual_std = lerobot_actions.std().item()
mean_diff = abs(actual_mean - EXPECTED_ACTIONS_MEAN)
std_diff = abs(actual_std - EXPECTED_ACTIONS_STD)
print(f"\nMean: {actual_mean:.6f} (expected: {EXPECTED_ACTIONS_MEAN:.6f}, diff: {mean_diff:.6e})")
print(f"Std: {actual_std:.6f} (expected: {EXPECTED_ACTIONS_STD:.6f}, diff: {std_diff:.6e})")
# Compare first 5 actions
actual_first_5 = lerobot_actions[0, :5]
first_5_diff = torch.abs(actual_first_5 - EXPECTED_ACTIONS_FIRST_5)
print("\nFirst 5 actions comparison:")
print(f" Actual: {actual_first_5}")
print(f" Expected: {EXPECTED_ACTIONS_FIRST_5}")
print(f" Max diff: {first_5_diff.max().item():.6e}")
print(f" Mean diff: {first_5_diff.mean().item():.6e}")
# Check with different tolerances
tolerances = [1e-5, 1e-4, 1e-3, 1e-2]
for tol in tolerances:
is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol)
status = "Success" if is_close else "Failure"
print(f"{status}: First 5 actions close (atol={tol}): {is_close}")
# Assert with reasonable tolerance
tolerance = 1e-3
assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), (
f"First 5 actions differ by more than tolerance ({tolerance})"
)
print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!")
@require_cuda
def test_xvla_inference_reproducibility(policy, preprocessor):
"""Test that XVLA inference is reproducible with the same seed."""
print("\n" + "=" * 80)
print("Test: XVLA Inference Reproducibility")
print("=" * 80)
print("\nCreating dummy data...")
batch = create_dummy_data()
# First inference
print("\n[Run 1] Running inference...")
set_seed_all(42)
lerobot_observation = preprocessor(deepcopy(batch))
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
with torch.no_grad():
actions_1 = policy.model.generate_actions(**lerobot_inputs, steps=10)
actions_1 = actions_1.squeeze(0).float().cpu()
# Second inference with same seed
print("\n[Run 2] Running inference with same seed...")
set_seed_all(42)
lerobot_observation = preprocessor(deepcopy(batch))
lerobot_inputs = policy._build_model_inputs(lerobot_observation)
with torch.no_grad():
actions_2 = policy.model.generate_actions(**lerobot_inputs, steps=10)
actions_2 = actions_2.squeeze(0).float().cpu()
print("\nComparing two runs:")
print("-" * 80)
if torch.allclose(actions_1, actions_2, atol=1e-8):
print("Inference is perfectly reproducible!")
else:
diff = torch.abs(actions_1 - actions_2)
print("Small differences detected:")
print(f" Max diff: {diff.max().item():.6e}")
print(f" Mean diff: {diff.mean().item():.6e}")
assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!"
print("\nInference is reproducible!")
if __name__ == "__main__":
print("\n" + "=" * 80)
print("XVLA LeRobot Validation Test Suite")
print("=" * 80)
try:
# Initialize model once for all tests
print("\n[Setup] Instantiating LeRobot XVLA policy...")
policy, preprocessor, postprocessor = instantiate_lerobot_xvla(from_pretrained=True)
print("✔️ Model loaded successfully")
# Run all tests with the same model instance
test_xvla_preprocessor_alignment(policy, preprocessor)
test_xvla_action_generation(policy, preprocessor)
test_xvla_inference_reproducibility(policy, preprocessor)
print("\n" + "=" * 80)
print("All tests passed!")
print("=" * 80)
cleanup_memory()
except Exception as e:
print("\n" + "=" * 80)
print(f"Test failed with error: {e}")
print("=" * 80)
cleanup_memory()
raise