diff --git a/src/lerobot/rollout/context.py b/src/lerobot/rollout/context.py index 48c60b7fd..f54d4248b 100644 --- a/src/lerobot/rollout/context.py +++ b/src/lerobot/rollout/context.py @@ -177,7 +177,7 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple) } - action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")} + action_features_hw = robot.action_features # Build dataset features dataset_features = combine_feature_dicts( @@ -196,7 +196,7 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC hw_features = hw_to_dataset_features(observation_features_hw, "observation") # Action keys - action_keys = [k for k in robot.action_features if k.endswith(".pos")] + action_keys = list(robot.action_features.keys()) # Ordered action keys (reconcile policy vs dataset ordering) policy_action_names = getattr(policy_config, "action_feature_names", None) diff --git a/src/lerobot/rollout/strategies/__init__.py b/src/lerobot/rollout/strategies/__init__.py index 3c5eee83a..4674d852e 100644 --- a/src/lerobot/rollout/strategies/__init__.py +++ b/src/lerobot/rollout/strategies/__init__.py @@ -17,6 +17,7 @@ from __future__ import annotations import abc +import time from typing import TYPE_CHECKING import torch @@ -25,6 +26,7 @@ from lerobot.policies.rtc import ActionInterpolator from lerobot.policies.utils import make_robot_action from lerobot.utils.constants import OBS_STR from lerobot.utils.feature_utils import build_dataset_frame +from lerobot.utils.robot_utils import precise_sleep if TYPE_CHECKING: from lerobot.rollout.configs import RolloutStrategyConfig @@ -42,6 +44,68 @@ class RolloutStrategy(abc.ABC): def __init__(self, config: RolloutStrategyConfig) -> None: self.config = config + self._engine: InferenceEngine | None = None + self._interpolator: ActionInterpolator | None = None + self._warmup_flushed: bool = False + + def _init_engine(self, ctx: RolloutContext) -> None: + """Create and start the inference engine and action interpolator. + + Call this from ``setup()`` to avoid duplicating the engine + construction across every strategy. + """ + from lerobot.rollout.inference import InferenceEngine + + self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) + self._engine = InferenceEngine( + policy=ctx.policy, + preprocessor=ctx.preprocessor, + postprocessor=ctx.postprocessor, + robot_wrapper=ctx.robot_wrapper, + rtc_config=ctx.cfg.rtc, + hw_features=ctx.hw_features, + action_keys=ctx.action_keys, + task=ctx.cfg.task, + fps=ctx.cfg.fps, + device=ctx.cfg.device, + use_torch_compile=ctx.cfg.use_torch_compile, + compile_warmup_inferences=ctx.cfg.compile_warmup_inferences, + ) + self._engine.start() + self._warmup_flushed = False + + def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool: + """Handle torch.compile warmup phase. + + Returns ``True`` if the caller should ``continue`` (still warming + up). On the first post-warmup iteration the engine and + interpolator are reset so stale warmup state is discarded. + """ + engine = self._engine + interpolator = self._interpolator + if not use_torch_compile: + return False + if not engine.compile_warmup_done.is_set(): + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + return True + if not self._warmup_flushed: + engine.reset() + interpolator.reset() + self._warmup_flushed = True + if engine.is_rtc: + engine.resume() + return False + + def _teardown_hardware(self, ctx: RolloutContext) -> None: + """Stop the inference engine and disconnect hardware.""" + if self._engine is not None: + self._engine.stop() + if ctx.robot.is_connected: + ctx.robot.disconnect() + if ctx.teleop is not None and ctx.teleop.is_connected: + ctx.teleop.disconnect() @abc.abstractmethod def setup(self, ctx: RolloutContext) -> None: diff --git a/src/lerobot/rollout/strategies/base.py b/src/lerobot/rollout/strategies/base.py index 8f5d5c7d1..30bef0376 100644 --- a/src/lerobot/rollout/strategies/base.py +++ b/src/lerobot/rollout/strategies/base.py @@ -19,11 +19,9 @@ from __future__ import annotations import logging import time -from lerobot.policies.rtc import ActionInterpolator from lerobot.utils.robot_utils import precise_sleep from ..context import RolloutContext -from ..inference import InferenceEngine from . import RolloutStrategy, infer_action logger = logging.getLogger(__name__) @@ -37,29 +35,8 @@ class BaseStrategy(RolloutStrategy): ``robot_action_processor`` pipeline before reaching the robot. """ - def __init__(self, config): - super().__init__(config) - self._engine: InferenceEngine | None = None - self._interpolator: ActionInterpolator | None = None - def setup(self, ctx: RolloutContext) -> None: - self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) - - self._engine = InferenceEngine( - policy=ctx.policy, - preprocessor=ctx.preprocessor, - postprocessor=ctx.postprocessor, - robot_wrapper=ctx.robot_wrapper, - rtc_config=ctx.cfg.rtc, - hw_features=ctx.hw_features, - action_keys=ctx.action_keys, - task=ctx.cfg.task, - fps=ctx.cfg.fps, - device=ctx.cfg.device, - use_torch_compile=ctx.cfg.use_torch_compile, - compile_warmup_inferences=ctx.cfg.compile_warmup_inferences, - ) - self._engine.start() + self._init_engine(ctx) logger.info("Base strategy ready (rtc=%s)", self._engine.is_rtc) def run(self, ctx: RolloutContext) -> None: @@ -72,7 +49,6 @@ class BaseStrategy(RolloutStrategy): ordered_keys = ctx.ordered_action_keys start_time = time.perf_counter() - warmup_flushed = False if engine.is_rtc: engine.resume() @@ -89,20 +65,9 @@ class BaseStrategy(RolloutStrategy): if engine.is_rtc: engine.update_observation(obs_processed) - # Wait for torch.compile warmup before running live inference - if cfg.use_torch_compile and not engine.compile_warmup_done.is_set(): - dt = time.perf_counter() - loop_start - if (sleep_t := control_interval - dt) > 0: - precise_sleep(sleep_t) + if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): continue - if cfg.use_torch_compile and not warmup_flushed: - engine.reset() - interpolator.reset() - warmup_flushed = True - if engine.is_rtc: - engine.resume() - infer_action(engine, obs_processed, obs, ctx, interpolator, ordered_keys, ctx.dataset_features) dt = time.perf_counter() - loop_start @@ -110,10 +75,5 @@ class BaseStrategy(RolloutStrategy): precise_sleep(sleep_t) def teardown(self, ctx: RolloutContext) -> None: - if self._engine is not None: - self._engine.stop() - if ctx.robot.is_connected: - ctx.robot.disconnect() - if ctx.teleop is not None and ctx.teleop.is_connected: - ctx.teleop.disconnect() + self._teardown_hardware(ctx) logger.info("Base strategy teardown complete") diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py index a6d82ec3f..dae4cadad 100644 --- a/src/lerobot/rollout/strategies/dagger.py +++ b/src/lerobot/rollout/strategies/dagger.py @@ -38,8 +38,8 @@ import numpy as np from lerobot.common.control_utils import is_headless from lerobot.datasets import VideoEncodingManager -from lerobot.policies.rtc import ActionInterpolator from lerobot.processor import RobotProcessorPipeline +from lerobot.teleoperators import Teleoperator from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame from lerobot.utils.robot_utils import precise_sleep @@ -47,7 +47,7 @@ from lerobot.utils.utils import log_say from ..configs import DAggerStrategyConfig from ..context import RolloutContext -from ..inference import InferenceEngine +from ..robot_wrapper import ThreadSafeRobot from . import RolloutStrategy, infer_action logger = logging.getLogger(__name__) @@ -58,21 +58,23 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- -def _teleop_has_motor_control(teleop) -> bool: +def _teleop_has_motor_control(teleop: Teleoperator) -> bool: return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions")) -def _teleop_disable_torque(teleop) -> None: +def _teleop_disable_torque(teleop: Teleoperator) -> None: if hasattr(teleop, "disable_torque"): teleop.disable_torque() -def _teleop_enable_torque(teleop) -> None: +def _teleop_enable_torque(teleop: Teleoperator) -> None: if hasattr(teleop, "enable_torque"): teleop.enable_torque() -def _teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 50) -> None: +def _teleop_smooth_move_to( + teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50 +) -> None: """Smoothly move teleop to target position if motor control is available.""" if not _teleop_has_motor_control(teleop): logger.warning("Teleop does not support motor control — cannot mirror robot position") @@ -95,8 +97,8 @@ def _teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fp def _reset_loop( - robot, - teleop, + robot: ThreadSafeRobot, + teleop: Teleoperator, events: dict, fps: int, teleop_action_processor: RobotProcessorPipeline, @@ -275,29 +277,11 @@ class DAggerStrategy(RolloutStrategy): def __init__(self, config: DAggerStrategyConfig): super().__init__(config) - self._engine: InferenceEngine | None = None - self._interpolator: ActionInterpolator | None = None self._listener = None self._events: dict[str, Any] = {} def setup(self, ctx: RolloutContext) -> None: - self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) - - self._engine = InferenceEngine( - policy=ctx.policy, - preprocessor=ctx.preprocessor, - postprocessor=ctx.postprocessor, - robot_wrapper=ctx.robot_wrapper, - rtc_config=ctx.cfg.rtc, - hw_features=ctx.hw_features, - action_keys=ctx.action_keys, - task=ctx.cfg.task, - fps=ctx.cfg.fps, - device=ctx.cfg.device, - use_torch_compile=ctx.cfg.use_torch_compile, - compile_warmup_inferences=ctx.cfg.compile_warmup_inferences, - ) - self._engine.start() + self._init_engine(ctx) self._listener, self._events = _init_dagger_keyboard() _start_pedal_listener(self._events) @@ -350,9 +334,6 @@ class DAggerStrategy(RolloutStrategy): def teardown(self, ctx: RolloutContext) -> None: log_say("Stop recording", self.config.play_sounds, blocking=True) - if self._engine is not None: - self._engine.stop() - if self._listener is not None and not is_headless(): self._listener.stop() @@ -364,10 +345,7 @@ class DAggerStrategy(RolloutStrategy): private=ctx.cfg.dataset.private, ) - if ctx.robot.is_connected: - ctx.robot.disconnect() - if ctx.teleop is not None and ctx.teleop.is_connected: - ctx.teleop.disconnect() + self._teardown_hardware(ctx) logger.info("DAgger strategy teardown complete") # ------------------------------------------------------------------ @@ -404,7 +382,6 @@ class DAggerStrategy(RolloutStrategy): timestamp = 0.0 record_tick = 0 start_t = time.perf_counter() - warmup_flushed = False if engine.is_rtc: engine.resume() @@ -493,21 +470,10 @@ class DAggerStrategy(RolloutStrategy): if engine.is_rtc: engine.update_observation(obs_processed) - # Wait for torch.compile warmup - if cfg.use_torch_compile and not engine.compile_warmup_done.is_set(): - dt = time.perf_counter() - loop_start - if (sleep_t := control_interval - dt) > 0: - precise_sleep(sleep_t) + if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): timestamp = time.perf_counter() - start_t continue - if cfg.use_torch_compile and not warmup_flushed: - engine.reset() - interpolator.reset() - warmup_flushed = True - if engine.is_rtc: - engine.resume() - action_dict = infer_action( engine, obs_processed, obs, ctx, interpolator, ordered_keys, features ) diff --git a/src/lerobot/rollout/strategies/highlight.py b/src/lerobot/rollout/strategies/highlight.py index 3e772470d..f9a0dd32d 100644 --- a/src/lerobot/rollout/strategies/highlight.py +++ b/src/lerobot/rollout/strategies/highlight.py @@ -22,14 +22,12 @@ import time from threading import Event as ThreadingEvent from lerobot.datasets import VideoEncodingManager -from lerobot.policies.rtc import ActionInterpolator from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame from lerobot.utils.robot_utils import precise_sleep from ..configs import HighlightStrategyConfig from ..context import RolloutContext -from ..inference import InferenceEngine from ..ring_buffer import RolloutRingBuffer from . import RolloutStrategy, infer_action @@ -54,31 +52,14 @@ class HighlightStrategy(RolloutStrategy): def __init__(self, config: HighlightStrategyConfig): super().__init__(config) - self._engine: InferenceEngine | None = None - self._interpolator: ActionInterpolator | None = None self._ring: RolloutRingBuffer | None = None self._listener = None self._save_requested = ThreadingEvent() self._recording_live = ThreadingEvent() + self._shutdown_event: ThreadingEvent | None = None def setup(self, ctx: RolloutContext) -> None: - self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) - - self._engine = InferenceEngine( - policy=ctx.policy, - preprocessor=ctx.preprocessor, - postprocessor=ctx.postprocessor, - robot_wrapper=ctx.robot_wrapper, - rtc_config=ctx.cfg.rtc, - hw_features=ctx.hw_features, - action_keys=ctx.action_keys, - task=ctx.cfg.task, - fps=ctx.cfg.fps, - device=ctx.cfg.device, - use_torch_compile=ctx.cfg.use_torch_compile, - compile_warmup_inferences=ctx.cfg.compile_warmup_inferences, - ) - self._engine.start() + self._init_engine(ctx) self._ring = RolloutRingBuffer( max_seconds=self.config.ring_buffer_seconds, @@ -86,6 +67,7 @@ class HighlightStrategy(RolloutStrategy): fps=ctx.cfg.fps, ) + self._shutdown_event = ctx.shutdown_event self._setup_keyboard() logger.info( "Highlight strategy ready (buffer=%.0fs, key='%s')", @@ -109,7 +91,6 @@ class HighlightStrategy(RolloutStrategy): engine.resume() start_time = time.perf_counter() - warmup_flushed = False task_str = cfg.dataset.single_task if cfg.dataset else cfg.task with VideoEncodingManager(dataset): @@ -126,19 +107,9 @@ class HighlightStrategy(RolloutStrategy): if engine.is_rtc: engine.update_observation(obs_processed) - if cfg.use_torch_compile and not engine.compile_warmup_done.is_set(): - dt = time.perf_counter() - loop_start - if (sleep_t := control_interval - dt) > 0: - precise_sleep(sleep_t) + if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): continue - if cfg.use_torch_compile and not warmup_flushed: - engine.reset() - interpolator.reset() - warmup_flushed = True - if engine.is_rtc: - engine.resume() - action_dict = infer_action( engine, obs_processed, obs, ctx, interpolator, ordered_keys, features ) @@ -186,8 +157,6 @@ class HighlightStrategy(RolloutStrategy): dataset.save_episode() def teardown(self, ctx: RolloutContext) -> None: - if self._engine is not None: - self._engine.stop() if self._listener is not None: self._listener.stop() @@ -199,10 +168,7 @@ class HighlightStrategy(RolloutStrategy): private=ctx.cfg.dataset.private, ) - if ctx.robot.is_connected: - ctx.robot.disconnect() - if ctx.teleop is not None and ctx.teleop.is_connected: - ctx.teleop.disconnect() + self._teardown_hardware(ctx) logger.info("Highlight strategy teardown complete") def _setup_keyboard(self) -> None: @@ -224,6 +190,8 @@ class HighlightStrategy(RolloutStrategy): self._save_requested.set() elif key == keyboard.Key.esc: self._save_requested.clear() + if self._shutdown_event is not None: + self._shutdown_event.set() self._listener = keyboard.Listener(on_press=on_press) self._listener.start() diff --git a/src/lerobot/rollout/strategies/sentry.py b/src/lerobot/rollout/strategies/sentry.py index 3b36cd0dd..a0d85747a 100644 --- a/src/lerobot/rollout/strategies/sentry.py +++ b/src/lerobot/rollout/strategies/sentry.py @@ -19,17 +19,15 @@ from __future__ import annotations import contextlib import logging import time -from threading import Thread +from threading import Event, Thread from lerobot.datasets import VideoEncodingManager -from lerobot.policies.rtc import ActionInterpolator from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame from lerobot.utils.robot_utils import precise_sleep from ..configs import SentryStrategyConfig from ..context import RolloutContext -from ..inference import InferenceEngine from . import RolloutStrategy, infer_action logger = logging.getLogger(__name__) @@ -54,29 +52,11 @@ class SentryStrategy(RolloutStrategy): def __init__(self, config: SentryStrategyConfig): super().__init__(config) - self._engine: InferenceEngine | None = None - self._interpolator: ActionInterpolator | None = None self._push_thread: Thread | None = None - self._needs_push: bool = False + self._needs_push = Event() def setup(self, ctx: RolloutContext) -> None: - self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) - - self._engine = InferenceEngine( - policy=ctx.policy, - preprocessor=ctx.preprocessor, - postprocessor=ctx.postprocessor, - robot_wrapper=ctx.robot_wrapper, - rtc_config=ctx.cfg.rtc, - hw_features=ctx.hw_features, - action_keys=ctx.action_keys, - task=ctx.cfg.task, - fps=ctx.cfg.fps, - device=ctx.cfg.device, - use_torch_compile=ctx.cfg.use_torch_compile, - compile_warmup_inferences=ctx.cfg.compile_warmup_inferences, - ) - self._engine.start() + self._init_engine(ctx) logger.info( "Sentry strategy ready (episode_duration=%.0fs, upload_every=%d eps)", self.config.episode_duration_s, @@ -100,7 +80,6 @@ class SentryStrategy(RolloutStrategy): start_time = time.perf_counter() episode_start = time.perf_counter() episodes_since_push = 0 - warmup_flushed = False task_str = cfg.dataset.single_task if cfg.dataset else cfg.task with VideoEncodingManager(dataset): @@ -117,19 +96,9 @@ class SentryStrategy(RolloutStrategy): if engine.is_rtc: engine.update_observation(obs_processed) - if cfg.use_torch_compile and not engine.compile_warmup_done.is_set(): - dt = time.perf_counter() - loop_start - if (sleep_t := control_interval - dt) > 0: - precise_sleep(sleep_t) + if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): continue - if cfg.use_torch_compile and not warmup_flushed: - engine.reset() - interpolator.reset() - warmup_flushed = True - if engine.is_rtc: - engine.resume() - action_dict = infer_action( engine, obs_processed, obs, ctx, interpolator, ordered_keys, features ) @@ -146,7 +115,7 @@ class SentryStrategy(RolloutStrategy): if elapsed >= self.config.episode_duration_s: dataset.save_episode() episodes_since_push += 1 - self._needs_push = True + self._needs_push.set() logger.info("Episode saved (total: %d)", dataset.num_episodes) if episodes_since_push >= self.config.upload_every_n_episodes: @@ -166,12 +135,9 @@ class SentryStrategy(RolloutStrategy): finally: with contextlib.suppress(Exception): dataset.save_episode() - self._needs_push = True + self._needs_push.set() def teardown(self, ctx: RolloutContext) -> None: - if self._engine is not None: - self._engine.stop() - # Wait for any in-flight background push if self._push_thread is not None and self._push_thread.is_alive(): self._push_thread.join(timeout=60) @@ -179,16 +145,13 @@ class SentryStrategy(RolloutStrategy): if ctx.dataset is not None: ctx.dataset.finalize() # Only push if there are unsaved changes since last background push - if self._needs_push and ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub: + if self._needs_push.is_set() and ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub: ctx.dataset.push_to_hub( tags=ctx.cfg.dataset.tags, private=ctx.cfg.dataset.private, ) - if ctx.robot.is_connected: - ctx.robot.disconnect() - if ctx.teleop is not None and ctx.teleop.is_connected: - ctx.teleop.disconnect() + self._teardown_hardware(ctx) logger.info("Sentry strategy teardown complete") def _background_push(self, dataset, cfg) -> None: @@ -203,7 +166,7 @@ class SentryStrategy(RolloutStrategy): tags=cfg.dataset.tags if cfg.dataset else None, private=cfg.dataset.private if cfg.dataset else False, ) - self._needs_push = False + self._needs_push.clear() logger.info("Background push to hub complete") except Exception as e: logger.error("Background push failed: %s", e)