import argparse
import logging
import time
from threading import Lock
from typing import Annotated, Any, Callable, Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F  # noqa: N812

from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, log_say

logging.basicConfig(level=logging.INFO)


class HILSerlRobotEnv(gym.Env):
    """
    Gym-compatible environment for evaluating robotic control policies with integrated human intervention.

    This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta)
    and absolute joint position commands and automatically configures its observation and action spaces based on the robot's
    sensors and configuration.

    The environment can switch between executing actions from a policy or using teleoperated actions (human intervention) during
    each step. When teleoperation is used, the override action is captured and returned in the `info` dict along with a flag
    `is_intervention`.
    """

    def __init__(
        self,
        robot,
        use_delta_action_space: bool = True,
        delta: float | None = None,
        display_cameras: bool = False,
    ):
        """
        Initialize the HILSerlRobotEnv environment.

        The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup
        supports both relative (delta) adjustments and absolute joint positions for controlling the robot.

        Args:
            robot: The robot interface object used to connect and interact with the physical robot.
            use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
                joint positions are used.
            delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
                0 and 1 when using a delta action space. NOTE: it is multiplied with the relative bounds below, so passing
                None currently raises a TypeError.
            display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
        """
        super().__init__()

        self.robot = robot
        self.display_cameras = display_cameras

        # Connect to the robot if not already connected.
        if not self.robot.is_connected:
            self.robot.connect()

        self.initial_follower_position = robot.follower_arms["main"].read(
            "Present_Position"
        )

        # Episode tracking.
        self.current_step = 0
        self.episode_data = None

        self.delta = delta
        self.use_delta_action_space = use_delta_action_space
        self.current_joint_positions = self.robot.follower_arms["main"].read(
            "Present_Position"
        )

        # Retrieve the size of the joint position interval bound.
        self.relative_bounds_size = (
            self.robot.config.joint_position_relative_bounds["max"]
            - self.robot.config.joint_position_relative_bounds["min"]
        )

        # Limit how far a single command may move the follower, scaled by ``delta``.
        self.delta_relative_bounds_size = self.relative_bounds_size * self.delta
        self.robot.config.max_relative_target = self.delta_relative_bounds_size.float()

        # Dynamically configure the observation and action spaces.
        self._setup_spaces()

    def _setup_spaces(self):
        """
        Dynamically configure the observation and action spaces based on the robot's capabilities.

        Observation Space:
            - For keys with "image": A Box space with pixel values ranging from 0 to 255.
            - For non-image keys: A nested Dict space is created under 'observation.state' with Box entries
              bounded in [0, 10].

        Action Space:
            - The action space is defined as a Tuple where:
                - The first element is a Box space representing joint position commands. In delta mode it is
                  bounded by +/- the relative bounds size; otherwise by the configured min/max joint bounds.
                - The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation).
        """
        example_obs = self.robot.capture_observation()

        # Define observation spaces for images and other states.
        image_keys = [key for key in example_obs if "image" in key]
        state_keys = [key for key in example_obs if "image" not in key]
        observation_spaces = {
            key: gym.spaces.Box(
                low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8
            )
            for key in image_keys
        }
        observation_spaces["observation.state"] = gym.spaces.Dict(
            {
                key: gym.spaces.Box(
                    low=0, high=10, shape=example_obs[key].shape, dtype=np.float32
                )
                for key in state_keys
            }
        )

        self.observation_space = gym.spaces.Dict(observation_spaces)

        # Define the action space for joint positions along with setting an intervention flag.
        action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
        if self.use_delta_action_space:
            action_space_robot = gym.spaces.Box(
                low=-self.relative_bounds_size.cpu().numpy(),
                high=self.relative_bounds_size.cpu().numpy(),
                shape=(action_dim,),
                dtype=np.float32,
            )
        else:
            action_space_robot = gym.spaces.Box(
                low=self.robot.config.joint_position_relative_bounds["min"]
                .cpu()
                .numpy(),
                high=self.robot.config.joint_position_relative_bounds["max"]
                .cpu()
                .numpy(),
                shape=(action_dim,),
                dtype=np.float32,
            )

        self.action_space = gym.spaces.Tuple(
            (
                action_space_robot,
                gym.spaces.Discrete(2),
            ),
        )

    def reset(
        self, seed=None, options=None
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
        """
        Reset the environment to its initial state.

        This method resets the step counter and clears any episodic data. It does not move the robot.

        Args:
            seed (Optional[int]): A seed for random number generation to ensure reproducibility.
            options (Optional[dict]): Additional options to influence the reset behavior.

        Returns:
            A tuple containing:
                - observation (dict): The initial sensor observation.
                - info (dict): A dictionary with supplementary information, including the key "initial_position".
        """
        super().reset(seed=seed, options=options)

        # Capture the initial observation.
        observation = self.robot.capture_observation()

        # Reset episode tracking variables.
        self.current_step = 0
        self.episode_data = None

        return observation, {"initial_position": self.initial_follower_position}

    def step(
        self, action: Tuple[np.ndarray, bool]
    ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
        """
        Execute a single step within the environment using the specified action.

        The provided action is a tuple comprised of:
            - A policy action (joint position commands) that may be either in absolute values or as a delta.
            - A boolean flag indicating whether teleoperation (human intervention) should be used for this step.

        Behavior:
            - When the intervention flag is False, the environment clips the policy action to the action space and
              sends it to the robot.
            - When True, a teleoperation step is executed. If using a delta action space, an absolute teleop action
              is converted to a relative change based on the current joint positions and clamped to the relative bounds.

        Args:
            action (tuple): A tuple with two elements:
                - policy_action (np.ndarray or torch.Tensor): The commanded joint positions.
                - intervention_bool (bool): True if the human operator intervenes by providing a teleoperation input.

        Returns:
            tuple: A tuple containing:
                - observation (dict): The new sensor observation after taking the step.
                - reward (float): The step reward (always 0.0 within this environment).
                - terminated (bool): Always False here; termination is decided by outer wrappers.
                - truncated (bool): Always False here.
                - info (dict): Additional debugging information including:
                    - "action_intervention": The teleop action if intervention was used (with a leading batch dim), else None.
                    - "is_intervention": Flag indicating whether teleoperation was employed.
        """
        policy_action, intervention_bool = action
        teleop_action = None
        self.current_joint_positions = self.robot.follower_arms["main"].read(
            "Present_Position"
        )

        if isinstance(policy_action, torch.Tensor):
            policy_action = policy_action.cpu().numpy()
        # Clip every policy action (tensor or ndarray) into the declared action space.
        policy_action = np.clip(
            policy_action, self.action_space[0].low, self.action_space[0].high
        )

        if not intervention_bool:
            if self.use_delta_action_space:
                target_joint_positions = (
                    self.current_joint_positions + self.delta * policy_action
                )
            else:
                target_joint_positions = policy_action
            self.robot.send_action(torch.from_numpy(target_joint_positions))
            observation = self.robot.capture_observation()
        else:
            observation, teleop_action = self.robot.teleop_step(record_data=True)
            teleop_action = teleop_action["action"]  # Convert tensor to appropriate format

            # When applying the delta action space, convert teleop absolute values to relative differences.
            if self.use_delta_action_space:
                teleop_action = (
                    teleop_action - self.current_joint_positions
                ) / self.delta
                # Report when ANY joint leaves the relative range on EITHER side.
                if torch.any(teleop_action < -self.relative_bounds_size) or torch.any(
                    teleop_action > self.relative_bounds_size
                ):
                    logging.debug(
                        f"Relative teleop delta exceeded bounds {self.relative_bounds_size}, teleop_action {teleop_action}\n"
                        f"lower bounds condition {teleop_action < -self.relative_bounds_size}\n"
                        f"upper bounds condition {teleop_action > self.relative_bounds_size}"
                    )
                # Clamping is a no-op for in-range values, so it is applied unconditionally.
                teleop_action = torch.clamp(
                    teleop_action,
                    -self.relative_bounds_size,
                    self.relative_bounds_size,
                )

            # NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action.
            if teleop_action.dim() == 1:
                teleop_action = teleop_action.unsqueeze(0)

        self.current_step += 1

        reward = 0.0
        terminated = False
        truncated = False

        return (
            observation,
            reward,
            terminated,
            truncated,
            {
                "action_intervention": teleop_action,
                "is_intervention": teleop_action is not None,
            },
        )

    def render(self):
        """
        Render the current state of the environment by displaying the robot's camera feeds.
        """
        import cv2

        observation = self.robot.capture_observation()
        image_keys = [key for key in observation if "image" in key]

        for key in image_keys:
            cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

    def close(self):
        """
        Close the environment and clean up resources by disconnecting the robot.

        If the robot is currently connected, this method properly terminates the connection to ensure that all
        associated resources are released.
        """
        if self.robot.is_connected:
            self.robot.disconnect()


class ActionRepeatWrapper(gym.Wrapper):
    """Repeat each action ``nb_repeat`` times, stopping early on termination/truncation.

    The returned transition is the one produced by the last executed repetition
    (its reward is not accumulated over the repetitions).
    """

    def __init__(self, env, nb_repeat: int = 1):
        super().__init__(env)
        if nb_repeat < 1:
            # With zero repetitions ``step`` would have no transition to return.
            raise ValueError(f"nb_repeat must be >= 1, got {nb_repeat}")
        self.nb_repeat = nb_repeat

    def step(self, action):
        for _ in range(self.nb_repeat):
            obs, reward, done, truncated, info = self.env.step(action)
            if done or truncated:
                break
        return obs, reward, done, truncated, info


class RewardWrapper(gym.Wrapper):
    def __init__(self, env, reward_classifier, device: torch.device | str = "cuda"):
        """
        Wrapper to add reward prediction to the environment using a trained classifier.

        Args:
            env: The environment to wrap.
            reward_classifier: The reward classifier model, or None to always produce a reward of 0.0.
            device: The device to run the model on (a ``torch.device`` or a device string).
        """
        super().__init__(env)

        # NOTE: We got 15% speedup by compiling the model.
        # Only compile an actual model: ``torch.compile(None)`` returns a decorator instead of None,
        # which would defeat the ``is not None`` check in ``step``.
        self.reward_classifier = (
            torch.compile(reward_classifier) if reward_classifier is not None else None
        )

        if isinstance(device, str):
            device = torch.device(device)
        self.device = device

    def step(self, action):
        """Step the wrapped env, replace its reward with the classifier's prediction, and
        terminate the episode when the predicted reward equals 1.0."""
        observation, _, terminated, truncated, info = self.env.step(action)
        images = [
            observation[key].to(self.device, non_blocking=self.device.type == "cuda")
            for key in observation
            if "image" in key
        ]
        start_time = time.perf_counter()
        with torch.inference_mode():
            reward = (
                self.reward_classifier.predict_reward(images, threshold=0.8)
                if self.reward_classifier is not None
                else 0.0
            )
        info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)

        if reward == 1.0:
            terminated = True
        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        return self.env.reset(seed=seed, options=options)


class JointMaskingActionSpace(gym.Wrapper):
    def __init__(self, env, mask):
        """
        Wrapper to mask out dimensions of the action space.

        Args:
            env: The environment to wrap.
            mask: Binary mask array where 0 indicates dimensions to remove.

        Raises:
            ValueError: If the mask length does not match the (first) Box action space dimension.
        """
        super().__init__(env)

        # Keep only dimensions where mask is non-zero.
        self.active_dims = np.where(mask)[0]

        # Create a new action space restricted to the active dimensions.
        if isinstance(env.action_space, gym.spaces.Box):
            if len(mask) != env.action_space.shape[0]:
                raise ValueError("Mask length must match action space dimensions")
            low = env.action_space.low[self.active_dims]
            high = env.action_space.high[self.active_dims]
            self.action_space = gym.spaces.Box(
                low=low, high=high, dtype=env.action_space.dtype
            )
        elif isinstance(env.action_space, gym.spaces.Tuple):
            if len(mask) != env.action_space[0].shape[0]:
                raise ValueError("Mask length must match action space 0 dimensions")
            low = env.action_space[0].low[self.active_dims]
            high = env.action_space[0].high[self.active_dims]
            action_space_masked = gym.spaces.Box(
                low=low, high=high, dtype=env.action_space[0].dtype
            )
            self.action_space = gym.spaces.Tuple(
                (action_space_masked, env.action_space[1])
            )

    def action(self, action):
        """
        Convert masked action back to full action space.

        Args:
            action: Action in masked space. For Tuple spaces, the first element is masked; a bare
                (non-tuple) action is treated as the masked Box part with the second element set to False.

        Returns:
            Action in original space with masked dims set to 0.
        """
        if isinstance(self.env.action_space, gym.spaces.Tuple):
            if isinstance(action, tuple):
                masked_action, remainder = action[0], action[1]
            else:
                # No second element was provided; default the flag to False
                # instead of indexing into the action array itself.
                masked_action, remainder = action, False
            full_box_action = np.zeros(
                self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype
            )
            full_box_action[self.active_dims] = masked_action
            return (full_box_action, remainder)

        # For Box action spaces.
        masked_action = action if not isinstance(action, tuple) else action[0]
        full_action = np.zeros(
            self.env.action_space.shape, dtype=self.env.action_space.dtype
        )
        full_action[self.active_dims] = masked_action
        return full_action

    def step(self, action):
        """Expand the masked action, step the env, and mask any teleop action reported in ``info``."""
        action = self.action(action)
        obs, reward, terminated, truncated, info = self.env.step(action)
        if "action_intervention" in info and info["action_intervention"] is not None:
            if info["action_intervention"].dim() == 1:
                info["action_intervention"] = info["action_intervention"][
                    self.active_dims
                ]
            else:
                info["action_intervention"] = info["action_intervention"][
                    :, self.active_dims
                ]
        return obs, reward, terminated, truncated, info


class TimeLimitWrapper(gym.Wrapper):
    """Terminate an episode once its accumulated wall-clock time exceeds ``control_time_s``.

    Time is measured between consecutive ``step`` calls (and from ``reset`` to the first step).
    """

    def __init__(self, env, control_time_s, fps):
        super().__init__(env)
        self.control_time_s = control_time_s
        self.fps = fps

        self.last_timestamp = 0.0
        self.episode_time_in_s = 0.0

        # Step budget equivalent to the time limit; not used by the time-based check in ``step``.
        self.max_episode_steps = int(self.control_time_s * self.fps)

        self.current_step = 0

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        time_since_last_step = time.perf_counter() - self.last_timestamp
        self.episode_time_in_s += time_since_last_step
        self.last_timestamp = time.perf_counter()
        self.current_step += 1
        # check if last timestep took more time than the expected fps
        if 1.0 / time_since_last_step < self.fps:
            logging.debug(f"Current timestep exceeded expected fps {self.fps}")

        if self.episode_time_in_s > self.control_time_s:
            terminated = True
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        self.episode_time_in_s = 0.0
        self.last_timestamp = time.perf_counter()
        self.current_step = 0
        return self.env.reset(seed=seed, options=options)


class ImageCropResizeWrapper(gym.Wrapper):
    """Crop, then resize, selected image observations on every ``step`` and ``reset``."""

    def __init__(
        self,
        env,
        crop_params_dict: Dict[str, Annotated[Tuple[int], 4]],
        resize_size=None,
    ):
        """
        Args:
            env: The environment to wrap.
            crop_params_dict: Maps an observation key to ``(top, left, height, width)`` crop parameters.
            resize_size: Output size passed to ``F.resize``; defaults to ``(128, 128)``.

        Raises:
            ValueError: If a key of ``crop_params_dict`` is not in the observation space.
        """
        super().__init__(env)
        self.env = env
        self.crop_params_dict = crop_params_dict
        logging.debug(f"obs_keys , {self.env.observation_space}")
        logging.debug(f"crop params dict {crop_params_dict.keys()}")
        for key_crop in crop_params_dict:
            if key_crop not in self.env.observation_space.keys():  # noqa: SIM118
                raise ValueError(f"Key {key_crop} not in observation space")
        for key in crop_params_dict:
            top, left, height, width = crop_params_dict[key]
            # NOTE(review): this shape ignores the subsequent resize and any channel
            # dimension, and updates the wrapped env's Dict space in place — confirm intent.
            new_shape = (top + height, left + width)
            self.observation_space[key] = gym.spaces.Box(
                low=0, high=255, shape=new_shape
            )

        self.resize_size = resize_size
        if self.resize_size is None:
            self.resize_size = (128, 128)

    def _crop_resize(self, obs):
        """Crop and resize every configured key of ``obs`` (in place), logging any NaNs."""
        for k in self.crop_params_dict:
            device = obs[k].device

            # Check for NaNs before processing
            if torch.isnan(obs[k]).any():
                logging.error(
                    f"NaN values detected in observation {k} before crop and resize"
                )

            # MPS tensors are processed on CPU and moved back afterwards.
            if device == torch.device("mps:0"):
                obs[k] = obs[k].cpu()

            obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
            obs[k] = F.resize(obs[k], self.resize_size)

            # Check for NaNs after processing
            if torch.isnan(obs[k]).any():
                logging.error(
                    f"NaN values detected in observation {k} after crop and resize"
                )

            obs[k] = obs[k].to(device)
        return obs

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        obs = self._crop_resize(obs)
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options=options)
        obs = self._crop_resize(obs)
        return obs, info


