mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
209 lines
8.2 KiB
Python
209 lines
8.2 KiB
Python
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||
|
|
#
|
||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
# you may not use this file except in compliance with the License.
|
||
|
|
# You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
# See the License for the specific language governing permissions and
|
||
|
|
# limitations under the License.
|
||
|
|
|
||
|
|
from collections.abc import Sequence
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
import torch
|
||
|
|
from torch import Tensor
|
||
|
|
|
||
|
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||
|
|
from lerobot.types import EnvTransition, TransitionKey
|
||
|
|
from lerobot.utils.constants import OBS_STATE
|
||
|
|
|
||
|
|
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||
|
|
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||
|
|
|
||
|
|
# Public names of this module. MapDeltaActionToRobotActionStep and
# MapTensorToDeltaActionDictStep are imported from .delta_action_processor and
# listed here so they remain importable from this module (re-export for
# backward compatibility); the remaining names are defined below.
__all__ = [
    "MapDeltaActionToRobotActionStep",
    "MapTensorToDeltaActionDictStep",
    "RelativeActionsProcessorStep",
    "AbsoluteActionsProcessorStep",
    "to_relative_actions",
    "to_absolute_actions",
]
|
||
|
|
|
||
|
|
|
||
|
|
def to_relative_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert absolute actions to relative: relative = action - state (for masked dims).

    Only the first ``len(mask)`` action dimensions are candidates for conversion;
    within that prefix, a dimension is shifted by the matching state value when its
    mask entry is truthy and left untouched otherwise. Dimensions beyond
    ``len(mask)`` are never modified.

    Args:
        actions: (B, T, action_dim) or (B, action_dim).
        state: (B, state_dim). Broadcast across time dimension.
        mask: Which dims to convert. Can be shorter than action_dim.

    Returns:
        A new tensor shaped like ``actions``; the input tensor is not modified.
    """
    selected = torch.tensor(mask, dtype=torch.bool, device=actions.device)
    dims = selected.shape[0]
    # Align state to the same device/dtype as actions. _last_state is cached before
    # DeviceProcessorStep moves the transition, so it can be on CPU while actions are on CUDA.
    # ``Tensor.to`` returns the tensor unchanged when device and dtype already match.
    state = state.to(device=actions.device, dtype=actions.dtype)
    # Zero the excluded dims by selection instead of multiplying by a 0/1 mask:
    # ``nan * 0`` and ``inf * 0`` are NaN, so a non-finite state value in an excluded
    # dim would otherwise corrupt an action dim that is supposed to stay as-is.
    state_offset = state[..., :dims].masked_fill(~selected, 0)
    if actions.ndim == 3:
        # Insert a time axis so the (B, dims) offset broadcasts over (B, T, dims).
        state_offset = state_offset.unsqueeze(-2)
    actions = actions.clone()
    actions[..., :dims] -= state_offset
    return actions
|
||
|
|
|
||
|
|
|
||
|
|
def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert relative actions back to absolute: absolute = relative + state (for masked dims).

    Only the first ``len(mask)`` action dimensions are candidates for conversion;
    within that prefix, the matching state value is added when the mask entry is
    truthy and the dimension is left untouched otherwise. Dimensions beyond
    ``len(mask)`` are never modified.

    Args:
        actions: (B, T, action_dim) or (B, action_dim).
        state: (B, state_dim). Broadcast across time dimension.
        mask: Which dims to convert. Can be shorter than action_dim.

    Returns:
        A new tensor shaped like ``actions``; the input tensor is not modified.
    """
    selected = torch.tensor(mask, dtype=torch.bool, device=actions.device)
    dims = selected.shape[0]
    # Align state to the same device/dtype as actions. _last_state is cached before
    # DeviceProcessorStep moves the transition, so it can be on CPU while actions are on CUDA.
    # ``Tensor.to`` returns the tensor unchanged when device and dtype already match.
    state = state.to(device=actions.device, dtype=actions.dtype)
    # Zero the excluded dims by selection instead of multiplying by a 0/1 mask:
    # ``nan * 0`` and ``inf * 0`` are NaN, so a non-finite state value in an excluded
    # dim would otherwise corrupt an action dim that is supposed to stay as-is.
    state_offset = state[..., :dims].masked_fill(~selected, 0)
    if actions.ndim == 3:
        # Insert a time axis so the (B, dims) offset broadcasts over (B, T, dims).
        state_offset = state_offset.unsqueeze(-2)
    actions = actions.clone()
    actions[..., :dims] += state_offset
    return actions
