refactor(pipeline): Transition from tuple to dictionary format for EnvTransition

- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
Adil Zouitine
2025-07-21 14:54:31 +02:00
parent 14c2ece004
commit f2b79656eb
16 changed files with 828 additions and 650 deletions

View File

@@ -18,7 +18,7 @@ from typing import Any
import torch
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@@ -29,7 +29,7 @@ class RenameProcessor:
rename_map: dict[str, str] = field(default_factory=dict)
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
@@ -39,15 +39,11 @@ class RenameProcessor:
processed_obs[self.rename_map[key]] = value
else:
processed_obs[key] = value
return (
processed_obs,
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
# Create a new transition with the renamed observation
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def get_config(self) -> dict[str, Any]:
return {"rename_map": self.rename_map}