mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
feat(scripts): lerobot-rollout
This commit is contained in:
committed by
Steven Palma
parent
5c43fa1cce
commit
bc06cb44ca
391
src/lerobot/rollout/inference/rtc.py
Normal file
391
src/lerobot/rollout/inference/rtc.py
Normal file
@@ -0,0 +1,391 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Real-Time Chunking inference engine.
|
||||
|
||||
A background thread produces action chunks asynchronously via
|
||||
:meth:`policy.predict_action_chunk`. The main control loop polls
|
||||
``get_action`` for the next ready action; observations flow the other
|
||||
way via ``notify_observation``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
import traceback
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionQueue, LatencyTracker
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.utils import prepare_observation_for_inference
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
RelativeActionsProcessorStep,
|
||||
TransitionKey,
|
||||
create_transition,
|
||||
to_relative_actions,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .base import InferenceEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# How long the RTC loop sleeps when paused, idle, or backpressured by a full queue.
|
||||
_RTC_IDLE_SLEEP_S: float = 0.01
|
||||
# Backoff between transient inference errors (per consecutive failure).
|
||||
_RTC_ERROR_RETRY_DELAY_S: float = 0.5
|
||||
# Consecutive transient errors tolerated before giving up and propagating shutdown.
|
||||
_RTC_MAX_CONSECUTIVE_ERRORS: int = 10
|
||||
# Hard timeout for joining the RTC thread on stop().
|
||||
_RTC_JOIN_TIMEOUT_S: float = 3.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RTC helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute: torch.Tensor,
|
||||
current_state: torch.Tensor,
|
||||
relative_step: RelativeActionsProcessorStep,
|
||||
normalizer_step: NormalizerProcessorStep | None,
|
||||
policy_device: torch.device | str,
|
||||
) -> torch.Tensor:
|
||||
"""Convert absolute leftover actions into model-space for relative-action RTC policies.
|
||||
|
||||
When using relative actions, the RTC prefix (previous chunk's unexecuted tail)
|
||||
is stored in absolute coordinates. Before feeding it back to the policy, this
|
||||
helper re-expresses those actions relative to the robot's current joint state
|
||||
and optionally normalizes them so the policy receives correctly scaled inputs.
|
||||
"""
|
||||
state = current_state.detach().cpu()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
action_cpu = prev_actions_absolute.detach().cpu()
|
||||
mask = relative_step._build_mask(action_cpu.shape[-1])
|
||||
relative_actions = to_relative_actions(action_cpu, state, mask)
|
||||
|
||||
transition = create_transition(action=relative_actions)
|
||||
if normalizer_step is not None:
|
||||
transition = normalizer_step(transition)
|
||||
|
||||
return transition[TransitionKey.ACTION].to(policy_device)
|
||||
|
||||
|
||||
def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
|
||||
"""Pad or truncate RTC prefix actions to a fixed length for stable compiled inference."""
|
||||
if prev_actions.ndim != 2:
|
||||
raise ValueError(f"Expected 2D [T, A] tensor, got shape={tuple(prev_actions.shape)}")
|
||||
steps, action_dim = prev_actions.shape
|
||||
if steps == target_steps:
|
||||
return prev_actions
|
||||
if steps > target_steps:
|
||||
return prev_actions[:target_steps]
|
||||
padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device)
|
||||
padded[:steps] = prev_actions
|
||||
return padded
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RTCInferenceEngine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RTCInferenceEngine(InferenceEngine):
|
||||
"""Async RTC inference: a background thread produces action chunks.
|
||||
|
||||
``get_action`` pops the next action from the shared queue (or
|
||||
returns ``None`` if the queue is empty). The main loop should call
|
||||
``notify_observation`` every tick and ``pause``/``resume`` around
|
||||
human-intervention phases.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
robot_wrapper: ThreadSafeRobot,
|
||||
rtc_config: RTCConfig,
|
||||
hw_features: dict,
|
||||
task: str,
|
||||
fps: float,
|
||||
device: str | None,
|
||||
use_torch_compile: bool = False,
|
||||
compile_warmup_inferences: int = 2,
|
||||
rtc_queue_threshold: int = 30,
|
||||
shutdown_event: Event | None = None,
|
||||
) -> None:
|
||||
self._policy = policy
|
||||
self._preprocessor = preprocessor
|
||||
self._postprocessor = postprocessor
|
||||
self._robot = robot_wrapper
|
||||
self._rtc_config = rtc_config
|
||||
self._hw_features = hw_features
|
||||
self._task = task
|
||||
self._fps = fps
|
||||
self._device = device or "cpu"
|
||||
self._use_torch_compile = use_torch_compile
|
||||
self._compile_warmup_inferences = compile_warmup_inferences
|
||||
self._rtc_queue_threshold = rtc_queue_threshold
|
||||
|
||||
self._action_queue: ActionQueue | None = None
|
||||
self._obs_holder: dict[str, Any] = {}
|
||||
self._obs_lock = Lock()
|
||||
self._policy_active = Event()
|
||||
self._compile_warmup_done = Event()
|
||||
self._shutdown_event = Event()
|
||||
self._rtc_error = Event()
|
||||
self._global_shutdown_event = shutdown_event
|
||||
self._rtc_thread: Thread | None = None
|
||||
|
||||
if not self._use_torch_compile:
|
||||
self._compile_warmup_done.set()
|
||||
logger.info("RTCInferenceEngine initialized (torch.compile disabled, no warmup needed)")
|
||||
else:
|
||||
logger.info(
|
||||
"RTCInferenceEngine initialized (torch.compile enabled, %d warmup inferences)",
|
||||
compile_warmup_inferences,
|
||||
)
|
||||
|
||||
# Processor introspection for relative-action re-anchoring.
|
||||
self._relative_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
||||
None,
|
||||
)
|
||||
self._normalizer_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
|
||||
None,
|
||||
)
|
||||
if self._relative_step is not None:
|
||||
if self._relative_step.action_names is None:
|
||||
cfg_names = getattr(policy.config, "action_feature_names", None)
|
||||
if cfg_names:
|
||||
self._relative_step.action_names = list(cfg_names)
|
||||
else:
|
||||
self._relative_step.action_names = [
|
||||
k for k in robot_wrapper.action_features if k.endswith(".pos")
|
||||
]
|
||||
logger.info("Relative actions enabled: RTC prefix will be re-anchored")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
"""True once torch.compile warmup is complete (or immediately if compile is disabled)."""
|
||||
return self._compile_warmup_done.is_set()
|
||||
|
||||
@property
|
||||
def failed(self) -> bool:
|
||||
"""True if the RTC background thread exited due to an unrecoverable error."""
|
||||
return self._rtc_error.is_set()
|
||||
|
||||
@property
|
||||
def action_queue(self) -> ActionQueue | None:
|
||||
"""The shared action queue between the RTC thread and the main loop."""
|
||||
return self._action_queue
|
||||
|
||||
def start(self) -> None:
|
||||
"""Launch the RTC background thread."""
|
||||
self._action_queue = ActionQueue(self._rtc_config)
|
||||
self._obs_holder = {
|
||||
"obs": None,
|
||||
"robot_type": self._robot.robot_type,
|
||||
}
|
||||
self._shutdown_event.clear()
|
||||
self._rtc_thread = Thread(
|
||||
target=self._rtc_loop,
|
||||
daemon=True,
|
||||
name="RTCInference",
|
||||
)
|
||||
self._rtc_thread.start()
|
||||
logger.info("RTC inference thread started")
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the RTC thread to stop and wait for it."""
|
||||
logger.info("Stopping RTC inference thread...")
|
||||
self._shutdown_event.set()
|
||||
self._policy_active.clear()
|
||||
if self._rtc_thread is not None and self._rtc_thread.is_alive():
|
||||
self._rtc_thread.join(timeout=_RTC_JOIN_TIMEOUT_S)
|
||||
if self._rtc_thread.is_alive():
|
||||
logger.warning("RTC thread did not join within %.1fs", _RTC_JOIN_TIMEOUT_S)
|
||||
else:
|
||||
logger.info("RTC inference thread stopped")
|
||||
self._rtc_thread = None
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pause the RTC background thread."""
|
||||
logger.info("Pausing RTC inference thread")
|
||||
self._policy_active.clear()
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resume the RTC background thread."""
|
||||
logger.info("Resuming RTC inference thread")
|
||||
self._policy_active.set()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the policy, processors, and action queue."""
|
||||
logger.info("Resetting RTC inference state (policy + processors + queue)")
|
||||
self._policy.reset()
|
||||
self._preprocessor.reset()
|
||||
self._postprocessor.reset()
|
||||
if self._action_queue is not None:
|
||||
self._action_queue.clear()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Action production (called from main thread)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||
"""Pop the next action from the RTC queue (ignores ``obs_frame``)."""
|
||||
if self._action_queue is None:
|
||||
return None
|
||||
return self._action_queue.get()
|
||||
|
||||
def notify_observation(self, obs: dict) -> None:
|
||||
"""Publish the latest observation for the RTC thread to consume."""
|
||||
with self._obs_lock:
|
||||
self._obs_holder["obs"] = obs
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# RTC: background inference thread
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _rtc_loop(self) -> None:
|
||||
"""Background thread that generates action chunks via RTC."""
|
||||
try:
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / self._fps
|
||||
policy_device = torch.device(self._device)
|
||||
|
||||
warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
|
||||
inference_count = 0
|
||||
consecutive_errors = 0
|
||||
|
||||
while not self._shutdown_event.is_set():
|
||||
if not self._policy_active.is_set():
|
||||
time.sleep(_RTC_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
queue = self._action_queue
|
||||
with self._obs_lock:
|
||||
obs = self._obs_holder.get("obs")
|
||||
if queue is None or obs is None:
|
||||
time.sleep(_RTC_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
if queue.qsize() <= self._rtc_queue_threshold:
|
||||
try:
|
||||
current_time = time.perf_counter()
|
||||
idx_before = queue.get_action_index()
|
||||
prev_actions = queue.get_left_over()
|
||||
|
||||
latency = latency_tracker.max()
|
||||
delay = math.ceil(latency / time_per_chunk) if latency else 0
|
||||
|
||||
obs_batch = build_dataset_frame(self._hw_features, obs, prefix="observation")
|
||||
obs_batch = prepare_observation_for_inference(
|
||||
obs_batch, policy_device, self._task, self._robot.robot_type
|
||||
)
|
||||
obs_batch["task"] = [self._task]
|
||||
|
||||
preprocessed = self._preprocessor(obs_batch)
|
||||
|
||||
if prev_actions is not None and self._relative_step is not None:
|
||||
state_tensor = preprocessed.get(OBS_STATE)
|
||||
if state_tensor is not None:
|
||||
prev_abs = queue.get_processed_left_over()
|
||||
if prev_abs is not None and prev_abs.numel() > 0:
|
||||
prev_actions = _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_abs,
|
||||
current_state=state_tensor,
|
||||
relative_step=self._relative_step,
|
||||
normalizer_step=self._normalizer_step,
|
||||
policy_device=policy_device,
|
||||
)
|
||||
|
||||
if prev_actions is not None:
|
||||
prev_actions = _normalize_prev_actions_length(
|
||||
prev_actions, target_steps=self._rtc_config.execution_horizon
|
||||
)
|
||||
|
||||
actions = self._policy.predict_action_chunk(
|
||||
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
|
||||
original = actions.squeeze(0).clone()
|
||||
processed = self._postprocessor(actions).squeeze(0)
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
|
||||
inference_count += 1
|
||||
consecutive_errors = 0
|
||||
is_warmup = self._use_torch_compile and inference_count <= warmup_required
|
||||
if is_warmup:
|
||||
latency_tracker.reset()
|
||||
else:
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
queue.merge(original, processed, new_delay, idx_before)
|
||||
|
||||
if (
|
||||
is_warmup
|
||||
and inference_count >= warmup_required
|
||||
and not self._compile_warmup_done.is_set()
|
||||
):
|
||||
self._compile_warmup_done.set()
|
||||
logger.info("Compile warmup complete (%d inferences)", inference_count)
|
||||
|
||||
logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
|
||||
|
||||
except Exception as e:
|
||||
consecutive_errors += 1
|
||||
logger.error(
|
||||
"RTC inference error (%d/%d): %s",
|
||||
consecutive_errors,
|
||||
_RTC_MAX_CONSECUTIVE_ERRORS,
|
||||
e,
|
||||
)
|
||||
logger.debug(traceback.format_exc())
|
||||
if consecutive_errors >= _RTC_MAX_CONSECUTIVE_ERRORS:
|
||||
# Persistent failure: stop retrying and propagate shutdown.
|
||||
raise
|
||||
time.sleep(_RTC_ERROR_RETRY_DELAY_S)
|
||||
else:
|
||||
time.sleep(_RTC_IDLE_SLEEP_S)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Fatal error in RTC thread: %s", e)
|
||||
logger.error(traceback.format_exc())
|
||||
self._rtc_error.set()
|
||||
# Unblock any warmup waiters so the main loop doesn't spin forever
|
||||
self._compile_warmup_done.set()
|
||||
# Signal the top-level shutdown so strategies exit their control loops
|
||||
if self._global_shutdown_event is not None:
|
||||
self._global_shutdown_event.set()
|
||||
Reference in New Issue
Block a user