Introduce a comprehensive TTS/ASR module that:
- Adds /v1/tts-asr/config, /status, /warmup, /tts, /asr endpoints with detailed JSON responses
- Implements Apple‑Silicon detection, device selection (MPS/CUDA/CPU), and memory limiting logic
- Supports selectable model size, quantization, and offline mode via environment variables
- Adds robust audio validation and multi‑path resampling fallback
- Provides new README sections for API usage, device detection, and performance benchmarking
- Includes a full testing suite: unit tests, integration tests, macOS simulation and performance reports
- Updates backend dependencies and CI scripts
- Adds new front‑end views and components for Univer editor integration

All changes are backward compatible; new features are exposed through environment variables and new API routes.
396 lines
13 KiB
Python
396 lines
13 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
TTS/ASR模块集成测试
|
||
测试API端点和完整流程(需要运行后端服务)
|
||
|
||
运行方式:
|
||
# 方式1: 使用pytest
|
||
pytest backend/tests/test_tts_asr_integration.py -v -s
|
||
|
||
# 方式2: 直接运行
|
||
python backend/tests/test_tts_asr_integration.py
|
||
|
||
# 方式3: 测试特定端点
|
||
python backend/tests/test_tts_asr_integration.py --test config
|
||
"""
|
||
|
||
import asyncio
|
||
import base64
|
||
import json
|
||
import os
|
||
import sys
|
||
import time
|
||
import unittest
|
||
from typing import Optional
|
||
import httpx
|
||
|
||
# Settings for the test run. The base URL and API key fall back to local
# development defaults when the matching environment variable is not set.
API_BASE_URL = os.getenv('API_BASE_URL', 'http://localhost:8001')
API_KEY = os.getenv('API_KEY', 'your-secret-key-here')
TEST_TIMEOUT = 120.0  # seconds (two minutes), used as the HTTP client timeout
|
||
|
||
|
||
class TTSASRIntegrationTest(unittest.TestCase):
    """Integration tests for the TTS/ASR HTTP endpoints.

    Every test sends real requests to the backend at ``API_BASE_URL``.
    ``setUpClass`` probes the status endpoint once; when that probe fails,
    each test is skipped instead of failing.
    """

    @classmethod
    def setUpClass(cls):
        """Create the shared HTTP client and check that the backend responds."""
        cls.client = httpx.Client(timeout=TEST_TIMEOUT)
        cls.headers = {'X-API-Key': API_KEY}

        # Probe the status endpoint once so the tests can skip when it is down.
        try:
            response = cls.client.get(f'{API_BASE_URL}/v1/tts-asr/status', headers=cls.headers)
        except Exception as e:
            cls.service_available = False
            print(f"\n✗ 无法连接到服务: {e}")
            print(" 请确保后端服务正在运行: python backend/main.py")
        else:
            cls.service_available = response.status_code == 200
            if cls.service_available:
                print(f"\n✓ 服务可用: {API_BASE_URL}")
            else:
                print(f"\n✗ 服务返回非200状态码: {response.status_code}")

    @classmethod
    def tearDownClass(cls):
        """Close the shared HTTP client."""
        cls.client.close()

    def setUp(self):
        """Skip the test when the availability probe did not succeed."""
        if not self.service_available:
            self.skipTest("后端服务不可用")

    @staticmethod
    def _silent_wav_bytes(sample_rate=16000, duration=1.0):
        """Return a mono, 16-bit PCM WAV file containing only silence.

        Args:
            sample_rate: Frames per second written to the WAV header.
            duration: Length of the clip in seconds.

        Returns:
            The complete WAV file (header plus zeroed frames) as ``bytes``.
        """
        import io
        import wave

        frame_count = int(sample_rate * duration)
        buffer = io.BytesIO()
        with wave.open(buffer, 'wb') as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(sample_rate)
            # Two zero bytes per frame: 16-bit silence, no numpy required.
            wf.writeframes(bytes(frame_count * 2))
        return buffer.getvalue()

    @staticmethod
    def _error_detail(response):
        """Extract a printable error description from a failed response.

        Falls back to the raw body text when the body is not valid JSON, so
        a non-JSON error page does not turn an intended skip into an error.
        """
        try:
            payload = response.json()
        except ValueError:
            return response.text or 'Unknown error'
        if isinstance(payload, dict):
            return payload.get('detail', 'Unknown error')
        return payload

    def test_01_config_endpoint(self):
        """The config endpoint returns environment, device and model sections."""
        response = self.client.get(
            f'{API_BASE_URL}/v1/tts-asr/config',
            headers=self.headers
        )

        self.assertEqual(response.status_code, 200)
        config = response.json()

        # Top-level structure.
        for section in ('environment', 'device', 'model', 'status'):
            self.assertIn(section, config)

        # Environment-variable settings.
        env = config['environment']
        for key in ('TTS_ASR_DEVICE', 'TTS_ASR_MODEL_SIZE', 'TTS_ASR_QUANTIZE'):
            self.assertIn(key, env)

        # Device information.
        device = config['device']
        for key in ('current', 'mps_available', 'cuda_available', 'is_apple_silicon'):
            self.assertIn(key, device)

        # Model information.
        model = config['model']
        for key in ('tts', 'asr_current_size', 'available_sizes'):
            self.assertIn(key, model)

        print("\n配置信息:")
        print(f" 设备: {device['current']}")
        print(f" Apple Silicon: {device['is_apple_silicon']}")
        print(f" MPS可用: {device['mps_available']}")
        print(f" ASR模型大小: {model['asr_current_size']}")

    def test_02_status_endpoint(self):
        """The status endpoint reports load state, device and mode flags."""
        response = self.client.get(
            f'{API_BASE_URL}/v1/tts-asr/status',
            headers=self.headers
        )

        self.assertEqual(response.status_code, 200)
        status = response.json()

        # Status structure.
        for key in ('tts_loaded', 'asr_loaded', 'device', 'offline_mode', 'quantize_enabled'):
            self.assertIn(key, status)

        print("\n状态信息:")
        print(f" TTS已加载: {status['tts_loaded']}")
        print(f" ASR已加载: {status['asr_loaded']}")
        print(f" 设备: {status['device']}")
        print(f" 离线模式: {status['offline_mode']}")
        print(f" 量化启用: {status['quantize_enabled']}")

    def test_03_warmup_endpoint(self):
        """The warmup endpoint answers 200 and reports per-model results."""
        print("\n开始模型预热(可能需要几分钟)...")
        start_time = time.time()

        response = self.client.post(
            f'{API_BASE_URL}/v1/tts-asr/warmup',
            headers=self.headers
        )

        elapsed = time.time() - start_time

        self.assertEqual(response.status_code, 200)
        result = response.json()

        for key in ('tts_warmup', 'asr_warmup', 'device'):
            self.assertIn(key, result)

        print(f"\n预热完成 (耗时: {elapsed:.2f}秒):")
        print(f" TTS预热: {'成功' if result['tts_warmup'] else '失败'}")
        print(f" ASR预热: {'成功' if result['asr_warmup'] else '失败'}")

        # A failed warmup is only warned about, not asserted: the models may
        # simply not have been downloaded yet.
        if not result['tts_warmup'] or not result['asr_warmup']:
            print("\n⚠ 警告: 预热失败可能是因为模型未下载")
            print(" 请确保网络连接正常,或使用已下载的模型")

    def test_04_tts_endpoint_basic(self):
        """A short text is synthesized into non-empty base64 audio."""
        # Short Chinese sample text.
        test_text = "这是一个测试"

        response = self.client.post(
            f'{API_BASE_URL}/v1/tts-asr/tts',
            headers=self.headers,
            json={
                'text': test_text,
                'voice': 'af_bella',
                'rate': 1.0,
                'format': 'wav'
            }
        )

        # A 500 is treated as "model not loaded" and skips the test.
        if response.status_code == 500:
            print(f"\n⚠ TTS失败(可能是模型未加载): {self._error_detail(response)}")
            self.skipTest("TTS模型未加载或不可用")

        self.assertEqual(response.status_code, 200)
        result = response.json()

        # Response structure.
        for key in ('audio_base64', 'format', 'duration_ms'):
            self.assertIn(key, result)

        # Audio payload.
        audio_data = base64.b64decode(result['audio_base64'])
        self.assertGreater(len(audio_data), 0)
        self.assertGreater(result['duration_ms'], 0)

        print("\nTTS测试成功:")
        print(f" 输入文本: {test_text}")
        print(f" 音频大小: {len(audio_data)} bytes")
        print(f" 时长: {result['duration_ms']} ms")

    def test_05_asr_endpoint_basic(self):
        """A silent WAV clip is accepted and yields text and language fields."""
        # One second of silence: 16 kHz, mono, 16-bit.
        audio_bytes = self._silent_wav_bytes(sample_rate=16000, duration=1.0)
        audio_base64 = base64.b64encode(audio_bytes).decode()

        # Send the ASR request.
        response = self.client.post(
            f'{API_BASE_URL}/v1/tts-asr/asr',
            headers=self.headers,
            json={
                'audio_base64': audio_base64,
                'language': 'zh-CN'
            }
        )

        # A 500 is treated as "model not loaded" and skips the test.
        if response.status_code == 500:
            print(f"\n⚠ ASR失败(可能是模型未加载): {self._error_detail(response)}")
            self.skipTest("ASR模型未加载或不可用")

        self.assertEqual(response.status_code, 200)
        result = response.json()

        # Response structure.
        self.assertIn('text', result)
        self.assertIn('language', result)

        print("\nASR测试成功:")
        print(f" 识别文本: '{result['text']}'")
        print(f" 语言: {result['language']}")
        print(" 注意: 静音音频应该返回空文本")

    def test_06_api_key_validation(self):
        """A request with a wrong API key is rejected with 403."""
        wrong_headers = {'X-API-Key': 'wrong-api-key'}

        response = self.client.get(
            f'{API_BASE_URL}/v1/tts-asr/status',
            headers=wrong_headers
        )

        # Expect 403 Forbidden.
        self.assertEqual(response.status_code, 403)
        print("\n✓ API密钥验证正常:错误密钥被拒绝")

    def test_07_tts_long_text(self):
        """A longer text is synthesized successfully."""
        long_text = "这是一段较长的测试文本,用于测试TTS系统对长文本的处理能力。" * 3

        # No per-request timeout: the client-wide TEST_TIMEOUT applies. The
        # previous 60 s override was shorter than that default, so it
        # shortened the allowance instead of extending it.
        response = self.client.post(
            f'{API_BASE_URL}/v1/tts-asr/tts',
            headers=self.headers,
            json={
                'text': long_text,
                'voice': 'af_bella',
                'rate': 1.0,
                'format': 'wav'
            }
        )

        if response.status_code == 500:
            self.skipTest("TTS模型未加载或不可用")

        self.assertEqual(response.status_code, 200)
        result = response.json()

        print("\n长文本TTS测试成功:")
        print(f" 输入长度: {len(long_text)} 字符")
        print(f" 音频大小: {len(base64.b64decode(result['audio_base64']))} bytes")
        print(f" 时长: {result['duration_ms']} ms")
|
||
|
||
|
||
class PerformanceTest(unittest.TestCase):
    """Latency measurements against a running TTS/ASR backend.

    The tests are skipped when the status endpoint cannot be reached or does
    not answer with HTTP 200.
    """

    @classmethod
    def setUpClass(cls):
        """Create the shared HTTP client and check that the backend responds."""
        cls.client = httpx.Client(timeout=TEST_TIMEOUT)
        cls.headers = {'X-API-Key': API_KEY}

        try:
            response = cls.client.get(f'{API_BASE_URL}/v1/tts-asr/status', headers=cls.headers)
            cls.service_available = response.status_code == 200
        except Exception:
            # Not a bare ``except``: Ctrl-C and SystemExit must still propagate.
            cls.service_available = False

    @classmethod
    def tearDownClass(cls):
        """Close the shared HTTP client."""
        cls.client.close()

    def setUp(self):
        """Skip the test when the availability probe did not succeed."""
        if not self.service_available:
            self.skipTest("后端服务不可用")

    def test_tts_latency(self):
        """Time three TTS requests and print average, minimum and maximum.

        Only requests answered with HTTP 200 count towards the statistics.
        If none succeed, the test is skipped with the status codes seen,
        instead of passing without having measured anything.
        """
        test_text = "测试延迟"

        latencies = []
        failed_statuses = []
        for _ in range(3):
            # perf_counter is monotonic, so clock adjustments cannot skew the
            # measured interval.
            start = time.perf_counter()
            response = self.client.post(
                f'{API_BASE_URL}/v1/tts-asr/tts',
                headers=self.headers,
                json={'text': test_text}
            )
            elapsed = time.perf_counter() - start

            if response.status_code == 200:
                latencies.append(elapsed)
            else:
                failed_statuses.append(response.status_code)

        if not latencies:
            self.skipTest(f"TTS请求均未成功,无法统计延迟 (状态码: {failed_statuses})")

        avg_latency = sum(latencies) / len(latencies)
        print("\nTTS延迟测试:")
        print(f" 平均延迟: {avg_latency:.3f}秒")
        print(f" 最小延迟: {min(latencies):.3f}秒")
        print(f" 最大延迟: {max(latencies):.3f}秒")
|
||
|
||
|
||
def run_tests(test_type: Optional[str] = None):
    """Build and run a test suite, returning True when every test passed.

    Args:
        test_type: ``'config'``, ``'status'``, ``'warmup'``, ``'tts'`` or
            ``'asr'`` runs the single matching integration test; ``'perf'``
            runs the performance test case; any other value (including
            ``None``) runs both test cases in full.

    Returns:
        The ``wasSuccessful()`` result of the verbose text test runner.
    """
    # Selector -> name of the single integration test method it runs.
    single_test_methods = {
        'config': 'test_01_config_endpoint',
        'status': 'test_02_status_endpoint',
        'warmup': 'test_03_warmup_endpoint',
        'tts': 'test_04_tts_endpoint_basic',
        'asr': 'test_05_asr_endpoint_basic',
    }

    case_loader = unittest.TestLoader()
    selected = unittest.TestSuite()

    method_name = single_test_methods.get(test_type)
    if method_name is not None:
        selected.addTest(TTSASRIntegrationTest(method_name))
    elif test_type == 'perf':
        selected.addTests(case_loader.loadTestsFromTestCase(PerformanceTest))
    else:
        # No recognised selector: run everything, integration tests first.
        for case_class in (TTSASRIntegrationTest, PerformanceTest):
            selected.addTests(case_loader.loadTestsFromTestCase(case_class))

    outcome = unittest.TextTestRunner(verbosity=2).run(selected)
    return outcome.wasSuccessful()
|
||
|
||
|
||
if __name__ == '__main__':
    import argparse

    # Command-line interface for running the suite directly as a script.
    cli = argparse.ArgumentParser(description='TTS/ASR集成测试')
    cli.add_argument(
        '--test',
        choices=['config', 'status', 'warmup', 'tts', 'asr', 'perf'],
        help='运行特定测试',
    )
    cli.add_argument('--url', default=API_BASE_URL, help='API基础URL')
    cli.add_argument('--key', default=API_KEY, help='API密钥')
    options = cli.parse_args()

    # Rebind the module-level settings so the test classes, which read these
    # globals when they run, use the command-line values.
    API_BASE_URL = options.url
    API_KEY = options.key

    divider = "=" * 70
    print(divider)
    print("TTS/ASR 集成测试")
    print(divider)
    print(f"API URL: {API_BASE_URL}")
    print(f"测试类型: {options.test or '全部'}")
    print(divider)

    # Exit status 0 on success and 1 on any failure, for use in scripts/CI.
    all_passed = run_tests(options.test)
    sys.exit(0 if all_passed else 1)
|