Files
2026-05-19 02:42:59 +08:00

1866 lines
74 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
LLM 中转网关:FastAPI + 嵌入式前端(单文件)。
账号与 API Key 从 JSON 配置文件读取,不提供注册。
运行:pip install -r requirements.txt && uvicorn main:app --host 0.0.0.0 --port 8000
配置:复制 gateway.json.example 为 gateway.json,填写 username、password。
api_key 可留空,首次启动会自动生成 sk-... 并写回 gateway.json。
环境变量(可选):
JWT_SECRET JWT 签名密钥(生产环境务必修改)
UPSTREAM_URL 上游 OpenAI 兼容地址,默认 http://127.0.0.1:10434
APP_ROOT 反代子路径前缀,例如 /wg
GATEWAY_CONFIG 配置文件路径,默认当前目录下 gateway.json
NODES_CONFIG 节点/模型配置,默认 nodes.json
STATS_DB 访问统计 SQLite 路径,默认 gateway_stats.db
GATEWAY_PORT 监听端口,默认 8150
"""
from __future__ import annotations
import asyncio
import json
import os
import secrets
import sqlite3
import threading
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Tuple, TypeVar
try:
from typing import Annotated # Python 3.9+
except ImportError: # Python 3.8
from typing_extensions import Annotated
import httpx
_T = TypeVar("_T")


async def run_in_thread(func: Callable[..., _T], /, *args: Any) -> _T:
    """Run a blocking callable in a worker thread and return its result.

    Uses ``asyncio.to_thread`` when the running interpreter provides it and
    falls back to the event loop's default executor otherwise (Python 3.8).

    Args:
        func: Blocking callable to execute off the event loop.
        *args: Positional arguments forwarded to ``func``.

    Returns:
        Whatever ``func(*args)`` returns.
    """
    to_thread = getattr(asyncio, "to_thread", None)
    if to_thread is not None:
        # Must delegate to asyncio.to_thread: awaiting run_in_thread here
        # would recurse without ever executing ``func``.
        return await to_thread(func, *args)
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, func, *args)
from fastapi import Depends, FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
import bcrypt
from jose import JWTError, jwt
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# File locations and the port; each one can be overridden through the
# environment variable named in the lookup.
CONFIG_PATH = os.environ.get("GATEWAY_CONFIG", "gateway.json")
NODES_PATH = os.environ.get("NODES_CONFIG", "nodes.json")
STATS_DB = os.environ.get("STATS_DB", "gateway_stats.db")
GATEWAY_PORT = int(os.environ.get("GATEWAY_PORT", "8150"))
# Key used to sign and verify web-login JWTs. The fallback value is a
# development placeholder; set JWT_SECRET in any real deployment.
JWT_SECRET = os.environ.get("JWT_SECRET", "dev-secret-change-me")
JWT_ALG = "HS256"
# Lifetime, in days, of the tokens issued by create_web_token().
JWT_EXPIRE_DAYS = 7
# The only user id that get_current_web_user() accepts.
GATE_WEB_UID = 1
# Upstream base URL with any trailing "/" removed.
UPSTREAM_BASE = os.environ.get("UPSTREAM_URL", "http://127.0.0.1:10434").rstrip("/")
# Optional URL path prefix (trailing "/" removed); empty string means no prefix.
APP_ROOT = os.environ.get("APP_ROOT", "").rstrip("/")
def app_url(path: str) -> str:
    """Return ``path`` as an absolute URL path, prefixed with ``APP_ROOT`` if set.

    A leading "/" is added to ``path`` when it does not already have one.
    """
    absolute = path if path.startswith("/") else "/" + path
    if not APP_ROOT:
        return absolute
    return APP_ROOT + absolute
# APP_ROOT rendered as a JSON string literal so it can be embedded in the
# inline page scripts.
_APP_ROOT_JSON = json.dumps(APP_ROOT)
# auto_error=False lets the auth dependencies handle missing credentials
# themselves (they check for ``creds is None``).
security_bearer = HTTPBearer(auto_error=False)
# Populated at startup by load_gateway_config()
_GATE: Dict[str, Any] = {}
_PASSWORD_HASH: str = ""
# Held around every connection made to the STATS_DB SQLite file.
_db_lock = threading.Lock()
# Guards mutations of _node_runtime.
_nodes_lock = threading.Lock()
# Node configuration document, replaced by load_nodes_config().
_NODES: Dict[str, Any] = {"nodes": []}
# Per-node runtime state keyed by node id
# (keys: in_flight, healthy, last_check, last_error).
_node_runtime: Dict[str, Dict[str, Any]] = {}
# NOTE(review): presumably holds the background _health_loop task; it is not
# assigned anywhere in the visible part of the file — confirm.
_health_task: Optional[asyncio.Task] = None
# Display label for each node status code returned by compute_node_status().
STATUS_LABELS: Dict[str, str] = {
"offline": "离线",
"idle": "空闲",
"busy": "忙碌",
"disabled": "未启用",
"error": "异常",
}
def _is_bcrypt_hash(value: str) -> bool:
return value.startswith(("$2a$", "$2b$", "$2y$"))
def hash_password(p: str) -> str:
    """Hash the plain-text password ``p`` with bcrypt and a fresh salt.

    Returns the resulting hash decoded as UTF-8 text.
    """
    secret_bytes = p.encode("utf-8")
    digest = bcrypt.hashpw(secret_bytes, bcrypt.gensalt())
    return digest.decode("utf-8")
def verify_password(plain: str, hashed: str) -> bool:
    """Check ``plain`` against the bcrypt hash ``hashed``.

    Returns False instead of raising when bcrypt rejects the inputs with
    ValueError or TypeError.
    """
    try:
        matches = bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8"))
    except (ValueError, TypeError):
        return False
    return matches
def generate_api_key() -> str:
    """Create a new random API key of the form ``sk-<url-safe token>``."""
    token = secrets.token_urlsafe(32)
    return f"sk-{token}"
def load_gateway_config() -> None:
    """Load the gateway JSON config into module state.

    Reads ``CONFIG_PATH``, validates ``username`` and ``password``, and stores
    the username/API key in ``_GATE`` and the bcrypt password hash in
    ``_PASSWORD_HASH``. When ``api_key`` is missing, empty or null, a new key
    is generated and the config file is rewritten with it.

    Raises:
        RuntimeError: The config file does not exist.
        ValueError: The JSON root is not an object, or ``username`` /
            ``password`` is missing, null or empty.
    """
    global _GATE, _PASSWORD_HASH
    if not os.path.isfile(CONFIG_PATH):
        raise RuntimeError(
            f"缺少配置文件: {CONFIG_PATH}(可复制 gateway.json.example 为 gateway.json"
        )
    with open(CONFIG_PATH, encoding="utf-8") as f:
        raw = json.load(f)
    if not isinstance(raw, dict):
        # A list/number/string root has no fields to read at all.
        raise ValueError(f"{CONFIG_PATH} 须包含非空的 username 与 password")

    def _text(key: str) -> str:
        # JSON null must count as "not set"; str(None) would silently turn it
        # into the literal string "None".
        value = raw.get(key)
        return "" if value is None else str(value)

    username = _text("username").strip()
    password = _text("password")
    if not username or not password:
        raise ValueError(f"{CONFIG_PATH} 须包含非空的 username 与 password")

    api_key = _text("api_key").strip()
    if not api_key:
        # First start: generate a key and persist it back to the config file.
        api_key = generate_api_key()
        raw["api_key"] = api_key
        with open(CONFIG_PATH, "w", encoding="utf-8") as wf:
            json.dump(raw, wf, indent=2, ensure_ascii=False)
            wf.write("\n")

    _GATE = {"username": username, "api_key": api_key}
    # The config may hold either a ready-made bcrypt hash or a plain password.
    if _is_bcrypt_hash(password):
        _PASSWORD_HASH = password
    else:
        _PASSWORD_HASH = hash_password(password)
def create_web_token(user_id: int) -> str:
    """Issue a signed web-session JWT for ``user_id``.

    The token carries ``sub`` (the id as a string), ``typ`` = "web" and an
    ``exp`` set ``JWT_EXPIRE_DAYS`` days from now (UTC).
    """
    issued_at = datetime.now(timezone.utc)
    claims = {
        "sub": str(user_id),
        "exp": issued_at + timedelta(days=JWT_EXPIRE_DAYS),
        "typ": "web",
    }
    return jwt.encode(claims, JWT_SECRET, algorithm=JWT_ALG)
def decode_web_token(token: str) -> int:
    """Validate a web-session JWT and return the user id stored in ``sub``.

    Args:
        token: Encoded JWT as issued by ``create_web_token``.

    Returns:
        The integer user id.

    Raises:
        HTTPException: 401 when the token cannot be decoded, is not of type
            "web", has no ``sub``, or ``sub`` is not an integer.
    """
    try:
        payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALG])
        if payload.get("typ") != "web":
            raise JWTError("wrong token type")
        sub = payload.get("sub")
        if not sub:
            raise JWTError("missing sub")
        return int(sub)
    except (JWTError, ValueError, TypeError) as e:
        # ValueError/TypeError: ``sub`` is present but not convertible to int;
        # report it as an invalid login instead of an internal server error.
        raise HTTPException(status_code=401, detail="无效或过期的登录状态") from e
@dataclass
class GateSessionUser:
    """Identity returned by the authentication dependencies."""

    # Account name taken from the gateway config.
    username: str
    # API key associated with the session.
    api_key: str
# ---------------------------------------------------------------------------
# 节点 / 模型配置(nodes.json)
# ---------------------------------------------------------------------------
def _default_nodes_payload() -> Dict[str, Any]:
return {
"nodes": [
{
"id": "node-1",
"name": "节点 1",
"host": "127.0.0.1",
"port": 3313,
"enabled": True,
"max_concurrent": 1,
"models": [],
},
{
"id": "node-2",
"name": "节点 2",
"host": "127.0.0.1",
"port": 3314,
"enabled": True,
"max_concurrent": 1,
"models": [],
},
{
"id": "node-3",
"name": "节点 3",
"host": "127.0.0.1",
"port": 3315,
"enabled": True,
"max_concurrent": 1,
"models": [],
},
]
}
def ensure_nodes_file() -> None:
    """Create the nodes config file when it does not exist yet.

    The initial content is read from ``NODES_PATH + ".example"`` if that file
    exists, otherwise it comes from ``_default_nodes_payload()``; it is then
    written through ``save_nodes_config``.
    """
    if os.path.isfile(NODES_PATH):
        return
    template_path = NODES_PATH + ".example"
    if not os.path.isfile(template_path):
        initial = _default_nodes_payload()
    else:
        with open(template_path, encoding="utf-8") as template:
            initial = json.load(template)
    save_nodes_config(initial)
def load_nodes_config() -> None:
    """Reload ``_NODES`` from the nodes file and seed runtime state.

    A ``nodes`` value that is not a list is treated as empty. Every node with
    a non-empty id gets a runtime entry unless one already exists; existing
    entries are left untouched.
    """
    global _NODES, _node_runtime
    ensure_nodes_file()
    with open(NODES_PATH, encoding="utf-8") as fh:
        document = json.load(fh)
    node_list = document.get("nodes")
    if not isinstance(node_list, list):
        node_list = []
    _NODES = {"nodes": node_list}
    with _nodes_lock:
        for entry in node_list:
            node_id = str(entry.get("id", ""))
            if not node_id:
                continue
            _node_runtime.setdefault(
                node_id,
                {
                    "in_flight": 0,
                    "healthy": None,
                    "last_check": None,
                    "last_error": "",
                },
            )
def save_nodes_config(data: Optional[Dict[str, Any]] = None) -> None:
    """Write the nodes document to ``NODES_PATH`` and reload it.

    When ``data`` is None the current ``_NODES`` value is written.
    """
    global _NODES
    document = _NODES if data is None else data
    with open(NODES_PATH, "w", encoding="utf-8") as out:
        json.dump(document, out, indent=2, ensure_ascii=False)
        out.write("\n")
    load_nodes_config()
def get_node_base(node: Dict[str, Any]) -> str:
    """Return the ``http://host:port`` base URL of ``node``.

    A missing or blank host falls back to 127.0.0.1; a missing port becomes 0.
    """
    raw_host = str(node.get("host", "127.0.0.1")).strip()
    host = raw_host if raw_host else "127.0.0.1"
    port = int(node.get("port", 0))
    base = "http://" + host + ":" + str(port)
    return base.rstrip("/")
def _runtime_unlocked(nid: str) -> Dict[str, Any]:
    """Return the runtime-state dict of node ``nid``, creating it if absent.

    Intended for callers that already hold ``_nodes_lock``.
    """
    try:
        return _node_runtime[nid]
    except KeyError:
        fresh_state: Dict[str, Any] = {
            "in_flight": 0,
            "healthy": None,
            "last_check": None,
            "last_error": "",
        }
        _node_runtime[nid] = fresh_state
        return fresh_state
def _runtime(nid: str) -> Dict[str, Any]:
    """Return the runtime-state dict of node ``nid``, taking ``_nodes_lock`` for the lookup."""
    with _nodes_lock:
        state = _runtime_unlocked(nid)
    return state
def node_in_flight_inc(nid: str) -> None:
    """Increase the in-flight request counter of node ``nid`` by one."""
    with _nodes_lock:
        state = _runtime_unlocked(nid)
        current = int(state.get("in_flight", 0))
        state["in_flight"] = current + 1
def node_in_flight_dec(nid: str) -> None:
    """Decrease the in-flight request counter of node ``nid`` by one, never below zero."""
    with _nodes_lock:
        state = _runtime_unlocked(nid)
        remaining = int(state.get("in_flight", 0)) - 1
        state["in_flight"] = remaining if remaining > 0 else 0
def compute_node_status(node: Dict[str, Any]) -> Tuple[str, str]:
    """Derive ``(status_code, display_label)`` for ``node``.

    Checks, in order: not enabled -> "disabled"; last probe failed -> "error"
    when an error text is recorded, else "offline"; in-flight count at or
    above ``max_concurrent`` -> "busy"; last probe succeeded -> "idle";
    anything else (never probed) -> "offline".
    """

    def labelled(code: str) -> Tuple[str, str]:
        return code, STATUS_LABELS[code]

    node_id = str(node.get("id", ""))
    if not node.get("enabled", True):
        return labelled("disabled")

    state = _runtime(node_id)
    healthy = state.get("healthy")
    if healthy is False:
        error_text = str(state.get("last_error") or "")
        return labelled("error" if error_text else "offline")

    in_flight = int(state.get("in_flight", 0))
    limit = int(node.get("max_concurrent", 1) or 1)
    if in_flight >= limit:
        return labelled("busy")
    return labelled("idle" if healthy is True else "offline")
def find_nodes_for_model(model_id: str) -> List[Dict[str, Any]]:
    """List the enabled nodes that declare a model whose id equals ``model_id``.

    Returns an empty list for an empty ``model_id``. Each node appears at most
    once, in configuration order.
    """
    if not model_id:
        return []
    return [
        node
        for node in _NODES.get("nodes", [])
        if node.get("enabled", True)
        and any(
            str(model.get("id", "")) == model_id
            for model in node.get("models") or []
        )
    ]
def select_node_for_model(model_id: str) -> Optional[Dict[str, Any]]:
    """Choose the node that should serve ``model_id``.

    Nodes whose status is "disabled", "error" or "busy" are skipped first; if
    that leaves nothing, every candidate that is not "disabled" is considered.
    Among the remaining nodes the one with the fewest in-flight requests wins,
    ties broken by node id. Returns None when no node qualifies.
    """
    candidates = find_nodes_for_model(model_id)
    if not candidates:
        return None

    def ranked(excluded: Tuple[str, ...]) -> List[Tuple[int, str, Dict[str, Any]]]:
        entries: List[Tuple[int, str, Dict[str, Any]]] = []
        for node in candidates:
            if compute_node_status(node)[0] in excluded:
                continue
            node_id = str(node["id"])
            load = int(_runtime(node_id).get("in_flight", 0))
            entries.append((load, node_id, node))
        return entries

    pool = ranked(("disabled", "error", "busy")) or ranked(("disabled",))
    if not pool:
        return None
    best = min(pool, key=lambda entry: (entry[0], entry[1]))
    return best[2]
async def probe_node(
    client: httpx.AsyncClient, node: Dict[str, Any]
) -> Tuple[bool, str]:
    """Probe ``node`` by requesting ``/v1/models`` with a 5 second timeout.

    Returns ``(True, "")`` for any response with a status below 500,
    ``(False, "HTTP <code>")`` for 5xx responses, and ``(False, <error text>)``
    when the request itself fails with ``httpx.RequestError``.
    """
    models_url = get_node_base(node) + "/v1/models"
    try:
        response = await client.get(models_url, timeout=5.0)
    except httpx.RequestError as exc:
        return False, str(exc)
    if response.status_code >= 500:
        return False, f"HTTP {response.status_code}"
    return True, ""
async def refresh_all_node_health(client: httpx.AsyncClient) -> None:
    """Probe every configured node in turn and record the outcome.

    Nodes without an id are skipped. For each probed node the runtime state
    gets ``healthy``, ``last_error`` and a UTC ISO ``last_check`` timestamp.
    """
    for node in _NODES.get("nodes", []):
        node_id = str(node.get("id", ""))
        if node_id:
            healthy, error_text = await probe_node(client, node)
            with _nodes_lock:
                state = _runtime_unlocked(node_id)
                state["healthy"] = healthy
                state["last_error"] = error_text
                state["last_check"] = datetime.now(timezone.utc).isoformat()
async def _health_loop(client: httpx.AsyncClient) -> None:
    """Refresh node health forever, pausing 30 seconds between rounds.

    A failing round never stops the loop; the failure is logged with its
    traceback and the next round runs after the usual pause.
    """
    # Imported locally so this block does not depend on a file-level import.
    import logging

    log = logging.getLogger(__name__)
    while True:
        try:
            await refresh_all_node_health(client)
        except Exception:
            # Best-effort by design, but do not hide the reason for failures.
            log.exception("node health refresh failed; retrying after the pause")
        await asyncio.sleep(30)
def _query_model_stats_today() -> Dict[str, Dict[str, int]]:
    """Aggregate today's (UTC) request and token totals per model and node.

    Returns a dict keyed by ``"<model>\\0<node_id>"`` (empty node id when the
    column is NULL/empty) whose values hold ``requests`` and ``total_tokens``.
    """
    day_pattern = _today_utc_prefix() + "%"
    sql = """
