Files
llm-in-text/backend/tts_stt.py

141 lines
5.2 KiB
Python

# Text-to-speech (TTS) and speech-to-text (STT) API for macOS on Apple silicon
import os
import asyncio
import logging
import base64
from typing import Optional
from fastapi import APIRouter, UploadFile, File, HTTPException, Security
from pydantic import BaseModel
from fastapi.security import APIKeyHeader
# Router that the /tts and /stt endpoints in this module are registered on.
router = APIRouter()
# Security scheme describing an API key carried in the "X-API-Key" request header.
api_key_header = APIKeyHeader(name="X-API-Key")
# Module logger; the name "tts_stt" is fixed rather than derived from __name__.
logger = logging.getLogger("tts_stt")
def _speak_text_macos(text: str, voice: str = "meijia", rate: float = 0.5) -> bytes:
    """Synthesize ``text`` with the macOS ``say`` command and return WAV bytes.

    Args:
        text: Text to speak. It is passed after a ``--`` separator so that
            text beginning with ``-`` cannot be parsed as a ``say`` option.
        voice: Voice name passed to ``say -v``.
        rate: Speed factor; ``rate * 10`` is passed to ``say -r``.

    Returns:
        The raw contents of the audio file written by ``say``.

    Raises:
        RuntimeError: If ``say`` exits with a non-zero status.
        subprocess.TimeoutExpired: If ``say`` runs longer than 30 seconds.
    """
    import subprocess
    import tempfile

    # delete=False: the file must outlive this handle so that `say` can
    # write to it by path; it is removed in the `finally` below.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name
    try:
        cmd = [
            "say",
            "-v", voice,
            # NOTE(review): say(1) documents -r as words per minute; confirm
            # that `rate * 10` is the intended mapping.
            "-r", str(rate * 10),
            # `say` has no --output-format option; the container and sample
            # format are selected with --file-format / --data-format, and a
            # WAVE container needs an explicit little-endian PCM data format.
            "--file-format=WAVE",
            "--data-format=LEI16@16000",
            "-o", output_path,
            # End of options: everything after this is text to speak.
            "--",
            text,
        ]
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            errors="replace",  # never fail on undecodable stderr bytes
            timeout=30,
        )
        if result.returncode != 0:
            raise RuntimeError(f"TTS failed: {result.stderr}")
        with open(output_path, "rb") as f:
            return f.read()
    finally:
        # Best-effort cleanup; the file may already be gone.
        try:
            os.unlink(output_path)
        except FileNotFoundError:
            pass
async def _speak_text_macos_async(text: str, voice: str = "meijia", rate: float = 0.5) -> bytes:
    """Run ``_speak_text_macos`` in the default executor and await its result.

    Takes the same arguments as ``_speak_text_macos`` and returns the same
    audio bytes; the blocking work is moved off the event loop thread.
    """
    # get_running_loop() is the supported way to obtain the loop from
    # inside a coroutine.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _speak_text_macos, text, voice, rate)
def _recognize_speech_macos(audio_data: bytes, language: str = "zh-CN") -> str:
import tempfile
try:
import whisper
model = whisper.load_model("tiny")
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp2:
tmp2.write(audio_data)
audio_for_whisper = tmp2.name
try:
result = model.transcribe(audio_for_whisper, language=language[:2])
return result["text"]
finally:
if os.path.exists(audio_for_whisper):
os.unlink(audio_for_whisper)
except ImportError:
raise Exception("Whisper is required for speech recognition on macOS")
async def _recognize_speech_macos_async(audio_data: bytes, language: str = "zh-CN") -> str:
    """Run ``_recognize_speech_macos`` in the default executor and await it.

    Takes the same arguments as ``_recognize_speech_macos`` and returns the
    same transcript; the blocking work is moved off the event loop thread.
    """
    # get_running_loop() is the supported way to obtain the loop from
    # inside a coroutine.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _recognize_speech_macos, audio_data, language)
class TTSRequest(BaseModel):
    """Request body accepted by the ``/tts`` endpoint."""

    # Text to synthesize.
    text: str
    # Voice name forwarded to the speech synthesizer.
    voice: str = "meijia"
    # Speaking-rate factor forwarded to the speech synthesizer.
    rate: float = 0.5
    # Requested audio format; "mp3" asks for an ffmpeg conversion of the WAV output.
    format: str = "wav"
class TTSResponse(BaseModel):
    """Response body returned by the ``/tts`` endpoint."""

    # Synthesized audio, base64-encoded.
    audio_base64: str
    # Audio format, echoed from the request.
    format: str
    # Audio length in milliseconds, as computed by the endpoint.
    duration_ms: int
class STTRequest(BaseModel):
    """Request body accepted by the ``/stt`` endpoint."""

    # Audio to transcribe, base64-encoded.
    audio_base64: str
    # Language tag forwarded to the speech recognizer.
    language: str = "zh-CN"
class STTResponse(BaseModel):
    """Response body returned by the ``/stt`` endpoint."""

    # Transcribed text.
    text: str
    # Language tag, echoed from the request.
    language: str
def get_api_key(api_key: str = Security(api_key_header)) -> str:
    """Validate the ``X-API-Key`` header against the configured API key.

    Defined before the endpoints because ``Security(get_api_key)`` in their
    signatures is evaluated when they are defined.

    Args:
        api_key: Value of the ``X-API-Key`` header, supplied by FastAPI
            through ``api_key_header``.

    Returns:
        The validated API key.

    Raises:
        HTTPException: 403 if the key does not match ``API_KEY``.
    """
    import secrets

    # Imported at call time; presumably to avoid a circular import with
    # backend.main -- confirm before moving it to module level.
    from backend.main import API_KEY

    key_matches = (
        isinstance(api_key, str)
        and isinstance(API_KEY, str)
        # Constant-time comparison; bytes are used so that non-ASCII keys
        # do not make compare_digest raise.
        and secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8"))
    )
    if not key_matches:
        raise HTTPException(status_code=403, detail="Could not validate credentials")
    return api_key


