refactor(rl): update shutdown_event type hints from 'any' to 'Any' for consistency and clarity

This commit is contained in:
Khalil Meftah
2026-04-18 15:39:32 +02:00
parent 87d4c9879c
commit 2c97cb23c8
2 changed files with 20 additions and 14 deletions

View File

@@ -51,11 +51,12 @@ import os
import time
from functools import lru_cache
from queue import Empty
from typing import Any
import grpc
import torch
from torch import nn
from torch.multiprocessing import Event, Queue
from torch.multiprocessing import Queue
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
@@ -210,7 +211,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
def act_with_policy(
cfg: TrainRLServerPipelineConfig,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
parameters_queue: Queue,
transitions_queue: Queue,
interactions_queue: Queue,
@@ -414,7 +415,7 @@ def act_with_policy(
def establish_learner_connection(
stub: services_pb2_grpc.LearnerServiceStub,
shutdown_event: Event, # type: ignore
shutdown_event: Any, # Event
attempts: int = 30,
):
"""Establish a connection with the learner.
@@ -466,7 +467,7 @@ def learner_service_client(
def receive_policy(
cfg: TrainRLServerPipelineConfig,
parameters_queue: Queue,
shutdown_event: Event, # type: ignore
shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
@@ -518,7 +519,7 @@ def receive_policy(
def send_transitions(
cfg: TrainRLServerPipelineConfig,
transitions_queue: Queue,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
@@ -568,7 +569,7 @@ def send_transitions(
def send_interactions(
cfg: TrainRLServerPipelineConfig,
interactions_queue: Queue,
shutdown_event: Event, # type: ignore
shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
@@ -618,7 +619,11 @@ def send_interactions(
logging.info("[ACTOR] Interactions process stopped")
def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore
def transitions_stream(
shutdown_event: Any, # Event
transitions_queue: Queue,
timeout: float,
) -> services_pb2.Empty:
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=timeout)
@@ -634,9 +639,9 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout:
def interactions_stream(
shutdown_event: Event,
shutdown_event: Any, # Event
interactions_queue: Queue,
timeout: float, # type: ignore
timeout: float,
) -> services_pb2.Empty:
while not shutdown_event.is_set():
try:

View File

@@ -51,6 +51,7 @@ import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from pprint import pformat
from typing import Any
import grpc
import torch
@@ -181,7 +182,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None):
def start_learner_threads(
cfg: TrainRLServerPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
) -> None:
"""
Start the learner threads for training.
@@ -255,7 +256,7 @@ def start_learner_threads(
def add_actor_information_and_train(
cfg: TrainRLServerPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
transition_queue: Queue,
interaction_message_queue: Queue,
parameters_queue: Queue,
@@ -465,7 +466,7 @@ def start_learner(
parameters_queue: Queue,
transition_queue: Queue,
interaction_message_queue: Queue,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
cfg: TrainRLServerPipelineConfig,
):
"""
@@ -907,7 +908,7 @@ def process_transitions(
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
dataset_repo_id: str | None,
shutdown_event: any,
shutdown_event: Any, # Event
):
"""Process all available transitions from the queue.
@@ -945,7 +946,7 @@ def process_interaction_messages(
interaction_message_queue: Queue,
interaction_step_shift: int,
wandb_logger: WandBLogger | None,
shutdown_event: any,
shutdown_event: Any, # Event
) -> dict | None:
"""Process all available interaction messages from the queue.