Files
lerobot-clone/src/lerobot/async_inference/robot_client.py

500 lines
19 KiB
Python
Raw Normal View History

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example command:
```shell
python src/lerobot/async_inference/robot_client.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
--robot.id=black \
--task="dummy" \
--server_address=127.0.0.1:8080 \
--policy_type=act \
--pretrained_name_or_path=user/model \
--policy_device=mps \
--actions_per_chunk=50 \
--chunk_size_threshold=0.5 \
--aggregate_fn_name=weighted_average \
--debug_visualize_queue_size=True
```
"""
import logging
import pickle # nosec
import threading
import time
from collections.abc import Callable
from dataclasses import asdict
from pprint import pformat
from queue import Queue
import draccus
import grpc
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.processor import RobotAction
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_so_follower,
koch_follower,
make_robot_from_config,
omx_follower,
so_follower,
)
from lerobot.transport import (
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
from .configs import RobotClientConfig
from .constants import SUPPORTED_ROBOTS
from .helpers import (
Action,
FPSTracker,
Observation,
RawObservation,
RemotePolicyConfig,
TimedAction,
TimedObservation,
get_logger,
map_robot_keys_to_lerobot_features,
visualize_action_queue_size,
)
class RobotClient:
    """Client that drives a robot while exchanging data with a remote policy server over gRPC.

    Observations read from the robot are sent to the server, and the action chunks the
    server returns are stored in a local queue that the control loop consumes one action
    at a time. Two threads are expected to cooperate (see ``start_barrier``): one running
    ``receive_actions`` and one running ``control_loop``.
    """

    prefix = "robot_client"
    logger = get_logger(prefix)

    def __init__(self, config: RobotClientConfig):
        """Initialize RobotClient with unified configuration.

        Builds and connects the robot, prepares the policy configuration to send to the
        server, and opens the (insecure) gRPC channel. The server handshake itself is
        performed by ``start()``.

        Args:
            config: RobotClientConfig containing all configuration parameters
        """
        # Store configuration
        self.config = config
        self.robot = make_robot_from_config(config.robot)
        self.robot.connect()

        lerobot_features = map_robot_keys_to_lerobot_features(self.robot)

        # The server address is taken verbatim from the configuration
        self.server_address = config.server_address

        self.policy_config = RemotePolicyConfig(
            config.policy_type,
            config.pretrained_name_or_path,
            lerobot_features,
            config.actions_per_chunk,
            config.policy_device,
        )
        self.channel = grpc.insecure_channel(
            self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
        )
        self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel)
        self.logger.info(f"Initializing client to connect to server at {self.server_address}")

        self.shutdown_event = threading.Event()

        # Initialize client side variables
        self.latest_action_lock = threading.Lock()
        self.latest_action = -1  # timestep of the last executed action; -1 until one is performed
        self.action_chunk_size = -1  # largest chunk length seen so far; -1 until a chunk arrives
        self._chunk_size_threshold = config.chunk_size_threshold

        self.action_queue = Queue()
        self.action_queue_lock = threading.Lock()  # Protect queue operations
        self.action_queue_size = []  # queue size sampled each time an action is popped

        self.start_barrier = threading.Barrier(2)  # 2 threads: action receiver, control loop

        # FPS measurement
        self.fps_tracker = FPSTracker(target_fps=self.config.fps)

        self.logger.info("Robot connected and ready")

        # Use an event for thread-safe coordination
        self.must_go = threading.Event()
        self.must_go.set()  # Initially set - observations qualify for direct processing

    @property
    def running(self):
        """True while the shutdown event has not been set."""
        return not self.shutdown_event.is_set()

    def start(self):
        """Start the robot client and connect to the policy server.

        Performs the ``Ready`` handshake and sends the pickled policy configuration.

        Returns:
            True if both RPCs succeeded, False if a ``grpc.RpcError`` was raised.
        """
        try:
            # client-server handshake
            start_time = time.perf_counter()
            self.stub.Ready(services_pb2.Empty())
            end_time = time.perf_counter()
            self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")

            # send policy instructions
            policy_config_bytes = pickle.dumps(self.policy_config)
            policy_setup = services_pb2.PolicySetup(data=policy_config_bytes)

            self.logger.info("Sending policy instructions to policy server")
            self.logger.debug(
                f"Policy type: {self.policy_config.policy_type} | "
                f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
                f"Device: {self.policy_config.device}"
            )

            self.stub.SendPolicyInstructions(policy_setup)

            self.shutdown_event.clear()

            return True

        except grpc.RpcError as e:
            self.logger.error(f"Failed to connect to policy server: {e}")
            return False

    def stop(self):
        """Stop the robot client: signal shutdown, disconnect the robot, close the channel."""
        self.shutdown_event.set()

        try:
            self.robot.disconnect()
            self.logger.debug("Robot disconnected")
        finally:
            # Always release the gRPC channel, even if disconnecting the robot raised
            self.channel.close()
            self.logger.debug("Client stopped, channel closed")

    def send_observation(
        self,
        obs: TimedObservation,
    ) -> bool:
        """Send observation to the policy server.

        Args:
            obs: The timed observation to pickle and stream to the server.

        Returns:
            True if the observation was sent successfully, False if the RPC failed.

        Raises:
            RuntimeError: If the client is not running.
            ValueError: If ``obs`` is not a ``TimedObservation``.
        """
        if not self.running:
            raise RuntimeError("Client not running. Run RobotClient.start() before sending observations.")

        if not isinstance(obs, TimedObservation):
            raise ValueError("Input observation needs to be a TimedObservation!")

        start_time = time.perf_counter()
        observation_bytes = pickle.dumps(obs)
        serialize_time = time.perf_counter() - start_time
        self.logger.debug(f"Observation serialization time: {serialize_time:.6f}s")

        try:
            observation_iterator = send_bytes_in_chunks(
                observation_bytes,
                services_pb2.Observation,
                log_prefix="[CLIENT] Observation",
                silent=True,
            )
            _ = self.stub.SendObservations(observation_iterator)
            obs_timestep = obs.get_timestep()
            self.logger.debug(f"Sent observation #{obs_timestep} | ")

            return True

        except grpc.RpcError as e:
            self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
            return False

    def _inspect_action_queue(self):
        """Return ``(queue_size, sorted_timesteps)`` for the current action queue."""
        with self.action_queue_lock:
            queue_size = self.action_queue.qsize()
            timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
        self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
        return queue_size, timestamps

    @staticmethod
    def _format_timestep_range(timesteps) -> str:
        """Format a list of timesteps as ``first:last``, or ``empty`` when there are none."""
        if not timesteps:
            return "empty"
        return f"{timesteps[0]}:{timesteps[-1]}"

    def _aggregate_action_queues(
        self,
        incoming_actions: list[TimedAction],
        aggregate_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
    ):
        """Finds the same timestep actions in the queue and aggregates them using the aggregate_fn.

        The action queue is replaced by a new queue built from ``incoming_actions`` only:
        incoming actions at or before the latest executed timestep are dropped, incoming
        actions whose timestep is already queued are merged with the queued action through
        ``aggregate_fn(queued, incoming)``, and the rest are added as they are.

        Args:
            incoming_actions: Newly received actions.
            aggregate_fn: Combines the queued and the incoming action for the same timestep.
                Defaults to keeping the incoming action.
        """
        if aggregate_fn is None:
            # default aggregate function: take the latest action
            def aggregate_fn(x1, x2):
                return x2

        future_action_queue = Queue()
        with self.action_queue_lock:
            internal_queue = self.action_queue.queue
            current_action_queue = {action.get_timestep(): action.get_action() for action in internal_queue}

        for new_action in incoming_actions:
            # Re-read on every iteration: the control loop may execute actions meanwhile
            with self.latest_action_lock:
                latest_action = self.latest_action

            # New action is older than the latest action in the queue, skip it
            if new_action.get_timestep() <= latest_action:
                continue

            # If the new action's timestep is not in the current action queue, add it directly
            elif new_action.get_timestep() not in current_action_queue:
                future_action_queue.put(new_action)
                continue

            # If the new action's timestep is in the current action queue, aggregate it
            # TODO: There is probably a way to do this with broadcasting of the two action tensors
            future_action_queue.put(
                TimedAction(
                    timestamp=new_action.get_timestamp(),
                    timestep=new_action.get_timestep(),
                    action=aggregate_fn(
                        current_action_queue[new_action.get_timestep()], new_action.get_action()
                    ),
                )
            )

        with self.action_queue_lock:
            self.action_queue = future_action_queue

    def receive_actions(self, verbose: bool = False):
        """Receive actions from the policy server.

        Waits on the start barrier, then repeatedly calls ``GetActions`` while the client
        is running, unpickles each non-empty payload into a list of timed actions and
        merges it into the local action queue.

        Args:
            verbose: When True, log details about each received chunk and the queue
                before and after the update.
        """
        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        self.logger.info("Action receiving thread starting")

        while self.running:
            try:
                # Ask the server for the next chunk of actions
                actions_chunk = self.stub.GetActions(services_pb2.Empty())
                if len(actions_chunk.data) == 0:
                    continue  # received `Empty` from server, wait for next call

                receive_time = time.time()

                # Deserialize bytes back into list[TimedAction]
                # NOTE(security): pickle.loads executes arbitrary code from the payload; this is
                # only acceptable when the policy server and the network path are trusted.
                deserialize_start = time.perf_counter()
                timed_actions = pickle.loads(actions_chunk.data)  # nosec
                deserialize_time = time.perf_counter() - deserialize_start

                self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))

                # Detailed logging needs at least one action to report on; using the same flag
                # for the pre- and post-update logs keeps the variables below always defined.
                log_details = verbose and len(timed_actions) > 0

                # Calculate network latency if we have matching observations
                if log_details:
                    with self.latest_action_lock:
                        latest_action = self.latest_action

                    self.logger.debug(f"Current latest action: {latest_action}")

                    # Get queue state before changes
                    old_size, old_timesteps = self._inspect_action_queue()
                    if not old_timesteps:
                        old_timesteps = [latest_action]  # queue was empty

                    # Log incoming actions
                    incoming_timesteps = [a.get_timestep() for a in timed_actions]

                    first_action_timestep = timed_actions[0].get_timestep()
                    server_to_client_latency = (receive_time - timed_actions[0].get_timestamp()) * 1000

                    self.logger.info(
                        f"Received action chunk for step #{first_action_timestep} | "
                        f"Latest action: #{latest_action} | "
                        f"Incoming actions: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
                        f"Network latency (server->client): {server_to_client_latency:.2f}ms | "
                        f"Deserialization time: {deserialize_time * 1000:.2f}ms"
                    )

                # Update action queue
                start_time = time.perf_counter()
                self._aggregate_action_queues(timed_actions, self.config.aggregate_fn)
                queue_update_time = time.perf_counter() - start_time

                self.must_go.set()  # after receiving actions, next empty queue triggers must-go processing!

                if log_details:
                    # Get queue state after changes
                    new_size, new_timesteps = self._inspect_action_queue()

                    with self.latest_action_lock:
                        latest_action = self.latest_action

                    # The rebuilt queue can be empty when every incoming action was stale,
                    # so the ranges are formatted defensively.
                    self.logger.info(
                        f"Latest action: {latest_action} | "
                        f"Old action steps: {self._format_timestep_range(old_timesteps)} | "
                        f"Incoming action steps: {self._format_timestep_range(incoming_timesteps)} | "
                        f"Updated action steps: {self._format_timestep_range(new_timesteps)}"
                    )
                    self.logger.debug(
                        f"Queue update complete ({queue_update_time:.6f}s) | "
                        f"Before: {old_size} items | "
                        f"After: {new_size} items | "
                    )

            except grpc.RpcError as e:
                self.logger.error(f"Error receiving actions: {e}")

    def actions_available(self):
        """Check if there are actions available in the queue"""
        with self.action_queue_lock:
            return not self.action_queue.empty()

    def _action_tensor_to_action_dict(self, action_tensor: torch.Tensor) -> dict[str, float]:
        """Map each entry of ``action_tensor`` to the robot action feature at the same index."""
        action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
        return action

    def control_loop_action(self, verbose: bool = False) -> RobotAction:
        """Reading and performing actions in local queue.

        Pops the next action without blocking, sends it to the robot and records its
        timestep as the latest executed action.

        Args:
            verbose: When True, log the performed action and queue timing.

        Returns:
            The value returned by ``robot.send_action``.
        """
        # Lock only for queue operations
        get_start = time.perf_counter()
        with self.action_queue_lock:
            self.action_queue_size.append(self.action_queue.qsize())
            # Get action from queue
            timed_action = self.action_queue.get_nowait()
        get_end = time.perf_counter() - get_start

        _performed_action = self.robot.send_action(
            self._action_tensor_to_action_dict(timed_action.get_action())
        )
        with self.latest_action_lock:
            self.latest_action = timed_action.get_timestep()

        if verbose:
            with self.action_queue_lock:
                current_queue_size = self.action_queue.qsize()

            self.logger.debug(
                f"Ts={timed_action.get_timestamp()} | "
                f"Action #{timed_action.get_timestep()} performed | "
                f"Queue size: {current_queue_size}"
            )

            self.logger.debug(
                f"Popping action from queue to perform took {get_end:.6f}s | Queue size: {current_queue_size}"
            )

        return _performed_action

    def _ready_to_send_observation(self):
        """Flags when the client is ready to send an observation.

        Ready means the queue fill ratio (queue size / largest chunk size seen) is at or
        below the configured threshold. While no usable chunk size is known yet, the client
        is always considered ready so that the first observation can be sent.
        """
        with self.action_queue_lock:
            chunk_size = self.action_chunk_size
            if chunk_size <= 0:
                # No chunk received yet (-1) or only empty chunks (0): avoid dividing by zero
                return True
            return self.action_queue.qsize() / chunk_size <= self._chunk_size_threshold

    def control_loop_observation(self, task: str, verbose: bool = False) -> RawObservation | None:
        """Capture an observation from the robot and send it to the policy server.

        Args:
            task: Task description stored under the ``"task"`` key of the observation.
            verbose: When True, log FPS metrics and capture timing.

        Returns:
            The raw observation, or None if any exception was raised (it is logged).
        """
        try:
            # Get serialized observation bytes from the function
            start_time = time.perf_counter()

            raw_observation: RawObservation = self.robot.get_observation()
            raw_observation["task"] = task

            with self.latest_action_lock:
                latest_action = self.latest_action

            observation = TimedObservation(
                timestamp=time.time(),  # need time.time() to compare timestamps across client and server
                observation=raw_observation,
                timestep=max(latest_action, 0),
            )

            obs_capture_time = time.perf_counter() - start_time

            # If there are no actions left in the queue, the observation must go through processing!
            with self.action_queue_lock:
                observation.must_go = self.must_go.is_set() and self.action_queue.empty()
                current_queue_size = self.action_queue.qsize()

            _ = self.send_observation(observation)

            self.logger.debug(f"QUEUE SIZE: {current_queue_size} (Must go: {observation.must_go})")
            if observation.must_go:
                # must-go event will be set again after receiving actions
                self.must_go.clear()

            if verbose:
                # Calculate comprehensive FPS metrics
                fps_metrics = self.fps_tracker.calculate_fps_metrics(observation.get_timestamp())

                self.logger.info(
                    f"Obs #{observation.get_timestep()} | "
                    f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "
                    f"Target: {fps_metrics['target_fps']:.2f}"
                )

                self.logger.debug(
                    f"Ts={observation.get_timestamp():.6f} | Capturing observation took {obs_capture_time:.6f}s"
                )

            return raw_observation

        except Exception as e:
            # Best-effort: a failed observation must not stop the control loop
            self.logger.error(f"Error in observation sender: {e}")
            return None

    def control_loop(self, task: str, verbose: bool = False) -> tuple[Observation, Action]:
        """Combined function for executing actions and streaming observations.

        Waits on the start barrier, then loops while the client is running, sleeping as
        needed so that each iteration lasts about ``config.environment_dt`` seconds.

        Args:
            task: Task description attached to every observation.
            verbose: Forwarded to the action and observation steps.

        Returns:
            The last captured observation and the last performed action; either may be
            None if the corresponding step never ran.
        """
        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        self.logger.info("Control loop thread starting")

        _performed_action = None
        _captured_observation = None

        while self.running:
            control_loop_start = time.perf_counter()

            # Control loop: (1) Performing actions, when available
            if self.actions_available():
                _performed_action = self.control_loop_action(verbose)

            # Control loop: (2) Streaming observations to the remote policy server
            if self._ready_to_send_observation():
                _captured_observation = self.control_loop_observation(task, verbose)

            self.logger.debug(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")

            # Dynamically adjust sleep time to maintain the desired control frequency
            time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))

        return _captured_observation, _performed_action
@draccus.wrap()
def async_client(cfg: RobotClientConfig):
    """Run the asynchronous robot client described by ``cfg``.

    Starts the action receiver in a daemon thread and runs the control loop in the
    calling thread. The client is always stopped afterwards, including when the initial
    connection to the policy server fails.

    Args:
        cfg: Client configuration (parsed from the command line by ``draccus.wrap``).

    Raises:
        ValueError: If ``cfg.robot.type`` is not in ``SUPPORTED_ROBOTS``.
    """
    logging.info(pformat(asdict(cfg)))

    if cfg.robot.type not in SUPPORTED_ROBOTS:
        raise ValueError(f"Robot {cfg.robot.type} not yet supported!")

    client = RobotClient(cfg)

    if not client.start():
        # The constructor already connected the robot and opened the gRPC channel;
        # release both instead of leaking them when the server cannot be reached.
        client.stop()
        return

    client.logger.info("Starting action receiver thread...")

    # Create and start action receiver thread
    action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
    action_receiver_thread.start()

    try:
        # The main thread runs the control loop
        client.control_loop(task=cfg.task)

    finally:
        client.stop()
        action_receiver_thread.join()
        if cfg.debug_visualize_queue_size:
            visualize_action_queue_size(client.action_queue_size)
        client.logger.info("Client stopped")
if __name__ == "__main__":
    # Script entry point: arguments are parsed by the draccus wrapper on async_client
    async_client()