194 lines
6.1 KiB
Python
194 lines
6.1 KiB
Python
import os
|
|
import sys
|
|
import asyncio
|
|
import types
|
|
import pytest
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
# Directory two levels above this test file; it is prepended to sys.path so
# backend modules can be imported as top-level names (e.g. ``import tts_asr``).
BACKEND_DIR = Path(__file__).resolve().parent.parent
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))
|
|
|
|
|
|
def _make_mlx_stub():
    """Build a minimal fake ``mlx`` package for tests without Apple Silicon.

    Returns:
        A namespace with ``core`` (exposing an ``array`` class whose instances
        return ``1`` from ``item()``) and an empty ``nn`` namespace.
    """
    fake_array_cls = type('mx.array', (), {'item': lambda self: 1})
    return types.SimpleNamespace(
        core=types.SimpleNamespace(array=fake_array_cls),
        nn=types.SimpleNamespace(),
    )
|
|
|
|
|
|
def _make_mlx_audio_stub():
    """Build a minimal fake ``mlx_audio`` package for tests.

    The returned namespace provides:
      * ``stt.utils.load(path, **kwargs)``: returns a new ``MagicMock`` on
        every call and ignores its arguments.
      * ``stt.models.qwen3_asr``: empty placeholder classes
        ``Qwen3ASRModel`` and ``ForcedAlignerModel``.
    """
    def mock_load(path, **kwargs):
        # Arguments are ignored; any "loaded model" is just a fresh mock.
        return MagicMock()

    qwen3_asr = types.SimpleNamespace(
        Qwen3ASRModel=type('Qwen3ASRModel', (), {}),
        ForcedAlignerModel=type('ForcedAlignerModel', (), {}),
    )
    stt = types.SimpleNamespace(
        utils=types.SimpleNamespace(load=mock_load),
        models=types.SimpleNamespace(qwen3_asr=qwen3_asr),
    )
    return types.SimpleNamespace(stt=stt)
|
|
|
|
|
|
def _reload_tts_asr_with_mocks():
    """Re-import ``tts_asr`` on top of freshly installed MLX stubs.

    Steps:
      1. Drop every cached module whose name contains ``tts_asr`` or ``mlx``
         from ``sys.modules``.
      2. Register the fake ``mlx`` and ``mlx_audio`` packages built by
         ``_make_mlx_stub`` / ``_make_mlx_audio_stub`` under the module names
         the backend imports.
      3. Import and return a fresh ``tts_asr`` module.

    NOTE(review): the stub entries stay in ``sys.modules`` after this
    returns; nothing here restores the previous modules.
    """
    stale_names = [name for name in sys.modules
                   if 'tts_asr' in name or 'mlx' in name]
    for name in stale_names:
        del sys.modules[name]

    mlx_stub = _make_mlx_stub()
    sys.modules.update({
        'mlx': mlx_stub,
        'mlx.core': mlx_stub.core,
        'mlx.nn': mlx_stub.nn,
    })  # type: ignore

    audio_stub = _make_mlx_audio_stub()
    stt_stub = audio_stub.stt
    sys.modules.update({
        # Hyphenated alias is registered alongside the importable name.
        'mlx-audio': audio_stub,
        'mlx_audio': audio_stub,
        'mlx_audio.stt': stt_stub,
        'mlx_audio.stt.utils': stt_stub.utils,
        'mlx_audio.stt.models': stt_stub.models,
        'mlx_audio.stt.models.qwen3_asr': stt_stub.models.qwen3_asr,
    })  # type: ignore

    import tts_asr
    return tts_asr
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _clean_env():
    """Unset ASR-related env vars for each test and restore them afterwards.

    Before the test, each variable in ``env_keys`` is removed from
    ``os.environ``. On teardown, a variable that originally had a value is
    set back to it. A variable that was originally unset is removed again,
    so a value set during the test (by the test or by the code under test)
    does not leak into later tests.
    """
    env_keys = ('HF_ENDPOINT',)
    # pop() both records the original value (None if absent) and unsets it.
    saved = {key: os.environ.pop(key, None) for key in env_keys}
    yield
    for key, value in saved.items():
        if value is None:
            # Was unset originally: drop anything set during the test.
            os.environ.pop(key, None)
        else:
            os.environ[key] = value
|
|
|
|
|
|
class TestRequestResponseModels:
    """Tests for the Pydantic request/response data models."""

    def test_tts_request_defaults(self):
        module = _reload_tts_asr_with_mocks()
        request = module.TTSRequest(text="hello")
        assert request.text == "hello"
        assert request.speaker == "Vivian"

    def test_asr_request_defaults(self):
        module = _reload_tts_asr_with_mocks()
        payload = "dGVzdA=="
        request = module.ASRRequest(audio_base64=payload)
        assert request.audio_base64 == payload
        assert request.language == "zh-CN"

    def test_asr_request_custom_language(self):
        module = _reload_tts_asr_with_mocks()
        request = module.ASRRequest(audio_base64="dGVzdA==", language="en")
        assert request.language == "en"

    def test_model_status_defaults(self):
        module = _reload_tts_asr_with_mocks()
        status = module.ModelStatus(
            tts_loaded=False,
            asr_loaded=True,
            device="cpu",
        )
        assert not status.tts_loaded
        assert status.asr_loaded
