mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
[Port HIL-SERL] Adjust Actor-Learner architecture & clean up dependency management for HIL-SERL (#722)
This commit is contained in:
committed by
Michel Aractingi
parent
85242cac67
commit
e1d55c7a44
@@ -17,9 +17,9 @@ import io
|
||||
import logging
|
||||
import pickle
|
||||
import queue
|
||||
import time
|
||||
from concurrent import futures
|
||||
from statistics import mean, quantiles
|
||||
import signal
|
||||
from functools import lru_cache
|
||||
|
||||
# from lerobot.scripts.eval import eval_policy
|
||||
from threading import Thread
|
||||
@@ -35,7 +35,6 @@ from torch import nn
|
||||
# from lerobot.common.envs.utils import preprocess_maniskill_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.robot_devices.control_utils import busy_wait
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import (
|
||||
@@ -44,14 +43,24 @@ from lerobot.common.utils.utils import (
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
|
||||
from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
|
||||
from lerobot.scripts.server.buffer import (
|
||||
Transition,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
bytes_buffer_size,
|
||||
)
|
||||
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
|
||||
from lerobot.scripts.server import learner_service
|
||||
|
||||
from threading import Event
|
||||
|
||||
# Configure the root logger so INFO-level records and above are emitted.
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Hand-off point for state dicts received from the learner; maxsize=1 means
# at most one update can be pending at a time.
parameters_queue = queue.Queue(maxsize=1)
|
||||
# Outgoing messages (transition batches / interaction statistics) waiting to
# be streamed to the learner; bounded at 1,000,000 entries.
message_queue = queue.Queue(maxsize=1_000_000)
|
||||
|
||||
# NOTE(review): presumably a shutdown timeout in seconds; it is not used in
# the code visible here -- confirm the unit and the consumer.
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
|
||||
class ActorInformation:
|
||||
"""
|
||||
@@ -70,95 +79,171 @@ class ActorInformation:
|
||||
self.interaction_message = interaction_message
|
||||
|
||||
|
||||
class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
|
||||
"""
|
||||
gRPC service for actor-learner communication in reinforcement learning.
|
||||
def receive_policy(
    learner_client: hilserl_pb2_grpc.LearnerServiceStub,
    shutdown_event: Event,
    parameters_queue: queue.Queue,
):
    """
    Receive chunked policy parameters from the learner and queue them for the actor.

    Iterates over the `StreamParameters` server stream. Each streamed message
    carries a `transfer_state` and a `parameter_bytes` chunk:

    - `TRANSFER_BEGIN`: the reassembly buffer is cleared and the first chunk is stored.
    - `TRANSFER_MIDDLE`: the chunk is appended.
    - `TRANSFER_END`: the last chunk is appended, the whole payload is deserialized
      with `torch.load` and the resulting state dict is put on `parameters_queue`.

    Args:
        learner_client: Stub used to open the `StreamParameters` stream.
        shutdown_event: When set, the function returns on the next received message.
        parameters_queue: Queue receiving each fully reassembled state dict.
            The `put` is blocking, so a full queue stalls this function until
            the consumer takes the previous item.

    Returns:
        hilserl_pb2.Empty: Returned when the stream ends, on shutdown, or after a
        gRPC error (which is logged rather than raised).
    """
    logging.info("[ACTOR] Start receiving parameters from the Learner")
    bytes_buffer = io.BytesIO()
    step = 0
    try:
        for model_update in learner_client.StreamParameters(hilserl_pb2.Empty()):
            if shutdown_event.is_set():
                logging.info("[ACTOR] Shutting down policy streaming receiver")
                return hilserl_pb2.Empty()

            if model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
                # A new transfer starts: drop whatever partial payload is left over.
                bytes_buffer.seek(0)
                bytes_buffer.truncate(0)
                bytes_buffer.write(model_update.parameter_bytes)
                logging.info("Received model update at step 0")
                step = 0
            elif (
                model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE
            ):
                bytes_buffer.write(model_update.parameter_bytes)
                step += 1
                logging.info(f"Received model update at step {step}")
            elif model_update.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
                bytes_buffer.write(model_update.parameter_bytes)
                logging.info(
                    f"Received model update at step end size {bytes_buffer_size(bytes_buffer)}"
                )

                # After the writes the stream position sits at the end of the
                # payload. Rewind explicitly instead of relying on the logging
                # helper above to have reset the position as a side effect.
                bytes_buffer.seek(0)
                # NOTE(review): torch.load unpickles data received over the
                # network; only connect to a trusted learner.
                state_dict = torch.load(bytes_buffer)

                # Reset the buffer so the next transfer starts from a clean state.
                bytes_buffer.seek(0)
                bytes_buffer.truncate(0)
                step = 0

                logging.info("Model updated")

                parameters_queue.put(state_dict)

    except grpc.RpcError as e:
        logging.error(f"[ACTOR] gRPC error: {e}")

    return hilserl_pb2.Empty()
|
||||
|
||||
|
||||
def serve_actor_service(port=50052):
|
||||
def transitions_stream(shutdown_event: Event, message_queue: queue.Queue):
    """
    Generate `hilserl_pb2.ActorInformation` messages from queued actor messages.

    Messages are pulled from `message_queue` until `shutdown_event` is set. The
    queue is polled with a 5 second timeout so the shutdown flag is re-checked
    even when no message arrives.

    - A message whose `transition` is not None is moved to the CPU, checked for
      NaNs in its `"state"` tensors (a warning is logged, the data is still
      sent), serialized with `torch.save` and wrapped in `hilserl_pb2.Transition`.
    - Otherwise, a message whose `interaction_message` is not None is pickled
      and wrapped in `hilserl_pb2.InteractionMessage`.
    - A message carrying neither payload is logged and skipped.

    Args:
        shutdown_event: Stops the generator once set.
        message_queue: Source of messages exposing `transition` and
            `interaction_message` attributes.

    Yields:
        hilserl_pb2.ActorInformation: One message per non-empty queue item.
    """
    while not shutdown_event.is_set():
        try:
            message = message_queue.get(block=True, timeout=5)
        except queue.Empty:
            logging.debug("[ACTOR] Transition queue is empty")
            continue

        if message.transition is not None:
            transition_to_send_to_learner: list[Transition] = [
                move_transition_to_device(transition=T, device="cpu")
                for T in message.transition
            ]
            # Check for NaNs in transitions before sending to learner
            for transition in transition_to_send_to_learner:
                for key, value in transition["state"].items():
                    if torch.isnan(value).any():
                        logging.warning(f"Found NaN values in transition {key}")

            buf = io.BytesIO()
            torch.save(transition_to_send_to_learner, buf)
            transition_message = hilserl_pb2.Transition(
                transition_bytes=buf.getvalue()
            )
            response = hilserl_pb2.ActorInformation(transition=transition_message)

        elif message.interaction_message is not None:
            content = hilserl_pb2.InteractionMessage(
                interaction_message_bytes=pickle.dumps(message.interaction_message)
            )
            response = hilserl_pb2.ActorInformation(interaction_message=content)

        else:
            # Without this branch `response` would be unbound on the first
            # iteration, or the previous response would be sent a second time.
            logging.warning(
                "[ACTOR] Skipping queue message without transition or interaction payload"
            )
            continue

        yield response

    # Kept from the original: the value only ends up in StopIteration.value.
    return hilserl_pb2.Empty()
|
||||
|
||||
|
||||
def send_transitions(
    learner_client: hilserl_pb2_grpc.LearnerServiceStub,
    shutdown_event: Event,
    message_queue: queue.Queue,
):
    """
    Stream queued actor messages to the learner over a single gRPC call.

    The request stream passed to `ReceiveTransitions` is produced by
    `transitions_stream(shutdown_event, message_queue)`, which turns queued
    items into `hilserl_pb2.ActorInformation` messages:

    - **Transition Data:** transitions are moved to the CPU, serialized using
      PyTorch and wrapped in a `hilserl_pb2.Transition` message.
    - **Interaction Messages:** serialized using `pickle` and wrapped in a
      `hilserl_pb2.InteractionMessage`.

    This function does not yield anything itself; it blocks until the RPC
    completes or fails.

    Args:
        learner_client: Stub exposing the `ReceiveTransitions` RPC.
        shutdown_event: Forwarded to `transitions_stream` to end the stream.
        message_queue: Forwarded to `transitions_stream` as the message source.

    Returns:
        None. A `grpc.RpcError` is logged instead of being propagated.
    """
    try:
        learner_client.ReceiveTransitions(
            transitions_stream(shutdown_event, message_queue)
        )
    except grpc.RpcError as e:
        logging.error(f"[ACTOR] gRPC error: {e}")

    logging.info("[ACTOR] Finished streaming transitions")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def learner_service_client(
    host="127.0.0.1", port=50051
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
    """
    Returns a client for the learner service.

    GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
    So we need to create only one client and reuse it.

    The result is memoized with `lru_cache(maxsize=1)`, so repeated calls with
    the same `host`/`port` return the same stub and channel.

    Args:
        host: Host name or address of the learner service.
        port: Port of the learner service.

    Returns:
        A `(stub, channel)` tuple: the `LearnerServiceStub` and the insecure
        channel it is bound to.
    """
    # Imported locally, as in the original; placed after the docstring so the
    # string above is a real docstring rather than a discarded expression.
    import json

    service_config = {
        "methodConfig": [
            {
                "name": [{}],  # Applies to ALL methods in ALL services
                "retryPolicy": {
                    "maxAttempts": 5,  # Max retries (total attempts = 5)
                    "initialBackoff": "0.1s",  # First retry after 0.1s
                    "maxBackoff": "2s",  # Max wait time between retries
                    "backoffMultiplier": 2,  # Exponential backoff factor
                    "retryableStatusCodes": [
                        "UNAVAILABLE",
                        "DEADLINE_EXCEEDED",
                    ],  # Retries on network failures
                },
            }
        ]
    }

    service_config_json = json.dumps(service_config)

    channel = grpc.insecure_channel(
        f"{host}:{port}",
        options=[
            ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
            ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
            ("grpc.enable_retries", 1),
            ("grpc.service_config", service_config_json),
        ],
    )
    stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
    logging.info("[LEARNER] Learner service client created")
    return stub, channel
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
|
||||
@@ -169,7 +254,9 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, d
|
||||
policy.load_state_dict(state_dict)
|
||||
|
||||
|
||||
def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
|
||||
def act_with_policy(
|
||||
cfg: DictConfig, robot: Robot, reward_classifier: nn.Module, shutdown_event: Event
|
||||
):
|
||||
"""
|
||||
Executes policy interaction within the environment.
|
||||
|
||||
@@ -182,7 +269,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
|
||||
online_env = make_robot_env(
|
||||
robot=robot, reward_classifier=reward_classifier, cfg=cfg
|
||||
)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
@@ -227,17 +316,27 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
episode_intervention = False
|
||||
|
||||
for interaction_step in range(cfg.training.online_steps):
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutdown signal received. Exiting...")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.training.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with TimerManager(
|
||||
elapsed_time_list=list_policy_time, label="Policy inference time", log=False
|
||||
elapsed_time_list=list_policy_time,
|
||||
label="Policy inference time",
|
||||
log=False,
|
||||
) as timer: # noqa: F841
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
log_policy_frequency_issue(
|
||||
policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step
|
||||
)
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
|
||||
next_obs, reward, done, truncated, info = online_env.step(
|
||||
action.squeeze(dim=0).cpu().numpy()
|
||||
)
|
||||
else:
|
||||
# TODO (azouitine): Make a custom space for torch tensor
|
||||
action = online_env.action_space.sample()
|
||||
@@ -245,7 +344,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
|
||||
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
|
||||
action = (
|
||||
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
|
||||
torch.from_numpy(action[0])
|
||||
.to(device, non_blocking=device.type == "cuda")
|
||||
.unsqueeze(dim=0)
|
||||
)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
@@ -261,7 +362,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
# Check for NaN values in observations
|
||||
for key, tensor in obs.items():
|
||||
if torch.isnan(tensor).any():
|
||||
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
|
||||
logging.error(
|
||||
f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}"
|
||||
)
|
||||
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
@@ -281,13 +384,19 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
# Because we are using a single environment we can index at zero
|
||||
if done or truncated:
|
||||
# TODO: Handle logging for episode information
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
logging.info(
|
||||
f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}"
|
||||
)
|
||||
|
||||
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
|
||||
update_policy_parameters(
|
||||
policy=policy.actor, parameters_queue=parameters_queue, device=device
|
||||
)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
send_transitions_in_chunks(
|
||||
transitions=list_transition_to_send_to_learner, message_queue=message_queue, chunk_size=4
|
||||
transitions=list_transition_to_send_to_learner,
|
||||
message_queue=message_queue,
|
||||
chunk_size=4,
|
||||
)
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
@@ -332,11 +441,16 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
||||
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
|
||||
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
|
||||
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
|
||||
stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
|
||||
stats = {
|
||||
"Policy frequency [Hz]": policy_fps,
|
||||
"Policy frequency 90th-p [Hz]": quantiles_90,
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
|
||||
def log_policy_frequency_issue(
|
||||
policy_fps: float, cfg: DictConfig, interaction_step: int
|
||||
):
|
||||
if policy_fps < cfg.fps:
|
||||
logging.warning(
|
||||
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
|
||||
@@ -347,7 +461,34 @@ def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_s
|
||||
def actor_cli(cfg: dict):
|
||||
robot = make_robot(cfg=cfg.robot)
|
||||
|
||||
server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)
|
||||
shutdown_event = Event()
|
||||
|
||||
# Define signal handler
|
||||
def signal_handler(signum, frame):
|
||||
logging.info("Shutdown signal received. Cleaning up...")
|
||||
shutdown_event.set()
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
|
||||
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
|
||||
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
|
||||
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.actor_learner_config.learner_host,
|
||||
port=cfg.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
receive_policy_thread = Thread(
|
||||
target=receive_policy,
|
||||
args=(learner_client, shutdown_event, parameters_queue),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_thread = Thread(
|
||||
target=send_transitions,
|
||||
args=(learner_client, shutdown_event, message_queue),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
# HACK: FOR MANISKILL we do not have a reward classifier
|
||||
# TODO: Remove this once we merge into main
|
||||
@@ -360,15 +501,27 @@ def actor_cli(cfg: dict):
|
||||
pretrained_path=cfg.env.reward_classifier.pretrained_path,
|
||||
config_path=cfg.env.reward_classifier.config_path,
|
||||
)
|
||||
|
||||
policy_thread = Thread(
|
||||
target=act_with_policy,
|
||||
daemon=True,
|
||||
args=(cfg, robot, reward_classifier),
|
||||
args=(cfg, robot, reward_classifier, shutdown_event),
|
||||
)
|
||||
server_thread.start()
|
||||
|
||||
transitions_thread.start()
|
||||
policy_thread.start()
|
||||
receive_policy_thread.start()
|
||||
|
||||
shutdown_event.wait()
|
||||
logging.info("[ACTOR] Shutdown event received")
|
||||
grpc_channel.close()
|
||||
|
||||
policy_thread.join()
|
||||
server_thread.join()
|
||||
logging.info("[ACTOR] Policy thread joined")
|
||||
transitions_thread.join()
|
||||
logging.info("[ACTOR] Transitions thread joined")
|
||||
receive_policy_thread.join()
|
||||
logging.info("[ACTOR] Receive policy thread joined")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user