Files
llm-in-text/backend/main.py
ydy0615 59334e4057 Stabilize pro editing without heavy office runtime
The workspace now carries the pro editing flow, streaming completion path, and lighter Office preview state as one checkpoint so the remote has the current runnable project shape.

Constraint: Preserve the current workspace as a single reviewable project commit while excluding local agent state and verification artifacts. Removed stale Univer runtime dependencies from the lockfile so installs match package.json.

Rejected: Commit runtime screenshots, .omx state, and coverage files | they are local artifacts rather than source state.

Confidence: medium

Scope-risk: broad

Directive: Keep package.json and package-lock.json synchronized when changing frontend dependencies.

Tested: npm run build; C:\Users\ydy\.conda\envs\llmwebsite\python.exe -m pytest backend/tests/test_main_endpoints.py backend/tests/test_main_cancel.py backend/tests/test_llm.py backend/tests/test_llm_extended.py -v -o addopts= (44 passed).

Not-tested: Full pytest with repository coverage addopts currently reports 0% coverage because pytest-cov watches backend.* module names while tests import top-level backend modules.

Co-authored-by: OmX <omx@oh-my-codex.dev>
2026-05-24 23:30:32 +08:00

503 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import base64
import json
import logging
import os
import re
import shutil
import subprocess
import tempfile
import uuid
from typing import Optional
from fastapi import FastAPI, HTTPException, Request, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from geoip import get_ip_location_text
from llm import call_ollama, call_vlm_ocr, stream_ollama
from models import UserPreferences
from prompt import build_completion_prompts, prepare_prompt_context
import markitdown
# Configure logging once at import time: INFO threshold, with each record
# carrying a timestamp, level and logger name ahead of the message.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s - %(message)s",
)
# Logger shared by every handler and helper in this module.
logger = logging.getLogger("api")
# Process-wide MarkItDown converter, created on first use.
_markitdown_instance = None


def _get_markitdown():  # pragma: no cover
    """Return the shared ``MarkItDown`` converter, building it on first call."""
    global _markitdown_instance
    if _markitdown_instance is not None:
        return _markitdown_instance
    _markitdown_instance = markitdown.MarkItDown()
    return _markitdown_instance
app = FastAPI()


@app.on_event("startup")  # pragma: no cover
async def startup_event():
    """Warm up the TTS and ASR models at application startup.

    The warmup helper is imported inside the ``try`` so that an import or
    warmup failure is logged as a warning instead of aborting startup.
    """
    # NOTE(review): FastAPI documents ``on_event`` as deprecated in favour of
    # lifespan handlers; consider migrating when this module is next touched.
    logger.info("Starting blocking preload for TTS and ASR models...")
    try:
        from tts_asr import _warmup_all

        await _warmup_all()
    except Exception as e:
        # Best-effort: the API keeps starting even if warmup cannot run.
        # Lazy %-args match the logging style used everywhere else here.
        logger.warning("Failed to initiate model warmup: %s", e)
# Registry of in-flight completion tasks keyed by request id, so that a later
# request (or the cancel endpoint) can cancel them.
ACTIVE_COMPLETIONS: dict[str, asyncio.Task] = {}
# Held by the cancel endpoint while it looks up and cancels a task.
ACTIVE_COMPLETIONS_LOCK = asyncio.Lock()
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive CORS policy — confirm it is intended beyond local use.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*", "X-API-Key", "X-Client-IP", "X-Request-Id"],
)
# NOTE(review): the fallback is a placeholder value; deployments presumably
# must set API_KEY in the environment — confirm.
API_KEY = os.getenv("API_KEY", "your-secret-key-here")
# Reads the client's key from the X-API-Key request header.
api_key_header = APIKeyHeader(name="X-API-Key")
async def get_api_key(api_key: str = Security(api_key_header)):  # pragma: no cover
    """Validate the key from the ``X-API-Key`` header against ``API_KEY``.

    Args:
        api_key: Key extracted from the request by ``api_key_header``.

    Returns:
        The supplied key when it matches the configured one.

    Raises:
        HTTPException: With status 403 when the key does not match.
    """
    import hmac  # local import keeps this change self-contained

    # Compare in constant time so response timing does not reveal how much of
    # a guessed key was correct. Both sides are encoded to bytes because
    # compare_digest rejects str arguments containing non-ASCII characters.
    matches = api_key is not None and hmac.compare_digest(
        api_key.encode("utf-8"),
        API_KEY.encode("utf-8"),
    )
    if not matches:
        raise HTTPException(
            status_code=403,
            detail="Could not validate credentials",
        )
    return api_key
