mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
129 lines
4.3 KiB
Python
129 lines
4.3 KiB
Python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Inference strategy configs and factory.
|
|
|
|
Selection is explicit via ``--inference.type=sync|rtc``. Adding a new
|
|
backend requires registering its config subclass and dispatching it in
|
|
:func:`create_inference_strategy`.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import abc
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from threading import Event
|
|
|
|
import draccus
|
|
|
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
|
from lerobot.processor import PolicyProcessorPipeline
|
|
|
|
from ..robot_wrapper import ThreadSafeRobot
|
|
from .base import InferenceStrategy
|
|
from .rtc import RTCInferenceStrategy
|
|
from .sync import SyncInferenceStrategy
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Configs
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class InferenceStrategyConfig(draccus.ChoiceRegistry, abc.ABC):
    """Common abstract parent of every inference backend config.

    Concrete backends register themselves on this class with
    ``register_subclass``; the backend is then picked on the command line
    with ``--inference.type=<name>``.
    """

    @property
    def type(self) -> str:
        """Return the choice name registered for this instance's class."""
        config_cls = self.__class__
        return self.get_choice_name(config_cls)
|
|
|
|
|
|
@InferenceStrategyConfig.register_subclass("sync")
@dataclass
class SyncInferenceConfig(InferenceStrategyConfig):
    """Config selecting the synchronous backend.

    Inference runs inline: one policy call per control tick. Registered
    under the choice name ``sync``; it declares no options of its own.
    """
|
|
|
|
|
|
@InferenceStrategyConfig.register_subclass("rtc")
@dataclass
class RTCInferenceConfig(InferenceStrategyConfig):
    """Config for the Real-Time Chunking (RTC) backend.

    With this backend, policy inference runs asynchronously in a background
    thread. Registered under the choice name ``rtc``.
    """

    # Nested RTC settings. ``RTCConfig`` is a small dataclass whose fields
    # all have defaults, so building one eagerly is cheap and keeps the
    # draccus CLI surface flat (``--inference.rtc.execution_horizon=...``
    # and so on); lazy initialisation is unnecessary.
    rtc: RTCConfig = field(default_factory=RTCConfig)

    # Handed to ``RTCInferenceStrategy`` as ``rtc_queue_threshold`` by
    # ``create_inference_strategy``.
    queue_threshold: int = 30
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Factory
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def create_inference_strategy(
    config: InferenceStrategyConfig,
    *,
    policy: PreTrainedPolicy,
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
    robot_wrapper: ThreadSafeRobot,
    hw_features: dict,
    dataset_features: dict,
    ordered_action_keys: list[str],
    task: str,
    fps: float,
    device: str | None,
    use_torch_compile: bool = False,
    compile_warmup_inferences: int = 2,
    shutdown_event: Event | None = None,
) -> InferenceStrategy:
    """Build the inference strategy that matches ``config``.

    The concrete class of ``config`` decides the backend:
    ``SyncInferenceConfig`` yields a ``SyncInferenceStrategy`` and
    ``RTCInferenceConfig`` yields an ``RTCInferenceStrategy``.

    Args:
        config: Backend configuration; its class selects the strategy.
        policy: Forwarded to both backends.
        preprocessor: Forwarded to both backends.
        postprocessor: Forwarded to both backends.
        robot_wrapper: Passed whole to the RTC backend; the sync backend
            only receives its ``robot_type``.
        hw_features: Forwarded to the RTC backend only.
        dataset_features: Forwarded to the sync backend only.
        ordered_action_keys: Forwarded to the sync backend only.
        task: Forwarded to both backends.
        fps: Forwarded to the RTC backend only.
        device: Forwarded to both backends.
        use_torch_compile: Forwarded to the RTC backend only.
        compile_warmup_inferences: Forwarded to the RTC backend only.
        shutdown_event: Forwarded to the RTC backend only.

    Returns:
        The newly constructed strategy instance.

    Raises:
        ValueError: If ``config`` is neither a ``SyncInferenceConfig`` nor
            an ``RTCInferenceConfig``.
    """
    # Keyword arguments that both backend constructors receive.
    shared_kwargs = {
        "policy": policy,
        "preprocessor": preprocessor,
        "postprocessor": postprocessor,
        "task": task,
        "device": device,
    }

    if isinstance(config, SyncInferenceConfig):
        return SyncInferenceStrategy(
            dataset_features=dataset_features,
            ordered_action_keys=ordered_action_keys,
            robot_type=robot_wrapper.robot_type,
            **shared_kwargs,
        )

    if isinstance(config, RTCInferenceConfig):
        return RTCInferenceStrategy(
            robot_wrapper=robot_wrapper,
            rtc_config=config.rtc,
            hw_features=hw_features,
            fps=fps,
            use_torch_compile=use_torch_compile,
            compile_warmup_inferences=compile_warmup_inferences,
            rtc_queue_threshold=config.queue_threshold,
            shutdown_event=shutdown_event,
            **shared_kwargs,
        )

    # No branch matched: the config class has no dispatch entry here.
    config_cls_name = type(config).__name__
    raise ValueError(f"Unknown inference strategy type: {config_cls_name}")
|