mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 04:41:24 +00:00
more improvements and fixes
This commit is contained in:
@@ -92,10 +92,10 @@ class ActionQueue:
|
||||
Returns:
|
||||
int: Number of unconsumed actions.
|
||||
"""
|
||||
if self.queue is None:
|
||||
return 0
|
||||
length = len(self.queue)
|
||||
return length - self.last_index
|
||||
with self.lock:
|
||||
if self.queue is None:
|
||||
return 0
|
||||
return len(self.queue) - self.last_index
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the queue is empty.
|
||||
@@ -103,11 +103,10 @@ class ActionQueue:
|
||||
Returns:
|
||||
bool: True if no actions remain, False otherwise.
|
||||
"""
|
||||
if self.queue is None:
|
||||
return True
|
||||
|
||||
length = len(self.queue)
|
||||
return length - self.last_index <= 0
|
||||
with self.lock:
|
||||
if self.queue is None:
|
||||
return True
|
||||
return len(self.queue) - self.last_index <= 0
|
||||
|
||||
def get_action_index(self) -> int:
|
||||
"""Get the current action consumption index.
|
||||
@@ -115,7 +114,8 @@ class ActionQueue:
|
||||
Returns:
|
||||
int: Index of the next action to be consumed.
|
||||
"""
|
||||
return self.last_index
|
||||
with self.lock:
|
||||
return self.last_index
|
||||
|
||||
def get_left_over(self) -> Tensor | None:
|
||||
"""Get leftover original actions for RTC prev_chunk_left_over.
|
||||
|
||||
@@ -203,10 +203,13 @@ class RolloutConfig:
|
||||
)
|
||||
|
||||
# Sentry MUST use streaming encoding to avoid disk I/O blocking the control loop
|
||||
if isinstance(self.strategy, SentryStrategyConfig) and self.dataset is not None:
|
||||
if not self.dataset.streaming_encoding:
|
||||
logger.warning("Sentry mode forces streaming_encoding=True")
|
||||
self.dataset.streaming_encoding = True
|
||||
if (
|
||||
isinstance(self.strategy, SentryStrategyConfig)
|
||||
and self.dataset is not None
|
||||
and not self.dataset.streaming_encoding
|
||||
):
|
||||
logger.warning("Sentry mode forces streaming_encoding=True")
|
||||
self.dataset.streaming_encoding = True
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
|
||||
@@ -147,6 +147,7 @@ class InferenceEngine:
|
||||
use_torch_compile: bool = False,
|
||||
compile_warmup_inferences: int = 2,
|
||||
rtc_queue_threshold: int = 30,
|
||||
shutdown_event: Event | None = None,
|
||||
) -> None:
|
||||
self._policy = policy
|
||||
self._preprocessor = preprocessor
|
||||
@@ -170,6 +171,8 @@ class InferenceEngine:
|
||||
self._policy_active = Event()
|
||||
self._compile_warmup_done = Event()
|
||||
self._shutdown_event = Event()
|
||||
self._rtc_error = Event()
|
||||
self._global_shutdown_event = shutdown_event
|
||||
self._rtc_thread: Thread | None = None
|
||||
|
||||
if not self._use_torch_compile:
|
||||
@@ -211,6 +214,11 @@ class InferenceEngine:
|
||||
def compile_warmup_done(self) -> Event:
|
||||
return self._compile_warmup_done
|
||||
|
||||
@property
|
||||
def rtc_failed(self) -> bool:
|
||||
"""True if the RTC background thread exited due to an unrecoverable error."""
|
||||
return self._rtc_error.is_set()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the inference engine. Launches the RTC background thread if enabled."""
|
||||
if self._use_rtc:
|
||||
@@ -249,8 +257,8 @@ class InferenceEngine:
|
||||
self._policy.reset()
|
||||
self._preprocessor.reset()
|
||||
self._postprocessor.reset()
|
||||
if self._use_rtc:
|
||||
self._action_queue = ActionQueue(self._rtc_config)
|
||||
if self._use_rtc and self._action_queue is not None:
|
||||
self._action_queue.clear()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sync inference
|
||||
@@ -401,3 +409,9 @@ class InferenceEngine:
|
||||
except Exception as e:
|
||||
logger.error("Fatal error in RTC thread: %s", e)
|
||||
logger.error(traceback.format_exc())
|
||||
self._rtc_error.set()
|
||||
# Unblock any warmup waiters so the main loop doesn't spin forever
|
||||
self._compile_warmup_done.set()
|
||||
# Signal the top-level shutdown so strategies exit their control loops
|
||||
if self._global_shutdown_event is not None:
|
||||
self._global_shutdown_event.set()
|
||||
|
||||
@@ -95,6 +95,6 @@ def _estimate_frame_bytes(frame: dict) -> int:
|
||||
total += v.nbytes
|
||||
elif isinstance(v, (int, float)):
|
||||
total += 8
|
||||
elif isinstance(v, str) or isinstance(v, bytes):
|
||||
elif isinstance(v, (str, bytes)):
|
||||
total += len(v)
|
||||
return max(total, 1) # avoid zero-size frames
|
||||
|
||||
@@ -12,209 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Rollout strategy ABC, factory, and shared inference helper."""
|
||||
"""Rollout strategies — public API re-exports."""
|
||||
|
||||
from __future__ import annotations
|
||||
from .core import RolloutStrategy, infer_action
|
||||
from .factory import create_strategy
|
||||
|
||||
import abc
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rollout.configs import RolloutStrategyConfig
|
||||
from lerobot.rollout.context import RolloutContext
|
||||
from lerobot.rollout.inference import InferenceEngine
|
||||
|
||||
|
||||
class RolloutStrategy(abc.ABC):
    """Abstract base for rollout execution strategies.

    Each concrete strategy implements a self-contained control loop with
    its own recording/interaction semantics.  Strategies are mutually
    exclusive — only one runs per session.

    Subclasses implement ``setup``, ``run`` and ``teardown``; the
    underscore-prefixed helpers below hold the logic they share.
    """

    def __init__(self, config: RolloutStrategyConfig) -> None:
        self.config = config
        # Both stay ``None`` until ``_init_engine()`` has been called.
        self._engine: InferenceEngine | None = None
        self._interpolator: ActionInterpolator | None = None
        # Set once the post-warmup reset in ``_handle_warmup()`` has run.
        self._warmup_flushed: bool = False

    def _init_engine(self, ctx: RolloutContext) -> None:
        """Create and start the inference engine and action interpolator.

        Call this from ``setup()`` to avoid duplicating the engine
        construction across every strategy.
        """
        from lerobot.rollout.inference import InferenceEngine

        self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
        self._engine = InferenceEngine(
            policy=ctx.policy,
            preprocessor=ctx.preprocessor,
            postprocessor=ctx.postprocessor,
            robot_wrapper=ctx.robot_wrapper,
            rtc_config=ctx.cfg.rtc,
            hw_features=ctx.hw_features,
            action_keys=ctx.action_keys,
            task=ctx.cfg.task,
            fps=ctx.cfg.fps,
            device=ctx.cfg.device,
            use_torch_compile=ctx.cfg.use_torch_compile,
            compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
        )
        self._engine.start()
        self._warmup_flushed = False

    def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool:
        """Handle torch.compile warmup phase.

        Args:
            use_torch_compile: When ``False`` this method is a no-op.
            loop_start: ``time.perf_counter()`` value taken at the start of
                the current control-loop iteration.
            control_interval: Target duration of one iteration, in seconds.

        Returns:
            ``True`` if the caller should ``continue`` (still warming up).
            On the first post-warmup iteration the engine and interpolator
            are reset so stale warmup state is discarded.
        """
        engine = self._engine
        interpolator = self._interpolator
        if not use_torch_compile:
            return False
        if not engine.compile_warmup_done.is_set():
            # Sleep out the rest of the iteration so the loop keeps its pace.
            dt = time.perf_counter() - loop_start
            if (sleep_t := control_interval - dt) > 0:
                precise_sleep(sleep_t)
            return True
        if not self._warmup_flushed:
            engine.reset()
            interpolator.reset()
            self._warmup_flushed = True
            if engine.is_rtc:
                engine.resume()
        return False

    def _teardown_hardware(self, ctx: RolloutContext) -> None:
        """Stop the inference engine and disconnect hardware.

        Each step runs even if an earlier one raised, so a failure while
        stopping the engine cannot leave the robot or teleoperator
        connected.  The exception still propagates to the caller.
        """
        try:
            if self._engine is not None:
                self._engine.stop()
        finally:
            try:
                if ctx.robot.is_connected:
                    ctx.robot.disconnect()
            finally:
                if ctx.teleop is not None and ctx.teleop.is_connected:
                    ctx.teleop.disconnect()

    @abc.abstractmethod
    def setup(self, ctx: RolloutContext) -> None:
        """Strategy-specific initialisation (keyboard listeners, buffers, etc.)."""

    @abc.abstractmethod
    def run(self, ctx: RolloutContext) -> None:
        """Main rollout loop.  Returns when shutdown is requested or duration expires."""

    @abc.abstractmethod
    def teardown(self, ctx: RolloutContext) -> None:
        """Cleanup: save dataset, stop threads, disconnect hardware."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared inference helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def infer_action(
    engine: InferenceEngine,
    obs_processed: dict,
    obs_raw: dict,
    ctx: RolloutContext,
    interpolator: ActionInterpolator,
    ordered_keys: list[str],
    features: dict,
) -> dict | None:
    """Perform one inference step and forward the resulting action to the robot.

    Works with both the sync and the RTC backend.  New policy output is
    only requested when the interpolator asks for it; the action that is
    actually sent always comes out of the interpolator.

    Parameters
    ----------
    engine:
        The inference engine (sync or RTC).
    obs_processed:
        Observation dict after ``robot_observation_processor``.
    obs_raw:
        Raw observation dict (needed by ``robot_action_processor``).
    ctx:
        Rollout context.
    interpolator:
        Action interpolator for Nx control rate.
    ordered_keys:
        Ordered action feature names (policy-to-robot mapping).
    features:
        Feature specification dict for ``build_dataset_frame`` /
        ``make_robot_action``.  Use ``dataset.features`` when recording,
        ``ctx.dataset_features`` otherwise.

    Returns
    -------
    Action dict sent to the robot, or ``None`` if no action was
    available (empty RTC queue, interpolator buffer not ready).
    """
    rtc_backend = engine.is_rtc
    if interpolator.needs_new_action():
        if rtc_backend:
            queued = engine.consume_rtc_action()
            if queued is not None:
                interpolator.add(queued.cpu())
        else:
            frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
            policy_out = engine.get_action_sync(frame)
            named = make_robot_action(policy_out, features)
            interpolator.add(torch.tensor([named[key] for key in ordered_keys]))

    step = interpolator.get()
    if step is None:
        return None

    # Keys beyond the length of the interpolated vector are left out.
    action_dict = {}
    for idx, key in enumerate(ordered_keys):
        if idx >= len(step):
            break
        action_dict[key] = step[idx].item()

    ctx.robot_wrapper.send_action(ctx.robot_action_processor((action_dict, obs_raw)))
    return action_dict
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Strategy factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
    """Build the strategy instance that matches the type of *config*.

    The config classes are checked in a fixed order and the module of the
    matching strategy is imported only inside its branch.

    Raises:
        ValueError: If *config* is not an instance of any known strategy
            config class.
    """
    from lerobot.rollout.configs import (
        BaseStrategyConfig,
        DAggerStrategyConfig,
        HighlightStrategyConfig,
        SentryStrategyConfig,
    )

    if isinstance(config, BaseStrategyConfig):
        from .base import BaseStrategy as strategy_cls
    elif isinstance(config, SentryStrategyConfig):
        from .sentry import SentryStrategy as strategy_cls
    elif isinstance(config, HighlightStrategyConfig):
        from .highlight import HighlightStrategy as strategy_cls
    elif isinstance(config, DAggerStrategyConfig):
        from .dagger import DAggerStrategy as strategy_cls
    else:
        raise ValueError(f"Unknown strategy config type: {type(config).__name__}")

    return strategy_cls(config)
|
||||
__all__ = [
|
||||
"RolloutStrategy",
|
||||
"create_strategy",
|
||||
"infer_action",
|
||||
]
|
||||
|
||||
187
src/lerobot/rollout/strategies/core.py
Normal file
187
src/lerobot/rollout/strategies/core.py
Normal file
@@ -0,0 +1,187 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Rollout strategy ABC and shared inference helper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rollout.configs import RolloutStrategyConfig
|
||||
from lerobot.rollout.context import RolloutContext
|
||||
from lerobot.rollout.inference import InferenceEngine
|
||||
|
||||
|
||||
class RolloutStrategy(abc.ABC):
    """Abstract base for rollout execution strategies.

    Each concrete strategy implements a self-contained control loop with
    its own recording/interaction semantics.  Strategies are mutually
    exclusive — only one runs per session.

    Subclasses implement ``setup``, ``run`` and ``teardown``; the
    underscore-prefixed helpers below hold the logic they share.
    """

    def __init__(self, config: RolloutStrategyConfig) -> None:
        self.config = config
        # Both stay ``None`` until ``_init_engine()`` has been called.
        self._engine: InferenceEngine | None = None
        self._interpolator: ActionInterpolator | None = None
        # Set once the post-warmup reset in ``_handle_warmup()`` has run.
        self._warmup_flushed: bool = False

    def _init_engine(self, ctx: RolloutContext) -> None:
        """Create and start the inference engine and action interpolator.

        Call this from ``setup()`` to avoid duplicating the engine
        construction across every strategy.  The context's
        ``shutdown_event`` is handed to the engine.
        """
        from lerobot.rollout.inference import InferenceEngine

        self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
        self._engine = InferenceEngine(
            policy=ctx.policy,
            preprocessor=ctx.preprocessor,
            postprocessor=ctx.postprocessor,
            robot_wrapper=ctx.robot_wrapper,
            rtc_config=ctx.cfg.rtc,
            hw_features=ctx.hw_features,
            action_keys=ctx.action_keys,
            task=ctx.cfg.task,
            fps=ctx.cfg.fps,
            device=ctx.cfg.device,
            use_torch_compile=ctx.cfg.use_torch_compile,
            compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
            shutdown_event=ctx.shutdown_event,
        )
        self._engine.start()
        self._warmup_flushed = False

    def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool:
        """Handle torch.compile warmup phase.

        Args:
            use_torch_compile: When ``False`` this method is a no-op.
            loop_start: ``time.perf_counter()`` value taken at the start of
                the current control-loop iteration.
            control_interval: Target duration of one iteration, in seconds.

        Returns:
            ``True`` if the caller should ``continue`` (still warming up).
            On the first post-warmup iteration the engine and interpolator
            are reset so stale warmup state is discarded.
        """
        engine = self._engine
        interpolator = self._interpolator
        if not use_torch_compile:
            return False
        if not engine.compile_warmup_done.is_set():
            # Sleep out the rest of the iteration so the loop keeps its pace.
            dt = time.perf_counter() - loop_start
            if (sleep_t := control_interval - dt) > 0:
                precise_sleep(sleep_t)
            return True
        if not self._warmup_flushed:
            engine.reset()
            interpolator.reset()
            self._warmup_flushed = True
            if engine.is_rtc:
                engine.resume()
        return False

    def _teardown_hardware(self, ctx: RolloutContext) -> None:
        """Stop the inference engine and disconnect hardware.

        Each step runs even if an earlier one raised, so a failure while
        stopping the engine cannot leave the robot or teleoperator
        connected.  The exception still propagates to the caller.
        """
        try:
            if self._engine is not None:
                self._engine.stop()
        finally:
            try:
                if ctx.robot.is_connected:
                    ctx.robot.disconnect()
            finally:
                if ctx.teleop is not None and ctx.teleop.is_connected:
                    ctx.teleop.disconnect()

    @abc.abstractmethod
    def setup(self, ctx: RolloutContext) -> None:
        """Strategy-specific initialisation (keyboard listeners, buffers, etc.)."""

    @abc.abstractmethod
    def run(self, ctx: RolloutContext) -> None:
        """Main rollout loop.  Returns when shutdown is requested or duration expires."""

    @abc.abstractmethod
    def teardown(self, ctx: RolloutContext) -> None:
        """Cleanup: save dataset, stop threads, disconnect hardware."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared inference helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def infer_action(
    engine: InferenceEngine,
    obs_processed: dict,
    obs_raw: dict,
    ctx: RolloutContext,
    interpolator: ActionInterpolator,
    ordered_keys: list[str],
    features: dict,
) -> dict | None:
    """Perform one inference step and forward the resulting action to the robot.

    Works with both the sync and the RTC backend.  New policy output is
    only requested when the interpolator asks for it; the action that is
    actually sent always comes out of the interpolator.

    Parameters
    ----------
    engine:
        The inference engine (sync or RTC).
    obs_processed:
        Observation dict after ``robot_observation_processor``.
    obs_raw:
        Raw observation dict (needed by ``robot_action_processor``).
    ctx:
        Rollout context.
    interpolator:
        Action interpolator for Nx control rate.
    ordered_keys:
        Ordered action feature names (policy-to-robot mapping).
    features:
        Feature specification dict for ``build_dataset_frame`` /
        ``make_robot_action``.  Use ``dataset.features`` when recording,
        ``ctx.dataset_features`` otherwise.

    Returns
    -------
    Action dict sent to the robot, or ``None`` if no action was
    available (empty RTC queue, interpolator buffer not ready).
    """
    rtc_backend = engine.is_rtc
    if interpolator.needs_new_action():
        if rtc_backend:
            queued = engine.consume_rtc_action()
            if queued is not None:
                interpolator.add(queued.cpu())
        else:
            frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
            policy_out = engine.get_action_sync(frame)
            named = make_robot_action(policy_out, features)
            interpolator.add(torch.tensor([named[key] for key in ordered_keys]))

    step = interpolator.get()
    if step is None:
        return None

    # Keys beyond the length of the interpolated vector are left out.
    action_dict = {}
    for idx, key in enumerate(ordered_keys):
        if idx >= len(step):
            break
        action_dict[key] = step[idx].item()

    ctx.robot_wrapper.send_action(ctx.robot_action_processor((action_dict, obs_raw)))
    return action_dict
|
||||
@@ -30,8 +30,10 @@ Keyboard Controls:
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import enum
|
||||
import logging
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@@ -53,6 +55,99 @@ from . import RolloutStrategy, infer_action
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DAgger state machine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DAggerPhase(enum.Enum):
    """Observable phases of a DAgger episode."""

    AUTONOMOUS = "autonomous"  # Policy driving, recording autonomous frames
    PAUSED = "paused"  # Engine paused, teleop aligned, awaiting takeover/resume
    CORRECTING = "correcting"  # Human driving via teleop, recording interventions


# Valid (current_phase, event) → next_phase.  Any pair that is not listed
# here is rejected by ``DAggerEvents``.
_DAGGER_TRANSITIONS: dict[tuple[DAggerPhase, str], DAggerPhase] = {
    (DAggerPhase.AUTONOMOUS, "pause"): DAggerPhase.PAUSED,
    (DAggerPhase.PAUSED, "takeover"): DAggerPhase.CORRECTING,
    (DAggerPhase.PAUSED, "resume"): DAggerPhase.AUTONOMOUS,
    (DAggerPhase.CORRECTING, "resume"): DAggerPhase.AUTONOMOUS,
}


class DAggerEvents:
    """Container for DAgger keyboard/pedal events.

    Replaces the previous plain dict with a lock-protected phase enum
    and edge-triggered transition requests.  The keyboard/pedal threads
    write transition requests; the main loop consumes them.

    Only the phase and the pending transition are guarded by the lock.
    The boolean flags are plain attributes that are read and written
    directly.
    """

    def __init__(self) -> None:
        self._lock = Lock()
        self._phase = DAggerPhase.AUTONOMOUS
        # At most one request is held; a newer valid request overwrites it.
        self._pending_transition: str | None = None

        # Episode-level flags (written by keyboard, consumed by main loop)
        self.exit_early: bool = False
        self.rerecord_episode: bool = False
        self.stop_recording: bool = False

        # Reset-phase flags (simpler lifecycle, shared between threads)
        self.in_reset: bool = False
        self.start_next_episode: bool = False

    # -- Thread-safe phase access ------------------------------------------

    @property
    def phase(self) -> DAggerPhase:
        """Current phase, read under the lock."""
        with self._lock:
            return self._phase

    @phase.setter
    def phase(self, value: DAggerPhase) -> None:
        with self._lock:
            self._phase = value

    def request_transition(self, event: str) -> bool:
        """Request a phase transition (called from keyboard/pedal threads).

        Only enqueues the request if it corresponds to a valid transition
        from the current phase, preventing impossible state changes.

        Args:
            event: Event name, e.g. ``"pause"``, ``"takeover"`` or
                ``"resume"``.

        Returns:
            ``True`` if the request was stored, ``False`` if it was
            dropped because it is not valid from the current phase.
        """
        with self._lock:
            if (self._phase, event) not in _DAGGER_TRANSITIONS:
                return False
            self._pending_transition = event
            return True

    def consume_transition(self) -> tuple[DAggerPhase, DAggerPhase] | None:
        """Consume a pending transition (called from main loop).

        The request is checked against the phase again at consume time,
        so a request that is no longer valid is discarded.

        Returns:
            ``(old_phase, new_phase)`` if a valid transition was pending,
            or ``None`` if there is nothing to process.
        """
        with self._lock:
            if self._pending_transition is None:
                return None
            key = (self._phase, self._pending_transition)
            # Clear the request whether or not it turns out to be valid.
            self._pending_transition = None
            new_phase = _DAGGER_TRANSITIONS.get(key)
            if new_phase is None:
                return None
            old_phase = self._phase
            self._phase = new_phase
            return old_phase, new_phase

    def reset_for_episode(self) -> None:
        """Reset all transient state at the start of an episode.

        Restores the ``AUTONOMOUS`` phase, drops any pending transition
        and clears ``exit_early`` and ``rerecord_episode``.  The other
        flags are left untouched.
        """
        with self._lock:
            self._phase = DAggerPhase.AUTONOMOUS
            self._pending_transition = None
        self.exit_early = False
        self.rerecord_episode = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Teleoperator helpers (extracted from examples/hil/hil_utils.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -99,7 +194,7 @@ def _teleop_smooth_move_to(
|
||||
def _reset_loop(
|
||||
robot: ThreadSafeRobot,
|
||||
teleop: Teleoperator,
|
||||
events: dict,
|
||||
events: DAggerEvents,
|
||||
fps: int,
|
||||
teleop_action_processor: RobotProcessorPipeline,
|
||||
robot_action_processor: RobotProcessorPipeline,
|
||||
@@ -111,24 +206,24 @@ def _reset_loop(
|
||||
"""
|
||||
logger.info("RESET — press any key to enable teleoperation")
|
||||
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
events.in_reset = True
|
||||
events.start_next_episode = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
while not events.start_next_episode and not events.stop_recording:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
if events.stop_recording:
|
||||
return
|
||||
|
||||
events["start_next_episode"] = False
|
||||
events.start_next_episode = False
|
||||
_teleop_disable_torque(teleop)
|
||||
logger.info("Teleop enabled — press any key to start episode")
|
||||
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
while not events.start_next_episode and not events.stop_recording:
|
||||
loop_start = time.perf_counter()
|
||||
obs = robot.get_observation()
|
||||
action = teleop.get_action()
|
||||
@@ -137,78 +232,78 @@ def _reset_loop(
|
||||
robot.send_action(robot_action_to_send)
|
||||
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
||||
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
events["resume_policy"] = False
|
||||
events.in_reset = False
|
||||
events.start_next_episode = False
|
||||
events.reset_for_episode()
|
||||
|
||||
|
||||
def _init_dagger_keyboard():
|
||||
"""Initialise keyboard listener with DAgger/HIL controls."""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False,
|
||||
"correction_active": False,
|
||||
"resume_policy": False,
|
||||
"in_reset": False,
|
||||
"start_next_episode": False,
|
||||
}
|
||||
def _init_dagger_keyboard(events: DAggerEvents):
|
||||
"""Initialise keyboard listener with DAgger/HIL controls.
|
||||
|
||||
Returns the pynput Listener (or ``None`` in headless mode).
|
||||
"""
|
||||
if is_headless():
|
||||
logger.warning("Headless environment — keyboard controls unavailable")
|
||||
return None, events
|
||||
return None
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
# During the reset phase, only accept episode-start or stop
|
||||
if events.in_reset:
|
||||
if (
|
||||
key in [keyboard.Key.space, keyboard.Key.right]
|
||||
or hasattr(key, "char")
|
||||
and key.char == "c"
|
||||
):
|
||||
events["start_next_episode"] = True
|
||||
events.start_next_episode = True
|
||||
elif key == keyboard.Key.esc:
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
logger.info("PAUSED — press 'c' to take control or 'p' to resume policy")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
logger.info("Taking control...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, "char") and key.char == "p":
|
||||
if events["policy_paused"] or events["correction_active"]:
|
||||
logger.info("Resuming policy...")
|
||||
events["resume_policy"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
logger.info("End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
logger.info("Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
logger.info("Stop recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
events.stop_recording = True
|
||||
events.start_next_episode = True
|
||||
return
|
||||
|
||||
# Phase-aware transition requests
|
||||
phase = events.phase
|
||||
if key == keyboard.Key.space and phase == DAggerPhase.AUTONOMOUS:
|
||||
logger.info("PAUSED — press 'c' to take control or 'p' to resume policy")
|
||||
events.request_transition("pause")
|
||||
elif hasattr(key, "char") and key.char == "c" and phase == DAggerPhase.PAUSED:
|
||||
logger.info("Taking control...")
|
||||
events.request_transition("takeover")
|
||||
elif (
|
||||
hasattr(key, "char")
|
||||
and key.char == "p"
|
||||
and phase
|
||||
in (
|
||||
DAggerPhase.PAUSED,
|
||||
DAggerPhase.CORRECTING,
|
||||
)
|
||||
):
|
||||
logger.info("Resuming policy...")
|
||||
events.request_transition("resume")
|
||||
|
||||
# Episode-level controls (valid in any phase)
|
||||
elif key == keyboard.Key.right:
|
||||
logger.info("End episode")
|
||||
events.exit_early = True
|
||||
elif key == keyboard.Key.left:
|
||||
logger.info("Re-record episode")
|
||||
events.rerecord_episode = True
|
||||
events.exit_early = True
|
||||
elif key == keyboard.Key.esc:
|
||||
logger.info("Stop recording...")
|
||||
events.stop_recording = True
|
||||
events.exit_early = True
|
||||
except Exception as e:
|
||||
logger.debug("Key error: %s", e)
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
return listener, events
|
||||
return listener
|
||||
|
||||
|
||||
def _start_pedal_listener(events: dict) -> None:
|
||||
def _start_pedal_listener(events: DAggerEvents) -> None:
|
||||
"""Start foot pedal listener thread if evdev is available."""
|
||||
import threading
|
||||
|
||||
@@ -232,18 +327,19 @@ def _start_pedal_listener(events: dict) -> None:
|
||||
code = code[0]
|
||||
if key.keystate != 1:
|
||||
continue
|
||||
if events["in_reset"]:
|
||||
if events.in_reset:
|
||||
if code in ["KEY_A", "KEY_C"]:
|
||||
events["start_next_episode"] = True
|
||||
events.start_next_episode = True
|
||||
else:
|
||||
if code not in ["KEY_A", "KEY_C"]:
|
||||
continue
|
||||
if events["correction_active"]:
|
||||
events["resume_policy"] = True
|
||||
elif events["policy_paused"]:
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
events["policy_paused"] = True
|
||||
phase = events.phase
|
||||
if phase == DAggerPhase.CORRECTING:
|
||||
events.request_transition("resume")
|
||||
elif phase == DAggerPhase.PAUSED:
|
||||
events.request_transition("takeover")
|
||||
elif phase == DAggerPhase.AUTONOMOUS:
|
||||
events.request_transition("pause")
|
||||
except (FileNotFoundError, PermissionError):
|
||||
pass
|
||||
except Exception as e:
|
||||
@@ -260,9 +356,11 @@ def _start_pedal_listener(events: dict) -> None:
|
||||
class DAggerStrategy(RolloutStrategy):
|
||||
"""Human-in-the-Loop data collection with intervention tagging.
|
||||
|
||||
State machine:
|
||||
AUTONOMOUS -> (SPACE) -> PAUSED -> ('c') -> TAKEOVER -> ('p') -> AUTONOMOUS
|
||||
-> (->) -> save episode
|
||||
Uses a formal state machine (see :class:`DAggerPhase`) for phase
|
||||
transitions, eliminating impossible states::
|
||||
|
||||
AUTONOMOUS --(SPACE)--> PAUSED --(c)--> CORRECTING --(p)--> AUTONOMOUS
|
||||
--(p)--> AUTONOMOUS
|
||||
|
||||
Supports both synchronous and RTC inference backends.
|
||||
All actions (policy and teleop) flow through the appropriate
|
||||
@@ -278,12 +376,12 @@ class DAggerStrategy(RolloutStrategy):
|
||||
def __init__(self, config: DAggerStrategyConfig):
|
||||
super().__init__(config)
|
||||
self._listener = None
|
||||
self._events: dict[str, Any] = {}
|
||||
self._events = DAggerEvents()
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
self._init_engine(ctx)
|
||||
|
||||
self._listener, self._events = _init_dagger_keyboard()
|
||||
self._listener = _init_dagger_keyboard(self._events)
|
||||
_start_pedal_listener(self._events)
|
||||
|
||||
logger.info(
|
||||
@@ -302,22 +400,22 @@ class DAggerStrategy(RolloutStrategy):
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
recorded = 0
|
||||
while recorded < self.config.num_episodes and not events["stop_recording"]:
|
||||
while recorded < self.config.num_episodes and not events.stop_recording:
|
||||
log_say(f"Episode {dataset.num_episodes}", self.config.play_sounds)
|
||||
|
||||
self._run_episode(ctx)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
if events.rerecord_episode:
|
||||
log_say("Re-recording", self.config.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events.rerecord_episode = False
|
||||
events.exit_early = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
if recorded < self.config.num_episodes and not events["stop_recording"]:
|
||||
if recorded < self.config.num_episodes and not events.stop_recording:
|
||||
_reset_loop(
|
||||
ctx.robot_wrapper,
|
||||
teleop,
|
||||
@@ -371,10 +469,9 @@ class DAggerStrategy(RolloutStrategy):
|
||||
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
events.reset_for_episode()
|
||||
_teleop_disable_torque(teleop)
|
||||
|
||||
was_paused = False
|
||||
waiting_for_takeover = False
|
||||
last_action: dict[str, Any] | None = None
|
||||
frame_buffer: list[dict] = []
|
||||
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
|
||||
@@ -389,59 +486,26 @@ class DAggerStrategy(RolloutStrategy):
|
||||
while timestamp < self.config.episode_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
events["resume_policy"] = False
|
||||
if events.exit_early:
|
||||
events.exit_early = False
|
||||
break
|
||||
|
||||
# --- Resume from pause/correction ---
|
||||
if events["resume_policy"] and (
|
||||
events["policy_paused"] or events["correction_active"] or waiting_for_takeover
|
||||
):
|
||||
events["resume_policy"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
waiting_for_takeover = False
|
||||
was_paused = False
|
||||
# --- Process pending phase transition ---
|
||||
transition = events.consume_transition()
|
||||
if transition is not None:
|
||||
old_phase, new_phase = transition
|
||||
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
|
||||
last_action = None
|
||||
interpolator.reset()
|
||||
engine.reset()
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
|
||||
# --- Pause: align teleop to robot position ---
|
||||
if events["policy_paused"] and not was_paused:
|
||||
if engine.is_rtc:
|
||||
engine.pause()
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {
|
||||
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||
}
|
||||
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
was_paused = True
|
||||
interpolator.reset()
|
||||
|
||||
# --- Takeover: enable teleop control ---
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
_teleop_disable_torque(teleop)
|
||||
events["start_next_episode"] = False
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
if engine.is_rtc:
|
||||
engine.reset()
|
||||
phase = events.phase
|
||||
|
||||
# --- Get observation ---
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.robot_observation_processor(obs)
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# --- CORRECTION: human teleop control ---
|
||||
if events["correction_active"]:
|
||||
# --- CORRECTING: human teleop control ---
|
||||
if phase == DAggerPhase.CORRECTING:
|
||||
teleop_action = teleop.get_action()
|
||||
processed_teleop = ctx.teleop_action_processor((teleop_action, obs))
|
||||
robot_action_to_send = ctx.robot_action_processor((processed_teleop, obs))
|
||||
@@ -461,7 +525,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
record_tick += 1
|
||||
|
||||
# --- PAUSED: hold position ---
|
||||
elif waiting_for_takeover or events["policy_paused"]:
|
||||
elif phase == DAggerPhase.PAUSED:
|
||||
if last_action:
|
||||
robot.send_action(last_action)
|
||||
|
||||
@@ -507,3 +571,41 @@ class DAggerStrategy(RolloutStrategy):
|
||||
if not stream_online:
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# State-machine transition side-effects
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
def _apply_transition(
    old_phase: DAggerPhase,
    new_phase: DAggerPhase,
    engine,
    interpolator,
    robot: ThreadSafeRobot,
    teleop: Teleoperator,
) -> None:
    """Run the side-effects that accompany a phase change.

    Three cases are handled; any other ``(old_phase, new_phase)`` pair does
    nothing:

    * ``AUTONOMOUS -> PAUSED``: pause the engine if ``engine.is_rtc``, read
      the robot observation, pass its ``.pos`` entries (restricted to keys in
      ``robot.observation_features``) to ``_teleop_smooth_move_to``, then
      reset the interpolator.
    * ``* -> CORRECTING``: call ``_teleop_disable_torque`` on the teleop and
      reset the engine if ``engine.is_rtc``.
    * ``* -> AUTONOMOUS``: reset interpolator and engine, then resume the
      engine if ``engine.is_rtc``.

    Args:
        old_phase: Phase the state machine is leaving.
        new_phase: Phase the state machine is entering.
        engine: Inference engine exposing ``is_rtc``, ``pause``, ``reset``
            and ``resume``.
        interpolator: Object exposing ``reset``.
        robot: Robot used to read the current observation.
        teleop: Teleoperator device to align or release.
    """
    entering_pause = old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED

    if entering_pause:
        # Stop the RTC engine first, then bring the teleop onto the robot pose.
        if engine.is_rtc:
            engine.pause()
        current_obs = robot.get_observation()
        joint_targets = {}
        for key, value in current_obs.items():
            if key.endswith(".pos") and key in robot.observation_features:
                joint_targets[key] = value
        _teleop_smooth_move_to(teleop, joint_targets, duration_s=2.0, fps=50)
        interpolator.reset()
        return

    if new_phase == DAggerPhase.CORRECTING:
        # Hand control to the human operator.
        _teleop_disable_torque(teleop)
        if engine.is_rtc:
            engine.reset()
        return

    if new_phase == DAggerPhase.AUTONOMOUS:
        # Back to the policy, whether coming from pause or from correction.
        interpolator.reset()
        engine.reset()
        if engine.is_rtc:
            engine.resume()
|
||||
|
||||
54
src/lerobot/rollout/strategies/factory.py
Normal file
54
src/lerobot/rollout/strategies/factory.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Strategy factory: config type-name → strategy class dispatch."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .core import RolloutStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rollout.configs import RolloutStrategyConfig
|
||||
|
||||
|
||||
def _lazy_strategy_map() -> dict[str, type[RolloutStrategy]]:
    """Return the mapping from strategy type-name to strategy class.

    The strategy modules are imported inside the function, so they are only
    loaded when the mapping is first requested rather than at module import.
    """
    from .base import BaseStrategy
    from .dagger import DAggerStrategy
    from .highlight import HighlightStrategy
    from .sentry import SentryStrategy

    registry: dict[str, type[RolloutStrategy]] = {}
    registry["base"] = BaseStrategy
    registry["sentry"] = SentryStrategy
    registry["highlight"] = HighlightStrategy
    registry["dagger"] = DAggerStrategy
    return registry
|
||||
|
||||
|
||||
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
    """Build the strategy instance that matches a config object.

    The class is looked up by ``config.type`` (the name registered via
    ``draccus.ChoiceRegistry``) in the mapping returned by
    ``_lazy_strategy_map``. Supporting a new strategy therefore means
    registering its config subclass and adding one entry to that mapping.

    Args:
        config: Strategy configuration; its ``type`` selects the class.

    Returns:
        The selected strategy class instantiated with ``config``.

    Raises:
        ValueError: If ``config.type`` has no entry in the mapping.
    """
    strategy_map = _lazy_strategy_map()
    type_name = config.type
    if type_name not in strategy_map:
        raise ValueError(f"Unknown strategy type '{config.type}'. Available: {sorted(strategy_map.keys())}")
    chosen_cls = strategy_map[type_name]
    return chosen_cls(config)
|
||||
@@ -19,7 +19,7 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from threading import Event, Thread
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
@@ -46,6 +46,10 @@ class SentryStrategy(RolloutStrategy):
|
||||
All actions flow through ``robot_observation_processor`` (observations)
|
||||
and ``robot_action_processor`` (actions) before reaching the robot,
|
||||
supporting EE-space recording with joint-space robots.
|
||||
|
||||
**Thread safety:** A lock (``_episode_lock``) serialises
|
||||
``save_episode`` and ``push_to_hub`` calls so the background push
|
||||
thread never reads an episode that is still being finalised.
|
||||
"""
|
||||
|
||||
config: SentryStrategyConfig
|
||||
@@ -54,6 +58,7 @@ class SentryStrategy(RolloutStrategy):
|
||||
super().__init__(config)
|
||||
self._push_thread: Thread | None = None
|
||||
self._needs_push = Event()
|
||||
self._episode_lock = Lock()
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
self._init_engine(ctx)
|
||||
@@ -113,7 +118,8 @@ class SentryStrategy(RolloutStrategy):
|
||||
# Auto-rotate episodes
|
||||
elapsed = time.perf_counter() - episode_start
|
||||
if elapsed >= self.config.episode_duration_s:
|
||||
dataset.save_episode()
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
episodes_since_push += 1
|
||||
self._needs_push.set()
|
||||
logger.info("Episode saved (total: %d)", dataset.num_episodes)
|
||||
@@ -134,7 +140,8 @@ class SentryStrategy(RolloutStrategy):
|
||||
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
self._needs_push.set()
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
@@ -155,17 +162,22 @@ class SentryStrategy(RolloutStrategy):
|
||||
logger.info("Sentry strategy teardown complete")
|
||||
|
||||
def _background_push(self, dataset, cfg) -> None:
|
||||
"""Push dataset to hub in a background thread (non-blocking)."""
|
||||
"""Push dataset to hub in a background thread (non-blocking).
|
||||
|
||||
Acquires ``_episode_lock`` during the push to prevent
|
||||
``save_episode`` from finalising a new episode mid-upload.
|
||||
"""
|
||||
if self._push_thread is not None and self._push_thread.is_alive():
|
||||
logger.info("Previous push still in progress, skipping")
|
||||
return
|
||||
|
||||
def _push():
|
||||
try:
|
||||
dataset.push_to_hub(
|
||||
tags=cfg.dataset.tags if cfg.dataset else None,
|
||||
private=cfg.dataset.private if cfg.dataset else False,
|
||||
)
|
||||
with self._episode_lock:
|
||||
dataset.push_to_hub(
|
||||
tags=cfg.dataset.tags if cfg.dataset else None,
|
||||
private=cfg.dataset.private if cfg.dataset else False,
|
||||
)
|
||||
self._needs_push.clear()
|
||||
logger.info("Background push to hub complete")
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user