mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
Compare commits
1 Commits
ci/convert
...
feat/accel
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
37103baa07 |
@@ -29,7 +29,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
||||
|
||||
# Install system dependencies and uv (as root)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential git curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
|
||||
build-essential git curl libglib2.0-0 libegl1-mesa ffmpeg \
|
||||
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
|
||||
@@ -39,8 +39,6 @@
|
||||
- sections:
|
||||
- local: notebooks
|
||||
title: Notebooks
|
||||
- local: feetech
|
||||
title: Updating Feetech Firmware
|
||||
title: "Resources"
|
||||
- sections:
|
||||
- local: contributing
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
# Feetech Motor Firmware Update
|
||||
|
||||
This tutorial guides you through updating the firmware of Feetech motors using the official Feetech software.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Windows computer (Feetech software is only available for Windows)
|
||||
- Feetech motor control board
|
||||
- USB cable to connect the control board to your computer
|
||||
- Feetech motors connected to the control board
|
||||
|
||||
## Step 1: Download Feetech Software
|
||||
|
||||
1. Visit the official Feetech software download page: [https://www.feetechrc.com/software.html](https://www.feetechrc.com/software.html)
|
||||
2. Download the latest version of the Feetech debugging software (FD)
|
||||
3. Install the software on your Windows computer
|
||||
|
||||
## Step 2: Hardware Setup
|
||||
|
||||
1. Connect your Feetech motors to the motor control board
|
||||
2. Connect the motor control board to your Windows computer via USB cable
|
||||
3. Ensure power is supplied to the motors
|
||||
|
||||
## Step 3: Configure Connection
|
||||
|
||||
1. Launch the Feetech debugging software
|
||||
2. Select the correct COM port from the port dropdown menu
|
||||
- If unsure which port to use, check Windows Device Manager under "Ports (COM & LPT)"
|
||||
3. Set the appropriate baud rate (typically 1000000 for most Feetech motors)
|
||||
4. Click "Open" to establish communication with the control board
|
||||
|
||||
## Step 4: Scan for Motors
|
||||
|
||||
1. Once connected, click the "Search" button to detect all connected motors
|
||||
2. The software will automatically discover and list all motors on the bus
|
||||
3. Each motor will appear with its ID number
|
||||
|
||||
## Step 5: Update Firmware
|
||||
|
||||
For each motor you want to update:
|
||||
|
||||
1. **Select the motor** from the list by clicking on it
|
||||
2. **Click on Upgrade tab**:
|
||||
3. **Click on Online button**:
|
||||
- If an potential firmware update is found, it will be displayed in the box
|
||||
4. **Click on Upgrade button**:
|
||||
- The update progress will be displayed
|
||||
|
||||
## Step 6: Verify Update
|
||||
|
||||
1. After the update completes, the software should automatically refresh the motor information
|
||||
2. Verify that the firmware version has been updated to the expected version
|
||||
|
||||
## Important Notes
|
||||
|
||||
⚠️ **Warning**: Do not disconnect power or USB during firmware updates, it will potentially brick the motor.
|
||||
|
||||
## Bonus: Motor Debugging on Linux/macOS
|
||||
|
||||
For debugging purposes only, you can use the open-source Feetech Debug Tool:
|
||||
|
||||
- **Repository**: [FT_SCServo_Debug_Qt](https://github.com/CarolinePascal/FT_SCServo_Debug_Qt/tree/fix/port-search-timer)
|
||||
|
||||
### Installation Instructions
|
||||
|
||||
Follow the instructions in the repository to install the tool, for Ubuntu you can directly install it, for MacOS you need to build it from source.
|
||||
|
||||
**Limitations:**
|
||||
|
||||
- This tool is for debugging and parameter adjustment only
|
||||
- Firmware updates must still be done on Windows with official Feetech software
|
||||
@@ -127,11 +127,11 @@ class RewardClassifierConfig:
|
||||
# Dataset configuration
|
||||
class DatasetConfig:
|
||||
repo_id: str # LeRobot dataset repository ID
|
||||
dataset_root: str # Local dataset root directory
|
||||
task: str # Task identifier
|
||||
root: str | None = None # Local dataset root directory
|
||||
num_episodes_to_record: int = 5 # Number of episodes for recording
|
||||
replay_episode: int | None = None # Episode index for replay
|
||||
push_to_hub: bool = False # Whether to push datasets to Hub
|
||||
num_episodes: int # Number of episodes for recording
|
||||
episode: int # Episode index for replay
|
||||
push_to_hub: bool # Whether to push datasets to Hub
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -351,7 +351,7 @@ Create a configuration file for recording demonstrations (or edit an existing on
|
||||
|
||||
1. Set `mode` to `"record"` at the root level
|
||||
2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name")
|
||||
3. Set `num_episodes_to_record` in the `dataset` section to the number of demonstrations you want to collect
|
||||
3. Set `num_episodes` in the `dataset` section to the number of demonstrations you want to collect
|
||||
4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later)
|
||||
5. Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section
|
||||
|
||||
@@ -390,10 +390,10 @@ Example configuration section:
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "username/pick_lift_cube",
|
||||
"root": null,
|
||||
"dataset_root": null,
|
||||
"task": "pick_and_lift",
|
||||
"num_episodes_to_record": 15,
|
||||
"replay_episode": 0,
|
||||
"num_episodes": 15,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
@@ -626,7 +626,7 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/r
|
||||
|
||||
- **mode**: set it to `"record"` to collect a dataset (at root level)
|
||||
- **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
|
||||
- **dataset.num_episodes_to_record**: Number of episodes to record
|
||||
- **dataset.num_episodes**: Number of episodes to record
|
||||
- **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`)
|
||||
- **env.fps**: Number of frames per second to record
|
||||
- **dataset.push_to_hub**: Whether to push the dataset to the hub
|
||||
@@ -664,8 +664,8 @@ Example configuration section for data collection:
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"task": "reward_classifier_task",
|
||||
"num_episodes_to_record": 20,
|
||||
"replay_episode": null,
|
||||
"num_episodes": 20,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
|
||||
@@ -107,10 +107,10 @@ To collect a dataset, set the mode to `record` whilst defining the repo_id and n
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "username/sim_dataset",
|
||||
"root": null,
|
||||
"dataset_root": null,
|
||||
"task": "pick_cube",
|
||||
"num_episodes_to_record": 10,
|
||||
"replay_episode": null,
|
||||
"num_episodes": 10,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record"
|
||||
|
||||
@@ -36,10 +36,10 @@ To teleoperate and collect a dataset, we need to modify this config file. Here's
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_gym",
|
||||
"root": null,
|
||||
"dataset_root": null,
|
||||
"task": "pick_cube",
|
||||
"num_episodes_to_record": 30,
|
||||
"replay_episode": null,
|
||||
"num_episodes": 30,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
@@ -50,7 +50,7 @@ To teleoperate and collect a dataset, we need to modify this config file. Here's
|
||||
Key configuration points:
|
||||
|
||||
- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
|
||||
- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes
|
||||
- Set `num_episodes: 30` to collect 30 demonstration episodes
|
||||
- Ensure `mode` is set to `"record"`
|
||||
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"`
|
||||
- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.factory import make_processor
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
@@ -46,7 +46,7 @@ listener, events = init_keyboard_listener()
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
|
||||
@@ -20,7 +20,7 @@ from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_featur
|
||||
from lerobot.datasets.utils import merge_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.factory import make_processor
|
||||
from lerobot.processor.converters import (
|
||||
to_output_robot_action,
|
||||
to_transition_robot_observation,
|
||||
@@ -127,7 +127,7 @@ robot.connect()
|
||||
episode_idx = 0
|
||||
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
@@ -38,8 +38,8 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone import Phone
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
@@ -19,7 +19,7 @@ import time
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
|
||||
from lerobot.processor.converters import to_output_robot_action
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
@@ -49,6 +49,31 @@ kinematics_solver = RobotKinematics(
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
|
||||
# This method converts the action from the dataset to a transition for pipeline
|
||||
def action_to_transition(action: dict):
|
||||
act = {}
|
||||
|
||||
# EE pose
|
||||
for k in ("ee.x", "ee.y", "ee.z", "ee.wx", "ee.wy", "ee.wz"):
|
||||
if k in action:
|
||||
act[f"action.{k}"] = float(action[k])
|
||||
|
||||
# Gripper: your dataset has absolute position
|
||||
if "gripper.pos" in action:
|
||||
act["action.gripper.pos"] = float(action["gripper.pos"])
|
||||
|
||||
return {
|
||||
"observation": None,
|
||||
"action": act,
|
||||
"reward": None,
|
||||
"done": False,
|
||||
"truncated": False,
|
||||
"info": {},
|
||||
"complementary_data": {},
|
||||
}
|
||||
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
steps=[
|
||||
@@ -59,7 +84,7 @@ robot_ee_to_joints = RobotProcessor(
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=to_transition_teleop_action,
|
||||
to_transition=action_to_transition,
|
||||
to_output=to_output_robot_action,
|
||||
)
|
||||
|
||||
@@ -28,8 +28,8 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone import Phone
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot_config = SO100FollowerConfig(
|
||||
@@ -48,8 +48,8 @@ kinematics_solver = RobotKinematics(
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action to joint action
|
||||
phone_to_robot_joints = RobotProcessor(
|
||||
# Build pipeline to convert phone action to ee pose action
|
||||
phone_to_robot_ee_pose = RobotProcessor(
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
@@ -63,6 +63,14 @@ phone_to_robot_joints = RobotProcessor(
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
],
|
||||
to_transition=to_transition_teleop_action,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
@@ -72,7 +80,7 @@ phone_to_robot_joints = RobotProcessor(
|
||||
speed_factor=20.0,
|
||||
),
|
||||
],
|
||||
to_transition=to_transition_teleop_action,
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=to_output_robot_action,
|
||||
)
|
||||
|
||||
@@ -81,11 +89,19 @@ teleop_device.connect()
|
||||
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot.")
|
||||
while True:
|
||||
phone_obs = teleop_device.get_action()
|
||||
if not phone_obs:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
# Get teleop observation
|
||||
phone_obs = teleop_device.get_action()
|
||||
|
||||
# Phone -> EE pose -> Joints transition
|
||||
joint_action = phone_to_robot_joints(phone_obs)
|
||||
# Phone to EE pose transition
|
||||
ee_transition = phone_to_robot_ee_pose(phone_obs)
|
||||
|
||||
# EE pose to Joints transition
|
||||
joint_action = robot_ee_to_joints(ee_transition)
|
||||
|
||||
if joint_action:
|
||||
robot.send_action(joint_action)
|
||||
@@ -24,11 +24,6 @@ OBS_IMAGES = "observation.images"
|
||||
OBS_LANGUAGE = "observation.language"
|
||||
ACTION = "action"
|
||||
REWARD = "next.reward"
|
||||
TRUNCATED = "next.truncated"
|
||||
DONE = "next.done"
|
||||
|
||||
OBS_LANGUAGE_TOKENS = "observation.language.tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = "observation.language.attention_mask"
|
||||
|
||||
ROBOTS = "robots"
|
||||
ROBOT_TYPE = "robot_type"
|
||||
|
||||
@@ -825,8 +825,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
if not episode_data:
|
||||
episode_buffer = self.episode_buffer
|
||||
else:
|
||||
episode_buffer = episode_data
|
||||
|
||||
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
|
||||
@@ -60,26 +59,26 @@ def aggregate_pipeline_dataset_features(
|
||||
|
||||
# Go over every feature from the pipeline and merge:
|
||||
for full_key, ty in all_features.items():
|
||||
if full_key.startswith(f"{ACTION}."):
|
||||
if full_key.startswith("action."):
|
||||
# action.<feat>
|
||||
if not keep(full_key):
|
||||
continue
|
||||
name = full_key[len(f"{ACTION}.") :]
|
||||
hw.setdefault(ACTION, {})[name] = ty
|
||||
name = full_key[len("action.") :]
|
||||
hw.setdefault("action", {})[name] = ty
|
||||
|
||||
elif full_key.startswith(f"{OBS_STATE}."):
|
||||
elif full_key.startswith("observation.state."):
|
||||
# observation.state.<feat>
|
||||
if not keep(full_key):
|
||||
continue
|
||||
name = full_key[len(f"{OBS_STATE}.") :]
|
||||
name = full_key[len("observation.state.") :]
|
||||
hw.setdefault("observation", {})[name] = ty
|
||||
|
||||
elif full_key.startswith(f"{OBS_IMAGES}."):
|
||||
elif full_key.startswith("observation.images."):
|
||||
# observation.images.<cam>
|
||||
# images obey ONLY the use_videos flag, not patterns
|
||||
if not use_videos:
|
||||
continue
|
||||
name = full_key[len(f"{OBS_IMAGES}.") :]
|
||||
name = full_key[len("observation.images.") :]
|
||||
hw.setdefault("observation", {})[name] = ty
|
||||
|
||||
else:
|
||||
@@ -87,8 +86,8 @@ def aggregate_pipeline_dataset_features(
|
||||
continue
|
||||
|
||||
out: dict[str, dict] = {}
|
||||
if ACTION in hw:
|
||||
out.update(hw_to_dataset_features(hw[ACTION], ACTION, use_videos))
|
||||
if "action" in hw:
|
||||
out.update(hw_to_dataset_features(hw["action"], "action", use_videos))
|
||||
if "observation" in hw:
|
||||
out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))
|
||||
|
||||
|
||||
@@ -107,8 +107,6 @@ X_SERIES_ENCODINGS_TABLE = {
|
||||
"Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1],
|
||||
"Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1],
|
||||
"Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1],
|
||||
"Goal_Position": X_SERIES_CONTROL_TABLE["Goal_Position"][1],
|
||||
"Present_Position": X_SERIES_CONTROL_TABLE["Present_Position"][1],
|
||||
"Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1],
|
||||
"Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1],
|
||||
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],
|
||||
|
||||
@@ -287,7 +287,7 @@ class ACT(nn.Module):
|
||||
└───────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config: ACTConfig):
|
||||
def __init__(self, config: ACTConfig, dataset_stats=None):
|
||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
super().__init__()
|
||||
|
||||
@@ -20,7 +20,6 @@ from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
@@ -28,17 +27,9 @@ from lerobot.processor import (
|
||||
)
|
||||
|
||||
|
||||
def make_act_pre_post_processors(
|
||||
config: ACTConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
def make_act_processor(
|
||||
config: ACTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
@@ -55,16 +46,6 @@ def make_act_pre_post_processors(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
|
||||
return (
|
||||
RobotProcessor(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -21,7 +21,6 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
@@ -29,17 +28,9 @@ from lerobot.processor import (
|
||||
)
|
||||
|
||||
|
||||
def make_diffusion_pre_post_processors(
|
||||
config: DiffusionConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
def make_diffusion_processor(
|
||||
config: DiffusionConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
@@ -56,15 +47,6 @@ def make_diffusion_pre_post_processors(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
RobotProcessor(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -17,9 +17,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
@@ -38,7 +39,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor.pipeline import ProcessorKwargs, RobotProcessor
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
@@ -114,11 +115,9 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
preprocessor_kwargs: ProcessorKwargs | None
|
||||
postprocessor_kwargs: ProcessorKwargs | None
|
||||
|
||||
|
||||
def make_pre_post_processors(
|
||||
def make_processor(
|
||||
policy_cfg: PreTrainedConfig,
|
||||
pretrained_path: str | None = None,
|
||||
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||
@@ -141,116 +140,82 @@ def make_pre_post_processors(
|
||||
NotImplementedError: If the policy type doesn't have a processor implemented.
|
||||
"""
|
||||
if pretrained_path:
|
||||
# Extract preprocessor and postprocessor kwargs
|
||||
preprocessor_kwargs = kwargs.get("preprocessor_kwargs", {})
|
||||
postprocessor_kwargs = kwargs.get("postprocessor_kwargs", {})
|
||||
|
||||
return (
|
||||
RobotProcessor.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
|
||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||
to_transition=preprocessor_kwargs.get("to_transition"),
|
||||
to_output=preprocessor_kwargs.get("to_output"),
|
||||
),
|
||||
RobotProcessor.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
|
||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||
to_transition=postprocessor_kwargs.get("to_transition"),
|
||||
to_output=postprocessor_kwargs.get("to_output"),
|
||||
),
|
||||
)
|
||||
|
||||
# Create a new processor based on policy type
|
||||
if isinstance(policy_cfg, TDMPCConfig):
|
||||
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
|
||||
if policy_cfg.type == "tdmpc":
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor
|
||||
|
||||
processors = make_tdmpc_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
processors = make_tdmpc_processor(
|
||||
config=cast(TDMPCConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, DiffusionConfig):
|
||||
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
|
||||
elif policy_cfg.type == "diffusion":
|
||||
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor
|
||||
|
||||
processors = make_diffusion_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
processors = make_diffusion_processor(
|
||||
cast(DiffusionConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, ACTConfig):
|
||||
from lerobot.policies.act.processor_act import make_act_pre_post_processors
|
||||
elif policy_cfg.type == "act":
|
||||
from lerobot.policies.act.processor_act import make_act_processor
|
||||
|
||||
processors = make_act_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
processors = make_act_processor(
|
||||
config=cast(ACTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, VQBeTConfig):
|
||||
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
|
||||
elif policy_cfg.type == "vqbet":
|
||||
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor
|
||||
|
||||
processors = make_vqbet_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
processors = make_vqbet_processor(
|
||||
config=cast(VQBeTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI0Config):
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
|
||||
elif policy_cfg.type == "pi0":
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_processor
|
||||
|
||||
processors = make_pi0_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
processors = make_pi0_processor(
|
||||
config=cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI0FASTConfig):
|
||||
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
|
||||
elif policy_cfg.type == "pi0fast":
|
||||
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor
|
||||
|
||||
processors = make_pi0fast_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
processors = make_pi0fast_processor(
|
||||
cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SACConfig):
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
elif policy_cfg.type == "sac":
|
||||
from lerobot.policies.sac.processor_sac import make_sac_processor
|
||||
|
||||
processors = make_sac_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
processors = make_sac_processor(
|
||||
cast(SACConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, RewardClassifierConfig):
|
||||
elif policy_cfg.type == "reward_classifier":
|
||||
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
|
||||
|
||||
processors = make_classifier_processor(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
cast(RewardClassifierConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SmolVLAConfig):
|
||||
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
||||
elif policy_cfg.type == "smolvla":
|
||||
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor
|
||||
|
||||
processors = make_smolvla_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
processors = make_smolvla_processor(
|
||||
cast(SmolVLAConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -330,7 +295,7 @@ def make_policy(
|
||||
policy = policy_cls(**kwargs)
|
||||
|
||||
policy.to(cfg.device)
|
||||
assert isinstance(policy, torch.nn.Module)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
|
||||
@@ -14,68 +14,82 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
ProcessorKwargs,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
TokenizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import (
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.rename_processor import RenameProcessor
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
|
||||
class Pi0NewLineProcessor(ComplementaryDataProcessor):
|
||||
class Pi0NewLineProcessor(ProcessorStep):
|
||||
"""Add a new line to the end of the task if it doesn't have one.
|
||||
This is required for the PaliGemma tokenizer.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Check if complementary_data exists
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None or "task" not in complementary_data:
|
||||
return transition
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
return complementary_data
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
return transition
|
||||
|
||||
# Handle both string and list of strings
|
||||
if isinstance(task, str):
|
||||
# Single string: add newline if not present
|
||||
if not task.endswith("\n"):
|
||||
new_complementary_data["task"] = f"{task}\n"
|
||||
complementary_data["task"] = f"{task}\n"
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: add newline to each if not present
|
||||
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
|
||||
return new_complementary_data
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Add tokenized task features to the features."""
|
||||
return features
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary (empty for this processor)."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {}
|
||||
|
||||
|
||||
def make_pi0_pre_post_processors(
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
def make_pi0_processor(
|
||||
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
|
||||
@@ -102,15 +116,6 @@ def make_pi0_pre_post_processors(
|
||||
),
|
||||
]
|
||||
|
||||
return (
|
||||
RobotProcessor(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -21,7 +21,6 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
@@ -29,17 +28,9 @@ from lerobot.processor import (
|
||||
)
|
||||
|
||||
|
||||
def make_pi0fast_pre_post_processors(
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
def make_pi0fast_processor(
|
||||
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessor(
|
||||
@@ -56,15 +47,6 @@ def make_pi0fast_pre_post_processors(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
RobotProcessor(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -22,7 +22,6 @@ from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
@@ -30,17 +29,9 @@ from lerobot.processor import (
|
||||
)
|
||||
|
||||
|
||||
def make_sac_pre_post_processors(
|
||||
config: SACConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
def make_sac_processor(
|
||||
config: SACConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
@@ -57,15 +48,6 @@ def make_sac_pre_post_processors(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
RobotProcessor(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -20,22 +20,13 @@ from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
IdentityProcessor,
|
||||
NormalizerProcessor,
|
||||
ProcessorKwargs,
|
||||
RobotProcessor,
|
||||
)
|
||||
|
||||
|
||||
def make_classifier_processor(
|
||||
config: RewardClassifierConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
config: RewardClassifierConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
NormalizerProcessor(
|
||||
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
@@ -46,16 +37,6 @@ def make_classifier_processor(
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [DeviceProcessor(device="cpu"), IdentityProcessor()]
|
||||
|
||||
return (
|
||||
RobotProcessor(
|
||||
steps=input_steps,
|
||||
name="classifier_preprocessor",
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
steps=output_steps,
|
||||
name="classifier_postprocessor",
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
return RobotProcessor(steps=input_steps, name="classifier_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="classifier_postprocessor"
|
||||
)
|
||||
|
||||
@@ -13,38 +13,28 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
TokenizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import (
|
||||
ComplementaryDataProcessor,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
def make_smolvla_pre_post_processors(
|
||||
config: SmolVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
def make_smolvla_processor(
|
||||
config: SmolVLAConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessor(
|
||||
@@ -68,42 +58,53 @@ def make_smolvla_pre_post_processors(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
RobotProcessor(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
|
||||
class SmolVLANewLineProcessor(ComplementaryDataProcessor):
|
||||
class SmolVLANewLineProcessor(ProcessorStep):
|
||||
"""Add a new line to the end of the task if it doesn't have one."""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Check if complementary_data exists
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None or "task" not in complementary_data:
|
||||
return transition
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
return complementary_data
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
return transition
|
||||
|
||||
# Handle both string and list of strings
|
||||
if isinstance(task, str):
|
||||
# Single string: add newline if not present
|
||||
if not task.endswith("\n"):
|
||||
new_complementary_data["task"] = f"{task}\n"
|
||||
complementary_data["task"] = f"{task}\n"
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: add newline to each if not present
|
||||
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
|
||||
return new_complementary_data
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Adds nothing to the features."""
|
||||
return features
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary (empty for this processor)."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {}
|
||||
|
||||
@@ -21,7 +21,6 @@ from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
@@ -29,17 +28,9 @@ from lerobot.processor import (
|
||||
)
|
||||
|
||||
|
||||
def make_tdmpc_pre_post_processors(
|
||||
config: TDMPCConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
def make_tdmpc_processor(
|
||||
config: TDMPCConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
@@ -56,15 +47,6 @@ def make_tdmpc_pre_post_processors(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
RobotProcessor(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -22,7 +22,6 @@ from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
@@ -30,17 +29,9 @@ from lerobot.processor import (
|
||||
)
|
||||
|
||||
|
||||
def make_vqbet_pre_post_processors(
|
||||
config: VQBeTConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
def make_vqbet_processor(
|
||||
config: VQBeTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}), # Let the possibility to the user to rename the keys
|
||||
NormalizerProcessor(
|
||||
@@ -57,15 +48,6 @@ def make_vqbet_pre_post_processors(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
RobotProcessor(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -15,17 +15,18 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .batch_processor import ToBatchProcessor
|
||||
from .delta_action_processor import MapDeltaActionToRobotAction, MapTensorToDeltaActionDict
|
||||
from .delta_action_processor import MapDeltaActionToRobotAction
|
||||
from .device_processor import DeviceProcessor
|
||||
from .gym_action_processor import Numpy2TorchActionProcessor, Torch2NumpyActionProcessor
|
||||
from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryData,
|
||||
AddTeleopEventsAsInfo,
|
||||
GripperPenaltyProcessor,
|
||||
ImageCropResizeProcessor,
|
||||
InterventionActionProcessor,
|
||||
Numpy2TorchActionProcessor,
|
||||
RewardClassifierProcessor,
|
||||
TimeLimitProcessor,
|
||||
Torch2NumpyActionProcessor,
|
||||
)
|
||||
from .joint_observations_processor import JointVelocityProcessor, MotorCurrentProcessor
|
||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor, hotswap_stats
|
||||
@@ -37,7 +38,6 @@ from .pipeline import (
|
||||
IdentityProcessor,
|
||||
InfoProcessor,
|
||||
ObservationProcessor,
|
||||
ProcessorKwargs,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RewardProcessor,
|
||||
@@ -55,7 +55,6 @@ __all__ = [
|
||||
"DeviceProcessor",
|
||||
"DoneProcessor",
|
||||
"MapDeltaActionToRobotAction",
|
||||
"MapTensorToDeltaActionDict",
|
||||
"EnvTransition",
|
||||
"GripperPenaltyProcessor",
|
||||
"IdentityProcessor",
|
||||
@@ -69,7 +68,6 @@ __all__ = [
|
||||
"UnnormalizerProcessor",
|
||||
"hotswap_stats",
|
||||
"ObservationProcessor",
|
||||
"ProcessorKwargs",
|
||||
"ProcessorStep",
|
||||
"ProcessorStepRegistry",
|
||||
"RenameProcessor",
|
||||
|
||||
@@ -11,88 +11,20 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor.pipeline import (
|
||||
ActionProcessor,
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
ObservationProcessor,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_action")
|
||||
class ToBatchProcessorAction(ActionProcessor):
|
||||
"""Process action component in-place, adding batch dimension if needed."""
|
||||
|
||||
def action(self, action):
|
||||
if not isinstance(action, Tensor) or action.dim() != 1:
|
||||
return action
|
||||
|
||||
return action.unsqueeze(0)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
|
||||
class ToBatchProcessorObservation(ObservationProcessor):
|
||||
"""Process observation component in-place, adding batch dimensions where needed."""
|
||||
|
||||
def observation(self, observation):
|
||||
# Process state observations - add batch dim if 1D
|
||||
for state_key in [OBS_STATE, OBS_ENV_STATE]:
|
||||
if state_key in observation:
|
||||
state_value = observation[state_key]
|
||||
if isinstance(state_value, Tensor) and state_value.dim() == 1:
|
||||
observation[state_key] = state_value.unsqueeze(0)
|
||||
|
||||
# Process single image observation - add batch dim if 3D
|
||||
if OBS_IMAGE in observation:
|
||||
image_value = observation[OBS_IMAGE]
|
||||
if isinstance(image_value, Tensor) and image_value.dim() == 3:
|
||||
observation[OBS_IMAGE] = image_value.unsqueeze(0)
|
||||
|
||||
# Process multiple image observations - add batch dim if 3D
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
|
||||
observation[key] = value.unsqueeze(0)
|
||||
return observation
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
|
||||
class ToBatchProcessorComplementaryData(ComplementaryDataProcessor):
|
||||
"""Process complementary data in-place, handling task field batching."""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
# Process task field - wrap string in list to add batch dimension
|
||||
if "task" in complementary_data:
|
||||
task_value = complementary_data["task"]
|
||||
if isinstance(task_value, str):
|
||||
complementary_data["task"] = [task_value]
|
||||
|
||||
# Process index field - add batch dim if 0D
|
||||
if "index" in complementary_data:
|
||||
index_value = complementary_data["index"]
|
||||
if isinstance(index_value, Tensor) and index_value.dim() == 0:
|
||||
complementary_data["index"] = index_value.unsqueeze(0)
|
||||
|
||||
# Process task_index field - add batch dim if 0D
|
||||
if "task_index" in complementary_data:
|
||||
task_index_value = complementary_data["task_index"]
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
return complementary_data
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor")
|
||||
class ToBatchProcessor(ProcessorStep):
|
||||
class ToBatchProcessor:
|
||||
"""Processor that adds batch dimensions to observations and actions when needed.
|
||||
|
||||
This processor ensures that observations and actions have proper batch dimensions for model processing:
|
||||
@@ -127,16 +59,81 @@ class ToBatchProcessor(ProcessorStep):
|
||||
```
|
||||
"""
|
||||
|
||||
to_batch_action_processor: ToBatchProcessorAction = field(default_factory=ToBatchProcessorAction)
|
||||
to_batch_observation_processor: ToBatchProcessorObservation = field(
|
||||
default_factory=ToBatchProcessorObservation
|
||||
)
|
||||
to_batch_complementary_data_processor: ToBatchProcessorComplementaryData = field(
|
||||
default_factory=ToBatchProcessorComplementaryData
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = self.to_batch_action_processor(transition)
|
||||
transition = self.to_batch_observation_processor(transition)
|
||||
transition = self.to_batch_complementary_data_processor(transition)
|
||||
self._process_observation(transition)
|
||||
self._process_action(transition)
|
||||
self._process_complementary_data(transition)
|
||||
return transition
|
||||
|
||||
def _process_observation(self, transition: EnvTransition) -> None:
|
||||
"""Process observation component in-place, adding batch dimensions where needed."""
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
return
|
||||
|
||||
# Process state observations - add batch dim if 1D
|
||||
for state_key in [OBS_STATE, OBS_ENV_STATE]:
|
||||
if state_key in observation:
|
||||
state_value = observation[state_key]
|
||||
if isinstance(state_value, Tensor) and state_value.dim() == 1:
|
||||
observation[state_key] = state_value.unsqueeze(0)
|
||||
|
||||
# Process single image observation - add batch dim if 3D
|
||||
if OBS_IMAGE in observation:
|
||||
image_value = observation[OBS_IMAGE]
|
||||
if isinstance(image_value, Tensor) and image_value.dim() == 3:
|
||||
observation[OBS_IMAGE] = image_value.unsqueeze(0)
|
||||
|
||||
# Process multiple image observations - add batch dim if 3D
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
|
||||
observation[key] = value.unsqueeze(0)
|
||||
|
||||
def _process_action(self, transition: EnvTransition) -> None:
|
||||
"""Process action component in-place, adding batch dimension if needed."""
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and isinstance(action, Tensor) and action.dim() == 1:
|
||||
transition[TransitionKey.ACTION] = action.unsqueeze(0)
|
||||
|
||||
def _process_complementary_data(self, transition: EnvTransition) -> None:
|
||||
"""Process complementary data in-place, handling task field batching."""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return
|
||||
|
||||
# Process task field - wrap string in list to add batch dimension
|
||||
if "task" in complementary_data:
|
||||
task_value = complementary_data["task"]
|
||||
if isinstance(task_value, str):
|
||||
complementary_data["task"] = [task_value]
|
||||
|
||||
# Process index field - add batch dim if 0D
|
||||
if "index" in complementary_data:
|
||||
index_value = complementary_data["index"]
|
||||
if isinstance(index_value, Tensor) and index_value.dim() == 0:
|
||||
complementary_data["index"] = index_value.unsqueeze(0)
|
||||
|
||||
# Process task_index field - add batch dim if 0D
|
||||
if "task_index" in complementary_data:
|
||||
task_index_value = complementary_data["task_index"]
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary (empty for this processor)."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -18,125 +18,26 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from functools import singledispatch
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
|
||||
|
||||
from .pipeline import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
@singledispatch
|
||||
def to_tensor(
|
||||
value: Any,
|
||||
*,
|
||||
dtype: torch.dtype | None = torch.float32,
|
||||
device: torch.device | str | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert various data types to PyTorch tensors with configurable options.
|
||||
|
||||
This is a unified tensor conversion function using single dispatch to handle
|
||||
different input types appropriately.
|
||||
|
||||
Args:
|
||||
value: Input value to convert (tensor, array, scalar, sequence, etc.)
|
||||
dtype: Target tensor dtype. If None, preserves original dtype.
|
||||
device: Target device for the tensor.
|
||||
|
||||
Returns:
|
||||
PyTorch tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input type is not supported.
|
||||
"""
|
||||
raise TypeError(f"Unsupported type for tensor conversion: {type(value)}")
|
||||
|
||||
|
||||
@to_tensor.register(torch.Tensor)
|
||||
def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle existing PyTorch tensors."""
|
||||
if dtype is not None:
|
||||
value = value.to(dtype=dtype)
|
||||
if device is not None:
|
||||
value = value.to(device=device)
|
||||
return value
|
||||
|
||||
|
||||
@to_tensor.register(np.ndarray)
|
||||
def _(
|
||||
value: np.ndarray,
|
||||
*,
|
||||
dtype=torch.float32,
|
||||
device=None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Handle numpy arrays."""
|
||||
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars
|
||||
if value.ndim == 0:
|
||||
# Numpy scalars should be converted to 0-dimensional tensors
|
||||
scalar_value = value.item()
|
||||
return torch.tensor(scalar_value, dtype=dtype, device=device)
|
||||
|
||||
# Create tensor from numpy array (torch.from_numpy handles contiguity automatically)
|
||||
tensor = torch.from_numpy(value)
|
||||
|
||||
# Apply dtype conversion if specified
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
if device is not None:
|
||||
tensor = tensor.to(device=device)
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
@to_tensor.register(int)
|
||||
@to_tensor.register(float)
|
||||
@to_tensor.register(np.integer)
|
||||
@to_tensor.register(np.floating)
|
||||
def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle scalar values including numpy scalars."""
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@to_tensor.register(list)
|
||||
@to_tensor.register(tuple)
|
||||
def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle sequences (lists, tuples)."""
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@to_tensor.register(dict)
|
||||
def _(value: dict, *, device=None, **kwargs) -> dict:
|
||||
"""Handle dictionaries by recursively converting values to tensors."""
|
||||
if not value:
|
||||
return {}
|
||||
|
||||
result = {}
|
||||
for key, sub_value in value.items():
|
||||
if sub_value is None:
|
||||
continue
|
||||
|
||||
if isinstance(sub_value, dict):
|
||||
# Recursively process nested dictionaries
|
||||
result[key] = to_tensor(
|
||||
sub_value,
|
||||
device=device,
|
||||
**kwargs,
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert individual values to tensors
|
||||
result[key] = to_tensor(
|
||||
sub_value,
|
||||
device=device,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
def _to_tensor(x: torch.Tensor | np.ndarray | Sequence[int | float]):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if isinstance(x, np.ndarray):
|
||||
# Keep images (uint8 HWC) and python objects as-is
|
||||
if x.dtype == np.uint8 or x.dtype == np.object_:
|
||||
return x
|
||||
# Scalars/arrays to float32 tensor
|
||||
return torch.as_tensor(x, dtype=torch.float32)
|
||||
# Anything else to float32 tensor
|
||||
return torch.as_tensor(x, dtype=torch.float32)
|
||||
|
||||
|
||||
def _from_tensor(x: Any):
|
||||
@@ -152,7 +53,7 @@ def _is_image(arr: Any) -> bool:
|
||||
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
state, images = {}, {}
|
||||
for k, v in obs.items():
|
||||
if "image" in k.lower() or _is_image(v):
|
||||
if _is_image(v):
|
||||
images[k] = v
|
||||
else:
|
||||
state[k] = v
|
||||
@@ -181,11 +82,11 @@ def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition:
|
||||
for k, v in action.items():
|
||||
# Check if the value is a type that should not be converted to a tensor.
|
||||
if isinstance(v, (Rotation, dict)):
|
||||
act_dict[f"{ACTION}.{k}"] = v
|
||||
act_dict[f"action.{k}"] = v
|
||||
continue
|
||||
|
||||
arr = np.array(v) if np.isscalar(v) else v
|
||||
act_dict[f"{ACTION}.{k}"] = to_tensor(arr)
|
||||
act_dict[f"action.{k}"] = _to_tensor(arr)
|
||||
|
||||
return make_obs_act_transition(act=act_dict)
|
||||
|
||||
@@ -200,10 +101,10 @@ def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransitio
|
||||
obs_dict: dict[str, Any] = {}
|
||||
for k, v in state.items():
|
||||
arr = np.array(v) if np.isscalar(v) else v
|
||||
obs_dict[f"{OBS_STATE}.{k}"] = to_tensor(arr)
|
||||
obs_dict[f"observation.state.{k}"] = _to_tensor(arr)
|
||||
|
||||
for cam, img in images.items():
|
||||
obs_dict[f"{OBS_IMAGES}.{cam}"] = img
|
||||
obs_dict[f"observation.images.{cam}"] = img
|
||||
|
||||
return make_obs_act_transition(obs=obs_dict)
|
||||
|
||||
@@ -215,12 +116,9 @@ def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]:
|
||||
out: dict[str, Any] = {}
|
||||
action_dict = transition.get(TransitionKey.ACTION) or {}
|
||||
|
||||
if action_dict is None:
|
||||
return out
|
||||
|
||||
for k, v in action_dict.items():
|
||||
if isinstance(k, str) and k.startswith(f"{ACTION}.") and k.endswith((".pos", ".vel")):
|
||||
out_key = k[len(f"{ACTION}.") :] # Strip the 'action.' prefix.
|
||||
if isinstance(k, str) and k.startswith("action.") and k.endswith((".pos", ".vel")):
|
||||
out_key = k[len("action.") :] # Strip the 'action.' prefix.
|
||||
out[out_key] = float(v)
|
||||
|
||||
return out
|
||||
@@ -251,9 +149,9 @@ def to_dataset_frame(
|
||||
- info dict
|
||||
- *_is_pad flags and task from complementary_data
|
||||
"""
|
||||
action_names = features.get(ACTION, {}).get("names", [])
|
||||
obs_state_names = features.get(OBS_STATE, {}).get("names", [])
|
||||
image_keys = [k for k in features if k.startswith(OBS_IMAGES)]
|
||||
action_names = features.get("action", {}).get("names", [])
|
||||
obs_state_names = features.get("observation.state", {}).get("names", [])
|
||||
image_keys = [k for k in features if k.startswith("observation.images.")]
|
||||
|
||||
def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition:
|
||||
out = deepcopy(base)
|
||||
@@ -297,20 +195,21 @@ def to_dataset_frame(
|
||||
|
||||
# Observation.state vector
|
||||
if obs_state_names:
|
||||
vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
|
||||
batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
|
||||
vals = [_from_tensor(obs.get(f"observation.state.{n}", 0.0)) for n in obs_state_names]
|
||||
batch["observation.state"] = np.asarray(vals, dtype=np.float32)
|
||||
|
||||
# Action vector
|
||||
if action_names:
|
||||
vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
|
||||
batch[ACTION] = np.asarray(vals, dtype=np.float32)
|
||||
vals = [_from_tensor(act.get(f"action.{n}", 0.0)) for n in action_names]
|
||||
batch["action"] = np.asarray(vals, dtype=np.float32)
|
||||
|
||||
# Next.* fields
|
||||
if tr.get(TransitionKey.REWARD) is not None:
|
||||
batch[REWARD] = _from_tensor(tr[TransitionKey.REWARD])
|
||||
batch["next.reward"] = _from_tensor(tr[TransitionKey.REWARD])
|
||||
if tr.get(TransitionKey.DONE) is not None:
|
||||
batch[DONE] = _from_tensor(tr[TransitionKey.DONE])
|
||||
batch["next.done"] = _from_tensor(tr[TransitionKey.DONE])
|
||||
if tr.get(TransitionKey.TRUNCATED) is not None:
|
||||
batch[TRUNCATED] = _from_tensor(tr[TransitionKey.TRUNCATED])
|
||||
batch["next.truncated"] = _from_tensor(tr[TransitionKey.TRUNCATED])
|
||||
|
||||
# Complementary data flags and task
|
||||
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
@@ -22,30 +22,6 @@ from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
|
||||
@dataclass
|
||||
class MapTensorToDeltaActionDict(ActionProcessor):
|
||||
"""
|
||||
Map a tensor to a delta action dictionary.
|
||||
"""
|
||||
|
||||
def action(self, action: Tensor) -> dict:
|
||||
if isinstance(action, dict):
|
||||
return action
|
||||
if action.dim() > 1:
|
||||
action = action.squeeze(0)
|
||||
|
||||
# TODO (maractingi): add rotation
|
||||
delta_action = {
|
||||
"action.delta_x": action[0],
|
||||
"action.delta_y": action[1],
|
||||
"action.delta_z": action[2],
|
||||
}
|
||||
if action.shape[0] > 3:
|
||||
delta_action["action.gripper"] = action[3]
|
||||
return delta_action
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
|
||||
@dataclass
|
||||
class MapDeltaActionToRobotAction(ActionProcessor):
|
||||
@@ -77,25 +53,35 @@ class MapDeltaActionToRobotAction(ActionProcessor):
|
||||
# Scale factors for delta movements
|
||||
position_scale: float = 1.0
|
||||
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
|
||||
noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise
|
||||
gripper_deadzone: float = 0.1 # Threshold for gripper activation
|
||||
_prev_enabled: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def action(self, action: dict | Tensor | None) -> dict:
|
||||
if action is None:
|
||||
return {}
|
||||
|
||||
def action(self, action: dict) -> dict:
|
||||
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
||||
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
|
||||
delta_x = action.pop("action.delta_x", 0.0)
|
||||
delta_y = action.pop("action.delta_y", 0.0)
|
||||
delta_z = action.pop("action.delta_z", 0.0)
|
||||
gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0)
|
||||
if isinstance(action, dict):
|
||||
delta_x = action.pop("action.delta_x", 0.0)
|
||||
delta_y = action.pop("action.delta_y", 0.0)
|
||||
delta_z = action.pop("action.delta_z", 0.0)
|
||||
gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0)
|
||||
else:
|
||||
delta_x = action[0].item()
|
||||
delta_y = action[1].item()
|
||||
delta_z = action[2].item()
|
||||
gripper = action[3].item()
|
||||
|
||||
# Determine if the teleoperator is actively providing input
|
||||
# Consider enabled if any significant movement delta is detected
|
||||
position_magnitude = (delta_x**2 + delta_y**2 + delta_z**2) ** 0.5 # Use Euclidean norm for position
|
||||
enabled = position_magnitude > self.noise_threshold # Small threshold to avoid noise
|
||||
position_magnitude = abs(delta_x) + abs(delta_y) + abs(delta_z)
|
||||
enabled = position_magnitude > 1e-6 # Small threshold to avoid noise
|
||||
|
||||
# Scale the deltas appropriately
|
||||
scaled_delta_x = delta_x * self.position_scale
|
||||
scaled_delta_y = delta_y * self.position_scale
|
||||
scaled_delta_z = delta_z * self.position_scale
|
||||
scaled_delta_x = float(delta_x) * self.position_scale
|
||||
scaled_delta_y = float(delta_y) * self.position_scale
|
||||
scaled_delta_z = float(delta_z) * self.position_scale
|
||||
|
||||
# For gamepad/keyboard, we don't have rotation input, so set to 0
|
||||
# These could be extended in the future for more sophisticated teleoperators
|
||||
@@ -115,6 +101,7 @@ class MapDeltaActionToRobotAction(ActionProcessor):
|
||||
"action.gripper": float(gripper),
|
||||
}
|
||||
|
||||
self._prev_enabled = enabled
|
||||
return action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
@@ -133,3 +120,6 @@ class MapDeltaActionToRobotAction(ActionProcessor):
|
||||
}
|
||||
)
|
||||
return features
|
||||
|
||||
def reset(self):
|
||||
self._prev_enabled = False
|
||||
|
||||
@@ -18,13 +18,14 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("device_processor")
|
||||
@dataclass
|
||||
class DeviceProcessor(ProcessorStep):
|
||||
class DeviceProcessor:
|
||||
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
|
||||
|
||||
This processor ensures that all tensors in the transition are moved to the
|
||||
@@ -35,30 +36,32 @@ class DeviceProcessor(ProcessorStep):
|
||||
|
||||
device: str = "cpu"
|
||||
float_dtype: str | None = None
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"half": torch.float16,
|
||||
"float": torch.float32,
|
||||
"double": torch.float64,
|
||||
}
|
||||
_device: torch.device | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._device: torch.device = get_safe_torch_device(self.device)
|
||||
self.device = self._device.type # cuda might have changed to cuda:1
|
||||
self._device = get_safe_torch_device(self.device)
|
||||
self.device = self._device.type
|
||||
self.non_blocking = "cuda" in str(self.device)
|
||||
|
||||
# Validate and convert float_dtype string to torch dtype
|
||||
if self.float_dtype is not None:
|
||||
if self.float_dtype not in self.DTYPE_MAPPING:
|
||||
dtype_mapping = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"half": torch.float16,
|
||||
"float": torch.float32,
|
||||
"double": torch.float64,
|
||||
}
|
||||
|
||||
if self.float_dtype not in dtype_mapping:
|
||||
available_dtypes = list(dtype_mapping.keys())
|
||||
raise ValueError(
|
||||
f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}"
|
||||
f"Invalid float_dtype '{self.float_dtype}'. Available options: {available_dtypes}"
|
||||
)
|
||||
|
||||
self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype]
|
||||
self._target_float_dtype = dtype_mapping[self.float_dtype]
|
||||
else:
|
||||
self._target_float_dtype = None
|
||||
|
||||
@@ -91,38 +94,69 @@ class DeviceProcessor(ProcessorStep):
|
||||
return tensor
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Create a copy of the transition
|
||||
new_transition = transition.copy()
|
||||
|
||||
simple_tensor_keys = [
|
||||
TransitionKey.ACTION,
|
||||
TransitionKey.REWARD,
|
||||
TransitionKey.DONE,
|
||||
TransitionKey.TRUNCATED,
|
||||
]
|
||||
# Process observation tensors
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is not None:
|
||||
new_observation = {
|
||||
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in observation.items()
|
||||
}
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
|
||||
dict_tensor_keys = [
|
||||
TransitionKey.OBSERVATION,
|
||||
TransitionKey.COMPLEMENTARY_DATA,
|
||||
]
|
||||
# Process action tensor
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and isinstance(action, torch.Tensor):
|
||||
new_transition[TransitionKey.ACTION] = self._process_tensor(action)
|
||||
|
||||
# Process simple tensors
|
||||
for key in simple_tensor_keys:
|
||||
value = transition.get(key)
|
||||
if isinstance(value, torch.Tensor):
|
||||
new_transition[key] = self._process_tensor(value)
|
||||
# Process reward tensor
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
if reward is not None and isinstance(reward, torch.Tensor):
|
||||
new_transition[TransitionKey.REWARD] = self._process_tensor(reward)
|
||||
|
||||
# Process dictionary-like tensors
|
||||
for key in dict_tensor_keys:
|
||||
data_dict = transition.get(key)
|
||||
if data_dict is not None:
|
||||
new_data_dict = {
|
||||
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in data_dict.items()
|
||||
}
|
||||
new_transition[key] = new_data_dict
|
||||
# Process done tensor
|
||||
done = transition.get(TransitionKey.DONE)
|
||||
if done is not None and isinstance(done, torch.Tensor):
|
||||
new_transition[TransitionKey.DONE] = self._process_tensor(done)
|
||||
|
||||
# Process truncated tensor
|
||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||
if truncated is not None and isinstance(truncated, torch.Tensor):
|
||||
new_transition[TransitionKey.TRUNCATED] = self._process_tensor(truncated)
|
||||
|
||||
# Process complementary data tensors
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is not None:
|
||||
new_complementary_data = {}
|
||||
|
||||
# Process all items in complementary_data
|
||||
for key, value in complementary_data.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
new_complementary_data[key] = self._process_tensor(value)
|
||||
else:
|
||||
new_complementary_data[key] = value
|
||||
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {"device": self.device, "float_dtype": self.float_dtype}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary (empty for this processor)."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
#! /usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("torch2numpy_action_processor")
|
||||
@dataclass
|
||||
class Torch2NumpyActionProcessor(ActionProcessor):
|
||||
"""Convert PyTorch tensor actions to NumPy arrays."""
|
||||
|
||||
squeeze_batch_dim: bool = True
|
||||
|
||||
def action(self, action: torch.Tensor) -> np.ndarray:
|
||||
if not isinstance(action, torch.Tensor):
|
||||
raise TypeError(
|
||||
f"Expected torch.Tensor or None, got {type(action).__name__}. "
|
||||
"Use appropriate processor for non-tensor actions."
|
||||
)
|
||||
|
||||
numpy_action = action.detach().cpu().numpy()
|
||||
|
||||
# Remove batch dimensions but preserve action dimensions
|
||||
# Only squeeze if there's a batch dimension (first dim == 1)
|
||||
if (
|
||||
self.squeeze_batch_dim
|
||||
and numpy_action.shape
|
||||
and len(numpy_action.shape) > 1
|
||||
and numpy_action.shape[0] == 1
|
||||
):
|
||||
numpy_action = numpy_action.squeeze(0)
|
||||
|
||||
return numpy_action
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("numpy2torch_action_processor")
|
||||
@dataclass
|
||||
class Numpy2TorchActionProcessor(ActionProcessor):
|
||||
"""Convert NumPy array action to PyTorch tensor."""
|
||||
|
||||
def action(self, action: np.ndarray) -> torch.Tensor:
|
||||
if not isinstance(action, np.ndarray):
|
||||
raise TypeError(
|
||||
f"Expected np.ndarray or None, got {type(action).__name__}. "
|
||||
"Use appropriate processor for non-tensor actions."
|
||||
)
|
||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||
return torch_action
|
||||
@@ -1,4 +1,3 @@
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
@@ -8,23 +7,19 @@ import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import ACTION
|
||||
from lerobot.processor.pipeline import (
|
||||
ActionProcessor,
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
InfoProcessor,
|
||||
ObservationProcessor,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
TruncatedProcessor,
|
||||
)
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
GRIPPER_KEY = "gripper"
|
||||
DISCRETE_PENALTY_KEY = "discrete_penalty"
|
||||
TELEOP_ACTION_KEY = "teleop_action"
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
|
||||
@@ -34,10 +29,10 @@ class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
|
||||
|
||||
teleop_device: Teleoperator
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
|
||||
return new_complementary_data
|
||||
def complementary_data(self, complementary_data: dict | None) -> dict:
|
||||
complementary_data = {} if complementary_data is None else dict(complementary_data)
|
||||
complementary_data["teleop_action"] = self.teleop_device.get_action()
|
||||
return complementary_data
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_info")
|
||||
@@ -47,11 +42,60 @@ class AddTeleopEventsAsInfo(InfoProcessor):
|
||||
|
||||
teleop_device: Teleoperator
|
||||
|
||||
def info(self, info: dict) -> dict:
|
||||
new_info = dict(info)
|
||||
def info(self, info: dict | None) -> dict:
|
||||
info = {} if info is None else dict(info)
|
||||
teleop_events = getattr(self.teleop_device, "get_teleop_events", lambda: {})()
|
||||
new_info.update(teleop_events)
|
||||
return new_info
|
||||
info.update(teleop_events)
|
||||
return info
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("torch2numpy_action_processor")
|
||||
@dataclass
|
||||
class Torch2NumpyActionProcessor(ActionProcessor):
|
||||
"""Convert PyTorch tensor actions to NumPy arrays."""
|
||||
|
||||
squeeze_batch_dim: bool = True
|
||||
|
||||
def action(self, action: torch.Tensor | None) -> np.ndarray | None:
|
||||
if action is None:
|
||||
return None
|
||||
|
||||
if not isinstance(action, torch.Tensor):
|
||||
raise TypeError(
|
||||
f"Expected torch.Tensor or None, got {type(action).__name__}. "
|
||||
"Use appropriate processor for non-tensor actions."
|
||||
)
|
||||
|
||||
numpy_action = action.detach().cpu().numpy()
|
||||
|
||||
# Remove batch dimensions but preserve action dimensions
|
||||
# Only squeeze if there's a batch dimension (first dim == 1)
|
||||
if (
|
||||
self.squeeze_batch_dim
|
||||
and numpy_action.shape
|
||||
and len(numpy_action.shape) > 1
|
||||
and numpy_action.shape[0] == 1
|
||||
):
|
||||
numpy_action = numpy_action.squeeze(0)
|
||||
|
||||
return numpy_action
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("numpy2torch_action_processor")
|
||||
@dataclass
|
||||
class Numpy2TorchActionProcessor(ActionProcessor):
|
||||
"""Convert NumPy array action to PyTorch tensor."""
|
||||
|
||||
def action(self, action: np.ndarray | None) -> torch.Tensor | None:
|
||||
if action is None:
|
||||
return None
|
||||
if not isinstance(action, np.ndarray):
|
||||
raise TypeError(
|
||||
f"Expected np.ndarray or None, got {type(action).__name__}. "
|
||||
"Use appropriate processor for non-tensor actions."
|
||||
)
|
||||
torch_action = torch.from_numpy(action)
|
||||
return torch_action
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("image_crop_resize_processor")
|
||||
@@ -62,7 +106,10 @@ class ImageCropResizeProcessor(ObservationProcessor):
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
def observation(self, observation: dict | None) -> dict | None:
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
if self.resize_size is None and not self.crop_params_dict:
|
||||
return observation
|
||||
|
||||
@@ -106,45 +153,63 @@ class ImageCropResizeProcessor(ObservationProcessor):
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("time_limit_processor")
|
||||
class TimeLimitProcessor(TruncatedProcessor):
|
||||
class TimeLimitProcessor:
|
||||
"""Track episode steps and enforce time limits."""
|
||||
|
||||
max_episode_steps: int
|
||||
current_step: int = 0
|
||||
|
||||
def truncated(self, truncated):
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||
if truncated is None:
|
||||
return transition
|
||||
|
||||
self.current_step += 1
|
||||
if self.current_step >= self.max_episode_steps:
|
||||
truncated = True
|
||||
# TODO (steven): missing an else truncated = False?
|
||||
return truncated
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.TRUNCATED] = truncated
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"max_episode_steps": self.max_episode_steps,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
self.current_step = 0
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessor(ComplementaryDataProcessor):
|
||||
class GripperPenaltyProcessor:
|
||||
"""Apply penalty for inappropriate gripper usage."""
|
||||
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Calculate gripper penalty and add to complementary data."""
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
|
||||
if complementary_data is None or action is None:
|
||||
return transition
|
||||
|
||||
current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None)
|
||||
if current_gripper_pos is None:
|
||||
return complementary_data
|
||||
return transition
|
||||
|
||||
gripper_action = action[f"{ACTION}.{GRIPPER_KEY}.pos"]
|
||||
gripper_action = action[f"action.{GRIPPER_KEY}.pos"]
|
||||
gripper_action_normalized = gripper_action / self.max_gripper_pos
|
||||
|
||||
# Normalize gripper state and action
|
||||
@@ -157,11 +222,19 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
# Add penalty information to complementary data
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
# Create new complementary data with penalty info
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
||||
new_complementary_data["discrete_penalty"] = gripper_penalty
|
||||
|
||||
return new_complementary_data
|
||||
# Create new transition with updated complementary data
|
||||
new_transition = transition.copy()
|
||||
existing_comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
existing_comp_data.update(new_complementary_data)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = existing_comp_data # type: ignore[misc]
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
@@ -169,14 +242,23 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
|
||||
"max_gripper_pos": self.max_gripper_pos,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the processor state."""
|
||||
self.last_gripper_state = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("intervention_action_processor")
|
||||
class InterventionActionProcessor(ProcessorStep):
|
||||
class InterventionActionProcessor:
|
||||
"""Handle human intervention actions and episode termination."""
|
||||
|
||||
use_gripper: bool = False
|
||||
@@ -189,8 +271,7 @@ class InterventionActionProcessor(ProcessorStep):
|
||||
|
||||
# Get intervention signals from complementary data
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {})
|
||||
teleop_action = info.get("teleop_action", {})
|
||||
is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
|
||||
terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
|
||||
success = info.get(TeleopEvents.SUCCESS, False)
|
||||
@@ -203,12 +284,12 @@ class InterventionActionProcessor(ProcessorStep):
|
||||
if isinstance(teleop_action, dict):
|
||||
# Convert teleop_action dict to tensor format
|
||||
action_list = [
|
||||
teleop_action.get(f"{ACTION}.delta_x", 0.0),
|
||||
teleop_action.get(f"{ACTION}.delta_y", 0.0),
|
||||
teleop_action.get(f"{ACTION}.delta_z", 0.0),
|
||||
teleop_action.get("action.delta_x", 0.0),
|
||||
teleop_action.get("action.delta_y", 0.0),
|
||||
teleop_action.get("action.delta_z", 0.0),
|
||||
]
|
||||
if self.use_gripper:
|
||||
action_list.append(teleop_action.get(GRIPPER_KEY, 1.0))
|
||||
action_list.append(teleop_action.get("gripper", 1.0))
|
||||
elif isinstance(teleop_action, np.ndarray):
|
||||
action_list = teleop_action.tolist()
|
||||
else:
|
||||
@@ -232,7 +313,7 @@ class InterventionActionProcessor(ProcessorStep):
|
||||
|
||||
# Update complementary data with teleop action
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION)
|
||||
complementary_data["teleop_action"] = new_transition.get(TransitionKey.ACTION)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return new_transition
|
||||
@@ -240,13 +321,24 @@ class InterventionActionProcessor(ProcessorStep):
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"use_gripper": self.use_gripper,
|
||||
"terminate_on_success": self.terminate_on_success,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("reward_classifier_processor")
|
||||
class RewardClassifierProcessor(ProcessorStep):
|
||||
class RewardClassifierProcessor:
|
||||
"""Apply reward classification to image observations."""
|
||||
|
||||
pretrained_path: str | None = None
|
||||
@@ -288,7 +380,7 @@ class RewardClassifierProcessor(ProcessorStep):
|
||||
reward = transition.get(TransitionKey.REWARD, 0.0)
|
||||
terminated = transition.get(TransitionKey.DONE, False)
|
||||
|
||||
if math.isclose(success, 1, abs_tol=1e-2):
|
||||
if success == 1.0:
|
||||
reward = self.success_reward
|
||||
if self.terminate_on_success:
|
||||
terminated = True
|
||||
@@ -312,3 +404,15 @@ class RewardClassifierProcessor(ProcessorStep):
|
||||
"success_reward": self.success_reward,
|
||||
"terminate_on_success": self.terminate_on_success,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -13,28 +13,30 @@ from lerobot.robots import Robot
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("joint_velocity_processor")
|
||||
class JointVelocityProcessor(ObservationProcessor):
|
||||
class JointVelocityProcessor:
|
||||
"""Add joint velocity information to observations."""
|
||||
|
||||
dt: float = 0.1
|
||||
joint_velocity_limits: float = 100.0
|
||||
dt: float = 1.0 / 10
|
||||
num_dof: int | None = None
|
||||
|
||||
last_joint_positions: torch.Tensor | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
def observation(self, observation: dict | None) -> dict | None:
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
# Get current joint positions (assuming they're in observation.state)
|
||||
current_positions = observation.get("observation.state")
|
||||
if current_positions is None:
|
||||
# TODO(steven): if we get here, then the transform_features method will not hold
|
||||
return observation
|
||||
|
||||
# Initialize last joint positions if not already set
|
||||
if self.last_joint_positions is None:
|
||||
self.last_joint_positions = current_positions.clone()
|
||||
joint_velocities = torch.zeros_like(current_positions)
|
||||
else:
|
||||
# Compute velocities
|
||||
joint_velocities = (current_positions - self.last_joint_positions) / self.dt
|
||||
|
||||
# Compute velocities
|
||||
joint_velocities = (current_positions - self.last_joint_positions) / self.dt
|
||||
self.last_joint_positions = current_positions.clone()
|
||||
|
||||
# Extend observation with velocities
|
||||
@@ -48,6 +50,7 @@ class JointVelocityProcessor(ObservationProcessor):
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"joint_velocity_limits": self.joint_velocity_limits,
|
||||
"dt": self.dt,
|
||||
}
|
||||
|
||||
@@ -55,11 +58,12 @@ class JointVelocityProcessor(ObservationProcessor):
|
||||
self.last_joint_positions = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if "observation.state" in features:
|
||||
if "observation.state" in features and self.num_dof is not None:
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
|
||||
original_feature = features["observation.state"]
|
||||
# Double the shape to account for positions + velocities
|
||||
new_shape = (original_feature.shape[0] * 2,) + original_feature.shape[1:]
|
||||
|
||||
new_shape = (original_feature.shape[0] + self.num_dof,) + original_feature.shape[1:]
|
||||
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
|
||||
return features
|
||||
|
||||
@@ -71,7 +75,10 @@ class MotorCurrentProcessor(ObservationProcessor):
|
||||
|
||||
robot: Robot | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
def observation(self, observation: dict | None) -> dict | None:
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
# Get current values from robot state
|
||||
if self.robot is None:
|
||||
return observation
|
||||
|
||||
@@ -1,88 +1,232 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.processor.pipeline import (
|
||||
EnvTransition,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RobotProcessor,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor, TransitionKey
|
||||
|
||||
|
||||
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
|
||||
"""Convert numpy arrays and other types to torch tensors."""
|
||||
tensor_stats: dict[str, dict[str, Tensor]] = {}
|
||||
for key, sub in stats.items():
|
||||
tensor_stats[key] = {}
|
||||
for stat_name, value in sub.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
tensor_val = torch.from_numpy(value.astype(np.float32))
|
||||
elif isinstance(value, torch.Tensor):
|
||||
tensor_val = value.to(dtype=torch.float32)
|
||||
elif isinstance(value, (int, float, list, tuple)):
|
||||
tensor_val = torch.tensor(value, dtype=torch.float32)
|
||||
else:
|
||||
raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}")
|
||||
tensor_stats[key][stat_name] = tensor_val
|
||||
return tensor_stats
|
||||
|
||||
|
||||
@dataclass
|
||||
class _NormalizationMixin:
|
||||
"""
|
||||
A mixin class providing core functionality for normalization and unnormalization.
|
||||
@ProcessorStepRegistry.register(name="normalizer_processor")
|
||||
class NormalizerProcessor:
|
||||
"""Normalizes observations and actions in a single processor step.
|
||||
|
||||
This class manages normalization statistics, their conversion to tensors, device placement,
|
||||
and the application of normalization transformations. It is designed to be inherited by
|
||||
concrete ProcessorStep implementations.
|
||||
This processor handles normalization of both observation and action tensors
|
||||
using either mean/std normalization or min/max scaling to a [-1, 1] range.
|
||||
|
||||
For each tensor key in the stats dictionary, the processor will:
|
||||
- Use mean/std normalization if those statistics are provided: (x - mean) / std
|
||||
- Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1
|
||||
|
||||
The processor can be configured to normalize only specific keys by setting
|
||||
the normalize_keys parameter.
|
||||
"""
|
||||
|
||||
# Features and normalisation map are mandatory to match the design of normalize.py
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
|
||||
# Pre-computed statistics coming from dataset.meta.stats for instance.
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
device: torch.device | str | None = None
|
||||
|
||||
# Explicit subset of keys to normalise. If ``None`` every key (except
|
||||
# "action") found in ``stats`` will be normalised. Using a ``set`` makes
|
||||
# membership checks O(1).
|
||||
normalize_keys: set[str] | None = None
|
||||
|
||||
eps: float = 1e-8
|
||||
normalize_observation_keys: set[str] | None = None
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
normalize_keys: set[str] | None = None,
|
||||
eps: float = 1e-8,
|
||||
) -> NormalizerProcessor:
|
||||
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
|
||||
|
||||
The features and norm_map parameters are mandatory to match the design
|
||||
pattern used in normalize.py.
|
||||
"""
|
||||
|
||||
return cls(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=dataset.meta.stats,
|
||||
normalize_keys=normalize_keys,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Robust JSON deserialization handling (guard empty maps)
|
||||
if self.features:
|
||||
first_val = next(iter(self.features.values()))
|
||||
if isinstance(first_val, dict):
|
||||
reconstructed = {}
|
||||
for key, ft_dict in self.features.items():
|
||||
reconstructed[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.features = reconstructed
|
||||
# Handle deserialization from JSON config
|
||||
if self.features and isinstance(list(self.features.values())[0], dict):
|
||||
# Features came from JSON - need to reconstruct PolicyFeature objects
|
||||
reconstructed_features = {}
|
||||
for key, ft_dict in self.features.items():
|
||||
reconstructed_features[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.features = reconstructed_features
|
||||
|
||||
if self.norm_map:
|
||||
# if keys are strings (JSON), rebuild enum map
|
||||
if all(isinstance(k, str) for k in self.norm_map.keys()):
|
||||
reconstructed = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed
|
||||
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
|
||||
# norm_map came from JSON - need to reconstruct enum keys and values
|
||||
reconstructed_norm_map = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed_norm_map
|
||||
|
||||
# Convert stats to tensors and move to the target device once during initialization.
|
||||
# Convert statistics once so we avoid repeated numpy→Tensor conversions
|
||||
# during runtime.
|
||||
self.stats = self.stats or {}
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device)
|
||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||
|
||||
def to(self, device: torch.device | str) -> _NormalizationMixin:
|
||||
"""Moves the processor's normalization stats to the specified device and returns self."""
|
||||
self.device = device
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device)
|
||||
return self
|
||||
# Ensure *normalize_keys* is a set for fast look-ups and compare by
|
||||
# value later when returning the configuration.
|
||||
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
|
||||
self.normalize_keys = set(self.normalize_keys)
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat: dict[str, Tensor] = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat[f"{key}.{stat_name}"] = tensor.cpu() # Always save to CPU
|
||||
return flat
|
||||
def _normalize_obs(self, observation, normalized_info):
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
def load_state_dict(self, state: dict[str, Tensor]) -> None:
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
# Load to the processor's configured device.
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
|
||||
dtype=torch.float32, device=self.device
|
||||
# Decide which keys should be normalised for this call.
|
||||
if self.normalize_keys is not None:
|
||||
keys_to_norm = self.normalize_keys
|
||||
else:
|
||||
# Use feature map to skip action keys.
|
||||
keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}
|
||||
|
||||
processed = dict(observation)
|
||||
for key in keys_to_norm:
|
||||
if key not in processed or key not in self.features:
|
||||
continue
|
||||
|
||||
# Check the normalization mode for this feature type
|
||||
feature = self.features[key]
|
||||
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
|
||||
|
||||
# Skip normalization if mode is IDENTITY
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
normalized_info[key] = "IDENTITY"
|
||||
continue
|
||||
|
||||
# Skip if no stats available for this key
|
||||
if key not in self._tensor_stats:
|
||||
continue
|
||||
|
||||
orig_val = processed[key]
|
||||
tensor = (
|
||||
orig_val.to(dtype=torch.float32)
|
||||
if isinstance(orig_val, torch.Tensor)
|
||||
else torch.as_tensor(orig_val, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = (tensor - mean) / (std + self.eps)
|
||||
normalized_info[key] = "MEAN_STD"
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
normalized_info[key] = "MIN_MAX"
|
||||
else:
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
return processed
|
||||
|
||||
def _normalize_action(self, action, normalized_info):
|
||||
if action is None:
|
||||
return action
|
||||
|
||||
# Check the normalization mode for actions
|
||||
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
|
||||
|
||||
# Skip normalization if mode is IDENTITY
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
normalized_info["action"] = "IDENTITY"
|
||||
return action
|
||||
|
||||
# Skip if no stats available for actions
|
||||
if "action" not in self._tensor_stats:
|
||||
return action
|
||||
|
||||
tensor = (
|
||||
action.to(dtype=torch.float32)
|
||||
if isinstance(action, torch.Tensor)
|
||||
else torch.as_tensor(action, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
normalized_info["action"] = "MEAN_STD"
|
||||
return (tensor - mean) / (std + self.eps)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
normalized_info["action"] = "MIN_MAX"
|
||||
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
# If we reach here, the required stats for the normalization mode are not available
|
||||
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Track what was normalized
|
||||
normalized_info = {}
|
||||
|
||||
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION), normalized_info)
|
||||
action = self._normalize_action(transition.get(TransitionKey.ACTION), normalized_info)
|
||||
|
||||
# Create a new transition with normalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Add normalization info to complementary data
|
||||
if normalized_info:
|
||||
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
comp_data = {} if comp_data is None else dict(comp_data)
|
||||
comp_data["normalized_keys"] = normalized_info
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
config = {
|
||||
@@ -92,177 +236,240 @@ class _NormalizationMixin:
|
||||
},
|
||||
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
|
||||
}
|
||||
if self.normalize_observation_keys is not None:
|
||||
config["normalize_observation_keys"] = sorted(self.normalize_observation_keys)
|
||||
if self.normalize_keys is not None:
|
||||
# Serialise as a list for YAML / JSON friendliness
|
||||
config["normalize_keys"] = sorted(self.normalize_keys)
|
||||
return config
|
||||
|
||||
def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]:
|
||||
new_observation = dict(observation)
|
||||
for key, feature in self.features.items():
|
||||
if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys:
|
||||
continue
|
||||
if feature.type != FeatureType.ACTION and key in new_observation:
|
||||
tensor = torch.as_tensor(new_observation[key], dtype=torch.float32)
|
||||
new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse)
|
||||
return new_observation
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat[f"{key}.{stat_name}"] = tensor
|
||||
return flat
|
||||
|
||||
def _normalize_action(self, action: Any, inverse: bool) -> Tensor:
|
||||
tensor = torch.as_tensor(action, dtype=torch.float32)
|
||||
processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse)
|
||||
return processed_action
|
||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
|
||||
|
||||
def _apply_transform(
|
||||
self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False
|
||||
) -> Tensor:
|
||||
"""Core logic to apply normalization or unnormalization."""
|
||||
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
|
||||
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
|
||||
return tensor
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX):
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
# Ensure input tensor is on the same device as the stats.
|
||||
if self.device and tensor.device != self.device:
|
||||
tensor = tensor.to(self.device)
|
||||
|
||||
# For Accelerate compatibility: move stats to match input tensor device
|
||||
input_device = tensor.device
|
||||
stats = self._tensor_stats[key]
|
||||
tensor = tensor.to(dtype=torch.float32)
|
||||
|
||||
# Move stats to input device if needed
|
||||
stats_device = next(iter(stats.values())).device
|
||||
if stats_device != input_device:
|
||||
stats = to_tensor({key: self._tensor_stats[key]}, device=input_device)[key]
|
||||
|
||||
if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
# Avoid division by zero by adding a small epsilon.
|
||||
denom = std + self.eps
|
||||
if inverse:
|
||||
return tensor * std + mean
|
||||
return (tensor - mean) / denom
|
||||
|
||||
if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
denom = max_val - min_val
|
||||
# When min_val == max_val, substitute the denominator with a small epsilon
|
||||
# to prevent division by zero. This consistently maps an input equal to
|
||||
# min_val to -1, ensuring a stable transformation.
|
||||
denom = torch.where(
|
||||
denom == 0, torch.tensor(self.eps, device=input_device, dtype=torch.float32), denom
|
||||
)
|
||||
if inverse:
|
||||
# Map from [-1, 1] back to [min, max]
|
||||
return (tensor + 1) / 2 * denom + min_val
|
||||
# Map from [min, max] to [-1, 1]
|
||||
return 2 * (tensor - min_val) / denom - 1
|
||||
|
||||
# If necessary stats are missing, return input unchanged.
|
||||
return tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="normalizer_processor")
|
||||
class NormalizerProcessor(_NormalizationMixin, ProcessorStep):
|
||||
"""
|
||||
A processor that applies normalization to observations and actions in a transition.
|
||||
|
||||
This class directly implements the normalization logic for both observation and action
|
||||
components of an `EnvTransition`, using statistics (mean/std or min/max) provided at
|
||||
initialization.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
normalize_observation_keys: set[str] | None = None,
|
||||
eps: float = 1e-8,
|
||||
device: torch.device | str | None = None,
|
||||
) -> NormalizerProcessor:
|
||||
return cls(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=dataset.meta.stats,
|
||||
normalize_observation_keys=normalize_observation_keys,
|
||||
eps=eps,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
new_transition = transition.copy()
|
||||
|
||||
# Handle observation normalization.
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is not None:
|
||||
new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(
|
||||
observation, inverse=False
|
||||
)
|
||||
|
||||
# Handle action normalization.
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False)
|
||||
|
||||
return new_transition
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
||||
class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep):
|
||||
"""
|
||||
A processor that applies unnormalization (the inverse of normalization) to
|
||||
observations and actions in a transition.
|
||||
class UnnormalizerProcessor:
|
||||
"""Inverse normalisation for observations and actions.
|
||||
|
||||
This is typically used to transform actions from a normalized policy output back into
|
||||
the original scale for execution in an environment.
|
||||
Exactly mirrors :class:`NormalizerProcessor` but applies the inverse
|
||||
transform.
|
||||
"""
|
||||
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
device: torch.device | str | None = None,
|
||||
) -> UnnormalizerProcessor:
|
||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device)
|
||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)
|
||||
|
||||
def __post_init__(self):
|
||||
# Handle deserialization from JSON config
|
||||
if self.features and isinstance(list(self.features.values())[0], dict):
|
||||
# Features came from JSON - need to reconstruct PolicyFeature objects
|
||||
reconstructed_features = {}
|
||||
for key, ft_dict in self.features.items():
|
||||
reconstructed_features[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.features = reconstructed_features
|
||||
|
||||
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
|
||||
# norm_map came from JSON - need to reconstruct enum keys and values
|
||||
reconstructed_norm_map = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed_norm_map
|
||||
|
||||
self.stats = self.stats or {}
|
||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||
|
||||
def _unnormalize_obs(self, observation, unnormalized_info):
|
||||
if observation is None:
|
||||
return None
|
||||
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
|
||||
processed = dict(observation)
|
||||
for key in keys:
|
||||
if key not in processed or key not in self.features:
|
||||
continue
|
||||
|
||||
# Check the normalization mode for this feature type
|
||||
feature = self.features[key]
|
||||
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
|
||||
|
||||
# Skip unnormalization if mode is IDENTITY
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
unnormalized_info[key] = "IDENTITY"
|
||||
continue
|
||||
|
||||
# Skip if no stats available for this key
|
||||
if key not in self._tensor_stats:
|
||||
continue
|
||||
|
||||
orig_val = processed[key]
|
||||
tensor = (
|
||||
orig_val.to(dtype=torch.float32)
|
||||
if isinstance(orig_val, torch.Tensor)
|
||||
else torch.as_tensor(orig_val, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = tensor * std + mean
|
||||
unnormalized_info[key] = "MEAN_STD"
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
unnormalized_info[key] = "MIN_MAX"
|
||||
else:
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
return processed
|
||||
|
||||
def _unnormalize_action(self, action, unnormalized_info):
|
||||
if action is None:
|
||||
return action
|
||||
|
||||
# Check the normalization mode for actions
|
||||
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
|
||||
|
||||
# Skip unnormalization if mode is IDENTITY
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
unnormalized_info["action"] = "IDENTITY"
|
||||
return action
|
||||
|
||||
# Skip if no stats available for actions
|
||||
if "action" not in self._tensor_stats:
|
||||
return action
|
||||
|
||||
tensor = (
|
||||
action.to(dtype=torch.float32)
|
||||
if isinstance(action, torch.Tensor)
|
||||
else torch.as_tensor(action, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
unnormalized_info["action"] = "MEAN_STD"
|
||||
return tensor * std + mean
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
unnormalized_info["action"] = "MIN_MAX"
|
||||
return (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
else:
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
# If we reach here, the required stats for the normalization mode are not available
|
||||
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Track what was unnormalized
|
||||
unnormalized_info = {}
|
||||
|
||||
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION), unnormalized_info)
|
||||
action = self._unnormalize_action(transition.get(TransitionKey.ACTION), unnormalized_info)
|
||||
|
||||
# Create a new transition with unnormalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Handle observation unnormalization.
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is not None:
|
||||
new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(observation, inverse=True)
|
||||
|
||||
# Handle action unnormalization.
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True)
|
||||
# Add unnormalization info to complementary data
|
||||
if unnormalized_info:
|
||||
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
comp_data = {} if comp_data is None else dict(comp_data)
|
||||
comp_data["unnormalized_keys"] = unnormalized_info
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"features": {
|
||||
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
|
||||
},
|
||||
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat[f"{key}.{stat_name}"] = tensor
|
||||
return flat
|
||||
|
||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
|
||||
"""
|
||||
Replaces normalization statistics in a RobotProcessor pipeline.
|
||||
|
||||
This function creates a deep copy of the provided `RobotProcessor` and updates the
|
||||
statistics of any `NormalizerProcessor` or `UnnormalizerProcessor` steps within it.
|
||||
It's useful for adapting a trained policy to a new environment or dataset with
|
||||
different data distributions.
|
||||
"""
|
||||
rp = deepcopy(robot_processor)
|
||||
for step in rp.steps:
|
||||
if isinstance(step, _NormalizationMixin):
|
||||
robot_processor = deepcopy(robot_processor)
|
||||
for step in robot_processor.steps:
|
||||
if isinstance(step, NormalizerProcessor) or isinstance(step, UnnormalizerProcessor):
|
||||
step: NormalizerProcessor | UnnormalizerProcessor
|
||||
step.stats = stats
|
||||
# Re-initialize tensor_stats on the correct device.
|
||||
step._tensor_stats = to_tensor(stats, device=step.device)
|
||||
return rp
|
||||
step._tensor_stats = _convert_stats_to_tensors(stats)
|
||||
return robot_processor
|
||||
|
||||
|
||||
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
|
||||
"""Rename keys in the stats dictionary according to the provided mapping.
|
||||
|
||||
Args:
|
||||
stats: The statistics dictionary with structure {feature_key: {stat_name: value}}
|
||||
rename_map: Dictionary mapping old key names to new key names
|
||||
|
||||
Returns:
|
||||
A new stats dictionary with renamed keys
|
||||
|
||||
Example:
|
||||
>>> stats = {"observation.state": {"mean": 0.0, "std": 1.0}, "action": {"mean": 0.5, "std": 0.5}}
|
||||
>>> rename_map = {"observation.state": "observation.robot_state"}
|
||||
>>> new_stats = rename_stats(stats, rename_map)
|
||||
>>> # new_stats will have "observation.robot_state" instead of "observation.state"
|
||||
"""
|
||||
renamed_stats = {}
|
||||
|
||||
for old_key, sub_stats in stats.items():
|
||||
# Use the new key if it exists in the rename map, otherwise keep the old key
|
||||
new_key = rename_map.get(old_key, old_key)
|
||||
renamed_stats[new_key] = deepcopy(sub_stats)
|
||||
|
||||
return renamed_stats
|
||||
|
||||
@@ -18,13 +18,12 @@ from __future__ import annotations
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Generic, TypedDict, TypeVar, cast
|
||||
from typing import Any, Protocol, TypedDict, runtime_checkable
|
||||
|
||||
import torch
|
||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||
@@ -33,9 +32,6 @@ from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
|
||||
# Type variable for generic processor output type
|
||||
TOutput = TypeVar("TOutput")
|
||||
|
||||
|
||||
class TransitionKey(str, Enum):
|
||||
"""Keys for accessing EnvTransition dictionary components."""
|
||||
@@ -136,7 +132,8 @@ class ProcessorStepRegistry:
|
||||
cls._registry.clear()
|
||||
|
||||
|
||||
class ProcessorStep(ABC):
|
||||
@runtime_checkable
|
||||
class ProcessorStep(Protocol):
|
||||
"""Structural typing interface for a single processor step.
|
||||
|
||||
A step is any callable accepting a full `EnvTransition` dict and
|
||||
@@ -169,34 +166,17 @@ class ProcessorStep(ABC):
|
||||
- state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)}
|
||||
"""
|
||||
|
||||
_current_transition: EnvTransition | None = None
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition: ...
|
||||
|
||||
@property
|
||||
def transition(self) -> EnvTransition:
|
||||
"""The current transition being processed by this step."""
|
||||
if self._current_transition is None:
|
||||
raise ValueError("Transition is not set. Make sure to call the step with a transition first.")
|
||||
return self._current_transition
|
||||
def get_config(self) -> dict[str, Any]: ...
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
def state_dict(self) -> dict[str, torch.Tensor]: ...
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ...
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
def reset(self) -> None: ...
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
return None
|
||||
|
||||
def reset(self) -> None:
|
||||
return None
|
||||
|
||||
# TODO(Steven): Consider making this abstract so it is more explicit
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
|
||||
|
||||
|
||||
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
|
||||
@@ -219,10 +199,6 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
|
||||
metadata without breaking the processor.
|
||||
"""
|
||||
|
||||
# Validate input type
|
||||
if not isinstance(batch, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
|
||||
|
||||
# Extract observation keys
|
||||
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
observation = observation_keys if observation_keys else None
|
||||
@@ -286,15 +262,8 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
return batch
|
||||
|
||||
|
||||
class ProcessorKwargs(TypedDict, total=False):
|
||||
"""Keyword arguments for RobotProcessor constructor."""
|
||||
|
||||
to_transition: Callable[[dict[str, Any]], EnvTransition] | None
|
||||
to_output: Callable[[EnvTransition], Any] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
class RobotProcessor(ModelHubMixin):
|
||||
"""
|
||||
Composable, debuggable post-processing processor for robot transitions.
|
||||
|
||||
@@ -302,43 +271,20 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts
|
||||
and batch dictionaries, automatically converting between formats as needed.
|
||||
|
||||
The processor is generic over its output type TOutput, which provides better type safety
|
||||
and clarity about what the processor returns.
|
||||
|
||||
Args:
|
||||
steps: Ordered list of processing steps executed on every call. Defaults to empty list.
|
||||
name: Human-readable identifier that is persisted inside the JSON config.
|
||||
Defaults to "RobotProcessor".
|
||||
to_transition: Function to convert batch dict to EnvTransition dict.
|
||||
Defaults to _default_batch_to_transition.
|
||||
to_output: Function to convert EnvTransition dict to the desired output format of type TOutput.
|
||||
Defaults to _default_transition_to_batch (returns batch dict).
|
||||
Use identity function (lambda x: x) for EnvTransition output.
|
||||
to_output: Function to convert EnvTransition dict to the desired output format.
|
||||
Usually it is a batch dict or EnvTransition dict.
|
||||
Defaults to _default_transition_to_batch.
|
||||
before_step_hooks: List of hooks called before each step. Each hook receives the step
|
||||
index and transition, and can optionally return a modified transition.
|
||||
after_step_hooks: List of hooks called after each step. Each hook receives the step
|
||||
index and transition, and can optionally return a modified transition.
|
||||
|
||||
Type Safety Examples:
|
||||
```python
|
||||
# Default behavior - returns batch dict
|
||||
processor: RobotProcessor[dict[str, Any]] = RobotProcessor(steps=[some_step1, some_step2])
|
||||
result: dict[str, Any] = processor(batch_data) # Type checker knows this is a dict
|
||||
|
||||
# For EnvTransition output, explicitly specify identity function
|
||||
transition_processor: RobotProcessor[EnvTransition] = RobotProcessor(
|
||||
steps=[some_step1, some_step2],
|
||||
to_output=lambda x: x, # Identity function
|
||||
)
|
||||
result: EnvTransition = transition_processor(batch_data) # Type checker knows this is EnvTransition
|
||||
|
||||
# For custom output types
|
||||
processor: RobotProcessor[str] = RobotProcessor(
|
||||
steps=[custom_step], to_output=lambda t: f"Processed {len(t)} keys"
|
||||
)
|
||||
result: str = processor(batch_data) # Type checker knows this is str
|
||||
```
|
||||
|
||||
Hook Semantics:
|
||||
- Hooks are executed sequentially in the order they were registered. There is no way to
|
||||
reorder hooks after registration without creating a new pipeline.
|
||||
@@ -360,13 +306,8 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
|
||||
default_factory=lambda: _default_batch_to_transition, repr=False
|
||||
)
|
||||
to_output: Callable[[EnvTransition], TOutput] = field(
|
||||
# Cast is necessary here: Working around Python type-checker limitation.
|
||||
# _default_transition_to_batch returns dict[str, Any], but we need it to be TOutput
|
||||
# for the generic to work. When no explicit type is given, TOutput defaults to dict[str, Any],
|
||||
# making this cast safe.
|
||||
default_factory=lambda: cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch),
|
||||
repr=False,
|
||||
to_output: Callable[[EnvTransition], dict[str, Any] | EnvTransition] = field(
|
||||
default_factory=lambda: _default_transition_to_batch, repr=False
|
||||
)
|
||||
|
||||
# Processor-level hooks for observation/monitoring
|
||||
@@ -374,57 +315,98 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
|
||||
def __call__(self, data: dict[str, Any]) -> TOutput:
|
||||
def __call__(self, data: EnvTransition | dict[str, Any]):
|
||||
"""Process data through all steps.
|
||||
|
||||
The method accepts a batch dictionary (like the ones returned by ReplayBuffer or
|
||||
LeRobotDataset). It is first converted to EnvTransition format using to_transition,
|
||||
then processed through all steps, and finally converted to the output format using to_output.
|
||||
The method accepts either the classic EnvTransition dict or a batch dictionary
|
||||
(like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied
|
||||
it is first converted to the internal dict format using to_transition; after all
|
||||
steps are executed the dict is transformed back into a batch dict with to_batch and the
|
||||
result is returned – thereby preserving the caller's original data type.
|
||||
|
||||
Args:
|
||||
data: A batch dictionary to process.
|
||||
data: Either an EnvTransition dict or a batch dictionary to process.
|
||||
|
||||
Returns:
|
||||
The processed data in the format specified by to_output.
|
||||
The processed data in the same format as the input (EnvTransition or batch dict).
|
||||
|
||||
Raises:
|
||||
ValueError: If the transition is not a valid EnvTransition format.
|
||||
"""
|
||||
# Always convert input through to_transition
|
||||
transition = self.to_transition(data)
|
||||
# Check if we need to convert back to batch format at the end
|
||||
_, called_with_batch = self._prepare_transition(data)
|
||||
|
||||
# Process through all steps
|
||||
for idx, processor_step in enumerate(self.steps):
|
||||
# Apply before hooks
|
||||
# Use step_through to get the iterator
|
||||
step_iterator = self.step_through(data)
|
||||
|
||||
# Get initial state (before any steps)
|
||||
current_transition = next(step_iterator)
|
||||
|
||||
# Process each step with hooks
|
||||
for idx, next_transition in enumerate(step_iterator):
|
||||
# Apply before hooks with current state (before step execution)
|
||||
for hook in self.before_step_hooks:
|
||||
hook(idx, transition)
|
||||
hook(idx, current_transition)
|
||||
|
||||
# Execute step
|
||||
transition = processor_step(transition)
|
||||
# Move to next state (after step execution)
|
||||
current_transition = next_transition
|
||||
|
||||
# Apply after hooks
|
||||
# Apply after hooks with updated state
|
||||
for hook in self.after_step_hooks:
|
||||
hook(idx, transition)
|
||||
hook(idx, current_transition)
|
||||
|
||||
# Always use to_output for consistent typing
|
||||
return self.to_output(transition)
|
||||
# Convert back to original format if needed
|
||||
if called_with_batch or self.to_output is not _default_transition_to_batch:
|
||||
return self.to_output(current_transition)
|
||||
else:
|
||||
return current_transition
|
||||
|
||||
def step_through(self, data: dict[str, Any]) -> Iterable[EnvTransition]:
|
||||
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
|
||||
"""Prepare and validate transition data for processing.
|
||||
|
||||
Args:
|
||||
data: Either an EnvTransition dict or a batch dictionary to process.
|
||||
|
||||
Returns:
|
||||
A tuple of (prepared_transition, called_with_batch_flag)
|
||||
|
||||
Raises:
|
||||
ValueError: If the transition is not a valid EnvTransition format.
|
||||
"""
|
||||
# Check if data is already an EnvTransition or needs conversion
|
||||
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
|
||||
# It's a batch dict, convert it
|
||||
called_with_batch = True
|
||||
transition = self.to_transition(data)
|
||||
else:
|
||||
# It's already an EnvTransition
|
||||
called_with_batch = False
|
||||
transition = data
|
||||
|
||||
# Basic validation
|
||||
if not isinstance(transition, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
|
||||
|
||||
return transition, called_with_batch
|
||||
|
||||
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]:
|
||||
"""Yield the intermediate results after each processor step.
|
||||
|
||||
This is a low-level method that does NOT apply hooks. It simply executes each step
|
||||
and yields the intermediate results. This allows users to debug the pipeline or
|
||||
apply custom logic between steps if needed.
|
||||
|
||||
Note: This method always yields EnvTransition objects regardless of output format.
|
||||
If you need the results in the output format, you'll need to convert them
|
||||
Note: This method always yields EnvTransition objects regardless of input format.
|
||||
If you need the results in the original input format, you'll need to convert them
|
||||
using `to_output()`.
|
||||
|
||||
Args:
|
||||
data: A batch dictionary to process.
|
||||
data: Either an EnvTransition dict or a batch dictionary to process.
|
||||
|
||||
Yields:
|
||||
The intermediate EnvTransition results after each step.
|
||||
"""
|
||||
# Always convert input through to_transition
|
||||
transition = self.to_transition(data)
|
||||
transition, _ = self._prepare_transition(data)
|
||||
|
||||
# Yield initial state
|
||||
yield transition
|
||||
@@ -526,10 +508,8 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
revision: str | None = None,
|
||||
config_filename: str | None = None,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
to_transition: Callable[[dict[str, Any]], EnvTransition] | None = None,
|
||||
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||
**kwargs,
|
||||
) -> RobotProcessor[TOutput]:
|
||||
) -> RobotProcessor:
|
||||
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
|
||||
|
||||
Args:
|
||||
@@ -543,14 +523,9 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
(for registered steps). Values are dictionaries containing parameter overrides
|
||||
that will be merged with the saved configuration. This is useful for providing
|
||||
non-serializable objects like environment instances.
|
||||
to_transition: Function to convert batch dict to EnvTransition dict.
|
||||
Defaults to _default_batch_to_transition.
|
||||
to_output: Function to convert EnvTransition dict to the desired output format of type T.
|
||||
Defaults to _default_transition_to_batch (returns batch dict).
|
||||
Use identity function (lambda x: x) for EnvTransition output.
|
||||
|
||||
Returns:
|
||||
A RobotProcessor[TOutput] instance loaded from the saved configuration.
|
||||
A RobotProcessor instance loaded from the saved configuration.
|
||||
|
||||
Raises:
|
||||
ImportError: If a processor step class cannot be loaded or imported.
|
||||
@@ -764,34 +739,19 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
f"Make sure override keys match exact step class names or registry names."
|
||||
)
|
||||
|
||||
return cls(
|
||||
steps=steps,
|
||||
name=loaded_config.get("name", "RobotProcessor"),
|
||||
to_transition=to_transition or _default_batch_to_transition,
|
||||
# Cast is necessary here: Same type-checker limitation as above.
|
||||
# When to_output is None, we use the default which returns dict[str, Any].
|
||||
# The cast ensures type consistency with the generic TOutput parameter.
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch),
|
||||
)
|
||||
return cls(steps, loaded_config.get("name", "RobotProcessor"))
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of steps in the processor."""
|
||||
return len(self.steps)
|
||||
|
||||
def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor[TOutput]:
|
||||
def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor:
|
||||
"""Indexing helper exposing underlying steps.
|
||||
* ``int`` – returns the idx-th ProcessorStep.
|
||||
* ``slice`` – returns a new RobotProcessor with the sliced steps.
|
||||
"""
|
||||
if isinstance(idx, slice):
|
||||
return RobotProcessor(
|
||||
steps=self.steps[idx],
|
||||
name=self.name,
|
||||
to_transition=self.to_transition,
|
||||
to_output=self.to_output,
|
||||
before_step_hooks=self.before_step_hooks.copy(),
|
||||
after_step_hooks=self.after_step_hooks.copy(),
|
||||
)
|
||||
return RobotProcessor(self.steps[idx], self.name)
|
||||
return self.steps[idx]
|
||||
|
||||
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
||||
@@ -860,7 +820,6 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
def __post_init__(self):
|
||||
for i, step in enumerate(self.steps):
|
||||
if not callable(step):
|
||||
# TODO(steven): This should instead check isinstance(step, ProcessorStep), test need to be updated
|
||||
raise TypeError(
|
||||
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
|
||||
)
|
||||
@@ -878,7 +837,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
return features
|
||||
|
||||
|
||||
class ObservationProcessor(ProcessorStep, ABC):
|
||||
class ObservationProcessor:
|
||||
"""Base class for processors that modify only the observation component of a transition.
|
||||
|
||||
Subclasses should override the `observation` method to implement custom observation processing.
|
||||
@@ -899,8 +858,7 @@ class ObservationProcessor(ProcessorStep, ABC):
|
||||
manipulation, focusing only on the specific observation processing logic.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def observation(self, observation) -> dict[str, Any]:
|
||||
def observation(self, observation):
|
||||
"""Process the observation component.
|
||||
|
||||
Args:
|
||||
@@ -909,22 +867,36 @@ class ObservationProcessor(ProcessorStep, ABC):
|
||||
Returns:
|
||||
The processed observation
|
||||
"""
|
||||
...
|
||||
return observation
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
self._current_transition = transition.copy()
|
||||
new_transition = self._current_transition
|
||||
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
return new_transition
|
||||
return transition
|
||||
|
||||
processed_observation = self.observation(observation)
|
||||
# Create a new transition dict with the processed observation
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_observation
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
class ActionProcessor(ProcessorStep, ABC):
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class ActionProcessor:
|
||||
"""Base class for processors that modify only the action component of a transition.
|
||||
|
||||
Subclasses should override the `action` method to implement custom action processing.
|
||||
@@ -946,8 +918,7 @@ class ActionProcessor(ProcessorStep, ABC):
|
||||
manipulation, focusing only on the specific action processing logic.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def action(self, action) -> Any | torch.Tensor:
|
||||
def action(self, action):
|
||||
"""Process the action component.
|
||||
|
||||
Args:
|
||||
@@ -956,22 +927,36 @@ class ActionProcessor(ProcessorStep, ABC):
|
||||
Returns:
|
||||
The processed action
|
||||
"""
|
||||
...
|
||||
return action
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
self._current_transition = transition.copy()
|
||||
new_transition = self._current_transition
|
||||
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
return new_transition
|
||||
return transition
|
||||
|
||||
processed_action = self.action(action)
|
||||
# Create a new transition dict with the processed action
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.ACTION] = processed_action
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
class RewardProcessor(ProcessorStep, ABC):
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class RewardProcessor:
|
||||
"""Base class for processors that modify only the reward component of a transition.
|
||||
|
||||
Subclasses should override the `reward` method to implement custom reward processing.
|
||||
@@ -992,8 +977,7 @@ class RewardProcessor(ProcessorStep, ABC):
|
||||
manipulation, focusing only on the specific reward processing logic.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def reward(self, reward) -> float | torch.Tensor:
|
||||
def reward(self, reward):
|
||||
"""Process the reward component.
|
||||
|
||||
Args:
|
||||
@@ -1002,22 +986,36 @@ class RewardProcessor(ProcessorStep, ABC):
|
||||
Returns:
|
||||
The processed reward
|
||||
"""
|
||||
...
|
||||
return reward
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
self._current_transition = transition.copy()
|
||||
new_transition = self._current_transition
|
||||
|
||||
reward = new_transition.get(TransitionKey.REWARD)
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
if reward is None:
|
||||
return new_transition
|
||||
return transition
|
||||
|
||||
processed_reward = self.reward(reward)
|
||||
# Create a new transition dict with the processed reward
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.REWARD] = processed_reward
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
class DoneProcessor(ProcessorStep, ABC):
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class DoneProcessor:
|
||||
"""Base class for processors that modify only the done flag of a transition.
|
||||
|
||||
Subclasses should override the `done` method to implement custom done flag processing.
|
||||
@@ -1043,8 +1041,7 @@ class DoneProcessor(ProcessorStep, ABC):
|
||||
manipulation, focusing only on the specific done flag processing logic.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def done(self, done) -> bool | torch.Tensor:
|
||||
def done(self, done):
|
||||
"""Process the done flag.
|
||||
|
||||
Args:
|
||||
@@ -1053,22 +1050,36 @@ class DoneProcessor(ProcessorStep, ABC):
|
||||
Returns:
|
||||
The processed done flag
|
||||
"""
|
||||
...
|
||||
return done
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
self._current_transition = transition.copy()
|
||||
new_transition = self._current_transition
|
||||
|
||||
done = new_transition.get(TransitionKey.DONE)
|
||||
done = transition.get(TransitionKey.DONE)
|
||||
if done is None:
|
||||
return new_transition
|
||||
return transition
|
||||
|
||||
processed_done = self.done(done)
|
||||
# Create a new transition dict with the processed done flag
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.DONE] = processed_done
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
class TruncatedProcessor(ProcessorStep, ABC):
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class TruncatedProcessor:
|
||||
"""Base class for processors that modify only the truncated flag of a transition.
|
||||
|
||||
Subclasses should override the `truncated` method to implement custom truncated flag processing.
|
||||
@@ -1090,8 +1101,7 @@ class TruncatedProcessor(ProcessorStep, ABC):
|
||||
manipulation, focusing only on the specific truncated flag processing logic.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def truncated(self, truncated) -> bool | torch.Tensor:
|
||||
def truncated(self, truncated):
|
||||
"""Process the truncated flag.
|
||||
|
||||
Args:
|
||||
@@ -1100,22 +1110,36 @@ class TruncatedProcessor(ProcessorStep, ABC):
|
||||
Returns:
|
||||
The processed truncated flag
|
||||
"""
|
||||
...
|
||||
return truncated
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
self._current_transition = transition.copy()
|
||||
new_transition = self._current_transition
|
||||
|
||||
truncated = new_transition.get(TransitionKey.TRUNCATED)
|
||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||
if truncated is None:
|
||||
return new_transition
|
||||
return transition
|
||||
|
||||
processed_truncated = self.truncated(truncated)
|
||||
# Create a new transition dict with the processed truncated flag
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.TRUNCATED] = processed_truncated
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
class InfoProcessor(ProcessorStep, ABC):
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class InfoProcessor:
|
||||
"""Base class for processors that modify only the info dictionary of a transition.
|
||||
|
||||
Subclasses should override the `info` method to implement custom info processing.
|
||||
@@ -1142,8 +1166,7 @@ class InfoProcessor(ProcessorStep, ABC):
|
||||
manipulation, focusing only on the specific info dictionary processing logic.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def info(self, info) -> dict[str, Any]:
|
||||
def info(self, info):
|
||||
"""Process the info dictionary.
|
||||
|
||||
Args:
|
||||
@@ -1152,22 +1175,36 @@ class InfoProcessor(ProcessorStep, ABC):
|
||||
Returns:
|
||||
The processed info dictionary
|
||||
"""
|
||||
...
|
||||
return info
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
self._current_transition = transition.copy()
|
||||
new_transition = self._current_transition
|
||||
|
||||
info = new_transition.get(TransitionKey.INFO)
|
||||
info = transition.get(TransitionKey.INFO)
|
||||
if info is None:
|
||||
return new_transition
|
||||
return transition
|
||||
|
||||
processed_info = self.info(info)
|
||||
# Create a new transition dict with the processed info
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.INFO] = processed_info
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
class ComplementaryDataProcessor(ProcessorStep, ABC):
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class ComplementaryDataProcessor:
|
||||
"""Base class for processors that modify only the complementary data of a transition.
|
||||
|
||||
Subclasses should override the `complementary_data` method to implement custom complementary data processing.
|
||||
@@ -1175,8 +1212,7 @@ class ComplementaryDataProcessor(ProcessorStep, ABC):
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def complementary_data(self, complementary_data) -> dict[str, Any]:
|
||||
def complementary_data(self, complementary_data):
|
||||
"""Process the complementary data.
|
||||
|
||||
Args:
|
||||
@@ -1185,23 +1221,52 @@ class ComplementaryDataProcessor(ProcessorStep, ABC):
|
||||
Returns:
|
||||
The processed complementary data
|
||||
"""
|
||||
...
|
||||
return complementary_data
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
self._current_transition = transition.copy()
|
||||
new_transition = self._current_transition
|
||||
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return new_transition
|
||||
return transition
|
||||
|
||||
processed_complementary_data = self.complementary_data(complementary_data)
|
||||
# Create a new transition dict with the processed complementary data
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
class IdentityProcessor(ProcessorStep):
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class IdentityProcessor:
|
||||
"""Identity processor that does nothing."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -50,14 +49,3 @@ class RenameProcessor(ObservationProcessor):
|
||||
- Keys not in `rename_map` remain unchanged.
|
||||
"""
|
||||
return {self.rename_map.get(k, k): v for k, v in features.items()}
|
||||
|
||||
|
||||
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
|
||||
"""Rename keys in the stats dictionary according to rename_map (defensive copy)."""
|
||||
if not stats:
|
||||
return {}
|
||||
renamed: dict[str, dict[str, Any]] = {}
|
||||
for old_key, sub_stats in stats.items():
|
||||
new_key = rename_map.get(old_key, old_key)
|
||||
renamed[new_key] = deepcopy(sub_stats) if sub_stats is not None else {}
|
||||
return renamed
|
||||
|
||||
@@ -10,13 +10,8 @@ from typing import TYPE_CHECKING, Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.processor.pipeline import (
|
||||
EnvTransition,
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.constants import OBS_LANGUAGE
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
@@ -27,7 +22,7 @@ else:
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="tokenizer_processor")
|
||||
class TokenizerProcessor(ObservationProcessor):
|
||||
class TokenizerProcessor:
|
||||
"""Tokenizes text tasks in complementary data using a huggingface tokenizer.
|
||||
|
||||
This processor handles tokenization of task strings found in the complementary_data
|
||||
@@ -123,7 +118,7 @@ class TokenizerProcessor(ObservationProcessor):
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation):
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Process the transition by tokenizing the task text.
|
||||
|
||||
Args:
|
||||
@@ -135,15 +130,15 @@ class TokenizerProcessor(ObservationProcessor):
|
||||
Raises:
|
||||
ValueError: If tokenizer initialization failed.
|
||||
"""
|
||||
task = self.get_task(self.transition)
|
||||
task = self.get_task(transition)
|
||||
if task is None:
|
||||
return observation
|
||||
return transition
|
||||
|
||||
# Tokenize the task (creates CPU tensors)
|
||||
tokenized_prompt = self._tokenize_text(task)
|
||||
|
||||
# Detect device from existing tensors in the transition
|
||||
target_device = self._detect_device(self.transition)
|
||||
target_device = self._detect_device(transition)
|
||||
|
||||
# Move tokenized tensors to match the device of other data
|
||||
if target_device is not None:
|
||||
@@ -153,13 +148,20 @@ class TokenizerProcessor(ObservationProcessor):
|
||||
}
|
||||
|
||||
# Get or create observation dict
|
||||
new_observation = dict(observation)
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
observation = {}
|
||||
else:
|
||||
observation = dict(observation) # Make a copy
|
||||
|
||||
# Add tokenized data to observation
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
|
||||
observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
|
||||
return new_observation
|
||||
transition[TransitionKey.OBSERVATION.value] = observation # type: ignore[misc]
|
||||
return transition
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
"""Detect device from existing tensors in the transition.
|
||||
@@ -185,6 +187,19 @@ class TokenizerProcessor(ObservationProcessor):
|
||||
if isinstance(action, torch.Tensor):
|
||||
return action.device
|
||||
|
||||
# Check other tensor fields
|
||||
for key in [TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED]:
|
||||
value = transition.get(key)
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
|
||||
# Check complementary data for tensors
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data:
|
||||
for value in complementary_data.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
|
||||
return None # No tensors found, keep on CPU
|
||||
|
||||
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
|
||||
@@ -220,12 +235,23 @@ class TokenizerProcessor(ObservationProcessor):
|
||||
}
|
||||
|
||||
# Only include tokenizer_name if it was used (not when tokenizer object was provided)
|
||||
# TODO(steven): Consider saving the name of the _tokenizer if it was loaded
|
||||
if self.tokenizer_name is not None and self.tokenizer is None:
|
||||
if self.tokenizer_name is not None:
|
||||
config["tokenizer_name"] = self.tokenizer_name
|
||||
|
||||
return config
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary (empty for this processor)."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Add tokenized task features to the feature contract.
|
||||
|
||||
@@ -237,13 +263,13 @@ class TokenizerProcessor(ObservationProcessor):
|
||||
"""
|
||||
# Add features for tokenized output if they don't exist
|
||||
# Standard tokenizer output includes tokens and attention_mask
|
||||
tokens_key = f"{OBS_LANGUAGE}.tokens"
|
||||
attention_mask_key = f"{OBS_LANGUAGE}.attention_mask"
|
||||
|
||||
if OBS_LANGUAGE_TOKENS not in features:
|
||||
features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
|
||||
if tokens_key not in features:
|
||||
features[tokens_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
|
||||
|
||||
if OBS_LANGUAGE_ATTENTION_MASK not in features:
|
||||
features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
if attention_mask_key not in features:
|
||||
features[attention_mask_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
|
||||
|
||||
return features
|
||||
|
||||
@@ -74,7 +74,7 @@ from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.factory import make_policy, make_processor
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.processor.converters import (
|
||||
@@ -83,8 +83,8 @@ from lerobot.processor.converters import (
|
||||
to_transition_robot_observation,
|
||||
to_transition_teleop_action,
|
||||
)
|
||||
from lerobot.processor.normalize_processor import rename_stats
|
||||
from lerobot.processor.pipeline import IdentityProcessor, TransitionKey
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -346,14 +346,13 @@ def record_loop(
|
||||
else:
|
||||
logging.info(
|
||||
"No policy or teleoperator provided, skipping action generation. "
|
||||
"This is likely to happen when resetting the environment without a teleop device."
|
||||
"The robot won't be at its rest position at the start of the next episode."
|
||||
"This is likely to happen during environment reset."
|
||||
)
|
||||
continue
|
||||
# Still continue to next loop to respect timing
|
||||
|
||||
# Applies a pipeline to the action, default is IdentityProcessor
|
||||
# IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action()
|
||||
if policy is not None and policy_transition is not None:
|
||||
if policy_transition is not None:
|
||||
robot_action_to_send = robot_action_processor(policy_transition)
|
||||
else:
|
||||
robot_action_to_send = robot_action_processor(teleop_transition)
|
||||
@@ -435,7 +434,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
if cfg.policy is not None:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
@@ -511,9 +510,5 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
record()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
record()
|
||||
|
||||
@@ -45,11 +45,9 @@ from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
from lerobot.configs import parser
|
||||
import draccus
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
|
||||
from lerobot.processor.pipeline import IdentityProcessor
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -85,25 +83,13 @@ class ReplayConfig:
|
||||
dataset: DatasetReplayConfig
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# Optional processor for actions before sending to robot
|
||||
robot_action_processor: RobotProcessor | None = None
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
@draccus.wrap()
|
||||
def replay(cfg: ReplayConfig):
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# Initialize robot action processor with default if not provided
|
||||
robot_action_processor = cfg.robot_action_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()],
|
||||
to_transition=to_transition_teleop_action,
|
||||
to_output=to_output_robot_action, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Reset processor
|
||||
robot_action_processor.reset()
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
@@ -118,10 +104,7 @@ def replay(cfg: ReplayConfig):
|
||||
for i, name in enumerate(dataset.features["action"]["names"]):
|
||||
action[name] = action_array[i]
|
||||
|
||||
# Process action through robot action processor
|
||||
processed_action = robot_action_processor(action)
|
||||
|
||||
robot.send_action(processed_action) # type: ignore[arg-type]
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
busy_wait(1 / dataset.fps - dt_s)
|
||||
|
||||
@@ -19,15 +19,13 @@ from dataclasses import dataclass, field
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor.pipeline import (
|
||||
ActionProcessor,
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
ObservationProcessor,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
@@ -36,7 +34,7 @@ from lerobot.robots.robot import Robot
|
||||
|
||||
@ProcessorStepRegistry.register("ee_reference_and_delta")
|
||||
@dataclass
|
||||
class EEReferenceAndDelta(ActionProcessor):
|
||||
class EEReferenceAndDelta:
|
||||
"""
|
||||
Compute the desired end-effector pose from the target pose and the current pose.
|
||||
|
||||
@@ -63,9 +61,9 @@ class EEReferenceAndDelta(ActionProcessor):
|
||||
_prev_enabled: bool = field(default=False, init=False, repr=False)
|
||||
_command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def action(self, action):
|
||||
new_action = action.copy()
|
||||
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
act = transition.get(TransitionKey.ACTION) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
# Get joint positions from complimentary data
|
||||
raw = comp.get("raw_joint_positions", None)
|
||||
@@ -82,13 +80,13 @@ class EEReferenceAndDelta(ActionProcessor):
|
||||
# Current pose from FK on measured joints
|
||||
t_curr = self.kinematics.forward_kinematics(q)
|
||||
|
||||
enabled = bool(new_action.pop(f"{ACTION}.enabled", 0))
|
||||
tx = float(new_action.pop(f"{ACTION}.target_x", 0.0))
|
||||
ty = float(new_action.pop(f"{ACTION}.target_y", 0.0))
|
||||
tz = float(new_action.pop(f"{ACTION}.target_z", 0.0))
|
||||
wx = float(new_action.pop(f"{ACTION}.target_wx", 0.0))
|
||||
wy = float(new_action.pop(f"{ACTION}.target_wy", 0.0))
|
||||
wz = float(new_action.pop(f"{ACTION}.target_wz", 0.0))
|
||||
enabled = bool(act.pop("action.enabled", 0))
|
||||
tx = float(act.pop("action.target_x", 0.0))
|
||||
ty = float(act.pop("action.target_y", 0.0))
|
||||
tz = float(act.pop("action.target_z", 0.0))
|
||||
wx = float(act.pop("action.target_wx", 0.0))
|
||||
wy = float(act.pop("action.target_wy", 0.0))
|
||||
wz = float(act.pop("action.target_wz", 0.0))
|
||||
|
||||
desired = None
|
||||
|
||||
@@ -124,36 +122,22 @@ class EEReferenceAndDelta(ActionProcessor):
|
||||
# Write action fields
|
||||
pos = desired[:3, 3]
|
||||
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
|
||||
new_action[f"{ACTION}.ee.x"] = float(pos[0])
|
||||
new_action[f"{ACTION}.ee.y"] = float(pos[1])
|
||||
new_action[f"{ACTION}.ee.z"] = float(pos[2])
|
||||
new_action[f"{ACTION}.ee.wx"] = float(tw[0])
|
||||
new_action[f"{ACTION}.ee.wy"] = float(tw[1])
|
||||
new_action[f"{ACTION}.ee.wz"] = float(tw[2])
|
||||
act.update(
|
||||
{
|
||||
"action.ee.x": float(pos[0]),
|
||||
"action.ee.y": float(pos[1]),
|
||||
"action.ee.z": float(pos[2]),
|
||||
"action.ee.wx": float(tw[0]),
|
||||
"action.ee.wy": float(tw[1]),
|
||||
"action.ee.wz": float(tw[2]),
|
||||
}
|
||||
)
|
||||
|
||||
self._prev_enabled = enabled
|
||||
return new_action
|
||||
|
||||
def reset(self):
|
||||
self._prev_enabled = False
|
||||
self.reference_ee_pose = None
|
||||
self._command_when_disabled = None
|
||||
transition[TransitionKey.ACTION] = act
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(f"{ACTION}.enabled", None)
|
||||
features.pop(f"{ACTION}.target_x", None)
|
||||
features.pop(f"{ACTION}.target_y", None)
|
||||
features.pop(f"{ACTION}.target_z", None)
|
||||
features.pop(f"{ACTION}.target_wx", None)
|
||||
features.pop(f"{ACTION}.target_wy", None)
|
||||
features.pop(f"{ACTION}.target_wz", None)
|
||||
|
||||
features[f"{ACTION}.ee.x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
return features
|
||||
|
||||
|
||||
@@ -178,15 +162,14 @@ class EEBoundsAndSafety(ActionProcessor):
|
||||
max_ee_step_m: float = 0.05
|
||||
max_ee_twist_step_rad: float = 0.20
|
||||
_last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def action(self, act: dict) -> dict:
|
||||
x = act.get(f"{ACTION}.ee.x", None)
|
||||
y = act.get(f"{ACTION}.ee.y", None)
|
||||
z = act.get(f"{ACTION}.ee.z", None)
|
||||
wx = act.get(f"{ACTION}.ee.wx", None)
|
||||
wy = act.get(f"{ACTION}.ee.wy", None)
|
||||
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||
def action(self, act: dict | None) -> dict:
|
||||
x = act.pop("action.ee.x", None)
|
||||
y = act.pop("action.ee.y", None)
|
||||
z = act.pop("action.ee.z", None)
|
||||
wx = act.pop("action.ee.wx", None)
|
||||
wy = act.pop("action.ee.wy", None)
|
||||
wz = act.pop("action.ee.wz", None)
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
return act
|
||||
@@ -208,22 +191,35 @@ class EEBoundsAndSafety(ActionProcessor):
|
||||
self._last_pos = pos
|
||||
self._last_twist = twist
|
||||
|
||||
act[f"{ACTION}.ee.x"] = float(pos[0])
|
||||
act[f"{ACTION}.ee.y"] = float(pos[1])
|
||||
act[f"{ACTION}.ee.z"] = float(pos[2])
|
||||
act[f"{ACTION}.ee.wx"] = float(twist[0])
|
||||
act[f"{ACTION}.ee.wy"] = float(twist[1])
|
||||
act[f"{ACTION}.ee.wz"] = float(twist[2])
|
||||
act.update(
|
||||
{
|
||||
"action.ee.x": float(pos[0]),
|
||||
"action.ee.y": float(pos[1]),
|
||||
"action.ee.z": float(pos[2]),
|
||||
"action.ee.wx": float(twist[0]),
|
||||
"action.ee.wy": float(twist[1]),
|
||||
"action.ee.wz": float(twist[2]),
|
||||
}
|
||||
)
|
||||
return act
|
||||
|
||||
def reset(self):
|
||||
self._last_pos = None
|
||||
self._last_twist = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# Because this is last step we specify the dataset features of this step that we want to be stored in the dataset
|
||||
features["action.ee.x"] = float
|
||||
features["action.ee.y"] = float
|
||||
features["action.ee.z"] = float
|
||||
features["action.ee.wx"] = float
|
||||
features["action.ee.wy"] = float
|
||||
features["action.ee.wz"] = float
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
|
||||
@dataclass
|
||||
class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
class InverseKinematicsEEToJoints:
|
||||
"""
|
||||
Compute the desired joint positions from the desired end-effector pose.
|
||||
|
||||
@@ -251,14 +247,26 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
act = transition.get(TransitionKey.ACTION) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
x = act.get(f"{ACTION}.ee.x", None)
|
||||
y = act.get(f"{ACTION}.ee.y", None)
|
||||
z = act.get(f"{ACTION}.ee.z", None)
|
||||
wx = act.get(f"{ACTION}.ee.wx", None)
|
||||
wy = act.get(f"{ACTION}.ee.wy", None)
|
||||
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||
x = act.get("action.ee.x", None)
|
||||
y = act.get("action.ee.y", None)
|
||||
z = act.get("action.ee.z", None)
|
||||
wx = act.get("action.ee.wx", None)
|
||||
wy = act.get("action.ee.wy", None)
|
||||
wz = act.get("action.ee.wz", None)
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
# Nothing to do; restore what we popped and return
|
||||
act.update(
|
||||
{
|
||||
"action.ee.x": x,
|
||||
"action.ee.y": y,
|
||||
"action.ee.z": z,
|
||||
"action.ee.wx": wx,
|
||||
"action.ee.wy": wy,
|
||||
"action.ee.wz": wz,
|
||||
}
|
||||
)
|
||||
transition[TransitionKey.ACTION] = act
|
||||
return transition
|
||||
|
||||
# Get joint positions from complimentary data
|
||||
@@ -286,20 +294,25 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
new_act = dict(act)
|
||||
for i, name in enumerate(self.motor_names):
|
||||
if name == "gripper":
|
||||
new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"])
|
||||
new_act["observation.state.gripper.pos"] = float(raw["gripper"])
|
||||
else:
|
||||
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
|
||||
new_act[f"action.{name}.pos"] = float(q_target[i])
|
||||
transition[TransitionKey.ACTION] = new_act
|
||||
if not self.initial_guess_current_joints:
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
for name in self.motor_names:
|
||||
features[f"{ACTION}.{name}.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
# We specify the dataset features of this step that we want to be stored in the dataset
|
||||
features["action.ee.x"] = float
|
||||
features["action.ee.y"] = float
|
||||
features["action.ee.z"] = float
|
||||
features["action.ee.wx"] = float
|
||||
features["action.ee.wy"] = float
|
||||
features["action.ee.wz"] = float
|
||||
|
||||
features["observation.state.gripper.pos"] = float
|
||||
features["action.gripper.pos"] = float
|
||||
return features
|
||||
|
||||
def reset(self):
|
||||
@@ -308,7 +321,7 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
|
||||
@ProcessorStepRegistry.register("gripper_velocity_to_joint")
|
||||
@dataclass
|
||||
class GripperVelocityToJoint(ProcessorStep):
|
||||
class GripperVelocityToJoint:
|
||||
"""
|
||||
Convert the gripper velocity to a joint velocity.
|
||||
|
||||
@@ -334,12 +347,12 @@ class GripperVelocityToJoint(ProcessorStep):
|
||||
act = transition.get(TransitionKey.ACTION) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
if f"{ACTION}.gripper" not in act:
|
||||
if "action.gripper" not in act:
|
||||
return transition
|
||||
|
||||
if "gripper" not in self.motor_names:
|
||||
new_act = dict(act)
|
||||
new_act.pop(f"{ACTION}.gripper", None)
|
||||
new_act.pop("action.gripper", None)
|
||||
transition[TransitionKey.ACTION] = new_act
|
||||
return transition
|
||||
|
||||
@@ -347,32 +360,33 @@ class GripperVelocityToJoint(ProcessorStep):
|
||||
# Discrete gripper actions are in [0, 1, 2]
|
||||
# 0: open, 1: close, 2: stay
|
||||
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
|
||||
gripper_action = act.get(f"{ACTION}.gripper", 1.0)
|
||||
gripper_action = act.get("action.gripper", 1.0)
|
||||
gripper_action = gripper_action - 1.0
|
||||
gripper_action *= self.clip_max
|
||||
act[f"{ACTION}.gripper"] = gripper_action
|
||||
act["action.gripper"] = gripper_action
|
||||
|
||||
# Get current gripper position from complementary data
|
||||
raw = comp.get("raw_joint_positions") or {}
|
||||
curr_pos = float(raw.get("gripper"))
|
||||
|
||||
# Compute desired gripper velocity
|
||||
u = float(act.get(f"{ACTION}.gripper", 0.0))
|
||||
u = float(act.get("action.gripper", 0.0))
|
||||
delta = u * float(self.speed_factor)
|
||||
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
|
||||
|
||||
new_act = dict(act)
|
||||
new_act[f"{ACTION}.gripper.pos"] = gripper_pos
|
||||
new_act.pop(f"{ACTION}.gripper", None)
|
||||
new_act["action.gripper.pos"] = gripper_pos
|
||||
new_act.pop("action.gripper", None)
|
||||
transition[TransitionKey.ACTION] = new_act
|
||||
|
||||
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
|
||||
obs.update({"observation.state.gripper.pos": curr_pos})
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(f"{ACTION}.gripper", None)
|
||||
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
# We specify the dataset features of this step that we want to be stored in the dataset
|
||||
features["observation.state.gripper.pos"] = float
|
||||
features["action.gripper.pos"] = float
|
||||
return features
|
||||
|
||||
|
||||
@@ -396,27 +410,31 @@ class ForwardKinematicsJointsToEE(ObservationProcessor):
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
|
||||
def observation(self, obs: dict) -> dict:
|
||||
if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names):
|
||||
def observation(self, obs: dict | None) -> dict:
|
||||
if not all(f"observation.state.{n}.pos" in obs for n in self.motor_names):
|
||||
return obs
|
||||
|
||||
q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float)
|
||||
q = np.array([obs[f"observation.state.{n}.pos"] for n in self.motor_names], dtype=float)
|
||||
t = self.kinematics.forward_kinematics(q)
|
||||
pos = t[:3, 3]
|
||||
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
|
||||
|
||||
obs[f"{OBS_STATE}.ee.x"] = float(pos[0])
|
||||
obs[f"{OBS_STATE}.ee.y"] = float(pos[1])
|
||||
obs[f"{OBS_STATE}.ee.z"] = float(pos[2])
|
||||
obs[f"{OBS_STATE}.ee.wx"] = float(tw[0])
|
||||
obs[f"{OBS_STATE}.ee.wy"] = float(tw[1])
|
||||
obs[f"{OBS_STATE}.ee.wz"] = float(tw[2])
|
||||
obs.update(
|
||||
{
|
||||
"observation.state.ee.x": float(pos[0]),
|
||||
"observation.state.ee.y": float(pos[1]),
|
||||
"observation.state.ee.z": float(pos[2]),
|
||||
"observation.state.ee.wx": float(tw[0]),
|
||||
"observation.state.ee.wy": float(tw[1]),
|
||||
"observation.state.ee.wz": float(tw[2]),
|
||||
}
|
||||
)
|
||||
return obs
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We specify the dataset features of this step that we want to be stored in the dataset
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz"]:
|
||||
features[f"{OBS_STATE}.ee.{k}"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"observation.state.ee.{k}"] = float
|
||||
return features
|
||||
|
||||
|
||||
@@ -433,14 +451,15 @@ class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor):
|
||||
robot: Robot
|
||||
|
||||
def complementary_data(self, comp: dict | None) -> dict:
|
||||
new_comp = dict(comp)
|
||||
obs = (
|
||||
self.robot.get_observation()
|
||||
) # todo(steven): why not self.trtansition.get(TransitionKey.OBSERVATION)?
|
||||
comp = {} if comp is None else dict(comp)
|
||||
obs = self.robot.get_observation()
|
||||
|
||||
new_comp["raw_joint_positions"] = {
|
||||
comp["raw_joint_positions"] = {
|
||||
k.removesuffix(".pos"): float(v)
|
||||
for k, v in obs.items()
|
||||
if isinstance(k, str) and k.endswith(".pos")
|
||||
}
|
||||
return new_comp
|
||||
return comp
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -161,11 +161,6 @@ class SO100Follower(Robot):
|
||||
self.bus.write("I_Coefficient", motor, 0)
|
||||
self.bus.write("D_Coefficient", motor, 32)
|
||||
|
||||
if motor == "gripper":
|
||||
self.bus.write("Max_Torque_Limit", motor, 500) # 50% of max torque to avoid burnout
|
||||
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
|
||||
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
for motor in reversed(self.bus.motors):
|
||||
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||
|
||||
@@ -157,13 +157,6 @@ class SO101Follower(Robot):
|
||||
self.bus.write("I_Coefficient", motor, 0)
|
||||
self.bus.write("D_Coefficient", motor, 32)
|
||||
|
||||
if motor == "gripper":
|
||||
self.bus.write(
|
||||
"Max_Torque_Limit", motor, 500
|
||||
) # 50% of the max torque limit to avoid burnout
|
||||
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
|
||||
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
for motor in reversed(self.bus.motors):
|
||||
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||
|
||||
@@ -29,6 +29,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .so100_follower import SO100Follower
|
||||
|
||||
return SO100Follower(config)
|
||||
elif config.type == "so100_follower_end_effector":
|
||||
from .so100_follower import SO100FollowerEndEffector
|
||||
|
||||
return SO100FollowerEndEffector(config)
|
||||
elif config.type == "so101_follower":
|
||||
from .so101_follower import SO101Follower
|
||||
|
||||
|
||||
@@ -98,6 +98,7 @@ from lerobot.utils.utils import (
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
|
||||
#################################################
|
||||
# Main entry point #
|
||||
#################################################
|
||||
@@ -287,9 +288,7 @@ def act_with_policy(
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
observation = {
|
||||
k: v for k, v in transition[TransitionKey.OBSERVATION].items() if k in cfg.policy.input_features
|
||||
}
|
||||
observation = transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
@@ -309,16 +308,8 @@ def act_with_policy(
|
||||
)
|
||||
|
||||
# Extract values from processed transition
|
||||
next_observation = {
|
||||
k: v
|
||||
for k, v in new_transition[TransitionKey.OBSERVATION].items()
|
||||
if k in cfg.policy.input_features
|
||||
}
|
||||
|
||||
# Teleop action is the action that was executed in the environment
|
||||
# It is either the action from the teleop device or the action from the policy
|
||||
executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
|
||||
|
||||
next_observation = new_transition[TransitionKey.OBSERVATION]
|
||||
executed_action = new_transition[TransitionKey.ACTION]
|
||||
reward = new_transition[TransitionKey.REWARD]
|
||||
done = new_transition.get(TransitionKey.DONE, False)
|
||||
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
@@ -37,7 +37,6 @@ from lerobot.processor import (
|
||||
InterventionActionProcessor,
|
||||
JointVelocityProcessor,
|
||||
MapDeltaActionToRobotAction,
|
||||
MapTensorToDeltaActionDict,
|
||||
MotorCurrentProcessor,
|
||||
Numpy2TorchActionProcessor,
|
||||
RewardClassifierProcessor,
|
||||
@@ -81,11 +80,11 @@ class DatasetConfig:
|
||||
"""Configuration for dataset creation and management."""
|
||||
|
||||
repo_id: str
|
||||
dataset_root: str
|
||||
task: str
|
||||
root: str | None = None
|
||||
num_episodes_to_record: int = 5
|
||||
replay_episode: int | None = None
|
||||
push_to_hub: bool = False
|
||||
num_episodes: int
|
||||
episode: int
|
||||
push_to_hub: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -402,11 +401,13 @@ def make_processors(
|
||||
joint_names=motor_names,
|
||||
)
|
||||
|
||||
env_pipeline_steps = [VanillaObservationProcessor()]
|
||||
env_pipeline_steps = [
|
||||
VanillaObservationProcessor(),
|
||||
]
|
||||
|
||||
if cfg.processor.observation is not None:
|
||||
if cfg.processor.observation.add_joint_velocity_to_observation:
|
||||
env_pipeline_steps.append(JointVelocityProcessor(dt=1.0 / cfg.fps))
|
||||
env_pipeline_steps.append(JointVelocityProcessor(dt=1.0 / cfg.fps, num_dof=len(motor_names)))
|
||||
if cfg.processor.observation.add_current_to_observation:
|
||||
env_pipeline_steps.append(MotorCurrentProcessor(robot=env.robot))
|
||||
|
||||
@@ -472,7 +473,6 @@ def make_processors(
|
||||
if cfg.processor.inverse_kinematics is not None and kinematics_solver is not None:
|
||||
# Add EE bounds and safety processor
|
||||
inverse_kinematics_steps = [
|
||||
MapTensorToDeltaActionDict(),
|
||||
MapDeltaActionToRobotAction(),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
@@ -625,7 +625,7 @@ def control_loop(
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.env.fps,
|
||||
root=cfg.dataset.root,
|
||||
root=cfg.dataset.dataset_root,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
image_writer_processes=0,
|
||||
@@ -636,7 +636,7 @@ def control_loop(
|
||||
episode_step = 0
|
||||
episode_start_time = time.perf_counter()
|
||||
|
||||
while episode_idx < cfg.dataset.num_episodes_to_record:
|
||||
while episode_idx < cfg.dataset.num_episodes:
|
||||
step_start_time = time.perf_counter()
|
||||
|
||||
# Create a neutral action (no movement)
|
||||
@@ -711,12 +711,10 @@ def control_loop(
|
||||
|
||||
def replay_trajectory(env: gym.Env, action_processor: RobotProcessor, cfg: GymManipulatorConfig) -> None:
|
||||
"""Replay recorded trajectory on robot environment."""
|
||||
assert cfg.dataset.replay_episode is not None, "Replay episode must be provided for replay"
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=[cfg.dataset.replay_episode],
|
||||
root=cfg.dataset.dataset_root,
|
||||
episodes=[cfg.dataset.episode],
|
||||
download_videos=False,
|
||||
)
|
||||
dataset_actions = dataset.hf_dataset.select_columns(["action"])
|
||||
|
||||
@@ -302,6 +302,11 @@ class RobotClient:
|
||||
|
||||
self.logger.debug(f"Current latest action: {latest_action}")
|
||||
|
||||
# Get queue state before changes
|
||||
old_size, old_timesteps = self._inspect_action_queue()
|
||||
if not old_timesteps:
|
||||
old_timesteps = [latest_action] # queue was empty
|
||||
|
||||
# Get queue state before changes
|
||||
old_size, old_timesteps = self._inspect_action_queue()
|
||||
if not old_timesteps:
|
||||
|
||||
@@ -32,7 +32,7 @@ from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.factory import make_policy, make_processor
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
@@ -141,7 +141,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
|
||||
)
|
||||
|
||||
|
||||
311
src/lerobot/scripts/train_accelerate.py
Normal file
311
src/lerobot/scripts/train_accelerate.py
Normal file
@@ -0,0 +1,311 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pprint import pformat
|
||||
from typing import Any, Callable
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from termcolor import colored
|
||||
from torch.amp import GradScaler
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
has_method,
|
||||
init_logging,
|
||||
is_launched_with_accelerate,
|
||||
)
|
||||
from lerobot.common.utils.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def update_policy(
|
||||
train_metrics: MetricsTracker,
|
||||
policy: PreTrainedPolicy,
|
||||
batch: Any,
|
||||
optimizer: Optimizer,
|
||||
grad_clip_norm: float,
|
||||
grad_scaler: GradScaler,
|
||||
lr_scheduler=None,
|
||||
use_amp: bool = False,
|
||||
lock=None,
|
||||
accelerator: Callable = None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
policy.train()
|
||||
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
accelerator.backward(loss)
|
||||
accelerator.unscale_gradients(optimizer=optimizer)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(),
|
||||
grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
optimizer.step()
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Step through pytorch scheduler at every batch instead of epoch
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step()
|
||||
|
||||
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
|
||||
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
|
||||
|
||||
train_metrics.loss = loss.item()
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
train_metrics.lr = optimizer.param_groups[0]["lr"]
|
||||
train_metrics.update_s = time.perf_counter() - start_time
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig, accelerator: Callable):
|
||||
cfg.validate()
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
|
||||
if accelerator.is_main_process:
|
||||
# Disable logging on non-main processes.
|
||||
cfg.wandb.enable = False
|
||||
|
||||
if cfg.wandb.enable and cfg.wandb.project:
|
||||
wandb_logger = WandBLogger(cfg)
|
||||
else:
|
||||
wandb_logger = None
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
|
||||
if cfg.seed is not None:
|
||||
set_seed(cfg.seed, accelerator=accelerator)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.device, log=True, accelerator=accelerator)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
eval_env = None
|
||||
if cfg.eval_freq > 0 and cfg.env is not None:
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
|
||||
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
device=device,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
policy.to(device)
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
if accelerator.is_main_process:
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||
logging.info(f"{dataset.num_episodes=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.episode_data_index,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=cfg.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
|
||||
train_metrics = {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
"lr": AverageMeter("lr", ":0.1e"),
|
||||
"update_s": AverageMeter("updt_s", ":.3f"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f"),
|
||||
}
|
||||
|
||||
train_tracker = MetricsTracker(
|
||||
cfg.batch_size,
|
||||
dataset.num_frames,
|
||||
dataset.num_episodes,
|
||||
train_metrics,
|
||||
initial_step=step,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
train_tracker,
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.use_amp,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and accelerator.is_main_process
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps and accelerator.is_main_process
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 and accelerator.is_main_process
|
||||
|
||||
if is_log_step:
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
cfg,
|
||||
accelerator.unwrap_model(policy),
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if cfg.env and is_eval_step:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
|
||||
with torch.no_grad():
|
||||
eval_info = eval_policy(
|
||||
env=eval_env,
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
|
||||
eval_metrics = {
|
||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||
"pc_success": AverageMeter("success", ":.1f"),
|
||||
"eval_s": AverageMeter("eval_s", ":.3f"),
|
||||
}
|
||||
eval_tracker = MetricsTracker(
|
||||
cfg.batch_size,
|
||||
dataset.num_frames,
|
||||
dataset.num_episodes,
|
||||
eval_metrics,
|
||||
initial_step=step,
|
||||
accelerator=None,
|
||||
)
|
||||
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
|
||||
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
|
||||
eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
|
||||
logging.info(eval_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
||||
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
||||
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
|
||||
if eval_env:
|
||||
eval_env.close()
|
||||
if not accelerator or accelerator.is_main_process:
|
||||
logging.info("End of training")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
# We set step_scheduler_with_optimizer False to prevent accelerate from
|
||||
# adjusting the lr_scheduler steps based on the num_processes
|
||||
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
|
||||
train(accelerator=accelerator)
|
||||
@@ -56,18 +56,11 @@ import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from pprint import pformat
|
||||
|
||||
import draccus
|
||||
import rerun as rr
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.processor.converters import (
|
||||
to_output_robot_action,
|
||||
to_transition_robot_observation,
|
||||
to_transition_teleop_action,
|
||||
)
|
||||
from lerobot.processor.pipeline import IdentityProcessor
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -104,82 +97,39 @@ class TeleoperateConfig:
|
||||
teleop_time_s: float | None = None
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
# Optional processors for data transformation
|
||||
teleop_action_processor: RobotProcessor | None = None # runs after teleop
|
||||
robot_action_processor: RobotProcessor | None = None # runs before robot
|
||||
robot_observation_processor: RobotProcessor | None = None # runs after robot
|
||||
|
||||
|
||||
def teleop_loop(
|
||||
teleop: Teleoperator,
|
||||
robot: Robot,
|
||||
fps: int,
|
||||
display_data: bool = False,
|
||||
duration: float | None = None,
|
||||
teleop_action_processor: RobotProcessor | None = None,
|
||||
robot_action_processor: RobotProcessor | None = None,
|
||||
robot_observation_processor: RobotProcessor | None = None,
|
||||
teleop: Teleoperator, robot: Robot, fps: int, display_data: bool = False, duration: float | None = None
|
||||
):
|
||||
# Initialize processors with defaults if not provided
|
||||
teleop_action_processor = teleop_action_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()], to_transition=to_transition_teleop_action, to_output=lambda tr: tr
|
||||
)
|
||||
robot_action_processor = robot_action_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=to_output_robot_action, # type: ignore[arg-type]
|
||||
)
|
||||
robot_observation_processor = robot_observation_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()], to_transition=to_transition_robot_observation, to_output=lambda tr: tr
|
||||
)
|
||||
|
||||
# Reset processors
|
||||
teleop_action_processor.reset()
|
||||
robot_action_processor.reset()
|
||||
robot_observation_processor.reset()
|
||||
|
||||
display_len = max(len(key) for key in robot.action_features)
|
||||
start = time.perf_counter()
|
||||
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get teleop action
|
||||
raw_action = teleop.get_action()
|
||||
|
||||
# Process teleop action through pipeline
|
||||
teleop_transition = teleop_action_processor(raw_action)
|
||||
|
||||
# Process action for robot through pipeline
|
||||
robot_action_to_send = robot_action_processor(teleop_transition)
|
||||
|
||||
# Send processed action to robot (robot_action_processor.to_output should return dict[str, Any])
|
||||
robot.send_action(robot_action_to_send) # type: ignore[arg-type]
|
||||
|
||||
action = teleop.get_action()
|
||||
if display_data:
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
# Process robot observation through pipeline
|
||||
obs_transition = robot_observation_processor(obs)
|
||||
log_rerun_data([obs_transition, teleop_transition])
|
||||
|
||||
print("\n" + "-" * (display_len + 10))
|
||||
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
|
||||
# Display the final robot action that was sent
|
||||
for motor, value in robot_action_to_send.items():
|
||||
print(f"{motor:<{display_len}} | {value:>7.2f}")
|
||||
move_cursor_up(len(robot_action_to_send) + 5)
|
||||
observation = robot.get_observation()
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
|
||||
robot.send_action(action)
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
loop_s = time.perf_counter() - loop_start
|
||||
|
||||
print("\n" + "-" * (display_len + 10))
|
||||
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
|
||||
for motor, value in action.items():
|
||||
print(f"{motor:<{display_len}} | {value:>7.2f}")
|
||||
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
|
||||
|
||||
if duration is not None and time.perf_counter() - start >= duration:
|
||||
return
|
||||
|
||||
move_cursor_up(len(action) + 5)
|
||||
|
||||
@parser.wrap()
|
||||
|
||||
@draccus.wrap()
|
||||
def teleoperate(cfg: TeleoperateConfig):
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
@@ -193,16 +143,7 @@ def teleoperate(cfg: TeleoperateConfig):
|
||||
robot.connect()
|
||||
|
||||
try:
|
||||
teleop_loop(
|
||||
teleop=teleop,
|
||||
robot=robot,
|
||||
fps=cfg.fps,
|
||||
display_data=cfg.display_data,
|
||||
duration=cfg.teleop_time_s,
|
||||
teleop_action_processor=cfg.teleop_action_processor,
|
||||
robot_action_processor=cfg.robot_action_processor,
|
||||
robot_observation_processor=cfg.robot_observation_processor,
|
||||
)
|
||||
teleop_loop(teleop, robot, cfg.fps, display_data=cfg.display_data, duration=cfg.teleop_time_s)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
|
||||
@@ -177,6 +177,16 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
"names": {"action.delta_x": 0, "action.delta_y": 1, "action.delta_z": 2},
|
||||
}
|
||||
|
||||
def _on_press(self, key):
|
||||
if hasattr(key, "char"):
|
||||
key = key.char
|
||||
self.event_queue.put((key, True))
|
||||
|
||||
def _on_release(self, key):
|
||||
if hasattr(key, "char"):
|
||||
key = key.char
|
||||
self.event_queue.put((key, False))
|
||||
|
||||
def get_action(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
|
||||
246
src/lerobot/teleoperators/phone/phone.py
Normal file
246
src/lerobot/teleoperators/phone/phone.py
Normal file
@@ -0,0 +1,246 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Docs:
|
||||
# hebi: https://docs.hebi.us/tools.html#mobile-io
|
||||
# teleop: https://github.com/SpesRobotics/teleop
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import hebi
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
from teleop import Teleop
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Phone(Teleoperator):
|
||||
"""
|
||||
Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API).
|
||||
For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs.
|
||||
|
||||
Press and hold **B1** to enable teleoperation. While enabled, the first B1 press
|
||||
captures a reference pose and rotation, when disabled and pressed again the position is reapplied.
|
||||
"""
|
||||
|
||||
config_class = PhoneConfig
|
||||
name = "phone"
|
||||
|
||||
def __init__(self, config: PhoneConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._group = None
|
||||
self._teleop = None
|
||||
self._teleop_thread = None
|
||||
self._latest_pose = None
|
||||
self._latest_message = None
|
||||
self._enabled: bool = False
|
||||
self._calib_pos: np.ndarray | None = None
|
||||
self._calib_rot_inv: Rotation | None = None
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return (self.config.phone_os == PhoneOS.IOS and self._group is not None) or (
|
||||
self.config.phone_os == PhoneOS.ANDROID and self._teleop is not None
|
||||
)
|
||||
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
if self.config.phone_os == PhoneOS.IOS:
|
||||
logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.")
|
||||
lookup = hebi.Lookup()
|
||||
time.sleep(2.0)
|
||||
group = lookup.get_group_from_names(["HEBI"], ["mobileIO"])
|
||||
if group is None:
|
||||
raise RuntimeError("Mobile I/O not found — check name/family settings in the app.")
|
||||
self._group = group
|
||||
logger.info(f"{self} connected to HEBI group with {group.size} module(s).")
|
||||
elif self.config.phone_os == PhoneOS.ANDROID:
|
||||
logger.info("Starting teleop stream for Android...")
|
||||
self._teleop = Teleop()
|
||||
self._teleop.subscribe(self._android_callback)
|
||||
self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True)
|
||||
self._teleop_thread.start()
|
||||
logger.info(f"{self} connected, teleop stream started.")
|
||||
else:
|
||||
raise ValueError(f"Invalid config phone_os: {self.config.phone_os}")
|
||||
|
||||
self.calibrate()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
print(
|
||||
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
|
||||
)
|
||||
if self.config.phone_os == PhoneOS.IOS:
|
||||
print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n")
|
||||
else:
|
||||
print("Touch and move on the WebXR page to capture this pose...\n")
|
||||
|
||||
pos, rot = self._wait_for_capture_trigger()
|
||||
self._calib_pos = pos.copy()
|
||||
self._calib_rot_inv = rot.inv()
|
||||
self._enabled = False
|
||||
print("Calibration done\n")
|
||||
|
||||
def _reapply_position_calibration(self, pos: np.ndarray) -> None:
|
||||
self._calib_pos = pos.copy()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return (self._calib_pos is not None) and (self._calib_rot_inv is not None)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {
|
||||
"phone.pos": np.ndarray, # shape (3,)
|
||||
"phone.rot": Rotation, # scipy.spatial.transform.Rotation
|
||||
"phone.raw_inputs": dict, # analogs/buttons or webXR meta
|
||||
"phone.enabled": bool,
|
||||
}
|
||||
|
||||
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
|
||||
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
|
||||
while True:
|
||||
ok, pos, rot, pose = self._read_current_pose()
|
||||
if not ok:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
if self.config.phone_os == PhoneOS.IOS:
|
||||
io = getattr(pose, "io", None)
|
||||
b = getattr(io, "b", None) if io is not None else None
|
||||
b1 = False
|
||||
if b is not None:
|
||||
b1 = bool(b.get_int(1))
|
||||
if b1:
|
||||
return pos, rot
|
||||
else:
|
||||
msg = self._latest_message or {}
|
||||
if bool(msg.get("move", False)):
|
||||
return pos, rot
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
|
||||
if self.config.phone_os == PhoneOS.IOS:
|
||||
fbk = self._group.get_next_feedback()
|
||||
pose = fbk[0]
|
||||
ar_pos = getattr(pose, "ar_position", None)
|
||||
ar_quat = getattr(pose, "ar_orientation", None)
|
||||
if ar_pos is None or ar_quat is None:
|
||||
return False, None, None, None
|
||||
quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw
|
||||
rot = Rotation.from_quat(quat_xyzw)
|
||||
pos = ar_pos - rot.apply(self.config.camera_offset)
|
||||
return True, pos, rot, pose
|
||||
else:
|
||||
p = self._latest_pose
|
||||
if p is None:
|
||||
return False, None, None, None
|
||||
rot = Rotation.from_matrix(p[:3, :3])
|
||||
pos = p[:3, 3] - rot.apply(self.config.camera_offset)
|
||||
pose = self._latest_pose
|
||||
return True, pos, rot, pose
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
# No haptic or other feedback implemented yet
|
||||
pass
|
||||
|
||||
def configure(self) -> None:
|
||||
# No additional configuration required for phone teleop
|
||||
pass
|
||||
|
||||
def _android_callback(self, pose: np.ndarray, message: dict) -> None:
|
||||
self._latest_pose = pose
|
||||
self._latest_message = message
|
||||
time.sleep(0.001) # 1ms delay to avoid race condition
|
||||
|
||||
def get_action(self) -> dict:
|
||||
ok, raw_pos, raw_rot, pose = self._read_current_pose()
|
||||
if not ok or not self.is_calibrated:
|
||||
return {}
|
||||
|
||||
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
|
||||
raw_inputs: dict[str, float | int | bool] = {}
|
||||
if self.config.phone_os == PhoneOS.IOS:
|
||||
io = getattr(pose, "io", None)
|
||||
if io is not None:
|
||||
bank_a, bank_b = io.a, io.b
|
||||
if bank_a:
|
||||
for ch in range(1, 9):
|
||||
if bank_a.has_float(ch):
|
||||
raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch))
|
||||
if bank_b:
|
||||
for ch in range(1, 9):
|
||||
if bank_b.has_int(ch):
|
||||
raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch))
|
||||
elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch):
|
||||
raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch))
|
||||
else:
|
||||
msg = self._latest_message or {}
|
||||
raw_inputs["move"] = bool(msg.get("move", False))
|
||||
raw_inputs["scale"] = float(msg.get("scale", 1.0))
|
||||
raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False))
|
||||
raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False))
|
||||
|
||||
if self.config.phone_os == PhoneOS.IOS:
|
||||
enable = bool(raw_inputs.get("b1", 0))
|
||||
else:
|
||||
enable = bool(raw_inputs.get("move", False))
|
||||
|
||||
# Rising edge then re-capture calibration immediately from current raw pose
|
||||
if enable and not self._enabled:
|
||||
self._reapply_position_calibration(raw_pos)
|
||||
|
||||
# Apply calibration
|
||||
pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
|
||||
rot_cal = self._calib_rot_inv * raw_rot
|
||||
|
||||
self._enabled = enable
|
||||
|
||||
return {
|
||||
"phone.pos": pos_cal,
|
||||
"phone.rot": rot_cal,
|
||||
"phone.raw_inputs": raw_inputs,
|
||||
"phone.enabled": self._enabled,
|
||||
}
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# We could add haptic feedback (vibrations) here, but it's not implemented yet
|
||||
raise NotImplementedError
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.config.phone_os == PhoneOS.IOS:
|
||||
self._group = None
|
||||
else:
|
||||
self._teleop = None
|
||||
if self._teleop_thread and self._teleop_thread.is_alive():
|
||||
self._teleop_thread.join(timeout=1.0)
|
||||
self._teleop_thread = None
|
||||
self._latest_pose = None
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneOS
|
||||
|
||||
@@ -46,9 +46,9 @@ class MapPhoneActionToRobotAction(ActionProcessor):
|
||||
platform: PhoneOS
|
||||
_enabled_prev: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def action(self, act: dict) -> dict:
|
||||
def action(self, act: dict | None) -> dict:
|
||||
# Pop them from the action
|
||||
enabled = bool(act.pop("action.phone.enabled", 0))
|
||||
enabled = act.pop("action.phone.enabled", 0)
|
||||
pos = act.pop("action.phone.pos", None)
|
||||
rot = act.pop("action.phone.rot", None)
|
||||
inputs = act.pop("action.phone.raw_inputs", {})
|
||||
@@ -69,28 +69,19 @@ class MapPhoneActionToRobotAction(ActionProcessor):
|
||||
) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed
|
||||
|
||||
# For some actions we need to invert the axis
|
||||
act["action.enabled"] = enabled
|
||||
act["action.target_x"] = -pos[1] if enabled else 0.0
|
||||
act["action.target_y"] = pos[0] if enabled else 0.0
|
||||
act["action.target_z"] = pos[2] if enabled else 0.0
|
||||
act["action.target_wx"] = rotvec[1] if enabled else 0.0
|
||||
act["action.target_wy"] = rotvec[0] if enabled else 0.0
|
||||
act["action.target_wz"] = -rotvec[2] if enabled else 0.0
|
||||
act["action.gripper"] = gripper # Still send gripper action when disabled
|
||||
act.update(
|
||||
{
|
||||
"action.enabled": enabled,
|
||||
"action.target_x": -pos[1] if enabled else 0.0,
|
||||
"action.target_y": pos[0] if enabled else 0.0,
|
||||
"action.target_z": pos[2] if enabled else 0.0,
|
||||
"action.target_wx": rotvec[1] if enabled else 0.0,
|
||||
"action.target_wy": rotvec[0] if enabled else 0.0,
|
||||
"action.target_wz": -rotvec[2] if enabled else 0.0,
|
||||
"action.gripper": gripper, # Still send gripper action when disabled
|
||||
}
|
||||
)
|
||||
return act
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop("action.phone.enabled", None)
|
||||
features.pop("action.phone.pos", None)
|
||||
features.pop("action.phone.rot", None)
|
||||
features.pop("action.phone.raw_inputs", None)
|
||||
|
||||
features["action.enabled"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.gripper"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
return features
|
||||
|
||||
@@ -1,359 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Docs:
|
||||
# hebi: https://docs.hebi.us/tools.html#mobile-io
|
||||
# teleop: https://github.com/SpesRobotics/teleop
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import hebi
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
from teleop import Teleop
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BasePhone:
|
||||
_enabled: bool = False
|
||||
_calib_pos: np.ndarray | None = None
|
||||
_calib_rot_inv: Rotation | None = None
|
||||
|
||||
def _reapply_position_calibration(self, pos: np.ndarray) -> None:
|
||||
self._calib_pos = pos.copy()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return (self._calib_pos is not None) and (self._calib_rot_inv is not None)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {
|
||||
"phone.pos": np.ndarray, # shape (3,)
|
||||
"phone.rot": Rotation, # scipy.spatial.transform.Rotation
|
||||
"phone.raw_inputs": dict, # analogs/buttons or webXR meta
|
||||
"phone.enabled": bool,
|
||||
}
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
# No haptic or other feedback implemented yet
|
||||
pass
|
||||
|
||||
def configure(self) -> None:
|
||||
# No additional configuration required for phone teleop
|
||||
pass
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# We could add haptic feedback (vibrations) here, but it's not implemented yet
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IOSPhone(BasePhone, Teleoperator):
|
||||
name = "ios_phone"
|
||||
|
||||
def __init__(self, config: PhoneConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._group = None
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._group is not None
|
||||
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.")
|
||||
lookup = hebi.Lookup()
|
||||
time.sleep(2.0)
|
||||
group = lookup.get_group_from_names(["HEBI"], ["mobileIO"])
|
||||
if group is None:
|
||||
raise RuntimeError("Mobile I/O not found — check name/family settings in the app.")
|
||||
self._group = group
|
||||
logger.info(f"{self} connected to HEBI group with {group.size} module(s).")
|
||||
|
||||
self.calibrate()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
print(
|
||||
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
|
||||
)
|
||||
print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n")
|
||||
|
||||
pos, rot = self._wait_for_capture_trigger()
|
||||
self._calib_pos = pos.copy()
|
||||
self._calib_rot_inv = rot.inv()
|
||||
self._enabled = False
|
||||
print("Calibration done\n")
|
||||
|
||||
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
|
||||
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
|
||||
while True:
|
||||
ok, pos, rot, pose = self._read_current_pose()
|
||||
if not ok:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
io = getattr(pose, "io", None)
|
||||
b = getattr(io, "b", None) if io is not None else None
|
||||
b1 = False
|
||||
if b is not None:
|
||||
b1 = bool(b.get_int(1))
|
||||
if b1:
|
||||
return pos, rot
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
|
||||
fbk = self._group.get_next_feedback()
|
||||
pose = fbk[0]
|
||||
ar_pos = getattr(pose, "ar_position", None)
|
||||
ar_quat = getattr(pose, "ar_orientation", None)
|
||||
if ar_pos is None or ar_quat is None:
|
||||
return False, None, None, None
|
||||
# HEBI provides orientation in w, x, y, z format.
|
||||
# Scipy's Rotation expects x, y, z, w.
|
||||
quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw
|
||||
rot = Rotation.from_quat(quat_xyzw)
|
||||
pos = ar_pos - rot.apply(self.config.camera_offset)
|
||||
return True, pos, rot, pose
|
||||
|
||||
def get_action(self) -> dict:
|
||||
ok, raw_pos, raw_rot, pose = self._read_current_pose()
|
||||
if not ok or not self.is_calibrated:
|
||||
return {}
|
||||
|
||||
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
|
||||
raw_inputs: dict[str, float | int | bool] = {}
|
||||
io = getattr(pose, "io", None)
|
||||
if io is not None:
|
||||
bank_a, bank_b = io.a, io.b
|
||||
if bank_a:
|
||||
for ch in range(1, 9):
|
||||
if bank_a.has_float(ch):
|
||||
raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch))
|
||||
if bank_b:
|
||||
for ch in range(1, 9):
|
||||
if bank_b.has_int(ch):
|
||||
raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch))
|
||||
elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch):
|
||||
raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch))
|
||||
|
||||
enable = bool(raw_inputs.get("b1", 0))
|
||||
|
||||
# Rising edge then re-capture calibration immediately from current raw pose
|
||||
if enable and not self._enabled:
|
||||
self._reapply_position_calibration(raw_pos)
|
||||
|
||||
# Apply calibration
|
||||
pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
|
||||
rot_cal = self._calib_rot_inv * raw_rot
|
||||
|
||||
self._enabled = enable
|
||||
|
||||
return {
|
||||
"phone.pos": pos_cal,
|
||||
"phone.rot": rot_cal,
|
||||
"phone.raw_inputs": raw_inputs,
|
||||
"phone.enabled": self._enabled,
|
||||
}
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._group = None
|
||||
|
||||
|
||||
class AndroidPhone(BasePhone, Teleoperator):
|
||||
name = "android_phone"
|
||||
|
||||
def __init__(self, config: PhoneConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._teleop = None
|
||||
self._teleop_thread = None
|
||||
self._latest_pose = None
|
||||
self._latest_message = None
|
||||
self._android_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._teleop is not None
|
||||
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
logger.info("Starting teleop stream for Android...")
|
||||
self._teleop = Teleop()
|
||||
self._teleop.subscribe(self._android_callback)
|
||||
self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True)
|
||||
self._teleop_thread.start()
|
||||
logger.info(f"{self} connected, teleop stream started.")
|
||||
|
||||
self.calibrate()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
print(
|
||||
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
|
||||
)
|
||||
print("Touch and move on the WebXR page to capture this pose...\n")
|
||||
|
||||
pos, rot = self._wait_for_capture_trigger()
|
||||
self._calib_pos = pos.copy()
|
||||
self._calib_rot_inv = rot.inv()
|
||||
self._enabled = False
|
||||
print("Calibration done\n")
|
||||
|
||||
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
|
||||
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
|
||||
while True:
|
||||
with self._android_lock:
|
||||
msg = self._latest_message or {}
|
||||
|
||||
if bool(msg.get("move", False)):
|
||||
ok, pos, rot, _pose = self._read_current_pose()
|
||||
if ok:
|
||||
return pos, rot
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
|
||||
with self._android_lock:
|
||||
if self._latest_pose is None:
|
||||
return False, None, None, None
|
||||
p = self._latest_pose.copy()
|
||||
pose = self._latest_pose
|
||||
rot = Rotation.from_matrix(p[:3, :3])
|
||||
pos = p[:3, 3] - rot.apply(self.config.camera_offset)
|
||||
return True, pos, rot, pose
|
||||
|
||||
def _android_callback(self, pose: np.ndarray, message: dict) -> None:
|
||||
with self._android_lock:
|
||||
self._latest_pose = pose
|
||||
self._latest_message = message
|
||||
|
||||
def get_action(self) -> dict:
|
||||
ok, raw_pos, raw_rot, pose = self._read_current_pose()
|
||||
if not ok or not self.is_calibrated:
|
||||
return {}
|
||||
|
||||
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
|
||||
raw_inputs: dict[str, float | int | bool] = {}
|
||||
msg = self._latest_message or {}
|
||||
raw_inputs["move"] = bool(msg.get("move", False))
|
||||
raw_inputs["scale"] = float(msg.get("scale", 1.0))
|
||||
raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False))
|
||||
raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False))
|
||||
|
||||
enable = bool(raw_inputs.get("move", False))
|
||||
|
||||
# Rising edge then re-capture calibration immediately from current raw pose
|
||||
if enable and not self._enabled:
|
||||
self._reapply_position_calibration(raw_pos)
|
||||
|
||||
# Apply calibration
|
||||
pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
|
||||
rot_cal = self._calib_rot_inv * raw_rot
|
||||
|
||||
self._enabled = enable
|
||||
|
||||
return {
|
||||
"phone.pos": pos_cal,
|
||||
"phone.rot": rot_cal,
|
||||
"phone.raw_inputs": raw_inputs,
|
||||
"phone.enabled": self._enabled,
|
||||
}
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._teleop = None
|
||||
if self._teleop_thread and self._teleop_thread.is_alive():
|
||||
self._teleop_thread.join(timeout=1.0)
|
||||
self._teleop_thread = None
|
||||
self._latest_pose = None
|
||||
|
||||
|
||||
class Phone(Teleoperator):
|
||||
"""
|
||||
Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API).
|
||||
For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs.
|
||||
|
||||
Press and hold **B1** to enable teleoperation. While enabled, the first B1 press
|
||||
captures a reference pose and rotation, when disabled and pressed again the position is reapplied.
|
||||
"""
|
||||
|
||||
config_class = PhoneConfig
|
||||
name = "phone"
|
||||
|
||||
def __init__(self, config: PhoneConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self._phone_impl: Teleoperator
|
||||
|
||||
if self.config.phone_os == PhoneOS.IOS:
|
||||
self._phone_impl = IOSPhone(config)
|
||||
elif self.config.phone_os == PhoneOS.ANDROID:
|
||||
self._phone_impl = AndroidPhone(config)
|
||||
else:
|
||||
raise ValueError(f"Invalid config phone_os: {self.config.phone_os}")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._phone_impl.is_connected
|
||||
|
||||
def connect(self) -> None:
|
||||
return self._phone_impl.connect()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
return self._phone_impl.calibrate()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self._phone_impl.is_calibrated
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._phone_impl.action_features
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return self._phone_impl.feedback_features
|
||||
|
||||
def configure(self) -> None:
|
||||
return self._phone_impl.configure()
|
||||
|
||||
def get_action(self) -> dict:
|
||||
return self._phone_impl.get_action()
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
return self._phone_impl.send_feedback(feedback)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
return self._phone_impl.disconnect()
|
||||
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
from lerobot.utils.utils import format_big_number
|
||||
|
||||
@@ -84,6 +84,7 @@ class MetricsTracker:
|
||||
"samples",
|
||||
"episodes",
|
||||
"epochs",
|
||||
"accelerator",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -93,12 +94,14 @@ class MetricsTracker:
|
||||
num_episodes: int,
|
||||
metrics: dict[str, AverageMeter],
|
||||
initial_step: int = 0,
|
||||
accelerator: Callable | None = None,
|
||||
):
|
||||
self.__dict__.update(dict.fromkeys(self.__keys__))
|
||||
self._batch_size = batch_size
|
||||
self._num_frames = num_frames
|
||||
self._avg_samples_per_ep = num_frames / num_episodes
|
||||
self.metrics = metrics
|
||||
self.accelerator = accelerator
|
||||
|
||||
self.steps = initial_step
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
@@ -128,7 +131,7 @@ class MetricsTracker:
|
||||
Updates metrics that depend on 'step' for one step.
|
||||
"""
|
||||
self.steps += 1
|
||||
self.samples += self._batch_size
|
||||
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import random
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Callable, Generator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -164,7 +164,7 @@ def set_rng_state(random_state_dict: dict[str, Any]):
|
||||
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
||||
|
||||
|
||||
def set_seed(seed) -> None:
|
||||
def set_seed(seed: int, accelerator: Callable | None = None) -> None:
|
||||
"""Set seed for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
@@ -172,6 +172,11 @@ def set_seed(seed) -> None:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
if accelerator:
|
||||
from accelerate.utils import set_seed as accelerate_set_seed
|
||||
|
||||
accelerate_set_seed(seed)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||
|
||||
@@ -24,6 +24,7 @@ import time
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
@@ -56,13 +57,15 @@ def auto_select_torch_device() -> torch.device:
|
||||
|
||||
|
||||
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
||||
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||
def get_safe_torch_device(
|
||||
try_device: str, log: bool = False, accelerator: Callable | None = None
|
||||
) -> torch.device:
|
||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||
try_device = str(try_device)
|
||||
match try_device:
|
||||
case "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
device = torch.device("cuda")
|
||||
device = accelerator.device if accelerator else torch.device("cuda")
|
||||
case "mps":
|
||||
assert torch.backends.mps.is_available()
|
||||
device = torch.device("mps")
|
||||
@@ -116,6 +119,7 @@ def init_logging(
|
||||
display_pid: bool = False,
|
||||
console_level: str = "INFO",
|
||||
file_level: str = "DEBUG",
|
||||
accelerator: Callable | None = None,
|
||||
):
|
||||
def custom_format(record: logging.LogRecord) -> str:
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -152,6 +156,11 @@ def init_logging(
|
||||
file_handler.setLevel(file_level.upper())
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
if accelerator is not None and not accelerator.is_main_process:
|
||||
# Disable duplicate logging on non-main processes
|
||||
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def format_big_number(num, precision=0):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
@@ -165,6 +174,10 @@ def format_big_number(num, precision=0):
|
||||
return num
|
||||
|
||||
|
||||
def is_launched_with_accelerate() -> bool:
|
||||
return "ACCELERATE_MIXED_PRECISION" in os.environ
|
||||
|
||||
|
||||
def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
"""Returns path1 relative to path2."""
|
||||
path1 = path1.absolute()
|
||||
|
||||
@@ -23,7 +23,7 @@ from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors
|
||||
from lerobot.policies.factory import make_policy, make_policy_config, make_processor
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
|
||||
@@ -40,7 +40,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
dataset = make_dataset(train_cfg)
|
||||
dataset_stats = dataset.meta.stats
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor, postprocessor = make_pre_post_processors(train_cfg.policy, dataset_stats=dataset_stats)
|
||||
preprocessor, postprocessor = make_processor(train_cfg.policy, dataset_stats=dataset_stats)
|
||||
policy.train()
|
||||
|
||||
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
|
||||
|
||||
@@ -39,7 +39,7 @@ from lerobot.policies.factory import (
|
||||
get_policy_class,
|
||||
make_policy,
|
||||
make_policy_config,
|
||||
make_pre_post_processors,
|
||||
make_processor,
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
@@ -151,7 +151,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(train_cfg)
|
||||
preprocessor, _ = make_pre_post_processors(train_cfg.policy, None)
|
||||
preprocessor, _ = make_processor(train_cfg.policy, None)
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
|
||||
assert isinstance(policy, PreTrainedPolicy)
|
||||
|
||||
@@ -225,7 +225,7 @@ def test_act_backbone_lr():
|
||||
assert cfg.policy.optimizer_lr_backbone == 0.001
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
preprocessor, _ = make_pre_post_processors(cfg.policy, None)
|
||||
preprocessor, _ = make_processor(cfg.policy, None)
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
assert len(optimizer.param_groups) == 2
|
||||
|
||||
@@ -23,7 +23,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.processor_act import make_act_pre_post_processors
|
||||
from lerobot.policies.act.processor_act import make_act_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
@@ -78,7 +78,7 @@ def test_make_act_processor_basic():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
@@ -102,12 +102,7 @@ def test_act_processor_normalization():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
@@ -136,12 +131,7 @@ def test_act_processor_cuda():
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
@@ -170,12 +160,7 @@ def test_act_processor_accelerate_scenario():
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
@@ -198,7 +183,7 @@ def test_act_processor_multi_gpu():
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU (like in multi-GPU training)
|
||||
device = torch.device("cuda:1")
|
||||
@@ -218,7 +203,7 @@ def test_act_processor_without_stats():
|
||||
"""Test ACT processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(config, dataset_stats=None)
|
||||
preprocessor, postprocessor = make_act_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors, but normalization won't have stats
|
||||
assert preprocessor is not None
|
||||
@@ -238,21 +223,14 @@ def test_act_processor_save_and_load():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
@@ -271,12 +249,7 @@ def test_act_processor_device_placement_preservation():
|
||||
|
||||
# Test with CPU config
|
||||
config.device = "cpu"
|
||||
preprocessor, _ = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, _ = make_act_processor(config, stats)
|
||||
|
||||
# Process CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
@@ -296,21 +269,12 @@ def test_act_processor_mixed_precision():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Modify the device processor to use float16
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
if isinstance(step, DeviceProcessor):
|
||||
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
|
||||
@@ -330,12 +294,7 @@ def test_act_processor_batch_consistency():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Test single sample (unbatched)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
|
||||
@@ -245,7 +245,7 @@ def test_mixed_observation():
|
||||
def test_integration_with_robot_processor():
|
||||
"""Test ToBatchProcessor integration with RobotProcessor."""
|
||||
to_batch_processor = ToBatchProcessor()
|
||||
pipeline = RobotProcessor([to_batch_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
pipeline = RobotProcessor([to_batch_processor])
|
||||
|
||||
# Create unbatched observation
|
||||
observation = {
|
||||
@@ -285,9 +285,7 @@ def test_serialization_methods():
|
||||
def test_save_and_load_pretrained():
|
||||
"""Test saving and loading ToBatchProcessor with RobotProcessor."""
|
||||
processor = ToBatchProcessor()
|
||||
pipeline = RobotProcessor(
|
||||
[processor], name="BatchPipeline", to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
pipeline = RobotProcessor([processor], name="BatchPipeline")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save pipeline
|
||||
@@ -298,9 +296,7 @@ def test_save_and_load_pretrained():
|
||||
assert config_path.exists()
|
||||
|
||||
# Load pipeline
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
|
||||
assert loaded_pipeline.name == "BatchPipeline"
|
||||
assert len(loaded_pipeline) == 1
|
||||
@@ -327,13 +323,11 @@ def test_registry_functionality():
|
||||
def test_registry_based_save_load():
|
||||
"""Test saving and loading using registry name."""
|
||||
processor = ToBatchProcessor()
|
||||
pipeline = RobotProcessor([processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
pipeline = RobotProcessor([processor])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
|
||||
# Verify the loaded processor works
|
||||
observation = {
|
||||
@@ -609,6 +603,24 @@ def test_action_dtype_preservation():
|
||||
assert result[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
|
||||
def test_action_in_place_mutation():
|
||||
"""Test that the processor mutates the transition in place for actions."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(action=action)
|
||||
|
||||
# Store reference to original transition
|
||||
original_transition = transition
|
||||
|
||||
# Process
|
||||
result = processor(transition)
|
||||
|
||||
# Should be the same object (in-place mutation)
|
||||
assert result is original_transition
|
||||
assert result[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
|
||||
def test_empty_action_tensor():
|
||||
"""Test handling of empty action tensors."""
|
||||
processor = ToBatchProcessor()
|
||||
@@ -839,6 +851,27 @@ def test_task_comprehensive_string_cases():
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert processed_comp_data["task"] == task_list
|
||||
assert isinstance(processed_comp_data["task"], list)
|
||||
assert processed_comp_data["task"] is task_list # Should be same object (in-place)
|
||||
|
||||
|
||||
def test_task_in_place_mutation():
|
||||
"""Test that the processor mutates complementary_data in place for tasks."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
complementary_data = {"task": "sort_objects"}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
|
||||
# Store reference to original transition and complementary_data
|
||||
original_transition = transition
|
||||
original_comp_data = complementary_data
|
||||
|
||||
# Process
|
||||
result = processor(transition)
|
||||
|
||||
# Should be the same objects (in-place mutation)
|
||||
assert result is original_transition
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data
|
||||
assert original_comp_data["task"] == ["sort_objects"]
|
||||
|
||||
|
||||
def test_task_preserves_other_keys():
|
||||
@@ -1094,49 +1127,3 @@ def test_empty_index_tensor():
|
||||
|
||||
# Should remain unchanged (already 1D)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["index"].shape == (0,)
|
||||
|
||||
|
||||
def test_action_processing_creates_new_transition():
|
||||
"""Test that the processor creates a new transition object with correctly processed action."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(action=action)
|
||||
|
||||
# Store reference to original transition
|
||||
original_transition = transition
|
||||
|
||||
# Process
|
||||
result = processor(transition)
|
||||
|
||||
# Should be a different object (functional design, not in-place mutation)
|
||||
assert result is not original_transition
|
||||
# Original transition should remain unchanged
|
||||
assert original_transition[TransitionKey.ACTION].shape == (4,)
|
||||
# Result should have correctly processed action with batch dimension
|
||||
assert result[TransitionKey.ACTION].shape == (1, 4)
|
||||
assert torch.equal(result[TransitionKey.ACTION][0], action)
|
||||
|
||||
|
||||
def test_task_processing_creates_new_transition():
|
||||
"""Test that the processor creates a new transition object with correctly processed task."""
|
||||
processor = ToBatchProcessor()
|
||||
|
||||
complementary_data = {"task": "sort_objects"}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
|
||||
# Store reference to original transition and complementary_data
|
||||
original_transition = transition
|
||||
original_comp_data = complementary_data
|
||||
|
||||
# Process
|
||||
result = processor(transition)
|
||||
|
||||
# Should be different transition object (functional design)
|
||||
assert result is not original_transition
|
||||
# But complementary_data is the same reference (current implementation behavior)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data
|
||||
# The task should be processed correctly (wrapped in list)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["sort_objects"]
|
||||
# Original complementary data is also modified (current behavior)
|
||||
assert original_comp_data["task"] == ["sort_objects"]
|
||||
|
||||
@@ -97,12 +97,7 @@ def test_classifier_processor_normalization():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
@@ -128,12 +123,7 @@ def test_classifier_processor_cuda():
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {
|
||||
@@ -166,12 +156,7 @@ def test_classifier_processor_accelerate_scenario():
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
@@ -245,22 +230,14 @@ def test_classifier_processor_save_and_load():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
# Get the steps from the factory function
|
||||
factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(
|
||||
factory_preprocessor.steps, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {
|
||||
@@ -283,19 +260,13 @@ def test_classifier_processor_mixed_precision():
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Get the steps from the factory function
|
||||
factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats)
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
modified_steps = []
|
||||
for step in factory_preprocessor.steps:
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
if isinstance(step, DeviceProcessor):
|
||||
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
|
||||
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
@@ -319,12 +290,7 @@ def test_classifier_processor_batch_data():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Test with batched data
|
||||
batch_size = 16
|
||||
@@ -349,12 +315,7 @@ def test_classifier_processor_postprocessor_identity():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Create test data for postprocessor
|
||||
reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions
|
||||
|
||||
@@ -5,7 +5,6 @@ import torch
|
||||
from lerobot.processor.converters import (
|
||||
to_dataset_frame,
|
||||
to_output_robot_action,
|
||||
to_tensor,
|
||||
to_transition_robot_observation,
|
||||
to_transition_teleop_action,
|
||||
)
|
||||
@@ -13,12 +12,12 @@ from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def test_to_transition_teleop_action_prefix_and_tensor_conversion():
|
||||
# Scalars, arrays, and uint8 arrays are all converted to tensors
|
||||
# Scalars, arrays, and "image-like" uint8 arrays are supported
|
||||
img = np.zeros((8, 12, 3), dtype=np.uint8)
|
||||
act = {
|
||||
"ee.x": 0.5, # scalar to torch tensor
|
||||
"delta": np.array([1.0, 2.0]), # ndarray to torch tensor
|
||||
"raw_img": img, # uint8 HWC to torch tensor
|
||||
"raw_img": img, # uint8 HWC to passthrough ndarray
|
||||
}
|
||||
|
||||
tr = to_transition_teleop_action(act)
|
||||
@@ -30,7 +29,7 @@ def test_to_transition_teleop_action_prefix_and_tensor_conversion():
|
||||
assert "action.delta" in tr[TransitionKey.ACTION]
|
||||
assert "action.raw_img" in tr[TransitionKey.ACTION]
|
||||
|
||||
# Types: all values -> torch tensor
|
||||
# Types: scalars/arrays -> torch tensor; images to np.ndarray
|
||||
assert isinstance(tr[TransitionKey.ACTION]["action.ee.x"], torch.Tensor)
|
||||
assert tr[TransitionKey.ACTION]["action.ee.x"].item() == pytest.approx(0.5)
|
||||
|
||||
@@ -38,8 +37,8 @@ def test_to_transition_teleop_action_prefix_and_tensor_conversion():
|
||||
assert tr[TransitionKey.ACTION]["action.delta"].shape == (2,)
|
||||
assert torch.allclose(tr[TransitionKey.ACTION]["action.delta"], torch.tensor([1.0, 2.0]))
|
||||
|
||||
assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], torch.Tensor)
|
||||
assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == torch.float32 # converted from uint8
|
||||
assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], np.ndarray)
|
||||
assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == np.uint8
|
||||
assert tr[TransitionKey.ACTION]["action.raw_img"].shape == (8, 12, 3)
|
||||
|
||||
# Observation is created as empty dict by make_transition
|
||||
@@ -195,185 +194,3 @@ def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
|
||||
# Complementary data
|
||||
assert batch["frame_is_pad"] is True
|
||||
assert batch["task"] == "Pick cube"
|
||||
|
||||
|
||||
# Tests for the unified to_tensor function
|
||||
def test_to_tensor_numpy_arrays():
|
||||
"""Test to_tensor with various numpy arrays."""
|
||||
# Regular numpy array
|
||||
arr = np.array([1.0, 2.0, 3.0])
|
||||
result = to_tensor(arr)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.dtype == torch.float32
|
||||
assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
# Different numpy dtypes should convert to float32 by default
|
||||
int_arr = np.array([1, 2, 3], dtype=np.int64)
|
||||
result = to_tensor(int_arr)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.dtype == torch.float32
|
||||
assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
# uint8 arrays (previously "preserved") should now convert
|
||||
uint8_arr = np.array([100, 150, 200], dtype=np.uint8)
|
||||
result = to_tensor(uint8_arr)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.dtype == torch.float32
|
||||
assert torch.allclose(result, torch.tensor([100.0, 150.0, 200.0]))
|
||||
|
||||
|
||||
def test_to_tensor_numpy_scalars():
|
||||
"""Test to_tensor with numpy scalars (0-dimensional arrays)."""
|
||||
# numpy float32 scalar
|
||||
scalar = np.float32(3.14)
|
||||
result = to_tensor(scalar)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.ndim == 0 # Should be 0-dimensional tensor
|
||||
assert result.dtype == torch.float32
|
||||
assert result.item() == pytest.approx(3.14)
|
||||
|
||||
# numpy int32 scalar
|
||||
int_scalar = np.int32(42)
|
||||
result = to_tensor(int_scalar)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.ndim == 0
|
||||
assert result.dtype == torch.float32
|
||||
assert result.item() == pytest.approx(42.0)
|
||||
|
||||
|
||||
def test_to_tensor_python_scalars():
|
||||
"""Test to_tensor with Python scalars."""
|
||||
# Python int
|
||||
result = to_tensor(42)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.dtype == torch.float32
|
||||
assert result.item() == pytest.approx(42.0)
|
||||
|
||||
# Python float
|
||||
result = to_tensor(3.14)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.dtype == torch.float32
|
||||
assert result.item() == pytest.approx(3.14)
|
||||
|
||||
|
||||
def test_to_tensor_sequences():
|
||||
"""Test to_tensor with lists and tuples."""
|
||||
# List
|
||||
result = to_tensor([1, 2, 3])
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.dtype == torch.float32
|
||||
assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
# Tuple
|
||||
result = to_tensor((4.5, 5.5, 6.5))
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.dtype == torch.float32
|
||||
assert torch.allclose(result, torch.tensor([4.5, 5.5, 6.5]))
|
||||
|
||||
|
||||
def test_to_tensor_existing_tensors():
|
||||
"""Test to_tensor with existing PyTorch tensors."""
|
||||
# Tensor with same dtype should pass through with potential device change
|
||||
tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
|
||||
result = to_tensor(tensor)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.dtype == torch.float32
|
||||
assert torch.allclose(result, tensor)
|
||||
|
||||
# Tensor with different dtype should convert
|
||||
int_tensor = torch.tensor([1, 2, 3], dtype=torch.int64)
|
||||
result = to_tensor(int_tensor)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.dtype == torch.float32
|
||||
assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
|
||||
def test_to_tensor_dictionaries():
|
||||
"""Test to_tensor with nested dictionaries."""
|
||||
# Simple dictionary
|
||||
data = {"mean": [0.1, 0.2], "std": np.array([1.0, 2.0]), "count": 42}
|
||||
result = to_tensor(data)
|
||||
assert isinstance(result, dict)
|
||||
assert isinstance(result["mean"], torch.Tensor)
|
||||
assert isinstance(result["std"], torch.Tensor)
|
||||
assert isinstance(result["count"], torch.Tensor)
|
||||
assert torch.allclose(result["mean"], torch.tensor([0.1, 0.2]))
|
||||
assert torch.allclose(result["std"], torch.tensor([1.0, 2.0]))
|
||||
assert result["count"].item() == pytest.approx(42.0)
|
||||
|
||||
# Nested dictionary
|
||||
nested = {
|
||||
"action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
|
||||
"observation": {"mean": np.array([0.5, 0.6]), "count": 10},
|
||||
}
|
||||
result = to_tensor(nested)
|
||||
assert isinstance(result, dict)
|
||||
assert isinstance(result["action"], dict)
|
||||
assert isinstance(result["observation"], dict)
|
||||
assert isinstance(result["action"]["mean"], torch.Tensor)
|
||||
assert isinstance(result["observation"]["mean"], torch.Tensor)
|
||||
assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2]))
|
||||
assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6]))
|
||||
|
||||
|
||||
def test_to_tensor_none_filtering():
|
||||
"""Test that None values are filtered out from dictionaries."""
|
||||
data = {"valid": [1, 2, 3], "none_value": None, "nested": {"valid": [4, 5], "also_none": None}}
|
||||
result = to_tensor(data)
|
||||
assert "none_value" not in result
|
||||
assert "also_none" not in result["nested"]
|
||||
assert "valid" in result
|
||||
assert "valid" in result["nested"]
|
||||
assert torch.allclose(result["valid"], torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
|
||||
def test_to_tensor_dtype_parameter():
|
||||
"""Test to_tensor with different dtype parameters."""
|
||||
arr = np.array([1, 2, 3])
|
||||
|
||||
# Default dtype (float32)
|
||||
result = to_tensor(arr)
|
||||
assert result.dtype == torch.float32
|
||||
|
||||
# Explicit float32
|
||||
result = to_tensor(arr, dtype=torch.float32)
|
||||
assert result.dtype == torch.float32
|
||||
|
||||
# Float64
|
||||
result = to_tensor(arr, dtype=torch.float64)
|
||||
assert result.dtype == torch.float64
|
||||
|
||||
# Preserve original dtype
|
||||
float64_arr = np.array([1.0, 2.0, 3.0], dtype=np.float64)
|
||||
result = to_tensor(float64_arr, dtype=None)
|
||||
assert result.dtype == torch.float64
|
||||
|
||||
|
||||
def test_to_tensor_device_parameter():
|
||||
"""Test to_tensor with device parameter."""
|
||||
arr = np.array([1.0, 2.0, 3.0])
|
||||
|
||||
# CPU device (default)
|
||||
result = to_tensor(arr, device="cpu")
|
||||
assert result.device.type == "cpu"
|
||||
|
||||
# CUDA device (if available)
|
||||
if torch.cuda.is_available():
|
||||
result = to_tensor(arr, device="cuda")
|
||||
assert result.device.type == "cuda"
|
||||
|
||||
|
||||
def test_to_tensor_empty_dict():
|
||||
"""Test to_tensor with empty dictionary."""
|
||||
result = to_tensor({})
|
||||
assert isinstance(result, dict)
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
def test_to_tensor_unsupported_type():
|
||||
"""Test to_tensor with unsupported types raises TypeError."""
|
||||
with pytest.raises(TypeError, match="Unsupported type for tensor conversion"):
|
||||
to_tensor("unsupported_string")
|
||||
|
||||
with pytest.raises(TypeError, match="Unsupported type for tensor conversion"):
|
||||
to_tensor(object())
|
||||
|
||||
@@ -311,12 +311,7 @@ def test_integration_with_robot_processor():
|
||||
device_processor = DeviceProcessor(device="cpu")
|
||||
batch_processor = ToBatchProcessor()
|
||||
|
||||
processor = RobotProcessor(
|
||||
steps=[batch_processor, device_processor],
|
||||
name="test_pipeline",
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
processor = RobotProcessor(steps=[batch_processor, device_processor], name="test_pipeline")
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
@@ -990,8 +985,6 @@ def test_policy_processor_integration():
|
||||
DeviceProcessor(device="cuda"),
|
||||
],
|
||||
name="test_preprocessor",
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
|
||||
# Create output processor (postprocessor) that moves to CPU
|
||||
@@ -1001,8 +994,6 @@ def test_policy_processor_integration():
|
||||
UnnormalizerProcessor(features={ACTION: features[ACTION]}, norm_map=norm_map, stats=stats),
|
||||
],
|
||||
name="test_postprocessor",
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
|
||||
# Test data on CPU
|
||||
|
||||
@@ -23,7 +23,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
|
||||
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
@@ -81,7 +81,7 @@ def test_make_diffusion_processor_basic():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
@@ -105,12 +105,7 @@ def test_diffusion_processor_with_images():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Create test data with images
|
||||
observation = {
|
||||
@@ -136,12 +131,7 @@ def test_diffusion_processor_cuda():
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {
|
||||
@@ -174,12 +164,7 @@ def test_diffusion_processor_accelerate_scenario():
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
@@ -206,7 +191,7 @@ def test_diffusion_processor_multi_gpu():
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
@@ -230,7 +215,7 @@ def test_diffusion_processor_without_stats():
|
||||
"""Test Diffusion processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, dataset_stats=None)
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
@@ -253,22 +238,14 @@ def test_diffusion_processor_save_and_load():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
# Get the steps from the factory function
|
||||
factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(
|
||||
factory_preprocessor.steps, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {
|
||||
@@ -291,19 +268,13 @@ def test_diffusion_processor_mixed_precision():
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Get the steps from the factory function
|
||||
factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats)
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
modified_steps = []
|
||||
for step in factory_preprocessor.steps:
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
if isinstance(step, DeviceProcessor):
|
||||
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
|
||||
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
@@ -327,12 +298,7 @@ def test_diffusion_processor_identity_normalization():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Create test data
|
||||
image_value = torch.rand(3, 224, 224) * 255 # Large values
|
||||
@@ -356,12 +322,7 @@ def test_diffusion_processor_batch_consistency():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Test with different batch sizes
|
||||
for batch_size in [1, 8, 32]:
|
||||
|
||||
@@ -20,11 +20,12 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.processor.normalize_processor import (
|
||||
NormalizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
_convert_stats_to_tensors,
|
||||
hotswap_stats,
|
||||
rename_stats,
|
||||
)
|
||||
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
|
||||
|
||||
@@ -51,7 +52,7 @@ def test_numpy_conversion():
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
}
|
||||
}
|
||||
tensor_stats = to_tensor(stats)
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
|
||||
assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor)
|
||||
assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor)
|
||||
@@ -66,7 +67,7 @@ def test_tensor_conversion():
|
||||
"std": torch.tensor([1.0, 1.0]),
|
||||
}
|
||||
}
|
||||
tensor_stats = to_tensor(stats)
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
|
||||
assert tensor_stats["action"]["mean"].dtype == torch.float32
|
||||
assert tensor_stats["action"]["std"].dtype == torch.float32
|
||||
@@ -79,7 +80,7 @@ def test_scalar_conversion():
|
||||
"std": 0.1,
|
||||
}
|
||||
}
|
||||
tensor_stats = to_tensor(stats)
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
|
||||
assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5))
|
||||
assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1))
|
||||
@@ -92,7 +93,7 @@ def test_list_conversion():
|
||||
"max": [1.0, 1.0, 2.0],
|
||||
}
|
||||
}
|
||||
tensor_stats = to_tensor(stats)
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
|
||||
assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0]))
|
||||
assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0]))
|
||||
@@ -105,7 +106,7 @@ def test_unsupported_type():
|
||||
}
|
||||
}
|
||||
with pytest.raises(TypeError, match="Unsupported type"):
|
||||
to_tensor(stats)
|
||||
_convert_stats_to_tensors(stats)
|
||||
|
||||
|
||||
# Helper functions to create feature maps and norm maps
|
||||
@@ -181,10 +182,7 @@ def test_selective_normalization(observation_stats):
|
||||
features = _create_observation_features()
|
||||
norm_map = _create_observation_norm_map()
|
||||
normalizer = NormalizerProcessor(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=observation_stats,
|
||||
normalize_observation_keys={"observation.image"},
|
||||
features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"}
|
||||
)
|
||||
|
||||
observation = {
|
||||
@@ -245,7 +243,6 @@ def test_from_lerobot_dataset():
|
||||
def test_state_dict_save_load(observation_normalizer):
|
||||
# Save state
|
||||
state_dict = observation_normalizer.state_dict()
|
||||
print("State dict:", state_dict)
|
||||
|
||||
# Create new normalizer and load state
|
||||
features = _create_observation_features()
|
||||
@@ -467,10 +464,10 @@ def test_processor_from_lerobot_dataset(full_stats):
|
||||
norm_map = _create_full_norm_map()
|
||||
|
||||
processor = NormalizerProcessor.from_lerobot_dataset(
|
||||
mock_dataset, features, norm_map, normalize_observation_keys={"observation.image"}
|
||||
mock_dataset, features, norm_map, normalize_keys={"observation.image"}
|
||||
)
|
||||
|
||||
assert processor.normalize_observation_keys == {"observation.image"}
|
||||
assert processor.normalize_keys == {"observation.image"}
|
||||
assert "observation.image" in processor._tensor_stats
|
||||
assert "action" in processor._tensor_stats
|
||||
|
||||
@@ -479,16 +476,12 @@ def test_get_config(full_stats):
|
||||
features = _create_full_features()
|
||||
norm_map = _create_full_norm_map()
|
||||
processor = NormalizerProcessor(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=full_stats,
|
||||
normalize_observation_keys={"observation.image"},
|
||||
eps=1e-6,
|
||||
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
|
||||
)
|
||||
|
||||
config = processor.get_config()
|
||||
expected_config = {
|
||||
"normalize_observation_keys": ["observation.image"],
|
||||
"normalize_keys": ["observation.image"],
|
||||
"eps": 1e-6,
|
||||
"features": {
|
||||
"observation.image": {"type": "VISUAL", "shape": (3, 96, 96)},
|
||||
@@ -506,7 +499,7 @@ def test_get_config(full_stats):
|
||||
|
||||
def test_integration_with_robot_processor(normalizer_processor):
|
||||
"""Test integration with RobotProcessor pipeline"""
|
||||
robot_processor = RobotProcessor([normalizer_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
robot_processor = RobotProcessor([normalizer_processor])
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
@@ -587,11 +580,7 @@ def test_serialization_roundtrip(full_stats):
|
||||
features = _create_full_features()
|
||||
norm_map = _create_full_norm_map()
|
||||
original_processor = NormalizerProcessor(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=full_stats,
|
||||
normalize_observation_keys={"observation.image"},
|
||||
eps=1e-6,
|
||||
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
|
||||
)
|
||||
|
||||
# Get config (serialization)
|
||||
@@ -602,7 +591,7 @@ def test_serialization_roundtrip(full_stats):
|
||||
features=config["features"],
|
||||
norm_map=config["norm_map"],
|
||||
stats=full_stats,
|
||||
normalize_observation_keys=set(config["normalize_observation_keys"]),
|
||||
normalize_keys=set(config["normalize_keys"]),
|
||||
eps=config["eps"],
|
||||
)
|
||||
|
||||
@@ -950,31 +939,31 @@ def test_identity_config_serialization():
|
||||
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
|
||||
|
||||
|
||||
# def test_unsupported_normalization_mode_error():
|
||||
# """Test that unsupported normalization modes raise appropriate errors."""
|
||||
# features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))}
|
||||
def test_unsupported_normalization_mode_error():
|
||||
"""Test that unsupported normalization modes raise appropriate errors."""
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))}
|
||||
|
||||
# # Create an invalid norm_map (this would never happen in practice, but tests error handling)
|
||||
# from enum import Enum
|
||||
# Create an invalid norm_map (this would never happen in practice, but tests error handling)
|
||||
from enum import Enum
|
||||
|
||||
# class InvalidMode(str, Enum):
|
||||
# INVALID = "INVALID"
|
||||
class InvalidMode(str, Enum):
|
||||
INVALID = "INVALID"
|
||||
|
||||
# # We can't actually pass an invalid enum to the processor due to type checking,
|
||||
# # but we can test the error by manipulating the norm_map after creation
|
||||
# norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
# stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}
|
||||
# We can't actually pass an invalid enum to the processor due to type checking,
|
||||
# but we can test the error by manipulating the norm_map after creation
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}
|
||||
|
||||
# normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# # Manually inject an invalid mode to test error handling
|
||||
# normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE"
|
||||
# Manually inject an invalid mode to test error handling
|
||||
normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE"
|
||||
|
||||
# observation = {"observation.state": torch.tensor([1.0, -0.5])}
|
||||
# transition = create_transition(observation=observation)
|
||||
observation = {"observation.state": torch.tensor([1.0, -0.5])}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
# with pytest.raises(ValueError, match="Unsupported normalization mode"):
|
||||
# normalizer(transition)
|
||||
with pytest.raises(ValueError, match="Unsupported normalization mode"):
|
||||
normalizer(transition)
|
||||
|
||||
|
||||
def test_hotswap_stats_basic_functionality():
|
||||
@@ -1017,7 +1006,7 @@ def test_hotswap_stats_basic_functionality():
|
||||
assert new_processor.steps[1].stats == new_stats
|
||||
|
||||
# Check that tensor stats are updated correctly
|
||||
expected_tensor_stats = to_tensor(new_stats)
|
||||
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
|
||||
for key in expected_tensor_stats:
|
||||
for stat_name in expected_tensor_stats[key]:
|
||||
torch.testing.assert_close(
|
||||
@@ -1160,15 +1149,11 @@ def test_hotswap_stats_preserves_other_attributes():
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
normalize_observation_keys = {"observation.image"}
|
||||
normalize_keys = {"observation.image"}
|
||||
eps = 1e-6
|
||||
|
||||
normalizer = NormalizerProcessor(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=initial_stats,
|
||||
normalize_observation_keys=normalize_observation_keys,
|
||||
eps=eps,
|
||||
features=features, norm_map=norm_map, stats=initial_stats, normalize_keys=normalize_keys, eps=eps
|
||||
)
|
||||
robot_processor = RobotProcessor(steps=[normalizer])
|
||||
|
||||
@@ -1179,7 +1164,7 @@ def test_hotswap_stats_preserves_other_attributes():
|
||||
new_normalizer = new_processor.steps[0]
|
||||
assert new_normalizer.features == features
|
||||
assert new_normalizer.norm_map == norm_map
|
||||
assert new_normalizer.normalize_observation_keys == normalize_observation_keys
|
||||
assert new_normalizer.normalize_keys == normalize_keys
|
||||
assert new_normalizer.eps == eps
|
||||
|
||||
# But stats should be updated
|
||||
@@ -1223,7 +1208,7 @@ def test_hotswap_stats_multiple_normalizer_types():
|
||||
assert step.stats == new_stats
|
||||
|
||||
# Check tensor stats conversion
|
||||
expected_tensor_stats = to_tensor(new_stats)
|
||||
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
|
||||
for key in expected_tensor_stats:
|
||||
for stat_name in expected_tensor_stats[key]:
|
||||
torch.testing.assert_close(
|
||||
@@ -1285,6 +1270,273 @@ def test_hotswap_stats_with_different_data_types():
|
||||
torch.testing.assert_close(tensor_stats["observation.image"]["max"], torch.tensor(1.0))
|
||||
|
||||
|
||||
def test_normalization_info_tracking():
|
||||
"""Test that normalization info is tracked in complementary_data."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||
FeatureType.ACTION: NormalizationMode.IDENTITY,
|
||||
}
|
||||
|
||||
stats = {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
},
|
||||
"observation.state": {
|
||||
"min": np.array([0.0, -1.0]),
|
||||
"max": np.array([1.0, 1.0]),
|
||||
},
|
||||
"action": {
|
||||
"mean": np.array([0.0, 0.0]),
|
||||
"std": np.array([1.0, 1.0]),
|
||||
},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
# Process the transition
|
||||
normalized_transition = normalizer(transition)
|
||||
|
||||
# Check that normalization info is added
|
||||
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
assert comp_data is not None
|
||||
assert "normalized_keys" in comp_data
|
||||
|
||||
norm_info = comp_data["normalized_keys"]
|
||||
assert norm_info["observation.image"] == "MEAN_STD"
|
||||
assert norm_info["observation.state"] == "MIN_MAX"
|
||||
assert norm_info["action"] == "IDENTITY"
|
||||
|
||||
|
||||
def test_unnormalization_info_tracking():
|
||||
"""Test that unnormalization info is tracked in complementary_data."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
stats = {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
},
|
||||
"action": {
|
||||
"min": np.array([-1.0, -1.0]),
|
||||
"max": np.array([1.0, 1.0]),
|
||||
},
|
||||
}
|
||||
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
||||
action = torch.tensor([0.0, -0.5])
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
# Process the transition
|
||||
unnormalized_transition = unnormalizer(transition)
|
||||
|
||||
# Check that unnormalization info is added
|
||||
comp_data = unnormalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
assert comp_data is not None
|
||||
assert "unnormalized_keys" in comp_data
|
||||
|
||||
unnorm_info = comp_data["unnormalized_keys"]
|
||||
assert unnorm_info["observation.image"] == "MEAN_STD"
|
||||
assert unnorm_info["action"] == "MIN_MAX"
|
||||
|
||||
|
||||
def test_normalization_info_with_missing_stats():
|
||||
"""Test normalization info when stats are missing for some keys."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
# Only provide stats for image, not state
|
||||
stats = {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
# Process the transition
|
||||
normalized_transition = normalizer(transition)
|
||||
|
||||
# Check that only keys with stats are in normalization info
|
||||
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
assert comp_data is not None
|
||||
assert "normalized_keys" in comp_data
|
||||
|
||||
norm_info = comp_data["normalized_keys"]
|
||||
assert norm_info["observation.image"] == "MEAN_STD"
|
||||
# State should not be in the normalization info since it has no stats
|
||||
assert "observation.state" not in norm_info
|
||||
|
||||
|
||||
def test_normalization_info_with_selective_keys():
|
||||
"""Test normalization info with selective normalization."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
stats = {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
},
|
||||
"observation.state": {
|
||||
"min": np.array([0.0, -1.0]),
|
||||
"max": np.array([1.0, 1.0]),
|
||||
},
|
||||
}
|
||||
|
||||
# Only normalize image
|
||||
normalizer = NormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.image"}
|
||||
)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
# Process the transition
|
||||
normalized_transition = normalizer(transition)
|
||||
|
||||
# Check that only selected keys are in normalization info
|
||||
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
assert comp_data is not None
|
||||
assert "normalized_keys" in comp_data
|
||||
|
||||
norm_info = comp_data["normalized_keys"]
|
||||
assert norm_info["observation.image"] == "MEAN_STD"
|
||||
# State should not be in the normalization info since it wasn't in normalize_keys
|
||||
assert "observation.state" not in norm_info
|
||||
|
||||
|
||||
def test_normalization_info_preserved_in_pipeline():
|
||||
"""Test that normalization info is preserved when using RobotProcessor pipeline."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
stats = {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
},
|
||||
"action": {
|
||||
"min": np.array([-1.0, -1.0]),
|
||||
"max": np.array([1.0, 1.0]),
|
||||
},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# Create pipeline
|
||||
pipeline = RobotProcessor([normalizer, unnormalizer])
|
||||
|
||||
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
||||
action = torch.tensor([0.5, -0.5])
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
# Process through pipeline
|
||||
result = pipeline(transition)
|
||||
|
||||
# Check that both normalization and unnormalization info are present
|
||||
comp_data = result.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
assert comp_data is not None
|
||||
assert "normalized_keys" in comp_data
|
||||
assert "unnormalized_keys" in comp_data
|
||||
|
||||
# Check normalization info
|
||||
norm_info = comp_data["normalized_keys"]
|
||||
assert norm_info["observation.image"] == "MEAN_STD"
|
||||
assert norm_info["action"] == "MIN_MAX"
|
||||
|
||||
# Check unnormalization info
|
||||
unnorm_info = comp_data["unnormalized_keys"]
|
||||
assert unnorm_info["observation.image"] == "MEAN_STD"
|
||||
assert unnorm_info["action"] == "MIN_MAX"
|
||||
|
||||
|
||||
def test_normalization_info_empty_transition():
|
||||
"""Test that no normalization info is added for empty transitions."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
stats = {
|
||||
"observation.image": {"mean": [0.5], "std": [0.2]},
|
||||
"action": {"min": [-1.0], "max": [1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# Empty transition
|
||||
transition = create_transition()
|
||||
|
||||
# Process the transition
|
||||
normalized_transition = normalizer(transition)
|
||||
|
||||
# Check that no normalization info is added
|
||||
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
assert comp_data is None or "normalized_keys" not in comp_data
|
||||
|
||||
|
||||
def test_hotswap_stats_functional_test():
|
||||
"""Test that hotswapped processor actually works functionally."""
|
||||
# Create test data
|
||||
@@ -1317,7 +1569,7 @@ def test_hotswap_stats_functional_test():
|
||||
|
||||
# Create original processor
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
original_processor = RobotProcessor(steps=[normalizer], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
original_processor = RobotProcessor(steps=[normalizer])
|
||||
|
||||
# Process with original stats
|
||||
original_result = original_processor(transition)
|
||||
@@ -1379,8 +1631,8 @@ def test_min_equals_max_maps_to_minus_one():
|
||||
assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([-1.0]))
|
||||
|
||||
|
||||
def test_action_normalized_despite_normalize_observation_keys():
|
||||
"""Action normalization is independent of normalize_observation_keys filter for observations."""
|
||||
def test_action_normalized_despite_normalize_keys():
|
||||
"""Action normalization is independent of normalize_keys filter for observations."""
|
||||
features = {
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (1,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
@@ -1388,7 +1640,7 @@ def test_action_normalized_despite_normalize_observation_keys():
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}
|
||||
normalizer = NormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={"observation.state"}
|
||||
features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.state"}
|
||||
)
|
||||
|
||||
transition = create_transition(
|
||||
@@ -1428,6 +1680,19 @@ def test_unnormalize_observations_mean_std_and_min_max():
|
||||
assert torch.allclose(out_mm, torch.tensor([1.0, 0.0])) # mid of [0,2] and [-2,2]
|
||||
|
||||
|
||||
def test_rename_stats_basic():
|
||||
orig = {
|
||||
"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])},
|
||||
"action": {"mean": np.array([0.0])},
|
||||
}
|
||||
mapping = {"observation.state": "observation.robot_state"}
|
||||
renamed = rename_stats(orig, mapping)
|
||||
assert "observation.robot_state" in renamed and "observation.state" not in renamed
|
||||
# Ensure deep copy: mutate original and verify renamed unaffected
|
||||
orig["observation.state"]["mean"][0] = 42.0
|
||||
assert renamed["observation.robot_state"]["mean"][0] != 42.0
|
||||
|
||||
|
||||
def test_unknown_observation_keys_ignored():
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
@@ -1440,6 +1705,8 @@ def test_unknown_observation_keys_ignored():
|
||||
|
||||
# Unknown key should pass through unchanged and not be tracked
|
||||
assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.unknown"], obs["observation.unknown"])
|
||||
comp = out.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
assert "normalized_keys" in comp and "observation.unknown" not in comp["normalized_keys"]
|
||||
|
||||
|
||||
def test_batched_action_normalization():
|
||||
@@ -1464,7 +1731,7 @@ def test_complementary_data_preservation():
|
||||
tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp)
|
||||
out = normalizer(tr)
|
||||
new_comp = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert new_comp["existing"] == 123
|
||||
assert new_comp["existing"] == 123 and "normalized_keys" in new_comp
|
||||
|
||||
|
||||
def test_roundtrip_normalize_unnormalize_non_identity():
|
||||
|
||||
@@ -23,7 +23,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors
|
||||
from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
@@ -84,12 +84,7 @@ def test_make_pi0_processor_basic():
|
||||
stats = create_default_stats()
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_pi0_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
@@ -188,12 +183,7 @@ def test_pi0_processor_cuda():
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_pi0_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {
|
||||
@@ -243,12 +233,7 @@ def test_pi0_processor_accelerate_scenario():
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_pi0_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU and batched
|
||||
device = torch.device("cuda:0")
|
||||
@@ -299,12 +284,7 @@ def test_pi0_processor_multi_gpu():
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_pi0_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
@@ -330,12 +310,7 @@ def test_pi0_processor_without_stats():
|
||||
|
||||
# Mock the tokenizer processor
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
dataset_stats=None,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_pi0_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
|
||||
@@ -176,7 +176,7 @@ class MockStepWithTensorState:
|
||||
|
||||
def test_empty_pipeline():
|
||||
"""Test pipeline with no steps."""
|
||||
pipeline = RobotProcessor([], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
pipeline = RobotProcessor()
|
||||
|
||||
transition = create_transition()
|
||||
result = pipeline(transition)
|
||||
@@ -188,7 +188,7 @@ def test_empty_pipeline():
|
||||
def test_single_step_pipeline():
|
||||
"""Test pipeline with a single step."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotProcessor([step], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
transition = create_transition()
|
||||
result = pipeline(transition)
|
||||
@@ -205,7 +205,7 @@ def test_multiple_steps_pipeline():
|
||||
"""Test pipeline with multiple steps."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotProcessor([step1, step2], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
|
||||
transition = create_transition()
|
||||
result = pipeline(transition)
|
||||
@@ -557,9 +557,7 @@ def test_save_and_load_pretrained():
|
||||
def test_step_without_optional_methods():
|
||||
"""Test pipeline with steps that don't implement optional methods."""
|
||||
step = MockStepWithoutOptionalMethods(multiplier=3.0)
|
||||
pipeline = RobotProcessor(
|
||||
[step], to_transition=lambda x: x, to_output=lambda x: x
|
||||
) # Identity for EnvTransition input/output
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
transition = create_transition(reward=2.0)
|
||||
result = pipeline(transition)
|
||||
@@ -880,9 +878,7 @@ def test_from_pretrained_with_overrides():
|
||||
"registered_mock_step": {"device": "cuda", "value": 200},
|
||||
}
|
||||
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
|
||||
|
||||
# Verify the pipeline was loaded correctly
|
||||
assert len(loaded_pipeline) == 2
|
||||
@@ -918,9 +914,7 @@ def test_from_pretrained_with_partial_overrides():
|
||||
|
||||
# The current implementation applies overrides to ALL steps with the same class name
|
||||
# Both steps will get the override
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
|
||||
|
||||
transition = create_transition(reward=1.0)
|
||||
result = loaded_pipeline(transition)
|
||||
@@ -977,9 +971,7 @@ def test_from_pretrained_registered_step_override():
|
||||
# Override using registry name
|
||||
overrides = {"registered_mock_step": {"value": 999, "device": "cuda"}}
|
||||
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
|
||||
|
||||
# Test that overrides were applied
|
||||
transition = create_transition()
|
||||
@@ -1007,9 +999,7 @@ def test_from_pretrained_mixed_registered_and_unregistered():
|
||||
"registered_mock_step": {"value": 777},
|
||||
}
|
||||
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
|
||||
|
||||
# Test both steps
|
||||
transition = create_transition(reward=2.0)
|
||||
@@ -1030,9 +1020,7 @@ def test_from_pretrained_no_overrides():
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Load without overrides
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
|
||||
assert len(loaded_pipeline) == 1
|
||||
|
||||
@@ -1052,9 +1040,7 @@ def test_from_pretrained_empty_overrides():
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Load with empty overrides
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
tmp_dir, overrides={}, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides={})
|
||||
|
||||
assert len(loaded_pipeline) == 1
|
||||
|
||||
@@ -1469,8 +1455,6 @@ def test_override_with_nested_config():
|
||||
loaded = RobotProcessor.from_pretrained(
|
||||
tmp_dir,
|
||||
overrides={"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}},
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
|
||||
# Test that override worked
|
||||
@@ -1569,10 +1553,7 @@ def test_override_with_callables():
|
||||
|
||||
# Load with callable override
|
||||
loaded = RobotProcessor.from_pretrained(
|
||||
tmp_dir,
|
||||
overrides={"callable_step": {"transform_fn": double_values}},
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
tmp_dir, overrides={"callable_step": {"transform_fn": double_values}}
|
||||
)
|
||||
|
||||
# Test it works
|
||||
@@ -1876,8 +1857,7 @@ def test_save_load_with_custom_converter_functions():
|
||||
|
||||
# Should work with standard format (wouldn't work with custom converter)
|
||||
result = loaded(batch)
|
||||
# With new behavior, default to_output is _default_transition_to_batch, so result is batch dict
|
||||
assert "observation.image" in result
|
||||
assert "observation.image" in result # Standard format preserved
|
||||
|
||||
|
||||
class NonCompliantStep:
|
||||
|
||||
@@ -21,7 +21,6 @@ import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -188,7 +187,7 @@ def test_integration_with_robot_processor():
|
||||
}
|
||||
rename_processor = RenameProcessor(rename_map=rename_map)
|
||||
|
||||
pipeline = RobotProcessor([rename_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
pipeline = RobotProcessor([rename_processor])
|
||||
|
||||
observation = {
|
||||
"agent_pos": np.array([1.0, 2.0, 3.0]),
|
||||
@@ -236,9 +235,7 @@ def test_save_and_load_pretrained():
|
||||
assert len(state_files) == 0
|
||||
|
||||
# Load pipeline
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
|
||||
assert loaded_pipeline.name == "TestRenameProcessor"
|
||||
assert len(loaded_pipeline) == 1
|
||||
@@ -279,7 +276,7 @@ def test_registry_functionality():
|
||||
def test_registry_based_save_load():
|
||||
"""Test save/load using registry name instead of module path."""
|
||||
processor = RenameProcessor(rename_map={"key1": "renamed_key1"})
|
||||
pipeline = RobotProcessor([processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
pipeline = RobotProcessor([processor])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save and load
|
||||
@@ -320,7 +317,7 @@ def test_chained_rename_processors():
|
||||
}
|
||||
)
|
||||
|
||||
pipeline = RobotProcessor([processor1, processor2], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
pipeline = RobotProcessor([processor1, processor2])
|
||||
|
||||
observation = {
|
||||
"pos": np.array([1.0, 2.0]),
|
||||
@@ -468,16 +465,3 @@ def test_features_chained_processors(policy_feature_factory):
|
||||
assert out["observation.image"] == spec["img"]
|
||||
assert out["extra"] == spec["extra"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_rename_stats_basic():
|
||||
orig = {
|
||||
"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])},
|
||||
"action": {"mean": np.array([0.0])},
|
||||
}
|
||||
mapping = {"observation.state": "observation.robot_state"}
|
||||
renamed = rename_stats(orig, mapping)
|
||||
assert "observation.robot_state" in renamed and "observation.state" not in renamed
|
||||
# Ensure deep copy: mutate original and verify renamed unaffected
|
||||
orig["observation.state"]["mean"][0] = 42.0
|
||||
assert renamed["observation.robot_state"]["mean"][0] != 42.0
|
||||
|
||||
@@ -23,7 +23,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
from lerobot.policies.sac.processor_sac import make_sac_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
@@ -78,12 +78,7 @@ def test_make_sac_processor_basic():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
@@ -107,12 +102,7 @@ def test_sac_processor_normalization_modes():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization
|
||||
@@ -143,12 +133,7 @@ def test_sac_processor_cuda():
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
@@ -177,12 +162,7 @@ def test_sac_processor_accelerate_scenario():
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
@@ -205,12 +185,7 @@ def test_sac_processor_multi_gpu():
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
@@ -230,22 +205,7 @@ def test_sac_processor_without_stats():
|
||||
"""Test SAC processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
# Get the steps from the factory function
|
||||
factory_preprocessor, factory_postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(
|
||||
factory_preprocessor.steps,
|
||||
name=factory_preprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
postprocessor = RobotProcessor(
|
||||
factory_postprocessor.steps,
|
||||
name=factory_postprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
preprocessor, postprocessor = make_sac_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
@@ -265,21 +225,14 @@ def test_sac_processor_save_and_load():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
@@ -299,12 +252,7 @@ def test_sac_processor_mixed_precision():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
@@ -329,12 +277,7 @@ def test_sac_processor_batch_data():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Test with batched data
|
||||
batch_size = 32
|
||||
@@ -355,12 +298,7 @@ def test_sac_processor_edge_cases():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Test with empty observation
|
||||
transition = create_transition(observation={}, action=torch.randn(5))
|
||||
|
||||
@@ -23,10 +23,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.processor_smolvla import (
|
||||
SmolVLANewLineProcessor,
|
||||
make_smolvla_pre_post_processors,
|
||||
)
|
||||
from lerobot.policies.smolvla.processor_smolvla import SmolVLANewLineProcessor, make_smolvla_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
@@ -89,12 +86,7 @@ def test_make_smolvla_processor_basic():
|
||||
stats = create_default_stats()
|
||||
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
|
||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_smolvla_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
@@ -193,12 +185,7 @@ def test_smolvla_processor_cuda():
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
|
||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_smolvla_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {
|
||||
@@ -248,12 +235,7 @@ def test_smolvla_processor_accelerate_scenario():
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
|
||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_smolvla_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU and batched
|
||||
device = torch.device("cuda:0")
|
||||
@@ -304,12 +286,7 @@ def test_smolvla_processor_multi_gpu():
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
|
||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_smolvla_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
@@ -335,12 +312,7 @@ def test_smolvla_processor_without_stats():
|
||||
|
||||
# Mock the tokenizer processor
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
|
||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
dataset_stats=None,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_smolvla_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
|
||||
@@ -23,7 +23,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
|
||||
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
@@ -81,12 +81,7 @@ def test_make_tdmpc_processor_basic():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
@@ -110,12 +105,7 @@ def test_tdmpc_processor_normalization():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
@@ -148,12 +138,7 @@ def test_tdmpc_processor_cuda():
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {
|
||||
@@ -186,12 +171,7 @@ def test_tdmpc_processor_accelerate_scenario():
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
@@ -218,12 +198,7 @@ def test_tdmpc_processor_multi_gpu():
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
@@ -247,22 +222,7 @@ def test_tdmpc_processor_without_stats():
|
||||
"""Test TDMPC processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
# Get the steps from the factory function
|
||||
factory_preprocessor, factory_postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(
|
||||
factory_preprocessor.steps,
|
||||
name=factory_preprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
postprocessor = RobotProcessor(
|
||||
factory_postprocessor.steps,
|
||||
name=factory_postprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
@@ -285,21 +245,14 @@ def test_tdmpc_processor_save_and_load():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {
|
||||
@@ -323,12 +276,7 @@ def test_tdmpc_processor_mixed_precision():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
@@ -357,12 +305,7 @@ def test_tdmpc_processor_batch_data():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Test with batched data
|
||||
batch_size = 64
|
||||
@@ -387,12 +330,7 @@ def test_tdmpc_processor_edge_cases():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Test with only state observation (no image)
|
||||
observation = {OBS_STATE: torch.randn(12)}
|
||||
|
||||
@@ -98,11 +98,7 @@ def test_basic_tokenization(mock_auto_tokenizer):
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "pick up the red cube"},
|
||||
)
|
||||
transition = create_transition(complementary_data={"task": "pick up the red cube"})
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -130,11 +126,7 @@ def test_basic_tokenization_with_tokenizer_object():
|
||||
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "pick up the red cube"},
|
||||
)
|
||||
transition = create_transition(complementary_data={"task": "pick up the red cube"})
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -164,11 +156,7 @@ def test_list_of_strings_tokenization(mock_auto_tokenizer):
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": ["pick up cube", "place on table"]},
|
||||
)
|
||||
transition = create_transition(complementary_data={"task": ["pick up cube", "place on table"]})
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -192,11 +180,7 @@ def test_custom_keys(mock_auto_tokenizer):
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", task_key="instruction", max_length=5)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"instruction": "move forward"},
|
||||
)
|
||||
transition = create_transition(complementary_data={"instruction": "move forward"})
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -389,7 +373,7 @@ def test_integration_with_robot_processor(mock_auto_tokenizer):
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
|
||||
robot_processor = RobotProcessor([tokenizer_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
robot_processor = RobotProcessor([tokenizer_processor])
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -427,23 +411,17 @@ def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer):
|
||||
tokenizer_name="test-tokenizer", max_length=32, task_key="instruction"
|
||||
)
|
||||
|
||||
robot_processor = RobotProcessor([original_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
robot_processor = RobotProcessor([original_processor])
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Save processor
|
||||
robot_processor.save_pretrained(temp_dir)
|
||||
|
||||
# Load processor - tokenizer will be recreated from saved config
|
||||
loaded_processor = RobotProcessor.from_pretrained(
|
||||
temp_dir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_processor = RobotProcessor.from_pretrained(temp_dir)
|
||||
|
||||
# Test that loaded processor works
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"instruction": "test instruction"},
|
||||
)
|
||||
transition = create_transition(complementary_data={"instruction": "test instruction"})
|
||||
|
||||
result = loaded_processor(transition)
|
||||
assert TransitionKey.OBSERVATION in result
|
||||
@@ -458,7 +436,7 @@ def test_save_and_load_pretrained_with_tokenizer_object():
|
||||
|
||||
original_processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=32, task_key="instruction")
|
||||
|
||||
robot_processor = RobotProcessor([original_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
robot_processor = RobotProcessor([original_processor])
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Save processor
|
||||
@@ -466,18 +444,11 @@ def test_save_and_load_pretrained_with_tokenizer_object():
|
||||
|
||||
# Load processor with tokenizer override (since tokenizer object wasn't saved)
|
||||
loaded_processor = RobotProcessor.from_pretrained(
|
||||
temp_dir,
|
||||
overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}},
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
temp_dir, overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}}
|
||||
)
|
||||
|
||||
# Test that loaded processor works
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"instruction": "test instruction"},
|
||||
)
|
||||
transition = create_transition(complementary_data={"instruction": "test instruction"})
|
||||
|
||||
result = loaded_processor(transition)
|
||||
assert TransitionKey.OBSERVATION in result
|
||||
@@ -598,11 +569,7 @@ def test_tokenization_parameters(mock_auto_tokenizer):
|
||||
padding_side="left",
|
||||
)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "test task"},
|
||||
)
|
||||
transition = create_transition(complementary_data={"task": "test task"})
|
||||
|
||||
processor(transition)
|
||||
|
||||
@@ -625,14 +592,12 @@ def test_preserves_other_complementary_data(mock_auto_tokenizer):
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={
|
||||
"task": "test task",
|
||||
"episode_id": 123,
|
||||
"timestamp": 456.789,
|
||||
"other_field": {"nested": "data"},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
@@ -659,11 +624,7 @@ def test_deterministic_tokenization(mock_auto_tokenizer):
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "consistent test"},
|
||||
)
|
||||
transition = create_transition(complementary_data={"task": "consistent test"})
|
||||
|
||||
result1 = processor(transition)
|
||||
result2 = processor(transition)
|
||||
@@ -687,11 +648,7 @@ def test_empty_string_task(mock_auto_tokenizer):
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": ""},
|
||||
)
|
||||
transition = create_transition(complementary_data={"task": ""})
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -712,11 +669,7 @@ def test_very_long_task(mock_auto_tokenizer):
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=5, truncation=True)
|
||||
|
||||
long_task = " ".join(["word"] * 100) # Very long task
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": long_task},
|
||||
)
|
||||
transition = create_transition(complementary_data={"task": long_task})
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -761,11 +714,7 @@ def test_custom_padding_side(mock_auto_tokenizer):
|
||||
# Test left padding
|
||||
processor_left = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10, padding_side="left")
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "test task"},
|
||||
)
|
||||
transition = create_transition(complementary_data={"task": "test task"})
|
||||
processor_left(transition)
|
||||
|
||||
assert tracking_tokenizer.padding_side_calls[-1] == "left"
|
||||
@@ -924,6 +873,32 @@ def test_device_detection_from_action():
|
||||
assert attention_mask.device.type == "cuda"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
@require_package("transformers")
|
||||
def test_device_detection_from_complementary_data():
|
||||
"""Test that device is detected from tensors in complementary_data."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with tensor in complementary_data
|
||||
transition = create_transition(
|
||||
observation={"metadata": {"key": "value"}}, # No tensors
|
||||
complementary_data={
|
||||
"task": "comp data test",
|
||||
"index": torch.tensor([42]).cuda(), # Tensor in complementary_data
|
||||
},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that tokenized tensors match complementary_data tensor's device
|
||||
tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
assert tokens.device.type == "cuda"
|
||||
assert attention_mask.device.type == "cuda"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_device_detection_preserves_dtype():
|
||||
"""Test that device detection doesn't affect dtype of tokenized tensors."""
|
||||
@@ -957,9 +932,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
|
||||
# Create pipeline with TokenizerProcessor then DeviceProcessor
|
||||
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
|
||||
device_processor = DeviceProcessor(device="cuda:0")
|
||||
robot_processor = RobotProcessor(
|
||||
[tokenizer_processor, device_processor], to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
robot_processor = RobotProcessor([tokenizer_processor, device_processor])
|
||||
|
||||
# Start with CPU tensors
|
||||
transition = create_transition(
|
||||
|
||||
@@ -23,7 +23,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
|
||||
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
@@ -81,12 +81,7 @@ def test_make_vqbet_processor_basic():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
@@ -110,12 +105,7 @@ def test_vqbet_processor_with_images():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Create test data with images and states
|
||||
observation = {
|
||||
@@ -141,12 +131,7 @@ def test_vqbet_processor_cuda():
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {
|
||||
@@ -179,12 +164,7 @@ def test_vqbet_processor_accelerate_scenario():
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU and batched
|
||||
device = torch.device("cuda:0")
|
||||
@@ -211,12 +191,7 @@ def test_vqbet_processor_multi_gpu():
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
@@ -240,22 +215,7 @@ def test_vqbet_processor_without_stats():
|
||||
"""Test VQBeT processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
# Get the steps from the factory function
|
||||
factory_preprocessor, factory_postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(
|
||||
factory_preprocessor.steps,
|
||||
name=factory_preprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
postprocessor = RobotProcessor(
|
||||
factory_postprocessor.steps,
|
||||
name=factory_postprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
@@ -278,21 +238,14 @@ def test_vqbet_processor_save_and_load():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {
|
||||
@@ -316,12 +269,7 @@ def test_vqbet_processor_mixed_precision():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_vqbet_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
@@ -350,12 +298,7 @@ def test_vqbet_processor_large_batch():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Test with large batch
|
||||
batch_size = 128
|
||||
@@ -380,12 +323,7 @@ def test_vqbet_processor_sequential_processing():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Process multiple samples sequentially
|
||||
results = []
|
||||
|
||||
Reference in New Issue
Block a user