1767 lines
71 KiB
Python
1767 lines
71 KiB
Python
"""
|
||
LLM relay gateway: FastAPI plus an embedded front end (single file).
The account and API key are read from a JSON config file; there is no sign-up.

Run: pip install -r requirements.txt && uvicorn main:app --host 0.0.0.0 --port 8000

Config: copy gateway.json.example to gateway.json and fill in username and password.
api_key may be left empty; on first start a sk-... key is generated and written
back to gateway.json.

Environment variables (optional):
  JWT_SECRET      JWT signing secret (must be changed in production)
  UPSTREAM_URL    Upstream OpenAI-compatible base URL, default http://127.0.0.1:10434
  APP_ROOT        Reverse-proxy sub-path prefix, e.g. /wg
  GATEWAY_CONFIG  Config file path, default gateway.json in the current directory
  NODES_CONFIG    Node/model config file, default nodes.json
  STATS_DB        SQLite path for access statistics, default gateway_stats.db
  GATEWAY_PORT    Listening port, default 8150
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import os
|
||
import secrets
|
||
import sqlite3
|
||
import threading
|
||
from contextlib import asynccontextmanager
|
||
from dataclasses import dataclass
|
||
from datetime import datetime, timedelta, timezone
|
||
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Tuple, TypeVar
|
||
|
||
try:
|
||
from typing import Annotated # Python 3.9+
|
||
except ImportError: # Python 3.8
|
||
from typing_extensions import Annotated
|
||
|
||
import httpx
|
||
|
||
_T = TypeVar("_T")


async def run_in_thread(func: Callable[..., _T], /, *args: Any) -> _T:
    """Run a blocking callable in a worker thread and return its result.

    Uses ``asyncio.to_thread`` when the running interpreter provides it
    (Python 3.9+); otherwise falls back to the event loop's default executor
    so the helper keeps working on Python 3.8.

    Args:
        func: Blocking callable to execute off the event loop.
        *args: Positional arguments forwarded to ``func``.

    Returns:
        Whatever ``func(*args)`` returns.
    """
    to_thread = getattr(asyncio, "to_thread", None)
    if to_thread is not None:
        # Delegate to the stdlib helper. Calling run_in_thread() here, as the
        # previous code did, recurses until RecursionError.
        return await to_thread(func, *args)
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, func, *args)
|
||
from fastapi import Depends, FastAPI, HTTPException, Query, Request
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
|
||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||
from jose import JWTError, jwt
|
||
from passlib.context import CryptContext
|
||
from pydantic import BaseModel, Field
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 配置
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Path of the gateway account config file (env: GATEWAY_CONFIG).
CONFIG_PATH = os.environ.get("GATEWAY_CONFIG", "gateway.json")
# Path of the node/model config file (env: NODES_CONFIG).
NODES_PATH = os.environ.get("NODES_CONFIG", "nodes.json")
# SQLite file holding the access statistics (env: STATS_DB).
STATS_DB = os.environ.get("STATS_DB", "gateway_stats.db")
# Listening port (env: GATEWAY_PORT); int() raises ValueError on a non-numeric value.
GATEWAY_PORT = int(os.environ.get("GATEWAY_PORT", "8150"))
# JWT signing secret. The built-in default is a development placeholder and
# should be overridden through JWT_SECRET in production.
JWT_SECRET = os.environ.get("JWT_SECRET", "dev-secret-change-me")
# Algorithm used to sign and verify web login tokens.
JWT_ALG = "HS256"
# Lifetime of a web login token, in days.
JWT_EXPIRE_DAYS = 7
# User id that a decoded web token must carry to be accepted.
GATE_WEB_UID = 1
# Upstream base URL with any trailing slash removed (env: UPSTREAM_URL).
UPSTREAM_BASE = os.environ.get("UPSTREAM_URL", "http://127.0.0.1:10434").rstrip("/")
# Reverse-proxy sub-path prefix with any trailing slash removed (env: APP_ROOT).
APP_ROOT = os.environ.get("APP_ROOT", "").rstrip("/")
|
||
|
||
|
||
def app_url(path: str) -> str:
    """Return ``path`` prefixed with the configured ``APP_ROOT``.

    A leading slash is added to ``path`` when it is missing. With an empty
    ``APP_ROOT`` the normalized path is returned as is.
    """
    normalized = path if path.startswith("/") else "/" + path
    if not APP_ROOT:
        return normalized
    return APP_ROOT + normalized
|
||
|
||
|
||
# APP_ROOT encoded as a JSON string literal, for embedding in inline <script> code.
_APP_ROOT_JSON = json.dumps(APP_ROOT)

# Password hashing context (bcrypt).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Bearer-credential extractor; auto_error=False lets the dependencies in this
# module raise their own 401 responses when credentials are missing.
security_bearer = HTTPBearer(auto_error=False)

# Filled in by load_gateway_config() at startup.
_GATE: Dict[str, Any] = {}
_PASSWORD_HASH: str = ""
# Held around the SQLite accesses in this module.
_db_lock = threading.Lock()
# Guards _node_runtime. This is a plain Lock, i.e. NOT re-entrant.
_nodes_lock = threading.Lock()
# Current node configuration as loaded from NODES_PATH.
_NODES: Dict[str, Any] = {"nodes": []}
# Per-node runtime state keyed by node id: in_flight, healthy, last_check, last_error.
_node_runtime: Dict[str, Dict[str, Any]] = {}
# NOTE(review): presumably the background health-check task; it is assigned
# outside the code visible here -- confirm.
_health_task: Optional[asyncio.Task] = None

# Display label for each node status code.
STATUS_LABELS: Dict[str, str] = {
    "offline": "离线",
    "idle": "空闲",
    "busy": "忙碌",
    "disabled": "未启用",
    "error": "异常",
}
|
||
|
||
|
||
def hash_password(p: str) -> str:
    """Return the hash of plaintext password ``p`` produced by ``pwd_context``."""
    hashed = pwd_context.hash(p)
    return hashed
|
||
|
||
|
||
def verify_password(plain: str, hashed: str) -> bool:
    """Report whether ``plain`` matches the stored hash ``hashed``."""
    matches = pwd_context.verify(plain, hashed)
    return matches
|
||
|
||
|
||
def generate_api_key() -> str:
    """Return a fresh random API key of the form ``sk-<url-safe token>``."""
    token = secrets.token_urlsafe(32)
    return f"sk-{token}"
|
||
|
||
|
||
def load_gateway_config() -> None:
    """Load the account settings from ``CONFIG_PATH`` into module state.

    Reads ``username``, ``password`` and ``api_key`` from the JSON file. When
    ``api_key`` is missing, empty or ``null`` a new key is generated and the
    file is rewritten with it. On success ``_GATE`` holds the username and API
    key and ``_PASSWORD_HASH`` holds the hash of the configured password.

    Raises:
        RuntimeError: The config file does not exist.
        ValueError: The file is not a JSON object, or username/password is
            missing or empty.
    """
    global _GATE, _PASSWORD_HASH

    def as_text(value: Any) -> str:
        # JSON null must count as "not set"; str(None) would yield "None".
        return "" if value is None else str(value)

    if not os.path.isfile(CONFIG_PATH):
        raise RuntimeError(
            f"缺少配置文件: {CONFIG_PATH}(可复制 gateway.json.example 为 gateway.json)"
        )
    with open(CONFIG_PATH, encoding="utf-8") as f:
        raw = json.load(f)
    if not isinstance(raw, dict):
        # A list/string/number at the top level cannot carry the settings.
        raise ValueError(f"{CONFIG_PATH} 须包含非空的 username 与 password")

    username = as_text(raw.get("username")).strip()
    password = as_text(raw.get("password"))
    if not username or not password:
        raise ValueError(f"{CONFIG_PATH} 须包含非空的 username 与 password")

    api_key = as_text(raw.get("api_key")).strip()
    if not api_key:
        # First start without a key: generate one and persist it.
        api_key = generate_api_key()
        raw["api_key"] = api_key
        with open(CONFIG_PATH, "w", encoding="utf-8") as wf:
            json.dump(raw, wf, indent=2, ensure_ascii=False)
            wf.write("\n")

    _GATE = {"username": username, "api_key": api_key}
    _PASSWORD_HASH = hash_password(password)
|
||
|
||
|
||
def create_web_token(user_id: int) -> str:
    """Return a signed web-login JWT for ``user_id``.

    The token carries ``sub`` (the user id as a string), ``typ="web"`` and an
    expiry ``JWT_EXPIRE_DAYS`` days from now (UTC).
    """
    issued_at = datetime.now(timezone.utc)
    claims = {
        "sub": str(user_id),
        "exp": issued_at + timedelta(days=JWT_EXPIRE_DAYS),
        "typ": "web",
    }
    return jwt.encode(claims, JWT_SECRET, algorithm=JWT_ALG)
|
||
|
||
|
||
def decode_web_token(token: str) -> int:
    """Validate a web-login JWT and return the user id stored in ``sub``.

    Args:
        token: Encoded JWT as issued by ``create_web_token``.

    Returns:
        The integer user id.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has the wrong
            ``typ``, lacks ``sub``, or carries a ``sub`` that is not an integer.
    """
    try:
        payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALG])
        if payload.get("typ") != "web":
            raise JWTError("wrong token type")
        sub = payload.get("sub")
        if not sub:
            raise JWTError("missing sub")
        return int(sub)
    except (JWTError, ValueError, TypeError) as e:
        # ValueError/TypeError: a correctly signed token whose "sub" is not
        # numeric must be rejected with 401 rather than crash with a 500.
        raise HTTPException(status_code=401, detail="无效或过期的登录状态") from e
|
||
|
||
|
||
@dataclass
class GateSessionUser:
    """Authenticated identity returned by the auth dependencies in this module."""

    # Account name taken from the loaded gateway config.
    username: str
    # API key tied to the session (the configured key, or the key presented).
    api_key: str
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 节点 / 模型配置(nodes.json)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _default_nodes_payload() -> Dict[str, Any]:
|
||
return {
|
||
"nodes": [
|
||
{
|
||
"id": "node-1",
|
||
"name": "节点 1",
|
||
"host": "127.0.0.1",
|
||
"port": 3313,
|
||
"enabled": True,
|
||
"max_concurrent": 1,
|
||
"models": [],
|
||
},
|
||
{
|
||
"id": "node-2",
|
||
"name": "节点 2",
|
||
"host": "127.0.0.1",
|
||
"port": 3314,
|
||
"enabled": True,
|
||
"max_concurrent": 1,
|
||
"models": [],
|
||
},
|
||
{
|
||
"id": "node-3",
|
||
"name": "节点 3",
|
||
"host": "127.0.0.1",
|
||
"port": 3315,
|
||
"enabled": True,
|
||
"max_concurrent": 1,
|
||
"models": [],
|
||
},
|
||
]
|
||
}
|
||
|
||
|
||
def ensure_nodes_file() -> None:
    """Create ``NODES_PATH`` when it does not exist yet.

    The initial content comes from ``NODES_PATH + ".example"`` when that file
    exists, otherwise from ``_default_nodes_payload()``.
    """
    if os.path.isfile(NODES_PATH):
        return
    example_path = NODES_PATH + ".example"
    if not os.path.isfile(example_path):
        save_nodes_config(_default_nodes_payload())
        return
    with open(example_path, encoding="utf-8") as src:
        seed = json.load(src)
    save_nodes_config(seed)
|
||
|
||
|
||
def load_nodes_config() -> None:
    """Reload ``_NODES`` from ``NODES_PATH`` and register new node ids.

    The file is created first when missing. A ``nodes`` value that is not a
    list is treated as empty. Node ids that have no runtime entry yet get a
    fresh one; existing runtime entries are left untouched.
    """
    global _NODES, _node_runtime
    ensure_nodes_file()
    with open(NODES_PATH, encoding="utf-8") as fh:
        loaded = json.load(fh)
    node_list = loaded.get("nodes")
    if not isinstance(node_list, list):
        node_list = []
    _NODES = {"nodes": node_list}
    with _nodes_lock:
        for entry in node_list:
            node_id = str(entry.get("id", ""))
            if not node_id or node_id in _node_runtime:
                continue
            _node_runtime[node_id] = {
                "in_flight": 0,
                "healthy": None,
                "last_check": None,
                "last_error": "",
            }
|
||
|
||
|
||
def save_nodes_config(data: Optional[Dict[str, Any]] = None) -> None:
    """Write the node config to ``NODES_PATH`` and reload it.

    Args:
        data: Payload to persist; when ``None`` the current ``_NODES`` is written.
    """
    global _NODES
    to_write = _NODES if data is None else data
    with open(NODES_PATH, "w", encoding="utf-8") as out:
        json.dump(to_write, out, indent=2, ensure_ascii=False)
        out.write("\n")
    # Re-read so the in-memory state matches what is now on disk.
    load_nodes_config()
|
||
|
||
|
||
def get_node_base(node: Dict[str, Any]) -> str:
    """Return the ``http://host:port`` base URL of a node entry.

    A missing or blank ``host`` falls back to 127.0.0.1; a missing ``port``
    becomes 0.
    """
    host = str(node.get("host", "127.0.0.1")).strip()
    if not host:
        host = "127.0.0.1"
    port = int(node.get("port", 0))
    base = "http://{}:{}".format(host, port)
    return base.rstrip("/")
|
||
|
||
|
||
def _runtime(nid: str) -> Dict[str, Any]:
    """Return the runtime-state dict of node ``nid``, creating it if absent.

    Acquires ``_nodes_lock`` for the lookup/creation.
    """
    with _nodes_lock:
        return _node_runtime.setdefault(
            nid,
            {
                "in_flight": 0,
                "healthy": None,
                "last_check": None,
                "last_error": "",
            },
        )
|
||
|
||
|
||
def node_in_flight_inc(nid: str) -> None:
    """Increment the in-flight request counter of node ``nid``."""
    # Fetch the entry BEFORE taking the lock: _runtime() acquires _nodes_lock
    # itself and that lock is not re-entrant, so calling it while holding the
    # lock would block forever.
    rt = _runtime(nid)
    with _nodes_lock:
        rt["in_flight"] = int(rt.get("in_flight", 0)) + 1
|
||
|
||
|
||
def node_in_flight_dec(nid: str) -> None:
    """Decrement the in-flight request counter of node ``nid`` (never below 0)."""
    # Fetch the entry BEFORE taking the lock: _runtime() acquires _nodes_lock
    # itself and that lock is not re-entrant, so calling it while holding the
    # lock would block forever.
    rt = _runtime(nid)
    with _nodes_lock:
        rt["in_flight"] = max(0, int(rt.get("in_flight", 0)) - 1)
|
||
|
||
|
||
def compute_node_status(node: Dict[str, Any]) -> Tuple[str, str]:
    """Return ``(status_code, display_label)`` for a node entry.

    Precedence: ``disabled`` when the node is not enabled; for an unhealthy
    node ``error`` if an error text was recorded, else ``offline``; ``busy``
    when in-flight requests reach ``max_concurrent``; ``idle`` when healthy;
    ``offline`` while no health result exists yet.
    """
    if not node.get("enabled", True):
        code = "disabled"
    else:
        rt = _runtime(str(node.get("id", "")))
        healthy = rt.get("healthy")
        capacity = int(node.get("max_concurrent", 1) or 1)
        if healthy is False:
            code = "error" if str(rt.get("last_error") or "") else "offline"
        elif int(rt.get("in_flight", 0)) >= capacity:
            code = "busy"
        elif healthy is True:
            code = "idle"
        else:
            code = "offline"
    return code, STATUS_LABELS[code]
|
||
|
||
|
||
def find_nodes_for_model(model_id: str) -> List[Dict[str, Any]]:
    """Return the enabled nodes whose model list contains ``model_id``.

    An empty ``model_id`` yields an empty list. Nodes keep their config order.
    """
    if not model_id:
        return []
    return [
        node
        for node in _NODES.get("nodes", [])
        if node.get("enabled", True)
        and any(str(m.get("id", "")) == model_id for m in node.get("models") or [])
    ]
|
||
|
||
|
||
def select_node_for_model(model_id: str) -> Optional[Dict[str, Any]]:
    """Pick the node that should serve ``model_id``, or ``None`` if none can.

    Nodes that are neither disabled, in error, nor busy are preferred. When no
    such node exists, every non-disabled candidate is considered instead. Among
    the pool the node with the fewest in-flight requests wins, ties broken by
    node id.
    """
    candidates = find_nodes_for_model(model_id)
    if not candidates:
        return None

    def ranked(excluded: Tuple[str, ...]) -> List[Tuple[int, str, Dict[str, Any]]]:
        # (in_flight, node id, node) for every candidate whose status is allowed.
        entries: List[Tuple[int, str, Dict[str, Any]]] = []
        for node in candidates:
            if compute_node_status(node)[0] in excluded:
                continue
            node_id = str(node["id"])
            load = int(_runtime(node_id).get("in_flight", 0))
            entries.append((load, node_id, node))
        return entries

    pool = ranked(("disabled", "error", "busy")) or ranked(("disabled",))
    if not pool:
        return None
    best = min(pool, key=lambda entry: (entry[0], entry[1]))
    return best[2]
|
||
|
||
|
||
async def probe_node(
    client: httpx.AsyncClient, node: Dict[str, Any]
) -> Tuple[bool, str]:
    """Probe a node's ``/v1/models`` endpoint with a 5 second timeout.

    Returns:
        ``(True, "")`` for any response with a status below 500,
        ``(False, "HTTP <status>")`` for a 5xx response, and
        ``(False, <error text>)`` when the request itself fails.
    """
    target = get_node_base(node) + "/v1/models"
    try:
        response = await client.get(target, timeout=5.0)
    except httpx.RequestError as exc:
        return False, str(exc)
    if response.status_code >= 500:
        return False, f"HTTP {response.status_code}"
    return True, ""
|
||
|
||
|
||
async def refresh_all_node_health(client: httpx.AsyncClient) -> None:
    """Probe every configured node in turn and record the outcome.

    For each node with a non-empty id the runtime entry receives ``healthy``,
    ``last_error`` and ``last_check`` (UTC ISO-8601 timestamp).

    Args:
        client: HTTP client used for the probes.
    """
    for node in _NODES.get("nodes", []):
        nid = str(node.get("id", ""))
        if not nid:
            continue
        ok, err = await probe_node(client, node)
        # Fetch the entry BEFORE taking the lock: _runtime() acquires
        # _nodes_lock itself and that lock is not re-entrant, so calling it
        # while holding the lock would block forever.
        rt = _runtime(nid)
        with _nodes_lock:
            rt["healthy"] = ok
            rt["last_error"] = err
            rt["last_check"] = datetime.now(timezone.utc).isoformat()
|
||
|
||
|
||
async def _health_loop(client: httpx.AsyncClient) -> None:
    """Refresh the health of all nodes forever, sleeping 30 seconds between rounds."""
    while True:
        try:
            await refresh_all_node_health(client)
        except Exception:
            # Best effort: a failed round is dropped and retried after the
            # sleep. On Python 3.8+ asyncio.CancelledError derives from
            # BaseException, so task cancellation is not swallowed here.
            pass
        await asyncio.sleep(30)
|
||
|
||
|
||
def _query_model_stats_today() -> Dict[str, Dict[str, int]]:
    """Aggregate today's (UTC) request and token counts per model and node.

    Returns:
        Mapping from ``"<model>\\0<node_id>"`` (empty node id when NULL) to
        ``{"requests": ..., "total_tokens": ...}``.
    """
    day_pattern = _today_utc_prefix() + "%"
    sql = """
        SELECT model, node_id,
               COUNT(*) AS requests,
               COALESCE(SUM(total_tokens), 0) AS total_tokens
        FROM api_logs
        WHERE created_at LIKE ? AND model IS NOT NULL
        GROUP BY model, node_id
        """
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        conn.row_factory = sqlite3.Row
        try:
            rows = conn.execute(sql, (day_pattern,)).fetchall()
        finally:
            conn.close()
    return {
        f"{r['model']}\0{r['node_id'] or ''}": {
            "requests": int(r["requests"]),
            "total_tokens": int(r["total_tokens"]),
        }
        for r in rows
    }
|
||
|
||
|
||
def build_model_cards() -> List[Dict[str, Any]]:
    """Build one display card per (node, model) pair.

    A node without models contributes a single placeholder card with an empty
    ``model_id`` and zero usage. Each card combines the node's endpoint and
    status, its current in-flight count and last error, and today's request and
    token totals for that model on that node.
    """
    stats = _query_model_stats_today()
    no_usage = {"requests": 0, "total_tokens": 0}
    cards: List[Dict[str, Any]] = []
    for node in _NODES.get("nodes", []):
        nid = str(node.get("id", ""))
        status, status_label = compute_node_status(node)
        host = str(node.get("host", "127.0.0.1"))
        port = int(node.get("port", 0))
        if not node.get("enabled", True):
            status, status_label = "disabled", STATUS_LABELS["disabled"]

        # (model id, label, usage) for every card this node produces.
        entries: List[Tuple[str, str, Dict[str, int]]] = []
        for m in node.get("models") or []:
            mid = str(m.get("id", ""))
            usage = stats.get(f"{mid}\0{nid}", no_usage)
            entries.append((mid, str(m.get("label") or mid), usage))
        if not entries:
            entries.append(("", "(未配置模型)", no_usage))

        for mid, label, usage in entries:
            rt = _runtime(nid)
            cards.append(
                {
                    "node_id": nid,
                    "node_name": node.get("name", nid),
                    "model_id": mid,
                    "model_label": label,
                    "host": host,
                    "port": port,
                    "endpoint": f"{host}:{port}",
                    "status": status,
                    "status_label": status_label,
                    "in_flight": int(rt.get("in_flight", 0)),
                    "max_concurrent": int(node.get("max_concurrent", 1)),
                    "today_requests": usage["requests"],
                    "today_tokens": usage["total_tokens"],
                    "last_error": str(rt.get("last_error") or ""),
                    "enabled": bool(node.get("enabled", True)),
                }
            )
    return cards
|
||
|
||
|
||
def list_nodes_with_status() -> List[Dict[str, Any]]:
    """Return a copy of every node entry, enriched with its live status.

    Added keys: ``status``, ``status_label``, ``in_flight``, ``last_error``
    and ``last_check``.
    """

    def describe(node: Dict[str, Any]) -> Dict[str, Any]:
        status, status_label = compute_node_status(node)
        rt = _runtime(str(node.get("id", "")))
        return {
            **node,
            "status": status,
            "status_label": status_label,
            "in_flight": int(rt.get("in_flight", 0)),
            "last_error": str(rt.get("last_error") or ""),
            "last_check": rt.get("last_check"),
        }

    return [describe(node) for node in _NODES.get("nodes", [])]
|
||
|
||
|
||
def _new_node_id() -> str:
|
||
return "node-" + secrets.token_hex(4)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 访问统计(SQLite)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def get_client_ip(request: Request) -> str:
    """Return the client IP, preferring reverse-proxy headers.

    Order: first entry of ``X-Forwarded-For``, then ``X-Real-IP``, then the
    peer address of the connection, else ``"unknown"``.
    """
    forwarded_for = request.headers.get("x-forwarded-for")
    if forwarded_for:
        first_hop, _, _ = forwarded_for.partition(",")
        return first_hop.strip()
    real_ip = request.headers.get("x-real-ip")
    if real_ip:
        return real_ip.strip()
    peer = request.client
    return peer.host if peer and peer.host else "unknown"
|
||
|
||
|
||
def parse_model_from_body(body: bytes) -> Optional[str]:
    """Extract the ``model`` field from a JSON request body.

    Args:
        body: Raw request body.

    Returns:
        The stripped model name, or ``None`` when the body is not valid JSON,
        is not a JSON object, or has no usable ``model`` value.
    """
    try:
        data = json.loads(body)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return None
    if not isinstance(data, dict):
        # Valid JSON that is a list/string/number has no "model" key; without
        # this guard .get() raised AttributeError.
        return None
    model = data.get("model")
    if not model:
        return None
    return str(model).strip() or None
|
||
|
||
|
||
_USAGE_KEYS = ("prompt_tokens", "completion_tokens", "total_tokens")


def _token_count(value: Any) -> int:
    """Coerce a usage field to a non-negative int; malformed values become 0."""
    try:
        count = int(value or 0)
    except (TypeError, ValueError, OverflowError):
        return 0
    return max(count, 0)


def _usage_from_sse(text: str) -> Optional[Dict[str, Any]]:
    """Return the last ``usage`` object with a non-zero count in an SSE body."""
    found: Optional[Dict[str, Any]] = None
    for line in text.splitlines():
        line = line.strip()
        if not line.startswith("data:"):
            continue
        payload = line[5:].strip()
        if not payload or payload == "[DONE]":
            continue
        try:
            obj = json.loads(payload)
        except json.JSONDecodeError:
            continue
        usage = obj.get("usage") if isinstance(obj, dict) else None
        if isinstance(usage, dict) and any(usage.get(k) for k in _USAGE_KEYS):
            found = usage
    return found


def _usage_from_json(text: str) -> Optional[Dict[str, Any]]:
    """Return the ``usage`` object of a plain JSON response body, if any."""
    try:
        obj = json.loads(text)
    except json.JSONDecodeError:
        return None
    usage = obj.get("usage") if isinstance(obj, dict) else None
    return usage if isinstance(usage, dict) else None


def parse_usage_from_response(raw: bytes, content_type: str) -> Dict[str, int]:
    """Extract token usage from an upstream response body.

    The body is parsed as a server-sent-event stream when ``content_type``
    mentions ``event-stream`` or the text starts with ``data:``; otherwise it
    is parsed as one JSON document. Anything that is not a JSON object, and
    any token value that is not numeric, is ignored rather than raising.

    Args:
        raw: Response body bytes (may be empty).
        content_type: Value of the response Content-Type header (may be empty).

    Returns:
        Dict with ``prompt_tokens``, ``completion_tokens`` and
        ``total_tokens``; all zero when no usage is found. A missing or zero
        total is replaced by prompt + completion.
    """
    out = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    if not raw:
        return out
    # errors="replace" never raises, so no exception handling is needed here.
    text = raw.decode("utf-8", errors="replace")

    ct = (content_type or "").lower()
    if "event-stream" in ct or text.lstrip().startswith("data:"):
        usage = _usage_from_sse(text)
    else:
        usage = _usage_from_json(text)
    if not usage:
        return out

    prompt = _token_count(usage.get("prompt_tokens"))
    completion = _token_count(usage.get("completion_tokens"))
    total = _token_count(usage.get("total_tokens")) or prompt + completion
    out["prompt_tokens"] = prompt
    out["completion_tokens"] = completion
    out["total_tokens"] = total
    return out
|
||
|
||
|
||
def init_stats_db() -> None:
    """Create the ``api_logs`` table and its indexes in ``STATS_DB`` if missing.

    Also attempts to add the ``node_id`` column, so that a table created
    without it is upgraded in place.
    """
    create_table = """
        CREATE TABLE IF NOT EXISTS api_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            created_at TEXT NOT NULL,
            client_ip TEXT NOT NULL,
            model TEXT,
            node_id TEXT,
            status_code INTEGER NOT NULL,
            req_bytes INTEGER NOT NULL DEFAULT 0,
            resp_bytes INTEGER NOT NULL DEFAULT 0,
            prompt_tokens INTEGER NOT NULL DEFAULT 0,
            completion_tokens INTEGER NOT NULL DEFAULT 0,
            total_tokens INTEGER NOT NULL DEFAULT 0
        )
        """
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        try:
            conn.execute(create_table)
            for column in ("created_at", "client_ip"):
                conn.execute(
                    f"CREATE INDEX IF NOT EXISTS idx_api_logs_{column} ON api_logs({column})"
                )
            try:
                conn.execute("ALTER TABLE api_logs ADD COLUMN node_id TEXT")
            except sqlite3.OperationalError:
                # Expected whenever the column is already there (e.g. the
                # table was just created by the statement above).
                pass
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_api_logs_node_id ON api_logs(node_id)"
            )
            conn.commit()
        finally:
            conn.close()
|
||
|
||
|
||
def _record_api_log_sync(
    client_ip: str,
    model: Optional[str],
    status_code: int,
    req_bytes: int,
    resp_bytes: int,
    usage: Dict[str, int],
    node_id: Optional[str] = None,
) -> None:
    """Insert one row into ``api_logs`` (blocking).

    ``created_at`` is the current UTC time in ISO-8601 form. ``usage`` must
    contain ``prompt_tokens``, ``completion_tokens`` and ``total_tokens``.
    """
    row = (
        datetime.now(timezone.utc).isoformat(),
        client_ip,
        model,
        node_id,
        status_code,
        req_bytes,
        resp_bytes,
        usage["prompt_tokens"],
        usage["completion_tokens"],
        usage["total_tokens"],
    )
    insert_sql = """
        INSERT INTO api_logs (
            created_at, client_ip, model, node_id, status_code,
            req_bytes, resp_bytes,
            prompt_tokens, completion_tokens, total_tokens
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        try:
            conn.execute(insert_sql, row)
            conn.commit()
        finally:
            conn.close()
|
||
|
||
|
||
async def record_api_log(
    client_ip: str,
    model: Optional[str],
    status_code: int,
    req_bytes: int,
    resp_bytes: int,
    usage: Dict[str, int],
    node_id: Optional[str] = None,
) -> None:
    """Write one access-log row without blocking the event loop.

    Forwards all arguments to ``_record_api_log_sync`` through
    ``run_in_thread``.
    """
    log_args = (client_ip, model, status_code, req_bytes, resp_bytes, usage, node_id)
    await run_in_thread(_record_api_log_sync, *log_args)
|
||
|
||
|
||
def _today_utc_prefix() -> str:
|
||
return datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||
|
||
|
||
def _query_stats_summary() -> Dict[str, Any]:
    """Return all-time and today's (UTC) totals from ``api_logs``.

    Keys: ``total_requests``, ``total_tokens``, ``unique_ips``,
    ``today_requests``, ``today_tokens``.
    """
    day_pattern = _today_utc_prefix() + "%"
    overall_sql = """
        SELECT
            COUNT(*) AS total_requests,
            COALESCE(SUM(total_tokens), 0) AS total_tokens,
            COUNT(DISTINCT client_ip) AS unique_ips
        FROM api_logs
        """
    today_sql = """
        SELECT
            COUNT(*) AS today_requests,
            COALESCE(SUM(total_tokens), 0) AS today_tokens
        FROM api_logs
        WHERE created_at LIKE ?
        """
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        conn.row_factory = sqlite3.Row
        try:
            overall = conn.execute(overall_sql).fetchone()
            today = conn.execute(today_sql, (day_pattern,)).fetchone()
        finally:
            conn.close()
    summary: Dict[str, Any] = {
        key: int(overall[key])
        for key in ("total_requests", "total_tokens", "unique_ips")
    }
    for key in ("today_requests", "today_tokens"):
        summary[key] = int(today[key])
    return summary
|
||
|
||
|
||
def _query_stats_ips(limit: int = 200) -> List[Dict[str, Any]]:
    """Return per-client-IP totals, most recently seen first.

    Args:
        limit: Maximum number of IP rows to return.

    Returns:
        Dicts with ``ip``, ``requests``, ``total_tokens``, ``prompt_tokens``,
        ``completion_tokens`` and ``last_seen``.
    """
    per_ip_sql = """
        SELECT
            client_ip AS ip,
            COUNT(*) AS requests,
            COALESCE(SUM(total_tokens), 0) AS total_tokens,
            COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens,
            COALESCE(SUM(completion_tokens), 0) AS completion_tokens,
            MAX(created_at) AS last_seen
        FROM api_logs
        GROUP BY client_ip
        ORDER BY last_seen DESC
        LIMIT ?
        """
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        conn.row_factory = sqlite3.Row
        try:
            records = conn.execute(per_ip_sql, (limit,)).fetchall()
        finally:
            conn.close()
    return list(map(dict, records))
|
||
|
||
|
||
def _query_stats_billing(days: int) -> Dict[str, Any]:
    """Return token usage grouped by day and by client IP for a recent window.

    The window starts at the UTC date ``days`` days before now (inclusive).

    Returns:
        ``{"days": days, "by_day": [...], "by_ip": [...]}``; ``by_day`` is
        ordered newest first, ``by_ip`` holds at most 100 rows ordered by
        total tokens descending.
    """
    window_start = (datetime.now(timezone.utc) - timedelta(days=days)).date().isoformat()
    by_day_sql = """
        SELECT
            substr(created_at, 1, 10) AS day,
            COUNT(*) AS requests,
            COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens,
            COALESCE(SUM(completion_tokens), 0) AS completion_tokens,
            COALESCE(SUM(total_tokens), 0) AS total_tokens
        FROM api_logs
        WHERE substr(created_at, 1, 10) >= ?
        GROUP BY day
        ORDER BY day DESC
        """
    by_ip_sql = """
        SELECT
            client_ip AS ip,
            COUNT(*) AS requests,
            COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens,
            COALESCE(SUM(completion_tokens), 0) AS completion_tokens,
            COALESCE(SUM(total_tokens), 0) AS total_tokens
        FROM api_logs
        WHERE substr(created_at, 1, 10) >= ?
        GROUP BY client_ip
        ORDER BY total_tokens DESC
        LIMIT 100
        """
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        conn.row_factory = sqlite3.Row
        try:
            day_rows = conn.execute(by_day_sql, (window_start,)).fetchall()
            ip_rows = conn.execute(by_ip_sql, (window_start,)).fetchall()
        finally:
            conn.close()
    return {
        "days": days,
        "by_day": list(map(dict, day_rows)),
        "by_ip": list(map(dict, ip_rows)),
    }
|
||
|
||
|
||
def _query_stats_logs(limit: int, offset: int) -> Tuple[List[Dict[str, Any]], int]:
    """Return one page of raw log rows (newest first) and the total row count.

    Args:
        limit: Page size; clamped to the range 1..500.
        offset: Rows to skip; negative values are treated as 0.
    """
    page_size = min(max(limit, 1), 500)
    skip = offset if offset > 0 else 0
    page_sql = """
        SELECT
            id, created_at, client_ip, model, status_code,
            req_bytes, resp_bytes,
            prompt_tokens, completion_tokens, total_tokens
        FROM api_logs
        ORDER BY id DESC
        LIMIT ? OFFSET ?
        """
    with _db_lock:
        conn = sqlite3.connect(STATS_DB)
        conn.row_factory = sqlite3.Row
        try:
            total = conn.execute("SELECT COUNT(*) FROM api_logs").fetchone()[0]
            records = conn.execute(page_sql, (page_size, skip)).fetchall()
        finally:
            conn.close()
    return list(map(dict, records)), int(total)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Pydantic
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class LoginBody(BaseModel):
    """Request body carrying login credentials; both fields must be non-empty."""

    username: str = Field(..., min_length=1)
    password: str = Field(..., min_length=1)
|
||
|
||
|
||
class ModelEntryIn(BaseModel):
    """One model entry of a node as submitted by a client."""

    # Required, non-empty model id.
    id: str = Field(..., min_length=1)
    # Optional display label; defaults to an empty string.
    label: str = ""
|
||
|
||
|
||
class NodeBody(BaseModel):
    """Request body describing a node and the models it serves."""

    # Required, non-empty display name.
    name: str = Field(..., min_length=1)
    host: str = "127.0.0.1"
    # Required TCP port, 1..65535.
    port: int = Field(..., ge=1, le=65535)
    enabled: bool = True
    # Allowed concurrent requests, 1..32 (default 1).
    max_concurrent: int = Field(1, ge=1, le=32)
    # Models hosted on the node; a new empty list per instance by default.
    models: List[ModelEntryIn] = Field(default_factory=list)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 依赖
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
async def get_current_web_user(
    creds: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security_bearer)],
) -> GateSessionUser:
    """Dependency: authenticate a web session from its Bearer login token.

    Raises:
        HTTPException: 401 when no Bearer credentials are sent, when an API
            key (``sk-...``) is sent instead of a login token, when the token
            is invalid or expired, or when its user id is not ``GATE_WEB_UID``.
    """
    has_bearer = creds is not None and creds.scheme.lower() == "bearer"
    if not has_bearer:
        raise HTTPException(status_code=401, detail="需要登录(Bearer Token)")

    token = creds.credentials
    if token.startswith("sk-"):
        raise HTTPException(status_code=401, detail="网页登录请使用登录接口返回的令牌,不是 API Key")

    if decode_web_token(token) != GATE_WEB_UID:
        raise HTTPException(status_code=401, detail="登录状态无效")
    return GateSessionUser(username=_GATE["username"], api_key=_GATE["api_key"])
