Added a new `.coveragerc` file configuring coverage thresholds and exclusions. Included `pytest.ini` to enable coverage reporting for multiple backend modules (`main`, `llm`, `prompt`, `geoip`, `tts_asr`) with a 90% fail-under requirement and detailed HTML output. Implemented a suite of unit tests: * `test_geoip.py` – validates geolocation lookup logic. * `test_llm_extended.py` – tests LLM response extraction and Ollama interactions. * `test_main_endpoints.py` – covers API endpoints for completions, OCR, and TTS. * `test_prompt_extended.py` – verifies language sanitization, timestamp generation, and prompt building. * `test_tts_asr_coverage.py` – checks device detection, cache clearing, and model loading under various environment configurations. * `test_tts_asr_extended.py` – further tests TTS/ASR device selection and timeouts. Updated `backend/requirements.txt` to use newer, compatible packages, removed obsolete testing dependencies, and added `qwen-tts`. Modified `backend/tts_asr.py` to work with the new `Qwen3TTSModel`, simplified imports, and adjusted device mapping logic. Additionally, frontend changes added a new `TreeNodeItem` component, updated Markdown rendering, added TTS instruction fields, and reworked context menu handling. No breaking changes were introduced.
216 lines
6.5 KiB
Python
216 lines
6.5 KiB
Python
import os
|
|
import sys
|
|
import base64
|
|
import asyncio
|
|
import pytest
|
|
from unittest.mock import MagicMock
|
|
from fastapi.testclient import TestClient
|
|
|
|
# Directory containing this test module, and its parent directory.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
BACKEND_DIR = os.path.abspath(os.path.join(CURRENT_DIR, os.pardir))

# Put the parent directory first on the import path (once only) -- presumably
# so that the `main` module imported below can be resolved; confirm the layout.
if BACKEND_DIR not in sys.path:
    sys.path.insert(0, BACKEND_DIR)
|
|
|
|
import main # type: ignore
|
|
|
|
# Key exposed by the `main` module; reused so authenticated requests in the
# tests carry the same value the application holds.
API_KEY: str = main.API_KEY

# Header set passed to every request that is expected to be authenticated.
HEADERS: dict = {"X-API-Key": API_KEY}
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _clear_active_completions():
    """Empty `main.ACTIVE_COMPLETIONS` before and after every test.

    Declared with ``autouse=True`` so it wraps each test in this module
    without being requested explicitly.
    """
    # Setup: start from an empty registry.
    main.ACTIVE_COMPLETIONS.clear()

    yield

    # Teardown: drop whatever the test registered.
    main.ACTIVE_COMPLETIONS.clear()
|
|
|
|
|
|
class DummyRequest:
    """Minimal request stand-in exposing only ``client`` and ``headers``.

    The tests in this module pass instances to ``main.get_client_ip``.

    Args:
        host: Value stored as ``client.host``. When ``None``, ``client``
            itself is ``None`` (no client object is created).
        headers: Mapping stored as ``headers``; any falsy value becomes a
            fresh empty dict.
    """

    def __init__(self, host=None, headers=None):
        if host is None:
            # No peer information at all.
            self.client = None
        else:
            class Client:
                pass

            peer = Client()
            peer.host = host
            self.client = peer

        self.headers = headers if headers else {}
|
|
|
|
|
|
def test_preview_short_text():
    """A short string comes back from `_preview` unchanged."""
    preview = main._preview("Hello")
    assert preview == "Hello"
|
|
|
|
|
|
def test_preview_long_text_truncated():
    """A 100-character input is reduced to its first 80 characters plus "..."."""
    source = "a" * 100
    expected = source[:80] + "..."
    assert main._preview(source) == expected
|
|
|
|
|
|
def test_preview_none_input():
    """`None` is previewed as the empty string."""
    preview = main._preview(None)
    assert preview == ""
|
|
|
|
|
|
def test_preview_newlines_replaced():
    """A real newline is rendered as the two characters backslash and 'n'."""
    preview = main._preview("line1\nline2")
    assert preview == "line1\\nline2"
|
|
|
|
|
|
def test_sanitize_markdown_strips_image_markdown():
    """Markdown image syntax must not survive sanitisation.

    The sample embeds an explicit ``![alt](x.png)`` image between two words
    and checks that the exact markup is gone from the sanitised output.

    The earlier form asserted ``"" not in <result>``, which can never pass
    because the empty string is a substring of every string, and its input
    no longer contained any image markup.
    """
    image_markup = "![alt](x.png)"
    cleaned = main._sanitize_converted_markdown(
        "text with image " + image_markup + " end"
    )
    assert image_markup not in cleaned
