def create_health_handler(policy_server: Any):
    """Build a request-handler factory bound to ``policy_server``.

    ``HTTPServer`` instantiates its handler as
    ``handler(request, client_address, server)`` and offers no way to pass extra
    constructor arguments, so the policy server is captured in a closure.

    Args:
        policy_server: Object whose state the health endpoints report on.

    Returns:
        A callable with the handler-class calling convention that creates a
        ``HealthHandler`` for each incoming request.
    """

    def handler(*args, **kwargs):
        return HealthHandler(policy_server, *args, **kwargs)

    return handler


class HealthHandler(BaseHTTPRequestHandler):
    """HTTP handler serving JSON health and info endpoints for a policy server.

    Routes:
        ``GET /health``: 200 with status ``"healthy"`` or 503 with ``"unhealthy"``.
        ``GET /``: static service description including the gRPC port.
        anything else: 404.
    """

    def __init__(self, policy_server: Any, *args, **kwargs):
        # The base constructor processes the whole request before returning, so the
        # reference has to be stored first.
        self.policy_server = policy_server
        super().__init__(*args, **kwargs)

    def do_GET(self):  # noqa: N802
        """Dispatch GET requests by path, ignoring any query string."""
        # Probes sometimes append a query string (e.g. cache busters); route on the path only.
        route = self.path.split("?", 1)[0]
        if route == "/health":
            self.send_health_response()
        elif route == "/":
            self.send_info_response()
        else:
            self.send_error(404, "Not Found")

    def _send_json(self, status_code: int, payload: dict) -> None:
        """Serialize ``payload`` and write it as a complete JSON response."""
        # Encode before emitting the status line: a serialization failure must not
        # leave a half-written response on the wire.
        body = json.dumps(payload).encode()
        self.send_response(status_code)
        self.send_header("Content-type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def send_health_response(self):
        """Send the health check response.

        Responds 200 when the policy server has a non-None ``_running_event``
        attribute and 503 otherwise. If inspecting the policy server raises,
        responds 500 with the exception text.
        """
        try:
            server = self.policy_server
            is_healthy = getattr(server, "_running_event", None) is not None
            payload = {
                "status": "healthy" if is_healthy else "unhealthy",
                "timestamp": time.time(),
                "server_running": bool(getattr(server, "running", False)),
                "policy_loaded": getattr(server, "policy", None) is not None,
            }
            status_code = 200 if is_healthy else 503
        except Exception as e:
            # Nothing has been written yet, so a clean 500 can still be sent.
            status_code = 500
            payload = {"status": "error", "message": str(e)}

        self._send_json(status_code, payload)

    def send_info_response(self):
        """Send basic server info; ``grpc_port`` is null when the config is unavailable."""
        config = getattr(self.policy_server, "config", None)
        payload = {
            "service": "lerobot-policy-server",
            "version": "1.0.0",
            "endpoints": {"health": "/health", "grpc_port": getattr(config, "port", None)},
        }
        self._send_json(200, payload)

    def log_message(self, format, *args):
        """Route access logs to the policy server's logger (debug level) instead of stderr."""
        logger = getattr(self.policy_server, "logger", None)
        if logger is not None:
            logger.debug("HTTP: %s", format % args)
def start_health_server(self):
    """Start the HTTP health-check server on a daemon thread.

    Binds to ``HEALTH_SERVER_HOST:HEALTH_CHECK_PORT``. Start-up is best effort:
    a failure is logged and leaves ``http_server`` / ``http_thread`` as ``None``
    rather than raising. Calling this while a health server is already
    registered is a no-op.
    """
    if self.http_server is not None:
        # Starting again would overwrite (and leak) the existing listening socket.
        self.logger.warning("Health server already started")
        return

    server = None
    try:
        server = HTTPServer((HEALTH_SERVER_HOST, HEALTH_CHECK_PORT), create_health_handler(self))
        thread = threading.Thread(target=server.serve_forever, daemon=True)
        thread.start()
    except Exception as e:
        # Release the socket if binding succeeded but the thread could not start,
        # and publish nothing so stop_health_server() has no half-started state.
        if server is not None:
            server.server_close()
        self.logger.error(f"Failed to start health server: {e}")
        return

    self.http_server = server
    self.http_thread = thread
    self.logger.info(f"Health server started on port {HEALTH_CHECK_PORT}")


def stop_health_server(self):
    """Stop the HTTP health server; safe to call repeatedly or when never started."""
    server, thread = self.http_server, self.http_thread
    # Clear the references first so a repeated call becomes a no-op.
    self.http_server = None
    self.http_thread = None
    if server is None and thread is None:
        return

    # shutdown() waits for serve_forever() to exit and deadlocks if that loop never
    # ran, so only call it while the serving thread is alive.
    serving = thread is not None and thread.is_alive()
    if server is not None:
        if serving:
            server.shutdown()
        server.server_close()
    if serving:
        thread.join(timeout=5)
    self.logger.info("Health server stopped")
self._running_event.is_set() @@ -372,6 +398,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): """Stop the server""" self._reset_server() self.logger.info("Server stopping...") + self.stop_health_server() @draccus.wrap() @@ -385,6 +412,7 @@ def serve(cfg: PolicyServerConfig): # Create the server instance first policy_server = PolicyServer(cfg) + policy_server.start_health_server() # Setup and start gRPC server server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) @@ -397,6 +425,7 @@ def serve(cfg: PolicyServerConfig): server.wait_for_termination() policy_server.logger.info("Server terminated") + policy_server.stop_health_server() if __name__ == "__main__":