This commit captures the workspace — the pro editing flow, the streaming completion path, and the lighter Office preview state — as a single checkpoint so that the remote reflects the current runnable shape of the project. Constraint: Preserve the current workspace as a single reviewable project commit while excluding local agent state and verification artifacts. Removed stale Univer runtime dependencies from the lockfile so that installs match package.json. Rejected: Committing runtime screenshots, .omx state, and coverage files | they are local artifacts rather than source state. Confidence: medium Scope-risk: broad Directive: Keep package.json and package-lock.json synchronized when changing frontend dependencies. Tested: npm run build; C:\Users\ydy\.conda\envs\llmwebsite\python.exe -m pytest backend/tests/test_main_endpoints.py backend/tests/test_main_cancel.py backend/tests/test_llm.py backend/tests/test_llm_extended.py -v -o addopts= (44 passed). Not-tested: The full pytest run with the repository's coverage addopts currently reports 0% coverage, because pytest-cov watches backend.* module names while the tests import top-level backend modules. Co-authored-by: OmX <omx@oh-my-codex.dev>
255 lines
7.8 KiB
Python
255 lines
7.8 KiB
Python
import os
|
|
import sys
|
|
import base64
|
|
import types
|
|
import pytest
|
|
from unittest.mock import MagicMock
|
|
from fastapi.testclient import TestClient
|
|
|
|
# Put the directory one level above this test file (BACKEND_DIR) at the front
# of sys.path so that the ``import main`` that follows resolves to the backend
# module.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
BACKEND_DIR = os.path.abspath(os.path.join(CURRENT_DIR, os.pardir))
if BACKEND_DIR not in sys.path:
    sys.path.insert(0, BACKEND_DIR)

# Provide a placeholder ``tts_asr`` module (presumably imported by ``main``)
# whose route registration does nothing. A module already present in
# sys.modules is left alone.
if "tts_asr" not in sys.modules:
    def _noop_register_tts_asr_routes(app):
        """Ignore the given app; register no routes."""
        return None

    fake_tts_asr = types.ModuleType("tts_asr")
    fake_tts_asr.register_tts_asr_routes = _noop_register_tts_asr_routes
    sys.modules["tts_asr"] = fake_tts_asr
|
|
|
|
import main  # type: ignore  # resolvable only after the sys.path / tts_asr setup

# Credentials reused by every authenticated request in this module.
API_KEY = main.API_KEY
HEADERS = {"X-API-Key": API_KEY}
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _clear_active_completions():
    """Keep ``main.ACTIVE_COMPLETIONS`` empty around every test.

    The registry is module-level state shared by all tests in the process.
    It is therefore wiped both before the test body runs and after it
    finishes, so no test observes entries left behind by another.
    """
    def _wipe():
        main.ACTIVE_COMPLETIONS.clear()

    _wipe()
    yield
    _wipe()
|
|
|
|
|
|
class DummyRequest:
    """Minimal stand-in for the request object passed to ``main.get_client_ip``.

    Only two attributes are provided:

    * ``client``: an object exposing ``host``, or ``None`` when no host is given.
    * ``headers``: the supplied headers mapping, or a fresh empty dict.

    Args:
        host: Peer address exposed as ``client.host``. When ``None``,
            ``client`` itself is ``None``, mimicking a request without
            connection info.
        headers: Request headers; defaults to an empty dict.
    """

    def __init__(self, host=None, headers=None):
        # A SimpleNamespace is the idiomatic attribute bag; it avoids
        # re-declaring a throwaway class on every construction.
        self.client = types.SimpleNamespace(host=host) if host is not None else None
        self.headers = headers or {}
|
|
|
|
|
|
def test_preview_short_text():
    """A short string is returned by ``_preview`` unchanged."""
    text = "Hello"
    assert main._preview(text) == text
|
|
|
|
|
|
def test_preview_long_text_truncated():
    """Input longer than 80 characters is cut to 80 and suffixed with ``...``."""
    text = "a" * 100
    expected = text[:80] + "..."
    assert main._preview(text) == expected
|
|
|
|
|
|
def test_preview_none_input():
    """``None`` is accepted and previews as the empty string."""
    preview = main._preview(None)
    assert preview == ""
|
|
|
|
|
|
def test_preview_newlines_replaced():
    """A real newline is rendered as the two characters backslash and ``n``."""
    expected = r"line1\nline2"
    assert main._preview("line1\nline2") == expected
|
|
|
|
|
|
def test_sanitize_markdown_strips_image_markdown():
    """Markdown image syntax ``![alt](src)`` is removed from converted output.

    The needle must be the image markup itself. ``"" not in s`` is always
    False (the empty string is a substring of every string), so asserting on
    an empty needle makes the test fail unconditionally.
    """
    image_markup = "![alt](image.png)"
    cleaned = main._sanitize_converted_markdown(f"text with image {image_markup} end")
    assert image_markup not in cleaned
|
|
|
|
|
|
def test_sanitize_markdown_strips_img_tag():
    """Raw HTML ``<img>`` tags do not survive sanitization."""
    cleaned = main._sanitize_converted_markdown("<img src='x.png'/>")
    assert "<img" not in cleaned