|
|
|
|
|
|
def test_sanitize_markdown_strips_img_tag():
    """No opening `<img` tag remains after sanitising an HTML image element."""
    cleaned = main._sanitize_converted_markdown("<img src='x.png'/>")
    assert "<img" not in cleaned
|
|
|
|
|
|
def test_sanitize_markdown_collapse_newlines():
    """Runs of three or four newlines are collapsed to exactly two."""
    raw = "a\n\n\nb\n\n\n\nc"
    expected = "a\n\nb\n\nc"
    assert main._sanitize_converted_markdown(raw) == expected
|
|
|
|
|
|
def test_sanitize_markdown_normalize_crlf():
    """CRLF line endings become plain LF and no carriage return is left over."""
    cleaned = main._sanitize_converted_markdown("line1\r\nline2\r\n")

    # The two lines must now be joined by a bare line feed ...
    assert "line1\nline2" in cleaned
    # ... and every carriage return must be gone, including the trailing one.
    assert "\r" not in cleaned
|
|
|
|
|
|
def test_get_client_ip_from_host():
    """With no headers present, the IP is taken from `client.host`."""
    request = DummyRequest(host="1.2.3.4", headers={})
    resolved = main.get_client_ip(request)
    assert resolved == "1.2.3.4"
|
|
|
|
|
|
def test_get_client_ip_header_overrides_host():
    """An `X-Client-IP` header wins over the connection's `client.host`."""
    forwarded = {"X-Client-IP": "5.6.7.8"}
    request = DummyRequest(host="1.2.3.4", headers=forwarded)
    assert main.get_client_ip(request) == "5.6.7.8"
|
|
|
|
|
|
def test_get_client_ip_when_client_missing():
    """The `X-Client-IP` header is still honoured when `request.client` is None."""
    # host=None makes DummyRequest leave `client` as None, so no explicit
    # reassignment is needed to model a request without peer information.
    request = DummyRequest(host=None, headers={"X-Client-IP": "9.9.9.9"})
    resolved = main.get_client_ip(request)
    assert resolved == "9.9.9.9"
|
|
|
|
|
|
def test_post_completions_wrong_api_key_returns_401():
    """POST /v1/completions is rejected with 401 when unauthenticated.

    Note that the request is sent without any `X-API-Key` header at all,
    i.e. it exercises the missing-key case.
    """
    payload = {
        "prefix": "hello",
        "suffix": "",
        "languageId": "markdown",
        "model_thinking": "low",
        "privacy_mode": True,
    }
    http = TestClient(main.app)
    response = http.post("/v1/completions", json=payload)
    assert response.status_code == 401
