Files
llm-in-text/backend/llm.py

548 lines
18 KiB
Python

import os
import time
import logging
import asyncio
import json
import base64
from datetime import datetime
from typing import AsyncIterator, Literal
import httpx
from dotenv import load_dotenv
from prompts import get_vlm_ocr_prompt
# Module configuration: every setting below is read from the environment once at import time
# (after loading a .env file).
load_dotenv()

# OpenAI-compatible endpoint config. os.getenv returns '' (not the default) for a
# variable that is set but empty, so use `or` to treat empty as unset.
LLM_BASE_URL = os.getenv('LLM_BASE_URL') or 'http://localhost:11434/v1/'
# API key for the OpenAI-compatible endpoint.
LLM_API_KEY = os.getenv('LLM_API_KEY', 'ollama')


def _env_nonblank(*names: str, default: str) -> str:
    """Return the first non-blank (stripped) value among env vars *names*, else *default*."""
    for name in names:
        value = (os.getenv(name) or '').strip()
        if value:
            return value
    return default


# Model names (backward compat: fall back to OLLAMA_MODEL if LLM_MODEL not set).
# Blank/whitespace-only values count as unset so no model name is ever ''.
LLM_MODEL = _env_nonblank('LLM_MODEL', 'OLLAMA_MODEL', default='gpt-oss:20b')
PRO_LLM_MODEL = _env_nonblank('PRO_LLM_MODEL', default=LLM_MODEL)
# VLM for OCR (vision models)
VLM_MODEL = _env_nonblank('VLM_MODEL', default='qwen3-vl:30b')

# Fallback for legacy OLLAMA_HOST env var, used only when LLM_BASE_URL is unset.
# A bare "host:port" gets an http:// scheme; /v1 is appended unless already present.
# (Previously a value already containing '/v1' was silently ignored.)
_legacy_host = (os.getenv('OLLAMA_HOST') or '').strip()
if _legacy_host and not os.getenv('LLM_BASE_URL'):
    _legacy_base = _legacy_host.rstrip('/')
    if '://' not in _legacy_base:
        _legacy_base = f'http://{_legacy_base}'
    LLM_BASE_URL = _legacy_base if '/v1' in _legacy_base else f'{_legacy_base}/v1'

# Normalize to exactly one trailing slash; this value is used as the httpx base_url.
LLM_BASE_URL = LLM_BASE_URL.rstrip('/') + '/'

# Timeouts in seconds (10 minutes for large model loading); blank values use the default.
COMPLETION_TIMEOUT = int(os.getenv("LLM_COMPLETION_TIMEOUT") or "600")
OCR_TIMEOUT = int(os.getenv("LLM_OCR_TIMEOUT") or "600")

logger = logging.getLogger('llm')
def _extract_message(response: dict) -> tuple[str, str]:
    """Extract ``(content, thinking)`` from an OpenAI-compatible response dict.

    Reads ``response['choices'][0]['message']``. Missing or malformed pieces
    (non-dict response, absent/empty choices, a non-dict choice, a ``null``
    message) yield ``('', '')`` instead of raising.

    Thinking text is the first non-empty of the message's
    ``reasoning_content``, ``reasoning`` or ``thinking`` fields, stripped.
    Content is returned as-is (``None`` becomes ``''``).
    """
    choices = response.get('choices') if isinstance(response, dict) else None
    first = choices[0] if isinstance(choices, list) and choices else None
    msg = first.get('message') if isinstance(first, dict) else None
    if not isinstance(msg, dict):
        return '', ''
    content = msg.get('content') or ''
    thinking = (
        msg.get('reasoning_content')
        or msg.get('reasoning')
        or msg.get('thinking')
        or ''
    )
    return content, thinking.strip()
