Files
llm-in-text/backend/main.py

164 lines
5.2 KiB
Python

import base64
import json
import logging
import os
import secrets
import uuid

from fastapi import FastAPI, Request, HTTPException, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel

from geoip import get_ip_location_text
from llm import call_ollama, call_vlm_ocr
from prompt import build_prompt, prepare_prompt_context
# Root logging setup: INFO and above, rendered as
# "<timestamp> <level> <logger name> - <message>".
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
    level=logging.INFO,
)

# Logger used by the request handlers in this module.
logger = logging.getLogger("api")
# Application object; the route decorators in this module register on it.
app = FastAPI()

# CORS policy: every origin, method and header is accepted, with credentials
# enabled. "X-API-Key" and "X-Client-IP" are listed explicitly next to the
# wildcard.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins with allow_credentials=True — confirm this is intended
# before exposing the service beyond trusted clients.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*", "X-API-Key", "X-Client-IP"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Shared secret that clients must present to be authorised.
# It is read from the LLM_IN_TEXT_API_KEY environment variable so that the
# real key never has to live in source control. The historical placeholder is
# kept as the fallback (also used when the variable is set but empty) so that
# existing setups keep working unchanged.
# NOTE(review): the fallback is a well-known placeholder — always set the
# environment variable in a real deployment.
API_KEY = os.environ.get("LLM_IN_TEXT_API_KEY") or "your-secret-key-here"
# Security scheme that extracts the caller's key from the "X-API-Key" header.
api_key_header = APIKeyHeader(name="X-API-Key")


async def get_api_key(api_key: str = Security(api_key_header)):
    """Validate the API key supplied in the ``X-API-Key`` header.

    Args:
        api_key: Header value resolved by the ``api_key_header`` scheme.

    Returns:
        The supplied key, unchanged, when it matches ``API_KEY``.

    Raises:
        HTTPException: 403 when the supplied key does not match.
    """
    # Compare in constant time so that response timing does not reveal how
    # much of the key was correct. Both sides are encoded to bytes because
    # compare_digest() rejects non-ASCII str arguments.
    supplied = (api_key or "").encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=403,
            detail="Could not validate credentials",
        )
    return api_key
from typing import Optional
# User preference settings that may accompany a completion request.
# Every field falls back to "auto" when the client omits it.
class UserPreferences(BaseModel):
    language: str = "auto"
    currency: str = "auto"
    timezone: str = "auto"
# Request body accepted by POST /v1/completions.
# NOTE(review): if this runs on Pydantic v2, a field name starting with
# "model_" may trigger a protected-namespace warning — confirm.
class CompletionRequest(BaseModel):
    # Text handed to the prompt builder; presumably the content before and
    # after the insertion point — confirm against the prompt module.
    prefix: str
    suffix: str
    # Language identifier forwarded to the prompt builder.
    languageId: str = "markdown"
    # Thinking level; the handler sends None to the model when this is "none".
    model_thinking: str = "low"
    # When true, the handler skips the client-IP and location lookup.
    privacy_mode: bool = False
    # Optional preferences forwarded to the prompt builder.
    user_preferences: Optional[UserPreferences] = None
# Request body accepted by POST /v1/ocr.
class OCRRequest(BaseModel):
    # Base64-encoded image data (the handler base64-decodes this value).
    image: str
    # Logged and echoed back in the response.
    filename: str = "image.jpg"
    # Language value forwarded to the OCR model call.
    language: str = "auto"
def _preview(text: str, limit: int = 80) -> str:
    """Return a one-line, length-capped rendering of *text* for log output.

    Newline characters are replaced by the two-character sequence ``\\n`` and
    a falsy *text* is treated as the empty string. If the escaped text is
    longer than *limit* characters, it is cut to *limit* characters and
    ``...`` is appended.
    """
    escaped = "\\n".join((text or "").split("\n"))
    return escaped if len(escaped) <= limit else f"{escaped[:limit]}..."
def get_client_ip(request: Request) -> str:
    """Return the best available client IP for *request*.

    Resolution order:
      1. the ``X-Client-IP`` request header, when present and non-empty;
      2. the peer address in ``request.client``, when available;
      3. the literal string ``"unknown"``.

    The previous one-line form parsed as
    ``(header or client.host) if request.client else "unknown"``, so the
    header was ignored whenever ``request.client`` was ``None``.

    NOTE(review): ``X-Client-IP`` is supplied by the caller and is used here
    without validation — confirm that it is set by a trusted proxy.
    """
    header_ip = request.headers.get("X-Client-IP")
    if header_ip:
        return header_ip
    if request.client:
        return request.client.host
    return "unknown"
@app.post("/v1/completions")
async def create_completion(request: Request, req: CompletionRequest, api_key: str = Security(get_api_key)):
    """Handle POST /v1/completions and reply with a two-event SSE stream.

    Args:
        request: Raw incoming request; used only to resolve the client IP.
        req: Validated completion payload.
        api_key: Result of the ``get_api_key`` security dependency; not used
            in the body.

    Returns:
        A ``text/event-stream`` ``StreamingResponse`` that emits one
        ``content`` event followed by one ``done`` event. If anything inside
        the main ``try`` raises, a ``JSONResponse`` with status 500 and a body
        of the form ``{"error": "<message>"}`` is returned instead.
    """
    # Short correlation id: the first 8 hex digits of a random UUID.
    request_id = uuid.uuid4().hex[:8]

    if req.privacy_mode:
        # Privacy mode: no IP resolution and no location lookup.
        client_ip, location = "hidden", ""
    else:
        client_ip = get_client_ip(request)
        # Look up the location text for the client IP.
        location = get_ip_location_text(client_ip)
        if location:
            logger.info("[%s] client_location=%s", request_id, location)

    try:
        logger.info(
            "[%s] /v1/completions client_ip=%s prefix_chars=%d suffix_chars=%d lang=%s thinking=%s privacy=%s",
            request_id,
            client_ip,
            len(req.prefix or ""),
            len(req.suffix or ""),
            req.languageId,
            req.model_thinking,
            req.privacy_mode,
        )

        # These prepared values are only logged; build_prompt() below is
        # called with the raw prefix and suffix.
        llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "")
        logger.info("[%s] llm_input_prefix=%r", request_id, llm_prefix)
        logger.info("[%s] llm_input_suffix=%r", request_id, llm_suffix)

        prompt = build_prompt(
            req.prefix,
            req.suffix,
            req.languageId,
            location=location,
            thinking_level=req.model_thinking,
            preferences=req.user_preferences,
        )

        # The string "none" is passed to the model call as None.
        thinking = None if req.model_thinking == "none" else req.model_thinking
        result = await call_ollama(
            prompt,
            tag=f"{request_id}-primary",
            temperature=0.7,
            thinking=thinking,
        )

        content = result["content"] or ""
        if not content.strip():
            logger.warning("[%s] primary returned empty content, returning empty result", request_id)

        logger.info(
            "[%s] completion resolved source=primary content_chars=%d content_preview='%s'",
            request_id,
            len(content),
            _preview(content, 120),
        )

        async def event_stream():
            # The whole completion is sent as one event, then a "done" marker.
            for event in ({"content": content}, {"done": True}):
                yield f"data: {json.dumps(event)}\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")
    except Exception as exc:
        logger.exception("[%s] /v1/completions failed: %s", request_id, exc)
        return JSONResponse(content={"error": str(exc)}, status_code=500)
def _decode_image_payload(image: str) -> bytes:
    """Base64-decode *image*, first dropping a ``data:...;base64,`` prefix if present."""
    payload = image or ""
    if payload.startswith("data:"):
        # Without this, the lenient decoder would skip the ":" ";" ","
        # characters and silently decode the URL header as image bytes.
        payload = payload.partition(",")[2]
    return base64.b64decode(payload)


@app.post("/v1/ocr")
async def ocr_image(request: OCRRequest, api_key: str = Security(get_api_key)):
    """Handle POST /v1/ocr: decode the uploaded image and run OCR on it.

    Args:
        request: Validated OCR payload; ``image`` holds base64 data, optionally
            wrapped in a ``data:`` URL.
        api_key: Result of the ``get_api_key`` security dependency; not used
            in the body.

    Returns:
        ``{"text": ..., "filename": ...}`` on success. A ``JSONResponse`` with
        status 400 when the image cannot be base64-decoded or decodes to zero
        bytes, and with status 500 when the OCR call itself fails. Error
        bodies have the form ``{"error": "<message>"}``.
    """
    request_id = str(uuid.uuid4())[:8]
    logger.info(
        "[%s] /v1/ocr filename=%s language=%s image_base64_chars=%d",
        request_id,
        request.filename,
        request.language,
        len(request.image or ""),
    )

    # Bad input is the client's fault, so report it as 400 rather than 500.
    try:
        image_bytes = _decode_image_payload(request.image)
    except ValueError as e:  # binascii.Error is a subclass of ValueError
        logger.warning("[%s] /v1/ocr invalid base64 image: %s", request_id, e)
        return JSONResponse(content={"error": "Invalid base64 image data"}, status_code=400)
    if not image_bytes:
        logger.warning("[%s] /v1/ocr empty image payload", request_id)
        return JSONResponse(content={"error": "Empty image data"}, status_code=400)

    try:
        logger.info("[%s] /v1/ocr decoded image_bytes=%d", request_id, len(image_bytes))
        result = await call_vlm_ocr(image_bytes, request.language)
        logger.info(
            "[%s] /v1/ocr success text_chars=%d text_preview='%s'",
            request_id,
            len(result or ""),
            _preview(result or "", 120),
        )
        return {"text": result, "filename": request.filename}
    except Exception as e:
        logger.exception("[%s] /v1/ocr failed: %s", request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on every interface, port 8001.
    # uvicorn is imported here so it is only needed when run as a script.
    import uvicorn

    uvicorn.run(app, port=8001, host="0.0.0.0")