mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
Refactor kinematics and switch to using placo (#1322)
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com> Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: leo-berte <leonardo.bertelli96@gmail.com>
This commit is contained in:
@@ -21,8 +21,7 @@ from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import cv2
|
||||
|
||||
# import torch.nn.functional as F # noqa: N812
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # type: ignore # noqa: N812
|
||||
from tqdm import tqdm # type: ignore
|
||||
|
||||
@@ -224,7 +223,8 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
cropped = F.crop(value, top, left, height, width)
|
||||
value = F.resize(cropped, resize_size)
|
||||
value = value.clamp(0, 1)
|
||||
|
||||
if key.startswith("complementary_info") and isinstance(value, torch.Tensor) and value.dim() == 0:
|
||||
value = value.unsqueeze(0)
|
||||
new_frame[key] = value
|
||||
|
||||
new_dataset.add_frame(new_frame, task=task)
|
||||
@@ -265,8 +265,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
type=bool,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to push the new dataset to the hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -254,20 +254,19 @@ class RobotEnv(gym.Env):
|
||||
self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors]
|
||||
self._image_keys = self.robot.cameras.keys()
|
||||
|
||||
# Read initial joint positions using the bus
|
||||
self.current_joint_positions = self._get_observation()["agent_pos"]
|
||||
self.current_observation = None
|
||||
|
||||
self.use_gripper = use_gripper
|
||||
|
||||
self._setup_spaces()
|
||||
|
||||
def _get_observation(self) -> np.ndarray:
|
||||
def _get_observation(self) -> dict[str, np.ndarray]:
|
||||
"""Helper to convert a dictionary from bus.sync_read to an ordered numpy array."""
|
||||
obs_dict = self.robot.get_observation()
|
||||
joint_positions = np.array([obs_dict[name] for name in self._joint_names], dtype=np.float32)
|
||||
joint_positions = np.array([obs_dict[name] for name in self._joint_names])
|
||||
|
||||
images = {key: obs_dict[key] for key in self._image_keys}
|
||||
return {"agent_pos": joint_positions, "pixels": images}
|
||||
self.current_observation = {"agent_pos": joint_positions, "pixels": images}
|
||||
|
||||
def _setup_spaces(self):
|
||||
"""
|
||||
@@ -281,24 +280,24 @@ class RobotEnv(gym.Env):
|
||||
- The action space is defined as a Box space representing joint position commands. It is defined as relative (delta)
|
||||
or absolute, based on the configuration.
|
||||
"""
|
||||
example_obs = self._get_observation()
|
||||
self._get_observation()
|
||||
|
||||
observation_spaces = {}
|
||||
|
||||
# Define observation spaces for images and other states.
|
||||
if "pixels" in example_obs:
|
||||
prefix = "observation.images" if len(example_obs["pixels"]) > 1 else "observation.image"
|
||||
if "pixels" in self.current_observation:
|
||||
prefix = "observation.images"
|
||||
observation_spaces = {
|
||||
f"{prefix}.{key}": gym.spaces.Box(
|
||||
low=0, high=255, shape=example_obs["pixels"][key].shape, dtype=np.uint8
|
||||
low=0, high=255, shape=self.current_observation["pixels"][key].shape, dtype=np.uint8
|
||||
)
|
||||
for key in example_obs["pixels"]
|
||||
for key in self.current_observation["pixels"]
|
||||
}
|
||||
|
||||
observation_spaces["observation.state"] = gym.spaces.Box(
|
||||
low=0,
|
||||
high=10,
|
||||
shape=example_obs["agent_pos"].shape,
|
||||
shape=self.current_observation["agent_pos"].shape,
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
@@ -340,14 +339,12 @@ class RobotEnv(gym.Env):
|
||||
|
||||
self.robot.reset()
|
||||
|
||||
# Capture the initial observation.
|
||||
observation = self._get_observation()
|
||||
|
||||
# Reset episode tracking variables.
|
||||
self.current_step = 0
|
||||
self.episode_data = None
|
||||
|
||||
return observation, {"is_intervention": False}
|
||||
self.current_observation = None
|
||||
self._get_observation()
|
||||
return self.current_observation, {"is_intervention": False}
|
||||
|
||||
def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]:
|
||||
"""
|
||||
@@ -367,8 +364,6 @@ class RobotEnv(gym.Env):
|
||||
- truncated (bool): True if the episode was truncated (e.g., time constraints).
|
||||
- info (dict): Additional debugging information including intervention status.
|
||||
"""
|
||||
self.current_joint_positions = self._get_observation()["agent_pos"]
|
||||
|
||||
action_dict = {"delta_x": action[0], "delta_y": action[1], "delta_z": action[2]}
|
||||
|
||||
# 1.0 action corresponds to no-op action
|
||||
@@ -376,6 +371,8 @@ class RobotEnv(gym.Env):
|
||||
|
||||
self.robot.send_action(action_dict)
|
||||
|
||||
self._get_observation()
|
||||
|
||||
if self.display_cameras:
|
||||
self.render()
|
||||
|
||||
@@ -386,7 +383,7 @@ class RobotEnv(gym.Env):
|
||||
truncated = False
|
||||
|
||||
return (
|
||||
self._get_observation(),
|
||||
self.current_observation,
|
||||
reward,
|
||||
terminated,
|
||||
truncated,
|
||||
@@ -399,11 +396,10 @@ class RobotEnv(gym.Env):
|
||||
"""
|
||||
import cv2
|
||||
|
||||
observation = self._get_observation()
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
image_keys = [key for key in self.current_observation if "image" in key]
|
||||
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.imshow(key, cv2.cvtColor(self.current_observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
def close(self):
|
||||
@@ -520,7 +516,10 @@ class AddCurrentToObservation(gym.ObservationWrapper):
|
||||
Returns:
|
||||
The modified observation with current values.
|
||||
"""
|
||||
present_current_observation = self.unwrapped._get_observation()["agent_pos"]
|
||||
present_current_dict = self.env.unwrapped.robot.bus.sync_read("Present_Current")
|
||||
present_current_observation = np.array(
|
||||
[present_current_dict[name] for name in self.env.unwrapped.robot.bus.motors]
|
||||
)
|
||||
observation["agent_pos"] = np.concatenate(
|
||||
[observation["agent_pos"], present_current_observation], axis=-1
|
||||
)
|
||||
@@ -1090,13 +1089,10 @@ class EEObservationWrapper(gym.ObservationWrapper):
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
# Initialize kinematics instance for the appropriate robot type
|
||||
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so101")
|
||||
if "so100" in robot_type or "so101" in robot_type:
|
||||
# Note to be compatible with the rest of the codebase,
|
||||
# we are using the new calibration method for so101 and so100
|
||||
robot_type = "so_new_calibration"
|
||||
self.kinematics = RobotKinematics(robot_type)
|
||||
self.kinematics = RobotKinematics(
|
||||
urdf_path=env.unwrapped.robot.config.urdf_path,
|
||||
target_frame_name=env.unwrapped.robot.config.target_frame_name,
|
||||
)
|
||||
|
||||
def observation(self, observation):
|
||||
"""
|
||||
@@ -1108,9 +1104,9 @@ class EEObservationWrapper(gym.ObservationWrapper):
|
||||
Returns:
|
||||
Enhanced observation with end-effector pose information.
|
||||
"""
|
||||
current_joint_pos = self.unwrapped._get_observation()["agent_pos"]
|
||||
current_joint_pos = self.unwrapped.current_observation["agent_pos"]
|
||||
|
||||
current_ee_pos = self.kinematics.forward_kinematics(current_joint_pos, frame="gripper_tip")[:3, 3]
|
||||
current_ee_pos = self.kinematics.forward_kinematics(current_joint_pos)[:3, 3]
|
||||
observation["agent_pos"] = np.concatenate([observation["agent_pos"], current_ee_pos], -1)
|
||||
return observation
|
||||
|
||||
@@ -1157,12 +1153,10 @@ class BaseLeaderControlWrapper(gym.Wrapper):
|
||||
self.event_lock = Lock() # Thread-safe access to events
|
||||
|
||||
# Initialize robot control
|
||||
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so101")
|
||||
if "so100" in robot_type or "so101" in robot_type:
|
||||
# Note to be compatible with the rest of the codebase,
|
||||
# we are using the new calibration method for so101 and so100
|
||||
robot_type = "so_new_calibration"
|
||||
self.kinematics = RobotKinematics(robot_type)
|
||||
self.kinematics = RobotKinematics(
|
||||
urdf_path=env.unwrapped.robot.config.urdf_path,
|
||||
target_frame_name=env.unwrapped.robot.config.target_frame_name,
|
||||
)
|
||||
self.leader_torque_enabled = True
|
||||
self.prev_leader_gripper = None
|
||||
|
||||
@@ -1260,14 +1254,14 @@ class BaseLeaderControlWrapper(gym.Wrapper):
|
||||
leader_pos_dict = self.robot_leader.bus.sync_read("Present_Position")
|
||||
follower_pos_dict = self.robot_follower.bus.sync_read("Present_Position")
|
||||
|
||||
leader_pos = np.array([leader_pos_dict[name] for name in leader_pos_dict], dtype=np.float32)
|
||||
follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict], dtype=np.float32)
|
||||
leader_pos = np.array([leader_pos_dict[name] for name in leader_pos_dict])
|
||||
follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict])
|
||||
|
||||
self.leader_tracking_error_queue.append(np.linalg.norm(follower_pos[:-1] - leader_pos[:-1]))
|
||||
|
||||
# [:3, 3] Last column of the transformation matrix corresponds to the xyz translation
|
||||
leader_ee = self.kinematics.forward_kinematics(leader_pos, frame="gripper_tip")[:3, 3]
|
||||
follower_ee = self.kinematics.forward_kinematics(follower_pos, frame="gripper_tip")[:3, 3]
|
||||
leader_ee = self.kinematics.forward_kinematics(leader_pos)[:3, 3]
|
||||
follower_ee = self.kinematics.forward_kinematics(follower_pos)[:3, 3]
|
||||
|
||||
action = np.clip(leader_ee - follower_ee, -self.end_effector_step_sizes, self.end_effector_step_sizes)
|
||||
# Normalize the action to the range [-1, 1]
|
||||
@@ -1341,6 +1335,9 @@ class BaseLeaderControlWrapper(gym.Wrapper):
|
||||
# NOTE:
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
|
||||
if isinstance(action, np.ndarray):
|
||||
action = torch.from_numpy(action)
|
||||
|
||||
# Add intervention info
|
||||
info["is_intervention"] = is_intervention
|
||||
info["action_intervention"] = action
|
||||
@@ -1877,7 +1874,6 @@ def make_robot_env(cfg: EnvConfig) -> gym.Env:
|
||||
if cfg.robot is None:
|
||||
raise ValueError("RobotConfig (cfg.robot) must be provided for gym_manipulator environment.")
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
|
||||
teleop_device = make_teleoperator_from_config(cfg.teleop)
|
||||
teleop_device.connect()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user