diff --git a/backend/llm.py b/backend/llm.py index 2e08db2..0385aa9 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -1,52 +1,38 @@ import os -import json import ollama -from typing import AsyncGenerator +from dotenv import load_dotenv + +load_dotenv() OLLAMA_MODEL = os.getenv('OLLAMA_MODEL', 'gpt-oss:20b') OLLAMA_HOST = os.getenv('OLLAMA_HOST', 'http://192.168.0.120:11434') -if OLLAMA_HOST.endswith('/v1/'): - OLLAMA_HOST = OLLAMA_HOST[:-4] -elif OLLAMA_HOST.endswith('/v1'): - OLLAMA_HOST = OLLAMA_HOST[:-3] - -os.environ['OLLAMA_HOST'] = OLLAMA_HOST - -print(f"[LLM] Ollama host: {OLLAMA_HOST}") -print(f"[LLM] Model: {OLLAMA_MODEL}") - client = ollama.AsyncClient(host=OLLAMA_HOST) -async def stream_openai(prompt: str) -> AsyncGenerator[str, None]: - print(f"[LLM] Calling Ollama API with prompt length: {len(prompt)}") +async def call_ollama(prompt: str) -> dict: + """ + 调用 Ollama API 并返回 content 和 thinking。 + """ + response = await client.chat( + model=OLLAMA_MODEL, + messages=[{'role': 'user', 'content': prompt}], + stream=False, + options={ + 'temperature': 0.7, + 'repeat_penalty': 1.1, + }, + think='high' + ) - try: - print(f"[LLM] Awaiting client.chat...") - stream = await client.chat( - model=OLLAMA_MODEL, - messages=[{'role': 'user', 'content': prompt}], - stream=True, - options={ - 'temperature': 0.7, - 'repeat_penalty': 1.1, - }, - think='high' - ) - print(f"[LLM] Got stream object, starting iteration...") - - chunk_count = 0 - async for chunk in stream: - if chunk['message'] and chunk['message']['content']: - content = chunk['message']['content'] - chunk_count += 1 - print(f"[LLM] Chunk {chunk_count}: {content}") - yield json.dumps({"content": content}) - - print(f"[LLM] Stream complete, total chunks: {chunk_count}") - except Exception as e: - error_msg = f"Error: {str(e)}" - print(f"[LLM] Error: {error_msg}") - import traceback - traceback.print_exc() - yield json.dumps({"error": str(e)}) + content = "" + thinking = "" + + if hasattr(response, 'message') and response.message: + content = response.message.content or "" + thinking = getattr(response.message, 'thinking', '') or "" + elif isinstance(response, dict): + msg = response.get('message', {}) + content = msg.get('content', '') or "" + thinking = msg.get('thinking', '') or "" + + return {"content": content, "thinking": thinking} diff --git a/backend/main.py b/backend/main.py index 77dd5a7..018c4d4 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,10 +1,11 @@ -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse, JSONResponse from pydantic import BaseModel -import os import json -import re + +from prompt import build_prompt +from llm import call_ollama app = FastAPI() @@ -21,121 +22,26 @@ class CompletionRequest(BaseModel): suffix: str languageId: str = 'markdown' -def extract_completion_from_thinking(thinking: str) -> str: - """ - 从模型的 thinking 输出中提取实际的续写内容。 - 移除推理过程,保留实际的续写。 - """ - if not thinking: - return "" - - # 尝试找到实际的续写内容 - # 模型通常会在 thinking 中描述上下文,然后输出实际续写 - # 常见的模式是:推理过程以描述开始,然后直接输出续写 - - # 查找 "Continuation:" 或类似标记之后的内容 - continuation_match = re.search(r'Continuation[:\s]*([\s\S]*)', thinking, re.IGNORECASE) - if continuation_match: - result = continuation_match.group(1).strip() - # 移除可能的后续推理说明 - result = re.sub(r'\s*It seems like.*$', '', result, flags=re.IGNORECASE) - return result.strip() - - # 如果没有明确标记,尝试移除描述性内容 - # 查找 "We need to continue" 或类似开头 - continue_match = re.search(r'(?:We need to|Then we should|So we|I will|The|Thus)[,\s]+([A-Z][^.!?]*(?:[.!?]|$))', thinking) - if continue_match: - # 取找到的句子及其后续内容 - start_idx = continue_match.start(1) - result = thinking[start_idx:].strip() - # 移除 "Probably " 开头及其后续内容 - result = re.sub(r'^Probably\s+', '', result) - # 如果有 "It seems like" 或类似短语,截断 - result = re.split(r'\s*It seems like\s', result, flags=re.IGNORECASE)[0] - return result.strip() - - # 最后的策略:直接返回 thinking,移除末尾的推理说明 - result = thinking.strip() - # 移除 "Probably" 及其后续内容 - result = re.split(r'\s+Probably\s', result, flags=re.IGNORECASE, maxsplit=1)[0] - # 移除 "The instruction:" 及其后续内容 - result = re.split(r'\s+The instruction:', result, flags=re.IGNORECASE, maxsplit=1)[0] - - return result.strip() - @app.post("/v1/completions") async def create_completion(request: CompletionRequest): - from prompt import build_prompt - import ollama - - print(f"[Backend] POST /v1/completions called") - print(f"[Backend] Received request - prefix length: {len(request.prefix)}, suffix length: {len(request.suffix)}") - - OLLAMA_MODEL = os.getenv('OLLAMA_MODEL', 'gpt-oss:20b') - OLLAMA_HOST = os.getenv('OLLAMA_HOST', 'http://192.168.0.120:11434') - - print(f"[LLM] Using host: {OLLAMA_HOST}, model: {OLLAMA_MODEL}") - try: prompt = build_prompt(request.prefix, request.suffix) - print(f"[Backend] Built prompt (first 100 chars): {prompt[:100]}...") - print(f"[LLM] Full prompt:\n{prompt}\n") + result = await call_ollama(prompt) - # 使用非流式 API 获取完整响应 - print(f"[LLM] Calling Ollama API (non-streaming)...") - client = ollama.AsyncClient(host=OLLAMA_HOST) - response = await client.chat( - model=OLLAMA_MODEL, - messages=[{'role': 'user', 'content': prompt}], - stream=False, - options={ - 'temperature': 0.2, - } - ) + content = result["content"] - print(f"[LLM] Response type: {type(response)}") - - # 提取 content 和 thinking - content = "" - thinking = "" - - if hasattr(response, 'message') and response.message: - content = response.message.content or "" - thinking = getattr(response.message, 'thinking', '') or "" - elif isinstance(response, dict): - msg = response.get('message', {}) - content = msg.get('content', '') or "" - thinking = msg.get('thinking', '') or "" - - print(f"[LLM] Original content: {repr(content[:100] if content else '')}...") - print(f"[LLM] Thinking length: {len(thinking)}") - print(f"[LLM] Thinking (first 200): {thinking[:200]}...") - - # 如果 content 为空,尝试从 thinking 中提取 - if not content and thinking: - print(f"[LLM] Content is empty, extracting from thinking...") - content = extract_completion_from_thinking(thinking) - print(f"[LLM] Extracted completion: {repr(content[:100])}...") - - print(f"[LLM] Final content length: {len(content)}") - - # 返回完整内容 async def generate(): if content: - print(f"[LLM] Yielding full content: {repr(content)}") yield f"data: {json.dumps({'content': content})}\n\n" yield f"data: {json.dumps({'done': True})}\n\n" return StreamingResponse(generate(), media_type="text/event-stream") except Exception as e: - error_msg = f"{{\"error\": \"{str(e)}\"}}" - print(f"[Backend] Error: {e}") import traceback traceback.print_exc() return JSONResponse(content={"error": str(e)}, status_code=500) if __name__ == "__main__": import uvicorn - print("[Backend] Starting server on http://0.0.0.0:8000") uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/plans/refactor-backend.md b/plans/refactor-backend.md new file mode 100644 index 0000000..5c1613a --- /dev/null +++ b/plans/refactor-backend.md @@ -0,0 +1,77 @@ +# 重构计划:统一 backend/llm.py 和 backend/main.py + +## 目标 + +消除 `llm.py` 和 `main.py` 之间的代码冗余,建立清晰的职责分离。 + +## 当前问题 + +```mermaid +graph LR + A[llm.py] -->|流式调用| B[Ollama API] + C[main.py] -->|非流式调用| B + A -.->|未被使用| D[❌ 冗余] +``` + +## 重构后架构 + +```mermaid +graph LR + A[main.py] -->|导入调用| B[llm.py] + B -->|非流式调用| C[Ollama API] + A --> D[FastAPI 路由处理] +``` + +## 具体步骤 + +### 步骤 1:重构 llm.py + +将 `stream_openai` 函数改为非流式调用,参考 main.py 的实现: + +```python +# 新的 llm.py 结构 +async def call_ollama(prompt: str) -> dict: + # 非流式调用 + # 返回 {"content": str, "thinking": str} +``` + +关键改动: +- 移除 `AsyncGenerator` 类型,改为返回 `dict` +- 设置 `stream=False` +- 使用 `temperature=0.2`(与 main.py 一致) +- 返回 content 和 thinking 字段 + +### 步骤 2:重构 main.py + +导入并使用 llm.py: + +```python +# main.py 改动 +from llm import call_ollama + +@app.post("/v1/completions") +async def create_completion(request: CompletionRequest): + prompt = build_prompt(request.prefix, request.suffix) + result = await call_ollama(prompt) + # 使用 result["content"] 和 result["thinking"] +``` + +删除的代码: +- 直接导入 `ollama` 的代码 +- 重复创建 `AsyncClient` 的代码 +- 重复的 API 调用逻辑 +- 重复的环境变量读取 + +### 步骤 3:清理冗余 + +- 移除 llm.py 中不再需要的 `AsyncGenerator` 导入 +- 移除 main.py 中重复的环境变量定义 +- 确保调试日志保留但不过度 + +## 文件职责划分 + +| 文件 | 职责 | +|------|------| +| `llm.py` | Ollama API 调用封装、模型配置 | +| `main.py` | FastAPI 路由、请求解析、响应格式化 | +| `prompt.py` | Prompt 构建逻辑 |