Files
llm-in-text/backend/main.py
ydy0615 637456ee34 feat(api): add completion request cancellation and mermaid rendering
Add support for cancelling in-progress LLM completion requests via new /v1/completions/cancel endpoint with task tracking. Implement mermaid diagram rendering in the Milkdown editor with a new mermaidPlugin. Update copilotPlugin to properly abort requests with descriptive reasons. Refactor settings panel to handle system theme changes reactively. Add camera capture support for image uploads.
2026-02-25 19:00:17 +08:00

250 lines
7.8 KiB
Python

import asyncio
import base64
import json
import logging
import os
import secrets
import uuid
from typing import Optional

from fastapi import FastAPI, HTTPException, Request, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel

from geoip import get_ip_location_text
from llm import call_ollama, call_vlm_ocr
from prompt import build_completion_prompts, prepare_prompt_context
# Root logging setup: "<time> <LEVEL> <logger> - <message>" at INFO and above.
logging.basicConfig(format="%(asctime)s %(levelname)s %(name)s - %(message)s", level=logging.INFO)
logger = logging.getLogger("api")

app = FastAPI()

# Completion inference tasks that are still in flight, keyed by request id,
# so that /v1/completions/cancel can look them up and cancel them.
ACTIVE_COMPLETIONS: dict[str, asyncio.Task] = dict()
# Serializes every read/modify of ACTIVE_COMPLETIONS.
# NOTE(review): creating an asyncio.Lock at import time is fine on Python
# 3.10+; on 3.9 the lock binds to the loop current at import — confirm runtime.
ACTIVE_COMPLETIONS_LOCK: asyncio.Lock = asyncio.Lock()
# CORS: any origin, method and header may call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; auth here uses the X-API-Key header rather than cookies, so
# credentials may not be needed — confirm. The explicit header names are
# already covered by "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*", "X-API-Key", "X-Client-IP", "X-Request-Id"],
)

# Shared secret clients must send in the X-API-Key header. It is read from the
# environment so the real key does not have to live in source control; the old
# placeholder remains the fallback (also when the variable is set but empty)
# so existing deployments keep working unchanged.
_DEFAULT_API_KEY = "your-secret-key-here"
API_KEY = os.environ.get("LLM_IN_TEXT_API_KEY") or _DEFAULT_API_KEY
if API_KEY == _DEFAULT_API_KEY:
    logger.warning("LLM_IN_TEXT_API_KEY is not set; using the built-in placeholder API key")

# Extracts the X-API-Key header for the get_api_key dependency.
api_key_header = APIKeyHeader(name="X-API-Key")
async def get_api_key(api_key: str = Security(api_key_header)):
    """Validate the ``X-API-Key`` header against the configured ``API_KEY``.

    Used as the security dependency of every endpoint in this module.

    Args:
        api_key: Header value extracted by ``api_key_header``.

    Returns:
        The validated key, unchanged.

    Raises:
        HTTPException: 403 when the key does not match.
    """
    # Constant-time comparison so response timing does not reveal how much of
    # the key a caller guessed correctly. Bytes are compared because
    # compare_digest raises TypeError for non-ASCII str arguments, which would
    # surface as a 500 instead of a 403.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(
            status_code=403,
            detail="Could not validate credentials",
        )
    return api_key
class UserPreferences(BaseModel):
    """Optional client preferences, forwarded to ``build_completion_prompts``."""

    # Each field defaults to "auto" — presumably meaning "let the prompt
    # builder decide/infer"; TODO confirm against prompt.py.
    language: str = "auto"
    currency: str = "auto"
    timezone: str = "auto"
class CompletionRequest(BaseModel):
    """Request body for ``POST /v1/completions``."""

    # Text before and after the completion point (presumably the editor
    # cursor — confirm with the client).
    prefix: str
    suffix: str
    # Language of the document; passed to build_completion_prompts.
    languageId: str = "markdown"
    # Thinking level passed to build_completion_prompts and to call_ollama;
    # the value "none" is sent to call_ollama as None.
    model_thinking: str = "low"
    # When true the client IP is neither resolved nor logged ("hidden") and
    # no location text is added to the prompt.
    privacy_mode: bool = False
    # Optional preferences forwarded to build_completion_prompts.
    user_preferences: Optional[UserPreferences] = None
class CancelCompletionRequest(BaseModel):
    """Request body for ``POST /v1/completions/cancel``."""

    # Request id (the X-Request-Id of the /v1/completions call) to cancel.
    request_id: str
    # Free-form reason; only written to the log.
    reason: str = "abort"
class OCRRequest(BaseModel):
    """Request body for ``POST /v1/ocr``."""

    # Base64-encoded image data, decoded by the /v1/ocr handler.
    image: str
    # Client-side file name; logged and echoed back in the response.
    filename: str = "image.jpg"
    # Passed through to call_vlm_ocr; "auto" presumably requests language
    # auto-detection — confirm in llm.py.
    language: str = "auto"
