351 lines
11 KiB
Python
351 lines
11 KiB
Python
import asyncio
|
|
import base64
|
|
import json
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
import uuid
|
|
from typing import Optional
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from llm import call_ollama, call_vlm_ocr
|
|
from prompt import build_completion_prompts, prepare_prompt_context
|
|
import markitdown
|
|
|
|
# Configure the root logger once at import time: INFO level, with the
# timestamp, level and logger name prefixed to every message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
)
# Logger shared by every route handler in this module.
logger = logging.getLogger("api")
|
|
|
|
# The ASGI application; route handlers below register themselves on it.
app = FastAPI()

# CORS policy for browser clients: only the origins listed here are allowed,
# credentials are not, and preflight requests may announce POST/OPTIONS with
# the Content-Type and X-Request-Id headers (X-Request-Id is the header the
# completion endpoint reads its request id from).
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://localhost:3000",
        "https://www.imageteach.tech",
        "https://chat.imageteach.tech",
    ],
    allow_credentials=False,
    allow_methods=["POST", "OPTIONS"],
    allow_headers=["Content-Type", "X-Request-Id"],
)
|
|
|
|
# In-flight completion tasks keyed by request id, so the cancel endpoint can
# look a task up and cancel it.
ACTIVE_COMPLETIONS: dict[str, asyncio.Task] = {}
# Held around every read/write of ACTIVE_COMPLETIONS.
ACTIVE_COMPLETIONS_LOCK = asyncio.Lock()

# Rate limiting
# NOTE(review): neither value is referenced in the visible part of this file;
# confirm where (or whether) these limits are enforced.
MAX_CONCURRENT_COMPLETIONS = 4
COMPLETION_RATE_LIMIT = 60  # per minute

# File size limits (bytes), applied to the decoded upload payloads.
MAX_IMAGE_SIZE = 10 * 1024 * 1024  # 10MB
MAX_CONVERT_SIZE = 50 * 1024 * 1024  # 50MB

# Allowed file extensions (lower-case, including the leading dot); handlers
# lower-case the uploaded filename's extension before checking membership.
ALLOWED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}
ALLOWED_CONVERT_EXTENSIONS = {".pdf", ".docx", ".pptx", ".xlsx", ".md", ".txt"}
|
|
|
|
|
|
class UserPreferences(BaseModel):
    # Optional per-user hints attached to a completion request and forwarded
    # to the prompt builder. Every field defaults to "auto".
    # NOTE(review): the accepted values other than "auto" are not visible in
    # this file; confirm against the prompt module.
    language: str = "auto"
    currency: str = "auto"
    timezone: str = "auto"
|
|
|
|
|
|
class CompletionRequest(BaseModel):
    # Request body of POST /v1/completions.

    # Document text surrounding the completion point (presumably the text
    # before and after the cursor; verify against the client).
    prefix: str
    suffix: str
    # Language identifier handed to the prompt builder.
    languageId: str = "markdown"
    # Thinking level for the model; the literal "none" is translated to
    # "no thinking argument" by the completion handler.
    model_thinking: str = "low"
    # Client-supplied privacy flag; see create_completion for how it is used.
    privacy_mode: bool = False
    # Optional user hints forwarded to the prompt builder.
    user_preferences: Optional[UserPreferences] = None
|
|
|
|
|
|
class CancelCompletionRequest(BaseModel):
    # Request body of POST /v1/completions/cancel.

    # Id of the completion to cancel; looked up in ACTIVE_COMPLETIONS.
    request_id: str
    # Free-form reason; the cancel handler only writes it to the log.
    reason: str = "abort"
|
|
|
|
|
|
class OCRRequest(BaseModel):
    # Request body of POST /v1/ocr.

    # Base64-encoded image bytes.
    image: str
    # Original filename; its extension is validated and it is echoed back in
    # the response.
    filename: str = "image.jpg"
    # Language hint passed through to the OCR backend.
    language: str = "auto"
|
|
|
|
|
|
class ConvertRequest(BaseModel):
    # Request body of POST /v1/convert.

    # Base64-encoded document bytes.
    file: str
    # Original filename; its extension is validated, reused as the temp-file
    # suffix, and the name is echoed back in the response.
    filename: str = "document.pdf"
|
|
|
|
|
|
def _preview(text: str, limit: int = 80) -> str:
|
|
value = (text or "").replace("\n", "\\n")
|
|
if len(value) <= limit:
|
|
return value
|
|
return value[:limit] + "..."
|
|
|
|
|
|
def _error_response(request_id: str, code: str, message: str, status_code: int = 500) -> JSONResponse:
    """Build the JSON error envelope shared by every endpoint.

    The response body is ``{"error": {"code", "message", "request_id"}}`` and
    is sent with HTTP status *status_code* (500 unless overridden).
    """
    error_body = {
        "code": code,
        "message": message,
        "request_id": request_id,
    }
    return JSONResponse(status_code=status_code, content={"error": error_body})
