target review

This commit is contained in:
Steven Palma
2026-04-14 17:14:09 +02:00
parent 49f32b9796
commit 8bc47e4318
6 changed files with 98 additions and 177 deletions

View File

@@ -177,7 +177,7 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple)
}
action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")}
action_features_hw = robot.action_features
# Build dataset features
dataset_features = combine_feature_dicts(
@@ -196,7 +196,7 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
# Action keys
action_keys = [k for k in robot.action_features if k.endswith(".pos")]
action_keys = list(robot.action_features.keys())
# Ordered action keys (reconcile policy vs dataset ordering)
policy_action_names = getattr(policy_config, "action_feature_names", None)

View File

@@ -17,6 +17,7 @@
from __future__ import annotations
import abc
import time
from typing import TYPE_CHECKING
import torch
@@ -25,6 +26,7 @@ from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.utils.constants import OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
if TYPE_CHECKING:
from lerobot.rollout.configs import RolloutStrategyConfig
@@ -42,6 +44,68 @@ class RolloutStrategy(abc.ABC):
def __init__(self, config: RolloutStrategyConfig) -> None:
    """Store the strategy configuration and clear the shared runtime state.

    The engine and interpolator slots start out empty; ``_init_engine``
    is the method in this class that fills them in.
    """
    # No post-warmup flush has been performed yet.
    self._warmup_flushed: bool = False
    # Runtime helpers are created later, not at construction time.
    self._interpolator: ActionInterpolator | None = None
    self._engine: InferenceEngine | None = None
    self.config = config
def _init_engine(self, ctx: RolloutContext) -> None:
    """Build the action interpolator and inference engine, then start the engine.

    Intended to be called from ``setup()`` so that each strategy does not
    have to repeat the engine construction itself.
    """
    # Imported at call time rather than at module import time
    # (presumably to avoid an import cycle -- TODO confirm).
    from lerobot.rollout.inference import InferenceEngine

    cfg = ctx.cfg
    self._interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)

    new_engine = InferenceEngine(
        policy=ctx.policy,
        preprocessor=ctx.preprocessor,
        postprocessor=ctx.postprocessor,
        robot_wrapper=ctx.robot_wrapper,
        rtc_config=cfg.rtc,
        hw_features=ctx.hw_features,
        action_keys=ctx.action_keys,
        task=cfg.task,
        fps=cfg.fps,
        device=cfg.device,
        use_torch_compile=cfg.use_torch_compile,
        compile_warmup_inferences=cfg.compile_warmup_inferences,
    )
    # Publish the engine on the instance before starting it.
    self._engine = new_engine
    new_engine.start()

    # A freshly started engine has not had its warmup state flushed yet.
    self._warmup_flushed = False
def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool:
    """Gate one control-loop iteration on the torch.compile warmup phase.

    Returns ``True`` while warmup is still in progress, meaning the caller
    should ``continue`` its loop. Returns ``False`` when torch.compile is
    not in use or warmup has finished. The first call after warmup
    finishes resets the engine and the interpolator once, so that state
    accumulated during warmup is thrown away.
    """
    eng = self._engine
    interp = self._interpolator

    if use_torch_compile:
        if not eng.compile_warmup_done.is_set():
            # Still warming up: sleep out what is left of this control tick.
            elapsed = time.perf_counter() - loop_start
            remaining = control_interval - elapsed
            if remaining > 0:
                precise_sleep(remaining)
            return True

        if not self._warmup_flushed:
            # One-time flush on the first iteration after warmup.
            eng.reset()
            interp.reset()
            self._warmup_flushed = True
            if eng.is_rtc:
                eng.resume()

    return False
def _teardown_hardware(self, ctx: RolloutContext) -> None:
    """Stop the inference engine and disconnect hardware.

    The three steps (stop the engine, disconnect the robot, disconnect
    the teleoperator) are each attempted even when an earlier step
    raises, so a failing engine shutdown cannot leave the hardware
    connected. Any exception raised by a step still propagates to the
    caller once the remaining steps have run.

    Args:
        ctx: Rollout context providing ``robot`` and the optional
            ``teleop`` device.
    """
    try:
        # The engine slot is still None if it was never initialised.
        if self._engine is not None:
            self._engine.stop()
    finally:
        try:
            if ctx.robot.is_connected:
                ctx.robot.disconnect()
        finally:
            # The teleoperator is optional.
            if ctx.teleop is not None and ctx.teleop.is_connected:
                ctx.teleop.disconnect()
@abc.abstractmethod
def setup(self, ctx: RolloutContext) -> None:

View File

@@ -19,11 +19,9 @@ from __future__ import annotations
import logging
import time
from lerobot.policies.rtc import ActionInterpolator
from lerobot.utils.robot_utils import precise_sleep
from ..context import RolloutContext
from ..inference import InferenceEngine
from . import RolloutStrategy, infer_action
logger = logging.getLogger(__name__)
@@ -37,29 +35,8 @@ class BaseStrategy(RolloutStrategy):
``robot_action_processor`` pipeline before reaching the robot.
"""
def __init__(self, config):
super().__init__(config)
self._engine: InferenceEngine | None = None
self._interpolator: ActionInterpolator | None = None
def setup(self, ctx: RolloutContext) -> None:
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
self._engine = InferenceEngine(
policy=ctx.policy,
preprocessor=ctx.preprocessor,
postprocessor=ctx.postprocessor,
robot_wrapper=ctx.robot_wrapper,
rtc_config=ctx.cfg.rtc,
hw_features=ctx.hw_features,
action_keys=ctx.action_keys,
task=ctx.cfg.task,
fps=ctx.cfg.fps,
device=ctx.cfg.device,
use_torch_compile=ctx.cfg.use_torch_compile,
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
)
self._engine.start()
self._init_engine(ctx)
logger.info("Base strategy ready (rtc=%s)", self._engine.is_rtc)
def run(self, ctx: RolloutContext) -> None:
@@ -72,7 +49,6 @@ class BaseStrategy(RolloutStrategy):
ordered_keys = ctx.ordered_action_keys
start_time = time.perf_counter()
warmup_flushed = False
if engine.is_rtc:
engine.resume()
@@ -89,20 +65,9 @@ class BaseStrategy(RolloutStrategy):
if engine.is_rtc:
engine.update_observation(obs_processed)
# Wait for torch.compile warmup before running live inference
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
if cfg.use_torch_compile and not warmup_flushed:
engine.reset()
interpolator.reset()
warmup_flushed = True
if engine.is_rtc:
engine.resume()
infer_action(engine, obs_processed, obs, ctx, interpolator, ordered_keys, ctx.dataset_features)
dt = time.perf_counter() - loop_start
@@ -110,10 +75,5 @@ class BaseStrategy(RolloutStrategy):
precise_sleep(sleep_t)
def teardown(self, ctx: RolloutContext) -> None:
if self._engine is not None:
self._engine.stop()
if ctx.robot.is_connected:
ctx.robot.disconnect()
if ctx.teleop is not None and ctx.teleop.is_connected:
ctx.teleop.disconnect()
self._teardown_hardware(ctx)
logger.info("Base strategy teardown complete")

View File

@@ -38,8 +38,8 @@ import numpy as np
from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
from lerobot.policies.rtc import ActionInterpolator
from lerobot.processor import RobotProcessorPipeline
from lerobot.teleoperators import Teleoperator
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
@@ -47,7 +47,7 @@ from lerobot.utils.utils import log_say
from ..configs import DAggerStrategyConfig
from ..context import RolloutContext
from ..inference import InferenceEngine
from ..robot_wrapper import ThreadSafeRobot
from . import RolloutStrategy, infer_action
logger = logging.getLogger(__name__)
@@ -58,21 +58,23 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
def _teleop_has_motor_control(teleop) -> bool:
def _teleop_has_motor_control(teleop: Teleoperator) -> bool:
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
def _teleop_disable_torque(teleop) -> None:
def _teleop_disable_torque(teleop: Teleoperator) -> None:
if hasattr(teleop, "disable_torque"):
teleop.disable_torque()
def _teleop_enable_torque(teleop) -> None:
def _teleop_enable_torque(teleop: Teleoperator) -> None:
if hasattr(teleop, "enable_torque"):
teleop.enable_torque()
def _teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 50) -> None:
def _teleop_smooth_move_to(
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50
) -> None:
"""Smoothly move teleop to target position if motor control is available."""
if not _teleop_has_motor_control(teleop):
logger.warning("Teleop does not support motor control — cannot mirror robot position")
@@ -95,8 +97,8 @@ def _teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fp
def _reset_loop(
robot,
teleop,
robot: ThreadSafeRobot,
teleop: Teleoperator,
events: dict,
fps: int,
teleop_action_processor: RobotProcessorPipeline,
@@ -275,29 +277,11 @@ class DAggerStrategy(RolloutStrategy):
def __init__(self, config: DAggerStrategyConfig):
super().__init__(config)
self._engine: InferenceEngine | None = None
self._interpolator: ActionInterpolator | None = None
self._listener = None
self._events: dict[str, Any] = {}
def setup(self, ctx: RolloutContext) -> None:
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
self._engine = InferenceEngine(
policy=ctx.policy,
preprocessor=ctx.preprocessor,
postprocessor=ctx.postprocessor,
robot_wrapper=ctx.robot_wrapper,
rtc_config=ctx.cfg.rtc,
hw_features=ctx.hw_features,
action_keys=ctx.action_keys,
task=ctx.cfg.task,
fps=ctx.cfg.fps,
device=ctx.cfg.device,
use_torch_compile=ctx.cfg.use_torch_compile,
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
)
self._engine.start()
self._init_engine(ctx)
self._listener, self._events = _init_dagger_keyboard()
_start_pedal_listener(self._events)
@@ -350,9 +334,6 @@ class DAggerStrategy(RolloutStrategy):
def teardown(self, ctx: RolloutContext) -> None:
log_say("Stop recording", self.config.play_sounds, blocking=True)
if self._engine is not None:
self._engine.stop()
if self._listener is not None and not is_headless():
self._listener.stop()
@@ -364,10 +345,7 @@ class DAggerStrategy(RolloutStrategy):
private=ctx.cfg.dataset.private,
)
if ctx.robot.is_connected:
ctx.robot.disconnect()
if ctx.teleop is not None and ctx.teleop.is_connected:
ctx.teleop.disconnect()
self._teardown_hardware(ctx)
logger.info("DAgger strategy teardown complete")
# ------------------------------------------------------------------
@@ -404,7 +382,6 @@ class DAggerStrategy(RolloutStrategy):
timestamp = 0.0
record_tick = 0
start_t = time.perf_counter()
warmup_flushed = False
if engine.is_rtc:
engine.resume()
@@ -493,21 +470,10 @@ class DAggerStrategy(RolloutStrategy):
if engine.is_rtc:
engine.update_observation(obs_processed)
# Wait for torch.compile warmup
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
timestamp = time.perf_counter() - start_t
continue
if cfg.use_torch_compile and not warmup_flushed:
engine.reset()
interpolator.reset()
warmup_flushed = True
if engine.is_rtc:
engine.resume()
action_dict = infer_action(
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
)

View File

@@ -22,14 +22,12 @@ import time
from threading import Event as ThreadingEvent
from lerobot.datasets import VideoEncodingManager
from lerobot.policies.rtc import ActionInterpolator
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from ..configs import HighlightStrategyConfig
from ..context import RolloutContext
from ..inference import InferenceEngine
from ..ring_buffer import RolloutRingBuffer
from . import RolloutStrategy, infer_action
@@ -54,31 +52,14 @@ class HighlightStrategy(RolloutStrategy):
def __init__(self, config: HighlightStrategyConfig):
super().__init__(config)
self._engine: InferenceEngine | None = None
self._interpolator: ActionInterpolator | None = None
self._ring: RolloutRingBuffer | None = None
self._listener = None
self._save_requested = ThreadingEvent()
self._recording_live = ThreadingEvent()
self._shutdown_event: ThreadingEvent | None = None
def setup(self, ctx: RolloutContext) -> None:
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
self._engine = InferenceEngine(
policy=ctx.policy,
preprocessor=ctx.preprocessor,
postprocessor=ctx.postprocessor,
robot_wrapper=ctx.robot_wrapper,
rtc_config=ctx.cfg.rtc,
hw_features=ctx.hw_features,
action_keys=ctx.action_keys,
task=ctx.cfg.task,
fps=ctx.cfg.fps,
device=ctx.cfg.device,
use_torch_compile=ctx.cfg.use_torch_compile,
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
)
self._engine.start()
self._init_engine(ctx)
self._ring = RolloutRingBuffer(
max_seconds=self.config.ring_buffer_seconds,
@@ -86,6 +67,7 @@ class HighlightStrategy(RolloutStrategy):
fps=ctx.cfg.fps,
)
self._shutdown_event = ctx.shutdown_event
self._setup_keyboard()
logger.info(
"Highlight strategy ready (buffer=%.0fs, key='%s')",
@@ -109,7 +91,6 @@ class HighlightStrategy(RolloutStrategy):
engine.resume()
start_time = time.perf_counter()
warmup_flushed = False
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
with VideoEncodingManager(dataset):
@@ -126,19 +107,9 @@ class HighlightStrategy(RolloutStrategy):
if engine.is_rtc:
engine.update_observation(obs_processed)
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
if cfg.use_torch_compile and not warmup_flushed:
engine.reset()
interpolator.reset()
warmup_flushed = True
if engine.is_rtc:
engine.resume()
action_dict = infer_action(
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
)
@@ -186,8 +157,6 @@ class HighlightStrategy(RolloutStrategy):
dataset.save_episode()
def teardown(self, ctx: RolloutContext) -> None:
if self._engine is not None:
self._engine.stop()
if self._listener is not None:
self._listener.stop()
@@ -199,10 +168,7 @@ class HighlightStrategy(RolloutStrategy):
private=ctx.cfg.dataset.private,
)
if ctx.robot.is_connected:
ctx.robot.disconnect()
if ctx.teleop is not None and ctx.teleop.is_connected:
ctx.teleop.disconnect()
self._teardown_hardware(ctx)
logger.info("Highlight strategy teardown complete")
def _setup_keyboard(self) -> None:
@@ -224,6 +190,8 @@ class HighlightStrategy(RolloutStrategy):
self._save_requested.set()
elif key == keyboard.Key.esc:
self._save_requested.clear()
if self._shutdown_event is not None:
self._shutdown_event.set()
self._listener = keyboard.Listener(on_press=on_press)
self._listener.start()

View File

@@ -19,17 +19,15 @@ from __future__ import annotations
import contextlib
import logging
import time
from threading import Thread
from threading import Event, Thread
from lerobot.datasets import VideoEncodingManager
from lerobot.policies.rtc import ActionInterpolator
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from ..configs import SentryStrategyConfig
from ..context import RolloutContext
from ..inference import InferenceEngine
from . import RolloutStrategy, infer_action
logger = logging.getLogger(__name__)
@@ -54,29 +52,11 @@ class SentryStrategy(RolloutStrategy):
def __init__(self, config: SentryStrategyConfig):
super().__init__(config)
self._engine: InferenceEngine | None = None
self._interpolator: ActionInterpolator | None = None
self._push_thread: Thread | None = None
self._needs_push: bool = False
self._needs_push = Event()
def setup(self, ctx: RolloutContext) -> None:
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
self._engine = InferenceEngine(
policy=ctx.policy,
preprocessor=ctx.preprocessor,
postprocessor=ctx.postprocessor,
robot_wrapper=ctx.robot_wrapper,
rtc_config=ctx.cfg.rtc,
hw_features=ctx.hw_features,
action_keys=ctx.action_keys,
task=ctx.cfg.task,
fps=ctx.cfg.fps,
device=ctx.cfg.device,
use_torch_compile=ctx.cfg.use_torch_compile,
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
)
self._engine.start()
self._init_engine(ctx)
logger.info(
"Sentry strategy ready (episode_duration=%.0fs, upload_every=%d eps)",
self.config.episode_duration_s,
@@ -100,7 +80,6 @@ class SentryStrategy(RolloutStrategy):
start_time = time.perf_counter()
episode_start = time.perf_counter()
episodes_since_push = 0
warmup_flushed = False
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
with VideoEncodingManager(dataset):
@@ -117,19 +96,9 @@ class SentryStrategy(RolloutStrategy):
if engine.is_rtc:
engine.update_observation(obs_processed)
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
if cfg.use_torch_compile and not warmup_flushed:
engine.reset()
interpolator.reset()
warmup_flushed = True
if engine.is_rtc:
engine.resume()
action_dict = infer_action(
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
)
@@ -146,7 +115,7 @@ class SentryStrategy(RolloutStrategy):
if elapsed >= self.config.episode_duration_s:
dataset.save_episode()
episodes_since_push += 1
self._needs_push = True
self._needs_push.set()
logger.info("Episode saved (total: %d)", dataset.num_episodes)
if episodes_since_push >= self.config.upload_every_n_episodes:
@@ -166,12 +135,9 @@ class SentryStrategy(RolloutStrategy):
finally:
with contextlib.suppress(Exception):
dataset.save_episode()
self._needs_push = True
self._needs_push.set()
def teardown(self, ctx: RolloutContext) -> None:
if self._engine is not None:
self._engine.stop()
# Wait for any in-flight background push
if self._push_thread is not None and self._push_thread.is_alive():
self._push_thread.join(timeout=60)
@@ -179,16 +145,13 @@ class SentryStrategy(RolloutStrategy):
if ctx.dataset is not None:
ctx.dataset.finalize()
# Only push if there are unsaved changes since last background push
if self._needs_push and ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub:
if self._needs_push.is_set() and ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub:
ctx.dataset.push_to_hub(
tags=ctx.cfg.dataset.tags,
private=ctx.cfg.dataset.private,
)
if ctx.robot.is_connected:
ctx.robot.disconnect()
if ctx.teleop is not None and ctx.teleop.is_connected:
ctx.teleop.disconnect()
self._teardown_hardware(ctx)
logger.info("Sentry strategy teardown complete")
def _background_push(self, dataset, cfg) -> None:
@@ -203,7 +166,7 @@ class SentryStrategy(RolloutStrategy):
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
)
self._needs_push = False
self._needs_push.clear()
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)