119 lines
3.7 KiB
Python
119 lines
3.7 KiB
Python
|
|
import asyncio
|
||
|
|
import importlib
|
||
|
|
import sys
|
||
|
|
import threading
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from fastapi.testclient import TestClient
|
||
|
|
|
||
|
|
|
||
|
|
# Parent of the directory that holds this test file. It is put at the front of
# ``sys.path`` (once) so that the ``main`` module imported below can be found.
BACKEND_DIR = Path(__file__).resolve().parents[1]
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))


# Import ``main`` dynamically. If it, or any module it imports, cannot be
# found, skip this whole test module instead of erroring during collection.
try:
    main = importlib.import_module("main")
except ModuleNotFoundError:
    pytest.skip("main module dependencies are not available", allow_module_level=True)


# API-key header sent with every request issued by the tests in this module.
API_KEY_HEADERS = {"X-API-Key": "your-secret-key-here"}
|
||
|
|
|
||
|
|
|
||
|
|
def _completion_payload():
|
||
|
|
return {
|
||
|
|
"prefix": "hello",
|
||
|
|
"suffix": "",
|
||
|
|
"languageId": "markdown",
|
||
|
|
"model_thinking": "low",
|
||
|
|
"privacy_mode": True,
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
def test_cancel_endpoint_cancels_running_task(monkeypatch):
    """Cancelling an in-flight completion interrupts the stubbed model call.

    A completion request is posted from a background thread against a
    ``call_ollama`` stub that never finishes on its own. Once the stub has
    started, the cancel endpoint is called with the same request id. The test
    then checks that the cancel call reports success, that the stub observed
    ``asyncio.CancelledError``, and that the completion response body contains
    the ``"cancelled": true`` marker.
    """
    main.ACTIVE_COMPLETIONS.clear()
    call_started = threading.Event()
    call_cancelled = threading.Event()

    async def hanging_call_ollama(*args, **kwargs):
        # Tell the test thread the call is running, then idle until cancelled.
        call_started.set()
        try:
            while True:
                await asyncio.sleep(0.05)
        except asyncio.CancelledError:
            # Record the cancellation and let it propagate to the caller.
            call_cancelled.set()
            raise

    # Replace the model call and the prompt helpers on ``main`` with stubs.
    for attr_name, stub in (
        ("call_ollama", hanging_call_ollama),
        ("build_completion_prompts", lambda *a, **k: ("system", "user")),
        ("prepare_prompt_context", lambda *a, **k: ("prefix", "suffix")),
    ):
        monkeypatch.setattr(main, attr_name, stub)

    with TestClient(main.app) as client:
        request_id = "req-cancel-1"
        headers_with_request_id = {**API_KEY_HEADERS, "X-Request-Id": request_id}
        completion_result = {}

        def post_completion():
            # Runs in the worker thread; blocks until the completion returns.
            completion_result["response"] = client.post(
                "/v1/completions",
                headers=headers_with_request_id,
                json=_completion_payload(),
            )

        worker = threading.Thread(target=post_completion, daemon=True)
        worker.start()

        # Do not cancel before the stubbed model call is actually running.
        assert call_started.wait(timeout=2.0)

        cancel_response = client.post(
            "/v1/completions/cancel",
            headers=API_KEY_HEADERS,
            json={"request_id": request_id, "reason": "superseded"},
        )
        assert cancel_response.status_code == 200
        assert cancel_response.json() == {"cancelled": True, "status": "ok"}

        # The completion request must finish once its task has been cancelled.
        worker.join(timeout=5.0)
        assert not worker.is_alive()
        assert call_cancelled.wait(timeout=2.0)

        completion_response = completion_result["response"]
        assert completion_response.status_code == 200
        assert '"cancelled": true' in completion_response.text
|
||
|
|
|
||
|
|
|
||
|
|
def test_cancel_not_found():
    """Cancelling an unknown request id answers HTTP 200 with a ``not_found`` status."""
    main.ACTIVE_COMPLETIONS.clear()
    cancel_body = {"request_id": "missing", "reason": "abort"}
    expected_result = {"cancelled": False, "status": "not_found"}

    with TestClient(main.app) as client:
        response = client.post(
            "/v1/completions/cancel",
            headers=API_KEY_HEADERS,
            json=cancel_body,
        )
        assert response.status_code == 200
        assert response.json() == expected_result
|
||
|
|
|
||
|
|
|
||
|
|
def test_completion_normal_flow(monkeypatch):
    """A completion that is not cancelled returns its content and a done marker.

    The model call and prompt helpers on ``main`` are stubbed. The test checks
    the HTTP status, that the response text carries the stubbed content and
    ``"done": true``, and that ``main.ACTIVE_COMPLETIONS`` is empty afterwards.
    """
    main.ACTIVE_COMPLETIONS.clear()

    async def instant_call_ollama(*args, **kwargs):
        # Stand-in for the model call: resolves immediately with fixed output.
        return {"content": "completion text", "think": ""}

    stubs = {
        "call_ollama": instant_call_ollama,
        "build_completion_prompts": lambda *a, **k: ("system", "user"),
        "prepare_prompt_context": lambda *a, **k: ("prefix", "suffix"),
    }
    for attr_name, stub in stubs.items():
        monkeypatch.setattr(main, attr_name, stub)

    with TestClient(main.app) as client:
        response = client.post(
            "/v1/completions",
            headers=API_KEY_HEADERS,
            json=_completion_payload(),
        )

        assert response.status_code == 200
        body_text = response.text
        assert '"content": "completion text"' in body_text
        assert '"done": true' in body_text
        assert main.ACTIVE_COMPLETIONS == {}
|