class ConvertToLeRobotObservation(gym.ObservationWrapper):
    """Convert raw robot observations with ``preprocess_observation`` and move them to ``device``."""

    def __init__(self, env, device):
        super().__init__(env)

        if isinstance(device, str):
            device = torch.device(device)
        self.device = device

    def observation(self, observation):
        observation = preprocess_observation(observation)

        # Every value is already a tensor after ``.to``; no extra ``torch.tensor`` copy is needed.
        return {
            key: observation[key].to(
                self.device, non_blocking=self.device.type == "cuda"
            )
            for key in observation
        }


class KeyboardInterfaceWrapper(gym.Wrapper):
    """Keyboard control of episodes via a ``pynput`` listener thread.

    Keys (as handled in ``on_press``):
        - Right arrow / Esc: request an early exit (the next step reports termination).
        - ``s``: mark the episode as successful (reward 1, termination).
        - Space: 1st press pauses the policy, 2nd press starts human intervention,
          3rd press resumes policy actions.
    """

    def __init__(self, env):
        super().__init__(env)
        self.listener = None
        self.events = {
            "exit_early": False,
            "pause_policy": False,
            "reset_env": False,
            "human_intervention_step": False,
            "episode_success": False,
        }
        self.event_lock = Lock()  # Thread-safe access to events
        self._init_keyboard_listener()

    def _init_keyboard_listener(self):
        """Initialize keyboard listener if not in headless mode"""

        if is_headless():
            logging.warning(
                "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
            )
            return
        try:
            from pynput import keyboard

            def on_press(key):
                with self.event_lock:
                    try:
                        if key == keyboard.Key.right or key == keyboard.Key.esc:
                            print("Right arrow key pressed. Exiting loop...")
                            self.events["exit_early"] = True
                            return
                        if hasattr(key, "char") and key.char == "s":
                            print("Key 's' pressed. Episode success triggered.")
                            self.events["episode_success"] = True
                            return
                        if key == keyboard.Key.space and not self.events["exit_early"]:
                            if not self.events["pause_policy"]:
                                print(
                                    "Space key pressed. Human intervention required.\n"
                                    "Place the leader in similar pose to the follower and press space again."
                                )
                                self.events["pause_policy"] = True
                                log_say(
                                    "Human intervention stage. Get ready to take over.",
                                    play_sounds=True,
                                )
                                return
                            if (
                                self.events["pause_policy"]
                                and not self.events["human_intervention_step"]
                            ):
                                self.events["human_intervention_step"] = True
                                print("Space key pressed. Human intervention starting.")
                                log_say(
                                    "Starting human intervention.", play_sounds=True
                                )
                                return
                            if (
                                self.events["pause_policy"]
                                and self.events["human_intervention_step"]
                            ):
                                self.events["pause_policy"] = False
                                self.events["human_intervention_step"] = False
                                print("Space key pressed for a third time.")
                                log_say(
                                    "Continuing with policy actions.", play_sounds=True
                                )
                                return
                    except Exception as e:
                        # Best effort: never let a key-handling error kill the listener thread.
                        print(f"Error handling key press: {e}")

            self.listener = keyboard.Listener(on_press=on_press)
            self.listener.start()
        except ImportError:
            logging.warning(
                "Could not import pynput. Keyboard interface will not be available."
            )
            self.listener = None

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
        """Step the wrapped env with the policy action and an intervention flag derived
        from the keyboard state; keyboard exit/success events force termination."""
        is_intervention = False
        terminated_by_keyboard = False

        # Only the policy part is forwarded; the intervention flag comes from the keyboard.
        if isinstance(self.env.action_space, gym.spaces.Tuple):
            policy_action = action[0]
        else:
            policy_action = action

        # Check the event flags without holding the lock for too long.
        with self.event_lock:
            if self.events["exit_early"]:
                terminated_by_keyboard = True
            pause_policy = self.events["pause_policy"]

        if pause_policy:
            # Poll (without holding the lock) until the operator starts the intervention
            # or requests an exit; space presses are ignored after an exit request, so
            # waiting only for the intervention could spin forever.
            while True:
                with self.event_lock:
                    if self.events["human_intervention_step"]:
                        is_intervention = True
                        break
                    if self.events["exit_early"]:
                        terminated_by_keyboard = True
                        break
                time.sleep(0.1)  # Check more frequently if desired

        # Execute the step in the underlying environment
        obs, reward, terminated, truncated, info = self.env.step(
            (policy_action, is_intervention)
        )

        # Override reward and termination if episode success event triggered
        with self.event_lock:
            if self.events["episode_success"]:
                reward = 1
                terminated_by_keyboard = True

        return obs, reward, terminated or terminated_by_keyboard, truncated, info

    def reset(self, **kwargs) -> Tuple[Any, Dict]:
        """
        Reset the environment and clear any pending events
        """
        with self.event_lock:
            self.events = {k: False for k in self.events}
        return self.env.reset(**kwargs)

    def close(self):
        """
        Properly clean up the keyboard listener when the environment is closed
        """
        if self.listener is not None:
            self.listener.stop()
        super().close()


