Feat/pipeline add feature contract (#1637)

* Add feature contract to pipelinestep and pipeline

* Add tests

* Add processor tests

* PR feedback

* encorperate pr feedback

* type in doc

* oops
This commit is contained in:
Pepijn
2025-07-31 16:29:48 +02:00
committed by Adil Zouitine
parent 5ced72e6b8
commit 2c4e888c7f
9 changed files with 472 additions and 0 deletions

View File

@@ -21,6 +21,7 @@ import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@@ -110,6 +111,27 @@ class ImageProcessor:
"""Reset processor state (no-op for this processor)."""
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms:
pixels -> OBS_IMAGE,
observation.pixels -> OBS_IMAGE,
pixels.<cam> -> OBS_IMAGES.<cam>,
observation.pixels.<cam> -> OBS_IMAGES.<cam>
"""
if "pixels" in features:
features[OBS_IMAGE] = features.pop("pixels")
if "observation.pixels" in features:
features[OBS_IMAGE] = features.pop("observation.pixels")
prefixes = ("pixels.", "observation.pixels.")
for key in list(features.keys()):
for p in prefixes:
if key.startswith(p):
suffix = key[len(p) :]
features[f"{OBS_IMAGES}.{suffix}"] = features.pop(key)
break
return features
@dataclass
class StateProcessor:
@@ -169,6 +191,25 @@ class StateProcessor:
"""Reset processor state (no-op for this processor)."""
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms:
environment_state -> OBS_ENV_STATE,
agent_pos -> OBS_STATE,
observation.environment_state -> OBS_ENV_STATE,
observation.agent_pos -> OBS_STATE
"""
pairs = (
("environment_state", OBS_ENV_STATE),
("agent_pos", OBS_STATE),
)
for old, new in pairs:
if old in features:
features[new] = features.pop(old)
prefixed = f"observation.{old}"
if prefixed in features:
features[new] = features.pop(prefixed)
return features
@dataclass
@ProcessorStepRegistry.register(name="observation_processor")
@@ -219,3 +260,8 @@ class VanillaObservationProcessor:
"""Reset processor state."""
self.image_processor.reset()
self.state_processor.reset()
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features = self.image_processor.feature_contract(features)
features = self.state_processor.feature_contract(features)
return features