|
||
|
||
|
||
async def get_user_from_api_key(
    creds: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security_bearer)],
) -> GateSessionUser:
    """Dependency: authenticate an API call by its ``Bearer sk-...`` key.

    Returns:
        The configured account, with the presented key as ``api_key``.

    Raises:
        HTTPException: 401 when no Bearer credentials are sent, when the key
            does not start with ``sk-``, or when it does not match the
            configured API key.
    """
    if creds is None or creds.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=401,
            detail='请在 Header 中携带 Authorization: Bearer sk-...',
        )
    key = creds.credentials.strip()
    if not key.startswith("sk-"):
        raise HTTPException(status_code=401, detail="API Key 必须以 sk- 开头")

    expected = str(_GATE.get("api_key") or "")
    # Constant-time comparison so response timing does not reveal how many
    # leading characters of a guessed key are correct. Comparing bytes also
    # avoids the TypeError compare_digest raises for non-ASCII str input.
    key_matches = bool(expected) and secrets.compare_digest(
        key.encode("utf-8"), expected.encode("utf-8")
    )
    if not key_matches:
        raise HTTPException(status_code=401, detail="无效的 API Key")
    return GateSessionUser(username=_GATE["username"], api_key=key)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# HTML(Tailwind CDN)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Tailwind CSS runtime, loaded from its CDN by every page.
