Add optional thinking parameter to the call_ollama function and pass it from the request. Also enhance timezone handling in prompt generation to support configurable timezone preferences.
152 lines
4.7 KiB
Python
152 lines
4.7 KiB
Python
from fastapi import FastAPI, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
from pydantic import BaseModel
|
|
import json
|
|
import base64
|
|
import uuid
|
|
import logging
|
|
|
|
from prompt import build_prompt, prepare_prompt_context
|
|
from llm import call_ollama, call_vlm_ocr
|
|
from geoip import get_ip_location_text
|
|
|
|
# Configure root logging once at import time: INFO and above, with each
# record rendered as "<timestamp> <level> <logger name> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
)
# Logger named "api", used by the request handlers in this module.
logger = logging.getLogger("api")
|
|
|
|
# ASGI application object; the route decorators in this module register on it.
app = FastAPI()

# Allow cross-origin requests from any origin, with any method and header.
# NOTE(review): FastAPI's CORS documentation states that allow_origins,
# allow_methods and allow_headers should not be ["*"] when
# allow_credentials=True. Confirm whether credentialed cross-origin requests
# are actually needed; if not, drop allow_credentials, otherwise list the
# allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
from typing import Optional
|
|
|
|
# Optional per-user settings sent with a completion request and handed to
# build_prompt() as `preferences`. Every field defaults to the string 'auto'.
# NOTE(review): 'auto' presumably means "let the prompt builder decide"
# (e.g. from location) -- confirm against prompt.build_prompt.
# Documented with comments rather than a docstring so the generated schema
# stays unchanged.
class UserPreferences(BaseModel):
    # Preferred language, or 'auto'.
    language: str = 'auto'
    # Preferred currency, or 'auto'.
    currency: str = 'auto'
    # Preferred timezone, or 'auto'.
    timezone: str = 'auto'
|
|
|
|
# JSON body accepted by POST /v1/completions.
# Documented with comments rather than a docstring so the generated schema
# stays unchanged.
class CompletionRequest(BaseModel):
    # Text before the completion point (required).
    prefix: str
    # Text after the completion point (required).
    suffix: str
    # Language identifier forwarded to build_prompt(); defaults to 'markdown'.
    languageId: str = 'markdown'
    # Thinking level: forwarded to build_prompt() as `thinking_level` and to
    # call_ollama() as `thinking`; the handler maps the value "none" to None.
    # NOTE(review): pydantic v2 reserves the "model_" prefix and may emit a
    # protected-namespace warning for this field name -- confirm.
    model_thinking: str = 'low'
    # When True the handler does not read the client IP and skips the
    # location lookup (the IP is logged as "hidden").
    privacy_mode: bool = False
    # Optional user preferences forwarded to build_prompt().
    user_preferences: Optional[UserPreferences] = None
|
|
|
|
# JSON body accepted by POST /v1/ocr.
# Documented with comments rather than a docstring so the generated schema
# stays unchanged.
class OCRRequest(BaseModel):
    # Image content as a base64 string; the handler decodes it with
    # base64.b64decode() before passing the bytes to call_vlm_ocr().
    image: str
    # Name echoed back in the response and written to the log.
    filename: str = "image.jpg"
    # Language hint passed to call_vlm_ocr(); defaults to 'auto'.
    language: str = 'auto'
|
|
|
|
|
|
def _preview(text: str, limit: int = 80) -> str:
    """Return a single-line, length-capped rendering of *text* for log output.

    Newline characters are shown as the two-character sequence ``\\n`` so the
    preview never spans several log lines. A falsy *text* (``None`` or ``""``)
    yields ``""``. When the escaped text is longer than *limit* characters,
    only the first *limit* characters are kept and ``"..."`` is appended.
    """
    # Escape first, then measure: the limit applies to the escaped form.
    escaped = "\\n".join((text or "").split("\n"))
    return escaped if len(escaped) <= limit else escaped[:limit] + "..."
|
|
|
|
def get_client_ip(request: Request) -> str:
    """Return the caller's IP address as a string.

    Resolution order:

    1. the ``X-Client-IP`` request header, when present and non-empty;
    2. ``request.client.host``, when the request carries client info;
    3. the literal ``"unknown"``.

    The previous one-line form, ``A or B if C else D``, parsed as
    ``(A or B) if C else D``: whenever ``request.client`` was ``None`` the
    header was ignored and ``"unknown"`` was returned. The header is now
    checked independently of ``request.client``.
    """
    # NOTE(review): X-Client-IP is supplied by the caller and can be forged
    # unless a trusted proxy sets or overwrites it -- confirm the deployment.
    header_ip = request.headers.get("X-Client-IP")
    if header_ip:
        return header_ip
    return request.client.host if request.client else "unknown"
