Files
lerobot-clone/src/lerobot/scripts/rl/gym_manipulator.py
2025-08-02 19:06:56 +02:00

1068 lines
38 KiB
Python

# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from dataclasses import dataclass, field
from typing import Any
import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.types import PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.envs.configs import EnvConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
DeviceProcessor,
ImageProcessor,
RobotProcessor,
StateProcessor,
)
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.robots import ( # noqa: F401
RobotConfig,
make_robot_from_config,
so100_follower,
)
from lerobot.robots.robot import Robot
from lerobot.teleoperators import (
gamepad, # noqa: F401
keyboard, # noqa: F401
make_teleoperator_from_config,
so101_leader, # noqa: F401
)
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
# Configure the root logger so the logging.info() progress messages in this script are shown.
logging.basicConfig(level=logging.INFO)


def create_transition(
    observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None
):
    """Assemble an ``EnvTransition`` dict from its individual fields.

    ``info`` and ``complementary_data`` fall back to fresh empty dicts (never shared between
    calls) when they are not supplied.

    Returns:
        A dict keyed by ``TransitionKey`` members holding every field of the transition.
    """
    transition = {
        TransitionKey.OBSERVATION: observation,
        TransitionKey.ACTION: action,
        TransitionKey.REWARD: reward,
        TransitionKey.DONE: done,
        TransitionKey.TRUNCATED: truncated,
    }
    transition[TransitionKey.INFO] = {} if info is None else info
    transition[TransitionKey.COMPLEMENTARY_DATA] = {} if complementary_data is None else complementary_data
    return transition
def reset_follower_position(
    robot_arm: Robot,
    target_position: np.ndarray,
    num_steps: int = 50,
    step_delay_s: float = 0.015,
):
    """Drive the follower arm to ``target_position`` along a straight line in joint space.

    The present positions are read from the motor bus and ``num_steps`` evenly spaced waypoints
    (both endpoints included) are interpolated towards the target. Each waypoint is written as
    a ``Goal_Position`` command, followed by a ``step_delay_s`` busy-wait.

    Args:
        robot_arm: Robot exposing a ``bus`` with ``sync_read`` / ``sync_write``.
        target_position: One target value per motor, ordered like the motors returned by
            ``bus.sync_read("Present_Position")``. A plain list is accepted as well.
        num_steps: Number of interpolation waypoints (default 50).
        step_delay_s: Pause after each waypoint, in seconds (default 0.015).
    """
    current_position_dict = robot_arm.bus.sync_read("Present_Position")
    motor_names = list(current_position_dict)
    current_position = np.array([current_position_dict[name] for name in motor_names], dtype=np.float32)

    # Shape (num_steps, num_motors); np.linspace raises if target_position does not broadcast.
    trajectory = np.linspace(current_position, target_position, num_steps)
    for pose in trajectory:
        # Send plain Python floats rather than numpy/torch scalars to the motor bus.
        action_dict = {name: float(value) for name, value in zip(motor_names, pose, strict=True)}
        robot_arm.bus.sync_write("Goal_Position", action_dict)
        busy_wait(step_delay_s)
