Files
llm-in-text/backend/tests/test_tts_asr_integration.py
ydy0615 7985fe9641 feat(tts): add api endpoints and optimization for apple silicon
Introduce a comprehensive TTS/ASR module that:
- Adds /v1/tts-asr/config, /status, /warmup, /tts, /asr endpoints with detailed JSON responses
- Implements Apple Silicon detection, device selection (MPS/CUDA/CPU), and memory-limiting logic
- Supports selectable model size, quantization, and offline mode via environment variables
- Adds robust audio validation and multi-path resampling fallback
- Provides new README sections for API usage, device detection, and performance benchmarking
- Includes a full testing suite: unit tests, integration tests, macOS simulation and performance reports
- Updates backend dependencies and CI scripts
- Adds new front-end views and components for Univer editor integration

All changes are backward compatible; new features are exposed through environment variables and new API routes.
2026-04-06 11:14:09 +08:00

396 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
TTS/ASR模块集成测试
测试API端点和完整流程需要运行后端服务
运行方式:
# 方式1: 使用pytest
pytest backend/tests/test_tts_asr_integration.py -v -s
# 方式2: 直接运行
python backend/tests/test_tts_asr_integration.py
# 方式3: 测试特定端点
python backend/tests/test_tts_asr_integration.py --test config
"""
import asyncio
import base64
import json
import os
import sys
import time
import unittest
from typing import Optional
import httpx
# Configuration (each value can be overridden through an environment variable).
# Base URL of the backend. A trailing slash is stripped because request URLs
# are built as f'{API_BASE_URL}/v1/...', which would otherwise contain '//'.
API_BASE_URL = os.environ.get('API_BASE_URL', 'http://localhost:8001').rstrip('/')
# Value sent in the X-API-Key header; the default looks like a placeholder.
API_KEY = os.environ.get('API_KEY', 'your-secret-key-here')
# Default per-request timeout (seconds) for the httpx clients: 2 minutes.
TEST_TIMEOUT = 120.0
class TTSASRIntegrationTest(unittest.TestCase):
    """End-to-end tests for the /v1/tts-asr/* HTTP endpoints.

    Requires a backend reachable at ``API_BASE_URL``. When the status probe in
    ``setUpClass`` fails, every test in the class is skipped instead of failing.
    Test methods are numbered so that unittest's default (alphabetical)
    ordering runs them in sequence, e.g. warmup before the TTS/ASR calls.
    """

    @classmethod
    def setUpClass(cls):
        """Create the shared HTTP client and probe whether the service is up."""
        cls.client = httpx.Client(timeout=TEST_TIMEOUT)
        cls.headers = {'X-API-Key': API_KEY}
        # Best-effort probe: any failure here (connection refused, bad URL, ...)
        # deliberately means "service unavailable" so the tests get skipped.
        try:
            response = cls.client.get(f'{API_BASE_URL}/v1/tts-asr/status', headers=cls.headers)
        except Exception as e:
            cls.service_available = False
            print(f"\n✗ 无法连接到服务: {e}")
            print(f" 请确保后端服务正在运行: python backend/main.py")
            return
        cls.service_available = response.status_code == 200
        if cls.service_available:
            print(f"\n✓ 服务可用: {API_BASE_URL}")
        else:
            print(f"\n✗ 服务返回非200状态码: {response.status_code}")

    @classmethod
    def tearDownClass(cls):
        """Close the shared HTTP client."""
        cls.client.close()

    def setUp(self):
        """Skip the test when the backend was not reachable in setUpClass."""
        if not self.service_available:
            self.skipTest("后端服务不可用")

    @staticmethod
    def _error_detail(response):
        """Return the ``detail`` of an error response, tolerating non-JSON bodies.

        A 500 from a proxy or a crashed worker may carry plain text/HTML, in
        which case ``response.json()`` raises a ValueError (JSONDecodeError).
        """
        try:
            payload = response.json()
        except ValueError:
            return response.text or 'Unknown error'
        if isinstance(payload, dict):
            return payload.get('detail', 'Unknown error')
        return payload

    def test_01_config_endpoint(self):
        """GET /config returns environment, device, model and status sections."""
        response = self.client.get(
            f'{API_BASE_URL}/v1/tts-asr/config',
            headers=self.headers
        )
        self.assertEqual(response.status_code, 200)
        config = response.json()
        # Top-level structure
        for key in ('environment', 'device', 'model', 'status'):
            self.assertIn(key, config)
        # Environment-variable driven settings
        env = config['environment']
        for key in ('TTS_ASR_DEVICE', 'TTS_ASR_MODEL_SIZE', 'TTS_ASR_QUANTIZE'):
            self.assertIn(key, env)
        # Device information
        device = config['device']
        for key in ('current', 'mps_available', 'cuda_available', 'is_apple_silicon'):
            self.assertIn(key, device)
        # Model information
        model = config['model']
        for key in ('tts', 'asr_current_size', 'available_sizes'):
            self.assertIn(key, model)
        print(f"\n配置信息:")
        print(f" 设备: {device['current']}")
        print(f" Apple Silicon: {device['is_apple_silicon']}")
        print(f" MPS可用: {device['mps_available']}")
        print(f" ASR模型大小: {model['asr_current_size']}")

    def test_02_status_endpoint(self):
        """GET /status reports load state, device, offline mode and quantization."""
        response = self.client.get(
            f'{API_BASE_URL}/v1/tts-asr/status',
            headers=self.headers
        )
        self.assertEqual(response.status_code, 200)
        status = response.json()
        for key in ('tts_loaded', 'asr_loaded', 'device', 'offline_mode', 'quantize_enabled'):
            self.assertIn(key, status)
        print(f"\n状态信息:")
        print(f" TTS已加载: {status['tts_loaded']}")
        print(f" ASR已加载: {status['asr_loaded']}")
        print(f" 设备: {status['device']}")
        print(f" 离线模式: {status['offline_mode']}")
        print(f" 量化启用: {status['quantize_enabled']}")

    def test_03_warmup_endpoint(self):
        """POST /warmup returns per-model warmup results (may take minutes)."""
        print("\n开始模型预热(可能需要几分钟)...")
        # perf_counter is monotonic and meant for measuring elapsed time.
        start_time = time.perf_counter()
        response = self.client.post(
            f'{API_BASE_URL}/v1/tts-asr/warmup',
            headers=self.headers
        )
        elapsed = time.perf_counter() - start_time
        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertIn('tts_warmup', result)
        self.assertIn('asr_warmup', result)
        self.assertIn('device', result)
        print(f"\n预热完成 (耗时: {elapsed:.2f}秒):")
        print(f" TTS预热: {'成功' if result['tts_warmup'] else '失败'}")
        print(f" ASR预热: {'成功' if result['asr_warmup'] else '失败'}")
        # A failed warmup is only reported, not asserted: the models may simply
        # not have been downloaded yet.
        if not result['tts_warmup'] or not result['asr_warmup']:
            print("\n⚠ 警告: 预热失败可能是因为模型未下载")
            print(" 请确保网络连接正常,或使用已下载的模型")

    def test_04_tts_endpoint_basic(self):
        """POST /tts with a short Chinese sentence returns non-empty base64 audio."""
        test_text = "这是一个测试"
        response = self.client.post(
            f'{API_BASE_URL}/v1/tts-asr/tts',
            headers=self.headers,
            json={
                'text': test_text,
                'voice': 'af_bella',
                'rate': 1.0,
                'format': 'wav'
            }
        )
        # A 500 is treated as "model not loaded" and skipped rather than failed.
        if response.status_code == 500:
            print(f"\n⚠ TTS失败可能是模型未加载: {self._error_detail(response)}")
            self.skipTest("TTS模型未加载或不可用")
        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertIn('audio_base64', result)
        self.assertIn('format', result)
        self.assertIn('duration_ms', result)
        audio_data = base64.b64decode(result['audio_base64'])
        self.assertGreater(len(audio_data), 0)
        self.assertGreater(result['duration_ms'], 0)
        print(f"\nTTS测试成功:")
        print(f" 输入文本: {test_text}")
        print(f" 音频大小: {len(audio_data)} bytes")
        print(f" 时长: {result['duration_ms']} ms")

    def test_05_asr_endpoint_basic(self):
        """POST /asr with one second of silence returns a text/language result."""
        import io
        import wave

        # Build a 1-second, 16 kHz, mono, 16-bit PCM WAV of silence.
        sample_rate = 16000
        duration = 1.0
        samples = int(sample_rate * duration)
        # 16-bit silence is all zero bytes (2 bytes per mono sample); this is
        # byte-identical to np.zeros(samples, np.int16).tobytes() without numpy.
        silence = bytes(samples * 2)
        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, 'wb') as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(sample_rate)
            wf.writeframes(silence)
        audio_base64 = base64.b64encode(wav_buffer.getvalue()).decode()
        response = self.client.post(
            f'{API_BASE_URL}/v1/tts-asr/asr',
            headers=self.headers,
            json={
                'audio_base64': audio_base64,
                'language': 'zh-CN'
            }
        )
        if response.status_code == 500:
            print(f"\n⚠ ASR失败可能是模型未加载: {self._error_detail(response)}")
            self.skipTest("ASR模型未加载或不可用")
        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertIn('text', result)
        self.assertIn('language', result)
        print(f"\nASR测试成功:")
        print(f" 识别文本: '{result['text']}'")
        print(f" 语言: {result['language']}")
        print(f" 注意: 静音音频应该返回空文本")

    def test_06_api_key_validation(self):
        """A wrong X-API-Key must be rejected with 403 Forbidden."""
        wrong_headers = {'X-API-Key': 'wrong-api-key'}
        response = self.client.get(
            f'{API_BASE_URL}/v1/tts-asr/status',
            headers=wrong_headers
        )
        self.assertEqual(response.status_code, 403)
        print(f"\n✓ API密钥验证正常错误密钥被拒绝")

    def test_07_tts_long_text(self):
        """POST /tts with a longer paragraph succeeds within an extended timeout."""
        long_text = "这是一段较长的测试文本用于测试TTS系统对长文本的处理能力。" * 3
        response = self.client.post(
            f'{API_BASE_URL}/v1/tts-asr/tts',
            headers=self.headers,
            json={
                'text': long_text,
                'voice': 'af_bella',
                'rate': 1.0,
                'format': 'wav'
            },
            # Long text needs MORE time than the client default. The previous
            # hard-coded 60 s was actually shorter than TEST_TIMEOUT.
            timeout=TEST_TIMEOUT * 2
        )
        if response.status_code == 500:
            print(f"\n⚠ TTS失败可能是模型未加载: {self._error_detail(response)}")
            self.skipTest("TTS模型未加载或不可用")
        self.assertEqual(response.status_code, 200)
        result = response.json()
        # Validate before reading, so a malformed response fails clearly
        # instead of raising KeyError inside a print.
        self.assertIn('audio_base64', result)
        self.assertIn('duration_ms', result)
        audio_data = base64.b64decode(result['audio_base64'])
        self.assertGreater(len(audio_data), 0)
        print(f"\n长文本TTS测试成功:")
        print(f" 输入长度: {len(long_text)} 字符")
        print(f" 音频大小: {len(audio_data)} bytes")
        print(f" 时长: {result['duration_ms']} ms")
class PerformanceTest(unittest.TestCase):
    """Rough latency measurements for the TTS endpoint (needs a running backend)."""

    @classmethod
    def setUpClass(cls):
        """Create the shared HTTP client and probe whether the service is up."""
        cls.client = httpx.Client(timeout=TEST_TIMEOUT)
        cls.headers = {'X-API-Key': API_KEY}
        try:
            response = cls.client.get(f'{API_BASE_URL}/v1/tts-asr/status', headers=cls.headers)
            cls.service_available = response.status_code == 200
        except Exception:
            # Any failure to reach the service just means the tests are skipped.
            # A bare `except:` would also swallow KeyboardInterrupt/SystemExit.
            cls.service_available = False

    @classmethod
    def tearDownClass(cls):
        """Close the shared HTTP client."""
        cls.client.close()

    def setUp(self):
        """Skip the test when the backend was not reachable in setUpClass."""
        if not self.service_available:
            self.skipTest("后端服务不可用")

    def test_tts_latency(self):
        """Time three sequential TTS requests and report avg/min/max latency."""
        test_text = "测试延迟"
        latencies = []
        for _ in range(3):
            # perf_counter is monotonic and intended for interval timing.
            start = time.perf_counter()
            response = self.client.post(
                f'{API_BASE_URL}/v1/tts-asr/tts',
                headers=self.headers,
                json={'text': test_text}
            )
            elapsed = time.perf_counter() - start
            # Only successful calls are counted; errors would skew the numbers.
            if response.status_code == 200:
                latencies.append(elapsed)
        if not latencies:
            # Previously this silently passed without measuring anything.
            self.skipTest("所有TTS请求均失败, 无法统计延迟")
        avg_latency = sum(latencies) / len(latencies)
        print(f"\nTTS延迟测试:")
        print(f" 平均延迟: {avg_latency:.3f}秒")
        print(f" 最小延迟: {min(latencies):.3f}秒")
        print(f" 最大延迟: {max(latencies):.3f}秒")
def run_tests(test_type: Optional[str] = None):
    """Build and run the requested test selection; return True if all passed.

    Args:
        test_type: 'config', 'status', 'warmup', 'tts' or 'asr' runs the single
            matching TTSASRIntegrationTest method; 'perf' runs PerformanceTest;
            any other value (including None) runs both test classes.

    Returns:
        ``wasSuccessful()`` of the unittest result.
    """
    # Selector -> name of the single integration test it runs.
    single_methods = {
        'config': 'test_01_config_endpoint',
        'status': 'test_02_status_endpoint',
        'warmup': 'test_03_warmup_endpoint',
        'tts': 'test_04_tts_endpoint_basic',
        'asr': 'test_05_asr_endpoint_basic',
    }
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    if test_type in single_methods:
        suite.addTest(TTSASRIntegrationTest(single_methods[test_type]))
    elif test_type == 'perf':
        suite.addTests(loader.loadTestsFromTestCase(PerformanceTest))
    else:
        # No (or unknown) selector: run everything.
        for case_class in (TTSASRIntegrationTest, PerformanceTest):
            suite.addTests(loader.loadTestsFromTestCase(case_class))
    outcome = unittest.TextTestRunner(verbosity=2).run(suite)
    return outcome.wasSuccessful()
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='TTS/ASR集成测试')
    parser.add_argument('--test', choices=[
        'config', 'status', 'warmup', 'tts', 'asr', 'perf'
    ], help='运行特定测试')
    parser.add_argument('--url', default=API_BASE_URL, help='API基础URL')
    parser.add_argument('--key', default=API_KEY, help='API密钥')
    args = parser.parse_args()

    # Rebind the module-level settings; the test classes read these globals
    # when each test runs. A trailing slash is stripped so that
    # f'{API_BASE_URL}/v1/...' does not turn into '//v1/...' (which would 404).
    API_BASE_URL = args.url.rstrip('/')
    API_KEY = args.key

    print("=" * 70)
    print("TTS/ASR 集成测试")
    print("=" * 70)
    print(f"API URL: {API_BASE_URL}")
    print(f"测试类型: {args.test or '全部'}")
    print("=" * 70)

    success = run_tests(args.test)
    # Exit status 0 only when every selected test passed.
    sys.exit(0 if success else 1)