273 lines
9.6 KiB
Python
273 lines
9.6 KiB
Python
"""
|
||
中控:聚合各子账户 /status,转发紧急全平。
|
||
|
||
默认 **HUB_HOST=0.0.0.0** 且 **HUB_TRUST_LAN=开启**,便于局域网内浏览器访问;中间件只放行本机回环地址(如 127.0.0.1 / ::1,始终允许)与 RFC1918 私网(10/8、172.16/12、192.168/16)来源,公网等其它来源一律拒绝。
|
||
若仅需本机访问,请设置:HUB_HOST=127.0.0.1 或 HUB_TRUST_LAN=0(或 false/no/off)。
|
||
|
||
与仓库根目录下四个策略/监控项目对应时,中控默认聚合的子代理地址为 127.0.0.1:15200–15203
|
||
(与各 crypto_monitor_* 里 Flask 的 APP_PORT 错开;Flask 仍用各自 .env 的 APP_HOST/APP_PORT)。
|
||
|
||
crypto_monitor_binance → 子代理建议 15200
|
||
crypto_monitor_okx → 子代理建议 15201
|
||
crypto_monitor_gate → 子代理建议 15202
|
||
crypto_monitor_gate_bot→ 子代理建议 15203
|
||
|
||
各目录单独启动 agent.py 时设置 PORT=上述端口(环境变量名是 PORT,不是 APP_PORT),与 Flask 并存。
|
||
|
||
环境变量:
|
||
HUB_PORT 默认 5100
|
||
HUB_HOST 默认 0.0.0.0(局域网可连);改为 127.0.0.1 则仅本机
|
||
HUB_AGENTS 逗号分隔子代理 URL,留空则默认 15200–15203(避免与 Flask APP_PORT 冲突)
|
||
HUB_AGENT_NAMES 可选,逗号分隔显示名,与 URL 顺序对应
|
||
HUB_DISABLED_IDS 可选,逗号分隔不参与监控/全平的账户 id(与 /api/agents 中 id 一致),例:暂不用 OKX 时写 1
|
||
CONTROL_TOKEN 若子代理启用校验,在此填同一令牌(由中控代发请求头)
|
||
HUB_TRUST_LAN 默认开启;设为 0/false/no/off(不区分大小写)则仅允许本机回环地址访问(即使 HUB_HOST=0.0.0.0,也只放行 127.0.0.1 / ::1 等本机地址)
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import os
|
||
from pathlib import Path
|
||
|
||
import httpx
|
||
from fastapi import Body, FastAPI, HTTPException, Query, Request
|
||
from fastapi.responses import FileResponse, JSONResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from pydantic import BaseModel, Field
|
||
|
||
# --- Runtime configuration, read once from the environment at import time ---

# Bind address and port of the hub's own HTTP server.
HUB_HOST = os.environ.get("HUB_HOST", "0.0.0.0")
HUB_PORT = int(os.environ.get("HUB_PORT", "5100"))

# Token forwarded to sub-agents as X-Control-Token; empty string = send nothing.
CONTROL_TOKEN = (os.environ.get("CONTROL_TOKEN") or "").strip()

# LAN access stays enabled unless HUB_TRUST_LAN is one of the explicit
# "off" spellings (compared case-insensitively after trimming).
_trust_raw = (os.environ.get("HUB_TRUST_LAN", "true") or "").strip().lower()
HUB_TRUST_LAN = not (_trust_raw in {"0", "false", "no", "off"})

# Absolute directory of this file (symlinks resolved); static assets live under it.
DIR = Path(__file__).resolve().parents[0]
|
||
|
||
|
||
def _is_local(host: str | None) -> bool:
|
||
if not host:
|
||
return False
|
||
h = host.lower()
|
||
return h in ("127.0.0.1", "::1", "localhost") or h.startswith("::ffff:127.0.0.1")
|
||
|
||
|
||
def _ipv4_rfc1918_private(host: str) -> bool:
|
||
h = host.lower()
|
||
if h.startswith("::ffff:"):
|
||
h = h[7:]
|
||
parts = h.split(".")
|
||
if len(parts) != 4:
|
||
return False
|
||
try:
|
||
a, b, c, d = (int(x) for x in parts)
|
||
except ValueError:
|
||
return False
|
||
if any(x < 0 or x > 255 for x in (a, b, c, d)):
|
||
return False
|
||
if a == 10:
|
||
return True
|
||
if a == 172 and 16 <= b <= 31:
|
||
return True
|
||
if a == 192 and b == 168:
|
||
return True
|
||
return False
|
||
|
||
|
||
def _client_allowed(host: str | None) -> bool:
    """Decide whether a peer address may use the hub.

    Loopback peers are always admitted. Other peers are admitted only when
    HUB_TRUST_LAN is enabled and the address is an RFC 1918 private IPv4
    address. Everything else (including a missing host) is refused.
    """
    if _is_local(host):
        return True
    # LAN access requires both the opt-in flag and a private source address.
    lan_ok = HUB_TRUST_LAN and host and _ipv4_rfc1918_private(host)
    return bool(lan_ok)
|
||
|
||
|
||
def _agent_headers() -> dict[str, str]:
    """Headers sent with every request to a sub-agent.

    Carries CONTROL_TOKEN as ``X-Control-Token`` when one is configured;
    otherwise returns an empty dict so no auth header is sent.
    """
    return {"X-Control-Token": CONTROL_TOKEN} if CONTROL_TOKEN else {}
|
||
|
||
|
||
_DEFAULT_FOLDER_LABELS = (
|
||
"币安山寨账户 · crypto_monitor_binance",
|
||
"OKX · crypto_monitor_okx",
|
||
"Gate训练账户 · crypto_monitor_gate",
|
||
"Gate趋势回调 · crypto_monitor_gate_bot",
|
||
)
|
||
|
||
|
||
def _ids_from_csv(raw: str | None) -> set[str]:
|
||
if not raw or not str(raw).strip():
|
||
return set()
|
||
return {x.strip() for x in str(raw).split(",") if x.strip()}
|
||
|
||
|
||
def hub_env_excluded_ids() -> set[str]:
    """Accounts disabled on the server side via HUB_DISABLED_IDS.

    These ids are skipped when polling /status and when running the global
    close-all. The variable is re-read on every call.
    """
    configured = os.getenv("HUB_DISABLED_IDS")
    return _ids_from_csv(configured)
|
||
|
||
|
||
def merged_excluded_ids(query_exclude: str | None, body_ids: list[str] | None) -> set[str]:
    """Union of every source of excluded account ids.

    Combines the server-side HUB_DISABLED_IDS, a comma-separated query value,
    and an optional list of ids from a request body. Body entries are
    stringified and trimmed; blank entries are ignored.
    """
    excluded = hub_env_excluded_ids()
    excluded.update(_ids_from_csv(query_exclude))
    for item in body_ids or ():
        text = str(item).strip()
        if text:
            excluded.add(text)
    return excluded
|
||
|
||
|
||
def parse_agents() -> list[dict[str, str]]:
    """Build the list of sub-agents this hub talks to.

    URLs come from HUB_AGENTS (comma-separated; blank entries are dropped).
    When it is unset or blank, they default to http://127.0.0.1:15200 through
    http://127.0.0.1:15203. Trailing slashes are stripped from each URL. Each
    agent's id is its position in the resulting list, as a string ("0", "1", ...).

    Display names: HUB_AGENT_NAMES is positional. Its n-th comma-separated
    entry names the n-th (non-blank) URL. An empty entry, e.g. ",,Foo", keeps
    the default name for that slot instead of shifting later names forward.
    Defaults are _DEFAULT_FOLDER_LABELS by index, then "账户{n}".

    Returns:
        A list of {"id", "name", "url"} dicts, in URL order.
    """
    urls_s = (os.getenv("HUB_AGENTS") or "").strip()
    if urls_s:
        urls = [u.strip() for u in urls_s.split(",") if u.strip()]
    else:
        urls = [f"http://127.0.0.1:{p}" for p in range(15200, 15204)]

    # NOTE: a non-empty HUB_AGENT_NAMES entry always takes precedence over
    # _DEFAULT_FOLDER_LABELS. If editing the labels in code seems to have no
    # effect, check whether this environment variable is set.
    names_s = (os.getenv("HUB_AGENT_NAMES") or "").strip()
    # Keep empty entries so positions stay aligned with the URL list.
    names = [n.strip() for n in names_s.split(",")] if names_s else []

    out: list[dict[str, str]] = []
    for i, url in enumerate(urls):
        if i < len(names) and names[i]:
            name = names[i]
        elif i < len(_DEFAULT_FOLDER_LABELS):
            name = _DEFAULT_FOLDER_LABELS[i]
        else:
            name = f"账户{i + 1}"
        out.append({"id": str(i), "name": name, "url": url.rstrip("/")})
    return out
|
||
|
||
|
||
# The hub application. Interactive docs (Swagger / ReDoc) are turned off.
app = FastAPI(title="hub", docs_url=None, redoc_url=None)

# Front-end bundle directory. It is mounted under /assets only when present,
# so the API still starts without it.
STATIC_DIR = DIR.joinpath("static")
if STATIC_DIR.is_dir():
    app.mount(
        "/assets",
        StaticFiles(directory=str(STATIC_DIR)),
        name="assets",
    )
|
||
|
||
|
||
@app.middleware("http")
async def local_only(request: Request, call_next):
    """Refuse any request whose peer address ``_client_allowed`` rejects.

    Fails closed: when the ASGI server reports no client address
    (``request.client is None``), the request is refused with 403 instead of
    skipping the check. This hub can forward emergency close-all commands, so
    an unknown origin must not be trusted.
    """
    host = request.client.host if request.client is not None else None
    if not _client_allowed(host):
        return JSONResponse({"detail": "forbidden"}, status_code=403)
    return await call_next(request)
|
||
|
||
|
||
@app.get("/")
def index_page():
    """Serve static/index.html, or a 500 JSON error when it is missing."""
    page = STATIC_DIR / "index.html"
    if page.is_file():
        return FileResponse(page)
    return JSONResponse({"detail": "missing static/index.html"}, status_code=500)
|
||
|
||
|
||
@app.get("/api/agents")
def api_agents():
    """List the configured sub-agents (id, name, url), ignoring exclusions."""
    agent_list = parse_agents()
    return {"agents": agent_list}
|
||
|
||
|
||
class CloseAllBody(BaseModel):
    # Optional JSON body accepted by POST /api/close-all.
    # exclude_ids: account ids (as strings, matching /api/agents) to skip in
    # addition to the server-side HUB_DISABLED_IDS; defaults to an empty list.
    exclude_ids: list[str] = Field(default_factory=list)
|
||
|
||
|
||
@app.get("/api/snapshot")
async def api_snapshot(
    exclude_ids: str | None = Query(
        default=None,
        description="逗号分隔,浏览器侧再关闭的账户 id,与服务端 HUB_DISABLED_IDS 合并",
    ),
):
    """Poll ``/status`` on every enabled sub-agent concurrently.

    Agents whose id is in HUB_DISABLED_IDS or in ``exclude_ids`` are skipped.
    A failing agent never fails the whole snapshot; it produces a row with
    ``http_ok=False`` and an ``error`` message instead.

    Returns:
        ``{"rows": [...], "env_excluded_ids": [...]}``. Each row carries
        id/name/url, ``http_ok`` (True only for HTTP 200 with a JSON or empty
        body), ``status_code`` (None if no response was received) and
        ``payload`` (parsed JSON or None). Failed rows also carry ``error``.
    """
    excl = merged_excluded_ids(exclude_ids, None)
    agents = [a for a in parse_agents() if a["id"] not in excl]
    headers = _agent_headers()

    async def one(client: httpx.AsyncClient, a: dict[str, str]) -> dict:
        base = {"id": a["id"], "name": a["name"], "url": a["url"]}
        try:
            r = await client.get(f"{a['url']}/status", headers=headers, timeout=10.0)
        except Exception as e:
            # Unreachable agent, timeout, bad URL, ... Several httpx timeout
            # exceptions stringify to "", so fall back to the exception type name.
            return {
                **base,
                "http_ok": False,
                "status_code": None,
                "error": str(e) or type(e).__name__,
                "payload": None,
            }

        body = None
        if r.content:
            try:
                body = r.json()
            except Exception as je:
                preview = (r.text or "")[:400].replace("\n", " ")
                return {
                    **base,
                    "http_ok": False,
                    "status_code": r.status_code,
                    "error": f"子代理返回非 JSON({je})。响应片段: {preview!r}",
                    "payload": None,
                }
        return {
            **base,
            "http_ok": r.status_code == 200,
            "status_code": r.status_code,
            "payload": body,
        }

    async with httpx.AsyncClient() as client:
        rows = await asyncio.gather(*(one(client, a) for a in agents))
    env_ex = sorted(hub_env_excluded_ids())
    return {"rows": list(rows), "env_excluded_ids": env_ex}
|
||
|
||
|
||
@app.post("/api/close/{agent_id}")
async def api_close_one(agent_id: str):
    """Forward an emergency close-all to a single sub-agent.

    The agent is looked up in ``parse_agents()`` by id. HUB_DISABLED_IDS is
    not consulted here. The sub-agent's ``POST /emergency/close-all`` gets
    up to 120 s to answer.

    Returns:
        ``{"agent": ..., "status_code": int, "payload": ...}``, where payload is
        the agent's JSON or ``{"raw": <first 2000 chars of text>}``.

    Raises:
        HTTPException(404): ``agent_id`` is unknown.
        HTTPException(502): no response could be obtained from the sub-agent.
    """
    target = next((a for a in parse_agents() if a["id"] == agent_id), None)
    if target is None:
        raise HTTPException(status_code=404, detail="unknown agent")

    headers = _agent_headers()
    url = f"{target['url']}/emergency/close-all"
    try:
        async with httpx.AsyncClient() as client:
            r = await client.post(url, headers=headers, timeout=120.0)
    except Exception as e:
        # httpx timeouts often have an empty str(); keep the 502 detail meaningful.
        raise HTTPException(status_code=502, detail=str(e) or type(e).__name__) from e

    try:
        payload = r.json()
    except Exception:
        # Non-JSON (or empty) reply: pass a bounded text preview through instead.
        payload = {"raw": r.text[:2000]}
    return {"agent": target, "status_code": r.status_code, "payload": payload}
|
||
|
||
|
||
@app.post("/api/close-all")
async def api_close_all(body: CloseAllBody | None = Body(default=None)):
    """Forward an emergency close-all to every enabled sub-agent concurrently.

    Agents in HUB_DISABLED_IDS or in ``body.exclude_ids`` are skipped. One
    agent failing does not stop the others.

    Returns:
        ``{"results": [...]}`` with one entry per targeted agent: id, name,
        ``status_code`` and ``payload`` (agent JSON, or ``{"raw": text}``), or,
        when no response was received, ``status_code=None`` and ``error``.
    """
    excl = merged_excluded_ids(None, body.exclude_ids if body else None)
    agents = [a for a in parse_agents() if a["id"] not in excl]
    headers = _agent_headers()

    async def post_close(client: httpx.AsyncClient, a: dict[str, str]) -> dict:
        url = f"{a['url']}/emergency/close-all"
        try:
            r = await client.post(url, headers=headers, timeout=120.0)
        except Exception as e:
            # httpx timeouts often stringify to ""; report the type name instead.
            return {"id": a["id"], "name": a["name"], "status_code": None, "error": str(e) or type(e).__name__}
        try:
            payload = r.json()
        except Exception:
            payload = {"raw": r.text[:2000]}
        return {"id": a["id"], "name": a["name"], "status_code": r.status_code, "payload": payload}

    async with httpx.AsyncClient() as client:
        results = await asyncio.gather(*(post_close(client, a) for a in agents))
    return {"results": list(results)}
|
||
|
||
|
||
def main():
    """Serve the hub with uvicorn on HUB_HOST:HUB_PORT (warnings only, no access log)."""
    import uvicorn

    server_options = {
        "host": HUB_HOST,
        "port": HUB_PORT,
        "log_level": "warning",
        "access_log": False,
    }
    uvicorn.run(app, **server_options)


if __name__ == "__main__":
    main()
|