# File: llm-in-text/backend/tests/test_tts_asr_coverage.py
# Listing metadata: 194 lines, 6.1 KiB, Python

import os
import sys
import asyncio
import types
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch
# Directory one level above the folder containing this file; it is pushed to
# the front of sys.path (once) so backend modules can be imported by bare name.
BACKEND_DIR = Path(__file__).resolve().parents[1]
if sys.path.count(str(BACKEND_DIR)) == 0:
    sys.path.insert(0, str(BACKEND_DIR))
def _make_mlx_stub():
    """Build a bare-bones stand-in for the ``mlx`` package.

    Intended for running the tests without Apple Silicon.

    Returns:
        A namespace with two attributes: ``core``, which holds an ``array``
        class whose ``item()`` always returns ``1``, and an empty ``nn``
        namespace.
    """
    fake_array = type('mx.array', (), {'item': lambda self: 1})
    return types.SimpleNamespace(
        core=types.SimpleNamespace(array=fake_array),
        nn=types.SimpleNamespace(),
    )
def _make_mlx_audio_stub():
    """Build a bare-bones stand-in for the ``mlx_audio`` package.

    Returns:
        A namespace whose ``stt`` attribute exposes ``utils.load`` (a function
        that ignores its arguments and returns a fresh ``MagicMock``) and
        ``models.qwen3_asr`` with empty ``Qwen3ASRModel`` and
        ``ForcedAlignerModel`` classes.
    """
    def mock_load(path, **kwargs):
        # Arguments are accepted only for signature compatibility.
        return MagicMock()

    asr_module = types.SimpleNamespace(
        Qwen3ASRModel=type('Qwen3ASRModel', (), {}),
        ForcedAlignerModel=type('ForcedAlignerModel', (), {}),
    )
    stt_ns = types.SimpleNamespace(
        utils=types.SimpleNamespace(load=mock_load),
        models=types.SimpleNamespace(qwen3_asr=asr_module),
    )
    return types.SimpleNamespace(stt=stt_ns)
def _reload_tts_asr_with_mocks():
    """Import a fresh copy of ``tts_asr`` against stubbed MLX modules.

    Cached ``tts_asr`` and MLX-related entries are evicted from
    ``sys.modules``, the stubs are registered under the ``mlx`` /
    ``mlx_audio`` module names, and ``tts_asr`` is imported again.

    Returns:
        The freshly imported ``tts_asr`` module.
    """
    # Match 'tts_asr' as a whole dotted-name component. A plain substring test
    # would also evict this very test module (its name contains 'tts_asr')
    # and any other module that merely has that text in its name.
    for mod_name in list(sys.modules):
        if 'tts_asr' in mod_name.split('.') or 'mlx' in mod_name:
            del sys.modules[mod_name]

    mlx_stub = _make_mlx_stub()
    sys.modules['mlx'] = mlx_stub  # type: ignore
    sys.modules['mlx.core'] = mlx_stub.core  # type: ignore
    sys.modules['mlx.nn'] = mlx_stub.nn  # type: ignore

    audio_stub = _make_mlx_audio_stub()
    # Kept from the original; the hyphenated key cannot be reached by a normal
    # ``import`` statement.
    sys.modules['mlx-audio'] = audio_stub  # type: ignore
    sys.modules['mlx_audio'] = audio_stub  # type: ignore
    sys.modules['mlx_audio.stt'] = audio_stub.stt  # type: ignore
    sys.modules['mlx_audio.stt.utils'] = audio_stub.stt.utils  # type: ignore
    sys.modules['mlx_audio.stt.models'] = audio_stub.stt.models  # type: ignore
    sys.modules['mlx_audio.stt.models.qwen3_asr'] = audio_stub.stt.models.qwen3_asr  # type: ignore

    import tts_asr
    return tts_asr
@pytest.fixture(autouse=True)
def _clean_env():
    """Isolate every test from ASR-related environment variables.

    The tracked variables are removed before the test runs. Afterwards each
    one is put back to its pre-test value, or removed again if it was unset
    beforehand, so a value set while the test ran cannot leak into the next.
    """
    tracked = ('HF_ENDPOINT',)
    saved = {name: os.environ.pop(name, None) for name in tracked}
    try:
        yield
    finally:
        for name, previous in saved.items():
            if previous is None:
                # Was unset before the test: drop anything set in the meantime.
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
class TestRequestResponseModels:
    """Tests for the Pydantic data models."""

    def test_tts_request_defaults(self):
        """A TTS request built from text alone uses the "Vivian" speaker."""
        module = _reload_tts_asr_with_mocks()
        request = module.TTSRequest(text="hello")
        assert request.text == "hello"
        assert request.speaker == "Vivian"

    def test_asr_request_defaults(self):
        """An ASR request built from audio alone uses the "zh-CN" language."""
        module = _reload_tts_asr_with_mocks()
        request = module.ASRRequest(audio_base64="dGVzdA==")
        assert request.audio_base64 == "dGVzdA=="
        assert request.language == "zh-CN"

    def test_asr_request_custom_language(self):
        """An explicitly passed language is kept on the request."""
        module = _reload_tts_asr_with_mocks()
        request = module.ASRRequest(audio_base64="dGVzdA==", language="en")
        assert request.language == "en"

    def test_model_status_defaults(self):
        """ModelStatus exposes the loaded flags it was constructed with."""
        module = _reload_tts_asr_with_mocks()
        model_status = module.ModelStatus(
            tts_loaded=False,
            asr_loaded=True,
            device="cpu",
        )
        assert not model_status.tts_loaded
        assert model_status.asr_loaded
