mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
Merge branch 'main' into feat/audio_dataset
This commit is contained in:
@@ -14,14 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .audio_processor import AudioProcessorStep
|
||||
from .batch_processor import AddBatchDimensionProcessorStep
|
||||
from .converters import (
|
||||
batch_to_transition,
|
||||
create_transition,
|
||||
transition_to_batch,
|
||||
)
|
||||
from .core import (
|
||||
from lerobot.types import (
|
||||
EnvAction,
|
||||
EnvTransition,
|
||||
PolicyAction,
|
||||
@@ -29,6 +22,14 @@ from .core import (
|
||||
RobotObservation,
|
||||
TransitionKey,
|
||||
)
|
||||
|
||||
from .audio_processor import AudioProcessorStep
|
||||
from .batch_processor import AddBatchDimensionProcessorStep
|
||||
from .converters import (
|
||||
batch_to_transition,
|
||||
create_transition,
|
||||
transition_to_batch,
|
||||
)
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .factory import (
|
||||
@@ -45,6 +46,7 @@ from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryDataStep,
|
||||
AddTeleopEventsAsInfoStep,
|
||||
GripperPenaltyProcessorStep,
|
||||
GymHILAdapterProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
RewardClassifierProcessorStep,
|
||||
@@ -74,6 +76,12 @@ from .policy_robot_bridge import (
|
||||
PolicyActionToRobotActionProcessorStep,
|
||||
RobotActionToPolicyActionProcessorStep,
|
||||
)
|
||||
from .relative_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
RelativeActionsProcessorStep,
|
||||
to_absolute_actions,
|
||||
to_relative_actions,
|
||||
)
|
||||
from .rename_processor import RenameObservationsProcessorStep
|
||||
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
|
||||
|
||||
@@ -89,6 +97,7 @@ __all__ = [
|
||||
"DoneProcessorStep",
|
||||
"EnvAction",
|
||||
"EnvTransition",
|
||||
"GymHILAdapterProcessorStep",
|
||||
"GripperPenaltyProcessorStep",
|
||||
"hotswap_stats",
|
||||
"IdentityProcessorStep",
|
||||
@@ -99,6 +108,8 @@ __all__ = [
|
||||
"make_default_teleop_action_processor",
|
||||
"make_default_robot_action_processor",
|
||||
"make_default_robot_observation_processor",
|
||||
"AbsoluteActionsProcessorStep",
|
||||
"RelativeActionsProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"NormalizerProcessorStep",
|
||||
@@ -128,6 +139,8 @@ __all__ = [
|
||||
"transition_to_batch",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessorStep",
|
||||
"to_absolute_actions",
|
||||
"to_relative_actions",
|
||||
"UnnormalizerProcessorStep",
|
||||
"VanillaObservationProcessorStep",
|
||||
]
|
||||
|
||||
@@ -27,7 +27,6 @@ from torch import Tensor
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from .core import EnvTransition, PolicyAction
|
||||
from .pipeline import (
|
||||
ComplementaryDataProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
|
||||
@@ -23,10 +23,9 @@ from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.types import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
from lerobot.utils.constants import ACTION, DONE, INFO, OBS_PREFIX, REWARD, TRUNCATED
|
||||
|
||||
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
|
||||
|
||||
@singledispatch
|
||||
def to_tensor(
|
||||
@@ -168,11 +167,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
|
||||
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, TypeAlias, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class TransitionKey(str, Enum):
|
||||
"""Keys for accessing EnvTransition dictionary components."""
|
||||
|
||||
# TODO(Steven): Use consts
|
||||
OBSERVATION = "observation"
|
||||
ACTION = "action"
|
||||
REWARD = "reward"
|
||||
DONE = "done"
|
||||
TRUNCATED = "truncated"
|
||||
INFO = "info"
|
||||
COMPLEMENTARY_DATA = "complementary_data"
|
||||
|
||||
|
||||
PolicyAction: TypeAlias = torch.Tensor
|
||||
RobotAction: TypeAlias = dict[str, Any]
|
||||
EnvAction: TypeAlias = np.ndarray
|
||||
RobotObservation: TypeAlias = dict[str, Any]
|
||||
|
||||
|
||||
EnvTransition = TypedDict(
|
||||
"EnvTransition",
|
||||
{
|
||||
TransitionKey.OBSERVATION.value: RobotObservation | None,
|
||||
TransitionKey.ACTION.value: PolicyAction | RobotAction | EnvAction | None,
|
||||
TransitionKey.REWARD.value: float | torch.Tensor | None,
|
||||
TransitionKey.DONE.value: bool | torch.Tensor | None,
|
||||
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
|
||||
TransitionKey.INFO.value: dict[str, Any] | None,
|
||||
TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
|
||||
},
|
||||
)
|
||||
@@ -17,8 +17,8 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import PolicyAction, RobotAction
|
||||
|
||||
from .core import PolicyAction, RobotAction
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||
|
||||
|
||||
|
||||
@@ -25,9 +25,9 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
@@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
|
||||
# copy over non-STATE features
|
||||
for ft, feats in features.items():
|
||||
if ft != PipelineFeatureType.STATE:
|
||||
if ft != FeatureType.STATE:
|
||||
new_features[ft] = feats.copy()
|
||||
|
||||
# rebuild STATE features
|
||||
@@ -100,13 +100,11 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
|
||||
# add our new flattened state
|
||||
state_feats[OBS_STATE] = PolicyFeature(
|
||||
key=OBS_STATE,
|
||||
type=FeatureType.STATE,
|
||||
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
|
||||
dtype="float32",
|
||||
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
|
||||
)
|
||||
|
||||
new_features[PipelineFeatureType.STATE] = state_feats
|
||||
new_features[FeatureType.STATE] = state_feats
|
||||
|
||||
return new_features
|
||||
|
||||
|
||||
@@ -14,13 +14,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
|
||||
from .converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from .core import RobotAction, RobotObservation
|
||||
from .pipeline import IdentityProcessorStep, RobotProcessorPipeline
|
||||
|
||||
|
||||
|
||||
@@ -17,9 +17,10 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvAction, EnvTransition, PolicyAction
|
||||
|
||||
from .converters import to_tensor
|
||||
from .core import EnvAction, EnvTransition, PolicyAction
|
||||
from .hil_processor import TELEOP_ACTION_KEY
|
||||
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@@ -74,7 +75,7 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Converts numpy action to torch tensor if action exists, otherwise passes through."""
|
||||
from .core import TransitionKey
|
||||
from lerobot.types import TransitionKey
|
||||
|
||||
self._current_transition = transition.copy()
|
||||
new_transition = self._current_transition
|
||||
@@ -89,6 +90,13 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
|
||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||
new_transition[TransitionKey.ACTION] = torch_action
|
||||
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if TELEOP_ACTION_KEY in complementary_data:
|
||||
teleop_action = complementary_data[TELEOP_ACTION_KEY]
|
||||
if isinstance(teleop_action, EnvAction):
|
||||
complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -18,17 +18,20 @@
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol, TypeVar, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
|
||||
from .pipeline import (
|
||||
ComplementaryDataProcessorStep,
|
||||
InfoProcessorStep,
|
||||
@@ -69,10 +72,10 @@ class HasTeleopEvents(Protocol):
|
||||
|
||||
|
||||
# Type variable constrained to Teleoperator subclasses that also implement events
|
||||
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
|
||||
TeleopWithEvents = TypeVar("TeleopWithEvents", bound="Teleoperator")
|
||||
|
||||
|
||||
def _check_teleop_with_events(teleop: Teleoperator) -> None:
|
||||
def _check_teleop_with_events(teleop: "Teleoperator") -> None:
|
||||
"""
|
||||
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
|
||||
|
||||
@@ -103,7 +106,7 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
teleop_device: The teleoperator instance to get the action from.
|
||||
"""
|
||||
|
||||
teleop_device: Teleoperator
|
||||
teleop_device: "Teleoperator"
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
@@ -310,9 +313,40 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("gym_hil_adapter_processor")
|
||||
class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors.
|
||||
|
||||
This step normalizes the `transition` object by:
|
||||
1. Copying `teleop_action` from `info` to `complementary_data`.
|
||||
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
if TELEOP_ACTION_KEY in info:
|
||||
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
||||
|
||||
if "is_intervention" in info:
|
||||
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
||||
|
||||
transition[TransitionKey.INFO] = info
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
|
||||
@@ -327,26 +361,27 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Calculates the gripper penalty and adds it to the complementary data.
|
||||
|
||||
Args:
|
||||
complementary_data: The incoming complementary data, which should contain
|
||||
raw joint positions.
|
||||
transition: The incoming environment transition.
|
||||
|
||||
Returns:
|
||||
A new complementary data dictionary with the `discrete_penalty` key added.
|
||||
The modified transition with the penalty added to complementary data.
|
||||
"""
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
raw_joint_positions = complementary_data.get("raw_joint_positions")
|
||||
if raw_joint_positions is None:
|
||||
return complementary_data
|
||||
return new_transition
|
||||
|
||||
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
|
||||
if current_gripper_pos is None:
|
||||
return complementary_data
|
||||
return new_transition
|
||||
|
||||
# Gripper action is a PolicyAction at this stage
|
||||
gripper_action = action[-1].item()
|
||||
@@ -362,11 +397,12 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
# Create new complementary data with penalty info
|
||||
# Update complementary data with penalty info
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
|
||||
return new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -26,10 +26,10 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
from .converters import from_tensor_to_numpy, to_tensor
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry, RobotObservation
|
||||
|
||||
|
||||
|
||||
@@ -39,17 +39,17 @@ from collections.abc import Callable, Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
|
||||
from typing import Any, TypedDict, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
from .converters import batch_to_transition, create_transition, transition_to_batch
|
||||
from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
|
||||
# Generic type variables for pipeline input and output.
|
||||
TInput = TypeVar("TInput")
|
||||
@@ -251,7 +251,7 @@ class ProcessorMigrationError(Exception):
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||
class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
|
||||
|
||||
This class chains together multiple `ProcessorStep` instances to form a complete
|
||||
@@ -413,7 +413,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||
Args:
|
||||
save_directory: The directory where the pipeline will be saved. If None, saves to
|
||||
HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}.
|
||||
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`.
|
||||
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=true`.
|
||||
push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it.
|
||||
card_kwargs: Additional arguments passed to the card template to customize the card.
|
||||
config_filename: The name of the JSON configuration file. If None, a name is
|
||||
@@ -1432,8 +1432,8 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||
|
||||
|
||||
# Type aliases for semantic clarity.
|
||||
RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
||||
PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
||||
RobotProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
|
||||
PolicyProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
|
||||
|
||||
|
||||
class ObservationProcessorStep(ProcessorStep, ABC):
|
||||
|
||||
208
src/lerobot/processor/relative_action_processor.py
Normal file
208
src/lerobot/processor/relative_action_processor.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
# Re-export for backward compatibility
|
||||
__all__ = [
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"RelativeActionsProcessorStep",
|
||||
"AbsoluteActionsProcessorStep",
|
||||
"to_relative_actions",
|
||||
"to_absolute_actions",
|
||||
]
|
||||
|
||||
|
||||
def to_relative_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
|
||||
"""Convert absolute actions to relative: relative = action - state (for masked dims).
|
||||
|
||||
Args:
|
||||
actions: (B, T, action_dim) or (B, action_dim).
|
||||
state: (B, state_dim). Broadcast across time dimension.
|
||||
mask: Which dims to convert. Can be shorter than action_dim.
|
||||
"""
|
||||
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
|
||||
dims = mask_t.shape[0]
|
||||
# Align state to the same device/dtype as actions. _last_state is cached before
|
||||
# DeviceProcessorStep moves the transition, so it can be on CPU while actions are on CUDA.
|
||||
if state.device != actions.device or state.dtype != actions.dtype:
|
||||
state = state.to(device=actions.device, dtype=actions.dtype)
|
||||
state_offset = state[..., :dims] * mask_t
|
||||
if actions.ndim == 3:
|
||||
state_offset = state_offset.unsqueeze(-2)
|
||||
actions = actions.clone()
|
||||
actions[..., :dims] -= state_offset
|
||||
return actions
|
||||
|
||||
|
||||
def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
|
||||
"""Convert relative actions back to absolute: absolute = relative + state (for masked dims).
|
||||
|
||||
Args:
|
||||
actions: (B, T, action_dim) or (B, action_dim).
|
||||
state: (B, state_dim). Broadcast across time dimension.
|
||||
mask: Which dims to convert. Can be shorter than action_dim.
|
||||
"""
|
||||
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
|
||||
dims = mask_t.shape[0]
|
||||
# Align state to the same device/dtype as actions. _last_state is cached before
|
||||
# DeviceProcessorStep moves the transition, so it can be on CPU while actions are on CUDA.
|
||||
if state.device != actions.device or state.dtype != actions.dtype:
|
||||
state = state.to(device=actions.device, dtype=actions.dtype)
|
||||
state_offset = state[..., :dims] * mask_t
|
||||
if actions.ndim == 3:
|
||||
state_offset = state_offset.unsqueeze(-2)
|
||||
actions = actions.clone()
|
||||
actions[..., :dims] += state_offset
|
||||
return actions
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("delta_actions_processor")
|
||||
@dataclass
|
||||
class RelativeActionsProcessorStep(ProcessorStep):
|
||||
"""Converts absolute actions to relative actions (action -= state) for masked dimensions.
|
||||
|
||||
Mirrors OpenPI's DeltaActions transform. Applied during preprocessing so the model
|
||||
trains on relative offsets instead of absolute positions.
|
||||
Caches the last seen state so a paired AbsoluteActionsProcessorStep can reverse
|
||||
the conversion during postprocessing.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether to apply the relative conversion.
|
||||
exclude_joints: Joint names to keep absolute (not converted to relative).
|
||||
action_names: Action dimension names from dataset metadata, used to build
|
||||
the mask from exclude_joints. If None, all dims are converted.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
exclude_joints: list[str] = field(default_factory=list)
|
||||
action_names: list[str] | None = None
|
||||
_last_state: torch.Tensor | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def _build_mask(self, action_dim: int) -> list[bool]:
|
||||
if not self.exclude_joints or self.action_names is None:
|
||||
return [True] * action_dim
|
||||
|
||||
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
|
||||
if not exclude_tokens:
|
||||
return [True] * action_dim
|
||||
|
||||
mask = []
|
||||
for name in self.action_names[:action_dim]:
|
||||
action_name = str(name).lower()
|
||||
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
|
||||
mask.append(not is_excluded)
|
||||
|
||||
if len(mask) < action_dim:
|
||||
mask.extend([True] * (action_dim - len(mask)))
|
||||
|
||||
return mask
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION, {})
|
||||
state = observation.get(OBS_STATE) if observation else None
|
||||
|
||||
# Always cache state for the paired AbsoluteActionsProcessorStep
|
||||
if state is not None:
|
||||
self._last_state = state
|
||||
|
||||
if not self.enabled:
|
||||
return transition
|
||||
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is None or state is None:
|
||||
return new_transition
|
||||
|
||||
mask = self._build_mask(action.shape[-1])
|
||||
new_transition[TransitionKey.ACTION] = to_relative_actions(action, state, mask)
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
"exclude_joints": self.exclude_joints,
|
||||
"action_names": self.action_names,
|
||||
}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("absolute_actions_processor")
|
||||
@dataclass
|
||||
class AbsoluteActionsProcessorStep(ProcessorStep):
|
||||
"""Converts relative actions back to absolute actions (action += state) for all dimensions.
|
||||
|
||||
Mirrors OpenPI's AbsoluteActions transform. Applied during postprocessing so
|
||||
predicted relative offsets are converted back to absolute positions for execution.
|
||||
Reads the cached state from its paired RelativeActionsProcessorStep.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether to apply the absolute conversion.
|
||||
relative_step: Reference to the paired RelativeActionsProcessorStep that caches state.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
relative_step: RelativeActionsProcessorStep | None = field(default=None, repr=False)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
if not self.enabled:
|
||||
return transition
|
||||
|
||||
if self.relative_step is None:
|
||||
raise RuntimeError(
|
||||
"AbsoluteActionsProcessorStep requires a paired RelativeActionsProcessorStep "
|
||||
"but relative_step is None. Ensure relative_step is set when constructing the postprocessor."
|
||||
)
|
||||
|
||||
if self.relative_step._last_state is None:
|
||||
raise RuntimeError(
|
||||
"AbsoluteActionsProcessorStep requires state from RelativeActionsProcessorStep "
|
||||
"but no state has been cached. Ensure the preprocessor runs before the postprocessor."
|
||||
)
|
||||
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
return new_transition
|
||||
|
||||
mask = self.relative_step._build_mask(action.shape[-1])
|
||||
new_transition[TransitionKey.ACTION] = to_absolute_actions(
|
||||
action, self.relative_step._last_state, mask
|
||||
)
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"enabled": self.enabled}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
@@ -30,15 +30,17 @@ from typing import TYPE_CHECKING, Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
from .core import EnvTransition, RobotObservation, TransitionKey
|
||||
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
@@ -139,6 +141,32 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
return None
|
||||
|
||||
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""
|
||||
Extracts the subtask from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
A list of subtask strings, or None if the subtask key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
subtask = complementary_data.get("subtask")
|
||||
if subtask is None:
|
||||
return None
|
||||
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(subtask, str):
|
||||
return [subtask]
|
||||
elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask):
|
||||
return subtask
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation: RobotObservation) -> RobotObservation:
|
||||
"""
|
||||
Tokenizes the task description and adds it to the observation dictionary.
|
||||
@@ -176,6 +204,24 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Tokenize subtask if available
|
||||
subtask = self.get_subtask(self.transition)
|
||||
if subtask is not None:
|
||||
tokenized_subtask = self._tokenize_text(subtask)
|
||||
|
||||
# Move new tokenized tensors to the detected device
|
||||
if target_device is not None:
|
||||
tokenized_subtask = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_subtask.items()
|
||||
}
|
||||
|
||||
# Add tokenized subtask to the observation
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
@@ -290,7 +336,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
Requires the `transformers` library to be installed.
|
||||
|
||||
Attributes:
|
||||
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
|
||||
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "lerobot/fast-action-tokenizer").
|
||||
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
|
||||
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
|
||||
|
||||
Reference in New Issue
Block a user