mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
refactor(pipeline): feature contract now categorizes between OBS or Action (#1867)
* refactor(processor): signature of transform_features * refactor(processor): remove prefixes + processor respect new transform_features signature + update test accordingly * refactor(processor): rename now is only for visual * refactor(processor): update normalize processor * refactor(processor): update vanilla processor features * refactor(processor): feature contract now uses its own enum * chore(processor): rename renameprocessor * chore(processor): minor changes * refactor(processor): add create & change aggregate * refactor(processor): update aggregate * refactor(processor): simplify to functions, fix features contracts and rename function * test(processor): remove to converter tests as now they are very simple * chore(docs): recover docs joint observations processor * fix(processor): update RKP * fix(tests): recv diff test_pipeline * chore(tests): add docs to test * chore(processor): leave obs language constant untouched * fix(processor): correct new shape of feature in crop image processor
This commit is contained in:
@@ -20,7 +20,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
@@ -128,7 +128,9 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
def observation(self, observation):
|
||||
return self._process_observation(observation)
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Transforms feature keys from the Gym standard to the LeRobot standard.
|
||||
|
||||
@@ -148,6 +150,10 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
Returns:
|
||||
The policy features dictionary with standardized LeRobot keys.
|
||||
"""
|
||||
# Build a new features mapping keyed by the same FeatureType buckets
|
||||
# We assume callers already placed features in the correct FeatureType.
|
||||
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features.keys()}
|
||||
|
||||
exact_pairs = {
|
||||
"pixels": OBS_IMAGE,
|
||||
"environment_state": OBS_ENV_STATE,
|
||||
@@ -158,29 +164,43 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
"pixels.": f"{OBS_IMAGES}.",
|
||||
}
|
||||
|
||||
for key in list(features.keys()):
|
||||
matched_prefix = False
|
||||
for old_prefix, new_prefix in prefix_pairs.items():
|
||||
prefixed_old = f"observation.{old_prefix}"
|
||||
if key.startswith(prefixed_old):
|
||||
suffix = key[len(prefixed_old) :]
|
||||
features[f"{new_prefix}{suffix}"] = features.pop(key)
|
||||
matched_prefix = True
|
||||
break
|
||||
# Iterate over all incoming feature buckets and normalize/move each entry
|
||||
for src_ft, bucket in features.items():
|
||||
for key, feat in list(bucket.items()):
|
||||
handled = False
|
||||
|
||||
if key.startswith(old_prefix):
|
||||
suffix = key[len(old_prefix) :]
|
||||
features[f"{new_prefix}{suffix}"] = features.pop(key)
|
||||
matched_prefix = True
|
||||
break
|
||||
|
||||
if matched_prefix:
|
||||
continue
|
||||
|
||||
for old, new in exact_pairs.items():
|
||||
if key == old or key == f"observation.{old}":
|
||||
if key in features:
|
||||
features[new] = features.pop(key)
|
||||
# Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1)
|
||||
for old_prefix, new_prefix in prefix_pairs.items():
|
||||
prefixed_old = f"observation.{old_prefix}"
|
||||
if key.startswith(prefixed_old):
|
||||
suffix = key[len(prefixed_old) :]
|
||||
new_key = f"{new_prefix}{suffix}"
|
||||
new_features[src_ft][new_key] = feat
|
||||
handled = True
|
||||
break
|
||||
|
||||
return features
|
||||
if key.startswith(old_prefix):
|
||||
suffix = key[len(old_prefix) :]
|
||||
new_key = f"{new_prefix}{suffix}"
|
||||
new_features[src_ft][new_key] = feat
|
||||
handled = True
|
||||
break
|
||||
|
||||
if handled:
|
||||
continue
|
||||
|
||||
# Exact-name rules (pixels, environment_state, agent_pos)
|
||||
for old, new in exact_pairs.items():
|
||||
if key == old or key == f"observation.{old}":
|
||||
new_key = new
|
||||
new_features[src_ft][new_key] = feat
|
||||
handled = True
|
||||
break
|
||||
|
||||
if handled:
|
||||
continue
|
||||
|
||||
# Default: keep key in the same source FeatureType bucket
|
||||
new_features[src_ft][key] = feat
|
||||
|
||||
return new_features
|
||||
|
||||
Reference in New Issue
Block a user