mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
sync recent changes
This commit is contained in:
219
eval_robot/utils/episode_writer.py
Normal file
219
eval_robot/utils/episode_writer.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import os
|
||||
import cv2
|
||||
import json
|
||||
import datetime
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
from queue import Queue, Empty
|
||||
from threading import Thread
|
||||
import logging_mp
|
||||
|
||||
logger_mp = logging_mp.get_logger(__name__)
|
||||
|
||||
|
||||
class EpisodeWriter:
|
||||
def __init__(self, task_dir, frequency=30, image_size=[640, 480]):
|
||||
"""
|
||||
image_size: [width, height]
|
||||
"""
|
||||
logger_mp.info("==> EpisodeWriter initializing...\n")
|
||||
self.task_dir = task_dir
|
||||
self.frequency = frequency
|
||||
self.image_size = image_size
|
||||
|
||||
self.data = {}
|
||||
self.episode_data = []
|
||||
self.item_id = -1
|
||||
self.episode_id = -1
|
||||
if os.path.exists(self.task_dir):
|
||||
episode_dirs = [episode_dir for episode_dir in os.listdir(self.task_dir) if "episode_" in episode_dir]
|
||||
episode_last = sorted(episode_dirs)[-1] if len(episode_dirs) > 0 else None
|
||||
self.episode_id = 0 if episode_last is None else int(episode_last.split("_")[-1])
|
||||
logger_mp.info(f"==> task_dir directory already exist, now self.episode_id is:{self.episode_id}\n")
|
||||
else:
|
||||
os.makedirs(self.task_dir)
|
||||
logger_mp.info("==> episode directory does not exist, now create one.\n")
|
||||
self.data_info()
|
||||
self.text_desc()
|
||||
self.result = None
|
||||
self.is_available = True # Indicates whether the class is available for new operations
|
||||
# Initialize the queue and worker thread
|
||||
self.item_data_queue = Queue(-1)
|
||||
self.stop_worker = False
|
||||
self.need_save = False # Flag to indicate when save_episode is triggered
|
||||
self.worker_thread = Thread(target=self.process_queue)
|
||||
self.worker_thread.start()
|
||||
|
||||
logger_mp.info("==> EpisodeWriter initialized successfully.\n")
|
||||
|
||||
def data_info(self, version="1.0.0", date=None, author=None):
|
||||
self.info = {
|
||||
"version": "1.0.0" if version is None else version,
|
||||
"date": datetime.date.today().strftime("%Y-%m-%d") if date is None else date,
|
||||
"author": "unitree" if author is None else author,
|
||||
"image": {"width": self.image_size[0], "height": self.image_size[1], "fps": self.frequency},
|
||||
"depth": {"width": self.image_size[0], "height": self.image_size[1], "fps": self.frequency},
|
||||
"audio": {"sample_rate": 16000, "channels": 1, "format": "PCM", "bits": 16}, # PCM_S16
|
||||
"joint_names": {
|
||||
"left_arm": [
|
||||
"kLeftShoulderPitch",
|
||||
"kLeftShoulderRoll",
|
||||
"kLeftShoulderYaw",
|
||||
"kLeftElbow",
|
||||
"kLeftWristRoll",
|
||||
"kLeftWristPitch",
|
||||
"kLeftWristyaw",
|
||||
],
|
||||
"left_ee": [],
|
||||
"right_arm": [],
|
||||
"right_ee": [],
|
||||
"body": [],
|
||||
},
|
||||
"tactile_names": {
|
||||
"left_ee": [],
|
||||
"right_ee": [],
|
||||
},
|
||||
"sim_state": "",
|
||||
}
|
||||
|
||||
def text_desc(self):
|
||||
self.text = {
|
||||
"goal": "Place the wooden blocks into the yellow frame, stacking them from bottom to top in the order: red, yellow, green.",
|
||||
"desc": "Using the gripper, first place the red wooden block into the yellow frame. Next, stack the yellow wooden block on top of the red one, and finally place the green wooden block on top of the yellow block.",
|
||||
"steps": "",
|
||||
}
|
||||
|
||||
def create_episode(self):
|
||||
"""
|
||||
Create a new episode.
|
||||
Returns:
|
||||
bool: True if the episode is successfully created, False otherwise.
|
||||
Note:
|
||||
Once successfully created, this function will only be available again after save_episode complete its save task.
|
||||
"""
|
||||
if not self.is_available:
|
||||
logger_mp.info(
|
||||
"==> The class is currently unavailable for new operations. Please wait until ongoing tasks are completed."
|
||||
)
|
||||
return False # Return False if the class is unavailable
|
||||
|
||||
# Reset episode-related data and create necessary directories
|
||||
self.item_id = -1
|
||||
self.episode_data = []
|
||||
self.episode_id = self.episode_id + 1
|
||||
|
||||
self.episode_dir = os.path.join(self.task_dir, f"episode_{str(self.episode_id).zfill(4)}")
|
||||
self.color_dir = os.path.join(self.episode_dir, "colors")
|
||||
self.depth_dir = os.path.join(self.episode_dir, "depths")
|
||||
self.audio_dir = os.path.join(self.episode_dir, "audios")
|
||||
self.json_path = os.path.join(self.episode_dir, "data.json")
|
||||
os.makedirs(self.episode_dir, exist_ok=True)
|
||||
os.makedirs(self.color_dir, exist_ok=True)
|
||||
os.makedirs(self.depth_dir, exist_ok=True)
|
||||
os.makedirs(self.audio_dir, exist_ok=True)
|
||||
|
||||
self.is_available = False # After the episode is created, the class is marked as unavailable until the episode is successfully saved
|
||||
logger_mp.info(f"==> New episode created: {self.episode_dir}")
|
||||
return True # Return True if the episode is successfully created
|
||||
|
||||
def add_item(self, colors, depths=None, states=None, actions=None, tactiles=None, audios=None, sim_state=None):
|
||||
# Increment the item ID
|
||||
self.item_id += 1
|
||||
# Create the item data dictionary
|
||||
item_data = {
|
||||
"idx": self.item_id,
|
||||
"colors": colors,
|
||||
"depths": depths,
|
||||
"states": states,
|
||||
"actions": actions,
|
||||
"tactiles": tactiles,
|
||||
"audios": audios,
|
||||
"sim_state": sim_state,
|
||||
}
|
||||
# Enqueue the item data
|
||||
self.item_data_queue.put(item_data)
|
||||
|
||||
def process_queue(self):
|
||||
while not self.stop_worker or not self.item_data_queue.empty():
|
||||
# Process items in the queue
|
||||
try:
|
||||
item_data = self.item_data_queue.get(timeout=1)
|
||||
try:
|
||||
self._process_item_data(item_data)
|
||||
except Exception as e:
|
||||
logger_mp.info(f"Error processing item_data (idx={item_data['idx']}): {e}")
|
||||
self.item_data_queue.task_done()
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
# Check if save_episode was triggered
|
||||
if self.need_save and self.item_data_queue.empty():
|
||||
self._save_episode()
|
||||
|
||||
def _process_item_data(self, item_data):
|
||||
idx = item_data["idx"]
|
||||
colors = item_data.get("colors", {})
|
||||
depths = item_data.get("depths", {})
|
||||
audios = item_data.get("audios", {})
|
||||
|
||||
# Save images
|
||||
if colors:
|
||||
for idx_color, (color_key, color) in enumerate(colors.items()):
|
||||
color_name = f"{str(idx).zfill(6)}_{color_key}.jpg"
|
||||
if not cv2.imwrite(os.path.join(self.color_dir, color_name), color):
|
||||
logger_mp.info("Failed to save color image.")
|
||||
item_data["colors"][color_key] = os.path.join("colors", color_name)
|
||||
|
||||
# Save depths
|
||||
if depths:
|
||||
for idx_depth, (depth_key, depth) in enumerate(depths.items()):
|
||||
depth_name = f"{str(idx).zfill(6)}_{depth_key}.jpg"
|
||||
if not cv2.imwrite(os.path.join(self.depth_dir, depth_name), depth):
|
||||
logger_mp.info("Failed to save depth image.")
|
||||
item_data["depths"][depth_key] = os.path.join("depths", depth_name)
|
||||
|
||||
# Save audios
|
||||
if audios:
|
||||
for mic, audio in audios.items():
|
||||
audio_name = f"audio_{str(idx).zfill(6)}_{mic}.npy"
|
||||
np.save(os.path.join(self.audio_dir, audio_name), audio.astype(np.int16))
|
||||
item_data["audios"][mic] = os.path.join("audios", audio_name)
|
||||
|
||||
# Update episode data
|
||||
self.episode_data.append(item_data)
|
||||
|
||||
def save_episode(self, result):
|
||||
"""
|
||||
Trigger the save operation. This sets the save flag, and the process_queue thread will handle it.
|
||||
"""
|
||||
self.need_save = True # Set the save flag
|
||||
self.result = result
|
||||
logger_mp.info("==> Episode saved start...")
|
||||
|
||||
def _save_episode(self):
|
||||
"""
|
||||
Save the episode data to a JSON file.
|
||||
"""
|
||||
self.data["info"] = self.info
|
||||
self.data["text"] = self.text
|
||||
self.data["data"] = self.episode_data
|
||||
self.data["result"] = self.result
|
||||
|
||||
with open(self.json_path, "w", encoding="utf-8") as jsonf:
|
||||
jsonf.write(json.dumps(self.data, indent=4, ensure_ascii=False))
|
||||
self.need_save = False # Reset the save flag
|
||||
self.is_available = True # Mark the class as available after saving
|
||||
logger_mp.info(f"==> Episode saved successfully to {self.json_path} with result: {self.result}")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Stop the worker thread and ensure all tasks are completed.
|
||||
"""
|
||||
self.item_data_queue.join()
|
||||
if not self.is_available: # If self.is_available is False, it means there is still data not saved.
|
||||
self.save_episode(self.result)
|
||||
while not self.is_available:
|
||||
time.sleep(0.01)
|
||||
self.stop_worker = True
|
||||
self.worker_thread.join()
|
||||
166
eval_robot/utils/rerun_visualizer.py
Normal file
166
eval_robot/utils/rerun_visualizer.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import torch
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import rerun as rr
|
||||
import rerun.blueprint as rrb
|
||||
|
||||
|
||||
class RerunLogger:
|
||||
"""
|
||||
A fully automatic Rerun logger designed to parse and visualize step
|
||||
dictionaries directly from a LeRobotDataset.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str = "",
|
||||
memory_limit: str = "200MB",
|
||||
idxrangeboundary: int | None = 300,
|
||||
):
|
||||
"""Initializes the Rerun logger."""
|
||||
# Use a descriptive name for the Rerun recording
|
||||
rr.init(f"Dataset_Log_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
|
||||
rr.spawn(memory_limit=memory_limit)
|
||||
|
||||
self.prefix = prefix
|
||||
self.blueprint_sent = False
|
||||
self.idxrangeboundary = idxrangeboundary
|
||||
|
||||
# --- Internal cache for discovered keys ---
|
||||
self._image_keys: tuple[str, ...] = ()
|
||||
self._state_key: str = ""
|
||||
self._action_key: str = ""
|
||||
self._index_key: str = "index"
|
||||
self._task_key: str = "task"
|
||||
self._episode_index_key: str = "episode_index"
|
||||
|
||||
self.current_episode = -1
|
||||
|
||||
def _initialize_from_data(self, step_data: dict[str, Any]):
|
||||
"""Inspects the first data dictionary to discover components and set up the blueprint."""
|
||||
print("RerunLogger: First data packet received. Auto-configuring...")
|
||||
|
||||
image_keys = []
|
||||
for key, value in step_data.items():
|
||||
if key.startswith("observation.images.") and isinstance(value, torch.Tensor) and value.ndim > 2:
|
||||
image_keys.append(key)
|
||||
elif key == "observation.state":
|
||||
self._state_key = key
|
||||
elif key == "action":
|
||||
self._action_key = key
|
||||
|
||||
self._image_keys = tuple(sorted(image_keys))
|
||||
|
||||
if "index" in step_data:
|
||||
self._index_key = "index"
|
||||
elif "frame_index" in step_data:
|
||||
self._index_key = "frame_index"
|
||||
|
||||
print(f" - Using '{self._index_key}' for time sequence.")
|
||||
print(f" - Detected State Key: '{self._state_key}'")
|
||||
print(f" - Detected Action Key: '{self._action_key}'")
|
||||
print(f" - Detected Image Keys: {self._image_keys}")
|
||||
if self.idxrangeboundary:
|
||||
self.setup_blueprint()
|
||||
|
||||
def setup_blueprint(self):
|
||||
"""Sets up and sends the Rerun blueprint based on detected components."""
|
||||
views = []
|
||||
|
||||
for key in self._image_keys:
|
||||
clean_name = key.replace("observation.images.", "")
|
||||
entity_path = f"{self.prefix}images/{clean_name}"
|
||||
views.append(rrb.Spatial2DView(origin=entity_path, name=clean_name))
|
||||
|
||||
if self._state_key:
|
||||
entity_path = f"{self.prefix}state"
|
||||
views.append(
|
||||
rrb.TimeSeriesView(
|
||||
origin=entity_path,
|
||||
name="Observation State",
|
||||
time_ranges=[
|
||||
rrb.VisibleTimeRange(
|
||||
"frame",
|
||||
start=rrb.TimeRangeBoundary.cursor_relative(seq=-self.idxrangeboundary),
|
||||
end=rrb.TimeRangeBoundary.cursor_relative(),
|
||||
)
|
||||
],
|
||||
plot_legend=rrb.PlotLegend(visible=True),
|
||||
)
|
||||
)
|
||||
|
||||
if self._action_key:
|
||||
entity_path = f"{self.prefix}action"
|
||||
views.append(
|
||||
rrb.TimeSeriesView(
|
||||
origin=entity_path,
|
||||
name="Action",
|
||||
time_ranges=[
|
||||
rrb.VisibleTimeRange(
|
||||
"frame",
|
||||
start=rrb.TimeRangeBoundary.cursor_relative(seq=-self.idxrangeboundary),
|
||||
end=rrb.TimeRangeBoundary.cursor_relative(),
|
||||
)
|
||||
],
|
||||
plot_legend=rrb.PlotLegend(visible=True),
|
||||
)
|
||||
)
|
||||
|
||||
if not views:
|
||||
print("Warning: No visualizable components detected in the data.")
|
||||
return
|
||||
|
||||
grid = rrb.Grid(contents=views)
|
||||
rr.send_blueprint(grid)
|
||||
self.blueprint_sent = True
|
||||
|
||||
def log_step(self, step_data: dict[str, Any]):
|
||||
"""Logs a single step dictionary from your dataset."""
|
||||
if not self.blueprint_sent:
|
||||
self._initialize_from_data(step_data)
|
||||
|
||||
if self._index_key in step_data:
|
||||
current_index = step_data[self._index_key].item()
|
||||
rr.set_time_sequence("frame", current_index)
|
||||
|
||||
episode_idx = step_data.get(self._episode_index_key, torch.tensor(-1)).item()
|
||||
if episode_idx != self.current_episode:
|
||||
self.current_episode = episode_idx
|
||||
task_name = step_data.get(self._task_key, "Unknown Task")
|
||||
log_text = f"Starting Episode {self.current_episode}: {task_name}"
|
||||
rr.log(f"{self.prefix}info/task", rr.TextLog(log_text, level=rr.TextLogLevel.INFO))
|
||||
|
||||
for key in self._image_keys:
|
||||
if key in step_data:
|
||||
image_tensor = step_data[key]
|
||||
if image_tensor.ndim > 2:
|
||||
clean_name = key.replace("observation.images.", "")
|
||||
entity_path = f"{self.prefix}images/{clean_name}"
|
||||
if image_tensor.shape[0] in [1, 3, 4]:
|
||||
image_tensor = image_tensor.permute(1, 2, 0)
|
||||
rr.log(entity_path, rr.Image(image_tensor))
|
||||
|
||||
if self._state_key in step_data:
|
||||
state_tensor = step_data[self._state_key]
|
||||
entity_path = f"{self.prefix}state"
|
||||
for i, val in enumerate(state_tensor):
|
||||
rr.log(f"{entity_path}/joint_{i}", rr.Scalar(val.item()))
|
||||
|
||||
if self._action_key in step_data:
|
||||
action_tensor = step_data[self._action_key]
|
||||
entity_path = f"{self.prefix}action"
|
||||
for i, val in enumerate(action_tensor):
|
||||
rr.log(f"{entity_path}/joint_{i}", rr.Scalar(val.item()))
|
||||
|
||||
|
||||
def visualization_data(idx, observation, state, action, online_logger):
|
||||
item_data: dict[str, Any] = {
|
||||
"index": torch.tensor(idx),
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
}
|
||||
for k, v in observation.items():
|
||||
if k not in ("index", "observation.state", "action"):
|
||||
item_data[k] = v
|
||||
online_logger.log_step(item_data)
|
||||
209
eval_robot/utils/sim_savedata_utils.py
Normal file
209
eval_robot/utils/sim_savedata_utils.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# for simulation
|
||||
import torch
|
||||
import numpy as np
|
||||
import logging_mp
|
||||
from unitree_lerobot.eval_robot.utils.utils import (
|
||||
reset_policy,
|
||||
)
|
||||
from unitree_lerobot.eval_robot.make_robot import (
|
||||
publish_reset_category,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
import time
|
||||
|
||||
logging_mp.basic_config(level=logging_mp.INFO)
|
||||
logger_mp = logging_mp.get_logger(__name__)
|
||||
|
||||
|
||||
def process_data_add(episode_writer, observation_image, current_arm_q, ee_state, action, arm_dof, ee_dof):
|
||||
if episode_writer is None:
|
||||
return
|
||||
if (
|
||||
observation_image is not None
|
||||
and current_arm_q is not None
|
||||
and ee_state is not None
|
||||
and action is not None
|
||||
and arm_dof is not None
|
||||
and ee_dof is not None
|
||||
):
|
||||
# Convert tensors to numpy arrays for JSON serialization
|
||||
if torch.is_tensor(current_arm_q):
|
||||
current_arm_q = current_arm_q.detach().cpu().numpy()
|
||||
if torch.is_tensor(ee_state):
|
||||
ee_state = ee_state.detach().cpu().numpy()
|
||||
if torch.is_tensor(action):
|
||||
action = action.detach().cpu().numpy()
|
||||
colors = {}
|
||||
i = 0
|
||||
for key, value in observation_image.items():
|
||||
if "images" in key:
|
||||
if value is not None:
|
||||
# Convert PyTorch tensor to numpy array for OpenCV compatibility
|
||||
if torch.is_tensor(value):
|
||||
# Convert tensor to numpy array and ensure correct format for OpenCV
|
||||
img_array = value.detach().cpu().numpy()
|
||||
# If the image is in CHW format (channels first), convert to HWC format (channels last)
|
||||
if img_array.ndim == 3 and img_array.shape[0] in [1, 3, 4]:
|
||||
img_array = np.transpose(img_array, (1, 2, 0))
|
||||
# Ensure the array is in uint8 format for OpenCV
|
||||
if img_array.dtype != np.uint8:
|
||||
if img_array.max() <= 1.0: # Normalized values [0, 1]
|
||||
img_array = (img_array * 255).astype(np.uint8)
|
||||
else: # Values already in [0, 255] range
|
||||
img_array = img_array.astype(np.uint8)
|
||||
# Keep original RGB format - no color channel conversion needed
|
||||
colors[f"color_{i}"] = img_array
|
||||
else:
|
||||
colors[f"color_{i}"] = value
|
||||
i += 1
|
||||
states = {
|
||||
"left_arm": {
|
||||
"qpos": current_arm_q[: arm_dof // 2].tolist(), # numpy.array -> list
|
||||
"qvel": [],
|
||||
"torque": [],
|
||||
},
|
||||
"right_arm": {
|
||||
"qpos": current_arm_q[arm_dof // 2 :].tolist(),
|
||||
"qvel": [],
|
||||
"torque": [],
|
||||
},
|
||||
"left_ee": {
|
||||
"qpos": ee_state[:ee_dof].tolist(),
|
||||
"qvel": [],
|
||||
"torque": [],
|
||||
},
|
||||
"right_ee": {
|
||||
"qpos": ee_state[ee_dof:].tolist(),
|
||||
"qvel": [],
|
||||
"torque": [],
|
||||
},
|
||||
"body": {
|
||||
"qpos": [],
|
||||
},
|
||||
}
|
||||
actions = {
|
||||
"left_arm": {
|
||||
"qpos": action[: arm_dof // 2].tolist(),
|
||||
"qvel": [],
|
||||
"torque": [],
|
||||
},
|
||||
"right_arm": {
|
||||
"qpos": action[arm_dof // 2 :].tolist(),
|
||||
"qvel": [],
|
||||
"torque": [],
|
||||
},
|
||||
"left_ee": {
|
||||
"qpos": action[arm_dof : arm_dof + ee_dof].tolist(),
|
||||
"qvel": [],
|
||||
"torque": [],
|
||||
},
|
||||
"right_ee": {
|
||||
"qpos": action[arm_dof + ee_dof : arm_dof + 2 * ee_dof].tolist(),
|
||||
"qvel": [],
|
||||
"torque": [],
|
||||
},
|
||||
"body": {
|
||||
"qpos": [],
|
||||
},
|
||||
}
|
||||
episode_writer.add_item(colors, states=states, actions=actions)
|
||||
|
||||
|
||||
def process_data_save(episode_writer, result):
|
||||
"""Processes data and saves it."""
|
||||
if episode_writer is None:
|
||||
return
|
||||
episode_writer.save_episode(result)
|
||||
|
||||
|
||||
def is_success(
|
||||
sim_reward_subscriber,
|
||||
episode_writer,
|
||||
reset_pose_publisher,
|
||||
policy,
|
||||
cfg,
|
||||
reward_stats,
|
||||
init_arm_pose,
|
||||
robot_interface,
|
||||
):
|
||||
# logger_mp.info(f"arm_action {arm_action}, tau {tau}")
|
||||
if sim_reward_subscriber:
|
||||
data = sim_reward_subscriber.read_data()
|
||||
if data is not None:
|
||||
if int(data["rewards"][0]) == 1:
|
||||
reward_stats["reward_sum"] += 1
|
||||
sim_reward_subscriber.reset_data()
|
||||
# success
|
||||
if reward_stats["reward_sum"] >= 25:
|
||||
process_data_save(episode_writer, "success")
|
||||
logger_mp.info(
|
||||
f"Episode {reward_stats['episode_num']} finished with reward {reward_stats['reward_sum']},save data..."
|
||||
)
|
||||
reward_stats["episode_num"] = -1
|
||||
reward_stats["reward_sum"] = 0
|
||||
time.sleep(1)
|
||||
publish_reset_category(1, reset_pose_publisher)
|
||||
time.sleep(1)
|
||||
reset_policy(policy)
|
||||
sim_reward_subscriber.reset_data()
|
||||
# fail
|
||||
elif reward_stats["episode_num"] > cfg.max_episodes:
|
||||
process_data_save(episode_writer, "fail")
|
||||
logger_mp.info(f"Episode {reward_stats['episode_num']} finished with reward {reward_stats['reward_sum']}")
|
||||
reward_stats["episode_num"] = -1
|
||||
reward_stats["reward_sum"] = 0
|
||||
reset_policy(policy)
|
||||
sim_reward_subscriber.reset_data()
|
||||
logger_mp.info("Initializing robot to starting pose...")
|
||||
tau = robot_interface["arm_ik"].solve_tau(init_arm_pose)
|
||||
robot_interface["arm_ctrl"].ctrl_dual_arm(init_arm_pose, tau)
|
||||
time.sleep(1)
|
||||
publish_reset_category(1, reset_pose_publisher)
|
||||
time.sleep(1)
|
||||
reset_policy(policy)
|
||||
sim_reward_subscriber.reset_data()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalRealConfig:
|
||||
repo_id: str
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
root: str = ""
|
||||
episodes: int = 0
|
||||
frequency: float = 30.0
|
||||
|
||||
# Basic control parameters
|
||||
arm: str = "G1_29" # G1_29, G1_23
|
||||
ee: str = "dex3" # dex3, dex1, inspire1, brainco
|
||||
|
||||
# Mode flags
|
||||
motion: bool = False
|
||||
headless: bool = False
|
||||
sim: bool = True
|
||||
visualization: bool = False
|
||||
send_real_robot: bool = False
|
||||
use_dataset: bool = False
|
||||
save_data: bool = False
|
||||
task_dir: str = "./data"
|
||||
max_episodes: int = 1200
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
logger_mp.warning(
|
||||
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
402
eval_robot/utils/sim_state_topic.py
Normal file
402
eval_robot/utils/sim_state_topic.py
Normal file
@@ -0,0 +1,402 @@
|
||||
# Copyright (c) 2025, Unitree Robotics Co., Ltd. All Rights Reserved.
|
||||
# License: Apache License, Version 2.0
|
||||
"""
|
||||
Simple sim state subscriber class
|
||||
Subscribe to rt/sim_state_cmd topic and write to shared memory
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import json
|
||||
from multiprocessing import shared_memory
|
||||
from typing import Any
|
||||
from unitree_sdk2py.core.channel import ChannelSubscriber
|
||||
from unitree_sdk2py.idl.std_msgs.msg.dds_ import String_
|
||||
|
||||
import logging_mp
|
||||
|
||||
logger_mp = logging_mp.get_logger(__name__)
|
||||
|
||||
|
||||
class SharedMemoryManager:
|
||||
"""Shared memory manager"""
|
||||
|
||||
def __init__(self, name: str | None = None, size: int = 512):
|
||||
"""Initialize shared memory manager
|
||||
|
||||
Args:
|
||||
name: shared memory name, if None, create new one
|
||||
size: shared memory size (bytes)
|
||||
"""
|
||||
self.size = size
|
||||
self.lock = threading.RLock() # reentrant lock
|
||||
|
||||
if name:
|
||||
try:
|
||||
self.shm = shared_memory.SharedMemory(name=name)
|
||||
self.shm_name = name
|
||||
self.created = False
|
||||
except FileNotFoundError:
|
||||
self.shm = shared_memory.SharedMemory(create=True, size=size)
|
||||
self.shm_name = self.shm.name
|
||||
self.created = True
|
||||
else:
|
||||
self.shm = shared_memory.SharedMemory(create=True, size=size)
|
||||
self.shm_name = self.shm.name
|
||||
self.created = True
|
||||
|
||||
def write_data(self, data: dict[str, Any]) -> bool:
|
||||
"""Write data to shared memory
|
||||
|
||||
Args:
|
||||
data: data to write
|
||||
|
||||
Returns:
|
||||
bool: write success or not
|
||||
"""
|
||||
try:
|
||||
with self.lock:
|
||||
json_str = json.dumps(data)
|
||||
json_bytes = json_str.encode("utf-8")
|
||||
|
||||
if len(json_bytes) > self.size - 8: # reserve 8 bytes for length and timestamp
|
||||
logger_mp.warning(f"Data too large for shared memory ({len(json_bytes)} > {self.size - 8})")
|
||||
return False
|
||||
|
||||
# write timestamp (4 bytes) and data length (4 bytes)
|
||||
timestamp = int(time.time()) & 0xFFFFFFFF # 32-bit timestamp, use bitmask to ensure in range
|
||||
self.shm.buf[0:4] = timestamp.to_bytes(4, "little")
|
||||
self.shm.buf[4:8] = len(json_bytes).to_bytes(4, "little")
|
||||
|
||||
# write data
|
||||
self.shm.buf[8 : 8 + len(json_bytes)] = json_bytes
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger_mp.error(f"Error writing to shared memory: {e}")
|
||||
return False
|
||||
|
||||
def read_data(self) -> dict[str, Any] | None:
|
||||
"""Read data from shared memory
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: read data dictionary, return None if failed
|
||||
"""
|
||||
try:
|
||||
with self.lock:
|
||||
# read timestamp and data length
|
||||
timestamp = int.from_bytes(self.shm.buf[0:4], "little")
|
||||
data_len = int.from_bytes(self.shm.buf[4:8], "little")
|
||||
|
||||
if data_len == 0:
|
||||
return None
|
||||
|
||||
# read data
|
||||
json_bytes = bytes(self.shm.buf[8 : 8 + data_len])
|
||||
data = json.loads(json_bytes.decode("utf-8"))
|
||||
data["_timestamp"] = timestamp # add timestamp information
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
logger_mp.error(f"Error reading from shared memory: {e}")
|
||||
return None
|
||||
|
||||
def reset_data(self):
|
||||
"""Reset data"""
|
||||
if self.shm:
|
||||
self.shm.buf[0:8] = b"\x00" * 8
|
||||
else:
|
||||
logger_mp.error("[SharedMemoryManager] Shared memory is not initialized")
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""Get shared memory name"""
|
||||
return self.shm_name
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up shared memory"""
|
||||
if hasattr(self, "shm") and self.shm:
|
||||
self.shm.close()
|
||||
if self.created:
|
||||
self.shm.unlink()
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor"""
|
||||
self.cleanup()
|
||||
|
||||
|
||||
class SimStateSubscriber:
|
||||
"""Simple sim state subscriber class"""
|
||||
|
||||
def __init__(self, shm_name: str = "sim_state_cmd_data", shm_size: int = 4096):
|
||||
"""Initialize the subscriber
|
||||
|
||||
Args:
|
||||
shm_name: shared memory name
|
||||
shm_size: shared memory size
|
||||
"""
|
||||
self.shm_name = shm_name
|
||||
self.shm_size = shm_size
|
||||
self.running = False
|
||||
self.subscriber = None
|
||||
self.subscribe_thread = None
|
||||
self.shared_memory = None
|
||||
|
||||
# initialize shared memory
|
||||
self._setup_shared_memory()
|
||||
|
||||
logger_mp.debug(f"[SimStateSubscriber] Initialized with shared memory: {shm_name}")
|
||||
|
||||
def _setup_shared_memory(self):
|
||||
"""Setup shared memory"""
|
||||
try:
|
||||
self.shared_memory = SharedMemoryManager(self.shm_name, self.shm_size)
|
||||
logger_mp.debug("[SimStateSubscriber] Shared memory setup successfully")
|
||||
except Exception as e:
|
||||
logger_mp.error(f"[SimStateSubscriber] Failed to setup shared memory: {e}")
|
||||
|
||||
def start_subscribe(self):
|
||||
"""Start subscribing"""
|
||||
if self.running:
|
||||
logger_mp.warning("[SimStateSubscriber] Already running")
|
||||
return
|
||||
|
||||
try:
|
||||
self.subscriber = ChannelSubscriber("rt/sim_state", String_)
|
||||
self.subscriber.Init()
|
||||
self.running = True
|
||||
|
||||
self.subscribe_thread = threading.Thread(target=self._subscribe_sim_state, daemon=True)
|
||||
self.subscribe_thread.start()
|
||||
|
||||
logger_mp.info("[SimStateSubscriber] Started subscribing to rt/sim_state")
|
||||
|
||||
except Exception as e:
|
||||
logger_mp.error(f"[SimStateSubscriber] Failed to start subscribing: {e}")
|
||||
self.running = False
|
||||
|
||||
def _subscribe_sim_state(self):
|
||||
"""Subscribe loop thread"""
|
||||
logger_mp.debug("[SimStateSubscriber] Subscribe thread started")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
if self.subscriber:
|
||||
msg = self.subscriber.Read()
|
||||
if msg:
|
||||
data = json.loads(msg.data)
|
||||
else:
|
||||
logger_mp.warning("[SimStateSubscriber] Received None message")
|
||||
if self.shared_memory and data:
|
||||
self.shared_memory.write_data(data)
|
||||
else:
|
||||
logger_mp.error("[SimStateSubscriber] Subscriber is not initialized")
|
||||
time.sleep(0.002)
|
||||
except Exception as e:
|
||||
logger_mp.error(f"[SimStateSubscriber] Error in subscribe loop: {e}")
|
||||
time.sleep(0.01)
|
||||
|
||||
def stop_subscribe(self):
|
||||
"""Stop subscribing"""
|
||||
if not self.running:
|
||||
logger_mp.warning("[SimStateSubscriber] Already stopped or not running")
|
||||
return
|
||||
|
||||
self.running = False
|
||||
# wait for thread to finish
|
||||
if self.subscribe_thread:
|
||||
self.subscribe_thread.join(timeout=1.0)
|
||||
|
||||
if self.shared_memory:
|
||||
self.shared_memory.cleanup()
|
||||
logger_mp.info("[SimStateSubscriber] Subscriber stopped")
|
||||
|
||||
def read_data(self) -> dict[str, Any] | None:
|
||||
"""Read data from shared memory
|
||||
|
||||
Returns:
|
||||
Dict: received data, None if no data or error
|
||||
"""
|
||||
try:
|
||||
if self.shared_memory:
|
||||
return self.shared_memory.read_data()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger_mp.error(f"[SimStateSubscriber] Error reading data: {e}")
|
||||
return None
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Check if subscriber is running"""
|
||||
return self.running
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor"""
|
||||
self.stop_subscribe()
|
||||
|
||||
|
||||
def start_sim_state_subscribe(shm_name: str = "sim_state_cmd_data", shm_size: int = 4096) -> SimStateSubscriber:
|
||||
"""Start sim state subscribing
|
||||
|
||||
Args:
|
||||
shm_name: shared memory name
|
||||
shm_size: shared memory size
|
||||
|
||||
Returns:
|
||||
SimStateSubscriber: started subscriber instance
|
||||
"""
|
||||
subscriber = SimStateSubscriber(shm_name, shm_size)
|
||||
subscriber.start_subscribe()
|
||||
return subscriber
|
||||
|
||||
|
||||
# ============================== sim reward topic ==============================
|
||||
class SimRewardSubscriber:
|
||||
"""Simple sim state subscriber class"""
|
||||
|
||||
def __init__(self, shm_name: str = "sim_reward_cmd_data", shm_size: int = 256):
|
||||
"""Initialize the subscriber
|
||||
|
||||
Args:
|
||||
shm_name: shared memory name
|
||||
shm_size: shared memory size
|
||||
"""
|
||||
self.shm_name = shm_name
|
||||
self.shm_size = shm_size
|
||||
self.running = False
|
||||
self.subscriber = None
|
||||
self.subscribe_thread = None
|
||||
self.shared_memory = None
|
||||
|
||||
# initialize shared memory
|
||||
self._setup_shared_memory()
|
||||
|
||||
logger_mp.debug(f"[SimRewardSubscriber] Initialized with shared memory: {shm_name}")
|
||||
|
||||
def _setup_shared_memory(self):
|
||||
"""Setup shared memory"""
|
||||
try:
|
||||
self.shared_memory = SharedMemoryManager(self.shm_name, self.shm_size)
|
||||
logger_mp.debug("[SimRewardSubscriber] Shared memory setup successfully")
|
||||
except Exception as e:
|
||||
logger_mp.error(f"[SimRewardSubscriber] Failed to setup shared memory: {e}")
|
||||
|
||||
def start_subscribe(self):
|
||||
"""Start subscribing"""
|
||||
if self.running:
|
||||
logger_mp.warning("[SimRewardSubscriber] Already running")
|
||||
return
|
||||
|
||||
try:
|
||||
self.subscriber = ChannelSubscriber("rt/rewards_state", String_)
|
||||
self.subscriber.Init()
|
||||
self.running = True
|
||||
|
||||
self.subscribe_thread = threading.Thread(target=self._subscribe_sim_reward, daemon=True)
|
||||
self.subscribe_thread.start()
|
||||
|
||||
logger_mp.info("[SimRewardSubscriber] Started subscribing to rt/sim_reward")
|
||||
|
||||
except Exception as e:
|
||||
logger_mp.error(f"[SimRewardSubscriber] Failed to start subscribing: {e}")
|
||||
self.running = False
|
||||
|
||||
def _subscribe_sim_reward(self):
|
||||
"""Subscribe loop thread"""
|
||||
logger_mp.debug("[SimRewardSubscriber] Subscribe thread started")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
if self.subscriber:
|
||||
msg = self.subscriber.Read()
|
||||
if msg:
|
||||
data = json.loads(msg.data)
|
||||
else:
|
||||
logger_mp.warning("[SimRewardSubscriber] Received None message")
|
||||
if self.shared_memory and data:
|
||||
self.shared_memory.write_data(data)
|
||||
else:
|
||||
logger_mp.error("[SimRewardSubscriber] Subscriber is not initialized")
|
||||
time.sleep(0.01)
|
||||
except Exception as e:
|
||||
logger_mp.error(f"[SimRewardSubscriber] Error in subscribe loop: {e}")
|
||||
time.sleep(0.02)
|
||||
|
||||
def stop_subscribe(self):
|
||||
"""Stop subscribing"""
|
||||
if not self.running:
|
||||
logger_mp.warning("[SimRewardSubscriber] Already stopped or not running")
|
||||
return
|
||||
|
||||
self.running = False
|
||||
# wait for thread to finish
|
||||
if self.subscribe_thread:
|
||||
self.subscribe_thread.join(timeout=1.0)
|
||||
|
||||
if self.shared_memory:
|
||||
self.shared_memory.cleanup()
|
||||
logger_mp.info("[SimRewardSubscriber] Subscriber stopped")
|
||||
|
||||
def read_data(self) -> dict[str, Any] | None:
|
||||
"""Read data from shared memory
|
||||
|
||||
Returns:
|
||||
Dict: received data, None if no data or error
|
||||
"""
|
||||
try:
|
||||
if self.shared_memory:
|
||||
return self.shared_memory.read_data()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger_mp.error(f"[SimRewardSubscriber] Error reading data: {e}")
|
||||
return None
|
||||
|
||||
def reset_data(self):
|
||||
"""Reset data"""
|
||||
if self.shared_memory:
|
||||
data = {"rewards": [0.0], "timestamp": 1758009108.266387}
|
||||
self.shared_memory.write_data(data)
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Check if subscriber is running"""
|
||||
return self.running
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor"""
|
||||
self.stop_subscribe()
|
||||
|
||||
|
||||
# ============================== sim reward topic ==============================
|
||||
def start_sim_reward_subscribe(shm_name: str = "sim_reward_cmd_data", shm_size: int = 256) -> SimRewardSubscriber:
|
||||
"""Start sim reward subscribing
|
||||
|
||||
Args:
|
||||
shm_name: shared memory name
|
||||
shm_size: shared memory size
|
||||
|
||||
Returns:
|
||||
SimRewardSubscriber: started subscriber instance
|
||||
"""
|
||||
subscriber = SimRewardSubscriber(shm_name, shm_size)
|
||||
subscriber.start_subscribe()
|
||||
return subscriber
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# # example usage
|
||||
# logger_mp.info("Starting sim state subscriber...")
|
||||
# ChannelFactoryInitialize(0)
|
||||
# # create and start subscriber
|
||||
# subscriber = start_sim_state_subscribe()
|
||||
|
||||
# try:
|
||||
# # keep running and check for data
|
||||
# while True:
|
||||
# data = subscriber.read_data()
|
||||
# if data:
|
||||
# logger_mp.info(f"Read data: {data}")
|
||||
# time.sleep(1)
|
||||
|
||||
# except KeyboardInterrupt:
|
||||
# logger_mp.warning("\nInterrupted by user")
|
||||
# finally:
|
||||
# subscriber.stop_subscribe()
|
||||
# logger_mp.info("Subscriber stopped")
|
||||
142
eval_robot/utils/utils.py
Normal file
142
eval_robot/utils/utils.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Any
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
import logging_mp
|
||||
|
||||
logging_mp.basic_config(level=logging_mp.INFO)
|
||||
logger_mp = logging_mp.get_logger(__name__)
|
||||
|
||||
|
||||
def extract_observation(step: dict):
|
||||
observation = {}
|
||||
|
||||
for key, value in step.items():
|
||||
if key.startswith("observation.images."):
|
||||
if isinstance(value, np.ndarray) and value.ndim == 3 and value.shape[-1] in [1, 3]:
|
||||
value = np.transpose(value, (2, 0, 1))
|
||||
observation[key] = value
|
||||
|
||||
elif key == "observation.state":
|
||||
observation[key] = value
|
||||
|
||||
return observation
|
||||
|
||||
|
||||
def predict_action(
|
||||
observation: dict[str, np.ndarray],
|
||||
policy: PreTrainedPolicy,
|
||||
device: torch.device,
|
||||
use_amp: bool,
|
||||
task: str | None = None,
|
||||
use_dataset: bool | None = False,
|
||||
):
|
||||
observation = copy(observation)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
if not use_dataset:
|
||||
# Skip non-tensor observations (like task strings)
|
||||
if not hasattr(observation[name], "unsqueeze"):
|
||||
continue
|
||||
if "images" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
|
||||
observation[name] = observation[name].unsqueeze(0).to(device)
|
||||
|
||||
observation["task"] = [task if task else ""]
|
||||
|
||||
# Compute the next action with the policy
|
||||
# based on the current observation
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# Remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
# Move to cpu, if not already the case
|
||||
action = action.to("cpu")
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def reset_policy(policy: PreTrainedPolicy):
|
||||
policy.reset()
|
||||
|
||||
|
||||
def cleanup_resources(image_info: dict[str, Any]):
|
||||
"""Safely close and unlink shared memory resources."""
|
||||
logger_mp.info("Cleaning up shared memory resources.")
|
||||
for shm in image_info["shm_resources"]:
|
||||
if shm:
|
||||
shm.close()
|
||||
shm.unlink()
|
||||
|
||||
|
||||
def to_list(x):
|
||||
if torch is not None and isinstance(x, torch.Tensor):
|
||||
return x.detach().cpu().ravel().tolist()
|
||||
if isinstance(x, np.ndarray):
|
||||
return x.ravel().tolist()
|
||||
if isinstance(x, (list, tuple)):
|
||||
return list(x)
|
||||
return [x]
|
||||
|
||||
|
||||
def to_scalar(x):
|
||||
if torch is not None and isinstance(x, torch.Tensor):
|
||||
return float(x.detach().cpu().ravel()[0].item())
|
||||
if isinstance(x, np.ndarray):
|
||||
return float(x.ravel()[0])
|
||||
if isinstance(x, (list, tuple)):
|
||||
return float(x[0])
|
||||
return float(x)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalRealConfig:
|
||||
repo_id: str
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
root: str = ""
|
||||
episodes: int = 0
|
||||
frequency: float = 30.0
|
||||
|
||||
# Basic control parameters
|
||||
arm: str = "G1_29" # G1_29, G1_23
|
||||
ee: str = "dex3" # dex3, dex1, inspire1, brainco
|
||||
|
||||
# Mode flags
|
||||
motion: bool = False
|
||||
headless: bool = False
|
||||
visualization: bool = False
|
||||
send_real_robot: bool = False
|
||||
use_dataset: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
logging.warning(
|
||||
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
99
eval_robot/utils/weighted_moving_filter.py
Normal file
99
eval_robot/utils/weighted_moving_filter.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class WeightedMovingFilter:
|
||||
def __init__(self, weights, data_size=14):
|
||||
self._window_size = len(weights)
|
||||
self._weights = np.array(weights)
|
||||
# assert np.isclose(np.sum(self._weights), 1.0), "[WeightedMovingFilter] the sum of weights list must be 1.0!"
|
||||
self._data_size = data_size
|
||||
self._filtered_data = np.zeros(self._data_size)
|
||||
self._data_queue = []
|
||||
|
||||
def _apply_filter(self):
|
||||
if len(self._data_queue) < self._window_size:
|
||||
return self._data_queue[-1]
|
||||
|
||||
data_array = np.array(self._data_queue)
|
||||
temp_filtered_data = np.zeros(self._data_size)
|
||||
for i in range(self._data_size):
|
||||
temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1]
|
||||
|
||||
return temp_filtered_data
|
||||
|
||||
def add_data(self, new_data):
|
||||
assert len(new_data) == self._data_size
|
||||
|
||||
if len(self._data_queue) > 0 and np.array_equal(new_data, self._data_queue[-1]):
|
||||
return # skip duplicate data
|
||||
|
||||
if len(self._data_queue) >= self._window_size:
|
||||
self._data_queue.pop(0)
|
||||
|
||||
self._data_queue.append(new_data)
|
||||
self._filtered_data = self._apply_filter()
|
||||
|
||||
@property
|
||||
def filtered_data(self):
|
||||
return self._filtered_data
|
||||
|
||||
|
||||
def visualize_filter_comparison(filter_params, steps):
|
||||
import time
|
||||
|
||||
t = np.linspace(0, 4 * np.pi, steps)
|
||||
original_data = np.array(
|
||||
[np.sin(t + i) + np.random.normal(0, 0.2, len(t)) for i in range(35)]
|
||||
).T # sin wave with noise, shape is [len(t), 35]
|
||||
|
||||
plt.figure(figsize=(14, 10))
|
||||
|
||||
for idx, weights in enumerate(filter_params):
|
||||
filter = WeightedMovingFilter(weights, 14)
|
||||
data_2b_filtered = original_data.copy()
|
||||
filtered_data = []
|
||||
|
||||
time1 = time.time()
|
||||
|
||||
for i in range(steps):
|
||||
filter.add_data(data_2b_filtered[i][13:27]) # step i, columns 13 to 26 (total:14)
|
||||
data_2b_filtered[i][13:27] = filter.filtered_data
|
||||
filtered_data.append(data_2b_filtered[i])
|
||||
|
||||
time2 = time.time()
|
||||
print(f"filter_params:{filter_params[idx]}, time cosume:{time2 - time1}")
|
||||
|
||||
filtered_data = np.array(filtered_data)
|
||||
|
||||
# col0 should not 2b filtered
|
||||
plt.subplot(len(filter_params), 2, idx * 2 + 1)
|
||||
plt.plot(filtered_data[:, 0], label=f"Filtered (Window {filter._window_size})")
|
||||
plt.plot(original_data[:, 0], "r--", label="Original", alpha=0.5)
|
||||
plt.title("Joint 1 - Should not to be filtered.")
|
||||
plt.xlabel("Step")
|
||||
plt.ylabel("Value")
|
||||
plt.legend()
|
||||
|
||||
# col13 should 2b filtered
|
||||
plt.subplot(len(filter_params), 2, idx * 2 + 2)
|
||||
plt.plot(filtered_data[:, 13], label=f"Filtered (Window {filter._window_size})")
|
||||
plt.plot(original_data[:, 13], "r--", label="Original", alpha=0.5)
|
||||
plt.title(f"Joint 13 - Window {filter._window_size}, Weights {weights}")
|
||||
plt.xlabel("Step")
|
||||
plt.ylabel("Value")
|
||||
plt.legend()
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# windows_size and weights
|
||||
filter_params = [
|
||||
(np.array([0.7, 0.2, 0.1])),
|
||||
(np.array([0.5, 0.3, 0.2])),
|
||||
(np.array([0.4, 0.3, 0.2, 0.1])),
|
||||
]
|
||||
|
||||
visualize_filter_comparison(filter_params, steps=100)
|
||||
Reference in New Issue
Block a user