mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
Clean the code
This commit is contained in:
@@ -356,10 +356,19 @@ def act_with_policy(
|
||||
|
||||
|
||||
def establish_learner_connection(
|
||||
stub,
|
||||
shutdown_event: any, # Event,
|
||||
attempts=30,
|
||||
stub: hilserl_pb2_grpc.LearnerServiceStub,
|
||||
shutdown_event: Event, # type: ignore
|
||||
attempts: int = 30,
|
||||
):
|
||||
"""Establish a connection with the learner.
|
||||
|
||||
Args:
|
||||
stub (hilserl_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
|
||||
shutdown_event (Event): The event that signals shutdown; it is checked before each connection attempt.
|
||||
attempts (int): The number of attempts to establish the connection.
|
||||
Returns:
|
||||
bool: True if the connection is established, False otherwise.
|
||||
"""
|
||||
for _ in range(attempts):
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down establish_learner_connection")
|
||||
@@ -378,7 +387,8 @@ def establish_learner_connection(
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def learner_service_client(
|
||||
host="127.0.0.1", port=50051
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 50051,
|
||||
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
||||
import json
|
||||
|
||||
@@ -426,12 +436,18 @@ def learner_service_client(
|
||||
def receive_policy(
|
||||
cfg: TrainPipelineConfig,
|
||||
parameters_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
):
|
||||
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
||||
"""Receive parameters from the learner.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): The configuration for the actor.
|
||||
parameters_queue (Queue): The queue to receive the parameters.
|
||||
shutdown_event (Event): The event to check if the process should shut down.
|
||||
"""
|
||||
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
@@ -481,7 +497,7 @@ def send_transitions(
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- **Transition Data:**
|
||||
- Transition Data:
|
||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
|
||||
@@ -522,7 +538,7 @@ def send_transitions(
|
||||
def send_interactions(
|
||||
cfg: TrainPipelineConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> hilserl_pb2.Empty:
|
||||
@@ -531,7 +547,7 @@ def send_interactions(
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- **Interaction Messages:**
|
||||
- Interaction Messages:
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
"""
|
||||
@@ -568,7 +584,7 @@ def send_interactions(
|
||||
logging.info("[ACTOR] Interactions process stopped")
|
||||
|
||||
|
||||
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:
|
||||
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: # type: ignore
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = transitions_queue.get(block=True, timeout=5)
|
||||
@@ -584,7 +600,7 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilse
|
||||
|
||||
|
||||
def interactions_stream(
|
||||
shutdown_event: any, # Event,
|
||||
shutdown_event: Event, # type: ignore
|
||||
interactions_queue: Queue,
|
||||
) -> hilserl_pb2.Empty:
|
||||
while not shutdown_event.is_set():
|
||||
@@ -643,6 +659,14 @@ def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||
|
||||
|
||||
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
||||
"""Get the frequency statistics of the policy.
|
||||
|
||||
Args:
|
||||
list_policy_time (list[float]): The list of policy times.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The frequency statistics of the policy.
|
||||
"""
|
||||
stats = {}
|
||||
list_policy_fps = [1.0 / t for t in list_policy_time]
|
||||
if len(list_policy_fps) > 1:
|
||||
|
||||
Reference in New Issue
Block a user