diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index ec5491b2a..f1c15a1d0 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -519,11 +519,14 @@ from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import _init_rerun from lerobot.record import record_loop +from lerobot.policies.factory import make_processor NUM_EPISODES = 5 FPS = 30 EPISODE_TIME_SEC = 60 TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" # Create the robot configuration camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} @@ -535,7 +538,7 @@ robot_config = SO100FollowerConfig( robot = SO100Follower(robot_config) # Initialize the policy -policy = ACTPolicy.from_pretrained("/") +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") @@ -544,7 +547,7 @@ dataset_features = {**action_features, **obs_features} # Create the dataset dataset = LeRobotDataset.create( - repo_id="/eval_", + repo_id=HF_DATASET_ID, fps=FPS, features=dataset_features, robot_type=robot.name, @@ -559,6 +562,12 @@ _init_rerun(session_name="recording") # Connect the robot robot.connect() +preprocessor, postprocessor = make_processor( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + dataset_stats=dataset.meta.stats, +) + for episode_idx in range(NUM_EPISODES): log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") @@ -568,6 +577,8 @@ for episode_idx in range(NUM_EPISODES): events=events, fps=FPS, policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 57fb62e10..564648329 100644 --- a/examples/lekiwi/evaluate.py +++ 
b/examples/lekiwi/evaluate.py @@ -1,6 +1,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_processor from lerobot.record import record_loop from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.utils.control_utils import init_keyboard_listener @@ -11,12 +12,14 @@ NUM_EPISODES = 2 FPS = 30 EPISODE_TIME_SEC = 60 TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" # Create the robot and teleoperator configurations robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") robot = LeKiwiClient(robot_config) -policy = ACTPolicy.from_pretrained("/") +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") @@ -25,7 +28,7 @@ dataset_features = {**action_features, **obs_features} # Create the dataset dataset = LeRobotDataset.create( - repo_id="/", + repo_id=HF_DATASET_ID, fps=FPS, features=dataset_features, robot_type=robot.name, @@ -43,6 +46,12 @@ listener, events = init_keyboard_listener() if not robot.is_connected: raise ValueError("Robot is not connected!") +preprocessor, postprocessor = make_processor( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + dataset_stats=dataset.meta.stats, +) + recorded_episodes = 0 while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}") @@ -53,6 +62,8 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: events=events, fps=FPS, policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, diff --git a/examples/lekiwi/teleoperate.py b/examples/lekiwi/teleoperate.py index 
8358a2b93..45afca0cf 100644 --- a/examples/lekiwi/teleoperate.py +++ b/examples/lekiwi/teleoperate.py @@ -38,7 +38,7 @@ while True: keyboard_keys = keyboard.get_action() base_action = robot._from_keyboard_to_base_action(keyboard_keys) - log_rerun_data(observation, {**arm_action, **base_action}) + log_rerun_data(observation=observation, action={**arm_action, **base_action}) action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action diff --git a/examples/phone_so100_eval.py b/examples/phone_so100_eval.py new file mode 100644 index 000000000..e3a577de5 --- /dev/null +++ b/examples/phone_so100_eval.py @@ -0,0 +1,158 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features +from lerobot.datasets.utils import merge_features +from lerobot.model.kinematics import RobotKinematics +from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_processor +from lerobot.processor.converters import ( + to_output_robot_action, + to_transition_robot_observation, +) +from lerobot.processor.pipeline import RobotProcessor +from lerobot.record import record_loop +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + AddRobotObservationAsComplimentaryData, + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +from lerobot.utils.visualization_utils import _init_rerun + +NUM_EPISODES = 5 +FPS = 30 +EPISODE_TIME_SEC = 60 +TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" + +# Initialize the robot with degrees +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem58760434471", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, +) + +# Initialize the robot +robot = SO100Follower(robot_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# 
Build pipeline to convert ee pose action to joint action +robot_ee_to_joints = RobotProcessor( + steps=[ + AddRobotObservationAsComplimentaryData(robot=robot), + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=lambda tr: tr, + to_output=to_output_robot_action, +) + +# Build pipeline to convert joint observation to ee pose observation +robot_joints_to_ee_pose = RobotProcessor( + steps=[ + ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) + ], + to_transition=to_transition_robot_observation, + to_output=lambda tr: tr, +) + +# Build dataset action and gripper features +action_ee_and_gripper = aggregate_pipeline_dataset_features( + pipeline=robot_ee_to_joints, + initial_features={}, + use_videos=True, + patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"], +) # Get all ee action features + gripper pos action features + +# Build dataset observation features +obs_ee = aggregate_pipeline_dataset_features( + pipeline=robot_joints_to_ee_pose, + initial_features=robot.observation_features, + use_videos=True, + patterns=["observation.state.ee"], +) # Get all ee observation features + +dataset_features = merge_features(obs_ee, action_ee_and_gripper) + +print("All dataset features: ", dataset_features) + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id=HF_DATASET_ID, + fps=FPS, + features=dataset_features, + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, +) + +# Initialize the keyboard listener and rerun visualization +_, events = init_keyboard_listener() +_init_rerun(session_name="recording_phone") + +# Connect the robot and teleoperator +robot.connect() + +episode_idx = 0 + +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) +preprocessor, postprocessor = make_processor( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + 
dataset_stats=dataset.meta.stats, +) + +for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + + record_loop( + robot=robot, + events=events, + fps=FPS, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + robot_action_processor=robot_ee_to_joints, + robot_observation_processor=robot_joints_to_ee_pose, + ) + dataset.save_episode() + +# Clean up +log_say("Stop recording") +robot.disconnect() +dataset.push_to_hub() diff --git a/examples/phone_so100_record.py b/examples/phone_so100_record.py new file mode 100644 index 000000000..4ec3948ea --- /dev/null +++ b/examples/phone_so100_record.py @@ -0,0 +1,215 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features +from lerobot.datasets.utils import merge_features +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor.converters import ( + to_output_robot_action, + to_transition_robot_observation, + to_transition_teleop_action, +) +from lerobot.processor.pipeline import RobotProcessor +from lerobot.record import record_loop +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + AddRobotObservationAsComplimentaryData, + EEBoundsAndSafety, + EEReferenceAndDelta, + ForwardKinematicsJointsToEE, + GripperVelocityToJoint, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS +from lerobot.teleoperators.phone.phone import Phone +from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +from lerobot.utils.visualization_utils import _init_rerun + +NUM_EPISODES = 10 +FPS = 30 +EPISODE_TIME_SEC = 60 +RESET_TIME_SEC = 30 +TASK_DESCRIPTION = "My task description" +HF_REPO_ID = "/" + +# Initialize the robot and teleoperator +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem58760434471", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, +) +teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID + +# Initialize the robot and teleoperator +robot = SO100Follower(robot_config) +phone = Phone(teleop_config) + +# NOTE: It is highly recommended to 
use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert phone action to ee pose action +phone_to_robot_ee_pose = RobotProcessor( + steps=[ + MapPhoneActionToRobotAction(platform=teleop_config.phone_os), + AddRobotObservationAsComplimentaryData(robot=robot), + EEReferenceAndDelta( + kinematics=kinematics_solver, + end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, + motor_names=list(robot.bus.motors.keys()), + ), + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.20, + max_ee_twist_step_rad=0.50, + ), + ], + to_transition=to_transition_teleop_action, + to_output=lambda tr: tr, +) + +# Build pipeline to convert ee pose action to joint action +robot_ee_to_joints = RobotProcessor( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, + ), + GripperVelocityToJoint( + motor_names=list(robot.bus.motors.keys()), + speed_factor=20.0, + ), + ], + to_transition=lambda tr: tr, + to_output=to_output_robot_action, +) + +# Build pipeline to convert joint observation to ee pose observation +robot_joints_to_ee_pose = RobotProcessor( + steps=[ + ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) + ], + to_transition=to_transition_robot_observation, + to_output=lambda tr: tr, +) + +# Build dataset ee action features +action_ee = aggregate_pipeline_dataset_features( + pipeline=phone_to_robot_ee_pose, + initial_features=phone.action_features, + use_videos=True, + patterns=["action.ee"], +) + +# Get gripper pos action features +gripper = aggregate_pipeline_dataset_features( + 
pipeline=robot_ee_to_joints, + initial_features={}, + use_videos=True, + patterns=["action.gripper.pos", "observation.state.gripper.pos"], +) + +# Build dataset ee observation features +observation_ee = aggregate_pipeline_dataset_features( + pipeline=robot_joints_to_ee_pose, + initial_features=robot.observation_features, + use_videos=True, + patterns=["observation.state.ee"], +) + +dataset_features = merge_features(action_ee, gripper, observation_ee) + +print("All dataset features: ", dataset_features) + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id=HF_REPO_ID, + fps=FPS, + features=dataset_features, + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, +) + +# Initialize the keyboard listener and rerun visualization +_, events = init_keyboard_listener() +_init_rerun(session_name="recording_phone") + +# Connect the robot and teleoperator +robot.connect() +phone.connect() + +episode_idx = 0 +while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=phone, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=phone_to_robot_ee_pose, + robot_action_processor=robot_ee_to_joints, + robot_observation_processor=robot_joints_to_ee_pose, + ) + + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=phone, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=phone_to_robot_ee_pose, + robot_action_processor=robot_ee_to_joints, + robot_observation_processor=robot_joints_to_ee_pose, + ) + + if events["rerecord_episode"]: + log_say("Re-recording episode") + 
events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + episode_idx += 1 + +# Clean up +log_say("Stop recording") +robot.disconnect() +phone.disconnect() +dataset.push_to_hub() diff --git a/examples/phone_so100_replay.py b/examples/phone_so100_replay.py new file mode 100644 index 000000000..f44207789 --- /dev/null +++ b/examples/phone_so100_replay.py @@ -0,0 +1,106 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def action_to_transition(action: dict):
    """Wrap a flat dataset action into a transition dict for the pipeline.

    The end-effector pose entries ("ee.x" ... "ee.wz") and the "gripper.pos"
    entry are copied, when present in ``action``, under an "action." prefix
    and cast to ``float``. Every other key of ``action`` is ignored. The
    remaining transition slots are filled with fixed placeholder values.
    """
    # EE pose components first, then the gripper (absolute position in the dataset).
    wanted = ("ee.x", "ee.y", "ee.z", "ee.wx", "ee.wy", "ee.wz", "gripper.pos")
    prefixed = {f"action.{name}": float(action[name]) for name in wanted if name in action}

    return dict(
        observation=None,
        action=prefixed,
        reward=None,
        done=False,
        truncated=False,
        info={},
        complementary_data={},
    )
+robot_ee_to_joints = RobotProcessor( + steps=[ + AddRobotObservationAsComplimentaryData(robot=robot), + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=False, # Because replay is open loop + ), + ], + to_transition=action_to_transition, + to_output=to_output_robot_action, +) + +robot_ee_to_joints.reset() + +log_say(f"Replaying episode {EPISODE_IDX}") +for idx in range(dataset.num_frames): + t0 = time.perf_counter() + + ee_action = { + name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) + } + + joint_action = robot_ee_to_joints(ee_action) + action_sent = robot.send_action(joint_action) + + busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0)) + +robot.disconnect() diff --git a/examples/phone_so100_teleop.py b/examples/phone_so100_teleop.py new file mode 100644 index 000000000..82515c98f --- /dev/null +++ b/examples/phone_so100_teleop.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

"""Teleoperate an SO100 follower arm from a phone.

Phone actions are turned into an end-effector pose target by one processor
pipeline, and that target is turned into joint commands by a second pipeline
before being sent to the robot.
"""

import time

from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessor
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
    AddRobotObservationAsComplimentaryData,
    EEBoundsAndSafety,
    EEReferenceAndDelta,
    GripperVelocityToJoint,
    InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone import Phone
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction

# Configure the robot and teleoperator
robot_config = SO100FollowerConfig(
    port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS)  # or PhoneOS.ANDROID

# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
teleop_device = Phone(teleop_config)

# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
    urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
    target_frame_name="gripper_frame_link",
    joint_names=list(robot.bus.motors.keys()),
)

# Build pipeline to convert phone action to ee pose action
phone_to_robot_ee_pose = RobotProcessor(
    steps=[
        MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
        AddRobotObservationAsComplimentaryData(robot=robot),
        EEReferenceAndDelta(
            kinematics=kinematics_solver,
            end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
            motor_names=list(robot.bus.motors.keys()),
        ),
        EEBoundsAndSafety(
            end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
            max_ee_step_m=0.10,
            max_ee_twist_step_rad=0.50,
        ),
    ],
    to_transition=to_transition_teleop_action,
    to_output=lambda tr: tr,
)

# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
    steps=[
        InverseKinematicsEEToJoints(
            kinematics=kinematics_solver,
            motor_names=list(robot.bus.motors.keys()),
        ),
        GripperVelocityToJoint(
            motor_names=list(robot.bus.motors.keys()),
            speed_factor=20.0,
        ),
    ],
    to_transition=lambda tr: tr,
    to_output=to_output_robot_action,
)

