Files
llm-in-text/backend/tests/test_main_endpoints.py
ydy0615 2fdc996af9 test(backend): add comprehensive test coverage for backend modules
Added a new `.coveragerc` file configuring coverage thresholds and exclusions.
Included `pytest.ini` to enable coverage reporting for multiple backend modules (`main`, `llm`, `prompt`, `geoip`, `tts_asr`) with a 90% fail-under threshold and detailed HTML output.
Implemented a suite of unit tests:

* `test_geoip.py` – validates geo‑location lookup logic.
* `test_llm_extended.py` – tests LLM response extraction and Ollama interactions.
* `test_main_endpoints.py` – covers helper functions and the API endpoints for completions, OCR, document conversion, and completion cancellation.
* `test_prompt_extended.py` – verifies language sanitization, timestamp generation, and prompt building.
* `test_tts_asr_coverage.py` – checks device detection, cache clearing, and model loading under various environment configurations.
* `test_tts_asr_extended.py` – further tests TTS/ASR device selection and time‑outs.

Updated `backend/requirements.txt` to use newer, compatible packages, removed obsolete testing dependencies, and added `qwen-tts`.
Modified `backend/tts_asr.py` to work with the new `Qwen3TTSModel`, simplified imports, and adjusted device mapping logic.

Additionally, frontend changes added a new `TreeNodeItem` component, updated Markdown rendering, added TTS instruction fields, and reworked context menu handling.

No breaking changes were introduced.
2026-04-07 23:38:23 +08:00

216 lines
6.5 KiB
Python

import os
import sys
import base64
import asyncio
import pytest
from unittest.mock import MagicMock
from fastapi.testclient import TestClient
# Put the backend root (the parent of this tests/ directory) at the front of
# sys.path so that ``import main`` below resolves to the backend's main module
# no matter which working directory pytest is launched from.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
BACKEND_DIR = os.path.abspath(os.path.join(CURRENT_DIR, ".."))
if BACKEND_DIR not in sys.path:
    sys.path.insert(0, BACKEND_DIR)
# This import must stay after the sys.path adjustment above.
import main  # type: ignore
# Reuse the server's own configured key so authenticated test requests pass.
API_KEY = main.API_KEY
HEADERS = {"X-API-Key": API_KEY}
@pytest.fixture(autouse=True)
def _clear_active_completions():
    """Run every test with an empty ``main.ACTIVE_COMPLETIONS`` registry.

    The registry is emptied before the test body runs and again during
    teardown, so entries registered by one test never leak into the next.
    """
    main.ACTIVE_COMPLETIONS.clear()
    try:
        yield
    finally:
        main.ACTIVE_COMPLETIONS.clear()
class DummyRequest:
    """Minimal request-like object passed to ``main.get_client_ip`` in tests.

    Attributes:
        client: ``None`` when no host is given; otherwise an object whose
            ``host`` attribute holds the supplied host.
        headers: the supplied headers mapping, or a fresh empty dict when
            ``headers`` is falsy.
    """

    def __init__(self, host=None, headers=None):
        self.client = None
        if host is not None:
            class Client:
                pass

            peer = Client()
            peer.host = host
            self.client = peer
        self.headers = headers or {}
def test_preview_short_text():
    """Short text passes through ``_preview`` unchanged."""
    sample = "Hello"
    preview = main._preview(sample)
    assert preview == sample
def test_preview_long_text_truncated():
    """Long text is cut to its first 80 characters followed by ``...``."""
    text = "a" * 100
    expected = text[:80] + "..."
    assert main._preview(text) == expected
def test_preview_none_input():
    """``None`` produces an empty preview string."""
    result = main._preview(None)
    assert result == ""
def test_preview_newlines_replaced():
    """A newline character is replaced by a literal backslash followed by ``n``."""
    expected = "line1" + "\\n" + "line2"
    assert main._preview("line1\nline2") == expected
def test_sanitize_markdown_strips_image_markdown():
    """Markdown image syntax is removed from converted output."""
    cleaned = main._sanitize_converted_markdown("text with image ![alt](image.png) end")
    assert "![alt](image.png)" not in cleaned
