mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 19:31:25 +00:00
refactor(pipeline): Transition from tuple to dictionary format for EnvTransition
- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability. - Replaced instances of TransitionIndex with TransitionKey for accessing transition components. - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
@@ -21,7 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -36,7 +36,7 @@ class ImageProcessor:
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition[TransitionIndex.OBSERVATION]
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
if observation is None:
|
||||
return transition
|
||||
@@ -60,15 +60,9 @@ class ImageProcessor:
|
||||
processed_obs[key] = value
|
||||
|
||||
# Return new transition with processed observation
|
||||
return (
|
||||
processed_obs,
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
|
||||
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
||||
"""Process a single image array."""
|
||||
@@ -124,7 +118,7 @@ class StateProcessor:
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition[TransitionIndex.OBSERVATION]
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
if observation is None:
|
||||
return transition
|
||||
@@ -150,15 +144,9 @@ class StateProcessor:
|
||||
del processed_obs["agent_pos"]
|
||||
|
||||
# Return new transition with processed observation
|
||||
return (
|
||||
processed_obs,
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
|
||||
Reference in New Issue
Block a user