mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Feat/pipeline add feature contract (#1637)
* Add feature contract to pipelinestep and pipeline * Add tests * Add processor tests * PR feedback * encorperate pr feedback * type in doc * oops
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -25,8 +26,10 @@ import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
def create_transition(
|
||||
@@ -88,6 +91,10 @@ class MockStep:
|
||||
def reset(self) -> None:
|
||||
self.counter = 0
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockStepWithoutOptionalMethods:
|
||||
@@ -106,6 +113,10 @@ class MockStepWithoutOptionalMethods:
|
||||
|
||||
return transition
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockStepWithTensorState:
|
||||
@@ -158,6 +169,10 @@ class MockStepWithTensorState:
|
||||
self.running_mean.zero_()
|
||||
self.running_count.zero_()
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
return features
|
||||
|
||||
|
||||
def test_empty_pipeline():
|
||||
"""Test pipeline with no steps."""
|
||||
@@ -699,6 +714,10 @@ class MockModuleStep(nn.Module):
|
||||
self.running_mean.zero_()
|
||||
self.counter = 0
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
return features
|
||||
|
||||
|
||||
def test_to_device_with_state_dict():
|
||||
"""Test moving pipeline to device for steps with state_dict."""
|
||||
@@ -953,6 +972,10 @@ class MockNonModuleStepWithState:
|
||||
self.step_count.zero_()
|
||||
self.history.clear()
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
return features
|
||||
|
||||
|
||||
def test_to_device_non_module_class():
|
||||
"""Test moving pipeline to device for regular classes (non nn.Module) with tensor state.
|
||||
@@ -1127,6 +1150,10 @@ class MockStepWithNonSerializableParam:
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("registered_mock_step")
|
||||
@dataclass
|
||||
@@ -1162,6 +1189,10 @@ class RegisteredMockStep:
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
return features
|
||||
|
||||
|
||||
class MockEnvironment:
|
||||
"""Mock environment for testing non-serializable parameters."""
|
||||
@@ -1483,6 +1514,10 @@ class MockStepWithMixedState:
|
||||
"list_value": self.list_value,
|
||||
}
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
return features
|
||||
|
||||
|
||||
def test_to_device_with_mixed_state_types():
|
||||
"""Test that to() only moves tensor state, while non-tensor state remains in config."""
|
||||
@@ -1790,6 +1825,10 @@ def test_state_file_naming_with_registry():
|
||||
def load_state_dict(self, state):
|
||||
self.state_tensor = state["state_tensor"]
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
return features
|
||||
|
||||
try:
|
||||
# Create pipeline with registered steps
|
||||
step1 = TestStatefulStep(1)
|
||||
@@ -1843,6 +1882,10 @@ def test_override_with_nested_config():
|
||||
def get_config(self):
|
||||
return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config}
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
return features
|
||||
|
||||
try:
|
||||
step = ComplexConfigStep()
|
||||
pipeline = RobotProcessor([step])
|
||||
@@ -1931,6 +1974,10 @@ def test_override_with_callables():
|
||||
def get_config(self):
|
||||
return {"name": self.name}
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
return features
|
||||
|
||||
try:
|
||||
step = CallableStep()
|
||||
pipeline = RobotProcessor([step])
|
||||
@@ -2059,6 +2106,10 @@ def test_override_with_device_strings():
|
||||
def load_state_dict(self, state):
|
||||
self.buffer = state["buffer"]
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
return features
|
||||
|
||||
try:
|
||||
step = DeviceAwareStep(device="cpu")
|
||||
pipeline = RobotProcessor([step])
|
||||
@@ -2146,3 +2197,166 @@ def test_save_load_with_custom_converter_functions():
|
||||
# Should work with standard format (wouldn't work with custom converter)
|
||||
result = loaded(batch)
|
||||
assert "observation.image" in result # Standard format preserved
|
||||
|
||||
|
||||
class NonCompliantStep:
|
||||
"""Intentionally non-compliant: missing feature_contract."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
|
||||
def test_construction_rejects_step_without_feature_contract():
|
||||
with pytest.raises(TypeError, match=r"must define feature_contract\(features\) -> dict\[str, Any\]"):
|
||||
RobotProcessor([NonCompliantStep()])
|
||||
|
||||
|
||||
class NonCallableStep:
|
||||
"""Intentionally non-compliant: missing __call__."""
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
def test_construction_rejects_step_without_call():
|
||||
with pytest.raises(TypeError, match=r"must define __call__"):
|
||||
RobotProcessor([NonCallableStep()])
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureContractAddStep:
|
||||
"""Adds a PolicyFeature"""
|
||||
|
||||
key: str = "a"
|
||||
value: PolicyFeature = PolicyFeature(type=FeatureType.STATE, shape=(1,))
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[self.key] = self.value
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureContractMutateStep:
|
||||
"""Mutates a PolicyFeature"""
|
||||
|
||||
key: str = "a"
|
||||
fn: Callable[[PolicyFeature | None], PolicyFeature] = lambda x: x # noqa: E731
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[self.key] = self.fn(features.get(self.key))
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureContractBadReturnStep:
|
||||
"""Returns a non-dict"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return ["not-a-dict"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureContractRemoveStep:
|
||||
"""Removes a PolicyFeature"""
|
||||
|
||||
key: str
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(self.key, None)
|
||||
return features
|
||||
|
||||
|
||||
def test_feature_contract_orders_and_merges(policy_feature_factory):
|
||||
p = RobotProcessor(
|
||||
[
|
||||
FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))),
|
||||
FeatureContractMutateStep("a", lambda v: PolicyFeature(type=v.type, shape=(3,))),
|
||||
FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))),
|
||||
]
|
||||
)
|
||||
out = p.feature_contract({})
|
||||
|
||||
assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,)
|
||||
assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,)
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_feature_contract_respects_initial_without_mutation(policy_feature_factory):
|
||||
initial = {
|
||||
"seed": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
"nested": policy_feature_factory(FeatureType.ENV, (0,)),
|
||||
}
|
||||
p = RobotProcessor(
|
||||
[
|
||||
FeatureContractMutateStep("seed", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 1,))),
|
||||
FeatureContractMutateStep(
|
||||
"nested", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 5,))
|
||||
),
|
||||
]
|
||||
)
|
||||
out = p.feature_contract(initial_features=initial)
|
||||
|
||||
assert out["seed"].shape == (8,)
|
||||
assert out["nested"].shape == (5,)
|
||||
# Initial dict must be preserved
|
||||
assert initial["seed"].shape == (7,)
|
||||
assert initial["nested"].shape == (0,)
|
||||
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_feature_contract_type_error_on_bad_step():
|
||||
p = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()])
|
||||
with pytest.raises(TypeError, match=r"\w+\.feature_contract must return dict\[str, Any\]"):
|
||||
_ = p.feature_contract({})
|
||||
|
||||
|
||||
def test_feature_contract_execution_order_tracking():
|
||||
class Track:
|
||||
def __init__(self, label):
|
||||
self.label = label
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
code = {"A": 1, "B": 2, "C": 3}[self.label]
|
||||
pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=()))
|
||||
features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,))
|
||||
return features
|
||||
|
||||
out = RobotProcessor([Track("A"), Track("B"), Track("C")]).feature_contract({})
|
||||
assert out["order"].shape == (1, 2, 3)
|
||||
|
||||
|
||||
def test_feature_contract_remove_key(policy_feature_factory):
|
||||
p = RobotProcessor(
|
||||
[
|
||||
FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))),
|
||||
FeatureContractRemoveStep("a"),
|
||||
]
|
||||
)
|
||||
out = p.feature_contract({})
|
||||
assert "a" not in out
|
||||
|
||||
|
||||
def test_feature_contract_remove_from_initial(policy_feature_factory):
|
||||
initial = {
|
||||
"keep": policy_feature_factory(FeatureType.STATE, (1,)),
|
||||
"drop": policy_feature_factory(FeatureType.STATE, (1,)),
|
||||
}
|
||||
p = RobotProcessor([FeatureContractRemoveStep("drop")])
|
||||
out = p.feature_contract(initial_features=initial)
|
||||
assert "drop" not in out and out["keep"] == initial["keep"]
|
||||
|
||||
Reference in New Issue
Block a user