mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
Feat/pipeline add feature contract (#1637)
* Add feature contract to pipelinestep and pipeline * Add tests * Add processor tests * PR feedback * encorperate pr feedback * type in doc * oops
This commit is contained in:
@@ -18,6 +18,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
@@ -74,3 +75,6 @@ class DeviceProcessor:
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {"device": self.device}
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -204,6 +204,9 @@ class NormalizerProcessor:
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
||||
@@ -327,3 +330,6 @@ class UnnormalizerProcessor:
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -21,6 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
@@ -110,6 +111,27 @@ class ImageProcessor:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms:
|
||||
pixels -> OBS_IMAGE,
|
||||
observation.pixels -> OBS_IMAGE,
|
||||
pixels.<cam> -> OBS_IMAGES.<cam>,
|
||||
observation.pixels.<cam> -> OBS_IMAGES.<cam>
|
||||
"""
|
||||
if "pixels" in features:
|
||||
features[OBS_IMAGE] = features.pop("pixels")
|
||||
if "observation.pixels" in features:
|
||||
features[OBS_IMAGE] = features.pop("observation.pixels")
|
||||
|
||||
prefixes = ("pixels.", "observation.pixels.")
|
||||
for key in list(features.keys()):
|
||||
for p in prefixes:
|
||||
if key.startswith(p):
|
||||
suffix = key[len(p) :]
|
||||
features[f"{OBS_IMAGES}.{suffix}"] = features.pop(key)
|
||||
break
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateProcessor:
|
||||
@@ -169,6 +191,25 @@ class StateProcessor:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms:
|
||||
environment_state -> OBS_ENV_STATE,
|
||||
agent_pos -> OBS_STATE,
|
||||
observation.environment_state -> OBS_ENV_STATE,
|
||||
observation.agent_pos -> OBS_STATE
|
||||
"""
|
||||
pairs = (
|
||||
("environment_state", OBS_ENV_STATE),
|
||||
("agent_pos", OBS_STATE),
|
||||
)
|
||||
for old, new in pairs:
|
||||
if old in features:
|
||||
features[new] = features.pop(old)
|
||||
prefixed = f"observation.{old}"
|
||||
if prefixed in features:
|
||||
features[new] = features.pop(prefixed)
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="observation_processor")
|
||||
@@ -219,3 +260,8 @@ class VanillaObservationProcessor:
|
||||
"""Reset processor state."""
|
||||
self.image_processor.reset()
|
||||
self.state_processor.reset()
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features = self.image_processor.feature_contract(features)
|
||||
features = self.state_processor.feature_contract(features)
|
||||
return features
|
||||
|
||||
@@ -19,6 +19,7 @@ import importlib
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
@@ -29,6 +30,7 @@ from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
|
||||
|
||||
@@ -141,6 +143,11 @@ class ProcessorStep(Protocol):
|
||||
automatically serialise the step's configuration and learnable state using
|
||||
a safe-to-share JSON + SafeTensors format.
|
||||
|
||||
|
||||
**Required**:
|
||||
- ``__call__(transition: EnvTransition) -> EnvTransition``
|
||||
- ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``
|
||||
|
||||
Optional helper protocol:
|
||||
* ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable
|
||||
configuration and state. YOU decide what to save here. This is where all
|
||||
@@ -168,6 +175,8 @@ class ProcessorStep(Protocol):
|
||||
|
||||
def reset(self) -> None: ...
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
|
||||
|
||||
|
||||
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
|
||||
"""Convert a *batch* dict coming from Learobot replay/dataset code into an
|
||||
@@ -840,6 +849,33 @@ class RobotProcessor(ModelHubMixin):
|
||||
|
||||
return f"RobotProcessor({', '.join(parts)})"
|
||||
|
||||
def __post_init__(self):
|
||||
for i, step in enumerate(self.steps):
|
||||
if not callable(step):
|
||||
raise TypeError(
|
||||
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
|
||||
)
|
||||
|
||||
fc = getattr(step, "feature_contract", None)
|
||||
if not callable(fc):
|
||||
raise TypeError(
|
||||
f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]"
|
||||
)
|
||||
|
||||
def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Apply ALL steps in order. Each step must implement
|
||||
feature_contract(features) and return a dict (full or incremental schema).
|
||||
"""
|
||||
features: dict[str, PolicyFeature] = deepcopy(initial_features)
|
||||
|
||||
for _, step in enumerate(self.steps):
|
||||
out = step.feature_contract(features)
|
||||
if not isinstance(out, dict):
|
||||
raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]")
|
||||
features = out
|
||||
return features
|
||||
|
||||
|
||||
class ObservationProcessor:
|
||||
"""Base class for processors that modify only the observation component of a transition.
|
||||
@@ -1145,3 +1181,6 @@ class IdentityProcessor:
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
@@ -53,3 +54,10 @@ class RenameProcessor:
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms:
|
||||
- Each key in the observation that appears in `rename_map` is renamed to its value.
|
||||
- Keys not in `rename_map` remain unchanged.
|
||||
"""
|
||||
return {self.rename_map.get(k, k): v for k, v in features.items()}
|
||||
|
||||
Reference in New Issue
Block a user