import asyncio
import base64
import json
import logging
import os
import secrets
import tempfile
import uuid
from typing import Optional

from fastapi import FastAPI, HTTPException, Request, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel

from geoip import get_ip_location_text
from llm import call_ollama, call_vlm_ocr
from prompt import build_completion_prompts, prepare_prompt_context

import markitdown

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
)
logger = logging.getLogger("api")

app = FastAPI()

# In-flight completion tasks keyed by request id, so that the cancel endpoint
# (and a newer request reusing the same id) can cancel them.
ACTIVE_COMPLETIONS: dict[str, asyncio.Task] = {}
ACTIVE_COMPLETIONS_LOCK = asyncio.Lock()

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed cross-origin calls. Authentication here is header
# based (X-API-Key), so credentials may not be needed -- confirm with the
# front-end before tightening this.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*", "X-API-Key", "X-Client-IP", "X-Request-Id"],
)

# The key is read from the API_KEY environment variable so that the real
# secret does not have to live in source control. The historical placeholder
# remains the fallback, so deployments that never set the variable behave
# exactly as before.
API_KEY = os.environ.get("API_KEY", "your-secret-key-here")
api_key_header = APIKeyHeader(name="X-API-Key")


async def get_api_key(api_key: str = Security(api_key_header)):
    """Validate the ``X-API-Key`` header and return it.

    Args:
        api_key: Header value extracted by ``api_key_header``.

    Returns:
        The supplied key, unchanged, when it matches ``API_KEY``.

    Raises:
        HTTPException: With status 403 when the key does not match.
    """
    # Compare as bytes in constant time: compare_digest rejects non-ASCII
    # ``str`` operands, and a plain ``!=`` leaks timing information.
    supplied = (api_key or "").encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=403,
            detail="Could not validate credentials",
        )
    return api_key


class UserPreferences(BaseModel):
    """Optional per-user preferences; every field defaults to ``"auto"``."""

    language: str = "auto"
    currency: str = "auto"
    timezone: str = "auto"


class CompletionRequest(BaseModel):
    """Request body for ``POST /v1/completions``."""

    # Text before and after the insertion point.
    prefix: str
    suffix: str
    languageId: str = "markdown"
    # The completion handler maps the value "none" to "no thinking".
    model_thinking: str = "low"
    # When true the handler skips client IP and location lookup.
    privacy_mode: bool = False
    user_preferences: Optional[UserPreferences] = None


class CancelCompletionRequest(BaseModel):
    """Request body for ``POST /v1/completions/cancel``."""

    request_id: str
    reason: str = "abort"


class OCRRequest(BaseModel):
    """Request body for ``POST /v1/ocr``; ``image`` is base64 encoded."""

    image: str
    filename: str = "image.jpg"
    language: str = "auto"


class ConvertRequest(BaseModel):
    """Request body for ``POST /v1/convert``; ``file`` is base64 encoded."""

    file: str
    filename: str = "document.pdf"


def _preview(text: str, limit: int = 80) -> str:
    """Return a single-line, length-limited preview of *text* for logging.

    Newlines are replaced by a literal ``\\n``. Values longer than *limit*
    characters are cut to *limit* characters and suffixed with ``...``.
    A falsy *text* (``None`` or ``""``) yields an empty string.
    """
    value = (text or "").replace("\n", "\\n")
    if len(value) <= limit:
        return value
    return value[:limit] + "..."
