mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
refactor(processor): improve processor pipeline typing with generic type (#1810)
* refactor(processor): introduce generic type for to_output - Always return `TOutput` - Remove `_prepare_transition`, so `__call__` now always returns `TOutput` - Update tests accordingly - This refactor paves the way for adding settings for `to_transition` and `to_output` in `make_processor` and the post-processor * refactor(processor): consolidate ProcessorKwargs usage across policies - Removed the ProcessorTypes module and integrated ProcessorKwargs directly into the processor pipeline. - Updated multiple policy files to utilize the new ProcessorKwargs structure for preprocessor and postprocessor arguments. - Simplified the handling of processor kwargs by initializing them to empty dictionaries when not provided.
This commit is contained in:
@@ -21,6 +21,7 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
@@ -29,8 +30,16 @@ from lerobot.processor import (
|
||||
|
||||
|
||||
def make_diffusion_pre_post_processors(
|
||||
config: DiffusionConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
config: DiffusionConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
@@ -47,6 +56,15 @@ def make_diffusion_pre_post_processors(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
return (
|
||||
RobotProcessor(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user