[Async Inference] Merge Protos & refactoring (#1480)

* Merge together proto files and refactor Async inference

* Fixup for Async inference

* Drop not reuqired changes

* Fix tests

* Drop old async files

* Drop chunk_size param

* Fix versions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix wrong fix

Co-authored-by: Ben Zhang <ben.zhang@uwaterloo.ca>

* Fixup

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Ben Zhang <ben.zhang@uwaterloo.ca>
Co-authored-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
Eugene Mironov
2025-07-23 16:30:01 +07:00
committed by GitHub
parent f5d6b5b3a7
commit 989f3d05ba
12 changed files with 299 additions and 518 deletions

View File

@@ -69,15 +69,14 @@ from lerobot.scripts.server.helpers import (
TimedObservation,
get_logger,
map_robot_keys_to_lerobot_features,
send_bytes_in_chunks,
validate_robot_cameras_for_policy,
visualize_action_queue_size,
)
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import grpc_channel_options
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
class RobotClient:
@@ -118,10 +117,10 @@ class RobotClient:
self.channel = grpc.insecure_channel(
self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
)
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel)
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
self._running_event = threading.Event()
self.shutdown_event = threading.Event()
# Initialize client side variables
self.latest_action_lock = threading.Lock()
@@ -146,20 +145,20 @@ class RobotClient:
@property
def running(self):
return self._running_event.is_set()
return not self.shutdown_event.is_set()
def start(self):
"""Start the robot client and connect to the policy server"""
try:
# client-server handshake
start_time = time.perf_counter()
self.stub.Ready(async_inference_pb2.Empty())
self.stub.Ready(services_pb2.Empty())
end_time = time.perf_counter()
self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
# send policy instructions
policy_config_bytes = pickle.dumps(self.policy_config)
policy_setup = async_inference_pb2.PolicySetup(data=policy_config_bytes)
policy_setup = services_pb2.PolicySetup(data=policy_config_bytes)
self.logger.info("Sending policy instructions to policy server")
self.logger.debug(
@@ -170,7 +169,7 @@ class RobotClient:
self.stub.SendPolicyInstructions(policy_setup)
self._running_event.set()
self.shutdown_event.clear()
return True
@@ -180,7 +179,7 @@ class RobotClient:
def stop(self):
"""Stop the robot client"""
self._running_event.clear()
self.shutdown_event.set()
self.robot.disconnect()
self.logger.debug("Robot disconnected")
@@ -208,7 +207,7 @@ class RobotClient:
try:
observation_iterator = send_bytes_in_chunks(
observation_bytes,
async_inference_pb2.Observation,
services_pb2.Observation,
log_prefix="[CLIENT] Observation",
silent=True,
)
@@ -283,7 +282,7 @@ class RobotClient:
while self.running:
try:
# Use StreamActions to get a stream of actions from the server
actions_chunk = self.stub.GetActions(async_inference_pb2.Empty())
actions_chunk = self.stub.GetActions(services_pb2.Empty())
if len(actions_chunk.data) == 0:
continue # received `Empty` from server, wait for next call