296 lines
8.9 KiB
Python
296 lines
8.9 KiB
Python
import asyncio
|
|
import importlib
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
|
|
# Directory one level above the directory that holds this test file; it is
# put at the front of sys.path so ``llm`` can be imported by its bare name.
BACKEND_DIR = Path(__file__).resolve().parents[1]
_backend_dir_entry = str(BACKEND_DIR)
if _backend_dir_entry not in sys.path:
    sys.path.insert(0, _backend_dir_entry)

# Skip every test in this module (rather than erroring at collection time)
# when the import raises ModuleNotFoundError.
try:
    llm = importlib.import_module("llm")
except ModuleNotFoundError:
    pytest.skip("llm module dependencies are not available", allow_module_level=True)
|
|
|
|
|
|
def test_extract_message_openai_format():
    """A message carrying ``content`` and ``thinking`` yields both values."""
    message = {"content": "hello world", "thinking": "reasoning"}
    text, reasoning = llm._extract_message({"choices": [{"message": message}]})
    assert (text, reasoning) == ("hello world", "reasoning")
|
|
|
|
|
|
def test_extract_message_openai_reasoning_content():
    """A ``reasoning_content`` field is returned as the thinking text."""
    message = {"content": "answer", "reasoning_content": "deep thought"}
    text, reasoning = llm._extract_message({"choices": [{"message": message}]})
    assert (text, reasoning) == ("answer", "deep thought")
|
|
|
|
|
|
def test_extract_message_empty_choices():
    """An empty ``choices`` list produces two empty strings."""
    text, reasoning = llm._extract_message({"choices": []})
    assert (text, reasoning) == ("", "")
|
|
|
|
|
|
def test_extract_message_no_choices_key():
    """A response without a ``choices`` key produces two empty strings."""
    empty_response = {}
    text, reasoning = llm._extract_message(empty_response)
    assert (text, reasoning) == ("", "")
|
|
|
|
|
|
def test_extract_message_none_content():
    """``None`` values for ``content``/``thinking`` come back as empty strings."""
    message = {"content": None, "thinking": None}
    text, reasoning = llm._extract_message({"choices": [{"message": message}]})
    assert (text, reasoning) == ("", "")
|
|
|
|
|
|
def test_extract_delta_text():
    """The ``content`` of a streaming delta is returned as the text."""
    delta = {"content": "hello"}
    extracted = llm._extract_delta_text({"choices": [{"delta": delta}]})
    assert extracted == "hello"
|
|
|
|
|
|
def test_extract_delta_text_empty():
    """A delta with no fields yields an empty string."""
    extracted = llm._extract_delta_text({"choices": [{"delta": {}}]})
    assert extracted == ""
|
|
|
|
|
|
def test_extract_delta_thinking():
    """The ``thinking`` field of a streaming delta is returned."""
    delta = {"thinking": "reasoning step"}
    extracted = llm._extract_delta_thinking({"choices": [{"delta": delta}]})
    assert extracted == "reasoning step"
|
|
|
|
|
|
def test_extract_delta_reasoning_content():
    """A delta's ``reasoning_content`` field is returned as thinking text."""
    delta = {"reasoning_content": "deep thought"}
    extracted = llm._extract_delta_thinking({"choices": [{"delta": delta}]})
    assert extracted == "deep thought"
|
|
|
|
|
|
def test_resolve_model_name_explicit():
    """An explicitly supplied model name is returned unchanged."""
    resolved = llm._resolve_model_name("custom-model")
    assert resolved == "custom-model"
|
|
|
|
|
|
def test_resolve_model_name_default():
    """With no arguments the resolver returns ``llm.LLM_MODEL``."""
    resolved = llm._resolve_model_name()
    assert resolved == llm.LLM_MODEL
|
|
|
|
|
|
def test_resolve_model_name_pro():
    """``use_pro_model=True`` resolves to ``llm.PRO_LLM_MODEL``."""
    resolved = llm._resolve_model_name(use_pro_model=True)
    assert resolved == llm.PRO_LLM_MODEL
|
|
|
|
|
|
def test_resolve_system_prompt():
    """Surrounding whitespace is stripped; empty and ``None`` become ``""``."""
    cases = (
        (" system prompt ", "system prompt"),
        ("", ""),
        (None, ""),
    )
    for raw_prompt, expected in cases:
        assert llm._resolve_system_prompt(raw_prompt) == expected
|
|
|
|
|
|
def test_build_chat_payload_with_system():
    """A system prompt adds a leading system message to a non-streaming payload."""
    built = llm._build_chat_payload(
        "user prompt",
        system_prompt="sys prompt",
        temperature=0.5,
        model="test-model",
    )
    assert built["model"] == "test-model"
    # Exactly two messages, in system-then-user order.
    roles = [message["role"] for message in built["messages"]]
    assert roles == ["system", "user"]
    assert built["stream"] is False
