Files
llm-in-text/backend/main.py
ydy0615 8d89c2a0f6 Merge remote changes with local modifications
- Add docx and html2pdf.js export functionality (from remote)
- Update backend with new API endpoints
- Sync local configuration changes
2026-03-10 23:10:11 +08:00

310 lines
9.7 KiB
Python

import asyncio
import base64
import json
import logging
import os
import secrets
import tempfile
import uuid
from typing import Optional

import markitdown
from fastapi import FastAPI, HTTPException, Request, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel

from geoip import get_ip_location_text
from llm import call_ollama, call_vlm_ocr
from prompt import build_completion_prompts, prepare_prompt_context
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
)
logger = logging.getLogger("api")

app = FastAPI()

# In-flight completion inference tasks, keyed by request id. The cancel
# endpoint looks tasks up here, and a new completion reusing an id cancels the
# older task. Access is serialized with ACTIVE_COMPLETIONS_LOCK.
ACTIVE_COMPLETIONS: dict[str, asyncio.Task] = {}
ACTIVE_COMPLETIONS_LOCK = asyncio.Lock()

# Any origin may call the API; endpoints authenticate with the X-API-Key header.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider restricting allow_origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*", "X-API-Key", "X-Client-IP", "X-Request-Id"],
)

# Shared secret expected in the X-API-Key header. It is read from the
# environment so the real key does not have to live in source control. The
# placeholder default preserves the previous behaviour when the variable is
# unset.
API_KEY = os.environ.get("API_KEY", "your-secret-key-here")
api_key_header = APIKeyHeader(name="X-API-Key")
async def get_api_key(api_key: str = Security(api_key_header)):
    """FastAPI security dependency that validates the ``X-API-Key`` header.

    Args:
        api_key: Header value extracted by ``api_key_header``.

    Returns:
        The validated API key.

    Raises:
        HTTPException: 403 when the key does not match ``API_KEY``.
    """
    # Use a constant-time comparison so response timing does not reveal how
    # much of the key matched. Compare bytes: compare_digest rejects str
    # arguments that contain non-ASCII characters.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(
            status_code=403,
            detail="Could not validate credentials",
        )
    return api_key
class UserPreferences(BaseModel):
    # Optional per-user hints carried in CompletionRequest.user_preferences and
    # forwarded to build_completion_prompts(preferences=...).
    # NOTE(review): "auto" presumably means "let the prompt builder infer it";
    # confirm in prompt.py.
    language: str = "auto"
    currency: str = "auto"
    timezone: str = "auto"
class CompletionRequest(BaseModel):
    # JSON body of POST /v1/completions.
    # Text before / after the completion point (presumably the cursor); both
    # are passed to prepare_prompt_context and build_completion_prompts.
    prefix: str
    suffix: str
    # Passed to build_completion_prompts as the document language.
    languageId: str = "markdown"
    # Forwarded to call_ollama(thinking=...); the value "none" is sent as None.
    model_thinking: str = "low"
    # When True, the handler skips client-IP lookup and geolocation.
    privacy_mode: bool = False
    user_preferences: Optional[UserPreferences] = None
class CancelCompletionRequest(BaseModel):
    # JSON body of POST /v1/completions/cancel.
    # Matched against the request id (X-Request-Id header) under which a
    # /v1/completions call registered its inference task.
    request_id: str
    # Free-form reason; only written to the log.
    reason: str = "abort"
class OCRRequest(BaseModel):
    # JSON body of POST /v1/ocr.
    # Base64-encoded image bytes.
    image: str
    # Echoed back in the response and logged; not used to read any file.
    filename: str = "image.jpg"
    # Passed to call_vlm_ocr as the OCR language hint.
    language: str = "auto"
class ConvertRequest(BaseModel):
    # JSON body of POST /v1/convert.
    # Base64-encoded document bytes.
    file: str
    # Its extension becomes the temp-file suffix handed to MarkItDown
    # (presumably so MarkItDown can pick a converter by file type); also
    # echoed back in the response.
    filename: str = "document.pdf"
def _preview(text: str, limit: int = 80) -> str:
value = (text or "").replace("\n", "\\n")
if len(value) <= limit:
return value
return value[:limit] + "..."
def _sse_payload(payload: dict) -> str:
return f"data: {json.dumps(payload)}\n\n"
def get_client_ip(request: Request) -> str:
    """Return a best-effort client address for logging and geolocation.

    Uses the ``X-Client-IP`` header when it is present and non-empty.
    Otherwise uses the socket peer host, or ``"unknown"`` when there is no
    peer information.

    NOTE(review): ``X-Client-IP`` is taken verbatim from the caller and can be
    spoofed; do not rely on it for access control.
    """
    header_ip = request.headers.get("X-Client-IP")
    if header_ip:
        return header_ip
    peer = request.client
    return peer.host if peer else "unknown"
@app.post("/v1/completions")
async def create_completion(request: Request, req: CompletionRequest, api_key: str = Security(get_api_key)):
    """Generate a completion for the text around the completion point.

    A single ``call_ollama`` inference is started as a task and registered in
    ``ACTIVE_COMPLETIONS`` under the request id: the ``X-Request-Id`` header
    if present, otherwise a fresh UUID. This lets ``/v1/completions/cancel``
    abort it. A still-running task already registered under the same id is
    cancelled first.

    Responses:
      * success: an SSE stream of ``{"content": ...}`` then ``{"done": true}``;
      * cancelled inference: one SSE frame
        ``{"cancelled": true, "request_id": ..., "done": true}``;
      * any other error: HTTP 500 with ``{"error": message}``.

    When ``req.privacy_mode`` is set, the client IP and location are not
    looked up. The document text and the generated completion are not written
    to the log either; only their lengths are logged.
    """
    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
    request_tag = request_id[:8]
    inference_task: Optional[asyncio.Task] = None
    client_ip = "hidden"
    location = ""
    if not req.privacy_mode:
        client_ip = get_client_ip(request)
        # NOTE(review): synchronous call inside an async handler; if the geoip
        # lookup does network I/O it blocks the event loop. Confirm in geoip.py.
        location = get_ip_location_text(client_ip)
        if location:
            logger.info("[%s] client_location=%s", request_tag, location)
    try:
        logger.info(
            "[%s] /v1/completions request_id=%s client_ip=%s prefix_chars=%d suffix_chars=%d lang=%s thinking=%s privacy=%s",
            request_tag,
            request_id,
            client_ip,
            len(req.prefix or ""),
            len(req.suffix or ""),
            req.languageId,
            req.model_thinking,
            req.privacy_mode,
        )
        if not req.privacy_mode:
            # These lines dump (a form of) the user's document text, so they
            # must not be emitted in privacy mode.
            llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "")
            logger.info("[%s] llm_input_prefix=%r", request_tag, llm_prefix)
            logger.info("[%s] llm_input_suffix=%r", request_tag, llm_suffix)
        system_prompt, user_prompt = build_completion_prompts(
            req.prefix,
            req.suffix,
            req.languageId,
            location=location,
            thinking_level=req.model_thinking,
            preferences=req.user_preferences,
        )
        inference_task = asyncio.create_task(
            call_ollama(
                user_prompt,
                system_prompt=system_prompt,
                tag=f"{request_tag}-primary",
                temperature=0.7,
                # "none" disables thinking entirely.
                thinking=req.model_thinking if req.model_thinking != "none" else None,
            )
        )
        async with ACTIVE_COMPLETIONS_LOCK:
            # A newer request with the same id supersedes the older one.
            existing = ACTIVE_COMPLETIONS.get(request_id)
            if existing and not existing.done():
                existing.cancel()
            ACTIVE_COMPLETIONS[request_id] = inference_task
        result = await inference_task
        content = result["content"] or ""
        if not content.strip():
            logger.warning("[%s] primary returned empty content, returning empty result", request_tag)
        logger.info(
            "[%s] completion resolved source=primary request_id=%s content_chars=%d content_preview='%s'",
            request_tag,
            request_id,
            len(content),
            # Same placeholder as client_ip: never log generated text in privacy mode.
            "hidden" if req.privacy_mode else _preview(content, 120),
        )

        async def generate():
            yield _sse_payload({"content": content})
            yield _sse_payload({"done": True})

        return StreamingResponse(generate(), media_type="text/event-stream")
    except asyncio.CancelledError:
        # Raised by `await inference_task` when the task was cancelled, e.g. via
        # /v1/completions/cancel or by a newer request with the same id.
        logger.info("[%s] /v1/completions cancelled request_id=%s", request_tag, request_id)

        async def cancelled():
            yield _sse_payload({"cancelled": True, "request_id": request_id, "done": True})

        return StreamingResponse(cancelled(), media_type="text/event-stream")
    except Exception as e:
        logger.exception("[%s] /v1/completions failed request_id=%s: %s", request_tag, request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
    finally:
        async with ACTIVE_COMPLETIONS_LOCK:
            # Only unregister our own task; a newer request may have replaced it.
            active = ACTIVE_COMPLETIONS.get(request_id)
            if active is not None and active is inference_task:
                ACTIVE_COMPLETIONS.pop(request_id, None)
@app.post("/v1/completions/cancel")
async def cancel_completion(req: CancelCompletionRequest, api_key: str = Security(get_api_key)):
    """Request cancellation of the completion registered under ``req.request_id``.

    Returns ``{"cancelled": bool, "status": str}``. The status is one of:
      * ``"not_found"``: no task is registered for the id;
      * ``"already_done"``: the registered task has already finished;
      * ``"ok"``: ``cancel()`` was called on the running task.
    """
    request_tag = str(uuid.uuid4())[:8]
    request_id = req.request_id or ""
    async with ACTIVE_COMPLETIONS_LOCK:
        task = ACTIVE_COMPLETIONS.get(request_id)
        if task is None:
            status = "not_found"
            log_format = "[%s] /v1/completions/cancel request_id=%s status=not_found reason=%s"
        elif task.done():
            status = "already_done"
            log_format = "[%s] /v1/completions/cancel request_id=%s status=already_done reason=%s"
        else:
            # The awaiting /v1/completions handler sees CancelledError and
            # replies with a "cancelled" SSE frame.
            task.cancel()
            status = "ok"
            log_format = "[%s] /v1/completions/cancel request_id=%s status=ok reason=%s"
        logger.info(log_format, request_tag, request_id, req.reason)
        return {"cancelled": status == "ok", "status": status}
@app.post("/v1/ocr")
async def ocr_image(request: OCRRequest, api_key: str = Security(get_api_key)):
    """Extract text from a base64-encoded image with ``call_vlm_ocr``.

    Returns ``{"text": ..., "filename": ...}`` on success. A body whose
    ``image`` is not valid base64 gets HTTP 400; any other failure gets
    HTTP 500. Both error responses have the shape ``{"error": message}``.
    """
    request_id = str(uuid.uuid4())[:8]
    try:
        logger.info(
            "[%s] /v1/ocr filename=%s language=%s image_base64_chars=%d",
            request_id,
            request.filename,
            request.language,
            len(request.image or ""),
        )
        try:
            image_bytes = base64.b64decode(request.image)
        except ValueError as e:
            # binascii.Error (bad padding) and non-ASCII input both raise
            # ValueError. Malformed input is a client error, not a server fault.
            logger.warning("[%s] /v1/ocr invalid base64 image: %s", request_id, e)
            return JSONResponse(content={"error": f"Invalid base64 image: {e}"}, status_code=400)
        logger.info("[%s] /v1/ocr decoded image_bytes=%d", request_id, len(image_bytes))
        result = await call_vlm_ocr(image_bytes, request.language)
        logger.info(
            "[%s] /v1/ocr success text_chars=%d text_preview='%s'",
            request_id,
            len(result or ""),
            _preview(result or "", 120),
        )
        return {"text": result, "filename": request.filename}
    except Exception as e:
        logger.exception("[%s] /v1/ocr failed: %s", request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
@app.post("/v1/convert")
async def convert_to_markdown(request: ConvertRequest, api_key: str = Security(get_api_key)):
    """Convert a base64-encoded document to Markdown using MarkItDown.

    The decoded bytes are written to a temporary file that keeps the extension
    of ``request.filename``. The file is converted by path and always deleted
    afterwards.

    Returns ``{"markdown": ..., "filename": ...}`` on success. Invalid base64
    gets HTTP 400; any other failure gets HTTP 500. Both error responses have
    the shape ``{"error": message}``.
    """
    request_id = str(uuid.uuid4())[:8]
    try:
        logger.info(
            "[%s] /v1/convert filename=%s file_base64_chars=%d",
            request_id,
            request.filename,
            len(request.file or ""),
        )
        # Decode the Base64 file content; malformed input is a client error.
        try:
            file_bytes = base64.b64decode(request.file)
        except ValueError as e:
            logger.warning("[%s] /v1/convert invalid base64 file: %s", request_id, e)
            return JSONResponse(content={"error": f"Invalid base64 file: {e}"}, status_code=400)
        logger.info("[%s] /v1/convert decoded file_bytes=%d", request_id, len(file_bytes))
        # Keep the original extension on the temp file (presumably so MarkItDown
        # can choose a converter by file type). splitext only inspects the last
        # path component, so the suffix never contains a path separator.
        ext = os.path.splitext(request.filename)[1].lower()
        tmp_path: Optional[str] = None
        try:
            # delete=False keeps the file after the with-block closes it, so it
            # can be converted by path. Record the path before writing, so a
            # failed write is still cleaned up in the finally below.
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
                tmp_path = tmp.name
                tmp.write(file_bytes)
            md = markitdown.MarkItDown()
            # MarkItDown.convert is synchronous. Run it in a worker thread so a
            # slow conversion does not block the event loop and stall every
            # other request, including in-flight completions.
            result = await asyncio.to_thread(md.convert, tmp_path)
            markdown_text = result.text_content
            logger.info(
                "[%s] /v1/convert success text_chars=%d text_preview='%s'",
                request_id,
                len(markdown_text or ""),
                _preview(markdown_text, 120),
            )
            return {
                "markdown": markdown_text,
                "filename": request.filename,
            }
        finally:
            # Always remove the temporary file.
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except FileNotFoundError:
                    pass
    except Exception as e:
        logger.exception("[%s] /v1/convert failed: %s", request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
if __name__ == "__main__":
    # Development entry point: serve the ASGI app directly with uvicorn on all
    # interfaces, port 8001.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=8001)