sync updates

This commit is contained in:
Martino Russi
2025-11-26 13:11:01 +01:00
parent bebf9b8480
commit c65866ddd8
9 changed files with 978 additions and 350 deletions

View File

@@ -19,6 +19,7 @@ Provides the ZMQCamera class for capturing frames from remote cameras via ZeroMQ
import json
import logging
import os
import threading
import time
from pathlib import Path
from threading import Event, Lock, Thread
@@ -105,6 +106,9 @@ class ZMQCamera(Camera):
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.new_frame_event: Event = Event()
# Format type detected during connection (msgpack, json, or raw_jpeg)
self._format_type: str | None = None
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.camera_name}@{self.server_address}:{self.port})"
@@ -141,9 +145,22 @@ class ZMQCamera(Camera):
self._connected = True
# Try to receive one frame to validate connection
# Try to receive one frame to validate connection and detect format
try:
test_frame = self.read()
# Try each format until one works
test_frame = None
for format_type in ["msgpack", "json", "raw_jpeg"]:
try:
test_frame = self.read(format=format_type)
self._format_type = format_type
logger.info(f"{self} detected format: {format_type}")
break
except Exception as e:
logger.debug(f"{self} format '{format_type}' failed: {e}")
continue
if test_frame is None:
raise RuntimeError("Failed to decode frame with any supported format (msgpack, json, raw_jpeg)")
# Auto-detect resolution if not specified
if self.width is None or self.height is None:
@@ -179,136 +196,185 @@ class ZMQCamera(Camera):
raise RuntimeError(f"Failed to connect to {self}: {e}")
@staticmethod
def find_cameras() -> list[dict[str, Any]]:
def find_cameras(
subnet: str | None = None,
ports: list[int] | None = None,
timeout_ms: int = 200,
) -> list[dict[str, Any]]:
"""
Detects available ZMQ cameras based on configuration.
Scans the local network for ZMQ cameras (fast parallel scan).
Reads camera configurations from:
1. Environment variable LEROBOT_ZMQ_CAMERAS (JSON format)
2. Config file at ~/.lerobot/zmq_cameras.json
Uses threading to scan multiple hosts simultaneously. Without parallelization,
scanning 254 hosts would take 6+ minutes. With threads, takes ~10-15 seconds.
Example JSON format:
```json
[
{
"name": "unitree_g1_head",
"address": "192.168.123.164",
"port": 5554
},
{
"name": "lab_cam_1",
"address": "192.168.1.100",
"port": 5555
}
]
```
Args:
subnet: Network subnet to scan (e.g., "192.168.1.0/24"). If None, auto-detects.
ports: List of ports to scan. Defaults to [5554, 5555, 5556].
timeout_ms: Connection timeout per host in milliseconds. Default: 200ms.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing ZMQ camera information.
List of dicts containing camera info (address, port, format, resolution).
Example:
>>> cameras = ZMQCamera.find_cameras()
>>> # Or specify: cameras = ZMQCamera.find_cameras(subnet="10.0.0.0/24", ports=[5554])
"""
found_cameras_info = []
camera_configs = []
import socket
import ipaddress
from concurrent.futures import ThreadPoolExecutor, as_completed
# Try to load from environment variable first
env_cameras = os.environ.get("LEROBOT_ZMQ_CAMERAS")
if env_cameras:
if ports is None:
ports = [5554, 5555, 5556]
# Auto-detect local subnet
if subnet is None:
try:
camera_configs = json.loads(env_cameras)
logger.info(f"Loaded {len(camera_configs)} ZMQ camera configs from LEROBOT_ZMQ_CAMERAS")
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse LEROBOT_ZMQ_CAMERAS environment variable: {e}")
#use unitree_g1_head as an example
camera_configs = [
{
"name": "unitree_g1_head",
"address": "192.168.123.164",
"port": 5554
}
]
# Try to load from config file
if not camera_configs:
config_path = Path.home() / ".lerobot" / "zmq_cameras.json"
if config_path.exists():
try:
with open(config_path) as f:
camera_configs = json.load(f)
logger.info(f"Loaded {len(camera_configs)} ZMQ camera configs from {config_path}")
except (json.JSONDecodeError, IOError) as e:
logger.warning(f"Failed to load ZMQ camera config from {config_path}: {e}")
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
local_ip = s.getsockname()[0]
s.close()
subnet = ".".join(local_ip.split(".")[:-1]) + ".0/24"
logger.info(f"Auto-detected subnet: {subnet}")
except Exception as e:
logger.error(f"Failed to auto-detect subnet: {e}")
return []
if not camera_configs:
logger.info(
"No ZMQ cameras configured. Set LEROBOT_ZMQ_CAMERAS environment variable "
f"or create {Path.home() / '.lerobot' / 'zmq_cameras.json'}"
)
# Parse subnet
try:
network = ipaddress.ip_network(subnet, strict=False)
hosts = list(network.hosts())
# Always include localhost (for MuJoCo sim, local servers)
hosts.insert(0, ipaddress.IPv4Address("127.0.0.1"))
except Exception as e:
logger.error(f"Invalid subnet '{subnet}': {e}")
return []
# Test each configured camera
for cam_config in camera_configs:
try:
name = cam_config.get("name", "unknown")
address = cam_config.get("address")
port = cam_config.get("port", 5554)
if not address:
logger.warning(f"Skipping camera '{name}': missing address")
continue
# Try to connect with a short timeout
context = zmq.Context()
socket = context.socket(zmq.SUB)
socket.connect(f"tcp://{address}:{port}")
socket.setsockopt_string(zmq.SUBSCRIBE, "")
socket.setsockopt(zmq.RCVTIMEO, 2000) # 2 second timeout for discovery
total = len(hosts) * len(ports)
logger.info(f"Scanning {len(hosts)} hosts × {len(ports)} ports = {total} targets (this takes ~10-15s)...")
def test_target(host_ip: str, port: int) -> dict | None:
"""Test one host:port for ZMQ camera."""
ctx = zmq.Context()
sock = ctx.socket(zmq.SUB)
sock.connect(f"tcp://{host_ip}:{port}")
sock.setsockopt_string(zmq.SUBSCRIBE, "")
sock.setsockopt(zmq.RCVTIMEO, timeout_ms)
# Wait for subscription to establish (ZMQ "slow joiner" problem)
time.sleep(0.1)
# Try receiving a few times
msg = None
for _ in range(3):
try:
# Try to receive one frame to validate
message = socket.recv()
np_img = np.frombuffer(message, dtype=np.uint8)
test_image = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
if test_image is not None:
height, width = test_image.shape[:2]
camera_info = {
"name": f"ZMQ Camera: {name}",
"type": "ZMQ",
"id": f"{address}:{port}",
"server_address": address,
"port": port,
"camera_name": name,
"default_stream_profile": {
"width": width,
"height": height,
"format": "JPEG",
},
}
found_cameras_info.append(camera_info)
logger.info(f"Found ZMQ camera: {name} at {address}:{port}")
else:
logger.warning(f"Camera '{name}' at {address}:{port} returned invalid image")
msg = sock.recv()
break
except zmq.Again:
logger.warning(f"Camera '{name}' at {address}:{port} timeout - not streaming")
except Exception as e:
logger.warning(f"Error testing camera '{name}' at {address}:{port}: {e}")
finally:
socket.close()
context.term()
time.sleep(0.05)
if msg is None:
sock.close()
ctx.term()
return None
# Try formats: msgpack → json → raw_jpeg
frame = fmt = None
# Msgpack
try:
d = msgpack.unpackb(msg, object_hook=m.decode)
if isinstance(d, dict) and "images" in d and len(d["images"]) > 0:
img = next(iter(d["images"].values()))
if isinstance(img, str):
frame = cv2.imdecode(np.frombuffer(base64.b64decode(img), np.uint8), cv2.IMREAD_COLOR)
elif isinstance(img, np.ndarray):
frame = img
if frame is not None:
fmt = "msgpack"
except:
pass
# JSON
if frame is None:
try:
d = json.loads(msg.decode('utf-8'))
if isinstance(d, dict):
for v in d.values():
if isinstance(v, str) and len(v) > 100:
try:
frame = cv2.imdecode(np.frombuffer(base64.b64decode(v), np.uint8), cv2.IMREAD_COLOR)
if frame is not None:
fmt = "json"
break
except:
pass
except:
pass
# Raw JPEG
if frame is None:
try:
frame = cv2.imdecode(np.frombuffer(msg, np.uint8), cv2.IMREAD_COLOR)
if frame is not None:
fmt = "raw_jpeg"
except:
pass
sock.close()
ctx.term()
if frame is not None:
h, w = frame.shape[:2]
return {
"name": f"ZMQ @ {host_ip}:{port}",
"type": "ZMQ",
"id": f"{host_ip}:{port}",
"server_address": host_ip,
"port": port,
"camera_name": f"cam_{host_ip.replace('.', '_')}_{port}",
"format": fmt,
"default_stream_profile": {"width": w, "height": h, "format": fmt.upper()},
}
return None
except Exception as e:
logger.warning(f"Error processing camera config: {e}")
# Parallel scan with thread pool
found = []
with ThreadPoolExecutor(max_workers=100) as ex:
futures = [ex.submit(test_target, str(h), p) for h in hosts for p in ports]
for i, fut in enumerate(as_completed(futures), 1):
if i % 100 == 0:
logger.info(f" Progress: {i}/{total} ({100*i//total}%)")
res = fut.result()
if res:
found.append(res)
logger.info(f"{res['server_address']}:{res['port']} ({res['format']})")
logger.info(f"Scan complete! Found {len(found)} camera(s).")
return found
return found_cameras_info
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
def read(self, color_mode: ColorMode | None = None, format: str | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the ZMQ camera.
For the mujoco sim (127.0.0.1:5554), the server sends a msgpack-encoded
dict with base64-encoded JPEG images:
{"timestamps": {...}, "images": {camera_name: "<b64 jpeg>"}}
For other ZMQ cameras, we assume a raw JPEG buffer.
Supports three message formats:
1. "msgpack": Msgpack with base64 JPEGs: {"timestamps": {...}, "images": {camera_name: "b64"}}
(used by MuJoCo sim)
2. "json": JSON with base64 JPEGs: {"state": 0.0, "camera_name": "b64jpeg"}
(used by LeKiwi-style servers)
3. "raw_jpeg": Raw JPEG bytes (used by Unitree G1 head camera)
Args:
color_mode: Target color mode (RGB or BGR). If None, uses self.color_mode.
format: Message format to use. If None, uses auto-detected format from connect().
One of: "msgpack", "json", "raw_jpeg"
Returns:
np.ndarray: Decoded frame in shape (height, width, 3)
Raises:
DeviceNotConnectedError: If camera is not connected
TimeoutError: If no frame received within timeout_ms
RuntimeError: If frame decoding fails
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -316,6 +382,13 @@ class ZMQCamera(Camera):
if self.socket is None:
raise DeviceNotConnectedError(f"{self} socket is not initialized")
# Use detected format if not specified
if format is None:
format = self._format_type
if format is None:
raise RuntimeError(f"{self} format not specified and not auto-detected during connect()")
start_time = time.perf_counter()
try:
@@ -327,48 +400,55 @@ class ZMQCamera(Camera):
frame = None
# special-case: mujoco sim publisher (msgpack + base64)
if self.server_address in ("127.0.0.1", "localhost") and self.port == 5554:
try:
data = msgpack.unpackb(message, object_hook=m.decode)
# Decode based on format
if format == "msgpack":
data = msgpack.unpackb(message, object_hook=m.decode)
if not isinstance(data, dict) or "images" not in data:
raise RuntimeError(f"{self} invalid msgpack format: expected dict with 'images' key")
if not isinstance(data, dict) or "images" not in data:
raise RuntimeError(f"{self} received invalid message format from sim")
images_dict = data["images"]
# Prefer named camera if present
if self.camera_name in images_dict:
img_data = images_dict[self.camera_name]
elif len(images_dict) > 0:
# Fallback: first available camera
img_data = next(iter(images_dict.values()))
else:
raise RuntimeError(f"{self} no images found in msgpack message")
images_dict = data["images"]
img_data = None
# prefer named camera if present
if self.camera_name in images_dict:
img_data = images_dict[self.camera_name]
elif len(images_dict) > 0:
# fallback: first available camera
img_data = next(iter(images_dict.values()))
else:
raise RuntimeError(f"{self} no images found in sim message")
# same logic as ImageUtils.decode_image
if isinstance(img_data, str):
color_bytes = base64.b64decode(img_data)
np_img = np.frombuffer(color_bytes, dtype=np.uint8)
frame = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
elif isinstance(img_data, np.ndarray):
frame = img_data
else:
raise RuntimeError(
f"{self} unknown image payload type from sim: {type(img_data)}"
)
except Exception as e:
raise RuntimeError(f"{self} failed to decode sim image: {e}")
# default path: raw jpeg over ZMQ (robot-side cameras etc.)
else:
# Decode the image data
if isinstance(img_data, str):
color_bytes = base64.b64decode(img_data)
np_img = np.frombuffer(color_bytes, dtype=np.uint8)
frame = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
elif isinstance(img_data, np.ndarray):
frame = img_data
else:
raise RuntimeError(f"{self} unknown image payload type: {type(img_data)}")
elif format == "json":
data = json.loads(message.decode('utf-8'))
if not isinstance(data, dict) or self.camera_name not in data:
raise RuntimeError(f"{self} invalid JSON format: expected dict with '{self.camera_name}' key")
img_b64 = data[self.camera_name]
if not isinstance(img_b64, str):
raise RuntimeError(f"{self} expected base64 string in JSON, got {type(img_b64)}")
color_bytes = base64.b64decode(img_b64)
np_img = np.frombuffer(color_bytes, dtype=np.uint8)
frame = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
elif format == "raw_jpeg":
np_img = np.frombuffer(message, dtype=np.uint8)
frame = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
else:
raise ValueError(f"{self} unsupported format: {format}. Use 'msgpack', 'json', or 'raw_jpeg'")
if frame is None or not isinstance(frame, np.ndarray):
raise RuntimeError(f"{self} failed to decode image")
raise RuntimeError(f"{self} failed to decode image using format '{format}'")
processed_frame = self._postprocess_image(frame, color_mode)
@@ -471,7 +551,7 @@ class ZMQCamera(Camera):
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 2000) -> NDArray[Any]:
def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.

