From 783ec6e232debf380fbdfe0510b19bc4d2a35577 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 16 Apr 2026 14:34:22 +0200 Subject: [PATCH] minor improvements --- src/lerobot/rollout/configs.py | 24 ++++++++++----------- src/lerobot/rollout/context.py | 10 ++------- src/lerobot/rollout/strategies/highlight.py | 19 ++++++++++++++-- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/src/lerobot/rollout/configs.py b/src/lerobot/rollout/configs.py index d3da043b5..8d9ac776c 100644 --- a/src/lerobot/rollout/configs.py +++ b/src/lerobot/rollout/configs.py @@ -182,18 +182,6 @@ class RolloutConfig: compile_warmup_inferences: int = 2 def __post_init__(self): - # --- Policy loading (same pattern as existing scripts) --- - policy_path = parser.get_path_arg("policy") - if policy_path: - cli_overrides = parser.get_cli_overrides("policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) - self.policy.pretrained_path = policy_path - if self.policy is None: - raise ValueError("--policy.path is required for rollout") - - if self.robot is None: - raise ValueError("--robot.type is required for rollout") - # --- Strategy-specific validation --- if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None: raise ValueError("DAgger strategy requires --teleop.type to be set") @@ -237,6 +225,18 @@ class RolloutConfig: logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True") self.dataset.streaming_encoding = True + # --- Policy loading (same pattern as existing scripts) --- + if self.robot is None: + raise ValueError("--robot.type is required for rollout") + + policy_path = parser.get_path_arg("policy") + if policy_path: + cli_overrides = parser.get_cli_overrides("policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + if self.policy is None: + raise ValueError("--policy.path is required for rollout") + @classmethod def __get_path_fields__(cls) -> list[str]: return ["policy"] diff --git a/src/lerobot/rollout/context.py b/src/lerobot/rollout/context.py index 3d33f4cd5..81cab5713 100644 --- a/src/lerobot/rollout/context.py +++ b/src/lerobot/rollout/context.py @@ -21,7 +21,6 @@ and :class:`DatasetContext` — assembled into :class:`RolloutContext`. from __future__ import annotations -import datetime as _dt import logging from dataclasses import dataclass, field from threading import Event @@ -48,7 +47,7 @@ from lerobot.robots import make_robot_from_config from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features -from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig, SentryStrategyConfig +from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig from .inference import ( InferenceStrategy, RTCInferenceConfig, @@ -269,14 +268,9 @@ def build_rollout_context( raw_action_keys, ) - # --- 5. Dataset (Sentry gets a unique per-run suffix) ------------- + # --- 5. Dataset ------------- dataset = None if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig): - if not cfg.resume and isinstance(cfg.strategy, SentryStrategyConfig) and cfg.dataset.repo_id: - suffix = _dt.datetime.now(_dt.UTC).strftime("%Y%m%dT%H%M%SZ") - cfg.dataset.repo_id = f"{cfg.dataset.repo_id}-{suffix}" - logger.info("Sentry mode: using run-suffixed repo_id=%s", cfg.dataset.repo_id) - if cfg.resume: dataset = LeRobotDataset.resume( cfg.dataset.repo_id, diff --git a/src/lerobot/rollout/strategies/highlight.py b/src/lerobot/rollout/strategies/highlight.py index ed582d0a1..766929f45 100644 --- a/src/lerobot/rollout/strategies/highlight.py +++ b/src/lerobot/rollout/strategies/highlight.py @@ -18,6 +18,8 @@ from __future__ import annotations import contextlib import logging +import os +import sys import time from threading import Event as ThreadingEvent @@ -25,6 +27,7 @@ from lerobot.common.control_utils import is_headless from lerobot.datasets import VideoEncodingManager from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame +from lerobot.utils.import_utils import _pynput_available, require_package from lerobot.utils.robot_utils import precise_sleep from ..configs import HighlightStrategyConfig @@ -32,6 +35,19 @@ from ..context import RolloutContext from ..ring_buffer import RolloutRingBuffer from .core import RolloutStrategy, send_next_action +PYNPUT_AVAILABLE = _pynput_available +keyboard = None +if PYNPUT_AVAILABLE: + try: + if ("DISPLAY" not in os.environ) and ("linux" in sys.platform): + logging.info("No DISPLAY set. Skipping pynput import.") + PYNPUT_AVAILABLE = False + else: + from pynput import keyboard + except Exception as e: + PYNPUT_AVAILABLE = False + logging.info(f"Could not import pynput: {e}") + logger = logging.getLogger(__name__) @@ -54,6 +70,7 @@ class HighlightStrategy(RolloutStrategy): def __init__(self, config: HighlightStrategyConfig): super().__init__(config) + require_package("pynput", extra="pynput-dep") self._ring: RolloutRingBuffer | None = None self._listener = None self._save_requested = ThreadingEvent() @@ -181,8 +198,6 @@ class HighlightStrategy(RolloutStrategy): return try: - from pynput import keyboard - save_key = self.config.save_key def on_press(key):