Files
llm-in-text/backend/tests/test_main_endpoints.py
ydy0615 59334e4057 Stabilize pro editing without heavy office runtime
The workspace now carries the pro editing flow, the streaming completion path, and the lighter Office preview state as a single checkpoint, so that the remote reflects the project's current runnable shape.

Constraint: Preserve the current workspace as a single reviewable project commit while excluding local agent state and verification artifacts. Removed stale Univer runtime dependencies from the lockfile so installs match package.json.

Rejected: Commit runtime screenshots, .omx state, and coverage files | they are local artifacts rather than source state.

Confidence: medium

Scope-risk: broad

Directive: Keep package.json and package-lock.json synchronized when changing frontend dependencies.

Tested: npm run build; C:\Users\ydy\.conda\envs\llmwebsite\python.exe -m pytest backend/tests/test_main_endpoints.py backend/tests/test_main_cancel.py backend/tests/test_llm.py backend/tests/test_llm_extended.py -v -o addopts= (44 passed).

Not-tested: Full pytest with repository coverage addopts currently reports 0% coverage because pytest-cov watches backend.* module names while tests import top-level backend modules.

Co-authored-by: OmX <omx@oh-my-codex.dev>
2026-05-24 23:30:32 +08:00

255 lines
7.8 KiB
Python

import os
import sys
import base64
import types
import pytest
from unittest.mock import MagicMock
from fastapi.testclient import TestClient
# Absolute paths of this tests/ directory and of its parent backend/ directory.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
BACKEND_DIR = os.path.abspath(os.path.join(CURRENT_DIR, ".."))
# Put backend/ at the front of sys.path so backend modules such as ``main``
# can be imported as top-level modules.
if BACKEND_DIR not in sys.path:
    sys.path.insert(0, BACKEND_DIR)
# Register a stub ``tts_asr`` module (unless one is already loaded) so that
# ``import main`` below picks up the stub instead of the real module; its route
# registration becomes a no-op. Presumably this avoids the real module's heavy
# runtime dependencies — confirm. Must stay BEFORE ``import main``.
if "tts_asr" not in sys.modules:
    fake_tts_asr = types.ModuleType("tts_asr")
    fake_tts_asr.register_tts_asr_routes = lambda app: None
    sys.modules["tts_asr"] = fake_tts_asr
import main  # type: ignore
# Authenticate test requests with the same key the app itself is configured with.
API_KEY = main.API_KEY
HEADERS = {"X-API-Key": API_KEY}
@pytest.fixture(autouse=True)
def _clear_active_completions():
    """Empty ``main.ACTIVE_COMPLETIONS`` before and after every test in this module.

    The registry is module-level state on ``main``, so resetting it on both
    sides of the ``yield`` keeps entries from one test out of the next.
    """

    def _reset():
        main.ACTIVE_COMPLETIONS.clear()

    _reset()
    yield
    _reset()
class DummyRequest:
    """Tiny stand-in for a request object exposing only ``client`` and ``headers``.

    ``client`` is ``None`` when no host is given; otherwise it is an object
    whose ``host`` attribute holds the given value.
    """

    def __init__(self, host=None, headers=None):
        class Client:
            pass

        if host is None:
            self.client = None
        else:
            self.client = Client()
            self.client.host = host
        # Falsy headers (None or empty) are replaced with a fresh empty dict.
        self.headers = headers or {}
def test_preview_short_text():
    """A short string is returned by ``_preview`` unchanged."""
    text = "Hello"
    preview = main._preview(text)
    assert preview == text
def test_preview_long_text_truncated():
    """A 100-character string is cut to its first 80 characters plus ``...``."""
    source = "a" * 100
    expected = f"{source[:80]}..."
    assert main._preview(source) == expected
def test_preview_none_input():
    """``None`` previews as the empty string."""
    result = main._preview(None)
    assert result == ""
def test_preview_newlines_replaced():
    """A real newline is rendered as the two characters backslash and ``n``."""
    escaped = main._preview("line1\nline2")
    assert escaped == "line1\\nline2"
def test_sanitize_markdown_strips_image_markdown():
    """Markdown image syntax does not survive sanitisation."""
    raw = "text with image ![alt](image.png) end"
    cleaned = main._sanitize_converted_markdown(raw)
    assert "![alt](image.png)" not in cleaned
def test_sanitize_markdown_strips_img_tag():
    """Raw HTML ``<img>`` tags do not survive sanitisation."""
    cleaned = main._sanitize_converted_markdown("<img src='x.png'/>")
    assert "<img" not in cleaned