View File

@@ -54,11 +54,6 @@ class ZMQCameraConfig(CameraConfig):
camera_name: Identifier name for this camera (for logging/debugging).
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
timeout_ms: Timeout in milliseconds for receiving frames. Defaults to 1000ms.
Note:
- The server must be streaming JPEG-encoded images over ZMQ PUB socket.
- Width and height should match the expected output dimensions from the server.
- FPS is informational and doesn't control the server's frame rate.
"""
server_address: str
@@ -81,4 +76,3 @@ class ZMQCameraConfig(CameraConfig):
if self.port <= 0 or self.port > 65535:
raise ValueError(f"`port` must be between 1 and 65535, but {self.port} is provided.")

View File

@@ -4,7 +4,7 @@
"drive_mode": 0,
"homing_offset": 0,
"range_min": -3,
"range_max": 2.6
"range_max": 1
},
"kLeftShoulderYaw.pos": {
"id": 1,
@@ -25,7 +25,7 @@
"drive_mode": 0,
"homing_offset": 0,
"range_min": -1,
"range_max": 2
"range_max": 1
},
"kLeftWristRoll.pos": {
"id": 4,
@@ -61,7 +61,7 @@
"drive_mode": 0,
"homing_offset": 0,
"range_min": -3.0,
"range_max": 2.6
"range_max": 1
},
"kRightShoulderYaw.pos": {
"id": 1,
@@ -82,7 +82,7 @@
"drive_mode": 0,
"homing_offset": 0,
"range_min": -1,
"range_max": 2
"range_max": 1
},
"kRightWristRoll.pos": {
"id": 4,

View File

@@ -48,12 +48,19 @@ class UnitreeG1Config(RobotConfig):
audio_client: bool = True
freeze_body: bool = False
gravity_compensation: bool = False
gravity_compensation: bool = True
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# Socket communication configuration (REQUIRED)
# This robot class ONLY uses sockets to communicate with a bridge on the Orin
# Run 'python dds_to_socket.py' on the Orin first, then set this to the Orin's IP
# Example: socket_host="192.168.123.164" (Orin's wlan0 IP)
socket_host: str | None = None
socket_port: int | None = None
# Locomotion control
locomotion_control: bool = False
locomotion_control: bool = True
#policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/motion.pt"
policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/GR00T-WholeBodyControl-Walk.onnx"
@@ -64,9 +71,7 @@ class UnitreeG1Config(RobotConfig):
motion_file_path: str = "unitree_rl_lab/deploy/robots/g1_29dof/config/policy/mimic/dance_102/params/G1_Take_102.bvh_60hz.csv"
motion_fps: float = 60.0
motion_control_dt: float = 0.02
# Motion imitation parameters (from deploy.yaml)
motion_joint_ids_map: list = field(default_factory=lambda: [0, 6, 12, 1, 7, 13, 2, 8, 14, 3, 9, 15, 22, 4, 10, 16, 23, 5, 11, 17, 24, 18, 25, 19, 26, 20, 27, 21, 28])
motion_stiffness: list = field(default_factory=lambda: [40.2, 99.1, 40.2, 99.1, 28.5, 28.5, 40.2, 99.1, 40.2, 99.1, 28.5, 28.5, 40.2, 28.5, 28.5, 14.3, 14.3, 14.3, 14.3, 14.3, 16.8, 16.8, 14.3, 14.3, 14.3, 14.3, 14.3, 16.8, 16.8])
motion_damping: list = field(default_factory=lambda: [2.56, 6.31, 2.56, 6.31, 1.81, 1.81, 2.56, 6.31, 2.56, 6.31, 1.81, 1.81, 2.56, 1.81, 1.81, 0.907, 0.907, 0.907, 0.907, 0.907, 1.07, 1.07, 0.907, 0.907, 0.907, 0.907, 0.907, 1.07, 1.07])
@@ -98,6 +103,4 @@ class UnitreeG1Config(RobotConfig):
num_locomotion_actions: int = 12
num_locomotion_obs: int = 47
max_cmd: list = field(default_factory=lambda: [0.8, 0.5, 1.57])
locomotion_imu_type: str = "pelvis" # "torso" or "pelvis"
locomotion_imu_type: str = "pelvis" # "torso" or "pelvis"

View File

@@ -0,0 +1,347 @@
import cv2
import zmq
import time
import struct
from collections import deque
import numpy as np
import pyrealsense2 as rs
import logging_mp
logger_mp = logging_mp.get_logger(__name__, level=logging_mp.DEBUG)
class RealSenseCamera(object):
    """Wrapper around a pyrealsense2 pipeline that yields color (and optionally depth) images."""

    def __init__(self, img_shape, fps, serial_number=None, enable_depth=False) -> None:
        """
        Args:
            img_shape: [height, width] requested for the streams.
            fps: frame rate requested for the streams.
            serial_number: serial number of the device to open; when None no
                specific device is requested from the pipeline config.
            enable_depth: also enable a z16 depth stream.
        """
        self.img_shape = img_shape
        self.fps = fps
        self.serial_number = serial_number
        self.enable_depth = enable_depth

        # Frames are aligned to the color stream before being read out.
        align_to = rs.stream.color
        self.align = rs.align(align_to)
        self.init_realsense()

    def init_realsense(self):
        """Configure the streams, start the pipeline and cache the color intrinsics."""
        self.pipeline = rs.pipeline()
        config = rs.config()
        if self.serial_number is not None:
            config.enable_device(self.serial_number)

        # enable_stream takes (width, height), img_shape is stored as [height, width].
        config.enable_stream(rs.stream.color, self.img_shape[1], self.img_shape[0], rs.format.bgr8, self.fps)
        if self.enable_depth:
            config.enable_stream(rs.stream.depth, self.img_shape[1], self.img_shape[0], rs.format.z16, self.fps)

        profile = self.pipeline.start(config)
        self._device = profile.get_device()
        if self._device is None:
            logger_mp.error("[Image Server] pipe_profile.get_device() is None .")
        if self.enable_depth:
            assert self._device is not None
            depth_sensor = self._device.first_depth_sensor()
            self.g_depth_scale = depth_sensor.get_depth_scale()

        self.intrinsics = profile.get_stream(rs.stream.color).as_video_stream_profile().get_intrinsics()

    def get_frame(self):
        """Wait for the next frame set and return ``(color_image, depth_image)``.

        Returns:
            A 2-tuple of arrays. ``depth_image`` is None when depth is disabled
            or no depth frame was delivered. When no color frame is available
            the result is ``(None, None)``, so callers that unpack the result
            can test ``color_image is None`` instead of failing on a bare None.
        """
        frames = self.pipeline.wait_for_frames()
        aligned_frames = self.align.process(frames)
        color_frame = aligned_frames.get_color_frame()
        if not color_frame:
            return None, None

        color_image = np.asanyarray(color_frame.get_data())

        depth_image = None
        if self.enable_depth:
            depth_frame = aligned_frames.get_depth_frame()
            if depth_frame:
                depth_image = np.asanyarray(depth_frame.get_data())
        return color_image, depth_image

    def release(self):
        """Stop the pipeline."""
        self.pipeline.stop()
class OpenCVCamera:
    """Color camera opened through OpenCV's V4L2 backend with MJPG capture."""

    def __init__(self, device_id, img_shape, fps):
        """
        Args:
            device_id: /dev/video* path or numeric index of the device.
            img_shape: [height, width] requested from the device.
            fps: frame rate requested from the device.
        """
        self.id = device_id
        self.fps = fps
        self.img_shape = img_shape

        self.cap = cv2.VideoCapture(self.id, cv2.CAP_V4L2)
        # Properties are applied in this fixed order: codec, height, width, fps.
        requested = (
            (cv2.CAP_PROP_FOURCC, cv2.VideoWriter.fourcc(*"MJPG")),
            (cv2.CAP_PROP_FRAME_HEIGHT, self.img_shape[0]),
            (cv2.CAP_PROP_FRAME_WIDTH, self.img_shape[1]),
            (cv2.CAP_PROP_FPS, self.fps),
        )
        for prop, value in requested:
            self.cap.set(prop, value)

        # Probe the device once; on failure the capture is released but the
        # object is still returned to the caller.
        if self._can_read_frame():
            return
        logger_mp.error(
            f"[Image Server] Camera {self.id} Error: Failed to initialize the camera or read frames. Exiting..."
        )
        self.release()

    def _can_read_frame(self):
        """Return the success flag of a single trial read."""
        ok, _frame = self.cap.read()
        return ok

    def release(self):
        """Release the underlying capture handle."""
        self.cap.release()

    def get_frame(self):
        """Read one frame; return the image, or None when the read fails."""
        ok, image = self.cap.read()
        return image if ok else None