|
|
|
|
|
|
def test_build_chat_payload_no_system():
    """Without a system prompt the payload holds a single message."""
    built = llm._build_chat_payload("user prompt", system_prompt=None)
    message_count = len(built["messages"])
    assert message_count == 1
    assert built["stream"] is False
|
|
|
|
|
|
def test_build_chat_payload_with_thinking():
    """The ``thinking`` argument is placed under ``options["think"]``."""
    built = llm._build_chat_payload("prompt", thinking="low")
    assert "options" in built
    options = built["options"]
    assert options["think"] == "low"
|
|
|
|
|
|
def test_build_chat_stream_payload():
    """The streaming builder sets ``stream`` to True and emits two messages."""
    built = llm._build_chat_stream_payload("prompt", system_prompt="sys")
    assert built["stream"] is True
    message_count = len(built["messages"])
    assert message_count == 2
|
|
|
|
|
|
def test_build_chat_stream_payload_with_thinking():
    """The streaming builder also places ``thinking`` under ``options["think"]``."""
    built = llm._build_chat_stream_payload("prompt", thinking="high")
    assert "options" in built
    options = built["options"]
    assert options["think"] == "high"
|
|
|
|
|
|
def test_call_ollama_non_streaming(monkeypatch):
    """``call_ollama`` posts a non-streaming chat request and returns its content.

    ``httpx.AsyncClient`` is swapped for an in-memory fake that records the
    URL and JSON body passed to ``post`` and answers with a canned response.
    """
    captured = {}

    class FakeResponse:
        """Canned successful response holding a single message."""

        def raise_for_status(self):
            pass

        def json(self):
            return {"choices": [{"message": {"content": "done"}}]}

    class FakeClient:
        """Stand-in for ``httpx.AsyncClient``.

        It is a plain class because the real client is constructed with a
        synchronous call and then entered with ``async with``; an
        ``async def`` factory would hand back a bare coroutine instead.
        """

        def __init__(self, *args, **kwargs):
            pass

        async def __aenter__(self):
            return self

        async def __aexit__(self, *exc_info):
            return None

        # Defined as a real method (explicit ``self``) so the instance is not
        # mistaken for the ``url`` argument when the attribute is bound.
        async def post(self, url, json=None, **kwargs):
            captured["url"] = url
            captured["json"] = json
            return FakeResponse()

    def passthrough(awaitable, *args, **kwargs):
        # Hand the awaitable straight back so the caller awaits it directly;
        # any timeout is ignored whether passed positionally or by keyword.
        return awaitable

    monkeypatch.setattr(llm.httpx, "AsyncClient", FakeClient)
    monkeypatch.setattr(llm.asyncio, "wait_for", passthrough)

    result = asyncio.run(
        llm.call_ollama("test prompt", system_prompt="sys", tag="t1")
    )

    assert result["content"] == "done"
    assert captured["url"] == "/chat/completions"
    assert captured["json"]["stream"] is False
|
|
|
|
|
|
def test_stream_ollama_text_deltas(monkeypatch):
    """``stream_ollama`` yields the text of each content delta, in order.

    ``httpx.AsyncClient`` is swapped for an in-memory fake whose streamed
    response replays three server-sent-event lines.
    """
    captured = {}
    sse_lines = [
        'data: {"choices": [{"delta": {"content": "hel"}}]}',
        'data: {"choices": [{"delta": {"content": "lo"}}]}',
        "data: [DONE]",
    ]

    class FakeResponse:
        """Streamed response that replays ``sse_lines``."""

        def raise_for_status(self):
            pass

        # An async generator, so the result of calling it can be consumed
        # directly with ``async for``.
        async def aiter_lines(self):
            for line in sse_lines:
                yield line

    class FakeStream:
        """Async context manager returned by ``FakeClient.stream``."""

        async def __aenter__(self):
            return FakeResponse()

        async def __aexit__(self, *exc_info):
            return None

    class FakeClient:
        """Stand-in for ``httpx.AsyncClient``.

        It is a plain class because the real client is constructed with a
        synchronous call and then entered with ``async with``.
        """

        def __init__(self, *args, **kwargs):
            captured["called"] = True

        async def __aenter__(self):
            return self

        async def __aexit__(self, *exc_info):
            return None

        def stream(self, *args, **kwargs):
            return FakeStream()

    def passthrough(awaitable, *args, **kwargs):
        # Hand the awaitable straight back so the caller awaits it directly.
        return awaitable

    monkeypatch.setattr(llm.httpx, "AsyncClient", FakeClient)
    monkeypatch.setattr(llm.asyncio, "wait_for", passthrough)

    async def collect():
        return [delta async for delta in llm.stream_ollama("prompt", tag="t1")]

    results = asyncio.run(collect())

    assert captured.get("called") is True
    assert results == ["hel", "lo"]
