Files
llm-in-text/backend/tts_asr.py
ydy0615 9ff51ac2f3 feat(plugin): add document export, doc‑block, and TTS/ASR support
Adds a DocBlock component that renders embedded documents, new export buttons for DOCX
and PDF, and updates the file‑upload picker to accept *.txt, *.docx, *.pptx, and *.pdf.
Introduces a DOCX→PDF conversion bridge in the backend and new /tts and /asr
endpoints that expose TTS and speech‑recognition functionality.  The README is
rewritten to describe the new features and clean up legacy documentation.  All
changes are backward‑compatible and do not introduce breaking API changes.
2026-04-04 23:56:18 +08:00

256 lines
8.0 KiB
Python

# TTS and ASR API routes built on HuggingFace transformers, preferring Apple Silicon (MPS) on macOS and falling back to CUDA or CPU.
import asyncio
import base64
import logging
import os
import platform
import threading

import numpy as np
from fastapi import APIRouter, HTTPException, Security
from pydantic import BaseModel

router = APIRouter()
logger = logging.getLogger("tts_asr")

# Lazily-populated module singletons: each stays None until the matching
# loader (_get_tts_pipeline / _get_asr_pipeline / _get_device) first runs.
_tts_pipeline = _asr_pipeline = _device = None
def _get_device():
    """Return the accelerator name ("mps", "cuda" or "cpu"), detecting it once.

    MPS is only considered on macOS (``platform.system() == "Darwin"``);
    otherwise CUDA is used when available, else CPU. The choice is cached in
    the module-level ``_device`` and logged the first time it is made.
    """
    global _device
    if _device is None:
        import torch

        on_macos = platform.system() == "Darwin"
        mps_backend = getattr(torch.backends, "mps", None) if on_macos else None
        if mps_backend is not None and mps_backend.is_available():
            _device, note = "mps", "[Device] 使用 MPS 加速"
        elif torch.cuda.is_available():
            _device, note = "cuda", "[Device] 使用 CUDA 加速"
        else:
            _device, note = "cpu", "[Device] 使用 CPU"
        logger.info(note)
    return _device
def _device_arg():
    """Translate the detected device into the value passed as ``device=`` to ``pipeline``.

    CUDA is pinned to the first GPU (``"cuda:0"``); ``"mps"`` and ``"cpu"``
    are passed through unchanged.
    """
    detected = _get_device()
    return "cuda:0" if detected == "cuda" else detected
_tts_pipeline_lock = threading.Lock()


def _get_tts_pipeline():
    """Return the shared Kokoro-82M text-to-speech pipeline, loading it on first use.

    ``_tts_sync`` calls this from worker threads (it is run through
    ``asyncio.to_thread``), so loading is guarded by a lock. Without it, two
    concurrent first requests would each load the model. fp16 is used on
    MPS/CUDA and fp32 on CPU.

    NOTE(review): hexgrad/Kokoro-82M is published for the standalone
    ``kokoro`` package. Confirm that ``transformers.pipeline("text-to-speech",
    ..., trust_remote_code=True)`` can load it with the installed transformers
    version, and that the resulting pipeline accepts the ``voice=`` call kwarg
    used by ``_tts_sync``.
    """
    global _tts_pipeline
    if _tts_pipeline is not None:
        return _tts_pipeline
    with _tts_pipeline_lock:
        # Re-check: another thread may have finished loading while we waited.
        if _tts_pipeline is None:
            import torch
            from transformers import pipeline

            logger.info("[TTS] 加载 Kokoro-82M 模型...")
            dtype = torch.float16 if _get_device() != "cpu" else torch.float32
            _tts_pipeline = pipeline(
                "text-to-speech",
                model="hexgrad/Kokoro-82M",
                trust_remote_code=True,
                device=_device_arg(),
                torch_dtype=dtype,
            )
            logger.info("[TTS] Kokoro-82M 模型加载完成")
    return _tts_pipeline
_asr_pipeline_lock = threading.Lock()


