mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
740 lines
28 KiB
Python
740 lines
28 KiB
Python
#!/usr/bin/env python
|
||
|
||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
from __future__ import annotations
|
||
|
||
import importlib
|
||
import json
|
||
import os
|
||
from dataclasses import dataclass, field
|
||
from enum import IntEnum
|
||
from pathlib import Path
|
||
from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple
|
||
|
||
import torch
|
||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||
from safetensors.torch import load_file, save_file
|
||
|
||
|
||
class TransitionIndex(IntEnum):
    """Named positions of the seven slots inside an ``EnvTransition`` tuple.

    Indexing with these members (e.g. ``transition[TransitionIndex.REWARD]``)
    keeps tuple access self-documenting instead of relying on bare integers.
    """

    # Members are bound in order to 0..6, matching the EnvTransition layout.
    (
        OBSERVATION,  # 0
        ACTION,  # 1
        REWARD,  # 2
        DONE,  # 3
        TRUNCATED,  # 4
        INFO,  # 5
        COMPLEMENTARY_DATA,  # 6
    ) = range(7)
|
||
|
||
|
||
# (observation, action, reward, done, truncated, info, complementary_data)
|
||
EnvTransition = Tuple[
|
||
Any | None, # observation
|
||
Any | None, # action
|
||
float | None, # reward
|
||
bool | None, # done
|
||
bool | None, # truncated
|
||
Dict[str, Any] | None, # info
|
||
Dict[str, Any] | None, # complementary_data
|
||
]
|
||
|
||
|
||
class ProcessorStepRegistry:
|
||
"""Registry for processor steps that enables saving/loading by name instead of module path."""
|
||
|
||
_registry: dict[str, type] = {}
|
||
|
||
@classmethod
|
||
def register(cls, name: str = None):
|
||
"""Decorator to register a processor step class.
|
||
|
||
Args:
|
||
name: Optional registration name. If not provided, uses class name.
|
||
|
||
Example:
|
||
@ProcessorStepRegistry.register("adaptive_normalizer")
|
||
class AdaptiveObservationNormalizer:
|
||
...
|
||
"""
|
||
|
||
def decorator(step_class: type) -> type:
|
||
registration_name = name if name is not None else step_class.__name__
|
||
|
||
if registration_name in cls._registry:
|
||
raise ValueError(
|
||
f"Processor step '{registration_name}' is already registered. "
|
||
f"Use a different name or unregister the existing one first."
|
||
)
|
||
|
||
cls._registry[registration_name] = step_class
|
||
# Store the registration name on the class for later reference
|
||
step_class._registry_name = registration_name
|
||
return step_class
|
||
|
||
return decorator
|
||
|
||
@classmethod
|
||
def get(cls, name: str) -> type:
|
||
"""Get a registered processor step class by name.
|
||
|
||
Args:
|
||
name: The registration name of the step.
|
||
|
||
Returns:
|
||
The registered step class.
|
||
|
||
Raises:
|
||
KeyError: If the step is not registered.
|
||
"""
|
||
if name not in cls._registry:
|
||
available = list(cls._registry.keys())
|
||
raise KeyError(
|
||
f"Processor step '{name}' not found in registry. "
|
||
f"Available steps: {available}. "
|
||
f"Make sure the step is registered using @ProcessorStepRegistry.register()"
|
||
)
|
||
return cls._registry[name]
|
||
|
||
@classmethod
|
||
def unregister(cls, name: str) -> None:
|
||
"""Remove a step from the registry."""
|
||
cls._registry.pop(name, None)
|
||
|
||
@classmethod
|
||
def list(cls) -> list[str]:
|
||
"""List all registered step names."""
|
||
return list(cls._registry.keys())
|
||
|
||
@classmethod
|
||
def clear(cls) -> None:
|
||
"""Clear all registrations."""
|
||
cls._registry.clear()
|
||
|
||
|
||
class ProcessorStep(Protocol):
    """Structural (duck-typed) interface for a single processor step.

    The one essential behaviour is being callable: a step receives a complete
    ``EnvTransition`` tuple and returns a tuple with the same layout, possibly
    with some slots replaced.

    The other methods are optional hooks. ``RobotProcessor`` checks for each one
    with ``hasattr`` and, when present, uses it to persist the step as JSON
    (configuration) plus SafeTensors (tensor state), to move tensor state to a
    device, and to reset state between episodes:

    * ``get_config() -> dict[str, Any]`` – JSON-serialisable configuration and
      any non-tensor state you choose to keep (e.g. name, counter, threshold,
      window_size). On load the dict is passed back to the constructor as
      keyword arguments, so its keys must match the ``__init__`` parameters.
    * ``state_dict() -> dict[str, torch.Tensor]`` – tensor state ONLY (learned
      weights, running statistics stored as tensors). Plain Python values
      belong in ``get_config`` instead.
    * ``load_state_dict(state)`` – inverse of ``state_dict``; receives a dict
      whose values are torch tensors.
    * ``reset()`` – clear internal buffers at episode boundaries.

    How the two persistence hooks divide the work:

    - get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10}
    - state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)}
    """

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Process ``transition`` and return the resulting transition."""

    def get_config(self) -> dict[str, Any]:
        """Return the JSON-serialisable constructor configuration."""

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return the step's tensor state."""

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Restore tensor state previously produced by ``state_dict``."""

    def reset(self) -> None:
        """Clear per-episode internal buffers."""
|
||
|
||
|
||
@dataclass
class RobotProcessor(ModelHubMixin):
    """
    Composable, debuggable post-processing processor for robot transitions.

    The class orchestrates an ordered collection of small, functional
    transforms—steps—executed left-to-right on each incoming
    `EnvTransition`.

    Parameters:
        steps : Sequence[ProcessorStep], optional
            Ordered list executed on every call
        name : str, default="RobotProcessor"
            Human-readable identifier that is persisted inside the JSON config.
        seed : int | None, optional
            Optional seed persisted inside the JSON config. This class only
            stores and reloads it; it does not pass it to the steps.
        before_step_hooks / after_step_hooks : list of callables, optional
            Called as ``hook(step_index, transition)`` around every step in
            ``__call__``; a non-``None`` return value replaces the transition.
        reset_hooks : list of callables, optional
            Called with no arguments at the end of ``reset()``.

    Examples:
        Basic usage::

            env = gym.make("CartPole-v1")
            proc = RobotProcessor([
                ObservationNormalizer(),
                IntrinsicVelocity(),
                VelocityBonus(0.02),
            ])
            obs, info = env.reset(seed=0)
            tr = (obs, None, 0.0, False, False, info, {})
            obs, *_ = proc(tr)  # agent sees a normalised observation

        Inspecting intermediate results::

            for idx, step_tr in enumerate(proc.step_through(tr)):
                print(idx, step_tr)

        Serialization to the Hugging Face Hub::

            proc.save_pretrained("chkpt")
            proc.push_to_hub("my-org/cartpole_proc")
            loaded = RobotProcessor.from_pretrained("my-org/cartpole_proc")
    """

    steps: Sequence[ProcessorStep] = field(default_factory=list)
    name: str = "RobotProcessor"
    seed: int | None = None

    # Processor-level hooks
    # A hook can optionally return a modified transition. If it returns
    # ``None`` the current value is left untouched.
    before_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
        default_factory=list, repr=False
    )
    after_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
        default_factory=list, repr=False
    )
    reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Run *transition* through every step, firing hooks on the way.

        For step ``idx``, every before-step hook is called as
        ``hook(idx, transition)``, then the step itself, then every after-step
        hook. A hook returning a non-``None`` value replaces the transition.

        Raises:
            ValueError: If *transition* is not a 7-tuple.
        """
        # Basic validation with helpful error message
        if not isinstance(transition, tuple) or len(transition) != 7:
            raise ValueError(
                f"EnvTransition must be a 7-tuple of (observation, action, reward, done, truncated, info, complementary_data), "
                f"got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}"
            )

        for idx, processor_step in enumerate(self.steps):
            for hook in self.before_step_hooks:
                updated = hook(idx, transition)
                if updated is not None:
                    transition = updated

            transition = processor_step(transition)

            for hook in self.after_step_hooks:
                updated = hook(idx, transition)
                if updated is not None:
                    transition = updated

        return transition

    def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]:
        """Yield the input transition, then the transition after each processor step.

        Unlike ``__call__``, this neither validates the input nor fires hooks.
        """
        yield transition
        for processor_step in self.steps:
            transition = processor_step(transition)
            yield transition

    # File name of the JSON config written by save_pretrained / read by from_pretrained.
    _CFG_NAME = "processor.json"

    def _save_pretrained(self, destination_path: str, **kwargs):
        """Internal save method for ModelHubMixin compatibility."""
        self.save_pretrained(destination_path)

    def _generate_model_card(self, destination_path: str) -> None:
        """Generate README.md from the RobotProcessor model card template.

        Does nothing when the template file is missing.
        """
        # Read the template
        template_path = Path(__file__).parent.parent / "templates" / "robotprocessor_modelcard_template.md"

        if not template_path.exists():
            # Fallback: if template doesn't exist, skip model card generation
            return

        with open(template_path, encoding="utf-8") as f:
            model_card_content = f.read()

        # Write the README.md
        readme_path = os.path.join(destination_path, "README.md")
        with open(readme_path, "w", encoding="utf-8") as f:
            f.write(model_card_content)

    @staticmethod
    def _registered_name(step_class: type) -> str | None:
        """Return the registry name to serialize *step_class* under, or ``None``.

        ``_registry_name`` is an ordinary class attribute, so an unregistered
        subclass of a registered step inherits its parent's name. The name is
        only trusted when the registry still maps it to exactly this class;
        otherwise the step would be reloaded as the wrong class, or would fail to
        load at all after ``ProcessorStepRegistry.unregister``.
        """
        registry_name = getattr(step_class, "_registry_name", None)
        if not registry_name:
            return None
        try:
            registered_class = ProcessorStepRegistry.get(registry_name)
        except KeyError:
            return None
        return registry_name if registered_class is step_class else None

    def save_pretrained(self, destination_path: str, **kwargs):
        """Serialize the processor definition and parameters to *destination_path*.

        Writes ``processor.json`` (name, seed and one entry per step), one
        ``step_<index>.safetensors`` file for every step whose ``state_dict()``
        is non-empty, and a ``README.md`` if the model-card template exists.
        Extra keyword arguments are accepted for ``ModelHubMixin`` compatibility
        and ignored.
        """
        os.makedirs(destination_path, exist_ok=True)

        config: dict[str, Any] = {
            "name": self.name,
            "seed": self.seed,
            "steps": [],
        }

        for step_index, processor_step in enumerate(self.steps):
            step_class = processor_step.__class__
            registry_name = self._registered_name(step_class)

            step_entry: dict[str, Any]
            if registry_name:
                # Use registry name for registered steps
                step_entry = {"registry_name": registry_name}
            else:
                # Fall back to full module path for unregistered steps
                step_entry = {"class": f"{step_class.__module__}.{step_class.__name__}"}

            if hasattr(processor_step, "get_config"):
                step_entry["config"] = processor_step.get_config()

            if hasattr(processor_step, "state_dict"):
                state = processor_step.state_dict()
                if state:
                    # safetensors rejects tensors that share memory with each other
                    # and tensors that are not contiguous. A plain ``clone()`` keeps
                    # the strides of e.g. a transposed tensor, so force contiguous
                    # layout. Processor state is small, so copying is cheap.
                    cloned_state = {
                        key: tensor.detach().clone(memory_format=torch.contiguous_format)
                        for key, tensor in state.items()
                    }

                    state_filename = f"step_{step_index}.safetensors"
                    save_file(cloned_state, os.path.join(destination_path, state_filename))
                    step_entry["state_file"] = state_filename

            config["steps"].append(step_entry)

        with open(os.path.join(destination_path, self._CFG_NAME), "w", encoding="utf-8") as file_pointer:
            json.dump(config, file_pointer, indent=2)

        # Generate README.md from template
        self._generate_model_card(destination_path)

    def to(self, device: str | torch.device):
        """Move all tensor states inside each step to device and return self.

        Uses a generic mechanism: fetch each step's state dict, move every tensor
        to the target device, and reload it. Only works for steps that implement
        both state_dict() and load_state_dict() methods.
        """
        device = torch.device(device)

        for step in self.steps:
            if hasattr(step, "state_dict") and hasattr(step, "load_state_dict"):
                state = step.state_dict()
                if state:  # Only process if there's actual state
                    moved_state = {k: v.to(device) for k, v in state.items()}
                    step.load_state_dict(moved_state)

        return self

    @staticmethod
    def _resolve_step_class(step_entry: dict[str, Any]) -> type:
        """Return the step class described by one saved ``steps`` entry.

        Raises:
            ImportError: If the registry name is unknown, or the module/class
                named by ``"class"`` cannot be imported.
        """
        if "registry_name" in step_entry:
            # Load from registry
            try:
                return ProcessorStepRegistry.get(step_entry["registry_name"])
            except KeyError as e:
                raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e

        # Fall back to module path loading for backward compatibility
        full_class_path = step_entry["class"]
        module_path, class_name = full_class_path.rsplit(".", 1)

        # Import the module containing the step class
        try:
            module = importlib.import_module(module_path)
            return getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            raise ImportError(
                f"Failed to load processor step '{full_class_path}'. "
                f"Make sure the module '{module_path}' is installed and contains class '{class_name}'. "
                f"Consider registering the step using @ProcessorStepRegistry.register() for better portability. "
                f"Error: {str(e)}"
            ) from e

    @classmethod
    def from_pretrained(cls, source: str) -> RobotProcessor:
        """Load a serialized processor from source (local path or Hugging Face Hub identifier).

        *source* is treated as a local directory if one exists at that path,
        otherwise as a Hub repository id whose files are fetched with
        ``hf_hub_download``.

        Raises:
            ImportError: If a step's class cannot be resolved.
            ValueError: If a step cannot be instantiated from its saved config.
        """
        is_local = Path(source).is_dir()

        if is_local:
            # Local path - use it directly
            config_path = Path(source) / cls._CFG_NAME
        else:
            # Hugging Face Hub - download the config file first
            config_path = hf_hub_download(source, cls._CFG_NAME, repo_type="model")

        with open(config_path, encoding="utf-8") as file_pointer:
            config: dict[str, Any] = json.load(file_pointer)

        steps: list[ProcessorStep] = []
        for step_entry in config["steps"]:
            step_class = cls._resolve_step_class(step_entry)

            # Instantiate the step with its config
            try:
                step_instance: ProcessorStep = step_class(**step_entry.get("config", {}))
            except Exception as e:
                step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown"))
                raise ValueError(
                    f"Failed to instantiate processor step '{step_name}' with config: {step_entry.get('config', {})}. "
                    f"Error: {str(e)}"
                ) from e

            # Load state if available
            if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"):
                if is_local:
                    state_path = str(Path(source) / step_entry["state_file"])
                else:
                    state_path = hf_hub_download(source, step_entry["state_file"], repo_type="model")
                step_instance.load_state_dict(load_file(state_path))

            steps.append(step_instance)

        return cls(steps, config.get("name", "RobotProcessor"), config.get("seed"))

    def __len__(self) -> int:
        """Return the number of steps in the processor."""
        return len(self.steps)

    def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor:
        """Indexing helper exposing underlying steps.

        * ``int`` – returns the idx-th ProcessorStep.
        * ``slice`` – returns a new RobotProcessor with the sliced steps (same
          name and seed; hooks are not copied).
        """
        if isinstance(idx, slice):
            return RobotProcessor(self.steps[idx], self.name, self.seed)
        return self.steps[idx]

    def register_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
        """Attach fn to be executed before every processor step."""
        self.before_step_hooks.append(fn)

    def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
        """Attach fn to be executed after every processor step."""
        self.after_step_hooks.append(fn)

    def register_reset_hook(self, fn: Callable[[], None]):
        """Attach fn to be executed when reset is called."""
        self.reset_hooks.append(fn)

    def reset(self):
        """Clear state in every step that implements ``reset()`` and fire registered hooks."""
        for step in self.steps:
            if hasattr(step, "reset"):
                step.reset()  # type: ignore[attr-defined]
        for fn in self.reset_hooks:
            fn()

    def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> dict[str, float]:
        """Profile the execution time of each step for performance optimization.

        Each step is warmed up 5 times and then timed over *num_runs* calls. All
        calls use the same input: the transition produced by running the
        preceding steps once. Hooks are not fired. Stateful steps see
        ``num_runs + 5`` calls each.

        Returns:
            Mapping ``"step_<idx>_<ClassName>"`` -> mean wall-clock time per call
            in milliseconds.

        Raises:
            ValueError: If *num_runs* is not positive.
        """
        import time

        if num_runs <= 0:
            raise ValueError(f"num_runs must be a positive integer, got {num_runs}")

        profile_results: dict[str, float] = {}

        for idx, processor_step in enumerate(self.steps):
            step_name = f"step_{idx}_{processor_step.__class__.__name__}"

            # Warm up
            for _ in range(5):
                processor_step(transition)

            # Time the step on a fixed input. Feeding each output back in would
            # apply the step num_runs times and hand an over-processed transition
            # to the following steps.
            start_time = time.perf_counter()
            for _ in range(num_runs):
                step_output = processor_step(transition)
            end_time = time.perf_counter()

            profile_results[step_name] = (end_time - start_time) / num_runs * 1000  # Convert to milliseconds

            # Advance the pipeline by exactly one application of this step.
            transition = step_output

        return profile_results
