refactor(tts): use numpy and proper temp file cleanup for WAV encoding
Update WAV encoding logic to convert audio to a NumPy array, employ a temporary file for safe write with soundfile, and ensure cleanup in a finally block. This resolves the BytesIO limitation and improves the reliability of the TTS endpoint.
This commit is contained in:
@@ -2,12 +2,13 @@ import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from io import BytesIO
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
# 设置 Hugging Face 镜像源为国内镜像
|
||||
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
@@ -193,11 +194,11 @@ async def tts_endpoint(req: TTSRequest):
|
||||
|
||||
text = req.text
|
||||
instruct = req.instruct or ""
|
||||
speaker = req.speaker or "Vivian"
|
||||
|
||||
try:
|
||||
# VoiceDesign 模型使用 generate_voice_design 方法
|
||||
wavs_sr = model.generate_voice_design(
|
||||
# 返回 (wavs, sr),其中 wavs 是列表,wavs[0] 是第一个音频数据
|
||||
wavs, sr = model.generate_voice_design(
|
||||
text=text,
|
||||
language="Chinese",
|
||||
instruct=instruct,
|
||||
@@ -206,28 +207,36 @@ async def tts_endpoint(req: TTSRequest):
|
||||
logger.exception("TTS 推理失败")
|
||||
raise HTTPException(status_code=500, detail=f"TTS 推理失败: {e}")
|
||||
|
||||
# Normalize output
|
||||
if isinstance(wavs_sr, tuple) and len(wavs_sr) == 2:
|
||||
wav_data, sr = wavs_sr
|
||||
else:
|
||||
wav_data, sr = wavs_sr[0], wavs_sr[1] # type: ignore
|
||||
# 获取第一个音频数据
|
||||
wav_data = wavs[0] if isinstance(wavs, (list, tuple)) else wavs
|
||||
|
||||
# 转换为 numpy 数组
|
||||
if hasattr(wav_data, 'numpy'):
|
||||
wav_data = wav_data.cpu().numpy()
|
||||
wav_data = np.asarray(wav_data, dtype=np.float32)
|
||||
|
||||
logger.debug("wav_data shape: %s, dtype: %s, sr: %s", wav_data.shape, wav_data.dtype, sr)
|
||||
|
||||
# 编码 WAV 到内存
|
||||
tmp_path = None
|
||||
try:
|
||||
import tempfile
|
||||
import soundfile as sf
|
||||
# soundfile 不支持直接写入 BytesIO,需要用临时文件
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||
tmp_path = tmp.name
|
||||
sf.write(tmp_path, wav_data, sr, format="WAV")
|
||||
# 创建临时文件
|
||||
fd, tmp_path = tempfile.mkstemp(suffix=".wav")
|
||||
os.close(fd)
|
||||
sf.write(tmp_path, wav_data, sr)
|
||||
with open(tmp_path, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
# 清理临时文件
|
||||
import os as _os
|
||||
_os.unlink(tmp_path)
|
||||
except Exception as e:
|
||||
logger.exception("音频编码失败")
|
||||
raise HTTPException(status_code=500, detail=f"音频编码失败: {e}")
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 计算时长(毫秒)
|
||||
duration_ms = int(len(wav_data) / sr * 1000) if sr > 0 else 0
|
||||
|
||||
Reference in New Issue
Block a user