Add support for downloading and loading the TTS model from ModelScope, with a fallback to the HuggingFace mirror. Implement a `documentLoadStatus` property and helper functions in `office.js` to track file-loading state. Improve the request-cancellation logic in `api.js`, ensuring correct cancel-URL resolution and request-ID handling. Together, these changes improve robustness, reduce external dependencies, and provide a better UX for Office file handling.
240 lines
6.8 KiB
Python
240 lines
6.8 KiB
Python
import asyncio
import base64
import logging
import os
import threading
from io import BytesIO
from typing import Optional
|
||
|
||
# Route Hugging Face hub traffic through the hf-mirror.com mirror unless the
# deployment already configured an endpoint. Keep this above the qwen_tts import:
# huggingface_hub presumably reads HF_ENDPOINT when it is first imported.
if "HF_ENDPOINT" not in os.environ:
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
||
|
||
import torch
|
||
from fastapi import APIRouter, HTTPException
|
||
from pydantic import BaseModel
|
||
|
||
logger = logging.getLogger(__name__)

# qwen_tts is optional at import time so the status/config endpoints still work
# without it; the loader checks for Qwen3TTSModel being None before use.
try:
    from qwen_tts import Qwen3TTSModel  # type: ignore
except Exception as _qwen_tts_import_error:  # pragma: no cover
    # Caught broadly on purpose: a broken transitive dependency raises more than
    # ImportError. Log the real cause so the later "not installed" error at load
    # time can be traced back to it instead of being silently swallowed.
    logger.warning("qwen_tts 导入失败,TTS 功能不可用: %s", _qwen_tts_import_error)
    Qwen3TTSModel = None  # type: ignore

router = APIRouter()

# Process-wide TTS model instance; None until the first successful load.
_tts_model: Optional["Qwen3TTSModel"] = None

# Repository ids used by each download source. They are currently identical;
# presumably kept separate so the two hubs can diverge if needed.
MODEL_ID_HF = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign"
MODEL_ID_MS = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign"
|
||
|
||
|
||
def _get_device_map() -> str:
|
||
"""设备检测逻辑:优先 CUDA,其次 MPS,最后 CPU"""
|
||
if torch.cuda.is_available():
|
||
return "cuda:0"
|
||
try:
|
||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||
return "mps"
|
||
except Exception as e:
|
||
logger.debug("MPS check failed: %s", e)
|
||
return "cpu"
|
||
|
||
|
||
def _download_model_from_modelscope(
    model_id: Optional[str] = None,
    cache_dir: Optional[str] = None,
    revision: str = "master",
) -> Optional[str]:
    """Download a model snapshot from ModelScope into a persistent local cache.

    Args:
        model_id: ModelScope repository id; defaults to ``MODEL_ID_MS``.
        cache_dir: Cache directory for the download; defaults to a ``models``
            directory next to this file (persistent, not a temporary dir).
        revision: Repository revision to fetch.

    Returns:
        The local snapshot directory, or ``None`` when the ``modelscope``
        package is unavailable or the download fails. The reason is logged;
        this function never raises.
    """
    # Import separately so "package missing/broken" is not reported as a
    # download failure. Broad catch: broken transitive deps raise non-ImportErrors.
    try:
        from modelscope import snapshot_download
    except Exception as e:
        logger.warning("modelscope 不可用,跳过 ModelScope 下载: %s", e)
        return None

    repo_id = MODEL_ID_MS if model_id is None else model_id
    if cache_dir is None:
        cache_dir = os.path.join(os.path.dirname(__file__), "models")

    try:
        os.makedirs(cache_dir, exist_ok=True)
        model_dir = snapshot_download(
            repo_id,
            cache_dir=cache_dir,
            revision=revision,
        )
    except Exception as e:
        logger.warning("ModelScope 下载失败: %s", e)
        return None

    logger.info("ModelScope 模型下载完成: %s", model_dir)
    return model_dir
|
||
|
||
|
||
async def _warmup_tts():
    """Load the TTS model without blocking the event loop.

    The loader is synchronous (download plus weight loading), so it is handed
    to a worker thread. Any loading error propagates to the awaiting caller.
    """
    loader = _load_tts_model_with_retry
    await asyncio.to_thread(loader)
|
||
|
||
|
||
async def _warmup_all():
    """Warm up the models served by this module.

    Only the TTS model is warmed: no ASR model is loaded anywhere in this module.
    If loading fails, the exception propagates and the completion line is not logged.
    """
    log_info = logger.info
    log_info("[Warmup] 开始预热 TTS 模型...")
    await _warmup_tts()
    log_info("[Warmup] TTS 模型预热完成")
