chore(docs): update doctrines pipeline files (#1872)

* docs(processor): update docstrings batch_processor

* docs(processor): update docstrings device_processor

* docs(processor): update docstrings tokenizer_processor

* update docstrings processor_act

* update docstrings for pipeline_features

* update docstrings for utils

* update docstring for processor_diffusion

* update docstrings factory

* add docstrings to pi0 processor

* add docstring to pi0fast processor

* add docstring classifier processor

* add docstring to sac processor

* add docstring smolvla processor

* add docstring to tdmpc processor

* add docstring to vqbet processor

* add docstrings to converters

* add docstrings for delta_action_processor

* add docstring to gym action processor

* update hil processor

* add docstring to joint obs processor

* add docstring to migrate_normalize_processor

* update docstrings normalize processor

* update docstring normalize processor

* update docstrings observation processor

* update docstrings rename_processor

* add docstrings robot_kinematic_processor

* cleanup rl comments

* add docstring to train.py

* add docstring to teleoperate.py

* add docstrings to phone_processor.py

* add docstrings to teleop_phone.py

* add docstrings to control_utils.py

* add docstrings to visualization_utils.py

---------

Co-authored-by: Pepijn <pepijn@huggingface.co>
This commit is contained in:
Steven Palma
2025-09-08 18:44:15 +02:00
committed by GitHub
parent d32006440c
commit af9ddcf9a2
33 changed files with 2325 additions and 519 deletions

View File

@@ -1,3 +1,20 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import time
from dataclasses import dataclass
@@ -29,21 +46,25 @@ TELEOP_ACTION_KEY = "teleop_action"
@runtime_checkable
class HasTeleopEvents(Protocol):
"""Minimal protocol for objects that provide teleoperation events.
"""
Minimal protocol for objects that provide teleoperation events.
This protocol only defines the additional get_teleop_events() method,
avoiding duplication of the entire Teleoperator interface.
This protocol defines the `get_teleop_events()` method, allowing processor
steps to interact with teleoperators that support event-based controls
(like episode termination or success flagging) without needing to know the
teleoperator's specific class.
"""
def get_teleop_events(self) -> dict[str, Any]:
"""Get extra control events from the teleoperator.
"""
Get extra control events from the teleoperator.
Returns:
Dictionary containing control events such as:
- is_intervention: bool - Whether human is currently intervening
- terminate_episode: bool - Whether to terminate the current episode
- success: bool - Whether the episode was successful
- rerecord_episode: bool - Whether to rerecord the episode
A dictionary containing control events such as:
- `is_intervention`: bool - Whether the human is currently intervening.
- `terminate_episode`: bool - Whether to terminate the current episode.
- `success`: bool - Whether the episode was successful.
- `rerecord_episode`: bool - Whether to rerecord the episode.
"""
...
@@ -53,7 +74,15 @@ TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
def _check_teleop_with_events(teleop: Teleoperator) -> None:
"""Runtime check that a teleoperator implements get_teleop_events."""
"""
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
Args:
teleop: The teleoperator instance to check.
Raises:
TypeError: If the teleoperator does not have a `get_teleop_events` method.
"""
if not isinstance(teleop, HasTeleopEvents):
raise TypeError(
f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. "
@@ -64,11 +93,30 @@ def _check_teleop_with_events(teleop: Teleoperator) -> None:
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
@dataclass
class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
"""Add teleoperator action to transition complementary data."""
"""
Adds the raw action from a teleoperator to the transition's complementary data.
This is useful for human-in-the-loop scenarios where the human's input needs to
be available to downstream processors, for example, to override a policy's action
during an intervention.
Attributes:
teleop_device: The teleoperator instance to get the action from.
"""
teleop_device: Teleoperator
def complementary_data(self, complementary_data: dict) -> dict:
"""
Retrieves the teleoperator's action and adds it to the complementary data.
Args:
complementary_data: The incoming complementary data dictionary.
Returns:
A new dictionary with the teleoperator action added under the
`teleop_action` key.
"""
new_complementary_data = dict(complementary_data)
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
return new_complementary_data
@@ -80,26 +128,33 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
@ProcessorStepRegistry.register("add_teleop_action_as_info")
@dataclass
class AddTeleopEventsAsInfoStep(InfoProcessorStep):
"""Add teleoperator control events to transition info.
"""
Adds teleoperator control events (e.g., terminate, success) to the transition's info.
This processor step extracts control events from teleoperators that support
event-based interaction (intervention detection, episode termination, etc.).
This step extracts control events from teleoperators that support event-based
interaction, making these signals available to other parts of the system.
Works with any teleoperator that inherits from Teleoperator and implements the
get_teleop_events() method, including custom user-defined teleoperators.
Built-in compatible teleoperators:
- GamepadTeleop: Uses gamepad buttons for control events
- KeyboardEndEffectorTeleop: Uses keyboard keys for control events
Attributes:
teleop_device: An instance of a teleoperator that implements the
`HasTeleopEvents` protocol.
"""
teleop_device: TeleopWithEvents
def __post_init__(self):
"""Validate that the teleoperator supports events."""
"""Validates that the provided teleoperator supports events after initialization."""
_check_teleop_with_events(self.teleop_device)
def info(self, info: dict) -> dict:
"""
Retrieves teleoperator events and updates the info dictionary.
Args:
info: The incoming info dictionary.
Returns:
A new dictionary including the teleoperator events.
"""
new_info = dict(info)
teleop_events = self.teleop_device.get_teleop_events()
@@ -113,12 +168,32 @@ class AddTeleopEventsAsInfoStep(InfoProcessorStep):
@ProcessorStepRegistry.register("image_crop_resize_processor")
@dataclass
class ImageCropResizeProcessorStep(ObservationProcessorStep):
"""Crop and resize image observations."""
"""
Crops and/or resizes image observations.
This step iterates through all image keys in an observation dictionary and applies
the specified transformations. It handles device placement, moving tensors to the
CPU if necessary for operations not supported on certain accelerators like MPS.
Attributes:
crop_params_dict: A dictionary mapping image keys to cropping parameters
(top, left, height, width).
resize_size: A tuple (height, width) to resize all images to.
"""
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
resize_size: tuple[int, int] | None = None
def observation(self, observation: dict) -> dict:
"""
Applies cropping and resizing to all images in the observation dictionary.
Args:
observation: The observation dictionary, potentially containing image tensors.
Returns:
A new observation dictionary with transformed images.
"""
if self.resize_size is None and not self.crop_params_dict:
return observation
@@ -146,12 +221,27 @@ class ImageCropResizeProcessorStep(ObservationProcessorStep):
return new_observation
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary with the crop parameters and resize dimensions.
"""
return {
"crop_params_dict": self.crop_params_dict,
"resize_size": self.resize_size,
}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
Updates the image feature shapes in the policy features dictionary if resizing is applied.
Args:
features: The policy features dictionary.
Returns:
The updated policy features dictionary with new image shapes.
"""
if self.resize_size is None:
return features
for key in features:
@@ -163,12 +253,27 @@ class ImageCropResizeProcessorStep(ObservationProcessorStep):
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessorStep(TruncatedProcessorStep):
"""Track episode steps and enforce time limits."""
"""
Tracks episode steps and enforces a time limit by truncating the episode.
Attributes:
max_episode_steps: The maximum number of steps allowed per episode.
current_step: The current step count for the active episode.
"""
max_episode_steps: int
current_step: int = 0
def truncated(self, truncated):
def truncated(self, truncated: bool) -> bool:
"""
Increments the step counter and sets the truncated flag if the time limit is reached.
Args:
truncated: The incoming truncated flag.
Returns:
True if the episode step limit is reached, otherwise the incoming value.
"""
self.current_step += 1
if self.current_step >= self.max_episode_steps:
truncated = True
@@ -176,11 +281,18 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
return truncated
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the `max_episode_steps`.
"""
return {
"max_episode_steps": self.max_episode_steps,
}
def reset(self) -> None:
"""Resets the step counter, typically called at the start of a new episode."""
self.current_step = 0
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
@@ -190,13 +302,31 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
"""Apply penalty for inappropriate gripper usage."""
"""
Applies a penalty for inefficient gripper usage.
This step penalizes actions that attempt to close an already closed gripper or
open an already open one, based on position thresholds.
Attributes:
penalty: The negative reward value to apply.
max_gripper_pos: The maximum position value for the gripper, used for normalization.
"""
penalty: float = -0.01
max_gripper_pos: float = 30.0
def complementary_data(self, complementary_data):
"""Calculate gripper penalty and add to complementary data."""
def complementary_data(self, complementary_data: dict) -> dict:
"""
Calculates the gripper penalty and adds it to the complementary data.
Args:
complementary_data: The incoming complementary data, which should contain
raw joint positions.
Returns:
A new complementary data dictionary with the `discrete_penalty` key added.
"""
action = self.transition.get(TransitionKey.ACTION)
current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None)
@@ -223,14 +353,20 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
return new_complementary_data
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the penalty value and max gripper position.
"""
return {
"penalty": self.penalty,
"max_gripper_pos": self.max_gripper_pos,
}
def reset(self) -> None:
"""Reset the processor state."""
self.last_gripper_state = None
"""Resets the processor's internal state."""
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -239,12 +375,33 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessorStep(ProcessorStep):
"""Handle human intervention actions and episode termination."""
"""
Handles human intervention, overriding policy actions and managing episode termination.
When an intervention is detected (via teleoperator events in the `info` dict),
this step replaces the policy's action with the human's teleoperated action.
It also processes signals to terminate the episode or flag success.
Attributes:
use_gripper: Whether to include the gripper in the teleoperated action.
terminate_on_success: If True, automatically sets the `done` flag when a
`success` event is received.
"""
use_gripper: bool = False
terminate_on_success: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Processes the transition to handle interventions.
Args:
transition: The incoming environment transition.
Returns:
The modified transition, potentially with an overridden action, updated
reward, and termination status.
"""
action = transition.get(TransitionKey.ACTION)
if action is None:
return transition
@@ -300,6 +457,12 @@ class InterventionActionProcessorStep(ProcessorStep):
return new_transition
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the step's configuration attributes.
"""
return {
"use_gripper": self.use_gripper,
"terminate_on_success": self.terminate_on_success,
@@ -312,7 +475,20 @@ class InterventionActionProcessorStep(ProcessorStep):
@dataclass
@ProcessorStepRegistry.register("reward_classifier_processor")
class RewardClassifierProcessorStep(ProcessorStep):
"""Apply reward classification to image observations."""
"""
Applies a pretrained reward classifier to image observations to predict success.
This step uses a model to determine if the current state is successful, updating
the reward and potentially terminating the episode.
Attributes:
pretrained_path: Path to the pretrained reward classifier model.
device: The device to run the classifier on.
success_threshold: The probability threshold to consider a prediction as successful.
success_reward: The reward value to assign on success.
terminate_on_success: If True, terminates the episode upon successful classification.
reward_classifier: The loaded classifier model instance.
"""
pretrained_path: str | None = None
device: str = "cpu"
@@ -323,7 +499,7 @@ class RewardClassifierProcessorStep(ProcessorStep):
reward_classifier: Any = None
def __post_init__(self):
"""Initialize the reward classifier after dataclass initialization."""
"""Initializes the reward classifier model after the dataclass is created."""
if self.pretrained_path is not None:
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
@@ -332,6 +508,16 @@ class RewardClassifierProcessorStep(ProcessorStep):
self.reward_classifier.eval()
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Processes a transition, applying the reward classifier to its image observations.
Args:
transition: The incoming environment transition.
Returns:
The modified transition with an updated reward and done flag based on the
classifier's prediction.
"""
new_transition = transition.copy()
observation = new_transition.get(TransitionKey.OBSERVATION)
if observation is None or self.reward_classifier is None:
@@ -371,6 +557,12 @@ class RewardClassifierProcessorStep(ProcessorStep):
return new_transition
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the step's configuration attributes.
"""
return {
"device": self.device,
"success_threshold": self.success_threshold,