""" LLM 中转网关:FastAPI + 嵌入式前端(单文件)。 账号与 API Key 从 JSON 配置文件读取,不提供注册。 运行:pip install -r requirements.txt && uvicorn main:app --host 0.0.0.0 --port 8000 配置:复制 gateway.json.example 为 gateway.json,填写 username、password。 api_key 可留空,首次启动会自动生成 sk-... 并写回 gateway.json。 环境变量(可选): JWT_SECRET JWT 签名密钥(生产环境务必修改) UPSTREAM_URL 上游 OpenAI 兼容地址,默认 http://127.0.0.1:10434 APP_ROOT 反代子路径前缀,例如 /wg GATEWAY_CONFIG 配置文件路径,默认当前目录下 gateway.json NODES_CONFIG 节点/模型配置,默认 nodes.json STATS_DB 访问统计 SQLite 路径,默认 gateway_stats.db GATEWAY_PORT 监听端口,默认 8150 """ from __future__ import annotations import asyncio import json import os import secrets import sqlite3 import threading from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime, timedelta, timezone from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Tuple, TypeVar try: from typing import Annotated # Python 3.9+ except ImportError: # Python 3.8 from typing_extensions import Annotated import httpx _T = TypeVar("_T") async def run_in_thread(func: Callable[..., _T], /, *args: Any) -> _T: """在线程池执行阻塞函数(兼容 Python 3.8,无 asyncio.to_thread)。""" if hasattr(asyncio, "to_thread"): return await run_in_thread(func, *args) # type: ignore[attr-defined] loop = asyncio.get_running_loop() return await loop.run_in_executor(None, func, *args) from fastapi import Depends, FastAPI, HTTPException, Query, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer import bcrypt from jose import JWTError, jwt from pydantic import BaseModel, Field # --------------------------------------------------------------------------- # 配置 # --------------------------------------------------------------------------- CONFIG_PATH = os.environ.get("GATEWAY_CONFIG", "gateway.json") NODES_PATH = os.environ.get("NODES_CONFIG", "nodes.json") STATS_DB = os.environ.get("STATS_DB", "gateway_stats.db") GATEWAY_PORT = int(os.environ.get("GATEWAY_PORT", "8150")) JWT_SECRET = os.environ.get("JWT_SECRET", "dev-secret-change-me") JWT_ALG = "HS256" JWT_EXPIRE_DAYS = 7 GATE_WEB_UID = 1 UPSTREAM_BASE = os.environ.get("UPSTREAM_URL", "http://127.0.0.1:10434").rstrip("/") APP_ROOT = os.environ.get("APP_ROOT", "").rstrip("/") def app_url(path: str) -> str: if not path.startswith("/"): path = "/" + path return APP_ROOT + path if APP_ROOT else path _APP_ROOT_JSON = json.dumps(APP_ROOT) security_bearer = HTTPBearer(auto_error=False) # 启动时由 load_gateway_config() 填充 _GATE: Dict[str, Any] = {} _PASSWORD_HASH: str = "" _db_lock = threading.Lock() _nodes_lock = threading.Lock() _NODES: Dict[str, Any] = {"nodes": []} _node_runtime: Dict[str, Dict[str, Any]] = {} _health_task: Optional[asyncio.Task] = None STATUS_LABELS: Dict[str, str] = { "offline": "离线", "idle": "空闲", "busy": "忙碌", "disabled": "未启用", "error": "异常", } def _is_bcrypt_hash(value: str) -> bool: return value.startswith(("$2a$", "$2b$", "$2y$")) def hash_password(p: str) -> str: return bcrypt.hashpw(p.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") def verify_password(plain: str, hashed: str) -> bool: try: return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8")) except (ValueError, TypeError): return False def generate_api_key() -> str: return "sk-" + secrets.token_urlsafe(32) def load_gateway_config() -> None: """读取 gateway.json;若无 api_key 则生成并写回文件。""" global _GATE, _PASSWORD_HASH if not os.path.isfile(CONFIG_PATH): raise RuntimeError( f"缺少配置文件: {CONFIG_PATH}(可复制 gateway.json.example 为 gateway.json)" ) with open(CONFIG_PATH, encoding="utf-8") as f: raw = json.load(f) username = str(raw.get("username", "")).strip() password = raw.get("password", "") if not isinstance(password, str): password = str(password) if not username or not password: raise ValueError(f"{CONFIG_PATH} 须包含非空的 username 与 password") api_key = str(raw.get("api_key", "")).strip() changed = False if not api_key: api_key = generate_api_key() raw["api_key"] = api_key changed = True if changed: with open(CONFIG_PATH, "w", encoding="utf-8") as wf: json.dump(raw, wf, indent=2, ensure_ascii=False) wf.write("\n") _GATE = {"username": username, "api_key": api_key} if _is_bcrypt_hash(password): _PASSWORD_HASH = password else: _PASSWORD_HASH = hash_password(password) def create_web_token(user_id: int) -> str: expire = datetime.now(timezone.utc) + timedelta(days=JWT_EXPIRE_DAYS) payload = {"sub": str(user_id), "exp": expire, "typ": "web"} return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALG) def decode_web_token(token: str) -> int: try: payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALG]) if payload.get("typ") != "web": raise JWTError("wrong token type") sub = payload.get("sub") if not sub: raise JWTError("missing sub") return int(sub) except JWTError as e: raise HTTPException(status_code=401, detail="无效或过期的登录状态") from e @dataclass class GateSessionUser: username: str api_key: str # --------------------------------------------------------------------------- # 节点 / 模型配置(nodes.json) # --------------------------------------------------------------------------- def _default_nodes_payload() -> Dict[str, Any]: return { "nodes": [ { "id": "node-1", "name": "节点 1", "host": "127.0.0.1", "port": 3313, "enabled": True, "max_concurrent": 1, "models": [], }, { "id": "node-2", "name": "节点 2", "host": "127.0.0.1", "port": 3314, "enabled": True, "max_concurrent": 1, "models": [], }, { "id": "node-3", "name": "节点 3", "host": "127.0.0.1", "port": 3315, "enabled": True, "max_concurrent": 1, "models": [], }, ] } def ensure_nodes_file() -> None: if not os.path.isfile(NODES_PATH): example = NODES_PATH + ".example" if os.path.isfile(example): with open(example, encoding="utf-8") as f: data = json.load(f) else: data = _default_nodes_payload() save_nodes_config(data) def load_nodes_config() -> None: global _NODES, _node_runtime ensure_nodes_file() with open(NODES_PATH, encoding="utf-8") as f: data = json.load(f) nodes = data.get("nodes") if not isinstance(nodes, list): nodes = [] _NODES = {"nodes": nodes} with _nodes_lock: for n in nodes: nid = str(n.get("id", "")) if nid and nid not in _node_runtime: _node_runtime[nid] = { "in_flight": 0, "healthy": None, "last_check": None, "last_error": "", } def save_nodes_config(data: Optional[Dict[str, Any]] = None) -> None: global _NODES payload = data if data is not None else _NODES with open(NODES_PATH, "w", encoding="utf-8") as f: json.dump(payload, f, indent=2, ensure_ascii=False) f.write("\n") load_nodes_config() def get_node_base(node: Dict[str, Any]) -> str: host = str(node.get("host", "127.0.0.1")).strip() or "127.0.0.1" port = int(node.get("port", 0)) return f"http://{host}:{port}".rstrip("/") def _runtime(nid: str) -> Dict[str, Any]: with _nodes_lock: if nid not in _node_runtime: _node_runtime[nid] = { "in_flight": 0, "healthy": None, "last_check": None, "last_error": "", } return _node_runtime[nid] def node_in_flight_inc(nid: str) -> None: with _nodes_lock: rt = _runtime(nid) rt["in_flight"] = int(rt.get("in_flight", 0)) + 1 def node_in_flight_dec(nid: str) -> None: with _nodes_lock: rt = _runtime(nid) rt["in_flight"] = max(0, int(rt.get("in_flight", 0)) - 1) def compute_node_status(node: Dict[str, Any]) -> Tuple[str, str]: nid = str(node.get("id", "")) if not node.get("enabled", True): return "disabled", STATUS_LABELS["disabled"] rt = _runtime(nid) if rt.get("healthy") is False: err = str(rt.get("last_error") or "") if err: return "error", STATUS_LABELS["error"] return "offline", STATUS_LABELS["offline"] inflight = int(rt.get("in_flight", 0)) max_c = int(node.get("max_concurrent", 1) or 1) if inflight >= max_c: return "busy", STATUS_LABELS["busy"] if rt.get("healthy") is True: return "idle", STATUS_LABELS["idle"] return "offline", STATUS_LABELS["offline"] def find_nodes_for_model(model_id: str) -> List[Dict[str, Any]]: if not model_id: return [] out: List[Dict[str, Any]] = [] for node in _NODES.get("nodes", []): if not node.get("enabled", True): continue for m in node.get("models") or []: if str(m.get("id", "")) == model_id: out.append(node) break return out def select_node_for_model(model_id: str) -> Optional[Dict[str, Any]]: candidates = find_nodes_for_model(model_id) if not candidates: return None scored: List[Tuple[int, str, Dict[str, Any]]] = [] for node in candidates: status, _ = compute_node_status(node) if status in ("disabled", "error"): continue if status == "busy": continue nid = str(node["id"]) scored.append((int(_runtime(nid).get("in_flight", 0)), nid, node)) if not scored: for node in candidates: if compute_node_status(node)[0] != "disabled": nid = str(node["id"]) scored.append((int(_runtime(nid).get("in_flight", 0)), nid, node)) if not scored: return None scored.sort(key=lambda x: (x[0], x[1])) return scored[0][2] async def probe_node( client: httpx.AsyncClient, node: Dict[str, Any] ) -> Tuple[bool, str]: url = get_node_base(node) + "/v1/models" try: resp = await client.get(url, timeout=5.0) if resp.status_code < 500: return True, "" return False, f"HTTP {resp.status_code}" except httpx.RequestError as e: return False, str(e) async def refresh_all_node_health(client: httpx.AsyncClient) -> None: for node in _NODES.get("nodes", []): nid = str(node.get("id", "")) if not nid: continue ok, err = await probe_node(client, node) with _nodes_lock: rt = _runtime(nid) rt["healthy"] = ok rt["last_error"] = err rt["last_check"] = datetime.now(timezone.utc).isoformat() async def _health_loop(client: httpx.AsyncClient) -> None: while True: try: await refresh_all_node_health(client) except Exception: pass await asyncio.sleep(30) def _query_model_stats_today() -> Dict[str, Dict[str, int]]: today = _today_utc_prefix() out: Dict[str, Dict[str, int]] = {} with _db_lock: conn = sqlite3.connect(STATS_DB) conn.row_factory = sqlite3.Row try: rows = conn.execute( """ SELECT model, node_id, COUNT(*) AS requests, COALESCE(SUM(total_tokens), 0) AS total_tokens FROM api_logs WHERE created_at LIKE ? AND model IS NOT NULL GROUP BY model, node_id """, (today + "%",), ).fetchall() finally: conn.close() for row in rows: key = f"{row['model']}\0{row['node_id'] or ''}" out[key] = { "requests": int(row["requests"]), "total_tokens": int(row["total_tokens"]), } return out def build_model_cards() -> List[Dict[str, Any]]: stats = _query_model_stats_today() cards: List[Dict[str, Any]] = [] for node in _NODES.get("nodes", []): nid = str(node.get("id", "")) status, status_label = compute_node_status(node) host = str(node.get("host", "127.0.0.1")) port = int(node.get("port", 0)) models = node.get("models") or [] if not models: cards.append( { "node_id": nid, "node_name": node.get("name", nid), "model_id": "", "model_label": "(未配置模型)", "host": host, "port": port, "endpoint": f"{host}:{port}", "status": status if node.get("enabled", True) else "disabled", "status_label": status_label if node.get("enabled", True) else STATUS_LABELS["disabled"], "in_flight": int(_runtime(nid).get("in_flight", 0)), "max_concurrent": int(node.get("max_concurrent", 1)), "today_requests": 0, "today_tokens": 0, "last_error": str(_runtime(nid).get("last_error") or ""), "enabled": bool(node.get("enabled", True)), } ) continue for m in models: mid = str(m.get("id", "")) label = str(m.get("label") or mid) st = status st_label = status_label if not node.get("enabled", True): st, st_label = "disabled", STATUS_LABELS["disabled"] sk = f"{mid}\0{nid}" agg = stats.get(sk, {"requests": 0, "total_tokens": 0}) cards.append( { "node_id": nid, "node_name": node.get("name", nid), "model_id": mid, "model_label": label, "host": host, "port": port, "endpoint": f"{host}:{port}", "status": st, "status_label": st_label, "in_flight": int(_runtime(nid).get("in_flight", 0)), "max_concurrent": int(node.get("max_concurrent", 1)), "today_requests": agg["requests"], "today_tokens": agg["total_tokens"], "last_error": str(_runtime(nid).get("last_error") or ""), "enabled": bool(node.get("enabled", True)), } ) return cards def list_nodes_with_status() -> List[Dict[str, Any]]: rows: List[Dict[str, Any]] = [] for node in _NODES.get("nodes", []): status, status_label = compute_node_status(node) nid = str(node.get("id", "")) rt = _runtime(nid) item = dict(node) item["status"] = status item["status_label"] = status_label item["in_flight"] = int(rt.get("in_flight", 0)) item["last_error"] = str(rt.get("last_error") or "") item["last_check"] = rt.get("last_check") rows.append(item) return rows def _new_node_id() -> str: return "node-" + secrets.token_hex(4) # --------------------------------------------------------------------------- # 访问统计(SQLite) # --------------------------------------------------------------------------- def get_client_ip(request: Request) -> str: """反代后取真实客户端 IP(优先 X-Forwarded-For / X-Real-IP)。""" forwarded = request.headers.get("x-forwarded-for") if forwarded: return forwarded.split(",")[0].strip() real_ip = request.headers.get("x-real-ip") if real_ip: return real_ip.strip() if request.client and request.client.host: return request.client.host return "unknown" def parse_model_from_body(body: bytes) -> Optional[str]: try: data = json.loads(body) model = data.get("model") return str(model).strip() if model else None except (json.JSONDecodeError, TypeError, ValueError): return None def parse_usage_from_response(raw: bytes, content_type: str) -> Dict[str, int]: out = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} if not raw: return out try: text = raw.decode("utf-8", errors="replace") except Exception: return out usage: Optional[Dict[str, Any]] = None ct = (content_type or "").lower() if "event-stream" in ct or text.lstrip().startswith("data:"): for line in text.splitlines(): line = line.strip() if not line.startswith("data:"): continue payload = line[5:].strip() if not payload or payload == "[DONE]": continue try: obj = json.loads(payload) except json.JSONDecodeError: continue u = obj.get("usage") if isinstance(u, dict) and ( u.get("total_tokens") or u.get("prompt_tokens") or u.get("completion_tokens") ): usage = u else: try: obj = json.loads(text) u = obj.get("usage") if isinstance(u, dict): usage = u except json.JSONDecodeError: pass if not usage: return out prompt = int(usage.get("prompt_tokens") or 0) completion = int(usage.get("completion_tokens") or 0) total = int(usage.get("total_tokens") or 0) if not total: total = prompt + completion out["prompt_tokens"] = prompt out["completion_tokens"] = completion out["total_tokens"] = total return out def init_stats_db() -> None: with _db_lock: conn = sqlite3.connect(STATS_DB) try: conn.execute( """ CREATE TABLE IF NOT EXISTS api_logs ( id INTEGER PRIMARY KEY AUTOINCREMENT, created_at TEXT NOT NULL, client_ip TEXT NOT NULL, model TEXT, node_id TEXT, status_code INTEGER NOT NULL, req_bytes INTEGER NOT NULL DEFAULT 0, resp_bytes INTEGER NOT NULL DEFAULT 0, prompt_tokens INTEGER NOT NULL DEFAULT 0, completion_tokens INTEGER NOT NULL DEFAULT 0, total_tokens INTEGER NOT NULL DEFAULT 0 ) """ ) conn.execute( "CREATE INDEX IF NOT EXISTS idx_api_logs_created_at ON api_logs(created_at)" ) conn.execute( "CREATE INDEX IF NOT EXISTS idx_api_logs_client_ip ON api_logs(client_ip)" ) try: conn.execute("ALTER TABLE api_logs ADD COLUMN node_id TEXT") except sqlite3.OperationalError: pass conn.execute( "CREATE INDEX IF NOT EXISTS idx_api_logs_node_id ON api_logs(node_id)" ) conn.commit() finally: conn.close() def _record_api_log_sync( client_ip: str, model: Optional[str], status_code: int, req_bytes: int, resp_bytes: int, usage: Dict[str, int], node_id: Optional[str] = None, ) -> None: created_at = datetime.now(timezone.utc).isoformat() with _db_lock: conn = sqlite3.connect(STATS_DB) try: conn.execute( """ INSERT INTO api_logs ( created_at, client_ip, model, node_id, status_code, req_bytes, resp_bytes, prompt_tokens, completion_tokens, total_tokens ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( created_at, client_ip, model, node_id, status_code, req_bytes, resp_bytes, usage["prompt_tokens"], usage["completion_tokens"], usage["total_tokens"], ), ) conn.commit() finally: conn.close() async def record_api_log( client_ip: str, model: Optional[str], status_code: int, req_bytes: int, resp_bytes: int, usage: Dict[str, int], node_id: Optional[str] = None, ) -> None: await run_in_thread( _record_api_log_sync, client_ip, model, status_code, req_bytes, resp_bytes, usage, node_id, ) def _today_utc_prefix() -> str: return datetime.now(timezone.utc).strftime("%Y-%m-%d") def _query_stats_summary() -> Dict[str, Any]: today = _today_utc_prefix() with _db_lock: conn = sqlite3.connect(STATS_DB) conn.row_factory = sqlite3.Row try: row = conn.execute( """ SELECT COUNT(*) AS total_requests, COALESCE(SUM(total_tokens), 0) AS total_tokens, COUNT(DISTINCT client_ip) AS unique_ips FROM api_logs """ ).fetchone() today_row = conn.execute( """ SELECT COUNT(*) AS today_requests, COALESCE(SUM(total_tokens), 0) AS today_tokens FROM api_logs WHERE created_at LIKE ? """, (today + "%",), ).fetchone() finally: conn.close() return { "total_requests": int(row["total_requests"]), "total_tokens": int(row["total_tokens"]), "unique_ips": int(row["unique_ips"]), "today_requests": int(today_row["today_requests"]), "today_tokens": int(today_row["today_tokens"]), } def _query_stats_ips(limit: int = 200) -> List[Dict[str, Any]]: with _db_lock: conn = sqlite3.connect(STATS_DB) conn.row_factory = sqlite3.Row try: rows = conn.execute( """ SELECT client_ip AS ip, COUNT(*) AS requests, COALESCE(SUM(total_tokens), 0) AS total_tokens, COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens, COALESCE(SUM(completion_tokens), 0) AS completion_tokens, MAX(created_at) AS last_seen FROM api_logs GROUP BY client_ip ORDER BY last_seen DESC LIMIT ? """, (limit,), ).fetchall() finally: conn.close() return [dict(r) for r in rows] def _query_stats_billing(days: int) -> Dict[str, Any]: since = (datetime.now(timezone.utc) - timedelta(days=days)).date().isoformat() with _db_lock: conn = sqlite3.connect(STATS_DB) conn.row_factory = sqlite3.Row try: by_day = conn.execute( """ SELECT substr(created_at, 1, 10) AS day, COUNT(*) AS requests, COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens, COALESCE(SUM(completion_tokens), 0) AS completion_tokens, COALESCE(SUM(total_tokens), 0) AS total_tokens FROM api_logs WHERE substr(created_at, 1, 10) >= ? GROUP BY day ORDER BY day DESC """, (since,), ).fetchall() by_ip = conn.execute( """ SELECT client_ip AS ip, COUNT(*) AS requests, COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens, COALESCE(SUM(completion_tokens), 0) AS completion_tokens, COALESCE(SUM(total_tokens), 0) AS total_tokens FROM api_logs WHERE substr(created_at, 1, 10) >= ? GROUP BY client_ip ORDER BY total_tokens DESC LIMIT 100 """, (since,), ).fetchall() finally: conn.close() return { "days": days, "by_day": [dict(r) for r in by_day], "by_ip": [dict(r) for r in by_ip], } def _query_stats_logs(limit: int, offset: int) -> Tuple[List[Dict[str, Any]], int]: limit = max(1, min(limit, 500)) offset = max(0, offset) with _db_lock: conn = sqlite3.connect(STATS_DB) conn.row_factory = sqlite3.Row try: total = conn.execute("SELECT COUNT(*) FROM api_logs").fetchone()[0] rows = conn.execute( """ SELECT id, created_at, client_ip, model, status_code, req_bytes, resp_bytes, prompt_tokens, completion_tokens, total_tokens FROM api_logs ORDER BY id DESC LIMIT ? OFFSET ? """, (limit, offset), ).fetchall() finally: conn.close() return [dict(r) for r in rows], int(total) # --------------------------------------------------------------------------- # Pydantic # --------------------------------------------------------------------------- class LoginBody(BaseModel): username: str = Field(..., min_length=1) password: str = Field(..., min_length=1) class ModelEntryIn(BaseModel): id: str = Field(..., min_length=1) label: str = "" class NodeBody(BaseModel): name: str = Field(..., min_length=1) host: str = "127.0.0.1" port: int = Field(..., ge=1, le=65535) enabled: bool = True max_concurrent: int = Field(1, ge=1, le=32) models: List[ModelEntryIn] = Field(default_factory=list) # --------------------------------------------------------------------------- # 依赖 # --------------------------------------------------------------------------- async def get_current_web_user( creds: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security_bearer)], ) -> GateSessionUser: if creds is None or creds.scheme.lower() != "bearer": raise HTTPException(status_code=401, detail="需要登录(Bearer Token)") token = creds.credentials if token.startswith("sk-"): raise HTTPException(status_code=401, detail="网页登录请使用登录接口返回的令牌,不是 API Key") uid = decode_web_token(token) if uid != GATE_WEB_UID: raise HTTPException(status_code=401, detail="登录状态无效") return GateSessionUser(username=_GATE["username"], api_key=_GATE["api_key"]) async def get_user_from_api_key( creds: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security_bearer)], ) -> GateSessionUser: if creds is None or creds.scheme.lower() != "bearer": raise HTTPException( status_code=401, detail='请在 Header 中携带 Authorization: Bearer sk-...', ) key = creds.credentials.strip() if not key.startswith("sk-"): raise HTTPException(status_code=401, detail="API Key 必须以 sk- 开头") if key != _GATE.get("api_key"): raise HTTPException(status_code=401, detail="无效的 API Key") return GateSessionUser(username=_GATE["username"], api_key=key) # --------------------------------------------------------------------------- # HTML(Tailwind CDN) # --------------------------------------------------------------------------- TW_SCRIPT = "https://cdn.tailwindcss.com" SHELL_HEAD = f""" {{title}}
AI 中转网关
""" SHELL_FOOT = """
""" def page(title: str, inner: str) -> str: return ( SHELL_HEAD.replace("{title}", title) + inner + SHELL_FOOT ) HOME_HTML = page( "中转网关", f"""

模型分布

各节点模型状态(约每 5 秒刷新;主机默认 127.0.0.1,经 frp 映射端口)

管理节点与模型 →
加载中…