add healthy route

This commit is contained in:
Francesco Capuano
2025-07-10 17:46:26 +02:00
parent abe51eeba3
commit b6eb651bab
3 changed files with 113 additions and 1 deletions

View File

@@ -27,3 +27,7 @@ SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]
# TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]
"""Networking support"""
HEALTH_CHECK_PORT = 8081
HEALTH_SERVER_HOST = "0.0.0.0" # nosec

View File

@@ -13,11 +13,13 @@
# limitations under the License.
import io
import json
import logging
import logging.handlers
import os
import time
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler
from pathlib import Path
from threading import Event
from typing import Any
@@ -206,6 +208,83 @@ def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
return logging.getLogger(name)
def create_health_handler(policy_server):
"""Factory function to create health handler with policy server reference."""
def handler(*args, **kwargs):
return HealthHandler(policy_server, *args, **kwargs)
return handler
class HealthHandler(BaseHTTPRequestHandler):
"""HTTP handler for health checks."""
def __init__(self, policy_server, *args, **kwargs):
self.policy_server = policy_server
super().__init__(*args, **kwargs)
def do_GET(self): # noqa: N802
"""Handle GET requests for health check."""
if self.path == "/health":
self.send_health_response()
elif self.path == "/":
self.send_info_response()
else:
self.send_error(404, "Not Found")
def send_health_response(self):
"""Send health check response."""
try:
# Check if the policy server is in a healthy state
is_healthy = (
hasattr(self.policy_server, "_running_event")
and self.policy_server._running_event is not None
)
status_code = 200 if is_healthy else 503
response = {
"status": "healthy" if is_healthy else "unhealthy",
"timestamp": time.time(),
"server_running": self.policy_server.running
if hasattr(self.policy_server, "running")
else False,
"policy_loaded": self.policy_server.policy is not None
if hasattr(self.policy_server, "policy")
else False,
}
self.send_response(status_code)
self.send_header("Content-type", "application/json")
self.end_headers()
self.wfile.write(json.dumps(response).encode())
except Exception as e:
self.send_response(500)
self.send_header("Content-type", "application/json")
self.end_headers()
error_response = {"status": "error", "message": str(e)}
self.wfile.write(json.dumps(error_response).encode())
def send_info_response(self):
"""Send basic server info."""
response = {
"service": "lerobot-policy-server",
"version": "1.0.0",
"endpoints": {"health": "/health", "grpc_port": self.policy_server.config.port},
}
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
self.wfile.write(json.dumps(response).encode())
def log_message(self, format, *args):
"""Override to use our logger instead of stderr."""
if hasattr(self.policy_server, "logger"):
self.policy_server.logger.debug(f"HTTP: {format % args}")
@dataclass
class TimedData:
"""A data object with timestamp and timestep information.

View File

@@ -30,6 +30,7 @@ import threading
import time
from concurrent import futures
from dataclasses import asdict
from http.server import HTTPServer
from pprint import pformat
from queue import Empty, Queue
@@ -39,13 +40,14 @@ import torch
from lerobot.policies.factory import get_policy_class
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.constants import SUPPORTED_POLICIES
from lerobot.scripts.server.constants import HEALTH_CHECK_PORT, HEALTH_SERVER_HOST, SUPPORTED_POLICIES
from lerobot.scripts.server.helpers import (
FPSTracker,
Observation,
RemotePolicyConfig,
TimedAction,
TimedObservation,
create_health_handler,
get_logger,
observations_similar,
raw_observation_to_observation,
@@ -82,6 +84,30 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
self.actions_per_chunk = None
self.policy = None
# HTTP health server
self.http_server = None
self.http_thread = None
def start_health_server(self):
"""Start HTTP server for health checks on port 8081."""
try:
health_handler = create_health_handler(self)
self.http_server = HTTPServer((HEALTH_SERVER_HOST, HEALTH_CHECK_PORT), health_handler)
self.http_thread = threading.Thread(target=self.http_server.serve_forever, daemon=True)
self.http_thread.start()
self.logger.info(f"Health server started on port {HEALTH_CHECK_PORT}")
except Exception as e:
self.logger.error(f"Failed to start health server: {e}")
def stop_health_server(self):
"""Stop the HTTP health server."""
if self.http_server:
self.http_server.shutdown()
self.http_server.server_close()
if self.http_thread:
self.http_thread.join(timeout=5)
self.logger.info("Health server stopped")
@property
def running(self):
return self._running_event.is_set()
@@ -372,6 +398,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
"""Stop the server"""
self._reset_server()
self.logger.info("Server stopping...")
self.stop_health_server()
@draccus.wrap()
@@ -385,6 +412,7 @@ def serve(cfg: PolicyServerConfig):
# Create the server instance first
policy_server = PolicyServer(cfg)
policy_server.start_health_server()
# Setup and start gRPC server
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
@@ -397,6 +425,7 @@ def serve(cfg: PolicyServerConfig):
server.wait_for_termination()
policy_server.logger.info("Server terminated")
policy_server.stop_health_server()
if __name__ == "__main__":