|
||
|
||
|
||
class ObservationProcessor:
    """Base class for steps that transform only the observation slot of a transition.

    Override :meth:`observation`. The inherited ``__call__`` takes the
    observation out of the incoming ``EnvTransition``, passes it through
    :meth:`observation`, and returns a new tuple in which every other slot is
    carried over unchanged, so subclasses never need their own ``__call__``.

    Example:
        ```python
        class MyObservationScaler(ObservationProcessor):
            def __init__(self, scale_factor):
                self.scale_factor = scale_factor

            def observation(self, observation):
                return observation * self.scale_factor
        ```
    """

    def observation(self, observation):
        """Return the processed observation (identity by default).

        Args:
            observation: The observation slot of the incoming transition.

        Returns:
            The value placed in the observation slot of the outgoing transition.
        """
        return observation

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of *transition* with its observation slot processed."""
        processed = self.observation(transition[TransitionIndex.OBSERVATION])
        slots = list(transition)
        slots[TransitionIndex.OBSERVATION] = processed
        return tuple(slots)
|
||
|
||
|
||
class ActionProcessor:
    """Base class for steps that transform only the action slot of a transition.

    Override :meth:`action`. The inherited ``__call__`` takes the action out of
    the incoming ``EnvTransition``, passes it through :meth:`action`, and
    returns a new tuple in which every other slot is carried over unchanged, so
    subclasses never need their own ``__call__``.

    Example:
        ```python
        class ActionClipping(ActionProcessor):
            def __init__(self, min_val, max_val):
                self.min_val = min_val
                self.max_val = max_val

            def action(self, action):
                return np.clip(action, self.min_val, self.max_val)
        ```
    """

    def action(self, action):
        """Return the processed action (identity by default).

        Args:
            action: The action slot of the incoming transition.

        Returns:
            The value placed in the action slot of the outgoing transition.
        """
        return action

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of *transition* with its action slot processed."""
        processed = self.action(transition[TransitionIndex.ACTION])
        slots = list(transition)
        slots[TransitionIndex.ACTION] = processed
        return tuple(slots)