robot.connect()
teleop_device.connect()

print("Starting teleop loop. Move your phone to teleoperate the robot.")
while True:
    # Read the phone exactly once per iteration. Reading it a second time after
    # the emptiness check would discard the validated sample and could hand an
    # empty action to the pipelines below.
    phone_obs = teleop_device.get_action()
    if not phone_obs:
        time.sleep(0.01)
        continue

    # Phone to EE pose transition
    ee_transition = phone_to_robot_ee_pose(phone_obs)

    # EE pose to Joints transition
    joint_action = robot_ee_to_joints(ee_transition)

    if joint_action:
        robot.send_action(joint_action)

    time.sleep(0.01)
def aggregate_pipeline_dataset_features(
    pipeline: RobotProcessor,
    initial_features: dict[str, Any],
    *,
    use_videos: bool = True,
    patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
    """
    Build a dataset features dict from the features a pipeline declares.

    `pipeline.transform_features(initial_features)` is queried once, and its
    keys are routed by prefix into an "action" group and an "observation" group
    (prefix stripped). Each non-empty group is then converted with
    `hw_to_dataset_features`.

    - `initial_features`: seed features passed to the pipeline. When
      `use_videos` is true, entries whose value is a 3-tuple are also copied
      into the observation group under their own name.
    - `use_videos`: gates the 3-tuple entries above and every
      "observation.images.*" key; also forwarded to `hw_to_dataset_features`.
    - `patterns`: regexes tried with `re.search` against the full key of
      "action.*" and "observation.state.*" features; a key is kept when any
      pattern matches. `None` keeps everything. Image keys ignore `patterns`.

    Keys with any other prefix are dropped.
    """
    import re

    pipeline_features = pipeline.transform_features(initial_features)

    def matches_any_pattern(key: str) -> bool:
        return patterns is None or any(re.search(pattern, key) for pattern in patterns)

    grouped: dict[str, dict[str, Any]] = {}

    # Seed the observation group with the 3-tuple entries of the initial features.
    if use_videos:
        camera_specs = {
            name: spec
            for name, spec in initial_features.items()
            if isinstance(spec, tuple) and len(spec) == 3
        }
        if camera_specs:
            grouped["observation"] = dict(camera_specs)

    # (key prefix, destination group, whether `patterns` applies).
    # The prefixes are mutually exclusive, so the first match decides.
    routes = (
        ("action.", "action", True),
        ("observation.state.", "observation", True),
        ("observation.images.", "observation", False),
    )

    for full_key, spec in pipeline_features.items():
        for prefix, group, pattern_filtered in routes:
            if not full_key.startswith(prefix):
                continue
            # Images obey only the use_videos flag; action/state obey only patterns.
            accepted = matches_any_pattern(full_key) if pattern_filtered else use_videos
            if accepted:
                grouped.setdefault(group, {})[full_key[len(prefix) :]] = spec
            break

    out: dict[str, dict] = {}
    for group in ("action", "observation"):
        if group in grouped:
            out.update(hw_to_dataset_features(grouped[group], group, use_videos))

    return out
def merge_features(*dicts: dict) -> dict:
    """
    Merge LeRobot grouped feature dicts into one new dict.

    - 1D numeric specs (dtype not image/video/string, `shape` a 1-tuple, and a
      "names" entry): names are merged in first-seen order without duplicates
      and the shape is recomputed from the merged names.
    - Everything else (e.g. observation.images.*, non-dict values): the last
      definition wins and is stored as-is.

    The input dicts are never mutated: merged vector specs are accumulated in
    dicts created by this function.

    Raises:
        ValueError: if two vector specs merged under the same key disagree on dtype.
    """
    out: dict = {}
    # Keys whose entry in `out` is an accumulator created here, hence safe to
    # mutate. Entries stored through the "last one wins" path alias caller data.
    owned: set = set()

    for d in dicts:
        for key, value in d.items():
            if isinstance(value, dict):
                dtype = value.get("dtype")
                shape = value.get("shape")
                is_vector = (
                    dtype not in ("image", "video", "string")
                    and isinstance(shape, tuple)
                    and len(shape) == 1
                    and "names" in value
                )
            else:
                is_vector = False

            if not is_vector:
                # For images/videos, non-1D entries and plain values: override
                # with the latest definition.
                out[key] = value
                owned.discard(key)
                continue

            if key not in owned:
                previous = out.get(key)
                if isinstance(previous, dict):
                    # A caller-provided spec already sits under this key: continue
                    # from a private copy so the caller's dict is left untouched.
                    previous_names = previous.get("names")
                    target = {
                        **previous,
                        "names": list(previous_names) if isinstance(previous_names, (list, tuple)) else [],
                    }
                else:
                    target = {"dtype": dtype, "names": [], "shape": (0,)}
                out[key] = target
                owned.add(key)
            else:
                target = out[key]

            # Ensure consistent data types across merged entries
            if "dtype" in target and dtype != target["dtype"]:
                raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")

            # Merge feature names: append only new ones to preserve order without duplicates
            seen = set(target["names"])
            for n in value["names"]:
                if n not in seen:
                    target["names"].append(n)
                    seen.add(n)
            # Recompute the shape to reflect the updated number of features
            target["shape"] = (len(target["names"]),)

    return out
ToBatchProcessor: """Reset processor state (no-op for this processor).""" pass - def feature_contract(self, features: dict[str, Any]) -> dict[str, Any]: - """Return features (no-op for this processor).""" + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py new file mode 100644 index 000000000..f0e081577 --- /dev/null +++ b/src/lerobot/processor/converters.py @@ -0,0 +1,225 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from copy import deepcopy +from typing import Any + +import numpy as np +import torch +from scipy.spatial.transform import Rotation + +from .pipeline import EnvTransition, TransitionKey + + +def _to_tensor(x: torch.Tensor | np.ndarray | Sequence[int | float]): + if isinstance(x, torch.Tensor): + return x + if isinstance(x, np.ndarray): + # Keep images (uint8 HWC) and python objects as-is + if x.dtype == np.uint8 or x.dtype == np.object_: + return x + # Scalars/arrays to float32 tensor + return torch.as_tensor(x, dtype=torch.float32) + # Anything else to float32 tensor + return torch.as_tensor(x, dtype=torch.float32) + + +def _from_tensor(x: Any): + if isinstance(x, torch.Tensor): + return x.item() if x.numel() == 1 else x.detach().cpu().numpy() + return x + + +def _is_image(arr: Any) -> bool: + return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3 + + +def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + state, images = {}, {} + for k, v in obs.items(): + if _is_image(v): + images[k] = v + else: + state[k] = v + return state, images + + +def make_obs_act_transition( + *, obs: dict[str, Any] | None = None, act: dict[str, Any] | None = None +) -> EnvTransition: + return { + TransitionKey.OBSERVATION: {} if obs is None else obs, + TransitionKey.ACTION: {} if act is None else act, + TransitionKey.INFO: {}, + TransitionKey.COMPLEMENTARY_DATA: {}, + TransitionKey.REWARD: None, + TransitionKey.DONE: None, + TransitionKey.TRUNCATED: None, + } + + +def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition: + """ + Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey. + """ + act_dict: dict[str, Any] = {} + for k, v in action.items(): + # Check if the value is a type that should not be converted to a tensor. 
+ if isinstance(v, (Rotation, dict)): + act_dict[f"action.{k}"] = v + continue + + arr = np.array(v) if np.isscalar(v) else v + act_dict[f"action.{k}"] = _to_tensor(arr) + + return make_obs_act_transition(act=act_dict) + + +# TODO(Adil, Pepijn): Over time we can maybe add these converters to pipeline.py itself +def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransition: + """ + Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey. + """ + state, images = _split_obs_to_state_and_images(observation) + + obs_dict: dict[str, Any] = {} + for k, v in state.items(): + arr = np.array(v) if np.isscalar(v) else v + obs_dict[f"observation.state.{k}"] = _to_tensor(arr) + + for cam, img in images.items(): + obs_dict[f"observation.images.{cam}"] = img + + return make_obs_act_transition(obs=obs_dict) + + +def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]: + """ + Converts an EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' or '.vel' for raw robot actions. + """ + out: dict[str, Any] = {} + action_dict = transition.get(TransitionKey.ACTION) or {} + + for k, v in action_dict.items(): + if isinstance(k, str) and k.startswith("action.") and k.endswith((".pos", ".vel")): + out_key = k[len("action.") :] # Strip the 'action.' prefix. + out[out_key] = float(v) + + return out + + +def to_dataset_frame( + transitions_or_transition: EnvTransition | Iterable[EnvTransition], features: dict[str, dict] +) -> dict[str, Any]: + """ + Converts a single EnvTransition or an iterable of them into a flat, + dataset-friendly dictionary for training or evaluation, according to + the provided `features` spec. + + Args: + transitions_or_transition: Either a single EnvTransition dict + or an iterable of them (which will be merged).
+ features (dict[str, dict]): + A feature specification dictionary: + - 'action': dict with 'names': list of action feature names + - 'observation.state': dict with 'names': list of state feature names + - keys starting with 'observation.images.' are passed through + + Returns: + batch (dict[str, Any]): Flat dictionary containing: + - numpy arrays for "observation.state" and "action" + - any image tensors defined in features + - next.{reward,done,truncated} + - no info entries (INFO is merged but not exported) + - *_is_pad flags and task from complementary_data + """ + action_names = features.get("action", {}).get("names", []) + obs_state_names = features.get("observation.state", {}).get("names", []) + image_keys = [k for k in features if k.startswith("observation.images.")] + + def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition: + out = deepcopy(base) + for key in ( + TransitionKey.OBSERVATION, + TransitionKey.ACTION, + TransitionKey.INFO, + TransitionKey.COMPLEMENTARY_DATA, + ): + if other.get(key): + out.setdefault(key, {}).update(deepcopy(other[key])) + for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED): + if k in other: + out[k] = other[k] + return out + + def _ensure_transition(obj) -> EnvTransition: + # single transition + if isinstance(obj, dict) and any(isinstance(k, TransitionKey) for k in obj): + return obj + # iterable of transitions + if isinstance(obj, Iterable): + items = list(obj) + if not items: + return {} + acc = items[0] + for t in items[1:]: + acc = _merge(acc, t) + return acc + raise TypeError("Expected EnvTransition or iterable of them") + + tr = _ensure_transition(transitions_or_transition) + obs = tr.get(TransitionKey.OBSERVATION, {}) or {} + act = tr.get(TransitionKey.ACTION, {}) or {} + batch: dict[str, Any] = {} + + # Images passthrough + for k in image_keys: + if k in obs: + batch[k] = obs[k] + + # Observation.state vector + if obs_state_names: + vals = [_from_tensor(obs.get(f"observation.state.{n}", 0.0)) for n in
obs_state_names] + batch["observation.state"] = np.asarray(vals, dtype=np.float32) + + # Action vector + if action_names: + vals = [_from_tensor(act.get(f"action.{n}", 0.0)) for n in action_names] + batch["action"] = np.asarray(vals, dtype=np.float32) + + # Next.* fields + if tr.get(TransitionKey.REWARD) is not None: + batch["next.reward"] = _from_tensor(tr[TransitionKey.REWARD]) + if tr.get(TransitionKey.DONE) is not None: + batch["next.done"] = _from_tensor(tr[TransitionKey.DONE]) + if tr.get(TransitionKey.TRUNCATED) is not None: + batch["next.truncated"] = _from_tensor(tr[TransitionKey.TRUNCATED]) + + # Complementary data flags and task + comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {} + if comp: + # pad flags + for k, v in comp.items(): + if k.endswith("_is_pad"): + batch[k] = v + # task label + if comp.get("task") is not None: + batch["task"] = comp["task"] + + return batch diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 12f9a5abc..39bd1cf11 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -141,5 +141,5 @@ class DeviceProcessor: """Reset processor state (no-op for this processor).""" pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 94390b004..92e654472 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -257,7 +257,7 @@ class NormalizerProcessor: def reset(self): pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -435,7 +435,7 @@ class 
UnnormalizerProcessor: def reset(self): pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 7d63db238..40273548e 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -106,9 +106,8 @@ class VanillaObservationProcessor(ObservationProcessor): def observation(self, observation): return self._process_observation(observation) - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: """Transforms feature keys to a standardized contract. - This method handles several renaming patterns: - Exact matches (e.g., 'pixels' -> 'OBS_IMAGE'). - Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE'). diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 6d3546035..19dc668f7 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -23,7 +23,7 @@ from copy import deepcopy from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Protocol, TypedDict +from typing import Any, Protocol, TypedDict, runtime_checkable import torch from huggingface_hub import ModelHubMixin, hf_hub_download @@ -132,6 +132,7 @@ class ProcessorStepRegistry: cls._registry.clear() +@runtime_checkable class ProcessorStep(Protocol): """Structural typing interface for a single processor step. 
@@ -145,7 +146,6 @@ class ProcessorStep(Protocol): **Required**: - ``__call__(transition: EnvTransition) -> EnvTransition`` - - ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]`` Optional helper protocol: * ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable @@ -158,6 +158,8 @@ class ProcessorStep(Protocol): * ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict containing torch tensors only. * ``reset()`` – Clear internal buffers at episode boundaries. + * ``transform_features(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]`` + If present, this method will be called to aggregate the dataset features of all steps. Example separation: - get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10} @@ -174,7 +176,7 @@ class ProcessorStep(Protocol): def reset(self) -> None: ... - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ... + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ... def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401 @@ -354,7 +356,10 @@ class RobotProcessor(ModelHubMixin): hook(idx, current_transition) # Convert back to original format if needed - return self.to_output(current_transition) if called_with_batch else current_transition + if called_with_batch or self.to_output is not _default_transition_to_batch: + return self.to_output(current_transition) + else: + return current_transition def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]: """Prepare and validate transition data for processing. 
@@ -819,23 +824,15 @@ class RobotProcessor(ModelHubMixin): f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition" ) - fc = getattr(step, "feature_contract", None) - if not callable(fc): - raise TypeError( - f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]" - ) - - def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: """ - Apply ALL steps in order. Each step must implement - feature_contract(features) and return a dict (full or incremental schema). + Apply ALL steps in order. A step is consulted only if it defines ``transform_features``; + other steps leave the features unchanged, and the result of each call feeds the next step. """ features: dict[str, PolicyFeature] = deepcopy(initial_features) for _, step in enumerate(self.steps): - out = step.feature_contract(features) - if not isinstance(out, dict): - raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]") + out = step.transform_features(features) if hasattr(step, "transform_features") else features features = out return features @@ -895,7 +892,7 @@ class ObservationProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -955,7 +952,7 @@ class ActionProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -1014,7 +1011,7 @@ class RewardProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features
@@ -1078,7 +1075,7 @@ class DoneProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -1138,7 +1135,7 @@ class TruncatedProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -1203,7 +1200,7 @@ class InfoProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -1249,7 +1246,7 @@ class ComplementaryDataProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -1271,5 +1268,5 @@ class IdentityProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py index 4fe4105a5..db20424df 100644 --- a/src/lerobot/processor/rename_processor.py +++ b/src/lerobot/processor/rename_processor.py @@ -43,7 +43,7 @@ class RenameProcessor(ObservationProcessor): def get_config(self) -> dict[str, Any]: return {"rename_map": self.rename_map} - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: """Transforms: - Each 
key in the observation that appears in `rename_map` is renamed to its value. - Keys not in `rename_map` remain unchanged. diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index c7086d6ce..4ec9fb351 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -187,7 +187,7 @@ class TokenizerProcessor: """Reset processor state (no-op for this processor).""" pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: """Add tokenized task features to the feature contract. Args: diff --git a/src/lerobot/record.py b/src/lerobot/record.py index e73c76384..78b671646 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -72,12 +72,19 @@ from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.image_writer import safe_stop_image_writer from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.datasets.utils import hw_to_dataset_features from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.policies.factory import make_policy, make_processor from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import RobotProcessor +from lerobot.processor.converters import ( + to_dataset_frame, + to_output_robot_action, + to_transition_robot_observation, + to_transition_teleop_action, +) from lerobot.processor.normalize_processor import rename_stats +from lerobot.processor.pipeline import IdentityProcessor, TransitionKey from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -191,6 +198,36 @@ class RecordConfig: return ["policy"] +""" --------------- record_loop() data flow -------------------------- + [ Robot ] + V + [ 
robot.get_observation() ] ---> raw_obs + V + [ robot_observation_processor ] ---> obs_transition + V + .-----( ACTION LOGIC )------------------. + V V + [ From Teleoperator ] [ From Policy ] + | | + | [teleop.get_action] -> raw_action | [predict_action] + | | | | + | V | V + | [teleop_action_processor] | | + | | | | + '---> teleop_transition '---> policy_transition + | | + '-------------------------.-------------' + V + [ robot_action_processor ] --> robot_action_to_send + V + [ robot.send_action() ] -- (Robot Executes) + V + ( Transitions are merged & added to Dataset ) + V + ( Rerun Log / Loop Wait ) +""" + + @safe_stop_image_writer def record_loop( robot: Robot, @@ -202,14 +239,27 @@ def record_loop( preprocessor: RobotProcessor | None = None, postprocessor: RobotProcessor | None = None, control_time_s: int | None = None, + teleop_action_processor: RobotProcessor | None = None, # runs after teleop + robot_action_processor: RobotProcessor | None = None, # runs before robot + robot_observation_processor: RobotProcessor | None = None, # runs after robot single_task: str | None = None, display_data: bool = False, ): + teleop_action_processor = teleop_action_processor or RobotProcessor( + steps=[IdentityProcessor()], to_transition=to_transition_teleop_action, to_output=lambda tr: tr + ) + robot_action_processor = robot_action_processor or RobotProcessor( + steps=[IdentityProcessor()], to_transition=lambda tr: tr, to_output=to_output_robot_action + ) + robot_observation_processor = robot_observation_processor or RobotProcessor( + steps=[IdentityProcessor()], to_transition=to_transition_robot_observation, to_output=lambda tr: tr + ) + if dataset is not None and dataset.fps != fps: raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).") teleop_arm = teleop_keyboard = None - if isinstance(teleop, list): + if isinstance(teleop, list): # For LeKiwi teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None) 
teleop_arm = next( ( @@ -226,11 +276,20 @@ def record_loop( ) # Reset policy and processor if they are provided - if policy is not None or preprocessor is not None: + if policy is not None and preprocessor is not None and postprocessor is not None: policy.reset() preprocessor.reset() postprocessor.reset() + # Reset custom pipelines + teleop_action_processor.reset() + robot_action_processor.reset() + robot_observation_processor.reset() + + policy_transition = None + teleop_transition = None + obs_transition = None + timestamp = 0 start_episode_t = time.perf_counter() while timestamp < control_time_s: @@ -240,12 +299,19 @@ def record_loop( events["exit_early"] = False break - observation = robot.get_observation() + # Get robot observation + obs = robot.get_observation() - if policy is not None or dataset is not None: - observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") + # Applies a pipeline to the raw robot observation, default is IdentityProcessor + obs_transition = robot_observation_processor(obs) + + # Get action from either policy or teleop + if policy is not None and preprocessor is not None and postprocessor is not None: + if dataset is not None: + observation_frame = to_dataset_frame( + obs_transition, dataset.features + ) # Convert the observation to the dataset format - if policy is not None or preprocessor is not None: action_values = predict_action( observation=observation_frame, policy=policy, @@ -256,37 +322,64 @@ def record_loop( task=single_task, robot_type=robot.robot_type, ) - action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)} - elif policy is None and isinstance(teleop, Teleoperator): - action = teleop.get_action() - elif policy is None and isinstance(teleop, list): - # TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline) + + action_names = dataset.features["action"]["names"] + policy_action = {f"action.{name}": 
float(action_values[i]) for i, name in enumerate(action_names)} + policy_transition = { + TransitionKey.ACTION: policy_action, + TransitionKey.COMPLEMENTARY_DATA: {}, + } + + elif isinstance(teleop, Teleoperator): + act = teleop.get_action() + + # Applies a pipeline to the raw teleop action, default is IdentityProcessor + teleop_transition = teleop_action_processor(act) + + elif isinstance(teleop, list): arm_action = teleop_arm.get_action() arm_action = {f"arm_{k}": v for k, v in arm_action.items()} - keyboard_action = teleop_keyboard.get_action() base_action = robot._from_keyboard_to_base_action(keyboard_action) - - action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action + act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action + teleop_transition = teleop_action_processor(act) else: logging.info( - "No policy or teleoperator provided, skipping action generation." - "This is likely to happen when resetting the environment without a teleop device." - "The robot won't be at its rest position at the start of the next episode." + "No policy or teleoperator provided, skipping action generation. " + "This is likely to happen during environment reset." ) - continue + # Still continue to next loop to respect timing + # Applies a pipeline to the action, default is IdentityProcessor + # IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action() + if policy_transition is not None: + robot_action_to_send = robot_action_processor(policy_transition) + else: + robot_action_to_send = robot_action_processor(teleop_transition) + + # Send action to robot # Action can eventually be clipped using `max_relative_target`, # so action actually sent is saved in the dataset. action = postprocessor.process(action) - sent_action = robot.send_action(action) + # TODO(pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot. 
+ _ = robot.send_action(robot_action_to_send) + # Write to dataset if dataset is not None: - action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action") - frame = {**observation_frame, **action_frame} + # If to_dataset_frame is provided, use it to merge the transitions. + merged = [] + if obs_transition is not None: # The observation from the robot + merged.append(obs_transition) + if teleop_transition is not None: # The action from teleop + merged.append(teleop_transition) + if policy_transition is not None: # The action from policy + merged.append(policy_transition) + frame = to_dataset_frame( + merged if len(merged) > 1 else merged[0], dataset.features + ) # Convert the observation to the dataset format dataset.add_frame(frame, task=single_task) if display_data: - log_rerun_data(observation, action) + log_rerun_data([obs_transition, teleop_transition or policy_transition]) dt_s = time.perf_counter() - start_loop_t busy_wait(1 / fps - dt_s) @@ -417,9 +510,5 @@ def record(cfg: RecordConfig) -> LeRobotDataset: return dataset -def main(): - record() - - if __name__ == "__main__": - main() + record() diff --git a/src/lerobot/robots/so100_follower/__init__.py b/src/lerobot/robots/so100_follower/__init__.py index b995aab13..5dc43ac3b 100644 --- a/src/lerobot/robots/so100_follower/__init__.py +++ b/src/lerobot/robots/so100_follower/__init__.py @@ -14,6 +14,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig +from .config_so100_follower import SO100FollowerConfig from .so100_follower import SO100Follower -from .so100_follower_end_effector import SO100FollowerEndEffector diff --git a/src/lerobot/robots/so100_follower/config_so100_follower.py b/src/lerobot/robots/so100_follower/config_so100_follower.py index ea8b9f1c2..16bab13e4 100644 --- a/src/lerobot/robots/so100_follower/config_so100_follower.py +++ b/src/lerobot/robots/so100_follower/config_so100_follower.py @@ -39,35 +39,3 @@ class SO100FollowerConfig(RobotConfig): # Set to `True` for backward compatibility with previous policies/dataset use_degrees: bool = False - - -@RobotConfig.register_subclass("so100_follower_end_effector") -@dataclass -class SO100FollowerEndEffectorConfig(SO100FollowerConfig): - """Configuration for the SO100FollowerEndEffector robot.""" - - # Path to URDF file for kinematics - # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: - # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf - urdf_path: str | None = None - - # End-effector frame name in the URDF - target_frame_name: str = "gripper_frame_link" - - # Default bounds for the end-effector position (in meters) - end_effector_bounds: dict[str, list[float]] = field( - default_factory=lambda: { - "min": [-1.0, -1.0, -1.0], # min x, y, z - "max": [1.0, 1.0, 1.0], # max x, y, z - } - ) - - max_gripper_pos: float = 50 - - end_effector_step_sizes: dict[str, float] = field( - default_factory=lambda: { - "x": 0.02, - "y": 0.02, - "z": 0.02, - } - ) diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py new file mode 100644 index 000000000..ed498557f --- /dev/null +++ b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py @@ -0,0 +1,447 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +import numpy as np +from scipy.spatial.transform import Rotation + +from lerobot.configs.types import PolicyFeature +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor.pipeline import ( + ActionProcessor, + ComplementaryDataProcessor, + EnvTransition, + ObservationProcessor, + ProcessorStepRegistry, + TransitionKey, +) +from lerobot.robots.robot import Robot + + +@ProcessorStepRegistry.register("ee_reference_and_delta") +@dataclass +class EEReferenceAndDelta: + """ + Compute the desired end-effector pose from the target pose and the current pose. 
+ + Input ACTION keys: + { + "action.enabled", "action.target_{x,y,z,wx,wy,wz}" : float + "complementary_data.raw_joint_positions": dict, + } + + Output ACTION keys: + { + "action.ee.{x,y,z,wx,wy,wz}" : float + } + """ + + kinematics: RobotKinematics + end_effector_step_sizes: dict + motor_names: list[str] + + reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False) + _prev_enabled: bool = field(default=False, init=False, repr=False) + _command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + act = transition.get(TransitionKey.ACTION) or {} + comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} + + # Get joint positions from complementary data + raw = comp.get("raw_joint_positions", None) + if raw is None: + raise ValueError( + "raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta" + ) + + q = np.array([float(raw[n]) for n in self.motor_names], dtype=float) + + # Current pose from FK on measured joints + t_curr = self.kinematics.forward_kinematics(q) + + enabled = bool(act.pop("action.enabled", 0)) + tx = float(act.pop("action.target_x", 0.0)) + ty = float(act.pop("action.target_y", 0.0)) + tz = float(act.pop("action.target_z", 0.0)) + wx = float(act.pop("action.target_wx", 0.0)) + wy = float(act.pop("action.target_wy", 0.0)) + wz = float(act.pop("action.target_wz", 0.0)) + + desired = None + + if enabled: + # Latch a reference at the rising edge; also be defensive if None + if not self._prev_enabled or self.reference_ee_pose is None: + self.reference_ee_pose = t_curr.copy() + + ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr + + delta_p = np.array( + [ + tx * self.end_effector_step_sizes["x"], + ty * self.end_effector_step_sizes["y"], + tz * self.end_effector_step_sizes["z"], + ], + dtype=float, + ) + r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix() + + desired = np.eye(4,
dtype=float) + desired[:3, :3] = ref[:3, :3] @ r_abs + desired[:3, 3] = ref[:3, 3] + delta_p + + self._command_when_disabled = desired.copy() + else: + # While disabled, keep sending the same command to avoid drift. + if self._command_when_disabled is None: + # If we've never had an enabled command yet, freeze current FK pose once. + self._command_when_disabled = t_curr.copy() + desired = self._command_when_disabled.copy() + + # Write action fields + pos = desired[:3, 3] + tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec() + act.update( + { + "action.ee.x": float(pos[0]), + "action.ee.y": float(pos[1]), + "action.ee.z": float(pos[2]), + "action.ee.wx": float(tw[0]), + "action.ee.wy": float(tw[1]), + "action.ee.wz": float(tw[2]), + } + ) + + self._prev_enabled = enabled + transition[TransitionKey.ACTION] = act + return transition + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +@ProcessorStepRegistry.register("ee_bounds_and_safety") +@dataclass +class EEBoundsAndSafety(ActionProcessor): + """ + Clip the end-effector pose to the bounds and check for jumps. 
+ + Input ACTION keys: + { + "action.ee.{x,y,z,wx,wy,wz}" : float + } + + Output ACTION keys: + { + "action.ee.{x,y,z,wx,wy,wz}" : float + } + """ + + end_effector_bounds: dict + max_ee_step_m: float = 0.05 + max_ee_twist_step_rad: float = 0.20 + _last_pos: np.ndarray | None = field(default=None, init=False, repr=False) + + def action(self, act: dict | None) -> dict: + x = act.pop("action.ee.x", None) + y = act.pop("action.ee.y", None) + z = act.pop("action.ee.z", None) + wx = act.pop("action.ee.wx", None) + wy = act.pop("action.ee.wy", None) + wz = act.pop("action.ee.wz", None) + + if None in (x, y, z, wx, wy, wz): + return act + + pos = np.array([x, y, z], dtype=float) + twist = np.array([wx, wy, wz], dtype=float) + + # Clip position + pos = np.clip(pos, self.end_effector_bounds["min"], self.end_effector_bounds["max"]) + + # Check for jumps in position + if self._last_pos is not None: + dpos = pos - self._last_pos + n = float(np.linalg.norm(dpos)) + if n > self.max_ee_step_m and n > 0: + pos = self._last_pos + dpos * (self.max_ee_step_m / n) + raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m") + + self._last_pos = pos + self._last_twist = twist + + act.update( + { + "action.ee.x": float(pos[0]), + "action.ee.y": float(pos[1]), + "action.ee.z": float(pos[2]), + "action.ee.wx": float(twist[0]), + "action.ee.wy": float(twist[1]), + "action.ee.wz": float(twist[2]), + } + ) + return act + + def reset(self): + self._last_pos = None + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # Because this is last step we specify the dataset features of this step that we want to be stored in the dataset + features["action.ee.x"] = float + features["action.ee.y"] = float + features["action.ee.z"] = float + features["action.ee.wx"] = float + features["action.ee.wy"] = float + features["action.ee.wz"] = float + return features + + +@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints") +@dataclass +class 
InverseKinematicsEEToJoints: + """ + Compute the desired joint positions from the desired end-effector pose. + + Input ACTION keys: + { + "action.ee.{x,y,z,wx,wy,wz}" : float + "complementary_data.raw_joint_positions": dict, + } + + Output ACTION keys: + { + "action.joint_name_1.pos": float, + "action.joint_name_2.pos": float, + ... + "action.joint_name_n.pos": float, + } + """ + + kinematics: RobotKinematics + motor_names: list[str] + q_curr: np.ndarray | None = field(default=None, init=False, repr=False) + initial_guess_current_joints: bool = True + + def __call__(self, transition: EnvTransition) -> EnvTransition: + act = transition.get(TransitionKey.ACTION) or {} + comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} + + x = act.get("action.ee.x", None) + y = act.get("action.ee.y", None) + z = act.get("action.ee.z", None) + wx = act.get("action.ee.wx", None) + wy = act.get("action.ee.wy", None) + wz = act.get("action.ee.wz", None) + + if None in (x, y, z, wx, wy, wz): + # Nothing to do; restore what we popped and return + act.update( + { + "action.ee.x": x, + "action.ee.y": y, + "action.ee.z": z, + "action.ee.wx": wx, + "action.ee.wy": wy, + "action.ee.wz": wz, + } + ) + transition[TransitionKey.ACTION] = act + return transition + + # Get joint positions from complimentary data + raw = comp.get("raw_joint_positions", None) + if raw is None: + raise ValueError( + "raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta" + ) + + if self.initial_guess_current_joints: # Use current joints as initial guess + self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float) + else: # Use previous ik solution as initial guess + if self.q_curr is None: + self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float) + + # Build desired 4x4 transform from pos + rotvec (twist) + t_des = np.eye(4, dtype=float) + t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix() + t_des[:3, 3] = [x, y, z] + 
+ # Compute inverse kinematics + q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des) + self.q_curr = q_target + + new_act = dict(act) + for i, name in enumerate(self.motor_names): + if name == "gripper": + new_act["observation.state.gripper.pos"] = float(raw["gripper"]) + else: + new_act[f"action.{name}.pos"] = float(q_target[i]) + transition[TransitionKey.ACTION] = new_act + return transition + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We specify the dataset features of this step that we want to be stored in the dataset + features["action.ee.x"] = float + features["action.ee.y"] = float + features["action.ee.z"] = float + features["action.ee.wx"] = float + features["action.ee.wy"] = float + features["action.ee.wz"] = float + + features["observation.state.gripper.pos"] = float + features["action.gripper.pos"] = float + return features + + def reset(self): + self.q_curr = None + + +@ProcessorStepRegistry.register("gripper_velocity_to_joint") +@dataclass +class GripperVelocityToJoint: + """ + Convert the gripper velocity to a joint velocity. 
+ + Input ACTION keys: + { + "action.gripper": float, + } + + Output ACTION keys: + { + "action.gripper.pos": float, + } + """ + + motor_names: list[str] + speed_factor: float = 20.0 + clip_min: float = 0.0 + clip_max: float = 100.0 + + def __call__(self, transition: EnvTransition) -> EnvTransition: + obs = transition.get(TransitionKey.OBSERVATION) or {} + act = transition.get(TransitionKey.ACTION) or {} + comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} + + if "action.gripper" not in act: + return transition + + if "gripper" not in self.motor_names: + new_act = dict(act) + new_act.pop("action.gripper", None) + transition[TransitionKey.ACTION] = new_act + return transition + + # Get current gripper position from complementary data + raw = comp.get("raw_joint_positions") or {} + curr_pos = float(raw.get("gripper")) + + # Compute desired gripper velocity + u = float(act.get("action.gripper", 0.0)) + delta = u * float(self.speed_factor) + gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max)) + + new_act = dict(act) + new_act["action.gripper.pos"] = gripper_pos + new_act.pop("action.gripper", None) + transition[TransitionKey.ACTION] = new_act + + obs.update({"observation.state.gripper.pos": curr_pos}) + transition[TransitionKey.OBSERVATION] = obs + return transition + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We specify the dataset features of this step that we want to be stored in the dataset + features["observation.state.gripper.pos"] = float + features["action.gripper.pos"] = float + return features + + +@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee") +@dataclass +class ForwardKinematicsJointsToEE(ObservationProcessor): + """ + Compute the end-effector pose from the joint positions. 
+ + Input OBSERVATION keys: + { + "observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float, + } + + Output OBSERVATION keys: + { + "observation.state.ee.{x,y,z,wx,wy,wz}" : float + } + """ + + kinematics: RobotKinematics + motor_names: list[str] + + def observation(self, obs: dict | None) -> dict: + if not all(f"observation.state.{n}.pos" in obs for n in self.motor_names): + return obs + + q = np.array([obs[f"observation.state.{n}.pos"] for n in self.motor_names], dtype=float) + t = self.kinematics.forward_kinematics(q) + pos = t[:3, 3] + tw = Rotation.from_matrix(t[:3, :3]).as_rotvec() + + obs.update( + { + "observation.state.ee.x": float(pos[0]), + "observation.state.ee.y": float(pos[1]), + "observation.state.ee.z": float(pos[2]), + "observation.state.ee.wx": float(tw[0]), + "observation.state.ee.wy": float(tw[1]), + "observation.state.ee.wz": float(tw[2]), + } + ) + return obs + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We specify the dataset features of this step that we want to be stored in the dataset + for k in ["x", "y", "z", "wx", "wy", "wz"]: + features[f"observation.state.ee.{k}"] = float + return features + + +@ProcessorStepRegistry.register("add_robot_observation") +@dataclass +class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor): + """ + Read the robot's current observation and insert it into the transition as complementary data. + + - Joint positions are added under complementary_data["raw_joint_positions"] as a dict: + { "": , ... 
} + """ + + robot: Robot + + def complementary_data(self, comp: dict | None) -> dict: + comp = {} if comp is None else dict(comp) + obs = self.robot.get_observation() + + comp["raw_joint_positions"] = { + k.removesuffix(".pos"): float(v) + for k, v in obs.items() + if isinstance(k, str) and k.endswith(".pos") + } + return comp + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/robots/so100_follower/so100_follower_end_effector.py b/src/lerobot/robots/so100_follower/so100_follower_end_effector.py deleted file mode 100644 index 5fe2993cb..000000000 --- a/src/lerobot/robots/so100_follower/so100_follower_end_effector.py +++ /dev/null @@ -1,200 +0,0 @@ -# !/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time -from typing import Any - -import numpy as np - -from lerobot.cameras import make_cameras_from_configs -from lerobot.errors import DeviceNotConnectedError -from lerobot.model.kinematics import RobotKinematics -from lerobot.motors import Motor, MotorNormMode -from lerobot.motors.feetech import FeetechMotorsBus - -from . import SO100Follower -from .config_so100_follower import SO100FollowerEndEffectorConfig - -logger = logging.getLogger(__name__) - - -class SO100FollowerEndEffector(SO100Follower): - """ - SO100Follower robot with end-effector space control. 
- - This robot inherits from SO100Follower but transforms actions from - end-effector space to joint space before sending them to the motors. - """ - - config_class = SO100FollowerEndEffectorConfig - name = "so100_follower_end_effector" - - def __init__(self, config: SO100FollowerEndEffectorConfig): - super().__init__(config) - self.bus = FeetechMotorsBus( - port=self.config.port, - motors={ - "shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES), - "shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES), - "elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES), - "wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES), - "wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES), - "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), - }, - calibration=self.calibration, - ) - - self.cameras = make_cameras_from_configs(config.cameras) - - self.config = config - - # Initialize the kinematics module for the so100 robot - if self.config.urdf_path is None: - raise ValueError( - "urdf_path must be provided in the configuration for end-effector control. " - "Please set urdf_path in your SO100FollowerEndEffectorConfig." - ) - - self.kinematics = RobotKinematics( - urdf_path=self.config.urdf_path, - target_frame_name=self.config.target_frame_name, - ) - - # Store the bounds for end-effector position - self.end_effector_bounds = self.config.end_effector_bounds - - self.current_ee_pos = None - self.current_joint_pos = None - - @property - def action_features(self) -> dict[str, Any]: - """ - Define action features for end-effector control. - Returns dictionary with dtype, shape, and names. - """ - return { - "dtype": "float32", - "shape": (4,), - "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3}, - } - - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: - """ - Transform action from end-effector space to joint space and send to motors. 
- - Args: - action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control - or a numpy array with [delta_x, delta_y, delta_z] - - Returns: - The joint-space action that was sent to the motors - """ - - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - # Convert action to numpy array if not already - if isinstance(action, dict): - if all(k in action for k in ["delta_x", "delta_y", "delta_z"]): - delta_ee = np.array( - [ - action["delta_x"] * self.config.end_effector_step_sizes["x"], - action["delta_y"] * self.config.end_effector_step_sizes["y"], - action["delta_z"] * self.config.end_effector_step_sizes["z"], - ], - dtype=np.float32, - ) - if "gripper" not in action: - action["gripper"] = [1.0] - action = np.append(delta_ee, action["gripper"]) - else: - logger.warning( - f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}" - ) - action = np.zeros(4, dtype=np.float32) - - if self.current_joint_pos is None: - # Read current joint positions - current_joint_pos = self.bus.sync_read("Present_Position") - self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors]) - - # Calculate current end-effector position using forward kinematics - if self.current_ee_pos is None: - self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos) - - # Set desired end-effector position by adding delta - desired_ee_pos = np.eye(4) - desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation - - # Add delta to position and clip to bounds - desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3] - if self.end_effector_bounds is not None: - desired_ee_pos[:3, 3] = np.clip( - desired_ee_pos[:3, 3], - self.end_effector_bounds["min"], - self.end_effector_bounds["max"], - ) - - # Compute inverse kinematics to get joint positions - target_joint_values_in_degrees = self.kinematics.inverse_kinematics( - self.current_joint_pos, 
desired_ee_pos - ) - - # Create joint space action dictionary - joint_action = { - f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys()) - } - - # Handle gripper separately if included in action - # Gripper delta action is in the range 0 - 2, - # We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos - joint_action["gripper.pos"] = np.clip( - self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos, - 5, - self.config.max_gripper_pos, - ) - - self.current_ee_pos = desired_ee_pos.copy() - self.current_joint_pos = target_joint_values_in_degrees.copy() - self.current_joint_pos[-1] = joint_action["gripper.pos"] - - # Send joint space action to parent class - return super().send_action(joint_action) - - def get_observation(self) -> dict[str, Any]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - # Read arm position - start = time.perf_counter() - obs_dict = self.bus.sync_read("Present_Position") - obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read state: {dt_ms:.1f}ms") - - # Capture images from cameras - for cam_key, cam in self.cameras.items(): - start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") - - return obs_dict - - def reset(self): - self.current_ee_pos = None - self.current_joint_pos = None diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index 7486ee499..87e751b26 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -69,6 +69,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot: raise ValueError(config.type) +# TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use 
in dataset def ensure_safe_goal_position( goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float] ) -> dict[str, float]: diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index 3c72caf79..320140bdb 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -109,7 +109,7 @@ def teleop_loop( action = teleop.get_action() if display_data: observation = robot.get_observation() - log_rerun_data(observation, action) + log_rerun_data(observation=observation, action=action) robot.send_action(action) dt_s = time.perf_counter() - loop_start diff --git a/src/lerobot/teleoperators/phone/__init__.py b/src/lerobot/teleoperators/phone/__init__.py new file mode 100644 index 000000000..f82ab11e1 --- /dev/null +++ b/src/lerobot/teleoperators/phone/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_phone import PhoneConfig +from .phone import Phone diff --git a/src/lerobot/teleoperators/phone/config_phone.py b/src/lerobot/teleoperators/phone/config_phone.py new file mode 100644 index 000000000..380d5f5ff --- /dev/null +++ b/src/lerobot/teleoperators/phone/config_phone.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from enum import Enum + +import numpy as np + +from ..config import TeleoperatorConfig + + +class PhoneOS(Enum): + ANDROID = "android" + IOS = "ios" + + +@TeleoperatorConfig.register_subclass("phone") +@dataclass +class PhoneConfig(TeleoperatorConfig): + phone_os: PhoneOS = PhoneOS.IOS + camera_offset = np.array( + [0.0, -0.02, 0.04] + ) # iPhone 14 Pro camera is 2cm off center and 4cm above center diff --git a/src/lerobot/teleoperators/phone/phone.py b/src/lerobot/teleoperators/phone/phone.py new file mode 100644 index 000000000..3c6d5fc5d --- /dev/null +++ b/src/lerobot/teleoperators/phone/phone.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Docs: +# hebi: https://docs.hebi.us/tools.html#mobile-io +# teleop: https://github.com/SpesRobotics/teleop + +import logging +import threading +import time + +import hebi +import numpy as np +from scipy.spatial.transform import Rotation +from teleop import Teleop + +from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS +from lerobot.teleoperators.teleoperator import Teleoperator + +logger = logging.getLogger(__name__) + + +class Phone(Teleoperator): + """ + Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API). + For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs. + + Press and hold **B1** to enable teleoperation. While enabled, the first B1 press + captures a reference pose and rotation, when disabled and pressed again the position is reapplied. + """ + + config_class = PhoneConfig + name = "phone" + + def __init__(self, config: PhoneConfig): + super().__init__(config) + self.config = config + self._group = None + self._teleop = None + self._teleop_thread = None + self._latest_pose = None + self._latest_message = None + self._enabled: bool = False + self._calib_pos: np.ndarray | None = None + self._calib_rot_inv: Rotation | None = None + + @property + def is_connected(self) -> bool: + return (self.config.phone_os == PhoneOS.IOS and self._group is not None) or ( + self.config.phone_os == PhoneOS.ANDROID and self._teleop is not None + ) + + def connect(self) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + if self.config.phone_os == PhoneOS.IOS: + logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.") + lookup = hebi.Lookup() + time.sleep(2.0) + group = lookup.get_group_from_names(["HEBI"], ["mobileIO"]) + if group is None: + raise RuntimeError("Mobile I/O not found — check name/family 
settings in the app.") + self._group = group + logger.info(f"{self} connected to HEBI group with {group.size} module(s).") + elif self.config.phone_os == PhoneOS.ANDROID: + logger.info("Starting teleop stream for Android...") + self._teleop = Teleop() + self._teleop.subscribe(self._android_callback) + self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True) + self._teleop_thread.start() + logger.info(f"{self} connected, teleop stream started.") + else: + raise ValueError(f"Invalid config phone_os: {self.config.phone_os}") + + self.calibrate() + + def calibrate(self) -> None: + print( + "Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)" + ) + if self.config.phone_os == PhoneOS.IOS: + print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n") + else: + print("Touch and move on the WebXR page to capture this pose...\n") + + pos, rot = self._wait_for_capture_trigger() + self._calib_pos = pos.copy() + self._calib_rot_inv = rot.inv() + self._enabled = False + print("Calibration done\n") + + def _reapply_position_calibration(self, pos: np.ndarray) -> None: + self._calib_pos = pos.copy() + + @property + def is_calibrated(self) -> bool: + return (self._calib_pos is not None) and (self._calib_rot_inv is not None) + + @property + def action_features(self) -> dict[str, type]: + return { + "phone.pos": np.ndarray, # shape (3,) + "phone.rot": Rotation, # scipy.spatial.transform.Rotation + "phone.raw_inputs": dict, # analogs/buttons or webXR meta + "phone.enabled": bool, + } + + def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]: + """Wait trigger for calibration: iOS: B1. 
Android: 'move'.""" + while True: + ok, pos, rot, pose = self._read_current_pose() + if not ok: + time.sleep(0.01) + continue + + if self.config.phone_os == PhoneOS.IOS: + io = getattr(pose, "io", None) + b = getattr(io, "b", None) if io is not None else None + b1 = False + if b is not None: + b1 = bool(b.get_int(1)) + if b1: + return pos, rot + else: + msg = self._latest_message or {} + if bool(msg.get("move", False)): + return pos, rot + + time.sleep(0.01) + + def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]: + if self.config.phone_os == PhoneOS.IOS: + fbk = self._group.get_next_feedback() + pose = fbk[0] + ar_pos = getattr(pose, "ar_position", None) + ar_quat = getattr(pose, "ar_orientation", None) + if ar_pos is None or ar_quat is None: + return False, None, None, None + quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw + rot = Rotation.from_quat(quat_xyzw) + pos = ar_pos - rot.apply(self.config.camera_offset) + return True, pos, rot, pose + else: + p = self._latest_pose + if p is None: + return False, None, None, None + rot = Rotation.from_matrix(p[:3, :3]) + pos = p[:3, 3] - rot.apply(self.config.camera_offset) + pose = self._latest_pose + return True, pos, rot, pose + + @property + def feedback_features(self) -> dict[str, type]: + # No haptic or other feedback implemented yet + pass + + def configure(self) -> None: + # No additional configuration required for phone teleop + pass + + def _android_callback(self, pose: np.ndarray, message: dict) -> None: + self._latest_pose = pose + self._latest_message = message + time.sleep(0.001) # 1ms delay to avoid race condition + + def get_action(self) -> dict: + ok, raw_pos, raw_rot, pose = self._read_current_pose() + if not ok or not self.is_calibrated: + return {} + + # Collect raw inputs (B1 / analogs on iOS, move/scale on Android) + raw_inputs: dict[str, float | int | bool] = {} + if self.config.phone_os == PhoneOS.IOS: + io = getattr(pose, "io", 
None) + if io is not None: + bank_a, bank_b = io.a, io.b + if bank_a: + for ch in range(1, 9): + if bank_a.has_float(ch): + raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch)) + if bank_b: + for ch in range(1, 9): + if bank_b.has_int(ch): + raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch)) + elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch): + raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch)) + else: + msg = self._latest_message or {} + raw_inputs["move"] = bool(msg.get("move", False)) + raw_inputs["scale"] = float(msg.get("scale", 1.0)) + raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False)) + raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False)) + + if self.config.phone_os == PhoneOS.IOS: + enable = bool(raw_inputs.get("b1", 0)) + else: + enable = bool(raw_inputs.get("move", False)) + + # Rising edge then re-capture calibration immediately from current raw pose + if enable and not self._enabled: + self._reapply_position_calibration(raw_pos) + + # Apply calibration + pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos) + rot_cal = self._calib_rot_inv * raw_rot + + self._enabled = enable + + return { + "phone.pos": pos_cal, + "phone.rot": rot_cal, + "phone.raw_inputs": raw_inputs, + "phone.enabled": self._enabled, + } + + def send_feedback(self, feedback: dict[str, float]) -> None: + # We could add haptic feedback (vibrations) here, but it's not implemented yet + raise NotImplementedError + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.config.phone_os == PhoneOS.IOS: + self._group = None + else: + self._teleop = None + if self._teleop_thread and self._teleop_thread.is_alive(): + self._teleop_thread.join(timeout=1.0) + self._teleop_thread = None + self._latest_pose = None diff --git a/src/lerobot/teleoperators/phone/phone_processor.py b/src/lerobot/teleoperators/phone/phone_processor.py new file mode 100644 index 000000000..436ee8444 
--- /dev/null +++ b/src/lerobot/teleoperators/phone/phone_processor.py @@ -0,0 +1,87 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.configs.types import PolicyFeature +from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry +from lerobot.teleoperators.phone.config_phone import PhoneOS + + +@ProcessorStepRegistry.register("map_phone_action_to_robot_action") +@dataclass +class MapPhoneActionToRobotAction(ActionProcessor): + """ + Map calibrated phone pose (actions) to the inputs for robot actions + + Expected input ACTION keys: + { + "action.phone.enabled": bool, + "action.phone.pos": np.ndarray, + "action.phone.rot": Rotation, + "action.phone.raw_inputs": dict, + } + + Output ACTION keys: + { + "action.enabled": bool, + "action.ee.{x,y,z,wx,wy,wz}" : float + "action.gripper": float, + } + """ + + platform: PhoneOS + _enabled_prev: bool = field(default=False, init=False, repr=False) + + def action(self, act: dict | None) -> dict: + # Pop them from the action + enabled = act.pop("action.phone.enabled", 0) + pos = act.pop("action.phone.pos", None) + rot = act.pop("action.phone.rot", None) + inputs = act.pop("action.phone.raw_inputs", {}) + + if pos is None or rot is None: + return act + + rotvec = rot.as_rotvec() # Absolute orientation as rotvec + + # Map certain inputs to certain actions + if 
self.platform == PhoneOS.IOS: + gripper = float(inputs.get("a3", 0.0)) + else: + a = float(inputs.get("reservedButtonA", 0.0)) + b = float(inputs.get("reservedButtonB", 0.0)) + gripper = ( + a - b + ) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed + + # For some actions we need to invert the axis + act.update( + { + "action.enabled": enabled, + "action.target_x": -pos[1] if enabled else 0.0, + "action.target_y": pos[0] if enabled else 0.0, + "action.target_z": pos[2] if enabled else 0.0, + "action.target_wx": rotvec[1] if enabled else 0.0, + "action.target_wy": rotvec[0] if enabled else 0.0, + "action.target_wz": -rotvec[2] if enabled else 0.0, + "action.gripper": gripper, # Still send gripper action when disabled + } + ) + return act + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index f0f9aebb7..8a4f65a03 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import numbers import os from typing import Any import numpy as np import rerun as rr +from lerobot.processor.pipeline import EnvTransition, TransitionKey + def _init_rerun(session_name: str = "lerobot_control_loop") -> None: """Initializes the Rerun SDK for visualizing the control loop.""" @@ -28,19 +31,87 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None: rr.spawn(memory_limit=memory_limit) -def log_rerun_data(observation: dict[str | Any], action: dict[str | Any]): - for obs, val in observation.items(): - if isinstance(val, float): - rr.log(f"observation.{obs}", rr.Scalar(val)) - elif isinstance(val, np.ndarray): - if val.ndim == 1: - for i, v in enumerate(val): - rr.log(f"observation.{obs}_{i}", rr.Scalar(float(v))) +def _is_scalar(x): + return ( + isinstance(x, numbers.Real) + or isinstance(x, (np.integer, np.floating)) + or (isinstance(x, np.ndarray) and x.ndim == 0) + ) + + +def log_rerun_data( + data: list[dict[str | Any] | EnvTransition] | dict[str | Any] | EnvTransition | None = None, + *, + observation: dict[str, Any] | None = None, + action: dict[str, Any] | None = None, +) -> None: + items = data if isinstance(data, list) else ([data] if data is not None else []) + + obs = {} if observation is None else dict(observation) + act = {} if action is None else dict(action) + + for idx, item in enumerate(items): + if not isinstance(item, dict): + continue + + if any(isinstance(k, TransitionKey) for k in item.keys()): + o = item.get(TransitionKey.OBSERVATION) or {} + a = item.get(TransitionKey.ACTION) or {} + if isinstance(o, dict): + obs.update(o) + if isinstance(a, dict): + act.update(a) + continue + + keys = list(item.keys()) + has_obs = any(str(k).startswith("observation.") for k in keys) + has_act = any(str(k).startswith("action.") for k in keys) + + if has_obs or has_act: + if has_obs: + obs.update(item) + if has_act: + act.update(item) + else: + # No prefixes: assume first is observation, second is action, others are observation + 
if idx == 0: + obs.update(item) + elif idx == 1: + act.update(item) else: - rr.log(f"observation.{obs}", rr.Image(val), static=True) - for act, val in action.items(): - if isinstance(val, float): - rr.log(f"action.{act}", rr.Scalar(val)) - elif isinstance(val, np.ndarray): - for i, v in enumerate(val): - rr.log(f"action.{act}_{i}", rr.Scalar(float(v))) + obs.update(item) + + for k, v in obs.items(): + if v is None: + continue + key = k if str(k).startswith("observation.") else f"observation.{k}" + + if _is_scalar(v): + rr.log(key, rr.Scalar(float(v))) + elif isinstance(v, np.ndarray): + arr = v + # Convert CHW -> HWC when needed + if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4): + arr = np.transpose(arr, (1, 2, 0)) + if arr.ndim == 1: + for i, vi in enumerate(arr): + rr.log(f"{key}_{i}", rr.Scalar(float(vi))) + else: + rr.log(key, rr.Image(arr), static=True) + + for k, v in act.items(): + if v is None: + continue + key = k if str(k).startswith("action.") else f"action.{k}" + + if _is_scalar(v): + rr.log(key, rr.Scalar(float(v))) + elif isinstance(v, np.ndarray): + if v.ndim == 1: + for i, vi in enumerate(v): + rr.log(f"{key}_{i}", rr.Scalar(float(vi))) + else: + # Fall back to flattening higher-dimensional arrays + flat = v.flatten() + for i, vi in enumerate(flat): + rr.log(f"{key}_{i}", rr.Scalar(float(vi))) diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py new file mode 100644 index 000000000..ae09fb262 --- /dev/null +++ b/tests/datasets/test_dataset_utils.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from datasets import Dataset +from huggingface_hub import DatasetCard + +from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index +from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch, merge_features + + +def test_default_parameters(): + card = create_lerobot_dataset_card() + assert isinstance(card, DatasetCard) + assert card.data.tags == ["LeRobot"] + assert card.data.task_categories == ["robotics"] + assert card.data.configs == [ + { + "config_name": "default", + "data_files": "data/*/*.parquet", + } + ] + + +def test_with_tags(): + tags = ["tag1", "tag2"] + card = create_lerobot_dataset_card(tags=tags) + assert card.data.tags == ["LeRobot", "tag1", "tag2"] + + +def test_calculate_episode_data_index(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) + assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6])) + + +def test_merge_simple_vectors(): + g1 = { + "action": { + "dtype": "float32", + "shape": (2,), + "names": ["ee.x", "ee.y"], + } + } + g2 = { + "action": { + "dtype": "float32", + "shape": (2,), + "names": ["ee.y", "ee.z"], + } + } + + out = merge_features(g1, g2) + + assert "action" in out + assert out["action"]["dtype"] == "float32" + # Names merged with 
preserved order and de-duplication + assert out["action"]["names"] == ["ee.x", "ee.y", "ee.z"] + # Shape correctly recomputed from names length + assert out["action"]["shape"] == (3,) + + +def test_merge_multiple_groups_order_and_dedup(): + g1 = {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}} + g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}} + g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}} + + out = merge_features(g1, g2, g3) + + assert out["action"]["names"] == ["a", "b", "c", "d"] + assert out["action"]["shape"] == (4,) + + +def test_non_vector_last_wins_for_images(): + # Non-vector (images) with same name should be overwritten by the last image specified + g1 = { + "observation.images.front": { + "dtype": "image", + "shape": (3, 480, 640), + "names": ["channels", "height", "width"], + } + } + g2 = { + "observation.images.front": { + "dtype": "image", + "shape": (3, 720, 1280), + "names": ["channels", "height", "width"], + } + } + + out = merge_features(g1, g2) + assert out["observation.images.front"]["shape"] == (3, 720, 1280) + assert out["observation.images.front"]["dtype"] == "image" + + +def test_dtype_mismatch_raises(): + g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}} + g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}} + + with pytest.raises(ValueError, match="dtype mismatch for 'action'"): + _ = merge_features(g1, g2) + + +def test_non_dict_passthrough_last_wins(): + g1 = {"misc": 123} + g2 = {"misc": 456} + + out = merge_features(g1, g2) + # For non-dict entries the last one wins + assert out["misc"] == 456 diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py deleted file mode 100644 index ba16874d0..000000000 --- a/tests/datasets/test_utils.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from datasets import Dataset -from huggingface_hub import DatasetCard - -from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index -from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch - - -def test_default_parameters(): - card = create_lerobot_dataset_card() - assert isinstance(card, DatasetCard) - assert card.data.tags == ["LeRobot"] - assert card.data.task_categories == ["robotics"] - assert card.data.configs == [ - { - "config_name": "default", - "data_files": "data/*/*.parquet", - } - ] - - -def test_with_tags(): - tags = ["tag1", "tag2"] - card = create_lerobot_dataset_card(tags=tags) - assert card.data.tags == ["LeRobot", "tag1", "tag2"] - - -def test_calculate_episode_data_index(): - dataset = Dataset.from_dict( - { - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - "index": [0, 1, 2, 3, 4, 5], - "episode_index": [0, 0, 1, 2, 2, 2], - }, - ) - dataset.set_transform(hf_transform_to_torch) - episode_data_index = calculate_episode_data_index(dataset) - assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) - assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6])) diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py new file mode 100644 index 000000000..590f6a892 --- /dev/null +++ b/tests/processor/test_converters.py @@ -0,0 +1,196 @@ +import numpy as np +import pytest +import torch 
+ +from lerobot.processor.converters import ( + to_dataset_frame, + to_output_robot_action, + to_transition_robot_observation, + to_transition_teleop_action, +) +from lerobot.processor.pipeline import TransitionKey + + +def test_to_transition_teleop_action_prefix_and_tensor_conversion(): + # Scalars, arrays, and "image-like" uint8 arrays are supported + img = np.zeros((8, 12, 3), dtype=np.uint8) + act = { + "ee.x": 0.5, # scalar to torch tensor + "delta": np.array([1.0, 2.0]), # ndarray to torch tensor + "raw_img": img, # uint8 HWC to passthrough ndarray + } + + tr = to_transition_teleop_action(act) + + # Should be an EnvTransition-like dict with ACTION populated + assert isinstance(tr, dict) + assert TransitionKey.ACTION in tr + assert "action.ee.x" in tr[TransitionKey.ACTION] + assert "action.delta" in tr[TransitionKey.ACTION] + assert "action.raw_img" in tr[TransitionKey.ACTION] + + # Types: scalars/arrays -> torch tensor; images to np.ndarray + assert isinstance(tr[TransitionKey.ACTION]["action.ee.x"], torch.Tensor) + assert tr[TransitionKey.ACTION]["action.ee.x"].item() == pytest.approx(0.5) + + assert isinstance(tr[TransitionKey.ACTION]["action.delta"], torch.Tensor) + assert tr[TransitionKey.ACTION]["action.delta"].shape == (2,) + assert torch.allclose(tr[TransitionKey.ACTION]["action.delta"], torch.tensor([1.0, 2.0])) + + assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], np.ndarray) + assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == np.uint8 + assert tr[TransitionKey.ACTION]["action.raw_img"].shape == (8, 12, 3) + + # Observation is created as empty dict by make_transition + assert TransitionKey.OBSERVATION in tr + assert isinstance(tr[TransitionKey.OBSERVATION], dict) + assert tr[TransitionKey.OBSERVATION] == {} + + +def test_to_transition_robot_observation_state_vs_images_split(): + # Create an observation with mixed content + img = np.full((10, 20, 3), 255, dtype=np.uint8) # image (uint8 HWC) + obs = { + "j1.pos": 10.0, # scalar to 
state to torch tensor + "j2.pos": np.float32(20.0), # scalar np to state to torch tensor + "image_front": img, # to images passthrough + "flag": np.int32(7), # scalar to state to torch tensor + "arr": np.array([1.5, 2.5]), # vector to state to torch tensor + } + + tr = to_transition_robot_observation(obs) + assert isinstance(tr, dict) + assert TransitionKey.OBSERVATION in tr + + out = tr[TransitionKey.OBSERVATION] + # Check state keys are present and converted to tensors + for k in ("j1.pos", "j2.pos", "flag", "arr"): + key = f"observation.state.{k}" + assert key in out + v = out[key] + if k != "arr": + assert isinstance(v, torch.Tensor) and v.ndim == 0 + else: + assert isinstance(v, torch.Tensor) and v.ndim == 1 and v.shape == (2,) + + # Check image present as is + assert "observation.images.image_front" in out + assert isinstance(out["observation.images.image_front"], np.ndarray) + assert out["observation.images.image_front"].dtype == np.uint8 + assert out["observation.images.image_front"].shape == (10, 20, 3) + + # ACTION should be empty dict by make_transition + assert TransitionKey.ACTION in tr + assert isinstance(tr[TransitionKey.ACTION], dict) + assert tr[TransitionKey.ACTION] == {} + + +def test_to_output_robot_action_strips_prefix_and_filters_pos_keys_only(): + # Build a transition with mixed action keys + tr = { + TransitionKey.ACTION: { + "action.j1.pos": 11.0, # keep "j1.pos" + "action.gripper.pos": torch.tensor(33.0), # keep: tensor accepted + "action.ee.x": 0.5, # ignore (doesn't end with .pos) + "misc": "ignore_me", # ignore (no 'action.' prefix) + } + } + + out = to_output_robot_action(tr) + # Only ".pos" keys with "action." 
prefix are retained and stripped to base names + assert set(out.keys()) == {"j1.pos", "gripper.pos"} + # Values converted to float + assert isinstance(out["j1.pos"], float) + assert isinstance(out["gripper.pos"], float) + assert out["j1.pos"] == pytest.approx(11.0) + assert out["gripper.pos"] == pytest.approx(33.0) + + +def test_to_dataset_frame_merge_and_pack_vectors_and_metadata(): + # Fabricate dataset features (as stored in dataset.meta["features"]) + features = { + # Action vector: 3 elements in specific order + "action": { + "dtype": "float32", + "shape": (3,), + "names": ["j1.pos", "j2.pos", "gripper.pos"], + }, + # Observation state vector: 2 elements + "observation.state": { + "dtype": "float32", + "shape": (2,), + "names": ["j1.pos", "j2.pos"], + }, + # Image spec (video/image dtype acceptable) + "observation.images.front": { + "dtype": "image", + "shape": (480, 640, 3), + "names": ["h", "w", "c"], + }, + } + + # Build two transitions to be merged: teleop (action) and robot obs (state/images) + img = np.random.randint(0, 255, size=(480, 640, 3), dtype=np.uint8) + + teleop_transition = { + TransitionKey.OBSERVATION: {}, + TransitionKey.ACTION: { + "action.j1.pos": torch.tensor(1.1), + "action.j2.pos": torch.tensor(2.2), + # gripper.pos missing → defaults to 0.0 + "action.ee.x": 0.5, # ignored, not in features["action"]["names"] + }, + TransitionKey.COMPLEMENTARY_DATA: { + "frame_is_pad": True, + "task": "Pick cube", + }, + } + + robot_transition = { + TransitionKey.OBSERVATION: { + "observation.state.j1.pos": torch.tensor(10.0), + "observation.state.j2.pos": torch.tensor(20.0), + "observation.images.front": img, + }, + TransitionKey.REWARD: torch.tensor(5.0), + TransitionKey.DONE: True, + TransitionKey.TRUNCATED: False, + TransitionKey.INFO: {"note": "ok"}, + } + + # Directly call the refactored function + batch = to_dataset_frame([teleop_transition, robot_transition], features) + + # Images passthrough + assert "observation.images.front" in batch + assert 
batch["observation.images.front"].shape == img.shape + assert batch["observation.images.front"].dtype == np.uint8 + assert np.shares_memory(batch["observation.images.front"], img) or np.array_equal( + batch["observation.images.front"], img + ) + + # Observation.state vector + assert "observation.state" in batch + obs_vec = batch["observation.state"] + assert isinstance(obs_vec, np.ndarray) and obs_vec.dtype == np.float32 + assert obs_vec.shape == (2,) + assert obs_vec[0] == pytest.approx(10.0) + assert obs_vec[1] == pytest.approx(20.0) + + # Action vector + assert "action" in batch + act_vec = batch["action"] + assert isinstance(act_vec, np.ndarray) and act_vec.dtype == np.float32 + assert act_vec.shape == (3,) + assert act_vec[0] == pytest.approx(1.1) + assert act_vec[1] == pytest.approx(2.2) + assert act_vec[2] == pytest.approx(0.0) # default for missing gripper.pos + + # Next.* metadata + assert batch["next.reward"] == pytest.approx(5.0) + assert batch["next.done"] is True + assert batch["next.truncated"] is False + + # Complementary data + assert batch["frame_is_pad"] is True + assert batch["task"] == "Pick cube" diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py index bda120015..7e30750f4 100644 --- a/tests/processor/test_device_processor.py +++ b/tests/processor/test_device_processor.py @@ -288,8 +288,8 @@ def test_serialization_methods(): assert processor.device == device -def test_feature_contract(): - """Test that feature_contract returns features unchanged.""" +def test_features(): + """Test that transform_features returns features unchanged.""" processor = DeviceProcessor(device="cpu") features = { @@ -297,7 +297,7 @@ def test_feature_contract(): "action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)), } - result = processor.feature_contract(features) + result = processor.transform_features(features) assert result == features assert result is features # Should return the same object diff --git 
a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 6fc60b49b..97c737e0c 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -621,10 +621,19 @@ def test_serialization_roundtrip(full_stats): assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) # Verify features and norm_map are correctly reconstructed - assert new_processor.features.keys() == original_processor.features.keys() - for key in new_processor.features: - assert new_processor.features[key].type == original_processor.features[key].type - assert new_processor.features[key].shape == original_processor.features[key].shape + assert ( + new_processor.transform_features(features).keys() + == original_processor.transform_features(features).keys() + ) + for key in new_processor.transform_features(features): + assert ( + new_processor.transform_features(features)[key].type + == original_processor.transform_features(features)[key].type + ) + assert ( + new_processor.transform_features(features)[key].shape + == original_processor.transform_features(features)[key].shape + ) assert new_processor.norm_map == original_processor.norm_map diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index e48b6bc08..4e6efdb6c 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -410,13 +410,13 @@ def test_equivalent_with_image_dict(): torch.testing.assert_close(original_result[key], processor_result[key]) -def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory): +def test_image_processor_features_pixels_to_image(policy_feature_factory): processor = VanillaObservationProcessor() features = { "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "keep": policy_feature_factory(FeatureType.ENV, (1,)), } - out = 
processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"] assert "pixels" not in out @@ -424,13 +424,13 @@ def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory assert_contract_is_typed(out) -def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory): +def test_image_processor_features_observation_pixels_to_image(policy_feature_factory): processor = VanillaObservationProcessor() features = { "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "keep": policy_feature_factory(FeatureType.ENV, (1,)), } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"] assert "observation.pixels" not in out @@ -438,7 +438,7 @@ def test_image_processor_feature_contract_observation_pixels_to_image(policy_fea assert_contract_is_typed(out) -def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory): +def test_image_processor_features_multi_camera_and_prefixed(policy_feature_factory): processor = VanillaObservationProcessor() features = { "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), @@ -446,7 +446,7 @@ def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_featu "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "keep": policy_feature_factory(FeatureType.ENV, (7,)), } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"] assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"] @@ -456,14 +456,14 @@ def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_featu 
assert_contract_is_typed(out) -def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory): +def test_state_processor_features_environment_and_agent_pos(policy_feature_factory): processor = VanillaObservationProcessor() features = { "environment_state": policy_feature_factory(FeatureType.STATE, (3,)), "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)), "keep": policy_feature_factory(FeatureType.ENV, (1,)), } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"] assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"] @@ -472,13 +472,13 @@ def test_state_processor_feature_contract_environment_and_agent_pos(policy_featu assert_contract_is_typed(out) -def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory): +def test_state_processor_features_prefixed_inputs(policy_feature_factory): proc = VanillaObservationProcessor() features = { "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), } - out = proc.feature_contract(features.copy()) + out = proc.transform_features(features.copy()) assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"] assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"] diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 26e865fad..42a8eb538 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -26,6 +26,7 @@ import torch import torch.nn as nn from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor from lerobot.processor.pipeline import 
TransitionKey from tests.conftest import assert_contract_is_typed @@ -90,8 +91,8 @@ class MockStep: def reset(self) -> None: self.counter = 0 - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -112,8 +113,8 @@ class MockStepWithoutOptionalMethods: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -168,8 +169,8 @@ class MockStepWithTensorState: self.running_mean.zero_() self.running_count.zero_() - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -662,8 +663,8 @@ class MockModuleStep(nn.Module): self.running_mean.zero_() self.counter = 0 - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -744,8 +745,8 @@ class MockNonModuleStepWithState: self.step_count.zero_() self.history.clear() - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -799,8 +800,8 @@ class MockStepWithNonSerializableParam: def reset(self) -> 
None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -838,8 +839,8 @@ class RegisteredMockStep: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -1382,8 +1383,8 @@ def test_state_file_naming_with_registry(): def load_state_dict(self, state): self.state_tensor = state["state_tensor"] - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features try: @@ -1439,8 +1440,8 @@ def test_override_with_nested_config(): def get_config(self): return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config} - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features try: @@ -1531,8 +1532,8 @@ def test_override_with_callables(): def get_config(self): return {"name": self.name} - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features try: @@ -1766,8 +1767,8 @@ def test_override_with_device_strings(): 
def load_state_dict(self, state): self.buffer = state["buffer"] - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features try: @@ -1860,21 +1861,16 @@ def test_save_load_with_custom_converter_functions(): class NonCompliantStep: - """Intentionally non-compliant: missing feature_contract.""" + """Intentionally non-compliant: missing features.""" def __call__(self, transition: EnvTransition) -> EnvTransition: return transition -def test_construction_rejects_step_without_feature_contract(): - with pytest.raises(TypeError, match=r"must define feature_contract\(features\) -> dict\[str, Any\]"): - RobotProcessor([NonCompliantStep()]) - - class NonCallableStep: """Intentionally non-compliant: missing __call__.""" - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -1893,7 +1889,7 @@ class FeatureContractAddStep: def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: features[self.key] = self.value return features @@ -1908,7 +1904,7 @@ class FeatureContractMutateStep: def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: features[self.key] = self.fn(features.get(self.key)) return features @@ -1920,7 +1916,7 @@ class FeatureContractBadReturnStep: def __call__(self, 
transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return ["not-a-dict"] @@ -1933,12 +1929,12 @@ class FeatureContractRemoveStep: def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: features.pop(self.key, None) return features -def test_feature_contract_orders_and_merges(policy_feature_factory): +def test_features_orders_and_merges(policy_feature_factory): p = RobotProcessor( [ FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), @@ -1946,14 +1942,14 @@ def test_feature_contract_orders_and_merges(policy_feature_factory): FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))), ] ) - out = p.feature_contract({}) + out = p.transform_features({}) assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,) assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,) assert_contract_is_typed(out) -def test_feature_contract_respects_initial_without_mutation(policy_feature_factory): +def test_features_respects_initial_without_mutation(policy_feature_factory): initial = { "seed": policy_feature_factory(FeatureType.STATE, (7,)), "nested": policy_feature_factory(FeatureType.ENV, (0,)), @@ -1966,7 +1962,7 @@ def test_feature_contract_respects_initial_without_mutation(policy_feature_facto ), ] ) - out = p.feature_contract(initial_features=initial) + out = p.transform_features(initial_features=initial) assert out["seed"].shape == (8,) assert out["nested"].shape == (5,) @@ -1977,13 +1973,7 @@ def test_feature_contract_respects_initial_without_mutation(policy_feature_facto assert_contract_is_typed(out) -def 
test_feature_contract_type_error_on_bad_step(): - p = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()]) - with pytest.raises(TypeError, match=r"\w+\.feature_contract must return dict\[str, Any\]"): - _ = p.feature_contract({}) - - -def test_feature_contract_execution_order_tracking(): +def test_features_execution_order_tracking(): class Track: def __init__(self, label): self.label = label @@ -1991,32 +1981,186 @@ def test_feature_contract_execution_order_tracking(): def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: code = {"A": 1, "B": 2, "C": 3}[self.label] pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=())) features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,)) return features - out = RobotProcessor([Track("A"), Track("B"), Track("C")]).feature_contract({}) + out = RobotProcessor([Track("A"), Track("B"), Track("C")]).transform_features({}) assert out["order"].shape == (1, 2, 3) -def test_feature_contract_remove_key(policy_feature_factory): +def test_features_remove_key(policy_feature_factory): p = RobotProcessor( [ FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), FeatureContractRemoveStep("a"), ] ) - out = p.feature_contract({}) + out = p.transform_features({}) assert "a" not in out -def test_feature_contract_remove_from_initial(policy_feature_factory): +def test_features_remove_from_initial(policy_feature_factory): initial = { "keep": policy_feature_factory(FeatureType.STATE, (1,)), "drop": policy_feature_factory(FeatureType.STATE, (1,)), } p = RobotProcessor([FeatureContractRemoveStep("drop")]) - out = p.feature_contract(initial_features=initial) + out = p.transform_features(initial_features=initial) assert "drop" not in out and out["keep"] == 
initial["keep"] + + +@dataclass +class AddActionEEAndJointFeatures: + """Adds both EE and JOINT action features.""" + + def __call__(self, tr): + return tr + + def transform_features(self, features: dict) -> dict: + # EE features + features["action.ee.x"] = float + features["action.ee.y"] = float + # JOINT features + features["action.j1.pos"] = float + features["action.j2.pos"] = float + return features + + +@dataclass +class AddObservationStateFeatures: + """Adds state features (and optionally an image spec to test precedence).""" + + add_front_image: bool = False + front_image_shape: tuple = (240, 320, 3) + + def __call__(self, tr): + return tr + + def transform_features(self, features: dict) -> dict: + # State features (mix EE and a joint state) + features["observation.state.ee.x"] = float + features["observation.state.j1.pos"] = float + if self.add_front_image: + features["observation.images.front"] = self.front_image_shape + return features + + +def test_aggregate_joint_action_only(): + rp = RobotProcessor([AddActionEEAndJointFeatures()]) + initial = {"front": (480, 640, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features=initial, + use_videos=True, + patterns=["action.j1.pos", "action.j2.pos"], + ) + + # Expect only "action" with joint names + assert "action" in out and "observation.state" not in out + assert out["action"]["dtype"] == "float32" + assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"} + assert out["action"]["shape"] == (len(out["action"]["names"]),) + + +def test_aggregate_ee_action_and_observation_with_videos(): + rp = RobotProcessor([AddActionEEAndJointFeatures(), AddObservationStateFeatures()]) + initial = {"front": (480, 640, 3), "side": (720, 1280, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features=initial, + use_videos=True, + patterns=["action.ee", "observation.state"], + ) + + # Action should pack only EE names + assert "action" in out + assert 
set(out["action"]["names"]) == {"ee.x", "ee.y"} + assert out["action"]["dtype"] == "float32" + + # Observation state should pack both ee.x and j1.pos as a vector + assert "observation.state" in out + assert set(out["observation.state"]["names"]) == {"ee.x", "j1.pos"} + assert out["observation.state"]["dtype"] == "float32" + + # Cameras from initial_features appear as videos + for cam in ("front", "side"): + key = f"observation.images.{cam}" + assert key in out + assert out[key]["dtype"] == "video" + assert out[key]["shape"] == initial[cam] + assert out[key]["names"] == ["height", "width", "channels"] + + +def test_aggregate_both_action_types(): + rp = RobotProcessor([AddActionEEAndJointFeatures()]) + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features={}, + use_videos=True, + patterns=["action.ee", "action.j1", "action.j2.pos"], + ) + + assert "action" in out + expected = {"ee.x", "ee.y", "j1.pos", "j2.pos"} + assert set(out["action"]["names"]) == expected + assert out["action"]["shape"] == (len(expected),) + + +def test_aggregate_images_when_use_videos_false(): + rp = RobotProcessor([AddObservationStateFeatures(add_front_image=True)]) + initial = {"back": (480, 640, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features=initial, + use_videos=False, # expect "image" dtype + patterns=None, + ) + + key = "observation.images.back" + key_front = "observation.images.front" + assert key not in out + assert key_front not in out + + +def test_aggregate_images_when_use_videos_true(): + rp = RobotProcessor([AddObservationStateFeatures(add_front_image=True)]) + initial = {"back": (480, 640, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features=initial, + use_videos=True, + patterns=None, + ) + + key = "observation.images.front" + key_back = "observation.images.back" + assert key in out + assert key_back in out + assert out[key]["dtype"] == "video" + assert out[key_back]["dtype"] == "video" + 
assert out[key_back]["shape"] == initial["back"] + + +def test_initial_camera_not_overridden_by_step_image(): + # Step explicitly sets a different front image shape; initial has another shape. + # aggregate_pipeline_dataset_features should keep the step's value (setdefault behavior on initial cams). + rp = RobotProcessor([AddObservationStateFeatures(add_front_image=True, front_image_shape=(240, 320, 3))]) + initial = {"front": (480, 640, 3)} # should NOT override the step-provided (240, 320, 3) + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features=initial, + use_videos=True, + patterns=["observation.images.front"], + ) + + key = "observation.images.front" + assert key in out + assert out[key]["shape"] == (240, 320, 3) # from the step, not from initial diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py index 229d57f9f..398b3ec9c 100644 --- a/tests/processor/test_rename_processor.py +++ b/tests/processor/test_rename_processor.py @@ -410,7 +410,7 @@ def test_value_types_preserved(): assert processed_obs["old_list"] == [1, 2, 3] -def test_feature_contract_basic_renaming(policy_feature_factory): +def test_features_basic_renaming(policy_feature_factory): processor = RenameProcessor(rename_map={"a": "x", "b": "y"}) features = { "a": policy_feature_factory(FeatureType.STATE, (2,)), @@ -418,7 +418,7 @@ def test_feature_contract_basic_renaming(policy_feature_factory): "c": policy_feature_factory(FeatureType.ENV, (1,)), } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) # Values preserved and typed assert out["x"] == features["a"] @@ -430,14 +430,14 @@ def test_feature_contract_basic_renaming(policy_feature_factory): assert set(features) == {"a", "b", "c"} -def test_feature_contract_overlapping_keys(policy_feature_factory): +def test_features_overlapping_keys(policy_feature_factory): # Overlapping renames: both 'a' and 'b' exist. 
'a'->'b', 'b'->'c' processor = RenameProcessor(rename_map={"a": "b", "b": "c"}) features = { "a": policy_feature_factory(FeatureType.STATE, (1,)), "b": policy_feature_factory(FeatureType.STATE, (2,)), } - out = processor.feature_contract(features) + out = processor.transform_features(features) assert set(out) == {"b", "c"} assert out["b"] == features["a"] # 'a' renamed to'b' @@ -445,7 +445,7 @@ def test_feature_contract_overlapping_keys(policy_feature_factory): assert_contract_is_typed(out) -def test_feature_contract_chained_processors(policy_feature_factory): +def test_features_chained_processors(policy_feature_factory): # Chain two rename processors at the contract level processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"}) processor2 = RenameProcessor( @@ -458,7 +458,7 @@ def test_feature_contract_chained_processors(policy_feature_factory): "img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "extra": policy_feature_factory(FeatureType.ENV, (1,)), } - out = pipeline.feature_contract(initial_features=spec) + out = pipeline.transform_features(initial_features=spec) assert set(out) == {"observation.state", "observation.image", "extra"} assert out["observation.state"] == spec["pos"] diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 452c36da9..784b1ce81 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -470,7 +470,7 @@ def test_registry_functionality(): @require_package("transformers") -def test_feature_contract_basic(): +def test_features_basic(): """Test basic feature contract functionality.""" mock_tokenizer = MockTokenizer(vocab_size=100) processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=128) @@ -480,7 +480,7 @@ def test_feature_contract_basic(): "action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)), } - output_features = processor.feature_contract(input_features) + 
output_features = processor.transform_features(input_features) # Check that original features are preserved assert "observation.state" in output_features @@ -501,13 +501,13 @@ def test_feature_contract_basic(): @require_package("transformers") -def test_feature_contract_with_custom_max_length(): +def test_features_with_custom_max_length(): """Test feature contract with custom max_length.""" mock_tokenizer = MockTokenizer(vocab_size=100) processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=64) input_features = {} - output_features = processor.feature_contract(input_features) + output_features = processor.transform_features(input_features) # Check that features use correct max_length assert f"{OBS_LANGUAGE}.tokens" in output_features @@ -521,7 +521,7 @@ def test_feature_contract_with_custom_max_length(): @require_package("transformers") -def test_feature_contract_existing_features(): +def test_features_existing_features(): """Test feature contract when tokenized features already exist.""" mock_tokenizer = MockTokenizer(vocab_size=100) processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=256) @@ -531,7 +531,7 @@ def test_feature_contract_existing_features(): f"{OBS_LANGUAGE}.attention_mask": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)), } - output_features = processor.feature_contract(input_features) + output_features = processor.transform_features(input_features) # Should not overwrite existing features assert output_features[f"{OBS_LANGUAGE}.tokens"].shape == (100,) # Original shape preserved diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py new file mode 100644 index 000000000..5e1eb4bab --- /dev/null +++ b/tests/utils/test_visualization_utils.py @@ -0,0 +1,205 @@ +import importlib +import sys +from types import SimpleNamespace + +import numpy as np +import pytest + +from lerobot.processor.pipeline import TransitionKey + + +@pytest.fixture +def mock_rerun(monkeypatch): + """ + 
Provide a mock `rerun` module so tests don't depend on the real library. + Also reload the module-under-test so it binds to this mock `rr`. + """ + calls = [] + + class DummyScalar: + def __init__(self, value): + self.value = float(value) + + class DummyImage: + def __init__(self, arr): + self.arr = arr + + def dummy_log(key, obj, **kwargs): + calls.append((key, obj, kwargs)) + + dummy_rr = SimpleNamespace( + Scalar=DummyScalar, + Image=DummyImage, + log=dummy_log, + init=lambda *a, **k: None, + spawn=lambda *a, **k: None, + ) + + # Inject fake module into sys.modules + monkeypatch.setitem(sys.modules, "rerun", dummy_rr) + + # Now import and reload the module under test, to bind to our rerun mock + import lerobot.utils.visualization_utils as vu + + importlib.reload(vu) + + # Expose both the reloaded module and the call recorder + yield vu, calls + + +def _keys(calls): + """Helper to extract just the keys logged to rr.log""" + return [k for (k, _obj, _kw) in calls] + + +def _obj_for(calls, key): + """Find the first object logged under a given key.""" + for k, obj, _kw in calls: + if k == key: + return obj + raise KeyError(f"Key {key} not found in calls: {calls}") + + +def _kwargs_for(calls, key): + for k, _obj, kw in calls: + if k == key: + return kw + raise KeyError(f"Key {key} not found in calls: {calls}") + + +def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): + vu, calls = mock_rerun + + # Build EnvTransition dict + obs = { + "observation.state.temperature": np.float32(25.0), + # CHW image should be converted to HWC for rr.Image + "observation.camera": np.zeros((3, 10, 20), dtype=np.uint8), + } + act = { + "action.throttle": 0.7, + # 1D array should log individual Scalars with suffix _i + "action.vector": np.array([1.0, 2.0], dtype=np.float32), + } + transition = { + TransitionKey.OBSERVATION: obs, + TransitionKey.ACTION: act, + } + + vu.log_rerun_data(transition) + + # We expect: + # - observation.state.temperature -> Scalar + # - 
observation.camera -> Image (HWC) with static=True + # - action.throttle -> Scalar + # - action.vector_0, action.vector_1 -> Scalars + expected_keys = { + "observation.state.temperature", + "observation.camera", + "action.throttle", + "action.vector_0", + "action.vector_1", + } + assert set(_keys(calls)) == expected_keys + + # Check scalar types and values + temp_obj = _obj_for(calls, "observation.state.temperature") + assert type(temp_obj).__name__ == "DummyScalar" + assert temp_obj.value == pytest.approx(25.0) + + throttle_obj = _obj_for(calls, "action.throttle") + assert type(throttle_obj).__name__ == "DummyScalar" + assert throttle_obj.value == pytest.approx(0.7) + + v0 = _obj_for(calls, "action.vector_0") + v1 = _obj_for(calls, "action.vector_1") + assert type(v0).__name__ == "DummyScalar" + assert type(v1).__name__ == "DummyScalar" + assert v0.value == pytest.approx(1.0) + assert v1.value == pytest.approx(2.0) + + # Check image handling: CHW -> HWC + img_obj = _obj_for(calls, "observation.camera") + assert type(img_obj).__name__ == "DummyImage" + assert img_obj.arr.shape == (10, 20, 3) # transposed + assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images + + +def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun): + vu, calls = mock_rerun + + # First dict without prefixes treated as observation + # Second dict without prefixes treated as action + obs_plain = { + "temp": 1.5, + # Already HWC image => should stay as-is + "img": np.zeros((5, 6, 3), dtype=np.uint8), + "none": None, # should be skipped + } + act_plain = { + "throttle": 0.3, + "vec": np.array([9, 8, 7], dtype=np.float32), + } + + vu.log_rerun_data([obs_plain, act_plain]) + + # Expected keys with auto-prefixes + expected = { + "observation.temp", + "observation.img", + "action.throttle", + "action.vec_0", + "action.vec_1", + "action.vec_2", + } + logged = set(_keys(calls)) + assert logged == expected + + # Scalars + t = _obj_for(calls, 
"observation.temp") + assert type(t).__name__ == "DummyScalar" + assert t.value == pytest.approx(1.5) + + throttle = _obj_for(calls, "action.throttle") + assert type(throttle).__name__ == "DummyScalar" + assert throttle.value == pytest.approx(0.3) + + # Image stays HWC + img = _obj_for(calls, "observation.img") + assert type(img).__name__ == "DummyImage" + assert img.arr.shape == (5, 6, 3) + assert _kwargs_for(calls, "observation.img").get("static", False) is True + + # Vectors + for i, val in enumerate([9, 8, 7]): + o = _obj_for(calls, f"action.vec_{i}") + assert type(o).__name__ == "DummyScalar" + assert o.value == pytest.approx(val) + + +def test_log_rerun_data_kwargs_only(mock_rerun): + vu, calls = mock_rerun + + vu.log_rerun_data( + None, + observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)}, + action={"action.a": 1.0}, + ) + + keys = set(_keys(calls)) + assert "observation.temp" in keys + assert "observation.gray" in keys + assert "action.a" in keys + + temp = _obj_for(calls, "observation.temp") + assert type(temp).__name__ == "DummyScalar" + assert temp.value == pytest.approx(10.0) + + img = _obj_for(calls, "observation.gray") + assert type(img).__name__ == "DummyImage" + assert img.arr.shape == (8, 8, 1) # remains HWC + assert _kwargs_for(calls, "observation.gray").get("static", False) is True + + a = _obj_for(calls, "action.a") + assert type(a).__name__ == "DummyScalar" + assert a.value == pytest.approx(1.0)