- Normalize line endings in Markdown export for DOCX files. - Improve selection serialization to Markdown with better handling of empty documents. - Add a new `updateFile` function to the file system for updating file properties. - Introduce video transcoding capabilities using FFmpeg, supporting various video formats. - Update AGENTS.md for clearer plugin structure and responsibilities. - Add scoped styles for TreeNodeItem component to improve UI consistency. - Implement cross-origin isolation headers in Vite configuration for enhanced security. - Remove obsolete test_cross.py file.
361 lines
11 KiB
Python
361 lines
11 KiB
Python
import asyncio
|
||
import base64
|
||
import logging
|
||
import os
|
||
import re
|
||
import shutil
|
||
import subprocess
|
||
import tempfile
|
||
import uuid
|
||
from typing import Optional
|
||
|
||
from fastapi import FastAPI, HTTPException, Request, Security
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import JSONResponse
|
||
from fastapi.security import APIKeyHeader
|
||
from pydantic import BaseModel
|
||
|
||
from geoip import get_ip_location_text
|
||
from llm import call_ollama, call_vlm_ocr
|
||
from models import UserPreferences
|
||
from prompt import build_completion_prompts, prepare_prompt_context
|
||
import markitdown
|
||
|
||
# Configure root logging once at import time: INFO and above, with each
# record carrying a timestamp, level and logger name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
    level=logging.INFO,
)

# Logger shared by the route handlers and helpers in this module.
logger = logging.getLogger("api")
|
||
|
||
# Shared MarkItDown converter; stays None until the first conversion needs it.
_markitdown_instance = None


def _get_markitdown():  # pragma: no cover
    """Return the module-wide ``markitdown.MarkItDown`` converter.

    The converter is constructed on the first call and the same object is
    handed back on every later call.
    """
    global _markitdown_instance
    if _markitdown_instance is not None:
        return _markitdown_instance
    _markitdown_instance = markitdown.MarkItDown()
    return _markitdown_instance
|
||
|
||
app = FastAPI()


# NOTE(review): newer FastAPI releases document ``on_event`` as deprecated in
# favour of lifespan handlers; the decorator is kept so registration is
# unchanged -- confirm against the pinned FastAPI version before migrating.
@app.on_event("startup")  # pragma: no cover
async def startup_event():
    """Warm up the TTS and ASR models when the application starts.

    ``tts_asr`` is imported inside the handler and its ``_warmup_all``
    coroutine is awaited, so startup waits for the warmup to finish.
    Warmup is best-effort: any exception is logged as a warning (with its
    traceback) and startup continues.
    """
    logger.info("Starting blocking preload for TTS and ASR models...")
    try:
        from tts_asr import _warmup_all

        await _warmup_all()
    except Exception as e:
        # Lazy %-formatting instead of an f-string; exc_info keeps the
        # traceback so a failed warmup can actually be diagnosed.
        logger.warning("Failed to initiate model warmup: %s", e, exc_info=True)
|
||
|
||
# In-flight completion tasks keyed by request id; the cancel endpoint looks
# tasks up here.
ACTIVE_COMPLETIONS: dict[str, asyncio.Task] = dict()
# Held by the cancel endpoint while it inspects and cancels a registered task.
ACTIVE_COMPLETIONS_LOCK = asyncio.Lock()
|
||
|
||
# Cross-origin policy: every origin, method and header is accepted, and
# credentialed requests are enabled.
# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive combination; the handlers in this file authenticate with the
# X-API-Key header, so confirm whether credentials are actually required.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*", "X-API-Key", "X-Client-IP", "X-Request-Id"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
|
||
|
||
# Shared secret that callers must present in the X-API-Key header.
# NOTE(review): when the API_KEY environment variable is unset, the placeholder
# below becomes the accepted key -- confirm deployments always set the variable.
API_KEY = os.environ.get("API_KEY", "your-secret-key-here")
# Security scheme that extracts the key from the X-API-Key request header.
api_key_header = APIKeyHeader(name="X-API-Key")
|
||
|
||
|
||
async def get_api_key(api_key: str = Security(api_key_header)):  # pragma: no cover
    """Validate the ``X-API-Key`` header against the configured ``API_KEY``.

    Args:
        api_key: Value extracted from the request by ``api_key_header``.

    Returns:
        The supplied key, unchanged, when it matches ``API_KEY``.

    Raises:
        HTTPException: 403 when the key is missing or does not match.
    """
    # Imported here so this block does not depend on the file's import header.
    import secrets

    # Compare as UTF-8 bytes: compare_digest rejects non-ASCII ``str`` input,
    # and a constant-time comparison avoids leaking how many leading
    # characters of the secret a guess got right.
    matches = api_key is not None and secrets.compare_digest(
        api_key.encode("utf-8"),
        API_KEY.encode("utf-8"),
    )
    if not matches:
        raise HTTPException(
            status_code=403,
            detail="Could not validate credentials",
        )
    return api_key