class CompletionRequest(BaseModel):
    # Request body shared by the plain and the streaming completion endpoints.

    # Document text handed to the prompt builders; presumably the text before
    # and after the insertion point — confirm against the prompt module.
    prefix: str
    suffix: str
    # Forwarded to the prompt builder as the language identifier.
    languageId: str = "markdown"
    # Forwarded as the thinking level; the value "none" disables the
    # ``thinking`` argument passed to the LLM call.
    model_thinking: str = "low"
    # When true the handlers skip the client IP and location lookup.
    privacy_mode: bool = False
    user_preferences: Optional[UserPreferences] = None
    # Optional model override passed through to the LLM call.
    model: Optional[str] = None
    # Clamped by the handlers before it reaches the LLM call.
    temperature: float = 0.7
class CancelCompletionRequest(BaseModel):
    # Request body for the cancel endpoint.

    # Id of the in-flight completion to cancel.
    request_id: str
    # Free-form reason; only written to the log.
    reason: str = "abort"
class OCRRequest(BaseModel):
    # Request body for the OCR endpoint.

    # Base64-encoded image payload (decoded by the handler).
    image: str
    # Echoed back in the response and written to the log.
    filename: str = "image.jpg"
    # Passed through to the OCR call.
    language: str = "auto"
class ConvertRequest(BaseModel):
    # Request body for the document conversion endpoint.

    # Base64-encoded file payload (decoded by the handler).
    file: str
    # Its extension selects the conversion path; echoed back in the response.
    filename: str = "document.pdf"
# Lower-cased file extensions the convert endpoint accepts.
ALLOWED_CONVERT_EXTENSIONS = {".txt", ".docx", ".pptx", ".pdf"}
# Matches Markdown image syntax: ![alt](target)
IMAGE_MARKDOWN_RE = re.compile(r"!\[[^\]]*]\([^)]+\)")
# Matches an HTML <img ...> tag, case-insensitively.
IMAGE_HTML_RE = re.compile(r"<img\b[^>]*>", re.IGNORECASE)
def _convert_docx_to_pdf(input_path: str, output_path: str) -> None:  # pragma: no cover
    """Convert a DOCX file to PDF by running the Node.js bridge script.

    Args:
        input_path: Path of the DOCX file, passed to the bridge script.
        output_path: Path for the PDF, passed to the bridge script.

    Raises:
        RuntimeError: If ``node`` is not on PATH, the bridge script is missing
            next to this module, or the script exits with a non-zero status.
    """
    node_binary = shutil.which("node")
    if not node_binary:
        raise RuntimeError("未找到 Node.js无法转换 DOCX 为 PDF")

    module_dir = os.path.dirname(__file__)
    bridge_script = os.path.join(module_dir, "docx2pdf_bridge.cjs")
    if not os.path.exists(bridge_script):
        raise RuntimeError("缺少 DOCX 转 PDF 桥接脚本")

    # The script runs from the parent of this module's directory.
    completed = subprocess.run(
        [node_binary, bridge_script, input_path, output_path],
        cwd=os.path.dirname(module_dir),
        capture_output=True,
        text=True,
    )
    if completed.returncode == 0:
        return

    # Prefer stderr, then stdout, then a generic message.
    failure_text = completed.stderr or completed.stdout or "DOCX 转 PDF 失败"
    raise RuntimeError(failure_text.strip())
def _preview(text: str, limit: int = 80) -> str:
value = (text or "").replace("\n", "\\n")
if len(value) <= limit:
return value
return value[:limit] + "..."
def _sanitize_converted_markdown(text: str) -> str:
    """Normalise converted Markdown and strip image references from it.

    Line endings become ``\\n``, Markdown images and HTML ``<img>`` tags are
    removed, runs of three or more newlines collapse to two, and surrounding
    whitespace is trimmed.
    """
    cleaned = (text or "").replace("\r\n", "\n").replace("\r", "\n")
    for image_pattern in (IMAGE_MARKDOWN_RE, IMAGE_HTML_RE):
        cleaned = image_pattern.sub("", cleaned)
    return re.sub(r"\n{3,}", "\n\n", cleaned).strip()
