# File: llm-in-text/backend/main.py (351 lines, 11 KiB, Python)

import asyncio
import base64
import json
import logging
import os
import tempfile
import uuid
from typing import Optional
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from llm import call_ollama, call_vlm_ocr
from prompt import build_completion_prompts, prepare_prompt_context
import markitdown
# Configure root logging at import time: INFO and above, with each record
# prefixed by timestamp, level and logger name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
    level=logging.INFO,
)

# Logger shared by every handler in this module.
logger = logging.getLogger("api")
# ASGI application; the route handlers below register themselves on it.
app = FastAPI()

# CORS: only the listed localhost and imageteach.tech origins may call the
# API from a browser, without credentials, and only with the two request
# headers the handlers read or need (content type and the request id).
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://localhost:3000",
        "https://www.imageteach.tech",
        "https://chat.imageteach.tech",
    ],
    allow_methods=["POST", "OPTIONS"],
    allow_headers=["Content-Type", "X-Request-Id"],
    allow_credentials=False,
)
# Registry of in-flight completion tasks keyed by request id. The handlers
# in this module take ACTIVE_COMPLETIONS_LOCK around every read or write.
ACTIVE_COMPLETIONS_LOCK = asyncio.Lock()
ACTIVE_COMPLETIONS: dict[str, asyncio.Task] = {}

# Rate limiting.
# NOTE(review): neither value is referenced in the visible part of this
# file -- confirm whether enforcement lives elsewhere or is still pending.
MAX_CONCURRENT_COMPLETIONS = 4
COMPLETION_RATE_LIMIT = 60  # per minute

# Upload size limits, in bytes of decoded payload.
MAX_IMAGE_SIZE = 10 * 1024**2  # 10MB
MAX_CONVERT_SIZE = 50 * 1024**2  # 50MB

# File extensions (lower-case, with the leading dot) accepted per endpoint.
ALLOWED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}
ALLOWED_CONVERT_EXTENSIONS = {".pdf", ".docx", ".pptx", ".xlsx", ".md", ".txt"}
class UserPreferences(BaseModel):
    """Optional per-user locale hints attached to a completion request.

    Every field defaults to the string "auto".
    """

    language: str = "auto"
    currency: str = "auto"
    timezone: str = "auto"
class CompletionRequest(BaseModel):
    """Request body for POST /v1/completions."""

    # Text on either side of the completion point; both are required.
    prefix: str
    suffix: str
    # Forwarded to the prompt builder as the document language.
    languageId: str = "markdown"
    # Thinking level passed to the prompt builder; the value "none" disables
    # the model's thinking option in the completion handler.
    # NOTE(review): a field name starting with "model_" may trigger Pydantic
    # v2's protected-namespace warning -- confirm against the pinned version.
    model_thinking: str = "low"
    # NOTE(review): only logged by the completion handler in this file;
    # confirm where (or whether) it is enforced.
    privacy_mode: bool = False
    user_preferences: Optional[UserPreferences] = None
class CancelCompletionRequest(BaseModel):
    """Request body for POST /v1/completions/cancel."""

    # Id of the completion to cancel; matched against the X-Request-Id the
    # completion was registered under.
    request_id: str
    # Free-form reason, used only in the cancel handler's log lines.
    reason: str = "abort"
class OCRRequest(BaseModel):
    """Request body for POST /v1/ocr."""

    # Base64-encoded image content.
    image: str
    # Only the extension is inspected (against ALLOWED_IMAGE_EXTENSIONS);
    # the name is echoed back in the response.
    filename: str = "image.jpg"
    # Passed through unchanged to the OCR model call.
    language: str = "auto"
class ConvertRequest(BaseModel):
    """Request body for POST /v1/convert."""

    # Base64-encoded file content.
    file: str
    # The extension selects the temp-file suffix and is validated against
    # ALLOWED_CONVERT_EXTENSIONS; the name is echoed back in the response.
    filename: str = "document.pdf"
def _preview(text: str, limit: int = 80) -> str:
value = (text or "").replace("\n", "\\n")
if len(value) <= limit:
return value
return value[:limit] + "..."
def _error_response(request_id: str, code: str, message: str, status_code: int = 500) -> JSONResponse:
    """Build the JSON error envelope returned by the handlers in this module.

    The response body has the shape
    ``{"error": {"code": ..., "message": ..., "request_id": ...}}``.

    Args:
        request_id: Identifier echoed back so the caller can correlate logs.
        code: Machine-readable error code.
        message: Human-readable description.
        status_code: HTTP status of the response (500 by default).
    """
    error_body = {
        "code": code,
        "message": message,
        "request_id": request_id,
    }
    return JSONResponse(status_code=status_code, content={"error": error_body})