class TestDeviceDetection:
    """Device detection tests."""

    def test_device_map_returns_string(self):
        """``_get_device_map()`` must hand back a ``str``."""
        module = _reload_tts_asr_with_mocks()
        assert isinstance(module._get_device_map(), str)
class TestModelLoading:
    """Model loading tests."""

    def test_load_asr_skips_when_mlx_unavailable(self):
        """ASR loading should be skipped when mlx_audio is not installed."""
        # Evict cached copies; 'tts_asr' is matched as a whole dotted-name
        # component so this test module itself is not removed.
        for mod_name in list(sys.modules):
            if 'tts_asr' in mod_name.split('.') or 'mlx' in mod_name:
                del sys.modules[mod_name]

        # Clearing sys.modules alone does not simulate a missing package: on a
        # machine with MLX installed the real one would simply be re-imported.
        # A None entry makes the import system raise ModuleNotFoundError.
        blocked = ('mlx', 'mlx_audio')
        for name in blocked:
            sys.modules[name] = None  # type: ignore[assignment]
        try:
            import tts_asr  # noqa: F811
            assert tts_asr.Qwen3ASRModel is None
            tts_asr._load_asr_models()  # should not crash
            assert tts_asr._asr_model is None
        finally:
            # Remove the blockers so later tests can import or stub MLX again.
            for name in blocked:
                sys.modules.pop(name, None)

    def test_load_asr_from_path_success(self):
        """Loading from a local path leaves ``_asr_model`` populated."""
        tts = _reload_tts_asr_with_mocks()
        # Patch the module object actually under test. It was imported as the
        # top-level ``tts_asr``, so patching 'backend.tts_asr' would target a
        # different module object and leave this one untouched.
        # create=True: it is not visible from here whether the module exposes
        # ``snapshot_download`` at module level.
        with patch.object(
            tts, 'snapshot_download', return_value='/fake/path', create=True
        ):
            tts._load_asr_from_path('/fake/path')
        assert tts._asr_model is not None
class TestWarmupFunctions:
    """Warm-up function tests."""

    def test_warmup_functions_callable(self):
        """Both warm-up entry points exist and are callable."""
        tts = _reload_tts_asr_with_mocks()
        assert callable(tts._warmup_tts)
        assert callable(tts._warmup_all)

    def test_warmup_asr_skips_when_mlx_unavailable(self):
        """With MLX missing, the module reports no ASR model class."""
        # Evict cached copies; 'tts_asr' is matched as a whole dotted-name
        # component so this test module itself is not removed.
        for mod_name in list(sys.modules):
            if 'tts_asr' in mod_name.split('.') or 'mlx' in mod_name:
                del sys.modules[mod_name]

        # A None entry in sys.modules makes the import raise
        # ModuleNotFoundError, so MLX is "missing" even where it is installed.
        blocked = ('mlx', 'mlx_audio')
        for name in blocked:
            sys.modules[name] = None  # type: ignore[assignment]
        try:
            import tts_asr  # noqa: F811
            assert tts_asr.Qwen3ASRModel is None
        finally:
            # Remove the blockers so later tests can import or stub MLX again.
            for name in blocked:
                sys.modules.pop(name, None)

    def test_warmup_all_runs_without_error(self):
        """``_warmup_all()`` completes when awaited on an event loop."""
        tts = _reload_tts_asr_with_mocks()
        # Pre-populate the TTS model global with a mock before warming up.
        tts._tts_model = MagicMock()

        async def run():
            await tts._warmup_all()

        # Use a dedicated loop instead of asyncio.get_event_loop(): implicit
        # loop creation there is deprecated, and the loop was never closed.
        # The process-wide current loop is left untouched for other tests.
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(run())
        finally:
            loop.close()
class TestRouteRegistration:
    """Route registration tests."""

    def test_register_function_exists(self):
        """The route registration entry point is callable."""
        module = _reload_tts_asr_with_mocks()
        register = module.register_tts_asr_routes
        assert callable(register)

    def test_router_prefix(self):
        """The module-level router exposes a ``routes`` attribute."""
        module = _reload_tts_asr_with_mocks()
        router = module.router
        assert hasattr(router, 'routes')
class TestModelConstants:
    """Model constant tests."""

    def test_asr_model_id(self):
        """The ASR model id mentions 'Qwen3-ASR'."""
        module = _reload_tts_asr_with_mocks()
        model_id = module.ASR_MODEL_ID_MS
        assert 'Qwen3-ASR' in model_id

    def test_align_model_id(self):
        """The aligner model id mentions 'ForcedAligner'."""
        module = _reload_tts_asr_with_mocks()
        model_id = module.ALIGN_MODEL_ID_MS
        assert 'ForcedAligner' in model_id