# TTS and ASR API for macOS Silicon with HuggingFace transformers
"""FastAPI routes for text-to-speech (Kokoro-82M) and speech recognition (Whisper).

Models are loaded lazily on first use and cached in module globals. Model
inference runs in worker threads via ``asyncio.to_thread`` so the event loop
is not blocked.
"""
import asyncio
import base64
import io
import logging
import os
import platform
import secrets
import subprocess
import tempfile
import threading
import uuid
import wave
from typing import Any, Optional

import numpy as np
from fastapi import APIRouter, HTTPException, Security
from pydantic import BaseModel

router = APIRouter()
logger = logging.getLogger("tts_asr")

_tts_pipeline = None
_asr_pipeline = None
_device = None

# The pipeline getters are called from worker threads (asyncio.to_thread), so
# concurrent first requests would otherwise each load a full copy of a model.
_tts_load_lock = threading.Lock()
_asr_load_lock = threading.Lock()

# Output formats the /tts endpoint can actually produce.
_SUPPORTED_TTS_FORMATS = frozenset({"wav", "mp3"})

# Sample rate assumed for the raw PCM uploaded to /asr.
_ASR_SAMPLE_RATE = 16000


def _get_device():
    """Return the torch device name ("mps", "cuda" or "cpu"), detecting it once."""
    global _device
    if _device is not None:
        return _device
    import torch

    if (
        platform.system() == "Darwin"
        and hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
    ):
        _device = "mps"
        logger.info("[Device] 使用 MPS 加速")
    elif torch.cuda.is_available():
        _device = "cuda"
        logger.info("[Device] 使用 CUDA 加速")
    else:
        _device = "cpu"
        logger.info("[Device] 使用 CPU")
    return _device


def _device_arg():
    """Return the ``device`` argument for ``transformers.pipeline``."""
    device = _get_device()
    if device == "cuda":
        return "cuda:0"
    return device


def _torch_dtype():
    """Return torch.float16 on an accelerator and torch.float32 on CPU."""
    import torch

    return torch.float16 if _get_device() != "cpu" else torch.float32


def _get_tts_pipeline():
    """Return the cached Kokoro-82M text-to-speech pipeline, loading it on first use."""
    global _tts_pipeline
    if _tts_pipeline is not None:
        return _tts_pipeline
    with _tts_load_lock:
        # Re-check under the lock: another thread may have finished loading.
        if _tts_pipeline is None:
            from transformers import pipeline

            logger.info("[TTS] 加载 Kokoro-82M 模型...")
            # NOTE(review): confirm hexgrad/Kokoro-82M really loads through the
            # transformers "text-to-speech" pipeline with trust_remote_code.
            _tts_pipeline = pipeline(
                "text-to-speech",
                model="hexgrad/Kokoro-82M",
                trust_remote_code=True,
                device=_device_arg(),
                torch_dtype=_torch_dtype(),
            )
            logger.info("[TTS] Kokoro-82M 模型加载完成")
    return _tts_pipeline


def _get_asr_pipeline():
    """Return the cached Whisper large-v3-turbo ASR pipeline, loading it on first use."""
    global _asr_pipeline
    if _asr_pipeline is not None:
        return _asr_pipeline
    with _asr_load_lock:
        # Re-check under the lock: another thread may have finished loading.
        if _asr_pipeline is None:
            from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

            logger.info("[ASR] 加载 Whisper large-v3-turbo 模型...")
            model_id = "openai/whisper-large-v3-turbo"
            dtype = _torch_dtype()
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_id,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            processor = AutoProcessor.from_pretrained(model_id)
            _asr_pipeline = pipeline(
                "automatic-speech-recognition",
                model=model,
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                torch_dtype=dtype,
                device=_device_arg(),
            )
            logger.info("[ASR] Whisper large-v3-turbo 模型加载完成")
    return _asr_pipeline