SELECT model, node_id,
COUNT(*) AS requests,
COALESCE(SUM(total_tokens), 0) AS total_tokens
FROM api_logs
WHERE created_at LIKE ? AND model IS NOT NULL
GROUP BY model, node_id
"""
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        conn.row_factory = sqlite3.Row
        try:
            fetched = conn.execute(sql, (day_pattern,)).fetchall()
        finally:
            conn.close()
    return {
        f"{row['model']}\0{row['node_id'] or ''}": {
            "requests": int(row["requests"]),
            "total_tokens": int(row["total_tokens"]),
        }
        for row in fetched
    }
def build_model_cards() -> List[Dict[str, Any]]:
    """Build one display card per (node, model) pair.

    A node without models yields a single placeholder card with zero usage.
    Each card combines node settings, current status, in-flight count, last
    error and today's request/token totals from the stats database.
    """
    stats = _query_model_stats_today()
    cards: List[Dict[str, Any]] = []
    for node in _NODES.get("nodes", []):
        node_id = str(node.get("id", ""))
        status, status_label = compute_node_status(node)
        if not node.get("enabled", True):
            status, status_label = "disabled", STATUS_LABELS["disabled"]
        host = str(node.get("host", "127.0.0.1"))
        port = int(node.get("port", 0))

        # One (model_id, label, requests, tokens) tuple per card of this node.
        models = node.get("models") or []
        if models:
            entries = []
            for model in models:
                model_id = str(model.get("id", ""))
                usage = stats.get(
                    f"{model_id}\0{node_id}", {"requests": 0, "total_tokens": 0}
                )
                entries.append(
                    (
                        model_id,
                        str(model.get("label") or model_id),
                        usage["requests"],
                        usage["total_tokens"],
                    )
                )
        else:
            entries = [("", "(未配置模型)", 0, 0)]

        for model_id, model_label, requests, tokens in entries:
            cards.append(
                {
                    "node_id": node_id,
                    "node_name": node.get("name", node_id),
                    "model_id": model_id,
                    "model_label": model_label,
                    "host": host,
                    "port": port,
                    "endpoint": f"{host}:{port}",
                    "status": status,
                    "status_label": status_label,
                    "in_flight": int(_runtime(node_id).get("in_flight", 0)),
                    "max_concurrent": int(node.get("max_concurrent", 1)),
                    "today_requests": requests,
                    "today_tokens": tokens,
                    "last_error": str(_runtime(node_id).get("last_error") or ""),
                    "enabled": bool(node.get("enabled", True)),
                }
            )
    return cards
def list_nodes_with_status() -> List[Dict[str, Any]]:
    """Return a copy of every configured node, enriched with runtime fields.

    Added keys: ``status``, ``status_label``, ``in_flight``, ``last_error``
    and ``last_check``.
    """
    enriched: List[Dict[str, Any]] = []
    for node in _NODES.get("nodes", []):
        status, status_label = compute_node_status(node)
        state = _runtime(str(node.get("id", "")))
        enriched.append(
            {
                **node,
                "status": status,
                "status_label": status_label,
                "in_flight": int(state.get("in_flight", 0)),
                "last_error": str(state.get("last_error") or ""),
                "last_check": state.get("last_check"),
            }
        )
    return enriched
def _new_node_id() -> str:
return "node-" + secrets.token_hex(4)
# ---------------------------------------------------------------------------
# 访问统计(SQLite)
# ---------------------------------------------------------------------------
def get_client_ip(request: Request) -> str:
    """Resolve the client IP of ``request``, preferring proxy headers.

    Order: first entry of ``X-Forwarded-For``, then ``X-Real-IP``, then the
    connection's peer host, and finally the string "unknown".
    """
    headers = request.headers
    forwarded = headers.get("x-forwarded-for")
    if forwarded:
        first_hop, _, _ = forwarded.partition(",")
        return first_hop.strip()
    real_ip = headers.get("x-real-ip")
    if real_ip:
        return real_ip.strip()
    peer = request.client
    if peer and peer.host:
        return peer.host
    return "unknown"
def parse_model_from_body(body: bytes) -> Optional[str]:
    """Extract the ``model`` field from a JSON request body.

    Args:
        body: Raw request body.

    Returns:
        The model value as a stripped string, or None when the body is not
        valid JSON, is not a JSON object, or has no truthy ``model`` field.
    """
    try:
        data = json.loads(body)
    except (json.JSONDecodeError, TypeError, ValueError):
        return None
    # Valid JSON is not necessarily an object; lists, numbers and strings
    # have no ``.get`` and must not crash the request.
    if not isinstance(data, dict):
        return None
    model = data.get("model")
    return str(model).strip() if model else None
def _token_count(value: Any) -> int:
    """Coerce one usage field to int; missing or non-numeric values count as 0."""
    try:
        return int(value or 0)
    except (TypeError, ValueError, OverflowError):
        return 0


def _usage_from_sse(text: str) -> Optional[Dict[str, Any]]:
    """Return the last ``usage`` object with a non-zero count among SSE ``data:`` lines."""
    found: Optional[Dict[str, Any]] = None
    for line in text.splitlines():
        line = line.strip()
        if not line.startswith("data:"):
            continue
        payload = line[5:].strip()
        if not payload or payload == "[DONE]":
            continue
        try:
            obj = json.loads(payload)
        except json.JSONDecodeError:
            continue
        # A chunk may be valid JSON without being an object.
        usage = obj.get("usage") if isinstance(obj, dict) else None
        if isinstance(usage, dict) and (
            usage.get("total_tokens")
            or usage.get("prompt_tokens")
            or usage.get("completion_tokens")
        ):
            found = usage
    return found


def _usage_from_json(text: str) -> Optional[Dict[str, Any]]:
    """Return the ``usage`` object of a plain JSON response body, if there is one."""
    try:
        obj = json.loads(text)
    except json.JSONDecodeError:
        return None
    usage = obj.get("usage") if isinstance(obj, dict) else None
    return usage if isinstance(usage, dict) else None


def parse_usage_from_response(raw: bytes, content_type: str) -> Dict[str, int]:
    """Extract token usage from an upstream response body.

    Handles both server-sent-event streams (detected through ``content_type``
    or a body starting with ``data:``) and plain JSON bodies.

    Args:
        raw: Response body bytes.
        content_type: Value of the response Content-Type header (may be empty).

    Returns:
        Dict with ``prompt_tokens``, ``completion_tokens`` and
        ``total_tokens``; all zero when no usage information is found. A
        missing total is computed as prompt + completion.
    """
    out = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    if not raw:
        return out
    try:
        text = raw.decode("utf-8", errors="replace")
    except (AttributeError, TypeError, UnicodeError):
        return out

    ct = (content_type or "").lower()
    if "event-stream" in ct or text.lstrip().startswith("data:"):
        usage = _usage_from_sse(text)
    else:
        usage = _usage_from_json(text)
    if not usage:
        return out

    prompt = _token_count(usage.get("prompt_tokens"))
    completion = _token_count(usage.get("completion_tokens"))
    total = _token_count(usage.get("total_tokens")) or prompt + completion
    out["prompt_tokens"] = prompt
    out["completion_tokens"] = completion
    out["total_tokens"] = total
    return out
def init_stats_db() -> None:
    """Create the ``api_logs`` table and its indexes if they do not exist.

    Also attempts to add the ``node_id`` column to an existing table; the
    ``sqlite3.OperationalError`` raised by that statement is ignored.
    """
    create_table = """
