Add support for cancelling in-progress LLM completion requests via new /v1/completions/cancel endpoint with task tracking. Implement mermaid diagram rendering in the Milkdown editor with a new mermaidPlugin. Update copilotPlugin to properly abort requests with descriptive reasons. Refactor settings panel to handle system theme changes reactively. Add camera capture support for image uploads.
250 lines
7.8 KiB
Python
250 lines
7.8 KiB
Python
import asyncio
import base64
import json
import logging
import os
import secrets
import uuid
from typing import Optional

from fastapi import FastAPI, HTTPException, Request, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel

from geoip import get_ip_location_text
from llm import call_ollama, call_vlm_ocr
from prompt import build_completion_prompts, prepare_prompt_context
|
|
|
|
# Root logging configuration: INFO and above, each record prefixed with the
# timestamp, level and logger name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
    level=logging.INFO,
)

# Logger used by every endpoint in this module.
logger = logging.getLogger("api")
|
|
|
|
# The ASGI application; routes are attached by the decorators further down.
app = FastAPI()

# In-flight inference tasks keyed by request id. create_completion registers
# and removes entries; cancel_completion looks them up to cancel them. Every
# access goes through ACTIVE_COMPLETIONS_LOCK.
ACTIVE_COMPLETIONS: dict[str, asyncio.Task] = {}
ACTIVE_COMPLETIONS_LOCK = asyncio.Lock()

# Cross-origin access is fully open: any origin, method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive -- confirm this is intended outside of local development.
# NOTE(review): the explicit header names look redundant next to "*" -- verify
# against the CORSMiddleware documentation before trimming them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*", "X-API-Key", "X-Client-IP", "X-Request-Id"],
)
|
|
|
|
# Shared secret that clients must send in the X-API-Key header.
# It is taken from the API_KEY environment variable when that is set to a
# non-empty value, so a deployment does not need the real secret committed to
# source. The literal is only the fallback placeholder and must not be relied
# on in production.
API_KEY = os.environ.get("API_KEY") or "your-secret-key-here"

# FastAPI security scheme that extracts the X-API-Key request header.
api_key_header = APIKeyHeader(name="X-API-Key")
|
|
|
|
|
|
async def get_api_key(api_key: str = Security(api_key_header)):
    """Validate the X-API-Key header against the configured secret.

    Args:
        api_key: Header value extracted by ``api_key_header``.

    Returns:
        The supplied key, unchanged, when it matches ``API_KEY``.

    Raises:
        HTTPException: 403 when the key is missing or does not match.
    """
    # Compare in constant time so response timing does not reveal how many
    # leading characters of a guess were correct. Both sides are encoded to
    # bytes because compare_digest rejects non-ASCII str arguments.
    matches = api_key is not None and secrets.compare_digest(
        api_key.encode("utf-8"),
        API_KEY.encode("utf-8"),
    )
    if not matches:
        raise HTTPException(
            status_code=403,
            detail="Could not validate credentials",
        )
    return api_key
|
|
|
|
|
|
class UserPreferences(BaseModel):
    # Optional per-user hints sent with a completion request and forwarded to
    # build_completion_prompts. Each field defaults to the literal "auto".
    # (Kept as comments, not a docstring, so the generated API schema is
    # unchanged.)
    language: str = "auto"
    currency: str = "auto"
    timezone: str = "auto"
|
|
|
|
|
|
class CompletionRequest(BaseModel):
    # Request body of POST /v1/completions.

    # Text on either side of the position to complete; both are required.
    prefix: str
    suffix: str

    # Language identifier forwarded to the prompt builder.
    languageId: str = "markdown"

    # Thinking level forwarded to the model; the value "none" disables it.
    # NOTE(review): pydantic v2 may warn about field names starting with
    # "model_" (protected namespace) -- confirm against the installed version.
    model_thinking: str = "low"

    # When true, the handler skips the client IP and location lookup.
    privacy_mode: bool = False

    # Optional locale hints forwarded to the prompt builder.
    user_preferences: Optional[UserPreferences] = None
|
|
|
|
|
|
class CancelCompletionRequest(BaseModel):
    # Request body of POST /v1/completions/cancel.

    # Id of the completion to cancel; matched against ACTIVE_COMPLETIONS keys.
    request_id: str

    # Free-form explanation from the client; the handler only logs it.
    reason: str = "abort"
|
|
|
|
|
|
class OCRRequest(BaseModel):
    # Request body of POST /v1/ocr.

    # Base64-encoded image data; the handler decodes it with base64.b64decode.
    image: str

    # Name echoed back in the response alongside the recognised text.
    filename: str = "image.jpg"

    # Language hint passed through to call_vlm_ocr.
    language: str = "auto"
