simplify dagger

This commit is contained in:
Steven Palma
2026-04-17 15:55:03 +02:00
parent 051f6c6803
commit 35bb2c7459
4 changed files with 448 additions and 310 deletions

View File

@@ -16,6 +16,8 @@
from .configs import (
BaseStrategyConfig,
DAggerKeyboardConfig,
DAggerPedalConfig,
DAggerStrategyConfig,
DatasetRecordConfig,
HighlightStrategyConfig,
@@ -39,6 +41,8 @@ from .strategies import RolloutStrategy, create_strategy
__all__ = [
"BaseStrategyConfig",
"DAggerKeyboardConfig",
"DAggerPedalConfig",
"DAggerStrategyConfig",
"HighlightStrategyConfig",
"InferenceEngine",

View File

@@ -66,7 +66,7 @@ class SentryStrategyConfig(RolloutStrategyConfig):
uploaded in the background every ``upload_every_n_episodes`` episodes.
"""
episode_duration_s: float = 120.0
episode_duration_s: float = 20.0
upload_every_n_episodes: int = 5
@@ -87,6 +87,32 @@ class HighlightStrategyConfig(RolloutStrategyConfig):
push_key: str = "h"
@dataclass
class DAggerKeyboardConfig:
    """Keyboard key bindings for DAgger controls.

    Keys are specified as single characters (e.g. ``"c"``, ``"h"``) or
    special key names (``"space"``).

    Each action must be bound to its own key, and ``"esc"`` cannot be
    used because the keyboard listener reserves it for stopping the
    session.

    Raises:
        ValueError: If two actions share a key, or an action is bound to
            ``"esc"``.
    """

    pause_resume: str = "space"
    correction: str = "c"
    upload: str = "h"

    def __post_init__(self):
        # The listener looks events up by key, so a key shared by two
        # actions would silently drop one of them. Fail at config time.
        owner_by_key = {}
        for action in ("pause_resume", "correction", "upload"):
            key = getattr(self, action)
            if key == "esc":
                raise ValueError(
                    f"DAgger keyboard key 'esc' is reserved for stopping the session, "
                    f"cannot be bound to '{action}'"
                )
            if key in owner_by_key:
                raise ValueError(
                    f"DAgger keyboard key '{key}' is bound to both "
                    f"'{owner_by_key[key]}' and '{action}'"
                )
            owner_by_key[key] = action
@dataclass
class DAggerPedalConfig:
    """Foot pedal configuration for DAgger controls.

    Pedal codes are evdev key code strings (e.g. ``"KEY_A"``). Each
    action must be bound to its own pedal code.

    Raises:
        ValueError: If two actions share the same pedal code.
    """

    device_path: str = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
    pause_resume: str = "KEY_A"
    correction: str = "KEY_B"
    upload: str = "KEY_C"

    def __post_init__(self):
        # The pedal handler looks events up by code, so a code shared by
        # two actions would silently drop one of them. Fail at config time.
        owner_by_code = {}
        for action in ("pause_resume", "correction", "upload"):
            code = getattr(self, action)
            if code in owner_by_code:
                raise ValueError(
                    f"DAgger pedal code '{code}' is bound to both "
                    f"'{owner_by_code[code]}' and '{action}'"
                )
            owner_by_code[code] = action
@RolloutStrategyConfig.register_subclass("dagger")
@dataclass
class DAggerStrategyConfig(RolloutStrategyConfig):
@@ -95,19 +121,30 @@ class DAggerStrategyConfig(RolloutStrategyConfig):
Alternates between autonomous policy execution and human intervention.
Intervention frames are tagged with ``intervention=True``.
Input is controlled via either a keyboard or foot pedal, selected by
``input_device``. Each device exposes three actions:
1. **pause_resume** — toggle policy execution on/off.
2. **correction** — toggle human correction recording.
3. **upload** — push dataset to hub on demand (corrections-only mode).
When ``record_autonomous=True`` (default) both autonomous and correction
frames are recorded — this requires streaming encoding so the policy
loop never blocks on disk I/O. Set to ``False`` to record only the
human-correction windows; encoding can then happen between phases.
frames are recorded with sentry-like time-based episode rotation and
background uploading. Set to ``False`` to record only the human-correction
windows, where each correction becomes its own episode.
"""
episode_time_s: float = 120.0
num_episodes: int = 50
play_sounds: bool = True
calibrate: bool = False
log_hz: bool = True
hz_log_interval_s: float = 2.0
record_autonomous: bool = True
episode_time_s: float = 20.0
num_episodes: int = 10
record_autonomous: bool = False
upload_every_n_episodes: int = 5
input_device: str = "keyboard"
keyboard: DAggerKeyboardConfig = field(default_factory=DAggerKeyboardConfig)
pedal: DAggerPedalConfig = field(default_factory=DAggerPedalConfig)
def __post_init__(self):
if self.input_device not in ("keyboard", "pedal"):
raise ValueError(f"DAgger input_device must be 'keyboard' or 'pedal', got '{self.input_device}'")
# ---------------------------------------------------------------------------
@@ -160,9 +197,7 @@ class RolloutConfig:
if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None:
raise ValueError("DAgger strategy requires --teleop.type to be set")
needs_dataset = isinstance(
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
)
needs_dataset = isinstance(self.strategy, (SentryStrategyConfig, HighlightStrategyConfig))
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")

View File

@@ -234,6 +234,18 @@ def build_rollout_context(
teleop = make_teleoperator_from_config(cfg.teleop)
teleop.connect()
# DAgger requires teleop with motor control capabilities (enable_torque,
# disable_torque, write_goal_positions).
if isinstance(cfg.strategy, DAggerStrategyConfig) and teleop is not None:
required_teleop_methods = ("enable_torque", "disable_torque", "write_goal_positions")
missing = [m for m in required_teleop_methods if not callable(getattr(teleop, m, None))]
if missing:
teleop.disconnect()
raise ValueError(
f"DAgger strategy requires a teleoperator with motor control methods "
f"{required_teleop_methods}. '{type(teleop).__name__}' is missing: {missing}"
)
# --- 4. Features + action-key reconciliation ---------------------
all_obs_features = robot.observation_features
observation_features_hw = {

View File

@@ -18,13 +18,20 @@ Implements the RaC paradigm (Recovery and Correction) for interactive
imitation learning. Alternates between autonomous policy execution and
human intervention via teleoperator.
Keyboard Controls:
SPACE - Pause policy (robot holds position, no recording)
c - Take control (start correction, recording resumes)
p - Resume policy after pause/correction
-> - End episode (save and continue)
<- - Re-record episode
ESC - Stop recording and push to hub
Input is controlled via either a keyboard or foot pedal, selected by
the ``input_device`` config field. Each device exposes three actions:
1. **pause_resume** — Toggle policy execution (AUTONOMOUS <-> PAUSED).
2. **correction** — Toggle correction recording (PAUSED <-> CORRECTING).
3. **upload** — Push dataset to hub on demand (corrections-only mode).
ESC (keyboard only) Stop session.
Recording Modes:
``record_autonomous=True``: Sentry-like continuous recording with
time-based episode rotation. Both autonomous and correction
frames are recorded; corrections tagged ``intervention=True``.
``record_autonomous=False``: Only correction windows are recorded.
Each correction (start to stop) becomes one episode.
"""
from __future__ import annotations
@@ -32,7 +39,10 @@ from __future__ import annotations
import contextlib
import enum
import logging
import os
import sys
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event, Lock
from typing import Any
@@ -40,19 +50,31 @@ import numpy as np
from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
from lerobot.processor import RobotProcessorPipeline
from lerobot.teleoperators import Teleoperator
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available
from lerobot.utils.pedal import start_pedal_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from ..configs import DAggerStrategyConfig
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
from ..context import RolloutContext
from ..robot_wrapper import ThreadSafeRobot
from .core import RolloutStrategy, send_next_action
# pynput is optional. ``keyboard`` stays ``None`` and ``PYNPUT_AVAILABLE`` is
# forced to False whenever the module cannot (or should not) be imported.
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
if PYNPUT_AVAILABLE:
    try:
        if ("linux" not in sys.platform) or ("DISPLAY" in os.environ):
            from pynput import keyboard
        else:
            # Linux without a DISPLAY: skip the import entirely.
            logging.info("No DISPLAY set. Skipping pynput import.")
            PYNPUT_AVAILABLE = False
    except Exception as e:
        # Any import-time failure just disables keyboard support.
        PYNPUT_AVAILABLE = False
        logging.info(f"Could not import pynput: {e}")
logger = logging.getLogger(__name__)
@@ -64,22 +86,22 @@ logger = logging.getLogger(__name__)
class DAggerPhase(enum.Enum):
"""Observable phases of a DAgger episode."""
AUTONOMOUS = "autonomous" # Policy driving, recording autonomous frames
PAUSED = "paused" # Engine paused, teleop aligned, awaiting takeover/resume
AUTONOMOUS = "autonomous" # Policy driving
PAUSED = "paused" # Engine paused, teleop aligned, awaiting input
CORRECTING = "correcting" # Human driving via teleop, recording interventions
# Valid (current_phase, event) next_phase
# Valid (current_phase, event) -> next_phase
_DAGGER_TRANSITIONS: dict[tuple[DAggerPhase, str], DAggerPhase] = {
(DAggerPhase.AUTONOMOUS, "pause"): DAggerPhase.PAUSED,
(DAggerPhase.PAUSED, "takeover"): DAggerPhase.CORRECTING,
(DAggerPhase.PAUSED, "resume"): DAggerPhase.AUTONOMOUS,
(DAggerPhase.CORRECTING, "resume"): DAggerPhase.AUTONOMOUS,
(DAggerPhase.AUTONOMOUS, "pause_resume"): DAggerPhase.PAUSED,
(DAggerPhase.PAUSED, "pause_resume"): DAggerPhase.AUTONOMOUS,
(DAggerPhase.PAUSED, "correction"): DAggerPhase.CORRECTING,
(DAggerPhase.CORRECTING, "correction"): DAggerPhase.PAUSED,
}
class DAggerEvents:
"""Thread-safe container for DAgger keyboard/pedal events.
"""Thread-safe container for DAgger input device events.
The keyboard/pedal threads write transition requests; the main loop
consumes them.
@@ -90,16 +112,9 @@ class DAggerEvents:
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition: str | None = None
# Episode-level flags written by keyboard/pedal threads, consumed by
# the main loop. ``threading.Event`` gives us atomic set/clear/check
# semantics without taking ``self._lock``.
self.exit_early = Event()
self.rerecord_episode = Event()
# Session-level flags
self.stop_recording = Event()
# Reset-phase flags (simpler lifecycle, shared between threads).
self.in_reset = Event()
self.start_next_episode = Event()
self.upload_requested = Event()
# -- Thread-safe phase access ------------------------------------------
@@ -138,13 +153,12 @@ class DAggerEvents:
self._phase = new_phase
return old_phase, new_phase
def reset_for_episode(self) -> None:
"""Reset all transient state at the start of an episode."""
def reset(self) -> None:
"""Reset all transient state for a fresh session."""
with self._lock:
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition = None
self.exit_early.clear()
self.rerecord_episode.clear()
self.upload_requested.clear()
# ---------------------------------------------------------------------------
@@ -152,29 +166,15 @@ class DAggerEvents:
# ---------------------------------------------------------------------------
def _teleop_has_motor_control(teleop: Teleoperator) -> bool:
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
def _teleop_disable_torque(teleop: Teleoperator) -> None:
if hasattr(teleop, "disable_torque"):
teleop.disable_torque()
def _teleop_enable_torque(teleop: Teleoperator) -> None:
if hasattr(teleop, "enable_torque"):
teleop.enable_torque()
def _teleop_smooth_move_to(
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50
) -> None:
"""Smoothly move teleop to target position if motor control is available."""
if not _teleop_has_motor_control(teleop):
logger.warning("Teleop does not support motor control — cannot mirror robot position")
return
"""Smoothly move teleop to target position via linear interpolation.
_teleop_enable_torque(teleop)
The teleoperator is guaranteed to have motor control methods
(validated at context build time).
"""
teleop.enable_torque()
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
@@ -190,103 +190,58 @@ def _teleop_smooth_move_to(
time.sleep(1 / fps)
def _reset_loop(
robot: ThreadSafeRobot,
teleop: Teleoperator,
events: DAggerEvents,
fps: int,
teleop_action_processor: RobotProcessorPipeline,
robot_action_processor: RobotProcessorPipeline,
) -> None:
"""Reset period where the human repositions the environment."""
logger.info("RESET — press any key to enable teleoperation")
events.in_reset.set()
events.start_next_episode.clear()
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
while not events.start_next_episode.is_set() and not events.stop_recording.is_set():
precise_sleep(0.05)
if events.stop_recording.is_set():
return
events.start_next_episode.clear()
_teleop_disable_torque(teleop)
logger.info("Teleop enabled — press any key to start episode")
while not events.start_next_episode.is_set() and not events.stop_recording.is_set():
loop_start = time.perf_counter()
obs = robot.get_observation()
action = teleop.get_action()
processed_teleop = teleop_action_processor((action, obs))
robot_action_to_send = robot_action_processor((processed_teleop, obs))
robot.send_action(robot_action_to_send)
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events.in_reset.clear()
events.start_next_episode.clear()
events.reset_for_episode()
# ---------------------------------------------------------------------------
# Input device handlers
# ---------------------------------------------------------------------------
def _init_dagger_keyboard(events: DAggerEvents):
"""Initialise keyboard listener with DAgger/HIL controls.
def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
"""Initialise keyboard listener with DAgger 3-key controls.
Returns the pynput Listener (or ``None`` in headless mode).
Returns the pynput Listener (or ``None`` in headless mode or when
pynput is unavailable).
"""
if is_headless():
logger.warning("Headless environment — keyboard controls unavailable")
if not PYNPUT_AVAILABLE or is_headless():
logger.warning("Headless environment or pynput unavailable — keyboard controls disabled")
return None
from pynput import keyboard
# Map config key names to pynput Key objects for special keys
special_keys = {
"space": keyboard.Key.space,
"tab": keyboard.Key.tab,
"enter": keyboard.Key.enter,
}
def _resolve_key(key) -> str | None:
"""Resolve a pynput key event to a config-comparable string."""
if key == keyboard.Key.esc:
return "esc"
for name, pynput_key in special_keys.items():
if key == pynput_key:
return name
if hasattr(key, "char") and key.char:
return key.char
return None
# Build mapping: resolved key string -> DAgger event name
key_to_event = {
cfg.pause_resume: "pause_resume",
cfg.correction: "correction",
}
def on_press(key):
try:
if events.in_reset.is_set():
if (
key in [keyboard.Key.space, keyboard.Key.right]
or hasattr(key, "char")
and key.char == "c"
):
events.start_next_episode.set()
elif key == keyboard.Key.esc:
events.stop_recording.set()
events.start_next_episode.set()
resolved = _resolve_key(key)
if resolved is None:
return
phase = events.phase
if key == keyboard.Key.space and phase == DAggerPhase.AUTONOMOUS:
logger.info("PAUSED — press 'c' to take control or 'p' to resume policy")
events.request_transition("pause")
elif hasattr(key, "char") and key.char == "c" and phase == DAggerPhase.PAUSED:
logger.info("Taking control...")
events.request_transition("takeover")
elif (
hasattr(key, "char")
and key.char == "p"
and phase
in (
DAggerPhase.PAUSED,
DAggerPhase.CORRECTING,
)
):
logger.info("Resuming policy...")
events.request_transition("resume")
elif key == keyboard.Key.right:
logger.info("End episode")
events.exit_early.set()
elif key == keyboard.Key.left:
logger.info("Re-record episode")
events.rerecord_episode.set()
events.exit_early.set()
elif key == keyboard.Key.esc:
if resolved == "esc":
logger.info("Stop recording...")
events.stop_recording.set()
events.exit_early.set()
return
if resolved in key_to_event:
events.request_transition(key_to_event[resolved])
if resolved == cfg.upload:
events.upload_requested.set()
except Exception as e:
logger.debug("Key error: %s", e)
@@ -295,27 +250,23 @@ def _init_dagger_keyboard(events: DAggerEvents):
return listener
_DAGGER_PEDAL_KEYS = ("KEY_A", "KEY_C")
def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
"""Initialise foot pedal listener with DAgger 3-pedal controls.
def _dagger_pedal_callback(events: DAggerEvents):
"""Build the pedal key-press handler for DAgger's state machine."""
Returns the pedal listener thread (or ``None`` if evdev is unavailable).
"""
code_to_event = {
cfg.pause_resume: "pause_resume",
cfg.correction: "correction",
}
def on_press(code: str) -> None:
if code not in _DAGGER_PEDAL_KEYS:
return
if events.in_reset.is_set():
events.start_next_episode.set()
return
phase = events.phase
if phase == DAggerPhase.CORRECTING:
events.request_transition("resume")
elif phase == DAggerPhase.PAUSED:
events.request_transition("takeover")
elif phase == DAggerPhase.AUTONOMOUS:
events.request_transition("pause")
if code in code_to_event:
events.request_transition(code_to_event[code])
if code == cfg.upload:
events.upload_requested.set()
return on_press
return start_pedal_listener(on_press, device_path=cfg.device_path)
# ---------------------------------------------------------------------------
@@ -328,12 +279,14 @@ class DAggerStrategy(RolloutStrategy):
State machine::
AUTONOMOUS --(SPACE)--> PAUSED --(c)--> CORRECTING --(p)--> AUTONOMOUS
--(p)--> AUTONOMOUS
AUTONOMOUS --(key1)--> PAUSED --(key2)--> CORRECTING --(key2)--> PAUSED
--(key1)--> AUTONOMOUS
Intervention frames are tagged with ``intervention=True`` (bool) in
the dataset; autonomous frames with ``intervention=False``. When
``record_autonomous=False`` only corrections are recorded.
Recording modes:
``record_autonomous=True``: Sentry-like continuous recording with
time-based episode rotation. Intervention frames tagged True.
``record_autonomous=False``: Only correction windows recorded.
Each correction = one episode. Upload on demand via key3.
"""
config: DAggerStrategyConfig
@@ -341,71 +294,51 @@ class DAggerStrategy(RolloutStrategy):
def __init__(self, config: DAggerStrategyConfig):
    """Create the strategy with no listener started and no push in flight.

    The input listener and the push executor are left as ``None`` here;
    ``setup`` is what creates them.
    """
    super().__init__(config)

    # Shared event container between the input-device thread and the loop.
    self._events = DAggerEvents()

    # Input listeners; ``setup`` starts the one selected by ``input_device``.
    self._listener = None
    self._pedal_thread = None

    # Hub upload bookkeeping: the lock is held around episode saves and
    # pushes, and the flag records that a push is still needed.
    self._episode_lock = Lock()
    self._needs_push = Event()
    self._pending_push: Future | None = None
    self._push_executor: ThreadPoolExecutor | None = None
def setup(self, ctx: RolloutContext) -> None:
"""Initialise the inference engine, keyboard listener, and pedal handler."""
"""Initialise the inference engine and input device listener."""
self._init_engine(ctx)
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="dagger-push")
self._listener = _init_dagger_keyboard(self._events)
start_pedal_listener(_dagger_pedal_callback(self._events))
if self.config.input_device == "keyboard":
self._listener = _init_dagger_keyboard(self._events, self.config.keyboard)
else:
self._pedal_thread = _init_dagger_pedal(self._events, self.config.pedal)
record_mode = "all frames (sentry-like)" if self.config.record_autonomous else "corrections only"
logger.info(
"DAgger strategy ready (episodes=%d, episode_time=%.0fs, record_autonomous=%s)",
"DAgger strategy ready (input=%s, episodes=%d, record=%s)",
self.config.input_device,
self.config.num_episodes,
self.config.episode_time_s,
self.config.record_autonomous,
record_mode,
)
logger.info("Controls: SPACE=pause, c=take control, p=resume, ->=end, <-=redo, ESC=stop")
def run(self, ctx: RolloutContext) -> None:
"""Run DAgger episodes with human-in-the-loop intervention."""
dataset = ctx.data.dataset
events = self._events
teleop = ctx.hardware.teleop
with VideoEncodingManager(dataset):
try:
recorded = 0
while recorded < self.config.num_episodes and not events.stop_recording.is_set():
log_say(f"Episode {dataset.num_episodes}", self.config.play_sounds)
self._run_episode(ctx)
if events.rerecord_episode.is_set():
log_say("Re-recording", self.config.play_sounds)
events.rerecord_episode.clear()
events.exit_early.clear()
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded += 1
if recorded < self.config.num_episodes and not events.stop_recording.is_set():
_reset_loop(
ctx.hardware.robot_wrapper,
teleop,
events,
int(ctx.runtime.cfg.fps),
ctx.processors.teleop_action_processor,
ctx.processors.robot_action_processor,
)
finally:
with contextlib.suppress(Exception):
dataset.save_episode()
if self.config.record_autonomous:
self._run_continuous(ctx)
else:
self._run_corrections_only(ctx)
def teardown(self, ctx: RolloutContext) -> None:
"""Stop listeners, finalise the dataset, and disconnect hardware."""
log_say("Stop recording", self.config.play_sounds, blocking=True)
if self._listener is not None and not is_headless():
self._listener.stop()
# Flush any queued/running push cleanly
if self._push_executor is not None:
self._push_executor.shutdown(wait=True)
self._push_executor = None
if ctx.data.dataset is not None:
ctx.data.dataset.finalize()
if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
ctx.data.dataset.push_to_hub(
tags=ctx.runtime.cfg.dataset.tags,
private=ctx.runtime.cfg.dataset.private,
@@ -415,11 +348,17 @@ class DAggerStrategy(RolloutStrategy):
logger.info("DAgger strategy teardown complete")
# ------------------------------------------------------------------
# Episode rollout (state machine)
# Continuous recording mode (record_autonomous=True)
# ------------------------------------------------------------------
def _run_episode(self, ctx: RolloutContext) -> None:
"""Run a single DAgger episode with the HIL state machine."""
def _run_continuous(self, ctx: RolloutContext) -> None:
"""Sentry-like continuous recording with intervention tagging.
Episodes are auto-rotated every ``episode_time_s`` seconds and
uploaded in the background every ``upload_every_n_episodes`` episodes.
Both autonomous and correction frames are recorded; corrections are
tagged with ``intervention=True``.
"""
engine = self._engine
cfg = ctx.runtime.cfg
robot = ctx.hardware.robot_wrapper
@@ -427,111 +366,231 @@ class DAggerStrategy(RolloutStrategy):
dataset = ctx.data.dataset
events = self._events
interpolator = self._interpolator
features = ctx.data.dataset_features
control_interval = interpolator.get_control_interval(cfg.fps)
stream_online = bool(cfg.dataset.streaming_encoding) if cfg.dataset else False
record_stride = max(1, cfg.interpolation_multiplier)
record_autonomous = self.config.record_autonomous
features = ctx.data.dataset_features
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
engine.reset()
interpolator.reset()
events.reset_for_episode()
_teleop_disable_torque(teleop)
last_action: dict[str, Any] | None = None
frame_buffer: list[dict] = []
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
timestamp = 0.0
record_tick = 0
start_t = time.perf_counter()
events.reset()
teleop.disable_torque()
engine.resume()
while timestamp < self.config.episode_time_s:
loop_start = time.perf_counter()
last_action: dict[str, Any] | None = None
record_tick = 0
episode_start = time.perf_counter()
start_time = time.perf_counter()
episodes_since_push = 0
if events.exit_early.is_set():
events.exit_early.clear()
break
with VideoEncodingManager(dataset):
try:
while not events.stop_recording.is_set() and not ctx.runtime.shutdown_event.is_set():
loop_start = time.perf_counter()
transition = events.consume_transition()
if transition is not None:
old_phase, new_phase = transition
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
last_action = None
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
break
phase = events.phase
# Process transitions
transition = events.consume_transition()
if transition is not None:
old_phase, new_phase = transition
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
last_action = None
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
phase = events.phase
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
# --- CORRECTING: human teleop control ---
if phase == DAggerPhase.CORRECTING:
teleop_action = teleop.get_action()
processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
robot.send_action(robot_action_to_send)
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
if record_tick % record_stride == 0:
frame = {
**obs_frame,
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
}
if stream_online:
dataset.add_frame(frame)
else:
frame_buffer.append(frame)
record_tick += 1
# --- PAUSED: hold position ---
elif phase == DAggerPhase.PAUSED:
if last_action:
robot.send_action(last_action)
# --- AUTONOMOUS: policy control ---
else:
engine.notify_observation(obs_processed)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
timestamp = time.perf_counter() - start_t
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
last_action = ctx.processors.robot_action_processor((action_dict, obs))
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
if record_autonomous and record_tick % record_stride == 0:
frame = {
**obs_frame,
**action_frame,
"task": task_str,
"intervention": np.array([False], dtype=bool),
}
if stream_online:
# --- CORRECTING: human teleop control ---
if phase == DAggerPhase.CORRECTING:
teleop_action = teleop.get_action()
processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
robot.send_action(robot_action_to_send)
last_action = robot_action_to_send
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
if record_tick % record_stride == 0:
frame = {
**obs_frame,
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
}
dataset.add_frame(frame)
else:
frame_buffer.append(frame)
record_tick += 1
record_tick += 1
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
timestamp = time.perf_counter() - start_t
# --- PAUSED: hold position ---
elif phase == DAggerPhase.PAUSED:
if last_action:
robot.send_action(last_action)
# End of episode: pause engine, disable teleop, flush buffer
engine.pause()
_teleop_disable_torque(teleop)
# --- AUTONOMOUS: policy control ---
else:
engine.notify_observation(obs_processed)
if not stream_online:
for frame in frame_buffer:
dataset.add_frame(frame)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
last_action = ctx.processors.robot_action_processor((action_dict, obs))
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
if record_tick % record_stride == 0:
frame = {
**obs_frame,
**action_frame,
"task": task_str,
"intervention": np.array([False], dtype=bool),
}
dataset.add_frame(frame)
record_tick += 1
# Sentry-like episode rotation
elapsed = time.perf_counter() - episode_start
if elapsed >= self.config.episode_time_s:
with self._episode_lock:
dataset.save_episode()
episodes_since_push += 1
self._needs_push.set()
logger.info("Episode saved (total: %d)", dataset.num_episodes)
if episodes_since_push >= self.config.upload_every_n_episodes:
self._background_push(dataset, cfg)
episodes_since_push = 0
episode_start = time.perf_counter()
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
finally:
engine.pause()
teleop.disable_torque()
with contextlib.suppress(Exception):
with self._episode_lock:
dataset.save_episode()
self._needs_push.set()
# ------------------------------------------------------------------
# Corrections-only mode (record_autonomous=False)
# ------------------------------------------------------------------
def _run_corrections_only(self, ctx: RolloutContext) -> None:
"""Record only human correction windows. Each correction = one episode.
The policy runs autonomously without recording. When the user
pauses and starts a correction, frames are recorded with
``intervention=True``. Stopping the correction saves the episode.
The dataset can be uploaded on demand via the upload key/pedal.
"""
# The loop below ends once ``num_episodes`` corrections have been saved, or
# when ``stop_recording`` or the runtime ``shutdown_event`` is set.
engine = self._engine
cfg = ctx.runtime.cfg
robot = ctx.hardware.robot_wrapper
teleop = ctx.hardware.teleop
dataset = ctx.data.dataset
events = self._events
interpolator = self._interpolator
features = ctx.data.dataset_features
control_interval = interpolator.get_control_interval(cfg.fps)
# A frame is recorded every ``record_stride`` correction ticks (never less than 1).
record_stride = max(1, cfg.interpolation_multiplier)
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
# Start from a clean state: ``events.reset()`` puts the state machine back in
# AUTONOMOUS with no pending transition, teleop torque is off, policy runs.
engine.reset()
interpolator.reset()
events.reset()
teleop.disable_torque()
engine.resume()
last_action: dict[str, Any] | None = None
record_tick = 0
# Number of correction episodes saved so far in this call.
recorded = 0
with VideoEncodingManager(dataset):
try:
while (
recorded < self.config.num_episodes
and not events.stop_recording.is_set()
and not ctx.runtime.shutdown_event.is_set()
):
loop_start = time.perf_counter()
# Process transitions
transition = events.consume_transition()
if transition is not None:
old_phase, new_phase = transition
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
# The cached action is dropped on every phase change.
last_action = None
# Correction ended -> save episode (blocking if not streaming)
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
# The same lock is taken by the background push worker.
with self._episode_lock:
dataset.save_episode()
recorded += 1
# Flags that saved data has not been pushed yet; teardown checks this flag.
self._needs_push.set()
logger.info("Episode %d saved", recorded)
# On-demand upload
if events.upload_requested.is_set():
# Clear the request first so one key/pedal press queues one push.
events.upload_requested.clear()
self._background_push(dataset, cfg)
phase = events.phase
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
# --- CORRECTING: human teleop control + recording ---
if phase == DAggerPhase.CORRECTING:
teleop_action = teleop.get_action()
processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
robot.send_action(robot_action_to_send)
last_action = robot_action_to_send
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
# The recorded action is the processed teleop action, not the robot-level action that was sent.
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
if record_tick % record_stride == 0:
dataset.add_frame(
{
**obs_frame,
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
}
)
# NOTE(review): record_tick is never reset between corrections — confirm that is intended when record_stride > 1.
record_tick += 1
# --- PAUSED: hold position ---
elif phase == DAggerPhase.PAUSED:
# NOTE(review): last_action is cleared on every transition and this branch
# never assigns it, so this send appears unreachable — confirm intended.
if last_action:
robot.send_action(last_action)
# --- AUTONOMOUS: policy control (no recording) ---
else:
engine.notify_observation(obs_processed)
# A truthy return skips the rest of this iteration, including the pacing sleep below.
# presumably handles torch.compile warmup — TODO confirm against _handle_warmup.
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
last_action = ctx.processors.robot_action_processor((action_dict, obs))
# Sleep off whatever is left of the control interval, if anything.
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
finally:
# Always stop the policy and release the teleop, then attempt a final
# best-effort save; exceptions raised by that save are suppressed.
engine.pause()
teleop.disable_torque()
with contextlib.suppress(Exception):
with self._episode_lock:
dataset.save_episode()
self._needs_push.set()
# ------------------------------------------------------------------
# State-machine transition side-effects
@@ -554,13 +613,41 @@ class DAggerStrategy(RolloutStrategy):
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
}
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
interpolator.reset()
elif new_phase == DAggerPhase.CORRECTING:
_teleop_disable_torque(teleop)
engine.reset()
teleop.disable_torque()
elif new_phase == DAggerPhase.AUTONOMOUS:
interpolator.reset()
engine.reset()
engine.resume()
# ------------------------------------------------------------------
# Background push (shared by both modes)
# ------------------------------------------------------------------
def _background_push(self, dataset, cfg) -> None:
"""Queue a Hub push on the single-worker executor.
The executor's max_workers=1 guarantees at most one push runs at
a time; submitted tasks are queued rather than dropped.
"""
if self._push_executor is None:
return
if self._pending_push is not None and not self._pending_push.done():
logger.info("Previous push still in progress; queueing next")
def _push():
try:
with self._episode_lock:
dataset.push_to_hub(
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
)
self._needs_push.clear()
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)
self._pending_push = self._push_executor.submit(_push)