优化测试用例以提高后端模块的测试覆盖率,调整测试断言和异常处理。 Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
378 lines
14 KiB
Python
378 lines
14 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
TTS/ASR模块单元测试
|
||
测试核心功能,无需实际运行模型
|
||
|
||
运行方式:
|
||
pytest backend/tests/test_tts_asr_unit.py -v
|
||
python backend/tests/test_tts_asr_unit.py
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import unittest
|
||
from unittest.mock import Mock, MagicMock, patch
|
||
import tempfile
|
||
import numpy as np
|
||
|
||
# 确保可以导入backend和tts_asr模块
|
||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||
|
||
|
||
class TestAppleSiliconDetection(unittest.TestCase):
    """Tests for Apple Silicon detection (``_is_apple_silicon``).

    Each test reloads ``backend.tts_asr`` while ``platform.system`` and
    ``platform.machine`` are patched, in case the module resolves platform
    information at import time. A cleanup reloads the module once more after
    the patches have been removed, so later tests do not inherit the
    simulated platform.
    """

    def _is_apple_silicon_on(self, system, machine):
        """Evaluate ``_is_apple_silicon()`` on a simulated platform.

        Args:
            system: Value returned by the patched ``platform.system()``.
            machine: Value returned by the patched ``platform.machine()``.

        Returns:
            Whatever ``_is_apple_silicon()`` returns while the patches are
            active.
        """
        import importlib

        import backend.tts_asr as tts_asr_module

        # Cleanups run after the test method returns, i.e. after the ``with``
        # below has removed the patches, so this reload restores the module
        # state for the real platform. Registering it before the reload
        # ensures restoration even if the patched reload itself fails.
        self.addCleanup(importlib.reload, tts_asr_module)

        with patch('platform.system', return_value=system), \
                patch('platform.machine', return_value=machine):
            # Reload so anything resolved at import time sees the mocks.
            importlib.reload(tts_asr_module)
            from backend.tts_asr import _is_apple_silicon
            return _is_apple_silicon()

    def test_is_apple_silicon_on_darwin_arm64(self):
        """Darwin on arm64 is reported as Apple Silicon."""
        self.assertTrue(self._is_apple_silicon_on('Darwin', 'arm64'))

    def test_is_apple_silicon_on_windows(self):
        """Windows on AMD64 is not reported as Apple Silicon."""
        self.assertFalse(self._is_apple_silicon_on('Windows', 'AMD64'))

    def test_is_apple_silicon_on_linux(self):
        """Linux on x86_64 is not reported as Apple Silicon."""
        self.assertFalse(self._is_apple_silicon_on('Linux', 'x86_64'))