def test_sanitize_markdown_strips_img_tag():
    """Raw HTML ``<img>`` tags are removed from converted output."""
    cleaned = main._sanitize_converted_markdown("<img src='x.png'/>")
    assert "<img" not in cleaned
def test_sanitize_markdown_collapse_newlines():
    """Runs of three or more newlines collapse to a single blank line."""
    raw = "a\n\n\nb\n\n\n\nc"
    assert main._sanitize_converted_markdown(raw) == "a\n\nb\n\nc"
def test_sanitize_markdown_normalize_crlf():
    """CRLF line endings are normalized to LF."""
    cleaned = main._sanitize_converted_markdown("line1\r\nline2\r\n")
    assert "\r" not in cleaned
    assert "line1\nline2" in cleaned
def test_get_client_ip_from_host():
    """Without an override header, the request's client host is returned."""
    request = DummyRequest(host="1.2.3.4", headers={})
    ip = main.get_client_ip(request)
    assert ip == "1.2.3.4"
def test_get_client_ip_header_overrides_host():
    """An ``X-Client-IP`` header takes precedence over the client host."""
    request = DummyRequest(headers={"X-Client-IP": "5.6.7.8"}, host="1.2.3.4")
    assert main.get_client_ip(request) == "5.6.7.8"
def test_get_client_ip_when_client_missing():
    """The ``X-Client-IP`` header is used when the request has no ``client``."""
    req = DummyRequest(host=None, headers={"X-Client-IP": "9.9.9.9"})
    # DummyRequest already leaves ``client`` as None when host is None, so
    # assert that precondition instead of redundantly re-assigning it.
    assert req.client is None
    assert main.get_client_ip(req) == "9.9.9.9"
def test_post_completions_wrong_api_key_returns_401():
    """``/v1/completions`` rejects a missing or incorrect ``X-API-Key`` with 401."""
    client = TestClient(main.app)
    payload = {
        "prefix": "hello", "suffix": "", "languageId": "markdown",
        "model_thinking": "low", "privacy_mode": True,
    }
    # No API key header at all.
    resp = client.post("/v1/completions", json=payload)
    assert resp.status_code == 401
    # A key that is guaranteed to differ from the configured one.
    resp = client.post(
        "/v1/completions",
        headers={"X-API-Key": f"{API_KEY}-wrong"},
        json=payload,
    )
    assert resp.status_code == 401
def test_post_completions_privacy_mode(monkeypatch):
    """A privacy-mode completion returns the content produced by the stubbed model call."""

    async def fake_call(*args, **kwargs):
        return {"content": "done", "think": ""}

    # Stub the model call and prompt construction so no real backend is needed.
    for attr, replacement in (
        ("call_ollama", fake_call),
        ("build_completion_prompts", lambda *a, **k: ("sys", "user")),
        ("prepare_prompt_context", lambda *a, **k: ("p", "s")),
    ):
        monkeypatch.setattr(main, attr, replacement)

    payload = {
        "prefix": "hello",
        "suffix": "",
        "languageId": "markdown",
        "model_thinking": "low",
        "privacy_mode": True,
    }
    resp = TestClient(main.app).post("/v1/completions", headers=HEADERS, json=payload)
    assert resp.status_code == 200
    assert resp.json().get("content") == "done"
def test_post_ocr_mocked(monkeypatch):
    """``/v1/ocr`` returns the stubbed VLM text and echoes the filename."""

    async def fake_ocr(*args, **kwargs):
        return "OCR result text"

    monkeypatch.setattr(main, "call_vlm_ocr", fake_ocr)
    encoded = base64.b64encode(b"pretend image data").decode()
    client = TestClient(main.app)
    body = {"image": encoded, "filename": "test.jpg", "language": "auto"}
    resp = client.post("/v1/ocr", headers=HEADERS, json=body)
    assert resp.status_code == 200
    payload = resp.json()
    assert (payload["text"], payload["filename"]) == ("OCR result text", "test.jpg")
