Files
llm-in-text/backend/tts_asr.py
ydy0615 f99acf5d50 feat(core): add ModelScope support for TTS and new office load status
Add support to download and load TTS model from ModelScope, with a fallback to the HuggingFace mirror.
Implement a `documentLoadStatus` property and helper functions in `office.js` to track file loading state.
Improve request cancellation logic in `api.js`, ensuring proper cancel URL resolution and request-id handling.

These changes enhance robustness, reduce external dependencies, and provide better UX for office file handling.
2026-04-11 10:04:34 +08:00

240 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import base64
import logging
import os
import threading
from io import BytesIO
from typing import Optional

# Point Hugging Face downloads at the mainland-China mirror (hf-mirror.com).
# setdefault keeps any HF_ENDPOINT the user already exported. This is set before
# the third-party imports below, presumably so that libraries reading
# HF_ENDPOINT at import time (e.g. huggingface_hub) see it — TODO confirm.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

import torch
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

logger = logging.getLogger(__name__)
# Optional TTS backend. If qwen_tts cannot be imported for any reason the module
# still loads; Qwen3TTSModel is then None and model loading reports the problem.
try:
    from qwen_tts import Qwen3TTSModel  # type: ignore
except Exception:  # pragma: no cover
    Qwen3TTSModel = None  # type: ignore

# Router that carries every endpoint defined in this module.
router = APIRouter()

# Process-wide TTS model instance; None until it has been loaded.
_tts_model: Optional["Qwen3TTSModel"] = None

# The model is published under the same repository id on both hubs.
_QWEN_TTS_REPO_ID = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign"
MODEL_ID_HF = _QWEN_TTS_REPO_ID  # id used with Hugging Face (HF_ENDPOINT mirror)
MODEL_ID_MS = _QWEN_TTS_REPO_ID  # id used with ModelScope snapshot_download
def _get_device_map() -> str:
"""设备检测逻辑:优先 CUDA其次 MPS最后 CPU"""
if torch.cuda.is_available():
return "cuda:0"
try:
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return "mps"
except Exception as e:
logger.debug("MPS check failed: %s", e)
return "cpu"
def _download_model_from_modelscope() -> Optional[str]:
    """Download the TTS model snapshot from ModelScope.

    The snapshot is stored under a ``models`` directory next to this source
    file (created if missing), passed to ``snapshot_download`` as its
    ``cache_dir``. This is a persistent directory, not a temporary one.

    Returns:
        The local snapshot directory, or None if anything fails (including the
        ``modelscope`` package not being installed). Failures are logged as
        warnings instead of being raised.
    """
    try:
        from modelscope import snapshot_download

        models_root = os.path.join(os.path.dirname(__file__), "models")
        os.makedirs(models_root, exist_ok=True)
        local_dir = snapshot_download(
            MODEL_ID_MS,
            revision="master",
            cache_dir=models_root,
        )
        logger.info("ModelScope 模型下载完成: %s", local_dir)
    except Exception as err:
        logger.warning("ModelScope 下载失败: %s", err)
        return None
    return local_dir
async def _warmup_tts():
    """Load the TTS model without blocking the event loop.

    The blocking loader runs in a worker thread via ``asyncio.to_thread``.
    Its return value is discarded, and any exception it raises propagates to
    the awaiting caller.
    """
    loader = _load_tts_model_with_retry
    await asyncio.to_thread(loader)
async def _warmup_all():
    """Warm up the models served by this module.

    Only the TTS model is loaded; this module has no ASR model to warm up.
    Start and finish are logged at info level; a load error propagates.
    """
    logger.info("[Warmup] 开始预热 TTS 模型...")
    await _warmup_tts()
    logger.info("[Warmup] TTS 模型预热完成")
# Serializes model loading. /warmup and the first /tts request can both reach
# the loader from worker threads; without this the model could be loaded twice,
# wasting time and device memory.
_tts_load_lock = threading.Lock()


def _get_tts_dtype(device_map: str) -> "torch.dtype":
    """Return the weight dtype to use on *device_map*.

    float16 halves memory on CUDA/MPS. On CPU many float16 kernels are missing
    or very slow, so CPU uses float32.
    """
    return torch.float32 if device_map == "cpu" else torch.float16


