refactor(pipeline): Transition from tuple to dictionary format for EnvTransition

- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
Adil Zouitine
2025-07-21 14:54:31 +02:00
parent 14c2ece004
commit f2b79656eb
16 changed files with 828 additions and 650 deletions

View File

@@ -21,7 +21,7 @@ import numpy as np
import torch
from torch import Tensor
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@@ -36,7 +36,7 @@ class ImageProcessor:
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
@@ -60,15 +60,9 @@ class ImageProcessor:
processed_obs[key] = value
# Return new transition with processed observation
return (
processed_obs,
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def _process_single_image(self, img: np.ndarray) -> Tensor:
"""Process a single image array."""
@@ -124,7 +118,7 @@ class StateProcessor:
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
@@ -150,15 +144,9 @@ class StateProcessor:
del processed_obs["agent_pos"]
# Return new transition with processed observation
return (
processed_obs,
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""