|
||
|
||
|
||
def _load_tts_model_with_retry(max_retries: int = 3) -> "Qwen3TTSModel":
|
||
"""加载 TTS 模型,支持多个镜像源"""
|
||
global _tts_model
|
||
if _tts_model is not None:
|
||
return _tts_model
|
||
if Qwen3TTSModel is None:
|
||
raise RuntimeError("qwen_tts 库未安装,无法加载 TTS 模型")
|
||
|
||
device_map = _get_device_map()
|
||
last_err = None
|
||
|
||
# 策略1: 尝试从 ModelScope 下载后加载
|
||
for attempt in range(max_retries):
|
||
try:
|
||
logger.info("尝试从 ModelScope 下载模型...")
|
||
model_path = _download_model_from_modelscope()
|
||
if model_path and os.path.isdir(model_path):
|
||
_tts_model = Qwen3TTSModel.from_pretrained(
|
||
model_path,
|
||
device_map=device_map,
|
||
dtype=torch.float16,
|
||
)
|
||
logger.info("ModelScope 模型加载成功: %s", model_path)
|
||
return _tts_model
|
||
except Exception as e:
|
||
logger.warning("ModelScope 加载失败 (尝试 %d/%d): %s", attempt + 1, max_retries, e)
|
||
last_err = e
|
||
|
||
# 策略2: 尝试从 HuggingFace 镜像加载
|
||
for attempt in range(max_retries):
|
||
try:
|
||
logger.info("尝试从 HuggingFace 镜像加载模型...")
|
||
_tts_model = Qwen3TTSModel.from_pretrained(
|
||
MODEL_ID_HF,
|
||
device_map=device_map,
|
||
dtype=torch.float16,
|
||
)
|
||
logger.info("HuggingFace 模型加载成功")
|
||
return _tts_model
|
||
except Exception as e:
|
||
logger.warning("HuggingFace 加载失败 (尝试 %d/%d): %s", attempt + 1, max_retries, e)
|
||
last_err = e
|
||
|
||
raise RuntimeError(f"无法加载 TTS 模型: {last_err}") from last_err
|
||
|
||
|
||
class TTSRequest(BaseModel):
    """Request body for the POST /tts endpoint."""

    # Text to synthesize.
    text: str
    # Free-form voice-design instruction, passed to the model as `instruct`.
    instruct: str = ""
    # Speaker name. NOTE(review): the /tts handler in this module does not pass it
    # to generate_voice_design(), so it currently has no effect.
    speaker: str = "Vivian"
    # Requested audio format. NOTE(review): the /tts handler in this module always
    # encodes and returns WAV regardless of this value.
    format: str = "wav"
|
||
|
||
|
||
class TTSResponse(BaseModel):
    """Response body returned by the POST /tts endpoint."""

    # The encoded audio file's bytes, base64-encoded into a UTF-8 string.
    audio_base64: str
    # Audio container format; the /tts handler in this module sets "wav".
    format: str
    # Length of the synthesized audio in milliseconds (0 if the sample rate is not positive).
    duration_ms: int
|
||
|
||
|
||
class ModelStatus(BaseModel):
    """Model state reported by GET /status."""

    # Whether the TTS model has been loaded in this process.
    tts_loaded: bool
    # No ASR model is loaded by this module; get_status always reports False.
    asr_loaded: bool = False
    # Device string chosen for inference ("cuda:0", "mps" or "cpu").
    device: str
    # NOTE(review): presumably last-use timestamps; nothing in this module sets
    # them, so they are always None in /status responses.
    tts_last_used: Optional[float] = None
    asr_last_used: Optional[float] = None
|
||
|
||
|
||
def _ensure_model() -> "Qwen3TTSModel":
    """Return the shared TTS model, loading it on first use.

    Blocks while loading. Propagates the loader's RuntimeError if the model
    cannot be loaded.
    """
    global _tts_model
    if _tts_model is not None:
        return _tts_model
    _tts_model = _load_tts_model_with_retry()
    return _tts_model
|
||
|
||
|
||
@router.get("/status", response_model=ModelStatus)
async def get_status():
    """Report whether the TTS model is resident and which device would be used."""
    loaded = _tts_model is not None
    target = _get_device_map()
    return ModelStatus(device=target, tts_loaded=loaded, asr_loaded=False)