|
|
|
|
|
|
def _sse_payload(payload: dict) -> str:
|
|
return f"data: {json.dumps(payload)}\n\n"
|
|
|
|
|
|
@app.post("/v1/completions")
async def create_completion(request: Request, req: CompletionRequest):
    """Run one completion against the model and return it as an SSE stream.

    The request id comes from the ``X-Request-Id`` header, or a fresh UUID
    when the header is absent. The model call runs as an asyncio task that is
    registered in ``ACTIVE_COMPLETIONS`` under that id so the cancel endpoint
    can reach it; an unfinished task already registered under the same id is
    cancelled first.

    Responses:
        * success: two SSE frames, ``{"content": ...}`` then ``{"done": true}``
        * cancellation: one SSE frame with ``cancelled``, ``request_id`` and
          ``done``
        * any other failure: HTTP 500 with the ``INTERNAL_ERROR`` envelope

    When ``req.privacy_mode`` is set, the prompt context and the completion
    text are kept out of the log; only their sizes are recorded.
    NOTE(review): this assumes privacy_mode means "do not write user text to
    logs" - confirm with the client contract.
    """
    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
    request_tag = request_id[:8]
    inference_task: Optional[asyncio.Task] = None

    try:
        logger.info(
            "[%s] /v1/completions request_id=%s prefix_chars=%d suffix_chars=%d lang=%s thinking=%s privacy=%s",
            request_tag,
            request_id,
            len(req.prefix or ""),
            len(req.suffix or ""),
            req.languageId,
            req.model_thinking,
            req.privacy_mode,
        )

        llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "")
        if req.privacy_mode:
            # Privacy mode: record sizes only, never the user's text.
            logger.info("[%s] llm_input_prefix=<redacted chars=%d>", request_tag, len(llm_prefix or ""))
            logger.info("[%s] llm_input_suffix=<redacted chars=%d>", request_tag, len(llm_suffix or ""))
        else:
            logger.info("[%s] llm_input_prefix=%r", request_tag, llm_prefix)
            logger.info("[%s] llm_input_suffix=%r", request_tag, llm_suffix)

        system_prompt, user_prompt = build_completion_prompts(
            req.prefix,
            req.suffix,
            req.languageId,
            thinking_level=req.model_thinking,
            preferences=req.user_preferences,
        )

        inference_task = asyncio.create_task(
            call_ollama(
                user_prompt,
                system_prompt=system_prompt,
                tag=f"{request_tag}-primary",
                temperature=0.7,
                # "none" means: do not pass a thinking level at all.
                thinking=req.model_thinking if req.model_thinking != "none" else None,
            )
        )

        async with ACTIVE_COMPLETIONS_LOCK:
            # A newer request with the same id supersedes the older one.
            existing = ACTIVE_COMPLETIONS.get(request_id)
            if existing and not existing.done():
                existing.cancel()
            ACTIVE_COMPLETIONS[request_id] = inference_task

        result = await inference_task
        content = result["content"] or ""
        if not content.strip():
            logger.warning("[%s] primary returned empty content, returning empty result", request_tag)
        logger.info(
            "[%s] completion resolved source=primary request_id=%s content_chars=%d content_preview='%s'",
            request_tag,
            request_id,
            len(content),
            "<redacted>" if req.privacy_mode else _preview(content, 120),
        )

        async def generate():
            yield _sse_payload({"content": content})
            yield _sse_payload({"done": True})

        return StreamingResponse(generate(), media_type="text/event-stream")
    except asyncio.CancelledError:
        # Raised by `await inference_task` when the task was cancelled, e.g.
        # by the cancel endpoint or by a newer request reusing this id.
        logger.info("[%s] /v1/completions cancelled request_id=%s", request_tag, request_id)

        async def cancelled():
            yield _sse_payload({"cancelled": True, "request_id": request_id, "done": True})

        return StreamingResponse(cancelled(), media_type="text/event-stream")
    except Exception as e:
        logger.exception("[%s] /v1/completions failed request_id=%s: %s", request_tag, request_id, e)
        return _error_response(request_id, "INTERNAL_ERROR", "Service temporarily unavailable", 500)
    finally:
        async with ACTIVE_COMPLETIONS_LOCK:
            # Only drop the registry entry if it still belongs to this
            # request; a superseding request may have replaced it.
            active = ACTIVE_COMPLETIONS.get(request_id)
            if active is not None and active is inference_task:
                ACTIVE_COMPLETIONS.pop(request_id, None)