def _load_tts_model_with_retry(max_retries: int = 3) -> "Qwen3TTSModel":
    """Load (once) and return the shared Qwen3 TTS model.

    Sources are tried in order, each up to *max_retries* times:

    1. Download the snapshot from ModelScope, then load it from the local
       directory.
    2. Load ``MODEL_ID_HF`` directly with ``from_pretrained`` (Hugging Face,
       via whatever ``HF_ENDPOINT`` is configured).

    The model is cached in the module-global ``_tts_model``, so later calls
    return it immediately. Loading runs under a lock, so concurrent callers in
    different threads load the model only once.

    Args:
        max_retries: Number of attempts per source.

    Returns:
        The loaded ``Qwen3TTSModel``.

    Raises:
        RuntimeError: ``qwen_tts`` is not installed, or every attempt failed.
            The last underlying error is chained as ``__cause__``.
    """
    global _tts_model
    # Fast path: no locking once the model exists.
    if _tts_model is not None:
        return _tts_model

    with _tts_load_lock:
        # Another thread may have finished loading while we waited for the lock.
        if _tts_model is not None:
            return _tts_model
        if Qwen3TTSModel is None:
            raise RuntimeError("qwen_tts 库未安装,无法加载 TTS 模型")

        device_map = _get_device_map()
        dtype = _get_tts_dtype(device_map)
        last_err: Optional[Exception] = None

        # Strategy 1: download from ModelScope, then load from the local directory.
        for attempt in range(1, max_retries + 1):
            try:
                logger.info("尝试从 ModelScope 下载模型...")
                model_path = _download_model_from_modelscope()
                if not (model_path and os.path.isdir(model_path)):
                    # The downloader logs and returns None rather than raising;
                    # turn that into an error so it is counted and reported.
                    raise RuntimeError(f"ModelScope 未返回有效的模型目录: {model_path!r}")
                _tts_model = Qwen3TTSModel.from_pretrained(
                    model_path,
                    device_map=device_map,
                    dtype=dtype,
                )
                logger.info("ModelScope 模型加载成功: %s", model_path)
                return _tts_model
            except Exception as e:
                logger.warning("ModelScope 加载失败 (尝试 %d/%d): %s", attempt, max_retries, e)
                last_err = e

        # Strategy 2: load directly through the Hugging Face hub (mirror).
        for attempt in range(1, max_retries + 1):
            try:
                logger.info("尝试从 HuggingFace 镜像加载模型...")
                _tts_model = Qwen3TTSModel.from_pretrained(
                    MODEL_ID_HF,
                    device_map=device_map,
                    dtype=dtype,
                )
                logger.info("HuggingFace 模型加载成功")
                return _tts_model
            except Exception as e:
                logger.warning("HuggingFace 加载失败 (尝试 %d/%d): %s", attempt, max_retries, e)
                last_err = e

        raise RuntimeError(f"无法加载 TTS 模型: {last_err}") from last_err
class TTSRequest(BaseModel):
    """Request body for ``POST /tts``."""

    # Text to synthesize.
    text: str
    # Free-form voice description forwarded to the model as ``instruct``.
    instruct: str = ""
    # NOTE(review): accepted but not forwarded to generate_voice_design by /tts.
    speaker: str = "Vivian"
    # NOTE(review): requested output format; /tts currently always returns WAV.
    format: str = "wav"
class TTSResponse(BaseModel):
    """Response body for ``POST /tts``."""

    # Base64 (standard alphabet) of the encoded audio file bytes.
    audio_base64: str
    # Container format of the audio payload (the endpoint reports "wav").
    format: str
    # Audio length in milliseconds, truncated to an integer.
    duration_ms: int
class ModelStatus(BaseModel):
    """Response body for ``GET /status``."""

    # True once the TTS model has been loaded in this process.
    tts_loaded: bool
    # This module provides no ASR model; defaults to False.
    asr_loaded: bool = False
    # Device string chosen for the model, e.g. "cuda:0", "mps" or "cpu".
    device: str
    # NOTE(review): not set by get_status; confirm whether anything fills it.
    tts_last_used: Optional[float] = None
    # NOTE(review): not set by get_status; confirm whether anything fills it.
    asr_last_used: Optional[float] = None
def _ensure_model() -> "Qwen3TTSModel":
    """Return the shared TTS model, loading it on first use.

    Raises whatever ``_load_tts_model_with_retry`` raises when loading fails.
    """
    global _tts_model
    model = _tts_model
    if model is None:
        model = _load_tts_model_with_retry()
        _tts_model = model
    return model
@router.get("/status", response_model=ModelStatus)
async def get_status():
    """Report whether the TTS model is loaded and which device would be used.

    ASR is not provided by this module, so ``asr_loaded`` is always False.
    """
    tts_ready = _tts_model is not None
    device = _get_device_map()
    return ModelStatus(tts_loaded=tts_ready, asr_loaded=False, device=device)