def _resolve_system_prompt(system_prompt: str | None) -> str:
if system_prompt and system_prompt.strip():
return system_prompt.strip()
return ''
def _resolve_model_name(model: str | None = None, *, use_pro_model: bool = False) -> str:
candidate = (model or '').strip()
if candidate:
return candidate
return PRO_LLM_MODEL if use_pro_model else LLM_MODEL
def _build_chat_payload(
    prompt: str,
    *,
    system_prompt: str | None = None,
    temperature: float = 0.7,
    thinking: str | None = None,
    model: str | None = None,
    use_pro_model: bool = False,
) -> dict:
    """Build a non-streaming ``/chat/completions`` request body.

    Args:
        prompt: User message content.
        system_prompt: Optional system message; omitted when None/blank.
        temperature: Sampling temperature, sent as the top-level
            ``temperature`` field of the chat-completions schema.
        thinking: Optional thinking hint; when truthy it is also sent as
            ``options.think`` (the payload shape used previously).
        model: Explicit model name; see ``_resolve_model_name``.
        use_pro_model: Prefer ``PRO_LLM_MODEL`` when *model* is blank.

    Returns:
        A JSON-serialisable dict with ``stream`` set to False.
    """
    messages = []
    sys_prompt = _resolve_system_prompt(system_prompt)
    if sys_prompt:
        messages.append({'role': 'system', 'content': sys_prompt})
    messages.append({'role': 'user', 'content': prompt})
    payload = {
        'model': _resolve_model_name(model, use_pro_model=use_pro_model),
        'messages': messages,
        'stream': False,
        # Top-level field of the OpenAI chat-completions schema. Previously the
        # value only ever appeared inside ``options`` (and only with thinking).
        'temperature': temperature,
    }
    if thinking:
        # NOTE(review): ``options.think`` is not part of the OpenAI
        # chat-completions schema; kept for backends that read it. Confirm
        # whether the target server expects ``reasoning_effort`` instead.
        payload['options'] = {'temperature': temperature, 'think': thinking}
    return payload
def _build_chat_stream_payload(
    prompt: str,
    *,
    system_prompt: str | None = None,
    temperature: float = 0.7,
    thinking: str | None = None,
    model: str | None = None,
    use_pro_model: bool = False,
) -> dict:
    """Build a streaming (SSE) ``/chat/completions`` request body.

    Args:
        prompt: User message content.
        system_prompt: Optional system message; omitted when None/blank.
        temperature: Sampling temperature, sent as the top-level
            ``temperature`` field of the chat-completions schema.
        thinking: Optional thinking hint; when truthy it is also sent as
            ``options.think`` (the payload shape used previously).
        model: Explicit model name; see ``_resolve_model_name``.
        use_pro_model: Prefer ``PRO_LLM_MODEL`` when *model* is blank.

    Returns:
        A JSON-serialisable dict with ``stream`` set to True.
    """
    messages = []
    sys_prompt = _resolve_system_prompt(system_prompt)
    if sys_prompt:
        messages.append({'role': 'system', 'content': sys_prompt})
    messages.append({'role': 'user', 'content': prompt})
    payload = {
        'model': _resolve_model_name(model, use_pro_model=use_pro_model),
        'messages': messages,
        'stream': True,
        # Top-level field of the OpenAI chat-completions schema. Previously the
        # value only ever appeared inside ``options`` (and only with thinking).
        'temperature': temperature,
    }
    if thinking:
        # NOTE(review): ``options.think`` is not part of the OpenAI
        # chat-completions schema; kept for backends that read it. Confirm
        # whether the target server expects ``reasoning_effort`` instead.
        payload['options'] = {'temperature': temperature, 'think': thinking}
    return payload
def _extract_delta_text(chunk: dict) -> str:
    """Return ``chunk['choices'][0]['delta']['content']`` from an SSE chunk, or ''.

    A non-dict chunk, absent/empty choices, a non-dict choice, or a ``null``
    delta/content all yield ``''`` instead of raising.
    """
    choices = chunk.get('choices') if isinstance(chunk, dict) else None
    first = choices[0] if isinstance(choices, list) and choices else None
    delta = first.get('delta') if isinstance(first, dict) else None
    if not isinstance(delta, dict):
        return ''
    return delta.get('content') or ''