|
|
|
|
|
|
def test_sanitize_markdown_collapse_newlines():
    """Runs of three or more newlines collapse to exactly one blank line."""
    raw = "a\n\n\nb\n\n\n\nc"
    expected = "a\n\nb\n\nc"
    assert main._sanitize_converted_markdown(raw) == expected
|
|
|
|
|
|
def test_sanitize_markdown_normalize_crlf():
    """CRLF line endings are normalized to bare LF."""
    cleaned = main._sanitize_converted_markdown("line1\r\nline2\r\n")
    assert "\r" not in cleaned
    assert "line1\nline2" in cleaned
|
|
|
|
|
|
def test_get_client_ip_from_host():
    """Without an override header, the peer host is reported."""
    peer = "1.2.3.4"
    request = DummyRequest(host=peer, headers={})
    assert main.get_client_ip(request) == peer
|
|
|
|
|
|
def test_get_client_ip_header_overrides_host():
    """An ``X-Client-IP`` header takes precedence over the peer host."""
    override = "5.6.7.8"
    request = DummyRequest(host="1.2.3.4", headers={"X-Client-IP": override})
    assert main.get_client_ip(request) == override
|
|
|
|
|
|
def test_get_client_ip_when_client_missing():
    """The header alone determines the IP when the request has no client."""
    override = "9.9.9.9"
    request = DummyRequest(host=None, headers={"X-Client-IP": override})
    # host=None already leaves ``client`` unset; forcing it keeps the test
    # independent of DummyRequest's defaults.
    request.client = None
    assert main.get_client_ip(request) == override
|
|
|
|
|
|
def test_post_completions_wrong_api_key_returns_401():
    """``/v1/completions`` rejects both a missing and an incorrect API key.

    Sending no header covers only the presence check. An explicitly wrong
    key is also sent so that a regression in the value comparison is caught.
    """
    client = TestClient(main.app)
    payload = {
        "prefix": "hello", "suffix": "", "languageId": "markdown",
        "model_thinking": "low", "privacy_mode": True,
    }

    # No ``X-API-Key`` header at all.
    missing = client.post("/v1/completions", json=payload)
    assert missing.status_code == 401

    # Derived from the real key, so it can never accidentally match it.
    wrong = client.post(
        "/v1/completions",
        headers={"X-API-Key": f"{API_KEY}-wrong"},
        json=payload,
    )
    assert wrong.status_code == 401
|
|
|
|
|
|
def test_post_completions_privacy_mode(monkeypatch):
    """A privacy-mode completion returns the stubbed model content."""

    async def fake_call(*args, **kwargs):
        # Canned reply used in place of a real ``call_ollama`` round-trip.
        return {"content": "done", "think": ""}

    stubs = {
        "call_ollama": fake_call,
        "build_completion_prompts": lambda *a, **k: ("sys", "user"),
        "prepare_prompt_context": lambda *a, **k: ("p", "s"),
    }
    for attr, replacement in stubs.items():
        monkeypatch.setattr(main, attr, replacement)

    payload = {
        "prefix": "hello", "suffix": "", "languageId": "markdown",
        "model_thinking": "low", "privacy_mode": True,
    }
    response = TestClient(main.app).post("/v1/completions", headers=HEADERS, json=payload)

    assert response.status_code == 200
    assert response.json().get("content") == "done"
|
|
|
|
|
|
def test_post_pro_stream_returns_sse(monkeypatch):
    """The pro stream endpoint emits SSE chunk/done events and selects the pro model."""
    captured = {}

    async def fake_stream(*args, **kwargs):
        # Remember the keyword arguments for later inspection, then stream
        # two canned text pieces.
        captured["kwargs"] = kwargs
        for piece in ("深度", "回答"):
            yield piece

    replacements = {
        "stream_ollama": fake_stream,
        "build_completion_prompts": lambda *a, **k: ("sys", "user"),
        "prepare_prompt_context": lambda *a, **k: ("p", "s"),
    }
    for attr, fake in replacements.items():
        monkeypatch.setattr(main, attr, fake)

    request_json = {
        "prefix": "hello",
        "suffix": "",
        "languageId": "markdown",
        "model_thinking": "high",
        "privacy_mode": True,
        "model": "pro-model",
        "temperature": 0.95,
    }
    client = TestClient(main.app)
    with client.stream(
        "POST", "/v1/pro/completions/stream", headers=HEADERS, json=request_json
    ) as resp:
        assert resp.status_code == 200
        body = "".join(resp.iter_text())

    # Both SSE event types and every streamed piece must reach the client.
    for fragment in ("event: chunk", "event: done", "深度", "回答"):
        assert fragment in body

    forwarded = captured["kwargs"]
    assert forwarded["model"] == "pro-model"
    assert forwarded["use_pro_model"] is True
    # Nothing may remain registered once the stream has finished.
    assert main.ACTIVE_COMPLETIONS == {}
