mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
refactor(pipeline): enforce ProcessorStep inheritance for pipeline steps (#1862)
- Updated the DataProcessorPipeline to require that all steps inherit from ProcessorStep, enhancing type safety and clarity. - Adjusted tests to utilize a MockTokenizerProcessorStep that adheres to the ProcessorStep interface, ensuring consistent behavior across tests. - Refactored various mock step classes in tests to inherit from ProcessorStep for improved consistency and maintainability.
This commit is contained in:
@@ -27,7 +27,13 @@ import torch.nn as nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.processor import DataProcessorPipeline, EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.processor import (
|
||||
DataProcessorPipeline,
|
||||
EnvTransition,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -47,7 +53,7 @@ def create_transition(
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockStep:
|
||||
class MockStep(ProcessorStep):
|
||||
"""Mock pipeline step for testing - demonstrates best practices.
|
||||
|
||||
This example shows the proper separation:
|
||||
@@ -96,7 +102,7 @@ class MockStep:
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockStepWithoutOptionalMethods:
|
||||
class MockStepWithoutOptionalMethods(ProcessorStep):
|
||||
"""Mock step that only implements the required __call__ method."""
|
||||
|
||||
multiplier: float = 2.0
|
||||
@@ -118,7 +124,7 @@ class MockStepWithoutOptionalMethods:
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockStepWithTensorState:
|
||||
class MockStepWithTensorState(ProcessorStep):
|
||||
"""Mock step demonstrating mixed JSON attributes and tensor state."""
|
||||
|
||||
name: str = "tensor_step"
|
||||
@@ -613,7 +619,7 @@ def test_mixed_json_and_tensor_state():
|
||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||
|
||||
|
||||
class MockModuleStep(nn.Module):
|
||||
class MockModuleStep(ProcessorStep, nn.Module):
|
||||
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
||||
|
||||
def __init__(self, input_dim: int = 10, hidden_dim: int = 5):
|
||||
@@ -653,12 +659,12 @@ class MockModuleStep(nn.Module):
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Override to return all module parameters and buffers."""
|
||||
# Get the module's state dict (includes all parameters and buffers)
|
||||
return super().state_dict()
|
||||
return nn.Module.state_dict(self)
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Override to load all module parameters and buffers."""
|
||||
# Use the module's load_state_dict
|
||||
super().load_state_dict(state)
|
||||
nn.Module.load_state_dict(self, state)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.running_mean.zero_()
|
||||
@@ -669,7 +675,7 @@ class MockModuleStep(nn.Module):
|
||||
return features
|
||||
|
||||
|
||||
class MockNonModuleStepWithState:
|
||||
class MockNonModuleStepWithState(ProcessorStep):
|
||||
"""Mock step that explicitly does NOT inherit from nn.Module but has tensor state.
|
||||
|
||||
This tests the state_dict/load_state_dict path for regular classes.
|
||||
@@ -753,7 +759,7 @@ class MockNonModuleStepWithState:
|
||||
|
||||
# Tests for overrides functionality
|
||||
@dataclass
|
||||
class MockStepWithNonSerializableParam:
|
||||
class MockStepWithNonSerializableParam(ProcessorStep):
|
||||
"""Mock step that requires a non-serializable parameter."""
|
||||
|
||||
def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None):
|
||||
@@ -808,7 +814,7 @@ class MockStepWithNonSerializableParam:
|
||||
|
||||
@ProcessorStepRegistry.register("registered_mock_step")
|
||||
@dataclass
|
||||
class RegisteredMockStep:
|
||||
class RegisteredMockStep(ProcessorStep):
|
||||
"""Mock step registered in the registry."""
|
||||
|
||||
value: int = 42
|
||||
@@ -1381,7 +1387,7 @@ def test_state_file_naming_with_registry():
|
||||
# Register a test step
|
||||
@ProcessorStepRegistry.register("test_stateful_step")
|
||||
@dataclass
|
||||
class TestStatefulStep:
|
||||
class TestStatefulStep(ProcessorStep):
|
||||
value: int = 0
|
||||
|
||||
def __init__(self, value: int = 0):
|
||||
@@ -1436,7 +1442,7 @@ def test_override_with_nested_config():
|
||||
|
||||
@ProcessorStepRegistry.register("complex_config_step")
|
||||
@dataclass
|
||||
class ComplexConfigStep:
|
||||
class ComplexConfigStep(ProcessorStep):
|
||||
name: str = "complex"
|
||||
simple_param: int = 42
|
||||
nested_config: dict = None
|
||||
@@ -1532,7 +1538,7 @@ def test_override_with_callables():
|
||||
|
||||
@ProcessorStepRegistry.register("callable_step")
|
||||
@dataclass
|
||||
class CallableStep:
|
||||
class CallableStep(ProcessorStep):
|
||||
name: str = "callable_step"
|
||||
transform_fn: Any = None
|
||||
|
||||
@@ -1667,7 +1673,7 @@ def test_override_with_device_strings():
|
||||
|
||||
@ProcessorStepRegistry.register("device_aware_step")
|
||||
@dataclass
|
||||
class DeviceAwareStep:
|
||||
class DeviceAwareStep(ProcessorStep):
|
||||
device: str = "cpu"
|
||||
|
||||
def __init__(self, device: str = "cpu"):
|
||||
@@ -1806,13 +1812,17 @@ class NonCallableStep:
|
||||
return features
|
||||
|
||||
|
||||
def test_construction_rejects_step_without_call():
|
||||
with pytest.raises(TypeError, match=r"must define __call__"):
|
||||
def test_construction_rejects_step_without_processorstep():
|
||||
"""Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep."""
|
||||
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
|
||||
DataProcessorPipeline([NonCallableStep()])
|
||||
|
||||
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
|
||||
DataProcessorPipeline([NonCompliantStep()])
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureContractAddStep:
|
||||
class FeatureContractAddStep(ProcessorStep):
|
||||
"""Adds a PolicyFeature"""
|
||||
|
||||
key: str = "a"
|
||||
@@ -1827,7 +1837,7 @@ class FeatureContractAddStep:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureContractMutateStep:
|
||||
class FeatureContractMutateStep(ProcessorStep):
|
||||
"""Mutates a PolicyFeature"""
|
||||
|
||||
key: str = "a"
|
||||
@@ -1842,7 +1852,7 @@ class FeatureContractMutateStep:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureContractBadReturnStep:
|
||||
class FeatureContractBadReturnStep(ProcessorStep):
|
||||
"""Returns a non-dict"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -1853,7 +1863,7 @@ class FeatureContractBadReturnStep:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureContractRemoveStep:
|
||||
class FeatureContractRemoveStep(ProcessorStep):
|
||||
"""Removes a PolicyFeature"""
|
||||
|
||||
key: str
|
||||
@@ -1906,7 +1916,7 @@ def test_features_respects_initial_without_mutation(policy_feature_factory):
|
||||
|
||||
|
||||
def test_features_execution_order_tracking():
|
||||
class Track:
|
||||
class Track(ProcessorStep):
|
||||
def __init__(self, label):
|
||||
self.label = label
|
||||
|
||||
@@ -1945,7 +1955,7 @@ def test_features_remove_from_initial(policy_feature_factory):
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddActionEEAndJointFeatures:
|
||||
class AddActionEEAndJointFeatures(ProcessorStep):
|
||||
"""Adds both EE and JOINT action features."""
|
||||
|
||||
def __call__(self, tr):
|
||||
@@ -1962,7 +1972,7 @@ class AddActionEEAndJointFeatures:
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddObservationStateFeatures:
|
||||
class AddObservationStateFeatures(ProcessorStep):
|
||||
"""Adds state features (and optionally an image spec to test precedence)."""
|
||||
|
||||
add_front_image: bool = False
|
||||
|
||||
Reference in New Issue
Block a user