mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
target review
This commit is contained in:
@@ -177,7 +177,7 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
|
||||
k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple)
|
||||
}
|
||||
|
||||
action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")}
|
||||
action_features_hw = robot.action_features
|
||||
|
||||
# Build dataset features
|
||||
dataset_features = combine_feature_dicts(
|
||||
@@ -196,7 +196,7 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
|
||||
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
|
||||
|
||||
# Action keys
|
||||
action_keys = [k for k in robot.action_features if k.endswith(".pos")]
|
||||
action_keys = list(robot.action_features.keys())
|
||||
|
||||
# Ordered action keys (reconcile policy vs dataset ordering)
|
||||
policy_action_names = getattr(policy_config, "action_feature_names", None)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@@ -25,6 +26,7 @@ from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rollout.configs import RolloutStrategyConfig
|
||||
@@ -42,6 +44,68 @@ class RolloutStrategy(abc.ABC):
|
||||
|
||||
def __init__(self, config: RolloutStrategyConfig) -> None:
|
||||
self.config = config
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._interpolator: ActionInterpolator | None = None
|
||||
self._warmup_flushed: bool = False
|
||||
|
||||
def _init_engine(self, ctx: RolloutContext) -> None:
|
||||
"""Create and start the inference engine and action interpolator.
|
||||
|
||||
Call this from ``setup()`` to avoid duplicating the engine
|
||||
construction across every strategy.
|
||||
"""
|
||||
from lerobot.rollout.inference import InferenceEngine
|
||||
|
||||
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
|
||||
self._engine = InferenceEngine(
|
||||
policy=ctx.policy,
|
||||
preprocessor=ctx.preprocessor,
|
||||
postprocessor=ctx.postprocessor,
|
||||
robot_wrapper=ctx.robot_wrapper,
|
||||
rtc_config=ctx.cfg.rtc,
|
||||
hw_features=ctx.hw_features,
|
||||
action_keys=ctx.action_keys,
|
||||
task=ctx.cfg.task,
|
||||
fps=ctx.cfg.fps,
|
||||
device=ctx.cfg.device,
|
||||
use_torch_compile=ctx.cfg.use_torch_compile,
|
||||
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
|
||||
)
|
||||
self._engine.start()
|
||||
self._warmup_flushed = False
|
||||
|
||||
def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool:
|
||||
"""Handle torch.compile warmup phase.
|
||||
|
||||
Returns ``True`` if the caller should ``continue`` (still warming
|
||||
up). On the first post-warmup iteration the engine and
|
||||
interpolator are reset so stale warmup state is discarded.
|
||||
"""
|
||||
engine = self._engine
|
||||
interpolator = self._interpolator
|
||||
if not use_torch_compile:
|
||||
return False
|
||||
if not engine.compile_warmup_done.is_set():
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
return True
|
||||
if not self._warmup_flushed:
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
self._warmup_flushed = True
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
return False
|
||||
|
||||
def _teardown_hardware(self, ctx: RolloutContext) -> None:
|
||||
"""Stop the inference engine and disconnect hardware."""
|
||||
if self._engine is not None:
|
||||
self._engine.stop()
|
||||
if ctx.robot.is_connected:
|
||||
ctx.robot.disconnect()
|
||||
if ctx.teleop is not None and ctx.teleop.is_connected:
|
||||
ctx.teleop.disconnect()
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
|
||||
@@ -19,11 +19,9 @@ from __future__ import annotations
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from ..context import RolloutContext
|
||||
from ..inference import InferenceEngine
|
||||
from . import RolloutStrategy, infer_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -37,29 +35,8 @@ class BaseStrategy(RolloutStrategy):
|
||||
``robot_action_processor`` pipeline before reaching the robot.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._interpolator: ActionInterpolator | None = None
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
|
||||
|
||||
self._engine = InferenceEngine(
|
||||
policy=ctx.policy,
|
||||
preprocessor=ctx.preprocessor,
|
||||
postprocessor=ctx.postprocessor,
|
||||
robot_wrapper=ctx.robot_wrapper,
|
||||
rtc_config=ctx.cfg.rtc,
|
||||
hw_features=ctx.hw_features,
|
||||
action_keys=ctx.action_keys,
|
||||
task=ctx.cfg.task,
|
||||
fps=ctx.cfg.fps,
|
||||
device=ctx.cfg.device,
|
||||
use_torch_compile=ctx.cfg.use_torch_compile,
|
||||
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
|
||||
)
|
||||
self._engine.start()
|
||||
self._init_engine(ctx)
|
||||
logger.info("Base strategy ready (rtc=%s)", self._engine.is_rtc)
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
@@ -72,7 +49,6 @@ class BaseStrategy(RolloutStrategy):
|
||||
ordered_keys = ctx.ordered_action_keys
|
||||
|
||||
start_time = time.perf_counter()
|
||||
warmup_flushed = False
|
||||
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
@@ -89,20 +65,9 @@ class BaseStrategy(RolloutStrategy):
|
||||
if engine.is_rtc:
|
||||
engine.update_observation(obs_processed)
|
||||
|
||||
# Wait for torch.compile warmup before running live inference
|
||||
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
if cfg.use_torch_compile and not warmup_flushed:
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
warmup_flushed = True
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
|
||||
infer_action(engine, obs_processed, obs, ctx, interpolator, ordered_keys, ctx.dataset_features)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
@@ -110,10 +75,5 @@ class BaseStrategy(RolloutStrategy):
|
||||
precise_sleep(sleep_t)
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
if self._engine is not None:
|
||||
self._engine.stop()
|
||||
if ctx.robot.is_connected:
|
||||
ctx.robot.disconnect()
|
||||
if ctx.teleop is not None and ctx.teleop.is_connected:
|
||||
ctx.teleop.disconnect()
|
||||
self._teardown_hardware(ctx)
|
||||
logger.info("Base strategy teardown complete")
|
||||
|
||||
@@ -38,8 +38,8 @@ import numpy as np
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
@@ -47,7 +47,7 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
from ..configs import DAggerStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..inference import InferenceEngine
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from . import RolloutStrategy, infer_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -58,21 +58,23 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _teleop_has_motor_control(teleop) -> bool:
|
||||
def _teleop_has_motor_control(teleop: Teleoperator) -> bool:
|
||||
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
|
||||
|
||||
|
||||
def _teleop_disable_torque(teleop) -> None:
|
||||
def _teleop_disable_torque(teleop: Teleoperator) -> None:
|
||||
if hasattr(teleop, "disable_torque"):
|
||||
teleop.disable_torque()
|
||||
|
||||
|
||||
def _teleop_enable_torque(teleop) -> None:
|
||||
def _teleop_enable_torque(teleop: Teleoperator) -> None:
|
||||
if hasattr(teleop, "enable_torque"):
|
||||
teleop.enable_torque()
|
||||
|
||||
|
||||
def _teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 50) -> None:
|
||||
def _teleop_smooth_move_to(
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50
|
||||
) -> None:
|
||||
"""Smoothly move teleop to target position if motor control is available."""
|
||||
if not _teleop_has_motor_control(teleop):
|
||||
logger.warning("Teleop does not support motor control — cannot mirror robot position")
|
||||
@@ -95,8 +97,8 @@ def _teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fp
|
||||
|
||||
|
||||
def _reset_loop(
|
||||
robot,
|
||||
teleop,
|
||||
robot: ThreadSafeRobot,
|
||||
teleop: Teleoperator,
|
||||
events: dict,
|
||||
fps: int,
|
||||
teleop_action_processor: RobotProcessorPipeline,
|
||||
@@ -275,29 +277,11 @@ class DAggerStrategy(RolloutStrategy):
|
||||
|
||||
def __init__(self, config: DAggerStrategyConfig):
|
||||
super().__init__(config)
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._interpolator: ActionInterpolator | None = None
|
||||
self._listener = None
|
||||
self._events: dict[str, Any] = {}
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
|
||||
|
||||
self._engine = InferenceEngine(
|
||||
policy=ctx.policy,
|
||||
preprocessor=ctx.preprocessor,
|
||||
postprocessor=ctx.postprocessor,
|
||||
robot_wrapper=ctx.robot_wrapper,
|
||||
rtc_config=ctx.cfg.rtc,
|
||||
hw_features=ctx.hw_features,
|
||||
action_keys=ctx.action_keys,
|
||||
task=ctx.cfg.task,
|
||||
fps=ctx.cfg.fps,
|
||||
device=ctx.cfg.device,
|
||||
use_torch_compile=ctx.cfg.use_torch_compile,
|
||||
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
|
||||
)
|
||||
self._engine.start()
|
||||
self._init_engine(ctx)
|
||||
|
||||
self._listener, self._events = _init_dagger_keyboard()
|
||||
_start_pedal_listener(self._events)
|
||||
@@ -350,9 +334,6 @@ class DAggerStrategy(RolloutStrategy):
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
log_say("Stop recording", self.config.play_sounds, blocking=True)
|
||||
|
||||
if self._engine is not None:
|
||||
self._engine.stop()
|
||||
|
||||
if self._listener is not None and not is_headless():
|
||||
self._listener.stop()
|
||||
|
||||
@@ -364,10 +345,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
private=ctx.cfg.dataset.private,
|
||||
)
|
||||
|
||||
if ctx.robot.is_connected:
|
||||
ctx.robot.disconnect()
|
||||
if ctx.teleop is not None and ctx.teleop.is_connected:
|
||||
ctx.teleop.disconnect()
|
||||
self._teardown_hardware(ctx)
|
||||
logger.info("DAgger strategy teardown complete")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -404,7 +382,6 @@ class DAggerStrategy(RolloutStrategy):
|
||||
timestamp = 0.0
|
||||
record_tick = 0
|
||||
start_t = time.perf_counter()
|
||||
warmup_flushed = False
|
||||
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
@@ -493,21 +470,10 @@ class DAggerStrategy(RolloutStrategy):
|
||||
if engine.is_rtc:
|
||||
engine.update_observation(obs_processed)
|
||||
|
||||
# Wait for torch.compile warmup
|
||||
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
timestamp = time.perf_counter() - start_t
|
||||
continue
|
||||
|
||||
if cfg.use_torch_compile and not warmup_flushed:
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
warmup_flushed = True
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
|
||||
action_dict = infer_action(
|
||||
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
|
||||
)
|
||||
|
||||
@@ -22,14 +22,12 @@ import time
|
||||
from threading import Event as ThreadingEvent
|
||||
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from ..configs import HighlightStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..inference import InferenceEngine
|
||||
from ..ring_buffer import RolloutRingBuffer
|
||||
from . import RolloutStrategy, infer_action
|
||||
|
||||
@@ -54,31 +52,14 @@ class HighlightStrategy(RolloutStrategy):
|
||||
|
||||
def __init__(self, config: HighlightStrategyConfig):
|
||||
super().__init__(config)
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._interpolator: ActionInterpolator | None = None
|
||||
self._ring: RolloutRingBuffer | None = None
|
||||
self._listener = None
|
||||
self._save_requested = ThreadingEvent()
|
||||
self._recording_live = ThreadingEvent()
|
||||
self._shutdown_event: ThreadingEvent | None = None
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
|
||||
|
||||
self._engine = InferenceEngine(
|
||||
policy=ctx.policy,
|
||||
preprocessor=ctx.preprocessor,
|
||||
postprocessor=ctx.postprocessor,
|
||||
robot_wrapper=ctx.robot_wrapper,
|
||||
rtc_config=ctx.cfg.rtc,
|
||||
hw_features=ctx.hw_features,
|
||||
action_keys=ctx.action_keys,
|
||||
task=ctx.cfg.task,
|
||||
fps=ctx.cfg.fps,
|
||||
device=ctx.cfg.device,
|
||||
use_torch_compile=ctx.cfg.use_torch_compile,
|
||||
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
|
||||
)
|
||||
self._engine.start()
|
||||
self._init_engine(ctx)
|
||||
|
||||
self._ring = RolloutRingBuffer(
|
||||
max_seconds=self.config.ring_buffer_seconds,
|
||||
@@ -86,6 +67,7 @@ class HighlightStrategy(RolloutStrategy):
|
||||
fps=ctx.cfg.fps,
|
||||
)
|
||||
|
||||
self._shutdown_event = ctx.shutdown_event
|
||||
self._setup_keyboard()
|
||||
logger.info(
|
||||
"Highlight strategy ready (buffer=%.0fs, key='%s')",
|
||||
@@ -109,7 +91,6 @@ class HighlightStrategy(RolloutStrategy):
|
||||
engine.resume()
|
||||
|
||||
start_time = time.perf_counter()
|
||||
warmup_flushed = False
|
||||
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
@@ -126,19 +107,9 @@ class HighlightStrategy(RolloutStrategy):
|
||||
if engine.is_rtc:
|
||||
engine.update_observation(obs_processed)
|
||||
|
||||
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
if cfg.use_torch_compile and not warmup_flushed:
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
warmup_flushed = True
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
|
||||
action_dict = infer_action(
|
||||
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
|
||||
)
|
||||
@@ -186,8 +157,6 @@ class HighlightStrategy(RolloutStrategy):
|
||||
dataset.save_episode()
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
if self._engine is not None:
|
||||
self._engine.stop()
|
||||
if self._listener is not None:
|
||||
self._listener.stop()
|
||||
|
||||
@@ -199,10 +168,7 @@ class HighlightStrategy(RolloutStrategy):
|
||||
private=ctx.cfg.dataset.private,
|
||||
)
|
||||
|
||||
if ctx.robot.is_connected:
|
||||
ctx.robot.disconnect()
|
||||
if ctx.teleop is not None and ctx.teleop.is_connected:
|
||||
ctx.teleop.disconnect()
|
||||
self._teardown_hardware(ctx)
|
||||
logger.info("Highlight strategy teardown complete")
|
||||
|
||||
def _setup_keyboard(self) -> None:
|
||||
@@ -224,6 +190,8 @@ class HighlightStrategy(RolloutStrategy):
|
||||
self._save_requested.set()
|
||||
elif key == keyboard.Key.esc:
|
||||
self._save_requested.clear()
|
||||
if self._shutdown_event is not None:
|
||||
self._shutdown_event.set()
|
||||
|
||||
self._listener = keyboard.Listener(on_press=on_press)
|
||||
self._listener.start()
|
||||
|
||||
@@ -19,17 +19,15 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from threading import Thread
|
||||
from threading import Event, Thread
|
||||
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from ..configs import SentryStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..inference import InferenceEngine
|
||||
from . import RolloutStrategy, infer_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -54,29 +52,11 @@ class SentryStrategy(RolloutStrategy):
|
||||
|
||||
def __init__(self, config: SentryStrategyConfig):
|
||||
super().__init__(config)
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._interpolator: ActionInterpolator | None = None
|
||||
self._push_thread: Thread | None = None
|
||||
self._needs_push: bool = False
|
||||
self._needs_push = Event()
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
|
||||
|
||||
self._engine = InferenceEngine(
|
||||
policy=ctx.policy,
|
||||
preprocessor=ctx.preprocessor,
|
||||
postprocessor=ctx.postprocessor,
|
||||
robot_wrapper=ctx.robot_wrapper,
|
||||
rtc_config=ctx.cfg.rtc,
|
||||
hw_features=ctx.hw_features,
|
||||
action_keys=ctx.action_keys,
|
||||
task=ctx.cfg.task,
|
||||
fps=ctx.cfg.fps,
|
||||
device=ctx.cfg.device,
|
||||
use_torch_compile=ctx.cfg.use_torch_compile,
|
||||
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
|
||||
)
|
||||
self._engine.start()
|
||||
self._init_engine(ctx)
|
||||
logger.info(
|
||||
"Sentry strategy ready (episode_duration=%.0fs, upload_every=%d eps)",
|
||||
self.config.episode_duration_s,
|
||||
@@ -100,7 +80,6 @@ class SentryStrategy(RolloutStrategy):
|
||||
start_time = time.perf_counter()
|
||||
episode_start = time.perf_counter()
|
||||
episodes_since_push = 0
|
||||
warmup_flushed = False
|
||||
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
@@ -117,19 +96,9 @@ class SentryStrategy(RolloutStrategy):
|
||||
if engine.is_rtc:
|
||||
engine.update_observation(obs_processed)
|
||||
|
||||
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
if cfg.use_torch_compile and not warmup_flushed:
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
warmup_flushed = True
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
|
||||
action_dict = infer_action(
|
||||
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
|
||||
)
|
||||
@@ -146,7 +115,7 @@ class SentryStrategy(RolloutStrategy):
|
||||
if elapsed >= self.config.episode_duration_s:
|
||||
dataset.save_episode()
|
||||
episodes_since_push += 1
|
||||
self._needs_push = True
|
||||
self._needs_push.set()
|
||||
logger.info("Episode saved (total: %d)", dataset.num_episodes)
|
||||
|
||||
if episodes_since_push >= self.config.upload_every_n_episodes:
|
||||
@@ -166,12 +135,9 @@ class SentryStrategy(RolloutStrategy):
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
self._needs_push = True
|
||||
self._needs_push.set()
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
if self._engine is not None:
|
||||
self._engine.stop()
|
||||
|
||||
# Wait for any in-flight background push
|
||||
if self._push_thread is not None and self._push_thread.is_alive():
|
||||
self._push_thread.join(timeout=60)
|
||||
@@ -179,16 +145,13 @@ class SentryStrategy(RolloutStrategy):
|
||||
if ctx.dataset is not None:
|
||||
ctx.dataset.finalize()
|
||||
# Only push if there are unsaved changes since last background push
|
||||
if self._needs_push and ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub:
|
||||
if self._needs_push.is_set() and ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub:
|
||||
ctx.dataset.push_to_hub(
|
||||
tags=ctx.cfg.dataset.tags,
|
||||
private=ctx.cfg.dataset.private,
|
||||
)
|
||||
|
||||
if ctx.robot.is_connected:
|
||||
ctx.robot.disconnect()
|
||||
if ctx.teleop is not None and ctx.teleop.is_connected:
|
||||
ctx.teleop.disconnect()
|
||||
self._teardown_hardware(ctx)
|
||||
logger.info("Sentry strategy teardown complete")
|
||||
|
||||
def _background_push(self, dataset, cfg) -> None:
|
||||
@@ -203,7 +166,7 @@ class SentryStrategy(RolloutStrategy):
|
||||
tags=cfg.dataset.tags if cfg.dataset else None,
|
||||
private=cfg.dataset.private if cfg.dataset else False,
|
||||
)
|
||||
self._needs_push = False
|
||||
self._needs_push.clear()
|
||||
logger.info("Background push to hub complete")
|
||||
except Exception as e:
|
||||
logger.error("Background push failed: %s", e)
|
||||
|
||||
Reference in New Issue
Block a user