From f2c29d78cfbfd99a1e53ce4d857f5adc0672fc1b Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 14 Apr 2026 17:51:03 +0200 Subject: [PATCH] more improvements and fixes --- src/lerobot/policies/rtc/action_queue.py | 20 +- src/lerobot/rollout/configs.py | 11 +- src/lerobot/rollout/inference.py | 18 +- src/lerobot/rollout/ring_buffer.py | 2 +- src/lerobot/rollout/strategies/__init__.py | 212 +------------ src/lerobot/rollout/strategies/core.py | 187 +++++++++++ src/lerobot/rollout/strategies/dagger.py | 342 +++++++++++++-------- src/lerobot/rollout/strategies/factory.py | 54 ++++ src/lerobot/rollout/strategies/sentry.py | 28 +- 9 files changed, 525 insertions(+), 349 deletions(-) create mode 100644 src/lerobot/rollout/strategies/core.py create mode 100644 src/lerobot/rollout/strategies/factory.py diff --git a/src/lerobot/policies/rtc/action_queue.py b/src/lerobot/policies/rtc/action_queue.py index dbbdc41df..199257b12 100644 --- a/src/lerobot/policies/rtc/action_queue.py +++ b/src/lerobot/policies/rtc/action_queue.py @@ -92,10 +92,10 @@ class ActionQueue: Returns: int: Number of unconsumed actions. """ - if self.queue is None: - return 0 - length = len(self.queue) - return length - self.last_index + with self.lock: + if self.queue is None: + return 0 + return len(self.queue) - self.last_index def empty(self) -> bool: """Check if the queue is empty. @@ -103,11 +103,10 @@ class ActionQueue: Returns: bool: True if no actions remain, False otherwise. """ - if self.queue is None: - return True - - length = len(self.queue) - return length - self.last_index <= 0 + with self.lock: + if self.queue is None: + return True + return len(self.queue) - self.last_index <= 0 def get_action_index(self) -> int: """Get the current action consumption index. @@ -115,7 +114,8 @@ class ActionQueue: Returns: int: Index of the next action to be consumed. """ - return self.last_index + with self.lock: + return self.last_index def get_left_over(self) -> Tensor | None: """Get leftover original actions for RTC prev_chunk_left_over. diff --git a/src/lerobot/rollout/configs.py b/src/lerobot/rollout/configs.py index 7e37fa8f2..994c49289 100644 --- a/src/lerobot/rollout/configs.py +++ b/src/lerobot/rollout/configs.py @@ -203,10 +203,13 @@ class RolloutConfig: ) # Sentry MUST use streaming encoding to avoid disk I/O blocking the control loop - if isinstance(self.strategy, SentryStrategyConfig) and self.dataset is not None: - if not self.dataset.streaming_encoding: - logger.warning("Sentry mode forces streaming_encoding=True") - self.dataset.streaming_encoding = True + if ( + isinstance(self.strategy, SentryStrategyConfig) + and self.dataset is not None + and not self.dataset.streaming_encoding + ): + logger.warning("Sentry mode forces streaming_encoding=True") + self.dataset.streaming_encoding = True @classmethod def __get_path_fields__(cls) -> list[str]: diff --git a/src/lerobot/rollout/inference.py b/src/lerobot/rollout/inference.py index c5e94bf3f..01f9ca81f 100644 --- a/src/lerobot/rollout/inference.py +++ b/src/lerobot/rollout/inference.py @@ -147,6 +147,7 @@ class InferenceEngine: use_torch_compile: bool = False, compile_warmup_inferences: int = 2, rtc_queue_threshold: int = 30, + shutdown_event: Event | None = None, ) -> None: self._policy = policy self._preprocessor = preprocessor @@ -170,6 +171,8 @@ class InferenceEngine: self._policy_active = Event() self._compile_warmup_done = Event() self._shutdown_event = Event() + self._rtc_error = Event() + self._global_shutdown_event = shutdown_event self._rtc_thread: Thread | None = None if not self._use_torch_compile: @@ -211,6 +214,11 @@ class InferenceEngine: def compile_warmup_done(self) -> Event: return self._compile_warmup_done + @property + def rtc_failed(self) -> bool: + """True if the RTC background thread exited due to an unrecoverable error.""" + return self._rtc_error.is_set() + def start(self) -> None: """Start the inference engine. Launches the RTC background thread if enabled.""" if self._use_rtc: @@ -249,8 +257,8 @@ class InferenceEngine: self._policy.reset() self._preprocessor.reset() self._postprocessor.reset() - if self._use_rtc: - self._action_queue = ActionQueue(self._rtc_config) + if self._use_rtc and self._action_queue is not None: + self._action_queue.clear() # ------------------------------------------------------------------ # Sync inference @@ -401,3 +409,9 @@ class InferenceEngine: except Exception as e: logger.error("Fatal error in RTC thread: %s", e) logger.error(traceback.format_exc()) + self._rtc_error.set() + # Unblock any warmup waiters so the main loop doesn't spin forever + self._compile_warmup_done.set() + # Signal the top-level shutdown so strategies exit their control loops + if self._global_shutdown_event is not None: + self._global_shutdown_event.set() diff --git a/src/lerobot/rollout/ring_buffer.py b/src/lerobot/rollout/ring_buffer.py index f2aa88d14..9041b5ce9 100644 --- a/src/lerobot/rollout/ring_buffer.py +++ b/src/lerobot/rollout/ring_buffer.py @@ -95,6 +95,6 @@ def _estimate_frame_bytes(frame: dict) -> int: total += v.nbytes elif isinstance(v, (int, float)): total += 8 - elif isinstance(v, str) or isinstance(v, bytes): + elif isinstance(v, (str, bytes)): total += len(v) return max(total, 1) # avoid zero-size frames diff --git a/src/lerobot/rollout/strategies/__init__.py b/src/lerobot/rollout/strategies/__init__.py index 4674d852e..446bc7155 100644 --- a/src/lerobot/rollout/strategies/__init__.py +++ b/src/lerobot/rollout/strategies/__init__.py @@ -12,209 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Rollout strategy ABC, factory, and shared inference helper.""" +"""Rollout strategies — public API re-exports.""" -from __future__ import annotations +from .core import RolloutStrategy, infer_action +from .factory import create_strategy -import abc -import time -from typing import TYPE_CHECKING - -import torch - -from lerobot.policies.rtc import ActionInterpolator -from lerobot.policies.utils import make_robot_action -from lerobot.utils.constants import OBS_STR -from lerobot.utils.feature_utils import build_dataset_frame -from lerobot.utils.robot_utils import precise_sleep - -if TYPE_CHECKING: - from lerobot.rollout.configs import RolloutStrategyConfig - from lerobot.rollout.context import RolloutContext - from lerobot.rollout.inference import InferenceEngine - - -class RolloutStrategy(abc.ABC): - """Abstract base for rollout execution strategies. - - Each concrete strategy implements a self-contained control loop with - its own recording/interaction semantics. Strategies are mutually - exclusive — only one runs per session. - """ - - def __init__(self, config: RolloutStrategyConfig) -> None: - self.config = config - self._engine: InferenceEngine | None = None - self._interpolator: ActionInterpolator | None = None - self._warmup_flushed: bool = False - - def _init_engine(self, ctx: RolloutContext) -> None: - """Create and start the inference engine and action interpolator. - - Call this from ``setup()`` to avoid duplicating the engine - construction across every strategy. - """ - from lerobot.rollout.inference import InferenceEngine - - self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) - self._engine = InferenceEngine( - policy=ctx.policy, - preprocessor=ctx.preprocessor, - postprocessor=ctx.postprocessor, - robot_wrapper=ctx.robot_wrapper, - rtc_config=ctx.cfg.rtc, - hw_features=ctx.hw_features, - action_keys=ctx.action_keys, - task=ctx.cfg.task, - fps=ctx.cfg.fps, - device=ctx.cfg.device, - use_torch_compile=ctx.cfg.use_torch_compile, - compile_warmup_inferences=ctx.cfg.compile_warmup_inferences, - ) - self._engine.start() - self._warmup_flushed = False - - def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool: - """Handle torch.compile warmup phase. - - Returns ``True`` if the caller should ``continue`` (still warming - up). On the first post-warmup iteration the engine and - interpolator are reset so stale warmup state is discarded. - """ - engine = self._engine - interpolator = self._interpolator - if not use_torch_compile: - return False - if not engine.compile_warmup_done.is_set(): - dt = time.perf_counter() - loop_start - if (sleep_t := control_interval - dt) > 0: - precise_sleep(sleep_t) - return True - if not self._warmup_flushed: - engine.reset() - interpolator.reset() - self._warmup_flushed = True - if engine.is_rtc: - engine.resume() - return False - - def _teardown_hardware(self, ctx: RolloutContext) -> None: - """Stop the inference engine and disconnect hardware.""" - if self._engine is not None: - self._engine.stop() - if ctx.robot.is_connected: - ctx.robot.disconnect() - if ctx.teleop is not None and ctx.teleop.is_connected: - ctx.teleop.disconnect() - - @abc.abstractmethod - def setup(self, ctx: RolloutContext) -> None: - """Strategy-specific initialisation (keyboard listeners, buffers, etc.).""" - - @abc.abstractmethod - def run(self, ctx: RolloutContext) -> None: - """Main rollout loop. Returns when shutdown is requested or duration expires.""" - - @abc.abstractmethod - def teardown(self, ctx: RolloutContext) -> None: - """Cleanup: save dataset, stop threads, disconnect hardware.""" - - -# --------------------------------------------------------------------------- -# Shared inference helper -# --------------------------------------------------------------------------- - - -def infer_action( - engine: InferenceEngine, - obs_processed: dict, - obs_raw: dict, - ctx: RolloutContext, - interpolator: ActionInterpolator, - ordered_keys: list[str], - features: dict, -) -> dict | None: - """Run one policy inference step and send the resulting action to the robot. - - Handles both sync and RTC backends. Uses the interpolator for smooth - control at higher-than-inference rates (works with any multiplier, - including 1 where it acts as a pass-through). - - Parameters - ---------- - engine: - The inference engine (sync or RTC). - obs_processed: - Observation dict after ``robot_observation_processor``. - obs_raw: - Raw observation dict (needed by ``robot_action_processor``). - ctx: - Rollout context. - interpolator: - Action interpolator for Nx control rate. - ordered_keys: - Ordered action feature names (policy-to-robot mapping). - features: - Feature specification dict for ``build_dataset_frame`` / - ``make_robot_action``. Use ``dataset.features`` when recording, - ``ctx.dataset_features`` otherwise. - - Returns - ------- - Action dict sent to the robot, or ``None`` if no action was - available (empty RTC queue, interpolator buffer not ready). - """ - if engine.is_rtc: - if interpolator.needs_new_action(): - action_tensor = engine.consume_rtc_action() - if action_tensor is not None: - interpolator.add(action_tensor.cpu()) - else: - if interpolator.needs_new_action(): - obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) - action_tensor = engine.get_action_sync(obs_frame) - action_dict = make_robot_action(action_tensor, features) - action_t = torch.tensor([action_dict[k] for k in ordered_keys]) - interpolator.add(action_t) - - interp = interpolator.get() - if interp is not None: - action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)} - processed = ctx.robot_action_processor((action_dict, obs_raw)) - ctx.robot_wrapper.send_action(processed) - return action_dict - return None - - -# --------------------------------------------------------------------------- -# Strategy factory -# --------------------------------------------------------------------------- - - -def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy: - """Instantiate the appropriate strategy from a config object.""" - from lerobot.rollout.configs import ( - BaseStrategyConfig, - DAggerStrategyConfig, - HighlightStrategyConfig, - SentryStrategyConfig, - ) - - if isinstance(config, BaseStrategyConfig): - from .base import BaseStrategy - - return BaseStrategy(config) - if isinstance(config, SentryStrategyConfig): - from .sentry import SentryStrategy - - return SentryStrategy(config) - if isinstance(config, HighlightStrategyConfig): - from .highlight import HighlightStrategy - - return HighlightStrategy(config) - if isinstance(config, DAggerStrategyConfig): - from .dagger import DAggerStrategy - - return DAggerStrategy(config) - - raise ValueError(f"Unknown strategy config type: {type(config).__name__}") +__all__ = [ + "RolloutStrategy", + "create_strategy", + "infer_action", +] diff --git a/src/lerobot/rollout/strategies/core.py b/src/lerobot/rollout/strategies/core.py new file mode 100644 index 000000000..565194259 --- /dev/null +++ b/src/lerobot/rollout/strategies/core.py @@ -0,0 +1,187 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rollout strategy ABC and shared inference helper.""" + +from __future__ import annotations + +import abc +import time +from typing import TYPE_CHECKING + +import torch + +from lerobot.policies.rtc import ActionInterpolator +from lerobot.policies.utils import make_robot_action +from lerobot.utils.constants import OBS_STR +from lerobot.utils.feature_utils import build_dataset_frame +from lerobot.utils.robot_utils import precise_sleep + +if TYPE_CHECKING: + from lerobot.rollout.configs import RolloutStrategyConfig + from lerobot.rollout.context import RolloutContext + from lerobot.rollout.inference import InferenceEngine + + +class RolloutStrategy(abc.ABC): + """Abstract base for rollout execution strategies. + + Each concrete strategy implements a self-contained control loop with + its own recording/interaction semantics. Strategies are mutually + exclusive — only one runs per session. + """ + + def __init__(self, config: RolloutStrategyConfig) -> None: + self.config = config + self._engine: InferenceEngine | None = None + self._interpolator: ActionInterpolator | None = None + self._warmup_flushed: bool = False + + def _init_engine(self, ctx: RolloutContext) -> None: + """Create and start the inference engine and action interpolator. + + Call this from ``setup()`` to avoid duplicating the engine + construction across every strategy. + """ + from lerobot.rollout.inference import InferenceEngine + + self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) + self._engine = InferenceEngine( + policy=ctx.policy, + preprocessor=ctx.preprocessor, + postprocessor=ctx.postprocessor, + robot_wrapper=ctx.robot_wrapper, + rtc_config=ctx.cfg.rtc, + hw_features=ctx.hw_features, + action_keys=ctx.action_keys, + task=ctx.cfg.task, + fps=ctx.cfg.fps, + device=ctx.cfg.device, + use_torch_compile=ctx.cfg.use_torch_compile, + compile_warmup_inferences=ctx.cfg.compile_warmup_inferences, + shutdown_event=ctx.shutdown_event, + ) + self._engine.start() + self._warmup_flushed = False + + def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool: + """Handle torch.compile warmup phase. + + Returns ``True`` if the caller should ``continue`` (still warming + up). On the first post-warmup iteration the engine and + interpolator are reset so stale warmup state is discarded. + """ + engine = self._engine + interpolator = self._interpolator + if not use_torch_compile: + return False + if not engine.compile_warmup_done.is_set(): + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + return True + if not self._warmup_flushed: + engine.reset() + interpolator.reset() + self._warmup_flushed = True + if engine.is_rtc: + engine.resume() + return False + + def _teardown_hardware(self, ctx: RolloutContext) -> None: + """Stop the inference engine and disconnect hardware.""" + if self._engine is not None: + self._engine.stop() + if ctx.robot.is_connected: + ctx.robot.disconnect() + if ctx.teleop is not None and ctx.teleop.is_connected: + ctx.teleop.disconnect() + + @abc.abstractmethod + def setup(self, ctx: RolloutContext) -> None: + """Strategy-specific initialisation (keyboard listeners, buffers, etc.).""" + + @abc.abstractmethod + def run(self, ctx: RolloutContext) -> None: + """Main rollout loop. Returns when shutdown is requested or duration expires.""" + + @abc.abstractmethod + def teardown(self, ctx: RolloutContext) -> None: + """Cleanup: save dataset, stop threads, disconnect hardware.""" + + +# --------------------------------------------------------------------------- +# Shared inference helper +# --------------------------------------------------------------------------- + + +def infer_action( + engine: InferenceEngine, + obs_processed: dict, + obs_raw: dict, + ctx: RolloutContext, + interpolator: ActionInterpolator, + ordered_keys: list[str], + features: dict, +) -> dict | None: + """Run one policy inference step and send the resulting action to the robot. + + Handles both sync and RTC backends. Uses the interpolator for smooth + control at higher-than-inference rates (works with any multiplier, + including 1 where it acts as a pass-through). + + Parameters + ---------- + engine: + The inference engine (sync or RTC). + obs_processed: + Observation dict after ``robot_observation_processor``. + obs_raw: + Raw observation dict (needed by ``robot_action_processor``). + ctx: + Rollout context. + interpolator: + Action interpolator for Nx control rate. + ordered_keys: + Ordered action feature names (policy-to-robot mapping). + features: + Feature specification dict for ``build_dataset_frame`` / + ``make_robot_action``. Use ``dataset.features`` when recording, + ``ctx.dataset_features`` otherwise. + + Returns + ------- + Action dict sent to the robot, or ``None`` if no action was + available (empty RTC queue, interpolator buffer not ready). + """ + if engine.is_rtc: + if interpolator.needs_new_action(): + action_tensor = engine.consume_rtc_action() + if action_tensor is not None: + interpolator.add(action_tensor.cpu()) + else: + if interpolator.needs_new_action(): + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) + action_tensor = engine.get_action_sync(obs_frame) + action_dict = make_robot_action(action_tensor, features) + action_t = torch.tensor([action_dict[k] for k in ordered_keys]) + interpolator.add(action_t) + + interp = interpolator.get() + if interp is not None: + action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)} + processed = ctx.robot_action_processor((action_dict, obs_raw)) + ctx.robot_wrapper.send_action(processed) + return action_dict + return None diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py index dae4cadad..95fdaf6a2 100644 --- a/src/lerobot/rollout/strategies/dagger.py +++ b/src/lerobot/rollout/strategies/dagger.py @@ -30,8 +30,10 @@ Keyboard Controls: from __future__ import annotations import contextlib +import enum import logging import time +from threading import Lock from typing import Any import numpy as np @@ -53,6 +55,99 @@ from . import RolloutStrategy, infer_action logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# DAgger state machine +# --------------------------------------------------------------------------- + + +class DAggerPhase(enum.Enum): + """Observable phases of a DAgger episode.""" + + AUTONOMOUS = "autonomous" # Policy driving, recording autonomous frames + PAUSED = "paused" # Engine paused, teleop aligned, awaiting takeover/resume + CORRECTING = "correcting" # Human driving via teleop, recording interventions + + +# Valid (current_phase, event) → next_phase +_DAGGER_TRANSITIONS: dict[tuple[DAggerPhase, str], DAggerPhase] = { + (DAggerPhase.AUTONOMOUS, "pause"): DAggerPhase.PAUSED, + (DAggerPhase.PAUSED, "takeover"): DAggerPhase.CORRECTING, + (DAggerPhase.PAUSED, "resume"): DAggerPhase.AUTONOMOUS, + (DAggerPhase.CORRECTING, "resume"): DAggerPhase.AUTONOMOUS, +} + + +class DAggerEvents: + """Thread-safe container for DAgger keyboard/pedal events. + + Replaces the previous plain dict with a lock-protected phase enum + and edge-triggered transition requests. The keyboard/pedal threads + write transition requests; the main loop consumes them. + """ + + def __init__(self) -> None: + self._lock = Lock() + self._phase = DAggerPhase.AUTONOMOUS + self._pending_transition: str | None = None + + # Episode-level flags (written by keyboard, consumed by main loop) + self.exit_early: bool = False + self.rerecord_episode: bool = False + self.stop_recording: bool = False + + # Reset-phase flags (simpler lifecycle, shared between threads) + self.in_reset: bool = False + self.start_next_episode: bool = False + + # -- Thread-safe phase access ------------------------------------------ + + @property + def phase(self) -> DAggerPhase: + with self._lock: + return self._phase + + @phase.setter + def phase(self, value: DAggerPhase) -> None: + with self._lock: + self._phase = value + + def request_transition(self, event: str) -> None: + """Request a phase transition (called from keyboard/pedal threads). + + Only enqueues the request if it corresponds to a valid transition + from the current phase, preventing impossible state changes. + """ + with self._lock: + if (self._phase, event) in _DAGGER_TRANSITIONS: + self._pending_transition = event + + def consume_transition(self) -> tuple[DAggerPhase, DAggerPhase] | None: + """Consume a pending transition (called from main loop). + + Returns ``(old_phase, new_phase)`` if a valid transition was + pending, or ``None`` if there is nothing to process. + """ + with self._lock: + if self._pending_transition is None: + return None + key = (self._phase, self._pending_transition) + self._pending_transition = None + new_phase = _DAGGER_TRANSITIONS.get(key) + if new_phase is None: + return None + old_phase = self._phase + self._phase = new_phase + return old_phase, new_phase + + def reset_for_episode(self) -> None: + """Reset all transient state at the start of an episode.""" + with self._lock: + self._phase = DAggerPhase.AUTONOMOUS + self._pending_transition = None + self.exit_early = False + self.rerecord_episode = False + + # --------------------------------------------------------------------------- # Teleoperator helpers (extracted from examples/hil/hil_utils.py) # --------------------------------------------------------------------------- @@ -99,7 +194,7 @@ def _teleop_smooth_move_to( def _reset_loop( robot: ThreadSafeRobot, teleop: Teleoperator, - events: dict, + events: DAggerEvents, fps: int, teleop_action_processor: RobotProcessorPipeline, robot_action_processor: RobotProcessorPipeline, @@ -111,24 +206,24 @@ def _reset_loop( """ logger.info("RESET — press any key to enable teleoperation") - events["in_reset"] = True - events["start_next_episode"] = False + events.in_reset = True + events.start_next_episode = False obs = robot.get_observation() robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) - while not events["start_next_episode"] and not events["stop_recording"]: + while not events.start_next_episode and not events.stop_recording: precise_sleep(0.05) - if events["stop_recording"]: + if events.stop_recording: return - events["start_next_episode"] = False + events.start_next_episode = False _teleop_disable_torque(teleop) logger.info("Teleop enabled — press any key to start episode") - while not events["start_next_episode"] and not events["stop_recording"]: + while not events.start_next_episode and not events.stop_recording: loop_start = time.perf_counter() obs = robot.get_observation() action = teleop.get_action() @@ -137,78 +232,78 @@ def _reset_loop( robot.send_action(robot_action_to_send) precise_sleep(1 / fps - (time.perf_counter() - loop_start)) - events["in_reset"] = False - events["start_next_episode"] = False - events["exit_early"] = False - events["policy_paused"] = False - events["correction_active"] = False - events["resume_policy"] = False + events.in_reset = False + events.start_next_episode = False + events.reset_for_episode() -def _init_dagger_keyboard(): - """Initialise keyboard listener with DAgger/HIL controls.""" - events = { - "exit_early": False, - "rerecord_episode": False, - "stop_recording": False, - "policy_paused": False, - "correction_active": False, - "resume_policy": False, - "in_reset": False, - "start_next_episode": False, - } +def _init_dagger_keyboard(events: DAggerEvents): + """Initialise keyboard listener with DAgger/HIL controls. + Returns the pynput Listener (or ``None`` in headless mode). + """ if is_headless(): logger.warning("Headless environment — keyboard controls unavailable") - return None, events + return None from pynput import keyboard def on_press(key): try: - if events["in_reset"]: + # During the reset phase, only accept episode-start or stop + if events.in_reset: if ( key in [keyboard.Key.space, keyboard.Key.right] or hasattr(key, "char") and key.char == "c" ): - events["start_next_episode"] = True + events.start_next_episode = True elif key == keyboard.Key.esc: - events["stop_recording"] = True - events["start_next_episode"] = True - else: - if key == keyboard.Key.space: - if not events["policy_paused"] and not events["correction_active"]: - logger.info("PAUSED — press 'c' to take control or 'p' to resume policy") - events["policy_paused"] = True - elif hasattr(key, "char") and key.char == "c": - if events["policy_paused"] and not events["correction_active"]: - logger.info("Taking control...") - events["start_next_episode"] = True - elif hasattr(key, "char") and key.char == "p": - if events["policy_paused"] or events["correction_active"]: - logger.info("Resuming policy...") - events["resume_policy"] = True - elif key == keyboard.Key.right: - logger.info("End episode") - events["exit_early"] = True - elif key == keyboard.Key.left: - logger.info("Re-record episode") - events["rerecord_episode"] = True - events["exit_early"] = True - elif key == keyboard.Key.esc: - logger.info("Stop recording...") - events["stop_recording"] = True - events["exit_early"] = True + events.stop_recording = True + events.start_next_episode = True + return + + # Phase-aware transition requests + phase = events.phase + if key == keyboard.Key.space and phase == DAggerPhase.AUTONOMOUS: + logger.info("PAUSED — press 'c' to take control or 'p' to resume policy") + events.request_transition("pause") + elif hasattr(key, "char") and key.char == "c" and phase == DAggerPhase.PAUSED: + logger.info("Taking control...") + events.request_transition("takeover") + elif ( + hasattr(key, "char") + and key.char == "p" + and phase + in ( + DAggerPhase.PAUSED, + DAggerPhase.CORRECTING, + ) + ): + logger.info("Resuming policy...") + events.request_transition("resume") + + # Episode-level controls (valid in any phase) + elif key == keyboard.Key.right: + logger.info("End episode") + events.exit_early = True + elif key == keyboard.Key.left: + logger.info("Re-record episode") + events.rerecord_episode = True + events.exit_early = True + elif key == keyboard.Key.esc: + logger.info("Stop recording...") + events.stop_recording = True + events.exit_early = True except Exception as e: logger.debug("Key error: %s", e) listener = keyboard.Listener(on_press=on_press) listener.start() - return listener, events + return listener -def _start_pedal_listener(events: dict) -> None: +def _start_pedal_listener(events: DAggerEvents) -> None: """Start foot pedal listener thread if evdev is available.""" import threading @@ -232,18 +327,19 @@ def _start_pedal_listener(events: dict) -> None: code = code[0] if key.keystate != 1: continue - if events["in_reset"]: + if events.in_reset: if code in ["KEY_A", "KEY_C"]: - events["start_next_episode"] = True + events.start_next_episode = True else: if code not in ["KEY_A", "KEY_C"]: continue - if events["correction_active"]: - events["resume_policy"] = True - elif events["policy_paused"]: - events["start_next_episode"] = True - else: - events["policy_paused"] = True + phase = events.phase + if phase == DAggerPhase.CORRECTING: + events.request_transition("resume") + elif phase == DAggerPhase.PAUSED: + events.request_transition("takeover") + elif phase == DAggerPhase.AUTONOMOUS: + events.request_transition("pause") except (FileNotFoundError, PermissionError): pass except Exception as e: @@ -260,9 +356,11 @@ def _start_pedal_listener(events: dict) -> None: class DAggerStrategy(RolloutStrategy): """Human-in-the-Loop data collection with intervention tagging. - State machine: - AUTONOMOUS -> (SPACE) -> PAUSED -> ('c') -> TAKEOVER -> ('p') -> AUTONOMOUS - -> (->) -> save episode + Uses a formal state machine (see :class:`DAggerPhase`) for phase + transitions, eliminating impossible states:: + + AUTONOMOUS --(SPACE)--> PAUSED --(c)--> CORRECTING --(p)--> AUTONOMOUS + --(p)--> AUTONOMOUS Supports both synchronous and RTC inference backends. All actions (policy and teleop) flow through the appropriate @@ -278,12 +376,12 @@ class DAggerStrategy(RolloutStrategy): def __init__(self, config: DAggerStrategyConfig): super().__init__(config) self._listener = None - self._events: dict[str, Any] = {} + self._events = DAggerEvents() def setup(self, ctx: RolloutContext) -> None: self._init_engine(ctx) - self._listener, self._events = _init_dagger_keyboard() + self._listener = _init_dagger_keyboard(self._events) _start_pedal_listener(self._events) logger.info( @@ -302,22 +400,22 @@ class DAggerStrategy(RolloutStrategy): with VideoEncodingManager(dataset): try: recorded = 0 - while recorded < self.config.num_episodes and not events["stop_recording"]: + while recorded < self.config.num_episodes and not events.stop_recording: log_say(f"Episode {dataset.num_episodes}", self.config.play_sounds) self._run_episode(ctx) - if events["rerecord_episode"]: + if events.rerecord_episode: log_say("Re-recording", self.config.play_sounds) - events["rerecord_episode"] = False - events["exit_early"] = False + events.rerecord_episode = False + events.exit_early = False dataset.clear_episode_buffer() continue dataset.save_episode() recorded += 1 - if recorded < self.config.num_episodes and not events["stop_recording"]: + if recorded < self.config.num_episodes and not events.stop_recording: _reset_loop( ctx.robot_wrapper, teleop, @@ -371,10 +469,9 @@ class DAggerStrategy(RolloutStrategy): engine.reset() interpolator.reset() + events.reset_for_episode() _teleop_disable_torque(teleop) - was_paused = False - waiting_for_takeover = False last_action: dict[str, Any] | None = None frame_buffer: list[dict] = [] task_str = cfg.dataset.single_task if cfg.dataset else cfg.task @@ -389,59 +486,26 @@ class DAggerStrategy(RolloutStrategy): while timestamp < self.config.episode_time_s: loop_start = time.perf_counter() - if events["exit_early"]: - events["exit_early"] = False - events["policy_paused"] = False - events["correction_active"] = False - events["resume_policy"] = False + if events.exit_early: + events.exit_early = False break - # --- Resume from pause/correction --- - if events["resume_policy"] and ( - events["policy_paused"] or events["correction_active"] or waiting_for_takeover - ): - events["resume_policy"] = False - events["start_next_episode"] = False - events["policy_paused"] = False - events["correction_active"] = False - waiting_for_takeover = False - was_paused = False + # --- Process pending phase transition --- + transition = events.consume_transition() + if transition is not None: + old_phase, new_phase = transition + self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop) last_action = None - interpolator.reset() - engine.reset() - if engine.is_rtc: - engine.resume() - # --- Pause: align teleop to robot position --- - if events["policy_paused"] and not was_paused: - if engine.is_rtc: - engine.pause() - obs = robot.get_observation() - robot_pos = { - k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features - } - _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) - events["start_next_episode"] = False - waiting_for_takeover = True - was_paused = True - interpolator.reset() - - # --- Takeover: enable teleop control --- - if waiting_for_takeover and events["start_next_episode"]: - _teleop_disable_torque(teleop) - events["start_next_episode"] = False - events["correction_active"] = True - waiting_for_takeover = False - if engine.is_rtc: - engine.reset() + phase = events.phase # --- Get observation --- obs = robot.get_observation() obs_processed = ctx.robot_observation_processor(obs) obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) - # --- CORRECTION: human teleop control --- - if events["correction_active"]: + # --- CORRECTING: human teleop control --- + if phase == DAggerPhase.CORRECTING: teleop_action = teleop.get_action() processed_teleop = ctx.teleop_action_processor((teleop_action, obs)) robot_action_to_send = ctx.robot_action_processor((processed_teleop, obs)) @@ -461,7 +525,7 @@ class DAggerStrategy(RolloutStrategy): record_tick += 1 # --- PAUSED: hold position --- - elif waiting_for_takeover or events["policy_paused"]: + elif phase == DAggerPhase.PAUSED: if last_action: robot.send_action(last_action) @@ -507,3 +571,41 @@ class DAggerStrategy(RolloutStrategy): if not stream_online: for frame in frame_buffer: dataset.add_frame(frame) + + # ------------------------------------------------------------------ + # State-machine transition side-effects + # ------------------------------------------------------------------ + + @staticmethod + def _apply_transition( + old_phase: DAggerPhase, + new_phase: DAggerPhase, + engine, + interpolator, + robot: ThreadSafeRobot, + teleop: Teleoperator, + ) -> None: + """Execute side-effects for a validated phase transition.""" + if old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED: + # Pause engine + align teleop to robot position + if engine.is_rtc: + engine.pause() + obs = robot.get_observation() + robot_pos = { + k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features + } + _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) + interpolator.reset() + + elif new_phase == DAggerPhase.CORRECTING: + # Enable human teleop control + _teleop_disable_torque(teleop) + if engine.is_rtc: + engine.reset() + + elif new_phase == DAggerPhase.AUTONOMOUS: + # Resume policy from pause or correction + interpolator.reset() + engine.reset() + if engine.is_rtc: + engine.resume() diff --git a/src/lerobot/rollout/strategies/factory.py b/src/lerobot/rollout/strategies/factory.py new file mode 100644 index 000000000..9c43ea2af --- /dev/null +++ b/src/lerobot/rollout/strategies/factory.py @@ -0,0 +1,54 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Strategy factory: config type-name → strategy class dispatch.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .core import RolloutStrategy + +if TYPE_CHECKING: + from lerobot.rollout.configs import RolloutStrategyConfig + + +def _lazy_strategy_map() -> dict[str, type[RolloutStrategy]]: + """Build the strategy type-name → class mapping with lazy imports.""" + from .base import BaseStrategy + from .dagger import DAggerStrategy + from .highlight import HighlightStrategy + from .sentry import SentryStrategy + + return { + "base": BaseStrategy, + "sentry": SentryStrategy, + "highlight": HighlightStrategy, + "dagger": DAggerStrategy, + } + + +def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy: + """Instantiate the appropriate strategy from a config object. + + Uses ``config.type`` (the name registered via ``draccus.ChoiceRegistry``) + to look up the strategy class, so adding a new strategy only requires + registering its config subclass and adding one entry to + ``_lazy_strategy_map``. + """ + strategy_map = _lazy_strategy_map() + strategy_cls = strategy_map.get(config.type) + if strategy_cls is None: + raise ValueError(f"Unknown strategy type '{config.type}'. Available: {sorted(strategy_map.keys())}") + return strategy_cls(config) diff --git a/src/lerobot/rollout/strategies/sentry.py b/src/lerobot/rollout/strategies/sentry.py index a0d85747a..024584a9e 100644 --- a/src/lerobot/rollout/strategies/sentry.py +++ b/src/lerobot/rollout/strategies/sentry.py @@ -19,7 +19,7 @@ from __future__ import annotations import contextlib import logging import time -from threading import Event, Thread +from threading import Event, Lock, Thread from lerobot.datasets import VideoEncodingManager from lerobot.utils.constants import ACTION, OBS_STR @@ -46,6 +46,10 @@ class SentryStrategy(RolloutStrategy): All actions flow through ``robot_observation_processor`` (observations) and ``robot_action_processor`` (actions) before reaching the robot, supporting EE-space recording with joint-space robots. + + **Thread safety:** A lock (``_episode_lock``) serialises + ``save_episode`` and ``push_to_hub`` calls so the background push + thread never reads an episode that is still being finalised. """ config: SentryStrategyConfig @@ -54,6 +58,7 @@ class SentryStrategy(RolloutStrategy): super().__init__(config) self._push_thread: Thread | None = None self._needs_push = Event() + self._episode_lock = Lock() def setup(self, ctx: RolloutContext) -> None: self._init_engine(ctx) @@ -113,7 +118,8 @@ class SentryStrategy(RolloutStrategy): # Auto-rotate episodes elapsed = time.perf_counter() - episode_start if elapsed >= self.config.episode_duration_s: - dataset.save_episode() + with self._episode_lock: + dataset.save_episode() episodes_since_push += 1 self._needs_push.set() logger.info("Episode saved (total: %d)", dataset.num_episodes) @@ -134,7 +140,8 @@ class SentryStrategy(RolloutStrategy): finally: with contextlib.suppress(Exception): - dataset.save_episode() + with self._episode_lock: + dataset.save_episode() self._needs_push.set() def teardown(self, ctx: RolloutContext) -> None: @@ -155,17 +162,22 @@ class SentryStrategy(RolloutStrategy): logger.info("Sentry strategy teardown complete") def _background_push(self, dataset, cfg) -> None: - """Push dataset to hub in a background thread (non-blocking).""" + """Push dataset to hub in a background thread (non-blocking). + + Acquires ``_episode_lock`` during the push to prevent + ``save_episode`` from finalising a new episode mid-upload. + """ if self._push_thread is not None and self._push_thread.is_alive(): logger.info("Previous push still in progress, skipping") return def _push(): try: - dataset.push_to_hub( - tags=cfg.dataset.tags if cfg.dataset else None, - private=cfg.dataset.private if cfg.dataset else False, - ) + with self._episode_lock: + dataset.push_to_hub( + tags=cfg.dataset.tags if cfg.dataset else None, + private=cfg.dataset.private if cfg.dataset else False, + ) self._needs_push.clear() logger.info("Background push to hub complete") except Exception as e: