548 lines
18 KiB
Python
548 lines
18 KiB
Python
import os
|
|
import time
|
|
import logging
|
|
import asyncio
|
|
import json
|
|
import base64
|
|
from datetime import datetime
|
|
from typing import AsyncIterator, Literal
|
|
|
|
import httpx
|
|
from dotenv import load_dotenv
|
|
|
|
from prompts import get_vlm_ocr_prompt
|
|
|
|
load_dotenv()

# OpenAI-compatible endpoint config
LLM_BASE_URL = os.getenv('LLM_BASE_URL', 'http://localhost:11434/v1/')
LLM_API_KEY = os.getenv('LLM_API_KEY', 'ollama')

# Model names (backward compat: fall back to OLLAMA_MODEL if LLM_MODEL not set).
# Blank / whitespace-only values fall back to a default instead of producing
# an empty model name.
_raw_model = os.getenv('LLM_MODEL') or os.getenv('OLLAMA_MODEL', 'gpt-oss:20b')
LLM_MODEL = (_raw_model or '').strip() or 'gpt-oss:20b'
PRO_LLM_MODEL = (os.getenv('PRO_LLM_MODEL') or '').strip() or LLM_MODEL

# VLM for OCR (vision models)
VLM_MODEL = (os.getenv('VLM_MODEL') or '').strip() or 'qwen3-vl:30b'

# Fallback for legacy OLLAMA_HOST env var (auto-convert to /v1/ path).
# Only consulted when LLM_BASE_URL is not set explicitly.
_legacy_host = os.getenv('OLLAMA_HOST')
if _legacy_host and _legacy_host.strip() and not os.getenv('LLM_BASE_URL'):
    base = _legacy_host.strip().rstrip('/')
    # OLLAMA_HOST is commonly written as a bare "host:port"; httpx needs a scheme.
    if '://' not in base:
        base = f'http://{base}'
    # A host that already contains '/v1' is used as-is rather than discarded.
    LLM_BASE_URL = base if '/v1' in base else f'{base}/v1/'

# Normalize trailing slash for base URL
LLM_BASE_URL = LLM_BASE_URL.rstrip('/') + '/'

# Timeouts in seconds (10 minutes for large model loading)
COMPLETION_TIMEOUT = int(os.getenv("LLM_COMPLETION_TIMEOUT", "600"))
OCR_TIMEOUT = int(os.getenv("LLM_OCR_TIMEOUT", "600"))

logger = logging.getLogger('llm')
|
|
|
|
|
|
def _extract_message(response: dict) -> tuple[str, str]:
    """Extract ``(content, thinking)`` from an OpenAI-compatible chat completion.

    Reads ``response['choices'][0]['message']``. Malformed or partial
    responses (not a dict, missing/empty ``choices``, a non-dict first choice,
    or a ``null`` message) yield ``('', '')`` instead of raising.

    Thinking text is taken from the first non-empty of ``reasoning_content``,
    ``reasoning`` or ``thinking`` (different servers use different keys) and
    is whitespace-stripped; non-string values are treated as absent.
    """
    choices = response.get('choices') if isinstance(response, dict) else None
    first = choices[0] if isinstance(choices, list) and choices else None
    msg = first.get('message') if isinstance(first, dict) else None
    if not isinstance(msg, dict):
        msg = {}

    content = msg.get('content') or ''
    thinking = msg.get('reasoning_content') or msg.get('reasoning') or msg.get('thinking') or ''
    if not isinstance(thinking, str):
        thinking = ''
    return content, thinking.strip()