def _save_audio_to_wav(audio_data: bytes, sample_rate: int = 16000) -> str:
    """Write raw 16-bit mono PCM bytes into a new temporary WAV file; return its path.

    The caller owns the file and must delete it. The ASR path in this module no
    longer uses this helper; it is kept in case other code imports it.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False, mode="wb") as tmp:
        # Write through the already-open handle instead of reopening by name.
        with wave.open(tmp, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(sample_rate)
            wf.writeframes(audio_data)
        return tmp.name


def _extract_tts_audio(result: Any, default_sample_rate: int) -> tuple[Any, int]:
    """Pull ``(audio, sample_rate)`` out of a text-to-speech pipeline result.

    Accepts a ``{"audio": ..., "sampling_rate": ...}`` dict, a non-empty
    list/tuple whose first item is such a dict or the audio itself, or a bare
    array/tensor.

    Raises:
        RuntimeError: if no audio is present or the sample rate is not positive.
    """
    if isinstance(result, (list, tuple)) and result and isinstance(result[0], dict):
        # Batched pipeline output: use the first item.
        result = result[0]

    audio = None
    sample_rate = default_sample_rate
    if isinstance(result, dict):
        audio = result.get("audio")
        sample_rate = int(result.get("sampling_rate") or default_sample_rate)
    elif isinstance(result, (list, tuple)) and result:
        audio = result[0]
    elif hasattr(result, "shape"):
        audio = result

    if audio is None:
        raise RuntimeError("Kokoro 未返回音频数据")
    if sample_rate <= 0:
        raise RuntimeError(f"invalid sampling_rate: {sample_rate}")
    return audio, sample_rate


def _to_pcm16(audio: Any) -> np.ndarray:
    """Convert model audio (torch tensor, ndarray or sequence) to 1-D int16 PCM.

    Float samples are treated as lying in [-1, 1]. They are upcast to float32
    and clipped before scaling. Without this, out-of-range values wrap around,
    and float16 rounds 32767 up to 32768, which overflows int16.
    """
    if hasattr(audio, "cpu"):  # torch tensor, possibly on MPS/CUDA
        audio = audio.detach().cpu().numpy()
    # Pipelines often return shape (1, n); flatten to one sample axis.
    samples = np.asarray(audio).reshape(-1)
    if samples.dtype == np.int16:
        return samples
    if np.issubdtype(samples.dtype, np.floating):
        clipped = np.clip(samples.astype(np.float32), -1.0, 1.0)
        return (clipped * 32767).astype(np.int16)
    # Any other numeric dtype: saturate into the int16 range.
    return np.clip(samples.astype(np.int64), -32768, 32767).astype(np.int16)


def _encode_wav(pcm: np.ndarray, sample_rate: int) -> bytes:
    """Return a mono 16-bit WAV file containing *pcm*, built in memory."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(pcm.tobytes())
    return buffer.getvalue()


def _tts_sync(text: str, voice: str = "af_bella", rate: float = 1.0) -> tuple[bytes, int]:
    """Synthesize *text* and return ``(wav_bytes, duration_ms)``.

    ``rate`` is accepted for API compatibility but is not applied to the model.
    This call blocks; async code should use ``_text_to_speech``.
    """
    tts = _get_tts_pipeline()
    # NOTE(review): confirm the Kokoro pipeline accepts a ``voice`` keyword.
    result = tts(text, voice=voice)
    audio, sample_rate = _extract_tts_audio(result, default_sample_rate=24000)
    pcm = _to_pcm16(audio)
    duration_ms = len(pcm) * 1000 // sample_rate
    return _encode_wav(pcm, sample_rate), duration_ms


async def _text_to_speech(text: str, voice: str = "af_bella", rate: float = 1.0) -> tuple[bytes, int]:
    """Run ``_tts_sync`` in a worker thread."""
    return await asyncio.to_thread(_tts_sync, text, voice, rate)


def _wav_to_mp3_sync(wav_bytes: bytes, bitrate: str = "128k", timeout: float = 30.0) -> bytes:
    """Transcode WAV bytes to MP3 with the ``ffmpeg`` CLI (blocking).

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
        subprocess.TimeoutExpired: if ffmpeg runs longer than *timeout* seconds.
        FileNotFoundError: if ``ffmpeg`` is not on PATH.
    """
    with tempfile.TemporaryDirectory() as workdir:
        input_path = os.path.join(workdir, "input.wav")
        output_path = os.path.join(workdir, "output.mp3")
        with open(input_path, "wb") as f:
            f.write(wav_bytes)
        # -y: overwrite without prompting; -nostdin + DEVNULL: never wait for
        # interactive input (a prompt would otherwise stall until the timeout).
        cmd = [
            "ffmpeg", "-nostdin", "-y", "-loglevel", "error",
            "-i", input_path,
            "-acodec", "libmp3lame", "-ab", bitrate,
            output_path,
        ]
        result = subprocess.run(
            cmd,
            stdin=subprocess.DEVNULL,
            capture_output=True,
            text=True,
            timeout=timeout,
            check=False,
        )
        if result.returncode != 0:
            raise RuntimeError(f"MP3 转换失败: {result.stderr}")
        with open(output_path, "rb") as f:
            return f.read()


def _whisper_language(tag: str) -> Optional[str]:
    """Map a language tag ("zh-CN", "en_US", "yue", "english") to a Whisper hint.

    Returns the lowercase primary subtag. A blank tag returns None, which lets
    Whisper auto-detect the language.
    """
    primary = tag.strip().replace("_", "-").split("-", 1)[0]
    return primary.lower() or None


