mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 04:41:24 +00:00
* docs(processor): update docstrings batch_processor * docs(processor): update docstrings device_processor * docs(processor): update docstrings tokenizer_processor * update docstrings processor_act * update docstrings for pipeline_features * update docstrings for utils * update docstring for processor_diffusion * update docstrings factory * add docstrings to pi0 processor * add docstring to pi0fast processor * add docstring classifier processor * add docstring to sac processor * add docstring smolvla processor * add docstring to tdmpc processor * add docstring to vqbet processor * add docstrings to converters * add docstrings for delta_action_processor * add docstring to gym action processor * update hil processor * add docstring to joint obs processor * add docstring to migrate_normalize_processor * update docstrings normalize processor * update docstring normalize processor * update docstrings observation processor * update docstrings rename_processor * add docstrings robot_kinematic_processor * cleanup rl comments * add docstring to train.py * add docstring to teleoperate.py * add docstrings to phone_processor.py * add docstrings to teleop_phone.py * add docstrings to control_utils.py * add docstrings to visualization_utils.py --------- Co-authored-by: Pepijn <pepijn@huggingface.co>
82 lines
3.0 KiB
Python
82 lines
3.0 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import torch
|
|
|
|
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
|
from lerobot.processor import (
|
|
DeviceProcessorStep,
|
|
IdentityProcessorStep,
|
|
NormalizerProcessorStep,
|
|
PolicyProcessorPipeline,
|
|
ProcessorKwargs,
|
|
)
|
|
|
|
|
|
def make_classifier_processor(
    config: RewardClassifierConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    preprocessor_kwargs: ProcessorKwargs | None = None,
    postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    """
    Build the pre-processor and post-processor pipelines used by the reward classifier.

    Pre-processing steps, in order:
        1. A normalizer over ``config.input_features``.
        2. A normalizer over ``config.output_features``.
        3. A device step targeting ``config.device``.

    Post-processing steps, in order:
        1. A device step targeting ``"cpu"``.
        2. An identity step (no un-normalization is applied to the classifier output).

    Both normalizers share ``config.normalization_mapping`` and ``dataset_stats``.

    Args:
        config: The configuration object for the RewardClassifier.
        dataset_stats: Statistics handed to the normalizer steps; may be ``None``.
        preprocessor_kwargs: Extra keyword arguments forwarded to the pre-processor
            pipeline constructor. ``None`` is treated as no extra arguments.
        postprocessor_kwargs: Extra keyword arguments forwarded to the post-processor
            pipeline constructor. ``None`` is treated as no extra arguments.

    Returns:
        A ``(preprocessor, postprocessor)`` tuple of configured pipelines.
    """
    pre_kwargs = {} if preprocessor_kwargs is None else preprocessor_kwargs
    post_kwargs = {} if postprocessor_kwargs is None else postprocessor_kwargs

    # Pre-processing steps: normalize inputs, normalize outputs, then move to the target device.
    normalize_inputs = NormalizerProcessorStep(
        features=config.input_features,
        norm_map=config.normalization_mapping,
        stats=dataset_stats,
    )
    normalize_outputs = NormalizerProcessorStep(
        features=config.output_features,
        norm_map=config.normalization_mapping,
        stats=dataset_stats,
    )
    to_target_device = DeviceProcessorStep(device=config.device)

    # Post-processing steps: bring results back to the CPU, then pass them through unchanged.
    to_cpu = DeviceProcessorStep(device="cpu")
    passthrough = IdentityProcessorStep()

    preprocessor = PolicyProcessorPipeline(
        steps=[normalize_inputs, normalize_outputs, to_target_device],
        name="classifier_preprocessor",
        **pre_kwargs,
    )
    postprocessor = PolicyProcessorPipeline(
        steps=[to_cpu, passthrough],
        name="classifier_postprocessor",
        **post_kwargs,
    )
    return preprocessor, postprocessor
|