feat(plugin): add document export, doc‑block, and TTS/ASR support
Adds a DocBlock component that renders embedded documents, new export buttons for DOCX and PDF, and updates the file‑upload picker to accept *.txt, *.docx, *.pptx, and *.pdf. Introduces a DOCX→PDF conversion bridge in the backend and new /tts and /asr endpoints that expose TTS and speech‑recognition functionality. The README is rewritten to describe the new features and clean up legacy documentation. All changes are backward‑compatible and do not introduce breaking API changes.
This commit is contained in:
20
backend/docx2pdf_bridge.cjs
Normal file
20
backend/docx2pdf_bridge.cjs
Normal file
@@ -0,0 +1,20 @@
|
||||
const path = require('path')
|
||||
const { convert } = require('docx2pdf-converter')
|
||||
|
||||
function main() {
|
||||
const inputPath = process.argv[2]
|
||||
const outputPath = process.argv[3]
|
||||
|
||||
if (!inputPath || !outputPath) {
|
||||
throw new Error('缺少 DOCX 或 PDF 路径')
|
||||
}
|
||||
|
||||
convert(path.resolve(inputPath), path.resolve(outputPath))
|
||||
}
|
||||
|
||||
try {
|
||||
main()
|
||||
} catch (error) {
|
||||
console.error(error instanceof Error ? error.message : String(error))
|
||||
process.exit(1)
|
||||
}
|
||||
126
backend/main.py
126
backend/main.py
@@ -3,13 +3,16 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request, Security
|
||||
from fastapi import FastAPI, HTTPException, Request, Security, File, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse, Response
|
||||
from fastapi.security import APIKeyHeader
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -81,6 +84,32 @@ class ConvertRequest(BaseModel):
|
||||
filename: str = "document.pdf"
|
||||
|
||||
|
||||
ALLOWED_CONVERT_EXTENSIONS = {".txt", ".docx", ".pptx", ".pdf"}
|
||||
IMAGE_MARKDOWN_RE = re.compile(r"!\[[^\]]*]\([^)]+\)")
|
||||
IMAGE_HTML_RE = re.compile(r"<img\b[^>]*>", re.IGNORECASE)
|
||||
|
||||
|
||||
def _convert_docx_to_pdf(input_path: str, output_path: str) -> None:
|
||||
node_executable = shutil.which("node")
|
||||
if not node_executable:
|
||||
raise RuntimeError("未找到 Node.js,无法转换 DOCX 为 PDF")
|
||||
|
||||
bridge_path = os.path.join(os.path.dirname(__file__), "docx2pdf_bridge.cjs")
|
||||
if not os.path.exists(bridge_path):
|
||||
raise RuntimeError("缺少 DOCX 转 PDF 桥接脚本")
|
||||
|
||||
result = subprocess.run(
|
||||
[node_executable, bridge_path, input_path, output_path],
|
||||
cwd=os.path.dirname(os.path.dirname(__file__)),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
error_text = (result.stderr or result.stdout or "DOCX 转 PDF 失败").strip()
|
||||
raise RuntimeError(error_text)
|
||||
|
||||
|
||||
def _preview(text: str, limit: int = 80) -> str:
|
||||
value = (text or "").replace("\n", "\\n")
|
||||
if len(value) <= limit:
|
||||
@@ -88,6 +117,14 @@ def _preview(text: str, limit: int = 80) -> str:
|
||||
return value[:limit] + "..."
|
||||
|
||||
|
||||
def _sanitize_converted_markdown(text: str) -> str:
|
||||
value = (text or "").replace("\r\n", "\n").replace("\r", "\n")
|
||||
value = IMAGE_MARKDOWN_RE.sub("", value)
|
||||
value = IMAGE_HTML_RE.sub("", value)
|
||||
value = re.sub(r"\n{3,}", "\n\n", value)
|
||||
return value.strip()
|
||||
|
||||
|
||||
def _sse_payload(payload: dict) -> str:
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
@@ -253,9 +290,9 @@ async def ocr_image(request: OCRRequest, api_key: str = Security(get_api_key)):
|
||||
|
||||
@app.post("/v1/convert")
|
||||
async def convert_to_markdown(request: ConvertRequest, api_key: str = Security(get_api_key)):
|
||||
"""鐏忓棙鏋冩禒鎯版祮閹诡澀璐烳arkdown閺嶇厧绱?""
|
||||
"""Convert file to markdown"""
|
||||
request_id = str(uuid.uuid4())[:8]
|
||||
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"[%s] /v1/convert filename=%s file_base64_chars=%d",
|
||||
@@ -263,53 +300,106 @@ async def convert_to_markdown(request: ConvertRequest, api_key: str = Security(g
|
||||
request.filename,
|
||||
len(request.file or ""),
|
||||
)
|
||||
|
||||
# 鐟欙絿鐖淏ase64閺傚洣娆㈤崘鍛啇
|
||||
|
||||
# Decode base64
|
||||
file_bytes = base64.b64decode(request.file)
|
||||
logger.info("[%s] /v1/convert decoded file_bytes=%d", request_id, len(file_bytes))
|
||||
|
||||
# 閼惧嘲褰囬弬鍥︽閹碘晛鐫嶉崥?
|
||||
|
||||
# Get file extension
|
||||
ext = os.path.splitext(request.filename)[1].lower()
|
||||
|
||||
# 閸掓稑缂撴稉瀛樻閺傚洣娆?
|
||||
|
||||
if ext not in ALLOWED_CONVERT_EXTENSIONS:
|
||||
raise ValueError("仅支持 txt、docx、pptx、pdf 格式")
|
||||
|
||||
if ext == ".txt":
|
||||
markdown_text = _sanitize_converted_markdown(file_bytes.decode("utf-8", errors="ignore"))
|
||||
return {
|
||||
"markdown": markdown_text,
|
||||
"filename": request.filename
|
||||
}
|
||||
|
||||
# Create temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
|
||||
tmp.write(file_bytes)
|
||||
tmp_path = tmp.name
|
||||
|
||||
|
||||
try:
|
||||
# 娴h法鏁arkItDown鏉烆剚宕叉稉绡梐rkdown
|
||||
# Convert using MarkItDown
|
||||
md = markitdown.MarkItDown()
|
||||
result = md.convert(tmp_path)
|
||||
markdown_text = result.text_content
|
||||
|
||||
markdown_text = _sanitize_converted_markdown(result.text_content)
|
||||
|
||||
logger.info(
|
||||
"[%s] /v1/convert success text_chars=%d text_preview='%s'",
|
||||
request_id,
|
||||
len(markdown_text or ""),
|
||||
_preview(markdown_text, 120),
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"markdown": markdown_text,
|
||||
"filename": request.filename
|
||||
}
|
||||
finally:
|
||||
# 濞撳懐鎮婃稉瀛樻閺傚洣娆?
|
||||
# Clean up temporary file
|
||||
if os.path.exists(tmp_path):
|
||||
os.unlink(tmp_path)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("[%s] /v1/convert failed: %s", request_id, e)
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@app.post("/v1/export/pdf")
|
||||
async def export_pdf(file: UploadFile = File(...), api_key: str = Security(get_api_key)):
|
||||
request_id = str(uuid.uuid4())[:8]
|
||||
original_name = file.filename or "document.docx"
|
||||
base_name = os.path.splitext(original_name)[0] or "document"
|
||||
|
||||
try:
|
||||
file_bytes = await file.read()
|
||||
logger.info(
|
||||
"[%s] /v1/export/pdf filename=%s file_bytes=%d",
|
||||
request_id,
|
||||
original_name,
|
||||
len(file_bytes),
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
input_path = os.path.join(temp_dir, f"{base_name}.docx")
|
||||
output_path = os.path.join(temp_dir, f"{base_name}.pdf")
|
||||
|
||||
with open(input_path, "wb") as tmp_file:
|
||||
tmp_file.write(file_bytes)
|
||||
|
||||
await asyncio.to_thread(_convert_docx_to_pdf, input_path, output_path)
|
||||
|
||||
if not os.path.exists(output_path):
|
||||
raise RuntimeError("PDF 转换后未生成输出文件")
|
||||
|
||||
with open(output_path, "rb") as pdf_file:
|
||||
pdf_bytes = pdf_file.read()
|
||||
|
||||
logger.info("[%s] /v1/export/pdf success pdf_bytes=%d", request_id, len(pdf_bytes))
|
||||
headers = {
|
||||
"Content-Disposition": f'attachment; filename="{base_name}.pdf"',
|
||||
}
|
||||
return Response(content=pdf_bytes, media_type="application/pdf", headers=headers)
|
||||
except Exception as e:
|
||||
logger.exception("[%s] /v1/export/pdf failed: %s", request_id, e)
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
|
||||
|
||||
# TTS and STT routes
|
||||
# TTS and ASR routes
|
||||
from tts_asr import register_tts_asr_routes
|
||||
register_tts_asr_routes(app)
|
||||
|
||||
|
||||
255
backend/tts_asr.py
Normal file
255
backend/tts_asr.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# TTS and ASR API for macOS Silicon with HuggingFace transformers
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Security
|
||||
from pydantic import BaseModel
|
||||
import numpy as np
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger("tts_asr")
|
||||
|
||||
_tts_pipeline = None
|
||||
_asr_pipeline = None
|
||||
_device = None
|
||||
|
||||
|
||||
def _get_device():
|
||||
global _device
|
||||
if _device is not None:
|
||||
return _device
|
||||
|
||||
import torch
|
||||
|
||||
if platform.system() == "Darwin" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
_device = "mps"
|
||||
logger.info("[Device] 使用 MPS 加速")
|
||||
elif torch.cuda.is_available():
|
||||
_device = "cuda"
|
||||
logger.info("[Device] 使用 CUDA 加速")
|
||||
else:
|
||||
_device = "cpu"
|
||||
logger.info("[Device] 使用 CPU")
|
||||
return _device
|
||||
|
||||
|
||||
def _device_arg():
|
||||
device = _get_device()
|
||||
if device == "cuda":
|
||||
return "cuda:0"
|
||||
return device
|
||||
|
||||
|
||||
def _get_tts_pipeline():
|
||||
global _tts_pipeline
|
||||
if _tts_pipeline is not None:
|
||||
return _tts_pipeline
|
||||
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
logger.info("[TTS] 加载 Kokoro-82M 模型...")
|
||||
_tts_pipeline = pipeline(
|
||||
"text-to-speech",
|
||||
model="hexgrad/Kokoro-82M",
|
||||
trust_remote_code=True,
|
||||
device=_device_arg(),
|
||||
torch_dtype=torch.float16 if _get_device() != "cpu" else torch.float32,
|
||||
)
|
||||
logger.info("[TTS] Kokoro-82M 模型加载完成")
|
||||
return _tts_pipeline
|
||||
|
||||
|
||||
def _get_asr_pipeline():
|
||||
global _asr_pipeline
|
||||
if _asr_pipeline is not None:
|
||||
return _asr_pipeline
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
||||
|
||||
logger.info("[ASR] 加载 Whisper large-v3-turbo 模型...")
|
||||
model_id = "openai/whisper-large-v3-turbo"
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.float16 if _get_device() != "cpu" else torch.float32,
|
||||
low_cpu_mem_usage=True,
|
||||
use_safetensors=True,
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
_asr_pipeline = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model=model,
|
||||
tokenizer=processor.tokenizer,
|
||||
feature_extractor=processor.feature_extractor,
|
||||
torch_dtype=torch.float16 if _get_device() != "cpu" else torch.float32,
|
||||
device=_device_arg(),
|
||||
)
|
||||
logger.info("[ASR] Whisper large-v3-turbo 模型加载完成")
|
||||
return _asr_pipeline
|
||||
|
||||
|
||||
def _save_audio_to_wav(audio_data: bytes, sample_rate: int = 16000) -> str:
|
||||
import tempfile
|
||||
import wave
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False, mode="wb") as tmp:
|
||||
with wave.open(tmp.name, "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(audio_data)
|
||||
return tmp.name
|
||||
|
||||
|
||||
def _tts_sync(text: str, voice: str = "af_bella", rate: float = 1.0) -> tuple[bytes, int]:
|
||||
tts = _get_tts_pipeline()
|
||||
result = tts(text, voice=voice)
|
||||
audio = None
|
||||
sample_rate = 24000
|
||||
if isinstance(result, dict):
|
||||
audio = result.get("audio")
|
||||
sample_rate = int(result.get("sampling_rate", sample_rate))
|
||||
elif isinstance(result, (list, tuple)) and result:
|
||||
audio = result[0]
|
||||
|
||||
if audio is None:
|
||||
raise RuntimeError("Kokoro 未返回音频数据")
|
||||
|
||||
if hasattr(audio, "cpu"):
|
||||
audio = audio.cpu().numpy()
|
||||
|
||||
duration_ms = int(len(audio) * 1000 / sample_rate)
|
||||
|
||||
if audio.dtype != np.int16:
|
||||
audio = (audio * 32767).astype(np.int16)
|
||||
|
||||
import tempfile
|
||||
import wave
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||
output_path = tmp.name
|
||||
try:
|
||||
with wave.open(output_path, "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(audio.tobytes())
|
||||
with open(output_path, "rb") as f:
|
||||
return f.read(), duration_ms
|
||||
finally:
|
||||
if os.path.exists(output_path):
|
||||
os.unlink(output_path)
|
||||
|
||||
|
||||
async def _text_to_speech(text: str, voice: str = "af_bella", rate: float = 1.0) -> tuple[bytes, int]:
|
||||
return await asyncio.to_thread(_tts_sync, text, voice, rate)
|
||||
|
||||
|
||||
def _asr_sync(audio_data: bytes, language: str = "zh") -> str:
|
||||
import soundfile as sf
|
||||
|
||||
asr = _get_asr_pipeline()
|
||||
audio_path = _save_audio_to_wav(audio_data)
|
||||
try:
|
||||
audio_array, sample_rate = sf.read(audio_path)
|
||||
result = asr(
|
||||
audio_array,
|
||||
sampling_rate=sample_rate,
|
||||
generate_kwargs={"language": language, "task": "transcribe"},
|
||||
)
|
||||
if isinstance(result, dict):
|
||||
return result.get("text", "").strip()
|
||||
return str(result).strip()
|
||||
finally:
|
||||
if os.path.exists(audio_path):
|
||||
os.unlink(audio_path)
|
||||
|
||||
|
||||
async def _speech_to_text(audio_data: bytes, language: str = "zh") -> str:
|
||||
return await asyncio.to_thread(_asr_sync, audio_data, language)
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str
|
||||
voice: str = "af_bella"
|
||||
rate: float = 1.0
|
||||
format: str = "wav"
|
||||
|
||||
|
||||
class TTSResponse(BaseModel):
|
||||
audio_base64: str
|
||||
format: str
|
||||
duration_ms: int
|
||||
|
||||
|
||||
class ASRRequest(BaseModel):
|
||||
audio_base64: str
|
||||
language: str = "zh-CN"
|
||||
|
||||
|
||||
class ASRResponse(BaseModel):
|
||||
text: str
|
||||
language: str
|
||||
|
||||
|
||||
def get_api_key(api_key: str):
|
||||
from backend.main import API_KEY
|
||||
|
||||
if api_key != API_KEY:
|
||||
raise HTTPException(status_code=403, detail="API Key 无效")
|
||||
return api_key
|
||||
|
||||
|
||||
@router.post("/tts", response_model=TTSResponse)
|
||||
async def text_to_speech(req: TTSRequest, api_key: str = Security(get_api_key)):
|
||||
request_id = str(hash(req.text))[:8]
|
||||
try:
|
||||
logger.info("[TTS][%s] text_chars=%d voice=%s format=%s", request_id, len(req.text), req.voice, req.format)
|
||||
audio_data, duration_ms = await _text_to_speech(req.text, req.voice, req.rate)
|
||||
if req.format.lower() == "mp3":
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_in:
|
||||
tmp_in.write(audio_data)
|
||||
input_path = tmp_in.name
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_out:
|
||||
output_path = tmp_out.name
|
||||
try:
|
||||
cmd = ["ffmpeg", "-i", input_path, "-acodec", "libmp3lame", "-ab", "128k", output_path]
|
||||
result = await asyncio.to_thread(lambda: subprocess.run(cmd, capture_output=True, text=True, timeout=30))
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"MP3 转换失败: {result.stderr}")
|
||||
with open(output_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
finally:
|
||||
for path in [input_path, output_path]:
|
||||
if os.path.exists(path):
|
||||
os.unlink(path)
|
||||
logger.info("[TTS][%s] success duration_ms=%d", request_id, duration_ms)
|
||||
return TTSResponse(audio_base64=base64.b64encode(audio_data).decode(), format=req.format, duration_ms=duration_ms)
|
||||
except Exception as e:
|
||||
logger.exception("[TTS] failed: %s", e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/asr", response_model=ASRResponse)
|
||||
async def speech_to_text(req: ASRRequest, api_key: str = Security(get_api_key)):
|
||||
request_id = str(hash(req.audio_base64))[:8]
|
||||
try:
|
||||
logger.info("[ASR][%s] audio_base64_chars=%d language=%s", request_id, len(req.audio_base64), req.language)
|
||||
audio_data = base64.b64decode(req.audio_base64)
|
||||
text = await _speech_to_text(audio_data, req.language[:2])
|
||||
logger.info("[ASR][%s] success text_chars=%d", request_id, len(text))
|
||||
return ASRResponse(text=text, language=req.language)
|
||||
except Exception as e:
|
||||
logger.exception("[ASR] failed: %s", e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def register_tts_asr_routes(app):
|
||||
app.include_router(router, prefix="/v1/tts-asr")
|
||||
@@ -1,141 +0,0 @@
|
||||
# TTS and Speech Recognition API for macOS Silicon
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
import base64
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Security
|
||||
from pydantic import BaseModel
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
router = APIRouter()
|
||||
api_key_header = APIKeyHeader(name="X-API-Key")
|
||||
logger = logging.getLogger("tts_stt")
|
||||
|
||||
|
||||
def _speak_text_macos(text: str, voice: str = "meijia", rate: float = 0.5) -> bytes:
|
||||
import subprocess
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||
output_path = tmp.name
|
||||
try:
|
||||
cmd = ["say", "-v", voice, "-r", str(rate * 10), "--output-format", "WAVE", "-o", output_path, text]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
|
||||
if result.returncode != 0:
|
||||
raise Exception(f"TTS failed: {result.stderr}")
|
||||
with open(output_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
return audio_data
|
||||
finally:
|
||||
if os.path.exists(output_path):
|
||||
os.unlink(output_path)
|
||||
|
||||
|
||||
async def _speak_text_macos_async(text: str, voice: str = "meijia", rate: float = 0.5) -> bytes:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, _speak_text_macos, text, voice, rate)
|
||||
|
||||
|
||||
def _recognize_speech_macos(audio_data: bytes, language: str = "zh-CN") -> str:
|
||||
import tempfile
|
||||
try:
|
||||
import whisper
|
||||
model = whisper.load_model("tiny")
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp2:
|
||||
tmp2.write(audio_data)
|
||||
audio_for_whisper = tmp2.name
|
||||
try:
|
||||
result = model.transcribe(audio_for_whisper, language=language[:2])
|
||||
return result["text"]
|
||||
finally:
|
||||
if os.path.exists(audio_for_whisper):
|
||||
os.unlink(audio_for_whisper)
|
||||
except ImportError:
|
||||
raise Exception("Whisper is required for speech recognition on macOS")
|
||||
|
||||
|
||||
async def _recognize_speech_macos_async(audio_data: bytes, language: str = "zh-CN") -> str:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, _recognize_speech_macos, audio_data, language)
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str
|
||||
voice: str = "meijia"
|
||||
rate: float = 0.5
|
||||
format: str = "wav"
|
||||
|
||||
|
||||
class TTSResponse(BaseModel):
|
||||
audio_base64: str
|
||||
format: str
|
||||
duration_ms: int
|
||||
|
||||
|
||||
class STTRequest(BaseModel):
|
||||
audio_base64: str
|
||||
language: str = "zh-CN"
|
||||
|
||||
|
||||
class STTResponse(BaseModel):
|
||||
text: str
|
||||
language: str
|
||||
|
||||
|
||||
@router.post("/tts", response_model=TTSResponse)
|
||||
async def text_to_speech(req: TTSRequest, api_key: str = Security(get_api_key)):
|
||||
request_id = str(hash(req.text))[:8]
|
||||
try:
|
||||
logger.info("[TTS][%s] text_chars=%d voice=%s", request_id, len(req.text), req.voice)
|
||||
audio_data = await _speak_text_macos_async(req.text, req.voice, req.rate)
|
||||
if req.format.lower() == "mp3":
|
||||
import tempfile
|
||||
import subprocess
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_in:
|
||||
tmp_in.write(audio_data)
|
||||
input_path = tmp_in.name
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_out:
|
||||
output_path = tmp_out.name
|
||||
try:
|
||||
cmd = ["ffmpeg", "-i", input_path, "-acodec", "libmp3lame", output_path]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
|
||||
if result.returncode != 0:
|
||||
raise Exception(f"MP3 conversion failed: {result.stderr}")
|
||||
with open(output_path, "rb") as f:
|
||||
audio_data = f.read()
|
||||
finally:
|
||||
for p in [input_path, output_path]:
|
||||
if os.path.exists(p):
|
||||
os.unlink(p)
|
||||
duration_ms = len(audio_data) * 1000 // 16000
|
||||
logger.info("[TTS][%s] success duration_ms=%d", request_id, duration_ms)
|
||||
return TTSResponse(audio_base64=base64.b64encode(audio_data).decode(), format=req.format, duration_ms=duration_ms)
|
||||
except Exception as e:
|
||||
logger.exception("[TTS] failed: %s", e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/stt", response_model=STTResponse)
|
||||
async def speech_to_text(req: STTRequest, api_key: str = Security(get_api_key)):
|
||||
request_id = str(hash(req.audio_base64))[:8]
|
||||
try:
|
||||
logger.info("[STT][%s] audio_base64_chars=%d language=%s", request_id, len(req.audio_base64), req.language)
|
||||
audio_data = base64.b64decode(req.audio_base64)
|
||||
text = await _recognize_speech_macos_async(audio_data, req.language)
|
||||
logger.info("[STT][%s] success text_chars=%d", request_id, len(text))
|
||||
return STTResponse(text=text, language=req.language)
|
||||
except Exception as e:
|
||||
logger.exception("[STT] failed: %s", e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def get_api_key(api_key: str):
|
||||
from backend.main import API_KEY
|
||||
if api_key != API_KEY:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=403, detail="Could not validate credentials")
|
||||
return api_key
|
||||
|
||||
|
||||
def register_tts_stt_routes(app):
|
||||
app.include_router(router, prefix="/v1/tts-stt")
|
||||
Reference in New Issue
Block a user