feat(batch_processor): Enhance ToBatchProcessor to handle action batching

- Updated ToBatchProcessor to add batch dimensions to actions in addition to observations.
- Implemented separate methods for processing observations and actions, improving code readability.
- Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types.
This commit is contained in:
Adil Zouitine
2025-07-24 17:20:57 +02:00
committed by Steven Palma
parent 21baa8fa02
commit 99de7567e6
3 changed files with 245 additions and 10 deletions

View File

@@ -46,6 +46,7 @@ from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.batch_processor import ToBatchProcessor
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from lerobot.processor.pipeline import RobotProcessor
@@ -403,14 +404,16 @@ def main():
preprocessor_steps = [
NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats),
NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
ToBatchProcessor(),
]
preprocessor = RobotProcessor(preprocessor_steps, name=f"{policy_type}_preprocessor")
preprocessor = RobotProcessor(preprocessor_steps, name="preprocessor")
# Create postprocessor with unnormalizer for outputs only
postprocessor_steps = [
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
ToBatchProcessor(),
]
postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor")
postprocessor = RobotProcessor(postprocessor_steps, name="postprocessor")
# Determine hub repo ID if pushing to hub
if args.push_to_hub: