[Port HIL-SERL] Adjust Actor-Learner architecture & clean up dependency management for HIL-SERL (#722)

This commit is contained in:
Eugene Mironov
2025-02-21 16:29:00 +07:00
committed by Michel Aractingi
parent 85242cac67
commit e1d55c7a44
17 changed files with 1949 additions and 475 deletions

View File

@@ -17,9 +17,9 @@ import io
import logging
import pickle
import queue
import time
from concurrent import futures
from statistics import mean, quantiles
import signal
from functools import lru_cache
# from lerobot.scripts.eval import eval_policy
from threading import Thread
@@ -35,7 +35,6 @@ from torch import nn
# from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.control_utils import busy_wait
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import (
@@ -44,14 +43,24 @@ from lerobot.common.utils.utils import (
set_global_seed,
)
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
from lerobot.scripts.server.buffer import (
Transition,
move_state_dict_to_device,
move_transition_to_device,
bytes_buffer_size,
)
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
from lerobot.scripts.server import learner_service
from threading import Event
logging.basicConfig(level=logging.INFO)
parameters_queue = queue.Queue(maxsize=1)
message_queue = queue.Queue(maxsize=1_000_000)
ACTOR_SHUTDOWN_TIMEOUT = 30
class ActorInformation:
"""
@@ -70,95 +79,171 @@ class ActorInformation:
self.interaction_message = interaction_message
class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
"""
gRPC service for actor-learner communication in reinforcement learning.
def receive_policy(
learner_client: hilserl_pb2_grpc.LearnerServiceStub,
shutdown_event: Event,
parameters_queue: queue.Queue,
):
logging.info("[ACTOR] Start receiving parameters from the Learner")
bytes_buffer = io.BytesIO()
step = 0
try:
for model_update in learner_client.StreamParameters(hilserl_pb2.Empty()):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down policy streaming receiver")
return hilserl_pb2.Empty()
This service is responsible for:
1. Streaming batches of transition data and statistical metrics from the actor to the learner.
2. Receiving updated network parameters from the learner.
"""
def StreamTransition(self, request, context): # noqa: N802
"""
Streams data from the actor to the learner.
This function continuously retrieves messages from the queue and processes them based on their type:
- **Transition Data:**
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
- **Interaction Messages:**
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Yields:
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
"""
while True:
message = message_queue.get(block=True)
if message.transition is not None:
transition_to_send_to_learner: list[Transition] = [
move_transition_to_device(transition=T, device="cpu") for T in message.transition
]
# Check for NaNs in transitions before sending to learner
for transition in transition_to_send_to_learner:
for key, value in transition["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
buf = io.BytesIO()
torch.save(transition_to_send_to_learner, buf)
transition_bytes = buf.getvalue()
transition_message = hilserl_pb2.Transition(transition_bytes=transition_bytes)
response = hilserl_pb2.ActorInformation(transition=transition_message)
elif message.interaction_message is not None:
content = hilserl_pb2.InteractionMessage(
interaction_message_bytes=pickle.dumps(message.interaction_message)
if model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
bytes_buffer.write(model_update.parameter_bytes)
logging.info("Received model update at step 0")
step = 0
continue
elif (
model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE
):
bytes_buffer.write(model_update.parameter_bytes)
step += 1
logging.info(f"Received model update at step {step}")
elif model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(model_update.parameter_bytes)
logging.info(
f"Received model update at step end size {bytes_buffer_size(bytes_buffer)}"
)
response = hilserl_pb2.ActorInformation(interaction_message=content)
yield response
state_dict = torch.load(bytes_buffer)
def SendParameters(self, request, context): # noqa: N802
"""
Receives updated parameters from the learner and updates the actor.
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
step = 0
The learner calls this method to send new model parameters. The received parameters are deserialized
and placed in a queue to be consumed by the actor.
logging.info("Model updated")
Args:
request (hilserl_pb2.ParameterUpdate): The request containing serialized network parameters.
context (grpc.ServicerContext): The gRPC context.
parameters_queue.put(state_dict)
Returns:
hilserl_pb2.Empty: An empty response to acknowledge receipt.
"""
buffer = io.BytesIO(request.parameter_bytes)
params = torch.load(buffer)
parameters_queue.put(params)
return hilserl_pb2.Empty()
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
return hilserl_pb2.Empty()
def serve_actor_service(port=50052):
def transitions_stream(shutdown_event: Event, message_queue: queue.Queue):
    """
    Generate the stream of actor messages that is sent to the learner.

    The queue is polled with a 5 second timeout so that ``shutdown_event`` is
    re-checked regularly even when nothing is queued. Each dequeued message is
    converted into one ``hilserl_pb2.ActorInformation``:

    - If ``message.transition`` is not None, every transition is passed through
      ``move_transition_to_device(..., device="cpu")``, the resulting list is
      serialized with ``torch.save`` and wrapped in a ``hilserl_pb2.Transition``.
    - Otherwise, if ``message.interaction_message`` is not None, it is pickled
      and wrapped in a ``hilserl_pb2.InteractionMessage``.
    - A message carrying neither payload is logged and skipped. Previously this
      case reached ``yield response`` with ``response`` either unbound
      (``UnboundLocalError``) or still holding the previous iteration's value,
      which re-sent stale data.

    Args:
        shutdown_event: Checked before every poll; once set, the stream ends.
        message_queue: Queue of objects exposing ``transition`` and
            ``interaction_message`` attributes.

    Yields:
        hilserl_pb2.ActorInformation: Either serialized transitions or a
        pickled interaction message.
    """
    while not shutdown_event.is_set():
        try:
            # Bounded wait so a shutdown request is noticed within the timeout.
            message = message_queue.get(block=True, timeout=5)
        except queue.Empty:
            logging.debug("[ACTOR] Transition queue is empty")
            continue

        if message.transition is not None:
            transition_to_send_to_learner: list[Transition] = [
                move_transition_to_device(transition=T, device="cpu")
                for T in message.transition
            ]
            # Check for NaNs in transitions before sending to learner
            for transition in transition_to_send_to_learner:
                for key, value in transition["state"].items():
                    if torch.isnan(value).any():
                        logging.warning(f"Found NaN values in transition {key}")
            buf = io.BytesIO()
            torch.save(transition_to_send_to_learner, buf)
            transition_bytes = buf.getvalue()

            transition_message = hilserl_pb2.Transition(
                transition_bytes=transition_bytes
            )

            response = hilserl_pb2.ActorInformation(transition=transition_message)

        elif message.interaction_message is not None:
            content = hilserl_pb2.InteractionMessage(
                interaction_message_bytes=pickle.dumps(message.interaction_message)
            )
            response = hilserl_pb2.ActorInformation(interaction_message=content)

        else:
            # Nothing to send: skip instead of yielding an unbound or stale response.
            logging.warning(
                "[ACTOR] Skipping message with neither transition nor interaction message"
            )
            continue

        yield response

    # Kept for backward compatibility: in a generator this value only ends up in
    # StopIteration.value and is ignored by plain iteration.
    return hilserl_pb2.Empty()
def send_transitions(
learner_client: hilserl_pb2_grpc.LearnerServiceStub,
shutdown_event: Event,
message_queue: queue.Queue,
):
"""
Runs a gRPC server to start streaming the data from the actor to the learner.
Through this server the learner can push parameters to the Actor as well.
Streams data from the actor to the learner.
This function continuously retrieves messages from the queue and processes them based on their type:
- **Transition Data:**
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
- **Interaction Messages:**
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Yields:
hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
"""
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=20),
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
try:
learner_client.ReceiveTransitions(
transitions_stream(shutdown_event, message_queue)
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming transitions")
@lru_cache(maxsize=1)
def learner_service_client(
host="127.0.0.1", port=50051
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
import json
"""
Returns a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
So we need to create only one client and reuse it.
"""
service_config = {
"methodConfig": [
{
"name": [{}], # Applies to ALL methods in ALL services
"retryPolicy": {
"maxAttempts": 5, # Max retries (total attempts = 5)
"initialBackoff": "0.1s", # First retry after 0.1s
"maxBackoff": "2s", # Max wait time between retries
"backoffMultiplier": 2, # Exponential backoff factor
"retryableStatusCodes": [
"UNAVAILABLE",
"DEADLINE_EXCEEDED",
], # Retries on network failures
},
}
]
}
service_config_json = json.dumps(service_config)
channel = grpc.insecure_channel(
f"{host}:{port}",
options=[
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.enable_retries", 1),
("grpc.service_config", service_config_json),
],
)
hilserl_pb2_grpc.add_ActorServiceServicer_to_server(ActorServiceServicer(), server)
server.add_insecure_port(f"[::]:{port}")
server.start()
logging.info(f"[ACTOR] gRPC server listening on port {port}")
server.wait_for_termination()
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
logging.info("[LEARNER] Learner service client created")
return stub, channel
def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
@@ -169,7 +254,9 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, d
policy.load_state_dict(state_dict)
def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
def act_with_policy(
cfg: DictConfig, robot: Robot, reward_classifier: nn.Module, shutdown_event: Event
):
"""
Executes policy interaction within the environment.
@@ -182,7 +269,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
logging.info("make_env online")
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
online_env = make_robot_env(
robot=robot, reward_classifier=reward_classifier, cfg=cfg
)
set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
@@ -227,17 +316,27 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
episode_intervention = False
for interaction_step in range(cfg.training.online_steps):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutdown signal received. Exiting...")
return
if interaction_step >= cfg.training.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with TimerManager(
elapsed_time_list=list_policy_time, label="Policy inference time", log=False
elapsed_time_list=list_policy_time,
label="Policy inference time",
log=False,
) as timer: # noqa: F841
action = policy.select_action(batch=obs)
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
log_policy_frequency_issue(
policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step
)
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
next_obs, reward, done, truncated, info = online_env.step(
action.squeeze(dim=0).cpu().numpy()
)
else:
# TODO (azouitine): Make a custom space for torch tensor
action = online_env.action_space.sample()
@@ -245,7 +344,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
action = (
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
torch.from_numpy(action[0])
.to(device, non_blocking=device.type == "cuda")
.unsqueeze(dim=0)
)
sum_reward_episode += float(reward)
@@ -261,7 +362,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
# Check for NaN values in observations
for key, tensor in obs.items():
if torch.isnan(tensor).any():
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
logging.error(
f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}"
)
list_transition_to_send_to_learner.append(
Transition(
@@ -281,13 +384,19 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
# Because we are using a single environment we can index at zero
if done or truncated:
# TODO: Handle logging for episode information
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
logging.info(
f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
update_policy_parameters(
policy=policy.actor, parameters_queue=parameters_queue, device=device
)
if len(list_transition_to_send_to_learner) > 0:
send_transitions_in_chunks(
transitions=list_transition_to_send_to_learner, message_queue=message_queue, chunk_size=4
transitions=list_transition_to_send_to_learner,
message_queue=message_queue,
chunk_size=4,
)
list_transition_to_send_to_learner = []
@@ -332,11 +441,16 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
stats = {
"Policy frequency [Hz]": policy_fps,
"Policy frequency 90th-p [Hz]": quantiles_90,
}
return stats
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
def log_policy_frequency_issue(
policy_fps: float, cfg: DictConfig, interaction_step: int
):
if policy_fps < cfg.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
@@ -347,7 +461,34 @@ def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_s
def actor_cli(cfg: dict):
robot = make_robot(cfg=cfg.robot)
server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)
shutdown_event = Event()
# Define signal handler
def signal_handler(signum, frame):
logging.info("Shutdown signal received. Cleaning up...")
shutdown_event.set()
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
)
receive_policy_thread = Thread(
target=receive_policy,
args=(learner_client, shutdown_event, parameters_queue),
daemon=True,
)
transitions_thread = Thread(
target=send_transitions,
args=(learner_client, shutdown_event, message_queue),
daemon=True,
)
# HACK: FOR MANISKILL we do not have a reward classifier
# TODO: Remove this once we merge into main
@@ -360,15 +501,27 @@ def actor_cli(cfg: dict):
pretrained_path=cfg.env.reward_classifier.pretrained_path,
config_path=cfg.env.reward_classifier.config_path,
)
policy_thread = Thread(
target=act_with_policy,
daemon=True,
args=(cfg, robot, reward_classifier),
args=(cfg, robot, reward_classifier, shutdown_event),
)
server_thread.start()
transitions_thread.start()
policy_thread.start()
receive_policy_thread.start()
shutdown_event.wait()
logging.info("[ACTOR] Shutdown event received")
grpc_channel.close()
policy_thread.join()
server_thread.join()
logging.info("[ACTOR] Policy thread joined")
transitions_thread.join()
logging.info("[ACTOR] Transitions thread joined")
receive_policy_thread.join()
logging.info("[ACTOR] Receive policy thread joined")
if __name__ == "__main__":