|
||
|
||
|
||
class RewardProcessor:
    """Base class for steps that transform only the reward slot of a transition.

    Override :meth:`reward`. The inherited ``__call__`` takes the reward out of
    the incoming ``EnvTransition``, passes it through :meth:`reward`, and
    returns a new tuple in which every other slot is carried over unchanged, so
    subclasses never need their own ``__call__``.

    Example:
        ```python
        class RewardScaler(RewardProcessor):
            def __init__(self, scale_factor):
                self.scale_factor = scale_factor

            def reward(self, reward):
                return reward * self.scale_factor
        ```
    """

    def reward(self, reward):
        """Return the processed reward (identity by default).

        Args:
            reward: The reward slot of the incoming transition.

        Returns:
            The value placed in the reward slot of the outgoing transition.
        """
        return reward

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of *transition* with its reward slot processed."""
        processed = self.reward(transition[TransitionIndex.REWARD])
        slots = list(transition)
        slots[TransitionIndex.REWARD] = processed
        return tuple(slots)
|
||
|
||
|
||
class DoneProcessor:
    """Base class for steps that transform only the done flag of a transition.

    Override :meth:`done`. The inherited ``__call__`` takes the done flag out of
    the incoming ``EnvTransition``, passes it through :meth:`done`, and returns
    a new tuple in which every other slot is carried over unchanged, so
    subclasses never need their own ``__call__``.

    Example:
        ```python
        class TimeoutDone(DoneProcessor):
            def __init__(self, max_steps):
                self.steps = 0
                self.max_steps = max_steps

            def done(self, done):
                self.steps += 1
                return done or self.steps >= self.max_steps

            def reset(self):
                self.steps = 0
        ```
    """

    def done(self, done):
        """Return the processed done flag (identity by default).

        Args:
            done: The done slot of the incoming transition.

        Returns:
            The value placed in the done slot of the outgoing transition.
        """
        return done

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of *transition* with its done slot processed."""
        processed = self.done(transition[TransitionIndex.DONE])
        slots = list(transition)
        slots[TransitionIndex.DONE] = processed
        return tuple(slots)