def _get_asr_pipeline():
    """Return the shared Whisper large-v3-turbo ASR pipeline, loading it on first use.

    ``_asr_sync`` calls this from worker threads (it is run through
    ``asyncio.to_thread``). Loading is therefore guarded by a lock, so that
    concurrent first requests do not each load a copy of the large Whisper
    model. fp16 is used on MPS/CUDA and fp32 on CPU.
    """
    global _asr_pipeline
    if _asr_pipeline is not None:
        return _asr_pipeline
    with _asr_pipeline_lock:
        # Re-check: another thread may have finished loading while we waited.
        if _asr_pipeline is None:
            import torch
            from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

            logger.info("[ASR] 加载 Whisper large-v3-turbo 模型...")
            model_id = "openai/whisper-large-v3-turbo"
            dtype = torch.float16 if _get_device() != "cpu" else torch.float32
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_id,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            processor = AutoProcessor.from_pretrained(model_id)
            _asr_pipeline = pipeline(
                "automatic-speech-recognition",
                model=model,
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                torch_dtype=dtype,
                device=_device_arg(),
            )
            logger.info("[ASR] Whisper large-v3-turbo 模型加载完成")
    return _asr_pipeline
def _save_audio_to_wav(audio_data: bytes, sample_rate: int = 16000) -> str:
    """Wrap raw PCM bytes in a WAV container and write it to a new temp file.

    Args:
        audio_data: Written verbatim as mono, 16-bit frames, so the bytes are
            assumed to already be 16-bit mono PCM. NOTE(review): confirm the
            client sends bare PCM rather than an encoded file (WAV/WebM/...).
        sample_rate: Frame rate recorded in the WAV header.

    Returns:
        Path of the ``.wav`` file. The caller owns it and must delete it.

    If writing fails, the temp file is removed before the error propagates.
    """
    import tempfile
    import wave

    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False, mode="wb")
    try:
        # Write through the already-open handle instead of reopening tmp.name.
        # Reopening an open NamedTemporaryFile by name fails on Windows.
        with tmp, wave.open(tmp, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(sample_rate)
            wf.writeframes(audio_data)
    except BaseException:
        os.unlink(tmp.name)
        raise
    return tmp.name
def _tts_sync(text: str, voice: str = "af_bella", rate: float = 1.0) -> tuple[bytes, int]:
    """Synthesise *text* into an in-memory mono 16-bit PCM WAV (blocking).

    Args:
        text: Text to speak.
        voice: Forwarded to the TTS pipeline as the ``voice=`` kwarg.
        rate: Accepted for API compatibility but currently NOT applied; the
            pipeline call below never receives it.

    Returns:
        ``(wav_bytes, duration_ms)``.

    Raises:
        RuntimeError: if the pipeline output contains no audio.
    """
    import io
    import wave

    tts = _get_tts_pipeline()
    result = tts(text, voice=voice)

    # Batched pipeline calls return one dict per input; we sent a single input.
    if isinstance(result, (list, tuple)) and result and isinstance(result[0], dict):
        result = result[0]

    audio = None
    sample_rate = 24000  # fallback when the output does not report a rate
    if isinstance(result, dict):
        audio = result.get("audio")
        # `or` also covers an explicit None/0, which int() or the division below would reject.
        sample_rate = int(result.get("sampling_rate") or sample_rate)
    elif isinstance(result, (list, tuple)) and result:
        audio = result[0]
    if audio is None:
        raise RuntimeError("Kokoro 未返回音频数据")

    if hasattr(audio, "cpu"):  # torch tensor, possibly on MPS/CUDA
        audio = audio.cpu().numpy()
    # Pipelines may return shape (1, n). Flatten so len() counts samples, not rows.
    samples = np.asarray(audio).reshape(-1)
    duration_ms = int(len(samples) * 1000 / sample_rate)
    if samples.dtype != np.int16:
        # Assumes float samples nominally in [-1, 1]. Clip so that overshoot
        # cannot wrap around when cast to int16; work in float32 for precision.
        samples = (np.clip(samples.astype(np.float32), -1.0, 1.0) * 32767).astype(np.int16)

    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(samples.tobytes())
    return buffer.getvalue(), duration_ms
async def _text_to_speech(text: str, voice: str = "af_bella", rate: float = 1.0) -> tuple[bytes, int]:
    """Run the synchronous ``_tts_sync`` on a worker thread and await its result.

    Keeps model inference off the event loop. Returns the same
    ``(wav_bytes, duration_ms)`` tuple as ``_tts_sync``.
    """
    synthesis = asyncio.to_thread(_tts_sync, text, voice=voice, rate=rate)
    return await synthesis
def _asr_sync(audio_data: bytes, language: str = "zh") -> str:
    """Transcribe raw audio bytes with the Whisper pipeline (blocking).

    Args:
        audio_data: Raw bytes wrapped as 16 kHz mono 16-bit PCM by
            ``_save_audio_to_wav``. NOTE(review): if clients send encoded
            files (WAV/WebM/...) rather than bare PCM, this interpretation is
            wrong. Confirm against the frontend.
        language: Whisper language code passed via ``generate_kwargs``
            (e.g. ``"zh"``, ``"en"``).

    Returns:
        The transcription with surrounding whitespace stripped, or ``""`` when
        the input is shorter than one 16-bit sample.
    """
    if len(audio_data) < 2:
        # Not even one sample: skip loading the model and return no text.
        return ""
    import soundfile as sf

    asr = _get_asr_pipeline()
    audio_path = _save_audio_to_wav(audio_data)
    try:
        audio_array, sample_rate = sf.read(audio_path)
        # The ASR pipeline takes the sampling rate inside the input dict
        # ({"raw": ..., "sampling_rate": ...}). It is not a call kwarg.
        result = asr(
            {"raw": audio_array, "sampling_rate": sample_rate},
            generate_kwargs={"language": language, "task": "transcribe"},
        )
    finally:
        if os.path.exists(audio_path):
            os.unlink(audio_path)
    if isinstance(result, dict):
        return (result.get("text") or "").strip()
    return str(result).strip()
async def _speech_to_text(audio_data: bytes, language: str = "zh") -> str:
    """Await the synchronous ``_asr_sync`` on a worker thread so transcription does not block the event loop."""
    transcription = asyncio.to_thread(_asr_sync, audio_data, language=language)
    return await transcription
class TTSRequest(BaseModel):
    # Request body for the /tts route.
    text: str  # text to synthesise
    voice: str = "af_bella"  # voice name forwarded to the TTS pipeline
    rate: float = 1.0  # NOTE(review): forwarded to the synthesis helper but not applied there
    format: str = "wav"  # "mp3" (case-insensitive) is converted with ffmpeg; other values are not transcoded
class TTSResponse(BaseModel):
    # Response body for the /tts route.
    audio_base64: str  # base64-encoded audio bytes
    format: str  # format label for the returned audio
    duration_ms: int  # clip length in milliseconds, computed from sample count and sample rate
class ASRRequest(BaseModel):
    # Request body for the /asr route.
    # NOTE(review): the decoded bytes are treated as raw 16 kHz mono 16-bit PCM; confirm clients send that.
    audio_base64: str  # base64-encoded audio
    language: str = "zh-CN"  # locale-style tag; the route reduces it to a Whisper language code (e.g. "zh")
class ASRResponse(BaseModel):
    # Response body for the /asr route.
    text: str  # transcription with surrounding whitespace stripped
    language: str  # the language value from the request, echoed unchanged
def get_api_key(api_key: str):
    """Validate the caller-supplied API key against ``backend.main.API_KEY``.

    NOTE(review): because this dependency declares a plain ``str`` parameter,
    FastAPI resolves ``api_key`` as a required query parameter
    (``?api_key=...``) for routes using ``Security(get_api_key)``. Confirm that
    is intended and consistent with how ``backend.main`` authenticates.

    Returns:
        The validated key.

    Raises:
        HTTPException: 403 when the key does not match, or when the configured
            key is not a string.
    """
    import secrets

    # Imported lazily, presumably to avoid a circular import with backend.main (verify).
    from backend.main import API_KEY

    # Constant-time comparison so response timing does not leak the key.
    # Bytes are compared because compare_digest rejects non-ASCII str arguments.
    if not isinstance(API_KEY, str) or not secrets.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="API Key 无效")
    return api_key
