mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
Package folder structure (#1417)
* Move files * Replace imports & paths * Update relative paths * Update doc symlinks * Update instructions paths * Fix imports * Update grpc files * Update more instructions * Downgrade grpc-tools * Update manifest * Update more paths * Update config paths * Update CI paths * Update bandit exclusions * Remove walkthrough section
This commit is contained in:
709
src/lerobot/scripts/rl/actor.py
Normal file
709
src/lerobot/scripts/rl/actor.py
Normal file
@@ -0,0 +1,709 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Actor server runner for distributed HILSerl robot policy training.
|
||||
|
||||
This script implements the actor component of the distributed HILSerl architecture.
|
||||
It executes the policy in the robot environment, collects experience,
|
||||
and sends transitions to the learner server for policy updates.
|
||||
|
||||
Examples of usage:
|
||||
|
||||
- Start an actor server for real robot training with human-in-the-loop intervention:
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
|
||||
server is started before launching the actor.
|
||||
|
||||
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
|
||||
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
|
||||
reduce interventions as the policy improves.
|
||||
|
||||
**WORKFLOW**:
|
||||
1. Determine robot workspace bounds using `find_joint_limits.py`
|
||||
2. Record demonstrations with `gym_manipulator.py` in record mode
|
||||
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
|
||||
4. Start the learner server with the training configuration
|
||||
5. Start this actor server with the same configuration
|
||||
6. Use human interventions to guide policy learning
|
||||
|
||||
For more details on the complete HILSerl training workflow, see:
|
||||
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from queue import Empty
|
||||
|
||||
import grpc
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.scripts.rl import learner_service
|
||||
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
bytes_to_state_dict,
|
||||
python_object_to_bytes,
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.queue import get_last_item_from_queue
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.transition import (
|
||||
Transition,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
from lerobot.utils.utils import (
|
||||
TimerManager,
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
)
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
|
||||
#################################################
|
||||
# Main entry point #
|
||||
#################################################
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def actor_cli(cfg: TrainRLServerPipelineConfig):
|
||||
cfg.validate()
|
||||
display_pid = False
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
display_pid = True
|
||||
|
||||
# Create logs directory to ensure it exists
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=display_pid)
|
||||
logging.info(f"Actor logging initialized, writing to {log_file}")
|
||||
|
||||
is_threaded = use_threads(cfg)
|
||||
shutdown_event = ProcessSignalHandler(is_threaded, display_pid=display_pid).shutdown_event
|
||||
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
logging.info("[ACTOR] Establishing connection with Learner")
|
||||
if not establish_learner_connection(learner_client, shutdown_event):
|
||||
logging.error("[ACTOR] Failed to establish connection with Learner")
|
||||
return
|
||||
|
||||
if not use_threads(cfg):
|
||||
# If we use multithreading, we can reuse the channel
|
||||
grpc_channel.close()
|
||||
grpc_channel = None
|
||||
|
||||
logging.info("[ACTOR] Connection with Learner established")
|
||||
|
||||
parameters_queue = Queue()
|
||||
transitions_queue = Queue()
|
||||
interactions_queue = Queue()
|
||||
|
||||
concurrency_entity = None
|
||||
if use_threads(cfg):
|
||||
from threading import Thread
|
||||
|
||||
concurrency_entity = Thread
|
||||
else:
|
||||
from multiprocessing import Process
|
||||
|
||||
concurrency_entity = Process
|
||||
|
||||
receive_policy_process = concurrency_entity(
|
||||
target=receive_policy,
|
||||
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process = concurrency_entity(
|
||||
target=send_transitions,
|
||||
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
interactions_process = concurrency_entity(
|
||||
target=send_interactions,
|
||||
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process.start()
|
||||
interactions_process.start()
|
||||
receive_policy_process.start()
|
||||
|
||||
act_with_policy(
|
||||
cfg=cfg,
|
||||
shutdown_event=shutdown_event,
|
||||
parameters_queue=parameters_queue,
|
||||
transitions_queue=transitions_queue,
|
||||
interactions_queue=interactions_queue,
|
||||
)
|
||||
logging.info("[ACTOR] Policy process joined")
|
||||
|
||||
logging.info("[ACTOR] Closing queues")
|
||||
transitions_queue.close()
|
||||
interactions_queue.close()
|
||||
parameters_queue.close()
|
||||
|
||||
transitions_process.join()
|
||||
logging.info("[ACTOR] Transitions process joined")
|
||||
interactions_process.join()
|
||||
logging.info("[ACTOR] Interactions process joined")
|
||||
receive_policy_process.join()
|
||||
logging.info("[ACTOR] Receive policy process joined")
|
||||
|
||||
logging.info("[ACTOR] join queues")
|
||||
transitions_queue.cancel_join_thread()
|
||||
interactions_queue.cancel_join_thread()
|
||||
parameters_queue.cancel_join_thread()
|
||||
|
||||
logging.info("[ACTOR] queues closed")
|
||||
|
||||
|
||||
#################################################
|
||||
# Core algorithm functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def act_with_policy(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
shutdown_event: any, # Event,
|
||||
parameters_queue: Queue,
|
||||
transitions_queue: Queue,
|
||||
interactions_queue: Queue,
|
||||
):
|
||||
"""
|
||||
Executes policy interaction within the environment.
|
||||
|
||||
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
|
||||
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
|
||||
|
||||
Args:
|
||||
cfg: Configuration settings for the interaction process.
|
||||
shutdown_event: Event to check if the process should shutdown.
|
||||
parameters_queue: Queue to receive updated network parameters from the learner.
|
||||
transitions_queue: Queue to send transitions to the learner.
|
||||
interactions_queue: Queue to send interactions to the learner.
|
||||
"""
|
||||
# Initialize logging for multiprocessing
|
||||
if not use_threads(cfg):
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor policy process logging initialized")
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(cfg=cfg.env)
|
||||
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
list_transition_to_send_to_learner = []
|
||||
episode_intervention = False
|
||||
# Add counters for intervention rate calculation
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
policy_timer = TimerManager("Policy inference", log=False)
|
||||
|
||||
for interaction_step in range(cfg.policy.online_steps):
|
||||
start_time = time.perf_counter()
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
episode_total_steps += 1
|
||||
|
||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
# NOTE: The action space for demonstration before hand is with the full action space
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
episode_intervention = True
|
||||
# Increment intervention steps counter
|
||||
episode_intervention_steps += 1
|
||||
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
done=done,
|
||||
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
||||
complementary_info=info,
|
||||
)
|
||||
)
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
push_transitions_to_transport_queue(
|
||||
transitions=list_transition_to_send_to_learner,
|
||||
transitions_queue=transitions_queue,
|
||||
)
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
stats = get_frequency_stats(policy_timer)
|
||||
policy_timer.reset()
|
||||
|
||||
# Calculate intervention rate
|
||||
intervention_rate = 0.0
|
||||
if episode_total_steps > 0:
|
||||
intervention_rate = episode_intervention_steps / episode_total_steps
|
||||
|
||||
# Send episodic reward to the learner
|
||||
interactions_queue.put(
|
||||
python_object_to_bytes(
|
||||
{
|
||||
"Episodic reward": sum_reward_episode,
|
||||
"Interaction step": interaction_step,
|
||||
"Episode intervention": int(episode_intervention),
|
||||
"Intervention rate": intervention_rate,
|
||||
**stats,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Reset intervention counters
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
obs, info = online_env.reset()
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
busy_wait(1 / cfg.env.fps - dt_time)
|
||||
|
||||
|
||||
#################################################
|
||||
# Communication Functions - Group all gRPC/messaging functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def establish_learner_connection(
|
||||
stub: services_pb2_grpc.LearnerServiceStub,
|
||||
shutdown_event: Event, # type: ignore
|
||||
attempts: int = 30,
|
||||
):
|
||||
"""Establish a connection with the learner.
|
||||
|
||||
Args:
|
||||
stub (services_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
|
||||
shutdown_event (Event): The event to check if the connection should be established.
|
||||
attempts (int): The number of attempts to establish the connection.
|
||||
Returns:
|
||||
bool: True if the connection is established, False otherwise.
|
||||
"""
|
||||
for _ in range(attempts):
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down establish_learner_connection")
|
||||
return False
|
||||
|
||||
# Force a connection attempt and check state
|
||||
try:
|
||||
logging.info("[ACTOR] Send ready message to Learner")
|
||||
if stub.Ready(services_pb2.Empty()) == services_pb2.Empty():
|
||||
return True
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
|
||||
time.sleep(2)
|
||||
return False
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def learner_service_client(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 50051,
|
||||
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
||||
import json
|
||||
|
||||
"""
|
||||
Returns a client for the learner service.
|
||||
|
||||
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
||||
So we need to create only one client and reuse it.
|
||||
"""
|
||||
|
||||
service_config = {
|
||||
"methodConfig": [
|
||||
{
|
||||
"name": [{}], # Applies to ALL methods in ALL services
|
||||
"retryPolicy": {
|
||||
"maxAttempts": 5, # Max retries (total attempts = 5)
|
||||
"initialBackoff": "0.1s", # First retry after 0.1s
|
||||
"maxBackoff": "2s", # Max wait time between retries
|
||||
"backoffMultiplier": 2, # Exponential backoff factor
|
||||
"retryableStatusCodes": [
|
||||
"UNAVAILABLE",
|
||||
"DEADLINE_EXCEEDED",
|
||||
], # Retries on network failures
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
service_config_json = json.dumps(service_config)
|
||||
|
||||
channel = grpc.insecure_channel(
|
||||
f"{host}:{port}",
|
||||
options=[
|
||||
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.enable_retries", 1),
|
||||
("grpc.service_config", service_config_json),
|
||||
],
|
||||
)
|
||||
stub = services_pb2_grpc.LearnerServiceStub(channel)
|
||||
logging.info("[ACTOR] Learner service client created")
|
||||
return stub, channel
|
||||
|
||||
|
||||
def receive_policy(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
parameters_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
):
|
||||
"""Receive parameters from the learner.
|
||||
|
||||
Args:
|
||||
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
|
||||
parameters_queue (Queue): The queue to receive the parameters.
|
||||
shutdown_event (Event): The event to check if the process should shutdown.
|
||||
"""
|
||||
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor receive policy process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
_ = ProcessSignalHandler(use_threads=False, display_pid=True)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
iterator = learner_client.StreamParameters(services_pb2.Empty())
|
||||
receive_bytes_in_chunks(
|
||||
iterator,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
log_prefix="[ACTOR] parameters",
|
||||
)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Received policy loop stopped")
|
||||
|
||||
|
||||
def send_transitions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
transitions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
"""
|
||||
Sends transitions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- Transition Data:
|
||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor transitions process logging initialized")
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendTransitions(
|
||||
transitions_stream(
|
||||
shutdown_event, transitions_queue, cfg.policy.actor_learner_config.queue_get_timeout
|
||||
)
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming transitions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Transitions process stopped")
|
||||
|
||||
|
||||
def send_interactions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
"""
|
||||
Sends interactions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- Interaction Messages:
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor interactions process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
_ = ProcessSignalHandler(use_threads=False, display_pid=True)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendInteractions(
|
||||
interactions_stream(
|
||||
shutdown_event, interactions_queue, cfg.policy.actor_learner_config.queue_get_timeout
|
||||
)
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming interactions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Interactions process stopped")
|
||||
|
||||
|
||||
def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = transitions_queue.get(block=True, timeout=timeout)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Transition queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message, services_pb2.Transition, log_prefix="[ACTOR] Send transitions"
|
||||
)
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
|
||||
def interactions_stream(
|
||||
shutdown_event: Event,
|
||||
interactions_queue: Queue,
|
||||
timeout: float, # type: ignore
|
||||
) -> services_pb2.Empty:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = interactions_queue.get(block=True, timeout=timeout)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Interaction queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message,
|
||||
services_pb2.InteractionMessage,
|
||||
log_prefix="[ACTOR] Send interactions",
|
||||
)
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
|
||||
#################################################
|
||||
# Policy functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
|
||||
if bytes_state_dict is not None:
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
state_dict = bytes_to_state_dict(bytes_state_dict)
|
||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||
policy.load_state_dict(state_dict)
|
||||
|
||||
|
||||
#################################################
|
||||
# Utilities functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||
"""Send transitions to learner in smaller chunks to avoid network issues.
|
||||
|
||||
Args:
|
||||
transitions: List of transitions to send
|
||||
message_queue: Queue to send messages to learner
|
||||
chunk_size: Size of each chunk to send
|
||||
"""
|
||||
transition_to_send_to_learner = []
|
||||
for transition in transitions:
|
||||
tr = move_transition_to_device(transition=transition, device="cpu")
|
||||
for key, value in tr["state"].items():
|
||||
if torch.isnan(value).any():
|
||||
logging.warning(f"Found NaN values in transition {key}")
|
||||
|
||||
transition_to_send_to_learner.append(tr)
|
||||
|
||||
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
|
||||
|
||||
|
||||
def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
|
||||
"""Get the frequency statistics of the policy.
|
||||
|
||||
Args:
|
||||
timer (TimerManager): The timer with collected metrics.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The frequency statistics of the policy.
|
||||
"""
|
||||
stats = {}
|
||||
if timer.count > 1:
|
||||
avg_fps = timer.fps_avg
|
||||
p90_fps = timer.fps_percentile(90)
|
||||
logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
|
||||
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
|
||||
stats = {
|
||||
"Policy frequency [Hz]": avg_fps,
|
||||
"Policy frequency 90th-p [Hz]": p90_fps,
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def log_policy_frequency_issue(policy_fps: float, cfg: TrainRLServerPipelineConfig, interaction_step: int):
|
||||
if policy_fps < cfg.env.fps:
|
||||
logging.warning(
|
||||
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
|
||||
)
|
||||
|
||||
|
||||
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency.actor == "threads"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
actor_cli()
|
||||
314
src/lerobot/scripts/rl/crop_dataset_roi.py
Normal file
314
src/lerobot/scripts/rl/crop_dataset_roi.py
Normal file
@@ -0,0 +1,314 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import cv2
|
||||
|
||||
# import torch.nn.functional as F # noqa: N812
|
||||
import torchvision.transforms.functional as F # type: ignore # noqa: N812
|
||||
from tqdm import tqdm # type: ignore
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def select_rect_roi(img):
|
||||
"""
|
||||
Allows the user to draw a rectangular ROI on the image.
|
||||
|
||||
The user must click and drag to draw the rectangle.
|
||||
- While dragging, the rectangle is dynamically drawn.
|
||||
- On mouse button release, the rectangle is fixed.
|
||||
- Press 'c' to confirm the selection.
|
||||
- Press 'r' to reset the selection.
|
||||
- Press ESC to cancel.
|
||||
|
||||
Returns:
|
||||
A tuple (top, left, height, width) representing the rectangular ROI,
|
||||
or None if no valid ROI is selected.
|
||||
"""
|
||||
# Create a working copy of the image
|
||||
clone = img.copy()
|
||||
working_img = clone.copy()
|
||||
|
||||
roi = None # Will store the final ROI as (top, left, height, width)
|
||||
drawing = False
|
||||
index_x, index_y = -1, -1 # Initial click coordinates
|
||||
|
||||
def mouse_callback(event, x, y, flags, param):
|
||||
nonlocal index_x, index_y, drawing, roi, working_img
|
||||
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
# Start drawing: record starting coordinates
|
||||
drawing = True
|
||||
index_x, index_y = x, y
|
||||
|
||||
elif event == cv2.EVENT_MOUSEMOVE:
|
||||
if drawing:
|
||||
# Compute the top-left and bottom-right corners regardless of drag direction
|
||||
top = min(index_y, y)
|
||||
left = min(index_x, x)
|
||||
bottom = max(index_y, y)
|
||||
right = max(index_x, x)
|
||||
# Show a temporary image with the current rectangle drawn
|
||||
temp = working_img.copy()
|
||||
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", temp)
|
||||
|
||||
elif event == cv2.EVENT_LBUTTONUP:
|
||||
# Finish drawing
|
||||
drawing = False
|
||||
top = min(index_y, y)
|
||||
left = min(index_x, x)
|
||||
bottom = max(index_y, y)
|
||||
right = max(index_x, x)
|
||||
height = bottom - top
|
||||
width = right - left
|
||||
roi = (top, left, height, width) # (top, left, height, width)
|
||||
# Draw the final rectangle on the working image and display it
|
||||
working_img = clone.copy()
|
||||
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
# Create the window and set the callback
|
||||
cv2.namedWindow("Select ROI")
|
||||
cv2.setMouseCallback("Select ROI", mouse_callback)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
print("Instructions for ROI selection:")
|
||||
print(" - Click and drag to draw a rectangular ROI.")
|
||||
print(" - Press 'c' to confirm the selection.")
|
||||
print(" - Press 'r' to reset and draw again.")
|
||||
print(" - Press ESC to cancel the selection.")
|
||||
|
||||
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
|
||||
while True:
|
||||
key = cv2.waitKey(1) & 0xFF
|
||||
# Confirm ROI if one has been drawn
|
||||
if key == ord("c") and roi is not None:
|
||||
break
|
||||
# Reset: clear the ROI and restore the original image
|
||||
elif key == ord("r"):
|
||||
working_img = clone.copy()
|
||||
roi = None
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
# Cancel selection for this image
|
||||
elif key == 27: # ESC key
|
||||
roi = None
|
||||
break
|
||||
|
||||
cv2.destroyWindow("Select ROI")
|
||||
return roi
|
||||
|
||||
|
||||
def select_square_roi_for_images(images: dict) -> dict:
|
||||
"""
|
||||
For each image in the provided dictionary, open a window to allow the user
|
||||
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
|
||||
(top, left, height, width) representing the ROI.
|
||||
|
||||
Parameters:
|
||||
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
|
||||
|
||||
Returns:
|
||||
dict: Mapping of image keys to the selected rectangular ROI.
|
||||
"""
|
||||
selected_rois = {}
|
||||
|
||||
for key, img in images.items():
|
||||
if img is None:
|
||||
print(f"Image for key '{key}' is None, skipping.")
|
||||
continue
|
||||
|
||||
print(f"\nSelect rectangular ROI for image with key: '{key}'")
|
||||
roi = select_rect_roi(img)
|
||||
|
||||
if roi is None:
|
||||
print(f"No valid ROI selected for '{key}'.")
|
||||
else:
|
||||
selected_rois[key] = roi
|
||||
print(f"ROI for '{key}': {roi}")
|
||||
|
||||
return selected_rois
|
||||
|
||||
|
||||
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
|
||||
"""
|
||||
Find the first row in the dataset and extract the image in order to be used for the crop.
|
||||
"""
|
||||
row = dataset[0]
|
||||
image_dict = {}
|
||||
for k in row:
|
||||
if "image" in k:
|
||||
image_dict[k] = deepcopy(row[k])
|
||||
return image_dict
|
||||
|
||||
|
||||
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset: LeRobotDataset,
|
||||
crop_params_dict: Dict[str, Tuple[int, int, int, int]],
|
||||
new_repo_id: str,
|
||||
new_dataset_root: str,
|
||||
resize_size: Tuple[int, int] = (128, 128),
|
||||
push_to_hub: bool = False,
|
||||
task: str = "",
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts an existing LeRobotDataset by iterating over its episodes and frames,
|
||||
applying cropping and resizing to image observations, and saving a new dataset
|
||||
with the transformed data.
|
||||
|
||||
Args:
|
||||
original_dataset (LeRobotDataset): The source dataset.
|
||||
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
|
||||
A dictionary mapping observation keys to crop parameters (top, left, height, width).
|
||||
new_repo_id (str): Repository id for the new dataset.
|
||||
new_dataset_root (str): The root directory where the new dataset will be written.
|
||||
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
|
||||
Defaults to (128, 128).
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
|
||||
and resized.
|
||||
"""
|
||||
# 1. Create a new (empty) LeRobotDataset for writing.
|
||||
new_dataset = LeRobotDataset.create(
|
||||
repo_id=new_repo_id,
|
||||
fps=original_dataset.fps,
|
||||
root=new_dataset_root,
|
||||
robot_type=original_dataset.meta.robot_type,
|
||||
features=original_dataset.meta.info["features"],
|
||||
use_videos=len(original_dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Update the metadata for every image key that will be cropped:
|
||||
# (Here we simply set the shape to be the final resize_size.)
|
||||
for key in crop_params_dict:
|
||||
if key in new_dataset.meta.info["features"]:
|
||||
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
|
||||
|
||||
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
|
||||
prev_episode_index = 0
|
||||
for frame_idx in tqdm(range(len(original_dataset))):
|
||||
frame = original_dataset[frame_idx]
|
||||
|
||||
# Create a copy of the frame to add to the new dataset
|
||||
new_frame = {}
|
||||
for key, value in frame.items():
|
||||
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"):
|
||||
continue
|
||||
if key in ("next.done", "next.reward"):
|
||||
# if not isinstance(value, str) and len(value.shape) == 0:
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
if key in crop_params_dict:
|
||||
top, left, height, width = crop_params_dict[key]
|
||||
# Apply crop then resize.
|
||||
cropped = F.crop(value, top, left, height, width)
|
||||
value = F.resize(cropped, resize_size)
|
||||
value = value.clamp(0, 1)
|
||||
|
||||
new_frame[key] = value
|
||||
|
||||
new_dataset.add_frame(new_frame, task=task)
|
||||
|
||||
if frame["episode_index"].item() != prev_episode_index:
|
||||
# Save the episode
|
||||
new_dataset.save_episode()
|
||||
prev_episode_index = frame["episode_index"].item()
|
||||
|
||||
# Save the last episode
|
||||
new_dataset.save_episode()
|
||||
|
||||
if push_to_hub:
|
||||
new_dataset.push_to_hub()
|
||||
|
||||
return new_dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot",
|
||||
help="The repository id of the LeRobot dataset to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The root directory of the LeRobot dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop-params-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the JSON file containing the ROIs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to push the new dataset to the hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="",
|
||||
help="The natural language task to describe the dataset.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
|
||||
|
||||
images = get_image_from_lerobot_dataset(dataset)
|
||||
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
|
||||
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
|
||||
|
||||
if args.crop_params_path is None:
|
||||
rois = select_square_roi_for_images(images)
|
||||
else:
|
||||
with open(args.crop_params_path) as f:
|
||||
rois = json.load(f)
|
||||
|
||||
# Print the selected rectangular ROIs
|
||||
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
|
||||
for key, roi in rois.items():
|
||||
print(f"{key}: {roi}")
|
||||
|
||||
new_repo_id = args.repo_id + "_cropped_resized"
|
||||
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
||||
|
||||
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset=dataset,
|
||||
crop_params_dict=rois,
|
||||
new_repo_id=new_repo_id,
|
||||
new_dataset_root=new_dataset_root,
|
||||
resize_size=(128, 128),
|
||||
push_to_hub=args.push_to_hub,
|
||||
task=args.task,
|
||||
)
|
||||
|
||||
meta_dir = new_dataset_root / "meta"
|
||||
meta_dir.mkdir(exist_ok=True)
|
||||
|
||||
with open(meta_dir / "crop_params.json", "w") as f:
|
||||
json.dump(rois, f, indent=4)
|
||||
74
src/lerobot/scripts/rl/eval_policy.py
Normal file
74
src/lerobot/scripts/rl/eval_policy.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
RobotConfig,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
)
|
||||
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.teleoperators import (
|
||||
gamepad, # noqa: F401
|
||||
so101_leader, # noqa: F401
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def eval_policy(env, policy, n_episodes):
|
||||
sum_reward_episode = []
|
||||
for _ in range(n_episodes):
|
||||
obs, _ = env.reset()
|
||||
episode_reward = 0.0
|
||||
while True:
|
||||
action = policy.select_action(obs)
|
||||
obs, reward, terminated, truncated, _ = env.step(action)
|
||||
episode_reward += reward
|
||||
if terminated or truncated:
|
||||
break
|
||||
sum_reward_episode.append(episode_reward)
|
||||
|
||||
logging.info(f"Success after 20 steps {sum_reward_episode}")
|
||||
logging.info(f"success rate {sum(sum_reward_episode) / len(sum_reward_episode)}")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: TrainRLServerPipelineConfig):
|
||||
env_cfg = cfg.env
|
||||
env = make_robot_env(env_cfg)
|
||||
dataset_cfg = cfg.dataset
|
||||
dataset = LeRobotDataset(repo_id=dataset_cfg.repo_id)
|
||||
dataset_meta = dataset.meta
|
||||
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
# env_cfg=cfg.env,
|
||||
ds_meta=dataset_meta,
|
||||
)
|
||||
policy.from_pretrained(env_cfg.pretrained_policy_name_or_path)
|
||||
policy.eval()
|
||||
|
||||
eval_policy(env, policy=policy, n_episodes=10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
2266
src/lerobot/scripts/rl/gym_manipulator.py
Normal file
2266
src/lerobot/scripts/rl/gym_manipulator.py
Normal file
File diff suppressed because it is too large
Load Diff
1206
src/lerobot/scripts/rl/learner.py
Normal file
1206
src/lerobot/scripts/rl/learner.py
Normal file
File diff suppressed because it is too large
Load Diff
118
src/lerobot/scripts/rl/learner_service.py
Normal file
118
src/lerobot/scripts/rl/learner_service.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||
from lerobot.utils.queue import get_last_item_from_queue
|
||||
|
||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||
SHUTDOWN_TIMEOUT = 10
|
||||
|
||||
|
||||
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
"""
|
||||
Implementation of the LearnerService gRPC service
|
||||
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
|
||||
check transport.proto for the gRPC service definition
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shutdown_event: Event, # type: ignore
|
||||
parameters_queue: Queue,
|
||||
seconds_between_pushes: float,
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
queue_get_timeout: float = 0.001,
|
||||
):
|
||||
self.shutdown_event = shutdown_event
|
||||
self.parameters_queue = parameters_queue
|
||||
self.seconds_between_pushes = seconds_between_pushes
|
||||
self.transition_queue = transition_queue
|
||||
self.interaction_message_queue = interaction_message_queue
|
||||
self.queue_get_timeout = queue_get_timeout
|
||||
|
||||
def StreamParameters(self, request, context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
||||
|
||||
last_push_time = 0
|
||||
|
||||
while not self.shutdown_event.is_set():
|
||||
time_since_last_push = time.time() - last_push_time
|
||||
if time_since_last_push < self.seconds_between_pushes:
|
||||
self.shutdown_event.wait(self.seconds_between_pushes - time_since_last_push)
|
||||
# Continue, because we could receive a shutdown event,
|
||||
# and it's checked in the while loop
|
||||
continue
|
||||
|
||||
logging.info("[LEARNER] Push parameters to the Actor")
|
||||
buffer = get_last_item_from_queue(
|
||||
self.parameters_queue, block=True, timeout=self.queue_get_timeout
|
||||
)
|
||||
|
||||
if buffer is None:
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
buffer,
|
||||
services_pb2.Parameters,
|
||||
log_prefix="[LEARNER] Sending parameters",
|
||||
silent=True,
|
||||
)
|
||||
|
||||
last_push_time = time.time()
|
||||
logging.info("[LEARNER] Parameters sent")
|
||||
|
||||
logging.info("[LEARNER] Stream parameters finished")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendTransitions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.transition_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] transitions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving transitions")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendInteractions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive interactions from the Actor")
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.interaction_message_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] interactions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving interactions")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
return services_pb2.Empty()
|
||||
Reference in New Issue
Block a user