Files
llm-in-text/backend/main.py

316 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import base64
import json
import logging
import os
import secrets
import tempfile
import uuid
from typing import Optional

import markitdown
from fastapi import FastAPI, HTTPException, Request, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel

from geoip import get_ip_location_text
from llm import call_ollama, call_vlm_ocr
from prompt import build_completion_prompts, prepare_prompt_context
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("api")

app = FastAPI()

# Completion tasks that are still in flight, keyed by request id.
# /v1/completions/cancel looks tasks up here, and a new completion that
# reuses an id cancels the task previously registered under it.
ACTIVE_COMPLETIONS: dict[str, asyncio.Task] = {}
# Serialises every read/write of ACTIVE_COMPLETIONS across concurrent handlers.
ACTIVE_COMPLETIONS_LOCK = asyncio.Lock()
# Browser access: every origin, method and header is allowed; requests are
# gated by the X-API-Key header instead.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*", "X-API-Key", "X-Client-IP", "X-Request-Id"],
)

# Shared secret expected in the X-API-Key header. It is read from the
# environment so the real key never has to be committed to source control;
# the old placeholder is kept only as a backward-compatible fallback.
_DEFAULT_API_KEY = "your-secret-key-here"
API_KEY = os.environ.get("API_KEY") or _DEFAULT_API_KEY
if API_KEY == _DEFAULT_API_KEY:
    logger.warning("API_KEY environment variable is not set; using the built-in placeholder key")
api_key_header = APIKeyHeader(name="X-API-Key")
async def get_api_key(api_key: str = Security(api_key_header)):
    """Dependency that validates the ``X-API-Key`` header.

    Returns:
        The supplied key when it matches ``API_KEY``.

    Raises:
        HTTPException: 403 when the key does not match.
    """
    # Constant-time comparison, so response timing does not reveal how much of
    # the key an attacker has guessed. Compare bytes rather than str, because
    # compare_digest raises TypeError for non-ASCII str input, and header
    # values are attacker-controlled.
    supplied = (api_key or "").encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=403,
            detail="Could not validate credentials",
        )
    return api_key
class UserPreferences(BaseModel):
    """Optional locale hints sent with a completion request.

    The handler forwards them unchanged to ``build_completion_prompts`` as
    ``preferences``.
    """

    # Each field defaults to "auto"; presumably this lets the prompt builder
    # infer the value itself. TODO confirm against prompt.py.
    language: str = "auto"
    currency: str = "auto"
    timezone: str = "auto"
class CompletionRequest(BaseModel):
    """JSON body of ``POST /v1/completions``."""

    # Text before / after the insertion point (presumably the editor cursor).
    # Both are passed to prepare_prompt_context and build_completion_prompts.
    prefix: str
    suffix: str
    # Language identifier forwarded to build_completion_prompts.
    languageId: str = "markdown"
    # Thinking level for the prompt builder and call_ollama; "none" sends thinking=None.
    model_thinking: str = "low"
    # When True, the client IP is not read (logged as "hidden") and no location lookup is made.
    privacy_mode: bool = False
    # Optional locale hints forwarded to build_completion_prompts as `preferences`.
    user_preferences: Optional[UserPreferences] = None
class CancelCompletionRequest(BaseModel):
    """JSON body of ``POST /v1/completions/cancel``."""

    # Id the target completion was registered under (its X-Request-Id header).
    request_id: str
    # Free-form label; it is only written to the log.
    reason: str = "abort"
class OCRRequest(BaseModel):
    """JSON body of ``POST /v1/ocr``."""

    # Base64-encoded image bytes, decoded with base64.b64decode by the handler.
    image: str
    # Only logged and echoed back in the response.
    filename: str = "image.jpg"
    # Passed through to call_vlm_ocr; "auto" presumably means auto-detect. TODO confirm.
    language: str = "auto"
class ConvertRequest(BaseModel):
    """JSON body of ``POST /v1/convert``."""

    # Base64-encoded document bytes.
    file: str
    # Its lower-cased extension becomes the temp file's suffix; also echoed back.
    filename: str = "document.pdf"
def _preview(text: str, limit: int = 80) -> str:
value = (text or "").replace("\n", "\\n")
if len(value) <= limit:
return value
return value[:limit] + "..."
def _sse_payload(payload: dict) -> str:
return f"data: {json.dumps(payload)}\n\n"
def get_client_ip(request: Request) -> str:
    """Best-effort client address for logging and location lookup.

    Prefers a non-empty ``X-Client-IP`` header, then the connection peer's
    host, and finally the string ``"unknown"``.
    """
    # NOTE(review): X-Client-IP is trusted verbatim, so any caller can set it;
    # presumably a trusted proxy in front of this service fills it in. Confirm.
    forwarded_ip = request.headers.get("X-Client-IP")
    if forwarded_ip:
        return forwarded_ip
    peer = request.client
    return peer.host if peer else "unknown"
