diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 83c88fa96..9007c370b 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -51,11 +51,12 @@ import os import time from functools import lru_cache from queue import Empty +from typing import Any import grpc import torch from torch import nn -from torch.multiprocessing import Event, Queue +from torch.multiprocessing import Queue from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser @@ -210,7 +211,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig): def act_with_policy( cfg: TrainRLServerPipelineConfig, - shutdown_event: any, # Event, + shutdown_event: Any, # Event parameters_queue: Queue, transitions_queue: Queue, interactions_queue: Queue, @@ -414,7 +415,7 @@ def act_with_policy( def establish_learner_connection( stub: services_pb2_grpc.LearnerServiceStub, - shutdown_event: Event, # type: ignore + shutdown_event: Any, # Event attempts: int = 30, ): """Establish a connection with the learner. @@ -466,7 +467,7 @@ def learner_service_client( def receive_policy( cfg: TrainRLServerPipelineConfig, parameters_queue: Queue, - shutdown_event: Event, # type: ignore + shutdown_event: Any, # Event learner_client: services_pb2_grpc.LearnerServiceStub | None = None, grpc_channel: grpc.Channel | None = None, ): @@ -518,7 +519,7 @@ def receive_policy( def send_transitions( cfg: TrainRLServerPipelineConfig, transitions_queue: Queue, - shutdown_event: any, # Event, + shutdown_event: Any, # Event learner_client: services_pb2_grpc.LearnerServiceStub | None = None, grpc_channel: grpc.Channel | None = None, ) -> services_pb2.Empty: @@ -568,7 +569,7 @@ def send_transitions( def send_interactions( cfg: TrainRLServerPipelineConfig, interactions_queue: Queue, - shutdown_event: Event, # type: ignore + shutdown_event: Any, # Event learner_client: services_pb2_grpc.LearnerServiceStub | None = None, grpc_channel: grpc.Channel | None = None, ) -> services_pb2.Empty: @@ -618,7 +619,11 @@ def send_interactions( logging.info("[ACTOR] Interactions process stopped") -def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore +def transitions_stream( + shutdown_event: Any, # Event + transitions_queue: Queue, + timeout: float, +) -> services_pb2.Empty: while not shutdown_event.is_set(): try: message = transitions_queue.get(block=True, timeout=timeout) @@ -634,9 +639,9 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: def interactions_stream( - shutdown_event: Event, + shutdown_event: Any, # Event interactions_queue: Queue, - timeout: float, # type: ignore + timeout: float, ) -> services_pb2.Empty: while not shutdown_event.is_set(): try: diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index fa5fa98e8..03b57eac9 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -51,6 +51,7 @@ import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path from pprint import pformat +from typing import Any import grpc import torch @@ -181,7 +182,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None): def start_learner_threads( cfg: TrainRLServerPipelineConfig, wandb_logger: WandBLogger | None, - shutdown_event: any, # Event, + shutdown_event: Any, # Event ) -> None: """ Start the learner threads for training. @@ -255,7 +256,7 @@ def start_learner_threads( def add_actor_information_and_train( cfg: TrainRLServerPipelineConfig, wandb_logger: WandBLogger | None, - shutdown_event: any, # Event, + shutdown_event: Any, # Event transition_queue: Queue, interaction_message_queue: Queue, parameters_queue: Queue, @@ -465,7 +466,7 @@ def start_learner( parameters_queue: Queue, transition_queue: Queue, interaction_message_queue: Queue, - shutdown_event: any, # Event, + shutdown_event: Any, # Event cfg: TrainRLServerPipelineConfig, ): """ @@ -907,7 +908,7 @@ def process_transitions( replay_buffer: ReplayBuffer, offline_replay_buffer: ReplayBuffer, dataset_repo_id: str | None, - shutdown_event: any, + shutdown_event: Any, # Event ): """Process all available transitions from the queue. @@ -945,7 +946,7 @@ def process_interaction_messages( interaction_message_queue: Queue, interaction_step_shift: int, wandb_logger: WandBLogger | None, - shutdown_event: any, + shutdown_event: Any, # Event ) -> dict | None: """Process all available interaction messages from the queue.