mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
add healthy route
This commit is contained in:
@@ -27,3 +27,7 @@ SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]
|
||||
|
||||
# TODO: Add all other robots
SUPPORTED_ROBOTS = [
    "so100_follower",
    "so101_follower",
]

# Networking support
# Port on which the HTTP health-check endpoint listens.
HEALTH_CHECK_PORT = 8081
# Wildcard address: the health endpoint binds on every network interface.
HEALTH_SERVER_HOST = "0.0.0.0"  # nosec
|
||||
|
||||
@@ -13,11 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
from pathlib import Path
|
||||
from threading import Event
|
||||
from typing import Any
|
||||
@@ -206,6 +208,83 @@ def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def create_health_handler(policy_server):
    """Return a callable that builds a ``HealthHandler`` bound to ``policy_server``.

    The returned callable forwards whatever positional and keyword arguments
    it receives to ``HealthHandler`` after the ``policy_server`` reference, so
    it can be used wherever a request-handler class is expected.

    Args:
        policy_server: Object whose state the health endpoints report on.

    Returns:
        A factory ``f(*args, **kwargs) -> HealthHandler``.
    """

    def _make_handler(*handler_args, **handler_kwargs):
        # Close over policy_server so every request handler sees the same server.
        bound_handler = HealthHandler(policy_server, *handler_args, **handler_kwargs)
        return bound_handler

    return _make_handler
