diff --git a/backend/requirements.txt b/backend/requirements.txt index 41147ee..94dfd75 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -8,6 +8,7 @@ torchaudio>=0.12.0 transformers>=4.25.0 whisper>=1.0.0 qwen-tts>=0.0.0 +modelscope>=1.20.0 # testing pytest>=7.0.0 diff --git a/backend/tts_asr.py b/backend/tts_asr.py index fdb20dc..f0ffe8d 100644 --- a/backend/tts_asr.py +++ b/backend/tts_asr.py @@ -25,6 +25,10 @@ router = APIRouter() # Global TTS model instance _tts_model: Optional["Qwen3TTSModel"] = None +# Model paths for loading +MODEL_ID_HF = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign" +MODEL_ID_MS = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign" + def _get_device_map() -> str: """设备检测逻辑:优先 CUDA,其次 MPS,最后 CPU""" @@ -38,6 +42,24 @@ def _get_device_map() -> str: return "cpu" +def _download_model_from_modelscope() -> Optional[str]: + """从 ModelScope 下载模型到本地临时目录""" + try: + from modelscope import snapshot_download + cache_dir = os.path.join(os.path.dirname(__file__), "models") + os.makedirs(cache_dir, exist_ok=True) + model_dir = snapshot_download( + MODEL_ID_MS, + cache_dir=cache_dir, + revision="master" + ) + logger.info("ModelScope 模型下载完成: %s", model_dir) + return model_dir + except Exception as e: + logger.warning("ModelScope 下载失败: %s", e) + return None + + async def _warmup_tts(): """预热 TTS 模型""" await asyncio.to_thread(_load_tts_model_with_retry) @@ -58,28 +80,42 @@ def _load_tts_model_with_retry(max_retries: int = 3) -> "Qwen3TTSModel": if Qwen3TTSModel is None: raise RuntimeError("qwen_tts 库未安装,无法加载 TTS 模型") - candidates = [ - "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign", - "ModelScope/Qwen3-TTS-12Hz-1.7B-VoiceDesign", - ] device_map = _get_device_map() last_err = None - for i, model_id in enumerate(candidates, start=1): + + # 策略1: 尝试从 ModelScope 下载后加载 + for attempt in range(max_retries): try: + logger.info("尝试从 ModelScope 下载模型...") + model_path = _download_model_from_modelscope() + if model_path and os.path.isdir(model_path): + _tts_model = Qwen3TTSModel.from_pretrained( + model_path, + device_map=device_map, + dtype=torch.float16, + ) + logger.info("ModelScope 模型加载成功: %s", model_path) + return _tts_model + except Exception as e: + logger.warning("ModelScope 加载失败 (尝试 %d/%d): %s", attempt + 1, max_retries, e) + last_err = e + + # 策略2: 尝试从 HuggingFace 镜像加载 + for attempt in range(max_retries): + try: + logger.info("尝试从 HuggingFace 镜像加载模型...") _tts_model = Qwen3TTSModel.from_pretrained( - model_id, + MODEL_ID_HF, device_map=device_map, dtype=torch.float16, - attn_implementation="flash_attention_2", ) - logger.info("Loaded TTS model from %s", model_id) + logger.info("HuggingFace 模型加载成功") return _tts_model except Exception as e: - logger.warning("Failed to load TTS model from %s: %s", model_id, e) + logger.warning("HuggingFace 加载失败 (尝试 %d/%d): %s", attempt + 1, max_retries, e) last_err = e - if i >= max_retries: - break - raise RuntimeError(f"Unable to load TTS model from sources: {candidates}") from last_err + + raise RuntimeError(f"无法加载 TTS 模型: {last_err}") from last_err class TTSRequest(BaseModel): @@ -160,10 +196,10 @@ async def tts_endpoint(req: TTSRequest): speaker = req.speaker or "Vivian" try: - wavs_sr = model.generate_custom_voice( + # VoiceDesign 模型使用 generate_voice_design 方法 + wavs_sr = model.generate_voice_design( text=text, language="Chinese", - speaker=speaker, instruct=instruct, ) except Exception as e: diff --git a/src/components/UniverPreview.vue b/src/components/UniverPreview.vue index d6049f6..e3fcdab 100644 --- a/src/components/UniverPreview.vue +++ b/src/components/UniverPreview.vue @@ -129,56 +129,33 @@ const loadDocumentIntoUniver = async (instance, blob, fileName, fmt) => { const lowerFmt = (fmt || '').toLowerCase(); console.log(`[Univer] 开始加载文档: ${fileName}, 格式: ${lowerFmt}`); - // 尝试使用服务端导入API(如果可用) - // 注意:这些API需要后端服务支持,纯前端模式下会失败 + // 重要提示:Univer 纯前端模式不支持直接加载 DOCX/XLSX/PPTX 文件 + // 这些格式需要后端服务进行转换 + // 当前实现创建空白文档作为预览占位符 + try { - if (lowerFmt === 'docx' && typeof univerAPI.importDOCXToSnapshotAsync === 'function') { - console.log('[Univer] 尝试使用 importDOCXToSnapshotAsync...'); - const snapshot = await univerAPI.importDOCXToSnapshotAsync(blob); - if (snapshot) { - // 使用快照创建文档 - const doc = await univerAPI.createUniverDoc(snapshot); - console.log('[Univer] DOCX 文档加载成功'); - return; - } - } else if (lowerFmt === 'xlsx' && typeof univerAPI.importXLSXToSnapshotAsync === 'function') { - console.log('[Univer] 尝试使用 importXLSXToSnapshotAsync...'); - const snapshot = await univerAPI.importXLSXToSnapshotAsync(blob); - if (snapshot) { - const workbook = await univerAPI.createWorkbook(snapshot); - console.log('[Univer] XLSX 文档加载成功'); - return; + // 尝试创建对应格式的空白文档 + if (lowerFmt === 'xlsx') { + await univerAPI.createWorkbook({}); + console.log('[Univer] 创建空白 Excel 工作簿'); + } else if (lowerFmt === 'pptx') { + // PPTX: 尝试使用 Slides API,如果不可用则降级到 Docs + if (typeof univerAPI.createUniverSlide === 'function') { + await univerAPI.createUniverSlide({}); + console.log('[Univer] 创建空白 PPT 演示文稿'); + } else { + await univerAPI.createUniverDoc({}); + console.log('[Univer] Slides API 不可用,创建空白文档'); } + } else { + // DOCX 和默认情况 + await univerAPI.createUniverDoc({}); + console.log('[Univer] 创建空白 Word 文档'); } } catch (e) { - console.warn('[Univer] 服务端导入API不可用或失败,使用纯前端模式:', e.message); + console.error('[Univer] 创建文档失败:', e); + throw new Error(`无法创建 ${lowerFmt.toUpperCase()} 文档: ${e.message}`); } - - // 纯前端模式:创建空白文档 - // 注意:这是 fallback 方案,无法加载实际的 DOCX/XLSX/PPTX 内容 - console.log('[Univer] 使用纯前端模式创建空白文档'); - - if (lowerFmt === 'xlsx') { - await univerAPI.createWorkbook({}); - console.log('[Univer] 创建空白 Excel 工作簿'); - } else if (lowerFmt === 'pptx') { - // PPTX 需要创建 Slides 文档 - // 注意:需要先检查是否有 createUniverSlide 方法 - if (typeof univerAPI.createUniverSlide === 'function') { - await univerAPI.createUniverSlide({}); - console.log('[Univer] 创建空白 PPT 演示文稿'); - } else { - // Fallback 到普通文档 - await univerAPI.createUniverDoc({}); - console.log('[Univer] Slides API 不可用,创建空白文档'); - } - } else { - await univerAPI.createUniverDoc({}); - console.log('[Univer] 创建空白 Word 文档'); - } - - // 提示用户当前是纯前端模式 - console.warn('[Univer] 当前为纯前端模式,无法加载实际的 DOCX/XLSX/PPTX 文件内容。如需完整功能,请配置后端服务。'); }; const destroyUniver = () => { diff --git a/src/stores/office.js b/src/stores/office.js index 1d23481..230721e 100644 --- a/src/stores/office.js +++ b/src/stores/office.js @@ -8,18 +8,22 @@ export const useOfficeStore = defineStore('office', () => { const currentFormat = ref(null) const currentFileSize = ref(0) const currentBytes = ref(null) - + // 快照模式 const isSnapshotMode = ref(true) // 默认启用快照模式 const currentSnapshot = ref(null) - + // 编辑状态 const isEditing = ref(false) const hasUnsavedChanges = ref(false) - + // 视图状态 const activeView = ref('milkdown') // 'milkdown' | 'univer' + // 文档加载状态(新增) + const documentLoadStatus = ref('idle') // 'idle' | 'loading' | 'success' | 'error' + const documentErrorMessage = ref('') + // 计算属性 const hasDocument = computed(() => { return currentFileName.value && currentFormat.value @@ -31,7 +35,8 @@ export const useOfficeStore = defineStore('office', () => { name: currentFileName.value, format: currentFormat.value, size: currentFileSize.value, - isSnapshot: isSnapshotMode.value + isSnapshot: isSnapshotMode.value, + loadStatus: documentLoadStatus.value } }) @@ -49,6 +54,8 @@ export const useOfficeStore = defineStore('office', () => { currentFileSize.value = file.size || 0 currentBytes.value = bytes hasUnsavedChanges.value = false + documentLoadStatus.value = 'idle' + documentErrorMessage.value = '' } /** @@ -61,6 +68,8 @@ export const useOfficeStore = defineStore('office', () => { currentBytes.value = null currentSnapshot.value = null hasUnsavedChanges.value = false + documentLoadStatus.value = 'idle' + documentErrorMessage.value = '' } /** @@ -92,6 +101,39 @@ export const useOfficeStore = defineStore('office', () => { isSnapshotMode.value = !isSnapshotMode.value } + // 新增:文档加载状态管理方法 + /** + * 开始加载文档 + */ + function startDocumentLoad() { + documentLoadStatus.value = 'loading' + documentErrorMessage.value = '' + } + + /** + * 文档加载成功 + */ + function completeDocumentLoad() { + documentLoadStatus.value = 'success' + documentErrorMessage.value = '' + } + + /** + * 文档加载失败 + */ + function failDocumentLoad(message) { + documentLoadStatus.value = 'error' + documentErrorMessage.value = message || '文档加载失败' + } + + /** + * 重置文档加载状态 + */ + function resetDocumentLoadStatus() { + documentLoadStatus.value = 'idle' + documentErrorMessage.value = '' + } + return { // 状态 currentFileName, @@ -103,6 +145,8 @@ export const useOfficeStore = defineStore('office', () => { isEditing, hasUnsavedChanges, activeView, + documentLoadStatus, + documentErrorMessage, // 计算属性 hasDocument, @@ -114,7 +158,11 @@ export const useOfficeStore = defineStore('office', () => { setSnapshot, markAsChanged, switchView, - toggleSnapshotMode + toggleSnapshotMode, + startDocumentLoad, + completeDocumentLoad, + failDocumentLoad, + resetDocumentLoadStatus } }) diff --git a/src/utils/api.js b/src/utils/api.js index 4b24f12..519d326 100644 --- a/src/utils/api.js +++ b/src/utils/api.js @@ -2,159 +2,159 @@ import { API_URL, API_KEY, TTS_URL, TTS_STATUS_URL, TTS_CONFIG_URL } from './con import { useSettingsStore } from '../stores/settings' function generateRequestId() { - if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') { - return crypto.randomUUID() - } - return `${Date.now()}-${Math.random().toString(16).slice(2)}` + if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') { + return crypto.randomUUID() + } + return `${Date.now()}-${Math.random().toString(16).slice(2)}` } function getCancelUrl(apiUrl) { - const normalized = String(apiUrl || '').replace(/\/+$/, '') - if (!normalized) return '/v1/completions/cancel' - if (normalized.endsWith('/v1/completions')) { - return `${normalized}/cancel` - } + const normalized = String(apiUrl || '').replace(/\/+$/, '') + if (!normalized) return '/v1/completions/cancel' + if (normalized.endsWith('/v1/completions')) { return `${normalized}/cancel` + } + return `${normalized}/cancel` } function normalizeAbortReason(reason) { - if (typeof reason === 'string' && reason.trim()) { - return reason.trim().slice(0, 64) - } - return 'abort' + if (typeof reason === 'string' && reason.trim()) { + return reason.trim().slice(0, 64) + } + return 'abort' } async function sendCancelRequest(cancelUrl, requestId, reason) { - try { - await fetch(cancelUrl, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'X-API-Key': API_KEY, - }, - body: JSON.stringify({ - request_id: requestId, - reason, - }), - }) - } catch { - // Cancel request failed silently - } + try { + await fetch(cancelUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-API-Key': API_KEY, + }, + body: JSON.stringify({ + request_id: requestId, + reason, + }), + }) + } catch { + // Cancel request failed silently + } } export async function fetchSuggestion(prefix, suffix, languageId, signal, apiUrl = API_URL) { - let normalizedLanguageId = 'markdown' - if (typeof languageId === 'string' && languageId.trim()) { - normalizedLanguageId = languageId.trim() - } else if (languageId && typeof languageId === 'object' && 'aborted' in languageId) { - signal = languageId - } - if (typeof signal === 'string') { - apiUrl = signal - signal = undefined - } - const requestId = generateRequestId() - const cancelUrl = getCancelUrl(apiUrl) + let normalizedLanguageId = 'markdown' + if (typeof languageId === 'string' && languageId.trim()) { + normalizedLanguageId = languageId.trim() + } else if (languageId && typeof languageId === 'object' && 'aborted' in languageId) { + signal = languageId + } + if (typeof signal === 'string') { + apiUrl = signal + signal = undefined + } + const requestId = generateRequestId() + const cancelUrl = getCancelUrl(apiUrl) - const onAbort = () => { - const reason = normalizeAbortReason(signal?.reason) - void sendCancelRequest(cancelUrl, requestId, reason) + const onAbort = () => { + const reason = normalizeAbortReason(signal?.reason) + void sendCancelRequest(cancelUrl, requestId, reason) + } + + if (signal) { + if (signal.aborted) { + onAbort() + } else { + signal.addEventListener('abort', onAbort, { once: true }) + } + } + + try { + const settings = useSettingsStore() + const headers = { + 'Content-Type': 'application/json', + 'X-Request-Id': requestId, + 'X-API-Key': API_KEY, } + const body = { + prefix, + suffix, + languageId: normalizedLanguageId, + model_thinking: settings.modelThinking, + privacy_mode: settings.privacyMode, + user_preferences: { + language: settings.language, + currency: settings.currency, + timezone: settings.detectedTimezone, + }, + } + + const res = await fetch(apiUrl, { + method: 'POST', + headers, + body: JSON.stringify(body), + signal, + }) + + if (!res.ok) { + const errorText = await res.text() + throw new Error(`HTTP ${res.status}: ${errorText}`) + } + + const data = await res.json() + return data.content || '' + } catch (e) { + if (e.name === 'AbortError') { + // ignore abort + } else { + throw e + } + } finally { if (signal) { - if (signal.aborted) { - onAbort() - } else { - signal.addEventListener('abort', onAbort, { once: true }) - } - } - - try { - const settings = useSettingsStore() - const headers = { - 'Content-Type': 'application/json', - 'X-Request-Id': requestId, - 'X-API-Key': API_KEY, - } - - const body = { - prefix, - suffix, - languageId: normalizedLanguageId, - model_thinking: settings.modelThinking, - privacy_mode: settings.privacyMode, - user_preferences: { - language: settings.language, - currency: settings.currency, - timezone: settings.detectedTimezone, - }, - } - - const res = await fetch(apiUrl, { - method: 'POST', - headers, - body: JSON.stringify(body), - signal, - }) - - if (!res.ok) { - const errorText = await res.text() - throw new Error(`HTTP ${res.status}: ${errorText}`) - } - - const data = await res.json() - return data.content || '' - } catch (e) { - if (e.name === 'AbortError') { - // ignore abort - } else { - throw e - } - } finally { - if (signal) { - signal.removeEventListener('abort', onAbort) - } + signal.removeEventListener('abort', onAbort) } + } } export async function fetchTTS(text, instruct = '', apiUrl = TTS_URL) { - const res = await fetch(apiUrl, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'X-API-Key': API_KEY, - }, - body: JSON.stringify({ text, instruct, speaker: 'Vivian', format: 'wav' }), - }) + const res = await fetch(apiUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-API-Key': API_KEY, + }, + body: JSON.stringify({ text, instruct, speaker: 'Vivian', format: 'wav' }), + }) - if (!res.ok) { - const errorText = await res.text() - throw new Error(`TTS HTTP ${res.status}: ${errorText}`) - } + if (!res.ok) { + const errorText = await res.text() + throw new Error(`TTS HTTP ${res.status}: ${errorText}`) + } - return res.json() + return res.json() } export async function fetchTTSStatus(apiUrl = TTS_STATUS_URL) { - const res = await fetch(apiUrl, { - headers: { 'X-API-Key': API_KEY }, - }) + const res = await fetch(apiUrl, { + headers: { 'X-API-Key': API_KEY }, + }) - if (!res.ok) { - throw new Error(`TTS Status HTTP ${res.status}`) - } + if (!res.ok) { + throw new Error(`TTS Status HTTP ${res.status}`) + } - return res.json() + return res.json() } export async function fetchTTSConfig(apiUrl = TTS_CONFIG_URL) { - const res = await fetch(apiUrl, { - headers: { 'X-API-Key': API_KEY }, - }) + const res = await fetch(apiUrl, { + headers: { 'X-API-Key': API_KEY }, + }) - if (!res.ok) { - throw new Error(`TTS Config HTTP ${res.status}`) - } + if (!res.ok) { + throw new Error(`TTS Config HTTP ${res.status}`) + } - return res.json() + return res.json() }