def _wav_duration_ms(audio_data: bytes) -> int:
    """Return the playback length of WAV bytes in milliseconds.

    The length is read from the WAV header. If the header cannot be parsed,
    the legacy size-based estimate (16000 bytes per second) is returned.
    """
    import io
    import struct
    import wave

    try:
        with wave.open(io.BytesIO(audio_data), "rb") as wav_file:
            frame_rate = wav_file.getframerate()
            if frame_rate > 0:
                return wav_file.getnframes() * 1000 // frame_rate
    except (wave.Error, EOFError, struct.error):
        # Deliberate fallback: an unparsable header must not fail the request.
        pass
    return len(audio_data) * 1000 // 16000


def _convert_wav_to_mp3(audio_data: bytes) -> bytes:
    """Convert WAV bytes to MP3 bytes with the ``ffmpeg`` command.

    Blocking; intended to be run in an executor.

    Raises:
        RuntimeError: If ``ffmpeg`` exits with a non-zero status.
        subprocess.TimeoutExpired: If ``ffmpeg`` runs longer than 30 seconds.
    """
    import subprocess
    import tempfile

    # A private directory removes both files on exit, whatever happens.
    with tempfile.TemporaryDirectory() as work_dir:
        input_path = os.path.join(work_dir, "input.wav")
        output_path = os.path.join(work_dir, "output.mp3")
        with open(input_path, "wb") as f:
            f.write(audio_data)
        cmd = [
            "ffmpeg",
            "-nostdin",  # never wait on an interactive prompt
            "-y",  # overwrite the output path without asking
            "-i", input_path,
            "-acodec", "libmp3lame",
            output_path,
        ]
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            errors="replace",  # never fail on undecodable stderr bytes
            timeout=30,
        )
        if result.returncode != 0:
            raise RuntimeError(f"MP3 conversion failed: {result.stderr}")
        with open(output_path, "rb") as f:
            return f.read()


@router.post("/tts", response_model=TTSResponse)
async def text_to_speech(req: TTSRequest, api_key: str = Security(get_api_key)):
    """Synthesize speech for ``req.text`` and return it base64-encoded.

    Args:
        req: Text, voice, rate and output format ("wav" or "mp3").
        api_key: Validated API key; present only to enforce authentication.

    Returns:
        A ``TTSResponse`` with the encoded audio, the requested format and
        the audio length in milliseconds.

    Raises:
        HTTPException: 400 for empty text or an unsupported format; 500 if
            synthesis or conversion fails.
    """
    import uuid

    # A random id keeps log lines of separate requests apart even when they
    # carry identical text.
    request_id = uuid.uuid4().hex[:8]

    # Validate before the try block so these 400s are not turned into 500s.
    audio_format = req.format.lower()
    if audio_format not in ("wav", "mp3"):
        raise HTTPException(status_code=400, detail=f"Unsupported format: {req.format}")
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty")

    try:
        logger.info("[TTS][%s] text_chars=%d voice=%s", request_id, len(req.text), req.voice)
        audio_data = await _speak_text_macos_async(req.text, req.voice, req.rate)
        # Measure the WAV output before any conversion; the byte length of
        # MP3 data says nothing about its duration.
        duration_ms = _wav_duration_ms(audio_data)
        if audio_format == "mp3":
            # ffmpeg is blocking; keep it off the event loop thread.
            loop = asyncio.get_running_loop()
            audio_data = await loop.run_in_executor(None, _convert_wav_to_mp3, audio_data)
        logger.info("[TTS][%s] success duration_ms=%d", request_id, duration_ms)
        return TTSResponse(
            audio_base64=base64.b64encode(audio_data).decode(),
            format=req.format,
            duration_ms=duration_ms,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("[TTS][%s] failed: %s", request_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/stt", response_model=STTResponse)
async def speech_to_text(req: STTRequest, api_key: str = Security(get_api_key)):
    """Transcribe base64-encoded audio and return the recognized text.

    Args:
        req: Base64-encoded audio and the language tag to transcribe in.
        api_key: Validated API key; present only to enforce authentication.

    Returns:
        An ``STTResponse`` with the transcript and the requested language.

    Raises:
        HTTPException: 400 if the audio is not valid base64 or is empty;
            500 if recognition fails.
    """
    import uuid

    request_id = uuid.uuid4().hex[:8]
    logger.info(
        "[STT][%s] audio_base64_chars=%d language=%s",
        request_id, len(req.audio_base64), req.language,
    )

    # Malformed input is a client error, so report it as 400 rather than 500.
    try:
        audio_data = base64.b64decode(req.audio_base64)
    except ValueError as exc:  # binascii.Error is a ValueError subclass
        raise HTTPException(status_code=400, detail="audio_base64 is not valid base64") from exc
    if not audio_data:
        raise HTTPException(status_code=400, detail="audio_base64 must not be empty")

    try:
        text = await _recognize_speech_macos_async(audio_data, req.language)
        logger.info("[STT][%s] success text_chars=%d", request_id, len(text))
        return STTResponse(text=text, language=req.language)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("[STT][%s] failed: %s", request_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
def register_tts_stt_routes(app):
    """Attach this module's router to ``app`` under the ``/v1/tts-stt`` prefix."""
    route_prefix = "/v1/tts-stt"
    app.include_router(router, prefix=route_prefix)