|
||
|
||
|
||
class CompletionRequest(BaseModel):
    """JSON body accepted by ``POST /v1/completions``."""

    # Text handed to the prompt builder as the prefix.
    prefix: str
    # Text handed to the prompt builder as the suffix.
    suffix: str
    # Language identifier forwarded to the prompt builder.
    languageId: str = "markdown"
    # Thinking level; the handler maps the value "none" to no thinking at all.
    model_thinking: str = "low"
    # When true, the handler skips client IP and location lookup.
    privacy_mode: bool = False
    # Optional preferences forwarded to the prompt builder.
    user_preferences: Optional[UserPreferences] = None
|
||
|
||
|
||
class CancelCompletionRequest(BaseModel):
    """JSON body accepted by ``POST /v1/completions/cancel``."""

    # Id of the completion to cancel; looked up in ACTIVE_COMPLETIONS.
    request_id: str
    # Free-form reason; the handler only writes it to the log.
    reason: str = "abort"
|
||
|
||
|
||
class OCRRequest(BaseModel):
    """JSON body accepted by ``POST /v1/ocr``."""

    # Base64-encoded image; the handler decodes it before running OCR.
    image: str
    # Name echoed back in the response and written to the log.
    filename: str = "image.jpg"
    # Language value passed through to the OCR call.
    language: str = "auto"
|
||
|
||
|
||
class ConvertRequest(BaseModel):
    """JSON body accepted by ``POST /v1/convert``."""

    # Base64-encoded file content.
    file: str
    # Original file name; its extension selects the conversion path.
    filename: str = "document.pdf"
|
||
|
||
|
||
# Lower-cased extensions (dot included) that /v1/convert accepts.
ALLOWED_CONVERT_EXTENSIONS = {".pdf", ".pptx", ".docx", ".txt"}
# Inline Markdown image: ![alt](target)
IMAGE_MARKDOWN_RE = re.compile(r"!\[[^\]]*]\([^)]+\)")
# Any HTML <img ...> tag, matched case-insensitively.
IMAGE_HTML_RE = re.compile(r"<img\b[^>]*>", flags=re.IGNORECASE)
|
||
|
||
|
||
def _convert_docx_to_pdf(input_path: str, output_path: str) -> None: # pragma: no cover
|
||
node_executable = shutil.which("node")
|
||
if not node_executable:
|
||
raise RuntimeError("未找到 Node.js,无法转换 DOCX 为 PDF")
|
||
|
||
bridge_path = os.path.join(os.path.dirname(__file__), "docx2pdf_bridge.cjs")
|
||
if not os.path.exists(bridge_path):
|
||
raise RuntimeError("缺少 DOCX 转 PDF 桥接脚本")
|
||
|
||
result = subprocess.run(
|
||
[node_executable, bridge_path, input_path, output_path],
|
||
cwd=os.path.dirname(os.path.dirname(__file__)),
|
||
capture_output=True,
|
||
text=True,
|
||
)
|
||
|
||
if result.returncode != 0:
|
||
error_text = (result.stderr or result.stdout or "DOCX 转 PDF 失败").strip()
|
||
raise RuntimeError(error_text)
|
||
|
||
|
||
def _preview(text: str, limit: int = 80) -> str:
|
||
value = (text or "").replace("\n", "\\n")
|
||
if len(value) <= limit:
|
||
return value
|
||
return value[:limit] + "..."
|
||
|
||
|
||
def _sanitize_converted_markdown(text: str) -> str:
    """Normalise line endings to ``\\n`` and strip image references.

    Both Markdown image syntax (``IMAGE_MARKDOWN_RE``) and HTML ``<img>`` tags
    (``IMAGE_HTML_RE``) are removed. A falsy ``text`` yields an empty string.
    """
    cleaned = text or ""
    # CRLF must be handled before bare CR, otherwise each CRLF would turn
    # into two newlines.
    for line_ending in ("\r\n", "\r"):
        cleaned = cleaned.replace(line_ending, "\n")
    for image_pattern in (IMAGE_MARKDOWN_RE, IMAGE_HTML_RE):
        cleaned = image_pattern.sub("", cleaned)
    return cleaned
|
||
|
||
|
||
def get_client_ip(request: Request) -> str:
    """Return the caller's address for logging and geolocation.

    A non-empty ``X-Client-IP`` header wins; otherwise the connection's peer
    host is used, or ``"unknown"`` when the request has no client info.
    """
    # NOTE(review): X-Client-IP is taken from the request as-is -- confirm it
    # is only ever set by a trusted proxy, since a caller can send any value.
    header_ip = request.headers.get("X-Client-IP")
    if header_ip:
        return header_ip
    return request.client.host if request.client else "unknown"