|
||
|
|
|
||
|
|
|
||
|
|
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class RelativeActionsProcessorStep(ProcessorStep):
    """Rewrites absolute actions as offsets from the observed state (masked dims only).

    Mirrors OpenPI's DeltaActions transform. Meant to run during preprocessing so
    the model is trained on relative offsets rather than absolute positions.
    The most recent observation state is remembered on the instance (whether or
    not the step is enabled) so that a paired AbsoluteActionsProcessorStep can
    undo the conversion during postprocessing.

    Attributes:
        enabled: Whether to apply the relative conversion.
        exclude_joints: Joint names to keep absolute (not converted to relative).
            Matching is case-insensitive and by substring of the action name.
        action_names: Action dimension names from dataset metadata, used to build
            the mask from exclude_joints. If None, all dims are converted.
    """

    enabled: bool = False
    exclude_joints: list[str] = field(default_factory=list)
    action_names: list[str] | None = None
    # Last observation state seen by __call__; read by the paired absolute step.
    _last_state: torch.Tensor | None = field(default=None, init=False, repr=False)

    def _build_mask(self, action_dim: int) -> list[bool]:
        """Return one bool per action dim: True means "convert to relative"."""
        if self.action_names is None or not self.exclude_joints:
            return [True] * action_dim

        tokens = [str(joint).lower() for joint in self.exclude_joints if joint]
        if not tokens:
            return [True] * action_dim

        # A dim stays absolute when any exclusion token occurs in its lowercased
        # name (an exact match is just a special case of the substring test).
        lowered_names = (str(raw).lower() for raw in self.action_names[:action_dim])
        mask = [not any(token in lowered for token in tokens) for lowered in lowered_names]

        # Dims without a name are converted by default.
        mask += [True] * (action_dim - len(mask))
        return mask

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Cache the observation state and, if enabled, make the action relative to it."""
        observation = transition.get(TransitionKey.OBSERVATION, {})
        state = None
        if observation:
            state = observation.get(OBS_STATE)

        # The cache is refreshed even when disabled so the paired
        # AbsoluteActionsProcessorStep always has a state to read.
        if state is not None:
            self._last_state = state

        if not self.enabled:
            return transition

        result = transition.copy()
        action = result.get(TransitionKey.ACTION)
        if action is not None and state is not None:
            dim_mask = self._build_mask(action.shape[-1])
            result[TransitionKey.ACTION] = to_relative_actions(action, state, dim_mask)
        return result

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to recreate this step."""
        config: dict[str, Any] = {}
        config["enabled"] = self.enabled
        config["exclude_joints"] = self.exclude_joints
        config["action_names"] = self.action_names
        return config

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return ``features`` unchanged; this step does not alter the feature spec."""
        return features
|
||
|
|
|
||
|
|
|
||
|
|
@ProcessorStepRegistry.register("absolute_actions_processor")
@dataclass
class AbsoluteActionsProcessorStep(ProcessorStep):
    """Turns relative actions back into absolute ones (action += state) on the masked dims.

    Mirrors OpenPI's AbsoluteActions transform. Meant to run during postprocessing
    so predicted relative offsets become absolute positions for execution.
    Both the cached state and the dimension mask are taken from the paired
    RelativeActionsProcessorStep, so dims that step excludes stay untouched here.

    Attributes:
        enabled: Whether to apply the absolute conversion.
        relative_step: Reference to the paired RelativeActionsProcessorStep that caches state.
    """

    enabled: bool = False
    relative_step: RelativeActionsProcessorStep | None = field(default=None, repr=False)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Add the cached state back onto the action, if the step is enabled.

        Raises:
            RuntimeError: If enabled but no paired step is set, or the paired
                step has not cached any state yet.
        """
        if not self.enabled:
            return transition

        paired = self.relative_step
        if paired is None:
            raise RuntimeError(
                "AbsoluteActionsProcessorStep requires a paired RelativeActionsProcessorStep "
                "but relative_step is None. Ensure relative_step is set when constructing the postprocessor."
            )

        cached_state = paired._last_state
        if cached_state is None:
            raise RuntimeError(
                "AbsoluteActionsProcessorStep requires state from RelativeActionsProcessorStep "
                "but no state has been cached. Ensure the preprocessor runs before the postprocessor."
            )

        result = transition.copy()
        action = result.get(TransitionKey.ACTION)
        if action is not None:
            # Reuse the paired step's mask so exactly the same dims are restored.
            dim_mask = paired._build_mask(action.shape[-1])
            result[TransitionKey.ACTION] = to_absolute_actions(action, cached_state, dim_mask)
        return result

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration (the paired step is not included)."""
        config: dict[str, Any] = {"enabled": self.enabled}
        return config

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return ``features`` unchanged; this step does not alter the feature spec."""
        return features
|