mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
sync updates
This commit is contained in:
@@ -19,6 +19,7 @@ Provides the ZMQCamera class for capturing frames from remote cameras via ZeroMQ
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Event, Lock, Thread
|
||||
@@ -105,6 +106,9 @@ class ZMQCamera(Camera):
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
# Format type detected during connection (msgpack, json, or raw_jpeg)
|
||||
self._format_type: str | None = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.camera_name}@{self.server_address}:{self.port})"
|
||||
@@ -141,9 +145,22 @@ class ZMQCamera(Camera):
|
||||
|
||||
self._connected = True
|
||||
|
||||
# Try to receive one frame to validate connection
|
||||
# Try to receive one frame to validate connection and detect format
|
||||
try:
|
||||
test_frame = self.read()
|
||||
# Try each format until one works
|
||||
test_frame = None
|
||||
for format_type in ["msgpack", "json", "raw_jpeg"]:
|
||||
try:
|
||||
test_frame = self.read(format=format_type)
|
||||
self._format_type = format_type
|
||||
logger.info(f"{self} detected format: {format_type}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug(f"{self} format '{format_type}' failed: {e}")
|
||||
continue
|
||||
|
||||
if test_frame is None:
|
||||
raise RuntimeError("Failed to decode frame with any supported format (msgpack, json, raw_jpeg)")
|
||||
|
||||
# Auto-detect resolution if not specified
|
||||
if self.width is None or self.height is None:
|
||||
@@ -179,136 +196,185 @@ class ZMQCamera(Camera):
|
||||
raise RuntimeError(f"Failed to connect to {self}: {e}")
|
||||
|
||||
@staticmethod
|
||||
def find_cameras() -> list[dict[str, Any]]:
|
||||
def find_cameras(
|
||||
subnet: str | None = None,
|
||||
ports: list[int] | None = None,
|
||||
timeout_ms: int = 200,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Detects available ZMQ cameras based on configuration.
|
||||
Scans the local network for ZMQ cameras (fast parallel scan).
|
||||
|
||||
Reads camera configurations from:
|
||||
1. Environment variable LEROBOT_ZMQ_CAMERAS (JSON format)
|
||||
2. Config file at ~/.lerobot/zmq_cameras.json
|
||||
Uses threading to scan multiple hosts simultaneously. Without parallelization,
|
||||
scanning 254 hosts would take 6+ minutes. With threads, takes ~10-15 seconds.
|
||||
|
||||
Example JSON format:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"name": "unitree_g1_head",
|
||||
"address": "192.168.123.164",
|
||||
"port": 5554
|
||||
},
|
||||
{
|
||||
"name": "lab_cam_1",
|
||||
"address": "192.168.1.100",
|
||||
"port": 5555
|
||||
}
|
||||
]
|
||||
```
|
||||
Args:
|
||||
subnet: Network subnet to scan (e.g., "192.168.1.0/24"). If None, auto-detects.
|
||||
ports: List of ports to scan. Defaults to [5554, 5555, 5556].
|
||||
timeout_ms: Connection timeout per host in milliseconds. Default: 200ms.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries containing ZMQ camera information.
|
||||
List of dicts containing camera info (address, port, format, resolution).
|
||||
|
||||
Example:
|
||||
>>> cameras = ZMQCamera.find_cameras()
|
||||
>>> # Or specify: cameras = ZMQCamera.find_cameras(subnet="10.0.0.0/24", ports=[5554])
|
||||
"""
|
||||
found_cameras_info = []
|
||||
camera_configs = []
|
||||
import socket
|
||||
import ipaddress
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
# Try to load from environment variable first
|
||||
env_cameras = os.environ.get("LEROBOT_ZMQ_CAMERAS")
|
||||
if env_cameras:
|
||||
if ports is None:
|
||||
ports = [5554, 5555, 5556]
|
||||
|
||||
# Auto-detect local subnet
|
||||
if subnet is None:
|
||||
try:
|
||||
camera_configs = json.loads(env_cameras)
|
||||
logger.info(f"Loaded {len(camera_configs)} ZMQ camera configs from LEROBOT_ZMQ_CAMERAS")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse LEROBOT_ZMQ_CAMERAS environment variable: {e}")
|
||||
#use unitree_g1_head as an example
|
||||
camera_configs = [
|
||||
{
|
||||
"name": "unitree_g1_head",
|
||||
"address": "192.168.123.164",
|
||||
"port": 5554
|
||||
}
|
||||
]
|
||||
# Try to load from config file
|
||||
if not camera_configs:
|
||||
config_path = Path.home() / ".lerobot" / "zmq_cameras.json"
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
camera_configs = json.load(f)
|
||||
logger.info(f"Loaded {len(camera_configs)} ZMQ camera configs from {config_path}")
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
logger.warning(f"Failed to load ZMQ camera config from {config_path}: {e}")
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.connect(("8.8.8.8", 80))
|
||||
local_ip = s.getsockname()[0]
|
||||
s.close()
|
||||
subnet = ".".join(local_ip.split(".")[:-1]) + ".0/24"
|
||||
logger.info(f"Auto-detected subnet: {subnet}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to auto-detect subnet: {e}")
|
||||
return []
|
||||
|
||||
if not camera_configs:
|
||||
logger.info(
|
||||
"No ZMQ cameras configured. Set LEROBOT_ZMQ_CAMERAS environment variable "
|
||||
f"or create {Path.home() / '.lerobot' / 'zmq_cameras.json'}"
|
||||
)
|
||||
# Parse subnet
|
||||
try:
|
||||
network = ipaddress.ip_network(subnet, strict=False)
|
||||
hosts = list(network.hosts())
|
||||
# Always include localhost (for MuJoCo sim, local servers)
|
||||
hosts.insert(0, ipaddress.IPv4Address("127.0.0.1"))
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid subnet '{subnet}': {e}")
|
||||
return []
|
||||
|
||||
# Test each configured camera
|
||||
for cam_config in camera_configs:
|
||||
try:
|
||||
name = cam_config.get("name", "unknown")
|
||||
address = cam_config.get("address")
|
||||
port = cam_config.get("port", 5554)
|
||||
|
||||
if not address:
|
||||
logger.warning(f"Skipping camera '{name}': missing address")
|
||||
continue
|
||||
|
||||
# Try to connect with a short timeout
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.SUB)
|
||||
socket.connect(f"tcp://{address}:{port}")
|
||||
socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
socket.setsockopt(zmq.RCVTIMEO, 2000) # 2 second timeout for discovery
|
||||
|
||||
total = len(hosts) * len(ports)
|
||||
logger.info(f"Scanning {len(hosts)} hosts × {len(ports)} ports = {total} targets (this takes ~10-15s)...")
|
||||
|
||||
def test_target(host_ip: str, port: int) -> dict | None:
|
||||
"""Test one host:port for ZMQ camera."""
|
||||
ctx = zmq.Context()
|
||||
sock = ctx.socket(zmq.SUB)
|
||||
sock.connect(f"tcp://{host_ip}:{port}")
|
||||
sock.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
sock.setsockopt(zmq.RCVTIMEO, timeout_ms)
|
||||
|
||||
# Wait for subscription to establish (ZMQ "slow joiner" problem)
|
||||
time.sleep(0.1)
|
||||
|
||||
# Try receiving a few times
|
||||
msg = None
|
||||
for _ in range(3):
|
||||
try:
|
||||
# Try to receive one frame to validate
|
||||
message = socket.recv()
|
||||
np_img = np.frombuffer(message, dtype=np.uint8)
|
||||
test_image = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
|
||||
|
||||
if test_image is not None:
|
||||
height, width = test_image.shape[:2]
|
||||
|
||||
camera_info = {
|
||||
"name": f"ZMQ Camera: {name}",
|
||||
"type": "ZMQ",
|
||||
"id": f"{address}:{port}",
|
||||
"server_address": address,
|
||||
"port": port,
|
||||
"camera_name": name,
|
||||
"default_stream_profile": {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"format": "JPEG",
|
||||
},
|
||||
}
|
||||
found_cameras_info.append(camera_info)
|
||||
logger.info(f"Found ZMQ camera: {name} at {address}:{port}")
|
||||
else:
|
||||
logger.warning(f"Camera '{name}' at {address}:{port} returned invalid image")
|
||||
|
||||
msg = sock.recv()
|
||||
break
|
||||
except zmq.Again:
|
||||
logger.warning(f"Camera '{name}' at {address}:{port} timeout - not streaming")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error testing camera '{name}' at {address}:{port}: {e}")
|
||||
finally:
|
||||
socket.close()
|
||||
context.term()
|
||||
time.sleep(0.05)
|
||||
|
||||
if msg is None:
|
||||
sock.close()
|
||||
ctx.term()
|
||||
return None
|
||||
|
||||
# Try formats: msgpack → json → raw_jpeg
|
||||
frame = fmt = None
|
||||
|
||||
# Msgpack
|
||||
try:
|
||||
d = msgpack.unpackb(msg, object_hook=m.decode)
|
||||
if isinstance(d, dict) and "images" in d and len(d["images"]) > 0:
|
||||
img = next(iter(d["images"].values()))
|
||||
if isinstance(img, str):
|
||||
frame = cv2.imdecode(np.frombuffer(base64.b64decode(img), np.uint8), cv2.IMREAD_COLOR)
|
||||
elif isinstance(img, np.ndarray):
|
||||
frame = img
|
||||
if frame is not None:
|
||||
fmt = "msgpack"
|
||||
except:
|
||||
pass
|
||||
|
||||
# JSON
|
||||
if frame is None:
|
||||
try:
|
||||
d = json.loads(msg.decode('utf-8'))
|
||||
if isinstance(d, dict):
|
||||
for v in d.values():
|
||||
if isinstance(v, str) and len(v) > 100:
|
||||
try:
|
||||
frame = cv2.imdecode(np.frombuffer(base64.b64decode(v), np.uint8), cv2.IMREAD_COLOR)
|
||||
if frame is not None:
|
||||
fmt = "json"
|
||||
break
|
||||
except:
|
||||
pass
|
||||
except:
|
||||
pass
|
||||
|
||||
# Raw JPEG
|
||||
if frame is None:
|
||||
try:
|
||||
frame = cv2.imdecode(np.frombuffer(msg, np.uint8), cv2.IMREAD_COLOR)
|
||||
if frame is not None:
|
||||
fmt = "raw_jpeg"
|
||||
except:
|
||||
pass
|
||||
|
||||
sock.close()
|
||||
ctx.term()
|
||||
|
||||
if frame is not None:
|
||||
h, w = frame.shape[:2]
|
||||
return {
|
||||
"name": f"ZMQ @ {host_ip}:{port}",
|
||||
"type": "ZMQ",
|
||||
"id": f"{host_ip}:{port}",
|
||||
"server_address": host_ip,
|
||||
"port": port,
|
||||
"camera_name": f"cam_{host_ip.replace('.', '_')}_{port}",
|
||||
"format": fmt,
|
||||
"default_stream_profile": {"width": w, "height": h, "format": fmt.upper()},
|
||||
}
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing camera config: {e}")
|
||||
# Parallel scan with thread pool
|
||||
found = []
|
||||
with ThreadPoolExecutor(max_workers=100) as ex:
|
||||
futures = [ex.submit(test_target, str(h), p) for h in hosts for p in ports]
|
||||
for i, fut in enumerate(as_completed(futures), 1):
|
||||
if i % 100 == 0:
|
||||
logger.info(f" Progress: {i}/{total} ({100*i//total}%)")
|
||||
res = fut.result()
|
||||
if res:
|
||||
found.append(res)
|
||||
logger.info(f" ✓ {res['server_address']}:{res['port']} ({res['format']})")
|
||||
|
||||
logger.info(f"Scan complete! Found {len(found)} camera(s).")
|
||||
return found
|
||||
|
||||
return found_cameras_info
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
def read(self, color_mode: ColorMode | None = None, format: str | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the ZMQ camera.
|
||||
|
||||
For the mujoco sim (127.0.0.1:5554), the server sends a msgpack-encoded
|
||||
dict with base64-encoded JPEG images:
|
||||
{"timestamps": {...}, "images": {camera_name: "<b64 jpeg>"}}
|
||||
For other ZMQ cameras, we assume a raw JPEG buffer.
|
||||
Supports three message formats:
|
||||
1. "msgpack": Msgpack with base64 JPEGs: {"timestamps": {...}, "images": {camera_name: "b64"}}
|
||||
(used by MuJoCo sim)
|
||||
2. "json": JSON with base64 JPEGs: {"state": 0.0, "camera_name": "b64jpeg"}
|
||||
(used by LeKiwi-style servers)
|
||||
3. "raw_jpeg": Raw JPEG bytes (used by Unitree G1 head camera)
|
||||
|
||||
Args:
|
||||
color_mode: Target color mode (RGB or BGR). If None, uses self.color_mode.
|
||||
format: Message format to use. If None, uses auto-detected format from connect().
|
||||
One of: "msgpack", "json", "raw_jpeg"
|
||||
|
||||
Returns:
|
||||
np.ndarray: Decoded frame in shape (height, width, 3)
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If camera is not connected
|
||||
TimeoutError: If no frame received within timeout_ms
|
||||
RuntimeError: If frame decoding fails
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
@@ -316,6 +382,13 @@ class ZMQCamera(Camera):
|
||||
if self.socket is None:
|
||||
raise DeviceNotConnectedError(f"{self} socket is not initialized")
|
||||
|
||||
# Use detected format if not specified
|
||||
if format is None:
|
||||
format = self._format_type
|
||||
|
||||
if format is None:
|
||||
raise RuntimeError(f"{self} format not specified and not auto-detected during connect()")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
@@ -327,48 +400,55 @@ class ZMQCamera(Camera):
|
||||
|
||||
frame = None
|
||||
|
||||
# special-case: mujoco sim publisher (msgpack + base64)
|
||||
if self.server_address in ("127.0.0.1", "localhost") and self.port == 5554:
|
||||
try:
|
||||
data = msgpack.unpackb(message, object_hook=m.decode)
|
||||
# Decode based on format
|
||||
if format == "msgpack":
|
||||
data = msgpack.unpackb(message, object_hook=m.decode)
|
||||
if not isinstance(data, dict) or "images" not in data:
|
||||
raise RuntimeError(f"{self} invalid msgpack format: expected dict with 'images' key")
|
||||
|
||||
if not isinstance(data, dict) or "images" not in data:
|
||||
raise RuntimeError(f"{self} received invalid message format from sim")
|
||||
images_dict = data["images"]
|
||||
|
||||
# Prefer named camera if present
|
||||
if self.camera_name in images_dict:
|
||||
img_data = images_dict[self.camera_name]
|
||||
elif len(images_dict) > 0:
|
||||
# Fallback: first available camera
|
||||
img_data = next(iter(images_dict.values()))
|
||||
else:
|
||||
raise RuntimeError(f"{self} no images found in msgpack message")
|
||||
|
||||
images_dict = data["images"]
|
||||
img_data = None
|
||||
|
||||
# prefer named camera if present
|
||||
if self.camera_name in images_dict:
|
||||
img_data = images_dict[self.camera_name]
|
||||
elif len(images_dict) > 0:
|
||||
# fallback: first available camera
|
||||
img_data = next(iter(images_dict.values()))
|
||||
else:
|
||||
raise RuntimeError(f"{self} no images found in sim message")
|
||||
|
||||
# same logic as ImageUtils.decode_image
|
||||
if isinstance(img_data, str):
|
||||
color_bytes = base64.b64decode(img_data)
|
||||
np_img = np.frombuffer(color_bytes, dtype=np.uint8)
|
||||
frame = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
|
||||
elif isinstance(img_data, np.ndarray):
|
||||
frame = img_data
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"{self} unknown image payload type from sim: {type(img_data)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"{self} failed to decode sim image: {e}")
|
||||
|
||||
# default path: raw jpeg over ZMQ (robot-side cameras etc.)
|
||||
else:
|
||||
# Decode the image data
|
||||
if isinstance(img_data, str):
|
||||
color_bytes = base64.b64decode(img_data)
|
||||
np_img = np.frombuffer(color_bytes, dtype=np.uint8)
|
||||
frame = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
|
||||
elif isinstance(img_data, np.ndarray):
|
||||
frame = img_data
|
||||
else:
|
||||
raise RuntimeError(f"{self} unknown image payload type: {type(img_data)}")
|
||||
|
||||
elif format == "json":
|
||||
data = json.loads(message.decode('utf-8'))
|
||||
if not isinstance(data, dict) or self.camera_name not in data:
|
||||
raise RuntimeError(f"{self} invalid JSON format: expected dict with '{self.camera_name}' key")
|
||||
|
||||
img_b64 = data[self.camera_name]
|
||||
if not isinstance(img_b64, str):
|
||||
raise RuntimeError(f"{self} expected base64 string in JSON, got {type(img_b64)}")
|
||||
|
||||
color_bytes = base64.b64decode(img_b64)
|
||||
np_img = np.frombuffer(color_bytes, dtype=np.uint8)
|
||||
frame = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
|
||||
|
||||
elif format == "raw_jpeg":
|
||||
np_img = np.frombuffer(message, dtype=np.uint8)
|
||||
frame = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
|
||||
|
||||
else:
|
||||
raise ValueError(f"{self} unsupported format: {format}. Use 'msgpack', 'json', or 'raw_jpeg'")
|
||||
|
||||
if frame is None or not isinstance(frame, np.ndarray):
|
||||
raise RuntimeError(f"{self} failed to decode image")
|
||||
raise RuntimeError(f"{self} failed to decode image using format '{format}'")
|
||||
|
||||
processed_frame = self._postprocess_image(frame, color_mode)
|
||||
|
||||
@@ -471,7 +551,7 @@ class ZMQCamera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 2000) -> NDArray[Any]:
|
||||
def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
|
||||
@@ -54,11 +54,6 @@ class ZMQCameraConfig(CameraConfig):
|
||||
camera_name: Identifier name for this camera (for logging/debugging).
|
||||
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
|
||||
timeout_ms: Timeout in milliseconds for receiving frames. Defaults to 1000ms.
|
||||
|
||||
Note:
|
||||
- The server must be streaming JPEG-encoded images over ZMQ PUB socket.
|
||||
- Width and height should match the expected output dimensions from the server.
|
||||
- FPS is informational and doesn't control the server's frame rate.
|
||||
"""
|
||||
|
||||
server_address: str
|
||||
@@ -81,4 +76,3 @@ class ZMQCameraConfig(CameraConfig):
|
||||
|
||||
if self.port <= 0 or self.port > 65535:
|
||||
raise ValueError(f"`port` must be between 1 and 65535, but {self.port} is provided.")
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -3,
|
||||
"range_max": 2.6
|
||||
"range_max": 1
|
||||
},
|
||||
"kLeftShoulderYaw.pos": {
|
||||
"id": 1,
|
||||
@@ -25,7 +25,7 @@
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -1,
|
||||
"range_max": 2
|
||||
"range_max": 1
|
||||
},
|
||||
"kLeftWristRoll.pos": {
|
||||
"id": 4,
|
||||
@@ -61,7 +61,7 @@
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -3.0,
|
||||
"range_max": 2.6
|
||||
"range_max": 1
|
||||
},
|
||||
"kRightShoulderYaw.pos": {
|
||||
"id": 1,
|
||||
@@ -82,7 +82,7 @@
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -1,
|
||||
"range_max": 2
|
||||
"range_max": 1
|
||||
},
|
||||
"kRightWristRoll.pos": {
|
||||
"id": 4,
|
||||
|
||||
@@ -48,12 +48,19 @@ class UnitreeG1Config(RobotConfig):
|
||||
audio_client: bool = True
|
||||
|
||||
freeze_body: bool = False
|
||||
gravity_compensation: bool = False
|
||||
gravity_compensation: bool = True
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# Socket communication configuration (REQUIRED)
|
||||
# This robot class ONLY uses sockets to communicate with a bridge on the Orin
|
||||
# Run 'python dds_to_socket.py' on the Orin first, then set this to the Orin's IP
|
||||
# Example: socket_host="192.168.123.164" (Orin's wlan0 IP)
|
||||
socket_host: str | None = None
|
||||
socket_port: int | None = None
|
||||
|
||||
# Locomotion control
|
||||
locomotion_control: bool = False
|
||||
locomotion_control: bool = True
|
||||
#policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/motion.pt"
|
||||
policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/GR00T-WholeBodyControl-Walk.onnx"
|
||||
|
||||
@@ -64,9 +71,7 @@ class UnitreeG1Config(RobotConfig):
|
||||
motion_file_path: str = "unitree_rl_lab/deploy/robots/g1_29dof/config/policy/mimic/dance_102/params/G1_Take_102.bvh_60hz.csv"
|
||||
motion_fps: float = 60.0
|
||||
motion_control_dt: float = 0.02
|
||||
|
||||
# Motion imitation parameters (from deploy.yaml)
|
||||
|
||||
|
||||
motion_joint_ids_map: list = field(default_factory=lambda: [0, 6, 12, 1, 7, 13, 2, 8, 14, 3, 9, 15, 22, 4, 10, 16, 23, 5, 11, 17, 24, 18, 25, 19, 26, 20, 27, 21, 28])
|
||||
motion_stiffness: list = field(default_factory=lambda: [40.2, 99.1, 40.2, 99.1, 28.5, 28.5, 40.2, 99.1, 40.2, 99.1, 28.5, 28.5, 40.2, 28.5, 28.5, 14.3, 14.3, 14.3, 14.3, 14.3, 16.8, 16.8, 14.3, 14.3, 14.3, 14.3, 14.3, 16.8, 16.8])
|
||||
motion_damping: list = field(default_factory=lambda: [2.56, 6.31, 2.56, 6.31, 1.81, 1.81, 2.56, 6.31, 2.56, 6.31, 1.81, 1.81, 2.56, 1.81, 1.81, 0.907, 0.907, 0.907, 0.907, 0.907, 1.07, 1.07, 0.907, 0.907, 0.907, 0.907, 0.907, 1.07, 1.07])
|
||||
@@ -98,6 +103,4 @@ class UnitreeG1Config(RobotConfig):
|
||||
num_locomotion_actions: int = 12
|
||||
num_locomotion_obs: int = 47
|
||||
max_cmd: list = field(default_factory=lambda: [0.8, 0.5, 1.57])
|
||||
locomotion_imu_type: str = "pelvis" # "torso" or "pelvis"
|
||||
|
||||
|
||||
locomotion_imu_type: str = "pelvis" # "torso" or "pelvis"
|
||||
347
src/lerobot/robots/unitree_g1/image_server.py
Normal file
347
src/lerobot/robots/unitree_g1/image_server.py
Normal file
@@ -0,0 +1,347 @@
|
||||
import cv2
|
||||
import zmq
|
||||
import time
|
||||
import struct
|
||||
from collections import deque
|
||||
import numpy as np
|
||||
import pyrealsense2 as rs
|
||||
import logging_mp
|
||||
|
||||
logger_mp = logging_mp.get_logger(__name__, level=logging_mp.DEBUG)
|
||||
|
||||
|
||||
class RealSenseCamera(object):
|
||||
def __init__(self, img_shape, fps, serial_number=None, enable_depth=False) -> None:
|
||||
"""
|
||||
img_shape: [height, width]
|
||||
serial_number: serial number
|
||||
"""
|
||||
self.img_shape = img_shape
|
||||
self.fps = fps
|
||||
self.serial_number = serial_number
|
||||
self.enable_depth = enable_depth
|
||||
|
||||
align_to = rs.stream.color
|
||||
self.align = rs.align(align_to)
|
||||
self.init_realsense()
|
||||
|
||||
def init_realsense(self):
|
||||
self.pipeline = rs.pipeline()
|
||||
config = rs.config()
|
||||
if self.serial_number is not None:
|
||||
config.enable_device(self.serial_number)
|
||||
|
||||
config.enable_stream(rs.stream.color, self.img_shape[1], self.img_shape[0], rs.format.bgr8, self.fps)
|
||||
|
||||
if self.enable_depth:
|
||||
config.enable_stream(rs.stream.depth, self.img_shape[1], self.img_shape[0], rs.format.z16, self.fps)
|
||||
|
||||
profile = self.pipeline.start(config)
|
||||
self._device = profile.get_device()
|
||||
if self._device is None:
|
||||
logger_mp.error("[Image Server] pipe_profile.get_device() is None .")
|
||||
if self.enable_depth:
|
||||
assert self._device is not None
|
||||
depth_sensor = self._device.first_depth_sensor()
|
||||
self.g_depth_scale = depth_sensor.get_depth_scale()
|
||||
|
||||
self.intrinsics = profile.get_stream(rs.stream.color).as_video_stream_profile().get_intrinsics()
|
||||
|
||||
def get_frame(self):
|
||||
frames = self.pipeline.wait_for_frames()
|
||||
aligned_frames = self.align.process(frames)
|
||||
color_frame = aligned_frames.get_color_frame()
|
||||
|
||||
if self.enable_depth:
|
||||
depth_frame = aligned_frames.get_depth_frame()
|
||||
|
||||
if not color_frame:
|
||||
return None
|
||||
|
||||
color_image = np.asanyarray(color_frame.get_data())
|
||||
# color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
||||
depth_image = np.asanyarray(depth_frame.get_data()) if self.enable_depth else None
|
||||
return color_image, depth_image
|
||||
|
||||
def release(self):
|
||||
self.pipeline.stop()
|
||||
|
||||
|
||||
class OpenCVCamera:
|
||||
def __init__(self, device_id, img_shape, fps):
|
||||
"""
|
||||
decive_id: /dev/video* or *
|
||||
img_shape: [height, width]
|
||||
"""
|
||||
self.id = device_id
|
||||
self.fps = fps
|
||||
self.img_shape = img_shape
|
||||
self.cap = cv2.VideoCapture(self.id, cv2.CAP_V4L2)
|
||||
self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter.fourcc("M", "J", "P", "G"))
|
||||
self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.img_shape[0])
|
||||
self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.img_shape[1])
|
||||
self.cap.set(cv2.CAP_PROP_FPS, self.fps)
|
||||
|
||||
# Test if the camera can read frames
|
||||
if not self._can_read_frame():
|
||||
logger_mp.error(
|
||||
f"[Image Server] Camera {self.id} Error: Failed to initialize the camera or read frames. Exiting..."
|
||||
)
|
||||
self.release()
|
||||
|
||||
def _can_read_frame(self):
|
||||
success, _ = self.cap.read()
|
||||
return success
|
||||
|
||||
def release(self):
|
||||
self.cap.release()
|
||||
|
||||
def get_frame(self):
|
||||
ret, color_image = self.cap.read()
|
||||
if not ret:
|
||||
return None
|
||||
return color_image
|
||||
|
||||
|
||||
class ImageServer:
|
||||
def __init__(self, config, port=5554, Unit_Test=False):
|
||||
"""
|
||||
config example1:
|
||||
{
|
||||
'fps':30 # frame per second
|
||||
'head_camera_type': 'opencv', # opencv or realsense
|
||||
'head_camera_image_shape': [480, 1280], # Head camera resolution [height, width]
|
||||
'head_camera_id_numbers': [0], # '/dev/video0' (opencv)
|
||||
'wrist_camera_type': 'realsense',
|
||||
'wrist_camera_image_shape': [480, 640], # Wrist camera resolution [height, width]
|
||||
'wrist_camera_id_numbers': ["218622271789", "241222076627"], # realsense camera's serial number
|
||||
}
|
||||
|
||||
config example2:
|
||||
{
|
||||
'fps':30 # frame per second
|
||||
'head_camera_type': 'realsense', # opencv or realsense
|
||||
'head_camera_image_shape': [480, 640], # Head camera resolution [height, width]
|
||||
'head_camera_id_numbers': ["218622271739"], # realsense camera's serial number
|
||||
'wrist_camera_type': 'opencv',
|
||||
'wrist_camera_image_shape': [480, 640], # Wrist camera resolution [height, width]
|
||||
'wrist_camera_id_numbers': [0,1], # '/dev/video0' and '/dev/video1' (opencv)
|
||||
}
|
||||
|
||||
If you are not using the wrist camera, you can comment out its configuration, like this below:
|
||||
config:
|
||||
{
|
||||
'fps':30 # frame per second
|
||||
'head_camera_type': 'opencv', # opencv or realsense
|
||||
'head_camera_image_shape': [480, 1280], # Head camera resolution [height, width]
|
||||
'head_camera_id_numbers': [0], # '/dev/video0' (opencv)
|
||||
#'wrist_camera_type': 'realsense',
|
||||
#'wrist_camera_image_shape': [480, 640], # Wrist camera resolution [height, width]
|
||||
#'wrist_camera_id_numbers': ["218622271789", "241222076627"], # serial number (realsense)
|
||||
}
|
||||
"""
|
||||
logger_mp.info(config)
|
||||
self.fps = config.get("fps", 30)
|
||||
self.head_camera_type = config.get("head_camera_type", "opencv")
|
||||
self.head_image_shape = config.get("head_camera_image_shape", [480, 640]) # (height, width)
|
||||
self.head_camera_id_numbers = config.get("head_camera_id_numbers", [0])
|
||||
|
||||
self.wrist_camera_type = config.get("wrist_camera_type", None)
|
||||
self.wrist_image_shape = config.get("wrist_camera_image_shape", [480, 640]) # (height, width)
|
||||
self.wrist_camera_id_numbers = config.get("wrist_camera_id_numbers", None)
|
||||
|
||||
self.port = port
|
||||
self.Unit_Test = Unit_Test
|
||||
|
||||
# Initialize head cameras
|
||||
self.head_cameras = []
|
||||
if self.head_camera_type == "opencv":
|
||||
for device_id in self.head_camera_id_numbers:
|
||||
camera = OpenCVCamera(device_id=device_id, img_shape=self.head_image_shape, fps=self.fps)
|
||||
self.head_cameras.append(camera)
|
||||
elif self.head_camera_type == "realsense":
|
||||
for serial_number in self.head_camera_id_numbers:
|
||||
camera = RealSenseCamera(img_shape=self.head_image_shape, fps=self.fps, serial_number=serial_number)
|
||||
self.head_cameras.append(camera)
|
||||
else:
|
||||
logger_mp.warning(f"[Image Server] Unsupported head_camera_type: {self.head_camera_type}")
|
||||
|
||||
# Initialize wrist cameras if provided
|
||||
self.wrist_cameras = []
|
||||
if self.wrist_camera_type and self.wrist_camera_id_numbers:
|
||||
if self.wrist_camera_type == "opencv":
|
||||
for device_id in self.wrist_camera_id_numbers:
|
||||
camera = OpenCVCamera(device_id=device_id, img_shape=self.wrist_image_shape, fps=self.fps)
|
||||
self.wrist_cameras.append(camera)
|
||||
elif self.wrist_camera_type == "realsense":
|
||||
for serial_number in self.wrist_camera_id_numbers:
|
||||
camera = RealSenseCamera(
|
||||
img_shape=self.wrist_image_shape, fps=self.fps, serial_number=serial_number
|
||||
)
|
||||
self.wrist_cameras.append(camera)
|
||||
else:
|
||||
logger_mp.warning(f"[Image Server] Unsupported wrist_camera_type: {self.wrist_camera_type}")
|
||||
|
||||
# Set ZeroMQ context and socket
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.PUB)
|
||||
self.socket.bind(f"tcp://*:{self.port}")
|
||||
|
||||
if self.Unit_Test:
|
||||
self._init_performance_metrics()
|
||||
|
||||
for cam in self.head_cameras:
|
||||
if isinstance(cam, OpenCVCamera):
|
||||
logger_mp.info(
|
||||
f"[Image Server] Head camera {cam.id} resolution: {cam.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)} x {cam.cap.get(cv2.CAP_PROP_FRAME_WIDTH)}"
|
||||
)
|
||||
elif isinstance(cam, RealSenseCamera):
|
||||
logger_mp.info(
|
||||
f"[Image Server] Head camera {cam.serial_number} resolution: {cam.img_shape[0]} x {cam.img_shape[1]}"
|
||||
)
|
||||
else:
|
||||
logger_mp.warning("[Image Server] Unknown camera type in head_cameras.")
|
||||
|
||||
for cam in self.wrist_cameras:
|
||||
if isinstance(cam, OpenCVCamera):
|
||||
logger_mp.info(
|
||||
f"[Image Server] Wrist camera {cam.id} resolution: {cam.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)} x {cam.cap.get(cv2.CAP_PROP_FRAME_WIDTH)}"
|
||||
)
|
||||
elif isinstance(cam, RealSenseCamera):
|
||||
logger_mp.info(
|
||||
f"[Image Server] Wrist camera {cam.serial_number} resolution: {cam.img_shape[0]} x {cam.img_shape[1]}"
|
||||
)
|
||||
else:
|
||||
logger_mp.warning("[Image Server] Unknown camera type in wrist_cameras.")
|
||||
|
||||
logger_mp.info("[Image Server] Image server has started, waiting for client connections...")
|
||||
|
||||
def _init_performance_metrics(self):
|
||||
self.frame_count = 0 # Total frames sent
|
||||
self.time_window = 1.0 # Time window for FPS calculation (in seconds)
|
||||
self.frame_times = deque() # Timestamps of frames sent within the time window
|
||||
self.start_time = time.time() # Start time of the streaming
|
||||
|
||||
def _update_performance_metrics(self, current_time):
|
||||
# Add current time to frame times deque
|
||||
self.frame_times.append(current_time)
|
||||
# Remove timestamps outside the time window
|
||||
while self.frame_times and self.frame_times[0] < current_time - self.time_window:
|
||||
self.frame_times.popleft()
|
||||
# Increment frame count
|
||||
self.frame_count += 1
|
||||
|
||||
def _print_performance_metrics(self, current_time):
|
||||
if self.frame_count % 30 == 0:
|
||||
elapsed_time = current_time - self.start_time
|
||||
real_time_fps = len(self.frame_times) / self.time_window
|
||||
logger_mp.info(
|
||||
f"[Image Server] Real-time FPS: {real_time_fps:.2f}, Total frames sent: {self.frame_count}, Elapsed time: {elapsed_time:.2f} sec"
|
||||
)
|
||||
|
||||
def _close(self):
|
||||
for cam in self.head_cameras:
|
||||
cam.release()
|
||||
for cam in self.wrist_cameras:
|
||||
cam.release()
|
||||
self.socket.close()
|
||||
self.context.term()
|
||||
logger_mp.info("[Image Server] The server has been closed.")
|
||||
|
||||
def send_process(self):
|
||||
try:
|
||||
while True:
|
||||
head_frames = []
|
||||
for cam in self.head_cameras:
|
||||
if self.head_camera_type == "opencv":
|
||||
color_image = cam.get_frame()
|
||||
if color_image is None:
|
||||
logger_mp.error("[Image Server] Head camera frame read is error.")
|
||||
break
|
||||
elif self.head_camera_type == "realsense":
|
||||
color_image, depth_iamge = cam.get_frame()
|
||||
if color_image is None:
|
||||
logger_mp.error("[Image Server] Head camera frame read is error.")
|
||||
break
|
||||
head_frames.append(color_image)
|
||||
if len(head_frames) != len(self.head_cameras):
|
||||
break
|
||||
head_color = cv2.hconcat(head_frames)
|
||||
|
||||
if self.wrist_cameras:
|
||||
wrist_frames = []
|
||||
for cam in self.wrist_cameras:
|
||||
if self.wrist_camera_type == "opencv":
|
||||
color_image = cam.get_frame()
|
||||
if color_image is None:
|
||||
logger_mp.error("[Image Server] Wrist camera frame read is error.")
|
||||
break
|
||||
elif self.wrist_camera_type == "realsense":
|
||||
color_image, depth_iamge = cam.get_frame()
|
||||
if color_image is None:
|
||||
logger_mp.error("[Image Server] Wrist camera frame read is error.")
|
||||
break
|
||||
wrist_frames.append(color_image)
|
||||
wrist_color = cv2.hconcat(wrist_frames)
|
||||
|
||||
# Concatenate head and wrist frames
|
||||
full_color = cv2.hconcat([head_color, wrist_color])
|
||||
else:
|
||||
full_color = head_color
|
||||
|
||||
ret, buffer = cv2.imencode(".jpg", full_color)
|
||||
if not ret:
|
||||
logger_mp.error("[Image Server] Frame imencode is failed.")
|
||||
continue
|
||||
|
||||
jpg_bytes = buffer.tobytes()
|
||||
|
||||
if self.Unit_Test:
|
||||
timestamp = time.time()
|
||||
frame_id = self.frame_count
|
||||
header = struct.pack("dI", timestamp, frame_id) # 8-byte double, 4-byte unsigned int
|
||||
message = header + jpg_bytes
|
||||
else:
|
||||
message = jpg_bytes
|
||||
|
||||
self.socket.send(message)
|
||||
|
||||
if self.Unit_Test:
|
||||
current_time = time.time()
|
||||
self._update_performance_metrics(current_time)
|
||||
self._print_performance_metrics(current_time)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger_mp.warning("[Image Server] Interrupted by user.")
|
||||
finally:
|
||||
self._close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# config = {
|
||||
# "fps": 30,
|
||||
# "head_camera_type": "opencv",
|
||||
# "head_camera_image_shape": [480, 1280], # Head camera resolution
|
||||
# "head_camera_id_numbers": [0],
|
||||
# "wrist_camera_type": "opencv",
|
||||
# "wrist_camera_image_shape": [480, 640], # Wrist camera resolution
|
||||
# "wrist_camera_id_numbers": [2, 4],
|
||||
#
|
||||
#infrared
|
||||
# config = {
|
||||
# "fps": 30,
|
||||
# "head_camera_type": "opencv",
|
||||
# "head_camera_image_shape": [480, 640],
|
||||
# "head_camera_id_numbers": [2], # <-- wrist cam that reported 480x640
|
||||
# # no wrist_* keys
|
||||
# }
|
||||
#rgb
|
||||
config = {
|
||||
"fps": 30,
|
||||
"head_camera_type": "opencv",
|
||||
"head_camera_image_shape": [480,640], # match the device
|
||||
"head_camera_id_numbers": [4], # /dev/video4 is RGB
|
||||
}
|
||||
|
||||
server = ImageServer(config, Unit_Test=False)
|
||||
server.send_process()
|
||||
134
src/lerobot/robots/unitree_g1/robot_server.py
Normal file
134
src/lerobot/robots/unitree_g1/robot_server.py
Normal file
@@ -0,0 +1,134 @@
|
||||
#!/usr/bin/env python3
|
||||
import time
|
||||
import pickle
|
||||
import threading
|
||||
|
||||
import zmq
|
||||
|
||||
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import MotionSwitcherClient
|
||||
|
||||
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||
kTopicLowCommand_Motion = "rt/arm_sdk"
|
||||
kTopicLowState = "rt/lowstate"
|
||||
|
||||
LOWCMD_PORT = 6000 # laptop -> robot
|
||||
LOWSTATE_PORT = 6001 # robot -> laptop
|
||||
|
||||
|
||||
def state_forward_loop(lowstate_sub, lowstate_sock, state_period: float):
|
||||
"""
|
||||
read lowstate from dds and push to laptop at ~state_period.
|
||||
runs in its own thread.
|
||||
"""
|
||||
last_state_time = 0.0
|
||||
|
||||
while True:
|
||||
# read from dds (blocking)
|
||||
msg = lowstate_sub.Read()
|
||||
if msg is None:
|
||||
continue
|
||||
|
||||
now = time.time()
|
||||
# optional downsampling (if robot dds rate > state_period)
|
||||
if now - last_state_time >= state_period:
|
||||
payload = pickle.dumps((kTopicLowState, msg), protocol=pickle.HIGHEST_PROTOCOL)
|
||||
try:
|
||||
lowstate_sock.send(payload, zmq.NOBLOCK)
|
||||
except zmq.Again:
|
||||
# if no subscribers / tx buffer full, just drop
|
||||
pass
|
||||
last_state_time = now
|
||||
|
||||
|
||||
def cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, lowcmd_pub_motion, crc: CRC):
|
||||
"""
|
||||
read lowcmd from laptop (zmq) and push to dds.
|
||||
runs in its own thread.
|
||||
"""
|
||||
while True:
|
||||
# blocking wait for commands from laptop
|
||||
payload = lowcmd_sock.recv()
|
||||
topic, cmd = pickle.loads(payload) # cmd is hg_LowCmd
|
||||
|
||||
# recompute crc just in case
|
||||
cmd.crc = crc.Crc(cmd)
|
||||
|
||||
if topic == kTopicLowCommand_Debug:
|
||||
lowcmd_pub_debug.Write(cmd)
|
||||
elif topic == kTopicLowCommand_Motion:
|
||||
lowcmd_pub_motion.Write(cmd)
|
||||
else:
|
||||
# ignore unknown topics
|
||||
pass
|
||||
|
||||
|
||||
def main():
|
||||
# dds init
|
||||
ChannelFactoryInitialize(0)
|
||||
|
||||
# acquire motion mode on the robot
|
||||
msc = MotionSwitcherClient()
|
||||
msc.SetTimeout(5.0)
|
||||
msc.Init()
|
||||
|
||||
status, result = msc.CheckMode()
|
||||
while result is not None and "name" in result and result["name"]:
|
||||
msc.ReleaseMode()
|
||||
status, result = msc.CheckMode()
|
||||
time.sleep(1.0)
|
||||
|
||||
crc = CRC()
|
||||
|
||||
# dds publishers / subscriber
|
||||
lowcmd_pub_debug = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||
lowcmd_pub_motion = ChannelPublisher(kTopicLowCommand_Motion, hg_LowCmd)
|
||||
lowcmd_pub_debug.Init()
|
||||
lowcmd_pub_motion.Init()
|
||||
|
||||
lowstate_sub = ChannelSubscriber(kTopicLowState, hg_LowState)
|
||||
lowstate_sub.Init()
|
||||
|
||||
# zmq setup
|
||||
ctx = zmq.Context.instance()
|
||||
|
||||
# commands from laptop
|
||||
lowcmd_sock = ctx.socket(zmq.PULL)
|
||||
lowcmd_sock.bind(f"tcp://0.0.0.0:{LOWCMD_PORT}")
|
||||
|
||||
# state to laptop
|
||||
lowstate_sock = ctx.socket(zmq.PUB)
|
||||
lowstate_sock.bind(f"tcp://0.0.0.0:{LOWSTATE_PORT}")
|
||||
|
||||
state_period = 0.002 # ~500 hz
|
||||
|
||||
# start threads
|
||||
t_state = threading.Thread(
|
||||
target=state_forward_loop,
|
||||
args=(lowstate_sub, lowstate_sock, state_period),
|
||||
daemon=True,
|
||||
)
|
||||
t_cmd = threading.Thread(
|
||||
target=cmd_forward_loop,
|
||||
args=(lowcmd_sock, lowcmd_pub_debug, lowcmd_pub_motion, crc),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
t_state.start()
|
||||
t_cmd.start()
|
||||
|
||||
print("bridge running (lowstate -> zmq, lowcmd -> dds)")
|
||||
|
||||
# keep main thread alive so daemon threads don’t exit
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1.0)
|
||||
except KeyboardInterrupt:
|
||||
print("shutting down bridge...")
|
||||
# sockets/context will be cleaned up on process exit
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -27,31 +27,16 @@ import termios
|
||||
import tty
|
||||
from collections import deque
|
||||
|
||||
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState # idl for g1, h1_2
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
from unitree_sdk2py.g1.audio.g1_audio_client import AudioClient
|
||||
|
||||
from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowCmd_ as go_LowCmd, LowState_ as go_LowState # idl for h1
|
||||
from unitree_sdk2py.idl.default import unitree_go_msg_dds__LowCmd_
|
||||
|
||||
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
import onnxruntime as ort
|
||||
|
||||
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelFactoryInitialize
|
||||
from unitree_sdk2py.core.channel import ChannelSubscriber, ChannelFactoryInitialize
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_, unitree_hg_msg_dds__LowState_
|
||||
from unitree_sdk2py.idl.default import unitree_go_msg_dds__LowCmd_, unitree_go_msg_dds__LowState_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as LowCmdHG
|
||||
from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowCmd_ as LowCmdGo
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowState_ as LowStateHG
|
||||
from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowState_ as LowStateGo
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState # idl for g1, h1_2
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
from unitree_sdk2py.g1.audio.g1_audio_client import AudioClient
|
||||
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import (
|
||||
MotionSwitcherClient,
|
||||
)
|
||||
@@ -65,13 +50,11 @@ import yaml
|
||||
|
||||
from typing import Union
|
||||
|
||||
import logging_mp
|
||||
|
||||
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
|
||||
|
||||
import torch
|
||||
|
||||
logger_mp = logging_mp.get_logger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||
kTopicLowCommand_Motion = "rt/arm_sdk"
|
||||
@@ -121,7 +104,6 @@ class DataBuffer:
|
||||
|
||||
#eventually observations should be everything: motor torques etc etc
|
||||
#motor class for unitree?
|
||||
#TODO: camera, sim
|
||||
class UnitreeG1(Robot):
|
||||
|
||||
config_class = UnitreeG1Config
|
||||
@@ -130,7 +112,7 @@ class UnitreeG1(Robot):
|
||||
def __init__(self, config: UnitreeG1Config):
|
||||
super().__init__(config)
|
||||
|
||||
logger_mp.info("Initialize UnitreeG1...")
|
||||
logger.info("Initialize UnitreeG1...")
|
||||
|
||||
self.config = config
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
@@ -153,6 +135,11 @@ class UnitreeG1(Robot):
|
||||
self._gradual_start_time = config.gradual_start_time
|
||||
self._gradual_time = config.gradual_time
|
||||
|
||||
# Teleop warmup: gradually move from current position to targets over 2 seconds
|
||||
self.teleop_warmup_duration = 2.0 # seconds
|
||||
self.teleop_warmup_start_time = None
|
||||
self.teleop_warmup_initial_q = None
|
||||
|
||||
self.freeze_body = config.freeze_body
|
||||
self.gravity_compensation = config.gravity_compensation
|
||||
|
||||
@@ -163,33 +150,39 @@ class UnitreeG1(Robot):
|
||||
|
||||
self.arm_ik = G1_29_ArmIK()
|
||||
|
||||
if self.config.socket_host is not None:
|
||||
from lerobot.robots.unitree_g1.unitree_sdk2_socket import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
|
||||
else:
|
||||
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
|
||||
|
||||
# initialize lowcmd publisher and lowstate subscriber
|
||||
# initialize lowcmd nd lowstate subscriber
|
||||
if self.simulation_mode:
|
||||
ChannelFactoryInitialize(0, "lo")
|
||||
|
||||
# Launch MuJoCo simulation environment
|
||||
logger_mp.info("Launching MuJoCo simulation environment...")
|
||||
logger.info("Launching MuJoCo simulation environment...")
|
||||
self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
|
||||
logger_mp.info("MuJoCo environment launched successfully!")
|
||||
logger.info("MuJoCo environment launched successfully!")
|
||||
else:
|
||||
ChannelFactoryInitialize(0)
|
||||
|
||||
|
||||
if not self.config.simulation_mode:
|
||||
self.msc = MotionSwitcherClient()
|
||||
self.msc.SetTimeout(5.0)
|
||||
self.msc.Init()
|
||||
|
||||
status, result = self.msc.CheckMode()
|
||||
print(status, result)
|
||||
#check if result name first
|
||||
if result is not None and "name" in result:
|
||||
while result["name"]:
|
||||
self.msc.ReleaseMode()
|
||||
status, result = self.msc.CheckMode()
|
||||
print(status, result)
|
||||
time.sleep(1)
|
||||
if not self.config.simulation_mode:
|
||||
pass
|
||||
# self.msc = MotionSwitcherClient()
|
||||
# self.msc.SetTimeout(5.0)
|
||||
# self.msc.Init()
|
||||
|
||||
# status, result = self.msc.CheckMode()
|
||||
# print(status, result)
|
||||
# #check if result name first
|
||||
# if result is not None and "name" in result:
|
||||
# while result["name"]:
|
||||
# self.msc.ReleaseMode()
|
||||
# status, result = self.msc.CheckMode()
|
||||
# print(status, result)
|
||||
# time.sleep(1)
|
||||
|
||||
if self.motion_mode:
|
||||
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Motion, hg_LowCmd)
|
||||
@@ -207,8 +200,8 @@ class UnitreeG1(Robot):
|
||||
|
||||
while not self.lowstate_buffer.GetData():
|
||||
time.sleep(0.1)
|
||||
logger_mp.warning("[UnitreeG1] Waiting to subscribe dds...")
|
||||
logger_mp.info("[UnitreeG1] Subscribe dds ok.")
|
||||
logger.warning("[UnitreeG1] Waiting to subscribe dds...")
|
||||
logger.info("[UnitreeG1] Subscribe dds ok.")
|
||||
|
||||
# initialize audio client for LED, TTS, and audio playback
|
||||
|
||||
@@ -221,9 +214,9 @@ class UnitreeG1(Robot):
|
||||
print(self.msg)
|
||||
|
||||
self.all_motor_q = self.get_current_motor_q()
|
||||
logger_mp.info(f"Current all body motor state q:\n{self.all_motor_q} \n")
|
||||
logger_mp.info(f"Current two arms motor state q:\n{self.get_current_dual_arm_q()}\n")
|
||||
logger_mp.info("Lock all joints except two arms...\n")
|
||||
logger.info(f"Current all body motor state q:\n{self.all_motor_q} \n")
|
||||
logger.info(f"Current two arms motor state q:\n{self.get_current_dual_arm_q()}\n")
|
||||
logger.info("Lock all joints except two arms...\n")
|
||||
|
||||
arm_indices = set(member.value for member in G1_29_JointArmIndex)
|
||||
for id in G1_29_JointIndex:
|
||||
@@ -248,12 +241,13 @@ class UnitreeG1(Robot):
|
||||
|
||||
|
||||
if config.audio_client:
|
||||
self.audio_client = AudioClient()
|
||||
self.audio_client.SetTimeout(10.0)
|
||||
self.audio_client.Init()
|
||||
logger_mp.info("[UnitreeG1] Audio client initialized!")
|
||||
pass
|
||||
# self.audio_client = AudioClient()
|
||||
# self.audio_client.SetTimeout(10.0)
|
||||
# self.audio_client.Init()
|
||||
# logger.info("[UnitreeG1] Audio client initialized!")
|
||||
|
||||
logger_mp.info("Lock OK!\n") #motors are not locked x
|
||||
logger.info("Lock OK!\n") #motors are not locked x
|
||||
# for i in range(10000):
|
||||
# print(self.get_current_motor_q())
|
||||
# time.sleep(0.05)
|
||||
@@ -266,15 +260,18 @@ class UnitreeG1(Robot):
|
||||
self.motion_imitation_thread = None
|
||||
self.motion_imitation_running = False
|
||||
|
||||
# Initialize publish thread ONLY if not using motion imitation or locomotion
|
||||
# (those modes handle their own motor commands and publishing)
|
||||
# Initialize publish thread for arm control
|
||||
# Note: This thread runs alongside locomotion/motion_imitation threads
|
||||
# - Arm thread: controls arms (indices 15-28)
|
||||
# - Locomotion thread: controls legs (0-11), waist (12-14)
|
||||
# Both update different parts of self.msg, both call Write()
|
||||
self.publish_thread = None
|
||||
self.ctrl_lock = threading.Lock()
|
||||
if not config.motion_imitation_control and not config.locomotion_control:
|
||||
if not config.motion_imitation_control: # Allow with locomotion, disable only for motion imitation
|
||||
self.publish_thread = threading.Thread(target=self._ctrl_motor_state)
|
||||
self.publish_thread.daemon = True
|
||||
self.publish_thread.start()
|
||||
logger_mp.info("Arm control publish thread started")
|
||||
logger.info("Arm control publish thread started")
|
||||
|
||||
# Load locomotion policy if enabled
|
||||
self.policy = None
|
||||
@@ -286,21 +283,21 @@ class UnitreeG1(Robot):
|
||||
if config.motion_file_path is None:
|
||||
raise ValueError("motion_imitation_control is True but motion_file_path is not set")
|
||||
|
||||
logger_mp.info(f"Loading motion reference from {config.motion_file_path}")
|
||||
logger.info(f"Loading motion reference from {config.motion_file_path}")
|
||||
|
||||
# Load motion file
|
||||
self.motion_loader = self.MotionLoader(config.motion_file_path, config.motion_fps)
|
||||
|
||||
# Load ONNX policy (optional for now - can run in direct playback mode)
|
||||
if config.motion_policy_path and Path(config.motion_policy_path).exists():
|
||||
logger_mp.info(f"Loading motion imitation policy from {config.motion_policy_path}")
|
||||
logger.info(f"Loading motion imitation policy from {config.motion_policy_path}")
|
||||
self.policy = ort.InferenceSession(config.motion_policy_path)
|
||||
self.policy_type = 'motion_imitation'
|
||||
logger_mp.info("Motion imitation ONNX policy loaded successfully")
|
||||
logger_mp.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
|
||||
logger_mp.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
|
||||
logger.info("Motion imitation ONNX policy loaded successfully")
|
||||
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
|
||||
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
|
||||
else:
|
||||
logger_mp.info("Running in DIRECT PLAYBACK mode (no policy - just reference motion)")
|
||||
logger.info("Running in DIRECT PLAYBACK mode (no policy - just reference motion)")
|
||||
self.policy = None
|
||||
self.policy_type = 'motion_playback'
|
||||
|
||||
@@ -319,21 +316,21 @@ class UnitreeG1(Robot):
|
||||
if config.policy_path is None:
|
||||
raise ValueError("locomotion_control is True but policy_path is not set")
|
||||
|
||||
logger_mp.info(f"Loading locomotion policy from {config.policy_path}")
|
||||
logger.info(f"Loading locomotion policy from {config.policy_path}")
|
||||
|
||||
# Check file extension and load accordingly
|
||||
if config.policy_path.endswith('.pt'):
|
||||
logger_mp.info("Detected TorchScript (.pt) policy")
|
||||
logger.info("Detected TorchScript (.pt) policy")
|
||||
self.policy = torch.jit.load(config.policy_path)
|
||||
self.policy_type = 'torchscript'
|
||||
logger_mp.info("TorchScript policy loaded successfully")
|
||||
logger.info("TorchScript policy loaded successfully")
|
||||
elif config.policy_path.endswith('.onnx'):
|
||||
logger_mp.info("Detected ONNX (.onnx) policy")
|
||||
logger.info("Detected ONNX (.onnx) policy")
|
||||
self.policy = ort.InferenceSession(config.policy_path)
|
||||
self.policy_type = 'onnx'
|
||||
logger_mp.info("ONNX policy loaded successfully")
|
||||
logger_mp.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
|
||||
logger_mp.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
|
||||
logger.info("ONNX policy loaded successfully")
|
||||
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
|
||||
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported policy format: {config.policy_path}. Only .pt (TorchScript) and .onnx (ONNX) are supported.")
|
||||
|
||||
@@ -363,7 +360,7 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Start keyboard controls if in simulation mode
|
||||
if self.simulation_mode:
|
||||
logger_mp.info("Starting keyboard controls for simulation...")
|
||||
logger.info("Starting keyboard controls for simulation...")
|
||||
self.start_keyboard_controls()
|
||||
|
||||
# Use different init based on policy type
|
||||
@@ -373,10 +370,10 @@ class UnitreeG1(Robot):
|
||||
self.init_locomotion()
|
||||
elif self.simulation_mode:
|
||||
# Even without locomotion, provide keyboard feedback in sim
|
||||
logger_mp.info("Simulation mode active (locomotion disabled)")
|
||||
logger.info("Simulation mode active (locomotion disabled)")
|
||||
|
||||
|
||||
logger_mp.info("Initialize G1 OK!\n")
|
||||
logger.info("Initialize G1 OK!\n")
|
||||
|
||||
def _subscribe_motor_state(self):
|
||||
while True:
|
||||
@@ -463,8 +460,8 @@ class UnitreeG1(Robot):
|
||||
all_t_elapsed = current_time - start_time
|
||||
sleep_time = max(0, (self.control_dt - all_t_elapsed))
|
||||
time.sleep(sleep_time)
|
||||
# logger_mp.debug(f"arm_velocity_limit:{self.arm_velocity_limit}")
|
||||
# logger_mp.debug(f"sleep_time:{sleep_time}")
|
||||
# logger.debug(f"arm_velocity_limit:{self.arm_velocity_limit}")
|
||||
# logger.debug(f"sleep_time:{sleep_time}")
|
||||
|
||||
def ctrl_dual_arm(self, q_target, tauff_target):
|
||||
"""Set control target values q & tau of the left and right arm motors."""
|
||||
@@ -490,7 +487,7 @@ class UnitreeG1(Robot):
|
||||
|
||||
def ctrl_dual_arm_go_home(self):
|
||||
"""Move both the left and right arms of the robot to their home position by setting the target joint angles (q) and torques (tau) to zero."""
|
||||
logger_mp.info("[G1_29_ArmController] ctrl_dual_arm_go_home start...")
|
||||
logger.info("[G1_29_ArmController] ctrl_dual_arm_go_home start...")
|
||||
max_attempts = 100
|
||||
current_attempts = 0
|
||||
with self.ctrl_lock:
|
||||
@@ -505,7 +502,7 @@ class UnitreeG1(Robot):
|
||||
for weight in np.linspace(1, 0, num=101):
|
||||
self.msg.motor_cmd[G1_29_JointIndex.kNotUsedJoint0].q = weight
|
||||
time.sleep(0.02)
|
||||
logger_mp.info("[G1_29_ArmController] both arms have reached the home position.")
|
||||
logger.info("[G1_29_ArmController] both arms have reached the home position.")
|
||||
break
|
||||
current_attempts += 1
|
||||
time.sleep(0.05)
|
||||
@@ -563,7 +560,7 @@ class UnitreeG1(Robot):
|
||||
# Connect cameras
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
logger_mp.info(f"{self} connected with {len(self.cameras)} camera(s).")
|
||||
logger.info(f"{self} connected with {len(self.cameras)} camera(s).")
|
||||
|
||||
def disconnect(self):
|
||||
# Disconnect cameras
|
||||
@@ -572,11 +569,11 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Close MuJoCo environment if in simulation mode
|
||||
if self.simulation_mode and hasattr(self, 'mujoco_env'):
|
||||
logger_mp.info("Closing MuJoCo environment...")
|
||||
logger.info("Closing MuJoCo environment...")
|
||||
print(self.mujoco_env)
|
||||
self.mujoco_env["hub_env"][0].envs[0].kill_sim()
|
||||
|
||||
logger_mp.info(f"{self} disconnected.")
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
def get_full_robot_state(self) -> dict[str, Any]:
|
||||
"""
|
||||
@@ -640,18 +637,18 @@ class UnitreeG1(Robot):
|
||||
if isinstance(command, tuple) and len(command) == 3:
|
||||
# LED control - RGB tuple
|
||||
r, g, b = command
|
||||
logger_mp.info(f"Setting LED to RGB({r}, {g}, {b})")
|
||||
logger.info(f"Setting LED to RGB({r}, {g}, {b})")
|
||||
self.audio_client.LedControl(r, g, b)
|
||||
|
||||
elif isinstance(command, str):
|
||||
# Check if it's a file path
|
||||
if Path(command).exists():
|
||||
# Play WAV file
|
||||
logger_mp.info(f"Playing audio file: {command}")
|
||||
logger.info(f"Playing audio file: {command}")
|
||||
self._play_wav_file(command)
|
||||
else:
|
||||
# Text-to-speech
|
||||
logger_mp.info(f"Speaking: {command}")
|
||||
logger.info(f"Speaking: {command}")
|
||||
self.audio_client.TtsMaker(command, 0) # 0 for English
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -746,7 +743,7 @@ class UnitreeG1(Robot):
|
||||
chunk_index = 0
|
||||
total_size = len(pcm_data)
|
||||
|
||||
logger_mp.info(f"Playing audio: {total_size} bytes in {(total_size // chunk_size) + 1} chunks")
|
||||
logger.info(f"Playing audio: {total_size} bytes in {(total_size // chunk_size) + 1} chunks")
|
||||
|
||||
# Send audio in chunks
|
||||
while offset < total_size:
|
||||
@@ -757,10 +754,10 @@ class UnitreeG1(Robot):
|
||||
# Send chunk
|
||||
ret_code, _ = self.audio_client.PlayStream(app_name, stream_id, list(chunk))
|
||||
if ret_code != 0:
|
||||
logger_mp.error(f"Failed to send chunk {chunk_index}, return code: {ret_code}")
|
||||
logger.error(f"Failed to send chunk {chunk_index}, return code: {ret_code}")
|
||||
break
|
||||
else:
|
||||
logger_mp.debug(f"Sent chunk {chunk_index}/{(total_size // chunk_size)}")
|
||||
logger.debug(f"Sent chunk {chunk_index}/{(total_size // chunk_size)}")
|
||||
|
||||
offset += current_chunk_size
|
||||
chunk_index += 1
|
||||
@@ -768,7 +765,7 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Calculate playback duration
|
||||
duration_seconds = len(pcm_data) / (16000 * 2) # 16kHz, 16-bit (2 bytes)
|
||||
logger_mp.info(f"Audio playback will take ~{duration_seconds:.1f} seconds")
|
||||
logger.info(f"Audio playback will take ~{duration_seconds:.1f} seconds")
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
obs_array = self.get_current_dual_arm_q()
|
||||
@@ -779,7 +776,7 @@ class UnitreeG1(Robot):
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger_mp.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@@ -909,13 +906,13 @@ class UnitreeG1(Robot):
|
||||
|
||||
def locomotion_zero_torque_state(self):
|
||||
"""Enter zero torque state."""
|
||||
logger_mp.info("Enter zero torque state.")
|
||||
logger.info("Enter zero torque state.")
|
||||
self.locomotion_create_zero_cmd()
|
||||
time.sleep(self.config.locomotion_control_dt)
|
||||
|
||||
def locomotion_move_to_default_pos(self):
|
||||
"""Move robot legs to default standing position over 2 seconds (arms are not moved)."""
|
||||
logger_mp.info("Moving legs to default locomotion pos.")
|
||||
logger.info("Moving legs to default locomotion pos.")
|
||||
total_time = 2.0
|
||||
num_step = int(total_time / self.config.locomotion_control_dt)
|
||||
|
||||
@@ -929,7 +926,7 @@ class UnitreeG1(Robot):
|
||||
# Get current lowstate
|
||||
lowstate = self.lowstate_buffer.GetData()
|
||||
if lowstate is None:
|
||||
logger_mp.error("Cannot get lowstate for locomotion")
|
||||
logger.error("Cannot get lowstate for locomotion")
|
||||
return
|
||||
|
||||
# Record the current leg positions
|
||||
@@ -951,11 +948,11 @@ class UnitreeG1(Robot):
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
time.sleep(self.config.locomotion_control_dt)
|
||||
logger_mp.info("Reached default locomotion position (legs only)")
|
||||
logger.info("Reached default locomotion position (legs only)")
|
||||
|
||||
def locomotion_default_pos_state(self):
|
||||
"""Hold default leg position for 2 seconds (arms are not controlled)."""
|
||||
logger_mp.info("Enter default pos state - holding legs for 2 seconds")
|
||||
logger.info("Enter default pos state - holding legs for 2 seconds")
|
||||
|
||||
# Only control legs, not arms
|
||||
for i in range(len(self.config.leg_joint2motor_idx)):
|
||||
@@ -973,7 +970,7 @@ class UnitreeG1(Robot):
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
time.sleep(self.config.locomotion_control_dt)
|
||||
logger_mp.info("Finished holding default leg position")
|
||||
logger.info("Finished holding default leg position")
|
||||
|
||||
|
||||
class RemoteController:
|
||||
@@ -1022,7 +1019,7 @@ class UnitreeG1(Robot):
|
||||
self.index_1 = 0
|
||||
self.blend = 0.0
|
||||
|
||||
logger_mp.info(f"MotionLoader: Loaded {self.num_frames} frames, duration={self.duration:.2f}s")
|
||||
logger.info(f"MotionLoader: Loaded {self.num_frames} frames, duration={self.duration:.2f}s")
|
||||
|
||||
def update(self, time: float):
|
||||
"""Update motion to specific time (loops at duration)."""
|
||||
@@ -1146,7 +1143,7 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Debug: print remote controller values every 50 iterations (~1 second at 50Hz)
|
||||
if self.locomotion_counter % 50 == 0:
|
||||
logger_mp.debug(f"Remote controller - lx:{self.remote_controller.lx:.2f}, ly:{self.remote_controller.ly:.2f}, rx:{self.remote_controller.rx:.2f}")
|
||||
logger.debug(f"Remote controller - lx:{self.remote_controller.lx:.2f}, ly:{self.remote_controller.ly:.2f}, rx:{self.remote_controller.rx:.2f}")
|
||||
|
||||
# Build observation vector
|
||||
num_actions = self.config.num_locomotion_actions
|
||||
@@ -1343,7 +1340,7 @@ class UnitreeG1(Robot):
|
||||
|
||||
def _locomotion_thread_loop(self):
|
||||
"""Background thread that runs the locomotion policy at specified rate."""
|
||||
logger_mp.info("Locomotion thread started")
|
||||
logger.info("Locomotion thread started")
|
||||
while self.locomotion_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
@@ -1353,40 +1350,40 @@ class UnitreeG1(Robot):
|
||||
else:
|
||||
self.locomotion_run()
|
||||
except Exception as e:
|
||||
logger_mp.error(f"Error in locomotion loop: {e}")
|
||||
logger.error(f"Error in locomotion loop: {e}")
|
||||
|
||||
# Sleep to maintain control rate
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, self.config.locomotion_control_dt - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger_mp.info("Locomotion thread stopped")
|
||||
logger.info("Locomotion thread stopped")
|
||||
|
||||
def start_locomotion_thread(self):
|
||||
"""Start the background locomotion control thread."""
|
||||
if not self.config.locomotion_control:
|
||||
logger_mp.warning("locomotion_control is False, cannot start thread")
|
||||
logger.warning("locomotion_control is False, cannot start thread")
|
||||
return
|
||||
|
||||
if self.locomotion_running:
|
||||
logger_mp.warning("Locomotion thread already running")
|
||||
logger.warning("Locomotion thread already running")
|
||||
return
|
||||
|
||||
logger_mp.info("Starting locomotion control thread...")
|
||||
logger.info("Starting locomotion control thread...")
|
||||
self.locomotion_running = True
|
||||
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
|
||||
self.locomotion_thread.start()
|
||||
logger_mp.info("Locomotion control thread started!")
|
||||
logger.info("Locomotion control thread started!")
|
||||
|
||||
def stop_locomotion_thread(self):
|
||||
"""Stop the background locomotion control thread."""
|
||||
if not self.locomotion_running:
|
||||
return
|
||||
|
||||
logger_mp.info("Stopping locomotion control thread...")
|
||||
logger.info("Stopping locomotion control thread...")
|
||||
self.locomotion_running = False
|
||||
if self.locomotion_thread:
|
||||
self.locomotion_thread.join(timeout=2.0)
|
||||
logger_mp.info("Locomotion control thread stopped")
|
||||
logger.info("Locomotion control thread stopped")
|
||||
|
||||
# Also stop keyboard thread if running
|
||||
if self.keyboard_running:
|
||||
@@ -1459,40 +1456,40 @@ class UnitreeG1(Robot):
|
||||
def start_keyboard_controls(self):
|
||||
"""Start the keyboard control thread (sim mode only)."""
|
||||
if not self.simulation_mode:
|
||||
logger_mp.warning("Keyboard controls only available in simulation mode")
|
||||
logger.warning("Keyboard controls only available in simulation mode")
|
||||
return
|
||||
|
||||
if self.keyboard_running:
|
||||
logger_mp.warning("Keyboard controls already running")
|
||||
logger.warning("Keyboard controls already running")
|
||||
return
|
||||
|
||||
self.keyboard_running = True
|
||||
self.keyboard_thread = threading.Thread(target=self._keyboard_listener_thread, daemon=True)
|
||||
self.keyboard_thread.start()
|
||||
logger_mp.info("Keyboard controls started!")
|
||||
logger.info("Keyboard controls started!")
|
||||
|
||||
def stop_keyboard_controls(self):
|
||||
"""Stop the keyboard control thread."""
|
||||
if not self.keyboard_running:
|
||||
return
|
||||
|
||||
logger_mp.info("Stopping keyboard controls...")
|
||||
logger.info("Stopping keyboard controls...")
|
||||
self.keyboard_running = False
|
||||
if self.keyboard_thread:
|
||||
self.keyboard_thread.join(timeout=2.0)
|
||||
logger_mp.info("Keyboard controls stopped")
|
||||
logger.info("Keyboard controls stopped")
|
||||
|
||||
|
||||
def init_locomotion(self):
|
||||
"""Test locomotion control sequence: home arms -> move legs to default -> start policy thread."""
|
||||
if not self.config.locomotion_control:
|
||||
logger_mp.warning("locomotion_control is False, cannot run test sequence")
|
||||
logger.warning("locomotion_control is False, cannot run test sequence")
|
||||
return
|
||||
|
||||
logger_mp.info("Starting locomotion test sequence...")
|
||||
logger.info("Starting locomotion test sequence...")
|
||||
|
||||
# 1. Home the arms first
|
||||
logger_mp.info("Homing arms to zero position...")
|
||||
logger.info("Homing arms to zero position...")
|
||||
#self.ctrl_dual_arm_go_home()
|
||||
|
||||
# 2. Move legs to default position
|
||||
@@ -1505,19 +1502,19 @@ class UnitreeG1(Robot):
|
||||
self.locomotion_default_pos_state()
|
||||
|
||||
# 5. Start locomotion policy thread (runs in background)
|
||||
logger_mp.info("Starting locomotion policy control...")
|
||||
logger.info("Starting locomotion policy control...")
|
||||
self.start_locomotion_thread()
|
||||
|
||||
logger_mp.info("Locomotion test sequence complete! Policy is now running in background.")
|
||||
logger_mp.info("Use robot.stop_locomotion_thread() to stop the policy.")
|
||||
logger.info("Locomotion test sequence complete! Policy is now running in background.")
|
||||
logger.info("Use robot.stop_locomotion_thread() to stop the policy.")
|
||||
|
||||
def init_groot_locomotion(self):
|
||||
"""Initialize GR00T-style locomotion for ONNX policies (29 DOF, 15D actions)."""
|
||||
if not self.config.locomotion_control:
|
||||
logger_mp.warning("locomotion_control is False, cannot run GR00T init")
|
||||
logger.warning("locomotion_control is False, cannot run GR00T init")
|
||||
return
|
||||
|
||||
logger_mp.info("Starting GR00T locomotion initialization...")
|
||||
logger.info("Starting GR00T locomotion initialization...")
|
||||
|
||||
# Move legs to default position (same as regular locomotion)
|
||||
self.locomotion_move_to_default_pos()
|
||||
@@ -1529,11 +1526,11 @@ class UnitreeG1(Robot):
|
||||
self.locomotion_default_pos_state()
|
||||
|
||||
# Start locomotion policy thread (will use groot_locomotion_run)
|
||||
logger_mp.info("Starting GR00T locomotion policy control...")
|
||||
logger.info("Starting GR00T locomotion policy control...")
|
||||
self.start_locomotion_thread()
|
||||
|
||||
logger_mp.info("GR00T locomotion initialization complete! Policy is now running.")
|
||||
logger_mp.info("516D observations (86D × 6 frames), 15D actions (legs + waist)")
|
||||
logger.info("GR00T locomotion initialization complete! Policy is now running.")
|
||||
logger.info("516D observations (86D × 6 frames), 15D actions (legs + waist)")
|
||||
|
||||
def motion_imitation_run(self):
|
||||
"""Motion imitation policy loop - tracks reference motion (dance_102, etc)."""
|
||||
@@ -1568,13 +1565,13 @@ class UnitreeG1(Robot):
|
||||
self.motion_qj_all[joint_idx] = 0.0
|
||||
self.motion_dqj_all[joint_idx] = 0.0
|
||||
if self.motion_counter == 1:
|
||||
logger_mp.info("="*60)
|
||||
logger_mp.info("🤖 23 DOF MODE ENABLED")
|
||||
logger_mp.info(f" Zeroing joints: {JOINTS_TO_ZERO_23DOF}")
|
||||
logger_mp.info(" Waist: yaw(12), pitch(14)")
|
||||
logger_mp.info(" Wrist L: pitch(20), yaw(21) | Wrist R: pitch(27), yaw(28)")
|
||||
logger_mp.info(" Applied to: robot obs, reference motion, policy actions")
|
||||
logger_mp.info("="*60)
|
||||
logger.info("="*60)
|
||||
logger.info("🤖 23 DOF MODE ENABLED")
|
||||
logger.info(f" Zeroing joints: {JOINTS_TO_ZERO_23DOF}")
|
||||
logger.info(" Waist: yaw(12), pitch(14)")
|
||||
logger.info(" Wrist L: pitch(20), yaw(21) | Wrist R: pitch(27), yaw(28)")
|
||||
logger.info(" Applied to: robot obs, reference motion, policy actions")
|
||||
logger.info("="*60)
|
||||
|
||||
# Get IMU data
|
||||
robot_quat = lowstate.imu_state.quaternion # [w, x, y, z]
|
||||
@@ -1643,9 +1640,9 @@ class UnitreeG1(Robot):
|
||||
self.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
if self.motion_counter == 1:
|
||||
logger_mp.info("="*60)
|
||||
logger_mp.info("⚠️ DEBUG MODE: DIRECT PLAYBACK (reference motion, no policy)")
|
||||
logger_mp.info("="*60)
|
||||
logger.info("="*60)
|
||||
logger.info("⚠️ DEBUG MODE: DIRECT PLAYBACK (reference motion, no policy)")
|
||||
logger.info("="*60)
|
||||
|
||||
target_joint_pos_bfs = None # Not used in this mode
|
||||
|
||||
@@ -1656,9 +1653,9 @@ class UnitreeG1(Robot):
|
||||
motion_joint_pos_bfs = np.zeros(29, dtype=np.float32)
|
||||
motion_joint_vel_bfs = np.zeros(29, dtype=np.float32)
|
||||
if self.motion_counter == 1:
|
||||
logger_mp.info("="*60)
|
||||
logger_mp.info("⚠️ DEBUG MODE: Using ZERO reference motion + RUNNING POLICY")
|
||||
logger_mp.info("="*60)
|
||||
logger.info("="*60)
|
||||
logger.info("⚠️ DEBUG MODE: Using ZERO reference motion + RUNNING POLICY")
|
||||
logger.info("="*60)
|
||||
else:
|
||||
# Get reference motion (DFS order from CSV)
|
||||
motion_joint_pos_dfs = self.motion_loader.get_joint_pos() # 29D
|
||||
@@ -1710,13 +1707,13 @@ class UnitreeG1(Robot):
|
||||
# DEBUG: Just send default positions (should make robot stand still)
|
||||
target_joint_pos_bfs = default_joint_pos.copy()
|
||||
if self.motion_counter == 1:
|
||||
logger_mp.info("="*60)
|
||||
logger_mp.info("⚠️ DEBUG MODE: Sending DEFAULT positions (NO POLICY)")
|
||||
logger_mp.info("="*60)
|
||||
logger_mp.info(f" Default pos BFS[0:5]: {target_joint_pos_bfs[0:5]}")
|
||||
logger.info("="*60)
|
||||
logger.info("⚠️ DEBUG MODE: Sending DEFAULT positions (NO POLICY)")
|
||||
logger.info("="*60)
|
||||
logger.info(f" Default pos BFS[0:5]: {target_joint_pos_bfs[0:5]}")
|
||||
if self.motion_counter % 50 == 0:
|
||||
logger_mp.info(f" [DEFAULT MODE] Sending: [{target_joint_pos_bfs[0]:.4f}, {target_joint_pos_bfs[6]:.4f}, {target_joint_pos_bfs[12]:.4f}]")
|
||||
logger_mp.info(f" [DEFAULT MODE] Robot at: [{self.motion_qj_all[0]:.4f}, {self.motion_qj_all[6]:.4f}, {self.motion_qj_all[12]:.4f}]")
|
||||
logger.info(f" [DEFAULT MODE] Sending: [{target_joint_pos_bfs[0]:.4f}, {target_joint_pos_bfs[6]:.4f}, {target_joint_pos_bfs[12]:.4f}]")
|
||||
logger.info(f" [DEFAULT MODE] Robot at: [{self.motion_qj_all[0]:.4f}, {self.motion_qj_all[6]:.4f}, {self.motion_qj_all[12]:.4f}]")
|
||||
else:
|
||||
# Run ONNX policy inference
|
||||
obs_tensor = torch.from_numpy(self.motion_obs).unsqueeze(0)
|
||||
@@ -1744,29 +1741,29 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Debug print (only when running policy, not in TEST_SEND_DEFAULT_POS or TEST_DIRECT_PLAYBACK mode)
|
||||
if self.motion_counter == 1 and self.policy and not TEST_SEND_DEFAULT_POS and not TEST_DIRECT_PLAYBACK:
|
||||
logger_mp.info("="*60)
|
||||
logger_mp.info("POLICY MODE OBSERVATION CHECK (First iteration)")
|
||||
logger_mp.info("="*60)
|
||||
logger_mp.info(f"Reference motion (BFS) samples: [{motion_joint_pos_bfs[0]:.3f}, {motion_joint_pos_bfs[6]:.3f}, {motion_joint_pos_bfs[12]:.3f}]")
|
||||
logger_mp.info(f"Robot joints (BFS) samples: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
|
||||
logger_mp.info(f"Default positions samples: [{default_joint_pos[0]:.3f}, {default_joint_pos[6]:.3f}, {default_joint_pos[12]:.3f}]")
|
||||
logger_mp.info(f"Joint pos rel samples: [{joint_pos_rel[0]:.3f}, {joint_pos_rel[6]:.3f}, {joint_pos_rel[12]:.3f}]")
|
||||
logger_mp.info(f"Joint vel rel samples: [{joint_vel_rel[0]:.3f}, {joint_vel_rel[6]:.3f}, {joint_vel_rel[12]:.3f}]")
|
||||
logger_mp.info(f"Angular velocity: [{ang_vel[0]:.3f}, {ang_vel[1]:.3f}, {ang_vel[2]:.3f}]")
|
||||
logger_mp.info(f"Motion anchor ori: [{motion_anchor_ori_b[0]:.3f}, ..., {motion_anchor_ori_b[5]:.3f}]")
|
||||
logger_mp.info(f"Observation breakdown:")
|
||||
logger_mp.info(f" [0:29] motion_cmd_pos: range [{self.motion_obs[0:29].min():.3f}, {self.motion_obs[0:29].max():.3f}]")
|
||||
logger_mp.info(f" [29:58] motion_cmd_vel: range [{self.motion_obs[29:58].min():.3f}, {self.motion_obs[29:58].max():.3f}]")
|
||||
logger_mp.info(f" [58:64] anchor_ori: range [{self.motion_obs[58:64].min():.3f}, {self.motion_obs[58:64].max():.3f}]")
|
||||
logger_mp.info(f" [64:67] ang_vel: range [{self.motion_obs[64:67].min():.3f}, {self.motion_obs[64:67].max():.3f}]")
|
||||
logger_mp.info(f" [67:96] joint_pos_rel: range [{self.motion_obs[67:96].min():.3f}, {self.motion_obs[67:96].max():.3f}]")
|
||||
logger_mp.info(f" [96:125] joint_vel_rel: range [{self.motion_obs[96:125].min():.3f}, {self.motion_obs[96:125].max():.3f}]")
|
||||
logger_mp.info(f" [125:154] last_action: range [{self.motion_obs[125:154].min():.3f}, {self.motion_obs[125:154].max():.3f}]")
|
||||
logger_mp.info(f"Full obs range: [{self.motion_obs.min():.3f}, {self.motion_obs.max():.3f}]")
|
||||
logger_mp.info(f"Action output (first): [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
|
||||
logger_mp.info(f"Action scale samples: [{action_scale[0]:.3f}, {action_scale[6]:.3f}, {action_scale[12]:.3f}]")
|
||||
logger_mp.info(f"Target positions samples: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
|
||||
logger_mp.info("="*60)
|
||||
logger.info("="*60)
|
||||
logger.info("POLICY MODE OBSERVATION CHECK (First iteration)")
|
||||
logger.info("="*60)
|
||||
logger.info(f"Reference motion (BFS) samples: [{motion_joint_pos_bfs[0]:.3f}, {motion_joint_pos_bfs[6]:.3f}, {motion_joint_pos_bfs[12]:.3f}]")
|
||||
logger.info(f"Robot joints (BFS) samples: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
|
||||
logger.info(f"Default positions samples: [{default_joint_pos[0]:.3f}, {default_joint_pos[6]:.3f}, {default_joint_pos[12]:.3f}]")
|
||||
logger.info(f"Joint pos rel samples: [{joint_pos_rel[0]:.3f}, {joint_pos_rel[6]:.3f}, {joint_pos_rel[12]:.3f}]")
|
||||
logger.info(f"Joint vel rel samples: [{joint_vel_rel[0]:.3f}, {joint_vel_rel[6]:.3f}, {joint_vel_rel[12]:.3f}]")
|
||||
logger.info(f"Angular velocity: [{ang_vel[0]:.3f}, {ang_vel[1]:.3f}, {ang_vel[2]:.3f}]")
|
||||
logger.info(f"Motion anchor ori: [{motion_anchor_ori_b[0]:.3f}, ..., {motion_anchor_ori_b[5]:.3f}]")
|
||||
logger.info(f"Observation breakdown:")
|
||||
logger.info(f" [0:29] motion_cmd_pos: range [{self.motion_obs[0:29].min():.3f}, {self.motion_obs[0:29].max():.3f}]")
|
||||
logger.info(f" [29:58] motion_cmd_vel: range [{self.motion_obs[29:58].min():.3f}, {self.motion_obs[29:58].max():.3f}]")
|
||||
logger.info(f" [58:64] anchor_ori: range [{self.motion_obs[58:64].min():.3f}, {self.motion_obs[58:64].max():.3f}]")
|
||||
logger.info(f" [64:67] ang_vel: range [{self.motion_obs[64:67].min():.3f}, {self.motion_obs[64:67].max():.3f}]")
|
||||
logger.info(f" [67:96] joint_pos_rel: range [{self.motion_obs[67:96].min():.3f}, {self.motion_obs[67:96].max():.3f}]")
|
||||
logger.info(f" [96:125] joint_vel_rel: range [{self.motion_obs[96:125].min():.3f}, {self.motion_obs[96:125].max():.3f}]")
|
||||
logger.info(f" [125:154] last_action: range [{self.motion_obs[125:154].min():.3f}, {self.motion_obs[125:154].max():.3f}]")
|
||||
logger.info(f"Full obs range: [{self.motion_obs.min():.3f}, {self.motion_obs.max():.3f}]")
|
||||
logger.info(f"Action output (first): [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
|
||||
logger.info(f"Action scale samples: [{action_scale[0]:.3f}, {action_scale[6]:.3f}, {action_scale[12]:.3f}]")
|
||||
logger.info(f"Target positions samples: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
|
||||
logger.info("="*60)
|
||||
|
||||
if self.motion_counter % 50 == 0:
|
||||
if self.policy is None:
|
||||
@@ -1779,12 +1776,12 @@ class UnitreeG1(Robot):
|
||||
mode = "POLICY_ZEROS"
|
||||
else:
|
||||
mode = "POLICY"
|
||||
logger_mp.info(f"Motion {mode}: t={self.motion_elapsed_time:.2f}s, frame={self.motion_loader.index_0}/{self.motion_loader.num_frames}")
|
||||
logger.info(f"Motion {mode}: t={self.motion_elapsed_time:.2f}s, frame={self.motion_loader.index_0}/{self.motion_loader.num_frames}")
|
||||
if self.policy and not TEST_SEND_DEFAULT_POS and not TEST_DIRECT_PLAYBACK:
|
||||
logger_mp.info(f" Policy action range: [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
|
||||
logger_mp.info(f" Sample actions[0,6,12]: [{self.motion_action[0]:.3f}, {self.motion_action[6]:.3f}, {self.motion_action[12]:.3f}]")
|
||||
logger_mp.info(f" Target pos (after scale)[0,6,12]: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
|
||||
logger_mp.info(f" Robot pos (BFS)[0,6,12]: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
|
||||
logger.info(f" Policy action range: [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
|
||||
logger.info(f" Sample actions[0,6,12]: [{self.motion_action[0]:.3f}, {self.motion_action[6]:.3f}, {self.motion_action[12]:.3f}]")
|
||||
logger.info(f" Target pos (after scale)[0,6,12]: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
|
||||
logger.info(f" Robot pos (BFS)[0,6,12]: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
|
||||
|
||||
# Send command
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
@@ -1792,13 +1789,13 @@ class UnitreeG1(Robot):
|
||||
|
||||
def _motion_imitation_thread_loop(self):
|
||||
"""Background thread that runs the motion imitation policy at specified rate."""
|
||||
logger_mp.info("Motion imitation thread started")
|
||||
logger.info("Motion imitation thread started")
|
||||
while self.motion_imitation_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
self.motion_imitation_run()
|
||||
except Exception as e:
|
||||
logger_mp.error(f"Error in motion imitation loop: {e}")
|
||||
logger.error(f"Error in motion imitation loop: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -1806,45 +1803,45 @@ class UnitreeG1(Robot):
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, self.config.motion_control_dt - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger_mp.info("Motion imitation thread stopped")
|
||||
logger.info("Motion imitation thread stopped")
|
||||
|
||||
def start_motion_imitation_thread(self):
|
||||
"""Start the background motion imitation control thread."""
|
||||
if not self.config.motion_imitation_control:
|
||||
logger_mp.warning("motion_imitation_control is False, cannot start thread")
|
||||
logger.warning("motion_imitation_control is False, cannot start thread")
|
||||
return
|
||||
|
||||
if self.motion_imitation_running:
|
||||
logger_mp.warning("Motion imitation thread already running")
|
||||
logger.warning("Motion imitation thread already running")
|
||||
return
|
||||
|
||||
logger_mp.info("Starting motion imitation control thread...")
|
||||
logger.info("Starting motion imitation control thread...")
|
||||
self.motion_imitation_running = True
|
||||
self.motion_imitation_thread = threading.Thread(target=self._motion_imitation_thread_loop, daemon=True)
|
||||
self.motion_imitation_thread.start()
|
||||
logger_mp.info("Motion imitation control thread started!")
|
||||
logger.info("Motion imitation control thread started!")
|
||||
|
||||
def stop_motion_imitation_thread(self):
|
||||
"""Stop the background motion imitation control thread."""
|
||||
if not self.motion_imitation_running:
|
||||
return
|
||||
|
||||
logger_mp.info("Stopping motion imitation control thread...")
|
||||
logger.info("Stopping motion imitation control thread...")
|
||||
self.motion_imitation_running = False
|
||||
if self.motion_imitation_thread:
|
||||
self.motion_imitation_thread.join(timeout=2.0)
|
||||
logger_mp.info("Motion imitation control thread stopped")
|
||||
logger.info("Motion imitation control thread stopped")
|
||||
|
||||
def init_motion_imitation(self):
|
||||
"""Initialize motion imitation - move to default standing pose and start policy."""
|
||||
if not self.config.motion_imitation_control:
|
||||
logger_mp.warning("motion_imitation_control is False, cannot run initialization")
|
||||
logger.warning("motion_imitation_control is False, cannot run initialization")
|
||||
return
|
||||
|
||||
logger_mp.info("Starting motion imitation initialization...")
|
||||
logger.info("Starting motion imitation initialization...")
|
||||
|
||||
# Move to default standing position
|
||||
logger_mp.info("Moving to default standing position...")
|
||||
logger.info("Moving to default standing position...")
|
||||
total_time = 3.0
|
||||
num_steps = int(total_time / self.config.motion_control_dt)
|
||||
|
||||
@@ -1871,17 +1868,17 @@ class UnitreeG1(Robot):
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
time.sleep(self.config.motion_control_dt)
|
||||
|
||||
logger_mp.info("Reached default position")
|
||||
logger.info("Reached default position")
|
||||
|
||||
# Wait 2 seconds
|
||||
time.sleep(2.0)
|
||||
|
||||
# Start motion imitation policy thread
|
||||
logger_mp.info("Starting motion imitation policy control...")
|
||||
logger.info("Starting motion imitation policy control...")
|
||||
self.start_motion_imitation_thread()
|
||||
|
||||
logger_mp.info("Motion imitation initialization complete! Policy is now running.")
|
||||
logger_mp.info(f"154D observations, 29D actions. Motion duration: {self.motion_loader.duration:.2f}s")
|
||||
logger.info("Motion imitation initialization complete! Policy is now running.")
|
||||
logger.info(f"154D observations, 29D actions. Motion duration: {self.motion_loader.duration:.2f}s")
|
||||
|
||||
|
||||
class G1_29_JointArmIndex(IntEnum):
|
||||
|
||||
73
src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py
Normal file
73
src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# unitree_sdk2_socket.py
|
||||
import zmq
|
||||
import pickle
|
||||
import time
|
||||
|
||||
# you can tune these or read from env
|
||||
ROBOT_IP = "172.18.129.215"
|
||||
LOWCMD_PORT = 6000 # laptop -> robot
|
||||
LOWSTATE_PORT = 6001 # robot -> laptop
|
||||
|
||||
_ctx = None
|
||||
_lowcmd_sock = None
|
||||
_lowstate_sock = None
|
||||
|
||||
def ChannelFactoryInitialize(*args, **kwargs):
|
||||
global _ctx, _lowcmd_sock, _lowstate_sock
|
||||
if _ctx is not None:
|
||||
return
|
||||
_ctx = zmq.Context.instance()
|
||||
|
||||
# lowcmd: PUSH from laptop to robot
|
||||
_lowcmd_sock = _ctx.socket(zmq.PUSH)
|
||||
_lowcmd_sock.setsockopt(zmq.CONFLATE, 1)
|
||||
_lowcmd_sock.connect(f"tcp://{ROBOT_IP}:{LOWCMD_PORT}")
|
||||
|
||||
# lowstate: SUB from robot
|
||||
_lowstate_sock = _ctx.socket(zmq.SUB) # no topic filtering
|
||||
_lowstate_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
|
||||
_lowstate_sock.connect(f"tcp://{ROBOT_IP}:{LOWSTATE_PORT}")
|
||||
_lowstate_sock.setsockopt_string(zmq.SUBSCRIBE, "") # subscribe to all
|
||||
|
||||
|
||||
class ChannelPublisher:
|
||||
# just enough api for your code: __init__, Init, Write
|
||||
def __init__(self, topic, msg_type):
|
||||
# we ignore topic/msg_type, the bridge only supports the topics you use
|
||||
self.topic = topic
|
||||
self.msg_type = msg_type
|
||||
|
||||
def Init(self):
|
||||
# nothing to do, sockets are global
|
||||
pass
|
||||
|
||||
def Write(self, msg):
|
||||
# msg is hg_LowCmd_ instance – we just pickle it
|
||||
payload = pickle.dumps((self.topic, msg))
|
||||
_lowcmd_sock.send(payload)
|
||||
|
||||
|
||||
class ChannelSubscriber:
|
||||
# api: __init__, Init, Read
|
||||
def __init__(self, topic, msg_type):
|
||||
self.topic = topic
|
||||
self.msg_type = msg_type
|
||||
|
||||
def Init(self):
|
||||
pass
|
||||
|
||||
def Read(self, timeout_ms=None):
|
||||
"""Block until we get a lowstate, optionally with timeout (ms)."""
|
||||
if timeout_ms is None:
|
||||
payload = _lowstate_sock.recv()
|
||||
else:
|
||||
poller = zmq.Poller()
|
||||
poller.register(_lowstate_sock, zmq.POLLIN)
|
||||
events = dict(poller.poll(timeout_ms))
|
||||
if _lowstate_sock not in events:
|
||||
return None
|
||||
payload = _lowstate_sock.recv()
|
||||
|
||||
topic, msg = pickle.loads(payload)
|
||||
# you can assert topic == self.topic, but not necessary if you only use one
|
||||
return msg
|
||||
@@ -55,7 +55,7 @@ class HomunculusArm(Teleoperator):
|
||||
"wrist_yaw": MotorNormMode.RANGE_M100_100,
|
||||
"wrist_pitch": MotorNormMode.RANGE_M100_100,
|
||||
}
|
||||
n = 50
|
||||
n = 10
|
||||
# EMA parameters ---------------------------------------------------
|
||||
self.n: int = n
|
||||
self.alpha: float = 2 / (n + 1)
|
||||
|
||||
Reference in New Issue
Block a user