class ImageServer:
    """Grab frames from head (and optionally wrist) cameras and publish them as JPEG over ZeroMQ.

    Every iteration reads one image per camera, concatenates all images
    horizontally, JPEG-encodes the result and sends it on a PUB socket bound
    to ``tcp://*:<port>``.
    """

    def __init__(self, config, port=5554, Unit_Test=False):
        """
        Args:
            config: dict of settings; keys that are absent fall back to defaults
                (fps 30, one opencv head camera with id 0 at [480, 640], no wrist cameras).
            port: TCP port the PUB socket binds to.
            Unit_Test: when True, each message is prefixed with a packed
                (timestamp, frame_id) header and FPS statistics are logged.

        config example 1:
        {
            'fps': 30,                                                     # frames per second
            'head_camera_type': 'opencv',                                  # opencv or realsense
            'head_camera_image_shape': [480, 1280],                        # head camera resolution [height, width]
            'head_camera_id_numbers': [0],                                 # '/dev/video0' (opencv)
            'wrist_camera_type': 'realsense',
            'wrist_camera_image_shape': [480, 640],                        # wrist camera resolution [height, width]
            'wrist_camera_id_numbers': ["218622271789", "241222076627"],   # realsense serial numbers
        }

        config example 2:
        {
            'fps': 30,
            'head_camera_type': 'realsense',
            'head_camera_image_shape': [480, 640],
            'head_camera_id_numbers': ["218622271739"],                    # realsense serial number
            'wrist_camera_type': 'opencv',
            'wrist_camera_image_shape': [480, 640],
            'wrist_camera_id_numbers': [0, 1],                             # '/dev/video0' and '/dev/video1'
        }

        To run without wrist cameras, omit the three ``wrist_*`` keys.
        """
        logger_mp.info(config)
        self.fps = config.get("fps", 30)
        self.head_camera_type = config.get("head_camera_type", "opencv")
        self.head_image_shape = config.get("head_camera_image_shape", [480, 640])  # (height, width)
        self.head_camera_id_numbers = config.get("head_camera_id_numbers", [0])

        self.wrist_camera_type = config.get("wrist_camera_type", None)
        self.wrist_image_shape = config.get("wrist_camera_image_shape", [480, 640])  # (height, width)
        self.wrist_camera_id_numbers = config.get("wrist_camera_id_numbers", None)

        self.port = port
        self.Unit_Test = Unit_Test

        # Head cameras are always configured; wrist cameras only when both
        # their type and their id list are given.
        self.head_cameras = []
        self._add_cameras(
            self.head_cameras, "head", self.head_camera_type, self.head_image_shape, self.head_camera_id_numbers
        )

        self.wrist_cameras = []
        if self.wrist_camera_type and self.wrist_camera_id_numbers:
            self._add_cameras(
                self.wrist_cameras,
                "wrist",
                self.wrist_camera_type,
                self.wrist_image_shape,
                self.wrist_camera_id_numbers,
            )

        # Set ZeroMQ context and socket
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PUB)
        self.socket.bind(f"tcp://*:{self.port}")

        if self.Unit_Test:
            self._init_performance_metrics()

        self._log_camera_resolutions("head", self.head_cameras)
        self._log_camera_resolutions("wrist", self.wrist_cameras)

        logger_mp.info("[Image Server] Image server has started, waiting for client connections...")

    def _add_cameras(self, cameras, role, camera_type, image_shape, id_numbers):
        """Open one camera per entry of ``id_numbers`` and append it to ``cameras``."""
        if camera_type == "opencv":
            for device_id in id_numbers:
                cameras.append(OpenCVCamera(device_id=device_id, img_shape=image_shape, fps=self.fps))
        elif camera_type == "realsense":
            for serial_number in id_numbers:
                cameras.append(
                    RealSenseCamera(img_shape=image_shape, fps=self.fps, serial_number=serial_number)
                )
        else:
            logger_mp.warning(f"[Image Server] Unsupported {role}_camera_type: {camera_type}")

    def _log_camera_resolutions(self, role, cameras):
        """Log the resolution of every camera in ``cameras`` (``role`` is 'head' or 'wrist')."""
        label = role.capitalize()
        for cam in cameras:
            if isinstance(cam, OpenCVCamera):
                logger_mp.info(
                    f"[Image Server] {label} camera {cam.id} resolution: "
                    f"{cam.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)} x {cam.cap.get(cv2.CAP_PROP_FRAME_WIDTH)}"
                )
            elif isinstance(cam, RealSenseCamera):
                logger_mp.info(
                    f"[Image Server] {label} camera {cam.serial_number} resolution: "
                    f"{cam.img_shape[0]} x {cam.img_shape[1]}"
                )
            else:
                logger_mp.warning(f"[Image Server] Unknown camera type in {role}_cameras.")

    @staticmethod
    def _read_color_frames(cameras):
        """Read one color image from each camera.

        Accepts both camera flavours used by this server: ``get_frame()`` may
        return the image itself or a ``(color, depth)`` tuple, and either form
        may signal failure with None.

        Returns:
            The list of color images in camera order, or None as soon as any
            camera fails to deliver one, so a partial list is never returned.
        """
        frames = []
        for cam in cameras:
            result = cam.get_frame()
            color_image = result[0] if isinstance(result, tuple) else result
            if color_image is None:
                return None
            frames.append(color_image)
        return frames

    def _init_performance_metrics(self):
        """Reset the counters used for FPS reporting."""
        self.frame_count = 0  # Total frames sent
        self.time_window = 1.0  # Time window for FPS calculation (in seconds)
        self.frame_times = deque()  # Timestamps of frames sent within the time window
        self.start_time = time.time()  # Start time of the streaming

    def _update_performance_metrics(self, current_time):
        """Record a sent frame and drop timestamps older than the time window."""
        self.frame_times.append(current_time)
        while self.frame_times and self.frame_times[0] < current_time - self.time_window:
            self.frame_times.popleft()
        self.frame_count += 1

    def _print_performance_metrics(self, current_time):
        """Log the FPS statistics once every 30 frames."""
        if self.frame_count % 30 == 0:
            elapsed_time = current_time - self.start_time
            real_time_fps = len(self.frame_times) / self.time_window
            logger_mp.info(
                f"[Image Server] Real-time FPS: {real_time_fps:.2f}, Total frames sent: {self.frame_count}, Elapsed time: {elapsed_time:.2f} sec"
            )

    def _close(self):
        """Release every camera, then close the socket and terminate the context."""
        for cam in self.head_cameras:
            cam.release()
        for cam in self.wrist_cameras:
            cam.release()
        self.socket.close()
        self.context.term()
        logger_mp.info("[Image Server] The server has been closed.")

    def send_process(self):
        """Capture, encode and publish frames until a camera read fails or Ctrl-C is pressed.

        Cameras, socket and context are always released on exit.
        """
        try:
            while True:
                head_frames = self._read_color_frames(self.head_cameras)
                if head_frames is None:
                    logger_mp.error("[Image Server] Head camera frame read is error.")
                    break
                head_color = cv2.hconcat(head_frames)

                if self.wrist_cameras:
                    wrist_frames = self._read_color_frames(self.wrist_cameras)
                    if wrist_frames is None:
                        # Stop like the head path does; concatenating a partial
                        # (possibly empty) wrist list would fail or change the
                        # published image layout.
                        logger_mp.error("[Image Server] Wrist camera frame read is error.")
                        break
                    wrist_color = cv2.hconcat(wrist_frames)
                    # Concatenate head and wrist frames
                    full_color = cv2.hconcat([head_color, wrist_color])
                else:
                    full_color = head_color

                ret, buffer = cv2.imencode(".jpg", full_color)
                if not ret:
                    logger_mp.error("[Image Server] Frame imencode is failed.")
                    continue

                jpg_bytes = buffer.tobytes()

                if self.Unit_Test:
                    timestamp = time.time()
                    frame_id = self.frame_count
                    header = struct.pack("dI", timestamp, frame_id)  # 8-byte double, 4-byte unsigned int
                    message = header + jpg_bytes
                else:
                    message = jpg_bytes

                self.socket.send(message)

                if self.Unit_Test:
                    current_time = time.time()
                    self._update_performance_metrics(current_time)
                    self._print_performance_metrics(current_time)

        except KeyboardInterrupt:
            logger_mp.warning("[Image Server] Interrupted by user.")
        finally:
            self._close()
