diff --git a/backend/GeoLite2-City.mmdb b/backend/GeoLite2-City.mmdb new file mode 100644 index 0000000..a2f6a1f Binary files /dev/null and b/backend/GeoLite2-City.mmdb differ diff --git a/backend/geoip.py b/backend/geoip.py new file mode 100644 index 0000000..5948ff2 --- /dev/null +++ b/backend/geoip.py @@ -0,0 +1,56 @@ +import os +import logging +from typing import Optional + +logger = logging.getLogger("api") + +_geoip_reader = None + + +def _get_reader(): + global _geoip_reader + if _geoip_reader is not None: + return _geoip_reader + try: + import geoip2.database + db_path = os.path.join(os.path.dirname(__file__), "GeoLite2-City.mmdb") + if os.path.exists(db_path): + _geoip_reader = geoip2.database.Reader(db_path) + logger.info("GeoIP database loaded: %s", db_path) + return _geoip_reader + else: + logger.warning("GeoIP database not found: %s", db_path) + except ImportError: + logger.warning("geoip2 not installed, IP location disabled") + except Exception as e: + logger.warning("Failed to load GeoIP database: %s", e) + return None + + +def get_ip_location(ip: str) -> Optional[dict]: + if not ip or ip in ("127.0.0.1", "localhost", "::1"): + return None + reader = _get_reader() + if not reader: + return None + try: + response = reader.city(ip) + country = response.country.name + region = response.subdivisions.most_specific.name if response.subdivisions else None + city = response.city.name + parts = [p for p in [country, region, city] if p] + if not parts: + return None + return { + "country": country, + "region": region, + "city": city, + "display": " ".join(parts) + } + except Exception: + return None + + +def get_ip_location_text(ip: str) -> str: + loc = get_ip_location(ip) + return loc["display"] if loc else "" diff --git a/backend/main.py b/backend/main.py index c8415c8..67a8433 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,4 +1,4 @@ -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse, JSONResponse from pydantic import BaseModel @@ -9,6 +9,7 @@ import logging from prompt import build_prompt, prepare_prompt_context from llm import call_ollama, call_vlm_ocr +from geoip import get_ip_location_text logging.basicConfig( level=logging.INFO, @@ -43,23 +44,32 @@ def _preview(text: str, limit: int = 80) -> str: return value return value[:limit] + "..." +def get_client_ip(request: Request) -> str: + return request.headers.get("X-Client-IP") or request.client.host if request.client else "unknown" + @app.post("/v1/completions") -async def create_completion(request: CompletionRequest): +async def create_completion(request: Request, req: CompletionRequest): request_id = str(uuid.uuid4())[:8] + client_ip = get_client_ip(request) + # 查询 IP 归属地 + location = get_ip_location_text(client_ip) + if location: + logger.info("[%s] client_location=%s", request_id, location) try: logger.info( - "[%s] /v1/completions prefix_chars=%d suffix_chars=%d lang=%s prefix_tail='%s' suffix_head='%s'", + "[%s] /v1/completions client_ip=%s prefix_chars=%d suffix_chars=%d lang=%s prefix_tail='%s' suffix_head='%s'", request_id, - len(request.prefix or ""), - len(request.suffix or ""), - request.languageId, - _preview((request.prefix or "")[-120:]), - _preview((request.suffix or "")[:120]), + client_ip, + len(req.prefix or ""), + len(req.suffix or ""), + req.languageId, + _preview((req.prefix or "")[-120:]), + _preview((req.suffix or "")[:120]), ) - llm_prefix, llm_suffix = prepare_prompt_context(request.prefix or "", request.suffix or "") + llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "") logger.info("[%s] llm_input_prefix=%r", request_id, llm_prefix) logger.info("[%s] llm_input_suffix=%r", request_id, llm_suffix) - prompt = build_prompt(request.prefix, request.suffix, request.languageId) + prompt = build_prompt(req.prefix, req.suffix, req.languageId, location=location) result = await call_ollama(prompt, tag=f"{request_id}-primary", temperature=0.7) content = result["content"] or "" diff --git a/backend/prompt.py b/backend/prompt.py index 4bb340e..12ec50a 100644 --- a/backend/prompt.py +++ b/backend/prompt.py @@ -29,12 +29,13 @@ def prepare_prompt_context(prefix: str, suffix: str) -> Tuple[str, str]: return _prepare_context(prefix, suffix) -def build_prompt(prefix: str, suffix: str, language_id: str = "markdown") -> str: +def build_prompt(prefix: str, suffix: str, language_id: str = "markdown", location: str = "") -> str: safe_language_id = _sanitize_language_id(language_id) recent_prefix, recent_suffix = _prepare_context(prefix, suffix) current_time = _get_current_datetime() + location_info = f"\nUser location: {location}" if location else "" - prompt = f"""Current time: {current_time} + prompt = f"""Current time: {current_time}{location_info} You are an inline completion engine for a {safe_language_id} editor with ghost-text suggestions. diff --git a/backend/requirements.txt b/backend/requirements.txt index 588c686..2a2f806 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -4,3 +4,4 @@ ollama pydantic python-dotenv httpx +geoip2 diff --git a/backend/test_geoip.py b/backend/test_geoip.py new file mode 100644 index 0000000..52ba305 --- /dev/null +++ b/backend/test_geoip.py @@ -0,0 +1,80 @@ +""" +GeoIP2 IP归属地查询测试脚本 + +使用方法: +1. 安装依赖:pip install geoip2 +2. 下载数据库:https://dev.maxmind.com/geoip/geoip2/geolite2/ +3. 运行测试:python test_geoip.py +""" + +import os +import sys + +try: + import geoip2.database +except ImportError: + print("请先安装 geoip2: pip install geoip2") + sys.exit(1) + +DB_PATH = os.path.join(os.path.dirname(__file__), "GeoLite2-City.mmdb") + +TEST_IPS = [ + "8.8.8.8", # Google DNS (美国) + "114.114.114.114", # 114 DNS (中国南京) + "223.5.5.5", # 阿里DNS (中国杭州) + "1.1.1.1", # Cloudflare DNS (澳大利亚) + "119.29.29.29", # 腾讯DNS (中国) +] + + +def get_location(reader, ip: str) -> dict: + try: + response = reader.city(ip) + return { + "ip": ip, + "country": response.country.name, + "country_code": response.country.iso_code, + "region": response.subdivisions.most_specific.name if response.subdivisions else None, + "city": response.city.name, + "latitude": response.location.latitude, + "longitude": response.location.longitude, + "timezone": response.location.time_zone, + } + except geoip2.errors.AddressNotFoundError: + return {"ip": ip, "error": "IP未在数据库中找到"} + except Exception as e: + return {"ip": ip, "error": str(e)} + + +def main(): + if not os.path.exists(DB_PATH): + print(f"数据库文件不存在: {DB_PATH}") + print("请从 https://dev.maxmind.com/geoip/geoip2/geolite2/ 下载 GeoLite2-City.mmdb") + return + + print(f"加载数据库: {DB_PATH}") + reader = geoip2.database.Reader(DB_PATH) + + print("\n" + "=" * 60) + print("IP归属地查询测试") + print("=" * 60) + + for ip in TEST_IPS: + result = get_location(reader, ip) + if "error" in result: + print(f"\n{ip}: {result['error']}") + else: + print(f"\n{ip}:") + print(f" 国家: {result['country']} ({result['country_code']})") + print(f" 地区: {result['region'] or '未知'}") + print(f" 城市: {result['city'] or '未知'}") + print(f" 坐标: {result['latitude']}, {result['longitude']}") + print(f" 时区: {result['timezone']}") + + reader.close() + print("\n" + "=" * 60) + print("测试完成") + + +if __name__ == "__main__": + main() diff --git a/src/utils/api.js b/src/utils/api.js index 6e153eb..29712c6 100644 --- a/src/utils/api.js +++ b/src/utils/api.js @@ -1,10 +1,30 @@ import { API_URL } from './config.js' +let cachedIP = null + +async function getClientIP() { + if (cachedIP) return cachedIP + try { + const controller = new AbortController() + setTimeout(() => controller.abort(), 3000) + const res = await fetch('https://api.ipify.org?format=json', { signal: controller.signal }) + const data = await res.json() + cachedIP = data.ip + return cachedIP + } catch { + return null + } +} + export async function fetchSuggestion(prefix, suffix, signal, apiUrl = API_URL) { try { + const clientIP = await getClientIP() + const headers = { 'Content-Type': 'application/json' } + if (clientIP) headers['X-Client-IP'] = clientIP + const res = await fetch(apiUrl, { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers, body: JSON.stringify({ prefix, suffix, languageId: 'markdown' }), signal })