def _sse_payload(payload: dict) -> str:
    """Serialize *payload* as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def get_client_ip(request: Request) -> str:
    """Return the caller's IP address as a string.

    The ``X-Client-IP`` header wins when present and non-empty; otherwise the
    peer address of the connection is used, or ``"unknown"`` when the request
    carries no client information.
    """
    header_ip = request.headers.get("X-Client-IP")
    if header_ip:
        return header_ip
    return request.client.host if request.client else "unknown"


@app.post("/v1/completions")
async def create_completion(request: Request, req: CompletionRequest, api_key: str = Security(get_api_key)):
    """Run one completion and return it as a short SSE stream.

    The inference call runs as an ``asyncio.Task`` registered in
    ``ACTIVE_COMPLETIONS`` under the request id (``X-Request-Id`` header, or a
    fresh UUID), so it can be cancelled from ``/v1/completions/cancel``.

    Returns:
        A ``text/event-stream`` response carrying either the content followed
        by a ``done`` frame, or a single ``cancelled`` frame; a JSON 500
        response with an ``error`` field on any other failure.
    """
    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
    request_tag = request_id[:8]
    inference_task: Optional[asyncio.Task] = None

    # In privacy mode neither the IP nor the derived location is looked up.
    client_ip = "hidden"
    location = ""
    if not req.privacy_mode:
        client_ip = get_client_ip(request)
        location = get_ip_location_text(client_ip)
    if location:
        logger.info("[%s] client_location=%s", request_tag, location)

    try:
        logger.info(
            "[%s] /v1/completions request_id=%s client_ip=%s prefix_chars=%d suffix_chars=%d lang=%s thinking=%s privacy=%s",
            request_tag,
            request_id,
            client_ip,
            len(req.prefix or ""),
            len(req.suffix or ""),
            req.languageId,
            req.model_thinking,
            req.privacy_mode,
        )

        # NOTE(review): the prepared prefix/suffix are logged in full at INFO
        # level regardless of privacy_mode -- confirm this is intended.
        llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "")
        logger.info("[%s] llm_input_prefix=%r", request_tag, llm_prefix)
        logger.info("[%s] llm_input_suffix=%r", request_tag, llm_suffix)

        system_prompt, user_prompt = build_completion_prompts(
            req.prefix,
            req.suffix,
            req.languageId,
            location=location,
            thinking_level=req.model_thinking,
            preferences=req.user_preferences,
        )

        inference_task = asyncio.create_task(
            call_ollama(
                user_prompt,
                system_prompt=system_prompt,
                tag=f"{request_tag}-primary",
                temperature=0.7,
                thinking=req.model_thinking if req.model_thinking != "none" else None,
            )
        )

        # Register the task; a still-running task under the same id is
        # superseded and cancelled.
        async with ACTIVE_COMPLETIONS_LOCK:
            existing = ACTIVE_COMPLETIONS.get(request_id)
            if existing and not existing.done():
                existing.cancel()
            ACTIVE_COMPLETIONS[request_id] = inference_task

        result = await inference_task
        content = result["content"] or ""
        if not content.strip():
            logger.warning("[%s] primary returned empty content, returning empty result", request_tag)

        logger.info(
            "[%s] completion resolved source=primary request_id=%s content_chars=%d content_preview='%s'",
            request_tag,
            request_id,
            len(content),
            _preview(content, 120),
        )

        async def generate():
            yield _sse_payload({"content": content})
            yield _sse_payload({"done": True})

        return StreamingResponse(generate(), media_type="text/event-stream")
    except asyncio.CancelledError:
        # Raised when the awaited task is cancelled, e.g. by the cancel
        # endpoint or by a newer request that reused this request id.
        logger.info("[%s] /v1/completions cancelled request_id=%s", request_tag, request_id)

        async def cancelled():
            yield _sse_payload({"cancelled": True, "request_id": request_id, "done": True})

        return StreamingResponse(cancelled(), media_type="text/event-stream")
    except Exception as e:
        logger.exception("[%s] /v1/completions failed request_id=%s: %s", request_tag, request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
    finally:
        # Only drop the registry entry if it still points at *our* task; a
        # newer request with the same id may have replaced it already.
        async with ACTIVE_COMPLETIONS_LOCK:
            active = ACTIVE_COMPLETIONS.get(request_id)
            if active is not None and active is inference_task:
                ACTIVE_COMPLETIONS.pop(request_id, None)


@app.post("/v1/completions/cancel")
async def cancel_completion(req: CancelCompletionRequest, api_key: str = Security(get_api_key)):
    """Cancel the in-flight completion registered under ``req.request_id``.

    Returns:
        ``{"cancelled": bool, "status": str}`` where ``status`` is
        ``"not_found"``, ``"already_done"`` or ``"ok"``.
    """
    request_tag = str(uuid.uuid4())[:8]
    request_id = req.request_id or ""

    async with ACTIVE_COMPLETIONS_LOCK:
        task = ACTIVE_COMPLETIONS.get(request_id)

        if task is None:
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=not_found reason=%s",
                request_tag,
                request_id,
                req.reason,
            )
            return {"cancelled": False, "status": "not_found"}

        if task.done():
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=already_done reason=%s",
                request_tag,
                request_id,
                req.reason,
            )
            return {"cancelled": False, "status": "already_done"}

        task.cancel()
        logger.info(
            "[%s] /v1/completions/cancel request_id=%s status=ok reason=%s",
            request_tag,
            request_id,
            req.reason,
        )
        return {"cancelled": True, "status": "ok"}


@app.post("/v1/ocr")
async def ocr_image(request: OCRRequest, api_key: str = Security(get_api_key)):
    """Run OCR on a base64-encoded image.

    Returns:
        ``{"text": ..., "filename": ...}`` on success; a JSON 400 response
        when the image is not valid base64; a JSON 500 response with an
        ``error`` field on any other failure.
    """
    request_id = str(uuid.uuid4())[:8]
    try:
        logger.info(
            "[%s] /v1/ocr filename=%s language=%s image_base64_chars=%d",
            request_id,
            request.filename,
            request.language,
            len(request.image or ""),
        )

        try:
            image_bytes = base64.b64decode(request.image)
        except ValueError as e:
            # binascii.Error is a ValueError subclass. A malformed payload is
            # the caller's mistake, so answer 400 instead of a 500 + traceback.
            logger.warning("[%s] /v1/ocr invalid base64 payload: %s", request_id, e)
            return JSONResponse(
                content={"error": f"Invalid base64 image data: {e}"},
                status_code=400,
            )
        logger.info("[%s] /v1/ocr decoded image_bytes=%d", request_id, len(image_bytes))

        result = await call_vlm_ocr(image_bytes, request.language)
        logger.info(
            "[%s] /v1/ocr success text_chars=%d text_preview='%s'",
            request_id,
            len(result or ""),
            _preview(result or "", 120),
        )
        return {"text": result, "filename": request.filename}
    except Exception as e:
        logger.exception("[%s] /v1/ocr failed: %s", request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)


@app.post("/v1/convert")
async def convert_to_markdown(request: ConvertRequest, api_key: str = Security(get_api_key)):
    """Convert a base64-encoded file to Markdown.

    The decoded bytes are written to a temporary file whose suffix is the
    lower-cased extension of ``request.filename``, converted with MarkItDown,
    and the temporary file is removed afterwards.

    Returns:
        ``{"markdown": ..., "filename": ...}`` on success; a JSON 400 response
        when the file is not valid base64; a JSON 500 response with an
        ``error`` field on any other failure.
    """
    request_id = str(uuid.uuid4())[:8]
    try:
        logger.info(
            "[%s] /v1/convert filename=%s file_base64_chars=%d",
            request_id,
            request.filename,
            len(request.file or ""),
        )

        # Decode the base64 file content.
        try:
            file_bytes = base64.b64decode(request.file)
        except ValueError as e:
            # Malformed client input: report 400 rather than a server error.
            logger.warning("[%s] /v1/convert invalid base64 payload: %s", request_id, e)
            return JSONResponse(
                content={"error": f"Invalid base64 file data: {e}"},
                status_code=400,
            )
        logger.info("[%s] /v1/convert decoded file_bytes=%d", request_id, len(file_bytes))

        # Keep the file extension on the temporary file's name.
        ext = os.path.splitext(request.filename)[1].lower()

        tmp_path: Optional[str] = None
        try:
            # Create the temporary file. The path is recorded before writing
            # so the ``finally`` below removes it even if the write fails.
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
                tmp_path = tmp.name
                tmp.write(file_bytes)

            # Convert to Markdown with MarkItDown. ``convert`` is synchronous,
            # so run it in a worker thread to keep the event loop responsive.
            md = markitdown.MarkItDown()
            result = await asyncio.to_thread(md.convert, tmp_path)
            markdown_text = result.text_content

            logger.info(
                "[%s] /v1/convert success text_chars=%d text_preview='%s'",
                request_id,
                len(markdown_text or ""),
                _preview(markdown_text, 120),
            )
            return {
                "markdown": markdown_text,
                "filename": request.filename,
            }
        finally:
            # Clean up the temporary file.
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except FileNotFoundError:
                    pass
    except Exception as e:
        logger.exception("[%s] /v1/convert failed: %s", request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)