Files
lerobot-clone/lerobot/scripts/server/gym_manipulator.py
2025-04-25 16:34:54 +02:00

1505 lines
54 KiB
Python

import logging
import sys
import time
from collections import deque
from threading import Lock
from typing import Annotated, Any, Dict, Sequence, Tuple
import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.robot_devices.control_utils import (
busy_wait,
is_headless,
reset_follower_position,
)
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.common.utils.utils import log_say
from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics
# Configure the root logger at import time with INFO as the minimum level.
logging.basicConfig(level=logging.INFO)
# Upper bound of the gripper joint command. The gripper wrappers in this file use it to
# normalise gripper readings and to clip gripper targets to [0, MAX_GRIPPER_COMMAND].
# NOTE(review): presumably in the same units as the motor bus "Present_Position" — confirm.
MAX_GRIPPER_COMMAND = 40
class RobotEnv(gym.Env):
    """
    Gym-compatible environment around a physical robot interface.

    Actions are forwarded to ``robot.send_action`` and observations come straight from
    ``robot.capture_observation``. The observation and action spaces are derived from a
    sample observation and from the number of joints reported by the "main" follower arm.
    This class computes no reward or termination signal itself: ``step`` always returns
    ``0.0`` / ``False`` / ``False`` for those fields.
    """

    def __init__(
        self,
        robot,
        display_cameras: bool = False,
    ):
        """
        Initialize the RobotEnv environment.

        Args:
            robot: The robot interface object used to connect to and interact with the physical robot.
            display_cameras (bool): If True, the robot's camera feeds are displayed after every step.
        """
        super().__init__()

        self.robot = robot
        self.display_cameras = display_cameras

        # Connect to the robot if not already connected.
        if not self.robot.is_connected:
            self.robot.connect()

        # Episode tracking.
        self.current_step = 0
        self.episode_data = None

        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")

        self._setup_spaces()

    def _setup_spaces(self):
        """
        Configure the observation and action spaces from a sample observation.

        Observation space:
            - Every key containing "image": a uint8 Box in [0, 255] with the sample's shape.
            - "observation.state": a float32 Box with the sample's shape.
        Action space:
            - A float32 Box in [-1000, 1000] with one entry per follower joint.
        """
        example_obs = self.robot.capture_observation()

        # Define observation spaces for images and other states.
        image_keys = [key for key in example_obs if "image" in key]
        observation_spaces = {
            key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
            for key in image_keys
        }
        # NOTE(review): the [0, 10] state bounds look like placeholders — confirm the intended range.
        observation_spaces["observation.state"] = gym.spaces.Box(
            low=0,
            high=10,
            shape=example_obs["observation.state"].shape,
            dtype=np.float32,
        )

        self.observation_space = gym.spaces.Dict(observation_spaces)

        # Define the action space for joint positions.
        action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
        bounds = {}
        bounds["min"] = np.ones(action_dim) * -1000
        bounds["max"] = np.ones(action_dim) * 1000

        self.action_space = gym.spaces.Box(
            low=bounds["min"],
            high=bounds["max"],
            shape=(action_dim,),
            dtype=np.float32,
        )

    def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
        """
        Reset the step counter and episodic data and return a fresh observation.

        Args:
            seed (Optional[int]): Seed forwarded to ``gym.Env.reset``.
            options (Optional[dict]): Options forwarded to ``gym.Env.reset``.

        Returns:
            A tuple containing:
                - observation (dict): The observation captured from the robot.
                - info (dict): ``{"is_intervention": False}``.
        """
        super().reset(seed=seed, options=options)

        # Capture the initial observation.
        observation = self.robot.capture_observation()

        # Reset episode tracking variables.
        self.current_step = 0
        self.episode_data = None

        return observation, {"is_intervention": False}

    def step(self, action) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
        """
        Send ``action`` to the robot and return the resulting observation.

        Args:
            action (np.ndarray or torch.Tensor): The commanded joint positions.

        Returns:
            tuple: A tuple containing:
                - observation (dict): The observation captured after sending the action.
                - reward (float): Always 0.0 in this class.
                - terminated (bool): Always False in this class.
                - truncated (bool): Always False in this class.
                - info (dict): ``{"is_intervention": False}``.
        """
        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")

        # torch.from_numpy rejects tensors, yet tensors are part of the documented contract,
        # so only array-likes are converted.
        if isinstance(action, torch.Tensor):
            action_tensor = action
        else:
            action_tensor = torch.from_numpy(np.asarray(action))
        self.robot.send_action(action_tensor)

        observation = self.robot.capture_observation()

        if self.display_cameras:
            self.render()

        self.current_step += 1

        reward = 0.0
        terminated = False
        truncated = False

        return (
            observation,
            reward,
            terminated,
            truncated,
            {"is_intervention": False},
        )

    def render(self):
        """
        Display every image stream of a freshly captured observation in an OpenCV window.
        """
        import cv2

        observation = self.robot.capture_observation()
        image_keys = [key for key in observation if "image" in key]

        for key in image_keys:
            # The conversion assumes the frames are RGB, since OpenCV displays BGR.
            cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

    def close(self):
        """
        Disconnect the robot if it is currently connected.
        """
        if self.robot.is_connected:
            self.robot.disconnect()
class AddJointVelocityToObservation(gym.ObservationWrapper):
    """
    Append finite-difference velocities to "observation.state".

    The velocity is ``(state - previous_state) / dt`` with ``dt = 1 / fps``; the previous
    state starts as zeros of length ``num_dof``.
    """

    def __init__(self, env, joint_velocity_limits=100.0, fps=30, num_dof=6):
        """
        Args:
            env: The environment to wrap.
            joint_velocity_limits: Symmetric bound used for the velocity entries of the space.
            fps: Control frequency; its inverse is the time step of the finite difference.
            num_dof: Number of velocity entries appended to the state.
        """
        super().__init__(env)

        # Extend observation space to include joint velocities
        old_low = self.observation_space["observation.state"].low
        old_high = self.observation_space["observation.state"].high
        old_shape = self.observation_space["observation.state"].shape

        self.last_joint_positions = np.zeros(num_dof)

        new_low = np.concatenate([old_low, np.ones(num_dof) * -joint_velocity_limits])
        new_high = np.concatenate([old_high, np.ones(num_dof) * joint_velocity_limits])

        new_shape = (old_shape[0] + num_dof,)

        self.observation_space["observation.state"] = gym.spaces.Box(
            low=new_low,
            high=new_high,
            shape=new_shape,
            dtype=np.float32,
        )

        self.dt = 1.0 / fps

    def observation(self, observation):
        """Return ``observation`` with the velocities concatenated onto "observation.state"."""
        state = observation["observation.state"]
        # On the first call the stored value is a float64 NumPy array; bring it to the
        # state's dtype and device so the subtraction never mixes array types or
        # promotes the result to float64. Later calls already hold a matching tensor.
        previous = torch.as_tensor(self.last_joint_positions, dtype=state.dtype, device=state.device)
        joint_velocities = (state - previous) / self.dt
        self.last_joint_positions = state.clone()
        observation["observation.state"] = torch.cat([state, joint_velocities], dim=-1)
        return observation
