sync recent changes

This commit is contained in:
Martino Russi
2025-11-21 14:13:05 +01:00
parent e5cae6be64
commit 9a052566a3
326 changed files with 20122 additions and 15 deletions

View File

@@ -0,0 +1,219 @@
import os
import cv2
import json
import datetime
import numpy as np
import time
from queue import Queue, Empty
from threading import Thread
import logging_mp
logger_mp = logging_mp.get_logger(__name__)
class EpisodeWriter:
def __init__(self, task_dir, frequency=30, image_size=[640, 480]):
"""
image_size: [width, height]
"""
logger_mp.info("==> EpisodeWriter initializing...\n")
self.task_dir = task_dir
self.frequency = frequency
self.image_size = image_size
self.data = {}
self.episode_data = []
self.item_id = -1
self.episode_id = -1
if os.path.exists(self.task_dir):
episode_dirs = [episode_dir for episode_dir in os.listdir(self.task_dir) if "episode_" in episode_dir]
episode_last = sorted(episode_dirs)[-1] if len(episode_dirs) > 0 else None
self.episode_id = 0 if episode_last is None else int(episode_last.split("_")[-1])
logger_mp.info(f"==> task_dir directory already exist, now self.episode_id is:{self.episode_id}\n")
else:
os.makedirs(self.task_dir)
logger_mp.info("==> episode directory does not exist, now create one.\n")
self.data_info()
self.text_desc()
self.result = None
self.is_available = True # Indicates whether the class is available for new operations
# Initialize the queue and worker thread
self.item_data_queue = Queue(-1)
self.stop_worker = False
self.need_save = False # Flag to indicate when save_episode is triggered
self.worker_thread = Thread(target=self.process_queue)
self.worker_thread.start()
logger_mp.info("==> EpisodeWriter initialized successfully.\n")
def data_info(self, version="1.0.0", date=None, author=None):
self.info = {
"version": "1.0.0" if version is None else version,
"date": datetime.date.today().strftime("%Y-%m-%d") if date is None else date,
"author": "unitree" if author is None else author,
"image": {"width": self.image_size[0], "height": self.image_size[1], "fps": self.frequency},
"depth": {"width": self.image_size[0], "height": self.image_size[1], "fps": self.frequency},
"audio": {"sample_rate": 16000, "channels": 1, "format": "PCM", "bits": 16}, # PCM_S16
"joint_names": {
"left_arm": [
"kLeftShoulderPitch",
"kLeftShoulderRoll",
"kLeftShoulderYaw",
"kLeftElbow",
"kLeftWristRoll",
"kLeftWristPitch",
"kLeftWristyaw",
],
"left_ee": [],
"right_arm": [],
"right_ee": [],
"body": [],
},
"tactile_names": {
"left_ee": [],
"right_ee": [],
},
"sim_state": "",
}
def text_desc(self):
self.text = {
"goal": "Place the wooden blocks into the yellow frame, stacking them from bottom to top in the order: red, yellow, green.",
"desc": "Using the gripper, first place the red wooden block into the yellow frame. Next, stack the yellow wooden block on top of the red one, and finally place the green wooden block on top of the yellow block.",
"steps": "",
}
def create_episode(self):
"""
Create a new episode.
Returns:
bool: True if the episode is successfully created, False otherwise.
Note:
Once successfully created, this function will only be available again after save_episode complete its save task.
"""
if not self.is_available:
logger_mp.info(
"==> The class is currently unavailable for new operations. Please wait until ongoing tasks are completed."
)
return False # Return False if the class is unavailable
# Reset episode-related data and create necessary directories
self.item_id = -1
self.episode_data = []
self.episode_id = self.episode_id + 1
self.episode_dir = os.path.join(self.task_dir, f"episode_{str(self.episode_id).zfill(4)}")
self.color_dir = os.path.join(self.episode_dir, "colors")
self.depth_dir = os.path.join(self.episode_dir, "depths")
self.audio_dir = os.path.join(self.episode_dir, "audios")
self.json_path = os.path.join(self.episode_dir, "data.json")
os.makedirs(self.episode_dir, exist_ok=True)
os.makedirs(self.color_dir, exist_ok=True)
os.makedirs(self.depth_dir, exist_ok=True)
os.makedirs(self.audio_dir, exist_ok=True)
self.is_available = False # After the episode is created, the class is marked as unavailable until the episode is successfully saved
logger_mp.info(f"==> New episode created: {self.episode_dir}")
return True # Return True if the episode is successfully created
def add_item(self, colors, depths=None, states=None, actions=None, tactiles=None, audios=None, sim_state=None):
# Increment the item ID
self.item_id += 1
# Create the item data dictionary
item_data = {
"idx": self.item_id,
"colors": colors,
"depths": depths,
"states": states,
"actions": actions,
"tactiles": tactiles,
"audios": audios,
"sim_state": sim_state,
}
# Enqueue the item data
self.item_data_queue.put(item_data)
def process_queue(self):
while not self.stop_worker or not self.item_data_queue.empty():
# Process items in the queue
try:
item_data = self.item_data_queue.get(timeout=1)
try:
self._process_item_data(item_data)
except Exception as e:
logger_mp.info(f"Error processing item_data (idx={item_data['idx']}): {e}")
self.item_data_queue.task_done()
except Empty:
pass
# Check if save_episode was triggered
if self.need_save and self.item_data_queue.empty():
self._save_episode()
def _process_item_data(self, item_data):
idx = item_data["idx"]
colors = item_data.get("colors", {})
depths = item_data.get("depths", {})
audios = item_data.get("audios", {})
# Save images
if colors:
for idx_color, (color_key, color) in enumerate(colors.items()):
color_name = f"{str(idx).zfill(6)}_{color_key}.jpg"
if not cv2.imwrite(os.path.join(self.color_dir, color_name), color):
logger_mp.info("Failed to save color image.")
item_data["colors"][color_key] = os.path.join("colors", color_name)
# Save depths
if depths:
for idx_depth, (depth_key, depth) in enumerate(depths.items()):
depth_name = f"{str(idx).zfill(6)}_{depth_key}.jpg"
if not cv2.imwrite(os.path.join(self.depth_dir, depth_name), depth):
logger_mp.info("Failed to save depth image.")
item_data["depths"][depth_key] = os.path.join("depths", depth_name)
# Save audios
if audios:
for mic, audio in audios.items():
audio_name = f"audio_{str(idx).zfill(6)}_{mic}.npy"
np.save(os.path.join(self.audio_dir, audio_name), audio.astype(np.int16))
item_data["audios"][mic] = os.path.join("audios", audio_name)
# Update episode data
self.episode_data.append(item_data)
def save_episode(self, result):
"""
Trigger the save operation. This sets the save flag, and the process_queue thread will handle it.
"""
self.need_save = True # Set the save flag
self.result = result
logger_mp.info("==> Episode saved start...")
def _save_episode(self):
"""
Save the episode data to a JSON file.
"""
self.data["info"] = self.info
self.data["text"] = self.text
self.data["data"] = self.episode_data
self.data["result"] = self.result
with open(self.json_path, "w", encoding="utf-8") as jsonf:
jsonf.write(json.dumps(self.data, indent=4, ensure_ascii=False))
self.need_save = False # Reset the save flag
self.is_available = True # Mark the class as available after saving
logger_mp.info(f"==> Episode saved successfully to {self.json_path} with result: {self.result}")
def close(self):
"""
Stop the worker thread and ensure all tasks are completed.
"""
self.item_data_queue.join()
if not self.is_available: # If self.is_available is False, it means there is still data not saved.
self.save_episode(self.result)
while not self.is_available:
time.sleep(0.01)
self.stop_worker = True
self.worker_thread.join()