def test_post_ocr_invalid_base64_returns_500():
    """A malformed base64 image payload yields HTTP 500."""
    body = {"image": "not-base64!!!", "filename": "test.jpg"}
    resp = TestClient(main.app).post("/v1/ocr", headers=HEADERS, json=body)
    assert resp.status_code == 500
def test_post_convert_txt_returns_markdown():
    """A ``.txt`` upload is returned as its text content, with the filename echoed."""
    encoded = base64.b64encode(b"hello world").decode()
    resp = TestClient(main.app).post(
        "/v1/convert",
        headers=HEADERS,
        json={"file": encoded, "filename": "sample.txt"},
    )
    assert resp.status_code == 200
    body = resp.json()
    assert body["markdown"] == "hello world"
    assert body["filename"] == "sample.txt"
def test_post_convert_unsupported_extension_returns_500():
    """An unsupported extension (here ``.xlsx``) is rejected with HTTP 500."""
    encoded = base64.b64encode(b"data").decode()
    resp = TestClient(main.app).post(
        "/v1/convert",
        headers=HEADERS,
        json={"file": encoded, "filename": "sample.xlsx"},
    )
    assert resp.status_code == 500
    # The server's error message is in Chinese; "仅支持" means "only supports".
    assert "仅支持" in resp.json()["error"]
def test_post_convert_docx_with_mocked_markitdown(monkeypatch):
    """A ``.docx`` upload is converted through the stubbed MarkItDown converter."""

    class FakeResult:
        text_content = "markdown from docx"

    class FakeMD:
        def convert(self, path):
            return FakeResult()

    monkeypatch.setattr(main, "_get_markitdown", lambda: FakeMD())
    client = TestClient(main.app)
    content = base64.b64encode(b"docx content").decode()
    resp = client.post("/v1/convert", headers=HEADERS, json={
        "file": content, "filename": "sample.docx",
    })
    assert resp.status_code == 200
    j = resp.json()
    assert j["markdown"] == "markdown from docx"
    # Keep parity with the .txt conversion test: the filename is echoed back.
    assert j["filename"] == "sample.docx"
def test_post_cancel_non_existent_returns_not_found():
    """Cancelling an unknown request id reports ``not_found`` and ``cancelled=False``."""
    resp = TestClient(main.app).post(
        "/v1/completions/cancel",
        headers=HEADERS,
        json={"request_id": "non-existent", "reason": "abort"},
    )
    assert resp.status_code == 200
    data = resp.json()
    assert data["status"] == "not_found"
    assert data["cancelled"] is False
def test_post_cancel_wrong_api_key_returns_401():
    """``/v1/completions/cancel`` rejects a missing or incorrect ``X-API-Key`` with 401."""
    client = TestClient(main.app)
    payload = {"request_id": "id", "reason": "abort"}
    # No API key header at all.
    resp = client.post("/v1/completions/cancel", json=payload)
    assert resp.status_code == 401
    # A key that is guaranteed to differ from the configured one.
    resp = client.post(
        "/v1/completions/cancel",
        headers={"X-API-Key": f"{API_KEY}-wrong"},
        json=payload,
    )
    assert resp.status_code == 401
def test_post_cancel_already_done(monkeypatch):
    """Cancelling a task whose ``done()`` is True reports ``already_done``.

    The task must not be cancelled. ``main.ACTIVE_COMPLETIONS`` is emptied
    before and after every test by the autouse ``_clear_active_completions``
    fixture in this module, so no manual clearing is done here.
    """
    # Stand-in task that reports completion. MagicMock auto-creates a
    # recording ``cancel`` attribute, which is checked at the end.
    mock_task = MagicMock()
    mock_task.done.return_value = True
    main.ACTIVE_COMPLETIONS["done-id"] = mock_task
    client = TestClient(main.app)
    resp = client.post("/v1/completions/cancel", headers=HEADERS, json={
        "request_id": "done-id", "reason": "abort",
    })
    assert resp.status_code == 200
    data = resp.json()
    assert data["cancelled"] is False
    assert data["status"] == "already_done"
    mock_task.cancel.assert_not_called()