|
|
|
|
|
|
def _preview(text: str, limit: int = 80) -> str:
    """Return a single-line, length-capped snippet of *text* for log output.

    Newlines are replaced by the two characters backslash + "n" so the snippet
    never spans several log lines. A falsy *text* (``None`` or ``""``) yields
    an empty string.

    Args:
        text: The text to summarise.
        limit: Maximum number of characters kept before the "..." marker.
            Negative values are treated as 0.

    Returns:
        The flattened text when it fits within *limit*, otherwise its first
        *limit* characters followed by "...".
    """
    flattened = (text or "").replace("\n", "\\n")
    # A negative limit would make the slice below count from the END of the
    # string and return a preview longer than requested; clamp it instead.
    limit = max(limit, 0)
    if len(flattened) <= limit:
        return flattened
    return flattened[:limit] + "..."
|
|
|
|
|
|
def _sse_payload(payload: dict) -> str:
    """Frame *payload* as one server-sent event.

    The payload is JSON-encoded, prefixed with ``data: `` and terminated by
    the blank line that ends an SSE event.
    """
    encoded = json.dumps(payload)
    return "data: " + encoded + "\n\n"
|
|
|
|
|
|
def get_client_ip(request: Request) -> str:
    """Return the caller's IP address as a string.

    A non-empty ``X-Client-IP`` header takes precedence; otherwise the peer
    address of the connection is used, or "unknown" when the request carries
    no client information.
    """
    # NOTE(review): the header is client-supplied; presumably a trusted proxy
    # sets it -- confirm before relying on the value for anything sensitive.
    forwarded = request.headers.get("X-Client-IP")
    if forwarded:
        return forwarded
    if request.client:
        return request.client.host
    return "unknown"