|
|
|
|
|
|
def test_post_ocr_mocked(monkeypatch):
    """The OCR endpoint returns the stubbed recognized text and echoes the filename."""

    async def fake_ocr(*args, **kwargs):
        return "OCR result text"

    monkeypatch.setattr(main, "call_vlm_ocr", fake_ocr)

    encoded_image = base64.b64encode(b"pretend image data").decode()
    request_body = {"image": encoded_image, "filename": "test.jpg", "language": "auto"}
    response = TestClient(main.app).post("/v1/ocr", headers=HEADERS, json=request_body)

    assert response.status_code == 200
    payload = response.json()
    assert payload["text"] == "OCR result text"
    assert payload["filename"] == "test.jpg"
|
|
|
|
|
|
def test_post_ocr_invalid_base64_returns_500():
    """An ``image`` value that is not valid base64 yields HTTP 500."""
    response = TestClient(main.app).post(
        "/v1/ocr",
        headers=HEADERS,
        json={"image": "not-base64!!!", "filename": "test.jpg"},
    )
    assert response.status_code == 500
|
|
|
|
|
|
def test_post_convert_txt_returns_markdown():
    """A ``.txt`` upload comes back unchanged in the ``markdown`` field."""
    encoded = base64.b64encode(b"hello world").decode()
    response = TestClient(main.app).post(
        "/v1/convert", headers=HEADERS, json={"file": encoded, "filename": "sample.txt"}
    )

    assert response.status_code == 200
    result = response.json()
    assert result["markdown"] == "hello world"
    assert result["filename"] == "sample.txt"
|
|
|
|
|
|
def test_post_convert_unsupported_extension_returns_500():
    """A ``.xlsx`` upload is refused with HTTP 500 and an explanatory error."""
    encoded = base64.b64encode(b"data").decode()
    response = TestClient(main.app).post(
        "/v1/convert", headers=HEADERS, json={"file": encoded, "filename": "sample.xlsx"}
    )

    assert response.status_code == 500
    # "仅支持" means "only supports"; it is expected somewhere in the error text.
    assert "仅支持" in response.json()["error"]
|
|
|
|
|
|
def test_post_convert_docx_with_mocked_markitdown(monkeypatch):
    """``.docx`` conversion returns the ``text_content`` of the stubbed converter."""

    class StubResult:
        text_content = "markdown from docx"

    class StubConverter:
        # ``path`` keeps the original parameter name in case it is passed by keyword.
        def convert(self, path):
            return StubResult()

    monkeypatch.setattr(main, "_get_markitdown", lambda: StubConverter())

    encoded = base64.b64encode(b"docx content").decode()
    response = TestClient(main.app).post(
        "/v1/convert", headers=HEADERS, json={"file": encoded, "filename": "sample.docx"}
    )

    assert response.status_code == 200
    assert response.json()["markdown"] == "markdown from docx"
|
|
|
|
|
|
def test_post_cancel_non_existent_returns_not_found():
    """Cancelling an unknown request id returns 200 with status ``not_found``."""
    response = TestClient(main.app).post(
        "/v1/completions/cancel",
        headers=HEADERS,
        json={"request_id": "non-existent", "reason": "abort"},
    )

    assert response.status_code == 200
    outcome = response.json()
    assert outcome["cancelled"] is False
    assert outcome["status"] == "not_found"
|
|
|
|
|
|
def test_post_cancel_wrong_api_key_returns_401():
    """The cancel endpoint rejects both a missing and an incorrect API key.

    A request with no header only exercises the presence check. An explicitly
    wrong key is sent as well so that the value comparison is also covered.
    """
    client = TestClient(main.app)
    payload = {"request_id": "id", "reason": "abort"}

    # No ``X-API-Key`` header at all.
    missing = client.post("/v1/completions/cancel", json=payload)
    assert missing.status_code == 401

    # Derived from the real key, so it can never accidentally match it.
    wrong = client.post(
        "/v1/completions/cancel",
        headers={"X-API-Key": f"{API_KEY}-wrong"},
        json=payload,
    )
    assert wrong.status_code == 401
|
|
|
|
|
|
def test_post_cancel_already_done(monkeypatch):
    """Cancelling a request whose task already finished reports ``already_done``.

    A ``MagicMock`` whose ``done()`` returns ``True`` is registered under
    ``"done-id"``. The endpoint must answer 200 with ``cancelled`` False.
    """
    # Only ``done()`` needs a fixed answer. MagicMock auto-creates ``cancel``,
    # so any call to it is recorded rather than raising.
    mock_task = MagicMock()
    mock_task.done.return_value = True

    # setitem is undone automatically at teardown, even if the test fails. The
    # autouse fixture also empties ACTIVE_COMPLETIONS before and after each test,
    # so no manual clearing is needed here.
    monkeypatch.setitem(main.ACTIVE_COMPLETIONS, "done-id", mock_task)

    client = TestClient(main.app)
    resp = client.post("/v1/completions/cancel", headers=HEADERS, json={
        "request_id": "done-id", "reason": "abort",
    })

    assert resp.status_code == 200
    data = resp.json()
    assert data["cancelled"] is False
    assert data["status"] == "already_done"
|