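commit a98b24fa32ccab53e44dd45d12927193929ed120
Author: dekun
Date:   Tue May 19 02:05:00 2026 +0800

    first commit

diff --git a/DEPLOY.md b/DEPLOY.md
new file mode 100644
index 0000000..751eed7
--- /dev/null
+++ b/DEPLOY.md
@@ -0,0 +1,181 @@
+# LLM Relay Gateway · PM2 Deployment Guide
+
+This guide explains how to run the **FastAPI (uvicorn)** process under **PM2** on a server, so that it is restarted automatically after a crash, starts at boot, and has managed logs.
+
+---
+
+## 1. Prerequisites
+
+| Component | Notes |
+|------|------|
+| **Python** | 3.10+ recommended, available on `PATH` |
+| **Node.js** | Needed to install PM2 (an LTS release is fine) |
+| **PM2** | `npm i -g pm2` |
+| **This project** | Contains `main.py` and `requirements.txt`, optionally `ecosystem.config.cjs` |
+
+The upstream local LLM (default `http://127.0.0.1:10434`) must run on the same machine or be reachable over the internal network. If the gateway is deployed on the same host as the LLM, make sure the gateway port is free (the manual commands below use **8000**; the bundled `ecosystem.config.cjs` uses **8150**).
+
+---
+
+## 2. Prepare the runtime environment
+
+Run the following in the **project root** (the directory that contains `main.py`):
+
+### Linux / macOS
+
+```bash
+cd /path/to/中转网关
+python3 -m venv venv
+source venv/bin/activate
+pip install -r requirements.txt -U
+```
+
+### Windows (PowerShell)
+
+```powershell
+cd C:\path\to\中转网关
+python -m venv venv
+.\venv\Scripts\Activate.ps1
+pip install -r requirements.txt -U
+```
+
+Create **`gateway.json`** in the project directory (you can copy `gateway.json.example`) and fill in **`username`** and **`password`**. **`api_key`** may be left empty; an `sk-...` key is written automatically on first start. Alternatively, set the **`GATEWAY_CONFIG`** environment variable to the absolute path of the config file.
+
+Confirm that the service starts manually:
+
+```bash
+python -m uvicorn main:app --host 0.0.0.0 --port 8000
+```
+
+Once `http://<server-ip>:8000` loads in a browser, hand the process over to PM2.
+
+---
+
+## 3. Environment variables (must be changed in production)
+
+| Variable | Meaning | Recommendation |
+|------|------|------|
+| `JWT_SECRET` | Key used to sign the web-login JWT | A long random string; keep it secret |
+| `UPSTREAM_URL` | Root URL of the upstream OpenAI-compatible service (no trailing `/`) | Defaults to `http://127.0.0.1:10434`; change as needed |
+
+These can be exported in the shell or set through PM2's `env` / `ecosystem.config.cjs` (see below).
+
+---
+
+## 4. Start with PM2 (ecosystem file recommended)
+
+The project root ships an **`ecosystem.config.cjs`** that picks the Python interpreter inside the virtual environment according to the platform (**Windows / non-Windows**) and starts:
+
+`python -m uvicorn main:app --host 0.0.0.0 --port 8150`
+
+### 4.1 Edit the configuration
+
+1. Open `ecosystem.config.cjs` in an editor.
+2. Set **`JWT_SECRET`** in `env` (required) and adjust **`UPSTREAM_URL`** if needed.
+3. 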
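If the port conflicts, change `8150` in `args` to another port.
+
+### 4.2 Start and common commands
+
+Run in the project root:
+
+```bash
+pm2 start ecosystem.config.cjs
+pm2 save
+```
+
+Common operations:
+
+| Command | Purpose |
+|------|------|
+| `pm2 status` | Show process status |
+| `pm2 logs llm-gateway` | Tail the logs |
+| `pm2 restart llm-gateway` | Restart |
+| `pm2 stop llm-gateway` | Stop |
+| `pm2 delete llm-gateway` | Remove from the PM2 list |
+
+After changing code or dependencies, finish `pip install …` first and then run `pm2 restart llm-gateway`.
+
+---
+
+## 5. Equivalent commands without the config file
+
+If you are not using `ecosystem.config.cjs` for now, you have to spell out the **absolute path of the Python executable** inside the virtual environment yourself.
+
+**Linux example:**
+
+```bash
+pm2 start /path/to/中转网关/venv/bin/python \
+  --name llm-gateway \
+  --cwd /path/to/中转网关 \
+  --interpreter none \
+  -- -m uvicorn main:app --host 0.0.0.0 --port 8000
+```
+
+**Windows example (PowerShell; adjust the paths to your setup):**
+
+```powershell
+pm2 start "C:\path\to\中转网关\venv\Scripts\python.exe" `
+  --name llm-gateway `
+  --cwd "C:\path\to\中转网关" `
+  --interpreter none `
+  -- -m uvicorn main:app --host 0.0.0.0 --port 8000
+```
+
+Then set the environment variables through PM2 (pick one):
+
+```bash
+pm2 restart llm-gateway --update-env
+```
+
+or export `JWT_SECRET` and `UPSTREAM_URL` in the terminal before running `pm2 start …` (note: check the documentation of your PM2 version for how `--update-env` and persistence behave).
+
+---
+
+## 6. Start at boot
+
+After PM2 has saved the current process list, generate and enable the startup script:
+
+```bash
+pm2 save
+pm2 startup
+```
+
+The terminal prints a command that starts with `sudo`; **copy and run it** to finish the systemd (Linux) or platform-specific startup setup. On Windows, use **pm2-windows-service** or Task Scheduler together with `pm2 resurrect`, whichever suits the environment.
+
+---
+
+## 7. Logs and configuration files
+
+| Item | Notes |
+|------|------|
+| PM2 logs | `pm2 logs llm-gateway`; files are under `~/.pm2/logs/` by default |
+| Account and API key | **`gateway.json`** (or the file `GATEWAY_CONFIG` points to) holds the password in plain text; restrict its permissions (e.g. `chmod 600`) and do not commit it to version control |
+
+---
+
+## 8. Reverse proxy (optional)
+
+When only 443/80 are exposed, use Nginx/Caddy to reverse-proxy the domain to the gateway's local port (`http://127.0.0.1:8000`, or `8150` when started from `ecosystem.config.cjs`) and configure TLS. **`UPSTREAM_URL`** still points to the local LLM address and is independent of the gateway's public port.
+
+---
+
+## 9. Troubleshooting
+
+1. **`ModuleNotFoundError` in `pm2 logs`**
+   Make sure the interpreter PM2 uses is the **python inside `venv`**, and that `pip install -r requirements.txt` was run in that same venv.
+
+2. **502 / upstream connection failed**
+   Check that the LLM is listening at `UPSTREAM_URL` and that the firewall and any tunnelling are working.
+
+3. **Old logins are invalid after changing `JWT_SECRET`**
+   This is expected; users need to log in to the web UI again.
+
+---
+
+## 10. Summary
+
+1. 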
Create the venv and install `requirements.txt`.
+2. Set **`JWT_SECRET`** (and optionally **`UPSTREAM_URL`**).
+3. `pm2 start ecosystem.config.cjs` → `pm2 save` → `pm2 startup` (as needed).
+
+The process name **`llm-gateway`** matches the `name` field in `ecosystem.config.cjs`, so commands such as `pm2 restart llm-gateway` work as written.
diff --git a/ecosystem.config.cjs b/ecosystem.config.cjs
new file mode 100644
index 0000000..a69678e
--- /dev/null
+++ b/ecosystem.config.cjs
@@ -0,0 +1,35 @@
+/**
+ * PM2 process config: run `pm2 start ecosystem.config.cjs` in the project root.
+ * Create the venv and run `pip install -r requirements.txt` first.
+ */
+const path = require("path");
+
+const isWin = process.platform === "win32";
+const py = isWin
+  ? path.join(__dirname, "venv", "Scripts", "python.exe")
+  : path.join(__dirname, "venv", "bin", "python");
+
+module.exports = {
+  apps: [
+    {
+      name: "llm-gateway",
+      cwd: __dirname,
+      interpreter: "none",
+      script: py,
+      args: "-m uvicorn main:app --host 0.0.0.0 --port 8150",
+      instances: 1,
+      autorestart: true,
+      watch: false,
+      max_memory_restart: "500M",
+      env: {
+        NODE_ENV: "production",
+        // Keep stdout/stderr unbuffered when there is no TTY; otherwise `pm2 logs` may show nothing for a long time
+        PYTHONUNBUFFERED: "1",
+        // In production, replace this with a sufficiently long random string
+        JWT_SECRET: "change-me-to-a-long-random-secret",
+        GATEWAY_PORT: "8150",
+        UPSTREAM_URL: "http://127.0.0.1:10434",
+      },
+    },
+  ],
+};
diff --git a/gateway.json.example b/gateway.json.example
new file mode 100644
index 0000000..bb4159a
--- /dev/null
+++ b/gateway.json.example
@@ -0,0 +1,5 @@
+{
+  "username": "admin",
+  "password": "change this to a strong password",
+  "api_key": ""
+}
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..a998ed1
--- /dev/null
+++ b/main.py
@@ -0,0 +1,1756 @@
+"""
+LLM relay gateway: FastAPI + embedded frontend (single file).
+The account and API key are read from a JSON config file; there is no registration.
+
+Run: pip install -r requirements.txt && uvicorn main:app --host 0.0.0.0 --port 8000
+
+Config: copy gateway.json.example to gateway.json and fill in username and password.
+      api_key may be left empty; the first start generates sk-... 
and writes it back to gateway.json.
+
+Environment variables (optional):
+  JWT_SECRET      JWT signing key (must be changed in production)
+  UPSTREAM_URL    Upstream OpenAI-compatible URL, default http://127.0.0.1:10434
+  APP_ROOT        Sub-path prefix behind a reverse proxy, e.g. /wg
+  GATEWAY_CONFIG  Config file path, default gateway.json in the current directory
+  NODES_CONFIG    Node/model config, default nodes.json
+  STATS_DB        SQLite path for access statistics, default gateway_stats.db
+  GATEWAY_PORT    Listening port, default 8150
+"""
+from __future__ import annotations
+
+import asyncio
+import json
+import os
+import secrets
+import sqlite3
+import threading
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from datetime import datetime, timedelta, timezone
+from typing import Any, AsyncIterator, Dict, List, Optional, Tuple
+
+try:
+    from typing import Annotated  # Python 3.9+
+except ImportError:  # Python 3.8
+    from typing_extensions import Annotated
+
+import httpx
+from fastapi import Depends, FastAPI, HTTPException, Query, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from jose import JWTError, jwt
+from passlib.context import CryptContext
+from pydantic import BaseModel, Field
+
+# ---------------------------------------------------------------------------
+# Configuration
+# ---------------------------------------------------------------------------
+
+CONFIG_PATH = os.environ.get("GATEWAY_CONFIG", "gateway.json")
+NODES_PATH = os.environ.get("NODES_CONFIG", "nodes.json")
+STATS_DB = os.environ.get("STATS_DB", "gateway_stats.db")
+GATEWAY_PORT = int(os.environ.get("GATEWAY_PORT", "8150"))
+JWT_SECRET = os.environ.get("JWT_SECRET", "dev-secret-change-me")
+JWT_ALG = "HS256"
+JWT_EXPIRE_DAYS = 7
+GATE_WEB_UID = 1
+UPSTREAM_BASE = os.environ.get("UPSTREAM_URL", "http://127.0.0.1:10434").rstrip("/")
+APP_ROOT = os.environ.get("APP_ROOT", "").rstrip("/")
+
+
+def app_url(path: str) -> str:
+    if not path.startswith("/"):
+        path = "/" + path
+    return APP_ROOT + path if APP_ROOT else path
+
+
+_APP_ROOT_JSON = 
json.dumps(APP_ROOT)
+
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+security_bearer = HTTPBearer(auto_error=False)
+
+# Populated by load_gateway_config() at startup
+_GATE: Dict[str, Any] = {}
+_PASSWORD_HASH: str = ""
+_db_lock = threading.Lock()
+_nodes_lock = threading.RLock()  # re-entrant: callers that hold it also call _runtime(), which acquires it again
+_NODES: Dict[str, Any] = {"nodes": []}
+_node_runtime: Dict[str, Dict[str, Any]] = {}
+_health_task: Optional[asyncio.Task] = None
+
+STATUS_LABELS: Dict[str, str] = {
+    "offline": "Offline",
+    "idle": "Idle",
+    "busy": "Busy",
+    "disabled": "Disabled",
+    "error": "Error",
+}
+
+
+def hash_password(p: str) -> str:
+    return pwd_context.hash(p)
+
+
+def verify_password(plain: str, hashed: str) -> bool:
+    return pwd_context.verify(plain, hashed)
+
+
+def generate_api_key() -> str:
+    return "sk-" + secrets.token_urlsafe(32)
+
+
+def load_gateway_config() -> None:
+    """Read gateway.json; if api_key is missing, generate one and write it back to the file."""
+    global _GATE, _PASSWORD_HASH
+    if not os.path.isfile(CONFIG_PATH):
+        raise RuntimeError(
+            f"Missing config file: {CONFIG_PATH} (copy gateway.json.example to gateway.json)"
+        )
+    with open(CONFIG_PATH, encoding="utf-8") as f:
+        raw = json.load(f)
+    username = str(raw.get("username", "")).strip()
+    password = raw.get("password", "")
+    if not isinstance(password, str):
+        password = str(password)
+    if not username or not password:
+        raise ValueError(f"{CONFIG_PATH} must contain a non-empty username and password")
+
+    api_key = str(raw.get("api_key", "")).strip()
+    changed = False
+    if not api_key:
+        api_key = generate_api_key()
+        raw["api_key"] = api_key
+        changed = True
+    if changed:
+        with open(CONFIG_PATH, "w", encoding="utf-8") as wf:
+            json.dump(raw, wf, indent=2, ensure_ascii=False)
+            wf.write("\n")
+
+    _GATE = {"username": username, "api_key": api_key}
+    _PASSWORD_HASH = hash_password(password)
+
+
+def create_web_token(user_id: int) -> str:
+    expire = datetime.now(timezone.utc) + timedelta(days=JWT_EXPIRE_DAYS)
+    payload = {"sub": str(user_id), "exp": expire, "typ": "web"}
+    return jwt.encode(payload, JWT_SECRET, 
algorithm=JWT_ALG) + + +def decode_web_token(token: str) -> int: + try: + payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALG]) + if payload.get("typ") != "web": + raise JWTError("wrong token type") + sub = payload.get("sub") + if not sub: + raise JWTError("missing sub") + return int(sub) + except JWTError as e: + raise HTTPException(status_code=401, detail="无效或过期的登录状态") from e + + +@dataclass +class GateSessionUser: + username: str + api_key: str + + +# --------------------------------------------------------------------------- +# 节点 / 模型配置(nodes.json) +# --------------------------------------------------------------------------- + + +def _default_nodes_payload() -> Dict[str, Any]: + return { + "nodes": [ + { + "id": "node-1", + "name": "节点 1", + "host": "127.0.0.1", + "port": 3313, + "enabled": True, + "max_concurrent": 1, + "models": [], + }, + { + "id": "node-2", + "name": "节点 2", + "host": "127.0.0.1", + "port": 3314, + "enabled": True, + "max_concurrent": 1, + "models": [], + }, + { + "id": "node-3", + "name": "节点 3", + "host": "127.0.0.1", + "port": 3315, + "enabled": True, + "max_concurrent": 1, + "models": [], + }, + ] + } + + +def ensure_nodes_file() -> None: + if not os.path.isfile(NODES_PATH): + example = NODES_PATH + ".example" + if os.path.isfile(example): + with open(example, encoding="utf-8") as f: + data = json.load(f) + else: + data = _default_nodes_payload() + save_nodes_config(data) + + +def load_nodes_config() -> None: + global _NODES, _node_runtime + ensure_nodes_file() + with open(NODES_PATH, encoding="utf-8") as f: + data = json.load(f) + nodes = data.get("nodes") + if not isinstance(nodes, list): + nodes = [] + _NODES = {"nodes": nodes} + with _nodes_lock: + for n in nodes: + nid = str(n.get("id", "")) + if nid and nid not in _node_runtime: + _node_runtime[nid] = { + "in_flight": 0, + "healthy": None, + "last_check": None, + "last_error": "", + } + + +def save_nodes_config(data: Optional[Dict[str, Any]] = None) -> None: + global 
_NODES + payload = data if data is not None else _NODES + with open(NODES_PATH, "w", encoding="utf-8") as f: + json.dump(payload, f, indent=2, ensure_ascii=False) + f.write("\n") + load_nodes_config() + + +def get_node_base(node: Dict[str, Any]) -> str: + host = str(node.get("host", "127.0.0.1")).strip() or "127.0.0.1" + port = int(node.get("port", 0)) + return f"http://{host}:{port}".rstrip("/") + + +def _runtime(nid: str) -> Dict[str, Any]: + with _nodes_lock: + if nid not in _node_runtime: + _node_runtime[nid] = { + "in_flight": 0, + "healthy": None, + "last_check": None, + "last_error": "", + } + return _node_runtime[nid] + + +def node_in_flight_inc(nid: str) -> None: + with _nodes_lock: + rt = _runtime(nid) + rt["in_flight"] = int(rt.get("in_flight", 0)) + 1 + + +def node_in_flight_dec(nid: str) -> None: + with _nodes_lock: + rt = _runtime(nid) + rt["in_flight"] = max(0, int(rt.get("in_flight", 0)) - 1) + + +def compute_node_status(node: Dict[str, Any]) -> Tuple[str, str]: + nid = str(node.get("id", "")) + if not node.get("enabled", True): + return "disabled", STATUS_LABELS["disabled"] + rt = _runtime(nid) + if rt.get("healthy") is False: + err = str(rt.get("last_error") or "") + if err: + return "error", STATUS_LABELS["error"] + return "offline", STATUS_LABELS["offline"] + inflight = int(rt.get("in_flight", 0)) + max_c = int(node.get("max_concurrent", 1) or 1) + if inflight >= max_c: + return "busy", STATUS_LABELS["busy"] + if rt.get("healthy") is True: + return "idle", STATUS_LABELS["idle"] + return "offline", STATUS_LABELS["offline"] + + +def find_nodes_for_model(model_id: str) -> List[Dict[str, Any]]: + if not model_id: + return [] + out: List[Dict[str, Any]] = [] + for node in _NODES.get("nodes", []): + if not node.get("enabled", True): + continue + for m in node.get("models") or []: + if str(m.get("id", "")) == model_id: + out.append(node) + break + return out + + +def select_node_for_model(model_id: str) -> Optional[Dict[str, Any]]: + candidates = 
find_nodes_for_model(model_id) + if not candidates: + return None + scored: List[Tuple[int, str, Dict[str, Any]]] = [] + for node in candidates: + status, _ = compute_node_status(node) + if status in ("disabled", "error"): + continue + if status == "busy": + continue + nid = str(node["id"]) + scored.append((int(_runtime(nid).get("in_flight", 0)), nid, node)) + if not scored: + for node in candidates: + if compute_node_status(node)[0] != "disabled": + nid = str(node["id"]) + scored.append((int(_runtime(nid).get("in_flight", 0)), nid, node)) + if not scored: + return None + scored.sort(key=lambda x: (x[0], x[1])) + return scored[0][2] + + +async def probe_node( + client: httpx.AsyncClient, node: Dict[str, Any] +) -> Tuple[bool, str]: + url = get_node_base(node) + "/v1/models" + try: + resp = await client.get(url, timeout=5.0) + if resp.status_code < 500: + return True, "" + return False, f"HTTP {resp.status_code}" + except httpx.RequestError as e: + return False, str(e) + + +async def refresh_all_node_health(client: httpx.AsyncClient) -> None: + for node in _NODES.get("nodes", []): + nid = str(node.get("id", "")) + if not nid: + continue + ok, err = await probe_node(client, node) + with _nodes_lock: + rt = _runtime(nid) + rt["healthy"] = ok + rt["last_error"] = err + rt["last_check"] = datetime.now(timezone.utc).isoformat() + + +async def _health_loop(client: httpx.AsyncClient) -> None: + while True: + try: + await refresh_all_node_health(client) + except Exception: + pass + await asyncio.sleep(30) + + +def _query_model_stats_today() -> Dict[str, Dict[str, int]]: + today = _today_utc_prefix() + out: Dict[str, Dict[str, int]] = {} + with _db_lock: + conn = sqlite3.connect(STATS_DB) + conn.row_factory = sqlite3.Row + try: + rows = conn.execute( + """ + SELECT model, node_id, + COUNT(*) AS requests, + COALESCE(SUM(total_tokens), 0) AS total_tokens + FROM api_logs + WHERE created_at LIKE ? 
AND model IS NOT NULL + GROUP BY model, node_id + """, + (today + "%",), + ).fetchall() + finally: + conn.close() + for row in rows: + key = f"{row['model']}\0{row['node_id'] or ''}" + out[key] = { + "requests": int(row["requests"]), + "total_tokens": int(row["total_tokens"]), + } + return out + + +def build_model_cards() -> List[Dict[str, Any]]: + stats = _query_model_stats_today() + cards: List[Dict[str, Any]] = [] + for node in _NODES.get("nodes", []): + nid = str(node.get("id", "")) + status, status_label = compute_node_status(node) + host = str(node.get("host", "127.0.0.1")) + port = int(node.get("port", 0)) + models = node.get("models") or [] + if not models: + cards.append( + { + "node_id": nid, + "node_name": node.get("name", nid), + "model_id": "", + "model_label": "(未配置模型)", + "host": host, + "port": port, + "endpoint": f"{host}:{port}", + "status": status if node.get("enabled", True) else "disabled", + "status_label": status_label + if node.get("enabled", True) + else STATUS_LABELS["disabled"], + "in_flight": int(_runtime(nid).get("in_flight", 0)), + "max_concurrent": int(node.get("max_concurrent", 1)), + "today_requests": 0, + "today_tokens": 0, + "last_error": str(_runtime(nid).get("last_error") or ""), + "enabled": bool(node.get("enabled", True)), + } + ) + continue + for m in models: + mid = str(m.get("id", "")) + label = str(m.get("label") or mid) + st = status + st_label = status_label + if not node.get("enabled", True): + st, st_label = "disabled", STATUS_LABELS["disabled"] + sk = f"{mid}\0{nid}" + agg = stats.get(sk, {"requests": 0, "total_tokens": 0}) + cards.append( + { + "node_id": nid, + "node_name": node.get("name", nid), + "model_id": mid, + "model_label": label, + "host": host, + "port": port, + "endpoint": f"{host}:{port}", + "status": st, + "status_label": st_label, + "in_flight": int(_runtime(nid).get("in_flight", 0)), + "max_concurrent": int(node.get("max_concurrent", 1)), + "today_requests": agg["requests"], + "today_tokens": 
agg["total_tokens"], + "last_error": str(_runtime(nid).get("last_error") or ""), + "enabled": bool(node.get("enabled", True)), + } + ) + return cards + + +def list_nodes_with_status() -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + for node in _NODES.get("nodes", []): + status, status_label = compute_node_status(node) + nid = str(node.get("id", "")) + rt = _runtime(nid) + item = dict(node) + item["status"] = status + item["status_label"] = status_label + item["in_flight"] = int(rt.get("in_flight", 0)) + item["last_error"] = str(rt.get("last_error") or "") + item["last_check"] = rt.get("last_check") + rows.append(item) + return rows + + +def _new_node_id() -> str: + return "node-" + secrets.token_hex(4) + + +# --------------------------------------------------------------------------- +# 访问统计(SQLite) +# --------------------------------------------------------------------------- + + +def get_client_ip(request: Request) -> str: + """反代后取真实客户端 IP(优先 X-Forwarded-For / X-Real-IP)。""" + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() + real_ip = request.headers.get("x-real-ip") + if real_ip: + return real_ip.strip() + if request.client and request.client.host: + return request.client.host + return "unknown" + + +def parse_model_from_body(body: bytes) -> Optional[str]: + try: + data = json.loads(body) + model = data.get("model") + return str(model).strip() if model else None + except (json.JSONDecodeError, TypeError, ValueError): + return None + + +def parse_usage_from_response(raw: bytes, content_type: str) -> Dict[str, int]: + out = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + if not raw: + return out + try: + text = raw.decode("utf-8", errors="replace") + except Exception: + return out + + usage: Optional[Dict[str, Any]] = None + ct = (content_type or "").lower() + if "event-stream" in ct or text.lstrip().startswith("data:"): + for line in text.splitlines(): + line = 
line.strip() + if not line.startswith("data:"): + continue + payload = line[5:].strip() + if not payload or payload == "[DONE]": + continue + try: + obj = json.loads(payload) + except json.JSONDecodeError: + continue + u = obj.get("usage") + if isinstance(u, dict) and ( + u.get("total_tokens") or u.get("prompt_tokens") or u.get("completion_tokens") + ): + usage = u + else: + try: + obj = json.loads(text) + u = obj.get("usage") + if isinstance(u, dict): + usage = u + except json.JSONDecodeError: + pass + + if not usage: + return out + prompt = int(usage.get("prompt_tokens") or 0) + completion = int(usage.get("completion_tokens") or 0) + total = int(usage.get("total_tokens") or 0) + if not total: + total = prompt + completion + out["prompt_tokens"] = prompt + out["completion_tokens"] = completion + out["total_tokens"] = total + return out + + +def init_stats_db() -> None: + with _db_lock: + conn = sqlite3.connect(STATS_DB) + try: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS api_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_at TEXT NOT NULL, + client_ip TEXT NOT NULL, + model TEXT, + node_id TEXT, + status_code INTEGER NOT NULL, + req_bytes INTEGER NOT NULL DEFAULT 0, + resp_bytes INTEGER NOT NULL DEFAULT 0, + prompt_tokens INTEGER NOT NULL DEFAULT 0, + completion_tokens INTEGER NOT NULL DEFAULT 0, + total_tokens INTEGER NOT NULL DEFAULT 0 + ) + """ + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_api_logs_created_at ON api_logs(created_at)" + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_api_logs_client_ip ON api_logs(client_ip)" + ) + try: + conn.execute("ALTER TABLE api_logs ADD COLUMN node_id TEXT") + except sqlite3.OperationalError: + pass + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_api_logs_node_id ON api_logs(node_id)" + ) + conn.commit() + finally: + conn.close() + + +def _record_api_log_sync( + client_ip: str, + model: Optional[str], + status_code: int, + req_bytes: int, + resp_bytes: int, + usage: Dict[str, int], + node_id: 
Optional[str] = None, +) -> None: + created_at = datetime.now(timezone.utc).isoformat() + with _db_lock: + conn = sqlite3.connect(STATS_DB) + try: + conn.execute( + """ + INSERT INTO api_logs ( + created_at, client_ip, model, node_id, status_code, + req_bytes, resp_bytes, + prompt_tokens, completion_tokens, total_tokens + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + created_at, + client_ip, + model, + node_id, + status_code, + req_bytes, + resp_bytes, + usage["prompt_tokens"], + usage["completion_tokens"], + usage["total_tokens"], + ), + ) + conn.commit() + finally: + conn.close() + + +async def record_api_log( + client_ip: str, + model: Optional[str], + status_code: int, + req_bytes: int, + resp_bytes: int, + usage: Dict[str, int], + node_id: Optional[str] = None, +) -> None: + await asyncio.to_thread( + _record_api_log_sync, + client_ip, + model, + status_code, + req_bytes, + resp_bytes, + usage, + node_id, + ) + + +def _today_utc_prefix() -> str: + return datetime.now(timezone.utc).strftime("%Y-%m-%d") + + +def _query_stats_summary() -> Dict[str, Any]: + today = _today_utc_prefix() + with _db_lock: + conn = sqlite3.connect(STATS_DB) + conn.row_factory = sqlite3.Row + try: + row = conn.execute( + """ + SELECT + COUNT(*) AS total_requests, + COALESCE(SUM(total_tokens), 0) AS total_tokens, + COUNT(DISTINCT client_ip) AS unique_ips + FROM api_logs + """ + ).fetchone() + today_row = conn.execute( + """ + SELECT + COUNT(*) AS today_requests, + COALESCE(SUM(total_tokens), 0) AS today_tokens + FROM api_logs + WHERE created_at LIKE ? 
+ """, + (today + "%",), + ).fetchone() + finally: + conn.close() + return { + "total_requests": int(row["total_requests"]), + "total_tokens": int(row["total_tokens"]), + "unique_ips": int(row["unique_ips"]), + "today_requests": int(today_row["today_requests"]), + "today_tokens": int(today_row["today_tokens"]), + } + + +def _query_stats_ips(limit: int = 200) -> List[Dict[str, Any]]: + with _db_lock: + conn = sqlite3.connect(STATS_DB) + conn.row_factory = sqlite3.Row + try: + rows = conn.execute( + """ + SELECT + client_ip AS ip, + COUNT(*) AS requests, + COALESCE(SUM(total_tokens), 0) AS total_tokens, + COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens, + COALESCE(SUM(completion_tokens), 0) AS completion_tokens, + MAX(created_at) AS last_seen + FROM api_logs + GROUP BY client_ip + ORDER BY last_seen DESC + LIMIT ? + """, + (limit,), + ).fetchall() + finally: + conn.close() + return [dict(r) for r in rows] + + +def _query_stats_billing(days: int) -> Dict[str, Any]: + since = (datetime.now(timezone.utc) - timedelta(days=days)).date().isoformat() + with _db_lock: + conn = sqlite3.connect(STATS_DB) + conn.row_factory = sqlite3.Row + try: + by_day = conn.execute( + """ + SELECT + substr(created_at, 1, 10) AS day, + COUNT(*) AS requests, + COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens, + COALESCE(SUM(completion_tokens), 0) AS completion_tokens, + COALESCE(SUM(total_tokens), 0) AS total_tokens + FROM api_logs + WHERE substr(created_at, 1, 10) >= ? + GROUP BY day + ORDER BY day DESC + """, + (since,), + ).fetchall() + by_ip = conn.execute( + """ + SELECT + client_ip AS ip, + COUNT(*) AS requests, + COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens, + COALESCE(SUM(completion_tokens), 0) AS completion_tokens, + COALESCE(SUM(total_tokens), 0) AS total_tokens + FROM api_logs + WHERE substr(created_at, 1, 10) >= ? 
+ GROUP BY client_ip + ORDER BY total_tokens DESC + LIMIT 100 + """, + (since,), + ).fetchall() + finally: + conn.close() + return { + "days": days, + "by_day": [dict(r) for r in by_day], + "by_ip": [dict(r) for r in by_ip], + } + + +def _query_stats_logs(limit: int, offset: int) -> Tuple[List[Dict[str, Any]], int]: + limit = max(1, min(limit, 500)) + offset = max(0, offset) + with _db_lock: + conn = sqlite3.connect(STATS_DB) + conn.row_factory = sqlite3.Row + try: + total = conn.execute("SELECT COUNT(*) FROM api_logs").fetchone()[0] + rows = conn.execute( + """ + SELECT + id, created_at, client_ip, model, status_code, + req_bytes, resp_bytes, + prompt_tokens, completion_tokens, total_tokens + FROM api_logs + ORDER BY id DESC + LIMIT ? OFFSET ? + """, + (limit, offset), + ).fetchall() + finally: + conn.close() + return [dict(r) for r in rows], int(total) + + +# --------------------------------------------------------------------------- +# Pydantic +# --------------------------------------------------------------------------- + + +class LoginBody(BaseModel): + username: str = Field(..., min_length=1) + password: str = Field(..., min_length=1) + + +class ModelEntryIn(BaseModel): + id: str = Field(..., min_length=1) + label: str = "" + + +class NodeBody(BaseModel): + name: str = Field(..., min_length=1) + host: str = "127.0.0.1" + port: int = Field(..., ge=1, le=65535) + enabled: bool = True + max_concurrent: int = Field(1, ge=1, le=32) + models: List[ModelEntryIn] = Field(default_factory=list) + + +# --------------------------------------------------------------------------- +# 依赖 +# --------------------------------------------------------------------------- + + +async def get_current_web_user( + creds: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security_bearer)], +) -> GateSessionUser: + if creds is None or creds.scheme.lower() != "bearer": + raise HTTPException(status_code=401, detail="需要登录(Bearer Token)") + token = creds.credentials + if 
token.startswith("sk-"): + raise HTTPException(status_code=401, detail="网页登录请使用登录接口返回的令牌,不是 API Key") + + uid = decode_web_token(token) + if uid != GATE_WEB_UID: + raise HTTPException(status_code=401, detail="登录状态无效") + return GateSessionUser(username=_GATE["username"], api_key=_GATE["api_key"]) + + +async def get_user_from_api_key( + creds: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security_bearer)], +) -> GateSessionUser: + if creds is None or creds.scheme.lower() != "bearer": + raise HTTPException( + status_code=401, + detail='请在 Header 中携带 Authorization: Bearer sk-...', + ) + key = creds.credentials.strip() + if not key.startswith("sk-"): + raise HTTPException(status_code=401, detail="API Key 必须以 sk- 开头") + if key != _GATE.get("api_key"): + raise HTTPException(status_code=401, detail="无效的 API Key") + return GateSessionUser(username=_GATE["username"], api_key=key) + + +# --------------------------------------------------------------------------- +# HTML(Tailwind CDN) +# --------------------------------------------------------------------------- + +TW_SCRIPT = "https://cdn.tailwindcss.com" + +SHELL_HEAD = f""" + + + + + + {{title}} + + + + + + + + +
+
+
+
+
+ +
+
+""" + + +SHELL_FOOT = """ +
+ + + +""" + + +def page(title: str, inner: str) -> str: + return ( + SHELL_HEAD.replace("{title}", title) + + inner + + SHELL_FOOT + ) + + +HOME_HTML = page( + "中转网关", + f""" + +
+ +
+

Model distribution

+

Model status per node (refreshed about every 5 seconds; host defaults to 127.0.0.1, ports mapped through frp)

+
+ Manage nodes and models → +
+
Loading…
+