|
|
|
|
|
|
def test_post_completions_privacy_mode(monkeypatch):
    """An authenticated completion request returns the stubbed model content.

    `call_ollama`, `build_completion_prompts` and `prepare_prompt_context`
    are replaced on the `main` module so no real model or prompt logic runs.
    """
    async def fake_call(*args, **kwargs):
        return {"content": "done", "think": ""}

    def fake_build_prompts(*args, **kwargs):
        return ("sys", "user")

    def fake_prepare_context(*args, **kwargs):
        return ("p", "s")

    monkeypatch.setattr(main, "call_ollama", fake_call)
    monkeypatch.setattr(main, "build_completion_prompts", fake_build_prompts)
    monkeypatch.setattr(main, "prepare_prompt_context", fake_prepare_context)

    payload = {
        "prefix": "hello",
        "suffix": "",
        "languageId": "markdown",
        "model_thinking": "low",
        "privacy_mode": True,
    }
    http = TestClient(main.app)
    response = http.post("/v1/completions", headers=HEADERS, json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body.get("content") == "done"
|
|
|
|
|
|
def test_post_ocr_mocked(monkeypatch):
    """POST /v1/ocr echoes the filename and returns the stubbed OCR text.

    `main.call_vlm_ocr` is replaced with a coroutine returning a fixed string.
    """
    async def fake_ocr(*args, **kwargs):
        return "OCR result text"

    monkeypatch.setattr(main, "call_vlm_ocr", fake_ocr)

    encoded_image = base64.b64encode(b"pretend image data").decode()
    payload = {
        "image": encoded_image,
        "filename": "test.jpg",
        "language": "auto",
    }
    http = TestClient(main.app)
    response = http.post("/v1/ocr", headers=HEADERS, json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["text"] == "OCR result text"
    assert body["filename"] == "test.jpg"
|
|
|
|
|
|
def test_post_ocr_invalid_base64_returns_500():
    """POST /v1/ocr answers 500 when the `image` field is not valid base64."""
    payload = {
        "image": "not-base64!!!",
        "filename": "test.jpg",
    }
    http = TestClient(main.app)
    response = http.post("/v1/ocr", headers=HEADERS, json=payload)
    assert response.status_code == 500
|
|
|
|
|
|
def test_post_convert_txt_returns_markdown():
    """A base64-encoded .txt upload comes back verbatim as `markdown`."""
    encoded_file = base64.b64encode(b"hello world").decode()
    payload = {
        "file": encoded_file,
        "filename": "sample.txt",
    }
    http = TestClient(main.app)
    response = http.post("/v1/convert", headers=HEADERS, json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["markdown"] == "hello world"
    assert body["filename"] == "sample.txt"
|
|
|
|
|
|
def test_post_convert_unsupported_extension_returns_500():
    """An .xlsx upload is refused with 500 and an explanatory error message."""
    encoded_file = base64.b64encode(b"data").decode()
    payload = {
        "file": encoded_file,
        "filename": "sample.xlsx",
    }
    http = TestClient(main.app)
    response = http.post("/v1/convert", headers=HEADERS, json=payload)

    assert response.status_code == 500
    # The expected fragment is Chinese for "only supports".
    error_message = response.json()["error"]
    assert "仅支持" in error_message
|
|
|
|
|
|
def test_post_convert_docx_with_mocked_markitdown(monkeypatch):
    """A .docx upload returns the text produced by the stubbed converter.

    `main._get_markitdown` is replaced with a factory for a fake converter
    whose `convert()` result carries a fixed `text_content`.
    """
    class StubConversion:
        text_content = "markdown from docx"

    class StubConverter:
        def convert(self, path):
            return StubConversion()

    def make_converter():
        return StubConverter()

    monkeypatch.setattr(main, "_get_markitdown", make_converter)

    encoded_file = base64.b64encode(b"docx content").decode()
    payload = {
        "file": encoded_file,
        "filename": "sample.docx",
    }
    http = TestClient(main.app)
    response = http.post("/v1/convert", headers=HEADERS, json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["markdown"] == "markdown from docx"
|
|
|
|
|
|
def test_post_cancel_non_existent_returns_not_found():
    """Cancelling an unknown request id yields 200 with status "not_found"."""
    payload = {
        "request_id": "non-existent",
        "reason": "abort",
    }
    http = TestClient(main.app)
    response = http.post("/v1/completions/cancel", headers=HEADERS, json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["cancelled"] is False
    assert body["status"] == "not_found"
|
|
|
|
|
|
def test_post_cancel_wrong_api_key_returns_401():
    """POST /v1/completions/cancel is rejected with 401 when unauthenticated.

    The request is sent without any `X-API-Key` header.
    """
    payload = {
        "request_id": "id",
        "reason": "abort",
    }
    http = TestClient(main.app)
    response = http.post("/v1/completions/cancel", json=payload)
    assert response.status_code == 401
|
|
|
|
|
|
def test_post_cancel_already_done(monkeypatch):
    """Cancelling a finished completion yields 200 with status "already_done".

    A mock task whose ``done()`` returns True is registered under "done-id"
    in ``main.ACTIVE_COMPLETIONS`` before the cancel endpoint is called.
    """
    finished_task = MagicMock()
    finished_task.done.return_value = True

    # monkeypatch.setitem removes the entry again at teardown even when an
    # assertion below fails, so no manual clear() of the registry is needed
    # (the module's autouse fixture additionally empties it around each test).
    monkeypatch.setitem(main.ACTIVE_COMPLETIONS, "done-id", finished_task)

    payload = {
        "request_id": "done-id",
        "reason": "abort",
    }
    http = TestClient(main.app)
    response = http.post("/v1/completions/cancel", headers=HEADERS, json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["cancelled"] is False
    assert body["status"] == "already_done"
|