def _pcm16_to_float32(audio_data: bytes) -> np.ndarray:
    """Decode raw little-endian 16-bit PCM into float32 samples in [-1, 1).

    A trailing odd byte (half a sample) is ignored.
    """
    samples = np.frombuffer(audio_data, dtype="<i2", count=len(audio_data) // 2)
    return samples.astype(np.float32) / 32768.0


def _asr_sync(audio_data: bytes, language: str = "zh") -> str:
    """Transcribe raw PCM audio and return the stripped text (blocking).

    *audio_data* is assumed to be 16 kHz mono 16-bit little-endian PCM with no
    container header. TODO: confirm this with API clients. *language* is
    normalized with ``_whisper_language``.
    """
    asr = _get_asr_pipeline()
    generate_kwargs = {"task": "transcribe"}
    whisper_lang = _whisper_language(language)
    if whisper_lang:
        generate_kwargs["language"] = whisper_lang
    # The ASR pipeline takes the sample rate inside the input dict; it is not a
    # call keyword argument.
    result = asr(
        {"raw": _pcm16_to_float32(audio_data), "sampling_rate": _ASR_SAMPLE_RATE},
        generate_kwargs=generate_kwargs,
    )
    if isinstance(result, dict):
        return (result.get("text") or "").strip()
    return str(result).strip()


async def _speech_to_text(audio_data: bytes, language: str = "zh") -> str:
    """Run ``_asr_sync`` in a worker thread."""
    return await asyncio.to_thread(_asr_sync, audio_data, language)


class TTSRequest(BaseModel):
    text: str
    voice: str = "af_bella"
    rate: float = 1.0  # accepted but currently not applied by _tts_sync
    format: str = "wav"  # "wav" or "mp3" (case-insensitive)


class TTSResponse(BaseModel):
    audio_base64: str
    format: str
    duration_ms: int


class ASRRequest(BaseModel):
    audio_base64: str  # base64 of raw 16 kHz mono 16-bit PCM (assumed)
    language: str = "zh-CN"


class ASRResponse(BaseModel):
    text: str
    language: str


def get_api_key(api_key: str):
    """Validate *api_key* against ``backend.main.API_KEY``.

    The comparison is constant-time. A non-string configured key rejects
    every request.

    Raises:
        HTTPException: 403 when the key does not match.
    """
    from backend.main import API_KEY

    if not isinstance(API_KEY, str) or not secrets.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="API Key 无效")
    return api_key


@router.post("/tts", response_model=TTSResponse)
async def text_to_speech(req: TTSRequest, api_key: str = Security(get_api_key)):
    """Synthesize speech; return base64 WAV/MP3 audio and its duration.

    Returns 400 for an unsupported format or blank text, and 500 on
    synthesis/encoding failure.
    """
    request_id = uuid.uuid4().hex[:8]
    logger.info("[TTS][%s] text_chars=%d voice=%s format=%s", request_id, len(req.text), req.voice, req.format)
    fmt = req.format.strip().lower()
    if fmt not in _SUPPORTED_TTS_FORMATS:
        raise HTTPException(status_code=400, detail=f"不支持的音频格式: {req.format}")
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")
    try:
        audio_data, duration_ms = await _text_to_speech(req.text, req.voice, req.rate)
        if fmt == "mp3":
            audio_data = await asyncio.to_thread(_wav_to_mp3_sync, audio_data)
    except Exception as e:
        logger.exception("[TTS][%s] failed: %s", request_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    logger.info("[TTS][%s] success duration_ms=%d", request_id, duration_ms)
    return TTSResponse(
        audio_base64=base64.b64encode(audio_data).decode(),
        format=fmt,
        duration_ms=duration_ms,
    )


@router.post("/asr", response_model=ASRResponse)
async def speech_to_text(req: ASRRequest, api_key: str = Security(get_api_key)):
    """Transcribe base64-encoded PCM audio.

    Returns 400 for undecodable or empty audio, and 500 on recognition failure.
    """
    request_id = uuid.uuid4().hex[:8]
    logger.info("[ASR][%s] audio_base64_chars=%d language=%s", request_id, len(req.audio_base64), req.language)
    try:
        audio_data = base64.b64decode(req.audio_base64)
    except ValueError as e:  # binascii.Error is a ValueError subclass
        raise HTTPException(status_code=400, detail=f"音频 base64 解码失败: {e}") from e
    if len(audio_data) < 2:  # fewer bytes than one 16-bit sample
        raise HTTPException(status_code=400, detail="音频数据为空")
    try:
        text = await _speech_to_text(audio_data, req.language)
    except Exception as e:
        logger.exception("[ASR][%s] failed: %s", request_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    logger.info("[ASR][%s] success text_chars=%d", request_id, len(text))
    return ASRResponse(text=text, language=req.language)


def register_tts_asr_routes(app):
    """Mount this module's routes on *app* under ``/v1/tts-asr``."""
    app.include_router(router, prefix="/v1/tts-asr")