|
|
|
|
|
|
def _resolve_system_prompt(system_prompt: str | None) -> str:
|
|
if system_prompt and system_prompt.strip():
|
|
return system_prompt.strip()
|
|
return ''
|
|
|
|
|
|
def _resolve_model_name(model: str | None = None, *, use_pro_model: bool = False) -> str:
|
|
candidate = (model or '').strip()
|
|
if candidate:
|
|
return candidate
|
|
return PRO_LLM_MODEL if use_pro_model else LLM_MODEL
|
|
|
|
|
|
def _build_chat_payload(
    prompt: str,
    *,
    system_prompt: str | None = None,
    temperature: float = 0.7,
    thinking: str | None = None,
    model: str | None = None,
    use_pro_model: bool = False,
) -> dict:
    """Build a non-streaming ``/chat/completions`` request body.

    Args:
        prompt: user message content.
        system_prompt: optional system message; omitted when blank.
        temperature: sampling temperature, always sent as the standard
            top-level ``temperature`` field.
        thinking: optional thinking setting; when truthy it is additionally
            sent in an Ollama-style ``options`` block as ``think``.
        model: explicit model name; takes precedence over ``use_pro_model``.
        use_pro_model: use PRO_LLM_MODEL instead of LLM_MODEL when ``model``
            is empty.

    Returns:
        A JSON-serialisable payload dict with ``stream`` set to False.
    """
    messages = []
    sys_prompt = _resolve_system_prompt(system_prompt)
    if sys_prompt:
        messages.append({'role': 'system', 'content': sys_prompt})
    messages.append({'role': 'user', 'content': prompt})

    payload = {
        'model': _resolve_model_name(model, use_pro_model=use_pro_model),
        'messages': messages,
        'stream': False,
        # Top-level field of the OpenAI chat completions schema; sending it only
        # inside ``options`` meant it was dropped whenever ``thinking`` was unset.
        'temperature': temperature,
    }

    if thinking:
        # Kept for backward compatibility with servers that read Ollama-style
        # options. NOTE(review): OpenAI-compatible endpoints may ignore this
        # key — confirm the target server honours ``options.think``.
        payload['options'] = {'temperature': temperature, 'think': thinking}

    return payload
|
|
|
|
|
|
def _build_chat_stream_payload(
    prompt: str,
    *,
    system_prompt: str | None = None,
    temperature: float = 0.7,
    thinking: str | None = None,
    model: str | None = None,
    use_pro_model: bool = False,
) -> dict:
    """Build a streaming (SSE) ``/chat/completions`` request body.

    Args:
        prompt: user message content.
        system_prompt: optional system message; omitted when blank.
        temperature: sampling temperature, always sent as the standard
            top-level ``temperature`` field.
        thinking: optional thinking setting; when truthy it is additionally
            sent in an Ollama-style ``options`` block as ``think``.
        model: explicit model name; takes precedence over ``use_pro_model``.
        use_pro_model: use PRO_LLM_MODEL instead of LLM_MODEL when ``model``
            is empty.

    Returns:
        A JSON-serialisable payload dict with ``stream`` set to True.
    """
    messages = []
    sys_prompt = _resolve_system_prompt(system_prompt)
    if sys_prompt:
        messages.append({'role': 'system', 'content': sys_prompt})
    messages.append({'role': 'user', 'content': prompt})

    payload = {
        'model': _resolve_model_name(model, use_pro_model=use_pro_model),
        'messages': messages,
        'stream': True,
        # Top-level field of the OpenAI chat completions schema; sending it only
        # inside ``options`` meant it was dropped whenever ``thinking`` was unset.
        'temperature': temperature,
    }

    if thinking:
        # Kept for backward compatibility with servers that read Ollama-style
        # options. NOTE(review): OpenAI-compatible endpoints may ignore this
        # key — confirm the target server honours ``options.think``.
        payload['options'] = {'temperature': temperature, 'think': thinking}

    return payload