|
|
|
|
|
|
def test_stream_ollama_events_thinking_and_content(monkeypatch):
    """``stream_ollama_events`` emits a thinking event before a content event.

    ``httpx.AsyncClient`` is swapped for an in-memory fake whose streamed
    response replays one thinking delta, one content delta and ``[DONE]``.
    """
    captured = {}
    sse_lines = [
        'data: {"choices": [{"delta": {"thinking": "reasoning"}}]}',
        'data: {"choices": [{"delta": {"content": "answer"}}]}',
        "data: [DONE]",
    ]

    class FakeResponse:
        """Streamed response that replays ``sse_lines``."""

        def raise_for_status(self):
            pass

        # An async generator, so the result of calling it can be consumed
        # directly with ``async for``.
        async def aiter_lines(self):
            for line in sse_lines:
                yield line

    class FakeStream:
        """Async context manager returned by ``FakeClient.stream``."""

        async def __aenter__(self):
            return FakeResponse()

        async def __aexit__(self, *exc_info):
            return None

    class FakeClient:
        """Stand-in for ``httpx.AsyncClient``.

        It is a plain class because the real client is constructed with a
        synchronous call and then entered with ``async with``.
        """

        def __init__(self, *args, **kwargs):
            captured["called"] = True

        async def __aenter__(self):
            return self

        async def __aexit__(self, *exc_info):
            return None

        def stream(self, *args, **kwargs):
            return FakeStream()

    def passthrough(awaitable, *args, **kwargs):
        # Hand the awaitable straight back so the caller awaits it directly.
        return awaitable

    monkeypatch.setattr(llm.httpx, "AsyncClient", FakeClient)
    monkeypatch.setattr(llm.asyncio, "wait_for", passthrough)

    async def collect():
        return [
            (event_type, payload)
            async for event_type, payload in llm.stream_ollama_events("prompt", tag="t1")
        ]

    results = asyncio.run(collect())

    assert captured.get("called") is True
    # First event should be thinking, then content.
    # NOTE(review): the thinking delta carries "reasoning" but the expected
    # payload here is "" -- confirm this matches what stream_ollama_events
    # is meant to emit for thinking deltas.
    assert results[0] == ("thinking", "")
    assert results[1][0] == "content"
|
|
|
|
|
|
def test_call_vlm_ocr(monkeypatch):
    """``call_vlm_ocr`` sends the image as a base64 data URL and returns the text.

    ``httpx.AsyncClient`` is swapped for an in-memory fake that records the
    URL and JSON body passed to ``post`` and answers with a canned response.
    """
    captured = {}

    class FakeResponse:
        """Canned successful response holding the OCR text."""

        def raise_for_status(self):
            pass

        def json(self):
            return {"choices": [{"message": {"content": "ocr text"}}]}

    class FakeClient:
        """Stand-in for ``httpx.AsyncClient``.

        It is a plain class because the real client is constructed with a
        synchronous call and then entered with ``async with``; an
        ``async def`` factory would hand back a bare coroutine instead.
        """

        def __init__(self, *args, **kwargs):
            pass

        async def __aenter__(self):
            return self

        async def __aexit__(self, *exc_info):
            return None

        # Defined as a real method (explicit ``self``) so the instance is not
        # mistaken for the ``url`` argument when the attribute is bound.
        async def post(self, url, json=None, **kwargs):
            captured["url"] = url
            captured["json"] = json
            return FakeResponse()

    def passthrough(awaitable, *args, **kwargs):
        # Hand the awaitable straight back so the caller awaits it directly;
        # any timeout is ignored whether passed positionally or by keyword.
        return awaitable

    monkeypatch.setattr(llm.httpx, "AsyncClient", FakeClient)
    monkeypatch.setattr(llm.asyncio, "wait_for", passthrough)

    result = asyncio.run(llm.call_vlm_ocr(b"fake image bytes"))
    assert result == "ocr text"

    # Verify the payload uses OpenAI vision format (image_url).
    assert captured["url"] == "/chat/completions"
    messages = captured["json"]["messages"]
    assert len(messages) == 1
    content_parts = messages[0]["content"]
    # Should have a text part and exactly one image_url part.
    assert any(part.get("type") == "text" for part in content_parts)
    image_parts = [part for part in content_parts if part.get("type") == "image_url"]
    assert len(image_parts) == 1
    assert image_parts[0]["image_url"]["url"].startswith("data:image/png;base64,")
|