CREATE TABLE IF NOT EXISTS api_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TEXT NOT NULL,
client_ip TEXT NOT NULL,
model TEXT,
node_id TEXT,
status_code INTEGER NOT NULL,
req_bytes INTEGER NOT NULL DEFAULT 0,
resp_bytes INTEGER NOT NULL DEFAULT 0,
prompt_tokens INTEGER NOT NULL DEFAULT 0,
completion_tokens INTEGER NOT NULL DEFAULT 0,
total_tokens INTEGER NOT NULL DEFAULT 0
)
"""
    base_indexes = (
        "CREATE INDEX IF NOT EXISTS idx_api_logs_created_at ON api_logs(created_at)",
        "CREATE INDEX IF NOT EXISTS idx_api_logs_client_ip ON api_logs(client_ip)",
    )
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        try:
            conn.execute(create_table)
            for statement in base_indexes:
                conn.execute(statement)
            try:
                conn.execute("ALTER TABLE api_logs ADD COLUMN node_id TEXT")
            except sqlite3.OperationalError:
                # Ignored on purpose, e.g. when the column is already there.
                pass
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_api_logs_node_id ON api_logs(node_id)"
            )
            conn.commit()
        finally:
            conn.close()
def _record_api_log_sync(
    client_ip: str,
    model: Optional[str],
    status_code: int,
    req_bytes: int,
    resp_bytes: int,
    usage: Dict[str, int],
    node_id: Optional[str] = None,
) -> None:
    """Insert one row into ``api_logs`` (blocking).

    The row is stamped with the current UTC time in ISO format; ``usage``
    must provide ``prompt_tokens``, ``completion_tokens`` and
    ``total_tokens``.
    """
    insert_sql = """
INSERT INTO api_logs (
created_at, client_ip, model, node_id, status_code,
req_bytes, resp_bytes,
prompt_tokens, completion_tokens, total_tokens
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
    row = (
        datetime.now(timezone.utc).isoformat(),
        client_ip,
        model,
        node_id,
        status_code,
        req_bytes,
        resp_bytes,
        usage["prompt_tokens"],
        usage["completion_tokens"],
        usage["total_tokens"],
    )
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        try:
            conn.execute(insert_sql, row)
            conn.commit()
        finally:
            conn.close()
async def record_api_log(
    client_ip: str,
    model: Optional[str],
    status_code: int,
    req_bytes: int,
    resp_bytes: int,
    usage: Dict[str, int],
    node_id: Optional[str] = None,
) -> None:
    """Write one access-log row without blocking the event loop.

    Forwards all arguments to ``_record_api_log_sync`` through
    ``run_in_thread``.
    """
    log_args = (
        client_ip,
        model,
        status_code,
        req_bytes,
        resp_bytes,
        usage,
        node_id,
    )
    await run_in_thread(_record_api_log_sync, *log_args)
def _today_utc_prefix() -> str:
return datetime.now(timezone.utc).strftime("%Y-%m-%d")
def _query_stats_summary() -> Dict[str, Any]:
    """Return overall and today's (UTC) request and token totals.

    Keys: ``total_requests``, ``total_tokens``, ``unique_ips``,
    ``today_requests`` and ``today_tokens``.
    """
    overall_sql = """
SELECT
COUNT(*) AS total_requests,
COALESCE(SUM(total_tokens), 0) AS total_tokens,
COUNT(DISTINCT client_ip) AS unique_ips
FROM api_logs
"""
    today_sql = """
SELECT
COUNT(*) AS today_requests,
COALESCE(SUM(total_tokens), 0) AS today_tokens
FROM api_logs
WHERE created_at LIKE ?
"""
    day_pattern = _today_utc_prefix() + "%"
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        conn.row_factory = sqlite3.Row
        try:
            overall = conn.execute(overall_sql).fetchone()
            today_totals = conn.execute(today_sql, (day_pattern,)).fetchone()
        finally:
            conn.close()
    summary: Dict[str, Any] = {
        name: int(overall[name])
        for name in ("total_requests", "total_tokens", "unique_ips")
    }
    for name in ("today_requests", "today_tokens"):
        summary[name] = int(today_totals[name])
    return summary
def _query_stats_ips(limit: int = 200) -> List[Dict[str, Any]]:
    """Return per-client-IP totals, most recently seen first.

    At most ``limit`` rows are returned; each row has ``ip``, ``requests``,
    the three token sums and ``last_seen``.
    """
    sql = """
SELECT
client_ip AS ip,
COUNT(*) AS requests,
COALESCE(SUM(total_tokens), 0) AS total_tokens,
COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens,
COALESCE(SUM(completion_tokens), 0) AS completion_tokens,
MAX(created_at) AS last_seen
FROM api_logs
GROUP BY client_ip
ORDER BY last_seen DESC
LIMIT ?
"""
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        conn.row_factory = sqlite3.Row
        try:
            fetched = conn.execute(sql, (limit,)).fetchall()
        finally:
            conn.close()
    return list(map(dict, fetched))
def _query_stats_billing(days: int) -> Dict[str, Any]:
    """Return usage totals since ``days`` days ago (UTC date), by day and by IP.

    The result holds ``days``, ``by_day`` (newest day first) and ``by_ip``
    (top 100 client IPs by total tokens).
    """
    cutoff = datetime.now(timezone.utc) - timedelta(days=days)
    since = cutoff.date().isoformat()
    by_day_sql = """
SELECT
substr(created_at, 1, 10) AS day,
COUNT(*) AS requests,
COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens,
COALESCE(SUM(completion_tokens), 0) AS completion_tokens,
COALESCE(SUM(total_tokens), 0) AS total_tokens
FROM api_logs
WHERE substr(created_at, 1, 10) >= ?
GROUP BY day
ORDER BY day DESC
"""
    by_ip_sql = """
SELECT
client_ip AS ip,
COUNT(*) AS requests,
COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens,
COALESCE(SUM(completion_tokens), 0) AS completion_tokens,
COALESCE(SUM(total_tokens), 0) AS total_tokens
FROM api_logs
WHERE substr(created_at, 1, 10) >= ?
GROUP BY client_ip
ORDER BY total_tokens DESC
LIMIT 100
"""
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        conn.row_factory = sqlite3.Row
        try:
            day_rows = conn.execute(by_day_sql, (since,)).fetchall()
            ip_rows = conn.execute(by_ip_sql, (since,)).fetchall()
        finally:
            conn.close()
    return {
        "days": days,
        "by_day": list(map(dict, day_rows)),
        "by_ip": list(map(dict, ip_rows)),
    }
def _query_stats_logs(limit: int, offset: int) -> Tuple[List[Dict[str, Any]], int]:
    """Return one page of access-log rows (newest first) and the total row count.

    ``limit`` is clamped to 1..500 and a negative ``offset`` is treated as 0.
    """
    page_size = min(max(limit, 1), 500)
    start = offset if offset > 0 else 0
    page_sql = """
SELECT
id, created_at, client_ip, model, status_code,
req_bytes, resp_bytes,
prompt_tokens, completion_tokens, total_tokens
FROM api_logs
ORDER BY id DESC
LIMIT ? OFFSET ?
"""
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        conn.row_factory = sqlite3.Row
        try:
            total = conn.execute("SELECT COUNT(*) FROM api_logs").fetchone()[0]
            fetched = conn.execute(page_sql, (page_size, start)).fetchall()
        finally:
            conn.close()
    return list(map(dict, fetched)), int(total)
# ---------------------------------------------------------------------------
# Pydantic
# ---------------------------------------------------------------------------
# Request body of the login endpoint; both fields are required and non-empty.
class LoginBody(BaseModel):
    username: str = Field(..., min_length=1)
    password: str = Field(..., min_length=1)
# One model entry of a node: a required non-empty id and an optional label.
class ModelEntryIn(BaseModel):
    id: str = Field(..., min_length=1)
    label: str = ""
# Payload describing a node: name, address, enabled flag, concurrency limit
# (1-32, default 1) and the models it serves.
class NodeBody(BaseModel):
    name: str = Field(..., min_length=1)
    host: str = "127.0.0.1"
    port: int = Field(..., ge=1, le=65535)
    enabled: bool = True
    max_concurrent: int = Field(default=1, ge=1, le=32)
    models: List[ModelEntryIn] = Field(default_factory=list)
# ---------------------------------------------------------------------------
# 依赖
# ---------------------------------------------------------------------------
async def get_current_web_user(
    creds: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security_bearer)],
) -> GateSessionUser:
    """FastAPI dependency that authenticates a web session.

    Requires a Bearer token that is a web JWT (API keys starting with "sk-"
    are rejected) whose user id equals ``GATE_WEB_UID``.

    Raises:
        HTTPException: 401 for missing, wrong-scheme or invalid credentials.
    """
    has_bearer = creds is not None and creds.scheme.lower() == "bearer"
    if not has_bearer:
        raise HTTPException(status_code=401, detail="需要登录(Bearer Token")
    token = creds.credentials
    if token.startswith("sk-"):
        raise HTTPException(status_code=401, detail="网页登录请使用登录接口返回的令牌,不是 API Key")
    if decode_web_token(token) != GATE_WEB_UID:
        raise HTTPException(status_code=401, detail="登录状态无效")
    return GateSessionUser(username=_GATE["username"], api_key=_GATE["api_key"])