def _extract_delta_thinking(chunk: dict) -> str:
    """Return the stripped thinking/reasoning delta text of an SSE chunk, or ''.

    Reads the first non-empty of ``reasoning_content``, ``reasoning`` or
    ``thinking`` from ``chunk['choices'][0]['delta']``. Malformed chunks
    (no choices, non-dict choice, ``null`` delta) yield ``''``.
    """
    choices = chunk.get('choices') if isinstance(chunk, dict) else None
    first = choices[0] if isinstance(choices, list) and choices else None
    delta = first.get('delta') if isinstance(first, dict) else None
    if not isinstance(delta, dict):
        return ''
    text = (
        delta.get('reasoning_content')
        or delta.get('reasoning')
        or delta.get('thinking')
        or ''
    )
    return text.strip()
async def call_ollama(
    prompt: str,
    *,
    system_prompt: str | None = None,
    tag: str = 'default',
    temperature: float = 0.7,
    thinking: str | None = None,
    model: str | None = None,
    use_pro_model: bool = False,
) -> dict:
    """Call OpenAI-compatible chat completions (non-streaming) and return content/thinking.

    Args:
        prompt: User message text.
        system_prompt: Optional system message.
        tag: Label included in every log line for this call.
        temperature: Sampling temperature passed to ``_build_chat_payload``.
        thinking: Optional thinking hint passed to ``_build_chat_payload``.
        model: Explicit model name; overrides *use_pro_model* when non-blank.
        use_pro_model: Use ``PRO_LLM_MODEL`` when *model* is not given.

    Returns:
        ``{'content': str, 'think': str}`` taken from the first choice.

    Raises:
        asyncio.TimeoutError: the request exceeded ``COMPLETION_TIMEOUT``.
        httpx.HTTPStatusError: the server returned a 4xx/5xx status.
        asyncio.CancelledError: the awaiting task was cancelled.
        Other transport/JSON errors are logged and re-raised unchanged.
    """
    start = time.perf_counter()
    start_dt = datetime.now()

    def _log_call_time() -> float:
        """Log the wall-clock span of this call and return elapsed milliseconds."""
        elapsed = (time.perf_counter() - start) * 1000
        logger.info(
            '[LLM][%s] call_time [%s --> %s]', tag,
            start_dt.strftime('%H:%M:%S'), datetime.now().strftime('%H:%M:%S'),
        )
        return elapsed

    model_name = _resolve_model_name(model, use_pro_model=use_pro_model)
    log_model_name = 'pro' if (model is None and use_pro_model) else model_name
    logger.info(
        '[LLM][%s] request model=%s base_url=%s prompt_chars=%d system_chars=%d temp=%.2f thinking=%s',
        tag, log_model_name, LLM_BASE_URL, len(prompt),
        len(system_prompt or ''), temperature, thinking,
    )
    payload = _build_chat_payload(
        prompt=prompt, system_prompt=system_prompt, temperature=temperature,
        thinking=thinking, model=model, use_pro_model=use_pro_model,
    )
    # LLM_API_KEY was previously defined but never sent.
    headers = {'Authorization': f'Bearer {LLM_API_KEY}'} if LLM_API_KEY else None
    # No per-read timeout: the overall budget is enforced by wait_for below.
    http_timeout = httpx.Timeout(connect=10.0, read=None, write=30.0, pool=30.0)
    try:
        async with httpx.AsyncClient(
            base_url=LLM_BASE_URL, timeout=http_timeout, headers=headers,
        ) as client:
            resp = await asyncio.wait_for(
                client.post('/chat/completions', json=payload), timeout=COMPLETION_TIMEOUT,
            )
            resp.raise_for_status()
            response = resp.json()
    except asyncio.CancelledError:
        elapsed_ms = _log_call_time()
        logger.warning('[LLM][%s] request cancelled after %.1fms', tag, elapsed_ms)
        raise
    except Exception:
        elapsed_ms = _log_call_time()
        logger.exception('[LLM][%s] request failed after %.1fms', tag, elapsed_ms)
        raise

    content, thinking_out = _extract_message(response)
    elapsed_ms = _log_call_time()
    logger.info(
        '[LLM][%s] response in %.1fms content_chars=%d thinking_chars=%d',
        tag, elapsed_ms, len(content), len(thinking_out or ''),
    )
    if not content.strip():
        logger.warning('[LLM][%s] empty content returned by model', tag)
    return {'content': content, 'think': thinking_out or ''}