|
||
|
||
|
||
class TestEnvironmentVariables(unittest.TestCase):
    """Tests for parsing of the ``TTS_ASR_*`` environment variables.

    Each test edits ``os.environ`` only inside ``patch.dict``, which restores
    the original environment on exit (including values the developer had set
    beforehand). The tests reload ``backend.tts_asr`` so that values read at
    import time pick up the patched environment. A cleanup reloads it again
    once the environment has been restored, so later tests do not see the
    values forced here.
    """

    def _reload_tts_asr(self):
        """Reload ``backend.tts_asr`` now, and once more after the test.

        Returns:
            The reloaded ``backend.tts_asr`` module object.
        """
        import importlib

        import backend.tts_asr as tts_asr_module

        # Cleanups run after the test method returns, i.e. after the
        # ``patch.dict`` context in the caller has restored ``os.environ``.
        self.addCleanup(importlib.reload, tts_asr_module)
        return importlib.reload(tts_asr_module)

    def test_default_environment_values(self):
        """With every TTS_ASR_* variable unset, the expected defaults apply."""
        env_vars = [
            'TTS_ASR_DEVICE', 'TTS_ASR_MODEL_SIZE', 'TTS_ASR_QUANTIZE',
            'TTS_ASR_OFFLINE_MODE', 'TTS_ASR_WARMUP', 'TTS_ASR_WARMUP_TIMEOUT',
            'TTS_ASR_IDLE_TIMEOUT', 'TTS_ASR_MPS_MEMORY_LIMIT_MB',
        ]

        with patch.dict(os.environ):
            # Remove any values inherited from the caller's environment;
            # patch.dict puts them back when the block exits.
            for var in env_vars:
                os.environ.pop(var, None)

            module = self._reload_tts_asr()

            self.assertEqual(module.TTS_ASR_DEVICE, 'auto')
            self.assertEqual(module.TTS_ASR_MODEL_SIZE, 'auto')
            self.assertFalse(module.TTS_ASR_QUANTIZE)
            self.assertFalse(module.TTS_ASR_OFFLINE_MODE)
            self.assertTrue(module.TTS_ASR_WARMUP)
            self.assertEqual(module.TTS_ASR_WARMUP_TIMEOUT, 120)
            self.assertEqual(module.TTS_ASR_IDLE_TIMEOUT, 0)
            self.assertEqual(module.TTS_ASR_MPS_MEMORY_LIMIT_MB, 8192)

    def test_custom_environment_values(self):
        """Explicitly set TTS_ASR_* variables override the defaults."""
        overrides = {
            'TTS_ASR_DEVICE': 'cpu',
            'TTS_ASR_MODEL_SIZE': 'small',
            'TTS_ASR_QUANTIZE': 'true',
            'TTS_ASR_OFFLINE_MODE': 'true',
        }

        # patch.dict restores any pre-existing values instead of deleting them.
        with patch.dict(os.environ, overrides):
            module = self._reload_tts_asr()

            self.assertEqual(module.TTS_ASR_DEVICE, 'cpu')
            self.assertEqual(module.TTS_ASR_MODEL_SIZE, 'small')
            self.assertTrue(module.TTS_ASR_QUANTIZE)
            self.assertTrue(module.TTS_ASR_OFFLINE_MODE)
|
||
|
||
|
||
class TestModelSizeSelection(unittest.TestCase):
    """Tests for Whisper model-size selection.

    Tests that set ``TTS_ASR_MODEL_SIZE`` do so under ``patch.dict``, so any
    value already present in the caller's environment is restored rather
    than deleted. They also reload ``backend.tts_asr`` once more afterwards,
    so later tests do not see the forced size.
    """

    def _reload_tts_asr(self):
        """Reload ``backend.tts_asr`` now, and once more after the test.

        Returns:
            The reloaded ``backend.tts_asr`` module object.
        """
        import importlib

        import backend.tts_asr as tts_asr_module

        # Cleanups run after the ``patch.dict`` context has restored the
        # environment, so this reload reflects the real configuration.
        self.addCleanup(importlib.reload, tts_asr_module)
        return importlib.reload(tts_asr_module)

    def test_whisper_model_sizes_mapping(self):
        """WHISPER_MODEL_SIZES lists the expected sizes in order.

        Each model id starts with ``openai/whisper`` and contains its size
        name.
        """
        from backend.tts_asr import WHISPER_MODEL_SIZES

        expected_sizes = ['tiny', 'base', 'small', 'medium', 'large', 'turbo']
        self.assertEqual(list(WHISPER_MODEL_SIZES.keys()), expected_sizes)

        # Validate the model-id format of every entry.
        for size, model_id in WHISPER_MODEL_SIZES.items():
            with self.subTest(size=size):
                self.assertTrue(model_id.startswith('openai/whisper'))
                self.assertIn(size, model_id)

    def test_recommended_model_size_explicit(self):
        """An explicitly configured valid size is returned as-is."""
        with patch.dict(os.environ, {'TTS_ASR_MODEL_SIZE': 'medium'}):
            module = self._reload_tts_asr()
            self.assertEqual(module._get_recommended_model_size(), 'medium')

    def test_invalid_model_size_falls_back(self):
        """An unknown configured size falls back to a known size.

        The lookup must not raise.
        """
        with patch.dict(os.environ, {'TTS_ASR_MODEL_SIZE': 'invalid_size'}):
            module = self._reload_tts_asr()
            size = module._get_recommended_model_size()
            self.assertIn(size, module.WHISPER_MODEL_SIZES.keys())
