mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
* feat(scripts): lerobot-rollout * fix(rollout) require dataset in dagger + use duration too * fix(docs): dagger num_episodes * test(rollout): fix expectations * fix(rollout): features check * fix(rollout): device and task propagation + feature pos + warn fps + move rename_map config * docs(rollout): edit rename_map instructions * chore(rollout): multiple minor improvements * chore(rollout): address coments + minor improvements * fix(rollout): enable default * fix(tests): default value RTCConfig * fix(rollout): robot_observation_processor and notify_observation at policy frequency instead of interpolator rate Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * fix(rollout): prevent relativeactions with sync inference engine Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * fix(rollout): rtc reanchor to non normalized state Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * fix(rollout): fixing the episode length to use hwc (#3469) also reducing default length to 5 minutes * feat(rollout): go back to initial position is now a config * fix(rollout): properly propagating video_files_size_in_mb to lerobot_dataset (#3470) * chore(rollout): note about dagger correction stage * chore(docs): update comments and docstring * fix(test): move rtc relative out of rollout module * fix(rollout): address the review comments --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Maxime Ellerbach <maxime.ellerbach@huggingface.co>
361 lines
15 KiB
Python
361 lines
15 KiB
Python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Real-Time Chunking inference engine.
|
|
|
|
A background thread produces action chunks asynchronously via
|
|
:meth:`policy.predict_action_chunk`. The main control loop polls
|
|
``get_action`` for the next ready action; observations flow the other
|
|
way via ``notify_observation``.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import math
|
|
import time
|
|
import traceback
|
|
from threading import Event, Lock, Thread
|
|
from typing import Any
|
|
|
|
import torch
|
|
|
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
|
from lerobot.policies.rtc import ActionQueue, LatencyTracker, reanchor_relative_rtc_prefix
|
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
|
from lerobot.policies.utils import prepare_observation_for_inference
|
|
from lerobot.processor import (
|
|
NormalizerProcessorStep,
|
|
PolicyProcessorPipeline,
|
|
RelativeActionsProcessorStep,
|
|
)
|
|
from lerobot.utils.feature_utils import build_dataset_frame
|
|
|
|
from ..robot_wrapper import ThreadSafeRobot
|
|
from .base import InferenceEngine
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# How long the RTC loop sleeps when paused, idle, or backpressured by a full queue.
|
|
_RTC_IDLE_SLEEP_S: float = 0.01
|
|
# Backoff between transient inference errors (per consecutive failure).
|
|
_RTC_ERROR_RETRY_DELAY_S: float = 0.5
|
|
# Consecutive transient errors tolerated before giving up and propagating shutdown.
|
|
_RTC_MAX_CONSECUTIVE_ERRORS: int = 10
|
|
# Hard timeout for joining the RTC thread on stop().
|
|
_RTC_JOIN_TIMEOUT_S: float = 3.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RTC helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
|
|
"""Pad or truncate RTC prefix actions to a fixed length for stable compiled inference."""
|
|
if prev_actions.ndim != 2:
|
|
raise ValueError(f"Expected 2D [T, A] tensor, got shape={tuple(prev_actions.shape)}")
|
|
steps, action_dim = prev_actions.shape
|
|
if steps == target_steps:
|
|
return prev_actions
|
|
if steps > target_steps:
|
|
return prev_actions[:target_steps]
|
|
padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device)
|
|
padded[:steps] = prev_actions
|
|
return padded
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RTCInferenceEngine
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class RTCInferenceEngine(InferenceEngine):
|
|
"""Async RTC inference: a background thread produces action chunks.
|
|
|
|
``get_action`` pops the next action from the shared queue (or
|
|
returns ``None`` if the queue is empty). The main loop should call
|
|
``notify_observation`` every tick and ``pause``/``resume`` around
|
|
human-intervention phases.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
policy: PreTrainedPolicy,
|
|
preprocessor: PolicyProcessorPipeline,
|
|
postprocessor: PolicyProcessorPipeline,
|
|
robot_wrapper: ThreadSafeRobot,
|
|
rtc_config: RTCConfig,
|
|
hw_features: dict,
|
|
task: str,
|
|
fps: float,
|
|
device: str | None,
|
|
use_torch_compile: bool = False,
|
|
compile_warmup_inferences: int = 2,
|
|
rtc_queue_threshold: int = 30,
|
|
shutdown_event: Event | None = None,
|
|
) -> None:
|
|
self._policy = policy
|
|
self._preprocessor = preprocessor
|
|
self._postprocessor = postprocessor
|
|
self._robot = robot_wrapper
|
|
self._rtc_config = rtc_config
|
|
self._hw_features = hw_features
|
|
self._task = task
|
|
self._fps = fps
|
|
self._device = device or "cpu"
|
|
self._use_torch_compile = use_torch_compile
|
|
self._compile_warmup_inferences = compile_warmup_inferences
|
|
self._rtc_queue_threshold = rtc_queue_threshold
|
|
|
|
self._action_queue: ActionQueue | None = None
|
|
self._obs_holder: dict[str, Any] = {}
|
|
self._obs_lock = Lock()
|
|
self._policy_active = Event()
|
|
self._compile_warmup_done = Event()
|
|
self._shutdown_event = Event()
|
|
self._rtc_error = Event()
|
|
self._global_shutdown_event = shutdown_event
|
|
self._rtc_thread: Thread | None = None
|
|
|
|
if not self._use_torch_compile:
|
|
self._compile_warmup_done.set()
|
|
logger.info("RTCInferenceEngine initialized (torch.compile disabled, no warmup needed)")
|
|
else:
|
|
logger.info(
|
|
"RTCInferenceEngine initialized (torch.compile enabled, %d warmup inferences)",
|
|
compile_warmup_inferences,
|
|
)
|
|
|
|
# Processor introspection for relative-action re-anchoring.
|
|
self._relative_step = next(
|
|
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
|
None,
|
|
)
|
|
self._normalizer_step = next(
|
|
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
|
|
None,
|
|
)
|
|
if self._relative_step is not None:
|
|
if self._relative_step.action_names is None:
|
|
cfg_names = getattr(policy.config, "action_feature_names", None)
|
|
if cfg_names:
|
|
self._relative_step.action_names = list(cfg_names)
|
|
else:
|
|
self._relative_step.action_names = [
|
|
k for k in robot_wrapper.action_features if k.endswith(".pos")
|
|
]
|
|
logger.info("Relative actions enabled: RTC prefix will be re-anchored")
|
|
|
|
# ------------------------------------------------------------------
|
|
# Lifecycle
|
|
# ------------------------------------------------------------------
|
|
|
|
@property
|
|
def ready(self) -> bool:
|
|
"""True once torch.compile warmup is complete (or immediately if compile is disabled)."""
|
|
return self._compile_warmup_done.is_set()
|
|
|
|
@property
|
|
def failed(self) -> bool:
|
|
"""True if the RTC background thread exited due to an unrecoverable error."""
|
|
return self._rtc_error.is_set()
|
|
|
|
@property
|
|
def action_queue(self) -> ActionQueue | None:
|
|
"""The shared action queue between the RTC thread and the main loop."""
|
|
return self._action_queue
|
|
|
|
def start(self) -> None:
|
|
"""Launch the RTC background thread."""
|
|
self._action_queue = ActionQueue(self._rtc_config)
|
|
self._obs_holder = {
|
|
"obs": None,
|
|
"robot_type": self._robot.robot_type,
|
|
}
|
|
self._shutdown_event.clear()
|
|
self._rtc_thread = Thread(
|
|
target=self._rtc_loop,
|
|
daemon=True,
|
|
name="RTCInference",
|
|
)
|
|
self._rtc_thread.start()
|
|
logger.info("RTC inference thread started")
|
|
|
|
def stop(self) -> None:
|
|
"""Signal the RTC thread to stop and wait for it."""
|
|
logger.info("Stopping RTC inference thread...")
|
|
self._shutdown_event.set()
|
|
self._policy_active.clear()
|
|
if self._rtc_thread is not None and self._rtc_thread.is_alive():
|
|
self._rtc_thread.join(timeout=_RTC_JOIN_TIMEOUT_S)
|
|
if self._rtc_thread.is_alive():
|
|
logger.warning("RTC thread did not join within %.1fs", _RTC_JOIN_TIMEOUT_S)
|
|
else:
|
|
logger.info("RTC inference thread stopped")
|
|
self._rtc_thread = None
|
|
|
|
def pause(self) -> None:
|
|
"""Pause the RTC background thread."""
|
|
logger.info("Pausing RTC inference thread")
|
|
self._policy_active.clear()
|
|
|
|
def resume(self) -> None:
|
|
"""Resume the RTC background thread."""
|
|
logger.info("Resuming RTC inference thread")
|
|
self._policy_active.set()
|
|
|
|
def reset(self) -> None:
|
|
"""Reset the policy, processors, and action queue."""
|
|
logger.info("Resetting RTC inference state (policy + processors + queue)")
|
|
self._policy.reset()
|
|
self._preprocessor.reset()
|
|
self._postprocessor.reset()
|
|
if self._action_queue is not None:
|
|
self._action_queue.clear()
|
|
|
|
# ------------------------------------------------------------------
|
|
# Action production (called from main thread)
|
|
# ------------------------------------------------------------------
|
|
|
|
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
|
"""Pop the next action from the RTC queue (ignores ``obs_frame``)."""
|
|
if self._action_queue is None:
|
|
return None
|
|
return self._action_queue.get()
|
|
|
|
def notify_observation(self, obs: dict) -> None:
|
|
"""Publish the latest observation for the RTC thread to consume."""
|
|
with self._obs_lock:
|
|
self._obs_holder["obs"] = obs
|
|
|
|
# ------------------------------------------------------------------
|
|
# RTC: background inference thread
|
|
# ------------------------------------------------------------------
|
|
|
|
def _rtc_loop(self) -> None:
|
|
"""Background thread that generates action chunks via RTC."""
|
|
try:
|
|
latency_tracker = LatencyTracker()
|
|
time_per_chunk = 1.0 / self._fps
|
|
policy_device = torch.device(self._device)
|
|
|
|
warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
|
|
inference_count = 0
|
|
consecutive_errors = 0
|
|
|
|
while not self._shutdown_event.is_set():
|
|
if not self._policy_active.is_set():
|
|
time.sleep(_RTC_IDLE_SLEEP_S)
|
|
continue
|
|
|
|
queue = self._action_queue
|
|
with self._obs_lock:
|
|
obs = self._obs_holder.get("obs")
|
|
if queue is None or obs is None:
|
|
time.sleep(_RTC_IDLE_SLEEP_S)
|
|
continue
|
|
|
|
if queue.qsize() <= self._rtc_queue_threshold:
|
|
try:
|
|
current_time = time.perf_counter()
|
|
idx_before = queue.get_action_index()
|
|
prev_actions = queue.get_left_over()
|
|
|
|
latency = latency_tracker.max()
|
|
delay = math.ceil(latency / time_per_chunk) if latency else 0
|
|
|
|
obs_batch = build_dataset_frame(self._hw_features, obs, prefix="observation")
|
|
obs_batch = prepare_observation_for_inference(
|
|
obs_batch, policy_device, self._task, self._robot.robot_type
|
|
)
|
|
obs_batch["task"] = [self._task]
|
|
|
|
preprocessed = self._preprocessor(obs_batch)
|
|
|
|
if prev_actions is not None and self._relative_step is not None:
|
|
# Rebase against the raw cached state so the leftover tail stays in
|
|
# the training-time coordinate frame.
|
|
raw_state = self._relative_step.get_cached_state()
|
|
if raw_state is not None:
|
|
prev_abs = queue.get_processed_left_over()
|
|
if prev_abs is not None and prev_abs.numel() > 0:
|
|
prev_actions = reanchor_relative_rtc_prefix(
|
|
prev_actions_absolute=prev_abs,
|
|
current_state=raw_state,
|
|
relative_step=self._relative_step,
|
|
normalizer_step=self._normalizer_step,
|
|
policy_device=policy_device,
|
|
)
|
|
|
|
if prev_actions is not None:
|
|
prev_actions = _normalize_prev_actions_length(
|
|
prev_actions, target_steps=self._rtc_config.execution_horizon
|
|
)
|
|
|
|
actions = self._policy.predict_action_chunk(
|
|
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
|
|
)
|
|
|
|
original = actions.squeeze(0).clone()
|
|
processed = self._postprocessor(actions).squeeze(0)
|
|
new_latency = time.perf_counter() - current_time
|
|
new_delay = math.ceil(new_latency / time_per_chunk)
|
|
|
|
inference_count += 1
|
|
consecutive_errors = 0
|
|
is_warmup = self._use_torch_compile and inference_count <= warmup_required
|
|
if is_warmup:
|
|
latency_tracker.reset()
|
|
else:
|
|
latency_tracker.add(new_latency)
|
|
|
|
queue.merge(original, processed, new_delay, idx_before)
|
|
|
|
if (
|
|
is_warmup
|
|
and inference_count >= warmup_required
|
|
and not self._compile_warmup_done.is_set()
|
|
):
|
|
self._compile_warmup_done.set()
|
|
logger.info("Compile warmup complete (%d inferences)", inference_count)
|
|
|
|
logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
|
|
|
|
except Exception as e:
|
|
consecutive_errors += 1
|
|
logger.error(
|
|
"RTC inference error (%d/%d): %s",
|
|
consecutive_errors,
|
|
_RTC_MAX_CONSECUTIVE_ERRORS,
|
|
e,
|
|
)
|
|
logger.debug(traceback.format_exc())
|
|
if consecutive_errors >= _RTC_MAX_CONSECUTIVE_ERRORS:
|
|
# Persistent failure: stop retrying and propagate shutdown.
|
|
raise
|
|
time.sleep(_RTC_ERROR_RETRY_DELAY_S)
|
|
else:
|
|
time.sleep(_RTC_IDLE_SLEEP_S)
|
|
|
|
except Exception as e:
|
|
logger.error("Fatal error in RTC thread: %s", e)
|
|
logger.error(traceback.format_exc())
|
|
self._rtc_error.set()
|
|
# Unblock any warmup waiters so the main loop doesn't spin forever
|
|
self._compile_warmup_done.set()
|
|
# Signal the top-level shutdown so strategies exit their control loops
|
|
if self._global_shutdown_event is not None:
|
|
self._global_shutdown_event.set()
|