Clean the code

This commit is contained in:
AdilZouitine
2025-04-24 17:22:54 +02:00
parent b8c2b0bb93
commit a8da4a347e
4 changed files with 56 additions and 39 deletions

View File

@@ -356,10 +356,19 @@ def act_with_policy(
def establish_learner_connection(
stub,
shutdown_event: any, # Event,
attempts=30,
stub: hilserl_pb2_grpc.LearnerServiceStub,
shutdown_event: Event, # type: ignore
attempts: int = 30,
):
"""Establish a connection with the learner.
Args:
stub (hilserl_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
shutdown_event (Event): The event signalling shutdown; it is checked before each connection attempt.
attempts (int): The number of attempts to establish the connection.
Returns:
bool: True if the connection is established, False otherwise.
"""
for _ in range(attempts):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down establish_learner_connection")
@@ -378,7 +387,8 @@ def establish_learner_connection(
@lru_cache(maxsize=1)
def learner_service_client(
host="127.0.0.1", port=50051
host: str = "127.0.0.1",
port: int = 50051,
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
import json
@@ -426,12 +436,18 @@ def learner_service_client(
def receive_policy(
cfg: TrainPipelineConfig,
parameters_queue: Queue,
shutdown_event: any, # Event,
shutdown_event: Event, # type: ignore
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
logging.info("[ACTOR] Start receiving parameters from the Learner")
"""Receive parameters from the learner.
Args:
cfg (TrainPipelineConfig): The configuration for the actor.
parameters_queue (Queue): The queue to receive the parameters.
shutdown_event (Event): The event to check if the process should shut down.
"""
logging.info("[ACTOR] Start receiving parameters from the Learner")
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
@@ -481,7 +497,7 @@ def send_transitions(
This function continuously retrieves messages from the queue and processes:
- **Transition Data:**
- Transition Data:
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
@@ -522,7 +538,7 @@ def send_transitions(
def send_interactions(
cfg: TrainPipelineConfig,
interactions_queue: Queue,
shutdown_event: any, # Event,
shutdown_event: Event, # type: ignore
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
@@ -531,7 +547,7 @@ def send_interactions(
This function continuously retrieves messages from the queue and processes:
- **Interaction Messages:**
- Interaction Messages:
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
"""
@@ -568,7 +584,7 @@ def send_interactions(
logging.info("[ACTOR] Interactions process stopped")
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: # type: ignore
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=5)
@@ -584,7 +600,7 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilse
def interactions_stream(
shutdown_event: any, # Event,
shutdown_event: Event, # type: ignore
interactions_queue: Queue,
) -> hilserl_pb2.Empty:
while not shutdown_event.is_set():
@@ -643,6 +659,14 @@ def push_transitions_to_transport_queue(transitions: list, transitions_queue):
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
"""Get the frequency statistics of the policy.
Args:
list_policy_time (list[float]): The list of policy times.
Returns:
dict[str, float]: The frequency statistics of the policy.
"""
stats = {}
list_policy_fps = [1.0 / t for t in list_policy_time]
if len(list_policy_fps) > 1: