more improvements and fixes

This commit is contained in:
Steven Palma
2026-04-14 17:51:03 +02:00
parent 8bc47e4318
commit f2c29d78cf
9 changed files with 525 additions and 349 deletions

View File

@@ -92,10 +92,10 @@ class ActionQueue:
Returns:
int: Number of unconsumed actions.
"""
if self.queue is None:
return 0
length = len(self.queue)
return length - self.last_index
with self.lock:
if self.queue is None:
return 0
return len(self.queue) - self.last_index
def empty(self) -> bool:
"""Check if the queue is empty.
@@ -103,11 +103,10 @@ class ActionQueue:
Returns:
bool: True if no actions remain, False otherwise.
"""
if self.queue is None:
return True
length = len(self.queue)
return length - self.last_index <= 0
with self.lock:
if self.queue is None:
return True
return len(self.queue) - self.last_index <= 0
def get_action_index(self) -> int:
"""Get the current action consumption index.
@@ -115,7 +114,8 @@ class ActionQueue:
Returns:
int: Index of the next action to be consumed.
"""
return self.last_index
with self.lock:
return self.last_index
def get_left_over(self) -> Tensor | None:
"""Get leftover original actions for RTC prev_chunk_left_over.

View File

@@ -203,10 +203,13 @@ class RolloutConfig:
)
# Sentry MUST use streaming encoding to avoid disk I/O blocking the control loop
if isinstance(self.strategy, SentryStrategyConfig) and self.dataset is not None:
if not self.dataset.streaming_encoding:
logger.warning("Sentry mode forces streaming_encoding=True")
self.dataset.streaming_encoding = True
if (
isinstance(self.strategy, SentryStrategyConfig)
and self.dataset is not None
and not self.dataset.streaming_encoding
):
logger.warning("Sentry mode forces streaming_encoding=True")
self.dataset.streaming_encoding = True
@classmethod
def __get_path_fields__(cls) -> list[str]:

View File

@@ -147,6 +147,7 @@ class InferenceEngine:
use_torch_compile: bool = False,
compile_warmup_inferences: int = 2,
rtc_queue_threshold: int = 30,
shutdown_event: Event | None = None,
) -> None:
self._policy = policy
self._preprocessor = preprocessor
@@ -170,6 +171,8 @@ class InferenceEngine:
self._policy_active = Event()
self._compile_warmup_done = Event()
self._shutdown_event = Event()
self._rtc_error = Event()
self._global_shutdown_event = shutdown_event
self._rtc_thread: Thread | None = None
if not self._use_torch_compile:
@@ -211,6 +214,11 @@ class InferenceEngine:
def compile_warmup_done(self) -> Event:
return self._compile_warmup_done
@property
def rtc_failed(self) -> bool:
"""True if the RTC background thread exited due to an unrecoverable error."""
return self._rtc_error.is_set()
def start(self) -> None:
"""Start the inference engine. Launches the RTC background thread if enabled."""
if self._use_rtc:
@@ -249,8 +257,8 @@ class InferenceEngine:
self._policy.reset()
self._preprocessor.reset()
self._postprocessor.reset()
if self._use_rtc:
self._action_queue = ActionQueue(self._rtc_config)
if self._use_rtc and self._action_queue is not None:
self._action_queue.clear()
# ------------------------------------------------------------------
# Sync inference
@@ -401,3 +409,9 @@ class InferenceEngine:
except Exception as e:
logger.error("Fatal error in RTC thread: %s", e)
logger.error(traceback.format_exc())
self._rtc_error.set()
# Unblock any warmup waiters so the main loop doesn't spin forever
self._compile_warmup_done.set()
# Signal the top-level shutdown so strategies exit their control loops
if self._global_shutdown_event is not None:
self._global_shutdown_event.set()

View File

@@ -95,6 +95,6 @@ def _estimate_frame_bytes(frame: dict) -> int:
total += v.nbytes
elif isinstance(v, (int, float)):
total += 8
elif isinstance(v, str) or isinstance(v, bytes):
elif isinstance(v, (str, bytes)):
total += len(v)
return max(total, 1) # avoid zero-size frames

View File

