refactor(pipeline): Transition from tuple to dictionary format for EnvTransition

- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
Adil Zouitine
2025-07-21 14:54:31 +02:00
parent 14c2ece004
commit f2b79656eb
16 changed files with 828 additions and 650 deletions

View File

@@ -32,7 +32,7 @@ from .pipeline import (
ProcessorStepRegistry,
RewardProcessor,
RobotProcessor,
TransitionIndex,
TransitionKey,
TruncatedProcessor,
)
from .rename_processor import RenameProcessor
@@ -54,7 +54,7 @@ __all__ = [
"RewardProcessor",
"RobotProcessor",
"StateProcessor",
"TransitionIndex",
"TransitionKey",
"TruncatedProcessor",
"VanillaObservationProcessor",
]

View File

@@ -18,7 +18,7 @@ from typing import Any
import torch
from lerobot.processor.pipeline import EnvTransition, TransitionIndex
from lerobot.processor.pipeline import EnvTransition, TransitionKey
@dataclass
@@ -35,30 +35,41 @@ class DeviceProcessor:
self.non_blocking = "cuda" in self.device
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation: dict[str, torch.Tensor] = transition[TransitionIndex.OBSERVATION]
action = transition[TransitionIndex.ACTION]
reward = transition[TransitionIndex.REWARD]
done = transition[TransitionIndex.DONE]
truncated = transition[TransitionIndex.TRUNCATED]
info = transition[TransitionIndex.INFO]
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
# Create a copy of the transition
new_transition = transition.copy()
# Process observation tensors
observation = transition.get(TransitionKey.OBSERVATION)
if observation is not None:
observation = {
k: v.to(self.device, non_blocking=self.non_blocking) for k, v in observation.items()
new_observation = {
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()
}
if action is not None:
action = action.to(self.device)
new_transition[TransitionKey.OBSERVATION] = new_observation
return (
observation,
action,
reward,
done,
truncated,
info,
complementary_data,
)
# Process action tensor
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, torch.Tensor):
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
# Process reward tensor
reward = transition.get(TransitionKey.REWARD)
if reward is not None and isinstance(reward, torch.Tensor):
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
# Process done tensor
done = transition.get(TransitionKey.DONE)
if done is not None and isinstance(done, torch.Tensor):
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
# Process truncated tensor
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is not None and isinstance(truncated, torch.Tensor):
new_transition[TransitionKey.TRUNCATED] = truncated.to(
self.device, non_blocking=self.non_blocking
)
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any
from collections.abc import Mapping
import numpy as np
import torch
@@ -10,7 +10,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
@@ -166,17 +166,14 @@ class NormalizerProcessor:
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._normalize_obs(transition[TransitionIndex.OBSERVATION])
action = self._normalize_action(transition[TransitionIndex.ACTION])
return (
observation,
action,
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._normalize_action(transition.get(TransitionKey.ACTION))
# Create a new transition with normalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
return new_transition
def get_config(self) -> dict[str, Any]:
config = {
@@ -297,17 +294,14 @@ class UnnormalizerProcessor:
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._unnormalize_obs(transition[TransitionIndex.OBSERVATION])
action = self._unnormalize_action(transition[TransitionIndex.ACTION])
return (
observation,
action,
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
# Create a new transition with unnormalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
return new_transition
def get_config(self) -> dict[str, Any]:
return {

View File

@@ -21,7 +21,7 @@ import numpy as np
import torch
from torch import Tensor
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@@ -36,7 +36,7 @@ class ImageProcessor:
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
@@ -60,15 +60,9 @@ class ImageProcessor:
processed_obs[key] = value
# Return new transition with processed observation
return (
processed_obs,
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def _process_single_image(self, img: np.ndarray) -> Tensor:
"""Process a single image array."""
@@ -124,7 +118,7 @@ class StateProcessor:
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
@@ -150,15 +144,9 @@ class StateProcessor:
del processed_obs["agent_pos"]
# Return new transition with processed observation
return (
processed_obs,
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""

View File

@@ -18,39 +18,42 @@ from __future__ import annotations
import importlib
import json
import os
from dataclasses import dataclass, field
from enum import IntEnum
from pathlib import Path
from typing import Any, Protocol
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Protocol, TypedDict
import torch
from huggingface_hub import ModelHubMixin, hf_hub_download
from safetensors.torch import load_file, save_file
class TransitionIndex(IntEnum):
"""Explicit indices for EnvTransition tuple components."""
class TransitionKey(str, Enum):
"""Keys for accessing EnvTransition dictionary components."""
OBSERVATION = 0
ACTION = 1
REWARD = 2
DONE = 3
TRUNCATED = 4
INFO = 5
COMPLEMENTARY_DATA = 6
OBSERVATION = "observation"
ACTION = "action"
REWARD = "reward"
DONE = "done"
TRUNCATED = "truncated"
INFO = "info"
COMPLEMENTARY_DATA = "complementary_data"
# (observation, action, reward, done, truncated, info, complementary_data)
EnvTransition = tuple[
dict[str, Any] | None, # observation
Any | torch.Tensor | None, # action
float | torch.Tensor | None, # reward
bool | torch.Tensor | None, # done
bool | torch.Tensor | None, # truncated
dict[str, Any] | None, # info
dict[str, Any] | None, # complementary_data
]
class EnvTransition(TypedDict, total=False):
"""Environment transition data structure.
All fields are optional (total=False) to allow flexible usage.
"""
observation: dict[str, Any] | None
action: Any | torch.Tensor | None
reward: float | torch.Tensor | None
done: bool | torch.Tensor | None
truncated: bool | torch.Tensor | None
info: dict[str, Any] | None
complementary_data: dict[str, Any] | None
class ProcessorStepRegistry:
@@ -165,10 +168,9 @@ class ProcessorStep(Protocol):
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
"""Convert a *batch* dict coming from LeRobot replay/dataset code into an
``EnvTransition`` tuple.
``EnvTransition`` dictionary.
The function is intentionally **strictly positional** — it maps well-known
keys to the fixed slot order used inside the pipeline. Missing keys are
The function maps well-known keys to the EnvTransition structure. Missing keys are
filled with sane defaults (``None`` or ``0.0``/``False``).
Keys recognised (case-sensitive):
@@ -193,15 +195,16 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
task_key = {"task": batch["task"]} if "task" in batch else {}
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
return (
observation,
batch.get("action"),
batch.get("next.reward", 0.0),
batch.get("next.done", False),
batch.get("next.truncated", False),
batch.get("info", {}),
complementary_data,
)
transition: EnvTransition = {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: batch.get("action"),
TransitionKey.REWARD: batch.get("next.reward", 0.0),
TransitionKey.DONE: batch.get("next.done", False),
TransitionKey.TRUNCATED: batch.get("next.truncated", False),
TransitionKey.INFO: batch.get("info", {}),
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
return transition
def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401
@@ -209,25 +212,16 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
the canonical field names used throughout *LeRobot*.
"""
(
observation,
action,
reward,
done,
truncated,
info,
complementary_data,
) = transition
batch = {
"action": action,
"next.reward": reward,
"next.done": done,
"next.truncated": truncated,
"info": info,
"action": transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
"next.done": transition.get(TransitionKey.DONE, False),
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
"info": transition.get(TransitionKey.INFO, {}),
}
# Add padding and task data from complementary_data
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data:
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
batch.update(pad_data)
@@ -236,6 +230,7 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
batch["task"] = complementary_data["task"]
# Handle observation - flatten dict to observation.* keys if it's a dict
observation = transition.get(TransitionKey.OBSERVATION)
if isinstance(observation, dict):
batch.update(observation)
@@ -293,33 +288,35 @@ class RobotProcessor(ModelHubMixin):
def __call__(self, data: EnvTransition | dict[str, Any]):
"""Process data through all steps.
The method accepts either the classic EnvTransition tuple or a batch dictionary
The method accepts either the classic EnvTransition dict or a batch dictionary
(like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied
it is first converted to the internal tuple format using to_transition; after all
steps are executed the tuple is transformed back into a dict with to_batch and the
it is first converted to the internal dict format using to_transition; after all
steps are executed the dict is transformed back into a batch dict with to_batch and the
result is returned, thereby preserving the caller's original data type.
Args:
data: Either an EnvTransition tuple or a batch dictionary to process.
data: Either an EnvTransition dict or a batch dictionary to process.
Returns:
The processed data in the same format as the input (tuple or dict).
The processed data in the same format as the input (EnvTransition or batch dict).
Raises:
ValueError: If the transition is not a valid 7-tuple format.
ValueError: If the transition is not a valid EnvTransition format.
"""
called_with_batch = isinstance(data, dict)
# Check if data is already an EnvTransition or needs conversion
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
# It's a batch dict, convert it
called_with_batch = True
transition = self.to_transition(data)
else:
# It's already an EnvTransition
called_with_batch = False
transition = data
transition = self.to_transition(data) if called_with_batch else data
# Basic validation with helpful error message for tuple input
if not isinstance(transition, tuple) or len(transition) != 7:
raise ValueError(
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
"truncated, info, complementary_data). "
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
)
# Basic validation
if not isinstance(transition, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
for idx, processor_step in enumerate(self.steps):
for hook in self.before_step_hooks:
@@ -339,25 +336,28 @@ class RobotProcessor(ModelHubMixin):
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
"""Yield the intermediate results after each processor step.
Like __call__, this method accepts either EnvTransition tuples or batch dictionaries
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries
and preserves the input format in the yielded results.
Args:
data: Either an EnvTransition tuple or a batch dictionary to process.
data: Either an EnvTransition dict or a batch dictionary to process.
Yields:
The intermediate results after each step, in the same format as the input.
"""
called_with_batch = isinstance(data, dict)
transition = self.to_transition(data) if called_with_batch else data
# Check if data is already an EnvTransition or needs conversion
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
# It's a batch dict, convert it
called_with_batch = True
transition = self.to_transition(data)
else:
# It's already an EnvTransition
called_with_batch = False
transition = data
# Basic validation with helpful error message for tuple input
if not isinstance(transition, tuple) or len(transition) != 7:
raise ValueError(
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
"truncated, info, complementary_data). "
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
)
# Basic validation
if not isinstance(transition, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
# Yield initial state
yield self.to_output(transition) if called_with_batch else transition
@@ -684,7 +684,7 @@ class ObservationProcessor:
Subclasses should override the `observation` method to implement custom observation processing.
This class handles the boilerplate of extracting and reinserting the processed observation
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
@@ -696,7 +696,7 @@ class ObservationProcessor:
return observation * self.scale_factor
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
By inheriting from this class, you avoid writing repetitive code to handle transition dict
manipulation, focusing only on the specific observation processing logic.
"""
@@ -712,10 +712,12 @@ class ObservationProcessor:
return observation
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
observation = self.observation(observation)
transition = (observation, *transition[TransitionIndex.ACTION :])
return transition
observation = transition.get(TransitionKey.OBSERVATION)
processed_observation = self.observation(observation)
# Create a new transition dict with the processed observation
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_observation
return new_transition
class ActionProcessor:
@@ -723,7 +725,7 @@ class ActionProcessor:
Subclasses should override the `action` method to implement custom action processing.
This class handles the boilerplate of extracting and reinserting the processed action
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
@@ -736,7 +738,7 @@ class ActionProcessor:
return np.clip(action, self.min_val, self.max_val)
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
By inheriting from this class, you avoid writing repetitive code to handle transition dict
manipulation, focusing only on the specific action processing logic.
"""
@@ -752,10 +754,12 @@ class ActionProcessor:
return action
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition[TransitionIndex.ACTION]
action = self.action(action)
transition = (transition[TransitionIndex.OBSERVATION], action, *transition[TransitionIndex.REWARD :])
return transition
action = transition.get(TransitionKey.ACTION)
processed_action = self.action(action)
# Create a new transition dict with the processed action
new_transition = transition.copy()
new_transition[TransitionKey.ACTION] = processed_action
return new_transition
class RewardProcessor:
@@ -763,7 +767,7 @@ class RewardProcessor:
Subclasses should override the `reward` method to implement custom reward processing.
This class handles the boilerplate of extracting and reinserting the processed reward
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
@@ -775,7 +779,7 @@ class RewardProcessor:
return reward * self.scale_factor
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
By inheriting from this class, you avoid writing repetitive code to handle transition dict
manipulation, focusing only on the specific reward processing logic.
"""
@@ -791,15 +795,12 @@ class RewardProcessor:
return reward
def __call__(self, transition: EnvTransition) -> EnvTransition:
reward = transition[TransitionIndex.REWARD]
reward = self.reward(reward)
transition = (
transition[TransitionIndex.OBSERVATION],
transition[TransitionIndex.ACTION],
reward,
*transition[TransitionIndex.DONE :],
)
return transition
reward = transition.get(TransitionKey.REWARD)
processed_reward = self.reward(reward)
# Create a new transition dict with the processed reward
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = processed_reward
return new_transition
class DoneProcessor:
@@ -807,7 +808,7 @@ class DoneProcessor:
Subclasses should override the `done` method to implement custom done flag processing.
This class handles the boilerplate of extracting and reinserting the processed done flag
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
@@ -824,7 +825,7 @@ class DoneProcessor:
self.steps = 0
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
By inheriting from this class, you avoid writing repetitive code to handle transition dict
manipulation, focusing only on the specific done flag processing logic.
"""
@@ -840,16 +841,12 @@ class DoneProcessor:
return done
def __call__(self, transition: EnvTransition) -> EnvTransition:
done = transition[TransitionIndex.DONE]
done = self.done(done)
transition = (
transition[TransitionIndex.OBSERVATION],
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
done,
*transition[TransitionIndex.TRUNCATED :],
)
return transition
done = transition.get(TransitionKey.DONE)
processed_done = self.done(done)
# Create a new transition dict with the processed done flag
new_transition = transition.copy()
new_transition[TransitionKey.DONE] = processed_done
return new_transition
class TruncatedProcessor:
@@ -857,7 +854,7 @@ class TruncatedProcessor:
Subclasses should override the `truncated` method to implement custom truncated flag processing.
This class handles the boilerplate of extracting and reinserting the processed truncated flag
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
@@ -870,7 +867,7 @@ class TruncatedProcessor:
return truncated or some_condition > self.threshold
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
By inheriting from this class, you avoid writing repetitive code to handle transition dict
manipulation, focusing only on the specific truncated flag processing logic.
"""
@@ -886,17 +883,12 @@ class TruncatedProcessor:
return truncated
def __call__(self, transition: EnvTransition) -> EnvTransition:
truncated = transition[TransitionIndex.TRUNCATED]
truncated = self.truncated(truncated)
transition = (
transition[TransitionIndex.OBSERVATION],
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
truncated,
*transition[TransitionIndex.INFO :],
)
return transition
truncated = transition.get(TransitionKey.TRUNCATED)
processed_truncated = self.truncated(truncated)
# Create a new transition dict with the processed truncated flag
new_transition = transition.copy()
new_transition[TransitionKey.TRUNCATED] = processed_truncated
return new_transition
class InfoProcessor:
@@ -904,7 +896,7 @@ class InfoProcessor:
Subclasses should override the `info` method to implement custom info processing.
This class handles the boilerplate of extracting and reinserting the processed info
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
@@ -922,7 +914,7 @@ class InfoProcessor:
self.step_count = 0
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
By inheriting from this class, you avoid writing repetitive code to handle transition dict
manipulation, focusing only on the specific info dictionary processing logic.
"""
@@ -938,18 +930,12 @@ class InfoProcessor:
return info
def __call__(self, transition: EnvTransition) -> EnvTransition:
info = transition[TransitionIndex.INFO]
info = self.info(info)
transition = (
transition[TransitionIndex.OBSERVATION],
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
info,
*transition[TransitionIndex.COMPLEMENTARY_DATA :],
)
return transition
info = transition.get(TransitionKey.INFO)
processed_info = self.info(info)
# Create a new transition dict with the processed info
new_transition = transition.copy()
new_transition[TransitionKey.INFO] = processed_info
return new_transition
class ComplementaryDataProcessor:
@@ -957,7 +943,7 @@ class ComplementaryDataProcessor:
Subclasses should override the `complementary_data` method to implement custom complementary data processing.
This class handles the boilerplate of extracting and reinserting the processed complementary data
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
"""
def complementary_data(self, complementary_data):
@@ -972,18 +958,12 @@ class ComplementaryDataProcessor:
return complementary_data
def __call__(self, transition: EnvTransition) -> EnvTransition:
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
complementary_data = self.complementary_data(complementary_data)
transition = (
transition[TransitionIndex.OBSERVATION],
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
complementary_data,
)
return transition
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
processed_complementary_data = self.complementary_data(complementary_data)
# Create a new transition dict with the processed complementary data
new_transition = transition.copy()
new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data
return new_transition
class IdentityProcessor:

View File

@@ -18,7 +18,7 @@ from typing import Any
import torch
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@@ -29,7 +29,7 @@ class RenameProcessor:
rename_map: dict[str, str] = field(default_factory=dict)
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
@@ -39,15 +39,11 @@ class RenameProcessor:
processed_obs[self.rename_map[key]] = value
else:
processed_obs[key] = value
return (
processed_obs,
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
# Create a new transition with the renamed observation
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def get_config(self) -> dict[str, Any]:
return {"rename_map": self.rename_map}