def _sse_payload(payload: dict) -> str:
return f"data: {json.dumps(payload)}\n\n"
@app.post("/v1/completions")
async def create_completion(request: Request, req: CompletionRequest):
    """Run one completion through the model and return it as an SSE response.

    The request id is taken from the ``X-Request-Id`` header, or generated
    when the header is absent. The inference call runs as an asyncio task
    registered in ``ACTIVE_COMPLETIONS`` under that id so that
    ``/v1/completions/cancel`` can cancel it; a still-running task already
    registered under the same id is cancelled and replaced.

    Returns:
        * On success: an event stream with one ``{"content": ...}`` event
          followed by ``{"done": true}``. The full result is sent in a
          single event; it is not streamed token by token.
        * On cancellation: an event stream with a single
          ``{"cancelled": true, "request_id": ..., "done": true}`` event.
        * On any other exception: a 500 JSON error with code
          ``INTERNAL_ERROR``.
    """
    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
    request_tag = request_id[:8]
    inference_task: Optional[asyncio.Task] = None
    try:
        logger.info(
            "[%s] /v1/completions request_id=%s prefix_chars=%d suffix_chars=%d lang=%s thinking=%s privacy=%s",
            request_tag,
            request_id,
            len(req.prefix or ""),
            len(req.suffix or ""),
            req.languageId,
            req.model_thinking,
            req.privacy_mode,
        )
        # NOTE(review): the prepared context is only logged here, while
        # build_completion_prompts receives the raw prefix/suffix -- confirm
        # that the builder applies the same preparation itself.
        llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "")
        # NOTE(review): the full prompt text is logged at INFO even when
        # req.privacy_mode is set -- confirm this is intended.
        logger.info("[%s] llm_input_prefix=%r", request_tag, llm_prefix)
        logger.info("[%s] llm_input_suffix=%r", request_tag, llm_suffix)
        system_prompt, user_prompt = build_completion_prompts(
            req.prefix,
            req.suffix,
            req.languageId,
            thinking_level=req.model_thinking,
            preferences=req.user_preferences,
        )
        inference_task = asyncio.create_task(
            call_ollama(
                user_prompt,
                system_prompt=system_prompt,
                tag=f"{request_tag}-primary",
                temperature=0.7,
                thinking=req.model_thinking if req.model_thinking != "none" else None,
            )
        )
        async with ACTIVE_COMPLETIONS_LOCK:
            # A reused request id supersedes the earlier in-flight request.
            existing = ACTIVE_COMPLETIONS.get(request_id)
            if existing and not existing.done():
                existing.cancel()
            ACTIVE_COMPLETIONS[request_id] = inference_task
        result = await inference_task
        content = result["content"] or ""
        if not content.strip():
            logger.warning("[%s] primary returned empty content, returning empty result", request_tag)
        logger.info(
            "[%s] completion resolved source=primary request_id=%s content_chars=%d content_preview='%s'",
            request_tag,
            request_id,
            len(content),
            _preview(content, 120),
        )

        async def generate():
            yield _sse_payload({"content": content})
            yield _sse_payload({"done": True})

        return StreamingResponse(generate(), media_type="text/event-stream")
    except asyncio.CancelledError:
        # Raised when the awaited task was cancelled (by the cancel endpoint
        # or by a newer request with the same id) or when this handler
        # itself was cancelled.
        logger.info("[%s] /v1/completions cancelled request_id=%s", request_tag, request_id)

        async def cancelled():
            yield _sse_payload({"cancelled": True, "request_id": request_id, "done": True})

        return StreamingResponse(cancelled(), media_type="text/event-stream")
    except Exception as e:
        logger.exception("[%s] /v1/completions failed request_id=%s: %s", request_tag, request_id, e)
        return _error_response(request_id, "INTERNAL_ERROR", "Service temporarily unavailable", 500)
    finally:
        # If we leave before the task finished (e.g. interrupted while
        # waiting for the registry lock, before the task was registered),
        # nothing else holds a reference that could stop it -- cancel it
        # here so the model call does not keep running unobserved.
        if inference_task is not None and not inference_task.done():
            inference_task.cancel()
        async with ACTIVE_COMPLETIONS_LOCK:
            # Only remove our own entry; a newer request may have replaced it.
            active = ACTIVE_COMPLETIONS.get(request_id)
            if active is not None and active is inference_task:
                ACTIVE_COMPLETIONS.pop(request_id, None)
@app.post("/v1/completions/cancel")
async def cancel_completion(req: CancelCompletionRequest):
    """Cancel the in-flight completion registered under ``req.request_id``.

    Returns a dict with a ``cancelled`` flag and a ``status`` of
    ``"not_found"`` (no task registered under that id), ``"already_done"``
    (the task has finished) or ``"ok"`` (cancellation was requested).
    """
    request_tag = str(uuid.uuid4())[:8]
    request_id = req.request_id or ""
    async with ACTIVE_COMPLETIONS_LOCK:
        target = ACTIVE_COMPLETIONS.get(request_id)
        if target is None:
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=not_found reason=%s",
                request_tag,
                request_id,
                req.reason,
            )
            outcome = {"cancelled": False, "status": "not_found"}
        elif target.done():
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=already_done reason=%s",
                request_tag,
                request_id,
                req.reason,
            )
            outcome = {"cancelled": False, "status": "already_done"}
        else:
            target.cancel()
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=ok reason=%s",
                request_tag,
                request_id,
                req.reason,
            )
            outcome = {"cancelled": True, "status": "ok"}
    return outcome
@app.post("/v1/ocr")
async def ocr_image(request: OCRRequest):
    """Extract text from a base64-encoded image via the OCR model.

    Validation happens in this order: encoded length, file extension,
    base64 decoding, decoded size.

    Returns:
        ``{"text": ..., "filename": ...}`` on success, otherwise a JSON
        error envelope:

        * 413 ``FILE_TOO_LARGE`` -- image larger than ``MAX_IMAGE_SIZE``.
        * 415 ``INVALID_FILE_TYPE`` -- extension not in
          ``ALLOWED_IMAGE_EXTENSIONS``.
        * 400 ``INVALID_BASE64`` -- the payload cannot be base64-decoded.
        * 500 ``OCR_FAILED`` -- any other failure.
    """
    request_id = str(uuid.uuid4())[:8]
    try:
        encoded = request.image or ""
        logger.info(
            "[%s] /v1/ocr filename=%s language=%s image_base64_chars=%d",
            request_id,
            request.filename,
            request.language,
            len(encoded),
        )
        # Cheap guard before decoding. n bytes encode to 4 * ceil(n / 3)
        # characters, so round up; the truncating "MAX * 4 // 3" rejected a
        # padded payload of exactly MAX_IMAGE_SIZE bytes.
        max_encoded_chars = 4 * ((MAX_IMAGE_SIZE + 2) // 3)
        if len(encoded) > max_encoded_chars:
            return _error_response(request_id, "FILE_TOO_LARGE", "Image exceeds 10MB limit", 413)
        # Check extension
        ext = os.path.splitext(request.filename)[1].lower()
        if ext not in ALLOWED_IMAGE_EXTENSIONS:
            return _error_response(request_id, "INVALID_FILE_TYPE", "Only jpg/png/webp allowed", 415)
        try:
            image_bytes = base64.b64decode(encoded)
        except ValueError:
            # binascii.Error (bad padding) is a ValueError subclass, and
            # non-ASCII input raises ValueError directly. Both are client
            # errors, not server failures.
            logger.warning("[%s] /v1/ocr rejected: invalid base64 payload", request_id)
            return _error_response(request_id, "INVALID_BASE64", "Image is not valid base64", 400)
        # Authoritative check on the decoded size.
        if len(image_bytes) > MAX_IMAGE_SIZE:
            return _error_response(request_id, "FILE_TOO_LARGE", "Image exceeds 10MB limit", 413)
        logger.info("[%s] /v1/ocr decoded image_bytes=%d", request_id, len(image_bytes))
        result = await call_vlm_ocr(image_bytes, request.language)
        logger.info(
            "[%s] /v1/ocr success text_chars=%d text_preview='%s'",
            request_id,
            len(result or ""),
            _preview(result or "", 120),
        )
        return {"text": result, "filename": request.filename}
    except Exception as e:
        logger.exception("[%s] /v1/ocr failed: %s", request_id, e)
        return _error_response(request_id, "OCR_FAILED", "Failed to process image", 500)
def _convert_file_bytes(file_bytes: bytes, suffix: str) -> Optional[str]:
    """Write *file_bytes* to a temp file, convert it with MarkItDown, return the text.

    Synchronous and blocking; meant to be run off the event loop. The temp
    file is removed whether or not writing or conversion succeeds.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path before writing so a failed write is still
            # cleaned up by the finally clause.
            tmp_path = tmp.name
            tmp.write(file_bytes)
        # The file is closed at this point; MarkItDown reads it by path.
        return markitdown.MarkItDown().convert(tmp_path).text_content
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)


@app.post("/v1/convert")
async def convert_to_markdown(request: ConvertRequest):
    """Convert a base64-encoded document to Markdown.

    Validation happens in this order: encoded length, file extension,
    base64 decoding, decoded size. The conversion itself runs in a worker
    thread so it does not block the event loop.

    Returns:
        ``{"markdown": ..., "filename": ...}`` on success, otherwise a JSON
        error envelope:

        * 413 ``FILE_TOO_LARGE`` -- file larger than ``MAX_CONVERT_SIZE``.
        * 415 ``INVALID_FILE_TYPE`` -- extension not in
          ``ALLOWED_CONVERT_EXTENSIONS``.
        * 400 ``INVALID_BASE64`` -- the payload cannot be base64-decoded.
        * 500 ``CONVERT_FAILED`` -- any other failure.
    """
    request_id = str(uuid.uuid4())[:8]
    try:
        encoded = request.file or ""
        logger.info(
            "[%s] /v1/convert filename=%s file_base64_chars=%d",
            request_id,
            request.filename,
            len(encoded),
        )
        # Cheap guard before decoding. n bytes encode to 4 * ceil(n / 3)
        # characters, so round up; the truncating "MAX * 4 // 3" rejected a
        # padded payload of exactly MAX_CONVERT_SIZE bytes.
        max_encoded_chars = 4 * ((MAX_CONVERT_SIZE + 2) // 3)
        if len(encoded) > max_encoded_chars:
            return _error_response(request_id, "FILE_TOO_LARGE", "File exceeds 50MB limit", 413)
        # Get file extension and validate
        ext = os.path.splitext(request.filename)[1].lower()
        if ext not in ALLOWED_CONVERT_EXTENSIONS:
            return _error_response(request_id, "INVALID_FILE_TYPE", "Only pdf/docx/pptx/xlsx/md/txt allowed", 415)
        # Decode the base64 file content.
        try:
            file_bytes = base64.b64decode(encoded)
        except ValueError:
            # binascii.Error (bad padding) is a ValueError subclass, and
            # non-ASCII input raises ValueError directly. Both are client
            # errors, not server failures.
            logger.warning("[%s] /v1/convert rejected: invalid base64 payload", request_id)
            return _error_response(request_id, "INVALID_BASE64", "File is not valid base64", 400)
        # Authoritative check on the decoded size.
        if len(file_bytes) > MAX_CONVERT_SIZE:
            return _error_response(request_id, "FILE_TOO_LARGE", "File exceeds 50MB limit", 413)
        logger.info("[%s] /v1/convert decoded file_bytes=%d", request_id, len(file_bytes))
        # Writing the temp file and running MarkItDown are blocking calls;
        # run them in a worker thread so other requests keep being served.
        markdown_text = await asyncio.to_thread(_convert_file_bytes, file_bytes, ext)
        logger.info(
            "[%s] /v1/convert success text_chars=%d text_preview='%s'",
            request_id,
            len(markdown_text or ""),
            _preview(markdown_text, 120),
        )
        return {
            "markdown": markdown_text,
            "filename": request.filename,
        }
    except Exception as e:
        logger.exception("[%s] /v1/convert failed: %s", request_id, e)
        return _error_response(request_id, "CONVERT_FAILED", "Failed to convert file", 500)
@app.get("/health/live")
async def health_live():
    """Liveness probe: always answers ``{"status": "ok"}``."""
    return {"status": "ok"}


@app.get("/health/ready")
async def health_ready():
    """Readiness probe.

    Currently always answers ``{"status": "ready"}``. The ``try``/``except``
    is scaffolding for future dependency checks, which should raise to
    produce the 503 ``NOT_READY`` response.
    """
    # Check if critical components are available
    try:
        # Could add more checks here (e.g., Ollama connectivity)
        return {"status": "ready"}
    except Exception as e:
        logger.warning("[health/ready] not ready: %s", e)
        return _error_response("health-check", "NOT_READY", "Service not ready", 503)


# Keep this guard at the very end of the module: uvicorn.run() blocks, so any
# route defined after it is never registered when the file is run as a script.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)