@@ -12,209 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rollout strategy ABC, factory, and shared inference helper."""
"""Rollout strategies — public API re-exports."""
from __future__ import annotations
from .core import RolloutStrategy, infer_action
from .factory import create_strategy
import abc
import time
from typing import TYPE_CHECKING
import torch
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.utils.constants import OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
if TYPE_CHECKING:
from lerobot.rollout.configs import RolloutStrategyConfig
from lerobot.rollout.context import RolloutContext
from lerobot.rollout.inference import InferenceEngine
class RolloutStrategy(abc.ABC):
    """Abstract base for rollout execution strategies.

    Each concrete strategy implements a self-contained control loop with
    its own recording/interaction semantics.  Strategies are mutually
    exclusive — only one runs per session.

    Subclasses implement :meth:`setup`, :meth:`run` and :meth:`teardown`;
    the underscore-prefixed helpers hold the engine plumbing they share.
    """

    def __init__(self, config: RolloutStrategyConfig) -> None:
        self.config = config
        # Both stay ``None`` until ``_init_engine`` has been called.
        self._engine: InferenceEngine | None = None
        self._interpolator: ActionInterpolator | None = None
        # Flipped once the post-warmup reset in ``_handle_warmup`` has run.
        self._warmup_flushed: bool = False

    def _init_engine(self, ctx: RolloutContext) -> None:
        """Create and start the inference engine and action interpolator.

        Call this from ``setup()`` to avoid duplicating the engine
        construction across every strategy.
        """
        # Imported here rather than at module level (the module-level import
        # is only active under ``TYPE_CHECKING``).
        from lerobot.rollout.inference import InferenceEngine

        self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
        self._engine = InferenceEngine(
            policy=ctx.policy,
            preprocessor=ctx.preprocessor,
            postprocessor=ctx.postprocessor,
            robot_wrapper=ctx.robot_wrapper,
            rtc_config=ctx.cfg.rtc,
            hw_features=ctx.hw_features,
            action_keys=ctx.action_keys,
            task=ctx.cfg.task,
            fps=ctx.cfg.fps,
            device=ctx.cfg.device,
            use_torch_compile=ctx.cfg.use_torch_compile,
            compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
        )
        self._engine.start()
        self._warmup_flushed = False

    def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool:
        """Handle torch.compile warmup phase.

        Returns ``True`` if the caller should ``continue`` (still warming
        up).  On the first post-warmup iteration the engine and
        interpolator are reset so stale warmup state is discarded.
        """
        engine = self._engine
        interpolator = self._interpolator
        if not use_torch_compile:
            return False
        if not engine.compile_warmup_done.is_set():
            # Still warming up: sleep out the rest of this control tick so the
            # caller's loop keeps its cadence.
            dt = time.perf_counter() - loop_start
            if (sleep_t := control_interval - dt) > 0:
                precise_sleep(sleep_t)
            return True
        if not self._warmup_flushed:
            engine.reset()
            interpolator.reset()
            self._warmup_flushed = True
            if engine.is_rtc:
                engine.resume()
        return False

    def _teardown_hardware(self, ctx: RolloutContext) -> None:
        """Stop the inference engine and disconnect hardware.

        Every step is attempted even if an earlier one raises, so a failing
        engine shutdown cannot leave the robot or teleoperator connected.
        The exception still propagates to the caller afterwards.
        """
        try:
            if self._engine is not None:
                self._engine.stop()
        finally:
            try:
                if ctx.robot.is_connected:
                    ctx.robot.disconnect()
            finally:
                if ctx.teleop is not None and ctx.teleop.is_connected:
                    ctx.teleop.disconnect()

    @abc.abstractmethod
    def setup(self, ctx: RolloutContext) -> None:
        """Strategy-specific initialisation (keyboard listeners, buffers, etc.)."""

    @abc.abstractmethod
    def run(self, ctx: RolloutContext) -> None:
        """Main rollout loop.  Returns when shutdown is requested or duration expires."""

    @abc.abstractmethod
    def teardown(self, ctx: RolloutContext) -> None:
        """Cleanup: save dataset, stop threads, disconnect hardware."""
# ---------------------------------------------------------------------------
# Shared inference helper
# ---------------------------------------------------------------------------
def infer_action(
    engine: InferenceEngine,
    obs_processed: dict,
    obs_raw: dict,
    ctx: RolloutContext,
    interpolator: ActionInterpolator,
    ordered_keys: list[str],
    features: dict,
) -> dict | None:
    """Feed the interpolator if it is hungry, then send its next action to the robot.

    Works with both the synchronous and the RTC inference backend.  The
    interpolator sits between the policy and the robot so control can run
    at a multiple of the inference rate.

    Parameters
    ----------
    engine:
        The inference engine (sync or RTC).
    obs_processed:
        Observation dict after ``robot_observation_processor``.
    obs_raw:
        Raw observation dict (forwarded to ``robot_action_processor``).
    ctx:
        Rollout context; supplies ``robot_action_processor`` and
        ``robot_wrapper``.
    interpolator:
        Action interpolator for Nx control rate.
    ordered_keys:
        Ordered action feature names (policy-to-robot mapping).
    features:
        Feature specification dict passed to ``build_dataset_frame`` and
        ``make_robot_action``.

    Returns
    -------
    The action dict built from the interpolator output (the value handed to
    ``robot_action_processor``), or ``None`` when the interpolator had
    nothing to emit.
    """
    rtc_mode = engine.is_rtc
    if interpolator.needs_new_action():
        if rtc_mode:
            # RTC: the background thread fills a queue; it may be empty.
            queued = engine.consume_rtc_action()
            if queued is not None:
                interpolator.add(queued.cpu())
        else:
            # Sync: run the policy right now on the current observation.
            frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
            policy_out = engine.get_action_sync(frame)
            named_action = make_robot_action(policy_out, features)
            interpolator.add(torch.tensor([named_action[key] for key in ordered_keys]))

    step = interpolator.get()
    if step is None:
        return None

    # Only keys that have a matching entry in the interpolated vector are kept.
    available = len(step)
    action_dict = {}
    for idx, key in enumerate(ordered_keys):
        if idx < available:
            action_dict[key] = step[idx].item()

    ctx.robot_wrapper.send_action(ctx.robot_action_processor((action_dict, obs_raw)))
    return action_dict
# ---------------------------------------------------------------------------
# Strategy factory
# ---------------------------------------------------------------------------
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
    """Build the strategy instance that matches the concrete type of *config*.

    The config classes are checked in a fixed order (base, sentry,
    highlight, DAgger) and only the module of the matching strategy is
    imported.

    Raises:
        ValueError: If *config* is not an instance of any known strategy
            config class.
    """
    from lerobot.rollout.configs import (
        BaseStrategyConfig,
        DAggerStrategyConfig,
        HighlightStrategyConfig,
        SentryStrategyConfig,
    )

    # Each branch imports its strategy class on demand under a common alias.
    if isinstance(config, BaseStrategyConfig):
        from .base import BaseStrategy as strategy_cls
    elif isinstance(config, SentryStrategyConfig):
        from .sentry import SentryStrategy as strategy_cls
    elif isinstance(config, HighlightStrategyConfig):
        from .highlight import HighlightStrategy as strategy_cls
    elif isinstance(config, DAggerStrategyConfig):
        from .dagger import DAggerStrategy as strategy_cls
    else:
        raise ValueError(f"Unknown strategy config type: {type(config).__name__}")

    return strategy_cls(config)
__all__ = [
"RolloutStrategy",
"create_strategy",
"infer_action",
]

View File

@@ -0,0 +1,187 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rollout strategy ABC and shared inference helper."""
from __future__ import annotations
import abc
import time
from typing import TYPE_CHECKING
import torch
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.utils.constants import OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
if TYPE_CHECKING:
from lerobot.rollout.configs import RolloutStrategyConfig
from lerobot.rollout.context import RolloutContext
from lerobot.rollout.inference import InferenceEngine
class RolloutStrategy(abc.ABC):
    """Abstract base for rollout execution strategies.

    Each concrete strategy implements a self-contained control loop with
    its own recording/interaction semantics.  Strategies are mutually
    exclusive — only one runs per session.

    Subclasses implement :meth:`setup`, :meth:`run` and :meth:`teardown`;
    the underscore-prefixed helpers hold the engine plumbing they share.
    """

    def __init__(self, config: RolloutStrategyConfig) -> None:
        self.config = config
        # Both stay ``None`` until ``_init_engine`` has been called.
        self._engine: InferenceEngine | None = None
        self._interpolator: ActionInterpolator | None = None
        # Flipped once the post-warmup reset in ``_handle_warmup`` has run.
        self._warmup_flushed: bool = False

    def _init_engine(self, ctx: RolloutContext) -> None:
        """Create and start the inference engine and action interpolator.

        Call this from ``setup()`` to avoid duplicating the engine
        construction across every strategy.  The context's
        ``shutdown_event`` is handed to the engine.
        """
        # Imported here rather than at module level (the module-level import
        # is only active under ``TYPE_CHECKING``).
        from lerobot.rollout.inference import InferenceEngine

        self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
        self._engine = InferenceEngine(
            policy=ctx.policy,
            preprocessor=ctx.preprocessor,
            postprocessor=ctx.postprocessor,
            robot_wrapper=ctx.robot_wrapper,
            rtc_config=ctx.cfg.rtc,
            hw_features=ctx.hw_features,
            action_keys=ctx.action_keys,
            task=ctx.cfg.task,
            fps=ctx.cfg.fps,
            device=ctx.cfg.device,
            use_torch_compile=ctx.cfg.use_torch_compile,
            compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
            shutdown_event=ctx.shutdown_event,
        )
        self._engine.start()
        self._warmup_flushed = False

    def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool:
        """Handle torch.compile warmup phase.

        Returns ``True`` if the caller should ``continue`` (still warming
        up).  On the first post-warmup iteration the engine and
        interpolator are reset so stale warmup state is discarded.
        """
        engine = self._engine
        interpolator = self._interpolator
        if not use_torch_compile:
            return False
        if not engine.compile_warmup_done.is_set():
            # Still warming up: sleep out the rest of this control tick so the
            # caller's loop keeps its cadence.
            dt = time.perf_counter() - loop_start
            if (sleep_t := control_interval - dt) > 0:
                precise_sleep(sleep_t)
            return True
        if not self._warmup_flushed:
            engine.reset()
            interpolator.reset()
            self._warmup_flushed = True
            if engine.is_rtc:
                engine.resume()
        return False

    def _teardown_hardware(self, ctx: RolloutContext) -> None:
        """Stop the inference engine and disconnect hardware.

        Every step is attempted even if an earlier one raises, so a failing
        engine shutdown cannot leave the robot or teleoperator connected.
        The exception still propagates to the caller afterwards.
        """
        try:
            if self._engine is not None:
                self._engine.stop()
        finally:
            try:
                if ctx.robot.is_connected:
                    ctx.robot.disconnect()
            finally:
                if ctx.teleop is not None and ctx.teleop.is_connected:
                    ctx.teleop.disconnect()

    @abc.abstractmethod
    def setup(self, ctx: RolloutContext) -> None:
        """Strategy-specific initialisation (keyboard listeners, buffers, etc.)."""

    @abc.abstractmethod
    def run(self, ctx: RolloutContext) -> None:
        """Main rollout loop.  Returns when shutdown is requested or duration expires."""

    @abc.abstractmethod
    def teardown(self, ctx: RolloutContext) -> None:
        """Cleanup: save dataset, stop threads, disconnect hardware."""
