mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
minor improvements
This commit is contained in:
@@ -182,18 +182,6 @@ class RolloutConfig:
|
||||
compile_warmup_inferences: int = 2
|
||||
|
||||
def __post_init__(self):
|
||||
# --- Policy loading (same pattern as existing scripts) ---
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("--policy.path is required for rollout")
|
||||
|
||||
if self.robot is None:
|
||||
raise ValueError("--robot.type is required for rollout")
|
||||
|
||||
# --- Strategy-specific validation ---
|
||||
if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None:
|
||||
raise ValueError("DAgger strategy requires --teleop.type to be set")
|
||||
@@ -237,6 +225,18 @@ class RolloutConfig:
|
||||
logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True")
|
||||
self.dataset.streaming_encoding = True
|
||||
|
||||
# --- Policy loading (same pattern as existing scripts) ---
|
||||
if self.robot is None:
|
||||
raise ValueError("--robot.type is required for rollout")
|
||||
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("--policy.path is required for rollout")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
@@ -21,7 +21,6 @@ and :class:`DatasetContext` — assembled into :class:`RolloutContext`.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event
|
||||
@@ -48,7 +47,7 @@ from lerobot.robots import make_robot_from_config
|
||||
from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config
|
||||
from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features
|
||||
|
||||
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig, SentryStrategyConfig
|
||||
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
|
||||
from .inference import (
|
||||
InferenceStrategy,
|
||||
RTCInferenceConfig,
|
||||
@@ -269,14 +268,9 @@ def build_rollout_context(
|
||||
raw_action_keys,
|
||||
)
|
||||
|
||||
# --- 5. Dataset (Sentry gets a unique per-run suffix) -------------
|
||||
# --- 5. Dataset -------------
|
||||
dataset = None
|
||||
if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig):
|
||||
if not cfg.resume and isinstance(cfg.strategy, SentryStrategyConfig) and cfg.dataset.repo_id:
|
||||
suffix = _dt.datetime.now(_dt.UTC).strftime("%Y%m%dT%H%M%SZ")
|
||||
cfg.dataset.repo_id = f"{cfg.dataset.repo_id}-{suffix}"
|
||||
logger.info("Sentry mode: using run-suffixed repo_id=%s", cfg.dataset.repo_id)
|
||||
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset.resume(
|
||||
cfg.dataset.repo_id,
|
||||
|
||||
@@ -18,6 +18,8 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from threading import Event as ThreadingEvent
|
||||
|
||||
@@ -25,6 +27,7 @@ from lerobot.common.control_utils import is_headless
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available, require_package
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from ..configs import HighlightStrategyConfig
|
||||
@@ -32,6 +35,19 @@ from ..context import RolloutContext
|
||||
from ..ring_buffer import RolloutRingBuffer
|
||||
from .core import RolloutStrategy, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
keyboard = None
|
||||
if PYNPUT_AVAILABLE:
|
||||
try:
|
||||
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
|
||||
logging.info("No DISPLAY set. Skipping pynput import.")
|
||||
PYNPUT_AVAILABLE = False
|
||||
else:
|
||||
from pynput import keyboard
|
||||
except Exception as e:
|
||||
PYNPUT_AVAILABLE = False
|
||||
logging.info(f"Could not import pynput: {e}")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -54,6 +70,7 @@ class HighlightStrategy(RolloutStrategy):
|
||||
|
||||
def __init__(self, config: HighlightStrategyConfig):
|
||||
super().__init__(config)
|
||||
require_package("pynput", extra="pynput-dep")
|
||||
self._ring: RolloutRingBuffer | None = None
|
||||
self._listener = None
|
||||
self._save_requested = ThreadingEvent()
|
||||
@@ -181,8 +198,6 @@ class HighlightStrategy(RolloutStrategy):
|
||||
return
|
||||
|
||||
try:
|
||||
from pynput import keyboard
|
||||
|
||||
save_key = self.config.save_key
|
||||
|
||||
def on_press(key):
|
||||
|
||||
Reference in New Issue
Block a user