refactor(pipeline): feature contract now categorizes between OBS or Action (#1867)

* refactor(processor): signature of transform_features

* refactor(processor): remove prefixes + processor respect new transform_features signature + update test accordingly

* refactor(processor): rename now is only for visual

* refactor(processor): update normalize processor

* refactor(processor): update vanilla processor features

* refactor(processor): feature contract now uses its own enum

* chore(processor): rename renameprocessor

* chore(processor): minor changes

* refactor(processor): add create & change aggregate

* refactor(processor): update aggregate

* refactor(processor): simplify to functions, fix features contracts and rename function

* test(processor): remove to converter tests as now they are very simple

* chore(docs): recover docs joint observations processor

* fix(processor): update RKP

* fix(tests): recv diff test_pipeline

* chore(tests): add docs to test

* chore(processor): leave obs language constant untouched

* fix(processor): correct new shape of feature in crop image processor
This commit is contained in:
Steven Palma
2025-09-09 18:27:30 +02:00
committed by GitHub
parent acf0ba7fb3
commit e881fb6678
47 changed files with 781 additions and 616 deletions

View File

@@ -20,7 +20,7 @@ import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@@ -128,7 +128,9 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
def observation(self, observation):
return self._process_observation(observation)
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Transforms feature keys from the Gym standard to the LeRobot standard.
@@ -148,6 +150,10 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
Returns:
The policy features dictionary with standardized LeRobot keys.
"""
# Build a new features mapping keyed by the same FeatureType buckets
# We assume callers already placed features in the correct FeatureType.
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features.keys()}
exact_pairs = {
"pixels": OBS_IMAGE,
"environment_state": OBS_ENV_STATE,
@@ -158,29 +164,43 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
"pixels.": f"{OBS_IMAGES}.",
}
for key in list(features.keys()):
matched_prefix = False
for old_prefix, new_prefix in prefix_pairs.items():
prefixed_old = f"observation.{old_prefix}"
if key.startswith(prefixed_old):
suffix = key[len(prefixed_old) :]
features[f"{new_prefix}{suffix}"] = features.pop(key)
matched_prefix = True
break
# Iterate over all incoming feature buckets and normalize/move each entry
for src_ft, bucket in features.items():
for key, feat in list(bucket.items()):
handled = False
if key.startswith(old_prefix):
suffix = key[len(old_prefix) :]
features[f"{new_prefix}{suffix}"] = features.pop(key)
matched_prefix = True
break
if matched_prefix:
continue
for old, new in exact_pairs.items():
if key == old or key == f"observation.{old}":
if key in features:
features[new] = features.pop(key)
# Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1)
for old_prefix, new_prefix in prefix_pairs.items():
prefixed_old = f"observation.{old_prefix}"
if key.startswith(prefixed_old):
suffix = key[len(prefixed_old) :]
new_key = f"{new_prefix}{suffix}"
new_features[src_ft][new_key] = feat
handled = True
break
return features
if key.startswith(old_prefix):
suffix = key[len(old_prefix) :]
new_key = f"{new_prefix}{suffix}"
new_features[src_ft][new_key] = feat
handled = True
break
if handled:
continue
# Exact-name rules (pixels, environment_state, agent_pos)
for old, new in exact_pairs.items():
if key == old or key == f"observation.{old}":
new_key = new
new_features[src_ft][new_key] = feat
handled = True
break
if handled:
continue
# Default: keep key in the same source FeatureType bucket
new_features[src_ft][key] = feat
return new_features