Files
llm-in-text/backend/tests/test_tts_asr_unit.py
ydy0615 538f3e227a test: improve test coverage for backend modules
优化测试用例以提高后端模块的测试覆盖率，调整测试断言和异常处理。

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-07 12:46:56 +08:00

378 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
TTS/ASR模块单元测试
测试核心功能,无需实际运行模型
运行方式:
pytest backend/tests/test_tts_asr_unit.py -v
python backend/tests/test_tts_asr_unit.py
"""
import os
import sys
import unittest
from unittest.mock import Mock, MagicMock, patch
import tempfile
import numpy as np
# Make the project root (the directory containing ``backend``) and the backend
# directory itself importable, so both the ``backend`` package and the
# ``tts_asr`` module resolve when this file is run directly.
# '..' is inserted last, so it ends up ahead of '../..' on sys.path.
for _relative_dir in ('../..', '..'):
    sys.path.insert(
        0, os.path.abspath(os.path.join(os.path.dirname(__file__), _relative_dir))
    )
del _relative_dir
class TestAppleSiliconDetection(unittest.TestCase):
    """Tests for ``backend.tts_asr._is_apple_silicon`` on faked platforms."""

    def _detect(self, system, machine):
        """Return ``_is_apple_silicon()`` evaluated while the platform is faked.

        ``backend.tts_asr`` is reloaded while ``platform.system`` and
        ``platform.machine`` are patched, so any module-level state derived from
        the platform is recomputed for the fake values. A cleanup reloads the
        module once more after the patches are gone, so later tests do not
        inherit state computed for a fake platform.

        Args:
            system: value ``platform.system()`` should return.
            machine: value ``platform.machine()`` should return.

        Returns:
            Whatever ``_is_apple_silicon()`` returns under the patched platform.
        """
        import importlib
        import backend.tts_asr as tts_asr_module

        # Registered before reloading so the real-platform state is restored
        # even when the reload or the assertion below fails.
        self.addCleanup(importlib.reload, tts_asr_module)
        with patch('platform.system', return_value=system), \
                patch('platform.machine', return_value=machine):
            importlib.reload(tts_asr_module)
            return tts_asr_module._is_apple_silicon()

    def test_is_apple_silicon_on_darwin_arm64(self):
        """Darwin on arm64 is detected as Apple Silicon."""
        self.assertTrue(self._detect('Darwin', 'arm64'))

    def test_is_apple_silicon_on_windows(self):
        """Windows on AMD64 is not Apple Silicon."""
        self.assertFalse(self._detect('Windows', 'AMD64'))

    def test_is_apple_silicon_on_linux(self):
        """Linux on x86_64 is not Apple Silicon."""
        self.assertFalse(self._detect('Linux', 'x86_64'))
class TestEnvironmentVariables(unittest.TestCase):
    """Tests for how backend.tts_asr parses its ``TTS_ASR_*`` environment variables."""

    # Every variable whose default value is asserted in the defaults test.
    _ENV_VARS = (
        'TTS_ASR_DEVICE', 'TTS_ASR_MODEL_SIZE', 'TTS_ASR_QUANTIZE',
        'TTS_ASR_OFFLINE_MODE', 'TTS_ASR_WARMUP', 'TTS_ASR_WARMUP_TIMEOUT',
        'TTS_ASR_IDLE_TIMEOUT', 'TTS_ASR_MPS_MEMORY_LIMIT_MB',
    )

    def _reload_with_env(self, overrides=None, remove=()):
        """Reload backend.tts_asr under a temporary environment and return it.

        The whole ``os.environ`` is snapshotted with ``patch.dict`` and restored
        on cleanup, so values a developer already exported are preserved rather
        than deleted. After the environment is restored, the module is reloaded
        once more so its configuration constants no longer reflect this test.

        Args:
            overrides: mapping of variables to set for the duration of the test.
            remove: names of variables to unset for the duration of the test.

        Returns:
            The reloaded ``backend.tts_asr`` module object.
        """
        import importlib
        import backend.tts_asr as tts_asr_module

        # Cleanups run last-in-first-out: the environment is restored first,
        # then the module is reloaded under that restored environment.
        self.addCleanup(importlib.reload, tts_asr_module)
        env_patch = patch.dict(os.environ, overrides or {})
        env_patch.start()
        self.addCleanup(env_patch.stop)
        for name in remove:
            os.environ.pop(name, None)
        return importlib.reload(tts_asr_module)

    def test_default_environment_values(self):
        """With no TTS_ASR_* variables set, the module falls back to its defaults."""
        module = self._reload_with_env(remove=self._ENV_VARS)

        self.assertEqual(module.TTS_ASR_DEVICE, 'auto')
        self.assertEqual(module.TTS_ASR_MODEL_SIZE, 'auto')
        self.assertFalse(module.TTS_ASR_QUANTIZE)
        self.assertFalse(module.TTS_ASR_OFFLINE_MODE)
        self.assertTrue(module.TTS_ASR_WARMUP)
        self.assertEqual(module.TTS_ASR_WARMUP_TIMEOUT, 120)
        self.assertEqual(module.TTS_ASR_IDLE_TIMEOUT, 0)
        self.assertEqual(module.TTS_ASR_MPS_MEMORY_LIMIT_MB, 8192)

    def test_custom_environment_values(self):
        """Explicitly set TTS_ASR_* variables are picked up on (re)import."""
        module = self._reload_with_env(overrides={
            'TTS_ASR_DEVICE': 'cpu',
            'TTS_ASR_MODEL_SIZE': 'small',
            'TTS_ASR_QUANTIZE': 'true',
            'TTS_ASR_OFFLINE_MODE': 'true',
        })

        self.assertEqual(module.TTS_ASR_DEVICE, 'cpu')
        self.assertEqual(module.TTS_ASR_MODEL_SIZE, 'small')
        self.assertTrue(module.TTS_ASR_QUANTIZE)
        self.assertTrue(module.TTS_ASR_OFFLINE_MODE)
class TestModelSizeSelection(unittest.TestCase):
    """Tests for Whisper model-size mapping and selection in backend.tts_asr."""

    def _reload_with_model_size(self, value):
        """Reload backend.tts_asr with ``TTS_ASR_MODEL_SIZE=value`` and return it.

        ``os.environ`` is snapshotted with ``patch.dict`` and restored on
        cleanup, so a value the developer already exported survives the test.
        The module is then reloaded again so the override does not leak into
        later tests.
        """
        import importlib
        import backend.tts_asr as tts_asr_module

        # LIFO cleanups: restore the environment, then reload the module under it.
        self.addCleanup(importlib.reload, tts_asr_module)
        env_patch = patch.dict(os.environ, {'TTS_ASR_MODEL_SIZE': value})
        env_patch.start()
        self.addCleanup(env_patch.stop)
        return importlib.reload(tts_asr_module)

    def test_whisper_model_sizes_mapping(self):
        """WHISPER_MODEL_SIZES lists the expected sizes, in order, with matching IDs."""
        from backend.tts_asr import WHISPER_MODEL_SIZES

        expected_sizes = ['tiny', 'base', 'small', 'medium', 'large', 'turbo']
        self.assertEqual(list(WHISPER_MODEL_SIZES), expected_sizes)

        # Each model ID must be an openai/whisper* ID that mentions its size.
        for size, model_id in WHISPER_MODEL_SIZES.items():
            with self.subTest(size=size):
                self.assertTrue(model_id.startswith('openai/whisper'))
                self.assertIn(size, model_id)

    def test_recommended_model_size_explicit(self):
        """An explicitly configured valid size is returned as-is."""
        module = self._reload_with_model_size('medium')
        self.assertEqual(module._get_recommended_model_size(), 'medium')

    def test_invalid_model_size_falls_back(self):
        """An unknown size does not crash and yields one of the known sizes."""
        module = self._reload_with_model_size('invalid_size')
        size = module._get_recommended_model_size()
        self.assertIn(size, module.WHISPER_MODEL_SIZES.keys())
class TestAudioValidation(unittest.TestCase):
    """Checks for ``backend.tts_asr._validate_audio_data``."""

    @staticmethod
    def _validator():
        """Fetch the validator at call time, not when this test module is imported."""
        from backend.tts_asr import _validate_audio_data
        return _validate_audio_data

    def test_validate_empty_audio(self):
        """Empty and very short payloads are rejected."""
        is_valid = self._validator()
        for payload in (b'', b'short'):
            self.assertFalse(is_valid(payload))

    def test_validate_valid_wav_header(self):
        """A 44-byte buffer starting with RIFF (minimal WAV header size) is accepted."""
        minimal_header = b'RIFF'.ljust(44, b'\x00')
        self.assertTrue(self._validator()(minimal_header))

    def test_validate_invalid_audio(self):
        """A RIFF buffer shorter than the 44-byte minimal header is rejected."""
        truncated = b'RIFF' + bytes(30)
        self.assertFalse(self._validator()(truncated))
class TestAudioResampling(unittest.TestCase):
    """Tests for ``backend.tts_asr._resample_audio_robust``."""

    @staticmethod
    def _one_cycle_sine(num_samples):
        """Return one sine period spread over ``num_samples`` float32 samples."""
        return np.sin(np.linspace(0, 2 * np.pi, num_samples)).astype(np.float32)

    def _assert_resampled_length(self, src_rate, dst_rate):
        """Resample one second of audio and check the output length scales by the rate ratio."""
        from backend.tts_asr import _resample_audio_robust

        # src_rate samples == one second of audio at src_rate.
        audio = self._one_cycle_sine(src_rate)
        resampled = _resample_audio_robust(audio, src_rate, dst_rate)
        expected_length = int(len(audio) * dst_rate / src_rate)
        self.assertEqual(len(resampled), expected_length)

    def test_resample_same_rate(self):
        """Resampling to the same rate returns (almost exactly) the input samples."""
        from backend.tts_asr import _resample_audio_robust

        # Seeded generator so a failure reproduces with identical input data.
        rng = np.random.default_rng(0)
        audio = rng.standard_normal(16000).astype(np.float32)
        resampled = _resample_audio_robust(audio, 16000, 16000)
        np.testing.assert_array_almost_equal(audio, resampled)

    def test_resample_different_rate(self):
        """Upsampling 16 kHz -> 48 kHz triples the sample count."""
        self._assert_resampled_length(16000, 48000)

    def test_resample_downsample(self):
        """Downsampling 48 kHz -> 16 kHz divides the sample count by three."""
        self._assert_resampled_length(48000, 16000)
class TestDeviceCapabilities(unittest.TestCase):
    """Construction checks for the ``DeviceCapabilities`` dataclass."""

    def test_device_capabilities_dataclass(self):
        """A CPU-only instance keeps its fields and uses the default recommended size."""
        from backend.tts_asr import DeviceCapabilities

        cpu_caps = DeviceCapabilities(device='cpu', mps_available=False, cuda_available=False)

        self.assertEqual(cpu_caps.device, 'cpu')
        for accelerator_flag in (cpu_caps.mps_available, cpu_caps.cuda_available):
            self.assertFalse(accelerator_flag)
        # recommended_model_size was not passed, so this pins its default value.
        self.assertEqual(cpu_caps.recommended_model_size, 'large')

    def test_device_capabilities_with_mps(self):
        """An MPS instance keeps every explicitly supplied field."""
        from backend.tts_asr import DeviceCapabilities

        settings = dict(
            device='mps',
            mps_available=True,
            mps_memory_limit_mb=8192,
            recommended_model_size='small',
        )
        mps_caps = DeviceCapabilities(**settings)

        self.assertEqual(mps_caps.device, settings['device'])
        self.assertTrue(mps_caps.mps_available)
        self.assertEqual(mps_caps.mps_memory_limit_mb, settings['mps_memory_limit_mb'])
        self.assertEqual(mps_caps.recommended_model_size, settings['recommended_model_size'])
class TestModelCacheCheck(unittest.TestCase):
    """Tests for ``backend.tts_asr._check_model_cached``."""

    def test_cache_check_non_offline_mode(self):
        """Outside offline mode the cache check reports True for any model ID."""
        with patch('backend.tts_asr.TTS_ASR_OFFLINE_MODE', False):
            from backend.tts_asr import _check_model_cached
            self.assertTrue(_check_model_cached('any/model'))

    def test_cache_check_offline_mode_missing(self):
        """In offline mode a model absent from the hub cache reports False."""
        with patch('backend.tts_asr.TTS_ASR_OFFLINE_MODE', True):
            try:
                import huggingface_hub  # noqa: F401
            except ImportError:
                self.skipTest("huggingface_hub not installed")

            from backend.tts_asr import _check_model_cached

            # Point the hub cache constant at a path that does not exist.
            with patch('huggingface_hub.constants.HF_HUB_CACHE', '/nonexistent/path'):
                cached = _check_model_cached('nonexistent/model')
            self.assertFalse(cached)
class TestRequestResponseModels(unittest.TestCase):
    """Field and default-value checks for the request/response models."""

    def test_tts_request_model(self):
        """TTSRequest keeps its text and fills voice/rate/format with defaults."""
        from backend.tts_asr import TTSRequest

        request = TTSRequest(text="测试文本")
        self.assertEqual(request.text, "测试文本")

        # Fields not passed above must take their default values.
        expected_defaults = {"voice": "af_bella", "rate": 1.0, "format": "wav"}
        for field_name, default in expected_defaults.items():
            self.assertEqual(getattr(request, field_name), default)

    def test_asr_request_model(self):
        """ASRRequest keeps its audio payload and defaults the language."""
        from backend.tts_asr import ASRRequest

        payload = "dGVzdA=="  # base64 of b"test"
        request = ASRRequest(audio_base64=payload)
        self.assertEqual(request.audio_base64, payload)
        self.assertEqual(request.language, "zh-CN")

    def test_model_status_model(self):
        """ModelStatus keeps its flags/device and leaves last-used fields as None."""
        from backend.tts_asr import ModelStatus

        status = ModelStatus(tts_loaded=False, asr_loaded=False, device='cpu')
        self.assertFalse(status.tts_loaded)
        self.assertFalse(status.asr_loaded)
        self.assertEqual(status.device, 'cpu')
        for last_used in (status.tts_last_used, status.asr_last_used):
            self.assertIsNone(last_used)
def run_tests():
    """Run the TestCase classes listed below with a verbose text runner.

    Returns:
        bool: True when every test passed, False otherwise.
    """
    loader = unittest.TestLoader()
    case_classes = (
        TestAppleSiliconDetection,
        TestEnvironmentVariables,
        TestModelSizeSelection,
        TestAudioValidation,
        TestAudioResampling,
        TestDeviceCapabilities,
        TestModelCacheCheck,
        TestRequestResponseModels,
    )

    suite = unittest.TestSuite()
    for case_class in case_classes:
        suite.addTests(loader.loadTestsFromTestCase(case_class))

    outcome = unittest.TextTestRunner(verbosity=2).run(suite)
    return outcome.wasSuccessful()
if __name__ == '__main__':
    # Script mode: run the suite and report the outcome via the exit status
    # (0 on success, 1 on any failure).
    raise SystemExit(0 if run_tests() else 1)