async def stream_ollama(
    prompt: str,
    *,
    system_prompt: str | None = None,
    tag: str = 'default-stream',
    temperature: float = 0.7,
    thinking: str | None = None,
    model: str | None = None,
    use_pro_model: bool = False,
) -> AsyncIterator[str]:
    """Stream text deltas from OpenAI-compatible chat completions.

    Parses SSE ``data:`` lines (and bare JSON lines) from the response and
    yields each non-empty ``choices[0].delta.content``.

    Args:
        prompt: User message text.
        system_prompt: Optional system message.
        tag: Label included in every log line for this stream.
        temperature: Sampling temperature passed to ``_build_chat_stream_payload``.
        thinking: Optional thinking hint passed to ``_build_chat_stream_payload``.
        model: Explicit model name; overrides *use_pro_model* when non-blank.
        use_pro_model: Use ``PRO_LLM_MODEL`` when *model* is not given.

    Yields:
        Content text fragments in arrival order.

    Raises:
        TimeoutError / asyncio.TimeoutError: the body did not finish within
            ``COMPLETION_TIMEOUT`` seconds of the response headers arriving.
        RuntimeError: a stream chunk carried an ``error`` field.
        httpx.HTTPStatusError: the server returned a 4xx/5xx status.
        asyncio.CancelledError: the consuming task was cancelled.
    """
    start = time.perf_counter()
    start_dt = datetime.now()

    def _log_stream_time() -> float:
        """Log the wall-clock span of this stream and return elapsed milliseconds."""
        elapsed = (time.perf_counter() - start) * 1000
        logger.info(
            '[LLM][%s] stream_time [%s --> %s]', tag,
            start_dt.strftime('%H:%M:%S'), datetime.now().strftime('%H:%M:%S'),
        )
        return elapsed

    model_name = _resolve_model_name(model, use_pro_model=use_pro_model)
    log_model_name = 'pro' if (model is None and use_pro_model) else model_name
    yielded_chars = 0
    logger.info(
        '[LLM][%s] stream request model=%s base_url=%s prompt_chars=%d system_chars=%d temp=%.2f thinking=%s',
        tag, log_model_name, LLM_BASE_URL, len(prompt),
        len(system_prompt or ''), temperature, thinking,
    )
    payload = _build_chat_stream_payload(
        prompt=prompt, system_prompt=system_prompt, temperature=temperature,
        thinking=thinking, model=model, use_pro_model=use_pro_model,
    )
    headers = {'Authorization': f'Bearer {LLM_API_KEY}'} if LLM_API_KEY else None
    # NOTE(review): with read=None and the deadline starting after headers, the
    # wait for response headers itself is unbounded — confirm that is intended.
    http_timeout = httpx.Timeout(connect=10.0, read=None, write=30.0, pool=30.0)
    try:
        async with httpx.AsyncClient(
            base_url=LLM_BASE_URL, timeout=http_timeout, headers=headers,
        ) as client:
            # The `async with` closes the response on every exit path,
            # including cancellation, so no manual aclose() is needed (the old
            # handler could also hit an unbound `response` and mask the cancel).
            async with client.stream('POST', '/chat/completions', json=payload) as response:
                response.raise_for_status()
                deadline = time.perf_counter() + COMPLETION_TIMEOUT
                line_iterator = response.aiter_lines().__aiter__()
                while True:
                    remaining = deadline - time.perf_counter()
                    if remaining <= 0:
                        raise TimeoutError('LLM stream timed out')
                    try:
                        line = await asyncio.wait_for(line_iterator.__anext__(), timeout=remaining)
                    except StopAsyncIteration:
                        break
                    # Skip blank lines, SSE comments and non-data SSE fields.
                    if not line or line.startswith((':', 'event:', 'id:', 'retry:')):
                        continue
                    # SSE data lines: "data: {json}" / "data:{json}" or "data: [DONE]";
                    # anything else is treated as a bare JSON line.
                    if line.startswith('data:'):
                        data_str = line[5:].strip()
                    else:
                        data_str = line.strip()
                    if not data_str or data_str == '[DONE]':
                        continue
                    try:
                        chunk = json.loads(data_str)
                    except json.JSONDecodeError:
                        logger.warning('[LLM][%s] ignored invalid stream line', tag)
                        continue
                    if not isinstance(chunk, dict):
                        continue
                    # Surface server-side errors instead of ending the stream silently.
                    error = chunk.get('error')
                    if error:
                        raise RuntimeError(str(error))
                    text = _extract_delta_text(chunk)
                    if not text:
                        continue
                    yielded_chars += len(text)
                    yield text
    except asyncio.CancelledError:
        elapsed_ms = _log_stream_time()
        logger.warning('[LLM][%s] stream cancelled after %.1fms', tag, elapsed_ms)
        raise
    except Exception:
        elapsed_ms = _log_stream_time()
        logger.exception('[LLM][%s] stream failed after %.1fms', tag, elapsed_ms)
        raise

    elapsed_ms = _log_stream_time()
    logger.info(
        '[LLM][%s] stream finished in %.1fms yielded_chars=%d',
        tag, elapsed_ms, yielded_chars,
    )
