refactor(processors): add transform_features method to various processors (#1843)

This commit is contained in:
Steven Palma
2025-09-02 17:15:01 +02:00
committed by GitHub
parent 645c87e3a9
commit 2914ae2a96
11 changed files with 71 additions and 2 deletions

View File

@@ -15,6 +15,7 @@ from dataclasses import dataclass, field
from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import (
ActionProcessor,
@@ -37,6 +38,9 @@ class ToBatchProcessorAction(ActionProcessor):
return action.unsqueeze(0)
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
@@ -63,6 +67,9 @@ class ToBatchProcessorObservation(ObservationProcessor):
observation[key] = value.unsqueeze(0)
return observation
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
@@ -89,6 +96,9 @@ class ToBatchProcessorComplementaryData(ComplementaryDataProcessor):
complementary_data["task_index"] = task_index_value.unsqueeze(0)
return complementary_data
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor")
@@ -140,3 +150,6 @@ class ToBatchProcessor(ProcessorStep):
transition = self.to_batch_observation_processor(transition)
transition = self.to_batch_complementary_data_processor(transition)
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features