refactor(tts): use numpy and proper temp file cleanup for WAV encoding

Update WAV encoding logic to convert audio to a NumPy array, employ a
temporary file for safe write with soundfile, and ensure cleanup in a
finally block. This resolves the BytesIO limitation and improves the
reliability of the TTS endpoint.
This commit is contained in:
2026-04-11 10:33:46 +08:00
parent ae0d53e295
commit e0054d4cbc

View File

@@ -2,12 +2,13 @@ import asyncio
import base64
import logging
import os
from io import BytesIO
import tempfile
from typing import Optional
# 设置 Hugging Face 镜像源为国内镜像
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
import numpy as np
import torch
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
@@ -193,11 +194,11 @@ async def tts_endpoint(req: TTSRequest):
text = req.text
instruct = req.instruct or ""
speaker = req.speaker or "Vivian"
try:
# VoiceDesign 模型使用 generate_voice_design 方法
wavs_sr = model.generate_voice_design(
# 返回 (wavs, sr),其中 wavs 是列表wavs[0] 是第一个音频数据
wavs, sr = model.generate_voice_design(
text=text,
language="Chinese",
instruct=instruct,
@@ -206,28 +207,36 @@ async def tts_endpoint(req: TTSRequest):
logger.exception("TTS 推理失败")
raise HTTPException(status_code=500, detail=f"TTS 推理失败: {e}")
# Normalize output
if isinstance(wavs_sr, tuple) and len(wavs_sr) == 2:
wav_data, sr = wavs_sr
else:
wav_data, sr = wavs_sr[0], wavs_sr[1] # type: ignore
# 获取第一个音频数据
wav_data = wavs[0] if isinstance(wavs, (list, tuple)) else wavs
# 转换为 numpy 数组
if hasattr(wav_data, 'numpy'):
wav_data = wav_data.cpu().numpy()
wav_data = np.asarray(wav_data, dtype=np.float32)
logger.debug("wav_data shape: %s, dtype: %s, sr: %s", wav_data.shape, wav_data.dtype, sr)
# 编码 WAV 到内存
tmp_path = None
try:
import tempfile
import soundfile as sf
# soundfile 不支持直接写入 BytesIO需要用临时文件
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp_path = tmp.name
sf.write(tmp_path, wav_data, sr, format="WAV")
# 创建临时文件
fd, tmp_path = tempfile.mkstemp(suffix=".wav")
os.close(fd)
sf.write(tmp_path, wav_data, sr)
with open(tmp_path, "rb") as f:
audio_bytes = f.read()
# 清理临时文件
import os as _os
_os.unlink(tmp_path)
except Exception as e:
logger.exception("音频编码失败")
raise HTTPException(status_code=500, detail=f"音频编码失败: {e}")
finally:
# 清理临时文件
if tmp_path and os.path.exists(tmp_path):
try:
os.unlink(tmp_path)
except Exception:
pass
# 计算时长(毫秒)
duration_ms = int(len(wav_data) / sr * 1000) if sr > 0 else 0