# ---------------------------------------------------------------------------
# Shared inference helper
# ---------------------------------------------------------------------------
def infer_action(
    engine: InferenceEngine,
    obs_processed: dict,
    obs_raw: dict,
    ctx: RolloutContext,
    interpolator: ActionInterpolator,
    ordered_keys: list[str],
    features: dict,
) -> dict | None:
    """Feed the interpolator if it is hungry, then send its next action to the robot.

    Works with both the synchronous and the RTC inference backend.  The
    interpolator sits between the policy and the robot so control can run
    at a multiple of the inference rate.

    Parameters
    ----------
    engine:
        The inference engine (sync or RTC).
    obs_processed:
        Observation dict after ``robot_observation_processor``.
    obs_raw:
        Raw observation dict (forwarded to ``robot_action_processor``).
    ctx:
        Rollout context; supplies ``robot_action_processor`` and
        ``robot_wrapper``.
    interpolator:
        Action interpolator for Nx control rate.
    ordered_keys:
        Ordered action feature names (policy-to-robot mapping).
    features:
        Feature specification dict passed to ``build_dataset_frame`` and
        ``make_robot_action``.

    Returns
    -------
    The action dict built from the interpolator output (the value handed to
    ``robot_action_processor``), or ``None`` when the interpolator had
    nothing to emit.
    """
    rtc_mode = engine.is_rtc
    if interpolator.needs_new_action():
        if rtc_mode:
            # RTC: the background thread fills a queue; it may be empty.
            queued = engine.consume_rtc_action()
            if queued is not None:
                interpolator.add(queued.cpu())
        else:
            # Sync: run the policy right now on the current observation.
            frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
            policy_out = engine.get_action_sync(frame)
            named_action = make_robot_action(policy_out, features)
            interpolator.add(torch.tensor([named_action[key] for key in ordered_keys]))

    step = interpolator.get()
    if step is None:
        return None

    # Only keys that have a matching entry in the interpolated vector are kept.
    available = len(step)
    action_dict = {}
    for idx, key in enumerate(ordered_keys):
        if idx < available:
            action_dict[key] = step[idx].item()

    ctx.robot_wrapper.send_action(ctx.robot_action_processor((action_dict, obs_raw)))
    return action_dict