def _preview(text: str, limit: int = 80) -> str:
value = (text or "").replace("\n", "\\n")
if len(value) <= limit:
return value
return value[:limit] + "..."
def _sse_payload(payload: dict) -> str:
return f"data: {json.dumps(payload)}\n\n"
def get_client_ip(request: Request) -> str:
    """Return the caller's IP for logging and geolocation.

    A non-empty ``X-Client-IP`` header wins. Otherwise the socket peer
    address is used, and ``"unknown"`` is returned when there is none.

    NOTE(review): X-Client-IP is taken from the request without verification;
    presumably it is set by a trusted proxy/frontend — confirm.
    """
    header_ip = request.headers.get("X-Client-IP")
    if header_ip:
        return header_ip
    return request.client.host if request.client else "unknown"
@app.post("/v1/completions")
async def create_completion(request: Request, req: CompletionRequest, api_key: str = Security(get_api_key)):
    """Generate one completion for ``req.prefix``/``req.suffix`` and return it as SSE.

    The model call runs as an ``asyncio.Task`` registered in
    ``ACTIVE_COMPLETIONS`` under the ``X-Request-Id`` header (or a fresh UUID
    when absent), so ``POST /v1/completions/cancel`` can abort it. A new
    request reusing an id whose task is still running cancels the older task.

    Responses:
        * success: ``text/event-stream`` with a ``{"content": ...}`` event
          followed by ``{"done": true}`` (content may be empty);
        * inference task cancelled: one
          ``{"cancelled": true, "request_id": ..., "done": true}`` event;
        * any other error: JSON ``{"error": "<message>"}`` with status 500.
    """
    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
    request_tag = request_id[:8]
    inference_task: Optional[asyncio.Task] = None
    client_ip = "hidden"
    location = ""
    try:
        if not req.privacy_mode:
            # Resolved inside the try so a failing IP/geo lookup produces the
            # endpoint's normal JSON 500 (and is logged) instead of escaping
            # as an unhandled exception.
            client_ip = get_client_ip(request)
            location = get_ip_location_text(client_ip)
            if location:
                logger.info("[%s] client_location=%s", request_tag, location)
        logger.info(
            "[%s] /v1/completions request_id=%s client_ip=%s prefix_chars=%d suffix_chars=%d lang=%s thinking=%s privacy=%s",
            request_tag,
            request_id,
            client_ip,
            len(req.prefix or ""),
            len(req.suffix or ""),
            req.languageId,
            req.model_thinking,
            req.privacy_mode,
        )
        # llm_prefix/llm_suffix are only logged here; build_completion_prompts
        # receives the raw prefix/suffix (presumably it prepares them the same
        # way internally — confirm in prompt.py).
        # NOTE(review): this logs document text at INFO even when
        # privacy_mode is set — confirm that is intended.
        llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "")
        logger.info("[%s] llm_input_prefix=%r", request_tag, llm_prefix)
        logger.info("[%s] llm_input_suffix=%r", request_tag, llm_suffix)
        system_prompt, user_prompt = build_completion_prompts(
            req.prefix,
            req.suffix,
            req.languageId,
            location=location,
            thinking_level=req.model_thinking,
            preferences=req.user_preferences,
        )
        inference_task = asyncio.create_task(
            call_ollama(
                user_prompt,
                system_prompt=system_prompt,
                tag=f"{request_tag}-primary",
                temperature=0.7,
                # "none" is sent to call_ollama as None.
                thinking=req.model_thinking if req.model_thinking != "none" else None,
            )
        )
        async with ACTIVE_COMPLETIONS_LOCK:
            # A request reusing a still-running id supersedes the older run.
            existing = ACTIVE_COMPLETIONS.get(request_id)
            if existing is not None and not existing.done():
                existing.cancel()
            ACTIVE_COMPLETIONS[request_id] = inference_task
        result = await inference_task
        # Assumes call_ollama returns a mapping whose "content" may be None.
        content = result["content"] or ""
        if not content.strip():
            logger.warning("[%s] primary returned empty content, returning empty result", request_tag)
        logger.info(
            "[%s] completion resolved source=primary request_id=%s content_chars=%d content_preview='%s'",
            request_tag,
            request_id,
            len(content),
            _preview(content, 120),
        )

        async def generate():
            yield _sse_payload({"content": content})
            yield _sse_payload({"done": True})

        return StreamingResponse(generate(), media_type="text/event-stream")
    except asyncio.CancelledError:
        # On Python 3.11+, Task.cancelling() > 0 means this handler task itself
        # is being cancelled (e.g. server shutdown), not merely the inference
        # task via the cancel endpoint; that cancellation must propagate
        # rather than be turned into a normal response. Older Pythons lack
        # cancelling() and keep the previous behavior.
        current = asyncio.current_task()
        cancelling = getattr(current, "cancelling", None)
        if cancelling is not None and cancelling():
            raise
        logger.info("[%s] /v1/completions cancelled request_id=%s", request_tag, request_id)

        async def cancelled():
            yield _sse_payload({"cancelled": True, "request_id": request_id, "done": True})

        return StreamingResponse(cancelled(), media_type="text/event-stream")
    except Exception as e:
        logger.exception("[%s] /v1/completions failed request_id=%s: %s", request_tag, request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
    finally:
        # Never leave inference running after the handler exits, e.g. when an
        # error or cancellation hit between create_task and the await above.
        if inference_task is not None and not inference_task.done():
            inference_task.cancel()
        async with ACTIVE_COMPLETIONS_LOCK:
            # Only unregister our own task; a newer request may have replaced it.
            active = ACTIVE_COMPLETIONS.get(request_id)
            if active is not None and active is inference_task:
                ACTIVE_COMPLETIONS.pop(request_id, None)
@app.post("/v1/completions/cancel")
async def cancel_completion(req: CancelCompletionRequest, api_key: str = Security(get_api_key)):
    """Cancel the in-flight completion task registered under ``req.request_id``.

    Returns ``{"cancelled": bool, "status": ...}`` where status is:
        * ``"ok"``: the task was running and has been cancelled;
        * ``"not_found"``: no task is registered for that id (including
          requests that already finished and unregistered themselves);
        * ``"already_done"``: the task is still registered but has completed.
    """
    tag = uuid.uuid4().hex[:8]
    target_id = req.request_id or ""
    async with ACTIVE_COMPLETIONS_LOCK:
        task = ACTIVE_COMPLETIONS.get(target_id)
        running = task is not None and not task.done()
        if running:
            task.cancel()
    if running:
        logger.info(
            "[%s] /v1/completions/cancel request_id=%s status=ok reason=%s",
            tag,
            target_id,
            req.reason,
        )
        return {"cancelled": True, "status": "ok"}
    if task is None:
        logger.info(
            "[%s] /v1/completions/cancel request_id=%s status=not_found reason=%s",
            tag,
            target_id,
            req.reason,
        )
        return {"cancelled": False, "status": "not_found"}
    logger.info(
        "[%s] /v1/completions/cancel request_id=%s status=already_done reason=%s",
        tag,
        target_id,
        req.reason,
    )
    return {"cancelled": False, "status": "already_done"}
def _decode_base64_image(data: str) -> bytes:
    """Decode a base64 image payload, accepting an optional data-URL prefix.

    A ``data:<mime>;base64,<payload>`` string has its header stripped; any
    other input is decoded as-is. Decoding is non-validating (the
    ``b64decode`` default), so characters outside the base64 alphabet, such
    as line breaks, are ignored.

    Raises:
        ValueError: ``binascii.Error`` (a ValueError subclass) on bad padding,
            or ValueError when the string contains non-ASCII characters.
    """
    payload = data or ""
    if payload.startswith("data:"):
        header, sep, body = payload.partition(",")
        if sep and header.endswith(";base64"):
            payload = body
    return base64.b64decode(payload)


@app.post("/v1/ocr")
async def ocr_image(request: OCRRequest, api_key: str = Security(get_api_key)):
    """Extract text from a base64-encoded image via ``call_vlm_ocr``.

    Returns ``{"text": ..., "filename": ...}`` on success. An undecodable or
    empty image yields JSON ``{"error": ...}`` with status 400 (client
    error); any other failure yields the same shape with status 500.
    """
    request_id = str(uuid.uuid4())[:8]
    try:
        logger.info(
            "[%s] /v1/ocr filename=%s language=%s image_base64_chars=%d",
            request_id,
            request.filename,
            request.language,
            len(request.image or ""),
        )
        try:
            image_bytes = _decode_base64_image(request.image)
        except ValueError as e:
            # Malformed client input is a 400, not an internal server error.
            logger.warning("[%s] /v1/ocr invalid base64 image: %s", request_id, e)
            return JSONResponse(content={"error": f"invalid base64 image: {e}"}, status_code=400)
        if not image_bytes:
            logger.warning("[%s] /v1/ocr empty image", request_id)
            return JSONResponse(content={"error": "image is empty"}, status_code=400)
        logger.info("[%s] /v1/ocr decoded image_bytes=%d", request_id, len(image_bytes))
        result = await call_vlm_ocr(image_bytes, request.language)
        logger.info(
            "[%s] /v1/ocr success text_chars=%d text_preview='%s'",
            request_id,
            len(result or ""),
            _preview(result or "", 120),
        )
        return {"text": result, "filename": request.filename}
    except Exception as e:
        logger.exception("[%s] /v1/ocr failed: %s", request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
if __name__ == "__main__":
    # Direct-run entry point: serve the API with uvicorn on all interfaces, port 8001.
    from uvicorn import run as serve

    serve(app, port=8001, host="0.0.0.0")