View File

@@ -0,0 +1,166 @@
import torch
from datetime import datetime
from typing import Any
import rerun as rr
import rerun.blueprint as rrb
class RerunLogger:
"""
A fully automatic Rerun logger designed to parse and visualize step
dictionaries directly from a LeRobotDataset.
"""
def __init__(
self,
prefix: str = "",
memory_limit: str = "200MB",
idxrangeboundary: int | None = 300,
):
"""Initializes the Rerun logger."""
# Use a descriptive name for the Rerun recording
rr.init(f"Dataset_Log_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
rr.spawn(memory_limit=memory_limit)
self.prefix = prefix
self.blueprint_sent = False
self.idxrangeboundary = idxrangeboundary
# --- Internal cache for discovered keys ---
self._image_keys: tuple[str, ...] = ()
self._state_key: str = ""
self._action_key: str = ""
self._index_key: str = "index"
self._task_key: str = "task"
self._episode_index_key: str = "episode_index"
self.current_episode = -1
def _initialize_from_data(self, step_data: dict[str, Any]):
"""Inspects the first data dictionary to discover components and set up the blueprint."""
print("RerunLogger: First data packet received. Auto-configuring...")
image_keys = []
for key, value in step_data.items():
if key.startswith("observation.images.") and isinstance(value, torch.Tensor) and value.ndim > 2:
image_keys.append(key)
elif key == "observation.state":
self._state_key = key
elif key == "action":
self._action_key = key
self._image_keys = tuple(sorted(image_keys))
if "index" in step_data:
self._index_key = "index"
elif "frame_index" in step_data:
self._index_key = "frame_index"
print(f" - Using '{self._index_key}' for time sequence.")
print(f" - Detected State Key: '{self._state_key}'")
print(f" - Detected Action Key: '{self._action_key}'")
print(f" - Detected Image Keys: {self._image_keys}")
if self.idxrangeboundary:
self.setup_blueprint()
def setup_blueprint(self):
"""Sets up and sends the Rerun blueprint based on detected components."""
views = []
for key in self._image_keys:
clean_name = key.replace("observation.images.", "")
entity_path = f"{self.prefix}images/{clean_name}"
views.append(rrb.Spatial2DView(origin=entity_path, name=clean_name))
if self._state_key:
entity_path = f"{self.prefix}state"
views.append(
rrb.TimeSeriesView(
origin=entity_path,
name="Observation State",
time_ranges=[
rrb.VisibleTimeRange(
"frame",
start=rrb.TimeRangeBoundary.cursor_relative(seq=-self.idxrangeboundary),
end=rrb.TimeRangeBoundary.cursor_relative(),
)
],
plot_legend=rrb.PlotLegend(visible=True),
)
)
if self._action_key:
entity_path = f"{self.prefix}action"
views.append(
rrb.TimeSeriesView(
origin=entity_path,
name="Action",
time_ranges=[
rrb.VisibleTimeRange(
"frame",
start=rrb.TimeRangeBoundary.cursor_relative(seq=-self.idxrangeboundary),
end=rrb.TimeRangeBoundary.cursor_relative(),
)
],
plot_legend=rrb.PlotLegend(visible=True),
)
)
if not views:
print("Warning: No visualizable components detected in the data.")
return
grid = rrb.Grid(contents=views)
rr.send_blueprint(grid)
self.blueprint_sent = True
def log_step(self, step_data: dict[str, Any]):
"""Logs a single step dictionary from your dataset."""
if not self.blueprint_sent:
self._initialize_from_data(step_data)
if self._index_key in step_data:
current_index = step_data[self._index_key].item()
rr.set_time_sequence("frame", current_index)
episode_idx = step_data.get(self._episode_index_key, torch.tensor(-1)).item()
if episode_idx != self.current_episode:
self.current_episode = episode_idx
task_name = step_data.get(self._task_key, "Unknown Task")
log_text = f"Starting Episode {self.current_episode}: {task_name}"
rr.log(f"{self.prefix}info/task", rr.TextLog(log_text, level=rr.TextLogLevel.INFO))
for key in self._image_keys:
if key in step_data:
image_tensor = step_data[key]
if image_tensor.ndim > 2:
clean_name = key.replace("observation.images.", "")
entity_path = f"{self.prefix}images/{clean_name}"
if image_tensor.shape[0] in [1, 3, 4]:
image_tensor = image_tensor.permute(1, 2, 0)
rr.log(entity_path, rr.Image(image_tensor))
if self._state_key in step_data:
state_tensor = step_data[self._state_key]
entity_path = f"{self.prefix}state"
for i, val in enumerate(state_tensor):
rr.log(f"{entity_path}/joint_{i}", rr.Scalar(val.item()))
if self._action_key in step_data:
action_tensor = step_data[self._action_key]
entity_path = f"{self.prefix}action"
for i, val in enumerate(action_tensor):
rr.log(f"{entity_path}/joint_{i}", rr.Scalar(val.item()))
def visualization_data(idx, observation, state, action, online_logger):
item_data: dict[str, Any] = {
"index": torch.tensor(idx),
"observation.state": state,
"action": action,
}
for k, v in observation.items():
if k not in ("index", "observation.state", "action"):
item_data[k] = v
online_logger.log_step(item_data)

View File

@@ -0,0 +1,209 @@
# for simulation
import torch
import numpy as np
import logging_mp
from unitree_lerobot.eval_robot.utils.utils import (
reset_policy,
)
from unitree_lerobot.eval_robot.make_robot import (
publish_reset_category,
)
from dataclasses import dataclass
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
import time
logging_mp.basic_config(level=logging_mp.INFO)
logger_mp = logging_mp.get_logger(__name__)
def process_data_add(episode_writer, observation_image, current_arm_q, ee_state, action, arm_dof, ee_dof):
if episode_writer is None:
return
if (
observation_image is not None
and current_arm_q is not None
and ee_state is not None
and action is not None
and arm_dof is not None
and ee_dof is not None
):
# Convert tensors to numpy arrays for JSON serialization
if torch.is_tensor(current_arm_q):
current_arm_q = current_arm_q.detach().cpu().numpy()
if torch.is_tensor(ee_state):
ee_state = ee_state.detach().cpu().numpy()
if torch.is_tensor(action):
action = action.detach().cpu().numpy()
colors = {}
i = 0
for key, value in observation_image.items():
if "images" in key:
if value is not None:
# Convert PyTorch tensor to numpy array for OpenCV compatibility
if torch.is_tensor(value):
# Convert tensor to numpy array and ensure correct format for OpenCV
img_array = value.detach().cpu().numpy()
# If the image is in CHW format (channels first), convert to HWC format (channels last)
if img_array.ndim == 3 and img_array.shape[0] in [1, 3, 4]:
img_array = np.transpose(img_array, (1, 2, 0))
# Ensure the array is in uint8 format for OpenCV
if img_array.dtype != np.uint8:
if img_array.max() <= 1.0: # Normalized values [0, 1]
img_array = (img_array * 255).astype(np.uint8)
else: # Values already in [0, 255] range
img_array = img_array.astype(np.uint8)
# Keep original RGB format - no color channel conversion needed
colors[f"color_{i}"] = img_array
else:
colors[f"color_{i}"] = value
i += 1
states = {
"left_arm": {
"qpos": current_arm_q[: arm_dof // 2].tolist(), # numpy.array -> list
"qvel": [],
"torque": [],
},
"right_arm": {
"qpos": current_arm_q[arm_dof // 2 :].tolist(),
"qvel": [],
"torque": [],
},
"left_ee": {
"qpos": ee_state[:ee_dof].tolist(),
"qvel": [],
"torque": [],
},
"right_ee": {
"qpos": ee_state[ee_dof:].tolist(),
"qvel": [],
"torque": [],
},
"body": {
"qpos": [],
},
}
actions = {
"left_arm": {
"qpos": action[: arm_dof // 2].tolist(),
"qvel": [],
"torque": [],
},
"right_arm": {
"qpos": action[arm_dof // 2 :].tolist(),
"qvel": [],
"torque": [],
},
"left_ee": {
"qpos": action[arm_dof : arm_dof + ee_dof].tolist(),
"qvel": [],
"torque": [],
},
"right_ee": {
"qpos": action[arm_dof + ee_dof : arm_dof + 2 * ee_dof].tolist(),
"qvel": [],
"torque": [],
},
"body": {
"qpos": [],
},
}
episode_writer.add_item(colors, states=states, actions=actions)
def process_data_save(episode_writer, result):
"""Processes data and saves it."""
if episode_writer is None:
return
episode_writer.save_episode(result)
def is_success(
sim_reward_subscriber,
episode_writer,
reset_pose_publisher,
policy,
cfg,
reward_stats,
init_arm_pose,
robot_interface,
):
# logger_mp.info(f"arm_action {arm_action}, tau {tau}")
if sim_reward_subscriber:
data = sim_reward_subscriber.read_data()
if data is not None:
if int(data["rewards"][0]) == 1:
reward_stats["reward_sum"] += 1
sim_reward_subscriber.reset_data()
# success
if reward_stats["reward_sum"] >= 25:
process_data_save(episode_writer, "success")
logger_mp.info(
f"Episode {reward_stats['episode_num']} finished with reward {reward_stats['reward_sum']},save data..."
)
reward_stats["episode_num"] = -1
reward_stats["reward_sum"] = 0
time.sleep(1)
publish_reset_category(1, reset_pose_publisher)
time.sleep(1)
reset_policy(policy)
sim_reward_subscriber.reset_data()
# fail
elif reward_stats["episode_num"] > cfg.max_episodes:
process_data_save(episode_writer, "fail")
logger_mp.info(f"Episode {reward_stats['episode_num']} finished with reward {reward_stats['reward_sum']}")
reward_stats["episode_num"] = -1
reward_stats["reward_sum"] = 0
reset_policy(policy)
sim_reward_subscriber.reset_data()
logger_mp.info("Initializing robot to starting pose...")
tau = robot_interface["arm_ik"].solve_tau(init_arm_pose)
robot_interface["arm_ctrl"].ctrl_dual_arm(init_arm_pose, tau)
time.sleep(1)
publish_reset_category(1, reset_pose_publisher)
time.sleep(1)
reset_policy(policy)
sim_reward_subscriber.reset_data()
time.sleep(1)
@dataclass
class EvalRealConfig:
repo_id: str
policy: PreTrainedConfig | None = None
root: str = ""
episodes: int = 0
frequency: float = 30.0
# Basic control parameters
arm: str = "G1_29" # G1_29, G1_23
ee: str = "dex3" # dex3, dex1, inspire1, brainco
# Mode flags
motion: bool = False
headless: bool = False
sim: bool = True
visualization: bool = False
send_real_robot: bool = False
use_dataset: bool = False
save_data: bool = False
task_dir: str = "./data"
max_episodes: int = 1200
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
logger_mp.warning(
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
)
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]

View File

@@ -0,0 +1,402 @@
# Copyright (c) 2025, Unitree Robotics Co., Ltd. All Rights Reserved.
# License: Apache License, Version 2.0
"""
Simple sim state subscriber class
Subscribe to rt/sim_state_cmd topic and write to shared memory
"""
import threading
import time
import json
from multiprocessing import shared_memory
from typing import Any
from unitree_sdk2py.core.channel import ChannelSubscriber
from unitree_sdk2py.idl.std_msgs.msg.dds_ import String_
import logging_mp
logger_mp = logging_mp.get_logger(__name__)
class SharedMemoryManager:
"""Shared memory manager"""
def __init__(self, name: str | None = None, size: int = 512):
"""Initialize shared memory manager
Args:
name: shared memory name, if None, create new one
size: shared memory size (bytes)
"""
self.size = size
self.lock = threading.RLock() # reentrant lock
if name:
try:
self.shm = shared_memory.SharedMemory(name=name)
self.shm_name = name
self.created = False
except FileNotFoundError:
self.shm = shared_memory.SharedMemory(create=True, size=size)
self.shm_name = self.shm.name
self.created = True
else:
self.shm = shared_memory.SharedMemory(create=True, size=size)
self.shm_name = self.shm.name
self.created = True
def write_data(self, data: dict[str, Any]) -> bool:
"""Write data to shared memory
Args:
data: data to write
Returns:
bool: write success or not
"""
try:
with self.lock:
json_str = json.dumps(data)
json_bytes = json_str.encode("utf-8")
if len(json_bytes) > self.size - 8: # reserve 8 bytes for length and timestamp
logger_mp.warning(f"Data too large for shared memory ({len(json_bytes)} > {self.size - 8})")
return False
# write timestamp (4 bytes) and data length (4 bytes)
timestamp = int(time.time()) & 0xFFFFFFFF # 32-bit timestamp, use bitmask to ensure in range
self.shm.buf[0:4] = timestamp.to_bytes(4, "little")
self.shm.buf[4:8] = len(json_bytes).to_bytes(4, "little")
# write data
self.shm.buf[8 : 8 + len(json_bytes)] = json_bytes
return True
except Exception as e:
logger_mp.error(f"Error writing to shared memory: {e}")
return False
def read_data(self) -> dict[str, Any] | None:
"""Read data from shared memory
Returns:
Dict[str, Any]: read data dictionary, return None if failed
"""
try:
with self.lock:
# read timestamp and data length
timestamp = int.from_bytes(self.shm.buf[0:4], "little")
data_len = int.from_bytes(self.shm.buf[4:8], "little")
if data_len == 0:
return None
# read data
json_bytes = bytes(self.shm.buf[8 : 8 + data_len])
data = json.loads(json_bytes.decode("utf-8"))
data["_timestamp"] = timestamp # add timestamp information
return data
except Exception as e:
logger_mp.error(f"Error reading from shared memory: {e}")
return None
def reset_data(self):
"""Reset data"""
if self.shm:
self.shm.buf[0:8] = b"\x00" * 8
else:
logger_mp.error("[SharedMemoryManager] Shared memory is not initialized")
def get_name(self) -> str:
"""Get shared memory name"""
return self.shm_name
def cleanup(self):
"""Clean up shared memory"""
if hasattr(self, "shm") and self.shm:
self.shm.close()
if self.created:
self.shm.unlink()
def __del__(self):
"""Destructor"""
self.cleanup()
class SimStateSubscriber:
"""Simple sim state subscriber class"""
def __init__(self, shm_name: str = "sim_state_cmd_data", shm_size: int = 4096):
"""Initialize the subscriber
Args:
shm_name: shared memory name
shm_size: shared memory size
"""
self.shm_name = shm_name
self.shm_size = shm_size
self.running = False
self.subscriber = None
self.subscribe_thread = None
self.shared_memory = None
# initialize shared memory
self._setup_shared_memory()
logger_mp.debug(f"[SimStateSubscriber] Initialized with shared memory: {shm_name}")
def _setup_shared_memory(self):
"""Setup shared memory"""
try:
self.shared_memory = SharedMemoryManager(self.shm_name, self.shm_size)
logger_mp.debug("[SimStateSubscriber] Shared memory setup successfully")
except Exception as e:
logger_mp.error(f"[SimStateSubscriber] Failed to setup shared memory: {e}")
def start_subscribe(self):
"""Start subscribing"""
if self.running:
logger_mp.warning("[SimStateSubscriber] Already running")
return
try:
self.subscriber = ChannelSubscriber("rt/sim_state", String_)
self.subscriber.Init()
self.running = True
self.subscribe_thread = threading.Thread(target=self._subscribe_sim_state, daemon=True)
self.subscribe_thread.start()
logger_mp.info("[SimStateSubscriber] Started subscribing to rt/sim_state")
except Exception as e:
logger_mp.error(f"[SimStateSubscriber] Failed to start subscribing: {e}")
self.running = False
def _subscribe_sim_state(self):
"""Subscribe loop thread"""
logger_mp.debug("[SimStateSubscriber] Subscribe thread started")
while self.running:
try:
if self.subscriber:
msg = self.subscriber.Read()
if msg:
data = json.loads(msg.data)
else:
logger_mp.warning("[SimStateSubscriber] Received None message")
if self.shared_memory and data:
self.shared_memory.write_data(data)
else:
logger_mp.error("[SimStateSubscriber] Subscriber is not initialized")
time.sleep(0.002)
except Exception as e:
logger_mp.error(f"[SimStateSubscriber] Error in subscribe loop: {e}")
time.sleep(0.01)
def stop_subscribe(self):
"""Stop subscribing"""
if not self.running:
logger_mp.warning("[SimStateSubscriber] Already stopped or not running")
return
self.running = False
# wait for thread to finish
if self.subscribe_thread:
self.subscribe_thread.join(timeout=1.0)
if self.shared_memory:
self.shared_memory.cleanup()
logger_mp.info("[SimStateSubscriber] Subscriber stopped")
def read_data(self) -> dict[str, Any] | None:
"""Read data from shared memory
Returns:
Dict: received data, None if no data or error
"""
try:
if self.shared_memory:
return self.shared_memory.read_data()
return None
except Exception as e:
logger_mp.error(f"[SimStateSubscriber] Error reading data: {e}")
return None
def is_running(self) -> bool:
"""Check if subscriber is running"""
return self.running
def __del__(self):
"""Destructor"""
self.stop_subscribe()
def start_sim_state_subscribe(shm_name: str = "sim_state_cmd_data", shm_size: int = 4096) -> SimStateSubscriber:
"""Start sim state subscribing
Args:
shm_name: shared memory name
shm_size: shared memory size
Returns:
SimStateSubscriber: started subscriber instance
"""
subscriber = SimStateSubscriber(shm_name, shm_size)
subscriber.start_subscribe()
return subscriber
# ============================== sim reward topic ==============================
class SimRewardSubscriber:
"""Simple sim state subscriber class"""
def __init__(self, shm_name: str = "sim_reward_cmd_data", shm_size: int = 256):
"""Initialize the subscriber
Args:
shm_name: shared memory name
shm_size: shared memory size
"""
self.shm_name = shm_name
self.shm_size = shm_size
self.running = False
self.subscriber = None
self.subscribe_thread = None
self.shared_memory = None
# initialize shared memory
self._setup_shared_memory()
logger_mp.debug(f"[SimRewardSubscriber] Initialized with shared memory: {shm_name}")
def _setup_shared_memory(self):
"""Setup shared memory"""
try:
self.shared_memory = SharedMemoryManager(self.shm_name, self.shm_size)
logger_mp.debug("[SimRewardSubscriber] Shared memory setup successfully")
except Exception as e:
logger_mp.error(f"[SimRewardSubscriber] Failed to setup shared memory: {e}")
def start_subscribe(self):
"""Start subscribing"""
if self.running:
logger_mp.warning("[SimRewardSubscriber] Already running")
return
try:
self.subscriber = ChannelSubscriber("rt/rewards_state", String_)
self.subscriber.Init()
self.running = True
self.subscribe_thread = threading.Thread(target=self._subscribe_sim_reward, daemon=True)
self.subscribe_thread.start()
logger_mp.info("[SimRewardSubscriber] Started subscribing to rt/sim_reward")
except Exception as e:
logger_mp.error(f"[SimRewardSubscriber] Failed to start subscribing: {e}")
self.running = False
def _subscribe_sim_reward(self):
"""Subscribe loop thread"""
logger_mp.debug("[SimRewardSubscriber] Subscribe thread started")
while self.running:
try:
if self.subscriber:
msg = self.subscriber.Read()
if msg:
data = json.loads(msg.data)
else:
logger_mp.warning("[SimRewardSubscriber] Received None message")
if self.shared_memory and data:
self.shared_memory.write_data(data)
else:
logger_mp.error("[SimRewardSubscriber] Subscriber is not initialized")
time.sleep(0.01)
except Exception as e:
logger_mp.error(f"[SimRewardSubscriber] Error in subscribe loop: {e}")
time.sleep(0.02)
def stop_subscribe(self):
"""Stop subscribing"""
if not self.running:
logger_mp.warning("[SimRewardSubscriber] Already stopped or not running")
return
self.running = False
# wait for thread to finish
if self.subscribe_thread:
self.subscribe_thread.join(timeout=1.0)
if self.shared_memory:
self.shared_memory.cleanup()
logger_mp.info("[SimRewardSubscriber] Subscriber stopped")
def read_data(self) -> dict[str, Any] | None:
"""Read data from shared memory
Returns:
Dict: received data, None if no data or error
"""
try:
if self.shared_memory:
return self.shared_memory.read_data()
return None
except Exception as e:
logger_mp.error(f"[SimRewardSubscriber] Error reading data: {e}")
return None
def reset_data(self):
"""Reset data"""
if self.shared_memory:
data = {"rewards": [0.0], "timestamp": 1758009108.266387}
self.shared_memory.write_data(data)
def is_running(self) -> bool:
"""Check if subscriber is running"""
return self.running
def __del__(self):
"""Destructor"""
self.stop_subscribe()
# ============================== sim reward topic ==============================
def start_sim_reward_subscribe(shm_name: str = "sim_reward_cmd_data", shm_size: int = 256) -> SimRewardSubscriber:
"""Start sim reward subscribing
Args:
shm_name: shared memory name
shm_size: shared memory size
Returns:
SimRewardSubscriber: started subscriber instance
"""
subscriber = SimRewardSubscriber(shm_name, shm_size)
subscriber.start_subscribe()
return subscriber
# if __name__ == "__main__":
# # example usage
# logger_mp.info("Starting sim state subscriber...")
# ChannelFactoryInitialize(0)
# # create and start subscriber
# subscriber = start_sim_state_subscribe()
# try:
# # keep running and check for data
# while True:
# data = subscriber.read_data()
# if data:
# logger_mp.info(f"Read data: {data}")
# time.sleep(1)
# except KeyboardInterrupt:
# logger_mp.warning("\nInterrupted by user")
# finally:
# subscriber.stop_subscribe()
# logger_mp.info("Subscriber stopped")

142
eval_robot/utils/utils.py Normal file
View File

@@ -0,0 +1,142 @@
import numpy as np
import torch
from typing import Any
from contextlib import nullcontext
from copy import copy
import logging
from dataclasses import dataclass
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pretrained import PreTrainedPolicy
import logging_mp
logging_mp.basic_config(level=logging_mp.INFO)
logger_mp = logging_mp.get_logger(__name__)
def extract_observation(step: dict):
observation = {}
for key, value in step.items():
if key.startswith("observation.images."):
if isinstance(value, np.ndarray) and value.ndim == 3 and value.shape[-1] in [1, 3]:
value = np.transpose(value, (2, 0, 1))
observation[key] = value
elif key == "observation.state":
observation[key] = value
return observation
def predict_action(
observation: dict[str, np.ndarray],
policy: PreTrainedPolicy,
device: torch.device,
use_amp: bool,
task: str | None = None,
use_dataset: bool | None = False,
):
observation = copy(observation)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
if not use_dataset:
# Skip non-tensor observations (like task strings)
if not hasattr(observation[name], "unsqueeze"):
continue
if "images" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0).to(device)
observation["task"] = [task if task else ""]
# Compute the next action with the policy
# based on the current observation
action = policy.select_action(observation)
# Remove batch dimension
action = action.squeeze(0)
# Move to cpu, if not already the case
action = action.to("cpu")
return action
def reset_policy(policy: PreTrainedPolicy):
policy.reset()
def cleanup_resources(image_info: dict[str, Any]):
"""Safely close and unlink shared memory resources."""
logger_mp.info("Cleaning up shared memory resources.")
for shm in image_info["shm_resources"]:
if shm:
shm.close()
shm.unlink()
def to_list(x):
if torch is not None and isinstance(x, torch.Tensor):
return x.detach().cpu().ravel().tolist()
if isinstance(x, np.ndarray):
return x.ravel().tolist()
if isinstance(x, (list, tuple)):
return list(x)
return [x]
def to_scalar(x):
if torch is not None and isinstance(x, torch.Tensor):
return float(x.detach().cpu().ravel()[0].item())
if isinstance(x, np.ndarray):
return float(x.ravel()[0])
if isinstance(x, (list, tuple)):
return float(x[0])
return float(x)
@dataclass
class EvalRealConfig:
repo_id: str
policy: PreTrainedConfig | None = None
root: str = ""
episodes: int = 0
frequency: float = 30.0
# Basic control parameters
arm: str = "G1_29" # G1_29, G1_23
ee: str = "dex3" # dex3, dex1, inspire1, brainco
# Mode flags
motion: bool = False
headless: bool = False
visualization: bool = False
send_real_robot: bool = False
use_dataset: bool = False
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
logging.warning(
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
)
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]

View File

@@ -0,0 +1,99 @@
import numpy as np
import matplotlib.pyplot as plt
class WeightedMovingFilter:
def __init__(self, weights, data_size=14):
self._window_size = len(weights)
self._weights = np.array(weights)
# assert np.isclose(np.sum(self._weights), 1.0), "[WeightedMovingFilter] the sum of weights list must be 1.0!"
self._data_size = data_size
self._filtered_data = np.zeros(self._data_size)
self._data_queue = []
def _apply_filter(self):
if len(self._data_queue) < self._window_size:
return self._data_queue[-1]
data_array = np.array(self._data_queue)
temp_filtered_data = np.zeros(self._data_size)
for i in range(self._data_size):
temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1]
return temp_filtered_data
def add_data(self, new_data):
assert len(new_data) == self._data_size
if len(self._data_queue) > 0 and np.array_equal(new_data, self._data_queue[-1]):
return # skip duplicate data
if len(self._data_queue) >= self._window_size:
self._data_queue.pop(0)
self._data_queue.append(new_data)
self._filtered_data = self._apply_filter()
@property
def filtered_data(self):
return self._filtered_data
def visualize_filter_comparison(filter_params, steps):
import time
t = np.linspace(0, 4 * np.pi, steps)
original_data = np.array(
[np.sin(t + i) + np.random.normal(0, 0.2, len(t)) for i in range(35)]
).T # sin wave with noise, shape is [len(t), 35]
plt.figure(figsize=(14, 10))
for idx, weights in enumerate(filter_params):
filter = WeightedMovingFilter(weights, 14)
data_2b_filtered = original_data.copy()
filtered_data = []
time1 = time.time()
for i in range(steps):
filter.add_data(data_2b_filtered[i][13:27]) # step i, columns 13 to 26 (total:14)
data_2b_filtered[i][13:27] = filter.filtered_data
filtered_data.append(data_2b_filtered[i])
time2 = time.time()
print(f"filter_params:{filter_params[idx]}, time cosume:{time2 - time1}")
filtered_data = np.array(filtered_data)
# col0 should not 2b filtered
plt.subplot(len(filter_params), 2, idx * 2 + 1)
plt.plot(filtered_data[:, 0], label=f"Filtered (Window {filter._window_size})")
plt.plot(original_data[:, 0], "r--", label="Original", alpha=0.5)
plt.title("Joint 1 - Should not to be filtered.")
plt.xlabel("Step")
plt.ylabel("Value")
plt.legend()
# col13 should 2b filtered
plt.subplot(len(filter_params), 2, idx * 2 + 2)
plt.plot(filtered_data[:, 13], label=f"Filtered (Window {filter._window_size})")
plt.plot(original_data[:, 13], "r--", label="Original", alpha=0.5)
plt.title(f"Joint 13 - Window {filter._window_size}, Weights {weights}")
plt.xlabel("Step")
plt.ylabel("Value")
plt.legend()
plt.tight_layout()
plt.show()
if __name__ == "__main__":
# windows_size and weights
filter_params = [
(np.array([0.7, 0.2, 0.1])),
(np.array([0.5, 0.3, 0.2])),
(np.array([0.4, 0.3, 0.2, 0.1])),
]
visualize_filter_comparison(filter_params, steps=100)