async def stream_ollama_events(
    prompt: str,
    *,
    system_prompt: str | None = None,
    tag: str = 'default-events',
    temperature: float = 0.7,
    thinking: str | None = None,
    model: str | None = None,
    use_pro_model: bool = False,
    enable_thinking: bool = True,
    timeout: float | None = None,
) -> AsyncIterator[tuple[Literal['thinking', 'content'], str]]:
    """Stream (event_type, payload) tuples from OpenAI-compatible chat completions.

    Yields ``('thinking', '')`` once, the first time a chunk carries a
    thinking/reasoning delta, and ``('content', text)`` for every non-empty
    content delta.

    Args:
        prompt: User message text.
        system_prompt: Optional system message.
        tag: Label included in every log line for this stream.
        temperature: Sampling temperature passed to ``_build_chat_stream_payload``.
        thinking: Optional thinking hint; forwarded only when *enable_thinking*.
        model: Explicit model name; overrides *use_pro_model* when non-blank.
        use_pro_model: Use ``PRO_LLM_MODEL`` when *model* is not given.
        enable_thinking: When False, *thinking* is not forwarded to the payload.
        timeout: Seconds allowed for the body once response headers arrive;
            defaults to ``COMPLETION_TIMEOUT``.

    Raises:
        TimeoutError / asyncio.TimeoutError: the body did not finish in time.
        RuntimeError: a stream chunk carried an ``error`` field.
        httpx.HTTPStatusError: the server returned a 4xx/5xx status.
        asyncio.CancelledError: the consuming task was cancelled.
    """
    start = time.perf_counter()
    start_dt = datetime.now()

    def _log_stream_time() -> float:
        """Log the wall-clock span of this stream and return elapsed milliseconds."""
        elapsed = (time.perf_counter() - start) * 1000
        logger.info(
            '[LLM][%s] event_stream_time [%s --> %s]', tag,
            start_dt.strftime('%H:%M:%S'), datetime.now().strftime('%H:%M:%S'),
        )
        return elapsed

    model_name = _resolve_model_name(model, use_pro_model=use_pro_model)
    log_model_name = 'pro' if (model is None and use_pro_model) else model_name
    yielded_chars = 0
    logger.info(
        '[LLM][%s] event_stream request model=%s base_url=%s prompt_chars=%d system_chars=%d temp=%.2f thinking=%s',
        tag, log_model_name, LLM_BASE_URL, len(prompt),
        len(system_prompt or ''), temperature, thinking,
    )
    payload = _build_chat_stream_payload(
        prompt=prompt, system_prompt=system_prompt, temperature=temperature,
        thinking=thinking if enable_thinking else None, model=model, use_pro_model=use_pro_model,
    )
    effective_timeout = timeout if timeout is not None else COMPLETION_TIMEOUT
    headers = {'Authorization': f'Bearer {LLM_API_KEY}'} if LLM_API_KEY else None
    # NOTE(review): with read=None and the deadline starting after headers, the
    # wait for response headers itself is unbounded — confirm that is intended.
    http_timeout = httpx.Timeout(connect=10.0, read=None, write=30.0, pool=30.0)
    sent_thinking = False
    try:
        async with httpx.AsyncClient(
            base_url=LLM_BASE_URL, timeout=http_timeout, headers=headers,
        ) as client:
            # The `async with` closes the response on every exit path,
            # including cancellation, so no manual aclose() is needed (the old
            # handler could also hit an unbound `response` and mask the cancel).
            async with client.stream('POST', '/chat/completions', json=payload) as response:
                response.raise_for_status()
                deadline = time.perf_counter() + effective_timeout
                line_iterator = response.aiter_lines().__aiter__()
                while True:
                    remaining = deadline - time.perf_counter()
                    if remaining <= 0:
                        raise TimeoutError('LLM event stream timed out')
                    try:
                        line = await asyncio.wait_for(line_iterator.__anext__(), timeout=remaining)
                    except StopAsyncIteration:
                        break
                    # Skip blank lines, SSE comments and non-data SSE fields.
                    if not line or line.startswith((':', 'event:', 'id:', 'retry:')):
                        continue
                    # SSE data lines: "data: {json}" / "data:{json}" or "data: [DONE]";
                    # anything else is treated as a bare JSON line.
                    if line.startswith('data:'):
                        data_str = line[5:].strip()
                    else:
                        data_str = line.strip()
                    if not data_str or data_str == '[DONE]':
                        continue
                    try:
                        chunk = json.loads(data_str)
                    except json.JSONDecodeError:
                        logger.warning('[LLM][%s] ignored invalid Ollama stream line', tag)
                        continue
                    if not isinstance(chunk, dict):
                        continue
                    error = chunk.get('error')
                    if error:
                        raise RuntimeError(str(error))
                    # Emit a single marker event the first time thinking appears.
                    thinking_delta = _extract_delta_thinking(chunk)
                    if thinking_delta and not sent_thinking:
                        sent_thinking = True
                        yield 'thinking', ''
                    text = _extract_delta_text(chunk)
                    if not text:
                        continue
                    yielded_chars += len(text)
                    yield 'content', text
    except asyncio.CancelledError:
        elapsed_ms = _log_stream_time()
        logger.warning('[LLM][%s] event stream cancelled after %.1fms', tag, elapsed_ms)
        raise
    except Exception:
        elapsed_ms = _log_stream_time()
        logger.exception('[LLM][%s] event stream failed after %.1fms', tag, elapsed_ms)
        raise

    elapsed_ms = _log_stream_time()
    logger.info(
        '[LLM][%s] event stream finished in %.1fms yielded_chars=%d thinking_seen=%s',
        tag, elapsed_ms, yielded_chars, sent_thinking,
    )
