#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from functools import lru_cache
from queue import Empty
from statistics import mean, quantiles

# from lerobot.scripts.eval import eval_policy
import grpc
import torch
from torch import nn
from torch.multiprocessing import Event, Queue

# TODO: Remove the import of maniskill
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (
    TimerManager,
    get_safe_torch_device,
    init_logging,
)
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
from lerobot.scripts.server.buffer import (
    Transition,
    bytes_to_state_dict,
    move_state_dict_to_device,
    move_transition_to_device,
    python_object_to_bytes,
    transitions_to_bytes,
)
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
from lerobot.scripts.server.network_utils import (
    receive_bytes_in_chunks,
    send_bytes_in_chunks,
)
from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers

ACTOR_SHUTDOWN_TIMEOUT = 30


def _resolve_learner_client(
    cfg: TrainPipelineConfig,
    learner_client: hilserl_pb2_grpc.LearnerServiceStub | None,
    grpc_channel: grpc.Channel | None,
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
    """Return the given client/channel pair, or build one from ``cfg`` if either is missing."""
    if grpc_channel is None or learner_client is None:
        learner_client, grpc_channel = learner_service_client(
            host=cfg.policy.actor_learner_config.learner_host,
            port=cfg.policy.actor_learner_config.learner_port,
        )
    return learner_client, grpc_channel


def receive_policy(
    cfg: TrainPipelineConfig,
    parameters_queue: Queue,
    shutdown_event: any,  # Event,
    learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
    grpc_channel: grpc.Channel | None = None,
):
    """Receive policy parameters from the learner and push them onto ``parameters_queue``.

    Calls the ``StreamParameters`` RPC and hands the response iterator, the queue and
    ``shutdown_event`` to ``receive_bytes_in_chunks``. gRPC errors are logged, not raised.

    Args:
        cfg: Training configuration (learner host/port and concurrency mode).
        parameters_queue: Queue that receives the payloads produced by ``receive_bytes_in_chunks``.
        shutdown_event: Event shared with the main process to stop streaming.
        learner_client: Optional pre-built stub (thread mode); created from ``cfg`` when missing.
        grpc_channel: Channel backing ``learner_client``; created from ``cfg`` when missing.
    """
    logging.info("[ACTOR] Start receiving parameters from the Learner")

    if not use_threads(cfg):
        # Setup process handlers to handle shutdown signal
        # But use shutdown event from the main process
        setup_process_handlers(use_threads=False)

    learner_client, grpc_channel = _resolve_learner_client(cfg, learner_client, grpc_channel)

    try:
        iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
        receive_bytes_in_chunks(
            iterator,
            parameters_queue,
            shutdown_event,
            log_prefix="[ACTOR] parameters",
        )
    except grpc.RpcError as e:
        logging.error(f"[ACTOR] gRPC error: {e}")
    finally:
        # In process mode this worker owns its channel; in thread mode it is shared with the
        # main thread and closed there.
        if not use_threads(cfg):
            grpc_channel.close()

    logging.info("[ACTOR] Received policy loop stopped")


def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:
    """Yield ``hilserl_pb2.Transition`` chunks for every message taken from ``transitions_queue``.

    This is a generator consumed by the ``SendTransitions`` client-streaming RPC. It polls
    the queue with a 5 s timeout so that ``shutdown_event`` is re-checked regularly.
    """
    while not shutdown_event.is_set():
        try:
            message = transitions_queue.get(block=True, timeout=5)
        except Empty:
            logging.debug("[ACTOR] Transition queue is empty")
            continue

        yield from send_bytes_in_chunks(
            message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
        )

    return hilserl_pb2.Empty()


def interactions_stream(
    shutdown_event: any,  # Event,
    interactions_queue: Queue,
) -> hilserl_pb2.Empty:
    """Yield ``hilserl_pb2.InteractionMessage`` chunks for every message in ``interactions_queue``.

    This is a generator consumed by the ``SendInteractions`` client-streaming RPC. It polls
    the queue with a 5 s timeout so that ``shutdown_event`` is re-checked regularly.
    """
    while not shutdown_event.is_set():
        try:
            message = interactions_queue.get(block=True, timeout=5)
        except Empty:
            logging.debug("[ACTOR] Interaction queue is empty")
            continue

        yield from send_bytes_in_chunks(
            message,
            hilserl_pb2.InteractionMessage,
            log_prefix="[ACTOR] Send interactions",
        )

    return hilserl_pb2.Empty()


def send_transitions(
    cfg: TrainPipelineConfig,
    transitions_queue: Queue,
    shutdown_event: any,  # Event,
    learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
    grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
    """
    Sends transitions to the learner.

    Streams every serialized batch placed on ``transitions_queue`` (by
    ``push_transitions_to_transport_queue``, which moves transitions to CPU and serializes
    them with ``transitions_to_bytes``) to the learner's ``SendTransitions`` RPC, wrapped in
    ``hilserl_pb2.Transition`` chunks, until ``shutdown_event`` is set. gRPC errors are logged.
    """
    if not use_threads(cfg):
        # Setup process handlers to handle shutdown signal
        # But use shutdown event from the main process
        setup_process_handlers(False)

    learner_client, grpc_channel = _resolve_learner_client(cfg, learner_client, grpc_channel)

    try:
        learner_client.SendTransitions(transitions_stream(shutdown_event, transitions_queue))
    except grpc.RpcError as e:
        logging.error(f"[ACTOR] gRPC error: {e}")
    finally:
        logging.info("[ACTOR] Finished streaming transitions")
        # Only a process-mode worker owns its channel.
        if not use_threads(cfg):
            grpc_channel.close()

    logging.info("[ACTOR] Transitions process stopped")


def send_interactions(
    cfg: TrainPipelineConfig,
    interactions_queue: Queue,
    shutdown_event: any,  # Event,
    learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
    grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
    """
    Sends interactions to the learner.

    Streams every message placed on ``interactions_queue`` (episode statistics such as the
    episodic reward, intervention rate and policy frequency, serialized with
    ``python_object_to_bytes``) to the learner's ``SendInteractions`` RPC until
    ``shutdown_event`` is set. gRPC errors are logged.
    """
    if not use_threads(cfg):
        # Setup process handlers to handle shutdown signal
        # But use shutdown event from the main process
        setup_process_handlers(False)

    learner_client, grpc_channel = _resolve_learner_client(cfg, learner_client, grpc_channel)

    try:
        learner_client.SendInteractions(interactions_stream(shutdown_event, interactions_queue))
    except grpc.RpcError as e:
        logging.error(f"[ACTOR] gRPC error: {e}")
    finally:
        logging.info("[ACTOR] Finished streaming interactions")
        # Only a process-mode worker owns its channel.
        if not use_threads(cfg):
            grpc_channel.close()

    logging.info("[ACTOR] Interactions process stopped")


@lru_cache(maxsize=1)
def learner_service_client(
    host="127.0.0.1", port=50051
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
    """
    Returns a client for the learner service.

    GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
    So we need to create only one client and reuse it.

    The channel is insecure, raises the send/receive size limits to
    ``learner_service.MAX_MESSAGE_SIZE`` and retries UNAVAILABLE / DEADLINE_EXCEEDED calls
    with exponential backoff. The result is cached per process (``lru_cache(maxsize=1)``).
    """
    import json

    service_config = {
        "methodConfig": [
            {
                "name": [{}],  # Applies to ALL methods in ALL services
                "retryPolicy": {
                    "maxAttempts": 5,  # Max retries (total attempts = 5)
                    "initialBackoff": "0.1s",  # First retry after 0.1s
                    "maxBackoff": "2s",  # Max wait time between retries
                    "backoffMultiplier": 2,  # Exponential backoff factor
                    "retryableStatusCodes": [
                        "UNAVAILABLE",
                        "DEADLINE_EXCEEDED",
                    ],  # Retries on network failures
                },
            }
        ]
    }

    service_config_json = json.dumps(service_config)

    channel = grpc.insecure_channel(
        f"{host}:{port}",
        options=[
            ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
            ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
            ("grpc.enable_retries", 1),
            ("grpc.service_config", service_config_json),
        ],
    )
    stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
    logging.info("[ACTOR] Learner service client created")
    return stub, channel


def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
    """Load the most recent state dict from ``parameters_queue`` into ``policy``, if any.

    Older queued payloads are skipped via ``get_last_item_from_queue``; the payload is
    deserialized and moved to ``device`` before loading.
    """
    if not parameters_queue.empty():
        logging.info("[ACTOR] Load new parameters from Learner.")
        bytes_state_dict = get_last_item_from_queue(parameters_queue)
        state_dict = bytes_to_state_dict(bytes_state_dict)
        state_dict = move_state_dict_to_device(state_dict, device=device)
        policy.load_state_dict(state_dict)


def act_with_policy(
    cfg: TrainPipelineConfig,
    # robot: Robot,
    reward_classifier: nn.Module,
    shutdown_event: any,  # Event,
    parameters_queue: Queue,
    transitions_queue: Queue,
    interactions_queue: Queue,
):
    """
    Executes policy interaction within the environment.

    This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
    Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.

    For the first ``cfg.policy.online_step_before_learning`` steps, actions are sampled
    randomly from the action space; afterwards the policy selects them. Returns early when
    ``shutdown_event`` is set.

    Args:
        cfg (TrainPipelineConfig): Configuration settings for the interaction process.
        reward_classifier: Currently unused.
        shutdown_event: Event checked at every step to stop the rollout.
        parameters_queue: Source of serialized actor parameters from the learner.
        transitions_queue: Destination of serialized transition batches (one per episode).
        interactions_queue: Destination of serialized per-episode statistics.
    """

    logging.info("make_env online")

    online_env = make_robot_env(cfg=cfg.env)

    set_seed(cfg.seed)
    device = get_safe_torch_device(cfg.policy.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    logging.info("make_policy")

    ### Instantiate the policy in both the actor and learner processes
    ### To avoid sending a SACPolicy object through the port, we create a policy instance
    ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
    # TODO: At some point we should just need make sac policy
    policy: SACPolicy = make_policy(
        cfg=cfg.policy,
        env_cfg=cfg.env,
    )
    policy = torch.compile(policy)
    assert isinstance(policy, nn.Module)

    obs, info = online_env.reset()

    # NOTE: For the moment we will solely handle the case of a single environment
    sum_reward_episode = 0.0
    list_transition_to_send_to_learner = []
    list_policy_time = []
    episode_intervention = False
    # Counters for the per-episode intervention rate
    episode_intervention_steps = 0
    episode_total_steps = 0

    for interaction_step in range(cfg.policy.online_steps):
        start_time = time.perf_counter()
        if shutdown_event.is_set():
            logging.info("[ACTOR] Shutting down act_with_policy")
            return

        if interaction_step >= cfg.policy.online_step_before_learning:
            # Time policy inference and check if it meets FPS requirement
            with TimerManager(
                elapsed_time_list=list_policy_time,
                label="Policy inference time",
                log=False,
            ):
                action = policy.select_action(batch=obs)
            # Epsilon avoids a division by zero on an extremely fast call.
            policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)

            log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)

            next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
        else:
            # TODO (azouitine): Make a custom space for torch tensor
            action = online_env.action_space.sample()
            next_obs, reward, done, truncated, info = online_env.step(action)

            # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
            action = (
                torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
            )
        sum_reward_episode += float(reward)

        # Increment total steps counter for intervention rate
        episode_total_steps += 1

        # NOTE: We override the action if the intervention is True, because the action applied is the intervention action
        if info.get("is_intervention", False):
            # TODO: Check the shape
            # NOTE: The action space for demonstration before hand is with the full action space
            # but sometimes for example we want to deactivate the gripper
            action = info["action_intervention"]
            episode_intervention = True
            # Increment intervention steps counter
            episode_intervention_steps += 1

        # Check for NaN values in observations
        for key, tensor in obs.items():
            if torch.isnan(tensor).any():
                logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")

        list_transition_to_send_to_learner.append(
            Transition(
                state=obs,
                action=action,
                reward=reward,
                next_state=next_obs,
                done=done,
                truncated=truncated,  # TODO: (azouitine) Handle truncation properly
                complementary_info=info,  # TODO Handle information for the transition, is_demonstraction: bool
            )
        )

        # assign obs to the next obs and continue the rollout
        obs = next_obs

        # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
        # Because we are using a single environment we can index at zero
        if done or truncated:
            # TODO: Handle logging for episode information
            logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")

            update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)

            if len(list_transition_to_send_to_learner) > 0:
                push_transitions_to_transport_queue(
                    transitions=list_transition_to_send_to_learner,
                    transitions_queue=transitions_queue,
                )
                list_transition_to_send_to_learner = []

            stats = get_frequency_stats(list_policy_time)
            list_policy_time.clear()

            # Calculate intervention rate
            intervention_rate = 0.0
            if episode_total_steps > 0:
                intervention_rate = episode_intervention_steps / episode_total_steps

            # Send episodic reward to the learner
            interactions_queue.put(
                python_object_to_bytes(
                    {
                        "Episodic reward": sum_reward_episode,
                        "Interaction step": interaction_step,
                        "Episode intervention": int(episode_intervention),
                        "Intervention rate": intervention_rate,
                        **stats,
                    }
                )
            )

            sum_reward_episode = 0.0
            episode_intervention = False
            # Reset intervention counters
            episode_intervention_steps = 0
            episode_total_steps = 0
            obs, info = online_env.reset()

        if cfg.env.fps is not None:
            # Pad the step to the configured control period.
            dt_time = time.perf_counter() - start_time
            busy_wait(1 / cfg.env.fps - dt_time)


def push_transitions_to_transport_queue(transitions: list, transitions_queue):
    """Serialize a list of transitions and put the result on ``transitions_queue`` as one message.

    Each transition is moved to CPU first; NaN values in its ``state`` are logged as warnings.

    Args:
        transitions: List of transitions to send.
        transitions_queue: Queue consumed by ``transitions_stream``.
    """
    transition_to_send_to_learner = []
    for transition in transitions:
        tr = move_transition_to_device(transition=transition, device="cpu")
        for key, value in tr["state"].items():
            if torch.isnan(value).any():
                logging.warning(f"Found NaN values in transition {key}")

        transition_to_send_to_learner.append(tr)

    transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))