TW_SCRIPT = "https://cdn.tailwindcss.com"

# Shared page prologue: <head>, inline JS helpers (wgApi, wgFmtErr), the
# decorative background and the navigation header, up to the opening <main>.
# This is an f-string evaluated once at import time, so literal braces for
# JS/CSS are doubled. "{{title}}" renders as the literal "{title}", which
# page() later substitutes with the actual page title.
SHELL_HEAD = f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
  <meta charset="utf-8"/>
  <meta name="viewport" content="width=device-width, initial-scale=1"/>
  <title>{{title}}</title>
  <script src="{TW_SCRIPT}"></script>
  <script>
    tailwind.config = {{
      theme: {{
        extend: {{
          fontFamily: {{
            sans: ['Plus Jakarta Sans', 'system-ui', 'sans-serif'],
          }},
          colors: {{
            ink: {{ 950: '#0b1220', 900: '#101827', 700: '#334155' }},
          }},
        }},
      }},
    }};
  </script>
  <link rel="preconnect" href="https://fonts.googleapis.com"/>
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin/>
  <link href="https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;500;600;700&display=swap" rel="stylesheet"/>
</head>
<body class="min-h-screen bg-gradient-to-br from-slate-950 via-indigo-950 to-slate-900 text-slate-100 antialiased" data-app-root="{APP_ROOT}">
  <script>
    window.__APP_ROOT__ = {_APP_ROOT_JSON};
    function wgApi(path) {{
      if (!path || path.charAt(0) !== '/') path = '/' + (path || '');
      var root = (window.__APP_ROOT__ || '').replace(/\\/+$/, '');
      return root + path;
    }}
    function wgFmtErr(res, data) {{
      try {{
        if (data && data.detail !== undefined) {{
          var d = data.detail;
          if (typeof d === 'string') return d;
          if (Array.isArray(d))
            return d.map(function (e) {{
              var loc = e.loc ? e.loc.join('.') + ': ' : '';
              return loc + (e.msg || JSON.stringify(e));
            }}).join(';') || '校验失败';
          if (typeof d === 'object') return JSON.stringify(d);
        }}
      }} catch (e) {{}}
      return res && res.status ? ('HTTP ' + res.status + ' ' + (res.statusText || '')) : '网络或服务器错误';
    }}
  </script>
  <div class="pointer-events-none fixed inset-0 overflow-hidden">
    <div class="absolute -left-32 top-20 h-72 w-72 rounded-full bg-indigo-500/20 blur-3xl"></div>
    <div class="absolute -right-20 bottom-10 h-96 w-96 rounded-full bg-cyan-500/15 blur-3xl"></div>
  </div>
  <header class="relative z-10 border-b border-white/10 bg-black/20 backdrop-blur-md">
    <div class="mx-auto flex max-w-5xl items-center justify-between px-6 py-4">
      <a href="{app_url('/')}" class="flex items-center gap-2 text-lg font-semibold tracking-tight text-white hover:text-indigo-200 transition">
        <span class="inline-flex h-9 w-9 items-center justify-center rounded-xl bg-gradient-to-br from-indigo-400 to-cyan-400 text-slate-950 font-bold">AI</span>
        <span>中转网关</span>
      </a>
      <nav class="flex items-center gap-6 text-sm font-medium text-slate-300">
        <a href="{app_url('/')}" class="hover:text-white transition">首页</a>
        <a href="{app_url('/login')}" class="hover:text-white transition">登录</a>
        <a href="{app_url('/stats')}" class="hover:text-white transition">流量统计</a>
        <a href="{app_url('/settings')}" class="hover:text-white transition">系统设置</a>
        <a href="{app_url('/user')}" class="rounded-full bg-white/10 px-4 py-2 text-white ring-1 ring-white/15 hover:bg-white/15 transition">用户中心</a>
      </nav>
    </div>
  </header>
  <main class="relative z-10 mx-auto max-w-5xl px-6 py-14">
"""
|
||
|
||
|
||
SHELL_FOOT = """
|
||
</main>
|
||
<footer class="relative z-10 border-t border-white/10 py-8 text-center text-xs text-slate-500">
|
||
LLM Gateway · JSON 配置账号 · OpenAI 兼容
|
||
</footer>
|
||
</body>
|
||
</html>
|
||
"""
|
||
|
||
|
||
def page(title: str, inner: str) -> str:
    """Wrap ``inner`` HTML in the shared page shell.

    Args:
        title: Text substituted for the ``{title}`` placeholder in ``SHELL_HEAD``.
        inner: HTML placed between the shell's head and foot.
    """
    head = SHELL_HEAD.replace("{title}", title)
    return "".join((head, inner, SHELL_FOOT))
|
||
|
||
|
||
# Home page: model-distribution cards (polled every 5 s from /api/models/cards)
# plus usage notes. The inner HTML is an f-string, so literal JS braces are
# doubled.
#
# Fixes relative to the previous markup:
#   * esc() now emits HTML entities. It used to replace "&", "<" and ">" with
#     themselves, so values from the API reached innerHTML unescaped.
#   * Containers that opened as <motion-wrap> but closed with </div> are
#     proper <div> elements again, so the tags balance.
HOME_HTML = page(
    "中转网关",
    f"""
<div class="space-y-12">
  <div>
    <div class="flex flex-wrap items-end justify-between gap-4">
      <div>
        <h1 class="text-2xl font-bold text-white">模型分布</h1>
        <p class="mt-2 text-sm text-slate-400">各节点模型状态(约每 5 秒刷新;主机默认 127.0.0.1,经 frp 映射端口)</p>
      </div>
      <a href="{app_url('/settings')}" class="text-sm text-indigo-300 hover:text-white transition">管理节点与模型 →</a>
    </div>
    <div id="cards-loading" class="mt-6 text-slate-400">加载中…</div>
    <div id="cards-grid" class="mt-6 grid gap-4 sm:grid-cols-2 lg:grid-cols-3 hidden"></div>
    <p id="cards-empty" class="mt-6 hidden text-sm text-slate-500">暂无模型卡片,请登录后在系统设置中添加节点与模型。</p>
  </div>
  <div class="rounded-3xl bg-white/5 p-8 ring-1 ring-white/10 backdrop-blur">
    <h2 class="text-lg font-semibold text-white">使用说明</h2>
    <ul class="mt-4 space-y-3 text-sm leading-relaxed text-slate-300">
      <li>· 对外统一访问本网关(宝塔反代端口 <span class="text-cyan-200">8150</span>),请求 <code class="rounded bg-white/10 px-1 text-cyan-200">/v1/chat/completions</code>,Header 携带 <code class="rounded bg-white/10 px-1 text-cyan-200">Authorization: Bearer sk-...</code>。</li>
      <li>· JSON 里的 <code class="rounded bg-white/10 px-1 text-cyan-200">model</code> 须与系统设置中登记的模型 ID 一致,网关会转发到对应 <span class="text-white">127.0.0.1:端口</span>。</li>
      <li>· 节点端口建议 <span class="text-white">3313–3318</span>(frp 映射);未启用的节点显示为「未启用」。</li>
      <li>· 流量与 Token 见 <a href="{app_url('/stats')}" class="text-indigo-300 hover:underline">流量统计</a>;API Key 见 <a href="{app_url('/user')}" class="text-indigo-300 hover:underline">用户中心</a>(需登录)。</li>
    </ul>
  </div>