|
||||
|
||||
|
||||
class HealthHandler(BaseHTTPRequestHandler):
    """HTTP handler exposing health information about a policy server.

    Routes (the query string, if any, is ignored when routing):
        GET /health  JSON health report: 200 when healthy, 503 when unhealthy,
                     500 when the report itself cannot be built.
        GET /        JSON description of the service and its endpoints.
        otherwise    404.
    """

    def __init__(self, policy_server, *args, **kwargs):
        """Bind the handler to ``policy_server`` and handle the request.

        Args:
            policy_server: Object inspected by the health endpoints.
            *args, **kwargs: Forwarded unchanged to ``BaseHTTPRequestHandler``.
        """
        # Must be assigned before calling the base constructor: the base class
        # processes the request (and so calls do_GET) from inside __init__.
        self.policy_server = policy_server
        super().__init__(*args, **kwargs)

    def do_GET(self):  # noqa: N802
        """Handle GET requests for health check."""
        # Probes often append a query string (e.g. cache busters); route on the
        # path component only so "/health?x=1" still reaches the health check.
        route = self.path.partition("?")[0]
        if route == "/health":
            self.send_health_response()
        elif route == "/":
            self.send_info_response()
        else:
            self.send_error(404, "Not Found")

    def _send_json_body(self, status_code, body):
        """Write a complete JSON response with an already-encoded ``body``."""
        self.send_response(status_code)
        self.send_header("Content-type", "application/json")
        # An explicit length lets clients detect truncated responses.
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def send_health_response(self):
        """Send the health check response.

        The report is built and serialised *before* anything is written to the
        socket, so a failure while building it yields exactly one well-formed
        500 response instead of a second status line after partial output.
        """
        try:
            server = self.policy_server
            # Healthy means the server object exposes a running event at all;
            # whether the event is currently set is reported as server_running.
            is_healthy = getattr(server, "_running_event", None) is not None

            status_code = 200 if is_healthy else 503
            report = {
                "status": "healthy" if is_healthy else "unhealthy",
                "timestamp": time.time(),
                "server_running": server.running if hasattr(server, "running") else False,
                "policy_loaded": getattr(server, "policy", None) is not None,
            }
            body = json.dumps(report).encode()
        except Exception as e:
            status_code = 500
            body = json.dumps({"status": "error", "message": str(e)}).encode()

        self._send_json_body(status_code, body)

    def send_info_response(self):
        """Send basic server info."""
        # Tolerate a server object without config/port, mirroring the guarded
        # attribute access used for the health report.
        server_config = getattr(self.policy_server, "config", None)
        response = {
            "service": "lerobot-policy-server",
            "version": "1.0.0",
            "endpoints": {"health": "/health", "grpc_port": getattr(server_config, "port", None)},
        }
        self._send_json_body(200, json.dumps(response).encode())

    def log_message(self, format, *args):
        """Override to use our logger instead of stderr.

        Messages are dropped when the policy server has no ``logger`` attribute.
        """
        if hasattr(self.policy_server, "logger"):
            self.policy_server.logger.debug(f"HTTP: {format % args}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedData:
|
||||
"""A data object with timestamp and timestep information.
|
||||
|
||||
@@ -30,6 +30,7 @@ import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
from dataclasses import asdict
|
||||
from http.server import HTTPServer
|
||||
from pprint import pformat
|
||||
from queue import Empty, Queue
|
||||
|
||||
@@ -39,13 +40,14 @@ import torch
|
||||
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
from lerobot.scripts.server.configs import PolicyServerConfig
|
||||
from lerobot.scripts.server.constants import SUPPORTED_POLICIES
|
||||
from lerobot.scripts.server.constants import HEALTH_CHECK_PORT, HEALTH_SERVER_HOST, SUPPORTED_POLICIES
|
||||
from lerobot.scripts.server.helpers import (
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RemotePolicyConfig,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
create_health_handler,
|
||||
get_logger,
|
||||
observations_similar,
|
||||
raw_observation_to_observation,
|
||||
@@ -82,6 +84,30 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
||||
self.actions_per_chunk = None
|
||||
self.policy = None
|
||||
|
||||
# HTTP health server
|
||||
self.http_server = None
|
||||
self.http_thread = None
|
||||
|
||||
def start_health_server(self):
    """Start the HTTP health-check server in a daemon thread.

    The server binds to ``HEALTH_SERVER_HOST:HEALTH_CHECK_PORT``. Startup is
    best-effort: any failure is logged and swallowed so the caller can carry
    on without a health endpoint. ``self.http_server`` and
    ``self.http_thread`` are only assigned once the serve loop's thread has
    been started, so a failed start never leaves a bound-but-unserved server
    behind (calling ``shutdown()`` on such a server would block forever).
    """
    server = None
    try:
        server = HTTPServer((HEALTH_SERVER_HOST, HEALTH_CHECK_PORT), create_health_handler(self))
        thread = threading.Thread(target=server.serve_forever, daemon=True)
        thread.start()
    except Exception as e:
        # Release the listening socket if we got as far as binding it.
        if server is not None:
            server.server_close()
        self.logger.error(f"Failed to start health server: {e}")
        return

    self.http_server = server
    self.http_thread = thread
    self.logger.info(f"Health server started on port {HEALTH_CHECK_PORT}")
|
||||
|
||||
def stop_health_server(self):
    """Stop the HTTP health server, if one is running.

    Safe to call more than once and safe to call when the server never
    started: the references are cleared on the first call, and later calls
    return without doing or logging anything.
    """
    server, thread = self.http_server, self.http_thread
    if server is None and thread is None:
        return

    # Clear the references first so a repeated call becomes a no-op.
    self.http_server = None
    self.http_thread = None

    # shutdown() waits for serve_forever() to exit, so it must only be called
    # while the serving thread is alive; otherwise it can block forever.
    serving = thread is not None and thread.is_alive()
    if server is not None:
        if serving:
            server.shutdown()
        server.server_close()
    if serving:
        thread.join(timeout=5)
    self.logger.info("Health server stopped")
|
||||
|
||||
@property
def running(self):
    """Report whether the server's ``_running_event`` is currently set."""
    running_event = self._running_event
    return running_event.is_set()
|
||||
@@ -372,6 +398,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
||||
"""Stop the server"""
|
||||
self._reset_server()
|
||||
self.logger.info("Server stopping...")
|
||||
self.stop_health_server()
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
@@ -385,6 +412,7 @@ def serve(cfg: PolicyServerConfig):
|
||||
|
||||
# Create the server instance first
|
||||
policy_server = PolicyServer(cfg)
|
||||
policy_server.start_health_server()
|
||||
|
||||
# Setup and start gRPC server
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
|
||||
@@ -397,6 +425,7 @@ def serve(cfg: PolicyServerConfig):
|
||||
server.wait_for_termination()
|
||||
|
||||
policy_server.logger.info("Server terminated")
|
||||
policy_server.stop_health_server()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user