mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
refactor(processors): add transform_features method to various processors (#1843)
This commit is contained in:
@@ -15,6 +15,7 @@ from dataclasses import dataclass, field
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor.pipeline import (
|
||||
ActionProcessor,
|
||||
@@ -37,6 +38,9 @@ class ToBatchProcessorAction(ActionProcessor):
|
||||
|
||||
return action.unsqueeze(0)
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
|
||||
@@ -63,6 +67,9 @@ class ToBatchProcessorObservation(ObservationProcessor):
|
||||
observation[key] = value.unsqueeze(0)
|
||||
return observation
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
|
||||
@@ -89,6 +96,9 @@ class ToBatchProcessorComplementaryData(ComplementaryDataProcessor):
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
return complementary_data
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor")
|
||||
@@ -140,3 +150,6 @@ class ToBatchProcessor(ProcessorStep):
|
||||
transition = self.to_batch_observation_processor(transition)
|
||||
transition = self.to_batch_complementary_data_processor(transition)
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -295,7 +295,8 @@ def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> E
|
||||
Returns:
|
||||
Merged EnvTransition.
|
||||
"""
|
||||
if isinstance(transitions, EnvTransition): # Single transition
|
||||
|
||||
if not isinstance(transitions, Sequence): # Single transition
|
||||
return transitions
|
||||
|
||||
items = list(transitions)
|
||||
|
||||
@@ -45,6 +45,9 @@ class MapTensorToDeltaActionDict(ActionProcessor):
|
||||
delta_action["action.gripper"] = action[3]
|
||||
return delta_action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
|
||||
@dataclass
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
@@ -127,3 +128,6 @@ class DeviceProcessor(ProcessorStep):
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {"device": self.device, "float_dtype": self.float_dtype}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -16,6 +16,7 @@ from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
|
||||
|
||||
@@ -48,6 +49,9 @@ class Torch2NumpyActionProcessor(ActionProcessor):
|
||||
|
||||
return numpy_action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("numpy2torch_action_processor")
|
||||
@dataclass
|
||||
@@ -62,3 +66,6 @@ class Numpy2TorchActionProcessor(ActionProcessor):
|
||||
)
|
||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||
return torch_action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -39,6 +39,9 @@ class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
|
||||
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_info")
|
||||
@dataclass
|
||||
@@ -53,6 +56,9 @@ class AddTeleopEventsAsInfo(InfoProcessor):
|
||||
new_info.update(teleop_events)
|
||||
return new_info
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("image_crop_resize_processor")
|
||||
@dataclass
|
||||
@@ -127,6 +133,9 @@ class TimeLimitProcessor(TruncatedProcessor):
|
||||
def reset(self) -> None:
|
||||
self.current_step = 0
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
@@ -173,6 +182,9 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
|
||||
"""Reset the processor state."""
|
||||
self.last_gripper_state = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("intervention_action_processor")
|
||||
@@ -243,6 +255,9 @@ class InterventionActionProcessor(ProcessorStep):
|
||||
"terminate_on_success": self.terminate_on_success,
|
||||
}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("reward_classifier_processor")
|
||||
@@ -312,3 +327,6 @@ class RewardClassifierProcessor(ProcessorStep):
|
||||
"success_reward": self.success_reward,
|
||||
"terminate_on_success": self.terminate_on_success,
|
||||
}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -211,6 +211,9 @@ class NormalizerProcessor(_NormalizationMixin, ProcessorStep):
|
||||
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
||||
@@ -249,6 +252,9 @@ class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep):
|
||||
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
|
||||
"""
|
||||
|
||||
@@ -169,7 +169,7 @@ class ProcessorStep(ABC):
|
||||
def reset(self) -> None:
|
||||
return None
|
||||
|
||||
# TODO(Steven): Consider making this abstract so it is more explicit
|
||||
@abstractmethod
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -1091,3 +1091,6 @@ class IdentityProcessor(ProcessorStep):
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
Reference in New Issue
Block a user