async def get_user_from_api_key(
    creds: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security_bearer)],
) -> GateSessionUser:
    """FastAPI dependency that authenticates a request by API key.

    Requires ``Authorization: Bearer sk-...`` whose key equals the configured
    ``api_key``.

    Raises:
        HTTPException: 401 when the header is missing, uses another scheme,
            the key lacks the "sk-" prefix, or it does not match.
    """
    if creds is None or creds.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=401,
            detail='请在 Header 中携带 Authorization: Bearer sk-...',
        )
    key = creds.credentials.strip()
    if not key.startswith("sk-"):
        raise HTTPException(status_code=401, detail="API Key 必须以 sk- 开头")
    expected = _GATE.get("api_key")
    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key is correct; bytes are used because compare_digest only
    # accepts ASCII-only str values.
    if not isinstance(expected, str) or not secrets.compare_digest(
        key.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="无效的 API Key")
    return GateSessionUser(username=_GATE["username"], api_key=key)
# ---------------------------------------------------------------------------
# HTML(Tailwind CDN)
# ---------------------------------------------------------------------------
# URL of the Tailwind CSS script loaded by every page.
TW_SCRIPT = "https://cdn.tailwindcss.com"
# Opening part of every authenticated page, up to and including <main>.
# It is an f-string: "{{title}}" renders as the literal "{title}" placeholder
# that page() substitutes, and APP_ROOT / app_url() values are baked in when
# this module is imported. The inline script defines wgApi() and wgFmtErr(),
# redirects to the login page when localStorage has no "web_token", and wires
# the logout button.
SHELL_HEAD = f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<title>{{title}}</title>
<script src="{TW_SCRIPT}"></script>
<script>
tailwind.config = {{
theme: {{
extend: {{
fontFamily: {{
sans: ['Plus Jakarta Sans', 'system-ui', 'sans-serif'],
}},
colors: {{
ink: {{ 950: '#0b1220', 900: '#101827', 700: '#334155' }},
}},
}},
}},
}};
</script>
<link rel="preconnect" href="https://fonts.googleapis.com"/>
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin/>
<link href="https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;500;600;700&display=swap" rel="stylesheet"/>
</head>
<body class="min-h-screen bg-gradient-to-br from-slate-950 via-indigo-950 to-slate-900 text-slate-100 antialiased" data-app-root="{APP_ROOT}">
<script>
window.__APP_ROOT__ = {_APP_ROOT_JSON};
function wgApi(path) {{
if (!path || path.charAt(0) !== '/') path = '/' + (path || '');
var root = (window.__APP_ROOT__ || '').replace(/\\/+$/, '');
return root + path;
}}
function wgFmtErr(res, data) {{
try {{
if (data && data.detail !== undefined) {{
var d = data.detail;
if (typeof d === 'string') return d;
if (Array.isArray(d))
return d.map(function (e) {{
var loc = e.loc ? e.loc.join('.') + ': ' : '';
return loc + (e.msg || JSON.stringify(e));
}}).join('') || '校验失败';
if (typeof d === 'object') return JSON.stringify(d);
}}
}} catch (e) {{}}
return res && res.status ? ('HTTP ' + res.status + ' ' + (res.statusText || '')) : '网络或服务器错误';
}}
(function() {{
if (!localStorage.getItem('web_token')) {{ location.replace(wgApi('/login')); return; }}
document.addEventListener('DOMContentLoaded', function() {{
var btn = document.getElementById('wg-logout');
if (btn) btn.addEventListener('click', function() {{
localStorage.removeItem('web_token');
location.href = wgApi('/login');
}});
}});
}})();
</script>
<div class="pointer-events-none fixed inset-0 overflow-hidden">
<div class="absolute -left-32 top-20 h-72 w-72 rounded-full bg-indigo-500/20 blur-3xl"></div>
<div class="absolute -right-20 bottom-10 h-96 w-96 rounded-full bg-cyan-500/15 blur-3xl"></div>
</div>
<header class="relative z-10 border-b border-white/10 bg-black/20 backdrop-blur-md">
<div class="mx-auto flex max-w-5xl items-center justify-between px-6 py-4">
<a href="{app_url('/home')}" class="flex items-center gap-2 text-lg font-semibold tracking-tight text-white hover:text-indigo-200 transition">
<span class="inline-flex h-9 w-9 items-center justify-center rounded-xl bg-gradient-to-br from-indigo-400 to-cyan-400 text-slate-950 font-bold">AI</span>
<span>中转网关</span>
</a>
<nav class="flex items-center gap-6 text-sm font-medium text-slate-300">
<a href="{app_url('/home')}" class="hover:text-white transition">首页</a>
<a href="{app_url('/stats')}" class="hover:text-white transition">流量统计</a>
<a href="{app_url('/settings')}" class="hover:text-white transition">系统设置</a>
<a href="{app_url('/user')}" class="rounded-full bg-white/10 px-4 py-2 text-white ring-1 ring-white/15 hover:bg-white/15 transition">用户中心</a>
<button type="button" id="wg-logout" class="text-slate-400 hover:text-white transition">退出</button>
</nav>
</div>
</header>
<main class="relative z-10 mx-auto max-w-5xl px-6 py-14">
"""
# Closing part shared by page() and page_login(): ends <main>, adds the
# footer and closes <body>/<html>.
SHELL_FOOT = """
</main>
<footer class="relative z-10 border-t border-white/10 py-8 text-center text-xs text-slate-500">
LLM Gateway · JSON 配置账号 · OpenAI 兼容
</footer>
</body>
</html>
"""
# Opening part of the login page (no navigation header). Like SHELL_HEAD it is
# an f-string whose "{{title}}" renders as the "{title}" placeholder replaced
# by page_login(). Its inline script redirects to the home page when
# localStorage already holds a "web_token".
LOGIN_SHELL_HEAD = f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<title>{{title}}</title>
<script src="{TW_SCRIPT}"></script>
<link rel="preconnect" href="https://fonts.googleapis.com"/>
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin/>
<link href="https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;500;600;700&display=swap" rel="stylesheet"/>
</head>
<body class="min-h-screen bg-gradient-to-br from-slate-950 via-indigo-950 to-slate-900 text-slate-100 antialiased">
<script>
window.__APP_ROOT__ = {_APP_ROOT_JSON};
function wgApi(path) {{
if (!path || path.charAt(0) !== '/') path = '/' + (path || '');
var root = (window.__APP_ROOT__ || '').replace(/\\/+$/, '');
return root + path;
}}
function wgFmtErr(res, data) {{
try {{
if (data && data.detail !== undefined) {{
var d = data.detail;
if (typeof d === 'string') return d;
if (Array.isArray(d)) return d.map(function(e) {{ return (e.loc?e.loc.join('.'):'') + (e.msg||''); }}).join('');
if (typeof d === 'object') return JSON.stringify(d);
}}
}} catch (e) {{}}
return res && res.status ? ('HTTP ' + res.status) : '网络或服务器错误';
}}
if (localStorage.getItem('web_token')) location.replace(wgApi('/home'));
</script>
<div class="pointer-events-none fixed inset-0 overflow-hidden">
<div class="absolute -left-32 top-20 h-72 w-72 rounded-full bg-indigo-500/20 blur-3xl"></div>
<div class="absolute -right-20 bottom-10 h-96 w-96 rounded-full bg-cyan-500/15 blur-3xl"></div>
</div>
<main class="relative z-10 mx-auto flex min-h-screen max-w-5xl items-center justify-center px-6 py-12">
"""
def page(title: str, inner: str) -> str:
    """Assemble a full authenticated page.

    Substitutes ``title`` for the ``{title}`` placeholder of ``SHELL_HEAD``
    and appends ``inner`` followed by ``SHELL_FOOT``. Neither argument is
    HTML-escaped here.
    """
    head = SHELL_HEAD.replace("{title}", title)
    return "".join((head, inner, SHELL_FOOT))
def page_login(title: str, inner: str) -> str:
    """Assemble the login page.

    Substitutes ``title`` for the ``{title}`` placeholder of
    ``LOGIN_SHELL_HEAD`` and appends ``inner`` followed by ``SHELL_FOOT``.
    Neither argument is HTML-escaped here.
    """
    head = LOGIN_SHELL_HEAD.replace("{title}", title)
    return "".join((head, inner, SHELL_FOOT))