def test_sanitize_markdown_collapse_newlines():
    """Runs of three and four newlines both collapse to exactly two."""
    raw = "a\n\n\nb\n\n\n\nc"
    expected = "a\n\nb\n\nc"
    assert main._sanitize_converted_markdown(raw) == expected
def test_sanitize_markdown_normalize_crlf():
    """CRLF line endings come out as plain LF with no stray carriage returns."""
    normalized = main._sanitize_converted_markdown("line1\r\nline2\r\n")
    assert "line1\nline2" in normalized
    assert "\r" not in normalized
def test_get_client_ip_from_host():
    """With no override header, the client host address is reported."""
    request = DummyRequest(host="1.2.3.4", headers={})
    ip = main.get_client_ip(request)
    assert ip == "1.2.3.4"
def test_get_client_ip_header_overrides_host():
    """An ``X-Client-IP`` header wins over the client host address."""
    request = DummyRequest(host="1.2.3.4", headers={"X-Client-IP": "5.6.7.8"})
    assert main.get_client_ip(request) == "5.6.7.8"
def test_get_client_ip_when_client_missing():
    """When the request has no client info, the ``X-Client-IP`` header is used."""
    req = DummyRequest(host=None, headers={"X-Client-IP": "9.9.9.9"})
    # DummyRequest already leaves ``client`` as None when host is None; assert
    # the precondition explicitly rather than re-assigning it.
    assert req.client is None
    assert main.get_client_ip(req) == "9.9.9.9"
def test_post_completions_wrong_api_key_returns_401():
    """``/v1/completions`` rejects requests lacking a valid ``X-API-Key`` with 401.

    Both a missing header and a header carrying an incorrect key are checked,
    so the test matches its name.
    """
    client = TestClient(main.app)
    payload = {
        "prefix": "hello", "suffix": "", "languageId": "markdown",
        "model_thinking": "low", "privacy_mode": True,
    }
    missing = client.post("/v1/completions", json=payload)
    assert missing.status_code == 401
    # Appending a suffix guarantees the key differs from the configured one.
    wrong = client.post(
        "/v1/completions",
        headers={"X-API-Key": f"{API_KEY}-wrong"},
        json=payload,
    )
    assert wrong.status_code == 401
def test_post_completions_privacy_mode(monkeypatch):
    """A privacy-mode completion returns the content produced by the stubbed model call."""

    async def fake_call(*args, **kwargs):
        return {"content": "done", "think": ""}

    stubs = {
        "call_ollama": fake_call,
        "build_completion_prompts": lambda *a, **k: ("sys", "user"),
        "prepare_prompt_context": lambda *a, **k: ("p", "s"),
    }
    for attr_name, replacement in stubs.items():
        monkeypatch.setattr(main, attr_name, replacement)

    request_body = {
        "prefix": "hello", "suffix": "", "languageId": "markdown",
        "model_thinking": "low", "privacy_mode": True,
    }
    response = TestClient(main.app).post("/v1/completions", headers=HEADERS, json=request_body)
    assert response.status_code == 200
    assert response.json().get("content") == "done"
def test_post_pro_stream_returns_sse(monkeypatch):
    """The pro streaming endpoint emits SSE chunk/done events and forwards the pro model options."""
    seen = {}

    async def fake_stream(*args, **kwargs):
        seen["kwargs"] = kwargs
        for piece in ("深度", "回答"):
            yield piece

    monkeypatch.setattr(main, "stream_ollama", fake_stream)
    monkeypatch.setattr(main, "build_completion_prompts", lambda *a, **k: ("sys", "user"))
    monkeypatch.setattr(main, "prepare_prompt_context", lambda *a, **k: ("p", "s"))

    request_body = {
        "prefix": "hello",
        "suffix": "",
        "languageId": "markdown",
        "model_thinking": "high",
        "privacy_mode": True,
        "model": "pro-model",
        "temperature": 0.95,
    }
    client = TestClient(main.app)
    with client.stream("POST", "/v1/pro/completions/stream", headers=HEADERS, json=request_body) as resp:
        assert resp.status_code == 200
        body = "".join(resp.iter_text())

    for marker in ("event: chunk", "event: done", "深度", "回答"):
        assert marker in body
    forwarded = seen["kwargs"]
    assert forwarded["model"] == "pro-model"
    assert forwarded["use_pro_model"] is True
    # Once the stream has finished, the request must no longer be registered.
    assert main.ACTIVE_COMPLETIONS == {}
