Feat/pipeline add feature contract (#1637)

* Add feature contract to pipelinestep and pipeline

* Add tests

* Add processor tests

* PR feedback

* encorperate pr feedback

* type in doc

* oops
This commit is contained in:
Pepijn
2025-07-31 16:29:48 +02:00
committed by Adil Zouitine
parent 5ced72e6b8
commit 2c4e888c7f
9 changed files with 472 additions and 0 deletions

View File

@@ -19,6 +19,7 @@ import importlib
import json
import os
from collections.abc import Callable, Iterable, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
@@ -29,6 +30,7 @@ from huggingface_hub import ModelHubMixin, hf_hub_download
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_file, save_file
from lerobot.configs.types import PolicyFeature
from lerobot.utils.utils import get_safe_torch_device
@@ -141,6 +143,11 @@ class ProcessorStep(Protocol):
automatically serialise the step's configuration and learnable state using
a safe-to-share JSON + SafeTensors format.
**Required**:
- ``__call__(transition: EnvTransition) -> EnvTransition``
- ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``
Optional helper protocol:
* ``get_config() -> dict[str, Any]`` User-defined JSON-serializable
configuration and state. YOU decide what to save here. This is where all
@@ -168,6 +175,8 @@ class ProcessorStep(Protocol):
def reset(self) -> None: ...
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
"""Convert a *batch* dict coming from Learobot replay/dataset code into an
@@ -840,6 +849,33 @@ class RobotProcessor(ModelHubMixin):
return f"RobotProcessor({', '.join(parts)})"
def __post_init__(self):
for i, step in enumerate(self.steps):
if not callable(step):
raise TypeError(
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
)
fc = getattr(step, "feature_contract", None)
if not callable(fc):
raise TypeError(
f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]"
)
def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
Apply ALL steps in order. Each step must implement
feature_contract(features) and return a dict (full or incremental schema).
"""
features: dict[str, PolicyFeature] = deepcopy(initial_features)
for _, step in enumerate(self.steps):
out = step.feature_contract(features)
if not isinstance(out, dict):
raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]")
features = out
return features
class ObservationProcessor:
"""Base class for processors that modify only the observation component of a transition.
@@ -1145,3 +1181,6 @@ class IdentityProcessor:
def reset(self) -> None:
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features