|
||
|
||
|
||
@router.get("/config")
async def get_config():
    """Describe the configured model names, the inference device and load state."""
    tts_ready = _tts_model is not None
    model_names = {"tts": "Qwen3-TTS-12Hz-1.7B-VoiceDesign", "asr": None}
    load_state = {"tts_loaded": tts_ready, "asr_loaded": False}
    return {"model": model_names, "device": _get_device_map(), "status": load_state}
|
||
|
||
|
||
@router.post("/warmup")
async def warmup_models():
    """Load the TTS model now instead of on the first /tts request.

    Returns a dict with ``tts_warmup`` (whether the model is resident afterwards)
    and ``device`` (the selected inference device).

    Raises:
        HTTPException: 500 with the loader's error message if loading fails.
            This matches how the /tts endpoint reports load failures, instead
            of an opaque "Internal Server Error".
    """
    try:
        await _warmup_tts()
    except Exception as e:
        logger.exception("TTS 模型预热失败")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "tts_warmup": _tts_model is not None,
        "device": _get_device_map(),
    }
|
||
|
||
|
||
# Serializes generate_voice_design() calls on the shared model. Inference runs in
# worker threads, and the model object is not known to support concurrent
# generation. NOTE(review): relax this if qwen_tts documents thread safety.
_tts_infer_lock = threading.Lock()


def _unpack_tts_output(result) -> tuple:
    """Split a generate_voice_design() result into ``(waveform, sample_rate)``.

    Per the Qwen3-TTS examples, the model returns ``(wavs, sr)``, where ``wavs``
    is a batch (list) of waveforms. NOTE(review): confirm against the installed
    qwen_tts version. Only one text is synthesized here, so the first waveform
    is used. Torch tensors are converted to CPU float32 NumPy arrays so
    soundfile can encode them.

    Raises:
        ValueError: if the model returned an empty batch.
    """
    wav, sr = result[0], result[1]
    if isinstance(wav, (list, tuple)):
        if len(wav) == 0:
            raise ValueError("模型未返回音频")
        # A list of arrays is a batch; a flat list of numbers is already one waveform.
        if hasattr(wav[0], "__len__"):
            wav = wav[0]
    if isinstance(wav, torch.Tensor):
        wav = wav.detach().to(torch.float32).cpu().numpy()
    return wav, int(sr)


def _run_voice_design(model, text: str, instruct: str) -> tuple:
    """Run one blocking VoiceDesign generation; return ``(waveform, sample_rate)``."""
    with _tts_infer_lock:
        result = model.generate_voice_design(
            text=text,
            language="Chinese",
            instruct=instruct,
        )
    return _unpack_tts_output(result)


@router.post("/tts", response_model=TTSResponse)
async def tts_endpoint(req: TTSRequest):
    """Synthesize ``req.text`` with the VoiceDesign model; return base64-encoded WAV.

    Model loading and generation are blocking, so both run in worker threads
    (asyncio.to_thread) to keep the event loop serving other requests. The
    language is fixed to Chinese. ``req.speaker`` is not used by the VoiceDesign
    call, and the output is always WAV regardless of ``req.format``.

    Raises:
        HTTPException: 400 if the text is empty or whitespace-only. 500 if
            model loading, inference or WAV encoding fails.
    """
    text = req.text
    # Reject before triggering a (possibly very slow) model load.
    if not text or not text.strip():
        raise HTTPException(status_code=400, detail="text 不能为空")
    instruct = req.instruct or ""

    try:
        model = await asyncio.to_thread(_ensure_model)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    try:
        wav_data, sr = await asyncio.to_thread(_run_voice_design, model, text, instruct)
    except Exception as e:
        logger.exception("TTS 推理失败")
        raise HTTPException(status_code=500, detail=f"TTS 推理失败: {e}") from e

    # Encode the waveform as an in-memory WAV file.
    try:
        import soundfile as sf

        bio = BytesIO()
        sf.write(bio, wav_data, sr, format="WAV")
        audio_bytes = bio.getvalue()
    except Exception as e:
        logger.exception("音频编码失败")
        raise HTTPException(status_code=500, detail=f"音频编码失败: {e}") from e

    # len() counts frames (first axis) of the single unpacked waveform.
    duration_ms = int(len(wav_data) / sr * 1000) if sr > 0 else 0

    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    return TTSResponse(
        audio_base64=audio_base64,
        format="wav",
        duration_ms=duration_ms,
    )
|
||
|
||
|
||
def register_tts_asr_routes(app):
    """Mount this module's TTS/ASR router on the FastAPI *app* under /v1/tts-asr."""
    mount_prefix = "/v1/tts-asr"
    app.include_router(router, prefix=mount_prefix)
|