class ResetWrapper(gym.Wrapper):
    """Run a reset routine before each ``reset``.

    If ``reset_fn`` is given it is called with the wrapped env; otherwise the operator
    teleoperates the robot for ``reset_time_s`` seconds.
    """

    def __init__(
        self,
        env: gym.Env,
        reset_fn: Optional[Callable[[gym.Env], None]] = None,
        reset_time_s: float = 5,
    ):
        super().__init__(env)
        self.reset_fn = reset_fn
        self.reset_time_s = reset_time_s

        self.robot = self.unwrapped.robot
        self.init_pos = self.unwrapped.initial_follower_position

    def reset(self, *, seed=None, options=None):
        if self.reset_fn is not None:
            self.reset_fn(self.env)
        else:
            log_say(
                f"Manually reset the environment for {self.reset_time_s} seconds.",
                play_sounds=True,
            )
            start_time = time.perf_counter()
            while time.perf_counter() - start_time < self.reset_time_s:
                self.robot.teleop_step()

            log_say("Manual reseting of the environment done.", play_sounds=True)
        return super().reset(seed=seed, options=options)


class BatchCompitableWrapper(gym.ObservationWrapper):
    """Add a leading batch dimension to unbatched image (3-D) and state (1-D) observations."""

    def __init__(self, env):
        super().__init__(env)

    def observation(
        self, observation: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        for key in observation:
            if "image" in key and observation[key].dim() == 3:
                observation[key] = observation[key].unsqueeze(0)
            if "state" in key and observation[key].dim() == 1:
                observation[key] = observation[key].unsqueeze(0)
        return observation


# TODO: remove the ManiSkill branch in make_robot_env before merging into main (see the warning it logs).
def make_robot_env(
    robot,
    reward_classifier,
    cfg,
    n_envs: int = 1,
) -> gym.Env:
    """
    Factory function to create a wrapped (single, batch-compatible) robot environment.

    Args:
        robot: Robot instance to control.
        reward_classifier: Classifier model for computing rewards (or None).
        cfg: Configuration object; reads ``cfg.env.name``, ``cfg.env.wrapper.*``, ``cfg.device`` and ``cfg.fps``.
        n_envs: Number of environments to create in parallel. Currently unused; defaults to 1.

    Returns:
        A gym environment with all the necessary wrappers applied. Observations carry a leading batch dimension.
    """
    if "maniskill" in cfg.env.name:
        from lerobot.scripts.server.maniskill_manipulator import make_maniskill

        logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
        env = make_maniskill(
            cfg=cfg,
            n_envs=1,
        )
        return env

    # Create base environment
    env = HILSerlRobotEnv(
        robot=robot,
        display_cameras=cfg.env.wrapper.display_cameras,
        delta=cfg.env.wrapper.delta_action,
        use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions,
    )

    # Add observation and image processing
    env = ConvertToLeRobotObservation(env=env, device=cfg.device)
    if cfg.env.wrapper.crop_params_dict is not None:
        env = ImageCropResizeWrapper(
            env=env,
            crop_params_dict=cfg.env.wrapper.crop_params_dict,
            resize_size=cfg.env.wrapper.resize_size,
        )

    # Add reward computation and control wrappers
    env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
    env = TimeLimitWrapper(
        env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps
    )
    env = KeyboardInterfaceWrapper(env=env)
    env = ResetWrapper(
        env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s
    )
    env = JointMaskingActionSpace(
        env=env, mask=cfg.env.wrapper.joint_masking_action_space
    )
    # Batched version of the env that returns observations with a leading batch dimension.
    env = BatchCompitableWrapper(env=env)

    return env


def get_classifier(pretrained_path, config_path, device="mps"):
    """Build the reward classifier from a hydra config and load pretrained weights.

    Returns None when either ``pretrained_path`` or ``config_path`` is None.
    """
    if pretrained_path is None or config_path is None:
        return None

    from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
    from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
        ClassifierConfig,
    )
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
        Classifier,
    )

    cfg = init_hydra_config(config_path)

    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
    classifier_config.num_cameras = len(
        cfg.training.image_keys
    )  # TODO automate these paths
    model = Classifier(classifier_config)
    model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
    model = model.to(device)
    return model