</div>
<script>
const statusRing = {{ idle: 'ring-emerald-500/40', busy: 'ring-amber-500/40', offline: 'ring-slate-500/30', disabled: 'ring-slate-600/30', error: 'ring-rose-500/40' }};
const statusBadge = {{ idle: 'bg-emerald-500/20 text-emerald-200', busy: 'bg-amber-500/20 text-amber-200', offline: 'bg-slate-500/20 text-slate-300', disabled: 'bg-slate-600/20 text-slate-400', error: 'bg-rose-500/20 text-rose-200' }};
function esc(s) {{ return String(s == null ? '' : s).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;').replace(/"/g,'&quot;'); }}
function fmt(n) {{ return Number(n || 0).toLocaleString(); }}
function renderCards(cards) {{
  const grid = document.getElementById('cards-grid');
  const empty = document.getElementById('cards-empty');
  grid.innerHTML = '';
  if (!cards.length) {{ grid.classList.add('hidden'); empty.classList.remove('hidden'); return; }}
  empty.classList.add('hidden'); grid.classList.remove('hidden');
  cards.forEach(c => {{
    const st = c.status || 'offline';
    const el = document.createElement('article');
    el.className = 'rounded-2xl bg-white/5 p-5 ring-1 ' + (statusRing[st]||statusRing.offline) + ' backdrop-blur flex flex-col gap-3';
    el.innerHTML = '<div class="flex items-start justify-between gap-2"><h3 class="font-semibold text-white leading-snug">'+esc(c.model_label||c.model_id||'未命名')+'</h3><span class="shrink-0 rounded-full px-2.5 py-0.5 text-xs font-medium '+(statusBadge[st]||statusBadge.offline)+'">'+esc(c.status_label)+'</span></div>'
      + '<p class="text-xs text-slate-400 font-mono">'+esc(c.model_id||'—')+'</p>'
      + '<p class="text-sm text-slate-300">'+esc(c.node_name)+' · <span class="text-cyan-200/90">'+esc(c.endpoint)+'</span></p>'
      + '<div class="grid grid-cols-2 gap-2 text-xs text-slate-400 mt-auto pt-2 border-t border-white/10"><div>进行中 <span class="text-white font-medium">'+esc(c.in_flight)+'/'+esc(c.max_concurrent)+'</span></div><div>今日 Token <span class="text-cyan-200 font-medium">'+fmt(c.today_tokens)+'</span></div><div>今日请求 <span class="text-white font-medium">'+fmt(c.today_requests)+'</span></div><div></div></div>';
    if (c.last_error && (st==='error'||st==='offline')) {{ const e=document.createElement('p'); e.className='text-xs text-rose-300/90 truncate'; e.title=c.last_error; e.textContent=c.last_error; el.appendChild(e); }}
    grid.appendChild(el);
  }});
}}
async function refreshCards() {{
  try {{
    const res = await fetch(wgApi('/api/models/cards'));
    const data = await res.json().catch(() => ({{}}));
    document.getElementById('cards-loading').classList.add('hidden');
    if (res.ok) renderCards(Array.isArray(data) ? data : (data.cards || []));
  }} catch (_) {{}}
}}
refreshCards(); setInterval(refreshCards, 5000);
</script>
""",
)
|
||
|
||
|
||
# Login page markup, rendered through page() with the title "登录".
# The template is an f-string without interpolated values, so every literal
# JS brace is doubled ({{ / }}).
# The inline script POSTs {username, password} as JSON to /api/login, stores the
# returned access_token in localStorage under "web_token", then navigates to
# /user. On failure it shows the message in the #err paragraph.
# NOTE(review): wgApi() and wgFmtErr() are not defined in this template —
# presumably injected by page(); confirm.
LOGIN_HTML = page(
    "登录",
    f"""
<div class="mx-auto max-w-md">
<div class="rounded-3xl bg-white/5 p-8 ring-1 ring-white/10 backdrop-blur">
<h1 class="text-2xl font-bold text-white">登录</h1>
<p class="mt-2 text-sm text-slate-400">账号与密码与 <code class="text-cyan-200/90">gateway.json</code> 中一致</p>
<form id="f" class="mt-8 space-y-5">
<div>
<label class="block text-sm font-medium text-slate-300">用户名</label>
<input name="username" required autocomplete="username"
class="mt-2 w-full rounded-xl border border-white/10 bg-black/30 px-4 py-3 text-white outline-none focus:ring-2 focus:ring-indigo-500 transition"/>
</div>
<div>
<label class="block text-sm font-medium text-slate-300">密码</label>
<input name="password" type="password" required autocomplete="current-password"
class="mt-2 w-full rounded-xl border border-white/10 bg-black/30 px-4 py-3 text-white outline-none focus:ring-2 focus:ring-indigo-500 transition"/>
</div>
<p id="err" class="hidden text-sm text-rose-400"></p>
<button type="submit" class="w-full rounded-xl bg-white/10 py-3 text-sm font-semibold text-white ring-1 ring-white/15 hover:bg-white/15 transition">
登录
</button>
</form>
</div>
</div>
<script>
document.getElementById('f').addEventListener('submit', async (e) => {{
e.preventDefault();
const fd = new FormData(e.target);
const body = {{ username: String(fd.get('username')||'').trim(), password: String(fd.get('password')||'') }};
const err = document.getElementById('err');
err.classList.add('hidden');
try {{
const res = await fetch(wgApi('/api/login'), {{ method: 'POST', headers: {{ 'Content-Type': 'application/json' }}, body: JSON.stringify(body) }});
const raw = await res.text();
let data = {{}};
try {{ data = raw ? JSON.parse(raw) : {{}}; }} catch (_) {{}}
if (!res.ok) {{ err.textContent = wgFmtErr(res, data) || raw.slice(0, 200) || '登录失败'; err.classList.remove('hidden'); return; }}
if (!data.access_token) {{ err.textContent = '服务端未返回令牌'; err.classList.remove('hidden'); return; }}
localStorage.setItem('web_token', data.access_token);
window.location.href = wgApi('/user');
}} catch (ex) {{
err.textContent = '无法连接服务器:' + (ex && ex.message ? ex.message : String(ex));
err.classList.remove('hidden');
}}
}});
</script>
""",
)
|
||
|
||
|
||
# User-centre page markup, rendered through page() with the title "用户中心".
# Plain (non-f) string, so JS braces are written singly.
# The inline script reads the session token from localStorage ("web_token"),
# redirects to /login when it is missing, otherwise fetches /api/me and shows
# the username and API key.
# Fixes over the previous revision:
#   * a rejected fetch (network error) now surfaces in #err instead of leaving
#     the "加载中…" placeholder on screen forever;
#   * the copy button no longer throws when navigator.clipboard is unavailable
#     (it only exists in secure contexts); it falls back to a hidden textarea
#     plus document.execCommand('copy') and reports failure on the button.
# NOTE(review): wgApi() and wgFmtErr() are not defined in this template —
# presumably injected by page(); confirm.
USER_HTML = page(
    "用户中心",
    """
<div class="mx-auto max-w-lg">
  <div class="rounded-3xl bg-white/5 p-8 ring-1 ring-white/10 backdrop-blur">
    <h1 class="text-2xl font-bold text-white">用户中心</h1>
    <p class="mt-2 text-sm text-slate-400">OpenAI 兼容调用密钥(与 gateway.json 中 api_key 一致)</p>
    <div id="loading" class="mt-8 text-slate-400">加载中…</div>
    <div id="panel" class="hidden mt-8 space-y-6">
      <div>
        <span class="text-xs font-semibold uppercase tracking-wider text-indigo-200">用户名</span>
        <p id="uname" class="mt-2 text-lg font-medium text-white"></p>
      </div>
      <div>
        <span class="text-xs font-semibold uppercase tracking-wider text-indigo-200">API Key</span>
        <div class="mt-2 flex flex-col gap-3 sm:flex-row sm:items-center">
          <code id="apikey" class="flex-1 break-all rounded-xl bg-black/40 px-4 py-3 text-sm text-cyan-200 ring-1 ring-white/10"></code>
          <button id="copy" type="button" class="rounded-xl bg-gradient-to-r from-indigo-500 to-cyan-500 px-5 py-3 text-sm font-semibold text-slate-950 hover:brightness-110 transition shrink-0">复制</button>
        </div>
      </div>
      <button id="logout" type="button" class="text-sm text-slate-400 hover:text-white underline underline-offset-4">退出登录</button>
    </div>
    <p id="err" class="hidden mt-6 text-sm text-rose-400"></p>
  </div>
</div>
<script>
  const token = localStorage.getItem('web_token');
  function showErr(msg) {
    document.getElementById('loading').classList.add('hidden');
    const el = document.getElementById('err');
    el.textContent = msg;
    el.classList.remove('hidden');
  }
  if (!token) { window.location.href = wgApi('/login'); }
  else {
    fetch(wgApi('/api/me'), { headers: { 'Authorization': 'Bearer ' + token } })
      .then(async (res) => {
        const data = await res.json().catch(() => ({}));
        if (!res.ok) {
          showErr(wgFmtErr(res, data) || data.detail || '加载失败');
          return;
        }
        document.getElementById('loading').classList.add('hidden');
        document.getElementById('panel').classList.remove('hidden');
        document.getElementById('uname').textContent = data.username;
        document.getElementById('apikey').textContent = data.api_key;
      })
      .catch((ex) => {
        showErr('无法连接服务器:' + (ex && ex.message ? ex.message : String(ex)));
      });
  }
  // navigator.clipboard is undefined on insecure (plain HTTP) origins.
  function copyText(text) {
    if (navigator.clipboard && navigator.clipboard.writeText) {
      return navigator.clipboard.writeText(text);
    }
    return new Promise((resolve, reject) => {
      const ta = document.createElement('textarea');
      ta.value = text;
      ta.setAttribute('readonly', '');
      ta.style.position = 'fixed';
      ta.style.opacity = '0';
      document.body.appendChild(ta);
      ta.select();
      let ok = false;
      try { ok = document.execCommand('copy'); } catch (_) {}
      document.body.removeChild(ta);
      if (ok) { resolve(); } else { reject(new Error('copy failed')); }
    });
  }
  document.getElementById('copy').addEventListener('click', () => {
    const k = document.getElementById('apikey').textContent;
    const b = document.getElementById('copy');
    const t = b.textContent;
    const flash = (label) => {
      b.textContent = label;
      setTimeout(() => { b.textContent = t; }, 1500);
    };
    copyText(k).then(() => flash('已复制'), () => flash('复制失败'));
  });
  document.getElementById('logout').addEventListener('click', () => {
    localStorage.removeItem('web_token');
    window.location.href = wgApi('/login');
  });
</script>
""",
)
|
||
|
||
|
||
|
||
# Node-settings page markup, rendered through page() with the title "系统设置".
# f-string without interpolated values: every literal JS brace is doubled and
# "\\n" reaches the browser as the two-character JS escape \n.
# The inline script lists nodes from /api/nodes and lets the user add, edit,
# delete and test them through the /api/nodes endpoints.
# Fixes over the previous revision:
#   * the stray <motion-wrap> tags (opened/closed against <div>/</div>) are
#     plain, balanced <div> elements again, which makes the trailing no-op
#     .replace() chain unnecessary;
#   * esc() now really HTML-encodes & < > " ' — it is interpolated inside
#     double-quoted attributes (value="…", data-*="…"), so an unescaped quote
#     in a node name could break out of the attribute;
#   * the initial load() is skipped when there is no token and the page is
#     already redirecting to /login.
# NOTE(review): wgApi() and wgFmtErr() are not defined in this template —
# presumably injected by page(); confirm.
SETTINGS_HTML = page(
    "系统设置",
    f"""
<div class="mx-auto max-w-3xl space-y-8">
  <div>
    <h1 class="text-2xl font-bold text-white">系统设置</h1>
    <p class="mt-2 text-sm text-slate-400">管理节点主机、端口与模型(默认主机 127.0.0.1,对应 frp 映射)</p>
  </div>
  <p id="msg" class="hidden text-sm"></p>
  <div id="nodes-list" class="space-y-4"></div>
  <div class="rounded-3xl bg-white/5 p-6 ring-1 ring-white/10">
    <h2 class="text-lg font-semibold text-white">添加节点</h2>
    <form id="add-form" class="mt-4 space-y-4">
      <div class="grid gap-4 sm:grid-cols-2">
        <div><label class="text-sm text-slate-300">节点名称</label>
          <input name="name" required class="mt-1 w-full rounded-xl border border-white/10 bg-black/30 px-3 py-2 text-white"/></div>
        <div><label class="text-sm text-slate-300">主机</label>
          <input name="host" value="127.0.0.1" class="mt-1 w-full rounded-xl border border-white/10 bg-black/30 px-3 py-2 text-white"/></div>
        <div><label class="text-sm text-slate-300">端口</label>
          <input name="port" type="number" required min="1" max="65535" placeholder="3313" class="mt-1 w-full rounded-xl border border-white/10 bg-black/30 px-3 py-2 text-white"/></div>
        <div><label class="text-sm text-slate-300">最大并发</label>
          <input name="max_concurrent" type="number" value="1" min="1" max="32" class="mt-1 w-full rounded-xl border border-white/10 bg-black/30 px-3 py-2 text-white"/></div>
      </div>
      <div><label class="text-sm text-slate-300">模型列表 <span class="text-slate-500">(每行:模型ID|显示名,显示名可省略)</span></label>
        <textarea name="models" rows="4" placeholder="qwen2.5:14b|千问14B llama3|Llama3" class="mt-1 w-full rounded-xl border border-white/10 bg-black/30 px-3 py-2 text-sm text-white font-mono"></textarea>
      </div>
      <label class="flex items-center gap-2 text-sm text-slate-300"><input name="enabled" type="checkbox" checked class="rounded"/> 启用节点</label>
      <button type="submit" class="rounded-xl bg-gradient-to-r from-indigo-500 to-cyan-500 px-5 py-2.5 text-sm font-semibold text-slate-950">添加节点</button>
    </form>
  </div>
</div>
<script>
  const token = localStorage.getItem('web_token');
  if (!token) location.href = wgApi('/login');
  const hdrs = {{ 'Authorization': 'Bearer ' + token, 'Content-Type': 'application/json' }};
  // Safe for both element text and quoted attribute values.
  function esc(s) {{
    return String(s == null ? '' : s)
      .replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;').replace(/'/g, '&#39;');
  }}
  function showMsg(text, ok) {{
    const m = document.getElementById('msg');
    m.textContent = text;
    m.className = 'text-sm ' + (ok ? 'text-emerald-400' : 'text-rose-400');
    m.classList.remove('hidden');
    setTimeout(() => m.classList.add('hidden'), 4000);
  }}
  function parseModels(text) {{
    return String(text||'').split(/\\n/).map(l => l.trim()).filter(Boolean).map(line => {{
      const p = line.split('|');
      return {{ id: p[0].trim(), label: (p[1]||p[0]).trim() }};
    }}).filter(m => m.id);
  }}
  function modelsToText(models) {{
    return (models||[]).map(m => m.id + (m.label && m.label !== m.id ? '|' + m.label : '')).join('\\n');
  }}
  async function api(method, path, body) {{
    const res = await fetch(wgApi(path), {{ method, headers: hdrs, body: body ? JSON.stringify(body) : undefined }});
    const data = await res.json().catch(() => ({{}}));
    if (!res.ok) throw new Error(wgFmtErr(res, data) || '请求失败');
    return data;
  }}
  function renderNodes(nodes) {{
    const box = document.getElementById('nodes-list');
    box.innerHTML = '';
    nodes.forEach(n => {{
      const card = document.createElement('div');
      card.className = 'rounded-2xl bg-white/5 p-5 ring-1 ring-white/10 space-y-3';
      const modelsTxt = modelsToText(n.models);
      card.innerHTML = '<div class="flex flex-wrap items-center justify-between gap-2"><div><span class="font-semibold text-white">'+esc(n.name)+'</span> <span class="text-xs rounded-full px-2 py-0.5 bg-white/10 text-slate-300">'+esc(n.status_label)+'</span></div><div class="flex gap-2"><button data-test="'+esc(n.id)+'" class="text-xs text-indigo-300 hover:underline">测试</button><button data-del="'+esc(n.id)+'" class="text-xs text-rose-400 hover:underline">删除</button></div></div>'
        + '<p class="text-sm text-cyan-200/90 font-mono">'+esc(n.host)+':'+esc(n.port)+' · 进行中 '+esc(n.in_flight)+'/'+esc(n.max_concurrent)+'</p>'
        + '<div class="grid gap-2 sm:grid-cols-2 text-sm"><input data-f="name" value="'+esc(n.name)+'" class="rounded-lg bg-black/30 border border-white/10 px-2 py-1 text-white"/>'
        + '<input data-f="host" value="'+esc(n.host)+'" class="rounded-lg bg-black/30 border border-white/10 px-2 py-1 text-white"/>'
        + '<input data-f="port" type="number" value="'+esc(n.port)+'" class="rounded-lg bg-black/30 border border-white/10 px-2 py-1 text-white"/>'
        + '<input data-f="max_concurrent" type="number" value="'+esc(n.max_concurrent)+'" class="rounded-lg bg-black/30 border border-white/10 px-2 py-1 text-white"/></div>'
        + '<textarea data-f="models" rows="3" class="w-full rounded-lg bg-black/30 border border-white/10 px-2 py-1 text-xs text-white font-mono">'+esc(modelsTxt)+'</textarea>'
        + '<label class="flex items-center gap-2 text-sm"><input data-f="enabled" type="checkbox" '+(n.enabled?'checked':'')+'> 启用</label>'
        + '<button data-save="'+esc(n.id)+'" class="rounded-lg bg-white/10 px-3 py-1.5 text-sm text-white ring-1 ring-white/15">保存修改</button>';
      card.querySelector('[data-test]').onclick = async () => {{
        try {{ const r = await api('POST', '/api/nodes/'+n.id+'/test'); showMsg(r.ok ? '连接成功' : ('失败: '+(r.error||'')), r.ok); load(); }} catch(e) {{ showMsg(e.message, false); }}
      }};
      card.querySelector('[data-del]').onclick = async () => {{
        if (!confirm('确定删除该节点?')) return;
        try {{ await api('DELETE', '/api/nodes/'+n.id); showMsg('已删除', true); load(); }} catch(e) {{ showMsg(e.message, false); }}
      }};
      card.querySelector('[data-save]').onclick = async () => {{
        const body = {{
          name: card.querySelector('[data-f=name]').value.trim(),
          host: card.querySelector('[data-f=host]').value.trim() || '127.0.0.1',
          port: parseInt(card.querySelector('[data-f=port]').value, 10),
          max_concurrent: parseInt(card.querySelector('[data-f=max_concurrent]').value, 10) || 1,
          enabled: card.querySelector('[data-f=enabled]').checked,
          models: parseModels(card.querySelector('[data-f=models]').value),
        }};
        try {{ await api('PUT', '/api/nodes/'+n.id, body); showMsg('已保存', true); load(); }} catch(e) {{ showMsg(e.message, false); }}
      }};
      box.appendChild(card);
    }});
  }}
  async function load() {{ renderNodes(await api('GET', '/api/nodes')); }}
  document.getElementById('add-form').onsubmit = async (e) => {{
    e.preventDefault();
    const fd = new FormData(e.target);
    const body = {{
      name: String(fd.get('name')).trim(),
      host: String(fd.get('host')||'127.0.0.1').trim() || '127.0.0.1',
      port: parseInt(String(fd.get('port')), 10),
      max_concurrent: parseInt(String(fd.get('max_concurrent')), 10) || 1,
      enabled: !!fd.get('enabled'),
      models: parseModels(String(fd.get('models')||'')),
    }};
    try {{ await api('POST', '/api/nodes', body); showMsg('节点已添加', true); e.target.reset(); e.target.host.value='127.0.0.1'; load(); }} catch(err) {{ showMsg(err.message, false); }}
  }};
  // No token means the page is already navigating to /login.
  if (token) load().catch(e => showMsg(e.message, false));
</script>
""",
)
|
||
|
||
|
||
# Traffic-statistics page markup, rendered through page() with the title
# "流量统计". f-string without interpolated values: literal JS braces are doubled.
# The inline script requires a "web_token" in localStorage (else redirects to
# /login), then loads /api/stats/summary, /api/stats/ips, /api/stats/logs and
# /api/stats/billing and fills the four summary tiles and the tables.
# Fixes over the previous revision:
#   * esc() now really HTML-encodes & < > " ' — client IPs and model names are
#     written with innerHTML, so the previous no-op replacement let markup in
#     those values be interpreted by the browser;
#   * the stray, never-closed leading <div> wrapper is removed.
# NOTE(review): wgApi() and wgFmtErr() are not defined in this template —
# presumably injected by page(); confirm.
STATS_HTML = page(
    "流量统计",
    f"""
<div class="mx-auto max-w-4xl space-y-8">
  <div>
    <h1 class="text-2xl font-bold text-white">流量统计</h1>
    <p class="mt-2 text-sm text-slate-400">客户端 IP 与 Token 用量(UTC 日期汇总;反代需传递 X-Forwarded-For)</p>
  </div>

  <div id="loading" class="text-slate-400">加载中…</div>
  <p id="err" class="hidden text-sm text-rose-400"></p>

  <div id="panel" class="hidden space-y-8">
    <div class="grid gap-4 sm:grid-cols-2 lg:grid-cols-4">
      <div class="rounded-2xl bg-white/5 p-5 ring-1 ring-white/10">
        <p class="text-xs font-semibold uppercase tracking-wider text-indigo-200">今日请求</p>
        <p id="s-today-req" class="mt-2 text-2xl font-bold text-white">—</p>
      </div>
      <div class="rounded-2xl bg-white/5 p-5 ring-1 ring-white/10">
        <p class="text-xs font-semibold uppercase tracking-wider text-indigo-200">今日 Token</p>
        <p id="s-today-tok" class="mt-2 text-2xl font-bold text-cyan-300">—</p>
      </div>
      <div class="rounded-2xl bg-white/5 p-5 ring-1 ring-white/10">
        <p class="text-xs font-semibold uppercase tracking-wider text-indigo-200">累计 Token</p>
        <p id="s-total-tok" class="mt-2 text-2xl font-bold text-white">—</p>
      </div>
      <div class="rounded-2xl bg-white/5 p-5 ring-1 ring-white/10">
        <p class="text-xs font-semibold uppercase tracking-wider text-indigo-200">独立 IP</p>
        <p id="s-unique-ip" class="mt-2 text-2xl font-bold text-white">—</p>
      </div>
    </div>

    <div class="rounded-3xl bg-white/5 p-6 ring-1 ring-white/10 backdrop-blur">
      <h2 class="text-lg font-semibold text-white">IP 列表</h2>
      <p class="mt-1 text-sm text-slate-400">按最近访问排序</p>
      <div class="mt-4 overflow-x-auto rounded-2xl ring-1 ring-white/10">
        <table class="min-w-full text-left text-sm">
          <thead class="bg-black/30 text-xs uppercase tracking-wider text-slate-400">
            <tr>
              <th class="px-4 py-3">IP</th>
              <th class="px-4 py-3">请求数</th>
              <th class="px-4 py-3">Token 合计</th>
              <th class="px-4 py-3">最近访问 (UTC)</th>
            </tr>
          </thead>
          <tbody id="ip-tbody" class="divide-y divide-white/5 text-slate-200"></tbody>
        </table>
      </div>
      <p id="ip-empty" class="hidden mt-4 text-sm text-slate-500">暂无记录</p>
    </div>

    <div class="rounded-3xl bg-white/5 p-6 ring-1 ring-white/10 backdrop-blur">
      <div class="flex flex-wrap items-end justify-between gap-4">
        <div>
          <h2 class="text-lg font-semibold text-white">Token 账单</h2>
          <p class="mt-1 text-sm text-slate-400">按日汇总(近 <span id="bill-days-label">30</span> 天)</p>
        </div>
        <select id="bill-days" class="rounded-xl border border-white/10 bg-black/30 px-3 py-2 text-sm text-white outline-none focus:ring-2 focus:ring-indigo-500">
          <option value="7">近 7 天</option>
          <option value="30" selected>近 30 天</option>
          <option value="90">近 90 天</option>
        </select>
      </div>
      <div class="mt-4 overflow-x-auto rounded-2xl ring-1 ring-white/10">
        <table class="min-w-full text-left text-sm">
          <thead class="bg-black/30 text-xs uppercase tracking-wider text-slate-400">
            <tr>
              <th class="px-4 py-3">日期 (UTC)</th>
              <th class="px-4 py-3">请求</th>
              <th class="px-4 py-3">Prompt</th>
              <th class="px-4 py-3">Completion</th>
              <th class="px-4 py-3">合计</th>
            </tr>
          </thead>
          <tbody id="bill-tbody" class="divide-y divide-white/5 text-slate-200"></tbody>
        </table>
      </div>
      <p id="bill-empty" class="hidden mt-4 text-sm text-slate-500">该时段暂无账单</p>

      <h3 class="mt-8 text-sm font-semibold uppercase tracking-wider text-indigo-200">同期按 IP 分摊</h3>
      <div class="mt-3 overflow-x-auto rounded-2xl ring-1 ring-white/10">
        <table class="min-w-full text-left text-sm">
          <thead class="bg-black/30 text-xs uppercase tracking-wider text-slate-400">
            <tr>
              <th class="px-4 py-3">IP</th>
              <th class="px-4 py-3">请求</th>
              <th class="px-4 py-3">Prompt</th>
              <th class="px-4 py-3">Completion</th>
              <th class="px-4 py-3">合计</th>
            </tr>
          </thead>
          <tbody id="bill-ip-tbody" class="divide-y divide-white/5 text-slate-200"></tbody>
        </table>
      </div>
    </div>

    <div class="rounded-3xl bg-white/5 p-6 ring-1 ring-white/10 backdrop-blur">
      <h2 class="text-lg font-semibold text-white">最近请求</h2>
      <div class="mt-4 overflow-x-auto rounded-2xl ring-1 ring-white/10">
        <table class="min-w-full text-left text-sm">
          <thead class="bg-black/30 text-xs uppercase tracking-wider text-slate-400">
            <tr>
              <th class="px-4 py-3">时间 (UTC)</th>
              <th class="px-4 py-3">IP</th>
              <th class="px-4 py-3">模型</th>
              <th class="px-4 py-3">状态</th>
              <th class="px-4 py-3">Token</th>
            </tr>
          </thead>
          <tbody id="log-tbody" class="divide-y divide-white/5 text-slate-200"></tbody>
        </table>
      </div>
    </div>
  </div>
</div>
<script>
  const token = localStorage.getItem('web_token');
  if (!token) {{ window.location.href = wgApi('/login'); }}
  else {{
    const hdrs = {{ 'Authorization': 'Bearer ' + token }};
    function fmt(n) {{ return Number(n || 0).toLocaleString(); }}
    // Values such as client IPs and model names originate from requests; encode before innerHTML.
    function esc(s) {{
      return String(s == null ? '' : s)
        .replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;')
        .replace(/"/g, '&quot;').replace(/'/g, '&#39;');
    }}
    function showErr(msg) {{
      document.getElementById('loading').classList.add('hidden');
      const el = document.getElementById('err');
      el.textContent = msg;
      el.classList.remove('hidden');
    }}
    async function apiGet(path) {{
      const res = await fetch(wgApi(path), {{ headers: hdrs }});
      const data = await res.json().catch(() => ({{}}));
      if (!res.ok) throw new Error(wgFmtErr(res, data) || '请求失败');
      return data;
    }}
    function fillIpTable(rows) {{
      const tb = document.getElementById('ip-tbody');
      tb.innerHTML = '';
      if (!rows.length) {{
        document.getElementById('ip-empty').classList.remove('hidden');
        return;
      }}
      document.getElementById('ip-empty').classList.add('hidden');
      rows.forEach(r => {{
        const tr = document.createElement('tr');
        tr.className = 'hover:bg-white/5';
        tr.innerHTML = '<td class="px-4 py-3 font-mono text-cyan-200">' + esc(r.ip) + '</td>'
          + '<td class="px-4 py-3">' + fmt(r.requests) + '</td>'
          + '<td class="px-4 py-3">' + fmt(r.total_tokens) + '</td>'
          + '<td class="px-4 py-3 text-slate-400">' + esc((r.last_seen || '').replace('T',' ').slice(0,19)) + '</td>';
        tb.appendChild(tr);
      }});
    }}
    function fillBillTables(data) {{
      const dayTb = document.getElementById('bill-tbody');
      const ipTb = document.getElementById('bill-ip-tbody');
      dayTb.innerHTML = '';
      ipTb.innerHTML = '';
      const days = data.by_day || [];
      const ips = data.by_ip || [];
      document.getElementById('bill-empty').classList.toggle('hidden', days.length > 0);
      days.forEach(r => {{
        const tr = document.createElement('tr');
        tr.className = 'hover:bg-white/5';
        tr.innerHTML = '<td class="px-4 py-3">' + esc(r.day) + '</td>'
          + '<td class="px-4 py-3">' + fmt(r.requests) + '</td>'
          + '<td class="px-4 py-3">' + fmt(r.prompt_tokens) + '</td>'
          + '<td class="px-4 py-3">' + fmt(r.completion_tokens) + '</td>'
          + '<td class="px-4 py-3 font-medium text-cyan-200">' + fmt(r.total_tokens) + '</td>';
        dayTb.appendChild(tr);
      }});
      ips.forEach(r => {{
        const tr = document.createElement('tr');
        tr.className = 'hover:bg-white/5';
        tr.innerHTML = '<td class="px-4 py-3 font-mono text-cyan-200">' + esc(r.ip) + '</td>'
          + '<td class="px-4 py-3">' + fmt(r.requests) + '</td>'
          + '<td class="px-4 py-3">' + fmt(r.prompt_tokens) + '</td>'
          + '<td class="px-4 py-3">' + fmt(r.completion_tokens) + '</td>'
          + '<td class="px-4 py-3 font-medium text-cyan-200">' + fmt(r.total_tokens) + '</td>';
        ipTb.appendChild(tr);
      }});
    }}
    function fillLogs(rows) {{
      const tb = document.getElementById('log-tbody');
      tb.innerHTML = '';
      rows.forEach(r => {{
        const tr = document.createElement('tr');
        tr.className = 'hover:bg-white/5';
        const st = r.status_code >= 400 ? 'text-rose-400' : 'text-emerald-400';
        tr.innerHTML = '<td class="px-4 py-3 text-slate-400">' + esc((r.created_at || '').replace('T',' ').slice(0,19)) + '</td>'
          + '<td class="px-4 py-3 font-mono text-cyan-200">' + esc(r.client_ip) + '</td>'
          + '<td class="px-4 py-3">' + esc(r.model || '—') + '</td>'
          + '<td class="px-4 py-3 ' + st + '">' + esc(r.status_code) + '</td>'
          + '<td class="px-4 py-3">' + fmt(r.total_tokens) + '</td>';
        tb.appendChild(tr);
      }});
    }}
    async function loadBilling(days) {{
      document.getElementById('bill-days-label').textContent = days;
      const bill = await apiGet('/api/stats/billing?days=' + days);
      fillBillTables(bill);
    }}
    async function boot() {{
      try {{
        const [summary, ips, logs] = await Promise.all([
          apiGet('/api/stats/summary'),
          apiGet('/api/stats/ips'),
          apiGet('/api/stats/logs?limit=50'),
        ]);
        document.getElementById('s-today-req').textContent = fmt(summary.today_requests);
        document.getElementById('s-today-tok').textContent = fmt(summary.today_tokens);
        document.getElementById('s-total-tok').textContent = fmt(summary.total_tokens);
        document.getElementById('s-unique-ip').textContent = fmt(summary.unique_ips);
        fillIpTable(ips);
        fillLogs(logs.items || []);
        const daysSel = document.getElementById('bill-days');
        await loadBilling(daysSel.value);
        daysSel.addEventListener('change', () => loadBilling(daysSel.value).catch(e => showErr(e.message)));
        document.getElementById('loading').classList.add('hidden');
        document.getElementById('panel').classList.remove('hidden');
      }} catch (e) {{
        showErr(e.message || String(e));
      }}
    }}
    boot();
  }}
</script>
""",
)
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Manage application-wide resources for the lifetime of the server.

    Startup: loads the gateway and node configuration, initialises the stats
    database, creates the shared ``httpx.AsyncClient`` on ``app.state.http``,
    runs one node health refresh and starts the background health loop.

    Shutdown: cancels the health loop, waits for it to finish, then closes the
    HTTP client.
    """
    global _health_task

    load_gateway_config()
    load_nodes_config()
    init_stats_db()

    timeout = httpx.Timeout(600.0, connect=30.0)
    app.state.http = httpx.AsyncClient(timeout=timeout)
    try:
        # Inside the try so the client is closed even if the first probe fails.
        await refresh_all_node_health(app.state.http)
        _health_task = asyncio.create_task(_health_loop(app.state.http))
        yield
    finally:
        task, _health_task = _health_task, None
        try:
            if task is not None:
                task.cancel()
                # Wait for the loop to actually stop before closing the client
                # it uses; return_exceptions absorbs the CancelledError (or a
                # crash the loop may already have suffered).
                await asyncio.gather(task, return_exceptions=True)
        finally:
            await app.state.http.aclose()
|
||
|
||
|
||
# ASGI application; startup/shutdown resources are handled by lifespan().
app = FastAPI(title="LLM Gateway", lifespan=lifespan)

# CORS: any origin may call the API. Authentication is carried in the
# Authorization header (Bearer token), not in cookies, so credentialed CORS is
# not needed — and a wildcard origin must not be combined with
# allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
|
||
@app.get("/", response_class=HTMLResponse)
async def home() -> str:
    """Serve the pre-rendered home page markup."""
    markup = HOME_HTML
    return markup
|
||
|
||
|
||
@app.get("/login", response_class=HTMLResponse)
async def login_page() -> str:
    """Serve the pre-rendered login page markup."""
    markup = LOGIN_HTML
    return markup
|
||
|
||
|
||
@app.get("/user", response_class=HTMLResponse)
async def user_page() -> str:
    """Serve the pre-rendered user-centre page markup."""
    markup = USER_HTML
    return markup
|
||
|
||
|
||
@app.get("/settings", response_class=HTMLResponse)
async def settings_page() -> str:
    """Serve the pre-rendered node-settings page markup."""
    markup = SETTINGS_HTML
    return markup
|
||
|
||
|
||
@app.get("/stats", response_class=HTMLResponse)
async def stats_page() -> str:
    """Serve the pre-rendered traffic-statistics page markup."""
    markup = STATS_HTML
    return markup
|
||
|
||
|
||
@app.post("/api/login")
async def api_login(body: LoginBody) -> JSONResponse:
    """Check the submitted credentials and issue a web session token.

    Returns a JSON body with ``access_token`` and ``token_type``.

    Raises:
        HTTPException: 401 with the same message whether the username or the
            password is wrong.
    """
    supplied_name = body.username.strip().encode("utf-8")
    expected_name = str(_GATE["username"]).encode("utf-8")
    # Constant-time comparison, and the password check runs regardless of the
    # username result, so response timing does not reveal whether the
    # username exists.
    name_ok = secrets.compare_digest(supplied_name, expected_name)
    password_ok = verify_password(body.password, _PASSWORD_HASH)
    if not (name_ok and password_ok):
        raise HTTPException(status_code=401, detail="用户名或密码错误")

    token = create_web_token(GATE_WEB_UID)
    return JSONResponse({"access_token": token, "token_type": "bearer"})
|
||
|
||
|
||
@app.get("/api/me")
async def api_me(
    user: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> Dict[str, str]:
    """Return the username and API key of the session user resolved by the dependency."""
    profile: Dict[str, str] = {}
    profile["username"] = user.username
    profile["api_key"] = user.api_key
    return profile
|
||
|
||
|
||
|
||
@app.get("/api/models/cards")
async def api_model_cards() -> List[Dict[str, Any]]:
    """Return the model cards, built by ``build_model_cards`` via ``run_in_thread``.

    No session dependency is declared on this route.
    """
    cards = await run_in_thread(build_model_cards)
    return cards
|
||
|
||
|
||
@app.get("/api/nodes")
async def api_nodes(
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> List[Dict[str, Any]]:
    """Return the nodes with their status, as produced by ``list_nodes_with_status``.

    The web-session dependency must resolve for the request to reach the body.
    """
    listing = await run_in_thread(list_nodes_with_status)
    return listing
|
||
|
||
|
||
def _node_record_from_body(node_id: str, body: NodeBody) -> Dict[str, Any]:
    """Build the persisted node dict for *node_id* from a validated request body.

    Shared by the create and update endpoints so both store the same shape.
    Model entries whose id is blank after stripping are dropped; a blank label
    falls back to the model id.
    """
    models: List[Dict[str, str]] = []
    for entry in body.models:
        model_id = entry.id.strip()
        if not model_id:
            continue
        label = (entry.label or model_id).strip() or model_id
        models.append({"id": model_id, "label": label})
    return {
        "id": node_id,
        "name": body.name.strip(),
        "host": (body.host or "127.0.0.1").strip() or "127.0.0.1",
        "port": body.port,
        "enabled": body.enabled,
        "max_concurrent": body.max_concurrent,
        "models": models,
    }


@app.post("/api/nodes")
async def api_nodes_create(
    body: NodeBody,
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> Dict[str, Any]:
    """Append a new node with a freshly generated id and persist the node list.

    Returns the stored node dict.
    """
    node = _node_record_from_body(_new_node_id(), body)
    nodes = list(_NODES.get("nodes", [])) + [node]
    await run_in_thread(save_nodes_config, {"nodes": nodes})
    return node


@app.put("/api/nodes/{node_id}")
async def api_nodes_update(
    node_id: str,
    body: NodeBody,
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> Dict[str, Any]:
    """Replace the node identified by *node_id* with the submitted values.

    The stored entry is replaced wholesale (the id is kept). Returns the
    stored node dict.

    Raises:
        HTTPException: 404 when no node has that id.
    """
    nodes = list(_NODES.get("nodes", []))
    idx = next(
        (i for i, n in enumerate(nodes) if str(n.get("id")) == node_id),
        None,
    )
    if idx is None:
        raise HTTPException(status_code=404, detail="节点不存在")
    nodes[idx] = _node_record_from_body(node_id, body)
    await run_in_thread(save_nodes_config, {"nodes": nodes})
    return nodes[idx]
|
||
|
||
|
||
@app.delete("/api/nodes/{node_id}")
async def api_nodes_delete(
    node_id: str,
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> JSONResponse:
    """Remove the node identified by *node_id* and persist the remaining list.

    Raises:
        HTTPException: 404 when no node has that id.
    """
    existing = _NODES.get("nodes", [])
    remaining = [entry for entry in existing if str(entry.get("id")) != node_id]
    # Nothing filtered out means the id was not present.
    if len(remaining) == len(existing):
        raise HTTPException(status_code=404, detail="节点不存在")
    await run_in_thread(save_nodes_config, {"nodes": remaining})
    return JSONResponse({"ok": True})
|
||
|
||
|
||
@app.post("/api/nodes/{node_id}/test")
async def api_nodes_test(
    node_id: str,
    request: Request,
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> Dict[str, Any]:
    """Probe one node on demand and record the outcome in its runtime state.

    Returns ``{"ok": ..., "error": ...}`` with the values from ``probe_node``.

    Raises:
        HTTPException: 404 when no node has that id.
    """
    target = None
    for candidate in _NODES.get("nodes", []):
        if str(candidate.get("id")) == node_id:
            target = candidate
            break
    if target is None:
        raise HTTPException(status_code=404, detail="节点不存在")

    http_client: httpx.AsyncClient = request.app.state.http
    ok, err = await probe_node(http_client, target)

    # Store the result under the nodes lock, stamped with the current UTC time.
    with _nodes_lock:
        state = _runtime(node_id)
        state["healthy"] = ok
        state["last_error"] = err
        state["last_check"] = datetime.now(timezone.utc).isoformat()

    return {"ok": ok, "error": err}
|
||
|
||
|
||
@app.get("/api/stats/summary")
async def api_stats_summary(
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> Dict[str, Any]:
    """Return the result of ``_query_stats_summary``, executed via ``run_in_thread``."""
    summary = await run_in_thread(_query_stats_summary)
    return summary
|
||
|
||
|
||
@app.get("/api/stats/ips")
async def api_stats_ips(
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
) -> List[Dict[str, Any]]:
    """Return the result of ``_query_stats_ips``, executed via ``run_in_thread``."""
    rows = await run_in_thread(_query_stats_ips)
    return rows
|
||
|
||
|
||
@app.get("/api/stats/billing")
async def api_stats_billing(
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
    days: int = Query(30, ge=1, le=365),
) -> Dict[str, Any]:
    """Return the result of ``_query_stats_billing`` for *days*.

    *days* is a query parameter, default 30, validated to the range 1–365.
    """
    billing = await run_in_thread(_query_stats_billing, days)
    return billing
|
||
|
||
|
||
@app.get("/api/stats/logs")
async def api_stats_logs(
    _: Annotated[GateSessionUser, Depends(get_current_web_user)],
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
) -> Dict[str, Any]:
    """Return one page of request-log entries.

    *limit* (1–500, default 50) and *offset* (>= 0, default 0) are query
    parameters passed to ``_query_stats_logs`` and echoed back in the response
    together with ``items`` and ``total``.
    """
    page_items, total_count = await run_in_thread(_query_stats_logs, limit, offset)
    payload: Dict[str, Any] = {"items": page_items, "total": total_count}
    payload["limit"] = limit
    payload["offset"] = offset
    return payload
|
||
|
||
|
||
@app.post("/v1/chat/completions")
async def chat_completions(
    request: Request,
    _: Annotated[GateSessionUser, Depends(get_user_from_api_key)],
) -> StreamingResponse:
    """OpenAI-compatible chat endpoint: relay the raw request body upstream.

    The caller is authenticated by the ``get_user_from_api_key`` dependency
    (Bearer ``sk-...`` key). When ``_NODES`` has nodes configured, the model
    read from the body selects the target node and that node's in-flight
    counter is held for the duration of the upstream exchange; otherwise the
    request is sent to ``UPSTREAM_BASE``. The upstream response is streamed
    back chunk by chunk while also being buffered, so that usage can be parsed
    and an API log entry recorded once the stream ends.

    Returns:
        A ``StreamingResponse`` carrying the upstream status code, a
        whitelisted subset of the upstream headers, and the upstream body.

    Raises:
        HTTPException: 400 when nodes are configured but no model could be
            read from the body; 404 when no node is available for the model;
            502 when the upstream request fails with ``httpx.RequestError``.
    """
    body = await request.body()
    client_ip = get_client_ip(request)
    model = parse_model_from_body(body)
    req_bytes = len(body)

    selected_node: Optional[Dict[str, Any]] = None
    node_id: Optional[str] = None
    if _NODES.get("nodes"):
        if not model:
            raise HTTPException(status_code=400, detail="请求体须包含 model 字段")
        selected_node = select_node_for_model(model)
        if selected_node is None:
            raise HTTPException(
                status_code=404,
                detail=f"未找到模型「{model}」的可用节点,请在系统设置中配置",
            )
        base = get_node_base(selected_node)
        node_id = str(selected_node["id"])
    else:
        base = UPSTREAM_BASE

    url = f"{base}/v1/chat/completions"
    fwd_headers = {
        "Content-Type": request.headers.get("content-type") or "application/json",
        "Accept": request.headers.get("accept") or "*/*",
    }
    client: httpx.AsyncClient = request.app.state.http

    def release_node() -> None:
        # Give back the in-flight slot taken below (no-op without a node).
        if node_id is not None:
            node_in_flight_dec(node_id)

    if node_id is not None:
        node_in_flight_inc(node_id)
    try:
        # build_request is inside the guarded region: a malformed node base
        # can make it raise (e.g. httpx.InvalidURL, which is not a
        # RequestError), and the in-flight slot must still be released.
        req = client.build_request("POST", url, headers=fwd_headers, content=body)
        resp = await client.send(req, stream=True)
    except httpx.RequestError as e:
        release_node()
        await record_api_log(
            client_ip,
            model,
            502,
            req_bytes,
            0,
            parse_usage_from_response(b"", ""),
            node_id,
        )
        raise HTTPException(status_code=502, detail=f"上游连接失败: {e}") from e
    except BaseException:
        # Any other failure (including cancellation) propagates unchanged,
        # but must not leave the node's in-flight counter permanently raised.
        release_node()
        raise

    # Only these upstream response headers are passed back to the caller.
    passthrough_headers = (
        "content-type",
        "cache-control",
        "openai-processing-ms",
        "x-request-id",
    )
    out_headers = {
        k: v
        for k, v in resp.headers.items()
        if k.lower() in passthrough_headers
    }
    content_type = resp.headers.get("content-type", "")

    async def stream_body() -> AsyncIterator[bytes]:
        # The whole body is kept so usage can be parsed after the stream ends.
        buf = bytearray()
        try:
            async for chunk in resp.aiter_bytes():
                buf.extend(chunk)
                yield chunk
        finally:
            try:
                await resp.aclose()
            finally:
                # Synchronous release in its own finally: if closing the
                # upstream response raises (I/O error, or cancellation when
                # the client disconnects), the counter is still decremented.
                release_node()
            usage = parse_usage_from_response(bytes(buf), content_type)
            await record_api_log(
                client_ip,
                model,
                resp.status_code,
                req_bytes,
                len(buf),
                usage,
                node_id,
            )

    return StreamingResponse(
        stream_body(),
        status_code=resp.status_code,
        headers=out_headers,
    )
|
||
|
||
|
||
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces at GATEWAY_PORT, with auto-reload disabled. uvicorn is
    # imported here so it is only needed when the file is run as a script.
    import uvicorn

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=GATEWAY_PORT,
        reload=False,
    )
|