|
||
|
||
|
||
class TestAudioValidation(unittest.TestCase):
    """Tests for the raw-audio sanity check ``_validate_audio_data``."""

    def test_validate_empty_audio(self):
        """Empty and very short byte strings are rejected."""
        from backend.tts_asr import _validate_audio_data as validate

        for payload in (b'', b'short'):
            self.assertFalse(validate(payload))

    def test_validate_valid_wav_header(self):
        """A 44-byte buffer starting with ``RIFF`` is accepted."""
        from backend.tts_asr import _validate_audio_data as validate

        # 'RIFF' followed by zero padding up to the 44-byte WAV header size.
        header = b'RIFF'.ljust(44, b'\x00')
        self.assertTrue(validate(header))

    def test_validate_invalid_audio(self):
        """A ``RIFF`` buffer shorter than a WAV header is rejected."""
        from backend.tts_asr import _validate_audio_data as validate

        # Only 34 bytes: below the minimum WAV header size.
        truncated = b'RIFF'.ljust(34, b'\x00')
        self.assertFalse(validate(truncated))
|
||
|
||
|
||
class TestAudioResampling(unittest.TestCase):
    """Tests for ``_resample_audio_robust``."""

    def _assert_resampled_length(self, source, src_rate, dst_rate):
        """Resample *source* and check its length scales by ``dst/src``.

        The expected length is ``floor(len(source) * dst_rate / src_rate)``,
        computed with integer arithmetic so the expectation itself has no
        float rounding.
        """
        from backend.tts_asr import _resample_audio_robust

        result = _resample_audio_robust(source, src_rate, dst_rate)
        self.assertEqual(len(result), len(source) * dst_rate // src_rate)

    def test_resample_same_rate(self):
        """Resampling to the same rate returns the input unchanged."""
        from backend.tts_asr import _resample_audio_robust

        # Seeded generator: the input (and any failure) is reproducible.
        rng = np.random.default_rng(0)
        audio = rng.standard_normal(16000, dtype=np.float32)
        resampled = _resample_audio_robust(audio, 16000, 16000)

        np.testing.assert_array_almost_equal(audio, resampled)

    def test_resample_different_rate(self):
        """Upsampling one second of audio from 16 kHz to 48 kHz triples it."""
        audio_16k = np.sin(np.linspace(0, 2 * np.pi, 16000)).astype(np.float32)
        self._assert_resampled_length(audio_16k, 16000, 48000)

    def test_resample_downsample(self):
        """Downsampling one second of audio from 48 kHz to 16 kHz thirds it."""
        audio_48k = np.sin(np.linspace(0, 2 * np.pi, 48000)).astype(np.float32)
        self._assert_resampled_length(audio_48k, 48000, 16000)
|
||
|
||
|
||
class TestDeviceCapabilities(unittest.TestCase):
    """Tests for the ``DeviceCapabilities`` data class."""

    def test_device_capabilities_dataclass(self):
        """A CPU-only instance keeps its fields.

        ``recommended_model_size`` defaults to 'large'.
        """
        from backend.tts_asr import DeviceCapabilities

        cpu_only = DeviceCapabilities(
            device='cpu', mps_available=False, cuda_available=False
        )

        self.assertEqual(cpu_only.device, 'cpu')
        self.assertFalse(cpu_only.mps_available)
        self.assertFalse(cpu_only.cuda_available)
        # Not passed above, so this is the field's default value.
        self.assertEqual(cpu_only.recommended_model_size, 'large')

    def test_device_capabilities_with_mps(self):
        """An MPS instance keeps its memory limit and recommended size."""
        from backend.tts_asr import DeviceCapabilities

        settings = dict(
            device='mps',
            mps_available=True,
            mps_memory_limit_mb=8192,
            recommended_model_size='small',
        )
        mps_caps = DeviceCapabilities(**settings)

        self.assertEqual(mps_caps.device, 'mps')
        self.assertTrue(mps_caps.mps_available)
        self.assertEqual(mps_caps.mps_memory_limit_mb, 8192)
        self.assertEqual(mps_caps.recommended_model_size, 'small')
|
||
|
||
|
||
class TestModelCacheCheck(unittest.TestCase):
    """Tests for ``_check_model_cached``."""

    def test_cache_check_non_offline_mode(self):
        """Outside offline mode the cache check always returns True."""
        with patch('backend.tts_asr.TTS_ASR_OFFLINE_MODE', False):
            from backend.tts_asr import _check_model_cached

            self.assertTrue(_check_model_cached('any/model'))

    def test_cache_check_offline_mode_missing(self):
        """In offline mode an uncached model is reported as missing.

        The test is skipped when huggingface_hub is not installed.
        """
        with patch('backend.tts_asr.TTS_ASR_OFFLINE_MODE', True):
            try:
                import huggingface_hub  # noqa: F401
            except ImportError:
                self.skipTest("huggingface_hub not installed")

            from backend.tts_asr import _check_model_cached

            # Point the Hugging Face cache at a path that does not exist.
            fake_cache = '/nonexistent/path'
            with patch('huggingface_hub.constants.HF_HUB_CACHE', fake_cache):
                cached = _check_model_cached('nonexistent/model')
                self.assertFalse(cached)
|
||
|
||
|
||
class TestRequestResponseModels(unittest.TestCase):
    """Tests for the request/response data models."""

    def test_tts_request_model(self):
        """TTSRequest keeps its text and defaults voice, rate and format."""
        from backend.tts_asr import TTSRequest

        sample_text = "测试文本"
        request = TTSRequest(text=sample_text)

        self.assertEqual(request.text, sample_text)
        # Fields below are not passed, so they carry the model defaults.
        self.assertEqual(request.voice, "af_bella")
        self.assertEqual(request.rate, 1.0)
        self.assertEqual(request.format, "wav")

    def test_asr_request_model(self):
        """ASRRequest keeps its payload and defaults the language."""
        from backend.tts_asr import ASRRequest

        payload = "dGVzdA=="
        request = ASRRequest(audio_base64=payload)

        self.assertEqual(request.audio_base64, payload)
        # Not passed, so this is the model default.
        self.assertEqual(request.language, "zh-CN")

    def test_model_status_model(self):
        """ModelStatus keeps its flags and device.

        The last-used timestamps default to None.
        """
        from backend.tts_asr import ModelStatus

        snapshot = ModelStatus(tts_loaded=False, asr_loaded=False, device='cpu')

        self.assertFalse(snapshot.tts_loaded)
        self.assertFalse(snapshot.asr_loaded)
        self.assertEqual(snapshot.device, 'cpu')
        self.assertIsNone(snapshot.tts_last_used)
        self.assertIsNone(snapshot.asr_last_used)
|
||
|
||
|
||
def run_tests():
    """Run every test case class of this module with a verbose runner.

    Returns:
        True if all tests passed, False otherwise.
    """
    # Explicit ordering: classes are executed in exactly this sequence.
    case_classes = (
        TestAppleSiliconDetection,
        TestEnvironmentVariables,
        TestModelSizeSelection,
        TestAudioValidation,
        TestAudioResampling,
        TestDeviceCapabilities,
        TestModelCacheCheck,
        TestRequestResponseModels,
    )

    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for case_class in case_classes:
        suite.addTests(loader.loadTestsFromTestCase(case_class))

    outcome = unittest.TextTestRunner(verbosity=2).run(suite)
    return outcome.wasSuccessful()
|
||
|
||
|
||
if __name__ == '__main__':
    # When executed directly, run the suite and map the result to an exit
    # status: 0 if every test passed, 1 otherwise.
    all_passed = run_tests()
    exit_status = 0 if all_passed else 1
    sys.exit(exit_status)
|