"""Tests for the backend ``main`` module.

Covers the text helpers (``_preview``, ``_sanitize_converted_markdown``),
client-IP resolution (``get_client_ip``) and the HTTP endpoints
``/v1/completions``, ``/v1/completions/cancel``, ``/v1/ocr`` and
``/v1/convert``. Calls to external services (``call_ollama``,
``call_vlm_ocr``, ``_get_markitdown``) are replaced with fakes via
``monkeypatch`` where a test needs them.
"""

import asyncio
import base64
import os
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
from fastapi.testclient import TestClient

# Put the parent of this file's directory on sys.path so that ``import main``
# resolves from the absolute location of this file, not the working directory.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
BACKEND_DIR = os.path.abspath(os.path.join(CURRENT_DIR, ".."))
if BACKEND_DIR not in sys.path:
    sys.path.insert(0, BACKEND_DIR)

import main  # type: ignore  # noqa: E402

API_KEY = main.API_KEY
HEADERS = {"X-API-Key": API_KEY}


@pytest.fixture(autouse=True)
def _clear_active_completions():
    """Empty ``main.ACTIVE_COMPLETIONS`` before and after every test."""
    main.ACTIVE_COMPLETIONS.clear()
    yield
    main.ACTIVE_COMPLETIONS.clear()


@pytest.fixture
def client():
    """Return a ``TestClient`` bound to ``main.app``.

    The client is deliberately not entered as a context manager, matching the
    previous inline ``TestClient(main.app)`` construction in each test.
    """
    return TestClient(main.app)


class DummyRequest:
    """Minimal request stand-in used to exercise ``main.get_client_ip``.

    Provides only ``client.host`` and ``headers``. When ``host`` is None the
    ``client`` attribute is None, simulating a request without a peer address.
    """

    def __init__(self, host=None, headers=None):
        self.client = SimpleNamespace(host=host) if host is not None else None
        self.headers = headers or {}


# --- _preview -------------------------------------------------------------


def test_preview_short_text():
    assert main._preview("Hello") == "Hello"


def test_preview_long_text_truncated():
    long_text = "a" * 100
    assert main._preview(long_text) == long_text[:80] + "..."


def test_preview_none_input():
    assert main._preview(None) == ""


def test_preview_newlines_replaced():
    # A real newline in the input is rendered as the two characters "\n".
    assert main._preview("line1\nline2") == "line1\\nline2"


# --- _sanitize_converted_markdown -----------------------------------------


def test_sanitize_markdown_strips_image_markdown():
    assert "![alt](image.png)" not in main._sanitize_converted_markdown(
        "text with image ![alt](image.png) end"
    )


def test_sanitize_markdown_strips_img_tag():
    # Reconstructed: the original assertion was garbled into invalid syntax.
    assert "<img" not in main._sanitize_converted_markdown(
        "text with image <img src='image.png' alt='alt'>"
    )


def test_sanitize_markdown_collapse_newlines():
    assert main._sanitize_converted_markdown("a\n\n\nb\n\n\n\nc") == "a\n\nb\n\nc"


def test_sanitize_markdown_normalize_crlf():
    result = main._sanitize_converted_markdown("line1\r\nline2\r\n")
    assert "line1\nline2" in result
    assert "\r" not in result


# --- get_client_ip --------------------------------------------------------


def test_get_client_ip_from_host():
    req = DummyRequest(host="1.2.3.4", headers={})
    assert main.get_client_ip(req) == "1.2.3.4"


def test_get_client_ip_header_overrides_host():
    req = DummyRequest(host="1.2.3.4", headers={"X-Client-IP": "5.6.7.8"})
    assert main.get_client_ip(req) == "5.6.7.8"


def test_get_client_ip_when_client_missing():
    # host=None already yields req.client is None.
    req = DummyRequest(host=None, headers={"X-Client-IP": "9.9.9.9"})
    assert main.get_client_ip(req) == "9.9.9.9"


# --- /v1/completions ------------------------------------------------------


def test_post_completions_wrong_api_key_returns_401(client):
    """A request sent without any ``X-API-Key`` header is rejected with 401."""
    resp = client.post("/v1/completions", json={
        "prefix": "hello",
        "suffix": "",
        "languageId": "markdown",
        "model_thinking": "low",
        "privacy_mode": True,
    })
    assert resp.status_code == 401