def get_client_ip(request: Request) -> str:
    """Return the caller's IP address for logging and location lookup.

    A non-empty ``X-Client-IP`` header wins; otherwise the connection's peer
    host is used, or ``"unknown"`` when the request has no client info.
    """
    header_ip = request.headers.get("X-Client-IP")
    if header_ip:
        return header_ip
    return request.client.host if request.client else "unknown"
def _clamp_temperature(value: float, default: float = 0.7) -> float:
try:
numeric = float(value)
except (TypeError, ValueError):
return default
return max(0.0, min(numeric, 1.2))
@app.post("/v1/completions")
async def create_completion(request: Request, req: CompletionRequest, api_key: str = Security(get_api_key)):
    # Non-streaming completion: builds the prompts, runs the LLM call as a
    # cancellable task registered under the request id, and returns its
    # content as JSON. Responds 499 when the task is cancelled and 500 on any
    # other failure.
    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
    request_tag = request_id[:8]
    inference_task: Optional[asyncio.Task] = None
    client_ip, location = "hidden", ""
    if not req.privacy_mode:  # pragma: no cover
        client_ip = get_client_ip(request)
        location = get_ip_location_text(client_ip)
        if location:
            logger.info("[%s] client_location=%s", request_tag, location)
    try:
        logger.info(
            "[%s] /v1/completions request_id=%s client_ip=%s prefix_chars=%d suffix_chars=%d lang=%s thinking=%s privacy=%s",
            request_tag,
            request_id,
            client_ip,
            len(req.prefix or ""),
            len(req.suffix or ""),
            req.languageId,
            req.model_thinking,
            req.privacy_mode,
        )
        # The prepared context is only logged; the prompt builder below
        # receives the raw prefix and suffix.
        llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "")
        logger.info("[%s] llm_input_prefix=%r", request_tag, llm_prefix)
        logger.info("[%s] llm_input_suffix=%r", request_tag, llm_suffix)
        system_prompt, user_prompt = build_completion_prompts(
            req.prefix,
            req.suffix,
            req.languageId,
            location=location,
            thinking_level=req.model_thinking,
            preferences=req.user_preferences,
        )
        thinking = None if req.model_thinking == "none" else req.model_thinking
        inference_task = asyncio.create_task(
            call_ollama(
                user_prompt,
                system_prompt=system_prompt,
                tag=f"{request_tag}-primary",
                temperature=_clamp_temperature(req.temperature, 0.7),
                thinking=thinking,
                model=req.model,
            )
        )
        # A new request reusing an id supersedes the one still running.
        superseded = ACTIVE_COMPLETIONS.get(request_id)
        if superseded is not None and not superseded.done():
            superseded.cancel()
        ACTIVE_COMPLETIONS[request_id] = inference_task

        result = await inference_task
        content = result["content"] or ""
        if not content.strip():
            logger.warning("[%s] primary returned empty content, returning empty result", request_tag)
        logger.info(
            "[%s] completion resolved source=primary request_id=%s content_chars=%d content_preview='%s'",
            request_tag,
            request_id,
            len(content),
            _preview(content, 120),
        )
        return JSONResponse(content={"content": content, "request_id": request_id})
    except asyncio.CancelledError:
        logger.info("[%s] /v1/completions cancelled request_id=%s", request_tag, request_id)
        return JSONResponse(content={"cancelled": True, "request_id": request_id}, status_code=499)
    except Exception as e:
        logger.exception("[%s] /v1/completions failed request_id=%s: %s", request_tag, request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
    finally:
        # Drop the registry entry only if it is still ours; a superseding
        # request may already have replaced it.
        if inference_task is not None and ACTIVE_COMPLETIONS.get(request_id) is inference_task:
            ACTIVE_COMPLETIONS.pop(request_id, None)
@app.post("/v1/pro/completions/stream")
async def create_pro_completion_stream(request: Request, req: CompletionRequest, api_key: str = Security(get_api_key)):
    # Streaming completion over server-sent events. A producer task pulls
    # deltas from the LLM stream into a queue; the response generator drains
    # the queue and emits "chunk" events followed by one terminal "done",
    # "cancelled" or "error" event. The producer is registered under the
    # request id so the cancel endpoint or a superseding request can stop it.
    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
    request_tag = request_id[:8]
    # Items are (event_name, json_payload); None marks the end of the stream.
    queue: asyncio.Queue[Optional[tuple[str, str]]] = asyncio.Queue()
    client_ip = "hidden"
    location = ""
    if not req.privacy_mode:  # pragma: no cover
        client_ip = get_client_ip(request)
        location = get_ip_location_text(client_ip)
        if location:
            logger.info("[%s] client_location=%s", request_tag, location)
    temperature = _clamp_temperature(req.temperature, 0.7)
    thinking = req.model_thinking if req.model_thinking != "none" else None
    logger.info(
        "[%s] /v1/pro/completions/stream request_id=%s client_ip=%s prefix_chars=%d suffix_chars=%d lang=%s thinking=%s privacy=%s model=%s temp=%.2f",
        request_tag,
        request_id,
        client_ip,
        len(req.prefix or ""),
        len(req.suffix or ""),
        req.languageId,
        req.model_thinking,
        req.privacy_mode,
        req.model or "",
        temperature,
    )
    # The prepared context is only logged; the prompt builder below receives
    # the raw prefix and suffix.
    llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "")
    logger.info("[%s] pro_llm_input_prefix=%r", request_tag, llm_prefix)
    logger.info("[%s] pro_llm_input_suffix=%r", request_tag, llm_suffix)
    system_prompt, user_prompt = build_completion_prompts(
        req.prefix,
        req.suffix,
        req.languageId,
        location=location,
        thinking_level=req.model_thinking,
        preferences=req.user_preferences,
    )

    async def producer() -> None:
        # Feed LLM deltas into the queue and always finish with the sentinel.
        chunks: list[str] = []
        try:
            async for delta in stream_ollama(
                user_prompt,
                system_prompt=system_prompt,
                tag=f"{request_tag}-pro",
                temperature=temperature,
                thinking=thinking,
                model=req.model,
                use_pro_model=True,
            ):
                chunks.append(delta)
                await queue.put(("chunk", json.dumps({"delta": delta}, ensure_ascii=False)))
            content = "".join(chunks)
            logger.info(
                "[%s] pro stream resolved request_id=%s content_chars=%d content_preview='%s'",
                request_tag,
                request_id,
                len(content),
                _preview(content, 120),
            )
            await queue.put((
                "done",
                json.dumps({"content": content, "request_id": request_id}, ensure_ascii=False),
            ))
        except asyncio.CancelledError:
            logger.info("[%s] /v1/pro/completions/stream cancelled request_id=%s", request_tag, request_id)
            await queue.put((
                "cancelled",
                json.dumps({"cancelled": True, "request_id": request_id}, ensure_ascii=False),
            ))
            raise
        except Exception as e:
            logger.exception("[%s] /v1/pro/completions/stream failed request_id=%s: %s", request_tag, request_id, e)
            await queue.put((
                "error",
                json.dumps({"error": str(e), "request_id": request_id}, ensure_ascii=False),
            ))
        finally:
            await queue.put(None)

    producer_task = asyncio.create_task(producer())
    # A new request reusing an id supersedes the one still running.
    existing = ACTIVE_COMPLETIONS.get(request_id)
    if existing and not existing.done():
        existing.cancel()
    ACTIVE_COMPLETIONS[request_id] = producer_task

    async def event_stream():
        # Drain the queue as SSE frames until the sentinel arrives.
        try:
            while True:
                item = await queue.get()
                if item is None:
                    break
                event_name, data = item
                yield f"event: {event_name}\ndata: {data}\n\n"
        finally:
            # Stop the producer on every exit path. Previously this happened
            # only for CancelledError, so a generator closed at a ``yield``
            # (GeneratorExit) left the LLM stream running with no consumer.
            # After a normal finish the task is already done and this is
            # skipped.
            if not producer_task.done():
                producer_task.cancel()
            # Drop the registry entry only if it is still ours.
            active = ACTIVE_COMPLETIONS.get(request_id)
            if active is producer_task:
                ACTIVE_COMPLETIONS.pop(request_id, None)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