|
|
|
|
|
|
@app.post("/v1/completions/cancel")
async def cancel_completion(req: CancelCompletionRequest):
    """Cancel the in-flight completion registered under ``req.request_id``.

    Returns ``{"cancelled": bool, "status": str}`` where status is
    ``not_found`` (no task registered under that id), ``already_done`` (the
    task has finished) or ``ok`` (cancellation was requested on the task).
    """
    log_tag = str(uuid.uuid4())[:8]
    target_id = req.request_id or ""
    log_args = (log_tag, target_id, req.reason)

    async with ACTIVE_COMPLETIONS_LOCK:
        pending = ACTIVE_COMPLETIONS.get(target_id)

        if pending is None:
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=not_found reason=%s",
                *log_args,
            )
            return {"cancelled": False, "status": "not_found"}

        if pending.done():
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=already_done reason=%s",
                *log_args,
            )
            return {"cancelled": False, "status": "already_done"}

        pending.cancel()

    logger.info(
        "[%s] /v1/completions/cancel request_id=%s status=ok reason=%s",
        *log_args,
    )
    return {"cancelled": True, "status": "ok"}
|
|
|
|
|
|
@app.post("/v1/ocr")
async def ocr_image(request: OCRRequest):
    """Extract text from a base64-encoded image via the OCR backend.

    Checks run in this order: encoded-size guard, filename extension, base64
    decoding, empty payload, decoded size. On success the response is
    ``{"text": ..., "filename": ...}``.

    Error envelopes:
        * 413 ``FILE_TOO_LARGE``    - payload larger than ``MAX_IMAGE_SIZE``
        * 415 ``INVALID_FILE_TYPE`` - extension not in ``ALLOWED_IMAGE_EXTENSIONS``
        * 400 ``INVALID_BASE64``    - ``image`` cannot be base64-decoded
        * 400 ``EMPTY_FILE``        - ``image`` decodes to zero bytes
        * 500 ``OCR_FAILED``        - any other failure
    """
    request_id = str(uuid.uuid4())[:8]
    try:
        encoded = request.image or ""
        logger.info(
            "[%s] /v1/ocr filename=%s language=%s image_base64_chars=%d",
            request_id,
            request.filename,
            request.language,
            len(encoded),
        )

        # Cheap guard before decoding. Base64 emits 4 characters per started
        # 3-byte group, so this is the exact encoded length of a payload of
        # MAX_IMAGE_SIZE bytes. (The previous `MAX * 4 // 3` bound was a few
        # characters too small and rejected files of exactly the limit.)
        max_encoded_chars = 4 * ((MAX_IMAGE_SIZE + 2) // 3)
        if len(encoded) > max_encoded_chars:
            return _error_response(request_id, "FILE_TOO_LARGE", "Image exceeds 10MB limit", 413)

        # Check extension
        ext = os.path.splitext(request.filename)[1].lower()
        if ext not in ALLOWED_IMAGE_EXTENSIONS:
            return _error_response(request_id, "INVALID_FILE_TYPE", "Only jpg/png/webp allowed", 415)

        try:
            image_bytes = base64.b64decode(encoded)
        except ValueError:
            # binascii.Error is a ValueError subclass; malformed input is a
            # client error, not a server failure.
            logger.warning("[%s] /v1/ocr rejected: image is not valid base64", request_id)
            return _error_response(request_id, "INVALID_BASE64", "Image is not valid base64", 400)

        if not image_bytes:
            # Nothing to recognise; do not spend a backend call on it.
            return _error_response(request_id, "EMPTY_FILE", "Image is empty", 400)

        # Check actual decoded size
        if len(image_bytes) > MAX_IMAGE_SIZE:
            return _error_response(request_id, "FILE_TOO_LARGE", "Image exceeds 10MB limit", 413)

        logger.info("[%s] /v1/ocr decoded image_bytes=%d", request_id, len(image_bytes))
        result = await call_vlm_ocr(image_bytes, request.language)
        logger.info(
            "[%s] /v1/ocr success text_chars=%d text_preview='%s'",
            request_id,
            len(result or ""),
            _preview(result or "", 120),
        )
        return {"text": result, "filename": request.filename}
    except Exception as e:
        logger.exception("[%s] /v1/ocr failed: %s", request_id, e)
        return _error_response(request_id, "OCR_FAILED", "Failed to process image", 500)
|
|
|
|
|
|
@app.post("/v1/convert")
async def convert_to_markdown(request: ConvertRequest):
    """Convert a base64-encoded document to Markdown with MarkItDown.

    Checks run in this order: encoded-size guard, filename extension, base64
    decoding, decoded size. The decoded bytes are written to a temporary file
    (keeping the original extension as suffix) which is handed to MarkItDown
    and removed afterwards. On success the response is
    ``{"markdown": ..., "filename": ...}``.

    Error envelopes:
        * 413 ``FILE_TOO_LARGE``    - payload larger than ``MAX_CONVERT_SIZE``
        * 415 ``INVALID_FILE_TYPE`` - extension not in ``ALLOWED_CONVERT_EXTENSIONS``
        * 400 ``INVALID_BASE64``    - ``file`` cannot be base64-decoded
        * 500 ``CONVERT_FAILED``    - any other failure
    """
    request_id = str(uuid.uuid4())[:8]

    try:
        encoded = request.file or ""
        logger.info(
            "[%s] /v1/convert filename=%s file_base64_chars=%d",
            request_id,
            request.filename,
            len(encoded),
        )

        # Cheap guard before decoding. Base64 emits 4 characters per started
        # 3-byte group, so this is the exact encoded length of a payload of
        # MAX_CONVERT_SIZE bytes. (The previous `MAX * 4 // 3` bound was a few
        # characters too small and rejected files of exactly the limit.)
        max_encoded_chars = 4 * ((MAX_CONVERT_SIZE + 2) // 3)
        if len(encoded) > max_encoded_chars:
            return _error_response(request_id, "FILE_TOO_LARGE", "File exceeds 50MB limit", 413)

        # Get file extension and validate
        ext = os.path.splitext(request.filename)[1].lower()
        if ext not in ALLOWED_CONVERT_EXTENSIONS:
            return _error_response(request_id, "INVALID_FILE_TYPE", "Only pdf/docx/pptx/xlsx/md/txt allowed", 415)

        # Decode the base64 file content
        try:
            file_bytes = base64.b64decode(encoded)
        except ValueError:
            # binascii.Error is a ValueError subclass; malformed input is a
            # client error, not a server failure.
            logger.warning("[%s] /v1/convert rejected: file is not valid base64", request_id)
            return _error_response(request_id, "INVALID_BASE64", "File is not valid base64", 400)

        # Check actual decoded size
        if len(file_bytes) > MAX_CONVERT_SIZE:
            return _error_response(request_id, "FILE_TOO_LARGE", "File exceeds 50MB limit", 413)

        logger.info("[%s] /v1/convert decoded file_bytes=%d", request_id, len(file_bytes))

        tmp_path = None
        try:
            # Create the temporary file. The path is recorded before writing
            # so the cleanup below also runs when the write itself fails.
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
                tmp_path = tmp.name
                tmp.write(file_bytes)

            # Convert to Markdown with MarkItDown. convert() is synchronous,
            # so it runs in a worker thread to keep the event loop responsive.
            md = markitdown.MarkItDown()
            result = await asyncio.to_thread(md.convert, tmp_path)
            markdown_text = result.text_content

            logger.info(
                "[%s] /v1/convert success text_chars=%d text_preview='%s'",
                request_id,
                len(markdown_text or ""),
                _preview(markdown_text, 120),
            )

            return {
                "markdown": markdown_text,
                "filename": request.filename,
            }
        finally:
            # Clean up the temporary file
            if tmp_path is not None and os.path.exists(tmp_path):
                os.unlink(tmp_path)

    except Exception as e:
        logger.exception("[%s] /v1/convert failed: %s", request_id, e)
        return _error_response(request_id, "CONVERT_FAILED", "Failed to convert file", 500)
|
|
|
|
|
|
@app.get("/health/live")
async def health_live():
    """Liveness probe: always answers ``{"status": "ok"}``."""
    return {"status": "ok"}


@app.get("/health/ready")
async def health_ready():
    """Readiness probe.

    Currently answers ``{"status": "ready"}`` unconditionally. The ``except``
    branch (HTTP 503 ``NOT_READY``) is scaffolding for dependency checks that
    have not been added yet and cannot be reached as the code stands.
    """
    # Check if critical components are available
    try:
        # Could add more checks here (e.g., Ollama connectivity)
        return {"status": "ready"}
    except Exception as e:
        logger.warning("[health/ready] not ready: %s", e)
        return _error_response("health-check", "NOT_READY", "Service not ready", 503)


if __name__ == "__main__":
    # Keep this guard below every route definition: uvicorn.run() blocks, so
    # any route declared after it is never registered when the module is run
    # as a script (the health endpoints used to be lost this way).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)
|