refactor(pipeline): minor improvements (#1684)

* chore(pipeline): remove unused features + device torch + envtransition keys

* refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationPRocessor

* refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code

* test(pipeline): fix broken test after refactors

* docs(pipeline): update docstrings VanillaObservationProcessor

* chore(pipeline): move None check to base pipeline classes
This commit is contained in:
Steven Palma
2025-08-06 14:00:13 +02:00
committed by GitHub
parent 7beb040e8e
commit fd4ae3466b
8 changed files with 165 additions and 421 deletions

View File

@@ -16,24 +16,21 @@
from dataclasses import dataclass, field
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.processor.pipeline import (
ObservationProcessor,
ProcessorStepRegistry,
)
@dataclass
@ProcessorStepRegistry.register(name="rename_processor")
class RenameProcessor:
class RenameProcessor(ObservationProcessor):
"""Rename processor that renames keys in the observation."""
rename_map: dict[str, str] = field(default_factory=dict)
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
def observation(self, observation):
processed_obs = {}
for key, value in observation.items():
if key in self.rename_map:
@@ -41,20 +38,11 @@ class RenameProcessor:
else:
processed_obs[key] = value
# Create a new transition with the renamed observation
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
return processed_obs
def get_config(self) -> dict[str, Any]:
return {"rename_map": self.rename_map}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms:
- Each key in the observation that appears in `rename_map` is renamed to its value.