def replay_episode(env, repo_id, root=None, episode=0, fps: int = 10, action_dim: int = 4):
    """Replay the recorded actions of one dataset episode through ``env``.

    Each recorded action is truncated to its first ``action_dim`` entries, divided by
    ``env.unwrapped.delta`` and sent without intervention, paced at ``fps`` steps per second.
    """
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    local_files_only = root is not None
    dataset = LeRobotDataset(
        repo_id, root=root, episodes=[episode], local_files_only=local_files_only
    )
    actions = dataset.hf_dataset.select_columns("action")

    for idx in range(dataset.num_frames):
        start_episode_t = time.perf_counter()

        action = actions[idx]["action"][:action_dim]
        print(action)
        env.step((action / env.unwrapped.delta, False))

        dt_s = time.perf_counter() - start_episode_t
        busy_wait(1 / fps - dt_s)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fps", type=int, default=30, help="control frequency")
    parser.add_argument(
        "--robot-path",
        type=str,
        default="lerobot/configs/robot/koch.yaml",
        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
    )
    parser.add_argument(
        "--robot-overrides",
        type=str,
        nargs="*",
        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
    )
    parser.add_argument(
        "-p",
        "--pretrained-policy-name-or-path",
        help=(
            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
            "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
            "(useful for debugging). This argument is mutually exclusive with `--config`."
        ),
    )
    parser.add_argument(
        "--config",
        help=(
            "Path to a yaml config you want to use for initializing a policy from scratch (useful for "
            "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
        ),
    )
    parser.add_argument(
        "--display-cameras",
        help=("Whether to display the camera feed while the rollout is happening"),
    )
    parser.add_argument(
        "--reward-classifier-pretrained-path",
        type=str,
        default=None,
        help="Path to the pretrained classifier weights.",
    )
    parser.add_argument(
        "--reward-classifier-config-file",
        type=str,
        default=None,
        help="Path to a yaml config file that is necessary to build the reward classifier model.",
    )
    parser.add_argument(
        "--env-path", type=str, default=None, help="Path to the env yaml file"
    )
    parser.add_argument(
        "--env-overrides",
        type=str,
        default=None,
        help="Overrides for the env yaml file",
    )
    parser.add_argument(
        "--control-time-s",
        type=float,
        default=20,
        help="Maximum episode length in seconds",
    )
    parser.add_argument(
        "--reset-follower-pos",
        type=int,
        default=1,
        help="Reset follower between episodes",
    )
    parser.add_argument(
        "--replay-repo-id",
        type=str,
        default=None,
        help="Repo ID of the episode to replay",
    )
    parser.add_argument(
        "--replay-root", type=str, default=None, help="Root of the dataset to replay"
    )
    parser.add_argument(
        "--replay-episode", type=int, default=0, help="Episode to replay"
    )
    args = parser.parse_args()

    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
    robot = make_robot(robot_cfg)

    reward_classifier = get_classifier(
        args.reward_classifier_pretrained_path, args.reward_classifier_config_file
    )

    cfg = init_hydra_config(args.env_path, args.env_overrides)
    # make_robot_env reads ``cfg.env.name`` / ``cfg.env.wrapper``, so it needs the full config.
    env = make_robot_env(
        robot,
        reward_classifier,
        cfg,
    )

    env.reset()

    if args.replay_repo_id is not None:
        replay_episode(
            env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode
        )
        exit()

    # Retrieve the robot's action space for joint commands.
    action_space_robot = env.action_space.spaces[0]

    # Initialize the smoothed action as a random sample.
    smoothed_action = action_space_robot.sample()

    # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
    # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
    alpha = 0.4

    while True:
        start_loop_s = time.perf_counter()
        # Sample a new random action from the robot's action space.
        new_random_action = action_space_robot.sample()
        # Update the smoothed action using an exponential moving average.
        smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action

        # Execute the step: wrap the NumPy action in a torch tensor.
        obs, reward, terminated, truncated, info = env.step(
            (torch.from_numpy(smoothed_action), False)
        )
        if terminated or truncated:
            env.reset()

        dt_s = time.perf_counter() - start_loop_s
        busy_wait(1 / args.fps - dt_s)