def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
    """Summarize policy inference durations (seconds) as frequencies in Hz.

    Returns an empty dict when fewer than two samples are available (``quantiles`` needs
    at least two data points); otherwise the mean and the 90th-percentile frequency.
    """
    stats = {}
    # Same epsilon as the per-step FPS computation, so a zero duration cannot raise.
    list_policy_fps = [1.0 / (t + 1e-9) for t in list_policy_time]
    if len(list_policy_fps) > 1:
        policy_fps = mean(list_policy_fps)
        quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
        logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
        logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
        stats = {
            "Policy frequency [Hz]": policy_fps,
            "Policy frequency 90th-p [Hz]": quantiles_90,
        }
    return stats


def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
    """Warn when policy inference is slower than the environment's control rate.

    Does nothing when ``cfg.env.fps`` is ``None`` (no target rate configured), which
    previously raised ``TypeError`` on the comparison.
    """
    if cfg.env.fps is None:
        return
    if policy_fps < cfg.env.fps:
        logging.warning(
            f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
        )


def establish_learner_connection(
    stub,
    shutdown_event: any,  # Event,
    attempts=30,
):
    """Call the learner's ``Ready`` RPC up to ``attempts`` times, 2 s apart.

    Returns ``True`` once the learner answers with ``Empty``; ``False`` if all attempts fail
    or ``shutdown_event`` is set.
    """
    for _ in range(attempts):
        if shutdown_event.is_set():
            logging.info("[ACTOR] Shutting down establish_learner_connection")
            return False

        # Force a connection attempt and check state
        try:
            logging.info("[ACTOR] Send ready message to Learner")
            if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
                return True
        except grpc.RpcError as e:
            logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
            time.sleep(2)
    return False