class AddCurrentToObservation(gym.ObservationWrapper):
    """Append the follower arm's "Present_Current" readings to "observation.state"."""

    def __init__(self, env, max_current=500, num_dof=6):
        """
        Args:
            env: The environment to wrap.
            max_current: Upper bound used for the current entries of the space (lower bound is 0).
            num_dof: Number of current entries appended to the state.
        """
        super().__init__(env)

        # Grow the state space by ``num_dof`` entries bounded by [0, max_current].
        state_space = self.observation_space["observation.state"]
        lower = np.concatenate([state_space.low, np.zeros(num_dof)])
        upper = np.concatenate([state_space.high, np.ones(num_dof) * max_current])

        self.observation_space["observation.state"] = gym.spaces.Box(
            low=lower,
            high=upper,
            shape=(state_space.shape[0] + num_dof,),
            dtype=np.float32,
        )

    def observation(self, observation):
        """Read the motor currents and concatenate them onto "observation.state"."""
        follower = self.unwrapped.robot.follower_arms["main"]
        currents = follower.read("Present_Current").astype(np.float32)
        extended_state = torch.cat(
            [observation["observation.state"], torch.from_numpy(currents)],
            dim=-1,
        )
        observation["observation.state"] = extended_state
        return observation
class RewardWrapper(gym.Wrapper):
    """Replace the environment reward with the prediction of a trained reward classifier."""

    def __init__(self, env, reward_classifier, device: torch.device | str = "cuda"):
        """
        Wrapper to add reward prediction to the environment; it uses a trained classifier.

        Args:
            env: The environment to wrap
            reward_classifier: The reward classifier model, or None to disable prediction
            device: The device to run the model on (a ``torch.device`` or a device string)
        """
        # gym.Wrapper keeps internal state that is only set up by its own __init__.
        super().__init__(env)
        # Normalise to torch.device: ``step`` reads ``self.device.type``, which a plain
        # string such as the default "cuda" does not have.
        self.device = torch.device(device)
        if reward_classifier is None:
            # ``step`` already tolerates a missing classifier, so the constructor must too.
            self.reward_classifier = None
        else:
            self.reward_classifier = torch.compile(reward_classifier)
            self.reward_classifier.to(self.device)

    def step(self, action):
        """
        Step the wrapped environment and score the resulting images.

        The wrapped environment's reward is discarded. The returned reward is 1.0 and
        ``terminated`` is set when the classifier reports success, otherwise the reward is 0.0.
        ``info["Reward classifier frequency"]`` holds the inverse of the inference time.
        """
        observation, _, terminated, truncated, info = self.env.step(action)
        images = {
            key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
            for key in observation
            if "image" in key
        }
        start_time = time.perf_counter()
        with torch.inference_mode():
            success = (
                self.reward_classifier.predict_reward(images, threshold=0.8)
                if self.reward_classifier is not None
                else 0.0
            )
        elapsed = time.perf_counter() - start_time
        info["Reward classifier frequency"] = 1 / elapsed if elapsed > 0 else float("inf")

        # Always bind ``reward``: it was previously undefined whenever success != 1.0.
        reward = 0.0
        if success == 1.0:
            terminated = True
            reward = 1.0

        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Reset the wrapped environment."""
        return self.env.reset(seed=seed, options=options)
class TimeLimitWrapper(gym.Wrapper):
    """End the episode after ``control_time_s * fps`` steps and track elapsed wall-clock time."""

    def __init__(self, env, control_time_s, fps):
        """
        Args:
            env: The environment to wrap.
            control_time_s: Episode duration in seconds.
            fps: Expected control frequency; with ``control_time_s`` it gives the step budget.
        """
        # gym.Wrapper keeps internal state that is only set up by its own __init__.
        super().__init__(env)
        self.control_time_s = control_time_s
        self.fps = fps

        self.last_timestamp = 0.0
        self.episode_time_in_s = 0.0

        self.max_episode_steps = int(self.control_time_s * self.fps)

        self.current_step = 0

    def step(self, action):
        """
        Step the wrapped environment and set ``terminated`` once the step budget is used up.

        NOTE(review): the limit is reported through ``terminated`` rather than ``truncated``;
        kept as-is because callers may depend on it.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        time_since_last_step = time.perf_counter() - self.last_timestamp
        self.episode_time_in_s += time_since_last_step
        self.last_timestamp = time.perf_counter()
        self.current_step += 1
        # check if last timestep took more time than the expected fps
        # (written as a product so a zero-length interval cannot divide by zero)
        if time_since_last_step * self.fps > 1.0:
            logging.debug(f"Current timestep exceeded expected fps {self.fps}")

        if self.current_step >= self.max_episode_steps:
            terminated = True
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Restart the time and step counters, then reset the wrapped environment."""
        self.episode_time_in_s = 0.0
        self.last_timestamp = time.perf_counter()
        self.current_step = 0
        return self.env.reset(seed=seed, options=options)
class ImageCropResizeWrapper(gym.Wrapper):
    """Crop and resize selected image observations on every ``step`` and ``reset``."""

    def __init__(
        self,
        env,
        crop_params_dict: Dict[str, Annotated[Tuple[int], 4]],
        resize_size=None,
    ):
        """
        Args:
            env: The environment to wrap.
            crop_params_dict: Maps an observation key to the four positional crop arguments
                passed to ``torchvision.transforms.functional.crop``.
            resize_size: Target ``(height, width)``; defaults to ``(128, 128)`` when None.

        Raises:
            ValueError: If a key of ``crop_params_dict`` is not in the observation space.
        """
        super().__init__(env)
        self.env = env
        self.crop_params_dict = crop_params_dict
        # Resolve the default first: the observation space below needs the final size,
        # and indexing ``None`` used to raise a TypeError here.
        self.resize_size = (128, 128) if resize_size is None else resize_size

        print(f"obs_keys , {self.env.observation_space}")
        print(f"crop params dict {crop_params_dict.keys()}")
        for key_crop in crop_params_dict:
            if key_crop not in self.env.observation_space.keys():  # noqa: SIM118
                raise ValueError(f"Key {key_crop} not in observation space")

        for key in crop_params_dict:
            new_shape = (3, self.resize_size[0], self.resize_size[1])
            self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)

    def _crop_and_resize(self, obs):
        """Crop, resize and clamp to [0, 1] every configured image of ``obs`` in place."""
        for k in self.crop_params_dict:
            device = obs[k].device
            # MPS tensors are processed on the CPU and moved back afterwards.
            # NOTE(review): presumably crop/resize is unsupported or unreliable on MPS — confirm.
            if device.type == "mps":
                obs[k] = obs[k].cpu()

            obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
            obs[k] = F.resize(obs[k], self.resize_size)
            # TODO (michel-aractingi): Bug in resize, it returns values outside [0, 1]
            obs[k] = obs[k].clamp(0.0, 1.0)
            obs[k] = obs[k].to(device)
        return obs

    def step(self, action):
        """Step the wrapped environment, warn about flat images, then crop and resize."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        for k in self.crop_params_dict:
            if obs[k].dim() >= 3:
                # Reshape to combine height and width dimensions for easier calculation
                batch_size = obs[k].size(0)
                channels = obs[k].size(1)
                flattened_spatial_dims = obs[k].view(batch_size, channels, -1)

                # Calculate standard deviation across spatial dimensions (H, W)
                # If any channel has std=0, all pixels in that channel have the same value
                # This is helpful if one camera mistakenly covered or the image is black
                std_per_channel = torch.std(flattened_spatial_dims, dim=2)
                if (std_per_channel <= 0.02).any():
                    logging.warning(
                        f"Potential hardware issue detected: All pixels have the same value in observation {k}"
                    )

        obs = self._crop_and_resize(obs)
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Reset the wrapped environment and crop/resize the initial observation."""
        obs, info = self.env.reset(seed=seed, options=options)
        obs = self._crop_and_resize(obs)
        return obs, info
class ConvertToLeRobotObservation(gym.ObservationWrapper):
    """Cast observations to float, put images in channel-first [0, 1] form, and move them to a device."""

    def __init__(self, env, device: str = "cpu"):
        """
        Args:
            env: The environment to wrap.
            device: Name of the torch device the observation tensors are moved to.
        """
        super().__init__(env)
        self.device = torch.device(device)

    def observation(self, observation):
        """
        Convert every entry to float; image entries are permuted with ``(2, 0, 1)`` and
        divided by 255. The incoming dict is updated in place before a new dict holding
        the tensors on ``self.device`` is returned.
        """
        for name in observation:
            tensor = observation[name].float()
            if "image" in name:
                tensor = tensor.permute(2, 0, 1)
                tensor /= 255.0
            observation[name] = tensor

        use_non_blocking = self.device.type == "cuda"
        moved = {}
        for name in observation:
            moved[name] = observation[name].to(self.device, non_blocking=use_non_blocking)
        return moved
class ResetWrapper(gym.Wrapper):
    """Bring the robot back to a start configuration before each episode."""

    def __init__(
        self,
        env: RobotEnv,
        reset_pose: np.ndarray | None = None,
        reset_time_s: float = 5,
    ):
        """
        Args:
            env: The environment to wrap; its unwrapped ``robot`` attribute is used.
            reset_pose: Pose handed to ``reset_follower_position``; when None the operator
                resets the scene manually through teleoperation.
            reset_time_s: Duration in seconds of the reset phase.
        """
        super().__init__(env)
        self.reset_time_s = reset_time_s
        self.reset_pose = reset_pose

        self.robot = self.unwrapped.robot

    def reset(self, *, seed=None, options=None):
        """
        Run the reset routine, wait out the rest of ``reset_time_s``, then reset the env.

        With a ``reset_pose`` the follower (and the leader, if one exists) is driven to it.
        Without one, ``robot.teleop_step`` is called in a loop for ``reset_time_s`` seconds.
        """
        start_time = time.perf_counter()
        if self.reset_pose is not None:
            log_say("Reset the environment.", play_sounds=True)
            reset_follower_position(self.robot.follower_arms["main"], self.reset_pose)
            log_say("Reset the environment done.", play_sounds=True)

            if len(self.robot.leader_arms) > 0:
                self.robot.leader_arms["main"].write("Torque_Enable", 1)
                log_say("Reset the leader robot.", play_sounds=True)
                reset_follower_position(self.robot.leader_arms["main"], self.reset_pose)
                log_say("Reset the leader robot done.", play_sounds=True)
        else:
            log_say(
                f"Manually reset the environment for {self.reset_time_s} seconds.",
                play_sounds=True,
            )
            start_time = time.perf_counter()
            while time.perf_counter() - start_time < self.reset_time_s:
                self.robot.teleop_step()

            log_say("Manual reset of the environment done.", play_sounds=True)

        # The remaining time is never positive after a manual reset and can be negative
        # when the automatic reset overruns; only wait when there is time left.
        remaining_s = self.reset_time_s - (time.perf_counter() - start_time)
        if remaining_s > 0:
            busy_wait(remaining_s)

        return super().reset(seed=seed, options=options)
class BatchCompatibleWrapper(gym.ObservationWrapper):
    """Prepend a batch dimension to unbatched image, state and velocity observations."""

    def __init__(self, env):
        super().__init__(env)

    def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Unsqueeze dimension 0 of 3-D entries whose key contains "image" and of 1-D entries
        whose key contains "state" or "velocity"; all other entries are left unchanged.
        """
        for name, tensor in observation.items():
            rank = tensor.dim()
            is_unbatched_image = "image" in name and rank == 3
            is_unbatched_vector = rank == 1 and ("state" in name or "velocity" in name)
            if is_unbatched_image or is_unbatched_vector:
                observation[name] = tensor.unsqueeze(0)
        return observation
class GripperPenaltyWrapper(gym.RewardWrapper):
    """Report a penalty when the gripper is commanded to open or close from the opposite state."""

    def __init__(self, env, penalty: float = -0.1):
        """
        Args:
            env: The environment to wrap.
            penalty: Value added to the reward when the penalty condition holds.
        """
        super().__init__(env)
        self.penalty = penalty
        self.last_gripper_state = None

    def reward(self, reward, action):
        """
        Return ``reward`` plus ``self.penalty`` when the penalty condition holds.

        Args:
            reward: The reward to start from.
            action: The gripper command; it is shifted by -1.0 before being compared.

        The condition holds when the normalised gripper reading is below 0.5 and the shifted
        command is above 0.5, or the reading is above 0.75 and the shifted command is below -0.5.
        ``self.last_gripper_state`` must have been set by ``step`` beforehand.
        """
        gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND

        action_normalized = action - 1.0  # action / MAX_GRIPPER_COMMAND

        gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
            gripper_state_normalized > 0.75 and action_normalized < -0.5
        )

        return reward + self.penalty * int(gripper_penalty_bool)

    def step(self, action):
        """
        Step the wrapped environment and store the penalised value in ``info["discrete_penalty"]``.

        The reward returned to the caller is the wrapped environment's reward, unchanged.
        NOTE(review): ``info["discrete_penalty"]`` holds ``reward + penalty`` rather than the
        penalty alone — confirm this is what consumers expect.
        """
        # Read the gripper joint before the action moves it.
        self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]

        gripper_action = action[-1]
        obs, reward, terminated, truncated, info = self.env.step(action)
        gripper_penalty = self.reward(reward, gripper_action)

        info["discrete_penalty"] = gripper_penalty

        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Forget the last gripper reading and reset the wrapped environment."""
        self.last_gripper_state = None
        obs, info = super().reset(**kwargs)
        info["gripper_penalty"] = 0.0
        # ``step`` reports under "discrete_penalty"; seed that key too so the info dict has
        # the same keys after reset and after step. The old key is kept for compatibility.
        info["discrete_penalty"] = 0.0
        return obs, info
class GripperActionWrapper(gym.ActionWrapper):
    """Turn the last action entry (a gripper command in [0, 2]) into an absolute gripper target."""

    def __init__(self, env, quantization_threshold: float = 0.2, gripper_sleep: float = 0.0):
        """
        Args:
            env: The environment to wrap.
            quantization_threshold: Centred commands whose magnitude is not above this value
                become 0; pass None to skip quantization.
            gripper_sleep: Minimum time in seconds during which the previous gripper command
                is repeated; 0 disables the hold.
        """
        super().__init__(env)
        self.quantization_threshold = quantization_threshold
        self.gripper_sleep = gripper_sleep
        self.last_gripper_action_time = 0.0
        self.last_gripper_action = None

    def action(self, action):
        """
        Rewrite ``action[-1]`` in place and return ``action``.

        The command is centred by subtracting 1, optionally quantized to -1, 0 or 1, scaled by
        ``MAX_GRIPPER_COMMAND``, added to the present gripper position and clipped to
        ``[0, MAX_GRIPPER_COMMAND]``.
        """
        if self.gripper_sleep > 0.0:
            within_hold_window = (
                self.last_gripper_action is not None
                and time.perf_counter() - self.last_gripper_action_time < self.gripper_sleep
            )
            if within_hold_window:
                # Repeat the previous command until the hold time has elapsed.
                action[-1] = self.last_gripper_action
            else:
                self.last_gripper_action_time = time.perf_counter()
                self.last_gripper_action = action[-1]

        # Commands arrive in [0, 2]; shift them so that 0 means "no change".
        centered = action[-1] - 1.0

        if self.quantization_threshold is not None:
            if abs(centered) > self.quantization_threshold:
                centered = np.sign(centered)
            else:
                centered = 0.0

        delta = centered * MAX_GRIPPER_COMMAND
        present = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
        target = np.clip(present + delta, 0, MAX_GRIPPER_COMMAND)
        action[-1] = target.item()
        return action

    def reset(self, **kwargs):
        """Reset the wrapped environment and clear the held gripper command."""
        result = super().reset(**kwargs)
        obs, info = result
        self.last_gripper_action_time = 0.0
        self.last_gripper_action = None
        return obs, info
class EEActionWrapper(gym.ActionWrapper):
    """Convert end-effector position deltas into joint targets through inverse kinematics."""

    def __init__(self, env, ee_action_space_params=None, use_gripper=False):
        """
        Args:
            env: The environment to wrap.
            ee_action_space_params: Object providing ``x_step_size``, ``y_step_size``,
                ``z_step_size`` and ``bounds``; it is dereferenced here, so None is not usable.
            use_gripper: If True the action carries one extra trailing gripper entry.
        """
        super().__init__(env)
        self.ee_action_space_params = ee_action_space_params
        self.use_gripper = use_gripper

        # Kinematics for the robot type named in the robot config ("so100" if absent).
        robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
        self.kinematics = RobotKinematics(robot_type)
        self.fk_function = self.kinematics.fk_gripper_tip

        step_sizes = np.array(
            [
                ee_action_space_params.x_step_size,
                ee_action_space_params.y_step_size,
                ee_action_space_params.z_step_size,
            ]
        )
        low, high = -step_sizes, step_sizes
        if self.use_gripper:
            # gripper actions open at 2.0, and closed at 0.0
            low = np.append(low, 0.0)
            high = np.append(high, 2.0)

        self.action_space = gym.spaces.Box(
            low=low,
            high=high,
            shape=(3 + int(self.use_gripper),),
            dtype=np.float32,
        )

        self.bounds = ee_action_space_params.bounds

    def action(self, action):
        """
        Return joint targets that move the gripper tip by ``action`` within ``self.bounds``.

        When ``use_gripper`` is set, the trailing entry of ``action`` is copied unchanged into
        the last joint target.
        """
        gripper_command = None
        if self.use_gripper:
            gripper_command, action = action[-1], action[:-1]

        joints_now = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
        ee_now = self.fk_function(joints_now)

        # Only the translation column of the identity target pose is filled in.
        target_pose = np.eye(4)
        target_pose[:3, 3] = np.clip(
            ee_now[:3, 3] + action,
            self.bounds["min"],
            self.bounds["max"],
        )

        target_joint_pos = self.kinematics.ik(
            joints_now,
            target_pose,
            position_only=True,
            fk_func=self.fk_function,
        )
        if self.use_gripper:
            target_joint_pos[-1] = gripper_command

        return target_joint_pos
class EEObservationWrapper(gym.ObservationWrapper):
    """Append the gripper-tip position (x, y, z) to "observation.state"."""

    def __init__(self, env, ee_pose_limits):
        """
        Args:
            env: The environment to wrap.
            ee_pose_limits: Mapping with "min" and "max" arrays bounding the three appended entries.
        """
        super().__init__(env)

        # Extend observation space to include end effector pose
        prev_space = self.observation_space["observation.state"]

        self.observation_space["observation.state"] = gym.spaces.Box(
            low=np.concatenate([prev_space.low, ee_pose_limits["min"]]),
            high=np.concatenate([prev_space.high, ee_pose_limits["max"]]),
            shape=(prev_space.shape[0] + 3,),
            dtype=np.float32,
        )

        # Initialize kinematics instance for the appropriate robot type
        robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
        self.kinematics = RobotKinematics(robot_type)
        self.fk_function = self.kinematics.fk_gripper_tip

    def observation(self, observation):
        """Compute forward kinematics on the present joint positions and extend the state."""
        current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
        current_ee_pos = self.fk_function(current_joint_pos)

        state = observation["observation.state"]
        # [:3, 3] is the translation column of the homogeneous transform.
        ee_xyz = torch.from_numpy(current_ee_pos[:3, 3])
        # Match the state's dtype and device so the concatenation keeps the float32 dtype
        # declared in the observation space instead of being promoted by a float64 FK result.
        if state.is_floating_point():
            ee_xyz = ee_xyz.to(dtype=state.dtype)
        ee_xyz = ee_xyz.to(device=state.device)

        observation["observation.state"] = torch.cat([state, ee_xyz], dim=-1)
        return observation
###########################################################
# Wrappers related to human intervention and input devices
###########################################################
class BaseLeaderControlWrapper(gym.Wrapper):
    """
    Base class for leader-follower robot control wrappers.

    While no intervention is active the leader arm is driven to the follower's position.
    During an intervention the leader's torque is disabled and the action sent to the
    environment is derived from the leader's pose. Keyboard events can end the episode or
    mark it as successful. Subclasses decide when to intervene via ``_check_intervention``.
    """

    def __init__(
        self, env, use_geared_leader_arm: bool = False, ee_action_space_params=None, use_gripper=False
    ):
        """
        Args:
            env: The environment to wrap; its unwrapped robot must expose "main" leader and
                follower arms.
            use_geared_leader_arm: Stored on the instance; not otherwise used in this class.
            ee_action_space_params: Stored on the instance; ``use_ee_action_space`` is True
                when it is not None.
            use_gripper: If True, intervention actions carry a trailing gripper entry.
        """
        super().__init__(env)
        self.robot_leader = env.unwrapped.robot.leader_arms["main"]
        self.robot_follower = env.unwrapped.robot.follower_arms["main"]
        self.use_geared_leader_arm = use_geared_leader_arm
        self.ee_action_space_params = ee_action_space_params
        self.use_ee_action_space = ee_action_space_params is not None
        self.use_gripper: bool = use_gripper

        # Set up keyboard event tracking
        self._init_keyboard_events()
        self.event_lock = Lock()  # Thread-safe access to events

        # Initialize robot control
        robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
        self.kinematics = RobotKinematics(robot_type)
        self.prev_leader_ee = None
        self.prev_leader_pos = None
        self.leader_torque_enabled = True

        # Configure leader arm
        # NOTE: Lower the gains of leader arm for automatic take-over
        # With lower gains we can manually move the leader arm without risk of injury to ourselves or the robot
        # With higher gains, it would be dangerous and difficult to modify the leader's pose while torque is enabled
        # Default value for P_coeff is 32
        self.robot_leader.write("Torque_Enable", 1)
        self.robot_leader.write("P_Coefficient", 4)
        self.robot_leader.write("I_Coefficient", 0)
        self.robot_leader.write("D_Coefficient", 4)

        # Define the attribute up front so it exists on every path, including headless mode.
        self.listener = None
        self._init_keyboard_listener()

    def _init_keyboard_events(self):
        """Initialize the keyboard events dictionary - override in subclasses."""
        self.keyboard_events = {
            "episode_success": False,
            "episode_end": False,
            "rerecord_episode": False,
        }

    def _handle_key_press(self, key, keyboard):
        """
        Handle key presses - override in subclasses for additional keys.

        Esc sets "episode_end", the left arrow sets "rerecord_episode" and "s" sets
        "episode_success". Called from the listener thread with ``event_lock`` held.
        """
        # Errors are logged rather than raised: this runs inside the listener callback.
        try:
            if key == keyboard.Key.esc:
                self.keyboard_events["episode_end"] = True
                return
            if key == keyboard.Key.left:
                self.keyboard_events["rerecord_episode"] = True
                return
            if hasattr(key, "char") and key.char == "s":
                logging.info("Key 's' pressed. Episode success triggered.")
                self.keyboard_events["episode_success"] = True
                return
        except Exception as e:
            logging.error(f"Error handling key press: {e}")

    def _init_keyboard_listener(self):
        """Initialize keyboard listener if not in headless mode"""
        if is_headless():
            logging.warning(
                "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
            )
            return
        try:
            from pynput import keyboard

            def on_press(key):
                with self.event_lock:
                    self._handle_key_press(key, keyboard)

            self.listener = keyboard.Listener(on_press=on_press)
            self.listener.start()

        except ImportError:
            logging.warning("Could not import pynput. Keyboard interface will not be available.")
            self.listener = None

    def _check_intervention(self):
        """Check if intervention is needed - override in subclasses."""
        return False

    def _handle_intervention(self, action):
        """
        Process actions during intervention mode.

        Args:
            action: The incoming action; it is replaced, not used.

        Returns:
            A pair ``(action, action_intervention)``: the leader-minus-follower gripper-tip
            offset to execute, and the leader's gripper-tip displacement since the previous
            call. With ``use_gripper`` each gets one trailing gripper entry.
        """
        if self.leader_torque_enabled:
            self.robot_leader.write("Torque_Enable", 0)
            self.leader_torque_enabled = False

        leader_pos = self.robot_leader.read("Present_Position")
        follower_pos = self.robot_follower.read("Present_Position")

        # [:3, 3] Last column of the transformation matrix corresponds to the xyz translation
        leader_ee = self.kinematics.fk_gripper_tip(leader_pos)[:3, 3]
        follower_ee = self.kinematics.fk_gripper_tip(follower_pos)[:3, 3]

        if self.prev_leader_ee is None:
            self.prev_leader_ee = leader_ee

        # NOTE: Using the leader's position delta for teleoperation is too noisy
        # Instead, we move the follower to match the leader's absolute position,
        # and record the leader's position changes as the intervention action
        action = leader_ee - follower_ee
        action_intervention = leader_ee - self.prev_leader_ee
        self.prev_leader_ee = leader_ee

        if self.use_gripper:
            # Get gripper action delta based on leader pose
            leader_gripper = leader_pos[-1]
            follower_gripper = follower_pos[-1]
            gripper_delta = leader_gripper - follower_gripper

            # Normalize by max angle and quantize to {0,1,2}
            normalized_delta = gripper_delta / MAX_GRIPPER_COMMAND
            if normalized_delta > 0.3:
                gripper_action = 2
            elif normalized_delta < -0.3:
                gripper_action = 0
            else:
                gripper_action = 1

            action = np.append(action, gripper_action)
            action_intervention = np.append(action_intervention, gripper_delta)

        return action, action_intervention

    def _handle_leader_teleoperation(self):
        """Re-enable leader torque if needed and drive the leader to the follower's position."""
        if not self.leader_torque_enabled:
            self.robot_leader.write("Torque_Enable", 1)
            self.leader_torque_enabled = True

        follower_pos = self.robot_follower.read("Present_Position")
        self.robot_leader.write("Goal_Position", follower_pos)

    def step(self, action):
        """
        Execute environment step with possible intervention.

        ``info["is_intervention"]`` and ``info["action_intervention"]`` describe the
        intervention; the latter is None when no intervention took place. A success key press
        sets the reward to 1.0, and success or an episode-end key press sets ``terminated``.
        """
        is_intervention = self._check_intervention()
        action_intervention = None

        if is_intervention:
            action, action_intervention = self._handle_intervention(action)
        else:
            self._handle_leader_teleoperation()

        obs, reward, terminated, truncated, info = self.env.step(action)

        # Add intervention info
        info["is_intervention"] = is_intervention
        info["action_intervention"] = action_intervention if is_intervention else None

        # Check for success or manual termination. The listener thread updates the events
        # under ``event_lock``, so take the snapshot under the same lock.
        with self.event_lock:
            success = self.keyboard_events["episode_success"]
            episode_end = self.keyboard_events["episode_end"]
        terminated = terminated or episode_end or success

        if success:
            reward = 1.0
            logging.info("Episode ended successfully with reward 1.0")

        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the environment and internal state."""
        self.prev_leader_ee = None
        self.prev_leader_pos = None
        # Swap in the cleared event dict under the lock used by the listener thread.
        with self.event_lock:
            self.keyboard_events = dict.fromkeys(self.keyboard_events, False)
        return super().reset(**kwargs)

    def close(self):
        """Clean up resources."""
        listener = getattr(self, "listener", None)
        if listener is not None:
            listener.stop()
        return self.env.close()
class GearedLeaderControlWrapper(BaseLeaderControlWrapper):
    """Leader control wrapper whose intervention mode is toggled with the space key."""

    def _init_keyboard_events(self):
        """Create the base events and add the "human_intervention_step" flag."""
        super()._init_keyboard_events()
        self.keyboard_events["human_intervention_step"] = False

    def _handle_key_press(self, key, keyboard):
        """Process the base keys, then toggle the intervention flag on space."""
        super()._handle_key_press(key, keyboard)

        if key != keyboard.Key.space:
            return

        if self.keyboard_events["human_intervention_step"]:
            # Second press: hand control back to the policy.
            self.keyboard_events["human_intervention_step"] = False
            logging.info("Space key pressed for a second time.\nContinuing with policy actions.")
            log_say("Continuing with policy actions.", play_sounds=True)
            return

        # First press: start the intervention.
        logging.info(
            "Space key pressed. Human intervention required.\n"
            "Place the leader in similar pose to the follower and press space again."
        )
        self.keyboard_events["human_intervention_step"] = True
        log_say("Human intervention step.", play_sounds=True)

    def _check_intervention(self):
        """Return the current value of the "human_intervention_step" flag."""
        return self.keyboard_events["human_intervention_step"]
class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper):
    """Leader control wrapper that starts and stops interventions from the leader-follower error."""

    def __init__(
        self,
        env,
        ee_action_space_params=None,
        use_gripper=False,
        intervention_threshold=1.7,
        release_threshold=0.01,
        queue_size=10,
    ):
        """
        Args:
            env: The environment to wrap.
            ee_action_space_params: Forwarded to the base class.
            use_gripper: Forwarded to the base class.
            intervention_threshold: Average error change above which an intervention starts.
            release_threshold: Average error change below which an intervention ends.
            queue_size: Number of recent measurements that are averaged.
        """
        super().__init__(env, ee_action_space_params=ee_action_space_params, use_gripper=use_gripper)

        # Thresholds and window length.
        self.intervention_threshold = intervention_threshold
        self.release_threshold = release_threshold
        self.queue_size = queue_size

        # Rolling windows of the error and of its step-to-step change.
        self.error_queue = deque(maxlen=self.queue_size)
        self.error_over_time_queue = deque(maxlen=self.queue_size)

        self.previous_error = 0.0
        self.is_intervention_active = False
        self.start_time = time.perf_counter()

    def _check_intervention(self):
        """
        Update the error statistics and return whether an intervention is active.

        Nothing is measured during the first second after construction or reset. Afterwards
        the absolute change of the leader-follower joint error is averaged over the last
        ``queue_size`` calls and compared with the two thresholds.
        """
        # Grace period: do not measure anything yet.
        if time.perf_counter() - self.start_time < 1.0:
            return False

        leader_positions = self.robot_leader.read("Present_Position")
        follower_positions = self.robot_follower.read("Present_Position")

        error = np.linalg.norm(leader_positions - follower_positions)
        error_change = np.abs(error - self.previous_error)

        self.error_queue.append(error)
        self.error_over_time_queue.append(error_change)
        self.previous_error = error

        # The state can only change once the window is full.
        if len(self.error_over_time_queue) < self.queue_size:
            return self.is_intervention_active

        mean_change = np.mean(self.error_over_time_queue)

        if self.is_intervention_active:
            logging.debug(f"Error rate during intervention: {mean_change:.4f}")
            if mean_change < self.release_threshold:
                self.is_intervention_active = False
                logging.info(f"Ending automatic intervention: error rate {mean_change:.4f}")
        elif mean_change > self.intervention_threshold:
            self.is_intervention_active = True
            logging.info(f"Starting automatic intervention: error rate {mean_change:.4f}")

        return self.is_intervention_active

    def reset(self, **kwargs):
        """Clear the error statistics, restart the grace period and reset the environment."""
        for window in (self.error_queue, self.error_over_time_queue):
            window.clear()
        self.previous_error = 0.0
        self.is_intervention_active = False
        self.start_time = time.perf_counter()
        return super().reset(**kwargs)
class GamepadControlWrapper(gym.Wrapper):
"""
Wrapper that allows controlling a gym environment with a gamepad.
This wrapper intercepts the step method and allows human input via gamepad
to override the agent's actions when desired.
"""
def __init__(
    self,
    env,
    x_step_size=1.0,
    y_step_size=1.0,
    z_step_size=1.0,
    use_gripper=False,
    auto_reset=False,
    input_threshold=0.001,
):
    """
    Initialize the gamepad controller wrapper.

    Args:
        env: The environment to wrap
        x_step_size: Base movement step size for X axis in meters
        y_step_size: Base movement step size for Y axis in meters
        z_step_size: Base movement step size for Z axis in meters
        use_gripper: Whether gamepad actions carry a trailing gripper entry
        auto_reset: Whether to auto reset the environment when episode ends
        input_threshold: Minimum movement delta to consider as active input
    """
    super().__init__(env)
    from lerobot.scripts.server.end_effector_control_utils import (
        GamepadController,
        GamepadControllerHID,
    )

    # use HidApi for macos
    controller_cls = GamepadControllerHID if sys.platform == "darwin" else GamepadController
    self.controller = controller_cls(
        x_step_size=x_step_size,
        y_step_size=y_step_size,
        z_step_size=z_step_size,
    )

    self.auto_reset = auto_reset
    self.use_gripper = use_gripper
    self.input_threshold = input_threshold
    self.controller.start()

    logging.info("Gamepad control wrapper initialized")
    help_lines = (
        "Gamepad controls:",
        " Left analog stick: Move in X-Y plane",
        " Right analog stick: Move in Z axis (up/down)",
        " X/Square button: End episode (FAILURE)",
        " Y/Triangle button: End episode (SUCCESS)",
        " B/Circle button: Exit program",
    )
    for line in help_lines:
        print(line)
def get_gamepad_action(
self,
) -> Tuple[bool, np.ndarray, bool, bool, bool]:
"""
Get the current action from the gamepad if any input is active.
Returns:
Tuple of (is_active, action, terminate_episode, success)
"""
# Update the controller to get fresh inputs
self.controller.update()
# Get movement deltas from the controller
delta_x, delta_y, delta_z = self.controller.get_deltas()
intervention_is_active = self.controller.should_intervene()
# Create action from gamepad input
gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)
if self.use_gripper:
gripper_command = self.controller.gripper_command()
if gripper_command == "open":
gamepad_action = np.concatenate([gamepad_action, [2.0]])
elif gripper_command == "close":
gamepad_action = np.concatenate([gamepad_action, [0.0]])
else:
gamepad_action = np.concatenate([gamepad_action, [1.0]])
# Check episode ending buttons
# We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
episode_end_status = self.controller.get_episode_end_status()
terminate_episode = episode_end_status is not None
success = episode_end_status == "success"
rerecord_episode = episode_end_status == "rerecord_episode"
return (
intervention_is_active,
gamepad_action,
terminate_episode,
success,
rerecord_episode,
)
def step(self, action):
"""
Step the environment, using gamepad input to override actions when active.
cfg.
action: Original action from agent
Returns:
observation, reward, terminated, truncated, info
"""
# Get gamepad state and action
(
is_intervention,
gamepad_action,
terminate_episode,
success,
rerecord_episode,
) = self.get_gamepad_action()
# Update episode ending state if requested
if terminate_episode:
logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}")
# Only override the action if gamepad is active
action = gamepad_action if is_intervention else action
# Step the environment
obs, reward, terminated, truncated, info = self.env.step(action)
# Add episode ending if requested via gamepad
terminated = terminated or truncated or terminate_episode
if success:
reward = 1.0
logging.info("Episode ended successfully with reward 1.0")
if isinstance(action, np.ndarray):
action = torch.from_numpy(action)
info["is_intervention"] = is_intervention
info["action_intervention"] = action
info["rerecord_episode"] = rerecord_episode
# If episode ended, reset the state
if terminated or truncated:
# Add success/failure information to info dict
info["next.success"] = success
# Auto reset if configured
if self.auto_reset:
obs, reset_info = self.reset()
info.update(reset_info)
return obs, reward, terminated, truncated, info
def close(self):
"""Clean up resources when environment closes."""
# Stop the controller
if hasattr(self, "controller"):
self.controller.stop()
# Call the parent close method
return self.env.close()
class TorchBox(gym.spaces.Box):
    """A version of gym.spaces.Box that handles PyTorch tensors.

    This class extends gym.spaces.Box to work with PyTorch tensors,
    providing compatibility between NumPy arrays and PyTorch tensors:
    ``sample`` returns a tensor and ``contains`` accepts one.
    """

    def __init__(
        self,
        low: float | Sequence[float] | np.ndarray,
        high: float | Sequence[float] | np.ndarray,
        shape: Sequence[int] | None = None,
        np_dtype: np.dtype | type = np.float32,
        torch_dtype: torch.dtype = torch.float32,
        device: str = "cpu",
        seed: int | np.random.Generator | None = None,
    ) -> None:
        """
        Args:
            low: Lower bound(s), forwarded to ``gym.spaces.Box``.
            high: Upper bound(s), forwarded to ``gym.spaces.Box``.
            shape: Optional shape, forwarded to ``gym.spaces.Box``.
            np_dtype: NumPy dtype of the underlying Box.
            torch_dtype: dtype of the tensors returned by ``sample``.
            device: Device of the tensors returned by ``sample``.
            seed: Seed forwarded to ``gym.spaces.Box``.
        """
        super().__init__(low, high, shape=shape, dtype=np_dtype, seed=seed)
        self.torch_dtype = torch_dtype
        self.device = device

    def sample(self) -> torch.Tensor:
        """Draw a sample from the underlying Box and return it as a tensor."""
        arr = super().sample()
        return torch.as_tensor(arr, dtype=self.torch_dtype, device=self.device)

    def contains(self, x: torch.Tensor | np.ndarray) -> bool:
        """Return whether ``x`` lies inside the box.

        Tensors are moved to CPU and cast to the Box's NumPy dtype first; any
        other input (e.g. a NumPy array) is handed to ``Box.contains`` as is
        instead of failing on the tensor-only ``detach`` call.
        """
        if isinstance(x, torch.Tensor):
            # Move to CPU/numpy and cast to the internal dtype
            x = x.detach().cpu().numpy().astype(self.dtype, copy=False)
        return super().contains(x)

    def seed(self, seed: int | np.random.Generator | None = None):
        """Seed the space's RNG and return the seed(s) actually used, as a list.

        The parent's return value is reported rather than the argument, so that
        calling ``seed()`` without a value does not yield ``[None]``. Depending
        on the gymnasium version the parent returns either a list or a bare
        seed; both are normalised to a list.
        """
        used = super().seed(seed)
        return used if isinstance(used, list) else [used]

    def __repr__(self) -> str:
        return (
            f"TorchBox({self.low_repr}, {self.high_repr}, {self.shape}, "
            f"np={self.dtype.name}, torch={self.torch_dtype}, device={self.device})"
        )
class TorchActionWrapper(gym.Wrapper):
    """
    Give the wrapped environment a torch-facing action interface.

    ``action_space.sample()`` yields float32 torch tensors (via ``TorchBox``)
    and ``step`` accepts a torch tensor, converting it to a NumPy array before
    it reaches the wrapped environment.
    """

    def __init__(self, env: gym.Env, device: str):
        """
        Args:
            env: Environment whose Box action space is mirrored as a ``TorchBox``.
            device: Accepted but not used here; sampled actions are always
                created on the CPU.
        """
        super().__init__(env)
        base_space = env.action_space
        self.action_space = TorchBox(
            low=base_space.low,
            high=base_space.high,
            shape=base_space.shape,
            torch_dtype=torch.float32,
            device=torch.device("cpu"),
        )

    def step(self, action: torch.Tensor):
        """Drop a leading batch axis from 2-D actions and step with a NumPy array."""
        unbatched = action.squeeze(0) if action.dim() == 2 else action
        return self.env.step(unbatched.detach().cpu().numpy())
###########################################################
# Factory functions
###########################################################
def make_robot_env(cfg) -> gym.vector.VectorEnv:
"""
Factory function to create a vectorized robot environment.
cfg.
robot: Robot instance to control
reward_classifier: Classifier model for computing rewards
cfg: Configuration object containing environment parameters
Returns:
A vectorized gym environment with all the necessary wrappers applied.
"""
robot = make_robot_from_config(cfg.robot)
# Create base environment
env = RobotEnv(
robot=robot,
display_cameras=cfg.wrapper.display_cameras,
)
# Add observation and image processing
if cfg.wrapper.add_joint_velocity_to_observation:
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
if cfg.wrapper.add_current_to_observation:
env = AddCurrentToObservation(env=env)
if cfg.wrapper.add_ee_pose_to_observation:
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
if cfg.wrapper.crop_params_dict is not None:
env = ImageCropResizeWrapper(
env=env,
crop_params_dict=cfg.wrapper.crop_params_dict,
resize_size=cfg.wrapper.resize_size,
)
# Add reward computation and control wrappers
reward_classifier = init_reward_classifier(cfg)
if reward_classifier is not None:
env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
if cfg.wrapper.use_gripper:
env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
if cfg.wrapper.gripper_penalty is not None:
env = GripperPenaltyWrapper(
env=env,
penalty=cfg.wrapper.gripper_penalty,
)
env = EEActionWrapper(
env=env,
ee_action_space_params=cfg.wrapper.ee_action_space_params,
use_gripper=cfg.wrapper.use_gripper,
)
if cfg.wrapper.ee_action_space_params.control_mode == "gamepad":
env = GamepadControlWrapper(
env=env,
x_step_size=cfg.wrapper.ee_action_space_params.x_step_size,
y_step_size=cfg.wrapper.ee_action_space_params.y_step_size,
z_step_size=cfg.wrapper.ee_action_space_params.z_step_size,
use_gripper=cfg.wrapper.use_gripper,
)
elif cfg.wrapper.ee_action_space_params.control_mode == "leader":
env = GearedLeaderControlWrapper(
env=env,
ee_action_space_params=cfg.wrapper.ee_action_space_params,
use_gripper=cfg.wrapper.use_gripper,
)
elif cfg.wrapper.ee_action_space_params.control_mode == "leader_automatic":
env = GearedLeaderAutomaticControlWrapper(
env=env,
ee_action_space_params=cfg.wrapper.ee_action_space_params,
use_gripper=cfg.wrapper.use_gripper,
)
else:
raise ValueError(f"Invalid control mode: {cfg.wrapper.ee_action_space_params.control_mode}")
env = ResetWrapper(
env=env,
reset_pose=cfg.wrapper.fixed_reset_joint_positions,
reset_time_s=cfg.wrapper.reset_time_s,
)
env = BatchCompatibleWrapper(env=env)
env = TorchActionWrapper(env=env, device=cfg.device)
return env
def init_reward_classifier(cfg):
    """
    Load the pretrained reward classifier named in the configuration, if any.

    Args:
        cfg: Configuration exposing ``reward_classifier_pretrained_path`` and,
            optionally, ``device``.

    Returns:
        The classifier, moved to the configured device and switched to
        evaluation mode, or None when no pretrained path is configured.
    """
    pretrained_path = cfg.reward_classifier_pretrained_path
    if pretrained_path is None:
        return None

    # Deferred import: the classifier module is only needed when a path is set.
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

    # Use the configured device; fall back to the CPU when cfg has no ``device``.
    target_device = getattr(cfg, "device", "cpu")

    model = Classifier.from_pretrained(pretrained_name_or_path=pretrained_path)
    model.to(target_device)
    model.eval()  # inference only
    return model
###########################################################
# Record and replay functions
###########################################################
def record_dataset(env, policy, cfg):
    """
    Record a dataset of robot interactions using either a policy or teleop.

    Args:
        env: The environment to record from.
        policy: Optional policy to generate actions (if None, uses teleop and
            records the action reported in ``info["action_intervention"]``).
        cfg: Configuration object. The fields read here are ``repo_id``,
            ``dataset_root``, ``fps``, ``num_episodes``,
            ``wrapper.control_time_s`` (maximum episode length in seconds),
            ``task`` and ``push_to_hub``.
    """
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Setup initial action (zero action if using teleop)
    action = env.action_space.sample() * 0.0

    # Configure dataset features based on environment spaces
    features = {
        "observation.state": {
            "dtype": "float32",
            "shape": env.observation_space["observation.state"].shape,
            "names": None,
        },
        "action": {
            "dtype": "float32",
            "shape": env.action_space.shape,
            "names": None,
        },
        "next.reward": {"dtype": "float32", "shape": (1,), "names": None},
        "next.done": {"dtype": "bool", "shape": (1,), "names": None},
    }

    # Add image features
    for key in env.observation_space:
        if "image" in key:
            features[key] = {
                "dtype": "video",
                "shape": env.observation_space[key].shape,
                "names": None,
            }

    # Create dataset
    dataset = LeRobotDataset.create(
        cfg.repo_id,
        cfg.fps,
        root=cfg.dataset_root,
        use_videos=True,
        image_writer_threads=4,
        image_writer_processes=0,
        features=features,
    )

    # One flag decides both who produces the action and which action is recorded.
    use_policy = policy is not None

    # Record episodes
    episode_index = 0
    while episode_index < cfg.num_episodes:
        obs, _ = env.reset()
        start_episode_t = time.perf_counter()
        log_say(f"Recording episode {episode_index}", play_sounds=True)

        # Defined up front so the re-record check below is safe even if the
        # step loop does not run a single iteration.
        info = {}

        # Run episode steps
        while time.perf_counter() - start_episode_t < cfg.wrapper.control_time_s:
            start_loop_t = time.perf_counter()

            # Get action from policy if available
            if use_policy:
                action = policy.select_action(obs)

            # Step environment
            obs, reward, terminated, truncated, info = env.step(action)

            # Check if episode needs to be rerecorded
            if info.get("rerecord_episode", False):
                break

            # For teleop, get action from intervention
            if use_policy:
                frame_action = action
            else:
                frame_action = info["action_intervention"].cpu().squeeze(0).float()

            # Process observation for dataset. Kept in its own variable: ``obs``
            # must stay the raw env observation because it feeds the policy on
            # the next iteration.
            frame = {k: v.cpu().squeeze(0).float() for k, v in obs.items()}

            # Add frame to dataset
            frame["action"] = frame_action
            frame["next.reward"] = np.array([reward], dtype=np.float32)
            frame["next.done"] = np.array([terminated or truncated], dtype=bool)
            frame["task"] = cfg.task
            dataset.add_frame(frame)

            # Maintain consistent timing
            if cfg.fps:
                dt_s = time.perf_counter() - start_loop_t
                busy_wait(1 / cfg.fps - dt_s)

            if terminated or truncated:
                break

        # Handle episode recording: drop the buffered frames and redo the same index
        if info.get("rerecord_episode", False):
            dataset.clear_episode_buffer()
            logging.info(f"Re-recording episode {episode_index}")
            continue

        dataset.save_episode(cfg.task)
        episode_index += 1

    # Finalize dataset
    if cfg.push_to_hub:
        dataset.push_to_hub()
def replay_episode(env, cfg):
    """
    Replay the recorded actions of one dataset episode on the environment.

    Args:
        env: Environment to step with the recorded actions.
        cfg: Configuration object; ``repo_id``, ``dataset_root`` and ``episode``
            select the data, and ``fps`` (when present and non-zero) sets the
            replay rate.
    """
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode])
    env.reset()

    actions = dataset.hf_dataset.select_columns("action")

    # Pace the replay like the recording loop does (cfg.fps); keep the former
    # fixed 10 Hz only as a fallback when no rate is configured.
    fps = getattr(cfg, "fps", None) or 10
    step_period_s = 1 / fps

    for idx in range(dataset.num_frames):
        start_step_t = time.perf_counter()

        action = actions[idx]["action"]
        env.step(action)

        dt_s = time.perf_counter() - start_step_t
        busy_wait(step_period_s - dt_s)
@parser.wrap()
def main(cfg: EnvConfig):
    """
    Script entry point: record a dataset, replay an episode, or run random actions.

    ``cfg.mode == "record"`` records a dataset (with an SAC policy when
    ``cfg.pretrained_policy_name_or_path`` is set, teleop otherwise) and exits.
    ``cfg.mode == "replay"`` replays one episode and exits. Any other mode
    steps the environment with random actions for 20 episodes and logs the
    reward obtained at the end of each one.
    """
    env = make_robot_env(cfg)

    if cfg.mode == "record":
        policy = None
        if cfg.pretrained_policy_name_or_path is not None:
            from lerobot.common.policies.sac.modeling_sac import SACPolicy

            policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
            policy.to(cfg.device)
            policy.eval()

        record_dataset(
            env,
            policy=policy,
            cfg=cfg,
        )
        # sys.exit rather than the builtin exit(), which is an interactive
        # helper injected by the site module.
        sys.exit()

    if cfg.mode == "replay":
        replay_episode(
            env,
            cfg=cfg,
        )
        sys.exit()

    env.reset()

    # Initialize the smoothed action as a random sample.
    smoothed_action = env.action_space.sample()

    # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
    # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
    # With alpha = 1.0 the smoothed action is simply the newest sample.
    alpha = 1.0

    num_episodes = 20
    num_episode = 0
    successes = []
    while num_episode < num_episodes:
        start_loop_s = time.perf_counter()

        # Sample a new random action from the robot's action space.
        new_random_action = env.action_space.sample()

        # Update the smoothed action using an exponential moving average.
        smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action

        # Execute the step with the smoothed action.
        obs, reward, terminated, truncated, info = env.step(smoothed_action)

        if terminated or truncated:
            # The reward of the final step is used as the episode's success value.
            successes.append(reward)
            env.reset()
            num_episode += 1

        dt_s = time.perf_counter() - start_loop_s
        busy_wait(1 / cfg.fps - dt_s)

    logging.info(f"Success after {num_episodes} episodes {successes}")
    logging.info(f"success rate {sum(successes) / len(successes)}")


if __name__ == "__main__":
    main()