|
|
|
|
|
|
def _extract_delta_text(chunk: dict) -> str:
    """Extract the content text delta from an OpenAI-compatible SSE chunk.

    Reads ``chunk['choices'][0]['delta']['content']``. Returns ``''`` when the
    chunk has no choices, the first choice is not a dict, or its ``delta`` is
    missing or ``null`` (for example a final chunk that carries only
    ``finish_reason``), so one odd chunk cannot abort a stream.
    """
    choices = chunk.get('choices') if isinstance(chunk, dict) else None
    first = choices[0] if isinstance(choices, list) and choices else None
    delta = first.get('delta') if isinstance(first, dict) else None
    if not isinstance(delta, dict):
        return ''
    return delta.get('content') or ''
|
|
|
|
|
|
def _extract_delta_thinking(chunk: dict) -> str:
    """Extract the stripped thinking/reasoning delta from an SSE chunk.

    Looks in ``chunk['choices'][0]['delta']`` for the first non-empty of
    ``reasoning_content``, ``reasoning`` or ``thinking`` (servers differ in
    which key they use). Returns ``''`` for malformed chunks (no choices,
    non-dict choice, ``null`` delta) or non-string values instead of raising.
    """
    choices = chunk.get('choices') if isinstance(chunk, dict) else None
    first = choices[0] if isinstance(choices, list) and choices else None
    delta = first.get('delta') if isinstance(first, dict) else None
    if not isinstance(delta, dict):
        return ''
    text = delta.get('reasoning_content') or delta.get('reasoning') or delta.get('thinking') or ''
    return text.strip() if isinstance(text, str) else ''
|
|
|
|
|
|
async def call_ollama(
    prompt: str,
    *,
    system_prompt: str | None = None,
    tag: str = 'default',
    temperature: float = 0.7,
    thinking: str | None = None,
    model: str | None = None,
    use_pro_model: bool = False,
) -> dict:
    """Run one non-streaming chat completion against the configured endpoint.

    POSTs to ``/chat/completions`` under LLM_BASE_URL; the request and
    response read together are bounded by COMPLETION_TIMEOUT seconds.

    Returns:
        ``{'content': <assistant text>, 'think': <reasoning text or ''>}``.

    Raises:
        Whatever the HTTP call, status check or JSON decoding raise, after
        logging it; cancellation is logged as a warning and re-raised.
    """
    started = time.perf_counter()
    started_at = datetime.now()

    def _finish_timing() -> float:
        # Log the [start --> end] wall-clock window and return elapsed milliseconds.
        elapsed = (time.perf_counter() - started) * 1000
        finished_at = datetime.now()
        logger.info(
            '[LLM][%s] call_time [%s --> %s]', tag,
            started_at.strftime('%H:%M:%S'), finished_at.strftime('%H:%M:%S'),
        )
        return elapsed

    resolved_model = _resolve_model_name(model, use_pro_model=use_pro_model)
    # Default pro-model requests are logged as 'pro' rather than the concrete name.
    display_model = resolved_model if (model is not None or not use_pro_model) else 'pro'

    logger.info(
        '[LLM][%s] request model=%s base_url=%s prompt_chars=%d system_chars=%d temp=%.2f thinking=%s',
        tag, display_model, LLM_BASE_URL, len(prompt),
        len(system_prompt or ''), temperature, thinking,
    )

    request_body = _build_chat_payload(
        prompt=prompt,
        system_prompt=system_prompt,
        temperature=temperature,
        thinking=thinking,
        model=model,
        use_pro_model=use_pro_model,
    )
    # No per-read limit: the overall wait_for below bounds the call instead.
    timeouts = httpx.Timeout(connect=10.0, read=None, write=30.0, pool=30.0)

    try:
        async with httpx.AsyncClient(base_url=LLM_BASE_URL, timeout=timeouts) as client:
            post = client.post('/chat/completions', json=request_body)
            http_response = await asyncio.wait_for(post, timeout=COMPLETION_TIMEOUT)
            http_response.raise_for_status()
            parsed = http_response.json()
    except asyncio.CancelledError:
        elapsed_ms = _finish_timing()
        logger.warning('[LLM][%s] request cancelled after %.1fms', tag, elapsed_ms)
        raise
    except Exception:
        elapsed_ms = _finish_timing()
        logger.exception('[LLM][%s] request failed after %.1fms', tag, elapsed_ms)
        raise

    content, reasoning = _extract_message(parsed)
    elapsed_ms = _finish_timing()

    logger.info(
        '[LLM][%s] response in %.1fms content_chars=%d thinking_chars=%d',
        tag, elapsed_ms, len(content), len(reasoning or ''),
    )
    if not content.strip():
        logger.warning('[LLM][%s] empty content returned by model', tag)

    return {'content': content, 'think': reasoning or ''}