|
||
|
||
|
||
class TruncatedProcessor:
    """Base class for steps that transform only the truncated flag of a transition.

    Override :meth:`truncated`. The inherited ``__call__`` takes the truncated
    flag out of the incoming ``EnvTransition``, passes it through
    :meth:`truncated`, and returns a new tuple in which every other slot is
    carried over unchanged, so subclasses never need their own ``__call__``.

    Example:
        ```python
        class EarlyTruncation(TruncatedProcessor):
            def __init__(self, threshold):
                self.threshold = threshold

            def truncated(self, truncated):
                # Additional truncation condition; `some_condition` is a
                # placeholder for whatever signal the subclass tracks.
                return truncated or some_condition > self.threshold
        ```
    """

    def truncated(self, truncated):
        """Return the processed truncated flag (identity by default).

        Args:
            truncated: The truncated slot of the incoming transition.

        Returns:
            The value placed in the truncated slot of the outgoing transition.
        """
        return truncated

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of *transition* with its truncated slot processed."""
        processed = self.truncated(transition[TransitionIndex.TRUNCATED])
        slots = list(transition)
        slots[TransitionIndex.TRUNCATED] = processed
        return tuple(slots)
|
||
|
||
|
||
class InfoProcessor:
    """Base class for steps that transform only the info dictionary of a transition.

    Override :meth:`info`. The inherited ``__call__`` takes the info dict out of
    the incoming ``EnvTransition``, passes it through :meth:`info`, and returns
    a new tuple in which every other slot is carried over unchanged, so
    subclasses never need their own ``__call__``.

    Example:
        ```python
        class InfoAugmenter(InfoProcessor):
            def __init__(self):
                self.step_count = 0

            def info(self, info):
                info = info.copy()  # Create a copy to avoid modifying the original
                info['steps'] = self.step_count
                self.step_count += 1
                return info

            def reset(self):
                self.step_count = 0
        ```
    """

    def info(self, info):
        """Return the processed info dictionary (identity by default).

        Args:
            info: The info slot of the incoming transition.

        Returns:
            The value placed in the info slot of the outgoing transition.
        """
        return info

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of *transition* with its info slot processed."""
        processed = self.info(transition[TransitionIndex.INFO])
        slots = list(transition)
        slots[TransitionIndex.INFO] = processed
        return tuple(slots)
|