mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
- Updated DataProcessorPipeline to use HubMixin instead of ModelHubMixin for improved functionality. - Refactored save_pretrained method to handle saving
1169 lines
46 KiB
Python
1169 lines
46 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
This module defines a generic, sequential data processing pipeline framework, primarily designed for
|
|
transforming robotics data (observations, actions, rewards, etc.).
|
|
|
|
The core components are:
|
|
- ProcessorStep: An abstract base class for a single data transformation operation.
|
|
- ProcessorStepRegistry: A mechanism to register and retrieve ProcessorStep classes by name.
|
|
- DataProcessorPipeline: A class that chains multiple ProcessorStep instances together to form a complete
|
|
data processing workflow. It integrates with the Hugging Face Hub for easy sharing and versioning of
|
|
pipelines, including their configuration and state.
|
|
- Specialized abstract ProcessorStep subclasses (e.g., ObservationProcessorStep, ActionProcessorStep)
|
|
to simplify the creation of steps that target specific parts of a data transition.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import importlib
|
|
import json
|
|
import os
|
|
import re
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Callable, Iterable, Sequence
|
|
from copy import deepcopy
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
|
|
|
|
import torch
|
|
from huggingface_hub import hf_hub_download
|
|
from safetensors.torch import load_file, save_file
|
|
|
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
|
from lerobot.utils.hub import HubMixin
|
|
|
|
from .converters import batch_to_transition, create_transition, transition_to_batch
|
|
from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey
|
|
|
|
# Generic type variables for pipeline input and output.
|
|
TInput = TypeVar("TInput")
|
|
TOutput = TypeVar("TOutput")
|
|
|
|
|
|
class ProcessorStepRegistry:
|
|
"""A registry for ProcessorStep classes to allow instantiation from a string name.
|
|
|
|
This class provides a way to map string identifiers to `ProcessorStep` classes,
|
|
which is useful for deserializing pipelines from configuration files without
|
|
|
|
hardcoding class imports.
|
|
"""
|
|
|
|
_registry: dict[str, type] = {}
|
|
|
|
@classmethod
|
|
def register(cls, name: str | None = None):
|
|
"""A class decorator to register a ProcessorStep.
|
|
|
|
Args:
|
|
name: The name to register the class under. If None, the class's `__name__` is used.
|
|
|
|
Returns:
|
|
A decorator function that registers the class and returns it.
|
|
|
|
Raises:
|
|
ValueError: If a step with the same name is already registered.
|
|
"""
|
|
|
|
def decorator(step_class: type) -> type:
|
|
"""The actual decorator that performs the registration."""
|
|
registration_name = name if name is not None else step_class.__name__
|
|
|
|
if registration_name in cls._registry:
|
|
raise ValueError(
|
|
f"Processor step '{registration_name}' is already registered. "
|
|
f"Use a different name or unregister the existing one first."
|
|
)
|
|
|
|
cls._registry[registration_name] = step_class
|
|
# Store the registration name on the class for easy lookup during serialization.
|
|
step_class._registry_name = registration_name
|
|
return step_class
|
|
|
|
return decorator
|
|
|
|
@classmethod
|
|
def get(cls, name: str) -> type:
|
|
"""Retrieves a processor step class from the registry by its name.
|
|
|
|
Args:
|
|
name: The name of the step to retrieve.
|
|
|
|
Returns:
|
|
The processor step class corresponding to the given name.
|
|
|
|
Raises:
|
|
KeyError: If the name is not found in the registry.
|
|
"""
|
|
if name not in cls._registry:
|
|
available = list(cls._registry.keys())
|
|
raise KeyError(
|
|
f"Processor step '{name}' not found in registry. "
|
|
f"Available steps: {available}. "
|
|
f"Make sure the step is registered using @ProcessorStepRegistry.register()"
|
|
)
|
|
return cls._registry[name]
|
|
|
|
@classmethod
|
|
def unregister(cls, name: str) -> None:
|
|
"""Removes a processor step from the registry.
|
|
|
|
Args:
|
|
name: The name of the step to unregister.
|
|
"""
|
|
cls._registry.pop(name, None)
|
|
|
|
@classmethod
|
|
def list(cls) -> list[str]:
|
|
"""Returns a list of all registered processor step names."""
|
|
return list(cls._registry.keys())
|
|
|
|
@classmethod
|
|
def clear(cls) -> None:
|
|
"""Clears all processor steps from the registry."""
|
|
cls._registry.clear()
|
|
|
|
|
|
class ProcessorStep(ABC):
|
|
"""Abstract base class for a single step in a data processing pipeline.
|
|
|
|
Each step must implement the `__call__` method to perform its transformation
|
|
on a data transition and the `transform_features` method to describe how it
|
|
alters the shape or type of data features.
|
|
|
|
Subclasses can optionally be stateful by implementing `state_dict` and `load_state_dict`.
|
|
"""
|
|
|
|
_current_transition: EnvTransition | None = None
|
|
|
|
@property
|
|
def transition(self) -> EnvTransition:
|
|
"""Provides access to the most recent transition being processed.
|
|
|
|
This is useful for steps that need to access other parts of the transition
|
|
data beyond their primary target (e.g., an action processing step that
|
|
needs to look at the observation).
|
|
|
|
Raises:
|
|
ValueError: If accessed before the step has been called with a transition.
|
|
"""
|
|
if self._current_transition is None:
|
|
raise ValueError("Transition is not set. Make sure to call the step with a transition first.")
|
|
return self._current_transition
|
|
|
|
@abstractmethod
|
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
"""Processes an environment transition.
|
|
|
|
This method should contain the core logic of the processing step.
|
|
|
|
Args:
|
|
transition: The input data transition to be processed.
|
|
|
|
Returns:
|
|
The processed transition.
|
|
"""
|
|
return transition
|
|
|
|
def get_config(self) -> dict[str, Any]:
|
|
"""Returns the configuration of the step for serialization.
|
|
|
|
Returns:
|
|
A JSON-serializable dictionary of configuration parameters.
|
|
"""
|
|
return {}
|
|
|
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
"""Returns the state of the step (e.g., learned parameters, running means).
|
|
|
|
Returns:
|
|
A dictionary mapping state names to tensors.
|
|
"""
|
|
return {}
|
|
|
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
"""Loads the step's state from a state dictionary.
|
|
|
|
Args:
|
|
state: A dictionary of state tensors.
|
|
"""
|
|
return None
|
|
|
|
def reset(self) -> None:
|
|
"""Resets the internal state of the processor step, if any."""
|
|
return None
|
|
|
|
@abstractmethod
|
|
def transform_features(
|
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
|
"""Defines how this step modifies the description of pipeline features.
|
|
|
|
This method is used to track changes in data shapes, dtypes, or modalities
|
|
as data flows through the pipeline, without needing to process actual data.
|
|
|
|
Args:
|
|
features: A dictionary describing the input features for observations, actions, etc.
|
|
|
|
Returns:
|
|
A dictionary describing the output features after this step's transformation.
|
|
"""
|
|
return features
|
|
|
|
|
|
class ProcessorKwargs(TypedDict, total=False):
|
|
"""A TypedDict for optional keyword arguments used in pipeline construction."""
|
|
|
|
to_transition: Callable[[dict[str, Any]], EnvTransition] | None
|
|
to_output: Callable[[EnvTransition], Any] | None
|
|
name: str | None
|
|
before_step_hooks: list[Callable[[int, EnvTransition], None]] | None
|
|
after_step_hooks: list[Callable[[int, EnvTransition], None]] | None
|
|
|
|
|
|
@dataclass
|
|
class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
|
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
|
|
|
|
This class chains together multiple `ProcessorStep` instances to form a complete
|
|
data processing workflow. It's generic, allowing for custom input and output types,
|
|
which are handled by the `to_transition` and `to_output` converters.
|
|
|
|
Attributes:
|
|
steps: A sequence of `ProcessorStep` objects that make up the pipeline.
|
|
name: A descriptive name for the pipeline.
|
|
to_transition: A function to convert raw input data into the standardized `EnvTransition` format.
|
|
to_output: A function to convert the final `EnvTransition` into the desired output format.
|
|
before_step_hooks: A list of functions to be called before each step is executed.
|
|
after_step_hooks: A list of functions to be called after each step is executed.
|
|
"""
|
|
|
|
steps: Sequence[ProcessorStep] = field(default_factory=list)
|
|
name: str = "DataProcessorPipeline"
|
|
|
|
to_transition: Callable[[TInput], EnvTransition] = field(
|
|
default_factory=lambda: cast(Callable[[TInput], EnvTransition], batch_to_transition), repr=False
|
|
)
|
|
to_output: Callable[[EnvTransition], TOutput] = field(
|
|
default_factory=lambda: cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
|
repr=False,
|
|
)
|
|
|
|
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
|
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
|
|
|
def __call__(self, data: TInput) -> TOutput:
|
|
"""Processes input data through the full pipeline.
|
|
|
|
Args:
|
|
data: The input data to process.
|
|
|
|
Returns:
|
|
The processed data in the specified output format.
|
|
"""
|
|
transition = self.to_transition(data)
|
|
transformed_transition = self._forward(transition)
|
|
return self.to_output(transformed_transition)
|
|
|
|
def _forward(self, transition: EnvTransition) -> EnvTransition:
|
|
"""Executes all processing steps and hooks in sequence.
|
|
|
|
Args:
|
|
transition: The initial `EnvTransition` object.
|
|
|
|
Returns:
|
|
The final `EnvTransition` after all steps have been applied.
|
|
"""
|
|
for idx, processor_step in enumerate(self.steps):
|
|
# Execute pre-hooks
|
|
for hook in self.before_step_hooks:
|
|
hook(idx, transition)
|
|
|
|
transition = processor_step(transition)
|
|
|
|
# Execute post-hooks
|
|
for hook in self.after_step_hooks:
|
|
hook(idx, transition)
|
|
return transition
|
|
|
|
def step_through(self, data: TInput) -> Iterable[EnvTransition]:
|
|
"""Processes data step-by-step, yielding the transition at each stage.
|
|
|
|
This is a generator method useful for debugging and inspecting the intermediate
|
|
state of the data as it passes through the pipeline.
|
|
|
|
Args:
|
|
data: The input data.
|
|
|
|
Yields:
|
|
The `EnvTransition` object, starting with the initial state and then after
|
|
each processing step.
|
|
"""
|
|
transition = self.to_transition(data)
|
|
|
|
# Yield the initial state before any processing.
|
|
yield transition
|
|
|
|
for processor_step in self.steps:
|
|
transition = processor_step(transition)
|
|
yield transition
|
|
|
|
def _save_pretrained(self, save_directory: Path, **kwargs):
|
|
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
|
|
|
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
|
"""
|
|
config_filename = kwargs.pop("config_filename", None)
|
|
|
|
# Sanitize the pipeline name to create a valid filename prefix.
|
|
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
|
|
|
if config_filename is None:
|
|
config_filename = f"{sanitized_name}.json"
|
|
|
|
config: dict[str, Any] = {
|
|
"name": self.name,
|
|
"steps": [],
|
|
}
|
|
|
|
# Iterate through each step to build its configuration entry.
|
|
for step_index, processor_step in enumerate(self.steps):
|
|
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
|
|
|
step_entry: dict[str, Any] = {}
|
|
# Prefer registry name for portability, otherwise fall back to full class path.
|
|
if registry_name:
|
|
step_entry["registry_name"] = registry_name
|
|
else:
|
|
step_entry["class"] = (
|
|
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
|
)
|
|
|
|
# Save step configuration if `get_config` is implemented.
|
|
if hasattr(processor_step, "get_config"):
|
|
step_entry["config"] = processor_step.get_config()
|
|
|
|
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
|
if hasattr(processor_step, "state_dict"):
|
|
state = processor_step.state_dict()
|
|
if state:
|
|
# Clone tensors to avoid modifying the original state.
|
|
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
|
|
|
# Create a unique filename for the state file.
|
|
if registry_name:
|
|
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
|
else:
|
|
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
|
|
|
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
|
step_entry["state_file"] = state_filename
|
|
|
|
config["steps"].append(step_entry)
|
|
|
|
# Write the main configuration JSON file.
|
|
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
|
json.dump(config, file_pointer, indent=2)
|
|
|
|
def save_pretrained(
|
|
self,
|
|
save_directory: str | Path | None = None,
|
|
*,
|
|
repo_id: str | None = None,
|
|
push_to_hub: bool = False,
|
|
card_kwargs: dict[str, Any] | None = None,
|
|
config_filename: str | None = None,
|
|
**push_to_hub_kwargs,
|
|
):
|
|
"""Saves the pipeline's configuration and state to a directory.
|
|
|
|
This method creates a JSON configuration file that defines the pipeline's structure
|
|
(name and steps). For each stateful step, it also saves a `.safetensors` file
|
|
containing its state dictionary.
|
|
|
|
Args:
|
|
save_directory: The directory where the pipeline will be saved. If None, saves to
|
|
HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}.
|
|
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`.
|
|
push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it.
|
|
card_kwargs: Additional arguments passed to the card template to customize the card.
|
|
config_filename: The name of the JSON configuration file. If None, a name is
|
|
generated from the pipeline's `name` attribute.
|
|
**push_to_hub_kwargs: Additional key word arguments passed along to the push_to_hub method.
|
|
"""
|
|
if save_directory is None:
|
|
# Use default directory in HF_LEROBOT_HOME
|
|
from lerobot.constants import HF_LEROBOT_HOME
|
|
|
|
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
|
save_directory = HF_LEROBOT_HOME / "processors" / sanitized_name
|
|
|
|
# For direct saves (not through hub), handle config_filename
|
|
if not push_to_hub and config_filename is not None:
|
|
# Call _save_pretrained directly with config_filename
|
|
save_directory = Path(save_directory)
|
|
save_directory.mkdir(parents=True, exist_ok=True)
|
|
self._save_pretrained(save_directory, config_filename=config_filename)
|
|
return None
|
|
|
|
# Pass config_filename through kwargs for _save_pretrained when using hub
|
|
if config_filename is not None:
|
|
push_to_hub_kwargs["config_filename"] = config_filename
|
|
|
|
# Call parent's save_pretrained which will call our _save_pretrained
|
|
return super().save_pretrained(
|
|
save_directory=save_directory,
|
|
repo_id=repo_id,
|
|
push_to_hub=push_to_hub,
|
|
card_kwargs=card_kwargs,
|
|
**push_to_hub_kwargs,
|
|
)
|
|
|
|
@classmethod
|
|
def from_pretrained(
|
|
cls,
|
|
pretrained_model_name_or_path: str | Path,
|
|
*,
|
|
force_download: bool = False,
|
|
resume_download: bool | None = None,
|
|
proxies: dict[str, str] | None = None,
|
|
token: str | bool | None = None,
|
|
cache_dir: str | Path | None = None,
|
|
local_files_only: bool = False,
|
|
revision: str | None = None,
|
|
config_filename: str | None = None,
|
|
overrides: dict[str, Any] | None = None,
|
|
to_transition: Callable[[TInput], EnvTransition] | None = None,
|
|
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
|
**kwargs,
|
|
) -> DataProcessorPipeline[TInput, TOutput]:
|
|
"""Loads a pipeline from a local directory or a Hugging Face Hub repository.
|
|
|
|
This method reconstructs a `DataProcessorPipeline` by:
|
|
1. Loading the main JSON configuration file.
|
|
2. Iterating through the steps defined in the config.
|
|
3. Dynamically importing or looking up each step's class.
|
|
4. Instantiating each step with its saved configuration, potentially with overrides.
|
|
5. Loading the step's state from its `.safetensors` file, if it exists.
|
|
|
|
Args:
|
|
pretrained_model_name_or_path: The identifier of the repository on the Hugging Face Hub
|
|
or a path to a local directory.
|
|
force_download: Whether to force (re)downloading the files.
|
|
resume_download: Whether to resume a previously interrupted download.
|
|
proxies: A dictionary of proxy servers to use.
|
|
token: The token to use as HTTP bearer authorization for private Hub repositories.
|
|
cache_dir: The path to a specific cache folder to store downloaded files.
|
|
local_files_only: If True, avoid downloading files from the Hub.
|
|
revision: The specific model version to use (e.g., a branch name, tag name, or commit id).
|
|
config_filename: The name of the pipeline's JSON configuration file. If not provided,
|
|
it's auto-detected in local directories (if only one .json file exists). This parameter
|
|
is mandatory when loading from Hugging Face Hub repositories.
|
|
overrides: A dictionary to override the configuration of specific steps. Keys should
|
|
match the step's class name or registry name.
|
|
to_transition: A custom function to convert input data to `EnvTransition`.
|
|
to_output: A custom function to convert the final `EnvTransition` to the output format.
|
|
**kwargs: Additional arguments (not used).
|
|
|
|
Returns:
|
|
An instance of `DataProcessorPipeline` loaded with the specified configuration and state.
|
|
|
|
Raises:
|
|
FileNotFoundError: If the config file cannot be found.
|
|
ValueError: If configuration is ambiguous or instantiation fails.
|
|
ImportError: If a step's class cannot be imported.
|
|
KeyError: If an override key doesn't match any step in the pipeline.
|
|
"""
|
|
model_id = str(pretrained_model_name_or_path)
|
|
loaded_config: dict[str, Any] | None = None
|
|
base_path: Path | None = None
|
|
|
|
# Standard pattern: try local directory first
|
|
if Path(model_id).is_dir():
|
|
base_path = Path(model_id)
|
|
|
|
# Handle config filename
|
|
if config_filename is None:
|
|
json_files = list(base_path.glob("*.json"))
|
|
if len(json_files) == 0:
|
|
# No config files found locally, will try Hub next
|
|
pass
|
|
elif len(json_files) == 1:
|
|
config_filename = json_files[0].name
|
|
else:
|
|
raise ValueError(
|
|
f"Multiple .json files found in {model_id}: {[f.name for f in json_files]}. "
|
|
f"Please specify which one to load using the config_filename parameter."
|
|
)
|
|
|
|
# Try to load config from local directory
|
|
if config_filename and (base_path / config_filename).exists():
|
|
with open(base_path / config_filename) as f:
|
|
loaded_config = json.load(f)
|
|
|
|
# If not found locally, try Hub
|
|
if loaded_config is None:
|
|
# Check if this looks like a local path that doesn't exist
|
|
# Hub repo IDs have format "user/repo" with exactly one slash
|
|
# Local paths typically have multiple slashes, backslashes, or start with ./ or ../
|
|
looks_like_local_path = (
|
|
model_id.count("/") > 1 # Multiple slashes suggest local path
|
|
or "\\" in model_id # Backslashes are only in local paths
|
|
or Path(model_id).is_absolute() # Absolute paths are local
|
|
or model_id.startswith("./")
|
|
or model_id.startswith("../") # Relative path indicators
|
|
)
|
|
|
|
if looks_like_local_path:
|
|
# This appears to be a local path that doesn't exist
|
|
raise FileNotFoundError(f"Local path '{model_id}' does not exist")
|
|
# For Hub repositories, config_filename is mandatory
|
|
if config_filename is None:
|
|
raise ValueError(
|
|
f"When loading from Hugging Face Hub, 'config_filename' must be specified. "
|
|
f"Example: DataProcessorPipeline.from_pretrained('{model_id}', config_filename='processor.json')"
|
|
)
|
|
|
|
try:
|
|
# Download the configuration file from the Hub
|
|
config_path = hf_hub_download(
|
|
repo_id=model_id,
|
|
filename=config_filename,
|
|
repo_type="model",
|
|
force_download=force_download,
|
|
resume_download=resume_download,
|
|
proxies=proxies,
|
|
token=token,
|
|
cache_dir=cache_dir,
|
|
local_files_only=local_files_only,
|
|
revision=revision,
|
|
)
|
|
|
|
with open(config_path) as f:
|
|
loaded_config = json.load(f)
|
|
|
|
# The base path for other files (like state tensors) is the directory of the config file
|
|
base_path = Path(config_path).parent
|
|
|
|
except Exception as e:
|
|
raise FileNotFoundError(
|
|
f"Could not find {config_filename} on the HuggingFace Hub at {model_id}"
|
|
) from e
|
|
|
|
# At this point, loaded_config must be loaded successfully
|
|
if loaded_config is None:
|
|
raise RuntimeError("Failed to load configuration from local directory or Hub")
|
|
|
|
if overrides is None:
|
|
overrides = {}
|
|
|
|
override_keys = set(overrides.keys())
|
|
|
|
steps: list[ProcessorStep] = []
|
|
for step_entry in loaded_config["steps"]:
|
|
# Determine the step class, prioritizing the registry.
|
|
if "registry_name" in step_entry:
|
|
try:
|
|
step_class = ProcessorStepRegistry.get(step_entry["registry_name"])
|
|
step_key = step_entry["registry_name"]
|
|
except KeyError as e:
|
|
raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e
|
|
else:
|
|
# Fallback to dynamic import using the full class path.
|
|
full_class_path = step_entry["class"]
|
|
module_path, class_name = full_class_path.rsplit(".", 1)
|
|
|
|
try:
|
|
module = importlib.import_module(module_path)
|
|
step_class = getattr(module, class_name)
|
|
step_key = class_name
|
|
except (ImportError, AttributeError) as e:
|
|
raise ImportError(
|
|
f"Failed to load processor step '{full_class_path}'. "
|
|
f"Make sure the module '{module_path}' is installed and contains class '{class_name}'. "
|
|
f"Consider registering the step using @ProcessorStepRegistry.register() for better portability. "
|
|
f"Error: {str(e)}"
|
|
) from e
|
|
|
|
# Instantiate the step, merging saved config with user-provided overrides.
|
|
try:
|
|
saved_cfg = step_entry.get("config", {})
|
|
step_overrides = overrides.get(step_key, {})
|
|
merged_cfg = {**saved_cfg, **step_overrides}
|
|
step_instance: ProcessorStep = step_class(**merged_cfg)
|
|
|
|
if step_key in override_keys:
|
|
override_keys.discard(step_key)
|
|
|
|
except Exception as e:
|
|
step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown"))
|
|
raise ValueError(
|
|
f"Failed to instantiate processor step '{step_name}' with config: {step_entry.get('config', {})}. "
|
|
f"Error: {str(e)}"
|
|
) from e
|
|
|
|
# Load the step's state if a state file is specified.
|
|
if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"):
|
|
# Check if state file exists locally first
|
|
if base_path and (base_path / step_entry["state_file"]).exists():
|
|
state_path = str(base_path / step_entry["state_file"])
|
|
else:
|
|
# Download the state file from the Hub.
|
|
state_path = hf_hub_download(
|
|
repo_id=model_id,
|
|
filename=step_entry["state_file"],
|
|
repo_type="model",
|
|
force_download=force_download,
|
|
resume_download=resume_download,
|
|
proxies=proxies,
|
|
token=token,
|
|
cache_dir=cache_dir,
|
|
local_files_only=local_files_only,
|
|
revision=revision,
|
|
)
|
|
|
|
step_instance.load_state_dict(load_file(state_path))
|
|
|
|
steps.append(step_instance)
|
|
|
|
# Check for any unused override keys, which likely indicates a typo by the user.
|
|
if override_keys:
|
|
available_keys = [
|
|
step.get("registry_name") or step["class"].rsplit(".", 1)[1]
|
|
for step in loaded_config["steps"]
|
|
]
|
|
|
|
raise KeyError(
|
|
f"Override keys {list(override_keys)} do not match any step in the saved configuration. "
|
|
f"Available step keys: {available_keys}. "
|
|
f"Make sure override keys match exact step class names or registry names."
|
|
)
|
|
|
|
# Construct and return the final pipeline instance.
|
|
return cls(
|
|
steps=steps,
|
|
name=loaded_config.get("name", "DataProcessorPipeline"),
|
|
to_transition=to_transition or batch_to_transition,
|
|
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
|
)
|
|
|
|
def __len__(self) -> int:
|
|
"""Returns the number of steps in the pipeline."""
|
|
return len(self.steps)
|
|
|
|
def __getitem__(self, idx: int | slice) -> ProcessorStep | DataProcessorPipeline[TInput, TOutput]:
|
|
"""Retrieves a step or a sub-pipeline by index or slice.
|
|
|
|
Args:
|
|
idx: An integer index or a slice object.
|
|
|
|
Returns:
|
|
A `ProcessorStep` if `idx` is an integer, or a new `DataProcessorPipeline`
|
|
containing the sliced steps.
|
|
"""
|
|
if isinstance(idx, slice):
|
|
# Return a new pipeline instance with the sliced steps.
|
|
return DataProcessorPipeline(
|
|
steps=self.steps[idx],
|
|
name=self.name,
|
|
to_transition=self.to_transition,
|
|
to_output=self.to_output,
|
|
before_step_hooks=self.before_step_hooks.copy(),
|
|
after_step_hooks=self.after_step_hooks.copy(),
|
|
)
|
|
return self.steps[idx]
|
|
|
|
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
|
"""Registers a function to be called before each step.
|
|
|
|
Args:
|
|
fn: A callable that accepts the step index and the current transition.
|
|
"""
|
|
self.before_step_hooks.append(fn)
|
|
|
|
def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
|
"""Unregisters a 'before_step' hook.
|
|
|
|
Args:
|
|
fn: The exact function object that was previously registered.
|
|
|
|
Raises:
|
|
ValueError: If the hook is not found in the list.
|
|
"""
|
|
try:
|
|
self.before_step_hooks.remove(fn)
|
|
except ValueError:
|
|
raise ValueError(
|
|
f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference."
|
|
) from None
|
|
|
|
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
|
"""Registers a function to be called after each step.
|
|
|
|
Args:
|
|
fn: A callable that accepts the step index and the current transition.
|
|
"""
|
|
self.after_step_hooks.append(fn)
|
|
|
|
def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
|
"""Unregisters an 'after_step' hook.
|
|
|
|
Args:
|
|
fn: The exact function object that was previously registered.
|
|
|
|
Raises:
|
|
ValueError: If the hook is not found in the list.
|
|
"""
|
|
try:
|
|
self.after_step_hooks.remove(fn)
|
|
except ValueError:
|
|
raise ValueError(
|
|
f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference."
|
|
) from None
|
|
|
|
def reset(self):
|
|
"""Resets the state of all stateful steps in the pipeline."""
|
|
for step in self.steps:
|
|
if hasattr(step, "reset"):
|
|
step.reset()
|
|
|
|
def __repr__(self) -> str:
|
|
"""Provides a concise string representation of the pipeline."""
|
|
step_names = [step.__class__.__name__ for step in self.steps]
|
|
|
|
if not step_names:
|
|
steps_repr = "steps=0: []"
|
|
elif len(step_names) <= 3:
|
|
steps_repr = f"steps={len(step_names)}: [{', '.join(step_names)}]"
|
|
else:
|
|
# For long pipelines, show the first, second, and last steps.
|
|
displayed = f"{step_names[0]}, {step_names[1]}, ..., {step_names[-1]}"
|
|
steps_repr = f"steps={len(step_names)}: [{displayed}]"
|
|
|
|
parts = [f"name='{self.name}'", steps_repr]
|
|
|
|
return f"DataProcessorPipeline({', '.join(parts)})"
|
|
|
|
def __post_init__(self):
|
|
"""Validates that all provided steps are instances of `ProcessorStep`."""
|
|
for i, step in enumerate(self.steps):
|
|
if not isinstance(step, ProcessorStep):
|
|
raise TypeError(f"Step {i} ({type(step).__name__}) must inherit from ProcessorStep")
|
|
|
|
def transform_features(
|
|
self, initial_features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
|
"""Applies feature transformations from all steps sequentially.
|
|
|
|
This method propagates a feature description dictionary through each step's
|
|
`transform_features` method, allowing the pipeline to statically determine
|
|
the output feature specification without processing any real data.
|
|
|
|
Args:
|
|
initial_features: A dictionary describing the initial features.
|
|
|
|
Returns:
|
|
The final feature description after all transformations.
|
|
"""
|
|
features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = deepcopy(initial_features)
|
|
|
|
for _, step in enumerate(self.steps):
|
|
out = step.transform_features(features)
|
|
features = out
|
|
return features
|
|
|
|
# Convenience methods for processing individual parts of a transition.
|
|
def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
|
"""Processes only the observation part of a transition through the pipeline.
|
|
|
|
Args:
|
|
observation: The observation dictionary.
|
|
|
|
Returns:
|
|
The processed observation dictionary.
|
|
"""
|
|
transition: EnvTransition = create_transition(observation=observation)
|
|
transformed_transition = self._forward(transition)
|
|
return transformed_transition[TransitionKey.OBSERVATION]
|
|
|
|
def process_action(
|
|
self, action: PolicyAction | RobotAction | EnvAction
|
|
) -> PolicyAction | RobotAction | EnvAction:
|
|
"""Processes only the action part of a transition through the pipeline.
|
|
|
|
Args:
|
|
action: The action data.
|
|
|
|
Returns:
|
|
The processed action.
|
|
"""
|
|
transition: EnvTransition = create_transition(action=action)
|
|
transformed_transition = self._forward(transition)
|
|
return transformed_transition[TransitionKey.ACTION]
|
|
|
|
def process_reward(self, reward: float | torch.Tensor) -> float | torch.Tensor:
|
|
"""Processes only the reward part of a transition through the pipeline.
|
|
|
|
Args:
|
|
reward: The reward value.
|
|
|
|
Returns:
|
|
The processed reward.
|
|
"""
|
|
transition: EnvTransition = create_transition(reward=reward)
|
|
transformed_transition = self._forward(transition)
|
|
return transformed_transition[TransitionKey.REWARD]
|
|
|
|
def process_done(self, done: bool | torch.Tensor) -> bool | torch.Tensor:
|
|
"""Processes only the done flag of a transition through the pipeline.
|
|
|
|
Args:
|
|
done: The done flag.
|
|
|
|
Returns:
|
|
The processed done flag.
|
|
"""
|
|
transition: EnvTransition = create_transition(done=done)
|
|
transformed_transition = self._forward(transition)
|
|
return transformed_transition[TransitionKey.DONE]
|
|
|
|
def process_truncated(self, truncated: bool | torch.Tensor) -> bool | torch.Tensor:
|
|
"""Processes only the truncated flag of a transition through the pipeline.
|
|
|
|
Args:
|
|
truncated: The truncated flag.
|
|
|
|
Returns:
|
|
The processed truncated flag.
|
|
"""
|
|
transition: EnvTransition = create_transition(truncated=truncated)
|
|
transformed_transition = self._forward(transition)
|
|
return transformed_transition[TransitionKey.TRUNCATED]
|
|
|
|
def process_info(self, info: dict[str, Any]) -> dict[str, Any]:
|
|
"""Processes only the info dictionary of a transition through the pipeline.
|
|
|
|
Args:
|
|
info: The info dictionary.
|
|
|
|
Returns:
|
|
The processed info dictionary.
|
|
"""
|
|
transition: EnvTransition = create_transition(info=info)
|
|
transformed_transition = self._forward(transition)
|
|
return transformed_transition[TransitionKey.INFO]
|
|
|
|
def process_complementary_data(self, complementary_data: dict[str, Any]) -> dict[str, Any]:
|
|
"""Processes only the complementary data part of a transition through the pipeline.
|
|
|
|
Args:
|
|
complementary_data: The complementary data dictionary.
|
|
|
|
Returns:
|
|
The processed complementary data dictionary.
|
|
"""
|
|
transition: EnvTransition = create_transition(complementary_data=complementary_data)
|
|
transformed_transition = self._forward(transition)
|
|
return transformed_transition[TransitionKey.COMPLEMENTARY_DATA]
|
|
|
|
|
|
# Type aliases for semantic clarity.
|
|
RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
|
PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
|
|
|
|
|
class ObservationProcessorStep(ProcessorStep, ABC):
|
|
"""An abstract `ProcessorStep` that specifically targets the observation in a transition."""
|
|
|
|
@abstractmethod
|
|
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
|
"""Processes an observation dictionary. Subclasses must implement this method.
|
|
|
|
Args:
|
|
observation: The input observation dictionary from the transition.
|
|
|
|
Returns:
|
|
The processed observation dictionary.
|
|
"""
|
|
...
|
|
|
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
"""Applies the `observation` method to the transition's observation."""
|
|
self._current_transition = transition.copy()
|
|
new_transition = self._current_transition
|
|
|
|
observation = new_transition.get(TransitionKey.OBSERVATION)
|
|
if observation is None or not isinstance(observation, dict):
|
|
raise ValueError("ObservationProcessorStep requires an observation in the transition.")
|
|
|
|
processed_observation = self.observation(observation.copy())
|
|
new_transition[TransitionKey.OBSERVATION] = processed_observation
|
|
return new_transition
|
|
|
|
|
|
class ActionProcessorStep(ProcessorStep, ABC):
|
|
"""An abstract `ProcessorStep` that specifically targets the action in a transition."""
|
|
|
|
@abstractmethod
|
|
def action(
|
|
self, action: PolicyAction | RobotAction | EnvAction
|
|
) -> PolicyAction | RobotAction | EnvAction:
|
|
"""Processes an action. Subclasses must implement this method.
|
|
|
|
Args:
|
|
action: The input action from the transition.
|
|
|
|
Returns:
|
|
The processed action.
|
|
"""
|
|
...
|
|
|
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
"""Applies the `action` method to the transition's action."""
|
|
self._current_transition = transition.copy()
|
|
new_transition = self._current_transition
|
|
|
|
action = new_transition.get(TransitionKey.ACTION)
|
|
if action is None:
|
|
raise ValueError("ActionProcessorStep requires an action in the transition.")
|
|
|
|
processed_action = self.action(action)
|
|
new_transition[TransitionKey.ACTION] = processed_action
|
|
return new_transition
|
|
|
|
|
|
class RobotActionProcessorStep(ProcessorStep, ABC):
|
|
"""An abstract `ProcessorStep` for processing a `RobotAction` (a dictionary)."""
|
|
|
|
@abstractmethod
|
|
def action(self, action: RobotAction) -> RobotAction:
|
|
"""Processes a `RobotAction`. Subclasses must implement this method.
|
|
|
|
Args:
|
|
action: The input `RobotAction` dictionary.
|
|
|
|
Returns:
|
|
The processed `RobotAction`.
|
|
"""
|
|
...
|
|
|
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
"""Applies the `action` method to the transition's action, ensuring it's a `RobotAction`."""
|
|
self._current_transition = transition.copy()
|
|
new_transition = self._current_transition
|
|
|
|
action = new_transition.get(TransitionKey.ACTION)
|
|
if action is None or not isinstance(action, dict):
|
|
raise ValueError(f"Action should be a RobotAction type (dict), but got {type(action)}")
|
|
|
|
processed_action = self.action(action.copy())
|
|
new_transition[TransitionKey.ACTION] = processed_action
|
|
return new_transition
|
|
|
|
|
|
class PolicyActionProcessorStep(ProcessorStep, ABC):
|
|
"""An abstract `ProcessorStep` for processing a `PolicyAction` (a tensor or dict of tensors)."""
|
|
|
|
@abstractmethod
|
|
def action(self, action: PolicyAction) -> PolicyAction:
|
|
"""Processes a `PolicyAction`. Subclasses must implement this method.
|
|
|
|
Args:
|
|
action: The input `PolicyAction`.
|
|
|
|
Returns:
|
|
The processed `PolicyAction`.
|
|
"""
|
|
...
|
|
|
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
"""Applies the `action` method to the transition's action, ensuring it's a `PolicyAction`."""
|
|
self._current_transition = transition.copy()
|
|
new_transition = self._current_transition
|
|
|
|
action = new_transition.get(TransitionKey.ACTION)
|
|
if not isinstance(action, PolicyAction):
|
|
raise ValueError(f"Action should be a PolicyAction type (tensor), but got {type(action)}")
|
|
|
|
processed_action = self.action(action)
|
|
new_transition[TransitionKey.ACTION] = processed_action
|
|
return new_transition
|
|
|
|
|
|
class RewardProcessorStep(ProcessorStep, ABC):
|
|
"""An abstract `ProcessorStep` that specifically targets the reward in a transition."""
|
|
|
|
@abstractmethod
|
|
def reward(self, reward) -> float | torch.Tensor:
|
|
"""Processes a reward. Subclasses must implement this method.
|
|
|
|
Args:
|
|
reward: The input reward from the transition.
|
|
|
|
Returns:
|
|
The processed reward.
|
|
"""
|
|
...
|
|
|
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
"""Applies the `reward` method to the transition's reward."""
|
|
self._current_transition = transition.copy()
|
|
new_transition = self._current_transition
|
|
|
|
reward = new_transition.get(TransitionKey.REWARD)
|
|
if reward is None:
|
|
raise ValueError("RewardProcessorStep requires a reward in the transition.")
|
|
|
|
processed_reward = self.reward(reward)
|
|
new_transition[TransitionKey.REWARD] = processed_reward
|
|
return new_transition
|
|
|
|
|
|
class DoneProcessorStep(ProcessorStep, ABC):
|
|
"""An abstract `ProcessorStep` that specifically targets the 'done' flag in a transition."""
|
|
|
|
@abstractmethod
|
|
def done(self, done) -> bool | torch.Tensor:
|
|
"""Processes a 'done' flag. Subclasses must implement this method.
|
|
|
|
Args:
|
|
done: The input 'done' flag from the transition.
|
|
|
|
Returns:
|
|
The processed 'done' flag.
|
|
"""
|
|
...
|
|
|
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
"""Applies the `done` method to the transition's 'done' flag."""
|
|
self._current_transition = transition.copy()
|
|
new_transition = self._current_transition
|
|
|
|
done = new_transition.get(TransitionKey.DONE)
|
|
if done is None:
|
|
raise ValueError("DoneProcessorStep requires a done flag in the transition.")
|
|
|
|
processed_done = self.done(done)
|
|
new_transition[TransitionKey.DONE] = processed_done
|
|
return new_transition
|
|
|
|
|
|
class TruncatedProcessorStep(ProcessorStep, ABC):
|
|
"""An abstract `ProcessorStep` that specifically targets the 'truncated' flag in a transition."""
|
|
|
|
@abstractmethod
|
|
def truncated(self, truncated) -> bool | torch.Tensor:
|
|
"""Processes a 'truncated' flag. Subclasses must implement this method.
|
|
|
|
Args:
|
|
truncated: The input 'truncated' flag from the transition.
|
|
|
|
Returns:
|
|
The processed 'truncated' flag.
|
|
"""
|
|
...
|
|
|
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
"""Applies the `truncated` method to the transition's 'truncated' flag."""
|
|
self._current_transition = transition.copy()
|
|
new_transition = self._current_transition
|
|
|
|
truncated = new_transition.get(TransitionKey.TRUNCATED)
|
|
if truncated is None:
|
|
raise ValueError("TruncatedProcessorStep requires a truncated flag in the transition.")
|
|
|
|
processed_truncated = self.truncated(truncated)
|
|
new_transition[TransitionKey.TRUNCATED] = processed_truncated
|
|
return new_transition
|
|
|
|
|
|
class InfoProcessorStep(ProcessorStep, ABC):
|
|
"""An abstract `ProcessorStep` that specifically targets the 'info' dictionary in a transition."""
|
|
|
|
@abstractmethod
|
|
def info(self, info) -> dict[str, Any]:
|
|
"""Processes an 'info' dictionary. Subclasses must implement this method.
|
|
|
|
Args:
|
|
info: The input 'info' dictionary from the transition.
|
|
|
|
Returns:
|
|
The processed 'info' dictionary.
|
|
"""
|
|
...
|
|
|
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
"""Applies the `info` method to the transition's 'info' dictionary."""
|
|
self._current_transition = transition.copy()
|
|
new_transition = self._current_transition
|
|
|
|
info = new_transition.get(TransitionKey.INFO)
|
|
if info is None or not isinstance(info, dict):
|
|
raise ValueError("InfoProcessorStep requires an info dictionary in the transition.")
|
|
|
|
processed_info = self.info(info.copy())
|
|
new_transition[TransitionKey.INFO] = processed_info
|
|
return new_transition
|
|
|
|
|
|
class ComplementaryDataProcessorStep(ProcessorStep, ABC):
|
|
"""An abstract `ProcessorStep` that targets the 'complementary_data' in a transition."""
|
|
|
|
@abstractmethod
|
|
def complementary_data(self, complementary_data) -> dict[str, Any]:
|
|
"""Processes a 'complementary_data' dictionary. Subclasses must implement this method.
|
|
|
|
Args:
|
|
complementary_data: The input 'complementary_data' from the transition.
|
|
|
|
Returns:
|
|
The processed 'complementary_data' dictionary.
|
|
"""
|
|
...
|
|
|
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
"""Applies the `complementary_data` method to the transition's data."""
|
|
self._current_transition = transition.copy()
|
|
new_transition = self._current_transition
|
|
|
|
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
|
if complementary_data is None or not isinstance(complementary_data, dict):
|
|
raise ValueError("ComplementaryDataProcessorStep requires complementary data in the transition.")
|
|
|
|
processed_complementary_data = self.complementary_data(complementary_data.copy())
|
|
new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data
|
|
return new_transition
|
|
|
|
|
|
class IdentityProcessorStep(ProcessorStep):
|
|
"""A no-op processor step that returns the input transition and features unchanged.
|
|
|
|
This can be useful as a placeholder or for debugging purposes.
|
|
"""
|
|
|
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
"""Returns the transition without modification."""
|
|
return transition
|
|
|
|
def transform_features(
|
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
|
"""Returns the features without modification."""
|
|
return features
|