|
|
|
|
|
|
async def stream_ollama(
    prompt: str,
    *,
    system_prompt: str | None = None,
    tag: str = 'default-stream',
    temperature: float = 0.7,
    thinking: str | None = None,
    model: str | None = None,
    use_pro_model: bool = False,
) -> AsyncIterator[str]:
    """Stream assistant text deltas from an OpenAI-compatible chat completions endpoint.

    Yields each non-empty ``choices[0].delta.content`` string as it arrives;
    reasoning deltas are not yielded. A single deadline of COMPLETION_TIMEOUT
    seconds covers the whole call: sending the request, waiting for response
    headers, and reading the body.

    Raises:
        httpx errors (e.g. ``HTTPStatusError`` for a non-2xx status).
        TimeoutError / asyncio.TimeoutError: the deadline elapsed.
        RuntimeError: a stream chunk carried an ``error`` field.
    """
    start = time.perf_counter()
    start_dt = datetime.now()

    model_name = _resolve_model_name(model, use_pro_model=use_pro_model)
    log_model_name = 'pro' if (model is None and use_pro_model) else model_name
    yielded_chars = 0

    logger.info(
        '[LLM][%s] stream request model=%s base_url=%s prompt_chars=%d system_chars=%d temp=%.2f thinking=%s',
        tag, log_model_name, LLM_BASE_URL, len(prompt),
        len(system_prompt or ''), temperature, thinking,
    )

    payload = _build_chat_stream_payload(
        prompt=prompt, system_prompt=system_prompt, temperature=temperature,
        thinking=thinking, model=model, use_pro_model=use_pro_model,
    )

    # No per-read limit: the explicit deadline below bounds the call instead.
    http_timeout = httpx.Timeout(connect=10.0, read=None, write=30.0, pool=30.0)

    try:
        deadline = time.perf_counter() + COMPLETION_TIMEOUT
        async with httpx.AsyncClient(base_url=LLM_BASE_URL, timeout=http_timeout) as client:
            request = client.build_request('POST', '/chat/completions', json=payload)
            # Bound the wait for response headers too (e.g. while a model loads);
            # with read=None this wait was otherwise unlimited.
            response = await asyncio.wait_for(
                client.send(request, stream=True), timeout=COMPLETION_TIMEOUT,
            )
            try:
                response.raise_for_status()
                line_iterator = response.aiter_lines().__aiter__()

                while True:
                    remaining = deadline - time.perf_counter()
                    if remaining <= 0:
                        raise TimeoutError('LLM stream timed out')

                    try:
                        line = await asyncio.wait_for(line_iterator.__anext__(), timeout=remaining)
                    except StopAsyncIteration:
                        break

                    # Skip blank separators and SSE comment lines.
                    if not line or line.startswith(':'):
                        continue

                    # SSE data lines: "data: {json}", "data:{json}" or "data: [DONE]"
                    if line.startswith('data:'):
                        data_str = line[5:].strip()
                    else:
                        data_str = line.strip()

                    if not data_str or data_str == '[DONE]':
                        continue

                    try:
                        chunk = json.loads(data_str)
                    except json.JSONDecodeError:
                        logger.warning('[LLM][%s] ignored invalid stream line', tag)
                        continue

                    if not isinstance(chunk, dict):
                        continue

                    # Surface server-side errors instead of silently ending the stream.
                    error = chunk.get('error')
                    if error:
                        raise RuntimeError(str(error))

                    text = _extract_delta_text(chunk)
                    if not text:
                        continue

                    yielded_chars += len(text)
                    yield text
            finally:
                # Runs on completion, error, cancellation, and when the consumer
                # closes this generator early.
                await response.aclose()

    except asyncio.CancelledError:
        elapsed_ms = (time.perf_counter() - start) * 1000
        end_dt = datetime.now()
        logger.info(
            '[LLM][%s] stream_time [%s --> %s]', tag,
            start_dt.strftime('%H:%M:%S'), end_dt.strftime('%H:%M:%S'),
        )
        logger.warning('[LLM][%s] stream cancelled after %.1fms', tag, elapsed_ms)
        raise

    except Exception:
        elapsed_ms = (time.perf_counter() - start) * 1000
        end_dt = datetime.now()
        logger.info(
            '[LLM][%s] stream_time [%s --> %s]', tag,
            start_dt.strftime('%H:%M:%S'), end_dt.strftime('%H:%M:%S'),
        )
        logger.exception('[LLM][%s] stream failed after %.1fms', tag, elapsed_ms)
        raise

    elapsed_ms = (time.perf_counter() - start) * 1000
    end_dt = datetime.now()
    logger.info(
        '[LLM][%s] stream_time [%s --> %s]', tag,
        start_dt.strftime('%H:%M:%S'), end_dt.strftime('%H:%M:%S'),
    )
    logger.info(
        '[LLM][%s] stream finished in %.1fms yielded_chars=%d',
        tag, elapsed_ms, yielded_chars,
    )