|
||
|
||
|
||
@app.post("/v1/completions")
async def create_completion(request: Request, req: CompletionRequest, api_key: str = Security(get_api_key)):
    """Run one LLM completion and return its text as JSON.

    The request id is the ``X-Request-Id`` header, or a fresh UUID when the
    header is absent. Inference runs as its own task, registered in
    ``ACTIVE_COMPLETIONS`` under that id so the cancel endpoint can find it.
    A still-running task already registered under the same id is cancelled
    and replaced.

    Returns:
        ``{"content", "request_id"}`` on success; a 499 response with
        ``{"cancelled": True, "request_id"}`` when cancelled; a 500 response
        with ``{"error"}`` for any other exception.
    """
    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
    tag = request_id[:8]  # short prefix that ties this request's log lines together
    inference_task: Optional[asyncio.Task] = None

    # In privacy mode the caller's address and location are never resolved.
    if req.privacy_mode:
        client_ip, location = "hidden", ""
    else:  # pragma: no cover
        client_ip = get_client_ip(request)
        location = get_ip_location_text(client_ip)
        if location:
            logger.info("[%s] client_location=%s", tag, location)

    try:
        logger.info(
            "[%s] /v1/completions request_id=%s client_ip=%s prefix_chars=%d suffix_chars=%d lang=%s thinking=%s privacy=%s",
            tag,
            request_id,
            client_ip,
            len(req.prefix or ""),
            len(req.suffix or ""),
            req.languageId,
            req.model_thinking,
            req.privacy_mode,
        )

        # NOTE(review): the prepared prefix/suffix are only logged; the prompt
        # builder below receives the raw request fields. These two log lines
        # also record request text when privacy_mode is set -- confirm both
        # are intended.
        llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "")
        logger.info("[%s] llm_input_prefix=%r", tag, llm_prefix)
        logger.info("[%s] llm_input_suffix=%r", tag, llm_suffix)

        system_prompt, user_prompt = build_completion_prompts(
            req.prefix,
            req.suffix,
            req.languageId,
            location=location,
            thinking_level=req.model_thinking,
            preferences=req.user_preferences,
        )

        # The literal level "none" means thinking is switched off entirely.
        thinking = None if req.model_thinking == "none" else req.model_thinking
        inference_task = asyncio.create_task(
            call_ollama(
                user_prompt,
                system_prompt=system_prompt,
                tag=f"{tag}-primary",
                temperature=0.7,
                thinking=thinking,
            )
        )

        # A newer request reusing the same id supersedes the older one.
        superseded = ACTIVE_COMPLETIONS.get(request_id)
        if superseded is not None and not superseded.done():
            superseded.cancel()
        ACTIVE_COMPLETIONS[request_id] = inference_task

        result = await inference_task
        completion_text = result["content"] or ""
        if not completion_text.strip():
            logger.warning("[%s] primary returned empty content, returning empty result", tag)
        logger.info(
            "[%s] completion resolved source=primary request_id=%s content_chars=%d content_preview='%s'",
            tag,
            request_id,
            len(completion_text),
            _preview(completion_text, 120),
        )

        return JSONResponse(content={"content": completion_text, "request_id": request_id})
    except asyncio.CancelledError:
        logger.info("[%s] /v1/completions cancelled request_id=%s", tag, request_id)
        return JSONResponse(content={"cancelled": True, "request_id": request_id}, status_code=499)
    except Exception as e:
        logger.exception("[%s] /v1/completions failed request_id=%s: %s", tag, request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
    finally:
        # Only unregister our own task; a superseding request may already
        # have stored a different task under this id.
        if inference_task is not None and ACTIVE_COMPLETIONS.get(request_id) is inference_task:
            ACTIVE_COMPLETIONS.pop(request_id, None)
|
||
|
||
|
||
@app.post("/v1/completions/cancel")
async def cancel_completion(req: CancelCompletionRequest, api_key: str = Security(get_api_key)):
    """Cancel the in-flight completion registered under ``req.request_id``.

    Returns:
        ``{"cancelled": True, "status": "ok"}`` when a running task was
        cancelled; otherwise ``cancelled`` is ``False`` and ``status`` is
        ``"not_found"`` (no task under that id) or ``"already_done"``.
    """
    log_tag = str(uuid.uuid4())[:8]  # correlates this call's log lines only
    target_id = req.request_id or ""

    async with ACTIVE_COMPLETIONS_LOCK:
        pending = ACTIVE_COMPLETIONS.get(target_id)

        if pending is None:
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=not_found reason=%s",
                log_tag,
                target_id,
                req.reason,
            )
            return {"cancelled": False, "status": "not_found"}

        if pending.done():
            logger.info(
                "[%s] /v1/completions/cancel request_id=%s status=already_done reason=%s",
                log_tag,
                target_id,
                req.reason,
            )
            return {"cancelled": False, "status": "already_done"}

        pending.cancel()

    logger.info(
        "[%s] /v1/completions/cancel request_id=%s status=ok reason=%s",
        log_tag,
        target_id,
        req.reason,
    )
    return {"cancelled": True, "status": "ok"}