def _guess_image_mime(image_bytes: bytes) -> str:
    """Return an image MIME type from the leading magic bytes, defaulting to PNG.

    Recognises JPEG, GIF and WebP signatures; anything else (PNG included)
    maps to ``image/png``, the type that was previously always sent.
    """
    head = bytes(image_bytes[:12])
    if head.startswith(b'\xff\xd8\xff'):
        return 'image/jpeg'
    if head.startswith((b'GIF87a', b'GIF89a')):
        return 'image/gif'
    if head[:4] == b'RIFF' and head[8:12] == b'WEBP':
        return 'image/webp'
    return 'image/png'


async def call_vlm_ocr(image_bytes: bytes, language: str = 'auto') -> str:
    """OCR via VLM using OpenAI-compatible vision API (image_url content part).

    The image is sent inline as a base64 ``data:`` URL next to the text prompt
    returned by ``get_vlm_ocr_prompt()``.

    Args:
        image_bytes: Encoded image data; its MIME type is sniffed from the
            magic bytes (see ``_guess_image_mime``).
        language: Only logged. NOTE(review): not forwarded to
            ``get_vlm_ocr_prompt`` — confirm whether the prompt should use it.

    Returns:
        The model's message content (may be empty; a warning is logged).

    Raises:
        asyncio.TimeoutError: the request exceeded ``OCR_TIMEOUT``.
        httpx.HTTPStatusError: the server returned a 4xx/5xx status.
        asyncio.CancelledError: the awaiting task was cancelled.
    """
    start = time.perf_counter()
    start_dt = datetime.now()

    def _log_call_time() -> float:
        """Log the wall-clock span of this call and return elapsed milliseconds."""
        elapsed = (time.perf_counter() - start) * 1000
        logger.info(
            '[VLM][ocr] call_time [%s --> %s]', start_dt.strftime('%H:%M:%S'),
            datetime.now().strftime('%H:%M:%S'),
        )
        return elapsed

    logger.info(
        '[VLM][ocr] request model=%s base_url=%s image_bytes=%d language=%s',
        VLM_MODEL, LLM_BASE_URL, len(image_bytes), language,
    )
    image_b64 = base64.b64encode(image_bytes).decode('ascii')
    mime = _guess_image_mime(image_bytes)
    payload = {
        'model': VLM_MODEL,
        'messages': [{
            'role': 'user',
            'content': [
                {'type': 'text', 'text': get_vlm_ocr_prompt()},
                {
                    'type': 'image_url',
                    'image_url': {'url': f'data:{mime};base64,{image_b64}'},
                },
            ],
        }],
        'stream': False,
    }
    headers = {'Authorization': f'Bearer {LLM_API_KEY}'} if LLM_API_KEY else None
    http_timeout = httpx.Timeout(connect=10.0, read=None, write=30.0, pool=30.0)
    try:
        async with httpx.AsyncClient(
            base_url=LLM_BASE_URL, timeout=http_timeout, headers=headers,
        ) as client:
            resp = await asyncio.wait_for(
                client.post('/chat/completions', json=payload), timeout=OCR_TIMEOUT,
            )
            resp.raise_for_status()
            response = resp.json()
    except asyncio.CancelledError:
        elapsed_ms = _log_call_time()
        logger.warning('[VLM][ocr] request cancelled after %.1fms', elapsed_ms)
        raise
    except Exception:
        elapsed_ms = _log_call_time()
        logger.exception('[VLM][ocr] request failed after %.1fms', elapsed_ms)
        raise

    content, _ = _extract_message(response)
    elapsed_ms = _log_call_time()
    logger.info(
        '[VLM][ocr] response in %.1fms content_chars=%d', elapsed_ms, len(content),
    )
    if not content.strip():
        logger.warning('[VLM][ocr] empty content returned by model')
    return content