async def _wav_to_mp3(wav_bytes: bytes) -> bytes:
    """Transcode WAV bytes to 128 kbps MP3 with the ffmpeg CLI, off the event loop.

    Raises:
        RuntimeError: if ffmpeg exits non-zero (its stderr is included).
        FileNotFoundError / subprocess.TimeoutExpired: propagated from ``subprocess.run``.
    """
    import subprocess
    import tempfile

    input_path = output_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_in:
            input_path = tmp_in.name
            tmp_in.write(wav_bytes)
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_out:
            output_path = tmp_out.name
        # -y: output_path already exists (created above). Without it ffmpeg asks
        # "Overwrite? [y/N]" and then either exits with an error or blocks on stdin.
        # -nostdin / DEVNULL: never let ffmpeg read the server's stdin.
        cmd = [
            "ffmpeg", "-nostdin", "-y",
            "-i", input_path,
            "-acodec", "libmp3lame", "-ab", "128k",
            output_path,
        ]
        result = await asyncio.to_thread(
            subprocess.run,
            cmd,
            stdin=subprocess.DEVNULL,
            capture_output=True,
            text=True,
            errors="replace",  # ffmpeg stderr is not guaranteed to be valid UTF-8
            timeout=30,
        )
        if result.returncode != 0:
            raise RuntimeError(f"MP3 转换失败: {result.stderr}")
        with open(output_path, "rb") as f:
            return f.read()
    finally:
        for path in (input_path, output_path):
            if path and os.path.exists(path):
                os.unlink(path)


