minor improvements

This commit is contained in:
Steven Palma
2026-04-16 14:34:22 +02:00
parent 4e3175ff15
commit 783ec6e232
3 changed files with 31 additions and 22 deletions

View File

@@ -182,18 +182,6 @@ class RolloutConfig:
compile_warmup_inferences: int = 2
def __post_init__(self):
# --- Policy loading (same pattern as existing scripts) ---
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.policy is None:
raise ValueError("--policy.path is required for rollout")
if self.robot is None:
raise ValueError("--robot.type is required for rollout")
# --- Strategy-specific validation ---
if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None:
raise ValueError("DAgger strategy requires --teleop.type to be set")
@@ -237,6 +225,18 @@ class RolloutConfig:
logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True")
self.dataset.streaming_encoding = True
# --- Policy loading (same pattern as existing scripts) ---
if self.robot is None:
raise ValueError("--robot.type is required for rollout")
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.policy is None:
raise ValueError("--policy.path is required for rollout")
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]

View File

@@ -21,7 +21,6 @@ and :class:`DatasetContext` — assembled into :class:`RolloutContext`.
from __future__ import annotations
import datetime as _dt
import logging
from dataclasses import dataclass, field
from threading import Event
@@ -48,7 +47,7 @@ from lerobot.robots import make_robot_from_config
from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config
from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig, SentryStrategyConfig
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
from .inference import (
InferenceStrategy,
RTCInferenceConfig,
@@ -269,14 +268,9 @@ def build_rollout_context(
raw_action_keys,
)
# --- 5. Dataset (Sentry gets a unique per-run suffix) -------------
# --- 5. Dataset -------------
dataset = None
if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig):
if not cfg.resume and isinstance(cfg.strategy, SentryStrategyConfig) and cfg.dataset.repo_id:
suffix = _dt.datetime.now(_dt.UTC).strftime("%Y%m%dT%H%M%SZ")
cfg.dataset.repo_id = f"{cfg.dataset.repo_id}-{suffix}"
logger.info("Sentry mode: using run-suffixed repo_id=%s", cfg.dataset.repo_id)
if cfg.resume:
dataset = LeRobotDataset.resume(
cfg.dataset.repo_id,

View File

@@ -18,6 +18,8 @@ from __future__ import annotations
import contextlib
import logging
import os
import sys
import time
from threading import Event as ThreadingEvent
@@ -25,6 +27,7 @@ from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available, require_package
from lerobot.utils.robot_utils import precise_sleep
from ..configs import HighlightStrategyConfig
@@ -32,6 +35,19 @@ from ..context import RolloutContext
from ..ring_buffer import RolloutRingBuffer
from .core import RolloutStrategy, send_next_action
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
if PYNPUT_AVAILABLE:
try:
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
logging.info("No DISPLAY set. Skipping pynput import.")
PYNPUT_AVAILABLE = False
else:
from pynput import keyboard
except Exception as e:
PYNPUT_AVAILABLE = False
logging.info(f"Could not import pynput: {e}")
logger = logging.getLogger(__name__)
@@ -54,6 +70,7 @@ class HighlightStrategy(RolloutStrategy):
def __init__(self, config: HighlightStrategyConfig):
super().__init__(config)
require_package("pynput", extra="pynput-dep")
self._ring: RolloutRingBuffer | None = None
self._listener = None
self._save_requested = ThreadingEvent()
@@ -181,8 +198,6 @@ class HighlightStrategy(RolloutStrategy):
return
try:
from pynput import keyboard
save_key = self.config.save_key
def on_press(key):