View File

@@ -30,8 +30,10 @@ Keyboard Controls:
from __future__ import annotations
import contextlib
import enum
import logging
import time
from threading import Lock
from typing import Any
import numpy as np
@@ -53,6 +55,99 @@ from . import RolloutStrategy, infer_action
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# DAgger state machine
# ---------------------------------------------------------------------------
class DAggerPhase(enum.Enum):
"""Observable phases of a DAgger episode."""
AUTONOMOUS = "autonomous" # Policy driving, recording autonomous frames
PAUSED = "paused" # Engine paused, teleop aligned, awaiting takeover/resume
CORRECTING = "correcting" # Human driving via teleop, recording interventions
# Valid (current_phase, event) → next_phase
_DAGGER_TRANSITIONS: dict[tuple[DAggerPhase, str], DAggerPhase] = {
(DAggerPhase.AUTONOMOUS, "pause"): DAggerPhase.PAUSED,
(DAggerPhase.PAUSED, "takeover"): DAggerPhase.CORRECTING,
(DAggerPhase.PAUSED, "resume"): DAggerPhase.AUTONOMOUS,
(DAggerPhase.CORRECTING, "resume"): DAggerPhase.AUTONOMOUS,
}
class DAggerEvents:
"""Thread-safe container for DAgger keyboard/pedal events.
Replaces the previous plain dict with a lock-protected phase enum
and edge-triggered transition requests. The keyboard/pedal threads
write transition requests; the main loop consumes them.
"""
def __init__(self) -> None:
self._lock = Lock()
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition: str | None = None
# Episode-level flags (written by keyboard, consumed by main loop)
self.exit_early: bool = False
self.rerecord_episode: bool = False
self.stop_recording: bool = False
# Reset-phase flags (simpler lifecycle, shared between threads)
self.in_reset: bool = False
self.start_next_episode: bool = False
# -- Thread-safe phase access ------------------------------------------
@property
def phase(self) -> DAggerPhase:
with self._lock:
return self._phase
@phase.setter
def phase(self, value: DAggerPhase) -> None:
with self._lock:
self._phase = value
def request_transition(self, event: str) -> None:
"""Request a phase transition (called from keyboard/pedal threads).
Only enqueues the request if it corresponds to a valid transition
from the current phase, preventing impossible state changes.
"""
with self._lock:
if (self._phase, event) in _DAGGER_TRANSITIONS:
self._pending_transition = event
def consume_transition(self) -> tuple[DAggerPhase, DAggerPhase] | None:
"""Consume a pending transition (called from main loop).
Returns ``(old_phase, new_phase)`` if a valid transition was
pending, or ``None`` if there is nothing to process.
"""
with self._lock:
if self._pending_transition is None:
return None
key = (self._phase, self._pending_transition)
self._pending_transition = None
new_phase = _DAGGER_TRANSITIONS.get(key)
if new_phase is None:
return None
old_phase = self._phase
self._phase = new_phase
return old_phase, new_phase
def reset_for_episode(self) -> None:
"""Reset all transient state at the start of an episode."""
with self._lock:
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition = None
self.exit_early = False
self.rerecord_episode = False
# ---------------------------------------------------------------------------
# Teleoperator helpers (extracted from examples/hil/hil_utils.py)
# ---------------------------------------------------------------------------
@@ -99,7 +194,7 @@ def _teleop_smooth_move_to(
def _reset_loop(
robot: ThreadSafeRobot,
teleop: Teleoperator,
events: dict,
events: DAggerEvents,
fps: int,
teleop_action_processor: RobotProcessorPipeline,
robot_action_processor: RobotProcessorPipeline,
@@ -111,24 +206,24 @@ def _reset_loop(
"""
logger.info("RESET — press any key to enable teleoperation")
events["in_reset"] = True
events["start_next_episode"] = False
events.in_reset = True
events.start_next_episode = False
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
while not events["start_next_episode"] and not events["stop_recording"]:
while not events.start_next_episode and not events.stop_recording:
precise_sleep(0.05)
if events["stop_recording"]:
if events.stop_recording:
return
events["start_next_episode"] = False
events.start_next_episode = False
_teleop_disable_torque(teleop)
logger.info("Teleop enabled — press any key to start episode")
while not events["start_next_episode"] and not events["stop_recording"]:
while not events.start_next_episode and not events.stop_recording:
loop_start = time.perf_counter()
obs = robot.get_observation()
action = teleop.get_action()
@@ -137,78 +232,78 @@ def _reset_loop(
robot.send_action(robot_action_to_send)
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events["in_reset"] = False
events["start_next_episode"] = False
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
events["resume_policy"] = False
events.in_reset = False
events.start_next_episode = False
events.reset_for_episode()
def _init_dagger_keyboard():
"""Initialise keyboard listener with DAgger/HIL controls."""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"policy_paused": False,
"correction_active": False,
"resume_policy": False,
"in_reset": False,
"start_next_episode": False,
}
def _init_dagger_keyboard(events: DAggerEvents):
"""Initialise keyboard listener with DAgger/HIL controls.
Returns the pynput Listener (or ``None`` in headless mode).
"""
if is_headless():
logger.warning("Headless environment — keyboard controls unavailable")
return None, events
return None
from pynput import keyboard
def on_press(key):
try:
if events["in_reset"]:
# During the reset phase, only accept episode-start or stop
if events.in_reset:
if (
key in [keyboard.Key.space, keyboard.Key.right]
or hasattr(key, "char")
and key.char == "c"
):
events["start_next_episode"] = True
events.start_next_episode = True
elif key == keyboard.Key.esc:
events["stop_recording"] = True
events["start_next_episode"] = True
else:
if key == keyboard.Key.space:
if not events["policy_paused"] and not events["correction_active"]:
logger.info("PAUSED — press 'c' to take control or 'p' to resume policy")
events["policy_paused"] = True
elif hasattr(key, "char") and key.char == "c":
if events["policy_paused"] and not events["correction_active"]:
logger.info("Taking control...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "p":
if events["policy_paused"] or events["correction_active"]:
logger.info("Resuming policy...")
events["resume_policy"] = True
elif key == keyboard.Key.right:
logger.info("End episode")
events["exit_early"] = True
elif key == keyboard.Key.left:
logger.info("Re-record episode")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
logger.info("Stop recording...")
events["stop_recording"] = True
events["exit_early"] = True
events.stop_recording = True
events.start_next_episode = True
return
# Phase-aware transition requests
phase = events.phase
if key == keyboard.Key.space and phase == DAggerPhase.AUTONOMOUS:
logger.info("PAUSED — press 'c' to take control or 'p' to resume policy")
events.request_transition("pause")
elif hasattr(key, "char") and key.char == "c" and phase == DAggerPhase.PAUSED:
logger.info("Taking control...")
events.request_transition("takeover")
elif (
hasattr(key, "char")
and key.char == "p"
and phase
in (
DAggerPhase.PAUSED,
DAggerPhase.CORRECTING,
)
):
logger.info("Resuming policy...")
events.request_transition("resume")
# Episode-level controls (valid in any phase)
elif key == keyboard.Key.right:
logger.info("End episode")
events.exit_early = True
elif key == keyboard.Key.left:
logger.info("Re-record episode")
events.rerecord_episode = True
events.exit_early = True
elif key == keyboard.Key.esc:
logger.info("Stop recording...")
events.stop_recording = True
events.exit_early = True
except Exception as e:
logger.debug("Key error: %s", e)
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
return listener
def _start_pedal_listener(events: dict) -> None:
def _start_pedal_listener(events: DAggerEvents) -> None:
"""Start foot pedal listener thread if evdev is available."""
import threading
@@ -232,18 +327,19 @@ def _start_pedal_listener(events: dict) -> None:
code = code[0]
if key.keystate != 1:
continue
if events["in_reset"]:
if events.in_reset:
if code in ["KEY_A", "KEY_C"]:
events["start_next_episode"] = True
events.start_next_episode = True
else:
if code not in ["KEY_A", "KEY_C"]:
continue
if events["correction_active"]:
events["resume_policy"] = True
elif events["policy_paused"]:
events["start_next_episode"] = True
else:
events["policy_paused"] = True
phase = events.phase
if phase == DAggerPhase.CORRECTING:
events.request_transition("resume")
elif phase == DAggerPhase.PAUSED:
events.request_transition("takeover")
elif phase == DAggerPhase.AUTONOMOUS:
events.request_transition("pause")
except (FileNotFoundError, PermissionError):
pass
except Exception as e:
@@ -260,9 +356,11 @@ def _start_pedal_listener(events: dict) -> None:
class DAggerStrategy(RolloutStrategy):
"""Human-in-the-Loop data collection with intervention tagging.
State machine:
AUTONOMOUS -> (SPACE) -> PAUSED -> ('c') -> TAKEOVER -> ('p') -> AUTONOMOUS
-> (->) -> save episode
Uses a formal state machine (see :class:`DAggerPhase`) for phase
transitions, eliminating impossible states::
AUTONOMOUS --(SPACE)--> PAUSED --(c)--> CORRECTING --(p)--> AUTONOMOUS
--(p)--> AUTONOMOUS
Supports both synchronous and RTC inference backends.
All actions (policy and teleop) flow through the appropriate
@@ -278,12 +376,12 @@ class DAggerStrategy(RolloutStrategy):
def __init__(self, config: DAggerStrategyConfig):
super().__init__(config)
self._listener = None
self._events: dict[str, Any] = {}
self._events = DAggerEvents()
def setup(self, ctx: RolloutContext) -> None:
self._init_engine(ctx)
self._listener, self._events = _init_dagger_keyboard()
self._listener = _init_dagger_keyboard(self._events)
_start_pedal_listener(self._events)
logger.info(
@@ -302,22 +400,22 @@ class DAggerStrategy(RolloutStrategy):
with VideoEncodingManager(dataset):
try:
recorded = 0
while recorded < self.config.num_episodes and not events["stop_recording"]:
while recorded < self.config.num_episodes and not events.stop_recording:
log_say(f"Episode {dataset.num_episodes}", self.config.play_sounds)
self._run_episode(ctx)
if events["rerecord_episode"]:
if events.rerecord_episode:
log_say("Re-recording", self.config.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
events.rerecord_episode = False
events.exit_early = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded += 1
if recorded < self.config.num_episodes and not events["stop_recording"]:
if recorded < self.config.num_episodes and not events.stop_recording:
_reset_loop(
ctx.robot_wrapper,
teleop,
@@ -371,10 +469,9 @@ class DAggerStrategy(RolloutStrategy):
engine.reset()
interpolator.reset()
events.reset_for_episode()
_teleop_disable_torque(teleop)
was_paused = False
waiting_for_takeover = False
last_action: dict[str, Any] | None = None
frame_buffer: list[dict] = []
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
@@ -389,59 +486,26 @@ class DAggerStrategy(RolloutStrategy):
while timestamp < self.config.episode_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
events["resume_policy"] = False
if events.exit_early:
events.exit_early = False
break
# --- Resume from pause/correction ---
if events["resume_policy"] and (
events["policy_paused"] or events["correction_active"] or waiting_for_takeover
):
events["resume_policy"] = False
events["start_next_episode"] = False
events["policy_paused"] = False
events["correction_active"] = False
waiting_for_takeover = False
was_paused = False
# --- Process pending phase transition ---
transition = events.consume_transition()
if transition is not None:
old_phase, new_phase = transition
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
last_action = None
interpolator.reset()
engine.reset()
if engine.is_rtc:
engine.resume()
# --- Pause: align teleop to robot position ---
if events["policy_paused"] and not was_paused:
if engine.is_rtc:
engine.pause()
obs = robot.get_observation()
robot_pos = {
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
}
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
events["start_next_episode"] = False
waiting_for_takeover = True
was_paused = True
interpolator.reset()
# --- Takeover: enable teleop control ---
if waiting_for_takeover and events["start_next_episode"]:
_teleop_disable_torque(teleop)
events["start_next_episode"] = False
events["correction_active"] = True
waiting_for_takeover = False
if engine.is_rtc:
engine.reset()
phase = events.phase
# --- Get observation ---
obs = robot.get_observation()
obs_processed = ctx.robot_observation_processor(obs)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
# --- CORRECTION: human teleop control ---
if events["correction_active"]:
# --- CORRECTING: human teleop control ---
if phase == DAggerPhase.CORRECTING:
teleop_action = teleop.get_action()
processed_teleop = ctx.teleop_action_processor((teleop_action, obs))
robot_action_to_send = ctx.robot_action_processor((processed_teleop, obs))
@@ -461,7 +525,7 @@ class DAggerStrategy(RolloutStrategy):
record_tick += 1
# --- PAUSED: hold position ---
elif waiting_for_takeover or events["policy_paused"]:
elif phase == DAggerPhase.PAUSED:
if last_action:
robot.send_action(last_action)
@@ -507,3 +571,41 @@ class DAggerStrategy(RolloutStrategy):
if not stream_online:
for frame in frame_buffer:
dataset.add_frame(frame)
# ------------------------------------------------------------------
# State-machine transition side-effects
# ------------------------------------------------------------------
@staticmethod
def _apply_transition(
    old_phase: DAggerPhase,
    new_phase: DAggerPhase,
    engine,
    interpolator,
    robot: ThreadSafeRobot,
    teleop: Teleoperator,
) -> None:
    """Run the hardware/engine side-effects that go with a phase change.

    Three cases are handled; any other combination is a no-op:

    * ``AUTONOMOUS`` → ``PAUSED``: pause the engine (RTC only), move the
      teleoperator to the robot's current ``.pos`` values, reset the
      interpolator.
    * anything → ``CORRECTING``: disable teleoperator torque and reset the
      engine (RTC only).
    * anything → ``AUTONOMOUS``: reset interpolator and engine, then resume
      the engine (RTC only).
    """
    entering_pause = old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED
    if entering_pause:
        if engine.is_rtc:
            engine.pause()
        observation = robot.get_observation()
        # Keep only position readings that the robot declares as features.
        joint_targets = {}
        for name, value in observation.items():
            if name.endswith(".pos") and name in robot.observation_features:
                joint_targets[name] = value
        _teleop_smooth_move_to(teleop, joint_targets, duration_s=2.0, fps=50)
        interpolator.reset()
        return

    if new_phase == DAggerPhase.CORRECTING:
        # Hand control to the human operator.
        _teleop_disable_torque(teleop)
        if engine.is_rtc:
            engine.reset()
        return

    if new_phase == DAggerPhase.AUTONOMOUS:
        # Back to the policy, whether coming from pause or from correction.
        interpolator.reset()
        engine.reset()
        if engine.is_rtc:
            engine.resume()

View File

@@ -0,0 +1,54 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Strategy factory: config type-name → strategy class dispatch."""
from __future__ import annotations
from typing import TYPE_CHECKING
from .core import RolloutStrategy
if TYPE_CHECKING:
from lerobot.rollout.configs import RolloutStrategyConfig
def _lazy_strategy_map() -> dict[str, type[RolloutStrategy]]:
    """Return the registry mapping each strategy type-name to its class.

    The strategy modules are imported inside the function body, so they are
    only loaded when the registry is actually requested.
    """
    from .base import BaseStrategy
    from .dagger import DAggerStrategy
    from .highlight import HighlightStrategy
    from .sentry import SentryStrategy

    # (type-name, class) pairs, listed in the order they appear in the mapping.
    entries = (
        ("base", BaseStrategy),
        ("sentry", SentryStrategy),
        ("highlight", HighlightStrategy),
        ("dagger", DAggerStrategy),
    )
    return dict(entries)
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
    """Build the strategy instance that matches *config*.

    The class is selected by ``config.type`` (the name registered via
    ``draccus.ChoiceRegistry``). Supporting a new strategy therefore takes
    two steps: register its config subclass and add one entry to
    ``_lazy_strategy_map``.

    Args:
        config: Strategy configuration; it is also passed to the selected
            class's constructor.

    Returns:
        A new instance of the strategy class registered under ``config.type``.

    Raises:
        ValueError: If ``config.type`` is not a key of the strategy map.
    """
    registry = _lazy_strategy_map()
    selected = registry.get(config.type)
    if selected is not None:
        return selected(config)
    # Unknown name: list the valid ones so the caller can fix the config.
    raise ValueError(f"Unknown strategy type '{config.type}'. Available: {sorted(registry.keys())}")

View File

@@ -19,7 +19,7 @@ from __future__ import annotations
import contextlib
import logging
import time
from threading import Event, Thread
from threading import Event, Lock, Thread
from lerobot.datasets import VideoEncodingManager
from lerobot.utils.constants import ACTION, OBS_STR
@@ -46,6 +46,10 @@ class SentryStrategy(RolloutStrategy):
All actions flow through ``robot_observation_processor`` (observations)
and ``robot_action_processor`` (actions) before reaching the robot,
supporting EE-space recording with joint-space robots.
**Thread safety:** A lock (``_episode_lock``) serialises
``save_episode`` and ``push_to_hub`` calls so the background push
thread never reads an episode that is still being finalised.
"""
config: SentryStrategyConfig
@@ -54,6 +58,7 @@ class SentryStrategy(RolloutStrategy):
super().__init__(config)
self._push_thread: Thread | None = None
self._needs_push = Event()
self._episode_lock = Lock()
def setup(self, ctx: RolloutContext) -> None:
self._init_engine(ctx)
@@ -113,7 +118,8 @@ class SentryStrategy(RolloutStrategy):
# Auto-rotate episodes
elapsed = time.perf_counter() - episode_start
if elapsed >= self.config.episode_duration_s:
dataset.save_episode()
with self._episode_lock:
dataset.save_episode()
episodes_since_push += 1
self._needs_push.set()
logger.info("Episode saved (total: %d)", dataset.num_episodes)
@@ -134,7 +140,8 @@ class SentryStrategy(RolloutStrategy):
finally:
with contextlib.suppress(Exception):
dataset.save_episode()
with self._episode_lock:
dataset.save_episode()
self._needs_push.set()
def teardown(self, ctx: RolloutContext) -> None:
@@ -155,17 +162,22 @@ class SentryStrategy(RolloutStrategy):
logger.info("Sentry strategy teardown complete")
def _background_push(self, dataset, cfg) -> None:
"""Push dataset to hub in a background thread (non-blocking)."""
"""Push dataset to hub in a background thread (non-blocking).
Acquires ``_episode_lock`` during the push to prevent
``save_episode`` from finalising a new episode mid-upload.
"""
if self._push_thread is not None and self._push_thread.is_alive():
logger.info("Previous push still in progress, skipping")
return
def _push():
try:
dataset.push_to_hub(
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
)
with self._episode_lock:
dataset.push_to_hub(
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
)
self._needs_push.clear()
logger.info("Background push to hub complete")
except Exception as e: