From 475b8da7d62cd021f25ed5aab120cefd09aece1f Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 21 Apr 2026 19:06:53 +0200 Subject: [PATCH] device + task + warn fix --- src/lerobot/rollout/configs.py | 28 ++++++++++++++++++--- src/lerobot/rollout/context.py | 14 ++++++++--- src/lerobot/rollout/strategies/base.py | 1 + src/lerobot/rollout/strategies/core.py | 14 +++++++++++ src/lerobot/rollout/strategies/dagger.py | 2 ++ src/lerobot/rollout/strategies/highlight.py | 1 + src/lerobot/rollout/strategies/sentry.py | 1 + 7 files changed, 54 insertions(+), 7 deletions(-) diff --git a/src/lerobot/rollout/configs.py b/src/lerobot/rollout/configs.py index d1527dc08..7fb7de0a0 100644 --- a/src/lerobot/rollout/configs.py +++ b/src/lerobot/rollout/configs.py @@ -225,10 +225,10 @@ class RolloutConfig: if needs_dataset and (self.dataset is None or not self.dataset.repo_id): raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set") - if isinstance(self.strategy, BaseStrategyConfig) and self.dataset is not None: - raise ValueError( - "Base strategy does not record data. Use sentry, highlight, or dagger for recording." - ) + # if isinstance(self.strategy, BaseStrategyConfig) and self.dataset is not None: + # raise ValueError( + # "Base strategy does not record data. Use sentry, highlight, or dagger for recording." + # ) # Sentry MUST use streaming encoding to avoid disk I/O blocking the control loop if ( @@ -285,6 +285,26 @@ class RolloutConfig: if self.policy is None: raise ValueError("--policy.path is required for rollout") + # --- Task resolution --- + # When --dataset.rename_map (or any --dataset.* flag) is passed, draccus + # creates a DatasetRecordConfig with single_task="". If the user set + # the task via the top-level --task flag, propagate it so that all + # downstream consumers (inference engine, dataset frame builders) see it. + if self.dataset is not None and not self.dataset.single_task and self.task: + self.dataset.single_task = self.task + elif self.dataset is not None and self.dataset.single_task and not self.task: + self.task = self.dataset.single_task + + # --- Device resolution --- + # Resolve device from the policy config when not explicitly set so all + # components (policy.to, preprocessor, inference engine) use the same + # device string instead of inconsistent fallbacks. + if self.device is None and self.policy is not None: + resolved = getattr(self.policy, "device", None) + if resolved: + self.device = resolved + logger.info("Resolved device from policy config: %s", self.device) + @classmethod def __get_path_fields__(cls) -> list[str]: return ["policy"] diff --git a/src/lerobot/rollout/context.py b/src/lerobot/rollout/context.py index e44b5f5bf..e71cc20c6 100644 --- a/src/lerobot/rollout/context.py +++ b/src/lerobot/rollout/context.py @@ -417,7 +417,7 @@ def build_rollout_context( pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset_stats, preprocessor_overrides={ - "device_processor": {"device": cfg.device or getattr(policy_config, "device", "cpu")}, + "device_processor": {"device": cfg.device}, "rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}}, }, ) @@ -428,13 +428,21 @@ def build_rollout_context( for step in preprocessor.steps: if isinstance(step, NormalizerProcessorStep): n_stats = sum(len(v) for v in step._tensor_stats.values()) if step._tensor_stats else 0 - logger.info("Preprocessor normalizer: %d stat tensors, keys=%s", n_stats, list(step._tensor_stats.keys())[:3]) + logger.info( + "Preprocessor normalizer: %d stat tensors, keys=%s", + n_stats, + list(step._tensor_stats.keys())[:3], + ) if n_stats == 0: logger.error("PREPROCESSOR NORMALIZER HAS NO STATS — observations will NOT be normalized!") for step in postprocessor.steps: if isinstance(step, UnnormalizerProcessorStep): n_stats = sum(len(v) for v in step._tensor_stats.values()) if step._tensor_stats else 0 - logger.info("Postprocessor unnormalizer: %d stat tensors, keys=%s", n_stats, list(step._tensor_stats.keys())[:3]) + logger.info( + "Postprocessor unnormalizer: %d stat tensors, keys=%s", + n_stats, + list(step._tensor_stats.keys())[:3], + ) if n_stats == 0: logger.error("POSTPROCESSOR UNNORMALIZER HAS NO STATS — actions will NOT be denormalized!") diff --git a/src/lerobot/rollout/strategies/base.py b/src/lerobot/rollout/strategies/base.py index 959c9c28a..300612cf6 100644 --- a/src/lerobot/rollout/strategies/base.py +++ b/src/lerobot/rollout/strategies/base.py @@ -80,6 +80,7 @@ class BaseStrategy(RolloutStrategy): self._log_telemetry(obs_processed, action_dict, ctx.runtime) dt = time.perf_counter() - loop_start + self._warn_if_slow(dt, control_interval, cfg.fps) if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) diff --git a/src/lerobot/rollout/strategies/core.py b/src/lerobot/rollout/strategies/core.py index 3967d8804..6aa6b48be 100644 --- a/src/lerobot/rollout/strategies/core.py +++ b/src/lerobot/rollout/strategies/core.py @@ -146,6 +146,20 @@ class RolloutStrategy(abc.ABC): compress_images=cfg.display_compressed_images, ) + @staticmethod + def _warn_if_slow(dt: float, control_interval: float, fps: float) -> None: + """Log a warning when the control loop runs slower than target FPS.""" + if dt > control_interval: + actual_fps = 1.0 / dt if dt > 0 else 0 + logger.warning( + "Control loop is running slower (%.1f Hz) than target FPS (%.0f Hz). " + "Dataset frames might be dropped and robot control might be unstable. " + "Common causes: 1) Camera FPS not keeping up " + "2) Policy inference taking too long 3) CPU starvation", + actual_fps, + fps, + ) + @abc.abstractmethod def setup(self, ctx: RolloutContext) -> None: """Strategy-specific initialisation (keyboard listeners, buffers, etc.).""" diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py index ba842844f..0464fa2fc 100644 --- a/src/lerobot/rollout/strategies/dagger.py +++ b/src/lerobot/rollout/strategies/dagger.py @@ -506,6 +506,7 @@ class DAggerStrategy(RolloutStrategy): episode_start = time.perf_counter() dt = time.perf_counter() - loop_start + self._warn_if_slow(dt, control_interval, cfg.fps) if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) @@ -646,6 +647,7 @@ class DAggerStrategy(RolloutStrategy): last_action = ctx.processors.robot_action_processor((action_dict, obs)) dt = time.perf_counter() - loop_start + self._warn_if_slow(dt, control_interval, cfg.fps) if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) diff --git a/src/lerobot/rollout/strategies/highlight.py b/src/lerobot/rollout/strategies/highlight.py index 05b32e225..2bffa0801 100644 --- a/src/lerobot/rollout/strategies/highlight.py +++ b/src/lerobot/rollout/strategies/highlight.py @@ -187,6 +187,7 @@ class HighlightStrategy(RolloutStrategy): ring.append(frame) dt = time.perf_counter() - loop_start + self._warn_if_slow(dt, control_interval, cfg.fps) if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) diff --git a/src/lerobot/rollout/strategies/sentry.py b/src/lerobot/rollout/strategies/sentry.py index 047aced25..8d716497a 100644 --- a/src/lerobot/rollout/strategies/sentry.py +++ b/src/lerobot/rollout/strategies/sentry.py @@ -158,6 +158,7 @@ class SentryStrategy(RolloutStrategy): episode_start = time.perf_counter() dt = time.perf_counter() - loop_start + self._warn_if_slow(dt, control_interval, cfg.fps) if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t)