# Home page, rendered once at import time. Its script fetches
# /api/models/cards with the stored web token immediately and then every
# 5000 ms, renders one card per entry, and keeps showing the last successful
# result when a later request fails. A 401 clears the token and redirects to
# the login page.
# NOTE(review): esc() escapes &, < and > but not quotes, yet its output is also
# placed inside a title="..." attribute (last_error) — confirm that value can
# never contain a double quote.
HOME_HTML = page(
"中转网关",
f"""
<div class="space-y-12">
<div>
<div class="flex flex-wrap items-end justify-between gap-4">
<div>
<h1 class="text-2xl font-bold text-white">模型分布</h1>
<p class="mt-2 text-sm text-slate-400">各节点模型状态(约每 5 秒刷新;主机默认 127.0.0.1,经 frp 映射端口)</p>
</div>
<a href="{app_url('/settings')}" class="text-sm text-indigo-300 hover:text-white transition">管理节点与模型 →</a>
</div>
<div id="cards-loading" class="mt-6 text-slate-400">加载中…</div>
<div id="cards-grid" class="mt-6 grid gap-4 sm:grid-cols-2 xl:grid-cols-3"></div>
<p id="cards-empty" class="mt-6 hidden text-sm text-slate-500">暂无模型卡片,请登录后在系统设置中添加节点与模型。</p>
</div>
<div class="rounded-3xl bg-white/5 p-8 ring-1 ring-white/10 backdrop-blur">
<h2 class="text-lg font-semibold text-white">使用说明</h2>
<ul class="mt-4 space-y-3 text-sm leading-relaxed text-slate-300">
<li>· 对外统一访问本网关(宝塔反代端口 <span class="text-cyan-200">8150</span>),请求 <code class="rounded bg-white/10 px-1 text-cyan-200">/v1/chat/completions</code>Header 携带 <code class="rounded bg-white/10 px-1 text-cyan-200">Authorization: Bearer sk-...</code>。</li>
<li>· JSON 里的 <code class="rounded bg-white/10 px-1 text-cyan-200">model</code> 须与系统设置中登记的模型 ID 一致,网关会转发到对应 <span class="text-white">127.0.0.1:端口</span>。</li>
<li>· 节点端口建议 <span class="text-white">33133318</span>frp 映射);未启用的节点显示为「未启用」。</li>
<li>· 流量与 Token 见 <a href="{app_url('/stats')}" class="text-indigo-300 hover:underline">流量统计</a>API Key 见 <a href="{app_url('/user')}" class="text-indigo-300 hover:underline">用户中心</a>(需登录)。</li>
</ul>
</div>
</div>
<script>
const statusRing = {{ idle: 'ring-emerald-500/40', busy: 'ring-amber-500/40', offline: 'ring-slate-500/30', disabled: 'ring-slate-600/30', error: 'ring-rose-500/40' }};
const statusBadge = {{ idle: 'bg-emerald-500/20 text-emerald-200', busy: 'bg-amber-500/20 text-amber-200', offline: 'bg-slate-500/20 text-slate-300', disabled: 'bg-slate-600/20 text-slate-400', error: 'bg-rose-500/20 text-rose-200' }};
function esc(s) {{ return String(s == null ? '' : s).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;'); }}
function fmt(n) {{ return Number(n || 0).toLocaleString(); }}
function wgAuthHeaders() {{
var t = localStorage.getItem('web_token');
if (!t) {{ location.replace(wgApi('/login')); return null; }}
return {{ 'Authorization': 'Bearer ' + t }};
}}
let lastCards = [];
function renderCards(cards) {{
const grid = document.getElementById('cards-grid');
const empty = document.getElementById('cards-empty');
if (!cards.length) {{ grid.innerHTML = ''; empty.classList.remove('hidden'); return; }}
empty.classList.add('hidden');
grid.innerHTML = cards.map(function(c) {{
const st = c.status || 'offline';
const err = (c.last_error && (st === 'error' || st === 'offline'))
? '<p class="text-xs text-rose-300/90 truncate mt-2" title="'+esc(c.last_error)+'">'+esc(c.last_error)+'</p>' : '';
return '<article class="rounded-2xl bg-white/5 p-5 ring-1 '+(statusRing[st]||statusRing.offline)+' flex flex-col min-h-[190px]">'
+ '<div class="flex items-start justify-between gap-2 mb-2"><h3 class="font-semibold text-white text-base break-words flex-1">'+esc(c.model_label||c.model_id||'未命名')+'</h3>'
+ '<span class="shrink-0 rounded-full px-2.5 py-0.5 text-xs font-medium '+(statusBadge[st]||statusBadge.offline)+'">'+esc(c.status_label)+'</span></div>'
+ '<p class="text-xs text-slate-500 font-mono truncate mb-1">'+esc(c.model_id||'')+'</p>'
+ '<p class="text-sm text-slate-300 truncate mb-3">'+esc(c.node_name)+' · <span class="text-cyan-200">'+esc(c.endpoint)+'</span></p>'
+ '<dl class="grid grid-cols-2 gap-x-4 gap-y-2 text-xs mt-auto pt-3 border-t border-white/10">'
+ '<div><dt class="text-slate-500">进行中</dt><dd class="text-white font-medium">'+esc(c.in_flight)+' / '+esc(c.max_concurrent)+'</dd></div>'
+ '<div><dt class="text-slate-500">今日 Token</dt><dd class="text-cyan-200 font-medium">'+fmt(c.today_tokens)+'</dd></div>'
+ '<div><dt class="text-slate-500">今日请求</dt><dd class="text-white font-medium">'+fmt(c.today_requests)+'</dd></div>'
+ '<div></div></dl>' + err + '</article>';
}}).join('');
}}
async function refreshCards() {{
const hdrs = wgAuthHeaders();
if (!hdrs) return;
try {{
const res = await fetch(wgApi('/api/models/cards'), {{ headers: hdrs }});
if (res.status === 401) {{ localStorage.removeItem('web_token'); location.replace(wgApi('/login')); return; }}
const data = await res.json().catch(function() {{ return []; }});
document.getElementById('cards-loading').classList.add('hidden');
if (res.ok) {{
lastCards = Array.isArray(data) ? data : (data.cards || []);
renderCards(lastCards);
}} else if (lastCards.length) {{
renderCards(lastCards);
}}
}} catch (e) {{
if (lastCards.length) renderCards(lastCards);
}}
}}
refreshCards();
setInterval(refreshCards, 5000);
</script>
""",
)
# Login page template: page_login() receives the page title and this HTML/JS body.
# The body is an f-string, so every literal JS brace is doubled ({{ }}).
# The inline script serialises the form as JSON, POSTs it to /api/login through
# wgApi(), and on success stores data.access_token in localStorage["web_token"]
# before navigating to /home; on failure the message is shown in the #err element.
# wgApi() and wgFmtErr() are not defined in this template — presumably they are
# injected by the page wrapper (confirm in page_login()).
LOGIN_HTML = page_login(
"登录",
f"""
<div class="mx-auto max-w-md">
<div class="rounded-3xl bg-white/5 p-8 ring-1 ring-white/10 backdrop-blur">
<h1 class="text-2xl font-bold text-white">登录</h1>
<p class="mt-2 text-sm text-slate-400">账号与密码与 <code class="text-cyan-200/90">gateway.json</code> 中一致</p>
<form id="f" class="mt-8 space-y-5">
<div>
<label class="block text-sm font-medium text-slate-300">用户名</label>
<input name="username" required autocomplete="username"
class="mt-2 w-full rounded-xl border border-white/10 bg-black/30 px-4 py-3 text-white outline-none focus:ring-2 focus:ring-indigo-500 transition"/>
</div>
<div>
<label class="block text-sm font-medium text-slate-300">密码</label>
<input name="password" type="password" required autocomplete="current-password"
class="mt-2 w-full rounded-xl border border-white/10 bg-black/30 px-4 py-3 text-white outline-none focus:ring-2 focus:ring-indigo-500 transition"/>
</div>
<p id="err" class="hidden text-sm text-rose-400"></p>
<button type="submit" class="w-full rounded-xl bg-white/10 py-3 text-sm font-semibold text-white ring-1 ring-white/15 hover:bg-white/15 transition">
登录
</button>
</form>
</div>
</div>
<script>
document.getElementById('f').addEventListener('submit', async (e) => {{
e.preventDefault();
const fd = new FormData(e.target);
const body = {{ username: String(fd.get('username')||'').trim(), password: String(fd.get('password')||'') }};
const err = document.getElementById('err');
err.classList.add('hidden');
try {{
const res = await fetch(wgApi('/api/login'), {{ method: 'POST', headers: {{ 'Content-Type': 'application/json' }}, body: JSON.stringify(body) }});
const raw = await res.text();
let data = {{}};
try {{ data = raw ? JSON.parse(raw) : {{}}; }} catch (_) {{}}
if (!res.ok) {{ err.textContent = wgFmtErr(res, data) || raw.slice(0, 200) || '登录失败'; err.classList.remove('hidden'); return; }}
if (!data.access_token) {{ err.textContent = '服务端未返回令牌'; err.classList.remove('hidden'); return; }}
localStorage.setItem('web_token', data.access_token);
window.location.href = wgApi('/home');
}} catch (ex) {{
err.textContent = '无法连接服务器:' + (ex && ex.message ? ex.message : String(ex));
err.classList.remove('hidden');
}}
}});
</script>
""",
)
# User-centre page template (plain string, so JS braces are NOT doubled here).
# The inline script loads /api/me with the stored web token and shows the
# username and API key, offers a copy button and a logout link.
# Fixes over the previous revision:
#   * a 401 clears the stale token and returns to the login page (as the home
#     page script does) instead of leaving an error on screen;
#   * network failures are caught and reported instead of leaving the spinner;
#   * copying falls back to document.execCommand('copy') when the async
#     Clipboard API is unavailable (it only exists in secure contexts, and the
#     gateway is commonly served over plain HTTP);
#   * the button label is captured once, so repeated clicks cannot leave the
#     button stuck on the "copied" text.
USER_HTML = page(
"用户中心",
"""
<div class="mx-auto max-w-lg">
<div class="rounded-3xl bg-white/5 p-8 ring-1 ring-white/10 backdrop-blur">
<h1 class="text-2xl font-bold text-white">用户中心</h1>
<p class="mt-2 text-sm text-slate-400">OpenAI 兼容调用密钥(与 gateway.json 中 api_key 一致)</p>
<div id="loading" class="mt-8 text-slate-400">加载中…</div>
<div id="panel" class="hidden mt-8 space-y-6">
<div>
<span class="text-xs font-semibold uppercase tracking-wider text-indigo-200">用户名</span>
<p id="uname" class="mt-2 text-lg font-medium text-white"></p>
</div>
<div>
<span class="text-xs font-semibold uppercase tracking-wider text-indigo-200">API Key</span>
<div class="mt-2 flex flex-col gap-3 sm:flex-row sm:items-center">
<code id="apikey" class="flex-1 break-all rounded-xl bg-black/40 px-4 py-3 text-sm text-cyan-200 ring-1 ring-white/10"></code>
<button id="copy" type="button" class="rounded-xl bg-gradient-to-r from-indigo-500 to-cyan-500 px-5 py-3 text-sm font-semibold text-slate-950 hover:brightness-110 transition shrink-0">
复制
</button>
</div>
</div>
<button id="logout" type="button" class="text-sm text-slate-400 hover:text-white underline underline-offset-4">退出登录</button>
</div>
<p id="err" class="hidden mt-6 text-sm text-rose-400"></p>
</div>
</div>
<script>
const token = localStorage.getItem('web_token');
function showErr(msg) {
  document.getElementById('loading').classList.add('hidden');
  const el = document.getElementById('err');
  el.textContent = msg;
  el.classList.remove('hidden');
}
if (!token) { window.location.href = wgApi('/login'); }
else {
  fetch(wgApi('/api/me'), { headers: { 'Authorization': 'Bearer ' + token } })
    .then(async (res) => {
      if (res.status === 401) {
        localStorage.removeItem('web_token');
        window.location.href = wgApi('/login');
        return;
      }
      const data = await res.json().catch(() => ({}));
      if (!res.ok) {
        showErr(wgFmtErr(res, data) || data.detail || '加载失败');
        return;
      }
      document.getElementById('loading').classList.add('hidden');
      document.getElementById('panel').classList.remove('hidden');
      document.getElementById('uname').textContent = data.username;
      document.getElementById('apikey').textContent = data.api_key;
    })
    .catch((ex) => {
      showErr('无法连接服务器:' + (ex && ex.message ? ex.message : String(ex)));
    });
}
const copyBtn = document.getElementById('copy');
const copyLabel = copyBtn.textContent;
function copied() {
  copyBtn.textContent = '已复制';
  setTimeout(() => { copyBtn.textContent = copyLabel; }, 1500);
}
function legacyCopy(text) {
  const ta = document.createElement('textarea');
  ta.value = text;
  ta.setAttribute('readonly', '');
  ta.style.position = 'fixed';
  ta.style.opacity = '0';
  document.body.appendChild(ta);
  ta.select();
  try { if (document.execCommand('copy')) copied(); }
  finally { document.body.removeChild(ta); }
}
copyBtn.addEventListener('click', () => {
  const k = document.getElementById('apikey').textContent;
  if (navigator.clipboard && window.isSecureContext) {
    navigator.clipboard.writeText(k).then(copied, () => legacyCopy(k));
  } else {
    legacyCopy(k);
  }
});
document.getElementById('logout').addEventListener('click', () => {
  localStorage.removeItem('web_token');
  window.location.href = wgApi('/login');
});
</script>
""",
)
# Settings page template (f-string: literal JS braces are doubled).
# Lists the configured nodes as editable cards and offers an "add node" form;
# every action goes through the /api/nodes endpoints with the stored web token.
# Fixes over the previous revision:
#   * esc() now also escapes '>', '"' and "'" — its output is interpolated into
#     double-quoted HTML attributes (value="...", data-*="..."), where escaping
#     only '&' and '<' allowed attribute injection from node names/hosts;
#   * the initial load() is skipped when no token is stored, so no
#     "Bearer null" request is sent while the redirect to /login is pending;
#   * the trailing .replace("<div>", "<div>").replace("</div>", "</div>") calls
#     replaced each string with itself and were removed.
SETTINGS_HTML = page(
"系统设置",
f"""
<div class="mx-auto max-w-3xl space-y-8">
<div>
<h1 class="text-2xl font-bold text-white">系统设置</h1>
<p class="mt-2 text-sm text-slate-400">管理节点主机、端口与模型(默认主机 127.0.0.1,对应 frp 映射)</p>
</div>
<p id="msg" class="hidden text-sm"></p>
<div id="nodes-list" class="space-y-4"></div>
<div class="rounded-3xl bg-white/5 p-6 ring-1 ring-white/10">
<h2 class="text-lg font-semibold text-white">添加节点</h2>
<form id="add-form" class="mt-4 space-y-4">
<div class="grid gap-4 sm:grid-cols-2">
<div><label class="text-sm text-slate-300">节点名称</label>
<input name="name" required class="mt-1 w-full rounded-xl border border-white/10 bg-black/30 px-3 py-2 text-white"/></div>
<div><label class="text-sm text-slate-300">主机</label>
<input name="host" value="127.0.0.1" class="mt-1 w-full rounded-xl border border-white/10 bg-black/30 px-3 py-2 text-white"/></div>
<div><label class="text-sm text-slate-300">端口</label>
<input name="port" type="number" required min="1" max="65535" placeholder="3313" class="mt-1 w-full rounded-xl border border-white/10 bg-black/30 px-3 py-2 text-white"/></div>
<div><label class="text-sm text-slate-300">最大并发</label>
<input name="max_concurrent" type="number" value="1" min="1" max="32" class="mt-1 w-full rounded-xl border border-white/10 bg-black/30 px-3 py-2 text-white"/></div>
</div>
<div><label class="text-sm text-slate-300">模型列表 <span class="text-slate-500">(每行:模型ID|显示名,显示名可省略)</span></label>
<textarea name="models" rows="4" placeholder="qwen2.5:14b|千问14B&#10;llama3|Llama3" class="mt-1 w-full rounded-xl border border-white/10 bg-black/30 px-3 py-2 text-sm text-white font-mono"></textarea>
</div>
<label class="flex items-center gap-2 text-sm text-slate-300"><input name="enabled" type="checkbox" checked class="rounded"/> 启用节点</label>
<button type="submit" class="rounded-xl bg-gradient-to-r from-indigo-500 to-cyan-500 px-5 py-2.5 text-sm font-semibold text-slate-950">添加节点</button>
</form>
</div>
</div>
<script>
const token = localStorage.getItem('web_token');
if (!token) {{ location.replace(wgApi('/login')); }}
const hdrs = {{ 'Authorization': 'Bearer ' + token, 'Content-Type': 'application/json' }};
function esc(s) {{ return String(s==null?'':s).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;').replace(/"/g,'&quot;').replace(/'/g,'&#39;'); }}
function showMsg(text, ok) {{
const m = document.getElementById('msg');
m.textContent = text;
m.className = 'text-sm ' + (ok ? 'text-emerald-400' : 'text-rose-400');
m.classList.remove('hidden');
setTimeout(() => m.classList.add('hidden'), 4000);
}}
function parseModels(text) {{
return String(text||'').split(/\\n/).map(l => l.trim()).filter(Boolean).map(line => {{
const p = line.split('|');
return {{ id: p[0].trim(), label: (p[1]||p[0]).trim() }};
}}).filter(m => m.id);
}}
function modelsToText(models) {{
return (models||[]).map(m => m.id + (m.label && m.label !== m.id ? '|' + m.label : '')).join('\\n');
}}
async function api(method, path, body) {{
const res = await fetch(wgApi(path), {{ method, headers: hdrs, body: body ? JSON.stringify(body) : undefined }});
const data = await res.json().catch(() => ({{}}));
if (!res.ok) throw new Error(wgFmtErr(res, data) || '请求失败');
return data;
}}
function renderNodes(nodes) {{
const box = document.getElementById('nodes-list');
box.innerHTML = '';
nodes.forEach(n => {{
const card = document.createElement('div');
card.className = 'rounded-2xl bg-white/5 p-5 ring-1 ring-white/10 space-y-3';
const modelsTxt = modelsToText(n.models);
card.innerHTML = '<div class="flex flex-wrap items-center justify-between gap-2"><div><span class="font-semibold text-white">'+esc(n.name)+'</span> <span class="text-xs rounded-full px-2 py-0.5 bg-white/10 text-slate-300">'+esc(n.status_label)+'</span></div><div class="flex gap-2"><button data-test="'+esc(n.id)+'" class="text-xs text-indigo-300 hover:underline">测试</button><button data-del="'+esc(n.id)+'" class="text-xs text-rose-400 hover:underline">删除</button></div></div>'
+ '<p class="text-sm text-cyan-200/90 font-mono">'+esc(n.host)+':'+esc(n.port)+' · 进行中 '+esc(n.in_flight)+'/'+esc(n.max_concurrent)+'</p>'
+ '<div class="grid gap-2 sm:grid-cols-2 text-sm"><input data-f="name" value="'+esc(n.name)+'" class="rounded-lg bg-black/30 border border-white/10 px-2 py-1 text-white"/>'
+ '<input data-f="host" value="'+esc(n.host)+'" class="rounded-lg bg-black/30 border border-white/10 px-2 py-1 text-white"/>'
+ '<input data-f="port" type="number" value="'+esc(n.port)+'" class="rounded-lg bg-black/30 border border-white/10 px-2 py-1 text-white"/>'
+ '<input data-f="max_concurrent" type="number" value="'+esc(n.max_concurrent)+'" class="rounded-lg bg-black/30 border border-white/10 px-2 py-1 text-white"/></div>'
+ '<textarea data-f="models" rows="3" class="w-full rounded-lg bg-black/30 border border-white/10 px-2 py-1 text-xs text-white font-mono">'+esc(modelsTxt)+'</textarea>'
+ '<label class="flex items-center gap-2 text-sm"><input data-f="enabled" type="checkbox" '+(n.enabled?'checked':'')+'> 启用</label>'
+ '<button data-save="'+esc(n.id)+'" class="rounded-lg bg-white/10 px-3 py-1.5 text-sm text-white ring-1 ring-white/15">保存修改</button>';
card.querySelector('[data-test]').onclick = async () => {{
try {{ const r = await api('POST', '/api/nodes/'+n.id+'/test'); showMsg(r.ok ? '连接成功' : ('失败: '+(r.error||'')), r.ok); load(); }} catch(e) {{ showMsg(e.message, false); }}
}};
card.querySelector('[data-del]').onclick = async () => {{
if (!confirm('确定删除该节点?')) return;
try {{ await api('DELETE', '/api/nodes/'+n.id); showMsg('已删除', true); load(); }} catch(e) {{ showMsg(e.message, false); }}
}};
card.querySelector('[data-save]').onclick = async () => {{
const body = {{
name: card.querySelector('[data-f=name]').value.trim(),
host: card.querySelector('[data-f=host]').value.trim() || '127.0.0.1',
port: parseInt(card.querySelector('[data-f=port]').value, 10),
max_concurrent: parseInt(card.querySelector('[data-f=max_concurrent]').value, 10) || 1,
enabled: card.querySelector('[data-f=enabled]').checked,
models: parseModels(card.querySelector('[data-f=models]').value),
}};
try {{ await api('PUT', '/api/nodes/'+n.id, body); showMsg('已保存', true); load(); }} catch(e) {{ showMsg(e.message, false); }}
}};
box.appendChild(card);
}});
}}
async function load() {{ renderNodes(await api('GET', '/api/nodes')); }}
document.getElementById('add-form').onsubmit = async (e) => {{
e.preventDefault();
const fd = new FormData(e.target);
const body = {{
name: String(fd.get('name')).trim(),
host: String(fd.get('host')||'127.0.0.1').trim() || '127.0.0.1',
port: parseInt(String(fd.get('port')), 10),
max_concurrent: parseInt(String(fd.get('max_concurrent')), 10) || 1,
enabled: !!fd.get('enabled'),
models: parseModels(String(fd.get('models')||'')),
}};
try {{ await api('POST', '/api/nodes', body); showMsg('节点已添加', true); e.target.reset(); e.target.host.value='127.0.0.1'; load(); }} catch(err) {{ showMsg(err.message, false); }}
}};
if (token) {{ load().catch(e => showMsg(e.message, false)); }}
</script>
""",
)
# Traffic-statistics page template (f-string: literal JS braces are doubled).
# The inline script reads /api/stats/summary, /api/stats/ips,
# /api/stats/logs?limit=50 and /api/stats/billing?days=N with the stored web
# token and fills the summary tiles and the four tables.
# Fixes over the previous revision:
#   * removed a stray bare <div> at the top of the body that had no matching
#     </div>, leaving the markup unbalanced;
#   * closed the parenthesis in the subtitle text.
STATS_HTML = page(
"流量统计",
f"""
<div class="mx-auto max-w-4xl space-y-8">
<div>
<h1 class="text-2xl font-bold text-white">流量统计</h1>
<p class="mt-2 text-sm text-slate-400">客户端 IP 与 Token 用量(UTC 日期汇总;反代需传递 X-Forwarded-For)</p>
</div>
<div id="loading" class="text-slate-400">加载中…</div>
<p id="err" class="hidden text-sm text-rose-400"></p>
<div id="panel" class="hidden space-y-8">
<div class="grid gap-4 sm:grid-cols-2 lg:grid-cols-4">
<div class="rounded-2xl bg-white/5 p-5 ring-1 ring-white/10">
<p class="text-xs font-semibold uppercase tracking-wider text-indigo-200">今日请求</p>
<p id="s-today-req" class="mt-2 text-2xl font-bold text-white">—</p>
</div>
<div class="rounded-2xl bg-white/5 p-5 ring-1 ring-white/10">
<p class="text-xs font-semibold uppercase tracking-wider text-indigo-200">今日 Token</p>
<p id="s-today-tok" class="mt-2 text-2xl font-bold text-cyan-300">—</p>
</div>
<div class="rounded-2xl bg-white/5 p-5 ring-1 ring-white/10">
<p class="text-xs font-semibold uppercase tracking-wider text-indigo-200">累计 Token</p>
<p id="s-total-tok" class="mt-2 text-2xl font-bold text-white">—</p>
</div>
<div class="rounded-2xl bg-white/5 p-5 ring-1 ring-white/10">
<p class="text-xs font-semibold uppercase tracking-wider text-indigo-200">独立 IP</p>
<p id="s-unique-ip" class="mt-2 text-2xl font-bold text-white">—</p>
</div>
</div>
<div class="rounded-3xl bg-white/5 p-6 ring-1 ring-white/10 backdrop-blur">
<h2 class="text-lg font-semibold text-white">IP 列表</h2>
<p class="mt-1 text-sm text-slate-400">按最近访问排序</p>
<div class="mt-4 overflow-x-auto rounded-2xl ring-1 ring-white/10">
<table class="min-w-full text-left text-sm">
<thead class="bg-black/30 text-xs uppercase tracking-wider text-slate-400">
<tr>
<th class="px-4 py-3">IP</th>
<th class="px-4 py-3">请求数</th>
<th class="px-4 py-3">Token 合计</th>
<th class="px-4 py-3">最近访问 (UTC)</th>
</tr>
</thead>
<tbody id="ip-tbody" class="divide-y divide-white/5 text-slate-200"></tbody>
</table>
</div>
<p id="ip-empty" class="hidden mt-4 text-sm text-slate-500">暂无记录</p>
</div>
<div class="rounded-3xl bg-white/5 p-6 ring-1 ring-white/10 backdrop-blur">
<div class="flex flex-wrap items-end justify-between gap-4">
<div>
<h2 class="text-lg font-semibold text-white">Token 账单</h2>
<p class="mt-1 text-sm text-slate-400">按日汇总(近 <span id="bill-days-label">30</span> 天)</p>
</div>
<select id="bill-days" class="rounded-xl border border-white/10 bg-black/30 px-3 py-2 text-sm text-white outline-none focus:ring-2 focus:ring-indigo-500">
<option value="7">近 7 天</option>
<option value="30" selected>近 30 天</option>
<option value="90">近 90 天</option>
</select>
</div>
<div class="mt-4 overflow-x-auto rounded-2xl ring-1 ring-white/10">
<table class="min-w-full text-left text-sm">
<thead class="bg-black/30 text-xs uppercase tracking-wider text-slate-400">
<tr>
<th class="px-4 py-3">日期 (UTC)</th>
<th class="px-4 py-3">请求</th>
<th class="px-4 py-3">Prompt</th>
<th class="px-4 py-3">Completion</th>
<th class="px-4 py-3">合计</th>
</tr>
</thead>
<tbody id="bill-tbody" class="divide-y divide-white/5 text-slate-200"></tbody>
</table>
</div>
<p id="bill-empty" class="hidden mt-4 text-sm text-slate-500">该时段暂无账单</p>
<h3 class="mt-8 text-sm font-semibold uppercase tracking-wider text-indigo-200">同期按 IP 分摊</h3>
<div class="mt-3 overflow-x-auto rounded-2xl ring-1 ring-white/10">
<table class="min-w-full text-left text-sm">
<thead class="bg-black/30 text-xs uppercase tracking-wider text-slate-400">
<tr>
<th class="px-4 py-3">IP</th>
<th class="px-4 py-3">请求</th>
<th class="px-4 py-3">Prompt</th>
<th class="px-4 py-3">Completion</th>
<th class="px-4 py-3">合计</th>
</tr>
</thead>
<tbody id="bill-ip-tbody" class="divide-y divide-white/5 text-slate-200"></tbody>
</table>
</div>
</div>
<div class="rounded-3xl bg-white/5 p-6 ring-1 ring-white/10 backdrop-blur">
<h2 class="text-lg font-semibold text-white">最近请求</h2>
<div class="mt-4 overflow-x-auto rounded-2xl ring-1 ring-white/10">
<table class="min-w-full text-left text-sm">
<thead class="bg-black/30 text-xs uppercase tracking-wider text-slate-400">
<tr>
<th class="px-4 py-3">时间 (UTC)</th>
<th class="px-4 py-3">IP</th>
<th class="px-4 py-3">模型</th>
<th class="px-4 py-3">状态</th>
<th class="px-4 py-3">Token</th>
</tr>
</thead>
<tbody id="log-tbody" class="divide-y divide-white/5 text-slate-200"></tbody>
</table>
</div>
</div>
</div>
</div>
<script>
const token = localStorage.getItem('web_token');
if (!token) {{ window.location.href = wgApi('/login'); }}
else {{
const hdrs = {{ 'Authorization': 'Bearer ' + token }};
function fmt(n) {{ return Number(n || 0).toLocaleString(); }}
function esc(s) {{
return String(s == null ? '' : s)
.replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;');
}}
function showErr(msg) {{
document.getElementById('loading').classList.add('hidden');
const el = document.getElementById('err');
el.textContent = msg;
el.classList.remove('hidden');
}}
async function apiGet(path) {{
const res = await fetch(wgApi(path), {{ headers: hdrs }});
const data = await res.json().catch(() => ({{}}));
if (!res.ok) throw new Error(wgFmtErr(res, data) || '请求失败');
return data;
}}
function fillIpTable(rows) {{
const tb = document.getElementById('ip-tbody');
tb.innerHTML = '';
if (!rows.length) {{
document.getElementById('ip-empty').classList.remove('hidden');
return;
}}
document.getElementById('ip-empty').classList.add('hidden');
rows.forEach(r => {{
const tr = document.createElement('tr');
tr.className = 'hover:bg-white/5';
tr.innerHTML = '<td class="px-4 py-3 font-mono text-cyan-200">' + esc(r.ip) + '</td>'
+ '<td class="px-4 py-3">' + fmt(r.requests) + '</td>'
+ '<td class="px-4 py-3">' + fmt(r.total_tokens) + '</td>'
+ '<td class="px-4 py-3 text-slate-400">' + esc((r.last_seen || '').replace('T',' ').slice(0,19)) + '</td>';
tb.appendChild(tr);
}});
}}
function fillBillTables(data) {{
const dayTb = document.getElementById('bill-tbody');
const ipTb = document.getElementById('bill-ip-tbody');
dayTb.innerHTML = '';
ipTb.innerHTML = '';
const days = data.by_day || [];
const ips = data.by_ip || [];
document.getElementById('bill-empty').classList.toggle('hidden', days.length > 0);
days.forEach(r => {{
const tr = document.createElement('tr');
tr.className = 'hover:bg-white/5';
tr.innerHTML = '<td class="px-4 py-3">' + esc(r.day) + '</td>'
+ '<td class="px-4 py-3">' + fmt(r.requests) + '</td>'
+ '<td class="px-4 py-3">' + fmt(r.prompt_tokens) + '</td>'
+ '<td class="px-4 py-3">' + fmt(r.completion_tokens) + '</td>'
+ '<td class="px-4 py-3 font-medium text-cyan-200">' + fmt(r.total_tokens) + '</td>';
dayTb.appendChild(tr);
}});
ips.forEach(r => {{
const tr = document.createElement('tr');
tr.className = 'hover:bg-white/5';
tr.innerHTML = '<td class="px-4 py-3 font-mono text-cyan-200">' + esc(r.ip) + '</td>'
+ '<td class="px-4 py-3">' + fmt(r.requests) + '</td>'
+ '<td class="px-4 py-3">' + fmt(r.prompt_tokens) + '</td>'
+ '<td class="px-4 py-3">' + fmt(r.completion_tokens) + '</td>'
+ '<td class="px-4 py-3 font-medium text-cyan-200">' + fmt(r.total_tokens) + '</td>';
ipTb.appendChild(tr);
}});
}}
function fillLogs(rows) {{
const tb = document.getElementById('log-tbody');
tb.innerHTML = '';
rows.forEach(r => {{
const tr = document.createElement('tr');
tr.className = 'hover:bg-white/5';
const st = r.status_code >= 400 ? 'text-rose-400' : 'text-emerald-400';
tr.innerHTML = '<td class="px-4 py-3 text-slate-400">' + esc((r.created_at || '').replace('T',' ').slice(0,19)) + '</td>'
+ '<td class="px-4 py-3 font-mono text-cyan-200">' + esc(r.client_ip) + '</td>'
+ '<td class="px-4 py-3">' + esc(r.model || '') + '</td>'
+ '<td class="px-4 py-3 ' + st + '">' + esc(r.status_code) + '</td>'
+ '<td class="px-4 py-3">' + fmt(r.total_tokens) + '</td>';
tb.appendChild(tr);
}});
}}
async function loadBilling(days) {{
document.getElementById('bill-days-label').textContent = days;
const bill = await apiGet('/api/stats/billing?days=' + days);
fillBillTables(bill);
}}
async function boot() {{
try {{
const [summary, ips, logs] = await Promise.all([
apiGet('/api/stats/summary'),
apiGet('/api/stats/ips'),
apiGet('/api/stats/logs?limit=50'),
]);
document.getElementById('s-today-req').textContent = fmt(summary.today_requests);
document.getElementById('s-today-tok').textContent = fmt(summary.today_tokens);
document.getElementById('s-total-tok').textContent = fmt(summary.total_tokens);
document.getElementById('s-unique-ip').textContent = fmt(summary.unique_ips);
fillIpTable(ips);
fillLogs(logs.items || []);
const daysSel = document.getElementById('bill-days');
await loadBilling(daysSel.value);
daysSel.addEventListener('change', () => loadBilling(daysSel.value).catch(e => showErr(e.message)));
document.getElementById('loading').classList.add('hidden');
document.getElementById('panel').classList.remove('hidden');
}} catch (e) {{
showErr(e.message || String(e));
}}
}}
boot();
}}
</script>
""",
)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Manage startup and shutdown of the gateway application.

    Startup loads the gateway and node configuration, initialises the stats
    database, creates the shared ``httpx.AsyncClient`` on ``app.state.http``
    and starts the background ``_health_loop`` task.

    Shutdown cancels the health task and waits for it to finish *before*
    closing the HTTP client, so the task can never touch a closed client; the
    client is closed even if stopping the task fails.
    """
    global _health_task
    load_gateway_config()
    load_nodes_config()
    init_stats_db()
    # 600 s default for every timeout phase, but only 30 s to establish a connection.
    timeout = httpx.Timeout(600.0, connect=30.0)
    app.state.http = httpx.AsyncClient(timeout=timeout)
    _health_task = asyncio.create_task(_health_loop(app.state.http))
    try:
        yield
    finally:
        try:
            if _health_task:
                _health_task.cancel()
                # return_exceptions=True turns the task's CancelledError (or an
                # earlier crash of the loop) into a result instead of raising,
                # so shutdown stays quiet as before.
                await asyncio.gather(_health_task, return_exceptions=True)
        finally:
            await app.state.http.aclose()
# The ASGI application; startup/shutdown work is delegated to ``lifespan``.
app = FastAPI(lifespan=lifespan, title="LLM Gateway")

# CORS is fully open: any origin, method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive combination — confirm it is intended; the embedded pages in this
# file send a Bearer token taken from localStorage.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
@app.get("/", response_class=HTMLResponse)
async def root() -> str:
    """Serve the login page at the site root."""
    return LOGIN_HTML


@app.get("/login", response_class=HTMLResponse)
async def login_page() -> str:
    """Serve the login page."""
    return LOGIN_HTML


@app.get("/home", response_class=HTMLResponse)
async def home() -> str:
    """Serve the home page."""
    return HOME_HTML


@app.get("/user", response_class=HTMLResponse)
async def user_page() -> str:
    """Serve the user-centre page."""
    return USER_HTML


@app.get("/settings", response_class=HTMLResponse)
async def settings_page() -> str:
    """Serve the settings page."""
    return SETTINGS_HTML


@app.get("/stats", response_class=HTMLResponse)
async def stats_page() -> str:
    """Serve the traffic-statistics page."""
    return STATS_HTML
@app.post("/api/login")
async def api_login(body: LoginBody) -> JSONResponse:
    """Authenticate the configured account and issue a web session token.

    The submitted username (whitespace-stripped) is compared with
    ``_GATE["username"]`` and the password is checked against
    ``_PASSWORD_HASH``.  Both checks always run, and the username comparison is
    constant-time, so response timing does not reveal whether the username
    exists.

    Returns:
        JSON ``{"access_token": ..., "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 with a generic message when either check fails.
    """
    name = body.username.strip()
    expected = _GATE["username"]
    # compare_digest needs bytes for non-ASCII names; a non-string configured
    # value never matched before and still does not.
    user_ok = isinstance(expected, str) and secrets.compare_digest(
        name.encode("utf-8"), expected.encode("utf-8")
    )
    # Verify the password even for a wrong username: skipping the hash check
    # made unknown usernames answer measurably faster.
    password_ok = verify_password(body.password, _PASSWORD_HASH)
    if not (user_ok and password_ok):
        raise HTTPException(status_code=401, detail="用户名或密码错误")
    token = create_web_token(GATE_WEB_UID)
    return JSONResponse({"access_token": token, "token_type": "bearer"})
@app.get("/api/me")
async def api_me(
    user: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> Dict[str, str]:
    """Return the username and API key of the authenticated web session user."""
    profile = {"username": user.username, "api_key": user.api_key}
    return profile
@app.get("/api/models/cards")
async def api_model_cards(
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> List[Dict[str, Any]]:
    """Return the model cards, built off the event loop by ``build_model_cards``."""
    cards = await run_in_thread(build_model_cards)
    return cards
@app.get("/api/nodes")
async def api_nodes(
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> List[Dict[str, Any]]:
    """Return every node with its status, as produced by ``list_nodes_with_status``."""
    listing = await run_in_thread(list_nodes_with_status)
    return listing
@app.post("/api/nodes")
async def api_nodes_create(
    body: NodeBody,
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> Dict[str, Any]:
    """Create a node from the request body, persist it and return it.

    The new node gets a fresh id from ``_new_node_id``; a blank host falls
    back to ``127.0.0.1`` and a model without a label uses its id as label.
    The full node list (existing nodes plus the new one) is saved through
    ``save_nodes_config``.
    """
    node: Dict[str, Any] = {"id": _new_node_id()}
    node["name"] = body.name.strip()
    node["host"] = (body.host or "127.0.0.1").strip() or "127.0.0.1"
    node["port"] = body.port
    node["enabled"] = body.enabled
    node["max_concurrent"] = body.max_concurrent
    models = []
    for entry in body.models:
        models.append({"id": entry.id.strip(), "label": (entry.label or entry.id).strip()})
    node["models"] = models

    updated = [*_NODES.get("nodes", []), node]
    await run_in_thread(save_nodes_config, {"nodes": updated})
    return node
@app.put("/api/nodes/{node_id}")
async def api_nodes_update(
    node_id: str,
    body: NodeBody,
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> Dict[str, Any]:
    """Replace the node identified by ``node_id`` with the request body.

    Raises:
        HTTPException: 404 when no node has that id.
    """
    nodes = list(_NODES.get("nodes", []))
    # Locate the first node whose id (compared as a string) matches.
    position = None
    for i, existing in enumerate(nodes):
        if str(existing.get("id")) == node_id:
            position = i
            break
    if position is None:
        raise HTTPException(status_code=404, detail="节点不存在")

    replacement: Dict[str, Any] = {"id": node_id}
    replacement["name"] = body.name.strip()
    replacement["host"] = (body.host or "127.0.0.1").strip() or "127.0.0.1"
    replacement["port"] = body.port
    replacement["enabled"] = body.enabled
    replacement["max_concurrent"] = body.max_concurrent
    models = []
    for entry in body.models:
        models.append({"id": entry.id.strip(), "label": (entry.label or entry.id).strip()})
    replacement["models"] = models

    nodes[position] = replacement
    await run_in_thread(save_nodes_config, {"nodes": nodes})
    return nodes[position]
@app.delete("/api/nodes/{node_id}")
async def api_nodes_delete(
    node_id: str,
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> JSONResponse:
    """Delete the node identified by ``node_id`` and persist the remaining list.

    Raises:
        HTTPException: 404 when no node has that id.
    """
    remaining = []
    for node in _NODES.get("nodes", []):
        if str(node.get("id")) != node_id:
            remaining.append(node)
    # Nothing was filtered out, so the id did not exist.
    if len(_NODES.get("nodes", [])) == len(remaining):
        raise HTTPException(status_code=404, detail="节点不存在")
    await run_in_thread(save_nodes_config, {"nodes": remaining})
    return JSONResponse({"ok": True})
@app.post("/api/nodes/{node_id}/test")
async def api_nodes_test(
    node_id: str,
    request: Request,
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> Dict[str, Any]:
    """Probe one node now and record the outcome in its runtime state.

    Returns:
        ``{"ok": <bool>, "error": <probe error or None>}`` as reported by
        ``probe_node``.

    Raises:
        HTTPException: 404 when no node has that id.
    """
    target = None
    for candidate in _NODES.get("nodes", []):
        if str(candidate.get("id")) == node_id:
            target = candidate
            break
    if target is None:
        raise HTTPException(status_code=404, detail="节点不存在")

    http_client: httpx.AsyncClient = request.app.state.http
    ok, err = await probe_node(http_client, target)

    # Runtime state is shared, so it is updated under the nodes lock.
    with _nodes_lock:
        runtime = _runtime_unlocked(node_id)
        runtime["healthy"] = ok
        runtime["last_error"] = err
        runtime["last_check"] = datetime.now(timezone.utc).isoformat()
    outcome = {"ok": ok, "error": err}
    return outcome
@app.get("/api/stats/summary")
async def api_stats_summary(
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> Dict[str, Any]:
    """Return the statistics summary computed by ``_query_stats_summary``."""
    summary = await run_in_thread(_query_stats_summary)
    return summary
@app.get("/api/stats/ips")
async def api_stats_ips(
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> List[Dict[str, Any]]:
    """Return the per-IP statistics rows produced by ``_query_stats_ips``."""
    rows = await run_in_thread(_query_stats_ips)
    return rows
@app.get("/api/stats/billing")
async def api_stats_billing(
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
    days: int = Query(30, ge=1, le=365),
) -> Dict[str, Any]:
    """Return billing data for the last ``days`` days (1–365, default 30)."""
    billing = await run_in_thread(_query_stats_billing, days)
    return billing
@app.get("/api/stats/logs")
async def api_stats_logs(
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
) -> Dict[str, Any]:
    """Return one page of request logs together with the paging parameters.

    ``limit`` is 1–500 (default 50) and ``offset`` is non-negative; the
    response echoes both next to ``items`` and ``total``.
    """
    page_rows = await run_in_thread(_query_stats_logs, limit, offset)
    items, total = page_rows
    payload: Dict[str, Any] = {"items": items, "total": total}
    payload["limit"] = limit
    payload["offset"] = offset
    return payload
@app.post("/v1/chat/completions")
async def chat_completions(
    request: Request,
    _: Annotated[GateSessionUser, Depends(get_user_from_api_key)],
) -> StreamingResponse:
    """OpenAI-compatible chat endpoint: forward the raw request body upstream.

    After the API key dependency has accepted the caller, the body is sent
    unchanged to ``<base>/v1/chat/completions`` and the upstream response is
    streamed back.  When nodes are configured, the target is the node chosen
    by ``select_node_for_model`` and its in-flight counter is held for the
    duration of the request; otherwise ``UPSTREAM_BASE`` is used.

    The in-flight slot is released exactly once on every exit path — upstream
    connection failure, cancellation, unexpected errors, and the end of the
    stream — and it is released before any further ``await`` so a failing or
    cancelled ``aclose()`` cannot leak it.

    Raises:
        HTTPException: 400 when nodes are configured and the body has no
            model; 404 when no node serves the model; 502 when the upstream
            request fails with ``httpx.RequestError``.
    """
    body = await request.body()
    client_ip = get_client_ip(request)
    model = parse_model_from_body(body)
    req_bytes = len(body)

    node_id: Optional[str] = None
    if _NODES.get("nodes"):
        if not model:
            raise HTTPException(status_code=400, detail="请求体须包含 model 字段")
        selected_node = select_node_for_model(model)
        if selected_node is None:
            raise HTTPException(
                status_code=404,
                detail=f"未找到模型「{model}」的可用节点,请在系统设置中配置",
            )
        base = get_node_base(selected_node)
        node_id = str(selected_node["id"])
        node_in_flight_inc(node_id)
    else:
        base = UPSTREAM_BASE

    slot_released = False

    def release_slot() -> None:
        """Give the node's in-flight slot back; safe to call more than once."""
        nonlocal slot_released
        if node_id is not None and not slot_released:
            slot_released = True
            node_in_flight_dec(node_id)

    fwd_headers = {
        "Content-Type": request.headers.get("content-type") or "application/json",
        "Accept": request.headers.get("accept") or "*/*",
    }
    client: httpx.AsyncClient = request.app.state.http
    try:
        req = client.build_request(
            "POST", f"{base}/v1/chat/completions", headers=fwd_headers, content=body
        )
        resp = await client.send(req, stream=True)
    except httpx.RequestError as e:
        release_slot()
        await record_api_log(
            client_ip,
            model,
            502,
            req_bytes,
            0,
            parse_usage_from_response(b"", ""),
            node_id,
        )
        raise HTTPException(status_code=502, detail=f"上游连接失败: {e}") from e
    except BaseException:
        # Cancellation or an unexpected failure: the slot was taken above and
        # nothing else would ever release it.
        release_slot()
        raise

    # Only these upstream response headers are passed through to the caller.
    passthrough_headers = frozenset(
        ("content-type", "cache-control", "openai-processing-ms", "x-request-id")
    )
    out_headers = {
        k: v for k, v in resp.headers.items() if k.lower() in passthrough_headers
    }
    content_type = resp.headers.get("content-type", "")

    async def stream_body() -> AsyncIterator[bytes]:
        """Relay upstream chunks while keeping a copy for usage accounting."""
        buf = bytearray()
        try:
            async for chunk in resp.aiter_bytes():
                buf.extend(chunk)
                yield chunk
        finally:
            # Synchronous release first: the awaits below may raise or be
            # cancelled when the client disconnects mid-stream.
            release_slot()
            try:
                await resp.aclose()
            finally:
                usage = parse_usage_from_response(bytes(buf), content_type)
                await record_api_log(
                    client_ip,
                    model,
                    resp.status_code,
                    req_bytes,
                    len(buf),
                    usage,
                    node_id,
                )

    return StreamingResponse(
        stream_body(),
        status_code=resp.status_code,
        headers=out_headers,
    )
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces at
    # GATEWAY_PORT, with auto-reload disabled.
    import uvicorn

    uvicorn.run("main:app", reload=False, port=GATEWAY_PORT, host="0.0.0.0")