@app.post("/v1/completions/cancel")
async def cancel_completion(req: CancelCompletionRequest, api_key: str = Security(get_api_key)):
    # Cancels the in-flight completion registered under req.request_id and
    # reports whether anything was cancelled: "ok", "not_found" when no task
    # is registered, or "already_done" when the task has finished.
    request_tag = str(uuid.uuid4())[:8]
    request_id = req.request_id or ""
    async with ACTIVE_COMPLETIONS_LOCK:
        pending = ACTIVE_COMPLETIONS.get(request_id)
        if pending is None:
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=not_found reason=%s",
                request_tag,
                request_id,
                req.reason,
            )
            outcome = {"cancelled": False, "status": "not_found"}
        elif pending.done():
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=already_done reason=%s",
                request_tag,
                request_id,
                req.reason,
            )
            outcome = {"cancelled": False, "status": "already_done"}
        else:
            pending.cancel()
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=ok reason=%s",
                request_tag,
                request_id,
                req.reason,
            )
            outcome = {"cancelled": True, "status": "ok"}
    return outcome
@app.post("/v1/ocr")
async def ocr_image(request: OCRRequest, api_key: str = Security(get_api_key)):
    # Decodes the base64 image, runs it through the OCR call and returns the
    # recognised text with the original filename. Any failure is logged and
    # reported as a 500 carrying the error text.
    request_id = uuid.uuid4().hex[:8]
    try:
        logger.info(
            "[%s] /v1/ocr filename=%s language=%s image_base64_chars=%d",
            request_id,
            request.filename,
            request.language,
            len(request.image or ""),
        )
        image_bytes = base64.b64decode(request.image)
        logger.info("[%s] /v1/ocr decoded image_bytes=%d", request_id, len(image_bytes))
        recognized = await call_vlm_ocr(image_bytes, request.language)
        logged_text = recognized or ""
        logger.info(
            "[%s] /v1/ocr success text_chars=%d text_preview='%s'",
            request_id,
            len(logged_text),
            _preview(logged_text, 120),
        )
        return {"text": recognized, "filename": request.filename}
    except Exception as e:
        logger.exception("[%s] /v1/ocr failed: %s", request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
@app.post("/v1/convert")
async def convert_to_markdown(request: ConvertRequest, api_key: str = Security(get_api_key)):
    """Convert file to markdown"""
    # Decodes the base64 payload and converts it according to the filename's
    # extension: .txt is decoded directly, the other allowed types go through
    # MarkItDown via a temporary file. Any failure, including an unsupported
    # extension, is logged and reported as a 500 carrying the error text.
    request_id = str(uuid.uuid4())[:8]
    try:
        logger.info(
            "[%s] /v1/convert filename=%s file_base64_chars=%d",
            request_id,
            request.filename,
            len(request.file or ""),
        )
        file_bytes = base64.b64decode(request.file)
        logger.info("[%s] /v1/convert decoded file_bytes=%d", request_id, len(file_bytes))

        ext = os.path.splitext(request.filename)[1].lower()
        if ext not in ALLOWED_CONVERT_EXTENSIONS:
            raise ValueError("仅支持 txt、docx、pptx、pdf 格式")

        if ext == ".txt":
            # Plain text needs no converter; undecodable bytes are dropped.
            markdown_text = _sanitize_converted_markdown(file_bytes.decode("utf-8", errors="ignore"))
            return {
                "markdown": markdown_text,
                "filename": request.filename
            }

        tmp_path: Optional[str] = None
        try:
            # Record the path before writing so the cleanup below also runs
            # when the write itself fails; delete=False would otherwise leave
            # the file behind.
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
                tmp_path = tmp.name
                tmp.write(file_bytes)
            # The file is closed here, before the converter opens it by path.
            md = _get_markitdown()
            # The conversion is synchronous, so run it in a worker thread.
            result = await asyncio.to_thread(md.convert, tmp_path)
            markdown_text = _sanitize_converted_markdown(result.text_content)
            logger.info(
                "[%s] /v1/convert success text_chars=%d text_preview='%s'",
                request_id,
                len(markdown_text or ""),
                _preview(markdown_text, 120),
            )
            return {
                "markdown": markdown_text,
                "filename": request.filename
            }
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except FileNotFoundError:
                    # Already gone; nothing left to clean up.
                    pass
    except Exception as e:
        logger.exception("[%s] /v1/convert failed: %s", request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
# TTS and ASR routes (lazy loaded to avoid heavy import on startup)
def _register_tts_asr_routes():
    """Attach the optional TTS/ASR routes to ``app`` if they can be imported.

    Import and registration failures are logged as warnings and otherwise
    ignored, so the remaining endpoints stay available.
    """
    try:
        from tts_asr import register_tts_asr_routes as attach_routes
    except ModuleNotFoundError as missing:
        logger.warning("Skipping TTS/ASR route registration because a dependency is missing: %s", missing)
    except Exception as import_error:
        logger.warning("Skipping TTS/ASR route registration because import failed: %s", import_error)
    else:
        try:
            attach_routes(app)
        except Exception as registration_error:
            logger.warning("Failed to register TTS/ASR routes: %s", registration_error)


_register_tts_asr_routes()
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 8001.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8001)