def use_threads(cfg: TrainPipelineConfig) -> bool:
    """Return True when the actor's communication workers should run as threads."""
    return cfg.policy.concurrency.actor == "threads"


@parser.wrap()
def actor_cli(cfg: TrainPipelineConfig):
    """Entry point: connect to the learner, start the communication workers and run the rollout."""
    if not use_threads(cfg):
        import torch.multiprocessing as mp

        mp.set_start_method("spawn")

    init_logging(log_file="actor.log")
    shutdown_event = setup_process_handlers(use_threads(cfg))

    learner_client, grpc_channel = learner_service_client(
        host=cfg.policy.actor_learner_config.learner_host,
        port=cfg.policy.actor_learner_config.learner_port,
    )

    logging.info("[ACTOR] Establishing connection with Learner")
    if not establish_learner_connection(learner_client, shutdown_event):
        logging.error("[ACTOR] Failed to establish connection with Learner")
        return

    if not use_threads(cfg):
        # If we use multithreading, we can reuse the channel.
        # With processes the main-process client is not shared: pass None for both so each
        # worker builds its own (see _resolve_learner_client).
        grpc_channel.close()
        grpc_channel = None
        learner_client = None

    logging.info("[ACTOR] Connection with Learner established")

    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()

    if use_threads(cfg):
        from threading import Thread

        concurrency_entity = Thread
    else:
        from multiprocessing import Process

        concurrency_entity = Process

    # Pass client and channel in their own positional slots
    # (signature: cfg, queue, shutdown_event, learner_client, grpc_channel).
    receive_policy_process = concurrency_entity(
        target=receive_policy,
        args=(cfg, parameters_queue, shutdown_event, learner_client, grpc_channel),
        daemon=True,
    )

    transitions_process = concurrency_entity(
        target=send_transitions,
        args=(cfg, transitions_queue, shutdown_event, learner_client, grpc_channel),
        daemon=True,
    )

    interactions_process = concurrency_entity(
        target=send_interactions,
        args=(cfg, interactions_queue, shutdown_event, learner_client, grpc_channel),
        daemon=True,
    )

    transitions_process.start()
    interactions_process.start()
    receive_policy_process.start()

    # HACK: FOR MANISKILL we do not have a reward classifier
    # TODO: Remove this once we merge into main
    reward_classifier = None
    # if (
    #     cfg.env.reward_classifier["pretrained_path"] is not None
    #     and cfg.env.reward_classifier["config_path"] is not None
    # ):
    #     reward_classifier = get_classifier(
    #         pretrained_path=cfg.env.reward_classifier["pretrained_path"],
    #         config_path=cfg.env.reward_classifier["config_path"],
    #     )

    act_with_policy(
        cfg=cfg,
        reward_classifier=reward_classifier,
        shutdown_event=shutdown_event,
        parameters_queue=parameters_queue,
        transitions_queue=transitions_queue,
        interactions_queue=interactions_queue,
    )
    logging.info("[ACTOR] Policy process joined")

    # act_with_policy also returns after finishing all online steps without any shutdown
    # signal; set the event so the streaming loops (which poll it) exit and the joins return.
    shutdown_event.set()

    # Join the workers before closing the queues: in thread mode they share these queue
    # objects and would otherwise call get() on a closed queue.
    transitions_process.join()
    logging.info("[ACTOR] Transitions process joined")
    interactions_process.join()
    logging.info("[ACTOR] Interactions process joined")
    receive_policy_process.join()
    logging.info("[ACTOR] Receive policy process joined")

    logging.info("[ACTOR] Closing queues")
    transitions_queue.close()
    interactions_queue.close()
    parameters_queue.close()

    logging.info("[ACTOR] join queues")
    # Don't block interpreter exit waiting to flush data nobody will read anymore.
    transitions_queue.cancel_join_thread()
    interactions_queue.cancel_join_thread()
    parameters_queue.cancel_join_thread()

    if grpc_channel is not None:
        # Thread mode: the channel was shared with the workers, which have now exited.
        grpc_channel.close()

    logging.info("[ACTOR] queues closed")


if __name__ == "__main__":
    actor_cli()