if __name__ == "__main__":
    # Other setups that have been used with this script, kept for reference:
    #   * head: opencv, [480, 1280], ids [0]; wrist: opencv, [480, 640], ids [2, 4]
    #   * infrared: head: opencv, [480, 640], ids [2] (no wrist_* keys)
    # Active setup: RGB stream from /dev/video4.
    rgb_config = {
        "fps": 30,
        "head_camera_type": "opencv",
        "head_camera_image_shape": [480, 640],  # match the device
        "head_camera_id_numbers": [4],  # /dev/video4 is RGB
    }

    ImageServer(rgb_config, Unit_Test=False).send_process()

View File

@@ -0,0 +1,134 @@
#!/usr/bin/env python3
import time
import pickle
import threading
import zmq
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
from unitree_sdk2py.utils.crc import CRC
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import MotionSwitcherClient
# DDS topic names the bridge publishes low-level commands to / reads low-level state from.
kTopicLowCommand_Debug = "rt/lowcmd"
kTopicLowCommand_Motion = "rt/arm_sdk"
kTopicLowState = "rt/lowstate"
# TCP ports the bridge binds its ZeroMQ sockets to.
LOWCMD_PORT = 6000 # laptop -> robot
LOWSTATE_PORT = 6001 # robot -> laptop
def state_forward_loop(lowstate_sub, lowstate_sock, state_period: float):
    """Forward low-level state from DDS to the laptop, at most once per ``state_period``.

    Runs forever; intended to be the target of its own thread.

    Args:
        lowstate_sub: DDS subscriber whose ``Read()`` yields the next state
            message, or None when nothing was read.
        lowstate_sock: ZeroMQ socket the pickled ``(topic, msg)`` tuple is sent on.
        state_period: minimum number of seconds between two forwarded states.
    """
    # -inf guarantees the very first message is forwarded regardless of the
    # (arbitrary) origin of the monotonic clock.
    last_state_time = float("-inf")
    while True:
        msg = lowstate_sub.Read()
        if msg is None:
            continue

        # Monotonic clock: a wall-clock adjustment (e.g. an NTP step backwards)
        # must not be able to stall state forwarding.
        now = time.monotonic()
        # optional downsampling (if robot dds rate > state_period)
        if now - last_state_time < state_period:
            continue

        payload = pickle.dumps((kTopicLowState, msg), protocol=pickle.HIGHEST_PROTOCOL)
        try:
            lowstate_sock.send(payload, zmq.NOBLOCK)
        except zmq.Again:
            # if no subscribers / tx buffer full, just drop
            pass
        last_state_time = now
def cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, lowcmd_pub_motion, crc: CRC):
    """Forward low-level commands received from the laptop (ZeroMQ) to DDS.

    Runs forever; intended to be the target of its own thread.

    Args:
        lowcmd_sock: ZeroMQ socket whose ``recv()`` blocks until a pickled
            ``(topic, cmd)`` tuple arrives.
        lowcmd_pub_debug: DDS publisher used for the debug command topic.
        lowcmd_pub_motion: DDS publisher used for the motion command topic.
        crc: helper whose ``Crc(cmd)`` computes the command checksum.
    """
    while True:
        # blocking wait for commands from laptop
        payload = lowcmd_sock.recv()

        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload, and the payload arrives over the network. Only run this
        # bridge on a trusted link, or replace pickle with a data-only format.
        try:
            topic, cmd = pickle.loads(payload)  # cmd is hg_LowCmd
        except Exception as exc:
            # One malformed message must not kill the forwarding thread.
            print(f"dropping malformed lowcmd payload: {exc!r}")
            continue

        # recompute crc just in case
        cmd.crc = crc.Crc(cmd)

        if topic == kTopicLowCommand_Debug:
            lowcmd_pub_debug.Write(cmd)
        elif topic == kTopicLowCommand_Motion:
            lowcmd_pub_motion.Write(cmd)
        # commands for any other topic are ignored
def main():
    """Run the DDS <-> ZeroMQ bridge until interrupted with Ctrl-C.

    Initializes DDS, releases the robot's active motion mode, opens the DDS
    channels and ZeroMQ sockets, and starts one daemon thread per direction.
    """
    ChannelFactoryInitialize(0)

    # Keep releasing until CheckMode no longer reports an active mode name.
    switcher = MotionSwitcherClient()
    switcher.SetTimeout(5.0)
    switcher.Init()
    _status, mode = switcher.CheckMode()
    while mode is not None and "name" in mode and mode["name"]:
        switcher.ReleaseMode()
        _status, mode = switcher.CheckMode()
        time.sleep(1.0)

    crc = CRC()

    # DDS side: two command publishers and one state subscriber.
    debug_pub = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
    motion_pub = ChannelPublisher(kTopicLowCommand_Motion, hg_LowCmd)
    for publisher in (debug_pub, motion_pub):
        publisher.Init()

    state_sub = ChannelSubscriber(kTopicLowState, hg_LowState)
    state_sub.Init()

    # ZeroMQ side: PULL for incoming commands, PUB for outgoing state.
    zmq_ctx = zmq.Context.instance()

    cmd_sock = zmq_ctx.socket(zmq.PULL)
    cmd_sock.bind(f"tcp://0.0.0.0:{LOWCMD_PORT}")

    state_sock = zmq_ctx.socket(zmq.PUB)
    state_sock.bind(f"tcp://0.0.0.0:{LOWSTATE_PORT}")

    state_period = 0.002  # ~500 hz

    workers = [
        threading.Thread(
            target=state_forward_loop,
            args=(state_sub, state_sock, state_period),
            daemon=True,
        ),
        threading.Thread(
            target=cmd_forward_loop,
            args=(cmd_sock, debug_pub, motion_pub, crc),
            daemon=True,
        ),
    ]
    for worker in workers:
        worker.start()

    print("bridge running (lowstate -> zmq, lowcmd -> dds)")

    # The workers are daemon threads, so the main thread has to stay alive.
    try:
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        print("shutting down bridge...")
        # sockets/context will be cleaned up on process exit


if __name__ == "__main__":
    main()

View File

@@ -27,31 +27,16 @@ import termios
import tty
from collections import deque
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState # idl for g1, h1_2
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.utils.crc import CRC
from unitree_sdk2py.g1.audio.g1_audio_client import AudioClient
from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowCmd_ as go_LowCmd, LowState_ as go_LowState # idl for h1
from unitree_sdk2py.idl.default import unitree_go_msg_dds__LowCmd_
from typing import Union
import numpy as np
import time
import torch
import onnxruntime as ort
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelFactoryInitialize
from unitree_sdk2py.core.channel import ChannelSubscriber, ChannelFactoryInitialize
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_, unitree_hg_msg_dds__LowState_
from unitree_sdk2py.idl.default import unitree_go_msg_dds__LowCmd_, unitree_go_msg_dds__LowState_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as LowCmdHG
from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowCmd_ as LowCmdGo
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowState_ as LowStateHG
from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowState_ as LowStateGo
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState # idl for g1, h1_2
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.utils.crc import CRC
from unitree_sdk2py.g1.audio.g1_audio_client import AudioClient
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import (
MotionSwitcherClient,
)
@@ -65,13 +50,11 @@ import yaml
from typing import Union
import logging_mp
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
import torch
logger_mp = logging_mp.get_logger(__name__)
logger = logging.getLogger(__name__)
kTopicLowCommand_Debug = "rt/lowcmd"
kTopicLowCommand_Motion = "rt/arm_sdk"
@@ -121,7 +104,6 @@ class DataBuffer:
#eventually observations should be everything: motor torques etc etc
#motor class for unitree?
#TODO: camera, sim
class UnitreeG1(Robot):
config_class = UnitreeG1Config
@@ -130,7 +112,7 @@ class UnitreeG1(Robot):
def __init__(self, config: UnitreeG1Config):
super().__init__(config)
logger_mp.info("Initialize UnitreeG1...")
logger.info("Initialize UnitreeG1...")
self.config = config
self.cameras = make_cameras_from_configs(config.cameras)
@@ -153,6 +135,11 @@ class UnitreeG1(Robot):
self._gradual_start_time = config.gradual_start_time
self._gradual_time = config.gradual_time
# Teleop warmup: gradually move from current position to targets over 2 seconds
self.teleop_warmup_duration = 2.0 # seconds
self.teleop_warmup_start_time = None
self.teleop_warmup_initial_q = None
self.freeze_body = config.freeze_body
self.gravity_compensation = config.gravity_compensation
@@ -163,33 +150,39 @@ class UnitreeG1(Robot):
self.arm_ik = G1_29_ArmIK()
if self.config.socket_host is not None:
from lerobot.robots.unitree_g1.unitree_sdk2_socket import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
else:
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
# initialize lowcmd publisher and lowstate subscriber
# initialize lowcmd publisher and lowstate subscriber
if self.simulation_mode:
ChannelFactoryInitialize(0, "lo")
# Launch MuJoCo simulation environment
logger_mp.info("Launching MuJoCo simulation environment...")
logger.info("Launching MuJoCo simulation environment...")
self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
logger_mp.info("MuJoCo environment launched successfully!")
logger.info("MuJoCo environment launched successfully!")
else:
ChannelFactoryInitialize(0)
if not self.config.simulation_mode:
self.msc = MotionSwitcherClient()
self.msc.SetTimeout(5.0)
self.msc.Init()
status, result = self.msc.CheckMode()
print(status, result)
#check if result name first
if result is not None and "name" in result:
while result["name"]:
self.msc.ReleaseMode()
status, result = self.msc.CheckMode()
print(status, result)
time.sleep(1)
if not self.config.simulation_mode:
pass
# self.msc = MotionSwitcherClient()
# self.msc.SetTimeout(5.0)
# self.msc.Init()
# status, result = self.msc.CheckMode()
# print(status, result)
# #check if result name first
# if result is not None and "name" in result:
# while result["name"]:
# self.msc.ReleaseMode()
# status, result = self.msc.CheckMode()
# print(status, result)
# time.sleep(1)
if self.motion_mode:
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Motion, hg_LowCmd)
@@ -207,8 +200,8 @@ class UnitreeG1(Robot):
while not self.lowstate_buffer.GetData():
time.sleep(0.1)
logger_mp.warning("[UnitreeG1] Waiting to subscribe dds...")
logger_mp.info("[UnitreeG1] Subscribe dds ok.")
logger.warning("[UnitreeG1] Waiting to subscribe dds...")
logger.info("[UnitreeG1] Subscribe dds ok.")
# initialize audio client for LED, TTS, and audio playback
@@ -221,9 +214,9 @@ class UnitreeG1(Robot):
print(self.msg)
self.all_motor_q = self.get_current_motor_q()
logger_mp.info(f"Current all body motor state q:\n{self.all_motor_q} \n")
logger_mp.info(f"Current two arms motor state q:\n{self.get_current_dual_arm_q()}\n")
logger_mp.info("Lock all joints except two arms...\n")
logger.info(f"Current all body motor state q:\n{self.all_motor_q} \n")
logger.info(f"Current two arms motor state q:\n{self.get_current_dual_arm_q()}\n")
logger.info("Lock all joints except two arms...\n")
arm_indices = set(member.value for member in G1_29_JointArmIndex)
for id in G1_29_JointIndex:
@@ -248,12 +241,13 @@ class UnitreeG1(Robot):
if config.audio_client:
self.audio_client = AudioClient()
self.audio_client.SetTimeout(10.0)
self.audio_client.Init()
logger_mp.info("[UnitreeG1] Audio client initialized!")
pass
# self.audio_client = AudioClient()
# self.audio_client.SetTimeout(10.0)
# self.audio_client.Init()
# logger.info("[UnitreeG1] Audio client initialized!")
logger_mp.info("Lock OK!\n") #motors are not locked x
logger.info("Lock OK!\n") #motors are not locked x
# for i in range(10000):
# print(self.get_current_motor_q())
# time.sleep(0.05)
@@ -266,15 +260,18 @@ class UnitreeG1(Robot):
self.motion_imitation_thread = None
self.motion_imitation_running = False
# Initialize publish thread ONLY if not using motion imitation or locomotion
# (those modes handle their own motor commands and publishing)
# Initialize publish thread for arm control
# Note: This thread runs alongside locomotion/motion_imitation threads
# - Arm thread: controls arms (indices 15-28)
# - Locomotion thread: controls legs (0-11), waist (12-14)
# Both update different parts of self.msg, both call Write()
self.publish_thread = None
self.ctrl_lock = threading.Lock()
if not config.motion_imitation_control and not config.locomotion_control:
if not config.motion_imitation_control: # Allow with locomotion, disable only for motion imitation
self.publish_thread = threading.Thread(target=self._ctrl_motor_state)
self.publish_thread.daemon = True
self.publish_thread.start()
logger_mp.info("Arm control publish thread started")
logger.info("Arm control publish thread started")
# Load locomotion policy if enabled
self.policy = None
@@ -286,21 +283,21 @@ class UnitreeG1(Robot):
if config.motion_file_path is None:
raise ValueError("motion_imitation_control is True but motion_file_path is not set")
logger_mp.info(f"Loading motion reference from {config.motion_file_path}")
logger.info(f"Loading motion reference from {config.motion_file_path}")
# Load motion file
self.motion_loader = self.MotionLoader(config.motion_file_path, config.motion_fps)
# Load ONNX policy (optional for now - can run in direct playback mode)
if config.motion_policy_path and Path(config.motion_policy_path).exists():
logger_mp.info(f"Loading motion imitation policy from {config.motion_policy_path}")
logger.info(f"Loading motion imitation policy from {config.motion_policy_path}")
self.policy = ort.InferenceSession(config.motion_policy_path)
self.policy_type = 'motion_imitation'
logger_mp.info("Motion imitation ONNX policy loaded successfully")
logger_mp.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
logger_mp.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
logger.info("Motion imitation ONNX policy loaded successfully")
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
else:
logger_mp.info("Running in DIRECT PLAYBACK mode (no policy - just reference motion)")
logger.info("Running in DIRECT PLAYBACK mode (no policy - just reference motion)")
self.policy = None
self.policy_type = 'motion_playback'
@@ -319,21 +316,21 @@ class UnitreeG1(Robot):
if config.policy_path is None:
raise ValueError("locomotion_control is True but policy_path is not set")
logger_mp.info(f"Loading locomotion policy from {config.policy_path}")
logger.info(f"Loading locomotion policy from {config.policy_path}")
# Check file extension and load accordingly
if config.policy_path.endswith('.pt'):
logger_mp.info("Detected TorchScript (.pt) policy")
logger.info("Detected TorchScript (.pt) policy")
self.policy = torch.jit.load(config.policy_path)
self.policy_type = 'torchscript'
logger_mp.info("TorchScript policy loaded successfully")
logger.info("TorchScript policy loaded successfully")
elif config.policy_path.endswith('.onnx'):
logger_mp.info("Detected ONNX (.onnx) policy")
logger.info("Detected ONNX (.onnx) policy")
self.policy = ort.InferenceSession(config.policy_path)
self.policy_type = 'onnx'
logger_mp.info("ONNX policy loaded successfully")
logger_mp.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
logger_mp.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
logger.info("ONNX policy loaded successfully")
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
else:
raise ValueError(f"Unsupported policy format: {config.policy_path}. Only .pt (TorchScript) and .onnx (ONNX) are supported.")
@@ -363,7 +360,7 @@ class UnitreeG1(Robot):
# Start keyboard controls if in simulation mode
if self.simulation_mode:
logger_mp.info("Starting keyboard controls for simulation...")
logger.info("Starting keyboard controls for simulation...")
self.start_keyboard_controls()
# Use different init based on policy type
@@ -373,10 +370,10 @@ class UnitreeG1(Robot):
self.init_locomotion()
elif self.simulation_mode:
# Even without locomotion, provide keyboard feedback in sim
logger_mp.info("Simulation mode active (locomotion disabled)")
logger.info("Simulation mode active (locomotion disabled)")
logger_mp.info("Initialize G1 OK!\n")
logger.info("Initialize G1 OK!\n")
def _subscribe_motor_state(self):
while True:
@@ -463,8 +460,8 @@ class UnitreeG1(Robot):
all_t_elapsed = current_time - start_time
sleep_time = max(0, (self.control_dt - all_t_elapsed))
time.sleep(sleep_time)
# logger_mp.debug(f"arm_velocity_limit:{self.arm_velocity_limit}")
# logger_mp.debug(f"sleep_time:{sleep_time}")
# logger.debug(f"arm_velocity_limit:{self.arm_velocity_limit}")
# logger.debug(f"sleep_time:{sleep_time}")
def ctrl_dual_arm(self, q_target, tauff_target):
"""Set control target values q & tau of the left and right arm motors."""
@@ -490,7 +487,7 @@ class UnitreeG1(Robot):
def ctrl_dual_arm_go_home(self):
"""Move both the left and right arms of the robot to their home position by setting the target joint angles (q) and torques (tau) to zero."""
logger_mp.info("[G1_29_ArmController] ctrl_dual_arm_go_home start...")
logger.info("[G1_29_ArmController] ctrl_dual_arm_go_home start...")
max_attempts = 100
current_attempts = 0
with self.ctrl_lock:
@@ -505,7 +502,7 @@ class UnitreeG1(Robot):
for weight in np.linspace(1, 0, num=101):
self.msg.motor_cmd[G1_29_JointIndex.kNotUsedJoint0].q = weight
time.sleep(0.02)
logger_mp.info("[G1_29_ArmController] both arms have reached the home position.")
logger.info("[G1_29_ArmController] both arms have reached the home position.")
break
current_attempts += 1
time.sleep(0.05)
@@ -563,7 +560,7 @@ class UnitreeG1(Robot):
# Connect cameras
for cam in self.cameras.values():
cam.connect()
logger_mp.info(f"{self} connected with {len(self.cameras)} camera(s).")
logger.info(f"{self} connected with {len(self.cameras)} camera(s).")
def disconnect(self):
# Disconnect cameras
@@ -572,11 +569,11 @@ class UnitreeG1(Robot):
# Close MuJoCo environment if in simulation mode
if self.simulation_mode and hasattr(self, 'mujoco_env'):
logger_mp.info("Closing MuJoCo environment...")
logger.info("Closing MuJoCo environment...")
print(self.mujoco_env)
self.mujoco_env["hub_env"][0].envs[0].kill_sim()
logger_mp.info(f"{self} disconnected.")
logger.info(f"{self} disconnected.")
def get_full_robot_state(self) -> dict[str, Any]:
"""
@@ -640,18 +637,18 @@ class UnitreeG1(Robot):
if isinstance(command, tuple) and len(command) == 3:
# LED control - RGB tuple
r, g, b = command
logger_mp.info(f"Setting LED to RGB({r}, {g}, {b})")
logger.info(f"Setting LED to RGB({r}, {g}, {b})")
self.audio_client.LedControl(r, g, b)
elif isinstance(command, str):
# Check if it's a file path
if Path(command).exists():
# Play WAV file
logger_mp.info(f"Playing audio file: {command}")
logger.info(f"Playing audio file: {command}")
self._play_wav_file(command)
else:
# Text-to-speech
logger_mp.info(f"Speaking: {command}")
logger.info(f"Speaking: {command}")
self.audio_client.TtsMaker(command, 0) # 0 for English
else:
raise ValueError(
@@ -746,7 +743,7 @@ class UnitreeG1(Robot):
chunk_index = 0
total_size = len(pcm_data)
logger_mp.info(f"Playing audio: {total_size} bytes in {(total_size // chunk_size) + 1} chunks")
logger.info(f"Playing audio: {total_size} bytes in {(total_size // chunk_size) + 1} chunks")
# Send audio in chunks
while offset < total_size:
@@ -757,10 +754,10 @@ class UnitreeG1(Robot):
# Send chunk
ret_code, _ = self.audio_client.PlayStream(app_name, stream_id, list(chunk))
if ret_code != 0:
logger_mp.error(f"Failed to send chunk {chunk_index}, return code: {ret_code}")
logger.error(f"Failed to send chunk {chunk_index}, return code: {ret_code}")
break
else:
logger_mp.debug(f"Sent chunk {chunk_index}/{(total_size // chunk_size)}")
logger.debug(f"Sent chunk {chunk_index}/{(total_size // chunk_size)}")
offset += current_chunk_size
chunk_index += 1
@@ -768,7 +765,7 @@ class UnitreeG1(Robot):
# Calculate playback duration
duration_seconds = len(pcm_data) / (16000 * 2) # 16kHz, 16-bit (2 bytes)
logger_mp.info(f"Audio playback will take ~{duration_seconds:.1f} seconds")
logger.info(f"Audio playback will take ~{duration_seconds:.1f} seconds")
def get_observation(self) -> dict[str, Any]:
obs_array = self.get_current_dual_arm_q()
@@ -779,7 +776,7 @@ class UnitreeG1(Robot):
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger_mp.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
@@ -909,13 +906,13 @@ class UnitreeG1(Robot):
def locomotion_zero_torque_state(self):
    """Enter the zero torque state.

    Calls ``locomotion_create_zero_cmd()`` and then sleeps for one
    locomotion control period (``config.locomotion_control_dt``) before
    returning.
    """
    # Log once through the module-level stdlib logger; the legacy
    # ``logger_mp`` call duplicated this message.
    logger.info("Enter zero torque state.")
    self.locomotion_create_zero_cmd()
    # Wait one control tick before returning to the caller.
    time.sleep(self.config.locomotion_control_dt)
def locomotion_move_to_default_pos(self):
"""Move robot legs to default standing position over 2 seconds (arms are not moved)."""
logger_mp.info("Moving legs to default locomotion pos.")
logger.info("Moving legs to default locomotion pos.")
total_time = 2.0
num_step = int(total_time / self.config.locomotion_control_dt)
@@ -929,7 +926,7 @@ class UnitreeG1(Robot):
# Get current lowstate
lowstate = self.lowstate_buffer.GetData()
if lowstate is None:
logger_mp.error("Cannot get lowstate for locomotion")
logger.error("Cannot get lowstate for locomotion")
return
# Record the current leg positions
@@ -951,11 +948,11 @@ class UnitreeG1(Robot):
self.msg.crc = self.crc.Crc(self.msg)
self.lowcmd_publisher.Write(self.msg)
time.sleep(self.config.locomotion_control_dt)
logger_mp.info("Reached default locomotion position (legs only)")
logger.info("Reached default locomotion position (legs only)")
def locomotion_default_pos_state(self):
"""Hold default leg position for 2 seconds (arms are not controlled)."""
logger_mp.info("Enter default pos state - holding legs for 2 seconds")
logger.info("Enter default pos state - holding legs for 2 seconds")
# Only control legs, not arms
for i in range(len(self.config.leg_joint2motor_idx)):
@@ -973,7 +970,7 @@ class UnitreeG1(Robot):
self.msg.crc = self.crc.Crc(self.msg)
self.lowcmd_publisher.Write(self.msg)
time.sleep(self.config.locomotion_control_dt)
logger_mp.info("Finished holding default leg position")
logger.info("Finished holding default leg position")
class RemoteController:
@@ -1022,7 +1019,7 @@ class UnitreeG1(Robot):
self.index_1 = 0
self.blend = 0.0
logger_mp.info(f"MotionLoader: Loaded {self.num_frames} frames, duration={self.duration:.2f}s")
logger.info(f"MotionLoader: Loaded {self.num_frames} frames, duration={self.duration:.2f}s")
def update(self, time: float):
"""Update motion to specific time (loops at duration)."""
@@ -1146,7 +1143,7 @@ class UnitreeG1(Robot):
# Debug: print remote controller values every 50 iterations (~1 second at 50Hz)
if self.locomotion_counter % 50 == 0:
logger_mp.debug(f"Remote controller - lx:{self.remote_controller.lx:.2f}, ly:{self.remote_controller.ly:.2f}, rx:{self.remote_controller.rx:.2f}")
logger.debug(f"Remote controller - lx:{self.remote_controller.lx:.2f}, ly:{self.remote_controller.ly:.2f}, rx:{self.remote_controller.rx:.2f}")
# Build observation vector
num_actions = self.config.num_locomotion_actions
@@ -1343,7 +1340,7 @@ class UnitreeG1(Robot):
def _locomotion_thread_loop(self):
"""Background thread that runs the locomotion policy at specified rate."""
logger_mp.info("Locomotion thread started")
logger.info("Locomotion thread started")
while self.locomotion_running:
start_time = time.time()
try:
@@ -1353,40 +1350,40 @@ class UnitreeG1(Robot):
else:
self.locomotion_run()
except Exception as e:
logger_mp.error(f"Error in locomotion loop: {e}")
logger.error(f"Error in locomotion loop: {e}")
# Sleep to maintain control rate
elapsed = time.time() - start_time
sleep_time = max(0, self.config.locomotion_control_dt - elapsed)
time.sleep(sleep_time)
logger_mp.info("Locomotion thread stopped")
logger.info("Locomotion thread stopped")
def start_locomotion_thread(self):
    """Start the background locomotion control thread.

    Does nothing (beyond logging a warning) when
    ``config.locomotion_control`` is falsy or when a locomotion thread is
    already marked as running. Otherwise sets ``locomotion_running`` and
    launches ``_locomotion_thread_loop`` in a daemon thread stored on
    ``self.locomotion_thread``.
    """
    # Each message is logged once through the module-level stdlib logger;
    # the legacy ``logger_mp`` calls duplicated every line.
    if not self.config.locomotion_control:
        logger.warning("locomotion_control is False, cannot start thread")
        return

    if self.locomotion_running:
        logger.warning("Locomotion thread already running")
        return

    logger.info("Starting locomotion control thread...")
    # Set the run flag before starting so the loop's ``while`` condition is
    # already true when the thread begins executing.
    self.locomotion_running = True
    self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
    self.locomotion_thread.start()
    logger.info("Locomotion control thread started!")
def stop_locomotion_thread(self):
"""Stop the background locomotion control thread."""
if not self.locomotion_running:
return
logger_mp.info("Stopping locomotion control thread...")
logger.info("Stopping locomotion control thread...")
self.locomotion_running = False
if self.locomotion_thread:
self.locomotion_thread.join(timeout=2.0)
logger_mp.info("Locomotion control thread stopped")
logger.info("Locomotion control thread stopped")
# Also stop keyboard thread if running
if self.keyboard_running:
@@ -1459,40 +1456,40 @@ class UnitreeG1(Robot):
def start_keyboard_controls(self):
    """Start the keyboard control thread (sim mode only).

    Logs a warning and returns without starting anything when
    ``simulation_mode`` is falsy or when keyboard controls are already
    marked as running. Otherwise sets ``keyboard_running`` and launches
    ``_keyboard_listener_thread`` in a daemon thread stored on
    ``self.keyboard_thread``.
    """
    # Each message is logged once through the module-level stdlib logger;
    # the legacy ``logger_mp`` calls duplicated every line.
    if not self.simulation_mode:
        logger.warning("Keyboard controls only available in simulation mode")
        return

    if self.keyboard_running:
        logger.warning("Keyboard controls already running")
        return

    # Set the run flag before the listener thread starts.
    self.keyboard_running = True
    self.keyboard_thread = threading.Thread(target=self._keyboard_listener_thread, daemon=True)
    self.keyboard_thread.start()
    logger.info("Keyboard controls started!")
def stop_keyboard_controls(self):
    """Stop the keyboard control thread.

    Returns immediately when keyboard controls are not running. Otherwise
    clears ``keyboard_running`` and, if a thread object exists, waits up to
    2 seconds for it to finish.
    """
    if not self.keyboard_running:
        return

    # Each message is logged once through the module-level stdlib logger;
    # the legacy ``logger_mp`` calls duplicated every line.
    logger.info("Stopping keyboard controls...")
    # Clear the run flag first, then join with a bounded wait so a stuck
    # listener cannot block the caller indefinitely.
    self.keyboard_running = False
    if self.keyboard_thread:
        self.keyboard_thread.join(timeout=2.0)
    logger.info("Keyboard controls stopped")
def init_locomotion(self):
"""Test locomotion control sequence: home arms -> move legs to default -> start policy thread."""
if not self.config.locomotion_control:
logger_mp.warning("locomotion_control is False, cannot run test sequence")
logger.warning("locomotion_control is False, cannot run test sequence")
return
logger_mp.info("Starting locomotion test sequence...")
logger.info("Starting locomotion test sequence...")
# 1. Home the arms first
logger_mp.info("Homing arms to zero position...")
logger.info("Homing arms to zero position...")
#self.ctrl_dual_arm_go_home()
# 2. Move legs to default position
@@ -1505,19 +1502,19 @@ class UnitreeG1(Robot):
self.locomotion_default_pos_state()
# 5. Start locomotion policy thread (runs in background)
logger_mp.info("Starting locomotion policy control...")
logger.info("Starting locomotion policy control...")
self.start_locomotion_thread()
logger_mp.info("Locomotion test sequence complete! Policy is now running in background.")
logger_mp.info("Use robot.stop_locomotion_thread() to stop the policy.")
logger.info("Locomotion test sequence complete! Policy is now running in background.")
logger.info("Use robot.stop_locomotion_thread() to stop the policy.")
def init_groot_locomotion(self):
"""Initialize GR00T-style locomotion for ONNX policies (29 DOF, 15D actions)."""
if not self.config.locomotion_control:
logger_mp.warning("locomotion_control is False, cannot run GR00T init")
logger.warning("locomotion_control is False, cannot run GR00T init")
return
logger_mp.info("Starting GR00T locomotion initialization...")
logger.info("Starting GR00T locomotion initialization...")
# Move legs to default position (same as regular locomotion)
self.locomotion_move_to_default_pos()
@@ -1529,11 +1526,11 @@ class UnitreeG1(Robot):
self.locomotion_default_pos_state()
# Start locomotion policy thread (will use groot_locomotion_run)
logger_mp.info("Starting GR00T locomotion policy control...")
logger.info("Starting GR00T locomotion policy control...")
self.start_locomotion_thread()
logger_mp.info("GR00T locomotion initialization complete! Policy is now running.")
logger_mp.info("516D observations (86D × 6 frames), 15D actions (legs + waist)")
logger.info("GR00T locomotion initialization complete! Policy is now running.")
logger.info("516D observations (86D × 6 frames), 15D actions (legs + waist)")
def motion_imitation_run(self):
"""Motion imitation policy loop - tracks reference motion (dance_102, etc)."""
@@ -1568,13 +1565,13 @@ class UnitreeG1(Robot):
self.motion_qj_all[joint_idx] = 0.0
self.motion_dqj_all[joint_idx] = 0.0
if self.motion_counter == 1:
logger_mp.info("="*60)
logger_mp.info("🤖 23 DOF MODE ENABLED")
logger_mp.info(f" Zeroing joints: {JOINTS_TO_ZERO_23DOF}")
logger_mp.info(" Waist: yaw(12), pitch(14)")
logger_mp.info(" Wrist L: pitch(20), yaw(21) | Wrist R: pitch(27), yaw(28)")
logger_mp.info(" Applied to: robot obs, reference motion, policy actions")
logger_mp.info("="*60)
logger.info("="*60)
logger.info("🤖 23 DOF MODE ENABLED")
logger.info(f" Zeroing joints: {JOINTS_TO_ZERO_23DOF}")
logger.info(" Waist: yaw(12), pitch(14)")
logger.info(" Wrist L: pitch(20), yaw(21) | Wrist R: pitch(27), yaw(28)")
logger.info(" Applied to: robot obs, reference motion, policy actions")
logger.info("="*60)
# Get IMU data
robot_quat = lowstate.imu_state.quaternion # [w, x, y, z]
@@ -1643,9 +1640,9 @@ class UnitreeG1(Robot):
self.msg.motor_cmd[motor_idx].tau = 0
if self.motion_counter == 1:
logger_mp.info("="*60)
logger_mp.info("⚠️ DEBUG MODE: DIRECT PLAYBACK (reference motion, no policy)")
logger_mp.info("="*60)
logger.info("="*60)
logger.info("⚠️ DEBUG MODE: DIRECT PLAYBACK (reference motion, no policy)")
logger.info("="*60)
target_joint_pos_bfs = None # Not used in this mode
@@ -1656,9 +1653,9 @@ class UnitreeG1(Robot):
motion_joint_pos_bfs = np.zeros(29, dtype=np.float32)
motion_joint_vel_bfs = np.zeros(29, dtype=np.float32)
if self.motion_counter == 1:
logger_mp.info("="*60)
logger_mp.info("⚠️ DEBUG MODE: Using ZERO reference motion + RUNNING POLICY")
logger_mp.info("="*60)
logger.info("="*60)
logger.info("⚠️ DEBUG MODE: Using ZERO reference motion + RUNNING POLICY")
logger.info("="*60)
else:
# Get reference motion (DFS order from CSV)
motion_joint_pos_dfs = self.motion_loader.get_joint_pos() # 29D
@@ -1710,13 +1707,13 @@ class UnitreeG1(Robot):
# DEBUG: Just send default positions (should make robot stand still)
target_joint_pos_bfs = default_joint_pos.copy()
if self.motion_counter == 1:
logger_mp.info("="*60)
logger_mp.info("⚠️ DEBUG MODE: Sending DEFAULT positions (NO POLICY)")
logger_mp.info("="*60)
logger_mp.info(f" Default pos BFS[0:5]: {target_joint_pos_bfs[0:5]}")
logger.info("="*60)
logger.info("⚠️ DEBUG MODE: Sending DEFAULT positions (NO POLICY)")
logger.info("="*60)
logger.info(f" Default pos BFS[0:5]: {target_joint_pos_bfs[0:5]}")
if self.motion_counter % 50 == 0:
logger_mp.info(f" [DEFAULT MODE] Sending: [{target_joint_pos_bfs[0]:.4f}, {target_joint_pos_bfs[6]:.4f}, {target_joint_pos_bfs[12]:.4f}]")
logger_mp.info(f" [DEFAULT MODE] Robot at: [{self.motion_qj_all[0]:.4f}, {self.motion_qj_all[6]:.4f}, {self.motion_qj_all[12]:.4f}]")
logger.info(f" [DEFAULT MODE] Sending: [{target_joint_pos_bfs[0]:.4f}, {target_joint_pos_bfs[6]:.4f}, {target_joint_pos_bfs[12]:.4f}]")
logger.info(f" [DEFAULT MODE] Robot at: [{self.motion_qj_all[0]:.4f}, {self.motion_qj_all[6]:.4f}, {self.motion_qj_all[12]:.4f}]")
else:
# Run ONNX policy inference
obs_tensor = torch.from_numpy(self.motion_obs).unsqueeze(0)
@@ -1744,29 +1741,29 @@ class UnitreeG1(Robot):
# Debug print (only when running policy, not in TEST_SEND_DEFAULT_POS or TEST_DIRECT_PLAYBACK mode)
if self.motion_counter == 1 and self.policy and not TEST_SEND_DEFAULT_POS and not TEST_DIRECT_PLAYBACK:
logger_mp.info("="*60)
logger_mp.info("POLICY MODE OBSERVATION CHECK (First iteration)")
logger_mp.info("="*60)
logger_mp.info(f"Reference motion (BFS) samples: [{motion_joint_pos_bfs[0]:.3f}, {motion_joint_pos_bfs[6]:.3f}, {motion_joint_pos_bfs[12]:.3f}]")
logger_mp.info(f"Robot joints (BFS) samples: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
logger_mp.info(f"Default positions samples: [{default_joint_pos[0]:.3f}, {default_joint_pos[6]:.3f}, {default_joint_pos[12]:.3f}]")
logger_mp.info(f"Joint pos rel samples: [{joint_pos_rel[0]:.3f}, {joint_pos_rel[6]:.3f}, {joint_pos_rel[12]:.3f}]")
logger_mp.info(f"Joint vel rel samples: [{joint_vel_rel[0]:.3f}, {joint_vel_rel[6]:.3f}, {joint_vel_rel[12]:.3f}]")
logger_mp.info(f"Angular velocity: [{ang_vel[0]:.3f}, {ang_vel[1]:.3f}, {ang_vel[2]:.3f}]")
logger_mp.info(f"Motion anchor ori: [{motion_anchor_ori_b[0]:.3f}, ..., {motion_anchor_ori_b[5]:.3f}]")
logger_mp.info(f"Observation breakdown:")
logger_mp.info(f" [0:29] motion_cmd_pos: range [{self.motion_obs[0:29].min():.3f}, {self.motion_obs[0:29].max():.3f}]")
logger_mp.info(f" [29:58] motion_cmd_vel: range [{self.motion_obs[29:58].min():.3f}, {self.motion_obs[29:58].max():.3f}]")
logger_mp.info(f" [58:64] anchor_ori: range [{self.motion_obs[58:64].min():.3f}, {self.motion_obs[58:64].max():.3f}]")
logger_mp.info(f" [64:67] ang_vel: range [{self.motion_obs[64:67].min():.3f}, {self.motion_obs[64:67].max():.3f}]")
logger_mp.info(f" [67:96] joint_pos_rel: range [{self.motion_obs[67:96].min():.3f}, {self.motion_obs[67:96].max():.3f}]")
logger_mp.info(f" [96:125] joint_vel_rel: range [{self.motion_obs[96:125].min():.3f}, {self.motion_obs[96:125].max():.3f}]")
logger_mp.info(f" [125:154] last_action: range [{self.motion_obs[125:154].min():.3f}, {self.motion_obs[125:154].max():.3f}]")
logger_mp.info(f"Full obs range: [{self.motion_obs.min():.3f}, {self.motion_obs.max():.3f}]")
logger_mp.info(f"Action output (first): [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
logger_mp.info(f"Action scale samples: [{action_scale[0]:.3f}, {action_scale[6]:.3f}, {action_scale[12]:.3f}]")
logger_mp.info(f"Target positions samples: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
logger_mp.info("="*60)
logger.info("="*60)
logger.info("POLICY MODE OBSERVATION CHECK (First iteration)")
logger.info("="*60)
logger.info(f"Reference motion (BFS) samples: [{motion_joint_pos_bfs[0]:.3f}, {motion_joint_pos_bfs[6]:.3f}, {motion_joint_pos_bfs[12]:.3f}]")
logger.info(f"Robot joints (BFS) samples: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
logger.info(f"Default positions samples: [{default_joint_pos[0]:.3f}, {default_joint_pos[6]:.3f}, {default_joint_pos[12]:.3f}]")
logger.info(f"Joint pos rel samples: [{joint_pos_rel[0]:.3f}, {joint_pos_rel[6]:.3f}, {joint_pos_rel[12]:.3f}]")
logger.info(f"Joint vel rel samples: [{joint_vel_rel[0]:.3f}, {joint_vel_rel[6]:.3f}, {joint_vel_rel[12]:.3f}]")
logger.info(f"Angular velocity: [{ang_vel[0]:.3f}, {ang_vel[1]:.3f}, {ang_vel[2]:.3f}]")
logger.info(f"Motion anchor ori: [{motion_anchor_ori_b[0]:.3f}, ..., {motion_anchor_ori_b[5]:.3f}]")
logger.info(f"Observation breakdown:")
logger.info(f" [0:29] motion_cmd_pos: range [{self.motion_obs[0:29].min():.3f}, {self.motion_obs[0:29].max():.3f}]")
logger.info(f" [29:58] motion_cmd_vel: range [{self.motion_obs[29:58].min():.3f}, {self.motion_obs[29:58].max():.3f}]")
logger.info(f" [58:64] anchor_ori: range [{self.motion_obs[58:64].min():.3f}, {self.motion_obs[58:64].max():.3f}]")
logger.info(f" [64:67] ang_vel: range [{self.motion_obs[64:67].min():.3f}, {self.motion_obs[64:67].max():.3f}]")
logger.info(f" [67:96] joint_pos_rel: range [{self.motion_obs[67:96].min():.3f}, {self.motion_obs[67:96].max():.3f}]")
logger.info(f" [96:125] joint_vel_rel: range [{self.motion_obs[96:125].min():.3f}, {self.motion_obs[96:125].max():.3f}]")
logger.info(f" [125:154] last_action: range [{self.motion_obs[125:154].min():.3f}, {self.motion_obs[125:154].max():.3f}]")
logger.info(f"Full obs range: [{self.motion_obs.min():.3f}, {self.motion_obs.max():.3f}]")
logger.info(f"Action output (first): [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
logger.info(f"Action scale samples: [{action_scale[0]:.3f}, {action_scale[6]:.3f}, {action_scale[12]:.3f}]")
logger.info(f"Target positions samples: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
logger.info("="*60)
if self.motion_counter % 50 == 0:
if self.policy is None:
@@ -1779,12 +1776,12 @@ class UnitreeG1(Robot):
mode = "POLICY_ZEROS"
else:
mode = "POLICY"
logger_mp.info(f"Motion {mode}: t={self.motion_elapsed_time:.2f}s, frame={self.motion_loader.index_0}/{self.motion_loader.num_frames}")
logger.info(f"Motion {mode}: t={self.motion_elapsed_time:.2f}s, frame={self.motion_loader.index_0}/{self.motion_loader.num_frames}")
if self.policy and not TEST_SEND_DEFAULT_POS and not TEST_DIRECT_PLAYBACK:
logger_mp.info(f" Policy action range: [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
logger_mp.info(f" Sample actions[0,6,12]: [{self.motion_action[0]:.3f}, {self.motion_action[6]:.3f}, {self.motion_action[12]:.3f}]")
logger_mp.info(f" Target pos (after scale)[0,6,12]: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
logger_mp.info(f" Robot pos (BFS)[0,6,12]: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
logger.info(f" Policy action range: [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
logger.info(f" Sample actions[0,6,12]: [{self.motion_action[0]:.3f}, {self.motion_action[6]:.3f}, {self.motion_action[12]:.3f}]")
logger.info(f" Target pos (after scale)[0,6,12]: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
logger.info(f" Robot pos (BFS)[0,6,12]: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
# Send command
self.msg.crc = self.crc.Crc(self.msg)
@@ -1792,13 +1789,13 @@ class UnitreeG1(Robot):
def _motion_imitation_thread_loop(self):
"""Background thread that runs the motion imitation policy at specified rate."""
logger_mp.info("Motion imitation thread started")
logger.info("Motion imitation thread started")
while self.motion_imitation_running:
start_time = time.time()
try:
self.motion_imitation_run()
except Exception as e:
logger_mp.error(f"Error in motion imitation loop: {e}")
logger.error(f"Error in motion imitation loop: {e}")
import traceback
traceback.print_exc()
@@ -1806,45 +1803,45 @@ class UnitreeG1(Robot):
elapsed = time.time() - start_time
sleep_time = max(0, self.config.motion_control_dt - elapsed)
time.sleep(sleep_time)
logger_mp.info("Motion imitation thread stopped")
logger.info("Motion imitation thread stopped")
def start_motion_imitation_thread(self):
"""Start the background motion imitation control thread."""
if not self.config.motion_imitation_control:
logger_mp.warning("motion_imitation_control is False, cannot start thread")
logger.warning("motion_imitation_control is False, cannot start thread")
return
if self.motion_imitation_running:
logger_mp.warning("Motion imitation thread already running")
logger.warning("Motion imitation thread already running")
return
logger_mp.info("Starting motion imitation control thread...")
logger.info("Starting motion imitation control thread...")
self.motion_imitation_running = True
self.motion_imitation_thread = threading.Thread(target=self._motion_imitation_thread_loop, daemon=True)
self.motion_imitation_thread.start()
logger_mp.info("Motion imitation control thread started!")
logger.info("Motion imitation control thread started!")
def stop_motion_imitation_thread(self):
"""Stop the background motion imitation control thread."""
if not self.motion_imitation_running:
return
logger_mp.info("Stopping motion imitation control thread...")
logger.info("Stopping motion imitation control thread...")
self.motion_imitation_running = False
if self.motion_imitation_thread:
self.motion_imitation_thread.join(timeout=2.0)
logger_mp.info("Motion imitation control thread stopped")
logger.info("Motion imitation control thread stopped")
def init_motion_imitation(self):
"""Initialize motion imitation - move to default standing pose and start policy."""
if not self.config.motion_imitation_control:
logger_mp.warning("motion_imitation_control is False, cannot run initialization")
logger.warning("motion_imitation_control is False, cannot run initialization")
return
logger_mp.info("Starting motion imitation initialization...")
logger.info("Starting motion imitation initialization...")
# Move to default standing position
logger_mp.info("Moving to default standing position...")
logger.info("Moving to default standing position...")
total_time = 3.0
num_steps = int(total_time / self.config.motion_control_dt)
@@ -1871,17 +1868,17 @@ class UnitreeG1(Robot):
self.lowcmd_publisher.Write(self.msg)
time.sleep(self.config.motion_control_dt)
logger_mp.info("Reached default position")
logger.info("Reached default position")
# Wait 2 seconds
time.sleep(2.0)
# Start motion imitation policy thread
logger_mp.info("Starting motion imitation policy control...")
logger.info("Starting motion imitation policy control...")
self.start_motion_imitation_thread()
logger_mp.info("Motion imitation initialization complete! Policy is now running.")
logger_mp.info(f"154D observations, 29D actions. Motion duration: {self.motion_loader.duration:.2f}s")
logger.info("Motion imitation initialization complete! Policy is now running.")
logger.info(f"154D observations, 29D actions. Motion duration: {self.motion_loader.duration:.2f}s")
class G1_29_JointArmIndex(IntEnum):

View File

@@ -0,0 +1,73 @@
# unitree_sdk2_socket.py
import os
import pickle
import time

import zmq
# Connection settings for the ZMQ bridge running on the robot.
# Each value can be overridden with an environment variable; the defaults are
# the values that were previously hard-coded here, so existing setups keep
# working unchanged.
ROBOT_IP = os.environ.get("UNITREE_ROBOT_IP", "172.18.129.215")
LOWCMD_PORT = int(os.environ.get("UNITREE_LOWCMD_PORT", "6000"))  # laptop -> robot
LOWSTATE_PORT = int(os.environ.get("UNITREE_LOWSTATE_PORT", "6001"))  # robot -> laptop

# Process-wide ZMQ state. All three stay None until
# ChannelFactoryInitialize() has run; the publisher/subscriber classes below
# use the sockets directly.
_ctx = None
_lowcmd_sock = None
_lowstate_sock = None
def ChannelFactoryInitialize(*args, **kwargs):
    """Create and connect the module-level ZMQ sockets (idempotent).

    Sets up two sockets on the shared ZMQ context:

    * a PUSH socket connected to ``ROBOT_IP:LOWCMD_PORT`` (laptop -> robot),
    * a SUB socket connected to ``ROBOT_IP:LOWSTATE_PORT`` (robot -> laptop),
      subscribed to every topic.

    Both sockets enable ``zmq.CONFLATE`` so only the last message is kept.

    Positional and keyword arguments are accepted and ignored.
    NOTE(review): presumably this mirrors the signature of the real
    unitree_sdk2 ``ChannelFactoryInitialize`` -- confirm against callers.

    The module globals are only assigned once every step has succeeded. If
    socket creation or ``connect`` raises, any socket created so far is
    closed, the exception propagates, and a later call will retry instead of
    silently returning with half-initialized state.
    """
    global _ctx, _lowcmd_sock, _lowstate_sock
    if _ctx is not None:
        # Already initialized; the sockets are shared process-wide.
        return

    ctx = zmq.Context.instance()
    cmd_sock = None
    state_sock = None
    try:
        # lowcmd: PUSH from laptop to robot
        cmd_sock = ctx.socket(zmq.PUSH)
        cmd_sock.setsockopt(zmq.CONFLATE, 1)  # keep only last message
        cmd_sock.connect(f"tcp://{ROBOT_IP}:{LOWCMD_PORT}")

        # lowstate: SUB from robot, no topic filtering
        state_sock = ctx.socket(zmq.SUB)
        state_sock.setsockopt(zmq.CONFLATE, 1)  # keep only last message
        state_sock.connect(f"tcp://{ROBOT_IP}:{LOWSTATE_PORT}")
        state_sock.setsockopt_string(zmq.SUBSCRIBE, "")  # subscribe to all
    except Exception:
        # Do not leak partially configured sockets; linger=0 so close()
        # never waits on unsent messages.
        for sock in (cmd_sock, state_sock):
            if sock is not None:
                sock.close(linger=0)
        raise

    # Publish the fully configured state in one go; _ctx is the "initialized"
    # flag checked at the top, so it is set last.
    _lowcmd_sock = cmd_sock
    _lowstate_sock = state_sock
    _ctx = ctx
class ChannelPublisher:
    """Minimal publisher exposing only ``__init__``, ``Init`` and ``Write``.

    Messages go out over the module-level low-command socket, so
    ``ChannelFactoryInitialize()`` must have been called before ``Write``.
    """

    def __init__(self, topic, msg_type):
        """Remember the topic and message type.

        ``topic`` is embedded in every payload sent by ``Write``;
        ``msg_type`` is stored but not otherwise used by this class.
        """
        self.topic, self.msg_type = topic, msg_type

    def Init(self):
        """Do nothing: the sockets are module-level globals."""
        return None

    def Write(self, msg):
        """Pickle ``(topic, msg)`` and send it on the low-command socket."""
        _lowcmd_sock.send(pickle.dumps((self.topic, msg)))
class ChannelSubscriber:
    """Minimal subscriber exposing only ``__init__``, ``Init`` and ``Read``.

    Messages are read from the module-level low-state socket, so
    ``ChannelFactoryInitialize()`` must have been called before ``Read``.
    """

    def __init__(self, topic, msg_type):
        """Remember the topic and message type (neither is used by ``Read``)."""
        self.topic, self.msg_type = topic, msg_type

    def Init(self):
        """Do nothing: the sockets are module-level globals."""
        return None

    def Read(self, timeout_ms=None):
        """Return the next low-state message, or ``None`` on timeout.

        With ``timeout_ms=None`` the call blocks until a message arrives.
        Otherwise it waits at most ``timeout_ms`` milliseconds and returns
        ``None`` if nothing became readable in that time.

        NOTE(review): the payload is decoded with ``pickle.loads``, which can
        execute arbitrary code from the sender -- only use this on a trusted
        network.
        """
        # Guard clause: with a timeout, bail out when the socket has nothing
        # to read (poll returns an empty event mask).
        if timeout_ms is not None and not _lowstate_sock.poll(timeout_ms, zmq.POLLIN):
            return None

        # The payload is a pickled (topic, msg) pair; the topic is not
        # checked against self.topic.
        _topic, message = pickle.loads(_lowstate_sock.recv())
        return message

View File

@@ -55,7 +55,7 @@ class HomunculusArm(Teleoperator):
"wrist_yaw": MotorNormMode.RANGE_M100_100,
"wrist_pitch": MotorNormMode.RANGE_M100_100,
}
n = 50
n = 10
# EMA parameters ---------------------------------------------------
self.n: int = n
self.alpha: float = 2 / (n + 1)