def test_post_completions_privacy_mode(monkeypatch, client):
    async def fake_call(*args, **kwargs):
        return {"content": "done", "think": ""}

    monkeypatch.setattr(main, "call_ollama", fake_call)
    monkeypatch.setattr(main, "build_completion_prompts", lambda *a, **k: ("sys", "user"))
    monkeypatch.setattr(main, "prepare_prompt_context", lambda *a, **k: ("p", "s"))

    resp = client.post("/v1/completions", headers=HEADERS, json={
        "prefix": "hello",
        "suffix": "",
        "languageId": "markdown",
        "model_thinking": "low",
        "privacy_mode": True,
    })
    assert resp.status_code == 200
    data = resp.json()
    assert data.get("content") == "done"


# --- /v1/ocr --------------------------------------------------------------


def test_post_ocr_mocked(monkeypatch, client):
    async def fake_ocr(*args, **kwargs):
        return "OCR result text"

    monkeypatch.setattr(main, "call_vlm_ocr", fake_ocr)

    img_b64 = base64.b64encode(b"pretend image data").decode()
    resp = client.post("/v1/ocr", headers=HEADERS, json={
        "image": img_b64,
        "filename": "test.jpg",
        "language": "auto",
    })
    assert resp.status_code == 200
    j = resp.json()
    assert j["text"] == "OCR result text"
    assert j["filename"] == "test.jpg"


def test_post_ocr_invalid_base64_returns_500(client):
    resp = client.post("/v1/ocr", headers=HEADERS, json={
        "image": "not-base64!!!",
        "filename": "test.jpg",
    })
    assert resp.status_code == 500


# --- /v1/convert ----------------------------------------------------------


def test_post_convert_txt_returns_markdown(client):
    content = base64.b64encode(b"hello world").decode()
    resp = client.post("/v1/convert", headers=HEADERS, json={
        "file": content,
        "filename": "sample.txt",
    })
    assert resp.status_code == 200
    j = resp.json()
    assert j["markdown"] == "hello world"
    assert j["filename"] == "sample.txt"


def test_post_convert_unsupported_extension_returns_500(client):
    content = base64.b64encode(b"data").decode()
    resp = client.post("/v1/convert", headers=HEADERS, json={
        "file": content,
        "filename": "sample.xlsx",
    })
    assert resp.status_code == 500
    assert "仅支持" in resp.json()["error"]


def test_post_convert_docx_with_mocked_markitdown(monkeypatch, client):
    class FakeResult:
        text_content = "markdown from docx"

    class FakeMD:
        def convert(self, path):
            return FakeResult()

    monkeypatch.setattr(main, "_get_markitdown", lambda: FakeMD())

    content = base64.b64encode(b"docx content").decode()
    resp = client.post("/v1/convert", headers=HEADERS, json={
        "file": content,
        "filename": "sample.docx",
    })
    assert resp.status_code == 200
    j = resp.json()
    assert j["markdown"] == "markdown from docx"


# --- /v1/completions/cancel -----------------------------------------------


def test_post_cancel_non_existent_returns_not_found(client):
    resp = client.post("/v1/completions/cancel", headers=HEADERS, json={
        "request_id": "non-existent",
        "reason": "abort",
    })
    assert resp.status_code == 200
    data = resp.json()
    assert data["cancelled"] is False
    assert data["status"] == "not_found"


def test_post_cancel_wrong_api_key_returns_401(client):
    """A cancel request sent without any ``X-API-Key`` header gets 401."""
    resp = client.post("/v1/completions/cancel", json={
        "request_id": "id",
        "reason": "abort",
    })
    assert resp.status_code == 401


def test_post_cancel_already_done(client):
    # ACTIVE_COMPLETIONS is emptied around this test by the autouse fixture.
    # A MagicMock whose done() returns True stands in for a finished task.
    mock_task = MagicMock()
    mock_task.done.return_value = True
    main.ACTIVE_COMPLETIONS["done-id"] = mock_task

    resp = client.post("/v1/completions/cancel", headers=HEADERS, json={
        "request_id": "done-id",
        "reason": "abort",
    })
    assert resp.status_code == 200
    data = resp.json()
    assert data["cancelled"] is False
    assert data["status"] == "already_done"