refactor(processors): add transform_features method to various processors (#1843)

This commit is contained in:
Steven Palma
2025-09-02 17:15:01 +02:00
committed by GitHub
parent 645c87e3a9
commit 2914ae2a96
11 changed files with 71 additions and 2 deletions

View File

@@ -39,6 +39,9 @@ class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
return new_complementary_data
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@ProcessorStepRegistry.register("add_teleop_action_as_info")
@dataclass
@@ -53,6 +56,9 @@ class AddTeleopEventsAsInfo(InfoProcessor):
new_info.update(teleop_events)
return new_info
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@ProcessorStepRegistry.register("image_crop_resize_processor")
@dataclass
@@ -127,6 +133,9 @@ class TimeLimitProcessor(TruncatedProcessor):
def reset(self) -> None:
self.current_step = 0
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
@@ -173,6 +182,9 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
"""Reset the processor state."""
self.last_gripper_state = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
@@ -243,6 +255,9 @@ class InterventionActionProcessor(ProcessorStep):
"terminate_on_success": self.terminate_on_success,
}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register("reward_classifier_processor")
@@ -312,3 +327,6 @@ class RewardClassifierProcessor(ProcessorStep):
"success_reward": self.success_reward,
"terminate_on_success": self.terminate_on_success,
}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features