class RobotEnv(gym.Env):
    """Gym-compatible environment wrapping a physical robot.

    Observations are dicts ``{"agent_pos": <joint positions>, "pixels": {camera_name: image}}`` built
    from ``robot.get_observation()``. ``step`` expects one target value per bus motor and forwards
    them to ``robot.send_action`` as ``"<motor>.pos"`` entries.

    The declared ``action_space`` has 3 dimensions in [-1, 1], plus a gripper dimension in [0, 2]
    when ``use_gripper`` is set. In this script that low-dimensional command is converted to
    per-motor joint targets by ``InverseKinematicsProcessor`` before ``step`` is called.

    The environment itself always returns reward 0, ``terminated=False`` and ``truncated=False``.
    It also always reports ``info["is_intervention"] = False``.
    """

    def __init__(
        self,
        robot,
        use_gripper: bool = False,
        display_cameras: bool = False,
        reset_pose: list[float] | None = None,
        reset_time_s: float = 5.0,
    ):
        """Connect to the robot if needed and build the observation/action spaces.

        Args:
            robot: Robot interface. It must expose ``is_connected``, ``connect()``,
                ``bus.motors``, ``cameras``, ``get_observation()`` and ``send_action()``.
            use_gripper: If True, the action space gets a fourth (gripper) dimension in [0, 2].
            display_cameras: If True, camera frames are shown with OpenCV after every ``step``.
            reset_pose: Joint positions the arm is driven to on ``reset``. ``None`` skips the motion.
            reset_time_s: Minimum wall-clock duration of ``reset`` in seconds. The call busy-waits
                for whatever time remains after the optional reset motion.
        """
        super().__init__()

        self.robot = robot
        self.display_cameras = display_cameras

        # Connect to the robot if not already connected.
        if not self.robot.is_connected:
            self.robot.connect()

        # Episode tracking.
        self.current_step = 0
        self.episode_data = None

        self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors]
        self._image_keys = self.robot.cameras.keys()

        self.current_observation = None

        self.reset_pose = reset_pose
        self.reset_time_s = reset_time_s

        self.use_gripper = use_gripper

        self._setup_spaces()

    def _get_observation(self) -> dict[str, Any]:
        """Read the robot, cache the result in ``self.current_observation`` and return it.

        Returns:
            ``{"agent_pos": np.ndarray of joint positions ordered like the bus motors,
            "pixels": {camera_name: image}}``.
        """
        obs_dict = self.robot.get_observation()
        joint_positions = np.array([obs_dict[name] for name in self._joint_names])
        images = {key: obs_dict[key] for key in self._image_keys}
        self.current_observation = {"agent_pos": joint_positions, "pixels": images}
        return self.current_observation

    def _setup_spaces(self):
        """Configure ``observation_space`` and ``action_space`` from a live observation.

        Observation space (a ``gym.spaces.Dict``):
            - ``observation.images.<camera>``: uint8 Box in [0, 255] shaped like that camera's frame.
            - ``observation.state``: float32 Box shaped like ``agent_pos``.

        Action space:
            - float32 Box of 3 values in [-1, 1], plus one value in [0, 2] when ``use_gripper`` is set.
        """
        self._get_observation()

        observation_spaces = {}

        # Define observation spaces for images and other states.
        if "pixels" in self.current_observation:
            prefix = "observation.images"
            observation_spaces = {
                f"{prefix}.{key}": gym.spaces.Box(
                    low=0, high=255, shape=self.current_observation["pixels"][key].shape, dtype=np.uint8
                )
                for key in self.current_observation["pixels"]
            }

        # NOTE(review): the [0, 10] bounds look like placeholders for raw joint positions — confirm.
        observation_spaces["observation.state"] = gym.spaces.Box(
            low=0,
            high=10,
            shape=self.current_observation["agent_pos"].shape,
            dtype=np.float32,
        )

        self.observation_space = gym.spaces.Dict(observation_spaces)

        # Three bounded command dimensions, plus an optional gripper command in [0, 2].
        action_dim = 3
        bounds = {}
        bounds["min"] = -np.ones(action_dim)
        bounds["max"] = np.ones(action_dim)

        if self.use_gripper:
            action_dim += 1
            bounds["min"] = np.concatenate([bounds["min"], [0]])
            bounds["max"] = np.concatenate([bounds["max"], [2]])

        self.action_space = gym.spaces.Box(
            low=bounds["min"],
            high=bounds["max"],
            shape=(action_dim,),
            dtype=np.float32,
        )

    def reset(self, seed=None, options=None) -> tuple[dict[str, Any], dict[str, Any]]:
        """Reset the environment and return its first observation.

        If ``reset_pose`` is set, the arm is first driven there (with spoken announcements).
        The call then busy-waits until ``reset_time_s`` seconds have elapsed since it started.

        Args:
            seed: Forwarded to ``gym.Env.reset``.
            options: Forwarded to ``gym.Env.reset``.

        Returns:
            ``(observation, info)``. ``info`` holds ``"is_intervention": False`` and
            ``"raw_joint_positions"`` (the observation's ``agent_pos``).
        """
        start_time = time.perf_counter()
        if self.reset_pose is not None:
            log_say("Reset the environment.", play_sounds=True)
            reset_follower_position(self.robot, self.reset_pose)
            log_say("Reset the environment done.", play_sounds=True)

        busy_wait(self.reset_time_s - (time.perf_counter() - start_time))

        super().reset(seed=seed, options=options)

        # Reset episode tracking variables.
        self.current_step = 0
        self.episode_data = None
        self.current_observation = None
        self._get_observation()
        return self.current_observation, {
            "is_intervention": False,
            "raw_joint_positions": self.current_observation["agent_pos"],
        }

    def step(self, action) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
        """Send one joint target per bus motor, then read back a fresh observation.

        Args:
            action: Indexable sequence (tensor, array or list) with at least one value per motor,
                in the order of ``robot.bus.motors``.

        Returns:
            ``(observation, 0.0, False, False, info)``. ``info`` carries ``"is_intervention": False``
            and ``"raw_joint_positions"``.
        """
        # float() turns 0-d tensors / numpy scalars into plain Python numbers for the robot driver.
        joint_targets_dict = {
            f"{key}.pos": float(action[i]) for i, key in enumerate(self.robot.bus.motors.keys())
        }

        self.robot.send_action(joint_targets_dict)

        self._get_observation()

        if self.display_cameras:
            self.render()

        self.current_step += 1

        reward = 0.0
        terminated = False
        truncated = False

        return (
            self.current_observation,
            reward,
            terminated,
            truncated,
            {"is_intervention": False, "raw_joint_positions": self.current_observation["agent_pos"]},
        )

    def render(self):
        """Show the latest camera frames in OpenCV windows, one window per camera.

        Frames are treated as RGB and converted to BGR for ``cv2.imshow``.
        """
        import cv2

        # Images live under "pixels" (see _get_observation); np.asarray accepts arrays and CPU tensors.
        for camera_name, image in self.current_observation["pixels"].items():
            cv2.imshow(camera_name, cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

    def close(self):
        """Disconnect the robot if it is still connected."""
        if self.robot.is_connected:
            self.robot.disconnect()
@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessor:
    """Append finite-difference joint velocities to ``observation.state``.

    Velocities are ``(q_t - q_{t-1}) / dt``. On the first call after construction or ``reset`` the
    previous positions are taken to be the current ones, so the velocities are zero. The output
    state is ``cat([positions, velocities], dim=-1)``, twice as wide as the input state.
    """

    # NOTE(review): stored and reported by get_config, but not applied to the computed velocities.
    joint_velocity_limits: float = 100.0
    # Time between consecutive observations, in seconds.
    dt: float = 1.0 / 10
    # Positions seen on the previous call; None until the first observation after a reset.
    last_joint_positions: torch.Tensor | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition

        # Joint positions are read from "observation.state".
        current_positions = observation.get("observation.state")
        if current_positions is None:
            return transition

        # First call of an episode: treat the robot as stationary.
        if self.last_joint_positions is None:
            self.last_joint_positions = current_positions.clone()

        joint_velocities = (current_positions - self.last_joint_positions) / self.dt
        self.last_joint_positions = current_positions.clone()

        extended_state = torch.cat([current_positions, joint_velocities], dim=-1)

        new_observation = dict(observation)
        new_observation["observation.state"] = extended_state

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        # Keys mirror the constructor fields so the step can be rebuilt from this config.
        return {
            "joint_velocity_limits": self.joint_velocity_limits,
            "dt": self.dt,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        """Forget the previous positions so the next episode starts with zero velocity."""
        self.last_joint_positions = None

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # NOTE(review): the state width doubles in __call__ but the declared features are unchanged.
        return features
@dataclass
@ProcessorStepRegistry.register("current_processor")
class MotorCurrentProcessor:
    """Append the motors' present currents to ``observation.state``.

    Currents are read directly from the robot bus (``sync_read("Present_Current")``) of
    ``env.unwrapped.robot``, ordered like ``bus.motors``. They are concatenated along the last
    dimension of the state as a ``(1, num_motors)`` float32 tensor.
    """

    # Environment whose unwrapped ``robot.bus`` is queried for motor currents.
    env: gym.Env | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition

        # Check the state first so the motor bus is only read when the result is used.
        current_state = observation.get("observation.state")
        if current_state is None:
            return transition

        robot = self.env.unwrapped.robot
        present_current_dict = robot.bus.sync_read("Present_Current")
        # unsqueeze(0) adds a leading batch dim; assumes the state is (1, D) at this point — TODO confirm.
        motor_currents = torch.tensor(
            [present_current_dict[name] for name in robot.bus.motors],
            dtype=torch.float32,
            device=current_state.device,
        ).unsqueeze(0)

        extended_state = torch.cat([current_state, motor_currents], dim=-1)

        new_observation = dict(observation)
        new_observation["observation.state"] = extended_state

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {}

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # NOTE(review): the state grows by num_motors in __call__ but the declared features are unchanged.
        return features
@dataclass
@ProcessorStepRegistry.register("image_crop_resize_processor")
class ImageCropResizeProcessor:
    """Crop and/or resize every image observation (any key containing ``"image"``).

    For each image key, the crop from ``crop_params_dict`` (``(top, left, height, width)`` as
    taken by ``torchvision.transforms.functional.crop``) is applied first, if one exists. The image
    is then resized to ``resize_size`` when that is not ``None``, and finally clamped to [0, 1].
    Images on an MPS device are processed on CPU and moved back afterwards.
    """

    # Per-observation-key crop parameters; None or {} disables cropping.
    crop_params_dict: dict[str, tuple[int, int, int, int]] | None
    # Output (height, width); None disables resizing.
    resize_size: tuple[int, int] | None = (128, 128)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition

        crop_params_dict = self.crop_params_dict or {}
        if self.resize_size is None and not crop_params_dict:
            return transition

        new_observation = dict(observation)

        for key, image in observation.items():
            if "image" not in key:
                continue

            device = image.device
            # The crop/resize kernels are run on CPU for MPS tensors.
            if device.type == "mps":
                image = image.cpu()

            crop_params = crop_params_dict.get(key)
            if crop_params is not None:
                image = F.crop(image, *crop_params)
            if self.resize_size is not None:
                image = F.resize(image, self.resize_size)
            # Assumes float images normalised to [0, 1]; clamp removes interpolation overshoot.
            image = image.clamp(0.0, 1.0)
            new_observation[key] = image.to(device)

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "crop_params_dict": self.crop_params_dict,
            "resize_size": self.resize_size,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # NOTE(review): image shapes change in __call__ but the declared features are unchanged.
        return features
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessor:
    """Force ``truncated=True`` once ``max_episode_steps`` transitions have been counted.

    Only transitions whose ``TRUNCATED`` entry is not ``None`` are counted; all others pass through
    untouched. Call ``reset`` at the start of every episode.

    NOTE(review): a caller that feeds the post-reset transition (``truncated=False``) through this
    step — as this script's control loop does — also counts it as one step.
    """

    # Number of counted transitions after which the episode is truncated.
    max_episode_steps: int
    # Transitions counted since the last reset.
    current_step: int = 0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        incoming_truncated = transition.get(TransitionKey.TRUNCATED)
        if incoming_truncated is None:
            return transition

        self.current_step += 1
        limit_reached = self.current_step >= self.max_episode_steps

        updated = transition.copy()
        updated[TransitionKey.TRUNCATED] = True if limit_reached else incoming_truncated
        return updated

    def get_config(self) -> dict[str, Any]:
        return {"max_episode_steps": self.max_episode_steps}

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        """Restart the step counter for a new episode."""
        self.current_step = 0

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessor:
    """Compute a discrete penalty for commanding the gripper to switch state.

    Both the measured gripper position (``complementary_data["raw_joint_positions"][-1]``) and the
    commanded gripper target (``action[-1]``) are normalised by ``max_gripper_pos``. A penalty is
    applied when a nearly closed reading (< 0.5) is commanded above 0.5, or when a reading above 0.75
    is commanded below 0.5. The result (``penalty`` or 0) is written to
    ``complementary_data["discrete_penalty"]``.
    """

    # Value applied when the penalty condition holds (0 otherwise).
    penalty: float = -0.01
    # Gripper position/target value corresponding to a normalised value of 1.0.
    max_gripper_pos: float = 30.0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Calculate gripper penalty and add to complementary data."""
        action = transition.get(TransitionKey.ACTION)
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
        if complementary_data is None or action is None:
            return transition

        # Check for missing joint readings before indexing them.
        raw_joint_positions = complementary_data.get("raw_joint_positions")
        if raw_joint_positions is None:
            return transition
        current_gripper_pos = raw_joint_positions[-1]

        # In this script the action reaching the env pipeline is the joint-space target produced by
        # InverseKinematicsProcessor (gripper last), i.e. in the same units as the measured position,
        # so both are normalised the same way.
        gripper_action = action[-1].item()
        gripper_action_normalized = gripper_action / self.max_gripper_pos
        gripper_state_normalized = current_gripper_pos / self.max_gripper_pos

        gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
            gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
        )
        gripper_penalty = self.penalty * int(gripper_penalty_bool)

        new_complementary_data = dict(complementary_data)
        new_complementary_data["discrete_penalty"] = gripper_penalty

        new_transition = transition.copy()
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "penalty": self.penalty,
            "max_gripper_pos": self.max_gripper_pos,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        """Reset the processor state."""
        # NOTE(review): `last_gripper_state` is not read anywhere in this class; kept for compatibility.
        self.last_gripper_state = None

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessor:
    """Override the agent action with the teleoperator's command while an intervention is active.

    Signals are read from the transition's complementary data:
    ``teleop_action``, ``is_intervention``, ``terminate_episode``, ``success`` and ``rerecord_episode``.

    The returned transition has:
        - ``ACTION``: the teleop delta ``[delta_x, delta_y, delta_z(, gripper)]`` when intervening,
          otherwise the incoming action.
        - ``DONE = bool(terminate_episode)`` and ``REWARD = float(success)``.
        - ``INFO`` extended with ``is_intervention``, ``rerecord_episode`` and ``next.success``.
        - ``COMPLEMENTARY_DATA["teleop_action"]`` set to the final action.

    The input transition and its nested dicts are not modified.
    """

    # When True, a gripper command (default 1.0) is appended to the teleop action.
    use_gripper: bool = False

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition

        # Copy the nested dicts: transition.copy() is shallow and callers may share them.
        complementary_data = dict(transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
        info = dict(transition.get(TransitionKey.INFO) or {})

        teleop_action = complementary_data.get("teleop_action", {})
        is_intervention = complementary_data.get("is_intervention", False)
        terminate_episode = complementary_data.get("terminate_episode", False)
        success = complementary_data.get("success", False)
        rerecord_episode = complementary_data.get("rerecord_episode", False)

        new_transition = transition.copy()

        if is_intervention and teleop_action:
            action_list = [
                teleop_action.get("delta_x", 0.0),
                teleop_action.get("delta_y", 0.0),
                teleop_action.get("delta_z", 0.0),
            ]
            if self.use_gripper:
                action_list.append(teleop_action.get("gripper", 1.0))

            new_transition[TransitionKey.ACTION] = torch.tensor(
                action_list, dtype=action.dtype, device=action.device
            )

        # Episode termination and reward come from the teleop events.
        new_transition[TransitionKey.DONE] = bool(terminate_episode)
        new_transition[TransitionKey.REWARD] = float(success)

        info["is_intervention"] = is_intervention
        info["rerecord_episode"] = rerecord_episode
        info["next.success"] = success if terminate_episode else info.get("next.success", False)
        new_transition[TransitionKey.INFO] = info

        # Expose the action actually taken (teleop or agent) for recording.
        complementary_data["teleop_action"] = new_transition[TransitionKey.ACTION]
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "use_gripper": self.use_gripper,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("inverse_kinematics_processor")
class InverseKinematicsProcessor:
    """Convert end-effector delta actions into joint-space targets via inverse kinematics.

    The action's first three entries are scaled by ``end_effector_step_sizes`` and added to the
    tracked end-effector position (orientation is kept). The result is optionally clipped to
    ``end_effector_bounds`` and solved with IK, starting from the tracked joint positions.

    An optional fourth entry is a gripper command in [0, 2]. It is rounded to -1/0/+1, multiplied
    by ``max_gripper_pos``, added to the measured gripper position and clipped to
    ``[0, max_gripper_pos]``. The measured positions come from
    ``complementary_data["raw_joint_positions"]`` (gripper last).
    """

    urdf_path: str
    target_frame_name: str = "gripper_link"
    end_effector_step_sizes: dict[str, float] = field(default_factory=lambda: {"x": 1.0, "y": 1.0, "z": 1.0})
    # Optional {"min": [x, y, z], "max": [x, y, z]} clipping box for the end-effector position.
    end_effector_bounds: dict[str, list[float]] | None = None
    max_gripper_pos: float = 30.0
    env: gym.Env | None = None  # Environment reference to get current state

    # Tracked state; initialised from the first measured joint positions after reset.
    current_ee_pos: np.ndarray | None = field(default=None, init=False, repr=False)
    current_joint_pos: np.ndarray | None = field(default=None, init=False, repr=False)
    kinematics: RobotKinematics | None = field(default=None, init=False, repr=False)

    def __post_init__(self):
        """Initialize the kinematics module after dataclass initialization."""
        if self.urdf_path:
            self.kinematics = RobotKinematics(
                urdf_path=self.urdf_path,
                target_frame_name=self.target_frame_name,
            )

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition

        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
        raw_joint_positions = complementary_data.get("raw_joint_positions")
        if raw_joint_positions is None:
            return transition  # Cannot proceed without the measured joint positions.
        if self.kinematics is None:
            raise ValueError("InverseKinematicsProcessor requires a non-empty `urdf_path`.")

        action_np = action.detach().cpu().numpy().squeeze()
        current_gripper_pos = raw_joint_positions[-1]

        if self.current_joint_pos is None:
            self.current_joint_pos = np.asarray(raw_joint_positions)

        # Calculate current end-effector position using forward kinematics.
        if self.current_ee_pos is None:
            self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)

        # Scale deltas by step sizes.
        delta_ee = np.array(
            [
                action_np[0] * self.end_effector_step_sizes["x"],
                action_np[1] * self.end_effector_step_sizes["y"],
                action_np[2] * self.end_effector_step_sizes["z"],
            ],
            dtype=np.float32,
        )

        # Desired pose: current orientation, translated by the delta and clipped to the bounds.
        desired_ee_pos = np.eye(4)
        desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3]
        desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + delta_ee
        if self.end_effector_bounds is not None:
            desired_ee_pos[:3, 3] = np.clip(
                desired_ee_pos[:3, 3],
                self.end_effector_bounds["min"],
                self.end_effector_bounds["max"],
            )

        target_joint_values = self.kinematics.inverse_kinematics(self.current_joint_pos, desired_ee_pos)

        # Track the commanded pose so the next delta integrates from it.
        self.current_ee_pos = desired_ee_pos.copy()
        self.current_joint_pos = target_joint_values.copy()

        gripper_action = current_gripper_pos
        if len(action_np) > 3:
            # Command in [0, 2] -> discrete step of -1/0/+1 times max_gripper_pos.
            gripper_command = action_np[3]
            gripper_delta = np.round(gripper_command - 1.0).astype(int) * self.max_gripper_pos
            gripper_action = np.clip(current_gripper_pos + gripper_delta, 0, self.max_gripper_pos)

        # The last joint is the gripper.
        target_joint_values[-1] = gripper_action

        converted_action = torch.from_numpy(target_joint_values).to(action.device).to(action.dtype)

        new_transition = transition.copy()
        new_transition[TransitionKey.ACTION] = converted_action
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "urdf_path": self.urdf_path,
            "target_frame_name": self.target_frame_name,
            "end_effector_step_sizes": self.end_effector_step_sizes,
            "end_effector_bounds": self.end_effector_bounds,
            "max_gripper_pos": self.max_gripper_pos,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        """Reset the processor state."""
        self.current_ee_pos = None
        self.current_joint_pos = None

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
def make_robot_env(cfg: EnvConfig) -> tuple[gym.Env, Any]:
    """Instantiate the robot environment and its teleoperator from ``cfg``.

    The robot is built from ``cfg.robot`` and wrapped in a ``RobotEnv``. That env is configured from
    ``cfg.processor`` (gripper usage, camera display, fixed reset joint positions). The teleoperator
    is built from ``cfg.teleop`` and connected before the environment is created. No gym wrappers
    are applied; observation and action post-processing is handled by the processor pipelines.

    Args:
        cfg: Environment configuration.

    Returns:
        ``(env, teleop_device)``.
    """
    processor_cfg = cfg.processor

    robot = make_robot_from_config(cfg.robot)
    teleop_device = make_teleoperator_from_config(cfg.teleop)
    teleop_device.connect()

    env = RobotEnv(
        robot=robot,
        use_gripper=processor_cfg.use_gripper,
        display_cameras=processor_cfg.display_cameras,
        reset_pose=processor_cfg.fixed_reset_joint_positions,
    )
    return env, teleop_device
def make_processors(env, cfg):
    """Build the observation-side and action-side processor pipelines.

    Args:
        env: Robot environment, handed to the steps that query the robot directly
            (motor currents, inverse kinematics).
        cfg: Configuration providing ``fps``, ``device`` and a ``processor`` section.

    Returns:
        ``(env_processor, action_processor)``.

        - The env pipeline turns raw env transitions into policy-ready observations: images,
          state, joint velocities, motor currents, crop/resize, time limit, gripper penalty and
          device placement.
        - The action pipeline applies teleoperator interventions, then converts end-effector
          deltas to joint targets via inverse kinematics.
    """
    proc = cfg.processor
    max_episode_steps = int(proc.control_time_s * cfg.fps)

    env_processor = RobotProcessor(
        steps=[
            ImageProcessor(),
            StateProcessor(),
            JointVelocityProcessor(dt=1.0 / cfg.fps),
            MotorCurrentProcessor(env=env),
            ImageCropResizeProcessor(crop_params_dict=proc.crop_params_dict, resize_size=proc.resize_size),
            TimeLimitProcessor(max_episode_steps=max_episode_steps),
            GripperPenaltyProcessor(penalty=proc.gripper_penalty, max_gripper_pos=proc.max_gripper_pos),
            DeviceProcessor(device=cfg.device),
        ]
    )

    ik_step = InverseKinematicsProcessor(
        urdf_path=proc.urdf_path,
        target_frame_name=proc.target_frame_name,
        end_effector_step_sizes=proc.end_effector_step_sizes,
        end_effector_bounds=proc.end_effector_bounds,
        max_gripper_pos=proc.max_gripper_pos,
        env=env,
    )
    action_processor = RobotProcessor(
        steps=[InterventionActionProcessor(use_gripper=proc.use_gripper), ik_step]
    )

    return env_processor, action_processor
def step_env_and_process_transition(
    env,
    transition,
    action,
    teleop_device,
    env_processor,
    action_processor,
):
    """Execute one control step: teleop read, action pipeline, env step, observation pipeline.

    Args:
        env: The robot environment.
        transition: The previous (processed) transition.
        action: Agent action to execute. It may be overridden by a teleop intervention.
        teleop_device: Teleoperator providing ``get_action()`` and ``get_teleop_events()``.
        env_processor: Pipeline applied to the new transition (observation processing).
        action_processor: Pipeline applied to the action (intervention handling, IK).

    Returns:
        ``(new_transition, terminate_episode)``. ``terminate_episode`` is the ``DONE`` flag
        produced by the action pipeline.
    """
    teleop_action = teleop_device.get_action()
    teleop_events = teleop_device.get_teleop_events()

    action_transition = dict(transition)
    action_transition[TransitionKey.ACTION] = action

    # Copy so the previous transition's complementary data is left untouched.
    action_complementary_data = dict(action_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
    action_complementary_data["teleop_action"] = teleop_action
    action_complementary_data.update(teleop_events)
    action_transition[TransitionKey.COMPLEMENTARY_DATA] = action_complementary_data

    processed_action_transition = action_processor(action_transition)

    processed_action = processed_action_transition[TransitionKey.ACTION]
    terminate_episode = processed_action_transition.get(TransitionKey.DONE, False)

    obs, reward, terminated, truncated, info = env.step(processed_action)

    # Combine rewards from environment and action processor.
    reward = reward + processed_action_transition[TransitionKey.REWARD]

    # The carried-over complementary data still holds the previous step's joint readings, so the
    # fresh "raw_joint_positions" must be written last to take precedence.
    complementary_data = {
        **processed_action_transition[TransitionKey.COMPLEMENTARY_DATA],
        "raw_joint_positions": info.pop("raw_joint_positions"),
    }
    info.update(processed_action_transition[TransitionKey.INFO])

    new_transition = create_transition(
        observation=obs,
        action=processed_action,
        reward=reward,
        done=terminated or terminate_episode,
        truncated=truncated,
        info=info,
        complementary_data=complementary_data,
    )
    new_transition = env_processor(new_transition)

    return new_transition, terminate_episode
def _reset_episode(env, env_processor, action_processor):
    """Reset the env and both pipelines, then return the processed initial transition."""
    obs, info = env.reset()
    complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
    env_processor.reset()
    action_processor.reset()

    transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
    return env_processor(transition)


def _build_dataset_features(transition, teleop_device) -> dict[str, dict]:
    """Describe the dataset columns recorded by ``control_loop``, using a processed transition."""
    features = {
        "action": teleop_device.action_features,
        "next.reward": {"dtype": "float32", "shape": (1,), "names": None},
        "next.done": {"dtype": "bool", "shape": (1,), "names": None},
        "complementary_info.discrete_penalty": {
            "dtype": "float32",
            "shape": (1,),
            "names": ["discrete_penalty"],
        },
    }
    # squeeze(0) drops the leading batch dim added by the env pipeline.
    for key, value in transition[TransitionKey.OBSERVATION].items():
        if key == "observation.state":
            features[key] = {
                "dtype": "float32",
                "shape": value.squeeze(0).shape,
                "names": None,
            }
        if "image" in key:
            features[key] = {
                "dtype": "video",
                "shape": value.squeeze(0).shape,
                "names": ["channels", "height", "width"],
            }
    return features


def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvConfig):
    """Run ``cfg.num_episodes`` teleoperated episodes at ``cfg.fps``, optionally recording them.

    The agent action is always neutral (no end-effector motion, gripper "stay"); the robot only
    moves when the teleoperator intervenes. With ``cfg.mode == "record"``, every step is added to a
    new ``LeRobotDataset``. Each episode is then either saved or, if the final step flagged
    ``rerecord_episode``, discarded and recorded again. The dataset is pushed to the hub at the end
    when ``cfg.push_to_hub`` is set.
    """
    dt = 1.0 / cfg.fps

    print(f"Starting control loop at {cfg.fps} FPS")
    print("Controls:")
    print("- Use gamepad/teleop device for intervention")
    print("- When not intervening, robot will stay still")
    print("- Press Ctrl+C to exit")

    transition = _reset_episode(env, env_processor, action_processor)

    dataset = None
    if cfg.mode == "record":
        dataset = LeRobotDataset.create(
            cfg.repo_id,
            cfg.fps,
            root=cfg.dataset_root,
            use_videos=True,
            image_writer_threads=4,
            image_writer_processes=0,
            features=_build_dataset_features(transition, teleop_device),
        )

    episode_idx = 0
    episode_step = 0
    episode_start_time = time.perf_counter()

    while episode_idx < cfg.num_episodes:
        step_start_time = time.perf_counter()

        # Neutral action: no end-effector motion; gripper command 1.0 means "stay".
        neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
        if hasattr(env, "use_gripper") and env.use_gripper:
            neutral_action = torch.cat([neutral_action, torch.tensor([1.0])])

        transition, terminate_episode = step_env_and_process_transition(
            env=env,
            transition=transition,
            action=neutral_action,
            teleop_device=teleop_device,
            env_processor=env_processor,
            action_processor=action_processor,
        )

        terminated = transition.get(TransitionKey.DONE, False)
        truncated = transition.get(TransitionKey.TRUNCATED, False)

        if dataset is not None:
            observations = {k: v.squeeze(0) for k, v in transition[TransitionKey.OBSERVATION].items()}
            frame = {
                **observations,
                "action": transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"],
                "next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
                "next.done": np.array([terminated or truncated], dtype=bool),
                "complementary_info.discrete_penalty": np.array(
                    [transition[TransitionKey.COMPLEMENTARY_DATA]["discrete_penalty"]], dtype=np.float32
                ),
            }
            dataset.add_frame(frame, task=cfg.task)

        episode_step += 1

        if terminated or truncated or terminate_episode:
            episode_time = time.perf_counter() - episode_start_time
            logging.info(
                f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
            )
            # Read the flag from the episode's last step before the reset replaces `transition`.
            rerecord_episode = transition[TransitionKey.INFO].get("rerecord_episode", False)

            transition = _reset_episode(env, env_processor, action_processor)
            episode_step = 0

            if dataset is None:
                episode_idx += 1
            elif rerecord_episode:
                logging.info(f"Re-recording episode {episode_idx}")
                dataset.clear_episode_buffer()
            else:
                logging.info(f"Saving episode {episode_idx}")
                dataset.save_episode()
                episode_idx += 1

            # Start timing the next episode after the (possibly slow) reset and save.
            episode_start_time = time.perf_counter()

        # Maintain fps timing.
        busy_wait(dt - (time.perf_counter() - step_start_time))

    if dataset is not None and cfg.push_to_hub:
        logging.info("Pushing dataset to hub")
        dataset.push_to_hub()
def replay_trajectory(env, action_processor, cfg):
    """Replay the recorded actions of episode ``cfg.episode`` on the robot at ``cfg.fps``.

    Each recorded action goes through ``action_processor`` (intervention handling and inverse
    kinematics) before ``env.step``. The joint positions reported by every step are fed back as
    ``raw_joint_positions`` for the next action, so gripper commands are applied relative to the
    measured gripper position rather than the reset-time one.
    """
    dataset = LeRobotDataset(
        cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode], download_videos=False
    )
    dataset_actions = dataset.hf_dataset.select_columns(["action"])

    _, info = env.reset()
    # Start the IK tracking from the freshly reset pose.
    action_processor.reset()
    raw_joint_positions = info["raw_joint_positions"]

    for item in dataset_actions:
        start_time = time.perf_counter()
        transition = create_transition(
            action=item["action"], complementary_data={"raw_joint_positions": raw_joint_positions}
        )
        transition = action_processor(transition)
        *_, step_info = env.step(transition[TransitionKey.ACTION])
        raw_joint_positions = step_info["raw_joint_positions"]
        busy_wait(1 / cfg.fps - (time.perf_counter() - start_time))
@parser.wrap()
def main(cfg: EnvConfig):
    """Build the env, teleoperator and processors, then either replay an episode or run the loop.

    With ``cfg.mode == "replay"`` the recorded episode is replayed and the function returns.
    Otherwise ``control_loop`` runs, which records a dataset when ``cfg.mode == "record"``.
    """
    env, teleop_device = make_robot_env(cfg)
    env_processor, action_processor = make_processors(env, cfg)

    print("Environment observation space:", env.observation_space)
    print("Environment action space:", env.action_space)
    print("Environment processor:", env_processor)
    print("Action processor:", action_processor)

    if cfg.mode == "replay":
        replay_trajectory(env, action_processor, cfg)
        # Return instead of calling the interactive-only site `exit()` helper.
        return

    control_loop(env, env_processor, action_processor, teleop_device, cfg)


if __name__ == "__main__":
    main()