mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
refactor(pipeline): Transition from tuple to dictionary format for EnvTransition
- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability. - Replaced instances of TransitionIndex with TransitionKey for accessing transition components. - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
@@ -32,7 +32,7 @@ from .pipeline import (
|
||||
ProcessorStepRegistry,
|
||||
RewardProcessor,
|
||||
RobotProcessor,
|
||||
TransitionIndex,
|
||||
TransitionKey,
|
||||
TruncatedProcessor,
|
||||
)
|
||||
from .rename_processor import RenameProcessor
|
||||
@@ -54,7 +54,7 @@ __all__ = [
|
||||
"RewardProcessor",
|
||||
"RobotProcessor",
|
||||
"StateProcessor",
|
||||
"TransitionIndex",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessor",
|
||||
"VanillaObservationProcessor",
|
||||
]
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionIndex
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -35,30 +35,41 @@ class DeviceProcessor:
|
||||
self.non_blocking = "cuda" in self.device
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation: dict[str, torch.Tensor] = transition[TransitionIndex.OBSERVATION]
|
||||
action = transition[TransitionIndex.ACTION]
|
||||
reward = transition[TransitionIndex.REWARD]
|
||||
done = transition[TransitionIndex.DONE]
|
||||
truncated = transition[TransitionIndex.TRUNCATED]
|
||||
info = transition[TransitionIndex.INFO]
|
||||
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
|
||||
# Create a copy of the transition
|
||||
new_transition = transition.copy()
|
||||
|
||||
# Process observation tensors
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is not None:
|
||||
observation = {
|
||||
k: v.to(self.device, non_blocking=self.non_blocking) for k, v in observation.items()
|
||||
new_observation = {
|
||||
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in observation.items()
|
||||
}
|
||||
if action is not None:
|
||||
action = action.to(self.device)
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
|
||||
return (
|
||||
observation,
|
||||
action,
|
||||
reward,
|
||||
done,
|
||||
truncated,
|
||||
info,
|
||||
complementary_data,
|
||||
)
|
||||
# Process action tensor
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and isinstance(action, torch.Tensor):
|
||||
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
|
||||
|
||||
# Process reward tensor
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
if reward is not None and isinstance(reward, torch.Tensor):
|
||||
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
|
||||
|
||||
# Process done tensor
|
||||
done = transition.get(TransitionKey.DONE)
|
||||
if done is not None and isinstance(done, torch.Tensor):
|
||||
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
|
||||
|
||||
# Process truncated tensor
|
||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||
if truncated is not None and isinstance(truncated, torch.Tensor):
|
||||
new_transition[TransitionKey.TRUNCATED] = truncated.to(
|
||||
self.device, non_blocking=self.non_blocking
|
||||
)
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from collections.abc import Mapping
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -10,7 +10,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
|
||||
@@ -166,17 +166,14 @@ class NormalizerProcessor:
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._normalize_obs(transition[TransitionIndex.OBSERVATION])
|
||||
action = self._normalize_action(transition[TransitionIndex.ACTION])
|
||||
return (
|
||||
observation,
|
||||
action,
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
action = self._normalize_action(transition.get(TransitionKey.ACTION))
|
||||
|
||||
# Create a new transition with normalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
config = {
|
||||
@@ -297,17 +294,14 @@ class UnnormalizerProcessor:
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._unnormalize_obs(transition[TransitionIndex.OBSERVATION])
|
||||
action = self._unnormalize_action(transition[TransitionIndex.ACTION])
|
||||
return (
|
||||
observation,
|
||||
action,
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
|
||||
|
||||
# Create a new transition with unnormalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
||||
@@ -21,7 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -36,7 +36,7 @@ class ImageProcessor:
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition[TransitionIndex.OBSERVATION]
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
if observation is None:
|
||||
return transition
|
||||
@@ -60,15 +60,9 @@ class ImageProcessor:
|
||||
processed_obs[key] = value
|
||||
|
||||
# Return new transition with processed observation
|
||||
return (
|
||||
processed_obs,
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
|
||||
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
||||
"""Process a single image array."""
|
||||
@@ -124,7 +118,7 @@ class StateProcessor:
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition[TransitionIndex.OBSERVATION]
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
if observation is None:
|
||||
return transition
|
||||
@@ -150,15 +144,9 @@ class StateProcessor:
|
||||
del processed_obs["agent_pos"]
|
||||
|
||||
# Return new transition with processed observation
|
||||
return (
|
||||
processed_obs,
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
|
||||
@@ -18,39 +18,42 @@ from __future__ import annotations
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol, TypedDict
|
||||
|
||||
import torch
|
||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
class TransitionIndex(IntEnum):
|
||||
"""Explicit indices for EnvTransition tuple components."""
|
||||
class TransitionKey(str, Enum):
|
||||
"""Keys for accessing EnvTransition dictionary components."""
|
||||
|
||||
OBSERVATION = 0
|
||||
ACTION = 1
|
||||
REWARD = 2
|
||||
DONE = 3
|
||||
TRUNCATED = 4
|
||||
INFO = 5
|
||||
COMPLEMENTARY_DATA = 6
|
||||
OBSERVATION = "observation"
|
||||
ACTION = "action"
|
||||
REWARD = "reward"
|
||||
DONE = "done"
|
||||
TRUNCATED = "truncated"
|
||||
INFO = "info"
|
||||
COMPLEMENTARY_DATA = "complementary_data"
|
||||
|
||||
|
||||
# (observation, action, reward, done, truncated, info, complementary_data)
|
||||
EnvTransition = tuple[
|
||||
dict[str, Any] | None, # observation
|
||||
Any | torch.Tensor | None, # action
|
||||
float | torch.Tensor | None, # reward
|
||||
bool | torch.Tensor | None, # done
|
||||
bool | torch.Tensor | None, # truncated
|
||||
dict[str, Any] | None, # info
|
||||
dict[str, Any] | None, # complementary_data
|
||||
]
|
||||
class EnvTransition(TypedDict, total=False):
|
||||
"""Environment transition data structure.
|
||||
|
||||
All fields are optional (total=False) to allow flexible usage.
|
||||
"""
|
||||
|
||||
observation: dict[str, Any] | None
|
||||
action: Any | torch.Tensor | None
|
||||
reward: float | torch.Tensor | None
|
||||
done: bool | torch.Tensor | None
|
||||
truncated: bool | torch.Tensor | None
|
||||
info: dict[str, Any] | None
|
||||
complementary_data: dict[str, Any] | None
|
||||
|
||||
|
||||
class ProcessorStepRegistry:
|
||||
@@ -165,10 +168,9 @@ class ProcessorStep(Protocol):
|
||||
|
||||
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
|
||||
"""Convert a *batch* dict coming from Learobot replay/dataset code into an
|
||||
``EnvTransition`` tuple.
|
||||
``EnvTransition`` dictionary.
|
||||
|
||||
The function is intentionally **strictly positional** – it maps well known
|
||||
keys to the fixed slot order used inside the pipeline. Missing keys are
|
||||
The function maps well known keys to the EnvTransition structure. Missing keys are
|
||||
filled with sane defaults (``None`` or ``0.0``/``False``).
|
||||
|
||||
Keys recognised (case-sensitive):
|
||||
@@ -193,15 +195,16 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
|
||||
|
||||
return (
|
||||
observation,
|
||||
batch.get("action"),
|
||||
batch.get("next.reward", 0.0),
|
||||
batch.get("next.done", False),
|
||||
batch.get("next.truncated", False),
|
||||
batch.get("info", {}),
|
||||
complementary_data,
|
||||
)
|
||||
transition: EnvTransition = {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: batch.get("action"),
|
||||
TransitionKey.REWARD: batch.get("next.reward", 0.0),
|
||||
TransitionKey.DONE: batch.get("next.done", False),
|
||||
TransitionKey.TRUNCATED: batch.get("next.truncated", False),
|
||||
TransitionKey.INFO: batch.get("info", {}),
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
return transition
|
||||
|
||||
|
||||
def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401
|
||||
@@ -209,25 +212,16 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
the canonical field names used throughout *LeRobot*.
|
||||
"""
|
||||
|
||||
(
|
||||
observation,
|
||||
action,
|
||||
reward,
|
||||
done,
|
||||
truncated,
|
||||
info,
|
||||
complementary_data,
|
||||
) = transition
|
||||
|
||||
batch = {
|
||||
"action": action,
|
||||
"next.reward": reward,
|
||||
"next.done": done,
|
||||
"next.truncated": truncated,
|
||||
"info": info,
|
||||
"action": transition.get(TransitionKey.ACTION),
|
||||
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
||||
"next.done": transition.get(TransitionKey.DONE, False),
|
||||
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
|
||||
"info": transition.get(TransitionKey.INFO, {}),
|
||||
}
|
||||
|
||||
# Add padding and task data from complementary_data
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data:
|
||||
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
|
||||
batch.update(pad_data)
|
||||
@@ -236,6 +230,7 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
batch["task"] = complementary_data["task"]
|
||||
|
||||
# Handle observation - flatten dict to observation.* keys if it's a dict
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if isinstance(observation, dict):
|
||||
batch.update(observation)
|
||||
|
||||
@@ -293,33 +288,35 @@ class RobotProcessor(ModelHubMixin):
|
||||
def __call__(self, data: EnvTransition | dict[str, Any]):
|
||||
"""Process data through all steps.
|
||||
|
||||
The method accepts either the classic EnvTransition tuple or a batch dictionary
|
||||
The method accepts either the classic EnvTransition dict or a batch dictionary
|
||||
(like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied
|
||||
it is first converted to the internal tuple format using to_transition; after all
|
||||
steps are executed the tuple is transformed back into a dict with to_batch and the
|
||||
it is first converted to the internal dict format using to_transition; after all
|
||||
steps are executed the dict is transformed back into a batch dict with to_batch and the
|
||||
result is returned – thereby preserving the caller's original data type.
|
||||
|
||||
Args:
|
||||
data: Either an EnvTransition tuple or a batch dictionary to process.
|
||||
data: Either an EnvTransition dict or a batch dictionary to process.
|
||||
|
||||
Returns:
|
||||
The processed data in the same format as the input (tuple or dict).
|
||||
The processed data in the same format as the input (EnvTransition or batch dict).
|
||||
|
||||
Raises:
|
||||
ValueError: If the transition is not a valid 7-tuple format.
|
||||
ValueError: If the transition is not a valid EnvTransition format.
|
||||
"""
|
||||
|
||||
called_with_batch = isinstance(data, dict)
|
||||
# Check if data is already an EnvTransition or needs conversion
|
||||
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
|
||||
# It's a batch dict, convert it
|
||||
called_with_batch = True
|
||||
transition = self.to_transition(data)
|
||||
else:
|
||||
# It's already an EnvTransition
|
||||
called_with_batch = False
|
||||
transition = data
|
||||
|
||||
transition = self.to_transition(data) if called_with_batch else data
|
||||
|
||||
# Basic validation with helpful error message for tuple input
|
||||
if not isinstance(transition, tuple) or len(transition) != 7:
|
||||
raise ValueError(
|
||||
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
|
||||
"truncated, info, complementary_data). "
|
||||
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
|
||||
)
|
||||
# Basic validation
|
||||
if not isinstance(transition, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
|
||||
|
||||
for idx, processor_step in enumerate(self.steps):
|
||||
for hook in self.before_step_hooks:
|
||||
@@ -339,25 +336,28 @@ class RobotProcessor(ModelHubMixin):
|
||||
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
|
||||
"""Yield the intermediate results after each processor step.
|
||||
|
||||
Like __call__, this method accepts either EnvTransition tuples or batch dictionaries
|
||||
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries
|
||||
and preserves the input format in the yielded results.
|
||||
|
||||
Args:
|
||||
data: Either an EnvTransition tuple or a batch dictionary to process.
|
||||
data: Either an EnvTransition dict or a batch dictionary to process.
|
||||
|
||||
Yields:
|
||||
The intermediate results after each step, in the same format as the input.
|
||||
"""
|
||||
called_with_batch = isinstance(data, dict)
|
||||
transition = self.to_transition(data) if called_with_batch else data
|
||||
# Check if data is already an EnvTransition or needs conversion
|
||||
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
|
||||
# It's a batch dict, convert it
|
||||
called_with_batch = True
|
||||
transition = self.to_transition(data)
|
||||
else:
|
||||
# It's already an EnvTransition
|
||||
called_with_batch = False
|
||||
transition = data
|
||||
|
||||
# Basic validation with helpful error message for tuple input
|
||||
if not isinstance(transition, tuple) or len(transition) != 7:
|
||||
raise ValueError(
|
||||
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
|
||||
"truncated, info, complementary_data). "
|
||||
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
|
||||
)
|
||||
# Basic validation
|
||||
if not isinstance(transition, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
|
||||
|
||||
# Yield initial state
|
||||
yield self.to_output(transition) if called_with_batch else transition
|
||||
@@ -684,7 +684,7 @@ class ObservationProcessor:
|
||||
|
||||
Subclasses should override the `observation` method to implement custom observation processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed observation
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -696,7 +696,7 @@ class ObservationProcessor:
|
||||
return observation * self.scale_factor
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||
manipulation, focusing only on the specific observation processing logic.
|
||||
"""
|
||||
|
||||
@@ -712,10 +712,12 @@ class ObservationProcessor:
|
||||
return observation
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition[TransitionIndex.OBSERVATION]
|
||||
observation = self.observation(observation)
|
||||
transition = (observation, *transition[TransitionIndex.ACTION :])
|
||||
return transition
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
processed_observation = self.observation(observation)
|
||||
# Create a new transition dict with the processed observation
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_observation
|
||||
return new_transition
|
||||
|
||||
|
||||
class ActionProcessor:
|
||||
@@ -723,7 +725,7 @@ class ActionProcessor:
|
||||
|
||||
Subclasses should override the `action` method to implement custom action processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed action
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -736,7 +738,7 @@ class ActionProcessor:
|
||||
return np.clip(action, self.min_val, self.max_val)
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||
manipulation, focusing only on the specific action processing logic.
|
||||
"""
|
||||
|
||||
@@ -752,10 +754,12 @@ class ActionProcessor:
|
||||
return action
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition[TransitionIndex.ACTION]
|
||||
action = self.action(action)
|
||||
transition = (transition[TransitionIndex.OBSERVATION], action, *transition[TransitionIndex.REWARD :])
|
||||
return transition
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
processed_action = self.action(action)
|
||||
# Create a new transition dict with the processed action
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.ACTION] = processed_action
|
||||
return new_transition
|
||||
|
||||
|
||||
class RewardProcessor:
|
||||
@@ -763,7 +767,7 @@ class RewardProcessor:
|
||||
|
||||
Subclasses should override the `reward` method to implement custom reward processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed reward
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -775,7 +779,7 @@ class RewardProcessor:
|
||||
return reward * self.scale_factor
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||
manipulation, focusing only on the specific reward processing logic.
|
||||
"""
|
||||
|
||||
@@ -791,15 +795,12 @@ class RewardProcessor:
|
||||
return reward
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
reward = transition[TransitionIndex.REWARD]
|
||||
reward = self.reward(reward)
|
||||
transition = (
|
||||
transition[TransitionIndex.OBSERVATION],
|
||||
transition[TransitionIndex.ACTION],
|
||||
reward,
|
||||
*transition[TransitionIndex.DONE :],
|
||||
)
|
||||
return transition
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
processed_reward = self.reward(reward)
|
||||
# Create a new transition dict with the processed reward
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.REWARD] = processed_reward
|
||||
return new_transition
|
||||
|
||||
|
||||
class DoneProcessor:
|
||||
@@ -807,7 +808,7 @@ class DoneProcessor:
|
||||
|
||||
Subclasses should override the `done` method to implement custom done flag processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed done flag
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -824,7 +825,7 @@ class DoneProcessor:
|
||||
self.steps = 0
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||
manipulation, focusing only on the specific done flag processing logic.
|
||||
"""
|
||||
|
||||
@@ -840,16 +841,12 @@ class DoneProcessor:
|
||||
return done
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
done = transition[TransitionIndex.DONE]
|
||||
done = self.done(done)
|
||||
transition = (
|
||||
transition[TransitionIndex.OBSERVATION],
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
done,
|
||||
*transition[TransitionIndex.TRUNCATED :],
|
||||
)
|
||||
return transition
|
||||
done = transition.get(TransitionKey.DONE)
|
||||
processed_done = self.done(done)
|
||||
# Create a new transition dict with the processed done flag
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.DONE] = processed_done
|
||||
return new_transition
|
||||
|
||||
|
||||
class TruncatedProcessor:
|
||||
@@ -857,7 +854,7 @@ class TruncatedProcessor:
|
||||
|
||||
Subclasses should override the `truncated` method to implement custom truncated flag processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed truncated flag
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -870,7 +867,7 @@ class TruncatedProcessor:
|
||||
return truncated or some_condition > self.threshold
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||
manipulation, focusing only on the specific truncated flag processing logic.
|
||||
"""
|
||||
|
||||
@@ -886,17 +883,12 @@ class TruncatedProcessor:
|
||||
return truncated
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
truncated = transition[TransitionIndex.TRUNCATED]
|
||||
truncated = self.truncated(truncated)
|
||||
transition = (
|
||||
transition[TransitionIndex.OBSERVATION],
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
truncated,
|
||||
*transition[TransitionIndex.INFO :],
|
||||
)
|
||||
return transition
|
||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||
processed_truncated = self.truncated(truncated)
|
||||
# Create a new transition dict with the processed truncated flag
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.TRUNCATED] = processed_truncated
|
||||
return new_transition
|
||||
|
||||
|
||||
class InfoProcessor:
|
||||
@@ -904,7 +896,7 @@ class InfoProcessor:
|
||||
|
||||
Subclasses should override the `info` method to implement custom info processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed info
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -922,7 +914,7 @@ class InfoProcessor:
|
||||
self.step_count = 0
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||
manipulation, focusing only on the specific info dictionary processing logic.
|
||||
"""
|
||||
|
||||
@@ -938,18 +930,12 @@ class InfoProcessor:
|
||||
return info
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
info = transition[TransitionIndex.INFO]
|
||||
info = self.info(info)
|
||||
transition = (
|
||||
transition[TransitionIndex.OBSERVATION],
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
info,
|
||||
*transition[TransitionIndex.COMPLEMENTARY_DATA :],
|
||||
)
|
||||
return transition
|
||||
info = transition.get(TransitionKey.INFO)
|
||||
processed_info = self.info(info)
|
||||
# Create a new transition dict with the processed info
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.INFO] = processed_info
|
||||
return new_transition
|
||||
|
||||
|
||||
class ComplementaryDataProcessor:
|
||||
@@ -957,7 +943,7 @@ class ComplementaryDataProcessor:
|
||||
|
||||
Subclasses should override the `complementary_data` method to implement custom complementary data processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed complementary data
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
@@ -972,18 +958,12 @@ class ComplementaryDataProcessor:
|
||||
return complementary_data
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
|
||||
complementary_data = self.complementary_data(complementary_data)
|
||||
transition = (
|
||||
transition[TransitionIndex.OBSERVATION],
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
complementary_data,
|
||||
)
|
||||
return transition
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
processed_complementary_data = self.complementary_data(complementary_data)
|
||||
# Create a new transition dict with the processed complementary data
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data
|
||||
return new_transition
|
||||
|
||||
|
||||
class IdentityProcessor:
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -29,7 +29,7 @@ class RenameProcessor:
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition[TransitionIndex.OBSERVATION]
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
return transition
|
||||
|
||||
@@ -39,15 +39,11 @@ class RenameProcessor:
|
||||
processed_obs[self.rename_map[key]] = value
|
||||
else:
|
||||
processed_obs[key] = value
|
||||
return (
|
||||
processed_obs,
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
|
||||
# Create a new transition with the renamed observation
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"rename_map": self.rename_map}
|
||||
|
||||
Reference in New Issue
Block a user