Compare commits


1 Commit

Author: AdilZouitine
SHA1: 37103baa07
Message: feat(accelerate): port multi gpu training
Co-authored-by: mshukor <mustafa.shukor97@gmail.com>
Date: 2025-08-12 16:10:20 +02:00
74 changed files with 2615 additions and 2778 deletions

View File

@@ -29,7 +29,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
# Install system dependencies and uv (as root)
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential git curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
build-essential git curl libglib2.0-0 libegl1-mesa ffmpeg \
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
&& mv /root/.local/bin/uv /usr/local/bin/uv \

View File

@@ -39,8 +39,6 @@
- sections:
- local: notebooks
title: Notebooks
- local: feetech
title: Updating Feetech Firmware
title: "Resources"
- sections:
- local: contributing

View File

@@ -1,71 +0,0 @@
# Feetech Motor Firmware Update
This tutorial guides you through updating the firmware of Feetech motors using the official Feetech software.
## Prerequisites
- Windows computer (Feetech software is only available for Windows)
- Feetech motor control board
- USB cable to connect the control board to your computer
- Feetech motors connected to the control board
## Step 1: Download Feetech Software
1. Visit the official Feetech software download page: [https://www.feetechrc.com/software.html](https://www.feetechrc.com/software.html)
2. Download the latest version of the Feetech debugging software (FD)
3. Install the software on your Windows computer
## Step 2: Hardware Setup
1. Connect your Feetech motors to the motor control board
2. Connect the motor control board to your Windows computer via USB cable
3. Ensure power is supplied to the motors
## Step 3: Configure Connection
1. Launch the Feetech debugging software
2. Select the correct COM port from the port dropdown menu
- If unsure which port to use, check Windows Device Manager under "Ports (COM & LPT)"
3. Set the appropriate baud rate (typically 1000000 for most Feetech motors)
4. Click "Open" to establish communication with the control board
## Step 4: Scan for Motors
1. Once connected, click the "Search" button to detect all connected motors
2. The software will automatically discover and list all motors on the bus
3. Each motor will appear with its ID number
## Step 5: Update Firmware
For each motor you want to update:
1. **Select the motor** from the list by clicking on it
2. **Click the Upgrade tab**
3. **Click the Online button**
- If a potential firmware update is found, it will be displayed in the box
4. **Click the Upgrade button**
- The update progress will be displayed
## Step 6: Verify Update
1. After the update completes, the software should automatically refresh the motor information
2. Verify that the firmware version has been updated to the expected version
## Important Notes
⚠️ **Warning**: Do not disconnect power or USB during a firmware update; doing so can brick the motor.
## Bonus: Motor Debugging on Linux/macOS
For debugging purposes only, you can use the open-source Feetech Debug Tool:
- **Repository**: [FT_SCServo_Debug_Qt](https://github.com/CarolinePascal/FT_SCServo_Debug_Qt/tree/fix/port-search-timer)
### Installation Instructions
Follow the instructions in the repository to install the tool. On Ubuntu you can install it directly; on macOS you need to build it from source.
**Limitations:**
- This tool is for debugging and parameter adjustment only
- Firmware updates must still be done on Windows with official Feetech software

View File

@@ -127,11 +127,11 @@ class RewardClassifierConfig:
# Dataset configuration
class DatasetConfig:
repo_id: str # LeRobot dataset repository ID
dataset_root: str # Local dataset root directory
task: str # Task identifier
root: str | None = None # Local dataset root directory
num_episodes_to_record: int = 5 # Number of episodes for recording
replay_episode: int | None = None # Episode index for replay
push_to_hub: bool = False # Whether to push datasets to Hub
num_episodes: int # Number of episodes for recording
episode: int # Episode index for replay
push_to_hub: bool # Whether to push datasets to Hub
```
<!-- prettier-ignore-end -->
@@ -351,7 +351,7 @@ Create a configuration file for recording demonstrations (or edit an existing on
1. Set `mode` to `"record"` at the root level
2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name")
3. Set `num_episodes_to_record` in the `dataset` section to the number of demonstrations you want to collect
3. Set `num_episodes` in the `dataset` section to the number of demonstrations you want to collect
4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later)
5. Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section
@@ -390,10 +390,10 @@ Example configuration section:
},
"dataset": {
"repo_id": "username/pick_lift_cube",
"root": null,
"dataset_root": null,
"task": "pick_and_lift",
"num_episodes_to_record": 15,
"replay_episode": 0,
"num_episodes": 15,
"episode": 0,
"push_to_hub": true
},
"mode": "record",
@@ -626,7 +626,7 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/r
- **mode**: set it to `"record"` to collect a dataset (at root level)
- **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
- **dataset.num_episodes_to_record**: Number of episodes to record
- **dataset.num_episodes**: Number of episodes to record
- **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`)
- **env.fps**: Number of frames per second to record
- **dataset.push_to_hub**: Whether to push the dataset to the hub
@@ -664,8 +664,8 @@ Example configuration section for data collection:
"repo_id": "hf_username/dataset_name",
"dataset_root": "data/your_dataset",
"task": "reward_classifier_task",
"num_episodes_to_record": 20,
"replay_episode": null,
"num_episodes": 20,
"episode": 0,
"push_to_hub": true
},
"mode": "record",
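
Reviewer note: the hunks above pair `num_episodes_to_record` with `num_episodes`, `replay_episode` with `episode`, and `root` with `dataset_root`. The +/- markers are not preserved in this view, so the direction of the rename is an assumption. The sketch below shows one way an existing config dict could be migrated between the two spellings; `DATASET_KEY_RENAMES` and `rename_dataset_keys` are hypothetical names, not part of lerobot.

```python
import json

# Hypothetical mapping inferred from the hunks above; the direction is an assumption.
DATASET_KEY_RENAMES = {
    "num_episodes_to_record": "num_episodes",
    "replay_episode": "episode",
    "root": "dataset_root",
}


def rename_dataset_keys(config: dict, renames: dict) -> dict:
    """Return a copy of `config` whose "dataset" section uses the renamed keys."""
    dataset = config.get("dataset", {})
    # Keys without an entry in `renames` are carried over unchanged.
    migrated = {renames.get(key, key): value for key, value in dataset.items()}
    return {**config, "dataset": migrated}


example = {
    "dataset": {
        "repo_id": "username/pick_lift_cube",
        "root": None,
        "task": "pick_and_lift",
        "num_episodes_to_record": 15,
        "replay_episode": 0,
        "push_to_hub": True,
    },
    "mode": "record",
}

migrated = rename_dataset_keys(example, DATASET_KEY_RENAMES)
print(json.dumps(migrated["dataset"], indent=2))
```

Swapping the keys and values of the mapping migrates in the other direction. The example configs in this diff also change `null` to `0` for the episode index, which this sketch does not attempt to handle.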

View File

@@ -107,10 +107,10 @@ To collect a dataset, set the mode to `record` whilst defining the repo_id and n
},
"dataset": {
"repo_id": "username/sim_dataset",
"root": null,
"dataset_root": null,
"task": "pick_cube",
"num_episodes_to_record": 10,
"replay_episode": null,
"num_episodes": 10,
"episode": 0,
"push_to_hub": true
},
"mode": "record"

View File

@@ -36,10 +36,10 @@ To teleoperate and collect a dataset, we need to modify this config file. Here's
},
"dataset": {
"repo_id": "your_username/il_gym",
"root": null,
"dataset_root": null,
"task": "pick_cube",
"num_episodes_to_record": 30,
"replay_episode": null,
"num_episodes": 30,
"episode": 0,
"push_to_hub": true
},
"mode": "record",
@@ -50,7 +50,7 @@ To teleoperate and collect a dataset, we need to modify this config file. Here's
Key configuration points:
- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes
- Set `num_episodes: 30` to collect 30 demonstration episodes
- Ensure `mode` is set to `"record"`
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"`
- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`

View File

@@ -1,7 +1,7 @@
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.factory import make_processor
from lerobot.record import record_loop
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.utils.control_utils import init_keyboard_listener
@@ -46,7 +46,7 @@ listener, events = init_keyboard_listener()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
preprocessor, postprocessor = make_pre_post_processors(
preprocessor, postprocessor = make_processor(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,

View File

@@ -20,7 +20,7 @@ from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_featur
from lerobot.datasets.utils import merge_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.factory import make_processor
from lerobot.processor.converters import (
to_output_robot_action,
to_transition_robot_observation,
@@ -127,7 +127,7 @@ robot.connect()
episode_idx = 0
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
preprocessor, postprocessor = make_pre_post_processors(
preprocessor, postprocessor = make_processor(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,

View File

@@ -38,8 +38,8 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone import Phone
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun

View File

@@ -19,7 +19,7 @@ import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
from lerobot.processor.converters import to_output_robot_action
from lerobot.processor.pipeline import RobotProcessor
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
@@ -49,6 +49,31 @@ kinematics_solver = RobotKinematics(
joint_names=list(robot.bus.motors.keys()),
)
# This method converts the action from the dataset to a transition for pipeline
def action_to_transition(action: dict):
act = {}
# EE pose
for k in ("ee.x", "ee.y", "ee.z", "ee.wx", "ee.wy", "ee.wz"):
if k in action:
act[f"action.{k}"] = float(action[k])
# Gripper: your dataset has absolute position
if "gripper.pos" in action:
act["action.gripper.pos"] = float(action["gripper.pos"])
return {
"observation": None,
"action": act,
"reward": None,
"done": False,
"truncated": False,
"info": {},
"complementary_data": {},
}
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
steps=[
@@ -59,7 +84,7 @@ robot_ee_to_joints = RobotProcessor(
initial_guess_current_joints=False, # Because replay is open loop
),
],
to_transition=to_transition_teleop_action,
to_transition=action_to_transition,
to_output=to_output_robot_action,
)
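
Reviewer note: `action_to_transition` can be exercised without a robot or a dataset. The sketch below repeats the helper from the hunk above so that it runs standalone and feeds it a made-up action dict; the sample values and the `unused` key are illustrative assumptions.

```python
def action_to_transition(action: dict) -> dict:
    # Same logic as the helper in the diff, repeated here so the sketch is standalone.
    act = {}
    # End-effector pose components are copied under an "action." prefix.
    for k in ("ee.x", "ee.y", "ee.z", "ee.wx", "ee.wy", "ee.wz"):
        if k in action:
            act[f"action.{k}"] = float(action[k])
    # Gripper position is treated as an absolute value.
    if "gripper.pos" in action:
        act["action.gripper.pos"] = float(action["gripper.pos"])
    return {
        "observation": None,
        "action": act,
        "reward": None,
        "done": False,
        "truncated": False,
        "info": {},
        "complementary_data": {},
    }


# Made-up sample: two pose components, an integer gripper value, one unrelated key.
sample = {"ee.x": 0.1, "ee.z": 0.25, "gripper.pos": 40, "unused": 1}
transition = action_to_transition(sample)
print(transition["action"])
```

Keys outside the listed pose components and `gripper.pos` are dropped, and every kept value is coerced to `float`.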

View File

@@ -28,8 +28,8 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone import Phone
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
# Initialize the robot and teleoperator
robot_config = SO100FollowerConfig(
@@ -48,8 +48,8 @@ kinematics_solver = RobotKinematics(
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert phone action to ee pose action to joint action
phone_to_robot_joints = RobotProcessor(
# Build pipeline to convert phone action to ee pose action
phone_to_robot_ee_pose = RobotProcessor(
steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
AddRobotObservationAsComplimentaryData(robot=robot),
@@ -63,6 +63,14 @@ phone_to_robot_joints = RobotProcessor(
max_ee_step_m=0.10,
max_ee_twist_step_rad=0.50,
),
],
to_transition=to_transition_teleop_action,
to_output=lambda tr: tr,
)
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
@@ -72,7 +80,7 @@ phone_to_robot_joints = RobotProcessor(
speed_factor=20.0,
),
],
to_transition=to_transition_teleop_action,
to_transition=lambda tr: tr,
to_output=to_output_robot_action,
)
@@ -81,11 +89,19 @@ teleop_device.connect()
print("Starting teleop loop. Move your phone to teleoperate the robot.")
while True:
phone_obs = teleop_device.get_action()
if not phone_obs:
time.sleep(0.01)
continue
# Get teleop observation
phone_obs = teleop_device.get_action()
# Phone -> EE pose -> Joints transition
joint_action = phone_to_robot_joints(phone_obs)
# Phone to EE pose transition
ee_transition = phone_to_robot_ee_pose(phone_obs)
# EE pose to Joints transition
joint_action = robot_ee_to_joints(ee_transition)
if joint_action:
robot.send_action(joint_action)
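
Reviewer note: the hunks above show one variant with a single phone-to-joints pipeline and another with two chained pipelines, where the first returns the transition unchanged (`to_output=lambda tr: tr`) and the second accepts it unchanged (`to_transition=lambda tr: tr`). The sketch below illustrates that chaining pattern with a toy pipeline class. `MiniPipeline` and the step functions are hypothetical stand-ins, not lerobot's `RobotProcessor` or its steps, and the scaling factors are arbitrary.

```python
from typing import Any, Callable


class MiniPipeline:
    """Hypothetical stand-in for a processor pipeline; not lerobot's RobotProcessor."""

    def __init__(
        self,
        steps: list,
        to_transition: Callable[[Any], dict],
        to_output: Callable[[dict], Any],
    ) -> None:
        self.steps = steps
        self.to_transition = to_transition
        self.to_output = to_output

    def __call__(self, data: Any) -> Any:
        # Convert the input to a transition, run every step, convert back.
        transition = self.to_transition(data)
        for step in self.steps:
            transition = step(transition)
        return self.to_output(transition)


def wrap_action(raw: dict) -> dict:
    # Toy converter: wrap a raw action dict into a transition-like dict.
    return {"action": dict(raw)}


def phone_to_ee(tr: dict) -> dict:
    # Toy mapping from a phone reading to an end-effector target.
    return {"action": {"ee.x": tr["action"]["phone.x"] * 0.5}}


def ee_to_joint(tr: dict) -> dict:
    # Toy replacement for inverse kinematics.
    return {"action": {"joint_1.pos": tr["action"]["ee.x"] * 8.0}}


# Stage 1 hands over the whole transition; stage 2 consumes it as-is.
stage_one = MiniPipeline([phone_to_ee], to_transition=wrap_action, to_output=lambda tr: tr)
stage_two = MiniPipeline([ee_to_joint], to_transition=lambda tr: tr, to_output=lambda tr: tr["action"])

joint_action = stage_two(stage_one({"phone.x": 0.5}))
print(joint_action)
```

Splitting the pipeline this way exposes the intermediate end-effector transition, which can then be inspected or recorded before it is converted to joint targets.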

View File

@@ -24,11 +24,6 @@ OBS_IMAGES = "observation.images"
OBS_LANGUAGE = "observation.language"
ACTION = "action"
REWARD = "next.reward"
TRUNCATED = "next.truncated"
DONE = "next.done"
OBS_LANGUAGE_TOKENS = "observation.language.tokens"
OBS_LANGUAGE_ATTENTION_MASK = "observation.language.attention_mask"
ROBOTS = "robots"
ROBOT_TYPE = "robot_type"

View File

@@ -825,8 +825,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
if not episode_data:
episode_buffer = self.episode_buffer
else:
episode_buffer = episode_data
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)

View File

@@ -15,7 +15,6 @@
from collections.abc import Sequence
from typing import Any
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor.pipeline import RobotProcessor
@@ -60,26 +59,26 @@ def aggregate_pipeline_dataset_features(
# Go over every feature from the pipeline and merge:
for full_key, ty in all_features.items():
if full_key.startswith(f"{ACTION}."):
if full_key.startswith("action."):
# action.<feat>
if not keep(full_key):
continue
name = full_key[len(f"{ACTION}.") :]
hw.setdefault(ACTION, {})[name] = ty
name = full_key[len("action.") :]
hw.setdefault("action", {})[name] = ty
elif full_key.startswith(f"{OBS_STATE}."):
elif full_key.startswith("observation.state."):
# observation.state.<feat>
if not keep(full_key):
continue
name = full_key[len(f"{OBS_STATE}.") :]
name = full_key[len("observation.state.") :]
hw.setdefault("observation", {})[name] = ty
elif full_key.startswith(f"{OBS_IMAGES}."):
elif full_key.startswith("observation.images."):
# observation.images.<cam>
# images obey ONLY the use_videos flag, not patterns
if not use_videos:
continue
name = full_key[len(f"{OBS_IMAGES}.") :]
name = full_key[len("observation.images.") :]
hw.setdefault("observation", {})[name] = ty
else:
@@ -87,8 +86,8 @@ def aggregate_pipeline_dataset_features(
continue
out: dict[str, dict] = {}
if ACTION in hw:
out.update(hw_to_dataset_features(hw[ACTION], ACTION, use_videos))
if "action" in hw:
out.update(hw_to_dataset_features(hw["action"], "action", use_videos))
if "observation" in hw:
out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))
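
Reviewer note: the loop in this hunk routes each flat feature key into an `action` or `observation` bucket by prefix and strips the prefix from the name. Below is a minimal standalone sketch of that routing. It omits the `keep` pattern filter and the final `hw_to_dataset_features` call, and `group_features` is a hypothetical name. The constant values match the literals shown in the hunk.

```python
ACTION = "action"
OBS_STATE = "observation.state"
OBS_IMAGES = "observation.images"


def group_features(all_features: dict, use_videos: bool = True) -> dict:
    """Bucket flat feature keys by prefix, stripping the prefix from each name."""
    hw: dict = {}
    for full_key, ty in all_features.items():
        if full_key.startswith(f"{ACTION}."):
            name = full_key[len(f"{ACTION}."):]
            hw.setdefault("action", {})[name] = ty
        elif full_key.startswith(f"{OBS_STATE}."):
            name = full_key[len(f"{OBS_STATE}."):]
            hw.setdefault("observation", {})[name] = ty
        elif full_key.startswith(f"{OBS_IMAGES}."):
            # Image features depend only on the use_videos flag.
            if not use_videos:
                continue
            name = full_key[len(f"{OBS_IMAGES}."):]
            hw.setdefault("observation", {})[name] = ty
        # Keys with any other prefix are ignored.
    return hw


features = {
    "action.ee.x": float,
    "observation.state.joint_1.pos": float,
    "observation.images.front": (480, 640, 3),
    "reward": float,
}
grouped = group_features(features)
print(grouped)
```

Whether the prefixes are spelled as string literals or built from shared constants, the routing is the same as long as the constants hold these values.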

View File

@@ -107,8 +107,6 @@ X_SERIES_ENCODINGS_TABLE = {
"Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1],
"Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1],
"Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1],
"Goal_Position": X_SERIES_CONTROL_TABLE["Goal_Position"][1],
"Present_Position": X_SERIES_CONTROL_TABLE["Present_Position"][1],
"Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1],
"Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1],
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],

View File

@@ -287,7 +287,7 @@ class ACT(nn.Module):
└───────────────────────┘
"""
def __init__(self, config: ACTConfig):
def __init__(self, config: ACTConfig, dataset_stats=None):
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
super().__init__()

View File

@@ -20,7 +20,6 @@ from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
@@ -28,17 +27,9 @@ from lerobot.processor import (
)
def make_act_pre_post_processors(
config: ACTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
def make_act_processor(
config: ACTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
@@ -55,16 +46,6 @@ def make_act_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -21,7 +21,6 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
@@ -29,17 +28,9 @@ from lerobot.processor import (
)
def make_diffusion_pre_post_processors(
config: DiffusionConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
def make_diffusion_processor(
config: DiffusionConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
@@ -56,15 +47,6 @@ def make_diffusion_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -17,9 +17,10 @@
from __future__ import annotations
import logging
from typing import Any, TypedDict
from typing import Any, TypedDict, cast
import torch
from torch import nn
from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
@@ -38,7 +39,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor.pipeline import ProcessorKwargs, RobotProcessor
from lerobot.processor.pipeline import RobotProcessor
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
@@ -114,11 +115,9 @@ class ProcessorConfigKwargs(TypedDict, total=False):
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
preprocessor_kwargs: ProcessorKwargs | None
postprocessor_kwargs: ProcessorKwargs | None
def make_pre_post_processors(
def make_processor(
policy_cfg: PreTrainedConfig,
pretrained_path: str | None = None,
**kwargs: Unpack[ProcessorConfigKwargs],
@@ -141,116 +140,82 @@ def make_pre_post_processors(
NotImplementedError: If the policy type doesn't have a processor implemented.
"""
if pretrained_path:
# Extract preprocessor and postprocessor kwargs
preprocessor_kwargs = kwargs.get("preprocessor_kwargs", {})
postprocessor_kwargs = kwargs.get("postprocessor_kwargs", {})
return (
RobotProcessor.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
overrides=kwargs.get("preprocessor_overrides", {}),
to_transition=preprocessor_kwargs.get("to_transition"),
to_output=preprocessor_kwargs.get("to_output"),
),
RobotProcessor.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
overrides=kwargs.get("postprocessor_overrides", {}),
to_transition=postprocessor_kwargs.get("to_transition"),
to_output=postprocessor_kwargs.get("to_output"),
),
)
# Create a new processor based on policy type
if isinstance(policy_cfg, TDMPCConfig):
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
if policy_cfg.type == "tdmpc":
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor
processors = make_tdmpc_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
processors = make_tdmpc_processor(
config=cast(TDMPCConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif isinstance(policy_cfg, DiffusionConfig):
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
elif policy_cfg.type == "diffusion":
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor
processors = make_diffusion_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
processors = make_diffusion_processor(
cast(DiffusionConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif isinstance(policy_cfg, ACTConfig):
from lerobot.policies.act.processor_act import make_act_pre_post_processors
elif policy_cfg.type == "act":
from lerobot.policies.act.processor_act import make_act_processor
processors = make_act_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
processors = make_act_processor(
config=cast(ACTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif isinstance(policy_cfg, VQBeTConfig):
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
elif policy_cfg.type == "vqbet":
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor
processors = make_vqbet_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
processors = make_vqbet_processor(
config=cast(VQBeTConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif isinstance(policy_cfg, PI0Config):
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
elif policy_cfg.type == "pi0":
from lerobot.policies.pi0.processor_pi0 import make_pi0_processor
processors = make_pi0_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
processors = make_pi0_processor(
config=cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif isinstance(policy_cfg, PI0FASTConfig):
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
elif policy_cfg.type == "pi0fast":
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor
processors = make_pi0fast_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
processors = make_pi0fast_processor(
cast(PI0Config, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif isinstance(policy_cfg, SACConfig):
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
elif policy_cfg.type == "sac":
from lerobot.policies.sac.processor_sac import make_sac_processor
processors = make_sac_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
processors = make_sac_processor(
cast(SACConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif isinstance(policy_cfg, RewardClassifierConfig):
elif policy_cfg.type == "reward_classifier":
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
processors = make_classifier_processor(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
cast(RewardClassifierConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif isinstance(policy_cfg, SmolVLAConfig):
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors
elif policy_cfg.type == "smolvla":
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor
processors = make_smolvla_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
processors = make_smolvla_processor(
cast(SmolVLAConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
else:
@@ -330,7 +295,7 @@ def make_policy(
policy = policy_cls(**kwargs)
policy.to(cfg.device)
assert isinstance(policy, torch.nn.Module)
assert isinstance(policy, nn.Module)
# policy = torch.compile(policy, mode="reduce-overhead")

View File

@@ -14,68 +14,82 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RobotProcessor,
ToBatchProcessor,
TokenizerProcessor,
UnnormalizerProcessor,
)
from lerobot.processor.pipeline import (
ComplementaryDataProcessor,
EnvTransition,
ProcessorStep,
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.processor.rename_processor import RenameProcessor
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
class Pi0NewLineProcessor(ComplementaryDataProcessor):
class Pi0NewLineProcessor(ProcessorStep):
"""Add a new line to the end of the task if it doesn't have one.
This is required for the PaliGemma tokenizer.
"""
def complementary_data(self, complementary_data):
if "task" not in complementary_data:
return complementary_data
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Check if complementary_data exists
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None or "task" not in complementary_data:
return transition
task = complementary_data["task"]
if task is None:
return complementary_data
new_complementary_data = dict(complementary_data)
return transition
# Handle both string and list of strings
if isinstance(task, str):
# Single string: add newline if not present
if not task.endswith("\n"):
new_complementary_data["task"] = f"{task}\n"
complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: add newline to each if not present
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
# If task is neither string nor list of strings, leave unchanged
return new_complementary_data
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Add tokenized task features to the features."""
return features
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def make_pi0_pre_post_processors(
config: PI0Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
def make_pi0_processor(
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
@@ -102,15 +116,6 @@ def make_pi0_pre_post_processors(
),
]
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)
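
Reviewer note: both variants of `Pi0NewLineProcessor` in this hunk apply the same rule to `task`: append a newline to a string, or to each string in a list, only when it is missing. One variant writes into a copy of `complementary_data`; the other assigns into the existing dict. The sketch below isolates the rule as a pure function and applies it to a copy. `ensure_trailing_newline` and `with_newline_task` are hypothetical helper names.

```python
def ensure_trailing_newline(task):
    """Append a newline to a str, or to each str in a list, if it is missing."""
    if isinstance(task, str):
        return task if task.endswith("\n") else f"{task}\n"
    if isinstance(task, list) and all(isinstance(t, str) for t in task):
        return [t if t.endswith("\n") else f"{t}\n" for t in task]
    # Anything else (None, mixed lists, other types) is returned unchanged.
    return task


def with_newline_task(complementary_data: dict) -> dict:
    """Return a shallow copy whose "task" entry ends with a newline."""
    if "task" not in complementary_data:
        return complementary_data
    updated = dict(complementary_data)
    updated["task"] = ensure_trailing_newline(complementary_data["task"])
    return updated


original = {"task": ["pick the cube", "lift it\n"], "index": 3}
result = with_newline_task(original)
print(result["task"])
```

Working on a copy leaves the caller's dict untouched, whereas in-place assignment is also visible to any other holder of the same dict.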

View File

@@ -21,7 +21,6 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
@@ -29,17 +28,9 @@ from lerobot.processor import (
)
def make_pi0fast_pre_post_processors(
config: PI0Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
def make_pi0fast_processor(
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessor(
@@ -56,15 +47,6 @@ def make_pi0fast_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)

View File

@@ -22,7 +22,6 @@ from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
@@ -30,17 +29,9 @@ from lerobot.processor import (
)
def make_sac_pre_post_processors(
config: SACConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
def make_sac_processor(
config: SACConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
@@ -57,15 +48,6 @@ def make_sac_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)


@@ -20,22 +20,13 @@ from lerobot.processor import (
DeviceProcessor,
IdentityProcessor,
NormalizerProcessor,
ProcessorKwargs,
RobotProcessor,
)
def make_classifier_processor(
config: RewardClassifierConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
config: RewardClassifierConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
NormalizerProcessor(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
@@ -46,16 +37,6 @@ def make_classifier_processor(
DeviceProcessor(device=config.device),
]
output_steps = [DeviceProcessor(device="cpu"), IdentityProcessor()]
return (
RobotProcessor(
steps=input_steps,
name="classifier_preprocessor",
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name="classifier_postprocessor",
**postprocessor_kwargs,
),
return RobotProcessor(steps=input_steps, name="classifier_preprocessor"), RobotProcessor(
steps=output_steps, name="classifier_postprocessor"
)


@@ -13,38 +13,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
TokenizerProcessor,
UnnormalizerProcessor,
)
from lerobot.processor.pipeline import (
ComplementaryDataProcessor,
ProcessorStepRegistry,
)
from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey
def make_smolvla_pre_post_processors(
config: SmolVLAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
def make_smolvla_processor(
config: SmolVLAConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessor(
@@ -68,42 +58,53 @@ def make_smolvla_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
class SmolVLANewLineProcessor(ComplementaryDataProcessor):
class SmolVLANewLineProcessor(ProcessorStep):
"""Add a new line to the end of the task if it doesn't have one."""
def complementary_data(self, complementary_data):
if "task" not in complementary_data:
return complementary_data
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Check if complementary_data exists
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None or "task" not in complementary_data:
return transition
task = complementary_data["task"]
if task is None:
return complementary_data
new_complementary_data = dict(complementary_data)
return transition
# Handle both string and list of strings
if isinstance(task, str):
# Single string: add newline if not present
if not task.endswith("\n"):
new_complementary_data["task"] = f"{task}\n"
complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: add newline to each if not present
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
# If task is neither string nor list of strings, leave unchanged
return new_complementary_data
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Adds nothing to the features."""
return features
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
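
The `SmolVLANewLineProcessor` hunk above contains two variants of the same logic: one copies `complementary_data` before writing the task back, the other mutates it in place. The sketch below isolates the copying variant on a plain dict so it can be run without lerobot. The function names `ensure_trailing_newline` and `add_newline_to_task` are hypothetical.

```python
def ensure_trailing_newline(task):
    """Return the task with a trailing newline added where it is missing."""
    if isinstance(task, str):
        return task if task.endswith("\n") else f"{task}\n"
    if isinstance(task, list) and all(isinstance(t, str) for t in task):
        return [t if t.endswith("\n") else f"{t}\n" for t in task]
    # None, mixed lists, and other types are returned unchanged.
    return task


def add_newline_to_task(complementary_data):
    """Copying variant: the input dict is left untouched."""
    if "task" not in complementary_data:
        return complementary_data
    updated = dict(complementary_data)
    updated["task"] = ensure_trailing_newline(updated["task"])
    return updated


original = {"task": "pick up the cube"}
result = add_newline_to_task(original)
```

Copying costs one shallow dict copy per call, and in exchange the caller's dict is never modified behind its back.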


@@ -21,7 +21,6 @@ from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
@@ -29,17 +28,9 @@ from lerobot.processor import (
)
def make_tdmpc_pre_post_processors(
config: TDMPCConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
def make_tdmpc_processor(
config: TDMPCConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
@@ -56,15 +47,6 @@ def make_tdmpc_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)


@@ -22,7 +22,6 @@ from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
@@ -30,17 +29,9 @@ from lerobot.processor import (
)
def make_vqbet_pre_post_processors(
config: VQBeTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
def make_vqbet_processor(
config: VQBeTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}), # Lets the user rename the keys
NormalizerProcessor(
@@ -57,15 +48,6 @@ def make_vqbet_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
)


@@ -15,17 +15,18 @@
# limitations under the License.
from .batch_processor import ToBatchProcessor
from .delta_action_processor import MapDeltaActionToRobotAction, MapTensorToDeltaActionDict
from .delta_action_processor import MapDeltaActionToRobotAction
from .device_processor import DeviceProcessor
from .gym_action_processor import Numpy2TorchActionProcessor, Torch2NumpyActionProcessor
from .hil_processor import (
AddTeleopActionAsComplimentaryData,
AddTeleopEventsAsInfo,
GripperPenaltyProcessor,
ImageCropResizeProcessor,
InterventionActionProcessor,
Numpy2TorchActionProcessor,
RewardClassifierProcessor,
TimeLimitProcessor,
Torch2NumpyActionProcessor,
)
from .joint_observations_processor import JointVelocityProcessor, MotorCurrentProcessor
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor, hotswap_stats
@@ -37,7 +38,6 @@ from .pipeline import (
IdentityProcessor,
InfoProcessor,
ObservationProcessor,
ProcessorKwargs,
ProcessorStep,
ProcessorStepRegistry,
RewardProcessor,
@@ -55,7 +55,6 @@ __all__ = [
"DeviceProcessor",
"DoneProcessor",
"MapDeltaActionToRobotAction",
"MapTensorToDeltaActionDict",
"EnvTransition",
"GripperPenaltyProcessor",
"IdentityProcessor",
@@ -69,7 +68,6 @@ __all__ = [
"UnnormalizerProcessor",
"hotswap_stats",
"ObservationProcessor",
"ProcessorKwargs",
"ProcessorStep",
"ProcessorStepRegistry",
"RenameProcessor",


@@ -11,88 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any
import torch
from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import (
ActionProcessor,
ComplementaryDataProcessor,
EnvTransition,
ObservationProcessor,
ProcessorStep,
ProcessorStepRegistry,
)
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_action")
class ToBatchProcessorAction(ActionProcessor):
"""Process action component in-place, adding batch dimension if needed."""
def action(self, action):
if not isinstance(action, Tensor) or action.dim() != 1:
return action
return action.unsqueeze(0)
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
class ToBatchProcessorObservation(ObservationProcessor):
"""Process observation component in-place, adding batch dimensions where needed."""
def observation(self, observation):
# Process state observations - add batch dim if 1D
for state_key in [OBS_STATE, OBS_ENV_STATE]:
if state_key in observation:
state_value = observation[state_key]
if isinstance(state_value, Tensor) and state_value.dim() == 1:
observation[state_key] = state_value.unsqueeze(0)
# Process single image observation - add batch dim if 3D
if OBS_IMAGE in observation:
image_value = observation[OBS_IMAGE]
if isinstance(image_value, Tensor) and image_value.dim() == 3:
observation[OBS_IMAGE] = image_value.unsqueeze(0)
# Process multiple image observations - add batch dim if 3D
for key, value in observation.items():
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
observation[key] = value.unsqueeze(0)
return observation
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
class ToBatchProcessorComplementaryData(ComplementaryDataProcessor):
"""Process complementary data in-place, handling task field batching."""
def complementary_data(self, complementary_data):
# Process task field - wrap string in list to add batch dimension
if "task" in complementary_data:
task_value = complementary_data["task"]
if isinstance(task_value, str):
complementary_data["task"] = [task_value]
# Process index field - add batch dim if 0D
if "index" in complementary_data:
index_value = complementary_data["index"]
if isinstance(index_value, Tensor) and index_value.dim() == 0:
complementary_data["index"] = index_value.unsqueeze(0)
# Process task_index field - add batch dim if 0D
if "task_index" in complementary_data:
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
return complementary_data
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor")
class ToBatchProcessor(ProcessorStep):
class ToBatchProcessor:
"""Processor that adds batch dimensions to observations and actions when needed.
This processor ensures that observations and actions have proper batch dimensions for model processing:
@@ -127,16 +59,81 @@ class ToBatchProcessor(ProcessorStep):
```
"""
to_batch_action_processor: ToBatchProcessorAction = field(default_factory=ToBatchProcessorAction)
to_batch_observation_processor: ToBatchProcessorObservation = field(
default_factory=ToBatchProcessorObservation
)
to_batch_complementary_data_processor: ToBatchProcessorComplementaryData = field(
default_factory=ToBatchProcessorComplementaryData
)
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = self.to_batch_action_processor(transition)
transition = self.to_batch_observation_processor(transition)
transition = self.to_batch_complementary_data_processor(transition)
self._process_observation(transition)
self._process_action(transition)
self._process_complementary_data(transition)
return transition
def _process_observation(self, transition: EnvTransition) -> None:
"""Process observation component in-place, adding batch dimensions where needed."""
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return
# Process state observations - add batch dim if 1D
for state_key in [OBS_STATE, OBS_ENV_STATE]:
if state_key in observation:
state_value = observation[state_key]
if isinstance(state_value, Tensor) and state_value.dim() == 1:
observation[state_key] = state_value.unsqueeze(0)
# Process single image observation - add batch dim if 3D
if OBS_IMAGE in observation:
image_value = observation[OBS_IMAGE]
if isinstance(image_value, Tensor) and image_value.dim() == 3:
observation[OBS_IMAGE] = image_value.unsqueeze(0)
# Process multiple image observations - add batch dim if 3D
for key, value in observation.items():
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
observation[key] = value.unsqueeze(0)
def _process_action(self, transition: EnvTransition) -> None:
"""Process action component in-place, adding batch dimension if needed."""
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, Tensor) and action.dim() == 1:
transition[TransitionKey.ACTION] = action.unsqueeze(0)
def _process_complementary_data(self, transition: EnvTransition) -> None:
"""Process complementary data in-place, handling task field batching."""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return
# Process task field - wrap string in list to add batch dimension
if "task" in complementary_data:
task_value = complementary_data["task"]
if isinstance(task_value, str):
complementary_data["task"] = [task_value]
# Process index field - add batch dim if 0D
if "index" in complementary_data:
index_value = complementary_data["index"]
if isinstance(index_value, Tensor) and index_value.dim() == 0:
complementary_data["index"] = index_value.unsqueeze(0)
# Process task_index field - add batch dim if 0D
if "task_index" in complementary_data:
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features


@@ -18,125 +18,26 @@ from __future__ import annotations
from collections.abc import Iterable, Sequence
from copy import deepcopy
from functools import singledispatch
from typing import Any
import numpy as np
import torch
from scipy.spatial.transform import Rotation
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
from .pipeline import EnvTransition, TransitionKey
@singledispatch
def to_tensor(
value: Any,
*,
dtype: torch.dtype | None = torch.float32,
device: torch.device | str | None = None,
) -> torch.Tensor:
"""
Convert various data types to PyTorch tensors with configurable options.
This is a unified tensor conversion function using single dispatch to handle
different input types appropriately.
Args:
value: Input value to convert (tensor, array, scalar, sequence, etc.)
dtype: Target tensor dtype. If None, preserves original dtype.
device: Target device for the tensor.
Returns:
PyTorch tensor.
Raises:
TypeError: If the input type is not supported.
"""
raise TypeError(f"Unsupported type for tensor conversion: {type(value)}")
@to_tensor.register(torch.Tensor)
def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle existing PyTorch tensors."""
if dtype is not None:
value = value.to(dtype=dtype)
if device is not None:
value = value.to(device=device)
return value
@to_tensor.register(np.ndarray)
def _(
value: np.ndarray,
*,
dtype=torch.float32,
device=None,
**kwargs,
) -> torch.Tensor:
"""Handle numpy arrays."""
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars
if value.ndim == 0:
# Numpy scalars should be converted to 0-dimensional tensors
scalar_value = value.item()
return torch.tensor(scalar_value, dtype=dtype, device=device)
# Create tensor from numpy array (torch.from_numpy handles contiguity automatically)
tensor = torch.from_numpy(value)
# Apply dtype conversion if specified
if dtype is not None:
tensor = tensor.to(dtype=dtype)
if device is not None:
tensor = tensor.to(device=device)
return tensor
@to_tensor.register(int)
@to_tensor.register(float)
@to_tensor.register(np.integer)
@to_tensor.register(np.floating)
def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle scalar values including numpy scalars."""
return torch.tensor(value, dtype=dtype, device=device)
@to_tensor.register(list)
@to_tensor.register(tuple)
def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle sequences (lists, tuples)."""
return torch.tensor(value, dtype=dtype, device=device)
@to_tensor.register(dict)
def _(value: dict, *, device=None, **kwargs) -> dict:
"""Handle dictionaries by recursively converting values to tensors."""
if not value:
return {}
result = {}
for key, sub_value in value.items():
if sub_value is None:
continue
if isinstance(sub_value, dict):
# Recursively process nested dictionaries
result[key] = to_tensor(
sub_value,
device=device,
**kwargs,
)
continue
# Convert individual values to tensors
result[key] = to_tensor(
sub_value,
device=device,
**kwargs,
)
return result
def _to_tensor(x: torch.Tensor | np.ndarray | Sequence[int | float]):
if isinstance(x, torch.Tensor):
return x
if isinstance(x, np.ndarray):
# Keep images (uint8 HWC) and python objects as-is
if x.dtype == np.uint8 or x.dtype == np.object_:
return x
# Scalars/arrays to float32 tensor
return torch.as_tensor(x, dtype=torch.float32)
# Anything else to float32 tensor
return torch.as_tensor(x, dtype=torch.float32)
def _from_tensor(x: Any):
@@ -152,7 +53,7 @@ def _is_image(arr: Any) -> bool:
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
state, images = {}, {}
for k, v in obs.items():
if "image" in k.lower() or _is_image(v):
if _is_image(v):
images[k] = v
else:
state[k] = v
@@ -181,11 +82,11 @@ def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition:
for k, v in action.items():
# Check if the value is a type that should not be converted to a tensor.
if isinstance(v, (Rotation, dict)):
act_dict[f"{ACTION}.{k}"] = v
act_dict[f"action.{k}"] = v
continue
arr = np.array(v) if np.isscalar(v) else v
act_dict[f"{ACTION}.{k}"] = to_tensor(arr)
act_dict[f"action.{k}"] = _to_tensor(arr)
return make_obs_act_transition(act=act_dict)
@@ -200,10 +101,10 @@ def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransitio
obs_dict: dict[str, Any] = {}
for k, v in state.items():
arr = np.array(v) if np.isscalar(v) else v
obs_dict[f"{OBS_STATE}.{k}"] = to_tensor(arr)
obs_dict[f"observation.state.{k}"] = _to_tensor(arr)
for cam, img in images.items():
obs_dict[f"{OBS_IMAGES}.{cam}"] = img
obs_dict[f"observation.images.{cam}"] = img
return make_obs_act_transition(obs=obs_dict)
@@ -215,12 +116,9 @@ def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]:
out: dict[str, Any] = {}
action_dict = transition.get(TransitionKey.ACTION) or {}
if action_dict is None:
return out
for k, v in action_dict.items():
if isinstance(k, str) and k.startswith(f"{ACTION}.") and k.endswith((".pos", ".vel")):
out_key = k[len(f"{ACTION}.") :] # Strip the 'action.' prefix.
if isinstance(k, str) and k.startswith("action.") and k.endswith((".pos", ".vel")):
out_key = k[len("action.") :] # Strip the 'action.' prefix.
out[out_key] = float(v)
return out
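
The key filtering in `to_output_robot_action` can be sketched on plain dicts as follows. Because of the `or {}` fallback, a missing action already becomes an empty dict, so a separate `None` check is not needed. The function name `to_robot_action` and the `prefix` parameter are hypothetical and used only for illustration.

```python
def to_robot_action(action_dict, prefix="action."):
    """Keep only '<prefix>*.pos' / '<prefix>*.vel' entries and strip the prefix."""
    out = {}
    # `or {}` turns a missing (None) action into an empty dict.
    for key, value in (action_dict or {}).items():
        if isinstance(key, str) and key.startswith(prefix) and key.endswith((".pos", ".vel")):
            out[key[len(prefix):]] = float(value)
    return out


robot_action = to_robot_action(
    {
        "action.shoulder.pos": 1,
        "action.gripper": 0.5,        # dropped: no .pos/.vel suffix
        "observation.elbow.pos": 2,   # dropped: wrong prefix
    }
)
```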
@@ -251,9 +149,9 @@ def to_dataset_frame(
- info dict
- *_is_pad flags and task from complementary_data
"""
action_names = features.get(ACTION, {}).get("names", [])
obs_state_names = features.get(OBS_STATE, {}).get("names", [])
image_keys = [k for k in features if k.startswith(OBS_IMAGES)]
action_names = features.get("action", {}).get("names", [])
obs_state_names = features.get("observation.state", {}).get("names", [])
image_keys = [k for k in features if k.startswith("observation.images.")]
def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition:
out = deepcopy(base)
@@ -297,20 +195,21 @@ def to_dataset_frame(
# Observation.state vector
if obs_state_names:
vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
vals = [_from_tensor(obs.get(f"observation.state.{n}", 0.0)) for n in obs_state_names]
batch["observation.state"] = np.asarray(vals, dtype=np.float32)
# Action vector
if action_names:
vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
batch[ACTION] = np.asarray(vals, dtype=np.float32)
vals = [_from_tensor(act.get(f"action.{n}", 0.0)) for n in action_names]
batch["action"] = np.asarray(vals, dtype=np.float32)
# Next.* fields
if tr.get(TransitionKey.REWARD) is not None:
batch[REWARD] = _from_tensor(tr[TransitionKey.REWARD])
batch["next.reward"] = _from_tensor(tr[TransitionKey.REWARD])
if tr.get(TransitionKey.DONE) is not None:
batch[DONE] = _from_tensor(tr[TransitionKey.DONE])
batch["next.done"] = _from_tensor(tr[TransitionKey.DONE])
if tr.get(TransitionKey.TRUNCATED) is not None:
batch[TRUNCATED] = _from_tensor(tr[TransitionKey.TRUNCATED])
batch["next.truncated"] = _from_tensor(tr[TransitionKey.TRUNCATED])
# Complementary data flags and task
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
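
The `to_tensor` function in this file's diff relies on `functools.singledispatch`, which selects an implementation from the type of the first positional argument. The sketch below shows the same pattern without torch, including stacked `register` calls and a keyword-only option. `to_floats` and its `scale` parameter are hypothetical names chosen for illustration.

```python
from functools import singledispatch


@singledispatch
def to_floats(value, *, scale=1.0):
    """Fallback: reached only for types with no registered handler."""
    raise TypeError(f"Unsupported type: {type(value)}")


@to_floats.register(int)
@to_floats.register(float)
def _(value, *, scale=1.0):
    # Scalars become a one-element list.
    return [float(value) * scale]


@to_floats.register(list)
@to_floats.register(tuple)
def _(value, *, scale=1.0):
    # Sequences are flattened by dispatching again on each item.
    out = []
    for item in value:
        out.extend(to_floats(item, scale=scale))
    return out


flat = to_floats([1, (2.5, 3)])
```

Dispatch looks only at the first argument, so options such as `scale` have to be forwarded explicitly in every recursive call.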


@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from dataclasses import dataclass, field
from torch import Tensor
@@ -22,30 +22,6 @@ from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
@dataclass
class MapTensorToDeltaActionDict(ActionProcessor):
"""
Map a tensor to a delta action dictionary.
"""
def action(self, action: Tensor) -> dict:
if isinstance(action, dict):
return action
if action.dim() > 1:
action = action.squeeze(0)
# TODO (maractingi): add rotation
delta_action = {
"action.delta_x": action[0],
"action.delta_y": action[1],
"action.delta_z": action[2],
}
if action.shape[0] > 3:
delta_action["action.gripper"] = action[3]
return delta_action
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
@dataclass
class MapDeltaActionToRobotAction(ActionProcessor):
@@ -77,25 +53,35 @@ class MapDeltaActionToRobotAction(ActionProcessor):
# Scale factors for delta movements
position_scale: float = 1.0
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise
gripper_deadzone: float = 0.1 # Threshold for gripper activation
_prev_enabled: bool = field(default=False, init=False, repr=False)
def action(self, action: dict | Tensor | None) -> dict:
if action is None:
return {}
def action(self, action: dict) -> dict:
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
# TODO (maractingi): change this target_xyz naming convention from the teleop_devices
delta_x = action.pop("action.delta_x", 0.0)
delta_y = action.pop("action.delta_y", 0.0)
delta_z = action.pop("action.delta_z", 0.0)
gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0)
if isinstance(action, dict):
delta_x = action.pop("action.delta_x", 0.0)
delta_y = action.pop("action.delta_y", 0.0)
delta_z = action.pop("action.delta_z", 0.0)
gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0)
else:
delta_x = action[0].item()
delta_y = action[1].item()
delta_z = action[2].item()
gripper = action[3].item()
# Determine if the teleoperator is actively providing input
# Consider enabled if any significant movement delta is detected
position_magnitude = (delta_x**2 + delta_y**2 + delta_z**2) ** 0.5 # Use Euclidean norm for position
enabled = position_magnitude > self.noise_threshold # Small threshold to avoid noise
position_magnitude = abs(delta_x) + abs(delta_y) + abs(delta_z)
enabled = position_magnitude > 1e-6 # Small threshold to avoid noise
# Scale the deltas appropriately
scaled_delta_x = delta_x * self.position_scale
scaled_delta_y = delta_y * self.position_scale
scaled_delta_z = delta_z * self.position_scale
scaled_delta_x = float(delta_x) * self.position_scale
scaled_delta_y = float(delta_y) * self.position_scale
scaled_delta_z = float(delta_z) * self.position_scale
# For gamepad/keyboard, we don't have rotation input, so set to 0
# These could be extended in the future for more sophisticated teleoperators
@@ -115,6 +101,7 @@ class MapDeltaActionToRobotAction(ActionProcessor):
"action.gripper": float(gripper),
}
self._prev_enabled = enabled
return action
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
@@ -133,3 +120,6 @@ class MapDeltaActionToRobotAction(ActionProcessor):
}
)
return features
def reset(self):
self._prev_enabled = False
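
The `MapDeltaActionToRobotAction` hunk shows two ways to decide whether the teleoperator is actively providing input: a Euclidean norm compared against a configurable `noise_threshold` (1e-3), and a sum of absolute values compared against a fixed 1e-6. The sketch below puts both side by side. The function names `enabled_l2` and `enabled_l1` are hypothetical.

```python
import math


def enabled_l2(dx, dy, dz, threshold=1e-3):
    """Euclidean-norm check with a configurable noise threshold."""
    return math.sqrt(dx**2 + dy**2 + dz**2) > threshold


def enabled_l1(dx, dy, dz, threshold=1e-6):
    """Sum-of-absolute-values check with a much smaller threshold."""
    return abs(dx) + abs(dy) + abs(dz) > threshold


# Sub-millimetre jitter: filtered by the L2 check, accepted by the L1 check.
jitter = (1e-4, 1e-4, 0.0)
jitter_l2 = enabled_l2(*jitter)
jitter_l1 = enabled_l1(*jitter)
```

With these defaults the two checks disagree only for inputs whose magnitude falls between the two thresholds, which is the range where sensor noise is most likely.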


@@ -18,13 +18,14 @@ from typing import Any
import torch
from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.utils.utils import get_safe_torch_device
@ProcessorStepRegistry.register("device_processor")
@dataclass
class DeviceProcessor(ProcessorStep):
class DeviceProcessor:
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
This processor ensures that all tensors in the transition are moved to the
@@ -35,30 +36,32 @@ class DeviceProcessor(ProcessorStep):
device: str = "cpu"
float_dtype: str | None = None
DTYPE_MAPPING = {
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"bfloat16": torch.bfloat16,
"half": torch.float16,
"float": torch.float32,
"double": torch.float64,
}
_device: torch.device | None = None
def __post_init__(self):
self._device: torch.device = get_safe_torch_device(self.device)
self.device = self._device.type # cuda might have changed to cuda:1
self._device = get_safe_torch_device(self.device)
self.device = self._device.type
self.non_blocking = "cuda" in str(self.device)
# Validate and convert float_dtype string to torch dtype
if self.float_dtype is not None:
if self.float_dtype not in self.DTYPE_MAPPING:
dtype_mapping = {
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"bfloat16": torch.bfloat16,
"half": torch.float16,
"float": torch.float32,
"double": torch.float64,
}
if self.float_dtype not in dtype_mapping:
available_dtypes = list(dtype_mapping.keys())
raise ValueError(
f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}"
f"Invalid float_dtype '{self.float_dtype}'. Available options: {available_dtypes}"
)
self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype]
self._target_float_dtype = dtype_mapping[self.float_dtype]
else:
self._target_float_dtype = None
@@ -91,38 +94,69 @@ class DeviceProcessor(ProcessorStep):
return tensor
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Create a copy of the transition
new_transition = transition.copy()
simple_tensor_keys = [
TransitionKey.ACTION,
TransitionKey.REWARD,
TransitionKey.DONE,
TransitionKey.TRUNCATED,
]
# Process observation tensors
observation = transition.get(TransitionKey.OBSERVATION)
if observation is not None:
new_observation = {
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()
}
new_transition[TransitionKey.OBSERVATION] = new_observation
dict_tensor_keys = [
TransitionKey.OBSERVATION,
TransitionKey.COMPLEMENTARY_DATA,
]
# Process action tensor
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, torch.Tensor):
new_transition[TransitionKey.ACTION] = self._process_tensor(action)
# Process simple tensors
for key in simple_tensor_keys:
value = transition.get(key)
if isinstance(value, torch.Tensor):
new_transition[key] = self._process_tensor(value)
# Process reward tensor
reward = transition.get(TransitionKey.REWARD)
if reward is not None and isinstance(reward, torch.Tensor):
new_transition[TransitionKey.REWARD] = self._process_tensor(reward)
# Process dictionary-like tensors
for key in dict_tensor_keys:
data_dict = transition.get(key)
if data_dict is not None:
new_data_dict = {
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
for k, v in data_dict.items()
}
new_transition[key] = new_data_dict
# Process done tensor
done = transition.get(TransitionKey.DONE)
if done is not None and isinstance(done, torch.Tensor):
new_transition[TransitionKey.DONE] = self._process_tensor(done)
# Process truncated tensor
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is not None and isinstance(truncated, torch.Tensor):
new_transition[TransitionKey.TRUNCATED] = self._process_tensor(truncated)
# Process complementary data tensors
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is not None:
new_complementary_data = {}
# Process all items in complementary_data
for key, value in complementary_data.items():
if isinstance(value, torch.Tensor):
new_complementary_data[key] = self._process_tensor(value)
else:
new_complementary_data[key] = value
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {"device": self.device, "float_dtype": self.float_dtype}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
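
Note on the `DeviceProcessor.__call__` hunk above: it shows the same logic in two shapes, one block per transition key and a loop over two key lists. Below is a minimal sketch of the list-driven shape. It assumes plain string keys in place of `TransitionKey` and uses a hypothetical `FakeTensor` stand-in for `torch.Tensor` so it runs without PyTorch. `FakeTensor`, `SIMPLE_KEYS`, `DICT_KEYS` and `move_transition` are illustrative names, not part of the repository.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; it only tracks a device label."""

    device: str = "cpu"

    def to(self, device: str) -> "FakeTensor":
        # Return a new object, as Tensor.to does when the device changes.
        return replace(self, device=device)


# Assumed key names: entries holding one tensor vs. entries holding a dict.
SIMPLE_KEYS = ["action", "reward", "done", "truncated"]
DICT_KEYS = ["observation", "complementary_data"]


def move_transition(transition: dict, device: str) -> dict:
    """Return a shallow copy with every FakeTensor moved to `device`."""
    new_transition = dict(transition)

    # Single-tensor entries: convert only when the value is a tensor.
    for key in SIMPLE_KEYS:
        value = transition.get(key)
        if isinstance(value, FakeTensor):
            new_transition[key] = value.to(device)

    # Dict entries: rebuild the dict, leaving non-tensor values untouched.
    for key in DICT_KEYS:
        data = transition.get(key)
        if data is not None:
            new_transition[key] = {
                k: v.to(device) if isinstance(v, FakeTensor) else v
                for k, v in data.items()
            }
    return new_transition
```

With this shape, supporting another transition entry means adding a key to one of the two lists instead of writing another copy of the block.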

View File

@@ -1,64 +0,0 @@
#! /usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
from dataclasses import dataclass
import numpy as np
import torch
from lerobot.processor.converters import to_tensor
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
@ProcessorStepRegistry.register("torch2numpy_action_processor")
@dataclass
class Torch2NumpyActionProcessor(ActionProcessor):
"""Convert PyTorch tensor actions to NumPy arrays."""
squeeze_batch_dim: bool = True
def action(self, action: torch.Tensor) -> np.ndarray:
if not isinstance(action, torch.Tensor):
raise TypeError(
f"Expected torch.Tensor or None, got {type(action).__name__}. "
"Use appropriate processor for non-tensor actions."
)
numpy_action = action.detach().cpu().numpy()
# Remove batch dimensions but preserve action dimensions
# Only squeeze if there's a batch dimension (first dim == 1)
if (
self.squeeze_batch_dim
and numpy_action.shape
and len(numpy_action.shape) > 1
and numpy_action.shape[0] == 1
):
numpy_action = numpy_action.squeeze(0)
return numpy_action
@ProcessorStepRegistry.register("numpy2torch_action_processor")
@dataclass
class Numpy2TorchActionProcessor(ActionProcessor):
"""Convert NumPy array action to PyTorch tensor."""
def action(self, action: np.ndarray) -> torch.Tensor:
if not isinstance(action, np.ndarray):
raise TypeError(
f"Expected np.ndarray or None, got {type(action).__name__}. "
"Use appropriate processor for non-tensor actions."
)
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
return torch_action
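
Note on `Torch2NumpyActionProcessor.action` above: the squeeze guard is easy to misread, so here is a sketch of which shapes it affects. It works on shape tuples only and does not call NumPy. `should_squeeze_batch` and `squeezed_shape` are hypothetical helper names used for illustration.

```python
def should_squeeze_batch(shape: tuple, squeeze_batch_dim: bool = True) -> bool:
    """Mirror the condition that guards `numpy_action.squeeze(0)` in the diff."""
    return bool(squeeze_batch_dim and shape and len(shape) > 1 and shape[0] == 1)


def squeezed_shape(shape: tuple, squeeze_batch_dim: bool = True) -> tuple:
    """Shape the action would have after the optional squeeze."""
    if should_squeeze_batch(shape, squeeze_batch_dim):
        return shape[1:]
    return shape
```

Under this condition a `(1, 7)` action becomes `(7,)`, while `(7,)`, `(1,)` and `(4, 7)` pass through unchanged, so a single-element 1-D action is never reduced to a scalar.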

View File

@@ -1,4 +1,3 @@
import math
import time
from dataclasses import dataclass
from typing import Any
@@ -8,23 +7,19 @@ import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.configs.types import PolicyFeature
from lerobot.constants import ACTION
from lerobot.processor.pipeline import (
ActionProcessor,
ComplementaryDataProcessor,
EnvTransition,
InfoProcessor,
ObservationProcessor,
ProcessorStep,
ProcessorStepRegistry,
TransitionKey,
TruncatedProcessor,
)
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
GRIPPER_KEY = "gripper"
DISCRETE_PENALTY_KEY = "discrete_penalty"
TELEOP_ACTION_KEY = "teleop_action"
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
@@ -34,10 +29,10 @@ class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
teleop_device: Teleoperator
def complementary_data(self, complementary_data: dict) -> dict:
new_complementary_data = dict(complementary_data)
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
return new_complementary_data
def complementary_data(self, complementary_data: dict | None) -> dict:
complementary_data = {} if complementary_data is None else dict(complementary_data)
complementary_data["teleop_action"] = self.teleop_device.get_action()
return complementary_data
@ProcessorStepRegistry.register("add_teleop_action_as_info")
@@ -47,11 +42,60 @@ class AddTeleopEventsAsInfo(InfoProcessor):
teleop_device: Teleoperator
def info(self, info: dict) -> dict:
new_info = dict(info)
def info(self, info: dict | None) -> dict:
info = {} if info is None else dict(info)
teleop_events = getattr(self.teleop_device, "get_teleop_events", lambda: {})()
new_info.update(teleop_events)
return new_info
info.update(teleop_events)
return info
@ProcessorStepRegistry.register("torch2numpy_action_processor")
@dataclass
class Torch2NumpyActionProcessor(ActionProcessor):
"""Convert PyTorch tensor actions to NumPy arrays."""
squeeze_batch_dim: bool = True
def action(self, action: torch.Tensor | None) -> np.ndarray | None:
if action is None:
return None
if not isinstance(action, torch.Tensor):
raise TypeError(
f"Expected torch.Tensor or None, got {type(action).__name__}. "
"Use appropriate processor for non-tensor actions."
)
numpy_action = action.detach().cpu().numpy()
# Remove batch dimensions but preserve action dimensions
# Only squeeze if there's a batch dimension (first dim == 1)
if (
self.squeeze_batch_dim
and numpy_action.shape
and len(numpy_action.shape) > 1
and numpy_action.shape[0] == 1
):
numpy_action = numpy_action.squeeze(0)
return numpy_action
@ProcessorStepRegistry.register("numpy2torch_action_processor")
@dataclass
class Numpy2TorchActionProcessor(ActionProcessor):
"""Convert NumPy array action to PyTorch tensor."""
def action(self, action: np.ndarray | None) -> torch.Tensor | None:
if action is None:
return None
if not isinstance(action, np.ndarray):
raise TypeError(
f"Expected np.ndarray or None, got {type(action).__name__}. "
"Use appropriate processor for non-tensor actions."
)
torch_action = torch.from_numpy(action)
return torch_action
@ProcessorStepRegistry.register("image_crop_resize_processor")
@@ -62,7 +106,10 @@ class ImageCropResizeProcessor(ObservationProcessor):
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
resize_size: tuple[int, int] | None = None
def observation(self, observation: dict) -> dict:
def observation(self, observation: dict | None) -> dict | None:
if observation is None:
return None
if self.resize_size is None and not self.crop_params_dict:
return observation
@@ -106,45 +153,63 @@ class ImageCropResizeProcessor(ObservationProcessor):
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessor(TruncatedProcessor):
class TimeLimitProcessor:
"""Track episode steps and enforce time limits."""
max_episode_steps: int
current_step: int = 0
def truncated(self, truncated):
def __call__(self, transition: EnvTransition) -> EnvTransition:
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is None:
return transition
self.current_step += 1
if self.current_step >= self.max_episode_steps:
truncated = True
# TODO (steven): missing an else truncated = False?
return truncated
new_transition = transition.copy()
new_transition[TransitionKey.TRUNCATED] = truncated
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"max_episode_steps": self.max_episode_steps,
}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
self.current_step = 0
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessor(ComplementaryDataProcessor):
class GripperPenaltyProcessor:
"""Apply penalty for inappropriate gripper usage."""
penalty: float = -0.01
max_gripper_pos: float = 30.0
def complementary_data(self, complementary_data):
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Calculate gripper penalty and add to complementary data."""
action = self.transition.get(TransitionKey.ACTION)
action = transition.get(TransitionKey.ACTION)
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None or action is None:
return transition
current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None)
if current_gripper_pos is None:
return complementary_data
return transition
gripper_action = action[f"{ACTION}.{GRIPPER_KEY}.pos"]
gripper_action = action[f"action.{GRIPPER_KEY}.pos"]
gripper_action_normalized = gripper_action / self.max_gripper_pos
# Normalize gripper state and action
@@ -157,11 +222,19 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
gripper_penalty = self.penalty * int(gripper_penalty_bool)
# Add penalty information to complementary data
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
# Create new complementary data with penalty info
new_complementary_data = dict(complementary_data)
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
new_complementary_data["discrete_penalty"] = gripper_penalty
return new_complementary_data
# Create new transition with updated complementary data
new_transition = transition.copy()
existing_comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
existing_comp_data.update(new_complementary_data)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = existing_comp_data # type: ignore[misc]
return new_transition
def get_config(self) -> dict[str, Any]:
return {
@@ -169,14 +242,23 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
"max_gripper_pos": self.max_gripper_pos,
}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
"""Reset the processor state."""
self.last_gripper_state = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessor(ProcessorStep):
class InterventionActionProcessor:
"""Handle human intervention actions and episode termination."""
use_gripper: bool = False
@@ -189,8 +271,7 @@ class InterventionActionProcessor(ProcessorStep):
# Get intervention signals from complementary data
info = transition.get(TransitionKey.INFO, {})
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {})
teleop_action = info.get("teleop_action", {})
is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
success = info.get(TeleopEvents.SUCCESS, False)
@@ -203,12 +284,12 @@ class InterventionActionProcessor(ProcessorStep):
if isinstance(teleop_action, dict):
# Convert teleop_action dict to tensor format
action_list = [
teleop_action.get(f"{ACTION}.delta_x", 0.0),
teleop_action.get(f"{ACTION}.delta_y", 0.0),
teleop_action.get(f"{ACTION}.delta_z", 0.0),
teleop_action.get("action.delta_x", 0.0),
teleop_action.get("action.delta_y", 0.0),
teleop_action.get("action.delta_z", 0.0),
]
if self.use_gripper:
action_list.append(teleop_action.get(GRIPPER_KEY, 1.0))
action_list.append(teleop_action.get("gripper", 1.0))
elif isinstance(teleop_action, np.ndarray):
action_list = teleop_action.tolist()
else:
@@ -232,7 +313,7 @@ class InterventionActionProcessor(ProcessorStep):
# Update complementary data with teleop action
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION)
complementary_data["teleop_action"] = new_transition.get(TransitionKey.ACTION)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
@@ -240,13 +321,24 @@ class InterventionActionProcessor(ProcessorStep):
def get_config(self) -> dict[str, Any]:
return {
"use_gripper": self.use_gripper,
"terminate_on_success": self.terminate_on_success,
}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register("reward_classifier_processor")
class RewardClassifierProcessor(ProcessorStep):
class RewardClassifierProcessor:
"""Apply reward classification to image observations."""
pretrained_path: str | None = None
@@ -288,7 +380,7 @@ class RewardClassifierProcessor(ProcessorStep):
reward = transition.get(TransitionKey.REWARD, 0.0)
terminated = transition.get(TransitionKey.DONE, False)
if math.isclose(success, 1, abs_tol=1e-2):
if success == 1.0:
reward = self.success_reward
if self.terminate_on_success:
terminated = True
@@ -312,3 +404,15 @@ class RewardClassifierProcessor(ProcessorStep):
"success_reward": self.success_reward,
"terminate_on_success": self.terminate_on_success,
}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
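
Note on the `RewardClassifierProcessor` hunk above: two success checks appear side by side, `math.isclose(success, 1, abs_tol=1e-2)` and `success == 1.0`. The sketch below shows how they differ for a classifier output just under 1. It assumes `success` is a plain Python float, and the two function names are illustrative only.

```python
import math


def is_success_tolerant(success: float) -> bool:
    """Treat any value within 0.01 of 1 as success."""
    return math.isclose(success, 1, abs_tol=1e-2)


def is_success_exact(success: float) -> bool:
    """Treat only exactly 1.0 as success."""
    return success == 1.0
```

The exact comparison is safe only if the classifier already returns a hard 0/1 label. If it returns a probability or a rounded score, the tolerant form is the more forgiving choice.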

View File

@@ -13,28 +13,30 @@ from lerobot.robots import Robot
@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessor(ObservationProcessor):
class JointVelocityProcessor:
"""Add joint velocity information to observations."""
dt: float = 0.1
joint_velocity_limits: float = 100.0
dt: float = 1.0 / 10
num_dof: int | None = None
last_joint_positions: torch.Tensor | None = None
def observation(self, observation: dict) -> dict:
def observation(self, observation: dict | None) -> dict | None:
if observation is None:
return None
# Get current joint positions (assuming they're in observation.state)
current_positions = observation.get("observation.state")
if current_positions is None:
# TODO(steven): if we get here, then the transform_features method will not hold
return observation
# Initialize last joint positions if not already set
if self.last_joint_positions is None:
self.last_joint_positions = current_positions.clone()
joint_velocities = torch.zeros_like(current_positions)
else:
# Compute velocities
joint_velocities = (current_positions - self.last_joint_positions) / self.dt
# Compute velocities
joint_velocities = (current_positions - self.last_joint_positions) / self.dt
self.last_joint_positions = current_positions.clone()
# Extend observation with velocities
@@ -48,6 +50,7 @@ class JointVelocityProcessor(ObservationProcessor):
def get_config(self) -> dict[str, Any]:
return {
"joint_velocity_limits": self.joint_velocity_limits,
"dt": self.dt,
}
@@ -55,11 +58,12 @@ class JointVelocityProcessor(ObservationProcessor):
self.last_joint_positions = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
if "observation.state" in features:
if "observation.state" in features and self.num_dof is not None:
from lerobot.configs.types import PolicyFeature
original_feature = features["observation.state"]
# Double the shape to account for positions + velocities
new_shape = (original_feature.shape[0] * 2,) + original_feature.shape[1:]
new_shape = (original_feature.shape[0] + self.num_dof,) + original_feature.shape[1:]
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
return features
@@ -71,7 +75,10 @@ class MotorCurrentProcessor(ObservationProcessor):
robot: Robot | None = None
def observation(self, observation: dict) -> dict:
def observation(self, observation: dict | None) -> dict | None:
if observation is None:
return None
# Get current values from robot state
if self.robot is None:
return observation
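
Note on `JointVelocityProcessor.observation` above: velocities are a finite difference of consecutive joint positions divided by `dt`, with zeros on the first step after a reset. A minimal sketch of that bookkeeping follows. It uses plain Python lists instead of tensors, and `VelocityTracker` is a hypothetical name.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class VelocityTracker:
    """Finite-difference velocity estimate from successive joint positions."""

    dt: float = 0.1
    last_positions: Optional[List[float]] = None

    def step(self, positions: List[float]) -> List[float]:
        if self.last_positions is None:
            # First call after construction or reset: no history, so report zeros.
            velocities = [0.0] * len(positions)
        else:
            velocities = [
                (current - previous) / self.dt
                for current, previous in zip(positions, self.last_positions)
            ]
        # Store a copy so later mutation by the caller cannot change the history.
        self.last_positions = list(positions)
        return velocities

    def reset(self) -> None:
        """Forget the history so the next step reports zero velocity again."""
        self.last_positions = None
```

Calling `reset()` at episode boundaries matters here. Without it, the first velocity of a new episode would be computed against the last pose of the previous one.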

View File

@@ -1,88 +1,232 @@
from __future__ import annotations
from collections.abc import Mapping
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.converters import to_tensor
from lerobot.processor.pipeline import (
EnvTransition,
ProcessorStep,
ProcessorStepRegistry,
RobotProcessor,
TransitionKey,
)
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor, TransitionKey
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
"""Convert numpy arrays and other types to torch tensors."""
tensor_stats: dict[str, dict[str, Tensor]] = {}
for key, sub in stats.items():
tensor_stats[key] = {}
for stat_name, value in sub.items():
if isinstance(value, np.ndarray):
tensor_val = torch.from_numpy(value.astype(np.float32))
elif isinstance(value, torch.Tensor):
tensor_val = value.to(dtype=torch.float32)
elif isinstance(value, (int, float, list, tuple)):
tensor_val = torch.tensor(value, dtype=torch.float32)
else:
raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}")
tensor_stats[key][stat_name] = tensor_val
return tensor_stats
@dataclass
class _NormalizationMixin:
"""
A mixin class providing core functionality for normalization and unnormalization.
@ProcessorStepRegistry.register(name="normalizer_processor")
class NormalizerProcessor:
"""Normalizes observations and actions in a single processor step.
This class manages normalization statistics, their conversion to tensors, device placement,
and the application of normalization transformations. It is designed to be inherited by
concrete ProcessorStep implementations.
This processor handles normalization of both observation and action tensors
using either mean/std normalization or min/max scaling to a [-1, 1] range.
For each tensor key in the stats dictionary, the processor will:
- Use mean/std normalization if those statistics are provided: (x - mean) / std
- Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1
The processor can be configured to normalize only specific keys by setting
the normalize_keys parameter.
"""
# Features and normalisation map are mandatory to match the design of normalize.py
features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
# Pre-computed statistics coming from dataset.meta.stats for instance.
stats: dict[str, dict[str, Any]] | None = None
device: torch.device | str | None = None
# Explicit subset of keys to normalise. If ``None`` every key (except
# "action") found in ``stats`` will be normalised. Using a ``set`` makes
# membership checks O(1).
normalize_keys: set[str] | None = None
eps: float = 1e-8
normalize_observation_keys: set[str] | None = None
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@classmethod
def from_lerobot_dataset(
cls,
dataset: LeRobotDataset,
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
*,
normalize_keys: set[str] | None = None,
eps: float = 1e-8,
) -> NormalizerProcessor:
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
The features and norm_map parameters are mandatory to match the design
pattern used in normalize.py.
"""
return cls(
features=features,
norm_map=norm_map,
stats=dataset.meta.stats,
normalize_keys=normalize_keys,
eps=eps,
)
def __post_init__(self):
# Robust JSON deserialization handling (guard empty maps)
if self.features:
first_val = next(iter(self.features.values()))
if isinstance(first_val, dict):
reconstructed = {}
for key, ft_dict in self.features.items():
reconstructed[key] = PolicyFeature(
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
)
self.features = reconstructed
# Handle deserialization from JSON config
if self.features and isinstance(list(self.features.values())[0], dict):
# Features came from JSON - need to reconstruct PolicyFeature objects
reconstructed_features = {}
for key, ft_dict in self.features.items():
reconstructed_features[key] = PolicyFeature(
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
)
self.features = reconstructed_features
if self.norm_map:
# if keys are strings (JSON), rebuild enum map
if all(isinstance(k, str) for k in self.norm_map.keys()):
reconstructed = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
# norm_map came from JSON - need to reconstruct enum keys and values
reconstructed_norm_map = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed_norm_map
# Convert stats to tensors and move to the target device once during initialization.
# Convert statistics once so we avoid repeated numpy→Tensor conversions
# during runtime.
self.stats = self.stats or {}
self._tensor_stats = to_tensor(self.stats, device=self.device)
self._tensor_stats = _convert_stats_to_tensors(self.stats)
def to(self, device: torch.device | str) -> _NormalizationMixin:
"""Moves the processor's normalization stats to the specified device and returns self."""
self.device = device
self._tensor_stats = to_tensor(self.stats, device=self.device)
return self
# Ensure *normalize_keys* is a set for fast look-ups and compare by
# value later when returning the configuration.
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
self.normalize_keys = set(self.normalize_keys)
def state_dict(self) -> dict[str, Tensor]:
flat: dict[str, Tensor] = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
flat[f"{key}.{stat_name}"] = tensor.cpu() # Always save to CPU
return flat
def _normalize_obs(self, observation, normalized_info):
if observation is None:
return None
def load_state_dict(self, state: dict[str, Tensor]) -> None:
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.rsplit(".", 1)
# Load to the processor's configured device.
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
dtype=torch.float32, device=self.device
# Decide which keys should be normalised for this call.
if self.normalize_keys is not None:
keys_to_norm = self.normalize_keys
else:
# Use feature map to skip action keys.
keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}
processed = dict(observation)
for key in keys_to_norm:
if key not in processed or key not in self.features:
continue
# Check the normalization mode for this feature type
feature = self.features[key]
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
# Skip normalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
normalized_info[key] = "IDENTITY"
continue
# Skip if no stats available for this key
if key not in self._tensor_stats:
continue
orig_val = processed[key]
tensor = (
orig_val.to(dtype=torch.float32)
if isinstance(orig_val, torch.Tensor)
else torch.as_tensor(orig_val, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = (tensor - mean) / (std + self.eps)
normalized_info[key] = "MEAN_STD"
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
normalized_info[key] = "MIN_MAX"
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
return processed
def _normalize_action(self, action, normalized_info):
if action is None:
return action
# Check the normalization mode for actions
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
# Skip normalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
normalized_info["action"] = "IDENTITY"
return action
# Skip if no stats available for actions
if "action" not in self._tensor_stats:
return action
tensor = (
action.to(dtype=torch.float32)
if isinstance(action, torch.Tensor)
else torch.as_tensor(action, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
normalized_info["action"] = "MEAN_STD"
return (tensor - mean) / (std + self.eps)
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
normalized_info["action"] = "MIN_MAX"
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
# If we reach here, the required stats for the normalization mode are not available
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Track what was normalized
normalized_info = {}
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION), normalized_info)
action = self._normalize_action(transition.get(TransitionKey.ACTION), normalized_info)
# Create a new transition with normalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
# Add normalization info to complementary data
if normalized_info:
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data)
comp_data["normalized_keys"] = normalized_info
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
config = {
@@ -92,177 +236,240 @@ class _NormalizationMixin:
},
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
}
if self.normalize_observation_keys is not None:
config["normalize_observation_keys"] = sorted(self.normalize_observation_keys)
if self.normalize_keys is not None:
# Serialise as a list for YAML / JSON friendliness
config["normalize_keys"] = sorted(self.normalize_keys)
return config
def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]:
new_observation = dict(observation)
for key, feature in self.features.items():
if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys:
continue
if feature.type != FeatureType.ACTION and key in new_observation:
tensor = torch.as_tensor(new_observation[key], dtype=torch.float32)
new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse)
return new_observation
def state_dict(self) -> dict[str, Tensor]:
flat = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
flat[f"{key}.{stat_name}"] = tensor
return flat
def _normalize_action(self, action: Any, inverse: bool) -> Tensor:
tensor = torch.as_tensor(action, dtype=torch.float32)
processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse)
return processed_action
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.rsplit(".", 1)
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
def _apply_transform(
self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False
) -> Tensor:
"""Core logic to apply normalization or unnormalization."""
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
return tensor
def reset(self):
pass
if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX):
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
# Ensure input tensor is on the same device as the stats.
if self.device and tensor.device != self.device:
tensor = tensor.to(self.device)
# For Accelerate compatibility: move stats to match input tensor device
input_device = tensor.device
stats = self._tensor_stats[key]
tensor = tensor.to(dtype=torch.float32)
# Move stats to input device if needed
stats_device = next(iter(stats.values())).device
if stats_device != input_device:
stats = to_tensor({key: self._tensor_stats[key]}, device=input_device)[key]
if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
# Avoid division by zero by adding a small epsilon.
denom = std + self.eps
if inverse:
return tensor * std + mean
return (tensor - mean) / denom
if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
denom = max_val - min_val
# When min_val == max_val, substitute the denominator with a small epsilon
# to prevent division by zero. This consistently maps an input equal to
# min_val to -1, ensuring a stable transformation.
denom = torch.where(
denom == 0, torch.tensor(self.eps, device=input_device, dtype=torch.float32), denom
)
if inverse:
# Map from [-1, 1] back to [min, max]
return (tensor + 1) / 2 * denom + min_val
# Map from [min, max] to [-1, 1]
return 2 * (tensor - min_val) / denom - 1
# If necessary stats are missing, return input unchanged.
return tensor
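
Note on the min/max scaling in this file: the diff handles a zero range in two ways. One variant adds `eps` to `max - min`; the other replaces the denominator with `eps` only when it is exactly zero. The scalar sketch below compares the two. It uses plain floats instead of tensors, and both function names are hypothetical.

```python
def minmax_eps_added(x: float, lo: float, hi: float, eps: float = 1e-8) -> float:
    """Variant that always adds eps: 2 * (x - min) / (max - min + eps) - 1."""
    return 2 * (x - lo) / (hi - lo + eps) - 1


def minmax_eps_substituted(
    x: float, lo: float, hi: float, eps: float = 1e-8, inverse: bool = False
) -> float:
    """Variant that swaps in eps only when max == min."""
    denom = hi - lo
    if denom == 0:
        denom = eps
    if inverse:
        # Map from [-1, 1] back to [min, max].
        return (x + 1) / 2 * denom + lo
    # Map from [min, max] to [-1, 1].
    return 2 * (x - lo) / denom - 1
```

With substitution, `x == max` maps to exactly 1 and the inverse uses the same denominator as the forward pass. With the added `eps`, the upper end of the range lands slightly below 1.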
@dataclass
@ProcessorStepRegistry.register(name="normalizer_processor")
class NormalizerProcessor(_NormalizationMixin, ProcessorStep):
"""
A processor that applies normalization to observations and actions in a transition.
This class directly implements the normalization logic for both observation and action
components of an `EnvTransition`, using statistics (mean/std or min/max) provided at
initialization.
"""
@classmethod
def from_lerobot_dataset(
cls,
dataset: LeRobotDataset,
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
*,
normalize_observation_keys: set[str] | None = None,
eps: float = 1e-8,
device: torch.device | str | None = None,
) -> NormalizerProcessor:
return cls(
features=features,
norm_map=norm_map,
stats=dataset.meta.stats,
normalize_observation_keys=normalize_observation_keys,
eps=eps,
device=device,
)
def __call__(self, transition: EnvTransition) -> EnvTransition:
new_transition = transition.copy()
# Handle observation normalization.
observation = new_transition.get(TransitionKey.OBSERVATION)
if observation is not None:
new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(
observation, inverse=False
)
# Handle action normalization.
action = new_transition.get(TransitionKey.ACTION)
if action is not None:
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False)
return new_transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register(name="unnormalizer_processor")
class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep):
"""
A processor that applies unnormalization (the inverse of normalization) to
observations and actions in a transition.
class UnnormalizerProcessor:
"""Inverse normalisation for observations and actions.
This is typically used to transform actions from a normalized policy output back into
the original scale for execution in an environment.
Exactly mirrors :class:`NormalizerProcessor` but applies the inverse
transform.
"""
features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
stats: dict[str, dict[str, Any]] | None = None
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@classmethod
def from_lerobot_dataset(
cls,
dataset: LeRobotDataset,
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
*,
device: torch.device | str | None = None,
) -> UnnormalizerProcessor:
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device)
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)
def __post_init__(self):
# Handle deserialization from JSON config
if self.features and isinstance(list(self.features.values())[0], dict):
# Features came from JSON - need to reconstruct PolicyFeature objects
reconstructed_features = {}
for key, ft_dict in self.features.items():
reconstructed_features[key] = PolicyFeature(
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
)
self.features = reconstructed_features
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
# norm_map came from JSON - need to reconstruct enum keys and values
reconstructed_norm_map = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed_norm_map
self.stats = self.stats or {}
self._tensor_stats = _convert_stats_to_tensors(self.stats)
def _unnormalize_obs(self, observation, unnormalized_info):
if observation is None:
return None
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
processed = dict(observation)
for key in keys:
if key not in processed or key not in self.features:
continue
# Check the normalization mode for this feature type
feature = self.features[key]
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
# Skip unnormalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
unnormalized_info[key] = "IDENTITY"
continue
# Skip if no stats available for this key
if key not in self._tensor_stats:
continue
orig_val = processed[key]
tensor = (
orig_val.to(dtype=torch.float32)
if isinstance(orig_val, torch.Tensor)
else torch.as_tensor(orig_val, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = tensor * std + mean
unnormalized_info[key] = "MEAN_STD"
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
unnormalized_info[key] = "MIN_MAX"
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
return processed
def _unnormalize_action(self, action, unnormalized_info):
if action is None:
return action
# Check the normalization mode for actions
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
# Skip unnormalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
unnormalized_info["action"] = "IDENTITY"
return action
# Skip if no stats available for actions
if "action" not in self._tensor_stats:
return action
tensor = (
action.to(dtype=torch.float32)
if isinstance(action, torch.Tensor)
else torch.as_tensor(action, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
unnormalized_info["action"] = "MEAN_STD"
return tensor * std + mean
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
unnormalized_info["action"] = "MIN_MAX"
return (tensor + 1) / 2 * (max_val - min_val) + min_val
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
# If we reach here, the required stats for the normalization mode are not available
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Track what was unnormalized
unnormalized_info = {}
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION), unnormalized_info)
action = self._unnormalize_action(transition.get(TransitionKey.ACTION), unnormalized_info)
# Create a new transition with unnormalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
# Handle observation unnormalization.
observation = new_transition.get(TransitionKey.OBSERVATION)
if observation is not None:
new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(observation, inverse=True)
# Handle action unnormalization.
action = new_transition.get(TransitionKey.ACTION)
if action is not None:
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True)
# Add unnormalization info to complementary data
if unnormalized_info:
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data)
comp_data["unnormalized_keys"] = unnormalized_info
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"features": {
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
},
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
}
def state_dict(self) -> dict[str, Tensor]:
flat = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
flat[f"{key}.{stat_name}"] = tensor
return flat
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.rsplit(".", 1)
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
def reset(self):
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
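The `state_dict` / `load_state_dict` pair above flattens the nested stats into `"<key>.<stat_name>"` entries and splits them again on the last dot. A small sketch of why `rsplit(".", 1)` matters when feature keys themselves contain dots; `flatten_stats` and `unflatten_stats` are hypothetical names, and plain floats stand in for tensors:

```python
from typing import Any


def flatten_stats(nested: dict[str, dict[str, Any]]) -> dict[str, Any]:
    """Flatten {feature_key: {stat_name: value}} into {"feature_key.stat_name": value}."""
    flat: dict[str, Any] = {}
    for key, sub in nested.items():
        for stat_name, value in sub.items():
            flat[f"{key}.{stat_name}"] = value
    return flat


def unflatten_stats(flat: dict[str, Any]) -> dict[str, dict[str, Any]]:
    """Rebuild the nested layout; split on the LAST dot only."""
    nested: dict[str, dict[str, Any]] = {}
    for flat_key, value in flat.items():
        # "observation.state.mean" -> ("observation.state", "mean")
        key, stat_name = flat_key.rsplit(".", 1)
        nested.setdefault(key, {})[stat_name] = value
    return nested


stats = {
    "observation.state": {"mean": 0.0, "std": 1.0},
    "action": {"min": -1.0, "max": 1.0},
}
flat = flatten_stats(stats)
round_trip = unflatten_stats(flat)
```

Splitting from the left instead would break keys such as `observation.state`, so the round trip relies on stat names that contain no dot.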
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
"""
Replaces normalization statistics in a RobotProcessor pipeline.
This function creates a deep copy of the provided `RobotProcessor` and updates the
statistics of any `NormalizerProcessor` or `UnnormalizerProcessor` steps within it.
It's useful for adapting a trained policy to a new environment or dataset with
different data distributions.
"""
rp = deepcopy(robot_processor)
for step in rp.steps:
if isinstance(step, _NormalizationMixin):
robot_processor = deepcopy(robot_processor)
for step in robot_processor.steps:
if isinstance(step, NormalizerProcessor) or isinstance(step, UnnormalizerProcessor):
step: NormalizerProcessor | UnnormalizerProcessor
step.stats = stats
# Re-initialize tensor_stats on the correct device.
step._tensor_stats = to_tensor(stats, device=step.device)
return rp
step._tensor_stats = _convert_stats_to_tensors(stats)
return robot_processor
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
"""Rename keys in the stats dictionary according to the provided mapping.
Args:
stats: The statistics dictionary with structure {feature_key: {stat_name: value}}
rename_map: Dictionary mapping old key names to new key names
Returns:
A new stats dictionary with renamed keys
Example:
>>> stats = {"observation.state": {"mean": 0.0, "std": 1.0}, "action": {"mean": 0.5, "std": 0.5}}
>>> rename_map = {"observation.state": "observation.robot_state"}
>>> new_stats = rename_stats(stats, rename_map)
>>> # new_stats will have "observation.robot_state" instead of "observation.state"
"""
renamed_stats = {}
for old_key, sub_stats in stats.items():
# Use the new key if it exists in the rename map, otherwise keep the old key
new_key = rename_map.get(old_key, old_key)
renamed_stats[new_key] = deepcopy(sub_stats)
return renamed_stats
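A short usage sketch for `rename_stats`, restated compactly here so the snippet runs on its own. The sample statistics are made up for illustration; the point is that the returned dictionary is a deep copy, so editing it does not touch the original stats:

```python
from copy import deepcopy
from typing import Any


def rename_stats(
    stats: dict[str, dict[str, Any]], rename_map: dict[str, str]
) -> dict[str, dict[str, Any]]:
    # Compact restatement of the function above: unmapped keys keep their name.
    return {rename_map.get(old_key, old_key): deepcopy(sub) for old_key, sub in stats.items()}


stats = {
    "observation.state": {"mean": [0.0], "std": [1.0]},
    "action": {"mean": [0.5], "std": [0.5]},
}
new_stats = rename_stats(stats, {"observation.state": "observation.robot_state"})

# Mutating the renamed copy leaves the original untouched because of deepcopy.
new_stats["observation.robot_state"]["mean"][0] = 9.0
```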

View File

@@ -18,13 +18,12 @@ from __future__ import annotations
import importlib
import json
import os
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Generic, TypedDict, TypeVar, cast
from typing import Any, Protocol, TypedDict, runtime_checkable
import torch
from huggingface_hub import ModelHubMixin, hf_hub_download
@@ -33,9 +32,6 @@ from safetensors.torch import load_file, save_file
from lerobot.configs.types import PolicyFeature
# Type variable for generic processor output type
TOutput = TypeVar("TOutput")
class TransitionKey(str, Enum):
"""Keys for accessing EnvTransition dictionary components."""
@@ -136,7 +132,8 @@ class ProcessorStepRegistry:
cls._registry.clear()
class ProcessorStep(ABC):
@runtime_checkable
class ProcessorStep(Protocol):
"""Structural typing interface for a single processor step.
A step is any callable accepting a full `EnvTransition` dict and
@@ -169,34 +166,17 @@ class ProcessorStep(ABC):
- state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)}
"""
_current_transition: EnvTransition | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition: ...
@property
def transition(self) -> EnvTransition:
"""The current transition being processed by this step."""
if self._current_transition is None:
raise ValueError("Transition is not set. Make sure to call the step with a transition first.")
return self._current_transition
def get_config(self) -> dict[str, Any]: ...
@abstractmethod
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def state_dict(self) -> dict[str, torch.Tensor]: ...
def get_config(self) -> dict[str, Any]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ...
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def reset(self) -> None: ...
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
return None
def reset(self) -> None:
return None
# TODO(Steven): Consider making this abstract so it is more explicit
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
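One side of this hunk declares `ProcessorStep` as a `@runtime_checkable` `Protocol` rather than an ABC. A minimal sketch of what that buys, using a trimmed-down hypothetical protocol (`StepLike`) instead of the real one: classes match structurally, without inheriting, and `isinstance` only checks that the named members exist, not their signatures.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class StepLike(Protocol):
    """Hypothetical, trimmed-down stand-in for the ProcessorStep protocol."""

    def __call__(self, transition: dict[str, Any]) -> dict[str, Any]: ...

    def reset(self) -> None: ...


class AddOne:
    """Satisfies StepLike without inheriting from it."""

    def __call__(self, transition: dict[str, Any]) -> dict[str, Any]:
        return {**transition, "count": transition.get("count", 0) + 1}

    def reset(self) -> None:
        pass


class NotAStep:
    """Has __call__ but no reset, so it does not match the protocol."""

    def __call__(self, transition: dict[str, Any]) -> dict[str, Any]:
        return transition


is_step = isinstance(AddOne(), StepLike)
is_not_step = isinstance(NotAStep(), StepLike)
```

With the ABC variant, a step must subclass `ProcessorStep` explicitly; with the protocol variant, any object exposing the required methods is accepted.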
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
@@ -219,10 +199,6 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
metadata without breaking the processor.
"""
# Validate input type
if not isinstance(batch, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
# Extract observation keys
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
observation = observation_keys if observation_keys else None
@@ -286,15 +262,8 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
return batch
class ProcessorKwargs(TypedDict, total=False):
"""Keyword arguments for RobotProcessor constructor."""
to_transition: Callable[[dict[str, Any]], EnvTransition] | None
to_output: Callable[[EnvTransition], Any] | None
@dataclass
class RobotProcessor(ModelHubMixin, Generic[TOutput]):
class RobotProcessor(ModelHubMixin):
"""
Composable, debuggable post-processing processor for robot transitions.
@@ -302,43 +271,20 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts
and batch dictionaries, automatically converting between formats as needed.
The processor is generic over its output type TOutput, which provides better type safety
and clarity about what the processor returns.
Args:
steps: Ordered list of processing steps executed on every call. Defaults to empty list.
name: Human-readable identifier that is persisted inside the JSON config.
Defaults to "RobotProcessor".
to_transition: Function to convert batch dict to EnvTransition dict.
Defaults to _default_batch_to_transition.
to_output: Function to convert EnvTransition dict to the desired output format of type TOutput.
Defaults to _default_transition_to_batch (returns batch dict).
Use identity function (lambda x: x) for EnvTransition output.
to_output: Function to convert EnvTransition dict to the desired output format.
Usually it is a batch dict or EnvTransition dict.
Defaults to _default_transition_to_batch.
before_step_hooks: List of hooks called before each step. Each hook receives the step
index and transition, and can optionally return a modified transition.
after_step_hooks: List of hooks called after each step. Each hook receives the step
index and transition, and can optionally return a modified transition.
Type Safety Examples:
```python
# Default behavior - returns batch dict
processor: RobotProcessor[dict[str, Any]] = RobotProcessor(steps=[some_step1, some_step2])
result: dict[str, Any] = processor(batch_data) # Type checker knows this is a dict
# For EnvTransition output, explicitly specify identity function
transition_processor: RobotProcessor[EnvTransition] = RobotProcessor(
steps=[some_step1, some_step2],
to_output=lambda x: x, # Identity function
)
result: EnvTransition = transition_processor(batch_data) # Type checker knows this is EnvTransition
# For custom output types
processor: RobotProcessor[str] = RobotProcessor(
steps=[custom_step], to_output=lambda t: f"Processed {len(t)} keys"
)
result: str = processor(batch_data) # Type checker knows this is str
```
Hook Semantics:
- Hooks are executed sequentially in the order they were registered. There is no way to
reorder hooks after registration without creating a new pipeline.
@@ -360,13 +306,8 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
default_factory=lambda: _default_batch_to_transition, repr=False
)
to_output: Callable[[EnvTransition], TOutput] = field(
# Cast is necessary here: Working around Python type-checker limitation.
# _default_transition_to_batch returns dict[str, Any], but we need it to be TOutput
# for the generic to work. When no explicit type is given, TOutput defaults to dict[str, Any],
# making this cast safe.
default_factory=lambda: cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch),
repr=False,
to_output: Callable[[EnvTransition], dict[str, Any] | EnvTransition] = field(
default_factory=lambda: _default_transition_to_batch, repr=False
)
# Processor-level hooks for observation/monitoring
@@ -374,57 +315,98 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
def __call__(self, data: dict[str, Any]) -> TOutput:
def __call__(self, data: EnvTransition | dict[str, Any]):
"""Process data through all steps.
The method accepts a batch dictionary (like the ones returned by ReplayBuffer or
LeRobotDataset). It is first converted to EnvTransition format using to_transition,
then processed through all steps, and finally converted to the output format using to_output.
The method accepts either the classic EnvTransition dict or a batch dictionary
(like the ones returned by ReplayBuffer or LeRobotDataset). If a batch dict is supplied,
it is first converted to the internal EnvTransition format using to_transition; after all
steps are executed, the transition is converted back with to_output and the
result is returned, thereby preserving the caller's original data type.
Args:
data: A batch dictionary to process.
data: Either an EnvTransition dict or a batch dictionary to process.
Returns:
The processed data in the format specified by to_output.
The processed data in the same format as the input (EnvTransition or batch dict).
Raises:
ValueError: If the transition is not a valid EnvTransition format.
"""
# Always convert input through to_transition
transition = self.to_transition(data)
# Check if we need to convert back to batch format at the end
_, called_with_batch = self._prepare_transition(data)
# Process through all steps
for idx, processor_step in enumerate(self.steps):
# Apply before hooks
# Use step_through to get the iterator
step_iterator = self.step_through(data)
# Get initial state (before any steps)
current_transition = next(step_iterator)
# Process each step with hooks
for idx, next_transition in enumerate(step_iterator):
# Apply before hooks with current state (before step execution)
for hook in self.before_step_hooks:
hook(idx, transition)
hook(idx, current_transition)
# Execute step
transition = processor_step(transition)
# Move to next state (after step execution)
current_transition = next_transition
# Apply after hooks
# Apply after hooks with updated state
for hook in self.after_step_hooks:
hook(idx, transition)
hook(idx, current_transition)
# Always use to_output for consistent typing
return self.to_output(transition)
# Convert back to original format if needed
if called_with_batch or self.to_output is not _default_transition_to_batch:
return self.to_output(current_transition)
else:
return current_transition
def step_through(self, data: dict[str, Any]) -> Iterable[EnvTransition]:
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
"""Prepare and validate transition data for processing.
Args:
data: Either an EnvTransition dict or a batch dictionary to process.
Returns:
A tuple of (prepared_transition, called_with_batch_flag)
Raises:
ValueError: If the transition is not a valid EnvTransition format.
"""
# Check if data is already an EnvTransition or needs conversion
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
# It's a batch dict, convert it
called_with_batch = True
transition = self.to_transition(data)
else:
# It's already an EnvTransition
called_with_batch = False
transition = data
# Basic validation
if not isinstance(transition, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
return transition, called_with_batch
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]:
"""Yield the intermediate results after each processor step.
This is a low-level method that does NOT apply hooks. It simply executes each step
and yields the intermediate results. This allows users to debug the pipeline or
apply custom logic between steps if needed.
Note: This method always yields EnvTransition objects regardless of output format.
If you need the results in the output format, you'll need to convert them
Note: This method always yields EnvTransition objects regardless of input format.
If you need the results in the original input format, you'll need to convert them
using `to_output()`.
Args:
data: A batch dictionary to process.
data: Either an EnvTransition dict or a batch dictionary to process.
Yields:
The intermediate EnvTransition results after each step.
"""
# Always convert input through to_transition
transition = self.to_transition(data)
transition, _ = self._prepare_transition(data)
# Yield initial state
yield transition
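The `__call__` variant that drives hooks from `step_through` can be sketched in isolation as follows. This is a simplified illustration, assuming plain dicts as transitions and hypothetical free functions (`step_through`, `run_with_hooks`) rather than the `RobotProcessor` methods:

```python
from typing import Any, Callable, Iterator

Transition = dict[str, Any]
Step = Callable[[Transition], Transition]
Hook = Callable[[int, Transition], None]


def step_through(data: Transition, steps: list[Step]) -> Iterator[Transition]:
    """Yield the initial transition, then the result after each step (no hooks)."""
    transition = data
    yield transition
    for step in steps:
        transition = step(transition)
        yield transition


def run_with_hooks(
    data: Transition, steps: list[Step], before: list[Hook], after: list[Hook]
) -> Transition:
    iterator = step_through(data, steps)
    current = next(iterator)  # state before any step
    for idx, next_transition in enumerate(iterator):
        for hook in before:
            hook(idx, current)  # sees the pre-step state
        current = next_transition
        for hook in after:
            hook(idx, current)  # sees the post-step state
    return current


log: list[tuple[str, int, int]] = []
result = run_with_hooks(
    {"x": 1},
    steps=[lambda t: {"x": t["x"] + 1}, lambda t: {"x": t["x"] * 10}],
    before=[lambda i, t: log.append(("before", i, t["x"]))],
    after=[lambda i, t: log.append(("after", i, t["x"]))],
)
```

In this sketch the generator has already executed step `idx` by the time the before hook fires; the hook still receives the pre-step transition, but it cannot prevent the step from running.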
@@ -526,10 +508,8 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
revision: str | None = None,
config_filename: str | None = None,
overrides: dict[str, Any] | None = None,
to_transition: Callable[[dict[str, Any]], EnvTransition] | None = None,
to_output: Callable[[EnvTransition], TOutput] | None = None,
**kwargs,
) -> RobotProcessor[TOutput]:
) -> RobotProcessor:
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
Args:
@@ -543,14 +523,9 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
(for registered steps). Values are dictionaries containing parameter overrides
that will be merged with the saved configuration. This is useful for providing
non-serializable objects like environment instances.
to_transition: Function to convert batch dict to EnvTransition dict.
Defaults to _default_batch_to_transition.
to_output: Function to convert EnvTransition dict to the desired output format of type TOutput.
Defaults to _default_transition_to_batch (returns batch dict).
Use identity function (lambda x: x) for EnvTransition output.
Returns:
A RobotProcessor[TOutput] instance loaded from the saved configuration.
A RobotProcessor instance loaded from the saved configuration.
Raises:
ImportError: If a processor step class cannot be loaded or imported.
@@ -764,34 +739,19 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
f"Make sure override keys match exact step class names or registry names."
)
return cls(
steps=steps,
name=loaded_config.get("name", "RobotProcessor"),
to_transition=to_transition or _default_batch_to_transition,
# Cast is necessary here: Same type-checker limitation as above.
# When to_output is None, we use the default which returns dict[str, Any].
# The cast ensures type consistency with the generic TOutput parameter.
to_output=to_output or cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch),
)
return cls(steps, loaded_config.get("name", "RobotProcessor"))
def __len__(self) -> int:
"""Return the number of steps in the processor."""
return len(self.steps)
def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor[TOutput]:
def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor:
"""Indexing helper exposing underlying steps.
* ``int`` returns the idx-th ProcessorStep.
* ``slice`` returns a new RobotProcessor with the sliced steps.
"""
if isinstance(idx, slice):
return RobotProcessor(
steps=self.steps[idx],
name=self.name,
to_transition=self.to_transition,
to_output=self.to_output,
before_step_hooks=self.before_step_hooks.copy(),
after_step_hooks=self.after_step_hooks.copy(),
)
return RobotProcessor(self.steps[idx], self.name)
return self.steps[idx]
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
@@ -860,7 +820,6 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
def __post_init__(self):
for i, step in enumerate(self.steps):
if not callable(step):
# TODO(steven): This should instead check isinstance(step, ProcessorStep); tests need to be updated
raise TypeError(
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
)
@@ -878,7 +837,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
return features
class ObservationProcessor(ProcessorStep, ABC):
class ObservationProcessor:
"""Base class for processors that modify only the observation component of a transition.
Subclasses should override the `observation` method to implement custom observation processing.
@@ -899,8 +858,7 @@ class ObservationProcessor(ProcessorStep, ABC):
manipulation, focusing only on the specific observation processing logic.
"""
@abstractmethod
def observation(self, observation) -> dict[str, Any]:
def observation(self, observation):
"""Process the observation component.
Args:
@@ -909,22 +867,36 @@ class ObservationProcessor(ProcessorStep, ABC):
Returns:
The processed observation
"""
...
return observation
def __call__(self, transition: EnvTransition) -> EnvTransition:
self._current_transition = transition.copy()
new_transition = self._current_transition
observation = new_transition.get(TransitionKey.OBSERVATION)
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return new_transition
return transition
processed_observation = self.observation(observation)
# Create a new transition dict with the processed observation
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_observation
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
class ActionProcessor(ProcessorStep, ABC):
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class ActionProcessor:
"""Base class for processors that modify only the action component of a transition.
Subclasses should override the `action` method to implement custom action processing.
@@ -946,8 +918,7 @@ class ActionProcessor(ProcessorStep, ABC):
manipulation, focusing only on the specific action processing logic.
"""
@abstractmethod
def action(self, action) -> Any | torch.Tensor:
def action(self, action):
"""Process the action component.
Args:
@@ -956,22 +927,36 @@ class ActionProcessor(ProcessorStep, ABC):
Returns:
The processed action
"""
...
return action
def __call__(self, transition: EnvTransition) -> EnvTransition:
self._current_transition = transition.copy()
new_transition = self._current_transition
action = new_transition.get(TransitionKey.ACTION)
action = transition.get(TransitionKey.ACTION)
if action is None:
return new_transition
return transition
processed_action = self.action(action)
# Create a new transition dict with the processed action
new_transition = transition.copy()
new_transition[TransitionKey.ACTION] = processed_action
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
class RewardProcessor(ProcessorStep, ABC):
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class RewardProcessor:
"""Base class for processors that modify only the reward component of a transition.
Subclasses should override the `reward` method to implement custom reward processing.
@@ -992,8 +977,7 @@ class RewardProcessor(ProcessorStep, ABC):
manipulation, focusing only on the specific reward processing logic.
"""
@abstractmethod
def reward(self, reward) -> float | torch.Tensor:
def reward(self, reward):
"""Process the reward component.
Args:
@@ -1002,22 +986,36 @@ class RewardProcessor(ProcessorStep, ABC):
Returns:
The processed reward
"""
...
return reward
def __call__(self, transition: EnvTransition) -> EnvTransition:
self._current_transition = transition.copy()
new_transition = self._current_transition
reward = new_transition.get(TransitionKey.REWARD)
reward = transition.get(TransitionKey.REWARD)
if reward is None:
return new_transition
return transition
processed_reward = self.reward(reward)
# Create a new transition dict with the processed reward
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = processed_reward
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
class DoneProcessor(ProcessorStep, ABC):
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class DoneProcessor:
"""Base class for processors that modify only the done flag of a transition.
Subclasses should override the `done` method to implement custom done flag processing.
@@ -1043,8 +1041,7 @@ class DoneProcessor(ProcessorStep, ABC):
manipulation, focusing only on the specific done flag processing logic.
"""
@abstractmethod
def done(self, done) -> bool | torch.Tensor:
def done(self, done):
"""Process the done flag.
Args:
@@ -1053,22 +1050,36 @@ class DoneProcessor(ProcessorStep, ABC):
Returns:
The processed done flag
"""
...
return done
def __call__(self, transition: EnvTransition) -> EnvTransition:
self._current_transition = transition.copy()
new_transition = self._current_transition
done = new_transition.get(TransitionKey.DONE)
done = transition.get(TransitionKey.DONE)
if done is None:
return new_transition
return transition
processed_done = self.done(done)
# Create a new transition dict with the processed done flag
new_transition = transition.copy()
new_transition[TransitionKey.DONE] = processed_done
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
class TruncatedProcessor(ProcessorStep, ABC):
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class TruncatedProcessor:
"""Base class for processors that modify only the truncated flag of a transition.
Subclasses should override the `truncated` method to implement custom truncated flag processing.
@@ -1090,8 +1101,7 @@ class TruncatedProcessor(ProcessorStep, ABC):
manipulation, focusing only on the specific truncated flag processing logic.
"""
@abstractmethod
def truncated(self, truncated) -> bool | torch.Tensor:
def truncated(self, truncated):
"""Process the truncated flag.
Args:
@@ -1100,22 +1110,36 @@ class TruncatedProcessor(ProcessorStep, ABC):
Returns:
The processed truncated flag
"""
...
return truncated
def __call__(self, transition: EnvTransition) -> EnvTransition:
self._current_transition = transition.copy()
new_transition = self._current_transition
truncated = new_transition.get(TransitionKey.TRUNCATED)
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is None:
return new_transition
return transition
processed_truncated = self.truncated(truncated)
# Create a new transition dict with the processed truncated flag
new_transition = transition.copy()
new_transition[TransitionKey.TRUNCATED] = processed_truncated
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
class InfoProcessor(ProcessorStep, ABC):
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class InfoProcessor:
"""Base class for processors that modify only the info dictionary of a transition.
Subclasses should override the `info` method to implement custom info processing.
@@ -1142,8 +1166,7 @@ class InfoProcessor(ProcessorStep, ABC):
manipulation, focusing only on the specific info dictionary processing logic.
"""
@abstractmethod
def info(self, info) -> dict[str, Any]:
def info(self, info):
"""Process the info dictionary.
Args:
@@ -1152,22 +1175,36 @@ class InfoProcessor(ProcessorStep, ABC):
Returns:
The processed info dictionary
"""
...
return info
def __call__(self, transition: EnvTransition) -> EnvTransition:
self._current_transition = transition.copy()
new_transition = self._current_transition
info = new_transition.get(TransitionKey.INFO)
info = transition.get(TransitionKey.INFO)
if info is None:
return new_transition
return transition
processed_info = self.info(info)
# Create a new transition dict with the processed info
new_transition = transition.copy()
new_transition[TransitionKey.INFO] = processed_info
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
class ComplementaryDataProcessor(ProcessorStep, ABC):
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class ComplementaryDataProcessor:
"""Base class for processors that modify only the complementary data of a transition.
Subclasses should override the `complementary_data` method to implement custom complementary data processing.
@@ -1175,8 +1212,7 @@ class ComplementaryDataProcessor(ProcessorStep, ABC):
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
"""
@abstractmethod
def complementary_data(self, complementary_data) -> dict[str, Any]:
def complementary_data(self, complementary_data):
"""Process the complementary data.
Args:
@@ -1185,23 +1221,52 @@ class ComplementaryDataProcessor(ProcessorStep, ABC):
Returns:
The processed complementary data
"""
...
return complementary_data
def __call__(self, transition: EnvTransition) -> EnvTransition:
self._current_transition = transition.copy()
new_transition = self._current_transition
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA)
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return new_transition
return transition
processed_complementary_data = self.complementary_data(complementary_data)
# Create a new transition dict with the processed complementary data
new_transition = transition.copy()
new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
class IdentityProcessor(ProcessorStep):
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class IdentityProcessor:
"""Identity processor that does nothing."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def get_config(self) -> dict[str, Any]:
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
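The single-key base classes in this file (`InfoProcessor`, `ComplementaryDataProcessor`, and the others) all share one pattern: read one entry of the transition, pass it to an overridable hook, and write the result into a shallow copy. The sketch below mirrors the plain-class side of this diff. `TransitionKey` here is a hypothetical two-member stand-in for the real enum, and `AddStepCount` is an invented subclass used only for illustration.

```python
from enum import Enum
from typing import Any


class TransitionKey(str, Enum):
    # Hypothetical stand-in for the real TransitionKey; only two keys are modeled.
    OBSERVATION = "observation"
    INFO = "info"


class InfoProcessor:
    """Subclasses override `info`; `__call__` handles the transition dict."""

    def info(self, info: dict[str, Any]) -> dict[str, Any]:
        # Default hook: pass the info dictionary through unchanged.
        return info

    def __call__(self, transition: dict) -> dict:
        info = transition.get(TransitionKey.INFO)
        if info is None:
            # Nothing to process: the input transition is returned as-is.
            return transition
        processed_info = self.info(info)
        # Shallow copy, so the caller's transition dict is not modified.
        new_transition = transition.copy()
        new_transition[TransitionKey.INFO] = processed_info
        return new_transition


class AddStepCount(InfoProcessor):
    """Illustrative subclass (not part of the diff): increments a counter in info."""

    def info(self, info):
        return {**info, "steps_seen": info.get("steps_seen", 0) + 1}


original = {TransitionKey.OBSERVATION: {"x": 1.0}, TransitionKey.INFO: {"steps_seen": 2}}
result = AddStepCount()(original)
```

Because the copy is shallow, the hook should return a new dictionary rather than mutate its argument, otherwise the caller's `info` would change as well.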

View File

@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any
@@ -50,14 +49,3 @@ class RenameProcessor(ObservationProcessor):
- Keys not in `rename_map` remain unchanged.
"""
return {self.rename_map.get(k, k): v for k, v in features.items()}
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
"""Rename keys in the stats dictionary according to rename_map (defensive copy)."""
if not stats:
return {}
renamed: dict[str, dict[str, Any]] = {}
for old_key, sub_stats in stats.items():
new_key = rename_map.get(old_key, old_key)
renamed[new_key] = deepcopy(sub_stats) if sub_stats is not None else {}
return renamed
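A short usage sketch for `rename_stats`, with the function body reproduced from the hunk above so the example is self-contained. The feature names and statistics are made up for illustration.

```python
from copy import deepcopy
from typing import Any


def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
    """Rename keys in the stats dictionary according to rename_map (defensive copy)."""
    if not stats:
        return {}
    renamed: dict[str, dict[str, Any]] = {}
    for old_key, sub_stats in stats.items():
        new_key = rename_map.get(old_key, old_key)
        renamed[new_key] = deepcopy(sub_stats) if sub_stats is not None else {}
    return renamed


# Illustrative inputs (hypothetical feature names and values).
stats = {
    "observation.images.front": {"mean": [0.5], "std": [0.2]},
    "action": {"mean": [0.0]},
}
rename_map = {"observation.images.front": "observation.images.cam"}

renamed = rename_stats(stats, rename_map)

# Mutating the result does not touch the input, because sub-dicts are deep-copied.
renamed["observation.images.cam"]["mean"].append(9.9)
```

Keys missing from `rename_map` keep their name, and a `None` entry is replaced by an empty dictionary.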

View File

@@ -10,13 +10,8 @@ from typing import TYPE_CHECKING, Any
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from lerobot.processor.pipeline import (
EnvTransition,
ObservationProcessor,
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.constants import OBS_LANGUAGE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
@@ -27,7 +22,7 @@ else:
@dataclass
@ProcessorStepRegistry.register(name="tokenizer_processor")
class TokenizerProcessor(ObservationProcessor):
class TokenizerProcessor:
"""Tokenizes text tasks in complementary data using a huggingface tokenizer.
This processor handles tokenization of task strings found in the complementary_data
@@ -123,7 +118,7 @@ class TokenizerProcessor(ObservationProcessor):
return None
def observation(self, observation):
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Process the transition by tokenizing the task text.
Args:
@@ -135,15 +130,15 @@ class TokenizerProcessor(ObservationProcessor):
Raises:
ValueError: If tokenizer initialization failed.
"""
task = self.get_task(self.transition)
task = self.get_task(transition)
if task is None:
return observation
return transition
# Tokenize the task (creates CPU tensors)
tokenized_prompt = self._tokenize_text(task)
# Detect device from existing tensors in the transition
target_device = self._detect_device(self.transition)
target_device = self._detect_device(transition)
# Move tokenized tensors to match the device of other data
if target_device is not None:
@@ -153,13 +148,20 @@ class TokenizerProcessor(ObservationProcessor):
}
# Get or create observation dict
new_observation = dict(observation)
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
observation = {}
else:
observation = dict(observation) # Make a copy
# Add tokenized data to observation
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to(
dtype=torch.bool
)
return new_observation
transition[TransitionKey.OBSERVATION.value] = observation # type: ignore[misc]
return transition
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
"""Detect device from existing tensors in the transition.
@@ -185,6 +187,19 @@ class TokenizerProcessor(ObservationProcessor):
if isinstance(action, torch.Tensor):
return action.device
# Check other tensor fields
for key in [TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED]:
value = transition.get(key)
if isinstance(value, torch.Tensor):
return value.device
# Check complementary data for tensors
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data:
for value in complementary_data.values():
if isinstance(value, torch.Tensor):
return value.device
return None # No tensors found, keep on CPU
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
@@ -220,12 +235,23 @@ class TokenizerProcessor(ObservationProcessor):
}
# Only include tokenizer_name if it was used (not when tokenizer object was provided)
# TODO(steven): Consider saving the name of the _tokenizer if it was loaded
if self.tokenizer_name is not None and self.tokenizer is None:
if self.tokenizer_name is not None:
config["tokenizer_name"] = self.tokenizer_name
return config
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Add tokenized task features to the feature contract.
@@ -237,13 +263,13 @@ class TokenizerProcessor(ObservationProcessor):
"""
# Add features for tokenized output if they don't exist
# Standard tokenizer output includes tokens and attention_mask
tokens_key = f"{OBS_LANGUAGE}.tokens"
attention_mask_key = f"{OBS_LANGUAGE}.attention_mask"
if OBS_LANGUAGE_TOKENS not in features:
features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
if tokens_key not in features:
features[tokens_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
if OBS_LANGUAGE_ATTENTION_MASK not in features:
features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
if attention_mask_key not in features:
features[attention_mask_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
return features

View File

@@ -74,7 +74,7 @@ from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.factory import make_policy, make_processor
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import RobotProcessor
from lerobot.processor.converters import (
@@ -83,8 +83,8 @@ from lerobot.processor.converters import (
to_transition_robot_observation,
to_transition_teleop_action,
)
from lerobot.processor.normalize_processor import rename_stats
from lerobot.processor.pipeline import IdentityProcessor, TransitionKey
from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -346,14 +346,13 @@ def record_loop(
else:
logging.info(
"No policy or teleoperator provided, skipping action generation. "
"This is likely to happen when resetting the environment without a teleop device. "
"The robot won't be at its rest position at the start of the next episode."
"This is likely to happen during environment reset."
)
continue
# Still continue to next loop to respect timing
# Applies a pipeline to the action, default is IdentityProcessor
# IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action()
if policy is not None and policy_transition is not None:
if policy_transition is not None:
robot_action_to_send = robot_action_processor(policy_transition)
else:
robot_action_to_send = robot_action_processor(teleop_transition)
@@ -435,7 +434,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
preprocessor = None
postprocessor = None
if cfg.policy is not None:
preprocessor, postprocessor = make_pre_post_processors(
preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
@@ -511,9 +510,5 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
return dataset
def main():
record()
if __name__ == "__main__":
main()
record()

View File

@@ -45,11 +45,9 @@ from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat
from lerobot.configs import parser
import draccus
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor import RobotProcessor
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
from lerobot.processor.pipeline import IdentityProcessor
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -85,25 +83,13 @@ class ReplayConfig:
dataset: DatasetReplayConfig
# Use vocal synthesis to read events.
play_sounds: bool = True
# Optional processor for actions before sending to robot
robot_action_processor: RobotProcessor | None = None
@parser.wrap()
@draccus.wrap()
def replay(cfg: ReplayConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
# Initialize robot action processor with default if not provided
robot_action_processor = cfg.robot_action_processor or RobotProcessor(
steps=[IdentityProcessor()],
to_transition=to_transition_teleop_action,
to_output=to_output_robot_action, # type: ignore[arg-type]
)
# Reset processor
robot_action_processor.reset()
robot = make_robot_from_config(cfg.robot)
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
actions = dataset.hf_dataset.select_columns("action")
@@ -118,10 +104,7 @@ def replay(cfg: ReplayConfig):
for i, name in enumerate(dataset.features["action"]["names"]):
action[name] = action_array[i]
# Process action through robot action processor
processed_action = robot_action_processor(action)
robot.send_action(processed_action) # type: ignore[arg-type]
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / dataset.fps - dt_s)

View File

@@ -19,15 +19,13 @@ from dataclasses import dataclass, field
import numpy as np
from scipy.spatial.transform import Rotation
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.configs.types import PolicyFeature
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.pipeline import (
ActionProcessor,
ComplementaryDataProcessor,
EnvTransition,
ObservationProcessor,
ProcessorStep,
ProcessorStepRegistry,
TransitionKey,
)
@@ -36,7 +34,7 @@ from lerobot.robots.robot import Robot
@ProcessorStepRegistry.register("ee_reference_and_delta")
@dataclass
class EEReferenceAndDelta(ActionProcessor):
class EEReferenceAndDelta:
"""
Compute the desired end-effector pose from the target pose and the current pose.
@@ -63,9 +61,9 @@ class EEReferenceAndDelta(ActionProcessor):
_prev_enabled: bool = field(default=False, init=False, repr=False)
_command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)
def action(self, action):
new_action = action.copy()
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
def __call__(self, transition: EnvTransition) -> EnvTransition:
act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
# Get joint positions from complementary data
raw = comp.get("raw_joint_positions", None)
@@ -82,13 +80,13 @@ class EEReferenceAndDelta(ActionProcessor):
# Current pose from FK on measured joints
t_curr = self.kinematics.forward_kinematics(q)
enabled = bool(new_action.pop(f"{ACTION}.enabled", 0))
tx = float(new_action.pop(f"{ACTION}.target_x", 0.0))
ty = float(new_action.pop(f"{ACTION}.target_y", 0.0))
tz = float(new_action.pop(f"{ACTION}.target_z", 0.0))
wx = float(new_action.pop(f"{ACTION}.target_wx", 0.0))
wy = float(new_action.pop(f"{ACTION}.target_wy", 0.0))
wz = float(new_action.pop(f"{ACTION}.target_wz", 0.0))
enabled = bool(act.pop("action.enabled", 0))
tx = float(act.pop("action.target_x", 0.0))
ty = float(act.pop("action.target_y", 0.0))
tz = float(act.pop("action.target_z", 0.0))
wx = float(act.pop("action.target_wx", 0.0))
wy = float(act.pop("action.target_wy", 0.0))
wz = float(act.pop("action.target_wz", 0.0))
desired = None
@@ -124,36 +122,22 @@ class EEReferenceAndDelta(ActionProcessor):
# Write action fields
pos = desired[:3, 3]
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
new_action[f"{ACTION}.ee.x"] = float(pos[0])
new_action[f"{ACTION}.ee.y"] = float(pos[1])
new_action[f"{ACTION}.ee.z"] = float(pos[2])
new_action[f"{ACTION}.ee.wx"] = float(tw[0])
new_action[f"{ACTION}.ee.wy"] = float(tw[1])
new_action[f"{ACTION}.ee.wz"] = float(tw[2])
act.update(
{
"action.ee.x": float(pos[0]),
"action.ee.y": float(pos[1]),
"action.ee.z": float(pos[2]),
"action.ee.wx": float(tw[0]),
"action.ee.wy": float(tw[1]),
"action.ee.wz": float(tw[2]),
}
)
self._prev_enabled = enabled
return new_action
def reset(self):
self._prev_enabled = False
self.reference_ee_pose = None
self._command_when_disabled = None
transition[TransitionKey.ACTION] = act
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop(f"{ACTION}.enabled", None)
features.pop(f"{ACTION}.target_x", None)
features.pop(f"{ACTION}.target_y", None)
features.pop(f"{ACTION}.target_z", None)
features.pop(f"{ACTION}.target_wx", None)
features.pop(f"{ACTION}.target_wy", None)
features.pop(f"{ACTION}.target_wz", None)
features[f"{ACTION}.ee.x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
return features
@@ -178,15 +162,14 @@ class EEBoundsAndSafety(ActionProcessor):
max_ee_step_m: float = 0.05
max_ee_twist_step_rad: float = 0.20
_last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
def action(self, act: dict) -> dict:
x = act.get(f"{ACTION}.ee.x", None)
y = act.get(f"{ACTION}.ee.y", None)
z = act.get(f"{ACTION}.ee.z", None)
wx = act.get(f"{ACTION}.ee.wx", None)
wy = act.get(f"{ACTION}.ee.wy", None)
wz = act.get(f"{ACTION}.ee.wz", None)
def action(self, act: dict | None) -> dict:
x = act.pop("action.ee.x", None)
y = act.pop("action.ee.y", None)
z = act.pop("action.ee.z", None)
wx = act.pop("action.ee.wx", None)
wy = act.pop("action.ee.wy", None)
wz = act.pop("action.ee.wz", None)
if None in (x, y, z, wx, wy, wz):
return act
@@ -208,22 +191,35 @@ class EEBoundsAndSafety(ActionProcessor):
self._last_pos = pos
self._last_twist = twist
act[f"{ACTION}.ee.x"] = float(pos[0])
act[f"{ACTION}.ee.y"] = float(pos[1])
act[f"{ACTION}.ee.z"] = float(pos[2])
act[f"{ACTION}.ee.wx"] = float(twist[0])
act[f"{ACTION}.ee.wy"] = float(twist[1])
act[f"{ACTION}.ee.wz"] = float(twist[2])
act.update(
{
"action.ee.x": float(pos[0]),
"action.ee.y": float(pos[1]),
"action.ee.z": float(pos[2]),
"action.ee.wx": float(twist[0]),
"action.ee.wy": float(twist[1]),
"action.ee.wz": float(twist[2]),
}
)
return act
def reset(self):
self._last_pos = None
self._last_twist = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# Because this is the last step, we specify here the features of this step that should be stored in the dataset
features["action.ee.x"] = float
features["action.ee.y"] = float
features["action.ee.z"] = float
features["action.ee.wx"] = float
features["action.ee.wy"] = float
features["action.ee.wz"] = float
return features
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
@dataclass
class InverseKinematicsEEToJoints(ProcessorStep):
class InverseKinematicsEEToJoints:
"""
Compute the desired joint positions from the desired end-effector pose.
@@ -251,14 +247,26 @@ class InverseKinematicsEEToJoints(ProcessorStep):
act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
x = act.get(f"{ACTION}.ee.x", None)
y = act.get(f"{ACTION}.ee.y", None)
z = act.get(f"{ACTION}.ee.z", None)
wx = act.get(f"{ACTION}.ee.wx", None)
wy = act.get(f"{ACTION}.ee.wy", None)
wz = act.get(f"{ACTION}.ee.wz", None)
x = act.get("action.ee.x", None)
y = act.get("action.ee.y", None)
z = act.get("action.ee.z", None)
wx = act.get("action.ee.wx", None)
wy = act.get("action.ee.wy", None)
wz = act.get("action.ee.wz", None)
if None in (x, y, z, wx, wy, wz):
# Nothing to do; restore what we popped and return
act.update(
{
"action.ee.x": x,
"action.ee.y": y,
"action.ee.z": z,
"action.ee.wx": wx,
"action.ee.wy": wy,
"action.ee.wz": wz,
}
)
transition[TransitionKey.ACTION] = act
return transition
# Get joint positions from complementary data
@@ -286,20 +294,25 @@ class InverseKinematicsEEToJoints(ProcessorStep):
new_act = dict(act)
for i, name in enumerate(self.motor_names):
if name == "gripper":
new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"])
new_act["observation.state.gripper.pos"] = float(raw["gripper"])
else:
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
new_act[f"action.{name}.pos"] = float(q_target[i])
transition[TransitionKey.ACTION] = new_act
if not self.initial_guess_current_joints:
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
for name in self.motor_names:
features[f"{ACTION}.{name}.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
# We specify the dataset features of this step that we want to be stored in the dataset
features["action.ee.x"] = float
features["action.ee.y"] = float
features["action.ee.z"] = float
features["action.ee.wx"] = float
features["action.ee.wy"] = float
features["action.ee.wz"] = float
features["observation.state.gripper.pos"] = float
features["action.gripper.pos"] = float
return features
def reset(self):
@@ -308,7 +321,7 @@ class InverseKinematicsEEToJoints(ProcessorStep):
@ProcessorStepRegistry.register("gripper_velocity_to_joint")
@dataclass
class GripperVelocityToJoint(ProcessorStep):
class GripperVelocityToJoint:
"""
Convert the gripper velocity to a joint velocity.
@@ -334,12 +347,12 @@ class GripperVelocityToJoint(ProcessorStep):
act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if f"{ACTION}.gripper" not in act:
if "action.gripper" not in act:
return transition
if "gripper" not in self.motor_names:
new_act = dict(act)
new_act.pop(f"{ACTION}.gripper", None)
new_act.pop("action.gripper", None)
transition[TransitionKey.ACTION] = new_act
return transition
@@ -347,32 +360,33 @@ class GripperVelocityToJoint(ProcessorStep):
# Discrete gripper actions are in [0, 1, 2]
# 0: open, 1: close, 2: stay
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
gripper_action = act.get(f"{ACTION}.gripper", 1.0)
gripper_action = act.get("action.gripper", 1.0)
gripper_action = gripper_action - 1.0
gripper_action *= self.clip_max
act[f"{ACTION}.gripper"] = gripper_action
act["action.gripper"] = gripper_action
# Get current gripper position from complementary data
raw = comp.get("raw_joint_positions") or {}
curr_pos = float(raw.get("gripper"))
# Compute desired gripper velocity
u = float(act.get(f"{ACTION}.gripper", 0.0))
u = float(act.get("action.gripper", 0.0))
delta = u * float(self.speed_factor)
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
new_act = dict(act)
new_act[f"{ACTION}.gripper.pos"] = gripper_pos
new_act.pop(f"{ACTION}.gripper", None)
new_act["action.gripper.pos"] = gripper_pos
new_act.pop("action.gripper", None)
transition[TransitionKey.ACTION] = new_act
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
obs.update({"observation.state.gripper.pos": curr_pos})
transition[TransitionKey.OBSERVATION] = obs
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop(f"{ACTION}.gripper", None)
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
# We specify the dataset features of this step that we want to be stored in the dataset
features["observation.state.gripper.pos"] = float
features["action.gripper.pos"] = float
return features
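The arithmetic in the discrete-gripper branch of `GripperVelocityToJoint` can be checked in isolation. The sketch below is a standalone rewrite of those few lines, not the class itself. The function name `gripper_target` and all numeric values are assumptions chosen for illustration, and plain `min`/`max` stands in for `np.clip`.

```python
def gripper_target(
    discrete_action: float,
    curr_pos: float,
    speed_factor: float,
    clip_min: float,
    clip_max: float,
) -> float:
    """Map a discrete gripper action in {0, 1, 2} to a clipped target position."""
    # Shift {0, 1, 2} to {-1, 0, 1}, then scale by clip_max.
    u = (discrete_action - 1.0) * clip_max
    # Scale the command by the speed factor to get a position increment.
    delta = u * speed_factor
    # Clamp the new position to the allowed range.
    return min(max(curr_pos + delta, clip_min), clip_max)


# Hypothetical settings: position range [0, 100], speed factor 0.25.
down = gripper_target(0, curr_pos=50.0, speed_factor=0.25, clip_min=0.0, clip_max=100.0)
hold = gripper_target(1, curr_pos=50.0, speed_factor=0.25, clip_min=0.0, clip_max=100.0)
up = gripper_target(2, curr_pos=50.0, speed_factor=0.25, clip_min=0.0, clip_max=100.0)
clipped = gripper_target(2, curr_pos=90.0, speed_factor=0.25, clip_min=0.0, clip_max=100.0)
```

After the shift, the value 1 produces a zero increment, so the gripper keeps its current position, while 0 and 2 move it in opposite directions.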
@@ -396,27 +410,31 @@ class ForwardKinematicsJointsToEE(ObservationProcessor):
kinematics: RobotKinematics
motor_names: list[str]
def observation(self, obs: dict) -> dict:
if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names):
def observation(self, obs: dict | None) -> dict:
if not all(f"observation.state.{n}.pos" in obs for n in self.motor_names):
return obs
q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float)
q = np.array([obs[f"observation.state.{n}.pos"] for n in self.motor_names], dtype=float)
t = self.kinematics.forward_kinematics(q)
pos = t[:3, 3]
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
obs[f"{OBS_STATE}.ee.x"] = float(pos[0])
obs[f"{OBS_STATE}.ee.y"] = float(pos[1])
obs[f"{OBS_STATE}.ee.z"] = float(pos[2])
obs[f"{OBS_STATE}.ee.wx"] = float(tw[0])
obs[f"{OBS_STATE}.ee.wy"] = float(tw[1])
obs[f"{OBS_STATE}.ee.wz"] = float(tw[2])
obs.update(
{
"observation.state.ee.x": float(pos[0]),
"observation.state.ee.y": float(pos[1]),
"observation.state.ee.z": float(pos[2]),
"observation.state.ee.wx": float(tw[0]),
"observation.state.ee.wy": float(tw[1]),
"observation.state.ee.wz": float(tw[2]),
}
)
return obs
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We specify the dataset features of this step that we want to be stored in the dataset
for k in ["x", "y", "z", "wx", "wy", "wz"]:
features[f"{OBS_STATE}.ee.{k}"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"observation.state.ee.{k}"] = float
return features
@@ -433,14 +451,15 @@ class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor):
robot: Robot
def complementary_data(self, comp: dict | None) -> dict:
new_comp = dict(comp)
obs = (
self.robot.get_observation()
) # todo(steven): why not self.transition.get(TransitionKey.OBSERVATION)?
comp = {} if comp is None else dict(comp)
obs = self.robot.get_observation()
new_comp["raw_joint_positions"] = {
comp["raw_joint_positions"] = {
k.removesuffix(".pos"): float(v)
for k, v in obs.items()
if isinstance(k, str) and k.endswith(".pos")
}
return new_comp
return comp
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -161,11 +161,6 @@ class SO100Follower(Robot):
self.bus.write("I_Coefficient", motor, 0)
self.bus.write("D_Coefficient", motor, 32)
if motor == "gripper":
self.bus.write("Max_Torque_Limit", motor, 500) # 50% of max torque to avoid burnout
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")

View File

@@ -157,13 +157,6 @@ class SO101Follower(Robot):
self.bus.write("I_Coefficient", motor, 0)
self.bus.write("D_Coefficient", motor, 32)
if motor == "gripper":
self.bus.write(
"Max_Torque_Limit", motor, 500
) # 50% of the max torque limit to avoid burnout
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")

View File

@@ -29,6 +29,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .so100_follower import SO100Follower
return SO100Follower(config)
elif config.type == "so100_follower_end_effector":
from .so100_follower import SO100FollowerEndEffector
return SO100FollowerEndEffector(config)
elif config.type == "so101_follower":
from .so101_follower import SO101Follower

View File

@@ -98,6 +98,7 @@ from lerobot.utils.utils import (
ACTOR_SHUTDOWN_TIMEOUT = 30
#################################################
# Main entry point #
#################################################
@@ -287,9 +288,7 @@ def act_with_policy(
logging.info("[ACTOR] Shutting down act_with_policy")
return
observation = {
k: v for k, v in transition[TransitionKey.OBSERVATION].items() if k in cfg.policy.input_features
}
observation = transition[TransitionKey.OBSERVATION]
# Time policy inference and check if it meets FPS requirement
with policy_timer:
@@ -309,16 +308,8 @@ def act_with_policy(
)
# Extract values from processed transition
next_observation = {
k: v
for k, v in new_transition[TransitionKey.OBSERVATION].items()
if k in cfg.policy.input_features
}
# Teleop action is the action that was executed in the environment
# It is either the action from the teleop device or the action from the policy
executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
next_observation = new_transition[TransitionKey.OBSERVATION]
executed_action = new_transition[TransitionKey.ACTION]
reward = new_transition[TransitionKey.REWARD]
done = new_transition.get(TransitionKey.DONE, False)
truncated = new_transition.get(TransitionKey.TRUNCATED, False)

View File

@@ -37,7 +37,6 @@ from lerobot.processor import (
InterventionActionProcessor,
JointVelocityProcessor,
MapDeltaActionToRobotAction,
MapTensorToDeltaActionDict,
MotorCurrentProcessor,
Numpy2TorchActionProcessor,
RewardClassifierProcessor,
@@ -81,11 +80,11 @@ class DatasetConfig:
"""Configuration for dataset creation and management."""
repo_id: str
dataset_root: str
task: str
root: str | None = None
num_episodes_to_record: int = 5
replay_episode: int | None = None
push_to_hub: bool = False
num_episodes: int
episode: int
push_to_hub: bool
@dataclass
@@ -402,11 +401,13 @@ def make_processors(
joint_names=motor_names,
)
env_pipeline_steps = [VanillaObservationProcessor()]
env_pipeline_steps = [
VanillaObservationProcessor(),
]
if cfg.processor.observation is not None:
if cfg.processor.observation.add_joint_velocity_to_observation:
env_pipeline_steps.append(JointVelocityProcessor(dt=1.0 / cfg.fps))
env_pipeline_steps.append(JointVelocityProcessor(dt=1.0 / cfg.fps, num_dof=len(motor_names)))
if cfg.processor.observation.add_current_to_observation:
env_pipeline_steps.append(MotorCurrentProcessor(robot=env.robot))
@@ -472,7 +473,6 @@ def make_processors(
if cfg.processor.inverse_kinematics is not None and kinematics_solver is not None:
# Add EE bounds and safety processor
inverse_kinematics_steps = [
MapTensorToDeltaActionDict(),
MapDeltaActionToRobotAction(),
EEReferenceAndDelta(
kinematics=kinematics_solver,
@@ -625,7 +625,7 @@ def control_loop(
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.env.fps,
root=cfg.dataset.root,
root=cfg.dataset.dataset_root,
use_videos=True,
image_writer_threads=4,
image_writer_processes=0,
@@ -636,7 +636,7 @@ def control_loop(
episode_step = 0
episode_start_time = time.perf_counter()
while episode_idx < cfg.dataset.num_episodes_to_record:
while episode_idx < cfg.dataset.num_episodes:
step_start_time = time.perf_counter()
# Create a neutral action (no movement)
@@ -711,12 +711,10 @@ def control_loop(
def replay_trajectory(env: gym.Env, action_processor: RobotProcessor, cfg: GymManipulatorConfig) -> None:
"""Replay recorded trajectory on robot environment."""
assert cfg.dataset.replay_episode is not None, "Replay episode must be provided for replay"
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=[cfg.dataset.replay_episode],
root=cfg.dataset.dataset_root,
episodes=[cfg.dataset.episode],
download_videos=False,
)
dataset_actions = dataset.hf_dataset.select_columns(["action"])

View File

@@ -302,6 +302,11 @@ class RobotClient:
self.logger.debug(f"Current latest action: {latest_action}")
# Get queue state before changes
old_size, old_timesteps = self._inspect_action_queue()
if not old_timesteps:
old_timesteps = [latest_action] # queue was empty
# Get queue state before changes
old_size, old_timesteps = self._inspect_action_queue()
if not old_timesteps:

View File

@@ -32,7 +32,7 @@ from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.factory import make_policy, make_processor
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.scripts.eval import eval_policy
@@ -141,7 +141,7 @@ def train(cfg: TrainPipelineConfig):
cfg=cfg.policy,
ds_meta=dataset.meta,
)
preprocessor, postprocessor = make_pre_post_processors(
preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
)

View File

@@ -0,0 +1,311 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from contextlib import nullcontext
from pprint import pformat
from typing import Any, Callable
import accelerate
import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
has_method,
init_logging,
is_launched_with_accelerate,
)
from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
def update_policy(
train_metrics: MetricsTracker,
policy: PreTrainedPolicy,
batch: Any,
optimizer: Optimizer,
grad_clip_norm: float,
grad_scaler: GradScaler,
lr_scheduler=None,
use_amp: bool = False,
lock=None,
accelerator: Callable = None,
) -> tuple[MetricsTracker, dict]:
start_time = time.perf_counter()
policy.train()
loss, output_dict = policy.forward(batch)
accelerator.backward(loss)
accelerator.unscale_gradients(optimizer=optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
optimizer.step()
optimizer.zero_grad()
# Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None:
lr_scheduler.step()
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
train_metrics.lr = optimizer.param_groups[0]["lr"]
train_metrics.update_s = time.perf_counter() - start_time
return train_metrics, output_dict
@parser.wrap()
def train(cfg: TrainPipelineConfig, accelerator: Callable):
cfg.validate()
logging.info(pformat(cfg.to_dict()))
if not accelerator.is_main_process:
# Disable logging on non-main processes.
cfg.wandb.enable = False
if cfg.wandb.enable and cfg.wandb.project:
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if cfg.seed is not None:
set_seed(cfg.seed, accelerator=accelerator)
# Check device is available
device = get_safe_torch_device(cfg.device, log=True, accelerator=accelerator)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("Creating dataset")
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None
if cfg.eval_freq > 0 and cfg.env is not None:
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
device=device,
ds_meta=dataset.meta,
)
policy.to(device)
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
if accelerator.is_main_process:
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
dataset.episode_data_index,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
else:
shuffle = True
sampler = None
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=False,
)
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
)
dl_iter = cycle(dataloader)
policy.train()
train_metrics = {
"loss": AverageMeter("loss", ":.3f"),
"grad_norm": AverageMeter("grdn", ":.3f"),
"lr": AverageMeter("lr", ":0.1e"),
"update_s": AverageMeter("updt_s", ":.3f"),
"dataloading_s": AverageMeter("data_s", ":.3f"),
}
train_tracker = MetricsTracker(
cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
train_metrics,
initial_step=step,
accelerator=accelerator,
)
if accelerator.is_main_process:
logging.info("Start offline training on a fixed dataset")
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
train_tracker.dataloading_s = time.perf_counter() - start_time
train_tracker, output_dict = update_policy(
train_tracker,
policy,
batch,
optimizer,
cfg.optimizer.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
accelerator=accelerator,
)
# Note: eval and checkpoint happen *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and accelerator.is_main_process
is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps) and accelerator.is_main_process
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 and accelerator.is_main_process
if is_log_step:
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step:
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
save_checkpoint(
checkpoint_dir,
step,
cfg,
accelerator.unwrap_model(policy),
optimizer,
lr_scheduler,
)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
accelerator.wait_for_everyone()
if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
with torch.no_grad():
eval_info = eval_policy(
env=eval_env,
policy=accelerator.unwrap_model(policy),
n_episodes=cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
"pc_success": AverageMeter("success", ":.1f"),
"eval_s": AverageMeter("eval_s", ":.3f"),
}
eval_tracker = MetricsTracker(
cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
eval_metrics,
initial_step=step,
accelerator=None,
)
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
logging.info(eval_tracker)
if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
if eval_env:
eval_env.close()
if not accelerator or accelerator.is_main_process:
logging.info("End of training")
if __name__ == "__main__":
init_logging()
# We set step_scheduler_with_optimizer False to prevent accelerate from
# adjusting the lr_scheduler steps based on the num_processes
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
train(accelerator=accelerator)
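
Flags such as `is_saving_step` in the training loop combine `or` and `and` with a main-process check. Python binds `and` more tightly than `or`, so the grouping decides whether the main-process check guards every branch or only the last one. A minimal sketch that contrasts the two groupings; the helper names are hypothetical and not part of the commit:

```python
def saving_step_ungrouped(step, save_freq, total_steps, is_main):
    # Parsed as: (step % save_freq == 0) or ((step == total_steps) and is_main)
    return step % save_freq == 0 or step == total_steps and is_main


def saving_step_grouped(step, save_freq, total_steps, is_main):
    # The main-process check now guards both conditions.
    return (step % save_freq == 0 or step == total_steps) and is_main


if __name__ == "__main__":
    # A non-main process at a regular save step.
    print(saving_step_ungrouped(100, 100, 1000, is_main=False))
    print(saving_step_grouped(100, 100, 1000, is_main=False))
```

With the ungrouped form, a non-main process still reports a saving step whenever `step` is a multiple of `save_freq`, which is usually not what a multi-process run wants.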

View File

@@ -56,18 +56,11 @@ import time
from dataclasses import asdict, dataclass
from pprint import pformat
import draccus
import rerun as rr
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.processor import RobotProcessor
from lerobot.processor.converters import (
to_output_robot_action,
to_transition_robot_observation,
to_transition_teleop_action,
)
from lerobot.processor.pipeline import IdentityProcessor
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -104,82 +97,39 @@ class TeleoperateConfig:
teleop_time_s: float | None = None
# Display all cameras on screen
display_data: bool = False
# Optional processors for data transformation
teleop_action_processor: RobotProcessor | None = None # runs after teleop
robot_action_processor: RobotProcessor | None = None # runs before robot
robot_observation_processor: RobotProcessor | None = None # runs after robot
def teleop_loop(
teleop: Teleoperator,
robot: Robot,
fps: int,
display_data: bool = False,
duration: float | None = None,
teleop_action_processor: RobotProcessor | None = None,
robot_action_processor: RobotProcessor | None = None,
robot_observation_processor: RobotProcessor | None = None,
teleop: Teleoperator, robot: Robot, fps: int, display_data: bool = False, duration: float | None = None
):
# Initialize processors with defaults if not provided
teleop_action_processor = teleop_action_processor or RobotProcessor(
steps=[IdentityProcessor()], to_transition=to_transition_teleop_action, to_output=lambda tr: tr
)
robot_action_processor = robot_action_processor or RobotProcessor(
steps=[IdentityProcessor()],
to_transition=lambda tr: tr,
to_output=to_output_robot_action, # type: ignore[arg-type]
)
robot_observation_processor = robot_observation_processor or RobotProcessor(
steps=[IdentityProcessor()], to_transition=to_transition_robot_observation, to_output=lambda tr: tr
)
# Reset processors
teleop_action_processor.reset()
robot_action_processor.reset()
robot_observation_processor.reset()
display_len = max(len(key) for key in robot.action_features)
start = time.perf_counter()
while True:
loop_start = time.perf_counter()
# Get teleop action
raw_action = teleop.get_action()
# Process teleop action through pipeline
teleop_transition = teleop_action_processor(raw_action)
# Process action for robot through pipeline
robot_action_to_send = robot_action_processor(teleop_transition)
# Send processed action to robot (robot_action_processor.to_output should return dict[str, Any])
robot.send_action(robot_action_to_send) # type: ignore[arg-type]
action = teleop.get_action()
if display_data:
# Get robot observation
obs = robot.get_observation()
# Process robot observation through pipeline
obs_transition = robot_observation_processor(obs)
log_rerun_data([obs_transition, teleop_transition])
print("\n" + "-" * (display_len + 10))
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
# Display the final robot action that was sent
for motor, value in robot_action_to_send.items():
print(f"{motor:<{display_len}} | {value:>7.2f}")
move_cursor_up(len(robot_action_to_send) + 5)
observation = robot.get_observation()
log_rerun_data(observation=observation, action=action)
robot.send_action(action)
dt_s = time.perf_counter() - loop_start
busy_wait(1 / fps - dt_s)
loop_s = time.perf_counter() - loop_start
print("\n" + "-" * (display_len + 10))
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
for motor, value in action.items():
print(f"{motor:<{display_len}} | {value:>7.2f}")
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
if duration is not None and time.perf_counter() - start >= duration:
return
move_cursor_up(len(action) + 5)
@parser.wrap()
@draccus.wrap()
def teleoperate(cfg: TeleoperateConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
@@ -193,16 +143,7 @@ def teleoperate(cfg: TeleoperateConfig):
robot.connect()
try:
teleop_loop(
teleop=teleop,
robot=robot,
fps=cfg.fps,
display_data=cfg.display_data,
duration=cfg.teleop_time_s,
teleop_action_processor=cfg.teleop_action_processor,
robot_action_processor=cfg.robot_action_processor,
robot_observation_processor=cfg.robot_observation_processor,
)
teleop_loop(teleop, robot, cfg.fps, display_data=cfg.display_data, duration=cfg.teleop_time_s)
except KeyboardInterrupt:
pass
finally:
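
The teleoperation loop above holds a fixed rate by measuring how long one iteration took and waiting out the rest of the `1 / fps` period. A minimal sketch of that budgeting, assuming a hypothetical `remaining_budget_s` helper and plain `time.sleep` standing in for the repository's `busy_wait`:

```python
import time


def remaining_budget_s(fps: float, elapsed_s: float) -> float:
    """Time left in the current period, never negative."""
    return max(0.0, 1.0 / fps - elapsed_s)


def run_fixed_rate(step, fps: float, n_iters: int) -> None:
    """Call `step()` n_iters times, pacing iterations to roughly `fps`."""
    for _ in range(n_iters):
        loop_start = time.perf_counter()
        step()
        elapsed_s = time.perf_counter() - loop_start
        # Sleep only for what is left of the period; overruns sleep for 0 s.
        time.sleep(remaining_budget_s(fps, elapsed_s))


if __name__ == "__main__":
    ticks = []
    run_fixed_rate(lambda: ticks.append(time.perf_counter()), fps=50, n_iters=5)
    print(len(ticks))
```

Clamping the budget at zero means an iteration that overruns its period is followed immediately by the next one instead of raising on a negative sleep.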

View File

@@ -177,6 +177,16 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
"names": {"action.delta_x": 0, "action.delta_y": 1, "action.delta_z": 2},
}
def _on_press(self, key):
if hasattr(key, "char"):
key = key.char
self.event_queue.put((key, True))
def _on_release(self, key):
if hasattr(key, "char"):
key = key.char
self.event_queue.put((key, False))
def get_action(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(

View File

@@ -0,0 +1,246 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Docs:
# hebi: https://docs.hebi.us/tools.html#mobile-io
# teleop: https://github.com/SpesRobotics/teleop
import logging
import threading
import time
import hebi
import numpy as np
from scipy.spatial.transform import Rotation
from teleop import Teleop
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.teleoperator import Teleoperator
logger = logging.getLogger(__name__)
class Phone(Teleoperator):
"""
Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API).
For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs.
Press and hold **B1** to enable teleoperation. The first B1 press captures a reference
position and rotation; each time B1 is released and pressed again, the reference position is re-captured.
"""
config_class = PhoneConfig
name = "phone"
def __init__(self, config: PhoneConfig):
super().__init__(config)
self.config = config
self._group = None
self._teleop = None
self._teleop_thread = None
self._latest_pose = None
self._latest_message = None
self._enabled: bool = False
self._calib_pos: np.ndarray | None = None
self._calib_rot_inv: Rotation | None = None
@property
def is_connected(self) -> bool:
return (self.config.phone_os == PhoneOS.IOS and self._group is not None) or (
self.config.phone_os == PhoneOS.ANDROID and self._teleop is not None
)
def connect(self) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
if self.config.phone_os == PhoneOS.IOS:
logger.info("Connecting to iPhone, make sure to open the HEBI Mobile I/O app.")
lookup = hebi.Lookup()
time.sleep(2.0)
group = lookup.get_group_from_names(["HEBI"], ["mobileIO"])
if group is None:
raise RuntimeError("Mobile I/O not found — check name/family settings in the app.")
self._group = group
logger.info(f"{self} connected to HEBI group with {group.size} module(s).")
elif self.config.phone_os == PhoneOS.ANDROID:
logger.info("Starting teleop stream for Android...")
self._teleop = Teleop()
self._teleop.subscribe(self._android_callback)
self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True)
self._teleop_thread.start()
logger.info(f"{self} connected, teleop stream started.")
else:
raise ValueError(f"Invalid config phone_os: {self.config.phone_os}")
self.calibrate()
def calibrate(self) -> None:
print(
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
)
if self.config.phone_os == PhoneOS.IOS:
print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n")
else:
print("Touch and move on the WebXR page to capture this pose...\n")
pos, rot = self._wait_for_capture_trigger()
self._calib_pos = pos.copy()
self._calib_rot_inv = rot.inv()
self._enabled = False
print("Calibration done\n")
def _reapply_position_calibration(self, pos: np.ndarray) -> None:
self._calib_pos = pos.copy()
@property
def is_calibrated(self) -> bool:
return (self._calib_pos is not None) and (self._calib_rot_inv is not None)
@property
def action_features(self) -> dict[str, type]:
return {
"phone.pos": np.ndarray, # shape (3,)
"phone.rot": Rotation, # scipy.spatial.transform.Rotation
"phone.raw_inputs": dict, # analogs/buttons or webXR meta
"phone.enabled": bool,
}
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
while True:
ok, pos, rot, pose = self._read_current_pose()
if not ok:
time.sleep(0.01)
continue
if self.config.phone_os == PhoneOS.IOS:
io = getattr(pose, "io", None)
b = getattr(io, "b", None) if io is not None else None
b1 = False
if b is not None:
b1 = bool(b.get_int(1))
if b1:
return pos, rot
else:
msg = self._latest_message or {}
if bool(msg.get("move", False)):
return pos, rot
time.sleep(0.01)
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
if self.config.phone_os == PhoneOS.IOS:
fbk = self._group.get_next_feedback()
pose = fbk[0]
ar_pos = getattr(pose, "ar_position", None)
ar_quat = getattr(pose, "ar_orientation", None)
if ar_pos is None or ar_quat is None:
return False, None, None, None
quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw
rot = Rotation.from_quat(quat_xyzw)
pos = ar_pos - rot.apply(self.config.camera_offset)
return True, pos, rot, pose
else:
p = self._latest_pose
if p is None:
return False, None, None, None
rot = Rotation.from_matrix(p[:3, :3])
pos = p[:3, 3] - rot.apply(self.config.camera_offset)
pose = self._latest_pose
return True, pos, rot, pose
@property
def feedback_features(self) -> dict[str, type]:
# No haptic or other feedback implemented yet
pass
def configure(self) -> None:
# No additional configuration required for phone teleop
pass
def _android_callback(self, pose: np.ndarray, message: dict) -> None:
self._latest_pose = pose
self._latest_message = message
time.sleep(0.001) # 1ms delay to avoid race condition
def get_action(self) -> dict:
ok, raw_pos, raw_rot, pose = self._read_current_pose()
if not ok or not self.is_calibrated:
return {}
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
raw_inputs: dict[str, float | int | bool] = {}
if self.config.phone_os == PhoneOS.IOS:
io = getattr(pose, "io", None)
if io is not None:
bank_a, bank_b = io.a, io.b
if bank_a:
for ch in range(1, 9):
if bank_a.has_float(ch):
raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch))
if bank_b:
for ch in range(1, 9):
if bank_b.has_int(ch):
raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch))
elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch):
raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch))
else:
msg = self._latest_message or {}
raw_inputs["move"] = bool(msg.get("move", False))
raw_inputs["scale"] = float(msg.get("scale", 1.0))
raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False))
raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False))
if self.config.phone_os == PhoneOS.IOS:
enable = bool(raw_inputs.get("b1", 0))
else:
enable = bool(raw_inputs.get("move", False))
# On a rising edge, re-capture the position calibration from the current raw pose
if enable and not self._enabled:
self._reapply_position_calibration(raw_pos)
# Apply calibration
pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
rot_cal = self._calib_rot_inv * raw_rot
self._enabled = enable
return {
"phone.pos": pos_cal,
"phone.rot": rot_cal,
"phone.raw_inputs": raw_inputs,
"phone.enabled": self._enabled,
}
def send_feedback(self, feedback: dict[str, float]) -> None:
# We could add haptic feedback (vibrations) here, but it's not implemented yet
raise NotImplementedError
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.config.phone_os == PhoneOS.IOS:
self._group = None
else:
self._teleop = None
if self._teleop_thread and self._teleop_thread.is_alive():
self._teleop_thread.join(timeout=1.0)
self._teleop_thread = None
self._latest_pose = None
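
The enable logic in `get_action` re-captures the reference position on each rising edge of the enable signal, so motion is always measured from where the phone was when the button went down. A minimal, dependency-free sketch of that pattern; the `EdgeCalibrator` class is hypothetical and handles translation only, leaving out the rotation that the code above also applies:

```python
class EdgeCalibrator:
    """Re-zero a 3D position whenever `enable` goes from False to True."""

    def __init__(self, initial_ref):
        self._ref = tuple(initial_ref)
        self._enabled = False

    def update(self, raw_pos, enable):
        # Rising edge: take the current raw position as the new reference.
        if enable and not self._enabled:
            self._ref = tuple(raw_pos)
        self._enabled = enable
        # Position relative to the reference (translation only).
        return tuple(p - r for p, r in zip(raw_pos, self._ref))


if __name__ == "__main__":
    calib = EdgeCalibrator((0, 0, 0))
    print(calib.update((1, 2, 3), enable=True))   # button pressed: re-zeroed
    print(calib.update((2, 2, 3), enable=True))   # still held: offset from press
```

Storing the previous enable state after the comparison is what makes the re-capture fire once per press rather than on every frame while the button is held.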

View File

@@ -16,7 +16,7 @@
from dataclasses import dataclass, field
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
from lerobot.teleoperators.phone.config_phone import PhoneOS
@@ -46,9 +46,9 @@ class MapPhoneActionToRobotAction(ActionProcessor):
platform: PhoneOS
_enabled_prev: bool = field(default=False, init=False, repr=False)
def action(self, act: dict) -> dict:
def action(self, act: dict | None) -> dict:
# Pop them from the action
enabled = bool(act.pop("action.phone.enabled", 0))
enabled = act.pop("action.phone.enabled", 0)
pos = act.pop("action.phone.pos", None)
rot = act.pop("action.phone.rot", None)
inputs = act.pop("action.phone.raw_inputs", {})
@@ -69,28 +69,19 @@ class MapPhoneActionToRobotAction(ActionProcessor):
) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed
# For some actions we need to invert the axis
act["action.enabled"] = enabled
act["action.target_x"] = -pos[1] if enabled else 0.0
act["action.target_y"] = pos[0] if enabled else 0.0
act["action.target_z"] = pos[2] if enabled else 0.0
act["action.target_wx"] = rotvec[1] if enabled else 0.0
act["action.target_wy"] = rotvec[0] if enabled else 0.0
act["action.target_wz"] = -rotvec[2] if enabled else 0.0
act["action.gripper"] = gripper # Still send gripper action when disabled
act.update(
{
"action.enabled": enabled,
"action.target_x": -pos[1] if enabled else 0.0,
"action.target_y": pos[0] if enabled else 0.0,
"action.target_z": pos[2] if enabled else 0.0,
"action.target_wx": rotvec[1] if enabled else 0.0,
"action.target_wy": rotvec[0] if enabled else 0.0,
"action.target_wz": -rotvec[2] if enabled else 0.0,
"action.gripper": gripper, # Still send gripper action when disabled
}
)
return act
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop("action.phone.enabled", None)
features.pop("action.phone.pos", None)
features.pop("action.phone.rot", None)
features.pop("action.phone.raw_inputs", None)
features["action.enabled"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.target_x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.target_y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.target_z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.target_wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.target_wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.target_wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.gripper"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
return features
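
The mapping in `action` swaps and flips axes between the phone frame and the robot frame, and zeroes the pose targets while teleoperation is disabled. The same mapping as a standalone sketch; `map_phone_to_robot` is a hypothetical free function, and it assumes `pos` and `rotvec` are always supplied as plain 3-element sequences:

```python
def map_phone_to_robot(pos, rotvec, gripper, enabled):
    """Return robot targets from a calibrated phone pose."""

    def gate(value):
        # Pose targets are zeroed while teleoperation is disabled.
        return value if enabled else 0.0

    return {
        "action.enabled": enabled,
        "action.target_x": gate(-pos[1]),
        "action.target_y": gate(pos[0]),
        "action.target_z": gate(pos[2]),
        "action.target_wx": gate(rotvec[1]),
        "action.target_wy": gate(rotvec[0]),
        "action.target_wz": gate(-rotvec[2]),
        "action.gripper": gripper,  # still sent when disabled
    }


if __name__ == "__main__":
    print(map_phone_to_robot([1.0, 2.0, 3.0], [0.1, 0.2, 0.3], 0.5, enabled=True))
```

Only the pose targets are gated; the gripper command passes through in both states, matching the comment in the diff.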

View File

@@ -1,359 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Docs:
# hebi: https://docs.hebi.us/tools.html#mobile-io
# teleop: https://github.com/SpesRobotics/teleop
import logging
import threading
import time
import hebi
import numpy as np
from scipy.spatial.transform import Rotation
from teleop import Teleop
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.teleoperator import Teleoperator
logger = logging.getLogger(__name__)
class BasePhone:
_enabled: bool = False
_calib_pos: np.ndarray | None = None
_calib_rot_inv: Rotation | None = None
def _reapply_position_calibration(self, pos: np.ndarray) -> None:
self._calib_pos = pos.copy()
@property
def is_calibrated(self) -> bool:
return (self._calib_pos is not None) and (self._calib_rot_inv is not None)
@property
def action_features(self) -> dict[str, type]:
return {
"phone.pos": np.ndarray, # shape (3,)
"phone.rot": Rotation, # scipy.spatial.transform.Rotation
"phone.raw_inputs": dict, # analogs/buttons or webXR meta
"phone.enabled": bool,
}
@property
def feedback_features(self) -> dict[str, type]:
# No haptic or other feedback implemented yet
pass
def configure(self) -> None:
# No additional configuration required for phone teleop
pass
def send_feedback(self, feedback: dict[str, float]) -> None:
# We could add haptic feedback (vibrations) here, but it's not implemented yet
raise NotImplementedError
class IOSPhone(BasePhone, Teleoperator):
name = "ios_phone"
def __init__(self, config: PhoneConfig):
super().__init__(config)
self.config = config
self._group = None
@property
def is_connected(self) -> bool:
return self._group is not None
def connect(self) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
logger.info("Connecting to iPhone, make sure to open the HEBI Mobile I/O app.")
lookup = hebi.Lookup()
time.sleep(2.0)
group = lookup.get_group_from_names(["HEBI"], ["mobileIO"])
if group is None:
raise RuntimeError("Mobile I/O not found — check name/family settings in the app.")
self._group = group
logger.info(f"{self} connected to HEBI group with {group.size} module(s).")
self.calibrate()
def calibrate(self) -> None:
print(
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
)
print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n")
pos, rot = self._wait_for_capture_trigger()
self._calib_pos = pos.copy()
self._calib_rot_inv = rot.inv()
self._enabled = False
print("Calibration done\n")
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
while True:
ok, pos, rot, pose = self._read_current_pose()
if not ok:
time.sleep(0.01)
continue
io = getattr(pose, "io", None)
b = getattr(io, "b", None) if io is not None else None
b1 = False
if b is not None:
b1 = bool(b.get_int(1))
if b1:
return pos, rot
time.sleep(0.01)
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
fbk = self._group.get_next_feedback()
pose = fbk[0]
ar_pos = getattr(pose, "ar_position", None)
ar_quat = getattr(pose, "ar_orientation", None)
if ar_pos is None or ar_quat is None:
return False, None, None, None
# HEBI provides orientation in w, x, y, z format.
# Scipy's Rotation expects x, y, z, w.
quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw
rot = Rotation.from_quat(quat_xyzw)
pos = ar_pos - rot.apply(self.config.camera_offset)
return True, pos, rot, pose
def get_action(self) -> dict:
ok, raw_pos, raw_rot, pose = self._read_current_pose()
if not ok or not self.is_calibrated:
return {}
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
raw_inputs: dict[str, float | int | bool] = {}
io = getattr(pose, "io", None)
if io is not None:
bank_a, bank_b = io.a, io.b
if bank_a:
for ch in range(1, 9):
if bank_a.has_float(ch):
raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch))
if bank_b:
for ch in range(1, 9):
if bank_b.has_int(ch):
raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch))
elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch):
raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch))
enable = bool(raw_inputs.get("b1", 0))
# On a rising edge, re-capture the position calibration from the current raw pose
if enable and not self._enabled:
self._reapply_position_calibration(raw_pos)
# Apply calibration
pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
rot_cal = self._calib_rot_inv * raw_rot
self._enabled = enable
return {
"phone.pos": pos_cal,
"phone.rot": rot_cal,
"phone.raw_inputs": raw_inputs,
"phone.enabled": self._enabled,
}
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._group = None
class AndroidPhone(BasePhone, Teleoperator):
name = "android_phone"
def __init__(self, config: PhoneConfig):
super().__init__(config)
self.config = config
self._teleop = None
self._teleop_thread = None
self._latest_pose = None
self._latest_message = None
self._android_lock = threading.Lock()
@property
def is_connected(self) -> bool:
return self._teleop is not None
def connect(self) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
logger.info("Starting teleop stream for Android...")
self._teleop = Teleop()
self._teleop.subscribe(self._android_callback)
self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True)
self._teleop_thread.start()
logger.info(f"{self} connected, teleop stream started.")
self.calibrate()
def calibrate(self) -> None:
print(
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
)
print("Touch and move on the WebXR page to capture this pose...\n")
pos, rot = self._wait_for_capture_trigger()
self._calib_pos = pos.copy()
self._calib_rot_inv = rot.inv()
self._enabled = False
print("Calibration done\n")
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
while True:
with self._android_lock:
msg = self._latest_message or {}
if bool(msg.get("move", False)):
ok, pos, rot, _pose = self._read_current_pose()
if ok:
return pos, rot
time.sleep(0.01)
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
with self._android_lock:
if self._latest_pose is None:
return False, None, None, None
p = self._latest_pose.copy()
pose = self._latest_pose
rot = Rotation.from_matrix(p[:3, :3])
pos = p[:3, 3] - rot.apply(self.config.camera_offset)
return True, pos, rot, pose
def _android_callback(self, pose: np.ndarray, message: dict) -> None:
with self._android_lock:
self._latest_pose = pose
self._latest_message = message
def get_action(self) -> dict:
ok, raw_pos, raw_rot, pose = self._read_current_pose()
if not ok or not self.is_calibrated:
return {}
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
raw_inputs: dict[str, float | int | bool] = {}
msg = self._latest_message or {}
raw_inputs["move"] = bool(msg.get("move", False))
raw_inputs["scale"] = float(msg.get("scale", 1.0))
raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False))
raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False))
enable = bool(raw_inputs.get("move", False))
# On the rising edge of `enable`, re-capture the position calibration from the current raw pose
if enable and not self._enabled:
self._reapply_position_calibration(raw_pos)
# Apply calibration
pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
rot_cal = self._calib_rot_inv * raw_rot
self._enabled = enable
return {
"phone.pos": pos_cal,
"phone.rot": rot_cal,
"phone.raw_inputs": raw_inputs,
"phone.enabled": self._enabled,
}
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._teleop = None
if self._teleop_thread and self._teleop_thread.is_alive():
self._teleop_thread.join(timeout=1.0)
self._teleop_thread = None
self._latest_pose = None
class Phone(Teleoperator):
"""
Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API).
For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs.
Press and hold **B1** to enable teleoperation. The first B1 press captures a reference
position and rotation. After B1 is released, pressing it again re-captures the reference position.
"""
config_class = PhoneConfig
name = "phone"
def __init__(self, config: PhoneConfig):
super().__init__(config)
self.config = config
self._phone_impl: Teleoperator
if self.config.phone_os == PhoneOS.IOS:
self._phone_impl = IOSPhone(config)
elif self.config.phone_os == PhoneOS.ANDROID:
self._phone_impl = AndroidPhone(config)
else:
raise ValueError(f"Invalid config phone_os: {self.config.phone_os}")
@property
def is_connected(self) -> bool:
return self._phone_impl.is_connected
def connect(self) -> None:
return self._phone_impl.connect()
def calibrate(self) -> None:
return self._phone_impl.calibrate()
@property
def is_calibrated(self) -> bool:
return self._phone_impl.is_calibrated
@property
def action_features(self) -> dict[str, type]:
return self._phone_impl.action_features
@property
def feedback_features(self) -> dict[str, type]:
return self._phone_impl.feedback_features
def configure(self) -> None:
return self._phone_impl.configure()
def get_action(self) -> dict:
return self._phone_impl.get_action()
def send_feedback(self, feedback: dict[str, float]) -> None:
return self._phone_impl.send_feedback(feedback)
def disconnect(self) -> None:
return self._phone_impl.disconnect()
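
The calibration step in `AndroidPhone.get_action` above expresses the raw pose relative to the pose captured at calibration time: `pos_cal = R_calib^-1 * (raw_pos - calib_pos)` and `rot_cal = R_calib^-1 * raw_rot`. The sketch below reproduces that arithmetic with plain 3x3 rotation matrices instead of SciPy's `Rotation`. The helper names (`rot_z`, `capture_calibration`, `apply_calibration`) are hypothetical and not part of the repository; it is an illustration under those assumptions, not the project's implementation.

```python
import math


def rot_z(angle_rad):
    """Rotation matrix for a rotation of angle_rad about the z axis."""
    c, s = math.cos(angle_rad), math.sin(angle_rad)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def transpose(m):
    return [list(row) for row in zip(*m)]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def matmul(a, b):
    cols = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def capture_calibration(raw_pos, raw_rot):
    """Store the reference position and the inverse of the reference rotation.

    For a rotation matrix the inverse is simply the transpose.
    """
    return list(raw_pos), transpose(raw_rot)


def apply_calibration(raw_pos, raw_rot, calib_pos, calib_rot_inv):
    """Express a raw pose in the frame captured at calibration time."""
    delta = [p - c for p, c in zip(raw_pos, calib_pos)]
    pos_cal = matvec(calib_rot_inv, delta)
    rot_cal = matmul(calib_rot_inv, raw_rot)
    return pos_cal, rot_cal


# Calibrate with the phone at (1, 2, 3), yawed 90 degrees about z.
ref_rot = rot_z(math.pi / 2)
calib_pos, calib_rot_inv = capture_calibration([1.0, 2.0, 3.0], ref_rot)

# The same pose maps to the origin with an identity rotation.
pos_same, rot_same = apply_calibration([1.0, 2.0, 3.0], ref_rot, calib_pos, calib_rot_inv)

# Moving one unit along world +y is one unit along the calibrated +x axis.
pos_moved, _ = apply_calibration([1.0, 3.0, 3.0], ref_rot, calib_pos, calib_rot_inv)
```

Storing the inverse rotation once at calibration time, as the class above does with `_calib_rot_inv`, avoids recomputing it on every `get_action` call.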

View File

@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Callable
from lerobot.utils.utils import format_big_number
@@ -84,6 +84,7 @@ class MetricsTracker:
"samples",
"episodes",
"epochs",
"accelerator",
]
def __init__(
@@ -93,12 +94,14 @@ class MetricsTracker:
num_episodes: int,
metrics: dict[str, AverageMeter],
initial_step: int = 0,
accelerator: Callable | None = None,
):
self.__dict__.update(dict.fromkeys(self.__keys__))
self._batch_size = batch_size
self._num_frames = num_frames
self._avg_samples_per_ep = num_frames / num_episodes
self.metrics = metrics
self.accelerator = accelerator
self.steps = initial_step
# A sample is an (observation,action) pair, where observation and action
@@ -128,7 +131,7 @@ class MetricsTracker:
Updates metrics that depend on 'step' for one step.
"""
self.steps += 1
self.samples += self._batch_size
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
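
The change to `step` above scales the per-step sample count by the number of processes, since each process consumes its own batch per optimizer step. A minimal sketch of that bookkeeping follows. `SampleCounter` is a hypothetical, stripped-down stand-in for `MetricsTracker`, and a `SimpleNamespace` stands in for the accelerator object, assuming only that it exposes `num_processes`.

```python
from types import SimpleNamespace


class SampleCounter:
    """Hypothetical stand-in for the step logic of MetricsTracker."""

    def __init__(self, batch_size, num_frames, num_episodes, accelerator=None):
        self.batch_size = batch_size
        self.num_frames = num_frames
        self.avg_samples_per_ep = num_frames / num_episodes
        self.accelerator = accelerator
        self.steps = 0
        self.samples = 0
        self.episodes = 0.0
        self.epochs = 0.0

    def step(self):
        # Global samples per step = per-process batch size * number of processes.
        world_size = self.accelerator.num_processes if self.accelerator else 1
        self.steps += 1
        self.samples += self.batch_size * world_size
        self.episodes = self.samples / self.avg_samples_per_ep
        self.epochs = self.samples / self.num_frames


# Stand-in for an accelerator running 4 processes (assumption for illustration).
fake_accelerator = SimpleNamespace(num_processes=4)

single = SampleCounter(batch_size=8, num_frames=1000, num_episodes=10)
multi = SampleCounter(batch_size=8, num_frames=1000, num_episodes=10, accelerator=fake_accelerator)
for _ in range(10):
    single.step()
    multi.step()
```

With the same number of steps, the multi-process counter reports four times as many samples, episodes, and epochs, which is why the unscaled count would under-report progress in multi-GPU runs.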

View File

@@ -17,7 +17,7 @@ import random
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Any
from typing import Any, Callable, Generator
import numpy as np
import torch
@@ -164,7 +164,7 @@ def set_rng_state(random_state_dict: dict[str, Any]):
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
def set_seed(seed) -> None:
def set_seed(seed: int, accelerator: Callable | None = None) -> None:
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
@@ -172,6 +172,11 @@ def set_seed(seed) -> None:
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
if accelerator:
from accelerate.utils import set_seed as accelerate_set_seed
accelerate_set_seed(seed)
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
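
The hunk above ends at the signature of `seeded_context`, whose body is not shown in this diff. The sketch below shows how such a context manager is commonly written, assuming it saves the RNG state, seeds, and restores the state on exit. It covers only the standard-library `random` module, whereas the repository's `set_seed` also seeds NumPy and PyTorch, so treat it as an illustration rather than the project's code.

```python
import random
from collections.abc import Generator
from contextlib import contextmanager


@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
    """Seed `random` inside the block, then restore the previous RNG state."""
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        # Restore even if the block raises, so outer code is unaffected.
        random.setstate(state)


with seeded_context(1234):
    first = [random.random() for _ in range(3)]
with seeded_context(1234):
    second = [random.random() for _ in range(3)]
```

Both blocks draw the same three numbers, and the random stream outside the `with` blocks continues as if they had never run.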

View File

@@ -24,6 +24,7 @@ import time
from copy import copy, deepcopy
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable
from statistics import mean
import numpy as np
@@ -56,13 +57,15 @@ def auto_select_torch_device() -> torch.device:
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
def get_safe_torch_device(
try_device: str, log: bool = False, accelerator: Callable | None = None
) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available."""
try_device = str(try_device)
match try_device:
case "cuda":
assert torch.cuda.is_available()
device = torch.device("cuda")
device = accelerator.device if accelerator else torch.device("cuda")
case "mps":
assert torch.backends.mps.is_available()
device = torch.device("mps")
@@ -116,6 +119,7 @@ def init_logging(
display_pid: bool = False,
console_level: str = "INFO",
file_level: str = "DEBUG",
accelerator: Callable | None = None,
):
def custom_format(record: logging.LogRecord) -> str:
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -152,6 +156,11 @@ def init_logging(
file_handler.setLevel(file_level.upper())
logger.addHandler(file_handler)
if accelerator is not None and not accelerator.is_main_process:
# Disable duplicate logging on non-main processes
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
logging.getLogger().setLevel(logging.WARNING)
def format_big_number(num, precision=0):
suffixes = ["", "K", "M", "B", "T", "Q"]
@@ -165,6 +174,10 @@ def format_big_number(num, precision=0):
return num
def is_launched_with_accelerate() -> bool:
return "ACCELERATE_MIXED_PRECISION" in os.environ
def _relative_path_between(path1: Path, path2: Path) -> Path:
"""Returns path1 relative to path2."""
path1 = path1.absolute()
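
Two small patterns from the hunks above can be tried in isolation: detecting an Accelerate launch through the `ACCELERATE_MIXED_PRECISION` environment variable, and raising the log level on non-main processes so that each message is emitted once. The sketch below uses only the standard library. The `environ` parameter and the `quiet_if_not_main` helper are hypothetical additions for testability, and named loggers are used instead of the root logger to avoid side effects.

```python
import logging
import os


def is_launched_with_accelerate(environ=None):
    """Return True if the variable checked in the diff is present."""
    env = os.environ if environ is None else environ
    return "ACCELERATE_MIXED_PRECISION" in env


def quiet_if_not_main(logger, is_main_process):
    """Raise the level to WARNING on non-main processes to avoid duplicate logs."""
    if not is_main_process:
        logger.setLevel(logging.WARNING)
    return logger


# Simulate rank 0 (main) and rank 1 (worker) with separate named loggers.
main_logger = quiet_if_not_main(logging.getLogger("sketch.rank0"), is_main_process=True)
worker_logger = quiet_if_not_main(logging.getLogger("sketch.rank1"), is_main_process=False)
```

In the diff itself the level is set on the root logger, so the change applies to every logger in the non-main process rather than to a single named one.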

View File

@@ -23,7 +23,7 @@ from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors
from lerobot.policies.factory import make_policy, make_policy_config, make_processor
from lerobot.processor import TransitionKey
from lerobot.utils.random_utils import set_seed
@@ -40,7 +40,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
dataset = make_dataset(train_cfg)
dataset_stats = dataset.meta.stats
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
preprocessor, postprocessor = make_pre_post_processors(train_cfg.policy, dataset_stats=dataset_stats)
preprocessor, postprocessor = make_processor(train_cfg.policy, dataset_stats=dataset_stats)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)

View File

@@ -39,7 +39,7 @@ from lerobot.policies.factory import (
get_policy_class,
make_policy,
make_policy_config,
make_pre_post_processors,
make_processor,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.random_utils import seeded_context
@@ -151,7 +151,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
# Check that we can make the policy object.
dataset = make_dataset(train_cfg)
preprocessor, _ = make_pre_post_processors(train_cfg.policy, None)
preprocessor, _ = make_processor(train_cfg.policy, None)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
assert isinstance(policy, PreTrainedPolicy)
@@ -225,7 +225,7 @@ def test_act_backbone_lr():
assert cfg.policy.optimizer_lr_backbone == 0.001
dataset = make_dataset(cfg)
preprocessor, _ = make_pre_post_processors(cfg.policy, None)
preprocessor, _ = make_processor(cfg.policy, None)
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
assert len(optimizer.param_groups) == 2

View File

@@ -23,7 +23,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.processor_act import make_act_pre_post_processors
from lerobot.policies.act.processor_act import make_act_processor
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
@@ -78,7 +78,7 @@ def test_make_act_processor_basic():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
preprocessor, postprocessor = make_act_processor(config, stats)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -102,12 +102,7 @@ def test_act_processor_normalization():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_act_processor(config, stats)
# Create test data
observation = {OBS_STATE: torch.randn(7)}
@@ -136,12 +131,7 @@ def test_act_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_act_processor(config, stats)
# Create CPU data
observation = {OBS_STATE: torch.randn(7)}
@@ -170,12 +160,7 @@ def test_act_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_act_processor(config, stats)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -198,7 +183,7 @@ def test_act_processor_multi_gpu():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
preprocessor, postprocessor = make_act_processor(config, stats)
# Simulate data on different GPU (like in multi-GPU training)
device = torch.device("cuda:1")
@@ -218,7 +203,7 @@ def test_act_processor_without_stats():
"""Test ACT processor creation without dataset statistics."""
config = create_default_config()
preprocessor, postprocessor = make_act_pre_post_processors(config, dataset_stats=None)
preprocessor, postprocessor = make_act_processor(config, dataset_stats=None)
# Should still create processors, but normalization won't have stats
assert preprocessor is not None
@@ -238,21 +223,14 @@ def test_act_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_act_processor(config, stats)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
# Test that loaded processor works
observation = {OBS_STATE: torch.randn(7)}
@@ -271,12 +249,7 @@ def test_act_processor_device_placement_preservation():
# Test with CPU config
config.device = "cpu"
preprocessor, _ = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, _ = make_act_processor(config, stats)
# Process CPU data
observation = {OBS_STATE: torch.randn(7)}
@@ -296,21 +269,12 @@ def test_act_processor_mixed_precision():
stats = create_default_stats()
# Modify the device processor to use float16
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_act_processor(config, stats)
# Replace DeviceProcessor with one that uses float16
modified_steps = []
for step in preprocessor.steps:
for i, step in enumerate(preprocessor.steps):
if isinstance(step, DeviceProcessor):
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
# Create test data
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
@@ -330,12 +294,7 @@ def test_act_processor_batch_consistency():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_act_processor(config, stats)
# Test single sample (unbatched)
observation = {OBS_STATE: torch.randn(7)}

View File

@@ -245,7 +245,7 @@ def test_mixed_observation():
def test_integration_with_robot_processor():
"""Test ToBatchProcessor integration with RobotProcessor."""
to_batch_processor = ToBatchProcessor()
pipeline = RobotProcessor([to_batch_processor], to_transition=lambda x: x, to_output=lambda x: x)
pipeline = RobotProcessor([to_batch_processor])
# Create unbatched observation
observation = {
@@ -285,9 +285,7 @@ def test_serialization_methods():
def test_save_and_load_pretrained():
"""Test saving and loading ToBatchProcessor with RobotProcessor."""
processor = ToBatchProcessor()
pipeline = RobotProcessor(
[processor], name="BatchPipeline", to_transition=lambda x: x, to_output=lambda x: x
)
pipeline = RobotProcessor([processor], name="BatchPipeline")
with tempfile.TemporaryDirectory() as tmp_dir:
# Save pipeline
@@ -298,9 +296,7 @@ def test_save_and_load_pretrained():
assert config_path.exists()
# Load pipeline
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
assert loaded_pipeline.name == "BatchPipeline"
assert len(loaded_pipeline) == 1
@@ -327,13 +323,11 @@ def test_registry_functionality():
def test_registry_based_save_load():
"""Test saving and loading using registry name."""
processor = ToBatchProcessor()
pipeline = RobotProcessor([processor], to_transition=lambda x: x, to_output=lambda x: x)
pipeline = RobotProcessor([processor])
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
# Verify the loaded processor works
observation = {
@@ -609,6 +603,24 @@ def test_action_dtype_preservation():
assert result[TransitionKey.ACTION].shape == (1, 4)
def test_action_in_place_mutation():
"""Test that the processor mutates the transition in place for actions."""
processor = ToBatchProcessor()
action = torch.randn(4)
transition = create_transition(action=action)
# Store reference to original transition
original_transition = transition
# Process
result = processor(transition)
# Should be the same object (in-place mutation)
assert result is original_transition
assert result[TransitionKey.ACTION].shape == (1, 4)
def test_empty_action_tensor():
"""Test handling of empty action tensors."""
processor = ToBatchProcessor()
@@ -839,6 +851,27 @@ def test_task_comprehensive_string_cases():
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == task_list
assert isinstance(processed_comp_data["task"], list)
assert processed_comp_data["task"] is task_list # Should be same object (in-place)
def test_task_in_place_mutation():
"""Test that the processor mutates complementary_data in place for tasks."""
processor = ToBatchProcessor()
complementary_data = {"task": "sort_objects"}
transition = create_transition(complementary_data=complementary_data)
# Store reference to original transition and complementary_data
original_transition = transition
original_comp_data = complementary_data
# Process
result = processor(transition)
# Should be the same objects (in-place mutation)
assert result is original_transition
assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data
assert original_comp_data["task"] == ["sort_objects"]
def test_task_preserves_other_keys():
@@ -1094,49 +1127,3 @@ def test_empty_index_tensor():
# Should remain unchanged (already 1D)
assert result[TransitionKey.COMPLEMENTARY_DATA]["index"].shape == (0,)
def test_action_processing_creates_new_transition():
"""Test that the processor creates a new transition object with correctly processed action."""
processor = ToBatchProcessor()
action = torch.randn(4)
transition = create_transition(action=action)
# Store reference to original transition
original_transition = transition
# Process
result = processor(transition)
# Should be a different object (functional design, not in-place mutation)
assert result is not original_transition
# Original transition should remain unchanged
assert original_transition[TransitionKey.ACTION].shape == (4,)
# Result should have correctly processed action with batch dimension
assert result[TransitionKey.ACTION].shape == (1, 4)
assert torch.equal(result[TransitionKey.ACTION][0], action)
def test_task_processing_creates_new_transition():
"""Test that the processor creates a new transition object with correctly processed task."""
processor = ToBatchProcessor()
complementary_data = {"task": "sort_objects"}
transition = create_transition(complementary_data=complementary_data)
# Store reference to original transition and complementary_data
original_transition = transition
original_comp_data = complementary_data
# Process
result = processor(transition)
# Should be different transition object (functional design)
assert result is not original_transition
# But complementary_data is the same reference (current implementation behavior)
assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data
# The task should be processed correctly (wrapped in list)
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["sort_objects"]
# Original complementary data is also modified (current behavior)
assert original_comp_data["task"] == ["sort_objects"]

View File

@@ -97,12 +97,7 @@ def test_classifier_processor_normalization():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Create test data
observation = {
@@ -128,12 +123,7 @@ def test_classifier_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Create CPU data
observation = {
@@ -166,12 +156,7 @@ def test_classifier_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -245,22 +230,14 @@ def test_classifier_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps, to_transition=lambda x: x, to_output=lambda x: x
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
# Test that loaded processor works
observation = {
@@ -283,19 +260,13 @@ def test_classifier_processor_mixed_precision():
config.device = "cuda"
stats = create_default_stats()
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats)
# Create processor
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Replace DeviceProcessor with one that uses float16
modified_steps = []
for step in factory_preprocessor.steps:
for i, step in enumerate(preprocessor.steps):
if isinstance(step, DeviceProcessor):
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
else:
modified_steps.append(step)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
# Create test data
observation = {
@@ -319,12 +290,7 @@ def test_classifier_processor_batch_data():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Test with batched data
batch_size = 16
@@ -349,12 +315,7 @@ def test_classifier_processor_postprocessor_identity():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Create test data for postprocessor
reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions

View File

@@ -5,7 +5,6 @@ import torch
from lerobot.processor.converters import (
to_dataset_frame,
to_output_robot_action,
to_tensor,
to_transition_robot_observation,
to_transition_teleop_action,
)
@@ -13,12 +12,12 @@ from lerobot.processor.pipeline import TransitionKey
def test_to_transition_teleop_action_prefix_and_tensor_conversion():
# Scalars, arrays, and uint8 arrays are all converted to tensors
# Scalars, arrays, and "image-like" uint8 arrays are supported
img = np.zeros((8, 12, 3), dtype=np.uint8)
act = {
"ee.x": 0.5, # scalar to torch tensor
"delta": np.array([1.0, 2.0]), # ndarray to torch tensor
"raw_img": img, # uint8 HWC to torch tensor
"raw_img": img, # uint8 HWC to passthrough ndarray
}
tr = to_transition_teleop_action(act)
@@ -30,7 +29,7 @@ def test_to_transition_teleop_action_prefix_and_tensor_conversion():
assert "action.delta" in tr[TransitionKey.ACTION]
assert "action.raw_img" in tr[TransitionKey.ACTION]
# Types: all values -> torch tensor
# Types: scalars/arrays -> torch tensor; images to np.ndarray
assert isinstance(tr[TransitionKey.ACTION]["action.ee.x"], torch.Tensor)
assert tr[TransitionKey.ACTION]["action.ee.x"].item() == pytest.approx(0.5)
@@ -38,8 +37,8 @@ def test_to_transition_teleop_action_prefix_and_tensor_conversion():
assert tr[TransitionKey.ACTION]["action.delta"].shape == (2,)
assert torch.allclose(tr[TransitionKey.ACTION]["action.delta"], torch.tensor([1.0, 2.0]))
assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], torch.Tensor)
assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == torch.float32 # converted from uint8
assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], np.ndarray)
assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == np.uint8
assert tr[TransitionKey.ACTION]["action.raw_img"].shape == (8, 12, 3)
# Observation is created as empty dict by make_transition
@@ -195,185 +194,3 @@ def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
# Complementary data
assert batch["frame_is_pad"] is True
assert batch["task"] == "Pick cube"
# Tests for the unified to_tensor function
def test_to_tensor_numpy_arrays():
"""Test to_tensor with various numpy arrays."""
# Regular numpy array
arr = np.array([1.0, 2.0, 3.0])
result = to_tensor(arr)
assert isinstance(result, torch.Tensor)
assert result.dtype == torch.float32
assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))
# Different numpy dtypes should convert to float32 by default
int_arr = np.array([1, 2, 3], dtype=np.int64)
result = to_tensor(int_arr)
assert isinstance(result, torch.Tensor)
assert result.dtype == torch.float32
assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))
# uint8 arrays (previously "preserved") should now convert
uint8_arr = np.array([100, 150, 200], dtype=np.uint8)
result = to_tensor(uint8_arr)
assert isinstance(result, torch.Tensor)
assert result.dtype == torch.float32
assert torch.allclose(result, torch.tensor([100.0, 150.0, 200.0]))
def test_to_tensor_numpy_scalars():
"""Test to_tensor with numpy scalars (0-dimensional arrays)."""
# numpy float32 scalar
scalar = np.float32(3.14)
result = to_tensor(scalar)
assert isinstance(result, torch.Tensor)
assert result.ndim == 0 # Should be 0-dimensional tensor
assert result.dtype == torch.float32
assert result.item() == pytest.approx(3.14)
# numpy int32 scalar
int_scalar = np.int32(42)
result = to_tensor(int_scalar)
assert isinstance(result, torch.Tensor)
assert result.ndim == 0
assert result.dtype == torch.float32
assert result.item() == pytest.approx(42.0)
def test_to_tensor_python_scalars():
"""Test to_tensor with Python scalars."""
# Python int
result = to_tensor(42)
assert isinstance(result, torch.Tensor)
assert result.dtype == torch.float32
assert result.item() == pytest.approx(42.0)
# Python float
result = to_tensor(3.14)
assert isinstance(result, torch.Tensor)
assert result.dtype == torch.float32
assert result.item() == pytest.approx(3.14)
def test_to_tensor_sequences():
"""Test to_tensor with lists and tuples."""
# List
result = to_tensor([1, 2, 3])
assert isinstance(result, torch.Tensor)
assert result.dtype == torch.float32
assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))
# Tuple
result = to_tensor((4.5, 5.5, 6.5))
assert isinstance(result, torch.Tensor)
assert result.dtype == torch.float32
assert torch.allclose(result, torch.tensor([4.5, 5.5, 6.5]))
def test_to_tensor_existing_tensors():
"""Test to_tensor with existing PyTorch tensors."""
# Tensor with same dtype should pass through with potential device change
tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
result = to_tensor(tensor)
assert isinstance(result, torch.Tensor)
assert result.dtype == torch.float32
assert torch.allclose(result, tensor)
# Tensor with different dtype should convert
int_tensor = torch.tensor([1, 2, 3], dtype=torch.int64)
result = to_tensor(int_tensor)
assert isinstance(result, torch.Tensor)
assert result.dtype == torch.float32
assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))
def test_to_tensor_dictionaries():
"""Test to_tensor with nested dictionaries."""
# Simple dictionary
data = {"mean": [0.1, 0.2], "std": np.array([1.0, 2.0]), "count": 42}
result = to_tensor(data)
assert isinstance(result, dict)
assert isinstance(result["mean"], torch.Tensor)
assert isinstance(result["std"], torch.Tensor)
assert isinstance(result["count"], torch.Tensor)
assert torch.allclose(result["mean"], torch.tensor([0.1, 0.2]))
assert torch.allclose(result["std"], torch.tensor([1.0, 2.0]))
assert result["count"].item() == pytest.approx(42.0)
# Nested dictionary
nested = {
"action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
"observation": {"mean": np.array([0.5, 0.6]), "count": 10},
}
    result = to_tensor(nested)
    assert isinstance(result, dict)
    assert isinstance(result["action"], dict)
    assert isinstance(result["observation"], dict)
    assert isinstance(result["action"]["mean"], torch.Tensor)
    assert isinstance(result["observation"]["mean"], torch.Tensor)
    assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2]))
    assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6]))


def test_to_tensor_none_filtering():
    """Test that None values are filtered out from dictionaries."""
    data = {"valid": [1, 2, 3], "none_value": None, "nested": {"valid": [4, 5], "also_none": None}}
    result = to_tensor(data)
    assert "none_value" not in result
    assert "also_none" not in result["nested"]
    assert "valid" in result
    assert "valid" in result["nested"]
    assert torch.allclose(result["valid"], torch.tensor([1.0, 2.0, 3.0]))


def test_to_tensor_dtype_parameter():
    """Test to_tensor with different dtype parameters."""
    arr = np.array([1, 2, 3])
    # Default dtype (float32)
    result = to_tensor(arr)
    assert result.dtype == torch.float32
    # Explicit float32
    result = to_tensor(arr, dtype=torch.float32)
    assert result.dtype == torch.float32
    # Float64
    result = to_tensor(arr, dtype=torch.float64)
    assert result.dtype == torch.float64
    # Preserve original dtype
    float64_arr = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    result = to_tensor(float64_arr, dtype=None)
    assert result.dtype == torch.float64


def test_to_tensor_device_parameter():
    """Test to_tensor with device parameter."""
    arr = np.array([1.0, 2.0, 3.0])
    # CPU device (default)
    result = to_tensor(arr, device="cpu")
    assert result.device.type == "cpu"
    # CUDA device (if available)
    if torch.cuda.is_available():
        result = to_tensor(arr, device="cuda")
        assert result.device.type == "cuda"


def test_to_tensor_empty_dict():
    """Test to_tensor with empty dictionary."""
    result = to_tensor({})
    assert isinstance(result, dict)
    assert len(result) == 0


def test_to_tensor_unsupported_type():
    """Test to_tensor with unsupported types raises TypeError."""
    with pytest.raises(TypeError, match="Unsupported type for tensor conversion"):
        to_tensor("unsupported_string")
    with pytest.raises(TypeError, match="Unsupported type for tensor conversion"):
        to_tensor(object())
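
The behaviour these tests pin down (recursing into nested dictionaries, dropping `None` values, rejecting unsupported types) can be illustrated without PyTorch. The sketch below is not the `to_tensor` implementation from `lerobot.processor.converters`: the name `to_float_lists` is hypothetical, and plain lists of floats stand in for `torch.Tensor` values so that the example runs with the standard library only.

```python
from numbers import Real


def to_float_lists(value):
    """Hypothetical stand-in for a tensor converter; returns floats and lists of floats."""
    if isinstance(value, dict):
        # Recurse into mappings and drop entries whose value is None.
        return {k: to_float_lists(v) for k, v in value.items() if v is not None}
    if isinstance(value, (list, tuple)):
        # Flat sequences of numbers become lists of floats.
        return [float(x) for x in value]
    if isinstance(value, Real):
        return float(value)
    # Strings, arbitrary objects, and anything else are rejected.
    raise TypeError(f"Unsupported type for tensor conversion: {type(value).__name__}")


data = {"valid": [1, 2, 3], "none_value": None, "nested": {"valid": [4, 5], "also_none": None}}
result = to_float_lists(data)
```

Filtering `None` during the recursion, rather than in a separate pass, keeps the nested and top-level cases on the same code path, which is the property `test_to_tensor_none_filtering` checks at both levels.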


@@ -311,12 +311,7 @@ def test_integration_with_robot_processor():
device_processor = DeviceProcessor(device="cpu")
batch_processor = ToBatchProcessor()
processor = RobotProcessor(
steps=[batch_processor, device_processor],
name="test_pipeline",
to_transition=lambda x: x,
to_output=lambda x: x,
)
processor = RobotProcessor(steps=[batch_processor, device_processor], name="test_pipeline")
# Create test data
observation = {OBS_STATE: torch.randn(10)}
@@ -990,8 +985,6 @@ def test_policy_processor_integration():
DeviceProcessor(device="cuda"),
],
name="test_preprocessor",
to_transition=lambda x: x,
to_output=lambda x: x,
)
# Create output processor (postprocessor) that moves to CPU
@@ -1001,8 +994,6 @@ def test_policy_processor_integration():
UnnormalizerProcessor(features={ACTION: features[ACTION]}, norm_map=norm_map, stats=stats),
],
name="test_postprocessor",
to_transition=lambda x: x,
to_output=lambda x: x,
)
# Test data on CPU


@@ -23,7 +23,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
@@ -81,7 +81,7 @@ def test_make_diffusion_processor_basic():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
preprocessor, postprocessor = make_diffusion_processor(config, stats)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -105,12 +105,7 @@ def test_diffusion_processor_with_images():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_diffusion_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_diffusion_processor(config, stats)
# Create test data with images
observation = {
@@ -136,12 +131,7 @@ def test_diffusion_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_diffusion_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_diffusion_processor(config, stats)
# Create CPU data
observation = {
@@ -174,12 +164,7 @@ def test_diffusion_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_diffusion_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_diffusion_processor(config, stats)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -206,7 +191,7 @@ def test_diffusion_processor_multi_gpu():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
preprocessor, postprocessor = make_diffusion_processor(config, stats)
# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -230,7 +215,7 @@ def test_diffusion_processor_without_stats():
"""Test Diffusion processor creation without dataset statistics."""
config = create_default_config()
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, dataset_stats=None)
preprocessor, postprocessor = make_diffusion_processor(config, dataset_stats=None)
# Should still create processors
assert preprocessor is not None
@@ -253,22 +238,14 @@ def test_diffusion_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps, to_transition=lambda x: x, to_output=lambda x: x
)
preprocessor, postprocessor = make_diffusion_processor(config, stats)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
# Test that loaded processor works
observation = {
@@ -291,19 +268,13 @@ def test_diffusion_processor_mixed_precision():
config.device = "cuda"
stats = create_default_stats()
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats)
# Create processor
preprocessor, postprocessor = make_diffusion_processor(config, stats)
# Replace DeviceProcessor with one that uses float16
modified_steps = []
for step in factory_preprocessor.steps:
for i, step in enumerate(preprocessor.steps):
if isinstance(step, DeviceProcessor):
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
else:
modified_steps.append(step)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
# Create test data
observation = {
@@ -327,12 +298,7 @@ def test_diffusion_processor_identity_normalization():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_diffusion_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_diffusion_processor(config, stats)
# Create test data
image_value = torch.rand(3, 224, 224) * 255 # Large values
@@ -356,12 +322,7 @@ def test_diffusion_processor_batch_consistency():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_diffusion_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_diffusion_processor(config, stats)
# Test with different batch sizes
for batch_size in [1, 8, 32]:


@@ -20,11 +20,12 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.converters import to_tensor
from lerobot.processor.normalize_processor import (
NormalizerProcessor,
UnnormalizerProcessor,
_convert_stats_to_tensors,
hotswap_stats,
rename_stats,
)
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
@@ -51,7 +52,7 @@ def test_numpy_conversion():
"std": np.array([0.2, 0.2, 0.2]),
}
}
tensor_stats = to_tensor(stats)
tensor_stats = _convert_stats_to_tensors(stats)
assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor)
assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor)
@@ -66,7 +67,7 @@ def test_tensor_conversion():
"std": torch.tensor([1.0, 1.0]),
}
}
tensor_stats = to_tensor(stats)
tensor_stats = _convert_stats_to_tensors(stats)
assert tensor_stats["action"]["mean"].dtype == torch.float32
assert tensor_stats["action"]["std"].dtype == torch.float32
@@ -79,7 +80,7 @@ def test_scalar_conversion():
"std": 0.1,
}
}
tensor_stats = to_tensor(stats)
tensor_stats = _convert_stats_to_tensors(stats)
assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5))
assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1))
@@ -92,7 +93,7 @@ def test_list_conversion():
"max": [1.0, 1.0, 2.0],
}
}
tensor_stats = to_tensor(stats)
tensor_stats = _convert_stats_to_tensors(stats)
assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0]))
assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0]))
@@ -105,7 +106,7 @@ def test_unsupported_type():
}
}
with pytest.raises(TypeError, match="Unsupported type"):
to_tensor(stats)
_convert_stats_to_tensors(stats)
# Helper functions to create feature maps and norm maps
@@ -181,10 +182,7 @@ def test_selective_normalization(observation_stats):
features = _create_observation_features()
norm_map = _create_observation_norm_map()
normalizer = NormalizerProcessor(
features=features,
norm_map=norm_map,
stats=observation_stats,
normalize_observation_keys={"observation.image"},
features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"}
)
observation = {
@@ -245,7 +243,6 @@ def test_from_lerobot_dataset():
def test_state_dict_save_load(observation_normalizer):
# Save state
state_dict = observation_normalizer.state_dict()
print("State dict:", state_dict)
# Create new normalizer and load state
features = _create_observation_features()
@@ -467,10 +464,10 @@ def test_processor_from_lerobot_dataset(full_stats):
norm_map = _create_full_norm_map()
processor = NormalizerProcessor.from_lerobot_dataset(
mock_dataset, features, norm_map, normalize_observation_keys={"observation.image"}
mock_dataset, features, norm_map, normalize_keys={"observation.image"}
)
assert processor.normalize_observation_keys == {"observation.image"}
assert processor.normalize_keys == {"observation.image"}
assert "observation.image" in processor._tensor_stats
assert "action" in processor._tensor_stats
@@ -479,16 +476,12 @@ def test_get_config(full_stats):
features = _create_full_features()
norm_map = _create_full_norm_map()
processor = NormalizerProcessor(
features=features,
norm_map=norm_map,
stats=full_stats,
normalize_observation_keys={"observation.image"},
eps=1e-6,
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
)
config = processor.get_config()
expected_config = {
"normalize_observation_keys": ["observation.image"],
"normalize_keys": ["observation.image"],
"eps": 1e-6,
"features": {
"observation.image": {"type": "VISUAL", "shape": (3, 96, 96)},
@@ -506,7 +499,7 @@ def test_get_config(full_stats):
def test_integration_with_robot_processor(normalizer_processor):
"""Test integration with RobotProcessor pipeline"""
robot_processor = RobotProcessor([normalizer_processor], to_transition=lambda x: x, to_output=lambda x: x)
robot_processor = RobotProcessor([normalizer_processor])
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
@@ -587,11 +580,7 @@ def test_serialization_roundtrip(full_stats):
features = _create_full_features()
norm_map = _create_full_norm_map()
original_processor = NormalizerProcessor(
features=features,
norm_map=norm_map,
stats=full_stats,
normalize_observation_keys={"observation.image"},
eps=1e-6,
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
)
# Get config (serialization)
@@ -602,7 +591,7 @@ def test_serialization_roundtrip(full_stats):
features=config["features"],
norm_map=config["norm_map"],
stats=full_stats,
normalize_observation_keys=set(config["normalize_observation_keys"]),
normalize_keys=set(config["normalize_keys"]),
eps=config["eps"],
)
@@ -950,31 +939,31 @@ def test_identity_config_serialization():
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
# def test_unsupported_normalization_mode_error():
# """Test that unsupported normalization modes raise appropriate errors."""
# features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))}
def test_unsupported_normalization_mode_error():
"""Test that unsupported normalization modes raise appropriate errors."""
features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))}
# # Create an invalid norm_map (this would never happen in practice, but tests error handling)
# from enum import Enum
# Create an invalid norm_map (this would never happen in practice, but tests error handling)
from enum import Enum
# class InvalidMode(str, Enum):
# INVALID = "INVALID"
class InvalidMode(str, Enum):
INVALID = "INVALID"
# # We can't actually pass an invalid enum to the processor due to type checking,
# # but we can test the error by manipulating the norm_map after creation
# norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
# stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}
# We can't actually pass an invalid enum to the processor due to type checking,
# but we can test the error by manipulating the norm_map after creation
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}
# normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
# # Manually inject an invalid mode to test error handling
# normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE"
# Manually inject an invalid mode to test error handling
normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE"
# observation = {"observation.state": torch.tensor([1.0, -0.5])}
# transition = create_transition(observation=observation)
observation = {"observation.state": torch.tensor([1.0, -0.5])}
transition = create_transition(observation=observation)
# with pytest.raises(ValueError, match="Unsupported normalization mode"):
# normalizer(transition)
with pytest.raises(ValueError, match="Unsupported normalization mode"):
normalizer(transition)
def test_hotswap_stats_basic_functionality():
@@ -1017,7 +1006,7 @@ def test_hotswap_stats_basic_functionality():
assert new_processor.steps[1].stats == new_stats
# Check that tensor stats are updated correctly
expected_tensor_stats = to_tensor(new_stats)
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
for key in expected_tensor_stats:
for stat_name in expected_tensor_stats[key]:
torch.testing.assert_close(
@@ -1160,15 +1149,11 @@ def test_hotswap_stats_preserves_other_attributes():
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalize_observation_keys = {"observation.image"}
normalize_keys = {"observation.image"}
eps = 1e-6
normalizer = NormalizerProcessor(
features=features,
norm_map=norm_map,
stats=initial_stats,
normalize_observation_keys=normalize_observation_keys,
eps=eps,
features=features, norm_map=norm_map, stats=initial_stats, normalize_keys=normalize_keys, eps=eps
)
robot_processor = RobotProcessor(steps=[normalizer])
@@ -1179,7 +1164,7 @@ def test_hotswap_stats_preserves_other_attributes():
new_normalizer = new_processor.steps[0]
assert new_normalizer.features == features
assert new_normalizer.norm_map == norm_map
assert new_normalizer.normalize_observation_keys == normalize_observation_keys
assert new_normalizer.normalize_keys == normalize_keys
assert new_normalizer.eps == eps
# But stats should be updated
@@ -1223,7 +1208,7 @@ def test_hotswap_stats_multiple_normalizer_types():
assert step.stats == new_stats
# Check tensor stats conversion
expected_tensor_stats = to_tensor(new_stats)
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
for key in expected_tensor_stats:
for stat_name in expected_tensor_stats[key]:
torch.testing.assert_close(
@@ -1285,6 +1270,273 @@ def test_hotswap_stats_with_different_data_types():
torch.testing.assert_close(tensor_stats["observation.image"]["max"], torch.tensor(1.0))
def test_normalization_info_tracking():
"""Test that normalization info is tracked in complementary_data."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.STATE: NormalizationMode.MIN_MAX,
FeatureType.ACTION: NormalizationMode.IDENTITY,
}
stats = {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
"observation.state": {
"min": np.array([0.0, -1.0]),
"max": np.array([1.0, 1.0]),
},
"action": {
"mean": np.array([0.0, 0.0]),
"std": np.array([1.0, 1.0]),
},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = create_transition(observation=observation, action=action)
# Process the transition
normalized_transition = normalizer(transition)
# Check that normalization info is added
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
assert comp_data is not None
assert "normalized_keys" in comp_data
norm_info = comp_data["normalized_keys"]
assert norm_info["observation.image"] == "MEAN_STD"
assert norm_info["observation.state"] == "MIN_MAX"
assert norm_info["action"] == "IDENTITY"
def test_unnormalization_info_tracking():
"""Test that unnormalization info is tracked in complementary_data."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MIN_MAX,
}
stats = {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
"action": {
"min": np.array([-1.0, -1.0]),
"max": np.array([1.0, 1.0]),
},
}
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
action = torch.tensor([0.0, -0.5])
transition = create_transition(observation=observation, action=action)
# Process the transition
unnormalized_transition = unnormalizer(transition)
# Check that unnormalization info is added
comp_data = unnormalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
assert comp_data is not None
assert "unnormalized_keys" in comp_data
unnorm_info = comp_data["unnormalized_keys"]
assert unnorm_info["observation.image"] == "MEAN_STD"
assert unnorm_info["action"] == "MIN_MAX"
def test_normalization_info_with_missing_stats():
"""Test normalization info when stats are missing for some keys."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.STATE: NormalizationMode.MIN_MAX,
}
# Only provide stats for image, not state
stats = {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
# Process the transition
normalized_transition = normalizer(transition)
# Check that only keys with stats are in normalization info
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
assert comp_data is not None
assert "normalized_keys" in comp_data
norm_info = comp_data["normalized_keys"]
assert norm_info["observation.image"] == "MEAN_STD"
# State should not be in the normalization info since it has no stats
assert "observation.state" not in norm_info
def test_normalization_info_with_selective_keys():
"""Test normalization info with selective normalization."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.STATE: NormalizationMode.MIN_MAX,
}
stats = {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
"observation.state": {
"min": np.array([0.0, -1.0]),
"max": np.array([1.0, 1.0]),
},
}
# Only normalize image
normalizer = NormalizerProcessor(
features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.image"}
)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
# Process the transition
normalized_transition = normalizer(transition)
# Check that only selected keys are in normalization info
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
assert comp_data is not None
assert "normalized_keys" in comp_data
norm_info = comp_data["normalized_keys"]
assert norm_info["observation.image"] == "MEAN_STD"
# State should not be in the normalization info since it wasn't in normalize_keys
assert "observation.state" not in norm_info
def test_normalization_info_preserved_in_pipeline():
"""Test that normalization info is preserved when using RobotProcessor pipeline."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MIN_MAX,
}
stats = {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
"action": {
"min": np.array([-1.0, -1.0]),
"max": np.array([1.0, 1.0]),
},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
# Create pipeline
pipeline = RobotProcessor([normalizer, unnormalizer])
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
action = torch.tensor([0.5, -0.5])
transition = create_transition(observation=observation, action=action)
# Process through pipeline
result = pipeline(transition)
# Check that both normalization and unnormalization info are present
comp_data = result.get(TransitionKey.COMPLEMENTARY_DATA)
assert comp_data is not None
assert "normalized_keys" in comp_data
assert "unnormalized_keys" in comp_data
# Check normalization info
norm_info = comp_data["normalized_keys"]
assert norm_info["observation.image"] == "MEAN_STD"
assert norm_info["action"] == "MIN_MAX"
# Check unnormalization info
unnorm_info = comp_data["unnormalized_keys"]
assert unnorm_info["observation.image"] == "MEAN_STD"
assert unnorm_info["action"] == "MIN_MAX"
def test_normalization_info_empty_transition():
"""Test that no normalization info is added for empty transitions."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MIN_MAX,
}
stats = {
"observation.image": {"mean": [0.5], "std": [0.2]},
"action": {"min": [-1.0], "max": [1.0]},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
# Empty transition
transition = create_transition()
# Process the transition
normalized_transition = normalizer(transition)
# Check that no normalization info is added
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
assert comp_data is None or "normalized_keys" not in comp_data
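
Taken together, the tracking tests above suggest a simple rule for which keys end up under `normalized_keys`. The sketch below restates that rule in plain Python as read from the tests only; it is not the `NormalizerProcessor` code. The function name `track_normalized_keys` and the flattened `feature_modes` mapping (key to mode name, standing in for `features` plus `norm_map`) are hypothetical.

```python
def track_normalized_keys(observation, action, feature_modes, stats, normalize_keys=None):
    """Return {key: mode_name} for the entries a normalizer would record (assumed rules)."""
    tracked = {}
    for key in observation or {}:
        # Observation keys must be known, have stats, and pass the optional filter.
        selected = normalize_keys is None or key in normalize_keys
        if key in feature_modes and key in stats and selected:
            tracked[key] = feature_modes[key]
    # Assumption drawn from the tests: the action is not subject to the observation filter.
    if action is not None and "action" in feature_modes and "action" in stats:
        tracked["action"] = feature_modes["action"]
    return tracked


modes = {"observation.image": "MEAN_STD", "observation.state": "MIN_MAX", "action": "IDENTITY"}
full_stats = {"observation.image": {}, "observation.state": {}, "action": {}}
obs = {"observation.image": [0.7, 0.5, 0.3], "observation.state": [0.5, 0.0]}

all_keys = track_normalized_keys(obs, [1.0, -0.5], modes, full_stats)
missing = track_normalized_keys(obs, None, modes, {"observation.image": {}})
selective = track_normalized_keys(obs, None, modes, full_stats, {"observation.image"})
empty = track_normalized_keys(None, None, modes, full_stats)
```

An empty result corresponds to the empty-transition case, where the tests expect no `normalized_keys` entry to be written at all.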
def test_hotswap_stats_functional_test():
"""Test that hotswapped processor actually works functionally."""
# Create test data
@@ -1317,7 +1569,7 @@ def test_hotswap_stats_functional_test():
# Create original processor
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
original_processor = RobotProcessor(steps=[normalizer], to_transition=lambda x: x, to_output=lambda x: x)
original_processor = RobotProcessor(steps=[normalizer])
# Process with original stats
original_result = original_processor(transition)
@@ -1379,8 +1631,8 @@ def test_min_equals_max_maps_to_minus_one():
assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([-1.0]))
def test_action_normalized_despite_normalize_observation_keys():
"""Action normalization is independent of normalize_observation_keys filter for observations."""
def test_action_normalized_despite_normalize_keys():
"""Action normalization is independent of normalize_keys filter for observations."""
features = {
"observation.state": PolicyFeature(FeatureType.STATE, (1,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
@@ -1388,7 +1640,7 @@ def test_action_normalized_despite_normalize_observation_keys():
norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD}
stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}
normalizer = NormalizerProcessor(
features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={"observation.state"}
features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.state"}
)
transition = create_transition(
@@ -1428,6 +1680,19 @@ def test_unnormalize_observations_mean_std_and_min_max():
assert torch.allclose(out_mm, torch.tensor([1.0, 0.0])) # mid of [0,2] and [-2,2]
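
The expected values in `test_min_equals_max_maps_to_minus_one` and in the MIN_MAX assertion just above (a normalized 0 landing on the midpoint of each range) are consistent with min-max scaling to [-1, 1]. A minimal scalar sketch, assuming that formula and an `eps` fallback for a zero-width range; the function names and the `eps` value are illustrative, and the real processors operate on tensors:

```python
def minmax_normalize(x, lo, hi, eps=1e-8):
    """Scale x from [lo, hi] to [-1, 1]; eps replaces a zero-width range (assumption)."""
    width = hi - lo
    if width == 0:
        width = eps
    return 2.0 * (x - lo) / width - 1.0


def minmax_unnormalize(y, lo, hi):
    """Inverse mapping from [-1, 1] back to [lo, hi]."""
    return (y + 1.0) / 2.0 * (hi - lo) + lo


degenerate = minmax_normalize(3.0, 3.0, 3.0)   # min == max
mid_a = minmax_unnormalize(0.0, 0.0, 2.0)      # midpoint of [0, 2]
mid_b = minmax_unnormalize(0.0, -2.0, 2.0)     # midpoint of [-2, 2]
```

With `min == max` the numerator `x - lo` is zero, so the result is -1 regardless of the `eps` chosen, which is why the degenerate case has a fixed expected value.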
def test_rename_stats_basic():
orig = {
"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])},
"action": {"mean": np.array([0.0])},
}
mapping = {"observation.state": "observation.robot_state"}
renamed = rename_stats(orig, mapping)
assert "observation.robot_state" in renamed and "observation.state" not in renamed
# Ensure deep copy: mutate original and verify renamed unaffected
orig["observation.state"]["mean"][0] = 42.0
assert renamed["observation.robot_state"]["mean"][0] != 42.0
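
The deep-copy property checked by `test_rename_stats_basic` can be sketched with the standard library. This is an illustration of the expected semantics, not the `rename_stats` implementation from `lerobot.processor.normalize_processor`; `rename_stats_sketch` is a hypothetical name and plain lists replace the NumPy arrays used in the test.

```python
import copy


def rename_stats_sketch(stats, rename_map):
    """Return a deep copy of stats with top-level keys renamed according to rename_map."""
    return {rename_map.get(key, key): copy.deepcopy(value) for key, value in stats.items()}


orig = {
    "observation.state": {"mean": [0.0], "std": [1.0]},
    "action": {"mean": [0.0]},
}
renamed = rename_stats_sketch(orig, {"observation.state": "observation.robot_state"})

# Mutating the original afterwards must not leak into the renamed copy.
orig["observation.state"]["mean"][0] = 42.0
```

A shallow copy would share the inner `mean` list between both dictionaries, so the mutation above would show up in `renamed`; copying each value deeply avoids that.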
def test_unknown_observation_keys_ignored():
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
@@ -1440,6 +1705,8 @@ def test_unknown_observation_keys_ignored():
# Unknown key should pass through unchanged and not be tracked
assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.unknown"], obs["observation.unknown"])
comp = out.get(TransitionKey.COMPLEMENTARY_DATA) or {}
assert "normalized_keys" in comp and "observation.unknown" not in comp["normalized_keys"]
def test_batched_action_normalization():
@@ -1464,7 +1731,7 @@ def test_complementary_data_preservation():
tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp)
out = normalizer(tr)
new_comp = out[TransitionKey.COMPLEMENTARY_DATA]
assert new_comp["existing"] == 123
assert new_comp["existing"] == 123 and "normalized_keys" in new_comp
def test_roundtrip_normalize_unnormalize_non_identity():


@@ -23,7 +23,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors
from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_processor
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
@@ -84,12 +84,7 @@ def test_make_pi0_processor_basic():
stats = create_default_stats()
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_pi0_processor(config, stats)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -188,12 +183,7 @@ def test_pi0_processor_cuda():
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_pi0_processor(config, stats)
# Create CPU data
observation = {
@@ -243,12 +233,7 @@ def test_pi0_processor_accelerate_scenario():
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_pi0_processor(config, stats)
# Simulate Accelerate: data already on GPU and batched
device = torch.device("cuda:0")
@@ -299,12 +284,7 @@ def test_pi0_processor_multi_gpu():
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_pi0_processor(config, stats)
# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -330,12 +310,7 @@ def test_pi0_processor_without_stats():
# Mock the tokenizer processor
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
dataset_stats=None,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_pi0_processor(config, dataset_stats=None)
# Should still create processors
assert preprocessor is not None


@@ -176,7 +176,7 @@ class MockStepWithTensorState:
def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = RobotProcessor([], to_transition=lambda x: x, to_output=lambda x: x)
pipeline = RobotProcessor()
transition = create_transition()
result = pipeline(transition)
@@ -188,7 +188,7 @@ def test_empty_pipeline():
def test_single_step_pipeline():
"""Test pipeline with a single step."""
step = MockStep("test_step")
pipeline = RobotProcessor([step], to_transition=lambda x: x, to_output=lambda x: x)
pipeline = RobotProcessor([step])
transition = create_transition()
result = pipeline(transition)
@@ -205,7 +205,7 @@ def test_multiple_steps_pipeline():
"""Test pipeline with multiple steps."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2], to_transition=lambda x: x, to_output=lambda x: x)
pipeline = RobotProcessor([step1, step2])
transition = create_transition()
result = pipeline(transition)
@@ -557,9 +557,7 @@ def test_save_and_load_pretrained():
def test_step_without_optional_methods():
"""Test pipeline with steps that don't implement optional methods."""
step = MockStepWithoutOptionalMethods(multiplier=3.0)
pipeline = RobotProcessor(
[step], to_transition=lambda x: x, to_output=lambda x: x
) # Identity for EnvTransition input/output
pipeline = RobotProcessor([step])
transition = create_transition(reward=2.0)
result = pipeline(transition)
@@ -880,9 +878,7 @@ def test_from_pretrained_with_overrides():
"registered_mock_step": {"device": "cuda", "value": 200},
}
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
# Verify the pipeline was loaded correctly
assert len(loaded_pipeline) == 2
@@ -918,9 +914,7 @@ def test_from_pretrained_with_partial_overrides():
# The current implementation applies overrides to ALL steps with the same class name
# Both steps will get the override
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
@@ -977,9 +971,7 @@ def test_from_pretrained_registered_step_override():
# Override using registry name
overrides = {"registered_mock_step": {"value": 999, "device": "cuda"}}
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
# Test that overrides were applied
transition = create_transition()
@@ -1007,9 +999,7 @@ def test_from_pretrained_mixed_registered_and_unregistered():
"registered_mock_step": {"value": 777},
}
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
# Test both steps
transition = create_transition(reward=2.0)
@@ -1030,9 +1020,7 @@ def test_from_pretrained_no_overrides():
pipeline.save_pretrained(tmp_dir)
# Load without overrides
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
assert len(loaded_pipeline) == 1
@@ -1052,9 +1040,7 @@ def test_from_pretrained_empty_overrides():
pipeline.save_pretrained(tmp_dir)
# Load with empty overrides
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides={}, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides={})
assert len(loaded_pipeline) == 1
@@ -1469,8 +1455,6 @@ def test_override_with_nested_config():
loaded = RobotProcessor.from_pretrained(
tmp_dir,
overrides={"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}},
to_transition=lambda x: x,
to_output=lambda x: x,
)
# Test that override worked
@@ -1569,10 +1553,7 @@ def test_override_with_callables():
# Load with callable override
loaded = RobotProcessor.from_pretrained(
tmp_dir,
overrides={"callable_step": {"transform_fn": double_values}},
to_transition=lambda x: x,
to_output=lambda x: x,
tmp_dir, overrides={"callable_step": {"transform_fn": double_values}}
)
# Test it works
@@ -1876,8 +1857,7 @@ def test_save_load_with_custom_converter_functions():
# Should work with standard format (wouldn't work with custom converter)
result = loaded(batch)
# With new behavior, default to_output is _default_transition_to_batch, so result is batch dict
assert "observation.image" in result
assert "observation.image" in result # Standard format preserved
class NonCompliantStep:
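
Reviewer note: most hunks in this test file replace calls that pass explicit identity converters (`to_transition=lambda x: x, to_output=lambda x: x`) with calls that omit them. The sketch below illustrates what such converter hooks do around a list of steps. `TinyPipeline`, `identity`, and `double_reward` are hypothetical names written only for this illustration; this is not `RobotProcessor`, and the identity defaults used here are an assumption, since the library's actual defaults are not shown in this diff.

```python
from typing import Any, Callable, Sequence

Step = Callable[[dict], dict]


def identity(x: Any) -> Any:
    """Pass the value through unchanged (same effect as `lambda x: x`)."""
    return x


class TinyPipeline:
    """Hypothetical stand-in for a step pipeline with input/output converters."""

    def __init__(
        self,
        steps: Sequence[Step] = (),
        to_transition: Callable[[Any], dict] = identity,  # assumed default
        to_output: Callable[[dict], Any] = identity,  # assumed default
    ) -> None:
        self.steps = list(steps)
        self.to_transition = to_transition
        self.to_output = to_output

    def __call__(self, data: Any) -> Any:
        transition = self.to_transition(data)  # caller format -> internal dict
        for step in self.steps:
            transition = step(transition)
        return self.to_output(transition)  # internal dict -> caller format

    def __len__(self) -> int:
        return len(self.steps)


def double_reward(transition: dict) -> dict:
    """Example step: return a copy of the transition with the reward doubled."""
    return {**transition, "reward": transition["reward"] * 2}


# With identity converters the caller passes and receives the internal dict.
pipeline = TinyPipeline([double_reward])
result = pipeline({"reward": 2.0})

# With custom converters the caller can use another format (here, a bare float).
float_pipeline = TinyPipeline(
    [double_reward],
    to_transition=lambda r: {"reward": r},
    to_output=lambda t: t["reward"],
)
float_result = float_pipeline(3.0)
```

When the converters are identities, the pipeline input and output are the internal transition dict itself, which is the situation the tests above exercise.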


@@ -21,7 +21,6 @@ import torch
from lerobot.configs.types import FeatureType
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
from lerobot.processor.rename_processor import rename_stats
from tests.conftest import assert_contract_is_typed
@@ -188,7 +187,7 @@ def test_integration_with_robot_processor():
}
rename_processor = RenameProcessor(rename_map=rename_map)
pipeline = RobotProcessor([rename_processor], to_transition=lambda x: x, to_output=lambda x: x)
pipeline = RobotProcessor([rename_processor])
observation = {
"agent_pos": np.array([1.0, 2.0, 3.0]),
@@ -236,9 +235,7 @@ def test_save_and_load_pretrained():
assert len(state_files) == 0
# Load pipeline
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
assert loaded_pipeline.name == "TestRenameProcessor"
assert len(loaded_pipeline) == 1
@@ -279,7 +276,7 @@ def test_registry_functionality():
def test_registry_based_save_load():
"""Test save/load using registry name instead of module path."""
processor = RenameProcessor(rename_map={"key1": "renamed_key1"})
pipeline = RobotProcessor([processor], to_transition=lambda x: x, to_output=lambda x: x)
pipeline = RobotProcessor([processor])
with tempfile.TemporaryDirectory() as tmp_dir:
# Save and load
@@ -320,7 +317,7 @@ def test_chained_rename_processors():
}
)
pipeline = RobotProcessor([processor1, processor2], to_transition=lambda x: x, to_output=lambda x: x)
pipeline = RobotProcessor([processor1, processor2])
observation = {
"pos": np.array([1.0, 2.0]),
@@ -468,16 +465,3 @@ def test_features_chained_processors(policy_feature_factory):
assert out["observation.image"] == spec["img"]
assert out["extra"] == spec["extra"]
assert_contract_is_typed(out)
def test_rename_stats_basic():
orig = {
"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])},
"action": {"mean": np.array([0.0])},
}
mapping = {"observation.state": "observation.robot_state"}
renamed = rename_stats(orig, mapping)
assert "observation.robot_state" in renamed and "observation.state" not in renamed
# Ensure deep copy: mutate original and verify renamed unaffected
orig["observation.state"]["mean"][0] = 42.0
assert renamed["observation.robot_state"]["mean"][0] != 42.0
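
Reviewer note: the hunk above removes `test_rename_stats_basic`, which checked two properties: keys are renamed according to the mapping, and the result is a deep copy of the input. A minimal sketch of a function with those two properties follows. `rename_stats_sketch` is a hypothetical name and not the library's `rename_stats`; plain lists stand in for the NumPy arrays used in the removed test.

```python
from copy import deepcopy
from typing import Any


def rename_stats_sketch(
    stats: dict[str, dict[str, Any]], rename_map: dict[str, str]
) -> dict[str, dict[str, Any]]:
    """Return a deep-copied stats dict with top-level keys renamed via rename_map."""
    renamed = {}
    for old_key, sub_stats in stats.items():
        # Keys missing from the mapping keep their original name.
        new_key = rename_map.get(old_key, old_key)
        # Deep copy so later mutation of the input cannot affect the result.
        renamed[new_key] = deepcopy(sub_stats)
    return renamed


orig = {
    "observation.state": {"mean": [0.0], "std": [1.0]},
    "action": {"mean": [0.0]},
}
renamed = rename_stats_sketch(orig, {"observation.state": "observation.robot_state"})

# Mutate the original afterwards; the renamed copy should be unaffected.
orig["observation.state"]["mean"][0] = 42.0
```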


@@ -23,7 +23,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
from lerobot.policies.sac.processor_sac import make_sac_processor
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
@@ -78,12 +78,7 @@ def test_make_sac_processor_basic():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_sac_processor(config, stats)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -107,12 +102,7 @@ def test_sac_processor_normalization_modes():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_sac_processor(config, stats)
# Create test data
observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization
@@ -143,12 +133,7 @@ def test_sac_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_sac_processor(config, stats)
# Create CPU data
observation = {OBS_STATE: torch.randn(10)}
@@ -177,12 +162,7 @@ def test_sac_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_sac_processor(config, stats)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -205,12 +185,7 @@ def test_sac_processor_multi_gpu():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_sac_processor(config, stats)
# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -230,22 +205,7 @@ def test_sac_processor_without_stats():
"""Test SAC processor creation without dataset statistics."""
config = create_default_config()
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps,
name=factory_preprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
postprocessor = RobotProcessor(
factory_postprocessor.steps,
name=factory_postprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
preprocessor, postprocessor = make_sac_processor(config, dataset_stats=None)
# Should still create processors
assert preprocessor is not None
@@ -265,21 +225,14 @@ def test_sac_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_sac_processor(config, stats)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
# Test that loaded processor works
observation = {OBS_STATE: torch.randn(10)}
@@ -299,12 +252,7 @@ def test_sac_processor_mixed_precision():
stats = create_default_stats()
# Create processor
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_sac_processor(config, stats)
# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
@@ -329,12 +277,7 @@ def test_sac_processor_batch_data():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_sac_processor(config, stats)
# Test with batched data
batch_size = 32
@@ -355,12 +298,7 @@ def test_sac_processor_edge_cases():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_sac_processor(config, stats)
# Test with empty observation
transition = create_transition(observation={}, action=torch.randn(5))


@@ -23,10 +23,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.processor_smolvla import (
SmolVLANewLineProcessor,
make_smolvla_pre_post_processors,
)
from lerobot.policies.smolvla.processor_smolvla import SmolVLANewLineProcessor, make_smolvla_processor
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
@@ -89,12 +86,7 @@ def test_make_smolvla_processor_basic():
stats = create_default_stats()
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_smolvla_processor(config, stats)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -193,12 +185,7 @@ def test_smolvla_processor_cuda():
return features
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_smolvla_processor(config, stats)
# Create CPU data
observation = {
@@ -248,12 +235,7 @@ def test_smolvla_processor_accelerate_scenario():
return features
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_smolvla_processor(config, stats)
# Simulate Accelerate: data already on GPU and batched
device = torch.device("cuda:0")
@@ -304,12 +286,7 @@ def test_smolvla_processor_multi_gpu():
return features
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_smolvla_processor(config, stats)
# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -335,12 +312,7 @@ def test_smolvla_processor_without_stats():
# Mock the tokenizer processor
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
dataset_stats=None,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_smolvla_processor(config, dataset_stats=None)
# Should still create processors
assert preprocessor is not None


@@ -23,7 +23,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
@@ -81,12 +81,7 @@ def test_make_tdmpc_processor_basic():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -110,12 +105,7 @@ def test_tdmpc_processor_normalization():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
# Create test data
observation = {
@@ -148,12 +138,7 @@ def test_tdmpc_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
# Create CPU data
observation = {
@@ -186,12 +171,7 @@ def test_tdmpc_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -218,12 +198,7 @@ def test_tdmpc_processor_multi_gpu():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -247,22 +222,7 @@ def test_tdmpc_processor_without_stats():
"""Test TDMPC processor creation without dataset statistics."""
config = create_default_config()
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps,
name=factory_preprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
postprocessor = RobotProcessor(
factory_postprocessor.steps,
name=factory_postprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
preprocessor, postprocessor = make_tdmpc_processor(config, dataset_stats=None)
# Should still create processors
assert preprocessor is not None
@@ -285,21 +245,14 @@ def test_tdmpc_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
# Test that loaded processor works
observation = {
@@ -323,12 +276,7 @@ def test_tdmpc_processor_mixed_precision():
stats = create_default_stats()
# Create processor
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
@@ -357,12 +305,7 @@ def test_tdmpc_processor_batch_data():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
# Test with batched data
batch_size = 64
@@ -387,12 +330,7 @@ def test_tdmpc_processor_edge_cases():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
# Test with only state observation (no image)
observation = {OBS_STATE: torch.randn(12)}


@@ -98,11 +98,7 @@ def test_basic_tokenization(mock_auto_tokenizer):
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "pick up the red cube"},
)
transition = create_transition(complementary_data={"task": "pick up the red cube"})
result = processor(transition)
@@ -130,11 +126,7 @@ def test_basic_tokenization_with_tokenizer_object():
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "pick up the red cube"},
)
transition = create_transition(complementary_data={"task": "pick up the red cube"})
result = processor(transition)
@@ -164,11 +156,7 @@ def test_list_of_strings_tokenization(mock_auto_tokenizer):
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": ["pick up cube", "place on table"]},
)
transition = create_transition(complementary_data={"task": ["pick up cube", "place on table"]})
result = processor(transition)
@@ -192,11 +180,7 @@ def test_custom_keys(mock_auto_tokenizer):
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", task_key="instruction", max_length=5)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"instruction": "move forward"},
)
transition = create_transition(complementary_data={"instruction": "move forward"})
result = processor(transition)
@@ -389,7 +373,7 @@ def test_integration_with_robot_processor(mock_auto_tokenizer):
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
robot_processor = RobotProcessor([tokenizer_processor], to_transition=lambda x: x, to_output=lambda x: x)
robot_processor = RobotProcessor([tokenizer_processor])
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
@@ -427,23 +411,17 @@ def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer):
tokenizer_name="test-tokenizer", max_length=32, task_key="instruction"
)
robot_processor = RobotProcessor([original_processor], to_transition=lambda x: x, to_output=lambda x: x)
robot_processor = RobotProcessor([original_processor])
with tempfile.TemporaryDirectory() as temp_dir:
# Save processor
robot_processor.save_pretrained(temp_dir)
# Load processor - tokenizer will be recreated from saved config
loaded_processor = RobotProcessor.from_pretrained(
temp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_processor = RobotProcessor.from_pretrained(temp_dir)
# Test that loaded processor works
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"instruction": "test instruction"},
)
transition = create_transition(complementary_data={"instruction": "test instruction"})
result = loaded_processor(transition)
assert TransitionKey.OBSERVATION in result
@@ -458,7 +436,7 @@ def test_save_and_load_pretrained_with_tokenizer_object():
original_processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=32, task_key="instruction")
robot_processor = RobotProcessor([original_processor], to_transition=lambda x: x, to_output=lambda x: x)
robot_processor = RobotProcessor([original_processor])
with tempfile.TemporaryDirectory() as temp_dir:
# Save processor
@@ -466,18 +444,11 @@ def test_save_and_load_pretrained_with_tokenizer_object():
# Load processor with tokenizer override (since tokenizer object wasn't saved)
loaded_processor = RobotProcessor.from_pretrained(
temp_dir,
overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}},
to_transition=lambda x: x,
to_output=lambda x: x,
temp_dir, overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}}
)
# Test that loaded processor works
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"instruction": "test instruction"},
)
transition = create_transition(complementary_data={"instruction": "test instruction"})
result = loaded_processor(transition)
assert TransitionKey.OBSERVATION in result
@@ -598,11 +569,7 @@ def test_tokenization_parameters(mock_auto_tokenizer):
padding_side="left",
)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "test task"},
)
transition = create_transition(complementary_data={"task": "test task"})
processor(transition)
@@ -625,14 +592,12 @@ def test_preserves_other_complementary_data(mock_auto_tokenizer):
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={
"task": "test task",
"episode_id": 123,
"timestamp": 456.789,
"other_field": {"nested": "data"},
},
}
)
result = processor(transition)
@@ -659,11 +624,7 @@ def test_deterministic_tokenization(mock_auto_tokenizer):
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "consistent test"},
)
transition = create_transition(complementary_data={"task": "consistent test"})
result1 = processor(transition)
result2 = processor(transition)
@@ -687,11 +648,7 @@ def test_empty_string_task(mock_auto_tokenizer):
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": ""},
)
transition = create_transition(complementary_data={"task": ""})
result = processor(transition)
@@ -712,11 +669,7 @@ def test_very_long_task(mock_auto_tokenizer):
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=5, truncation=True)
long_task = " ".join(["word"] * 100) # Very long task
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": long_task},
)
transition = create_transition(complementary_data={"task": long_task})
result = processor(transition)
@@ -761,11 +714,7 @@ def test_custom_padding_side(mock_auto_tokenizer):
# Test left padding
processor_left = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10, padding_side="left")
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "test task"},
)
transition = create_transition(complementary_data={"task": "test task"})
processor_left(transition)
assert tracking_tokenizer.padding_side_calls[-1] == "left"
@@ -924,6 +873,32 @@ def test_device_detection_from_action():
assert attention_mask.device.type == "cuda"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@require_package("transformers")
def test_device_detection_from_complementary_data():
"""Test that device is detected from tensors in complementary_data."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
# Create transition with tensor in complementary_data
transition = create_transition(
observation={"metadata": {"key": "value"}}, # No tensors
complementary_data={
"task": "comp data test",
"index": torch.tensor([42]).cuda(), # Tensor in complementary_data
},
)
result = processor(transition)
# Check that tokenized tensors match complementary_data tensor's device
tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
assert tokens.device.type == "cuda"
assert attention_mask.device.type == "cuda"
@require_package("transformers")
def test_device_detection_preserves_dtype():
"""Test that device detection doesn't affect dtype of tokenized tensors."""
@@ -957,9 +932,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
# Create pipeline with TokenizerProcessor then DeviceProcessor
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
device_processor = DeviceProcessor(device="cuda:0")
robot_processor = RobotProcessor(
[tokenizer_processor, device_processor], to_transition=lambda x: x, to_output=lambda x: x
)
robot_processor = RobotProcessor([tokenizer_processor, device_processor])
# Start with CPU tensors
transition = create_transition(


@@ -23,7 +23,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
@@ -81,12 +81,7 @@ def test_make_vqbet_processor_basic():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_vqbet_processor(config, stats)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -110,12 +105,7 @@ def test_vqbet_processor_with_images():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_vqbet_processor(config, stats)
# Create test data with images and states
observation = {
@@ -141,12 +131,7 @@ def test_vqbet_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_vqbet_processor(config, stats)
# Create CPU data
observation = {
@@ -179,12 +164,7 @@ def test_vqbet_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_vqbet_processor(config, stats)
# Simulate Accelerate: data already on GPU and batched
device = torch.device("cuda:0")
@@ -211,12 +191,7 @@ def test_vqbet_processor_multi_gpu():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_vqbet_processor(config, stats)
# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -240,22 +215,7 @@ def test_vqbet_processor_without_stats():
"""Test VQBeT processor creation without dataset statistics."""
config = create_default_config()
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps,
name=factory_preprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
postprocessor = RobotProcessor(
factory_postprocessor.steps,
name=factory_postprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
preprocessor, postprocessor = make_vqbet_processor(config, dataset_stats=None)
# Should still create processors
assert preprocessor is not None
@@ -278,21 +238,14 @@ def test_vqbet_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_vqbet_processor(config, stats)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
# Test that loaded processor works
observation = {
@@ -316,12 +269,7 @@ def test_vqbet_processor_mixed_precision():
stats = create_default_stats()
# Create processor
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_vqbet_processor(config, stats)
# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
@@ -350,12 +298,7 @@ def test_vqbet_processor_large_batch():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_vqbet_processor(config, stats)
# Test with large batch
batch_size = 128
@@ -380,12 +323,7 @@ def test_vqbet_processor_sequential_processing():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
preprocessor, postprocessor = make_vqbet_processor(config, stats)
# Process multiple samples sequentially
results = []