mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-05 05:11:25 +00:00
simplify dagger
This commit is contained in:
@@ -16,6 +16,8 @@
|
||||
|
||||
from .configs import (
|
||||
BaseStrategyConfig,
|
||||
DAggerKeyboardConfig,
|
||||
DAggerPedalConfig,
|
||||
DAggerStrategyConfig,
|
||||
DatasetRecordConfig,
|
||||
HighlightStrategyConfig,
|
||||
@@ -39,6 +41,8 @@ from .strategies import RolloutStrategy, create_strategy
|
||||
|
||||
__all__ = [
|
||||
"BaseStrategyConfig",
|
||||
"DAggerKeyboardConfig",
|
||||
"DAggerPedalConfig",
|
||||
"DAggerStrategyConfig",
|
||||
"HighlightStrategyConfig",
|
||||
"InferenceEngine",
|
||||
|
||||
@@ -66,7 +66,7 @@ class SentryStrategyConfig(RolloutStrategyConfig):
|
||||
uploaded in the background every ``upload_every_n_episodes`` episodes.
|
||||
"""
|
||||
|
||||
episode_duration_s: float = 120.0
|
||||
episode_duration_s: float = 20.0
|
||||
upload_every_n_episodes: int = 5
|
||||
|
||||
|
||||
@@ -87,6 +87,32 @@ class HighlightStrategyConfig(RolloutStrategyConfig):
|
||||
push_key: str = "h"
|
||||
|
||||
|
||||
@dataclass
class DAggerKeyboardConfig:
    """Keyboard key bindings for DAgger controls.

    Keys are specified as single characters (e.g. ``"c"``, ``"h"``) or
    special key names (``"space"``).

    Raises:
        ValueError: If two actions are bound to the same key. The keyboard
            listener looks actions up by key, so a shared key would either
            silently drop one binding or trigger two actions on one press.
    """

    # Toggles policy execution on/off.
    pause_resume: str = "space"
    # Toggles human correction recording.
    correction: str = "c"
    # Requests an on-demand push of the dataset to the hub.
    upload: str = "h"

    def __post_init__(self) -> None:
        bindings = {
            "pause_resume": self.pause_resume,
            "correction": self.correction,
            "upload": self.upload,
        }
        # Fail at config time instead of misbehaving inside the listener thread.
        if len(set(bindings.values())) != len(bindings):
            raise ValueError(f"DAgger keyboard bindings must be distinct, got {bindings}")
|
||||
|
||||
|
||||
@dataclass
class DAggerPedalConfig:
    """Foot pedal configuration for DAgger controls.

    Pedal codes are evdev key code strings (e.g. ``"KEY_A"``).

    Raises:
        ValueError: If two actions are bound to the same pedal code. The
            pedal handler looks actions up by code, so a shared code would
            either silently drop one binding or trigger two actions at once.
    """

    # Input device node the pedal listener reads from.
    device_path: str = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
    # Toggles policy execution on/off.
    pause_resume: str = "KEY_A"
    # Toggles human correction recording.
    correction: str = "KEY_B"
    # Requests an on-demand push of the dataset to the hub.
    upload: str = "KEY_C"

    def __post_init__(self) -> None:
        bindings = {
            "pause_resume": self.pause_resume,
            "correction": self.correction,
            "upload": self.upload,
        }
        # Fail at config time instead of misbehaving inside the listener thread.
        if len(set(bindings.values())) != len(bindings):
            raise ValueError(f"DAgger pedal bindings must be distinct, got {bindings}")
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("dagger")
|
||||
@dataclass
|
||||
class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
@@ -95,19 +121,30 @@ class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
Alternates between autonomous policy execution and human intervention.
|
||||
Intervention frames are tagged with ``intervention=True``.
|
||||
|
||||
Input is controlled via either a keyboard or foot pedal, selected by
|
||||
``input_device``. Each device exposes three actions:
|
||||
|
||||
1. **pause_resume** — toggle policy execution on/off.
|
||||
2. **correction** — toggle human correction recording.
|
||||
3. **upload** — push dataset to hub on demand (corrections-only mode).
|
||||
|
||||
When ``record_autonomous=True`` (default) both autonomous and correction
|
||||
frames are recorded — this requires streaming encoding so the policy
|
||||
loop never blocks on disk I/O. Set to ``False`` to record only the
|
||||
human-correction windows; encoding can then happen between phases.
|
||||
frames are recorded with sentry-like time-based episode rotation and
|
||||
background uploading. Set to ``False`` to record only the human-correction
|
||||
windows, where each correction becomes its own episode.
|
||||
"""
|
||||
|
||||
episode_time_s: float = 120.0
|
||||
num_episodes: int = 50
|
||||
play_sounds: bool = True
|
||||
calibrate: bool = False
|
||||
log_hz: bool = True
|
||||
hz_log_interval_s: float = 2.0
|
||||
record_autonomous: bool = True
|
||||
episode_time_s: float = 20.0
|
||||
num_episodes: int = 10
|
||||
record_autonomous: bool = False
|
||||
upload_every_n_episodes: int = 5
|
||||
input_device: str = "keyboard"
|
||||
keyboard: DAggerKeyboardConfig = field(default_factory=DAggerKeyboardConfig)
|
||||
pedal: DAggerPedalConfig = field(default_factory=DAggerPedalConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.input_device not in ("keyboard", "pedal"):
|
||||
raise ValueError(f"DAgger input_device must be 'keyboard' or 'pedal', got '{self.input_device}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -160,9 +197,7 @@ class RolloutConfig:
|
||||
if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None:
|
||||
raise ValueError("DAgger strategy requires --teleop.type to be set")
|
||||
|
||||
needs_dataset = isinstance(
|
||||
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
|
||||
)
|
||||
needs_dataset = isinstance(self.strategy, (SentryStrategyConfig, HighlightStrategyConfig))
|
||||
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
||||
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
||||
|
||||
|
||||
@@ -234,6 +234,18 @@ def build_rollout_context(
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
teleop.connect()
|
||||
|
||||
# DAgger requires teleop with motor control capabilities (enable_torque,
|
||||
# disable_torque, write_goal_positions).
|
||||
if isinstance(cfg.strategy, DAggerStrategyConfig) and teleop is not None:
|
||||
required_teleop_methods = ("enable_torque", "disable_torque", "write_goal_positions")
|
||||
missing = [m for m in required_teleop_methods if not callable(getattr(teleop, m, None))]
|
||||
if missing:
|
||||
teleop.disconnect()
|
||||
raise ValueError(
|
||||
f"DAgger strategy requires a teleoperator with motor control methods "
|
||||
f"{required_teleop_methods}. '{type(teleop).__name__}' is missing: {missing}"
|
||||
)
|
||||
|
||||
# --- 4. Features + action-key reconciliation ---------------------
|
||||
all_obs_features = robot.observation_features
|
||||
observation_features_hw = {
|
||||
|
||||
@@ -18,13 +18,20 @@ Implements the RaC paradigm (Recovery and Correction) for interactive
|
||||
imitation learning. Alternates between autonomous policy execution and
|
||||
human intervention via teleoperator.
|
||||
|
||||
Keyboard Controls:
|
||||
SPACE - Pause policy (robot holds position, no recording)
|
||||
c - Take control (start correction, recording resumes)
|
||||
p - Resume policy after pause/correction
|
||||
-> - End episode (save and continue)
|
||||
<- - Re-record episode
|
||||
ESC - Stop recording and push to hub
|
||||
Input is controlled via either a keyboard or foot pedal, selected by
|
||||
the ``input_device`` config field. Each device exposes three actions:
|
||||
|
||||
1. **pause_resume** — Toggle policy execution (AUTONOMOUS <-> PAUSED).
|
||||
2. **correction** — Toggle correction recording (PAUSED <-> CORRECTING).
|
||||
3. **upload** — Push dataset to hub on demand (corrections-only mode).
|
||||
ESC (keyboard only) — Stop session.
|
||||
|
||||
Recording Modes:
|
||||
``record_autonomous=True``: Sentry-like continuous recording with
|
||||
time-based episode rotation. Both autonomous and correction
|
||||
frames are recorded; corrections tagged ``intervention=True``.
|
||||
``record_autonomous=False``: Only correction windows are recorded.
|
||||
Each correction (start to stop) becomes one episode.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -32,7 +39,10 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import enum
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from threading import Event, Lock
|
||||
from typing import Any
|
||||
|
||||
@@ -40,19 +50,31 @@ import numpy as np
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available
|
||||
from lerobot.utils.pedal import start_pedal_listener
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
from ..configs import DAggerStrategyConfig
|
||||
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .core import RolloutStrategy, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
keyboard = None
|
||||
if PYNPUT_AVAILABLE:
|
||||
try:
|
||||
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
|
||||
logging.info("No DISPLAY set. Skipping pynput import.")
|
||||
PYNPUT_AVAILABLE = False
|
||||
else:
|
||||
from pynput import keyboard
|
||||
except Exception as e:
|
||||
PYNPUT_AVAILABLE = False
|
||||
logging.info(f"Could not import pynput: {e}")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -64,22 +86,22 @@ logger = logging.getLogger(__name__)
|
||||
class DAggerPhase(enum.Enum):
    """Observable phases of a DAgger episode.

    Each member is declared exactly once: ``enum`` rejects a class body that
    assigns the same member name twice.
    """

    AUTONOMOUS = "autonomous"  # Policy driving
    PAUSED = "paused"  # Engine paused, teleop aligned, awaiting input
    CORRECTING = "correcting"  # Human driving via teleop, recording interventions
|
||||
|
||||
|
||||
# Valid (current_phase, event) → next_phase
|
||||
# Valid (current_phase, event) -> next_phase
|
||||
_DAGGER_TRANSITIONS: dict[tuple[DAggerPhase, str], DAggerPhase] = {
|
||||
(DAggerPhase.AUTONOMOUS, "pause"): DAggerPhase.PAUSED,
|
||||
(DAggerPhase.PAUSED, "takeover"): DAggerPhase.CORRECTING,
|
||||
(DAggerPhase.PAUSED, "resume"): DAggerPhase.AUTONOMOUS,
|
||||
(DAggerPhase.CORRECTING, "resume"): DAggerPhase.AUTONOMOUS,
|
||||
(DAggerPhase.AUTONOMOUS, "pause_resume"): DAggerPhase.PAUSED,
|
||||
(DAggerPhase.PAUSED, "pause_resume"): DAggerPhase.AUTONOMOUS,
|
||||
(DAggerPhase.PAUSED, "correction"): DAggerPhase.CORRECTING,
|
||||
(DAggerPhase.CORRECTING, "correction"): DAggerPhase.PAUSED,
|
||||
}
|
||||
|
||||
|
||||
class DAggerEvents:
|
||||
"""Thread-safe container for DAgger keyboard/pedal events.
|
||||
"""Thread-safe container for DAgger input device events.
|
||||
|
||||
The keyboard/pedal threads write transition requests; the main loop
|
||||
consumes them.
|
||||
@@ -90,16 +112,9 @@ class DAggerEvents:
|
||||
self._phase = DAggerPhase.AUTONOMOUS
|
||||
self._pending_transition: str | None = None
|
||||
|
||||
# Episode-level flags written by keyboard/pedal threads, consumed by
|
||||
# the main loop. ``threading.Event`` gives us atomic set/clear/check
|
||||
# semantics without taking ``self._lock``.
|
||||
self.exit_early = Event()
|
||||
self.rerecord_episode = Event()
|
||||
# Session-level flags
|
||||
self.stop_recording = Event()
|
||||
|
||||
# Reset-phase flags (simpler lifecycle, shared between threads).
|
||||
self.in_reset = Event()
|
||||
self.start_next_episode = Event()
|
||||
self.upload_requested = Event()
|
||||
|
||||
# -- Thread-safe phase access ------------------------------------------
|
||||
|
||||
@@ -138,13 +153,12 @@ class DAggerEvents:
|
||||
self._phase = new_phase
|
||||
return old_phase, new_phase
|
||||
|
||||
def reset_for_episode(self) -> None:
|
||||
"""Reset all transient state at the start of an episode."""
|
||||
def reset(self) -> None:
|
||||
"""Reset all transient state for a fresh session."""
|
||||
with self._lock:
|
||||
self._phase = DAggerPhase.AUTONOMOUS
|
||||
self._pending_transition = None
|
||||
self.exit_early.clear()
|
||||
self.rerecord_episode.clear()
|
||||
self.upload_requested.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -152,29 +166,15 @@ class DAggerEvents:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _teleop_has_motor_control(teleop: Teleoperator) -> bool:
|
||||
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
|
||||
|
||||
|
||||
def _teleop_disable_torque(teleop: Teleoperator) -> None:
|
||||
if hasattr(teleop, "disable_torque"):
|
||||
teleop.disable_torque()
|
||||
|
||||
|
||||
def _teleop_enable_torque(teleop: Teleoperator) -> None:
|
||||
if hasattr(teleop, "enable_torque"):
|
||||
teleop.enable_torque()
|
||||
|
||||
|
||||
def _teleop_smooth_move_to(
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50
|
||||
) -> None:
|
||||
"""Smoothly move teleop to target position if motor control is available."""
|
||||
if not _teleop_has_motor_control(teleop):
|
||||
logger.warning("Teleop does not support motor control — cannot mirror robot position")
|
||||
return
|
||||
"""Smoothly move teleop to target position via linear interpolation.
|
||||
|
||||
_teleop_enable_torque(teleop)
|
||||
The teleoperator is guaranteed to have motor control methods
|
||||
(validated at context build time).
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
@@ -190,103 +190,58 @@ def _teleop_smooth_move_to(
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def _reset_loop(
|
||||
robot: ThreadSafeRobot,
|
||||
teleop: Teleoperator,
|
||||
events: DAggerEvents,
|
||||
fps: int,
|
||||
teleop_action_processor: RobotProcessorPipeline,
|
||||
robot_action_processor: RobotProcessorPipeline,
|
||||
) -> None:
|
||||
"""Reset period where the human repositions the environment."""
|
||||
logger.info("RESET — press any key to enable teleoperation")
|
||||
|
||||
events.in_reset.set()
|
||||
events.start_next_episode.clear()
|
||||
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
while not events.start_next_episode.is_set() and not events.stop_recording.is_set():
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events.stop_recording.is_set():
|
||||
return
|
||||
|
||||
events.start_next_episode.clear()
|
||||
_teleop_disable_torque(teleop)
|
||||
logger.info("Teleop enabled — press any key to start episode")
|
||||
|
||||
while not events.start_next_episode.is_set() and not events.stop_recording.is_set():
|
||||
loop_start = time.perf_counter()
|
||||
obs = robot.get_observation()
|
||||
action = teleop.get_action()
|
||||
processed_teleop = teleop_action_processor((action, obs))
|
||||
robot_action_to_send = robot_action_processor((processed_teleop, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
||||
|
||||
events.in_reset.clear()
|
||||
events.start_next_episode.clear()
|
||||
events.reset_for_episode()
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input device handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _init_dagger_keyboard(events: DAggerEvents):
|
||||
"""Initialise keyboard listener with DAgger/HIL controls.
|
||||
def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
|
||||
"""Initialise keyboard listener with DAgger 3-key controls.
|
||||
|
||||
Returns the pynput Listener (or ``None`` in headless mode).
|
||||
Returns the pynput Listener (or ``None`` in headless mode or when
|
||||
pynput is unavailable).
|
||||
"""
|
||||
if is_headless():
|
||||
logger.warning("Headless environment — keyboard controls unavailable")
|
||||
if not PYNPUT_AVAILABLE or is_headless():
|
||||
logger.warning("Headless environment or pynput unavailable — keyboard controls disabled")
|
||||
return None
|
||||
|
||||
from pynput import keyboard
|
||||
# Map config key names to pynput Key objects for special keys
|
||||
special_keys = {
|
||||
"space": keyboard.Key.space,
|
||||
"tab": keyboard.Key.tab,
|
||||
"enter": keyboard.Key.enter,
|
||||
}
|
||||
|
||||
def _resolve_key(key) -> str | None:
|
||||
"""Resolve a pynput key event to a config-comparable string."""
|
||||
if key == keyboard.Key.esc:
|
||||
return "esc"
|
||||
for name, pynput_key in special_keys.items():
|
||||
if key == pynput_key:
|
||||
return name
|
||||
if hasattr(key, "char") and key.char:
|
||||
return key.char
|
||||
return None
|
||||
|
||||
# Build mapping: resolved key string -> DAgger event name
|
||||
key_to_event = {
|
||||
cfg.pause_resume: "pause_resume",
|
||||
cfg.correction: "correction",
|
||||
}
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events.in_reset.is_set():
|
||||
if (
|
||||
key in [keyboard.Key.space, keyboard.Key.right]
|
||||
or hasattr(key, "char")
|
||||
and key.char == "c"
|
||||
):
|
||||
events.start_next_episode.set()
|
||||
elif key == keyboard.Key.esc:
|
||||
events.stop_recording.set()
|
||||
events.start_next_episode.set()
|
||||
resolved = _resolve_key(key)
|
||||
if resolved is None:
|
||||
return
|
||||
|
||||
phase = events.phase
|
||||
if key == keyboard.Key.space and phase == DAggerPhase.AUTONOMOUS:
|
||||
logger.info("PAUSED — press 'c' to take control or 'p' to resume policy")
|
||||
events.request_transition("pause")
|
||||
elif hasattr(key, "char") and key.char == "c" and phase == DAggerPhase.PAUSED:
|
||||
logger.info("Taking control...")
|
||||
events.request_transition("takeover")
|
||||
elif (
|
||||
hasattr(key, "char")
|
||||
and key.char == "p"
|
||||
and phase
|
||||
in (
|
||||
DAggerPhase.PAUSED,
|
||||
DAggerPhase.CORRECTING,
|
||||
)
|
||||
):
|
||||
logger.info("Resuming policy...")
|
||||
events.request_transition("resume")
|
||||
|
||||
elif key == keyboard.Key.right:
|
||||
logger.info("End episode")
|
||||
events.exit_early.set()
|
||||
elif key == keyboard.Key.left:
|
||||
logger.info("Re-record episode")
|
||||
events.rerecord_episode.set()
|
||||
events.exit_early.set()
|
||||
elif key == keyboard.Key.esc:
|
||||
if resolved == "esc":
|
||||
logger.info("Stop recording...")
|
||||
events.stop_recording.set()
|
||||
events.exit_early.set()
|
||||
return
|
||||
if resolved in key_to_event:
|
||||
events.request_transition(key_to_event[resolved])
|
||||
if resolved == cfg.upload:
|
||||
events.upload_requested.set()
|
||||
except Exception as e:
|
||||
logger.debug("Key error: %s", e)
|
||||
|
||||
@@ -295,27 +250,23 @@ def _init_dagger_keyboard(events: DAggerEvents):
|
||||
return listener
|
||||
|
||||
|
||||
_DAGGER_PEDAL_KEYS = ("KEY_A", "KEY_C")
|
||||
def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
|
||||
"""Initialise foot pedal listener with DAgger 3-pedal controls.
|
||||
|
||||
|
||||
def _dagger_pedal_callback(events: DAggerEvents):
|
||||
"""Build the pedal key-press handler for DAgger's state machine."""
|
||||
Returns the pedal listener thread (or ``None`` if evdev is unavailable).
|
||||
"""
|
||||
code_to_event = {
|
||||
cfg.pause_resume: "pause_resume",
|
||||
cfg.correction: "correction",
|
||||
}
|
||||
|
||||
def on_press(code: str) -> None:
|
||||
if code not in _DAGGER_PEDAL_KEYS:
|
||||
return
|
||||
if events.in_reset.is_set():
|
||||
events.start_next_episode.set()
|
||||
return
|
||||
phase = events.phase
|
||||
if phase == DAggerPhase.CORRECTING:
|
||||
events.request_transition("resume")
|
||||
elif phase == DAggerPhase.PAUSED:
|
||||
events.request_transition("takeover")
|
||||
elif phase == DAggerPhase.AUTONOMOUS:
|
||||
events.request_transition("pause")
|
||||
if code in code_to_event:
|
||||
events.request_transition(code_to_event[code])
|
||||
if code == cfg.upload:
|
||||
events.upload_requested.set()
|
||||
|
||||
return on_press
|
||||
return start_pedal_listener(on_press, device_path=cfg.device_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -328,12 +279,14 @@ class DAggerStrategy(RolloutStrategy):
|
||||
|
||||
State machine::
|
||||
|
||||
AUTONOMOUS --(SPACE)--> PAUSED --(c)--> CORRECTING --(p)--> AUTONOMOUS
|
||||
--(p)--> AUTONOMOUS
|
||||
AUTONOMOUS --(key1)--> PAUSED --(key2)--> CORRECTING --(key2)--> PAUSED
|
||||
--(key1)--> AUTONOMOUS
|
||||
|
||||
Intervention frames are tagged with ``intervention=True`` (bool) in
|
||||
the dataset; autonomous frames with ``intervention=False``. When
|
||||
``record_autonomous=False`` only corrections are recorded.
|
||||
Recording modes:
|
||||
``record_autonomous=True``: Sentry-like continuous recording with
|
||||
time-based episode rotation. Intervention frames tagged True.
|
||||
``record_autonomous=False``: Only correction windows recorded.
|
||||
Each correction = one episode. Upload on demand via key3.
|
||||
"""
|
||||
|
||||
config: DAggerStrategyConfig
|
||||
@@ -341,71 +294,51 @@ class DAggerStrategy(RolloutStrategy):
|
||||
def __init__(self, config: DAggerStrategyConfig):
|
||||
super().__init__(config)
|
||||
self._listener = None
|
||||
self._pedal_thread = None
|
||||
self._events = DAggerEvents()
|
||||
self._push_executor: ThreadPoolExecutor | None = None
|
||||
self._pending_push: Future | None = None
|
||||
self._needs_push = Event()
|
||||
self._episode_lock = Lock()
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Initialise the inference engine, keyboard listener, and pedal handler."""
|
||||
"""Initialise the inference engine and input device listener."""
|
||||
self._init_engine(ctx)
|
||||
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="dagger-push")
|
||||
|
||||
self._listener = _init_dagger_keyboard(self._events)
|
||||
start_pedal_listener(_dagger_pedal_callback(self._events))
|
||||
if self.config.input_device == "keyboard":
|
||||
self._listener = _init_dagger_keyboard(self._events, self.config.keyboard)
|
||||
else:
|
||||
self._pedal_thread = _init_dagger_pedal(self._events, self.config.pedal)
|
||||
|
||||
record_mode = "all frames (sentry-like)" if self.config.record_autonomous else "corrections only"
|
||||
logger.info(
|
||||
"DAgger strategy ready (episodes=%d, episode_time=%.0fs, record_autonomous=%s)",
|
||||
"DAgger strategy ready (input=%s, episodes=%d, record=%s)",
|
||||
self.config.input_device,
|
||||
self.config.num_episodes,
|
||||
self.config.episode_time_s,
|
||||
self.config.record_autonomous,
|
||||
record_mode,
|
||||
)
|
||||
logger.info("Controls: SPACE=pause, c=take control, p=resume, ->=end, <-=redo, ESC=stop")
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Run DAgger episodes with human-in-the-loop intervention."""
|
||||
dataset = ctx.data.dataset
|
||||
events = self._events
|
||||
teleop = ctx.hardware.teleop
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
recorded = 0
|
||||
while recorded < self.config.num_episodes and not events.stop_recording.is_set():
|
||||
log_say(f"Episode {dataset.num_episodes}", self.config.play_sounds)
|
||||
|
||||
self._run_episode(ctx)
|
||||
|
||||
if events.rerecord_episode.is_set():
|
||||
log_say("Re-recording", self.config.play_sounds)
|
||||
events.rerecord_episode.clear()
|
||||
events.exit_early.clear()
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
if recorded < self.config.num_episodes and not events.stop_recording.is_set():
|
||||
_reset_loop(
|
||||
ctx.hardware.robot_wrapper,
|
||||
teleop,
|
||||
events,
|
||||
int(ctx.runtime.cfg.fps),
|
||||
ctx.processors.teleop_action_processor,
|
||||
ctx.processors.robot_action_processor,
|
||||
)
|
||||
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
if self.config.record_autonomous:
|
||||
self._run_continuous(ctx)
|
||||
else:
|
||||
self._run_corrections_only(ctx)
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Stop listeners, finalise the dataset, and disconnect hardware."""
|
||||
log_say("Stop recording", self.config.play_sounds, blocking=True)
|
||||
|
||||
if self._listener is not None and not is_headless():
|
||||
self._listener.stop()
|
||||
|
||||
# Flush any queued/running push cleanly
|
||||
if self._push_executor is not None:
|
||||
self._push_executor.shutdown(wait=True)
|
||||
self._push_executor = None
|
||||
|
||||
if ctx.data.dataset is not None:
|
||||
ctx.data.dataset.finalize()
|
||||
if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
|
||||
if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
|
||||
ctx.data.dataset.push_to_hub(
|
||||
tags=ctx.runtime.cfg.dataset.tags,
|
||||
private=ctx.runtime.cfg.dataset.private,
|
||||
@@ -415,11 +348,17 @@ class DAggerStrategy(RolloutStrategy):
|
||||
logger.info("DAgger strategy teardown complete")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Episode rollout (state machine)
|
||||
# Continuous recording mode (record_autonomous=True)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_episode(self, ctx: RolloutContext) -> None:
|
||||
"""Run a single DAgger episode with the HIL state machine."""
|
||||
def _run_continuous(self, ctx: RolloutContext) -> None:
|
||||
"""Sentry-like continuous recording with intervention tagging.
|
||||
|
||||
Episodes are auto-rotated every ``episode_time_s`` seconds and
|
||||
uploaded in the background every ``upload_every_n_episodes`` episodes.
|
||||
Both autonomous and correction frames are recorded; corrections are
|
||||
tagged with ``intervention=True``.
|
||||
"""
|
||||
engine = self._engine
|
||||
cfg = ctx.runtime.cfg
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
@@ -427,111 +366,231 @@ class DAggerStrategy(RolloutStrategy):
|
||||
dataset = ctx.data.dataset
|
||||
events = self._events
|
||||
interpolator = self._interpolator
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
stream_online = bool(cfg.dataset.streaming_encoding) if cfg.dataset else False
|
||||
record_stride = max(1, cfg.interpolation_multiplier)
|
||||
record_autonomous = self.config.record_autonomous
|
||||
|
||||
features = ctx.data.dataset_features
|
||||
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
|
||||
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
events.reset_for_episode()
|
||||
_teleop_disable_torque(teleop)
|
||||
|
||||
last_action: dict[str, Any] | None = None
|
||||
frame_buffer: list[dict] = []
|
||||
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
|
||||
|
||||
timestamp = 0.0
|
||||
record_tick = 0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
events.reset()
|
||||
teleop.disable_torque()
|
||||
engine.resume()
|
||||
|
||||
while timestamp < self.config.episode_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
last_action: dict[str, Any] | None = None
|
||||
record_tick = 0
|
||||
episode_start = time.perf_counter()
|
||||
start_time = time.perf_counter()
|
||||
episodes_since_push = 0
|
||||
|
||||
if events.exit_early.is_set():
|
||||
events.exit_early.clear()
|
||||
break
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
while not events.stop_recording.is_set() and not ctx.runtime.shutdown_event.is_set():
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
transition = events.consume_transition()
|
||||
if transition is not None:
|
||||
old_phase, new_phase = transition
|
||||
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
|
||||
last_action = None
|
||||
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
|
||||
break
|
||||
|
||||
phase = events.phase
|
||||
# Process transitions
|
||||
transition = events.consume_transition()
|
||||
if transition is not None:
|
||||
old_phase, new_phase = transition
|
||||
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
|
||||
last_action = None
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
phase = events.phase
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# --- CORRECTING: human teleop control ---
|
||||
if phase == DAggerPhase.CORRECTING:
|
||||
teleop_action = teleop.get_action()
|
||||
processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
|
||||
if record_tick % record_stride == 0:
|
||||
frame = {
|
||||
**obs_frame,
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([True], dtype=bool),
|
||||
}
|
||||
if stream_online:
|
||||
dataset.add_frame(frame)
|
||||
else:
|
||||
frame_buffer.append(frame)
|
||||
record_tick += 1
|
||||
|
||||
# --- PAUSED: hold position ---
|
||||
elif phase == DAggerPhase.PAUSED:
|
||||
if last_action:
|
||||
robot.send_action(last_action)
|
||||
|
||||
# --- AUTONOMOUS: policy control ---
|
||||
else:
|
||||
engine.notify_observation(obs_processed)
|
||||
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
timestamp = time.perf_counter() - start_t
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
|
||||
if action_dict is not None:
|
||||
last_action = ctx.processors.robot_action_processor((action_dict, obs))
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
if record_autonomous and record_tick % record_stride == 0:
|
||||
frame = {
|
||||
**obs_frame,
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([False], dtype=bool),
|
||||
}
|
||||
if stream_online:
|
||||
# --- CORRECTING: human teleop control ---
|
||||
if phase == DAggerPhase.CORRECTING:
|
||||
teleop_action = teleop.get_action()
|
||||
processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
last_action = robot_action_to_send
|
||||
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
|
||||
if record_tick % record_stride == 0:
|
||||
frame = {
|
||||
**obs_frame,
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([True], dtype=bool),
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
else:
|
||||
frame_buffer.append(frame)
|
||||
record_tick += 1
|
||||
record_tick += 1
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
# --- PAUSED: hold position ---
|
||||
elif phase == DAggerPhase.PAUSED:
|
||||
if last_action:
|
||||
robot.send_action(last_action)
|
||||
|
||||
# End of episode: pause engine, disable teleop, flush buffer
|
||||
engine.pause()
|
||||
_teleop_disable_torque(teleop)
|
||||
# --- AUTONOMOUS: policy control ---
|
||||
else:
|
||||
engine.notify_observation(obs_processed)
|
||||
|
||||
if not stream_online:
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
if action_dict is not None:
|
||||
last_action = ctx.processors.robot_action_processor((action_dict, obs))
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
if record_tick % record_stride == 0:
|
||||
frame = {
|
||||
**obs_frame,
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([False], dtype=bool),
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
record_tick += 1
|
||||
|
||||
# Sentry-like episode rotation
|
||||
elapsed = time.perf_counter() - episode_start
|
||||
if elapsed >= self.config.episode_time_s:
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
episodes_since_push += 1
|
||||
self._needs_push.set()
|
||||
logger.info("Episode saved (total: %d)", dataset.num_episodes)
|
||||
|
||||
if episodes_since_push >= self.config.upload_every_n_episodes:
|
||||
self._background_push(dataset, cfg)
|
||||
episodes_since_push = 0
|
||||
|
||||
episode_start = time.perf_counter()
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
|
||||
finally:
|
||||
engine.pause()
|
||||
teleop.disable_torque()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
self._needs_push.set()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Corrections-only mode (record_autonomous=False)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_corrections_only(self, ctx: RolloutContext) -> None:
    """Record only human correction windows. Each correction = one episode.

    The policy runs autonomously without recording. When the user
    pauses and starts a correction, frames are recorded with
    ``intervention=True``. Stopping the correction saves the episode.
    The dataset can be uploaded on demand via the upload key/pedal.

    The loop ends once ``self.config.num_episodes`` corrections have been
    saved, or when ``events.stop_recording`` or the runtime shutdown
    event is set. On exit the engine is paused, teleop torque is
    disabled, and a best-effort save of any partially recorded
    correction is attempted.

    Args:
        ctx: Rollout context supplying the runtime config, hardware
            handles, processors and the dataset to record into.
    """
    engine = self._engine
    cfg = ctx.runtime.cfg
    robot = ctx.hardware.robot_wrapper
    teleop = ctx.hardware.teleop
    dataset = ctx.data.dataset
    events = self._events
    interpolator = self._interpolator
    features = ctx.data.dataset_features

    control_interval = interpolator.get_control_interval(cfg.fps)
    # Only every ``record_stride``-th correction tick is written to the
    # dataset (presumably because the control loop runs faster than the
    # dataset fps when interpolating -- TODO confirm).
    record_stride = max(1, cfg.interpolation_multiplier)
    task_str = cfg.dataset.single_task if cfg.dataset else cfg.task

    engine.reset()
    interpolator.reset()
    events.reset()
    teleop.disable_torque()
    engine.resume()

    last_action: dict[str, Any] | None = None
    record_tick = 0
    recorded = 0

    with VideoEncodingManager(dataset):
        try:
            while (
                recorded < self.config.num_episodes
                and not events.stop_recording.is_set()
                and not ctx.runtime.shutdown_event.is_set()
            ):
                loop_start = time.perf_counter()

                # Process transitions
                transition = events.consume_transition()
                if transition is not None:
                    old_phase, new_phase = transition
                    self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
                    last_action = None

                    if new_phase == DAggerPhase.CORRECTING:
                        # Restart the stride counter so every correction
                        # records its very first tick. Without this, a
                        # correction could begin off-phase, drop its
                        # first frame, and (if very short) end with no
                        # frames at all.
                        record_tick = 0

                    # Correction ended -> save episode (blocking if not streaming)
                    if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
                        with self._episode_lock:
                            dataset.save_episode()
                        recorded += 1
                        self._needs_push.set()
                        logger.info("Episode %d saved", recorded)

                # On-demand upload
                if events.upload_requested.is_set():
                    events.upload_requested.clear()
                    self._background_push(dataset, cfg)

                phase = events.phase
                obs = robot.get_observation()
                obs_processed = ctx.processors.robot_observation_processor(obs)

                # --- CORRECTING: human teleop control + recording ---
                if phase == DAggerPhase.CORRECTING:
                    teleop_action = teleop.get_action()
                    processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
                    robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
                    robot.send_action(robot_action_to_send)
                    last_action = robot_action_to_send

                    if record_tick % record_stride == 0:
                        obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
                        action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
                        dataset.add_frame(
                            {
                                **obs_frame,
                                **action_frame,
                                "task": task_str,
                                "intervention": np.array([True], dtype=bool),
                            }
                        )
                    record_tick += 1

                # --- PAUSED: hold position ---
                elif phase == DAggerPhase.PAUSED:
                    if last_action:
                        robot.send_action(last_action)

                # --- AUTONOMOUS: policy control (no recording) ---
                else:
                    engine.notify_observation(obs_processed)

                    # The warmup helper returning truthy means this tick
                    # was already handled; skip straight to the next one.
                    if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
                        continue

                    action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
                    if action_dict is not None:
                        last_action = ctx.processors.robot_action_processor((action_dict, obs))

                # Sleep away whatever is left of this control period.
                dt = time.perf_counter() - loop_start
                if (sleep_t := control_interval - dt) > 0:
                    precise_sleep(sleep_t)

        finally:
            engine.pause()
            teleop.disable_torque()
            # Best-effort save of a correction that was still in progress.
            # Failure is expected when nothing is buffered, so it must not
            # propagate, but it is logged so a lost episode can be traced.
            try:
                with self._episode_lock:
                    dataset.save_episode()
                self._needs_push.set()
            except Exception:
                logger.debug("No partial episode saved on exit", exc_info=True)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# State-machine transition side-effects
|
||||
@@ -554,13 +613,41 @@ class DAggerStrategy(RolloutStrategy):
|
||||
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||
}
|
||||
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
interpolator.reset()
|
||||
|
||||
elif new_phase == DAggerPhase.CORRECTING:
|
||||
_teleop_disable_torque(teleop)
|
||||
engine.reset()
|
||||
teleop.disable_torque()
|
||||
|
||||
elif new_phase == DAggerPhase.AUTONOMOUS:
|
||||
interpolator.reset()
|
||||
engine.reset()
|
||||
engine.resume()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background push (shared by both modes)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _background_push(self, dataset, cfg) -> None:
    """Queue a Hub push on the single-worker executor.

    The executor is expected to be created with ``max_workers=1``, so at
    most one push runs at a time; submitted tasks are queued rather than
    dropped. Does nothing when no executor is configured.

    The push runs under ``self._episode_lock``. ``self._needs_push`` is
    cleared only after a successful push and while the lock is still
    held, so an episode saved right after the push cannot have its
    "needs push" flag wiped by a late ``clear()``. Push failures are
    logged (with traceback) and never raised to the caller.

    Args:
        dataset: Dataset object exposing ``push_to_hub(tags=, private=)``.
        cfg: Runtime config; ``cfg.dataset`` (when truthy) supplies the
            ``tags`` and ``private`` values for the push.
    """
    if self._push_executor is None:
        return

    if self._pending_push is not None and not self._pending_push.done():
        logger.info("Previous push still in progress; queueing next")

    def _push():
        try:
            with self._episode_lock:
                dataset.push_to_hub(
                    tags=cfg.dataset.tags if cfg.dataset else None,
                    private=cfg.dataset.private if cfg.dataset else False,
                )
                # Clear under the lock: once it is released, a newly
                # saved episode may set the flag again and that signal
                # must survive.
                self._needs_push.clear()
            logger.info("Background push to hub complete")
        except Exception as e:
            # Keep the flag set so a later push retries; include the
            # traceback because this runs on a worker thread.
            logger.error("Background push failed: %s", e, exc_info=True)

    self._pending_push = self._push_executor.submit(_push)
|
||||
|
||||
Reference in New Issue
Block a user