@app.post("/v1/completions")
async def create_completion(request: Request, req: CompletionRequest, api_key: str = Security(get_api_key)):
    """Produce one LLM completion for the given prefix/suffix and return it as SSE.

    The id from ``X-Request-Id`` (or a fresh UUID) keys the in-flight LLM task in
    ``ACTIVE_COMPLETIONS``. This lets ``/v1/completions/cancel``, or a newer request
    reusing the same id, cancel it.

    Returns:
        A ``text/event-stream`` with a ``content`` event followed by a ``done`` event.
        A cancelled request gets a single ``cancelled`` event instead. Unexpected
        errors produce ``{"error": ...}`` with HTTP 500.
    """
    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
    request_tag = request_id[:8]
    inference_task: Optional[asyncio.Task] = None
    client_ip = "hidden"
    location = ""
    if not req.privacy_mode:
        client_ip = get_client_ip(request)
        # The location only enriches the prompt, so a lookup failure must not
        # fail the whole completion. Continue without it.
        try:
            location = get_ip_location_text(client_ip)
        except Exception:
            logger.warning(
                "[%s] location lookup failed client_ip=%s", request_tag, client_ip, exc_info=True
            )
            location = ""
        if location:
            logger.info("[%s] client_location=%s", request_tag, location)
    try:
        logger.info(
            "[%s] /v1/completions request_id=%s client_ip=%s prefix_chars=%d suffix_chars=%d lang=%s thinking=%s privacy=%s",
            request_tag,
            request_id,
            client_ip,
            len(req.prefix or ""),
            len(req.suffix or ""),
            req.languageId,
            req.model_thinking,
            req.privacy_mode,
        )
        # NOTE(review): llm_prefix/llm_suffix are only logged here, while
        # build_completion_prompts receives the raw req.prefix/req.suffix.
        # Confirm that this is intended.
        llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "")
        if req.privacy_mode:
            # Privacy mode already hides the client IP. Keep the user's text
            # out of the logs as well.
            logger.info("[%s] llm_input redacted (privacy_mode)", request_tag)
        else:
            logger.info("[%s] llm_input_prefix=%r", request_tag, llm_prefix)
            logger.info("[%s] llm_input_suffix=%r", request_tag, llm_suffix)
        system_prompt, user_prompt = build_completion_prompts(
            req.prefix,
            req.suffix,
            req.languageId,
            location=location,
            thinking_level=req.model_thinking,
            preferences=req.user_preferences,
        )
        inference_task = asyncio.create_task(
            call_ollama(
                user_prompt,
                system_prompt=system_prompt,
                tag=f"{request_tag}-primary",
                temperature=0.7,
                thinking=req.model_thinking if req.model_thinking != "none" else None,
            )
        )
        async with ACTIVE_COMPLETIONS_LOCK:
            # A newer request reusing this id supersedes the older in-flight one.
            existing = ACTIVE_COMPLETIONS.get(request_id)
            if existing and not existing.done():
                existing.cancel()
            ACTIVE_COMPLETIONS[request_id] = inference_task
        result = await inference_task
        content = result["content"] or ""
        if not content.strip():
            logger.warning("[%s] primary returned empty content, returning empty result", request_tag)
        logger.info(
            "[%s] completion resolved source=primary request_id=%s content_chars=%d content_preview='%s'",
            request_tag,
            request_id,
            len(content),
            "<redacted>" if req.privacy_mode else _preview(content, 120),
        )

        async def generate():
            yield _sse_payload({"content": content})
            yield _sse_payload({"done": True})

        return StreamingResponse(generate(), media_type="text/event-stream")
    except asyncio.CancelledError:
        # Reached when the LLM task is cancelled via /v1/completions/cancel or
        # superseded by a newer request with the same id.
        logger.info("[%s] /v1/completions cancelled request_id=%s", request_tag, request_id)

        async def cancelled():
            yield _sse_payload({"cancelled": True, "request_id": request_id, "done": True})

        return StreamingResponse(cancelled(), media_type="text/event-stream")
    except Exception as e:
        logger.exception("[%s] /v1/completions failed request_id=%s: %s", request_tag, request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
    finally:
        # If we leave before the LLM task settles (e.g. this handler was
        # cancelled while waiting for the registry lock), stop the task rather
        # than let an orphaned model call keep running.
        if inference_task is not None and not inference_task.done():
            inference_task.cancel()
        async with ACTIVE_COMPLETIONS_LOCK:
            # Only unregister our own task; a newer request may have replaced it.
            active = ACTIVE_COMPLETIONS.get(request_id)
            if active is not None and active is inference_task:
                ACTIVE_COMPLETIONS.pop(request_id, None)
@app.post("/v1/completions/cancel")
async def cancel_completion(req: CancelCompletionRequest, api_key: str = Security(get_api_key)):
    """Cancel the in-flight completion registered under ``req.request_id``.

    Returns ``{"cancelled": bool, "status": ...}``. The status is one of:
    ``"ok"`` (the task was cancelled), ``"not_found"`` (nothing is registered
    under that id) or ``"already_done"`` (the task had already finished).
    """
    log_tag = str(uuid.uuid4())[:8]
    target_id = req.request_id or ""
    async with ACTIVE_COMPLETIONS_LOCK:
        pending = ACTIVE_COMPLETIONS.get(target_id)
        if pending is not None and not pending.done():
            pending.cancel()
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=ok reason=%s",
                log_tag,
                target_id,
                req.reason,
            )
            return {"cancelled": True, "status": "ok"}
        if pending is None:
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=not_found reason=%s",
                log_tag,
                target_id,
                req.reason,
            )
            return {"cancelled": False, "status": "not_found"}
        logger.info(
            "[%s] /v1/completions/cancel request_id=%s status=already_done reason=%s",
            log_tag,
            target_id,
            req.reason,
        )
        return {"cancelled": False, "status": "already_done"}
@app.post("/v1/ocr")
async def ocr_image(request: OCRRequest, api_key: str = Security(get_api_key)):
    """Run OCR on a base64-encoded image using ``call_vlm_ocr``.

    Returns ``{"text": ..., "filename": ...}`` on success. Any failure,
    including invalid base64, returns ``{"error": ...}`` with HTTP 500.
    """
    tag = str(uuid.uuid4())[:8]
    try:
        encoded = request.image
        logger.info(
            "[%s] /v1/ocr filename=%s language=%s image_base64_chars=%d",
            tag,
            request.filename,
            request.language,
            len(encoded or ""),
        )
        raw_image = base64.b64decode(encoded)
        logger.info("[%s] /v1/ocr decoded image_bytes=%d", tag, len(raw_image))
        recognized = await call_vlm_ocr(raw_image, request.language)
        loggable = recognized or ""
        logger.info(
            "[%s] /v1/ocr success text_chars=%d text_preview='%s'",
            tag,
            len(loggable),
            _preview(loggable, 120),
        )
        return {"text": recognized, "filename": request.filename}
    except Exception as exc:
        logger.exception("[%s] /v1/ocr failed: %s", tag, exc)
        return JSONResponse(content={"error": str(exc)}, status_code=500)
@app.post("/v1/convert")
async def convert_to_markdown(request: ConvertRequest, api_key: str = Security(get_api_key)):
    """Convert a base64-encoded document to Markdown with MarkItDown.

    The bytes are written to a temporary file that keeps the original
    filename's extension (presumably so MarkItDown can choose a converter from
    it; confirm). The file is converted and then always deleted.

    Returns:
        ``{"markdown": ..., "filename": ...}`` on success, or ``{"error": ...}``
        with HTTP 500 on any failure.
    """
    request_id = str(uuid.uuid4())[:8]
    tmp_path: Optional[str] = None
    try:
        logger.info(
            "[%s] /v1/convert filename=%s file_base64_chars=%d",
            request_id,
            request.filename,
            len(request.file or ""),
        )
        # Decode the base64 file content.
        file_bytes = base64.b64decode(request.file)
        logger.info("[%s] /v1/convert decoded file_bytes=%d", request_id, len(file_bytes))
        # Get the file extension so the temp file carries the same suffix.
        ext = os.path.splitext(request.filename)[1].lower()
        # Create the temporary file. Use delete=False so the closed file can be
        # reopened by path. Record the path before writing, so the file is still
        # cleaned up if the write fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
            tmp_path = tmp.name
            tmp.write(file_bytes)
        # Use MarkItDown to convert to Markdown. convert() is synchronous
        # file/CPU work, so run it in a worker thread. Otherwise it would stall
        # the event loop, and with it every other request (including cancels).
        md = markitdown.MarkItDown()
        result = await asyncio.to_thread(md.convert, tmp_path)
        markdown_text = result.text_content
        logger.info(
            "[%s] /v1/convert success text_chars=%d text_preview='%s'",
            request_id,
            len(markdown_text or ""),
            _preview(markdown_text, 120),
        )
        return {
            "markdown": markdown_text,
            "filename": request.filename,
        }
    except Exception as e:
        logger.exception("[%s] /v1/convert failed: %s", request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
    finally:
        # Clean up the temporary file. A cleanup failure is only logged, so it
        # does not discard a conversion that already succeeded.
        if tmp_path is not None and os.path.exists(tmp_path):
            try:
                os.unlink(tmp_path)
            except OSError:
                logger.warning(
                    "[%s] /v1/convert could not remove temp file %s", request_id, tmp_path, exc_info=True
                )
# TTS and STT routes. They must be registered before the script entry point:
# uvicorn.run() blocks until the server stops, so anything placed after it
# would never be mounted when this file is executed directly.
from tts_asr import register_tts_asr_routes

register_tts_asr_routes(app)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)