def test_post_ocr_mocked(monkeypatch):
    """The OCR endpoint returns the stubbed VLM text and echoes back the filename."""

    async def fake_ocr(*args, **kwargs):
        return "OCR result text"

    monkeypatch.setattr(main, "call_vlm_ocr", fake_ocr)
    client = TestClient(main.app)
    encoded_image = base64.b64encode(b"pretend image data").decode()
    request_body = {"image": encoded_image, "filename": "test.jpg", "language": "auto"}
    response = client.post("/v1/ocr", headers=HEADERS, json=request_body)
    assert response.status_code == 200
    payload = response.json()
    assert payload["text"] == "OCR result text"
    assert payload["filename"] == "test.jpg"
def test_post_ocr_invalid_base64_returns_500():
    """An image field that is not valid base64 produces a 500 response."""
    bad_body = {"image": "not-base64!!!", "filename": "test.jpg"}
    response = TestClient(main.app).post("/v1/ocr", headers=HEADERS, json=bad_body)
    assert response.status_code == 500
def test_post_convert_txt_returns_markdown():
    """A ``.txt`` upload comes back verbatim as markdown along with its filename."""
    client = TestClient(main.app)
    encoded = base64.b64encode(b"hello world").decode()
    response = client.post(
        "/v1/convert",
        headers=HEADERS,
        json={"file": encoded, "filename": "sample.txt"},
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["markdown"] == "hello world"
    assert payload["filename"] == "sample.txt"
def test_post_convert_unsupported_extension_returns_500():
    """An unsupported extension (``.xlsx``) is rejected with a 500 and an explanatory error."""
    client = TestClient(main.app)
    encoded = base64.b64encode(b"data").decode()
    response = client.post(
        "/v1/convert",
        headers=HEADERS,
        json={"file": encoded, "filename": "sample.xlsx"},
    )
    assert response.status_code == 500
    # "仅支持" ("only supports") is expected somewhere in the error text.
    assert "仅支持" in response.json()["error"]
def test_post_convert_docx_with_mocked_markitdown(monkeypatch):
    """A ``.docx`` upload is converted via the MarkItDown factory and its text returned."""

    class FakeResult:
        text_content = "markdown from docx"

    class FakeMD:
        def convert(self, path):
            return FakeResult()

    monkeypatch.setattr(main, "_get_markitdown", lambda: FakeMD())
    client = TestClient(main.app)
    request_body = {
        "file": base64.b64encode(b"docx content").decode(),
        "filename": "sample.docx",
    }
    response = client.post("/v1/convert", headers=HEADERS, json=request_body)
    assert response.status_code == 200
    assert response.json()["markdown"] == "markdown from docx"
def test_post_cancel_non_existent_returns_not_found():
    """Cancelling an unknown request id reports ``not_found`` without cancelling anything."""
    request_body = {"request_id": "non-existent", "reason": "abort"}
    response = TestClient(main.app).post("/v1/completions/cancel", headers=HEADERS, json=request_body)
    assert response.status_code == 200
    outcome = response.json()
    assert outcome["cancelled"] is False
    assert outcome["status"] == "not_found"
def test_post_cancel_wrong_api_key_returns_401():
    """``/v1/completions/cancel`` rejects requests lacking a valid ``X-API-Key`` with 401.

    Both a missing header and a header carrying an incorrect key are checked,
    so the test matches its name.
    """
    client = TestClient(main.app)
    payload = {"request_id": "id", "reason": "abort"}
    missing = client.post("/v1/completions/cancel", json=payload)
    assert missing.status_code == 401
    # Appending a suffix guarantees the key differs from the configured one.
    wrong = client.post(
        "/v1/completions/cancel",
        headers={"X-API-Key": f"{API_KEY}-wrong"},
        json=payload,
    )
    assert wrong.status_code == 401
def test_post_cancel_already_done():
    """Cancelling a request whose task has already finished reports ``already_done``.

    The module's autouse ``_clear_active_completions`` fixture empties the
    registry before and after this test, so no manual clearing is needed.
    """
    # MagicMock auto-creates ``cancel`` as a recording mock; only ``done()``
    # needs configuring so the task looks finished.
    finished_task = MagicMock()
    finished_task.done.return_value = True
    main.ACTIVE_COMPLETIONS["done-id"] = finished_task
    client = TestClient(main.app)
    resp = client.post("/v1/completions/cancel", headers=HEADERS, json={
        "request_id": "done-id", "reason": "abort",
    })
    assert resp.status_code == 200
    data = resp.json()
    assert data["cancelled"] is False
    assert data["status"] == "already_done"
    # A task that has already completed should not be asked to cancel.
    finished_task.cancel.assert_not_called()