From da41646073d288f01f21fdd602b2281864ed06f6 Mon Sep 17 00:00:00 2001 From: Sung-Wook Lee <60949213+sean1295@users.noreply.github.com> Date: Mon, 19 Jan 2026 07:18:52 -0500 Subject: [PATCH 1/3] fix libero reset logic for correct resetting (#2817) --- src/lerobot/envs/libero.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 74882ad18..96c5cf102 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -293,9 +293,9 @@ class LiberoEnv(gym.Env): def reset(self, seed=None, **kwargs): super().reset(seed=seed) self._env.seed(seed) - if self.init_states and self._init_states is not None: - self._env.set_init_state(self._init_states[self._init_state_id]) raw_obs = self._env.reset() + if self.init_states and self._init_states is not None: + raw_obs = self._env.set_init_state(self._init_states[self._init_state_id]) # After reset, objects may be unstable (slightly floating, intersecting, etc.). # Step the simulator with a no-op action for a few frames so everything settles. From fe068df711dd5d08aa04c577b40d31ec61245f22 Mon Sep 17 00:00:00 2001 From: bigmbigk Date: Mon, 19 Jan 2026 22:14:10 +0900 Subject: [PATCH 2/3] fix(train): eval env initialization on train script (#2818) * fix: eval env initialization on train script Signed-off-by: bigmbigk * fix: eval env creation condition --------- Signed-off-by: bigmbigk --- src/lerobot/scripts/lerobot_train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 41f866ce8..93b99e245 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -225,9 +225,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # On real-world data, no need to create an environment as evaluations are done outside train.py, # using the eval.py instead, with gym_dora environment and dora-rs. eval_env = None - if cfg.eval_freq > 0 and cfg.env is not None: - if is_main_process: - logging.info("Creating env") + if cfg.eval_freq > 0 and cfg.env is not None and is_main_process: + logging.info("Creating env") eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) if is_main_process: From 5286ef8439efe95ba3f98f3d7b3c1bdf4898c347 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 19 Jan 2026 16:43:11 +0100 Subject: [PATCH 3/3] feat(utils): extend import check util (#2820) * refactor(utils): is_package_available now differentiate between pkg name and module name * refactor(tests): update require_package decorator --- src/lerobot/utils/import_utils.py | 26 ++++++--- tests/async_inference/test_policy_server.py | 2 +- tests/rl/test_actor.py | 10 ++-- tests/rl/test_actor_learner.py | 6 +- tests/rl/test_learner_service.py | 16 +++--- tests/transport/test_transport_utils.py | 62 ++++++++++----------- tests/utils.py | 4 +- 7 files changed, 67 insertions(+), 59 deletions(-) diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 0206a8ac9..a499b96c7 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -21,12 +21,23 @@ from typing import Any from draccus.choice_types import ChoiceRegistry -def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: - """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py - Check if the package spec exists and grab its version to avoid importing a local directory. - **Note:** this doesn't work for all packages. +def is_package_available( + pkg_name: str, import_name: str | None = None, return_version: bool = False +) -> tuple[bool, str] | bool: """ - package_exists = importlib.util.find_spec(pkg_name) is not None + Check if the package spec exists and grab its version to avoid importing a local directory. + + Args: + pkg_name: The name of the package as installed via pip (e.g. "python-can"). + import_name: The actual name used to import the package (e.g. "can"). + Defaults to pkg_name if not provided. + return_version: Whether to return the version string. + """ + if import_name is None: + import_name = pkg_name + + # Check if the module spec exists using the import name + package_exists = importlib.util.find_spec(import_name) is not None package_version = "N/A" if package_exists: try: @@ -37,7 +48,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b # Fallback method: Only for "torch" and versions containing "dev" if pkg_name == "torch": try: - package = importlib.import_module(pkg_name) + package = importlib.import_module(import_name) temp_version = getattr(package, "__version__", "N/A") # Check if the version contains "dev" if "dev" in temp_version: @@ -48,9 +59,6 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b except ImportError: # If the package can't be imported, it's not available package_exists = False - elif pkg_name == "grpc": - package = importlib.import_module(pkg_name) - package_version = getattr(package, "__version__", "N/A") else: # For packages other than "torch", don't attempt the fallback and set as not available package_exists = False diff --git a/tests/async_inference/test_policy_server.py b/tests/async_inference/test_policy_server.py index 29583d4fa..c3ee37c8f 100644 --- a/tests/async_inference/test_policy_server.py +++ b/tests/async_inference/test_policy_server.py @@ -62,7 +62,7 @@ class MockPolicy: @pytest.fixture -@require_package("grpc") +@require_package("grpcio", "grpc") def policy_server(): """Fresh `PolicyServer` instance with a stubbed-out policy model.""" # Import only when the test actually runs (after decorator check) diff --git a/tests/rl/test_actor.py b/tests/rl/test_actor.py index ec67f1889..54e4d2870 100644 --- a/tests/rl/test_actor.py +++ b/tests/rl/test_actor.py @@ -64,7 +64,7 @@ def close_service_stub(channel, server): server.stop(None) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_establish_learner_connection_success(): from lerobot.rl.actor import establish_learner_connection @@ -81,7 +81,7 @@ def test_establish_learner_connection_success(): close_service_stub(channel, server) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_establish_learner_connection_failure(): from lerobot.rl.actor import establish_learner_connection @@ -100,7 +100,7 @@ def test_establish_learner_connection_failure(): close_service_stub(channel, server) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_push_transitions_to_transport_queue(): from lerobot.rl.actor import push_transitions_to_transport_queue from lerobot.transport.utils import bytes_to_transitions @@ -135,7 +135,7 @@ def test_push_transitions_to_transport_queue(): assert_transitions_equal(deserialized_transition, transitions[i]) -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_transitions_stream(): from lerobot.rl.actor import transitions_stream @@ -167,7 +167,7 @@ def test_transitions_stream(): assert streamed_data[2].data == b"transition_data_3" -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_interactions_stream(): from lerobot.rl.actor import interactions_stream diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py index 5d95dee04..e13862d82 100644 --- a/tests/rl/test_actor_learner.py +++ b/tests/rl/test_actor_learner.py @@ -88,7 +88,7 @@ def cfg(): return cfg -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(10) # force cross-platform watchdog def test_end_to_end_transitions_flow(cfg): from lerobot.rl.actor import ( @@ -150,7 +150,7 @@ def test_end_to_end_transitions_flow(cfg): assert_transitions_equal(transition, input_transitions[i]) -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(10) def test_end_to_end_interactions_flow(cfg): from lerobot.rl.actor import ( @@ -223,7 +223,7 @@ def test_end_to_end_interactions_flow(cfg): assert received == expected -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.parametrize("data_size", ["small", "large"]) @pytest.mark.timeout(10) def test_end_to_end_parameters_flow(cfg, data_size): diff --git a/tests/rl/test_learner_service.py b/tests/rl/test_learner_service.py index e0f0292be..d967388f0 100644 --- a/tests/rl/test_learner_service.py +++ b/tests/rl/test_learner_service.py @@ -39,7 +39,7 @@ def learner_service_stub(): close_learner_service_stub(channel, server) -@require_package("grpc") +@require_package("grpcio", "grpc") def create_learner_service_stub( shutdown_event: Event, parameters_queue: Queue, @@ -75,7 +75,7 @@ def create_learner_service_stub( return services_pb2_grpc.LearnerServiceStub(channel), channel, server -@require_package("grpc") +@require_package("grpcio", "grpc") def close_learner_service_stub(channel, server): channel.close() server.stop(None) @@ -91,7 +91,7 @@ def test_ready_method(learner_service_stub): assert response == services_pb2.Empty() -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_send_interactions(): from lerobot.transport import services_pb2 @@ -135,7 +135,7 @@ def test_send_interactions(): assert interactions == [b"123", b"4", b"5", b"678"] -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_send_transitions(): from lerobot.transport import services_pb2 @@ -181,7 +181,7 @@ def test_send_transitions(): assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"] -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_send_transitions_empty_stream(): from lerobot.transport import services_pb2 @@ -209,7 +209,7 @@ def test_send_transitions_empty_stream(): assert transitions_queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(10) # force cross-platform watchdog def test_stream_parameters(): import time @@ -267,7 +267,7 @@ def test_stream_parameters(): assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1) -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_stream_parameters_with_shutdown(): from lerobot.transport import services_pb2 @@ -319,7 +319,7 @@ def test_stream_parameters_with_shutdown(): assert received_params == [b"param_batch_1", b"stop"] -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_stream_parameters_waits_and_retries_on_empty_queue(): import threading diff --git a/tests/transport/test_transport_utils.py b/tests/transport/test_transport_utils.py index 52825a24e..63632a8f4 100644 --- a/tests/transport/test_transport_utils.py +++ b/tests/transport/test_transport_utils.py @@ -26,7 +26,7 @@ from lerobot.utils.transition import Transition from tests.utils import require_cuda, require_package -@require_package("grpc") +@require_package("grpcio", "grpc") def test_bytes_buffer_size_empty_buffer(): from lerobot.transport.utils import bytes_buffer_size @@ -37,7 +37,7 @@ def test_bytes_buffer_size_empty_buffer(): assert buffer.tell() == 0 -@require_package("grpc") +@require_package("grpcio", "grpc") def test_bytes_buffer_size_small_buffer(): from lerobot.transport.utils import bytes_buffer_size @@ -47,7 +47,7 @@ def test_bytes_buffer_size_small_buffer(): assert buffer.tell() == 0 -@require_package("grpc") +@require_package("grpcio", "grpc") def test_bytes_buffer_size_large_buffer(): from lerobot.transport.utils import CHUNK_SIZE, bytes_buffer_size @@ -58,7 +58,7 @@ def test_bytes_buffer_size_large_buffer(): assert buffer.tell() == 0 -@require_package("grpc") +@require_package("grpcio", "grpc") def test_send_bytes_in_chunks_empty_data(): from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 @@ -68,7 +68,7 @@ def test_send_bytes_in_chunks_empty_data(): assert len(chunks) == 0 -@require_package("grpc") +@require_package("grpcio", "grpc") def test_single_chunk_small_data(): from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 @@ -82,7 +82,7 @@ def test_single_chunk_small_data(): assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END -@require_package("grpc") +@require_package("grpcio", "grpc") def test_not_silent_mode(): from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 @@ -94,7 +94,7 @@ def test_not_silent_mode(): assert chunks[0].data == b"Some data" -@require_package("grpc") +@require_package("grpcio", "grpc") def test_send_bytes_in_chunks_large_data(): from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 @@ -111,7 +111,7 @@ def test_send_bytes_in_chunks_large_data(): assert chunks[2].transfer_state == services_pb2.TransferState.TRANSFER_END -@require_package("grpc") +@require_package("grpcio", "grpc") def test_send_bytes_in_chunks_large_data_with_exact_chunk_size(): from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 @@ -124,7 +124,7 @@ def test_send_bytes_in_chunks_large_data_with_exact_chunk_size(): assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_empty_data(): from lerobot.transport.utils import receive_bytes_in_chunks @@ -138,7 +138,7 @@ def test_receive_bytes_in_chunks_empty_data(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_single_chunk(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -157,7 +157,7 @@ def test_receive_bytes_in_chunks_single_chunk(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_single_not_end_chunk(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -175,7 +175,7 @@ def test_receive_bytes_in_chunks_single_not_end_chunk(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_multiple_chunks(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -199,7 +199,7 @@ def test_receive_bytes_in_chunks_multiple_chunks(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_multiple_messages(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -235,7 +235,7 @@ def test_receive_bytes_in_chunks_multiple_messages(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_shutdown_during_receive(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -259,7 +259,7 @@ def test_receive_bytes_in_chunks_shutdown_during_receive(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_only_begin_chunk(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -279,7 +279,7 @@ def test_receive_bytes_in_chunks_only_begin_chunk(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_missing_begin(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -303,7 +303,7 @@ def test_receive_bytes_in_chunks_missing_begin(): # Tests for state_to_bytes and bytes_to_state_dict -@require_package("grpc") +@require_package("grpcio", "grpc") def test_state_to_bytes_empty_dict(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -314,7 +314,7 @@ def test_state_to_bytes_empty_dict(): assert reconstructed == state_dict -@require_package("grpc") +@require_package("grpcio", "grpc") def test_bytes_to_state_dict_empty_data(): from lerobot.transport.utils import bytes_to_state_dict @@ -323,7 +323,7 @@ def test_bytes_to_state_dict_empty_data(): bytes_to_state_dict(b"") -@require_package("grpc") +@require_package("grpcio", "grpc") def test_state_to_bytes_simple_dict(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -347,7 +347,7 @@ def test_state_to_bytes_simple_dict(): assert torch.allclose(state_dict[key], reconstructed[key]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_state_to_bytes_various_dtypes(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -372,7 +372,7 @@ def test_state_to_bytes_various_dtypes(): assert torch.allclose(state_dict[key], reconstructed[key]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_bytes_to_state_dict_invalid_data(): from lerobot.transport.utils import bytes_to_state_dict @@ -382,7 +382,7 @@ def test_bytes_to_state_dict_invalid_data(): @require_cuda -@require_package("grpc") +@require_package("grpcio", "grpc") def test_state_to_bytes_various_dtypes_cuda(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -407,7 +407,7 @@ def test_state_to_bytes_various_dtypes_cuda(): assert torch.allclose(state_dict[key], reconstructed[key]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_python_object_to_bytes_none(): from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes @@ -439,7 +439,7 @@ def test_python_object_to_bytes_none(): (1, 2, 3), ], ) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_python_object_to_bytes_simple_types(obj): from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes @@ -450,7 +450,7 @@ def test_python_object_to_bytes_simple_types(obj): assert type(reconstructed) is type(obj) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_python_object_to_bytes_with_tensors(): from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes @@ -475,7 +475,7 @@ def test_python_object_to_bytes_with_tensors(): assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_transitions_to_bytes_empty_list(): from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes @@ -487,7 +487,7 @@ def test_transitions_to_bytes_empty_list(): assert isinstance(reconstructed, list) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_transitions_to_bytes_single_transition(): from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes @@ -509,7 +509,7 @@ def test_transitions_to_bytes_single_transition(): assert_transitions_equal(transitions[0], reconstructed[0]) -@require_package("grpc") +@require_package("grpcio", "grpc") def assert_transitions_equal(t1: Transition, t2: Transition): """Helper to assert two transitions are equal.""" assert_observation_equal(t1["state"], t2["state"]) @@ -519,7 +519,7 @@ def assert_transitions_equal(t1: Transition, t2: Transition): assert_observation_equal(t1["next_state"], t2["next_state"]) -@require_package("grpc") +@require_package("grpcio", "grpc") def assert_observation_equal(o1: dict, o2: dict): """Helper to assert two observations are equal.""" assert set(o1.keys()) == set(o2.keys()) @@ -527,7 +527,7 @@ def assert_observation_equal(o1: dict, o2: dict): assert torch.allclose(o1[key], o2[key]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_transitions_to_bytes_multiple_transitions(): from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes @@ -551,7 +551,7 @@ def test_transitions_to_bytes_multiple_transitions(): assert_transitions_equal(original, reconstructed_item) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_unknown_state(): from lerobot.transport.utils import receive_bytes_in_chunks diff --git a/tests/utils.py b/tests/utils.py index 800b7d4b3..38841db02 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -167,7 +167,7 @@ def require_package_arg(func): return wrapper -def require_package(package_name): +def require_package(package_name, import_name=None): """ Decorator that skips the test if the specified package is not installed. """ @@ -175,7 +175,7 @@ def require_package(package_name): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): - if not is_package_available(package_name): + if not is_package_available(pkg_name=package_name, import_name=import_name): pytest.skip(f"{package_name} not installed") return func(*args, **kwargs)