|
|
|
|
|
|
async def stream_ollama_events(
    prompt: str,
    *,
    system_prompt: str | None = None,
    tag: str = 'default-events',
    temperature: float = 0.7,
    thinking: str | None = None,
    model: str | None = None,
    use_pro_model: bool = False,
    enable_thinking: bool = True,
    timeout: float | None = None,
) -> AsyncIterator[tuple[Literal['thinking', 'content'], str]]:
    """Stream ``(event_type, payload)`` tuples from OpenAI-compatible chat completions.

    Events:
        ``('thinking', '')``: emitted at most once, when the first non-blank
            reasoning delta arrives (the reasoning text itself is not forwarded).
        ``('content', text)``: each non-empty content delta.

    Args:
        enable_thinking: when False, ``thinking`` is not passed to the payload.
        timeout: total seconds for the call (sending, waiting for response
            headers and reading the body); defaults to COMPLETION_TIMEOUT.

    Raises:
        httpx errors (e.g. ``HTTPStatusError`` for a non-2xx status).
        TimeoutError / asyncio.TimeoutError: the deadline elapsed.
        RuntimeError: a stream chunk carried an ``error`` field.
    """
    start = time.perf_counter()
    start_dt = datetime.now()

    model_name = _resolve_model_name(model, use_pro_model=use_pro_model)
    log_model_name = 'pro' if (model is None and use_pro_model) else model_name
    yielded_chars = 0

    logger.info(
        '[LLM][%s] event_stream request model=%s base_url=%s prompt_chars=%d system_chars=%d temp=%.2f thinking=%s',
        tag, log_model_name, LLM_BASE_URL, len(prompt),
        len(system_prompt or ''), temperature, thinking,
    )

    payload = _build_chat_stream_payload(
        prompt=prompt, system_prompt=system_prompt, temperature=temperature,
        thinking=thinking if enable_thinking else None, model=model, use_pro_model=use_pro_model,
    )

    effective_timeout = timeout if timeout is not None else COMPLETION_TIMEOUT
    # No per-read limit: the explicit deadline below bounds the call instead.
    http_timeout = httpx.Timeout(connect=10.0, read=None, write=30.0, pool=30.0)
    sent_thinking = False

    try:
        deadline = time.perf_counter() + effective_timeout
        async with httpx.AsyncClient(base_url=LLM_BASE_URL, timeout=http_timeout) as client:
            request = client.build_request('POST', '/chat/completions', json=payload)
            # Bound the wait for response headers too (e.g. while a model loads);
            # with read=None this wait was otherwise unlimited.
            response = await asyncio.wait_for(
                client.send(request, stream=True), timeout=effective_timeout,
            )
            try:
                response.raise_for_status()
                line_iterator = response.aiter_lines().__aiter__()

                while True:
                    remaining = deadline - time.perf_counter()
                    if remaining <= 0:
                        raise TimeoutError('LLM event stream timed out')

                    try:
                        line = await asyncio.wait_for(line_iterator.__anext__(), timeout=remaining)
                    except StopAsyncIteration:
                        break

                    # Skip blank separators and SSE comment lines.
                    if not line or line.startswith(':'):
                        continue

                    # SSE data lines: "data: {json}", "data:{json}" or "data: [DONE]"
                    if line.startswith('data:'):
                        data_str = line[5:].strip()
                    else:
                        data_str = line.strip()

                    if not data_str or data_str == '[DONE]':
                        continue

                    try:
                        chunk = json.loads(data_str)
                    except json.JSONDecodeError:
                        logger.warning('[LLM][%s] ignored invalid Ollama stream line', tag)
                        continue

                    if not isinstance(chunk, dict):
                        continue

                    error = chunk.get('error')
                    if error:
                        raise RuntimeError(str(error))

                    # Signal (once) that the model has started reasoning.
                    thinking_delta = _extract_delta_thinking(chunk)
                    if thinking_delta and not sent_thinking:
                        sent_thinking = True
                        yield 'thinking', ''

                    text = _extract_delta_text(chunk)
                    if not text:
                        continue

                    yielded_chars += len(text)
                    yield 'content', text
            finally:
                # Runs on completion, error, cancellation, and when the consumer
                # closes this generator early.
                await response.aclose()

    except asyncio.CancelledError:
        elapsed_ms = (time.perf_counter() - start) * 1000
        end_dt = datetime.now()
        logger.info(
            '[LLM][%s] event_stream_time [%s --> %s]', tag,
            start_dt.strftime('%H:%M:%S'), end_dt.strftime('%H:%M:%S'),
        )
        logger.warning('[LLM][%s] event stream cancelled after %.1fms', tag, elapsed_ms)
        raise

    except Exception:
        elapsed_ms = (time.perf_counter() - start) * 1000
        end_dt = datetime.now()
        logger.info(
            '[LLM][%s] event_stream_time [%s --> %s]', tag,
            start_dt.strftime('%H:%M:%S'), end_dt.strftime('%H:%M:%S'),
        )
        logger.exception('[LLM][%s] event stream failed after %.1fms', tag, elapsed_ms)
        raise

    elapsed_ms = (time.perf_counter() - start) * 1000
    end_dt = datetime.now()
    logger.info(
        '[LLM][%s] event_stream_time [%s --> %s]', tag,
        start_dt.strftime('%H:%M:%S'), end_dt.strftime('%H:%M:%S'),
    )
    logger.info(
        '[LLM][%s] event stream finished in %.1fms yielded_chars=%d thinking_seen=%s',
        tag, elapsed_ms, yielded_chars, sent_thinking,
    )