|
||
|
||
|
||
@app.post("/v1/ocr")
async def ocr_image(request: OCRRequest, api_key: str = Security(get_api_key)):
    """Decode the base64 image in the request and run OCR on it.

    Returns:
        ``{"text", "filename"}`` on success, or a 500 response with
        ``{"error"}`` when decoding or OCR raises.
    """
    log_tag = str(uuid.uuid4())[:8]
    try:
        logger.info(
            "[%s] /v1/ocr filename=%s language=%s image_base64_chars=%d",
            log_tag,
            request.filename,
            request.language,
            len(request.image or ""),
        )

        raw_image = base64.b64decode(request.image)
        logger.info("[%s] /v1/ocr decoded image_bytes=%d", log_tag, len(raw_image))

        recognized = await call_vlm_ocr(raw_image, request.language)
        logger.info(
            "[%s] /v1/ocr success text_chars=%d text_preview='%s'",
            log_tag,
            len(recognized or ""),
            _preview(recognized or "", 120),
        )
        return {"text": recognized, "filename": request.filename}
    except Exception as e:
        logger.exception("[%s] /v1/ocr failed: %s", log_tag, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
|
||
|
||
|
||
@app.post("/v1/convert")
async def convert_to_markdown(request: ConvertRequest, api_key: str = Security(get_api_key)):
    """Convert an uploaded txt/docx/pptx/pdf file to Markdown.

    The base64 payload is decoded and the extension of ``request.filename``
    selects the path: ``.txt`` is decoded as UTF-8 (undecodable bytes are
    dropped) and sanitised directly; other allowed types are written to a
    temporary file and converted by MarkItDown in a worker thread.

    Returns:
        ``{"markdown", "filename"}`` on success, or a 500 response with
        ``{"error"}`` when decoding, validation or conversion raises.
    """
    request_id = str(uuid.uuid4())[:8]

    try:
        logger.info(
            "[%s] /v1/convert filename=%s file_base64_chars=%d",
            request_id,
            request.filename,
            len(request.file or ""),
        )

        # Decode base64
        file_bytes = base64.b64decode(request.file)
        logger.info("[%s] /v1/convert decoded file_bytes=%d", request_id, len(file_bytes))

        # Get file extension
        ext = os.path.splitext(request.filename)[1].lower()

        if ext not in ALLOWED_CONVERT_EXTENSIONS:
            raise ValueError("仅支持 txt、docx、pptx、pdf 格式")

        if ext == ".txt":
            markdown_text = _sanitize_converted_markdown(file_bytes.decode("utf-8", errors="ignore"))
            return {
                "markdown": markdown_text,
                "filename": request.filename,
            }

        # The cleanup scope starts before the temp file exists, so a failed
        # write (e.g. disk full) no longer leaves the file behind.
        tmp_path: Optional[str] = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
                tmp_path = tmp.name
                tmp.write(file_bytes)

            # Convert using MarkItDown, off the event loop.
            md = _get_markitdown()
            result = await asyncio.to_thread(md.convert, tmp_path)
            markdown_text = _sanitize_converted_markdown(result.text_content)

            logger.info(
                "[%s] /v1/convert success text_chars=%d text_preview='%s'",
                request_id,
                len(markdown_text or ""),
                _preview(markdown_text, 120),
            )

            return {
                "markdown": markdown_text,
                "filename": request.filename,
            }
        finally:
            # Clean up the temporary file; one that is already gone is fine.
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except FileNotFoundError:
                    pass

    except Exception as e:
        logger.exception("[%s] /v1/convert failed: %s", request_id, e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
|
||
|
||
# TTS and ASR routes. The ``tts_asr`` import lives inside the helper rather
# than in the top-level import block.
# NOTE: the helper is called immediately below, so the import still happens
# while this module is being imported.
def _register_tts_asr_routes():
    """Import ``tts_asr`` and let it attach its routes to ``app``."""
    from tts_asr import register_tts_asr_routes as attach_routes

    attach_routes(app)


_register_tts_asr_routes()
|
||
|
||
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on every interface,
    # port 8001. uvicorn is imported only on this path.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8001,
    )
|
||
|