@router.post("/tts", response_model=TTSResponse)
async def text_to_speech(req: TTSRequest, api_key: str = Security(get_api_key)):
    """Synthesise ``req.text`` and return it base64-encoded as WAV or MP3.

    ``req.format`` is matched case-insensitively. ``"mp3"`` is transcoded with
    ffmpeg and every other value yields WAV. The response's ``format`` field
    reports the encoding actually returned. Any failure maps to HTTP 500 with
    the exception text as ``detail``.
    """
    import uuid

    # Random id so identical texts sent concurrently stay distinguishable in logs.
    request_id = uuid.uuid4().hex[:8]
    out_format = "mp3" if req.format.lower() == "mp3" else "wav"
    try:
        logger.info("[TTS][%s] text_chars=%d voice=%s format=%s", request_id, len(req.text), req.voice, req.format)
        if out_format == "wav" and req.format.lower() != "wav":
            logger.warning("[TTS][%s] unsupported format=%s, returning wav", request_id, req.format)
        audio_data, duration_ms = await _text_to_speech(req.text, req.voice, req.rate)
        if out_format == "mp3":
            audio_data = await _wav_to_mp3(audio_data)
        logger.info("[TTS][%s] success duration_ms=%d", request_id, duration_ms)
        return TTSResponse(
            audio_base64=base64.b64encode(audio_data).decode(),
            format=out_format,
            duration_ms=duration_ms,
        )
    except Exception as e:
        logger.exception("[TTS][%s] failed: %s", request_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/asr", response_model=ASRResponse)
async def speech_to_text(req: ASRRequest, api_key: str = Security(get_api_key)):
    """Transcribe base64-encoded audio and return the recognised text.

    Undecodable base64 maps to HTTP 400. Any failure during transcription maps
    to HTTP 500 with the exception text as ``detail``. ``req.language``
    (e.g. ``"zh-CN"``) is reduced to its primary subtag (``"zh"``) for Whisper
    and echoed back unchanged in the response.
    """
    import uuid

    # Random id so identical uploads sent concurrently stay distinguishable in logs.
    request_id = uuid.uuid4().hex[:8]
    logger.info("[ASR][%s] audio_base64_chars=%d language=%s", request_id, len(req.audio_base64), req.language)

    payload = req.audio_base64.strip()
    # Accept browser data URLs ("data:audio/wav;base64,<payload>"). ':' is not a
    # base64 character, so a plain payload can never start with "data:".
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    try:
        audio_data = base64.b64decode(payload)
    except ValueError as e:  # binascii.Error subclasses ValueError
        logger.warning("[ASR][%s] invalid base64: %s", request_id, e)
        raise HTTPException(status_code=400, detail=f"audio_base64 解码失败: {e}") from e

    # "zh-CN" / "zh_CN" -> "zh". Three-letter codes such as "yue" are kept
    # intact rather than truncated to two characters.
    whisper_language = req.language.replace("_", "-").split("-", 1)[0].lower()
    try:
        text = await _speech_to_text(audio_data, whisper_language)
        logger.info("[ASR][%s] success text_chars=%d", request_id, len(text))
        return ASRResponse(text=text, language=req.language)
    except Exception as e:
        logger.exception("[ASR][%s] failed: %s", request_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
def register_tts_asr_routes(app):
    """Mount this module's TTS/ASR routes on *app* under ``/v1/tts-asr``.

    *app* must provide ``include_router`` (as a FastAPI application does).
    """
    mount_prefix = "/v1/tts-asr"
    app.include_router(router, prefix=mount_prefix)