|
|
|
|
|
|
def _guess_image_mime_type(image_bytes: bytes) -> str:
    """Guess an image MIME type from its leading magic bytes; default 'image/png'."""
    head = bytes(image_bytes[:12])
    if head.startswith(b'\x89PNG\r\n\x1a\n'):
        return 'image/png'
    if head.startswith(b'\xff\xd8\xff'):
        return 'image/jpeg'
    if head.startswith((b'GIF87a', b'GIF89a')):
        return 'image/gif'
    if head[:4] == b'RIFF' and head[8:12] == b'WEBP':
        return 'image/webp'
    # Unknown signature: keep the previous hard-coded default.
    return 'image/png'


async def call_vlm_ocr(image_bytes: bytes, language: str = 'auto') -> str:
    """OCR an image via a VLM using the OpenAI-compatible vision API.

    Sends the OCR prompt plus the image as a base64 ``data:`` URL
    (``image_url`` content part) to ``/chat/completions`` using VLM_MODEL.
    The MIME type in the data URL is detected from the image's magic bytes
    (PNG, JPEG, GIF, WebP; PNG assumed otherwise). The request and response
    read are bounded by OCR_TIMEOUT seconds.

    Args:
        image_bytes: raw encoded image file bytes.
        language: only logged. NOTE(review): get_vlm_ocr_prompt() is called
            without it — confirm whether the prompt should be language-specific.

    Returns:
        The model's text content ('' if none was returned).

    Raises:
        Whatever the HTTP call, status check or JSON decoding raise, after
        logging it; cancellation is logged as a warning and re-raised.
    """
    start = time.perf_counter()
    start_dt = datetime.now()

    logger.info(
        '[VLM][ocr] request model=%s base_url=%s image_bytes=%d language=%s',
        VLM_MODEL, LLM_BASE_URL, len(image_bytes), language,
    )

    image_b64 = base64.b64encode(image_bytes).decode('ascii')
    mime_type = _guess_image_mime_type(image_bytes)

    payload = {
        'model': VLM_MODEL,
        'messages': [{
            'role': 'user',
            'content': [
                {'type': 'text', 'text': get_vlm_ocr_prompt()},
                {
                    'type': 'image_url',
                    'image_url': {'url': f'data:{mime_type};base64,{image_b64}'},
                },
            ],
        }],
        'stream': False,
    }

    # No per-read limit: the overall wait_for below bounds the call instead.
    http_timeout = httpx.Timeout(connect=10.0, read=None, write=30.0, pool=30.0)

    try:
        async with httpx.AsyncClient(base_url=LLM_BASE_URL, timeout=http_timeout) as client:
            resp = await asyncio.wait_for(
                client.post('/chat/completions', json=payload), timeout=OCR_TIMEOUT,
            )
            resp.raise_for_status()
            response = resp.json()

    except asyncio.CancelledError:
        # CancelledError is not an Exception subclass, so it needs its own branch
        # to get the same timing log as the LLM calls.
        elapsed_ms = (time.perf_counter() - start) * 1000
        end_dt = datetime.now()
        logger.info(
            '[VLM][ocr] call_time [%s --> %s]', start_dt.strftime('%H:%M:%S'),
            end_dt.strftime('%H:%M:%S'),
        )
        logger.warning('[VLM][ocr] request cancelled after %.1fms', elapsed_ms)
        raise

    except Exception:
        elapsed_ms = (time.perf_counter() - start) * 1000
        end_dt = datetime.now()
        logger.info(
            '[VLM][ocr] call_time [%s --> %s]', start_dt.strftime('%H:%M:%S'),
            end_dt.strftime('%H:%M:%S'),
        )
        logger.exception('[VLM][ocr] request failed after %.1fms', elapsed_ms)
        raise

    content, _ = _extract_message(response)

    elapsed_ms = (time.perf_counter() - start) * 1000
    end_dt = datetime.now()
    logger.info(
        '[VLM][ocr] call_time [%s --> %s]', start_dt.strftime('%H:%M:%S'),
        end_dt.strftime('%H:%M:%S'),
    )
    logger.info(
        '[VLM][ocr] response in %.1fms content_chars=%d', elapsed_ms, len(content),
    )

    if not content.strip():
        logger.warning('[VLM][ocr] empty content returned by model')

    return content
|