mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
chore(docs): update doctrines pipeline files (#1872)
* docs(processor): update docstrings batch_processor * docs(processor): update docstrings device_processor * docs(processor): update docstrings tokenizer_processor * update docstrings processor_act * update docstrings for pipeline_features * update docstrings for utils * update docstring for processor_diffusion * update docstrings factory * add docstrings to pi0 processor * add docstring to pi0fast processor * add docstring classifier processor * add docstring to sac processor * add docstring smolvla processor * add docstring to tdmpc processor * add docstring to vqbet processor * add docstrings to converters * add docstrings for delta_action_processor * add docstring to gym action processor * update hil processor * add docstring to joint obs processor * add docstring to migrate_normalize_processor * update docstrings normalize processor * update docstring normalize processor * update docstrings observation processor * update docstrings rename_processor * add docstrings robot_kinematic_processor * cleanup rl comments * add docstring to train.py * add docstring to teleoperate.py * add docstrings to phone_processor.py * add docstrings to teleop_phone.py * add docstrings to control_utils.py * add docstrings to visualization_utils.py --------- Co-authored-by: Pepijn <pepijn@huggingface.co>
This commit is contained in:
@@ -1,3 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
@@ -29,21 +46,25 @@ TELEOP_ACTION_KEY = "teleop_action"
|
||||
|
||||
@runtime_checkable
|
||||
class HasTeleopEvents(Protocol):
|
||||
"""Minimal protocol for objects that provide teleoperation events.
|
||||
"""
|
||||
Minimal protocol for objects that provide teleoperation events.
|
||||
|
||||
This protocol only defines the additional get_teleop_events() method,
|
||||
avoiding duplication of the entire Teleoperator interface.
|
||||
This protocol defines the `get_teleop_events()` method, allowing processor
|
||||
steps to interact with teleoperators that support event-based controls
|
||||
(like episode termination or success flagging) without needing to know the
|
||||
teleoperator's specific class.
|
||||
"""
|
||||
|
||||
def get_teleop_events(self) -> dict[str, Any]:
|
||||
"""Get extra control events from the teleoperator.
|
||||
"""
|
||||
Get extra control events from the teleoperator.
|
||||
|
||||
Returns:
|
||||
Dictionary containing control events such as:
|
||||
- is_intervention: bool - Whether human is currently intervening
|
||||
- terminate_episode: bool - Whether to terminate the current episode
|
||||
- success: bool - Whether the episode was successful
|
||||
- rerecord_episode: bool - Whether to rerecord the episode
|
||||
A dictionary containing control events such as:
|
||||
- `is_intervention`: bool - Whether the human is currently intervening.
|
||||
- `terminate_episode`: bool - Whether to terminate the current episode.
|
||||
- `success`: bool - Whether the episode was successful.
|
||||
- `rerecord_episode`: bool - Whether to rerecord the episode.
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -53,7 +74,15 @@ TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
|
||||
|
||||
|
||||
def _check_teleop_with_events(teleop: Teleoperator) -> None:
|
||||
"""Runtime check that a teleoperator implements get_teleop_events."""
|
||||
"""
|
||||
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
|
||||
|
||||
Args:
|
||||
teleop: The teleoperator instance to check.
|
||||
|
||||
Raises:
|
||||
TypeError: If the teleoperator does not have a `get_teleop_events` method.
|
||||
"""
|
||||
if not isinstance(teleop, HasTeleopEvents):
|
||||
raise TypeError(
|
||||
f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. "
|
||||
@@ -64,11 +93,30 @@ def _check_teleop_with_events(teleop: Teleoperator) -> None:
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
|
||||
@dataclass
|
||||
class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
"""Add teleoperator action to transition complementary data."""
|
||||
"""
|
||||
Adds the raw action from a teleoperator to the transition's complementary data.
|
||||
|
||||
This is useful for human-in-the-loop scenarios where the human's input needs to
|
||||
be available to downstream processors, for example, to override a policy's action
|
||||
during an intervention.
|
||||
|
||||
Attributes:
|
||||
teleop_device: The teleoperator instance to get the action from.
|
||||
"""
|
||||
|
||||
teleop_device: Teleoperator
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
Retrieves the teleoperator's action and adds it to the complementary data.
|
||||
|
||||
Args:
|
||||
complementary_data: The incoming complementary data dictionary.
|
||||
|
||||
Returns:
|
||||
A new dictionary with the teleoperator action added under the
|
||||
`teleop_action` key.
|
||||
"""
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
|
||||
return new_complementary_data
|
||||
@@ -80,26 +128,33 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_info")
|
||||
@dataclass
|
||||
class AddTeleopEventsAsInfoStep(InfoProcessorStep):
|
||||
"""Add teleoperator control events to transition info.
|
||||
"""
|
||||
Adds teleoperator control events (e.g., terminate, success) to the transition's info.
|
||||
|
||||
This processor step extracts control events from teleoperators that support
|
||||
event-based interaction (intervention detection, episode termination, etc.).
|
||||
This step extracts control events from teleoperators that support event-based
|
||||
interaction, making these signals available to other parts of the system.
|
||||
|
||||
Works with any teleoperator that inherits from Teleoperator and implements the
|
||||
get_teleop_events() method, including custom user-defined teleoperators.
|
||||
|
||||
Built-in compatible teleoperators:
|
||||
- GamepadTeleop: Uses gamepad buttons for control events
|
||||
- KeyboardEndEffectorTeleop: Uses keyboard keys for control events
|
||||
Attributes:
|
||||
teleop_device: An instance of a teleoperator that implements the
|
||||
`HasTeleopEvents` protocol.
|
||||
"""
|
||||
|
||||
teleop_device: TeleopWithEvents
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate that the teleoperator supports events."""
|
||||
"""Validates that the provided teleoperator supports events after initialization."""
|
||||
_check_teleop_with_events(self.teleop_device)
|
||||
|
||||
def info(self, info: dict) -> dict:
|
||||
"""
|
||||
Retrieves teleoperator events and updates the info dictionary.
|
||||
|
||||
Args:
|
||||
info: The incoming info dictionary.
|
||||
|
||||
Returns:
|
||||
A new dictionary including the teleoperator events.
|
||||
"""
|
||||
new_info = dict(info)
|
||||
|
||||
teleop_events = self.teleop_device.get_teleop_events()
|
||||
@@ -113,12 +168,32 @@ class AddTeleopEventsAsInfoStep(InfoProcessorStep):
|
||||
@ProcessorStepRegistry.register("image_crop_resize_processor")
|
||||
@dataclass
|
||||
class ImageCropResizeProcessorStep(ObservationProcessorStep):
|
||||
"""Crop and resize image observations."""
|
||||
"""
|
||||
Crops and/or resizes image observations.
|
||||
|
||||
This step iterates through all image keys in an observation dictionary and applies
|
||||
the specified transformations. It handles device placement, moving tensors to the
|
||||
CPU if necessary for operations not supported on certain accelerators like MPS.
|
||||
|
||||
Attributes:
|
||||
crop_params_dict: A dictionary mapping image keys to cropping parameters
|
||||
(top, left, height, width).
|
||||
resize_size: A tuple (height, width) to resize all images to.
|
||||
"""
|
||||
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
"""
|
||||
Applies cropping and resizing to all images in the observation dictionary.
|
||||
|
||||
Args:
|
||||
observation: The observation dictionary, potentially containing image tensors.
|
||||
|
||||
Returns:
|
||||
A new observation dictionary with transformed images.
|
||||
"""
|
||||
if self.resize_size is None and not self.crop_params_dict:
|
||||
return observation
|
||||
|
||||
@@ -146,12 +221,27 @@ class ImageCropResizeProcessorStep(ObservationProcessorStep):
|
||||
return new_observation
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary with the crop parameters and resize dimensions.
|
||||
"""
|
||||
return {
|
||||
"crop_params_dict": self.crop_params_dict,
|
||||
"resize_size": self.resize_size,
|
||||
}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Updates the image feature shapes in the policy features dictionary if resizing is applied.
|
||||
|
||||
Args:
|
||||
features: The policy features dictionary.
|
||||
|
||||
Returns:
|
||||
The updated policy features dictionary with new image shapes.
|
||||
"""
|
||||
if self.resize_size is None:
|
||||
return features
|
||||
for key in features:
|
||||
@@ -163,12 +253,27 @@ class ImageCropResizeProcessorStep(ObservationProcessorStep):
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("time_limit_processor")
|
||||
class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
"""Track episode steps and enforce time limits."""
|
||||
"""
|
||||
Tracks episode steps and enforces a time limit by truncating the episode.
|
||||
|
||||
Attributes:
|
||||
max_episode_steps: The maximum number of steps allowed per episode.
|
||||
current_step: The current step count for the active episode.
|
||||
"""
|
||||
|
||||
max_episode_steps: int
|
||||
current_step: int = 0
|
||||
|
||||
def truncated(self, truncated):
|
||||
def truncated(self, truncated: bool) -> bool:
|
||||
"""
|
||||
Increments the step counter and sets the truncated flag if the time limit is reached.
|
||||
|
||||
Args:
|
||||
truncated: The incoming truncated flag.
|
||||
|
||||
Returns:
|
||||
True if the episode step limit is reached, otherwise the incoming value.
|
||||
"""
|
||||
self.current_step += 1
|
||||
if self.current_step >= self.max_episode_steps:
|
||||
truncated = True
|
||||
@@ -176,11 +281,18 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
return truncated
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the `max_episode_steps`.
|
||||
"""
|
||||
return {
|
||||
"max_episode_steps": self.max_episode_steps,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the step counter, typically called at the start of a new episode."""
|
||||
self.current_step = 0
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
@@ -190,13 +302,31 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
"""Apply penalty for inappropriate gripper usage."""
|
||||
"""
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
|
||||
This step penalizes actions that attempt to close an already closed gripper or
|
||||
open an already open one, based on position thresholds.
|
||||
|
||||
Attributes:
|
||||
penalty: The negative reward value to apply.
|
||||
max_gripper_pos: The maximum position value for the gripper, used for normalization.
|
||||
"""
|
||||
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
"""Calculate gripper penalty and add to complementary data."""
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
Calculates the gripper penalty and adds it to the complementary data.
|
||||
|
||||
Args:
|
||||
complementary_data: The incoming complementary data, which should contain
|
||||
raw joint positions.
|
||||
|
||||
Returns:
|
||||
A new complementary data dictionary with the `discrete_penalty` key added.
|
||||
"""
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
|
||||
current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None)
|
||||
@@ -223,14 +353,20 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
return new_complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the penalty value and max gripper position.
|
||||
"""
|
||||
return {
|
||||
"penalty": self.penalty,
|
||||
"max_gripper_pos": self.max_gripper_pos,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the processor state."""
|
||||
self.last_gripper_state = None
|
||||
"""Resets the processor's internal state."""
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
@@ -239,12 +375,33 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("intervention_action_processor")
|
||||
class InterventionActionProcessorStep(ProcessorStep):
|
||||
"""Handle human intervention actions and episode termination."""
|
||||
"""
|
||||
Handles human intervention, overriding policy actions and managing episode termination.
|
||||
|
||||
When an intervention is detected (via teleoperator events in the `info` dict),
|
||||
this step replaces the policy's action with the human's teleoperated action.
|
||||
It also processes signals to terminate the episode or flag success.
|
||||
|
||||
Attributes:
|
||||
use_gripper: Whether to include the gripper in the teleoperated action.
|
||||
terminate_on_success: If True, automatically sets the `done` flag when a
|
||||
`success` event is received.
|
||||
"""
|
||||
|
||||
use_gripper: bool = False
|
||||
terminate_on_success: bool = True
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Processes the transition to handle interventions.
|
||||
|
||||
Args:
|
||||
transition: The incoming environment transition.
|
||||
|
||||
Returns:
|
||||
The modified transition, potentially with an overridden action, updated
|
||||
reward, and termination status.
|
||||
"""
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
return transition
|
||||
@@ -300,6 +457,12 @@ class InterventionActionProcessorStep(ProcessorStep):
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the step's configuration attributes.
|
||||
"""
|
||||
return {
|
||||
"use_gripper": self.use_gripper,
|
||||
"terminate_on_success": self.terminate_on_success,
|
||||
@@ -312,7 +475,20 @@ class InterventionActionProcessorStep(ProcessorStep):
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("reward_classifier_processor")
|
||||
class RewardClassifierProcessorStep(ProcessorStep):
|
||||
"""Apply reward classification to image observations."""
|
||||
"""
|
||||
Applies a pretrained reward classifier to image observations to predict success.
|
||||
|
||||
This step uses a model to determine if the current state is successful, updating
|
||||
the reward and potentially terminating the episode.
|
||||
|
||||
Attributes:
|
||||
pretrained_path: Path to the pretrained reward classifier model.
|
||||
device: The device to run the classifier on.
|
||||
success_threshold: The probability threshold to consider a prediction as successful.
|
||||
success_reward: The reward value to assign on success.
|
||||
terminate_on_success: If True, terminates the episode upon successful classification.
|
||||
reward_classifier: The loaded classifier model instance.
|
||||
"""
|
||||
|
||||
pretrained_path: str | None = None
|
||||
device: str = "cpu"
|
||||
@@ -323,7 +499,7 @@ class RewardClassifierProcessorStep(ProcessorStep):
|
||||
reward_classifier: Any = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize the reward classifier after dataclass initialization."""
|
||||
"""Initializes the reward classifier model after the dataclass is created."""
|
||||
if self.pretrained_path is not None:
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
@@ -332,6 +508,16 @@ class RewardClassifierProcessorStep(ProcessorStep):
|
||||
self.reward_classifier.eval()
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Processes a transition, applying the reward classifier to its image observations.
|
||||
|
||||
Args:
|
||||
transition: The incoming environment transition.
|
||||
|
||||
Returns:
|
||||
The modified transition with an updated reward and done flag based on the
|
||||
classifier's prediction.
|
||||
"""
|
||||
new_transition = transition.copy()
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None or self.reward_classifier is None:
|
||||
@@ -371,6 +557,12 @@ class RewardClassifierProcessorStep(ProcessorStep):
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the step's configuration attributes.
|
||||
"""
|
||||
return {
|
||||
"device": self.device,
|
||||
"success_threshold": self.success_threshold,
|
||||
|
||||
Reference in New Issue
Block a user