|
|
|
|
|
|
@app.post("/v1/completions")
async def create_completion(request: Request, req: CompletionRequest, api_key: str = Security(get_api_key)):
    """Run one completion through the primary model and return it as SSE.

    The inference call runs as an asyncio task registered in
    ``ACTIVE_COMPLETIONS`` under the request id (the ``X-Request-Id`` header,
    or a fresh UUID), so ``/v1/completions/cancel`` can cancel it. Registering
    a task under an id that already has an unfinished task cancels the older
    one.

    Args:
        request: Incoming HTTP request; used for headers and the client address.
        req: Parsed completion request body.
        api_key: Validated API key (only present to enforce authentication).

    Returns:
        * On success, a ``text/event-stream`` response with one ``content``
          event followed by a ``done`` event.
        * On cancellation, a ``text/event-stream`` response with a single
          ``cancelled`` event.
        * On any other error, a JSON 500 response carrying the error text.
    """
    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
    request_tag = request_id[:8]  # short prefix used to correlate log lines
    inference_task: Optional[asyncio.Task] = None

    # In privacy mode neither the client IP nor its location is looked up.
    client_ip = "hidden"
    location = ""

    if not req.privacy_mode:
        client_ip = get_client_ip(request)
        location = get_ip_location_text(client_ip)
        if location:
            logger.info("[%s] client_location=%s", request_tag, location)

    try:
        logger.info(
            "[%s] /v1/completions request_id=%s client_ip=%s prefix_chars=%d suffix_chars=%d lang=%s thinking=%s privacy=%s",
            request_tag,
            request_id,
            client_ip,
            len(req.prefix or ""),
            len(req.suffix or ""),
            req.languageId,
            req.model_thinking,
            req.privacy_mode,
        )

        # NOTE(review): the prepared context is only logged here, while
        # build_completion_prompts receives the raw prefix/suffix --
        # presumably it prepares them itself; confirm.
        llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "")
        logger.info("[%s] llm_input_prefix=%r", request_tag, llm_prefix)
        logger.info("[%s] llm_input_suffix=%r", request_tag, llm_suffix)

        system_prompt, user_prompt = build_completion_prompts(
            req.prefix,
            req.suffix,
            req.languageId,
            location=location,
            thinking_level=req.model_thinking,
            preferences=req.user_preferences,
        )

        inference_task = asyncio.create_task(
            call_ollama(
                user_prompt,
                system_prompt=system_prompt,
                tag=f"{request_tag}-primary",
                temperature=0.7,
                # "none" means thinking is switched off for the model call.
                thinking=req.model_thinking if req.model_thinking != "none" else None,
            )
        )

        async with ACTIVE_COMPLETIONS_LOCK:
            # A newer request with the same id supersedes the older one.
            existing = ACTIVE_COMPLETIONS.get(request_id)
            if existing and not existing.done():
                existing.cancel()
            ACTIVE_COMPLETIONS[request_id] = inference_task

        result = await inference_task
        content = result["content"] or ""
        if not content.strip():
            logger.warning("[%s] primary returned empty content, returning empty result", request_tag)
        logger.info(
            "[%s] completion resolved source=primary request_id=%s content_chars=%d content_preview='%s'",
            request_tag,
            request_id,
            len(content),
            _preview(content, 120),
        )

        async def generate():
            yield _sse_payload({"content": content})
            yield _sse_payload({"done": True})

        return StreamingResponse(generate(), media_type="text/event-stream")
    except asyncio.CancelledError:
        # Raised by `await inference_task` when the task is cancelled (cancel
        # endpoint or a superseding request).
        # NOTE(review): this also swallows cancellation of the handler itself
        # (e.g. server shutdown) -- confirm that is acceptable.
        logger.info("[%s] /v1/completions cancelled request_id=%s", request_tag, request_id)

        async def cancelled():
            yield _sse_payload({"cancelled": True, "request_id": request_id, "done": True})

        return StreamingResponse(cancelled(), media_type="text/event-stream")
    except Exception as e:
        logger.exception("[%s] /v1/completions failed request_id=%s: %s", request_tag, request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
    finally:
        # If we leave before the task finished (for example, cancelled while
        # waiting for the registry lock) nobody will ever await it; cancel it
        # so the model call does not keep running unobserved.
        if inference_task is not None and not inference_task.done():
            inference_task.cancel()

        async with ACTIVE_COMPLETIONS_LOCK:
            # Only remove our own entry: a newer request may have replaced it.
            active = ACTIVE_COMPLETIONS.get(request_id)
            if active is not None and active is inference_task:
                ACTIVE_COMPLETIONS.pop(request_id, None)
|
|
|
|
|
|
@app.post("/v1/completions/cancel")
async def cancel_completion(req: CancelCompletionRequest, api_key: str = Security(get_api_key)):
    """Cancel the in-flight completion registered under ``req.request_id``.

    Returns a dict with ``cancelled`` (bool) and ``status``:
    ``"not_found"`` when no task is registered under that id,
    ``"already_done"`` when the task has finished, and ``"ok"`` when
    cancellation was requested.
    """
    request_tag = uuid.uuid4().hex[:8]  # log-correlation tag for this call
    request_id = req.request_id or ""

    # Decide the outcome while holding the registry lock.
    async with ACTIVE_COMPLETIONS_LOCK:
        task = ACTIVE_COMPLETIONS.get(request_id)
        if task is None:
            status = "not_found"
        elif task.done():
            status = "already_done"
        else:
            task.cancel()
            status = "ok"

    log_templates = {
        "not_found": "[%s] /v1/completions/cancel request_id=%s status=not_found reason=%s",
        "already_done": "[%s] /v1/completions/cancel request_id=%s status=already_done reason=%s",
        "ok": "[%s] /v1/completions/cancel request_id=%s status=ok reason=%s",
    }
    logger.info(log_templates[status], request_tag, request_id, req.reason)

    return {"cancelled": status == "ok", "status": status}
|
|
|
|
|
|
@app.post("/v1/ocr")
async def ocr_image(request: OCRRequest, api_key: str = Security(get_api_key)):
    """Decode a base64 image and return the text recognised by the VLM.

    Args:
        request: Parsed OCR request body (base64 image, filename, language).
        api_key: Validated API key (only present to enforce authentication).

    Returns:
        * On success, ``{"text": ..., "filename": ...}``.
        * A JSON 400 response when ``image`` is not decodable base64.
        * A JSON 500 response carrying the error text on any other failure.
    """
    request_id = str(uuid.uuid4())[:8]
    try:
        logger.info(
            "[%s] /v1/ocr filename=%s language=%s image_base64_chars=%d",
            request_id,
            request.filename,
            request.language,
            len(request.image or ""),
        )

        encoded = request.image or ""
        # Accept data URLs ("data:<mime>;base64,<payload>") as well as bare
        # base64. b64decode discards non-alphabet characters by default, so
        # without this the prefix letters would be decoded into garbage bytes
        # in front of the image instead of raising an error.
        if encoded.startswith("data:") and "," in encoded:
            encoded = encoded.partition(",")[2]

        try:
            image_bytes = base64.b64decode(encoded)
        except ValueError as e:  # binascii.Error is a ValueError subclass
            # Malformed input is the caller's fault: report 400, not 500.
            logger.warning("[%s] /v1/ocr rejected invalid base64 image: %s", request_id, e)
            return JSONResponse(content={"error": f"Invalid base64 image data: {e}"}, status_code=400)

        logger.info("[%s] /v1/ocr decoded image_bytes=%d", request_id, len(image_bytes))
        result = await call_vlm_ocr(image_bytes, request.language)
        logger.info(
            "[%s] /v1/ocr success text_chars=%d text_preview='%s'",
            request_id,
            len(result or ""),
            _preview(result or "", 120),
        )
        return {"text": result, "filename": request.filename}
    except Exception as e:
        logger.exception("[%s] /v1/ocr failed: %s", request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct-run entry point: serve the app on all interfaces, port 8001.
    # uvicorn is imported here so importing this module does not require it.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)
|