From 8e21268c2954d692b36296bb622fa86b62c6db30 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 20 Apr 2026 00:36:02 +0200 Subject: [PATCH] test: add dataset guard + fix imports --- docs/source/inference.mdx | 5 ++- examples/lekiwi/rollout.py | 5 ++- examples/phone_to_so100/rollout.py | 5 ++- examples/so100_to_so100_EE/rollout.py | 5 ++- src/lerobot/rollout/__init__.py | 21 ++++++++++-- src/lerobot/rollout/strategies/__init__.py | 10 ++++++ src/lerobot/rollout/strategies/factory.py | 2 +- tests/test_rollout.py | 37 +++++++++++----------- 8 files changed, 56 insertions(+), 34 deletions(-) diff --git a/docs/source/inference.mdx b/docs/source/inference.mdx index 4a941fccd..c849f4c4e 100644 --- a/docs/source/inference.mdx +++ b/docs/source/inference.mdx @@ -227,10 +227,9 @@ See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters. For custom deployments (e.g. with kinematics processors), use the rollout module API directly: ```python -from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig -from lerobot.rollout.context import build_rollout_context +from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context from lerobot.rollout.inference import SyncInferenceConfig -from lerobot.rollout.strategies.base import BaseStrategy +from lerobot.rollout.strategies import BaseStrategy from lerobot.utils.process import ProcessSignalHandler cfg = RolloutConfig( diff --git a/examples/lekiwi/rollout.py b/examples/lekiwi/rollout.py index b785b6fcb..4fb103c8c 100644 --- a/examples/lekiwi/rollout.py +++ b/examples/lekiwi/rollout.py @@ -24,10 +24,9 @@ recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``. from lerobot.configs import PreTrainedConfig from lerobot.robots.lekiwi import LeKiwiClientConfig -from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig -from lerobot.rollout.context import build_rollout_context +from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context from lerobot.rollout.inference import SyncInferenceConfig -from lerobot.rollout.strategies.base import BaseStrategy +from lerobot.rollout.strategies import BaseStrategy from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.utils import init_logging diff --git a/examples/phone_to_so100/rollout.py b/examples/phone_to_so100/rollout.py index 2a17aa4d8..ca6706c52 100644 --- a/examples/phone_to_so100/rollout.py +++ b/examples/phone_to_so100/rollout.py @@ -40,10 +40,9 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( ForwardKinematicsJointsToEE, InverseKinematicsEEToJoints, ) -from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig -from lerobot.rollout.context import build_rollout_context +from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context from lerobot.rollout.inference import SyncInferenceConfig -from lerobot.rollout.strategies.base import BaseStrategy +from lerobot.rollout.strategies import BaseStrategy from lerobot.types import RobotAction, RobotObservation from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.utils import init_logging diff --git a/examples/so100_to_so100_EE/rollout.py b/examples/so100_to_so100_EE/rollout.py index 95339029d..d608bfab2 100644 --- a/examples/so100_to_so100_EE/rollout.py +++ b/examples/so100_to_so100_EE/rollout.py @@ -38,10 +38,9 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( ForwardKinematicsJointsToEE, InverseKinematicsEEToJoints, ) -from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig -from lerobot.rollout.context import build_rollout_context +from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context from lerobot.rollout.inference import SyncInferenceConfig -from lerobot.rollout.strategies.base import BaseStrategy +from lerobot.rollout.strategies import BaseStrategy from lerobot.types import RobotAction, RobotObservation from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.utils import init_logging diff --git a/src/lerobot/rollout/__init__.py b/src/lerobot/rollout/__init__.py index 8d7a90d43..6a2444762 100644 --- a/src/lerobot/rollout/__init__.py +++ b/src/lerobot/rollout/__init__.py @@ -14,6 +14,10 @@ """Policy deployment engine with pluggable rollout strategies.""" +from lerobot.utils.import_utils import require_package + +require_package("datasets", extra="dataset") + from .configs import ( BaseStrategyConfig, DAggerKeyboardConfig, @@ -25,7 +29,15 @@ from .configs import ( RolloutStrategyConfig, SentryStrategyConfig, ) -from .context import RolloutContext, build_rollout_context +from .context import ( + DatasetContext, + HardwareContext, + PolicyContext, + ProcessorContext, + RolloutContext, + RuntimeContext, + build_rollout_context, +) from .inference import ( InferenceEngine, InferenceEngineConfig, @@ -44,17 +56,22 @@ __all__ = [ "DAggerKeyboardConfig", "DAggerPedalConfig", "DAggerStrategyConfig", + "DatasetContext", + "DatasetRecordConfig", + "HardwareContext", "HighlightStrategyConfig", "InferenceEngine", "InferenceEngineConfig", + "PolicyContext", + "ProcessorContext", "RTCInferenceConfig", "RTCInferenceEngine", "RolloutConfig", "RolloutContext", - "DatasetRecordConfig", "RolloutRingBuffer", "RolloutStrategy", "RolloutStrategyConfig", + "RuntimeContext", "SentryStrategyConfig", "SyncInferenceConfig", "SyncInferenceEngine", diff --git a/src/lerobot/rollout/strategies/__init__.py b/src/lerobot/rollout/strategies/__init__.py index f900094dc..554327073 100644 --- a/src/lerobot/rollout/strategies/__init__.py +++ b/src/lerobot/rollout/strategies/__init__.py @@ -14,11 +14,21 @@ """Rollout strategies — public API re-exports.""" +from .base import BaseStrategy from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action +from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy from .factory import create_strategy +from .highlight import HighlightStrategy +from .sentry import SentryStrategy __all__ = [ + "BaseStrategy", + "DAggerEvents", + "DAggerPhase", + "DAggerStrategy", + "HighlightStrategy", "RolloutStrategy", + "SentryStrategy", "create_strategy", "estimate_max_episode_seconds", "safe_push_to_hub", diff --git a/src/lerobot/rollout/strategies/factory.py b/src/lerobot/rollout/strategies/factory.py index 0705ca3d0..0a07d4750 100644 --- a/src/lerobot/rollout/strategies/factory.py +++ b/src/lerobot/rollout/strategies/factory.py @@ -25,7 +25,7 @@ from .highlight import HighlightStrategy from .sentry import SentryStrategy if TYPE_CHECKING: - from lerobot.rollout.configs import RolloutStrategyConfig + from lerobot.rollout import RolloutStrategyConfig def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy: diff --git a/tests/test_rollout.py b/tests/test_rollout.py index 9f54f581d..0a53ea477 100644 --- a/tests/test_rollout.py +++ b/tests/test_rollout.py @@ -22,6 +22,8 @@ from unittest.mock import MagicMock import pytest import torch +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + # --------------------------------------------------------------------------- # Import smoke tests # --------------------------------------------------------------------------- @@ -54,7 +56,7 @@ def test_strategies_submodule_imports(): def test_strategy_config_types(): - from lerobot.rollout.configs import ( + from lerobot.rollout import ( BaseStrategyConfig, DAggerStrategyConfig, HighlightStrategyConfig, @@ -68,14 +70,14 @@ def test_strategy_config_types(): def test_dagger_config_invalid_input_device(): - from lerobot.rollout.configs import DAggerStrategyConfig + from lerobot.rollout import DAggerStrategyConfig with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"): DAggerStrategyConfig(input_device="joystick") def test_dagger_config_defaults(): - from lerobot.rollout.configs import DAggerStrategyConfig + from lerobot.rollout import DAggerStrategyConfig cfg = DAggerStrategyConfig() assert cfg.num_episodes == 10 @@ -95,7 +97,7 @@ def test_inference_config_types(): def test_sentry_config_defaults(): - from lerobot.rollout.configs import SentryStrategyConfig + from lerobot.rollout import SentryStrategyConfig cfg = SentryStrategyConfig() assert cfg.upload_every_n_episodes == 5 @@ -108,7 +110,7 @@ def test_sentry_config_defaults(): def test_ring_buffer_append_and_eviction(): - from lerobot.rollout.ring_buffer import RolloutRingBuffer + from lerobot.rollout import RolloutRingBuffer buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0) # max_frames = 5 @@ -118,7 +120,7 @@ def test_ring_buffer_append_and_eviction(): def test_ring_buffer_drain(): - from lerobot.rollout.ring_buffer import RolloutRingBuffer + from lerobot.rollout import RolloutRingBuffer buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) for i in range(3): @@ -130,7 +132,7 @@ def test_ring_buffer_drain(): def test_ring_buffer_clear(): - from lerobot.rollout.ring_buffer import RolloutRingBuffer + from lerobot.rollout import RolloutRingBuffer buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) buf.append({"val": 1}) @@ -140,7 +142,7 @@ def test_ring_buffer_clear(): def test_ring_buffer_tensor_bytes(): - from lerobot.rollout.ring_buffer import RolloutRingBuffer + from lerobot.rollout import RolloutRingBuffer buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0) t = torch.zeros(100, dtype=torch.float32) # 400 bytes @@ -154,7 +156,7 @@ def test_ring_buffer_tensor_bytes(): def test_thread_safe_robot_delegates(): - from lerobot.rollout.robot_wrapper import ThreadSafeRobot + from lerobot.rollout import ThreadSafeRobot from tests.mocks.mock_robot import MockRobot, MockRobotConfig robot = MockRobot(MockRobotConfig(n_motors=3)) @@ -174,7 +176,7 @@ def test_thread_safe_robot_delegates(): def test_thread_safe_robot_properties(): - from lerobot.rollout.robot_wrapper import ThreadSafeRobot + from lerobot.rollout import ThreadSafeRobot from tests.mocks.mock_robot import MockRobot, MockRobotConfig robot = MockRobot(MockRobotConfig(n_motors=3)) @@ -196,11 +198,8 @@ def test_thread_safe_robot_properties(): def test_create_strategy_dispatches(): - from lerobot.rollout.configs import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig - from lerobot.rollout.strategies import create_strategy - from lerobot.rollout.strategies.base import BaseStrategy - from lerobot.rollout.strategies.dagger import DAggerStrategy - from lerobot.rollout.strategies.sentry import SentryStrategy + from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig + from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy) assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy) @@ -280,7 +279,7 @@ def test_safe_push_to_hub(): def test_dagger_full_transition_cycle(): - from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase + from lerobot.rollout.strategies import DAggerEvents, DAggerPhase events = DAggerEvents() assert events.phase == DAggerPhase.AUTONOMOUS @@ -307,7 +306,7 @@ def test_dagger_full_transition_cycle(): def test_dagger_invalid_transition_ignored(): - from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase + from lerobot.rollout.strategies import DAggerEvents, DAggerPhase events = DAggerEvents() events.request_transition("correction") # Not valid from AUTONOMOUS @@ -316,7 +315,7 @@ def test_dagger_invalid_transition_ignored(): def test_dagger_events_reset(): - from lerobot.rollout.strategies.dagger import DAggerEvents, DAggerPhase + from lerobot.rollout.strategies import DAggerEvents, DAggerPhase events = DAggerEvents() events.request_transition("pause_resume") @@ -333,7 +332,7 @@ def test_dagger_events_reset(): def test_rollout_context_fields(): - from lerobot.rollout.context import RolloutContext + from lerobot.rollout import RolloutContext field_names = {f.name for f in dataclasses.fields(RolloutContext)} assert field_names == {"runtime", "hardware", "policy", "processors", "data"}