@router.get("/config")
async def get_config():
    """Return model configuration plus the current device and load status.

    The TTS model name is derived from ``MODEL_ID_HF`` (the part after the
    last ``/``), so it cannot drift from the configured model id. ASR is not
    provided by this module, so its entries are ``None`` / ``False``.
    """
    tts_model_name = MODEL_ID_HF.rsplit("/", 1)[-1]
    return {
        "model": {
            "tts": tts_model_name,
            "asr": None,
        },
        "device": _get_device_map(),
        "status": {
            "tts_loaded": _tts_model is not None,
            "asr_loaded": False,
        },
    }
@router.post("/warmup")
async def warmup_models():
    """Manually trigger loading of the TTS model and report the result.

    Loading runs in a worker thread (``_warmup_tts``), so the event loop stays
    responsive. A load failure is logged and returned as HTTP 500 whose detail
    is the error text, instead of escaping as an undescribed server error.
    """
    try:
        await _warmup_tts()
    except Exception as e:
        logger.exception("TTS 模型预热失败")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "tts_warmup": _tts_model is not None,
        "device": _get_device_map(),
    }
# The shared model is driven from worker threads; run one generation at a time
# so concurrent requests do not contend for the same model and device memory.
_tts_infer_lock = threading.Lock()


def _synthesize(model, text: str, instruct: str):
    """Run VoiceDesign generation for *text*, serialized across threads."""
    with _tts_infer_lock:
        # The VoiceDesign model is driven through generate_voice_design.
        return model.generate_voice_design(
            text=text,
            language="Chinese",
            instruct=instruct,
        )


def _split_tts_output(result):
    """Return ``(samples, sample_rate)`` for the first clip of a TTS result.

    *result* is indexed as ``(wavs, sr)``. ``wavs`` may be a single waveform or
    a batch (list/tuple) of waveforms; for a batch only the first clip is used.
    Otherwise soundfile would see a 1 x N array, i.e. one frame with N channels.
    Torch tensors are moved to CPU float32 so soundfile can encode them.

    Raises:
        ValueError: the model returned an empty batch.
    """
    wavs, sr = result[0], result[1]
    if isinstance(wavs, (list, tuple)):
        if not wavs:
            raise ValueError("TTS 模型未返回音频")
        # Only unwrap a batch of arrays, not a flat list of samples.
        if hasattr(wavs[0], "__len__"):
            wavs = wavs[0]
    if isinstance(wavs, torch.Tensor):
        wavs = wavs.detach().to(device="cpu", dtype=torch.float32).numpy()
    return wavs, int(sr)


def _encode_wav(samples, sample_rate: int) -> bytes:
    """Encode *samples* at *sample_rate* Hz as an in-memory WAV file."""
    import soundfile as sf  # lazy import: only needed when audio is produced

    buffer = BytesIO()
    sf.write(buffer, samples, sample_rate, format="WAV")
    return buffer.getvalue()


@router.post("/tts", response_model=TTSResponse)
async def tts_endpoint(req: TTSRequest):
    """Synthesize speech for ``req.text`` and return it as base64-encoded WAV.

    Model loading, inference and WAV encoding are blocking. Each runs in a
    worker thread via ``asyncio.to_thread`` so the event loop keeps serving
    other requests.

    ``req.speaker`` is not forwarded to the model; the voice is controlled only
    by ``req.instruct``. The output is always WAV, regardless of ``req.format``.

    Raises:
        HTTPException(400): ``req.text`` is empty or whitespace only.
        HTTPException(500): the model cannot be loaded, inference fails, or the
            audio cannot be encoded.
    """
    text = req.text
    if not text or not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")
    instruct = req.instruct or ""

    try:
        model = await asyncio.to_thread(_ensure_model)
    except Exception as e:
        logger.exception("TTS 模型加载失败")
        raise HTTPException(status_code=500, detail=str(e)) from e

    try:
        result = await asyncio.to_thread(_synthesize, model, text, instruct)
    except Exception as e:
        logger.exception("TTS 推理失败")
        raise HTTPException(status_code=500, detail=f"TTS 推理失败: {e}") from e

    try:
        wav_data, sr = _split_tts_output(result)
        audio_bytes = await asyncio.to_thread(_encode_wav, wav_data, sr)
    except Exception as e:
        logger.exception("音频编码失败")
        raise HTTPException(status_code=500, detail=f"音频编码失败: {e}") from e

    # Duration in milliseconds; len() counts frames for 1-D and (frames, ch) arrays.
    duration_ms = int(len(wav_data) / sr * 1000) if sr > 0 else 0

    return TTSResponse(
        audio_base64=base64.b64encode(audio_bytes).decode("utf-8"),
        format="wav",
        duration_ms=duration_ms,
    )
def register_tts_asr_routes(app):
    """Mount this module's router on the FastAPI application *app*.

    Every endpoint is exposed under the ``/v1/tts-asr`` path prefix.
    """
    mount_prefix = "/v1/tts-asr"
    app.include_router(router, prefix=mount_prefix)