feat: add IP geolocation tracking and include location in prompts
- Add GeoLite2-City.mmdb database for IP lookup - Create geoip.py module for IP location services - Extract client IP from requests and log location info - Pass location context to LLM prompts for enhanced responses
This commit is contained in:
BIN
backend/GeoLite2-City.mmdb
Normal file
BIN
backend/GeoLite2-City.mmdb
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 54 MiB |
56
backend/geoip.py
Normal file
56
backend/geoip.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
|
||||
_geoip_reader = None
|
||||
|
||||
|
||||
def _get_reader():
|
||||
global _geoip_reader
|
||||
if _geoip_reader is not None:
|
||||
return _geoip_reader
|
||||
try:
|
||||
import geoip2.database
|
||||
db_path = os.path.join(os.path.dirname(__file__), "GeoLite2-City.mmdb")
|
||||
if os.path.exists(db_path):
|
||||
_geoip_reader = geoip2.database.Reader(db_path)
|
||||
logger.info("GeoIP database loaded: %s", db_path)
|
||||
return _geoip_reader
|
||||
else:
|
||||
logger.warning("GeoIP database not found: %s", db_path)
|
||||
except ImportError:
|
||||
logger.warning("geoip2 not installed, IP location disabled")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load GeoIP database: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def get_ip_location(ip: str) -> Optional[dict]:
|
||||
if not ip or ip in ("127.0.0.1", "localhost", "::1"):
|
||||
return None
|
||||
reader = _get_reader()
|
||||
if not reader:
|
||||
return None
|
||||
try:
|
||||
response = reader.city(ip)
|
||||
country = response.country.name
|
||||
region = response.subdivisions.most_specific.name if response.subdivisions else None
|
||||
city = response.city.name
|
||||
parts = [p for p in [country, region, city] if p]
|
||||
if not parts:
|
||||
return None
|
||||
return {
|
||||
"country": country,
|
||||
"region": region,
|
||||
"city": city,
|
||||
"display": " ".join(parts)
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_ip_location_text(ip: str) -> str:
|
||||
loc = get_ip_location(ip)
|
||||
return loc["display"] if loc else ""
|
||||
@@ -1,4 +1,4 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from pydantic import BaseModel
|
||||
@@ -9,6 +9,7 @@ import logging
|
||||
|
||||
from prompt import build_prompt, prepare_prompt_context
|
||||
from llm import call_ollama, call_vlm_ocr
|
||||
from geoip import get_ip_location_text
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -43,23 +44,32 @@ def _preview(text: str, limit: int = 80) -> str:
|
||||
return value
|
||||
return value[:limit] + "..."
|
||||
|
||||
def get_client_ip(request: Request) -> str:
|
||||
return request.headers.get("X-Client-IP") or request.client.host if request.client else "unknown"
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def create_completion(request: CompletionRequest):
|
||||
async def create_completion(request: Request, req: CompletionRequest):
|
||||
request_id = str(uuid.uuid4())[:8]
|
||||
client_ip = get_client_ip(request)
|
||||
# 查询 IP 归属地
|
||||
location = get_ip_location_text(client_ip)
|
||||
if location:
|
||||
logger.info("[%s] client_location=%s", request_id, location)
|
||||
try:
|
||||
logger.info(
|
||||
"[%s] /v1/completions prefix_chars=%d suffix_chars=%d lang=%s prefix_tail='%s' suffix_head='%s'",
|
||||
"[%s] /v1/completions client_ip=%s prefix_chars=%d suffix_chars=%d lang=%s prefix_tail='%s' suffix_head='%s'",
|
||||
request_id,
|
||||
len(request.prefix or ""),
|
||||
len(request.suffix or ""),
|
||||
request.languageId,
|
||||
_preview((request.prefix or "")[-120:]),
|
||||
_preview((request.suffix or "")[:120]),
|
||||
client_ip,
|
||||
len(req.prefix or ""),
|
||||
len(req.suffix or ""),
|
||||
req.languageId,
|
||||
_preview((req.prefix or "")[-120:]),
|
||||
_preview((req.suffix or "")[:120]),
|
||||
)
|
||||
llm_prefix, llm_suffix = prepare_prompt_context(request.prefix or "", request.suffix or "")
|
||||
llm_prefix, llm_suffix = prepare_prompt_context(req.prefix or "", req.suffix or "")
|
||||
logger.info("[%s] llm_input_prefix=%r", request_id, llm_prefix)
|
||||
logger.info("[%s] llm_input_suffix=%r", request_id, llm_suffix)
|
||||
prompt = build_prompt(request.prefix, request.suffix, request.languageId)
|
||||
prompt = build_prompt(req.prefix, req.suffix, req.languageId, location=location)
|
||||
result = await call_ollama(prompt, tag=f"{request_id}-primary", temperature=0.7)
|
||||
|
||||
content = result["content"] or ""
|
||||
|
||||
@@ -29,12 +29,13 @@ def prepare_prompt_context(prefix: str, suffix: str) -> Tuple[str, str]:
|
||||
return _prepare_context(prefix, suffix)
|
||||
|
||||
|
||||
def build_prompt(prefix: str, suffix: str, language_id: str = "markdown") -> str:
|
||||
def build_prompt(prefix: str, suffix: str, language_id: str = "markdown", location: str = "") -> str:
|
||||
safe_language_id = _sanitize_language_id(language_id)
|
||||
recent_prefix, recent_suffix = _prepare_context(prefix, suffix)
|
||||
current_time = _get_current_datetime()
|
||||
location_info = f"\nUser location: {location}" if location else ""
|
||||
|
||||
prompt = f"""Current time: {current_time}
|
||||
prompt = f"""Current time: {current_time}{location_info}
|
||||
|
||||
You are an inline completion engine for a {safe_language_id} editor with ghost-text suggestions.
|
||||
|
||||
|
||||
@@ -4,3 +4,4 @@ ollama
|
||||
pydantic
|
||||
python-dotenv
|
||||
httpx
|
||||
geoip2
|
||||
|
||||
80
backend/test_geoip.py
Normal file
80
backend/test_geoip.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
GeoIP2 IP归属地查询测试脚本
|
||||
|
||||
使用方法:
|
||||
1. 安装依赖:pip install geoip2
|
||||
2. 下载数据库:https://dev.maxmind.com/geoip/geoip2/geolite2/
|
||||
3. 运行测试:python test_geoip.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
try:
|
||||
import geoip2.database
|
||||
except ImportError:
|
||||
print("请先安装 geoip2: pip install geoip2")
|
||||
sys.exit(1)
|
||||
|
||||
DB_PATH = os.path.join(os.path.dirname(__file__), "GeoLite2-City.mmdb")
|
||||
|
||||
TEST_IPS = [
|
||||
"8.8.8.8", # Google DNS (美国)
|
||||
"114.114.114.114", # 114 DNS (中国南京)
|
||||
"223.5.5.5", # 阿里DNS (中国杭州)
|
||||
"1.1.1.1", # Cloudflare DNS (澳大利亚)
|
||||
"119.29.29.29", # 腾讯DNS (中国)
|
||||
]
|
||||
|
||||
|
||||
def get_location(reader, ip: str) -> dict:
|
||||
try:
|
||||
response = reader.city(ip)
|
||||
return {
|
||||
"ip": ip,
|
||||
"country": response.country.name,
|
||||
"country_code": response.country.iso_code,
|
||||
"region": response.subdivisions.most_specific.name if response.subdivisions else None,
|
||||
"city": response.city.name,
|
||||
"latitude": response.location.latitude,
|
||||
"longitude": response.location.longitude,
|
||||
"timezone": response.location.time_zone,
|
||||
}
|
||||
except geoip2.errors.AddressNotFoundError:
|
||||
return {"ip": ip, "error": "IP未在数据库中找到"}
|
||||
except Exception as e:
|
||||
return {"ip": ip, "error": str(e)}
|
||||
|
||||
|
||||
def main():
|
||||
if not os.path.exists(DB_PATH):
|
||||
print(f"数据库文件不存在: {DB_PATH}")
|
||||
print("请从 https://dev.maxmind.com/geoip/geoip2/geolite2/ 下载 GeoLite2-City.mmdb")
|
||||
return
|
||||
|
||||
print(f"加载数据库: {DB_PATH}")
|
||||
reader = geoip2.database.Reader(DB_PATH)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("IP归属地查询测试")
|
||||
print("=" * 60)
|
||||
|
||||
for ip in TEST_IPS:
|
||||
result = get_location(reader, ip)
|
||||
if "error" in result:
|
||||
print(f"\n{ip}: {result['error']}")
|
||||
else:
|
||||
print(f"\n{ip}:")
|
||||
print(f" 国家: {result['country']} ({result['country_code']})")
|
||||
print(f" 地区: {result['region'] or '未知'}")
|
||||
print(f" 城市: {result['city'] or '未知'}")
|
||||
print(f" 坐标: {result['latitude']}, {result['longitude']}")
|
||||
print(f" 时区: {result['timezone']}")
|
||||
|
||||
reader.close()
|
||||
print("\n" + "=" * 60)
|
||||
print("测试完成")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,10 +1,30 @@
|
||||
import { API_URL } from './config.js'
|
||||
|
||||
let cachedIP = null
|
||||
|
||||
async function getClientIP() {
|
||||
if (cachedIP) return cachedIP
|
||||
try {
|
||||
const controller = new AbortController()
|
||||
setTimeout(() => controller.abort(), 3000)
|
||||
const res = await fetch('https://api.ipify.org?format=json', { signal: controller.signal })
|
||||
const data = await res.json()
|
||||
cachedIP = data.ip
|
||||
return cachedIP
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export async function fetchSuggestion(prefix, suffix, signal, apiUrl = API_URL) {
|
||||
try {
|
||||
const clientIP = await getClientIP()
|
||||
const headers = { 'Content-Type': 'application/json' }
|
||||
if (clientIP) headers['X-Client-IP'] = clientIP
|
||||
|
||||
const res = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
headers,
|
||||
body: JSON.stringify({ prefix, suffix, languageId: 'markdown' }),
|
||||
signal
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user