|
|
|
|
|
|
class TestDeviceDetection:
    """Tests for device detection."""

    def test_device_map_returns_string(self):
        module = _reload_tts_asr_with_mocks()
        assert isinstance(module._get_device_map(), str)
|
|
|
|
|
|
class TestModelLoading:
    """Tests for model loading."""

    # Module names that are forced to look uninstalled. A ``None`` entry in
    # sys.modules makes ``import`` raise ModuleNotFoundError, even on machines
    # where the real package is installed.
    _MLX_AUDIO_MODULES = (
        'mlx_audio',
        'mlx_audio.stt',
        'mlx_audio.stt.utils',
        'mlx_audio.stt.models',
        'mlx_audio.stt.models.qwen3_asr',
    )

    def test_load_asr_skips_when_mlx_unavailable(self):
        """ASR loading is skipped when mlx_audio is not installed."""
        for mod_name in list(sys.modules):
            if 'tts_asr' in mod_name or 'mlx' in mod_name:
                del sys.modules[mod_name]

        blocked = dict.fromkeys(self._MLX_AUDIO_MODULES)
        # patch.dict restores sys.modules afterwards. The assertions stay inside
        # the context so any lazy import of mlx_audio is also blocked.
        with patch.dict(sys.modules, blocked):
            import tts_asr  # noqa: F811

            assert tts_asr.Qwen3ASRModel is None
            tts_asr._load_asr_models()  # should not crash
            assert tts_asr._asr_model is None

    def test_load_asr_from_path_success(self):
        tts = _reload_tts_asr_with_mocks()
        # Patch the module object actually under test. It is imported as the
        # top-level ``tts_asr``, so patching the dotted path
        # ``backend.tts_asr`` would target a different module instance. The
        # stubbed mlx_audio load returns a MagicMock, so loading succeeds.
        with patch.object(tts, 'snapshot_download', return_value='/fake/path'):
            tts._load_asr_from_path('/fake/path')

        assert tts._asr_model is not None  # type: ignore (MagicMock)
|
|
|
|
|
|
class TestWarmupFunctions:
    """Tests for the warm-up functions."""

    # A ``None`` entry in sys.modules makes ``import`` raise
    # ModuleNotFoundError, so these look uninstalled on every machine.
    _MLX_AUDIO_MODULES = (
        'mlx_audio',
        'mlx_audio.stt',
        'mlx_audio.stt.utils',
        'mlx_audio.stt.models',
        'mlx_audio.stt.models.qwen3_asr',
    )

    def test_warmup_functions_callable(self):
        tts = _reload_tts_asr_with_mocks()
        assert callable(tts._warmup_tts)
        assert callable(tts._warmup_all)

    def test_warmup_asr_skips_when_mlx_unavailable(self):
        for mod_name in list(sys.modules):
            if 'tts_asr' in mod_name or 'mlx' in mod_name:
                del sys.modules[mod_name]

        # Block mlx_audio explicitly instead of relying on it being absent.
        # Otherwise the test fails on machines where it is installed.
        with patch.dict(sys.modules, dict.fromkeys(self._MLX_AUDIO_MODULES)):
            import tts_asr  # noqa: F811

            assert tts_asr.Qwen3ASRModel is None

    def test_warmup_all_runs_without_error(self):
        tts = _reload_tts_asr_with_mocks()

        # Set global models so warmup returns immediately without actual loading
        tts._tts_model = MagicMock()

        async def run():
            await tts._warmup_all()

        # asyncio.run creates and closes a fresh loop. Calling
        # get_event_loop() with no current loop is deprecated and raises on
        # newer Pythons.
        asyncio.run(run())
|
|
|
|
|
|
class TestRouteRegistration:
    """Tests for route registration."""

    def test_register_function_exists(self):
        module = _reload_tts_asr_with_mocks()
        register = module.register_tts_asr_routes
        assert callable(register)

    def test_router_prefix(self):
        # Only checks that the module-level router exposes a ``routes``
        # attribute; the prefix value itself is not asserted.
        module = _reload_tts_asr_with_mocks()
        assert hasattr(module.router, 'routes')
|
|
|
|
|
|
class TestModelConstants:
    """Tests for the model ID constants."""

    def test_asr_model_id(self):
        module = _reload_tts_asr_with_mocks()
        model_id = module.ASR_MODEL_ID_MS
        assert 'Qwen3-ASR' in model_id

    def test_align_model_id(self):
        module = _reload_tts_asr_with_mocks()
        model_id = module.ALIGN_MODEL_ID_MS
        assert 'ForcedAligner' in model_id
|