|
|
|
|
# POST /v1/completions
#
# Builds a prompt from the request's prefix/suffix, sends it to call_ollama()
# and returns the resulting text as a two-event server-sent-event stream:
# one event carrying {"content": ...}, then one carrying {"done": true}.
# Any exception raised inside the try block is logged with its traceback and
# reported as a JSON {"error": ...} body with HTTP status 500.
# (Documented with comments rather than a docstring so the generated
# OpenAPI description stays unchanged.)
@app.post("/v1/completions")
async def create_completion(request: Request, req: CompletionRequest):
    # Short id that ties together every log line of this request.
    request_id = str(uuid.uuid4())[:8]

    if req.privacy_mode:
        # Privacy mode: the client address is never read, so there is no
        # location lookup either.
        client_ip, location = "hidden", ""
    else:
        client_ip = get_client_ip(request)
        # Look up the location the IP address is registered to.
        location = get_ip_location_text(client_ip)
        if location:
            logger.info("[%s] client_location=%s", request_id, location)

    try:
        prefix_chars = len(req.prefix or "")
        suffix_chars = len(req.suffix or "")
        logger.info(
            "[%s] /v1/completions client_ip=%s prefix_chars=%d suffix_chars=%d lang=%s thinking=%s privacy=%s",
            request_id,
            client_ip,
            prefix_chars,
            suffix_chars,
            req.languageId,
            req.model_thinking,
            req.privacy_mode,
        )

        # The prepared context is only logged here; build_prompt() below is
        # given the raw prefix/suffix from the request.
        # NOTE(review): these two lines log the full prepared text even when
        # privacy_mode is set -- confirm that is intended.
        llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "")
        logger.info("[%s] llm_input_prefix=%r", request_id, llm_prefix)
        logger.info("[%s] llm_input_suffix=%r", request_id, llm_suffix)

        prompt = build_prompt(
            req.prefix,
            req.suffix,
            req.languageId,
            location=location,
            thinking_level=req.model_thinking,
            preferences=req.user_preferences,
        )

        # The string "none" means "no thinking argument": pass None instead.
        thinking = None if req.model_thinking == "none" else req.model_thinking
        reply = await call_ollama(
            prompt,
            tag=f"{request_id}-primary",
            temperature=0.7,
            thinking=thinking,
        )

        completion_text = reply["content"] or ""
        if not completion_text.strip():
            # An empty completion is still returned to the client; it is only
            # flagged in the log.
            logger.warning("[%s] primary returned empty content, returning empty result", request_id)
        logger.info(
            "[%s] completion resolved source=primary content_chars=%d content_preview='%s'",
            request_id,
            len(completion_text),
            _preview(completion_text, 120),
        )

        async def generate():
            # Two SSE events: the completion text, then the end marker.
            for payload in ({'content': completion_text}, {'done': True}):
                yield f"data: {json.dumps(payload)}\n\n"

        return StreamingResponse(generate(), media_type="text/event-stream")

    except Exception as exc:
        logger.exception("[%s] /v1/completions failed: %s", request_id, exc)
        return JSONResponse(content={"error": str(exc)}, status_code=500)
|
|
|
|
# POST /v1/ocr
#
# Decodes the base64 image from the request, passes the bytes and the
# language hint to call_vlm_ocr(), and returns {"text": ..., "filename": ...}.
# Any exception (including a base64 decoding error) is logged with its
# traceback and reported as a JSON {"error": ...} body with HTTP status 500.
# (Documented with comments rather than a docstring so the generated
# OpenAPI description stays unchanged.)
@app.post("/v1/ocr")
async def ocr_image(request: OCRRequest):
    # Short id that ties together every log line of this request.
    request_id = str(uuid.uuid4())[:8]
    try:
        encoded_chars = len(request.image or "")
        logger.info(
            "[%s] /v1/ocr filename=%s language=%s image_base64_chars=%d",
            request_id,
            request.filename,
            request.language,
            encoded_chars,
        )

        image_bytes = base64.b64decode(request.image)
        logger.info("[%s] /v1/ocr decoded image_bytes=%d", request_id, len(image_bytes))

        recognized = await call_vlm_ocr(image_bytes, request.language)

        # Log with an empty-string fallback, but return the value exactly as
        # call_vlm_ocr() produced it.
        loggable = recognized or ""
        logger.info(
            "[%s] /v1/ocr success text_chars=%d text_preview='%s'",
            request_id,
            len(loggable),
            _preview(loggable, 120),
        )
        return {"text": recognized, "filename": request.filename}
    except Exception as exc:
        logger.exception("[%s] /v1/ocr failed: %s", request_id, exc)
        return JSONResponse(content={"error": str(exc)}, status_code=500)
|
|
|
|
# Script entry point: serve the app with uvicorn on all interfaces, port 8001.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the module is run directly.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)
|