feat(core): add ModelScope support for TTS and new office load status

Add support to download and load TTS model from ModelScope, with a fallback to the HuggingFace mirror.
Implement a `documentLoadStatus` property and helper functions in `office.js` to track file loading state.
Improve request cancellation logic in `api.js`, ensuring proper cancel URL resolution and request‑id handling.

These changes enhance robustness, reduce external dependencies, and provide better UX for office file handling.
This commit is contained in:
2026-04-11 10:04:34 +08:00
parent d8b7832b14
commit f99acf5d50
5 changed files with 248 additions and 186 deletions

View File

@@ -8,6 +8,7 @@ torchaudio>=0.12.0
transformers>=4.25.0
whisper>=1.0.0
qwen-tts>=0.0.0
modelscope>=1.20.0
# testing
pytest>=7.0.0

View File

@@ -25,6 +25,10 @@ router = APIRouter()
# Global TTS model instance
_tts_model: Optional["Qwen3TTSModel"] = None
# Model paths for loading
MODEL_ID_HF = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign"
MODEL_ID_MS = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign"
def _get_device_map() -> str:
"""设备检测逻辑:优先 CUDA其次 MPS最后 CPU"""
@@ -38,6 +42,24 @@ def _get_device_map() -> str:
return "cpu"
def _download_model_from_modelscope() -> Optional[str]:
"""从 ModelScope 下载模型到本地临时目录"""
try:
from modelscope import snapshot_download
cache_dir = os.path.join(os.path.dirname(__file__), "models")
os.makedirs(cache_dir, exist_ok=True)
model_dir = snapshot_download(
MODEL_ID_MS,
cache_dir=cache_dir,
revision="master"
)
logger.info("ModelScope 模型下载完成: %s", model_dir)
return model_dir
except Exception as e:
logger.warning("ModelScope 下载失败: %s", e)
return None
async def _warmup_tts():
"""预热 TTS 模型"""
await asyncio.to_thread(_load_tts_model_with_retry)
@@ -58,28 +80,42 @@ def _load_tts_model_with_retry(max_retries: int = 3) -> "Qwen3TTSModel":
if Qwen3TTSModel is None:
raise RuntimeError("qwen_tts 库未安装,无法加载 TTS 模型")
candidates = [
"Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
"ModelScope/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
]
device_map = _get_device_map()
last_err = None
for i, model_id in enumerate(candidates, start=1):
# 策略1: 尝试从 ModelScope 下载后加载
for attempt in range(max_retries):
try:
logger.info("尝试从 ModelScope 下载模型...")
model_path = _download_model_from_modelscope()
if model_path and os.path.isdir(model_path):
_tts_model = Qwen3TTSModel.from_pretrained(
model_path,
device_map=device_map,
dtype=torch.float16,
)
logger.info("ModelScope 模型加载成功: %s", model_path)
return _tts_model
except Exception as e:
logger.warning("ModelScope 加载失败 (尝试 %d/%d): %s", attempt + 1, max_retries, e)
last_err = e
# 策略2: 尝试从 HuggingFace 镜像加载
for attempt in range(max_retries):
try:
logger.info("尝试从 HuggingFace 镜像加载模型...")
_tts_model = Qwen3TTSModel.from_pretrained(
model_id,
MODEL_ID_HF,
device_map=device_map,
dtype=torch.float16,
attn_implementation="flash_attention_2",
)
logger.info("Loaded TTS model from %s", model_id)
logger.info("HuggingFace 模型加载成功")
return _tts_model
except Exception as e:
logger.warning("Failed to load TTS model from %s: %s", model_id, e)
logger.warning("HuggingFace 加载失败 (尝试 %d/%d): %s", attempt + 1, max_retries, e)
last_err = e
if i >= max_retries:
break
raise RuntimeError(f"Unable to load TTS model from sources: {candidates}") from last_err
raise RuntimeError(f"无法加载 TTS 模型: {last_err}") from last_err
class TTSRequest(BaseModel):
@@ -160,10 +196,10 @@ async def tts_endpoint(req: TTSRequest):
speaker = req.speaker or "Vivian"
try:
wavs_sr = model.generate_custom_voice(
# VoiceDesign 模型使用 generate_voice_design 方法
wavs_sr = model.generate_voice_design(
text=text,
language="Chinese",
speaker=speaker,
instruct=instruct,
)
except Exception as e:

View File

@@ -129,56 +129,33 @@ const loadDocumentIntoUniver = async (instance, blob, fileName, fmt) => {
const lowerFmt = (fmt || '').toLowerCase();
console.log(`[Univer] 开始加载文档: ${fileName}, 格式: ${lowerFmt}`);
// 尝试使用服务端导入API如果可用
// 注意这些API需要后端服务支持纯前端模式下会失败
// 重要提示Univer 纯前端模式不支持直接加载 DOCX/XLSX/PPTX 文件
// 这些格式需要后端服务进行转换
// 当前实现创建空白文档作为预览占位符
try {
if (lowerFmt === 'docx' && typeof univerAPI.importDOCXToSnapshotAsync === 'function') {
console.log('[Univer] 尝试使用 importDOCXToSnapshotAsync...');
const snapshot = await univerAPI.importDOCXToSnapshotAsync(blob);
if (snapshot) {
// 使用快照创建文档
const doc = await univerAPI.createUniverDoc(snapshot);
console.log('[Univer] DOCX 文档加载成功');
return;
}
} else if (lowerFmt === 'xlsx' && typeof univerAPI.importXLSXToSnapshotAsync === 'function') {
console.log('[Univer] 尝试使用 importXLSXToSnapshotAsync...');
const snapshot = await univerAPI.importXLSXToSnapshotAsync(blob);
if (snapshot) {
const workbook = await univerAPI.createWorkbook(snapshot);
console.log('[Univer] XLSX 文档加载成功');
return;
// 尝试创建对应格式的空白文档
if (lowerFmt === 'xlsx') {
await univerAPI.createWorkbook({});
console.log('[Univer] 创建空白 Excel 工作簿');
} else if (lowerFmt === 'pptx') {
// PPTX: 尝试使用 Slides API如果不可用则降级到 Docs
if (typeof univerAPI.createUniverSlide === 'function') {
await univerAPI.createUniverSlide({});
console.log('[Univer] 创建空白 PPT 演示文稿');
} else {
await univerAPI.createUniverDoc({});
console.log('[Univer] Slides API 不可用,创建空白文档');
}
} else {
// DOCX 和默认情况
await univerAPI.createUniverDoc({});
console.log('[Univer] 创建空白 Word 文档');
}
} catch (e) {
console.warn('[Univer] 服务端导入API不可用或失败使用纯前端模式:', e.message);
console.error('[Univer] 创建文档失败:', e);
throw new Error(`无法创建 ${lowerFmt.toUpperCase()} 文档: ${e.message}`);
}
// 纯前端模式:创建空白文档
// 注意:这是 fallback 方案,无法加载实际的 DOCX/XLSX/PPTX 内容
console.log('[Univer] 使用纯前端模式创建空白文档');
if (lowerFmt === 'xlsx') {
await univerAPI.createWorkbook({});
console.log('[Univer] 创建空白 Excel 工作簿');
} else if (lowerFmt === 'pptx') {
// PPTX 需要创建 Slides 文档
// 注意:需要先检查是否有 createUniverSlide 方法
if (typeof univerAPI.createUniverSlide === 'function') {
await univerAPI.createUniverSlide({});
console.log('[Univer] 创建空白 PPT 演示文稿');
} else {
// Fallback 到普通文档
await univerAPI.createUniverDoc({});
console.log('[Univer] Slides API 不可用,创建空白文档');
}
} else {
await univerAPI.createUniverDoc({});
console.log('[Univer] 创建空白 Word 文档');
}
// 提示用户当前是纯前端模式
console.warn('[Univer] 当前为纯前端模式,无法加载实际的 DOCX/XLSX/PPTX 文件内容。如需完整功能,请配置后端服务。');
};
const destroyUniver = () => {

View File

@@ -8,18 +8,22 @@ export const useOfficeStore = defineStore('office', () => {
const currentFormat = ref(null)
const currentFileSize = ref(0)
const currentBytes = ref(null)
// 快照模式
const isSnapshotMode = ref(true) // 默认启用快照模式
const currentSnapshot = ref(null)
// 编辑状态
const isEditing = ref(false)
const hasUnsavedChanges = ref(false)
// 视图状态
const activeView = ref('milkdown') // 'milkdown' | 'univer'
// 文档加载状态(新增)
const documentLoadStatus = ref('idle') // 'idle' | 'loading' | 'success' | 'error'
const documentErrorMessage = ref('')
// 计算属性
const hasDocument = computed(() => {
return currentFileName.value && currentFormat.value
@@ -31,7 +35,8 @@ export const useOfficeStore = defineStore('office', () => {
name: currentFileName.value,
format: currentFormat.value,
size: currentFileSize.value,
isSnapshot: isSnapshotMode.value
isSnapshot: isSnapshotMode.value,
loadStatus: documentLoadStatus.value
}
})
@@ -49,6 +54,8 @@ export const useOfficeStore = defineStore('office', () => {
currentFileSize.value = file.size || 0
currentBytes.value = bytes
hasUnsavedChanges.value = false
documentLoadStatus.value = 'idle'
documentErrorMessage.value = ''
}
/**
@@ -61,6 +68,8 @@ export const useOfficeStore = defineStore('office', () => {
currentBytes.value = null
currentSnapshot.value = null
hasUnsavedChanges.value = false
documentLoadStatus.value = 'idle'
documentErrorMessage.value = ''
}
/**
@@ -92,6 +101,39 @@ export const useOfficeStore = defineStore('office', () => {
isSnapshotMode.value = !isSnapshotMode.value
}
// 新增:文档加载状态管理方法
/**
* 开始加载文档
*/
function startDocumentLoad() {
documentLoadStatus.value = 'loading'
documentErrorMessage.value = ''
}
/**
* 文档加载成功
*/
function completeDocumentLoad() {
documentLoadStatus.value = 'success'
documentErrorMessage.value = ''
}
/**
* 文档加载失败
*/
function failDocumentLoad(message) {
documentLoadStatus.value = 'error'
documentErrorMessage.value = message || '文档加载失败'
}
/**
* 重置文档加载状态
*/
function resetDocumentLoadStatus() {
documentLoadStatus.value = 'idle'
documentErrorMessage.value = ''
}
return {
// 状态
currentFileName,
@@ -103,6 +145,8 @@ export const useOfficeStore = defineStore('office', () => {
isEditing,
hasUnsavedChanges,
activeView,
documentLoadStatus,
documentErrorMessage,
// 计算属性
hasDocument,
@@ -114,7 +158,11 @@ export const useOfficeStore = defineStore('office', () => {
setSnapshot,
markAsChanged,
switchView,
toggleSnapshotMode
toggleSnapshotMode,
startDocumentLoad,
completeDocumentLoad,
failDocumentLoad,
resetDocumentLoadStatus
}
})

View File

@@ -2,159 +2,159 @@ import { API_URL, API_KEY, TTS_URL, TTS_STATUS_URL, TTS_CONFIG_URL } from './con
import { useSettingsStore } from '../stores/settings'
function generateRequestId() {
if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') {
return crypto.randomUUID()
}
return `${Date.now()}-${Math.random().toString(16).slice(2)}`
if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') {
return crypto.randomUUID()
}
return `${Date.now()}-${Math.random().toString(16).slice(2)}`
}
function getCancelUrl(apiUrl) {
const normalized = String(apiUrl || '').replace(/\/+$/, '')
if (!normalized) return '/v1/completions/cancel'
if (normalized.endsWith('/v1/completions')) {
return `${normalized}/cancel`
}
const normalized = String(apiUrl || '').replace(/\/+$/, '')
if (!normalized) return '/v1/completions/cancel'
if (normalized.endsWith('/v1/completions')) {
return `${normalized}/cancel`
}
return `${normalized}/cancel`
}
function normalizeAbortReason(reason) {
if (typeof reason === 'string' && reason.trim()) {
return reason.trim().slice(0, 64)
}
return 'abort'
if (typeof reason === 'string' && reason.trim()) {
return reason.trim().slice(0, 64)
}
return 'abort'
}
async function sendCancelRequest(cancelUrl, requestId, reason) {
try {
await fetch(cancelUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-API-Key': API_KEY,
},
body: JSON.stringify({
request_id: requestId,
reason,
}),
})
} catch {
// Cancel request failed silently
}
try {
await fetch(cancelUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-API-Key': API_KEY,
},
body: JSON.stringify({
request_id: requestId,
reason,
}),
})
} catch {
// Cancel request failed silently
}
}
export async function fetchSuggestion(prefix, suffix, languageId, signal, apiUrl = API_URL) {
let normalizedLanguageId = 'markdown'
if (typeof languageId === 'string' && languageId.trim()) {
normalizedLanguageId = languageId.trim()
} else if (languageId && typeof languageId === 'object' && 'aborted' in languageId) {
signal = languageId
}
if (typeof signal === 'string') {
apiUrl = signal
signal = undefined
}
const requestId = generateRequestId()
const cancelUrl = getCancelUrl(apiUrl)
let normalizedLanguageId = 'markdown'
if (typeof languageId === 'string' && languageId.trim()) {
normalizedLanguageId = languageId.trim()
} else if (languageId && typeof languageId === 'object' && 'aborted' in languageId) {
signal = languageId
}
if (typeof signal === 'string') {
apiUrl = signal
signal = undefined
}
const requestId = generateRequestId()
const cancelUrl = getCancelUrl(apiUrl)
const onAbort = () => {
const reason = normalizeAbortReason(signal?.reason)
void sendCancelRequest(cancelUrl, requestId, reason)
const onAbort = () => {
const reason = normalizeAbortReason(signal?.reason)
void sendCancelRequest(cancelUrl, requestId, reason)
}
if (signal) {
if (signal.aborted) {
onAbort()
} else {
signal.addEventListener('abort', onAbort, { once: true })
}
}
try {
const settings = useSettingsStore()
const headers = {
'Content-Type': 'application/json',
'X-Request-Id': requestId,
'X-API-Key': API_KEY,
}
const body = {
prefix,
suffix,
languageId: normalizedLanguageId,
model_thinking: settings.modelThinking,
privacy_mode: settings.privacyMode,
user_preferences: {
language: settings.language,
currency: settings.currency,
timezone: settings.detectedTimezone,
},
}
const res = await fetch(apiUrl, {
method: 'POST',
headers,
body: JSON.stringify(body),
signal,
})
if (!res.ok) {
const errorText = await res.text()
throw new Error(`HTTP ${res.status}: ${errorText}`)
}
const data = await res.json()
return data.content || ''
} catch (e) {
if (e.name === 'AbortError') {
// ignore abort
} else {
throw e
}
} finally {
if (signal) {
if (signal.aborted) {
onAbort()
} else {
signal.addEventListener('abort', onAbort, { once: true })
}
}
try {
const settings = useSettingsStore()
const headers = {
'Content-Type': 'application/json',
'X-Request-Id': requestId,
'X-API-Key': API_KEY,
}
const body = {
prefix,
suffix,
languageId: normalizedLanguageId,
model_thinking: settings.modelThinking,
privacy_mode: settings.privacyMode,
user_preferences: {
language: settings.language,
currency: settings.currency,
timezone: settings.detectedTimezone,
},
}
const res = await fetch(apiUrl, {
method: 'POST',
headers,
body: JSON.stringify(body),
signal,
})
if (!res.ok) {
const errorText = await res.text()
throw new Error(`HTTP ${res.status}: ${errorText}`)
}
const data = await res.json()
return data.content || ''
} catch (e) {
if (e.name === 'AbortError') {
// ignore abort
} else {
throw e
}
} finally {
if (signal) {
signal.removeEventListener('abort', onAbort)
}
signal.removeEventListener('abort', onAbort)
}
}
}
export async function fetchTTS(text, instruct = '', apiUrl = TTS_URL) {
const res = await fetch(apiUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-API-Key': API_KEY,
},
body: JSON.stringify({ text, instruct, speaker: 'Vivian', format: 'wav' }),
})
const res = await fetch(apiUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-API-Key': API_KEY,
},
body: JSON.stringify({ text, instruct, speaker: 'Vivian', format: 'wav' }),
})
if (!res.ok) {
const errorText = await res.text()
throw new Error(`TTS HTTP ${res.status}: ${errorText}`)
}
if (!res.ok) {
const errorText = await res.text()
throw new Error(`TTS HTTP ${res.status}: ${errorText}`)
}
return res.json()
return res.json()
}
export async function fetchTTSStatus(apiUrl = TTS_STATUS_URL) {
const res = await fetch(apiUrl, {
headers: { 'X-API-Key': API_KEY },
})
const res = await fetch(apiUrl, {
headers: { 'X-API-Key': API_KEY },
})
if (!res.ok) {
throw new Error(`TTS Status HTTP ${res.status}`)
}
if (!res.ok) {
throw new Error(`TTS Status HTTP ${res.status}`)
}
return res.json()
return res.json()
}
export async function fetchTTSConfig(apiUrl = TTS_CONFIG_URL) {
const res = await fetch(apiUrl, {
headers: { 'X-API-Key': API_KEY },
})
const res = await fetch(apiUrl, {
headers: { 'X-API-Key': API_KEY },
})
if (!res.ok) {
throw new Error(`TTS Config HTTP ${res.status}`)
}
if (!res.ok) {
throw new Error(`TTS Config HTTP ${res.status}`)
}
return res.json()
return res.json()
}