"""从 sing-box Clash API / V2Ray gRPC 采集节点连接与流量。""" from __future__ import annotations import json import os import sqlite3 import time import urllib.error import urllib.request from pathlib import Path import grpc from db import connect, list_nodes ROOT = Path(os.environ.get("JIEDIAN_ROOT", Path(__file__).resolve().parents[1])) ENV_FILE = ROOT / ".env" CLASH_ADDR = "127.0.0.1:9090" V2RAY_ADDR = "127.0.0.1:9091" GRPC_METHOD = "/v2ray.core.app.stats.command.StatsService/QueryStats" _grpc_channel: grpc.Channel | None = None _speed_cache: dict[int, tuple[float, int, int]] = {} def _load_env() -> dict[str, str]: env: dict[str, str] = {} if not ENV_FILE.exists(): return env for line in ENV_FILE.read_text(encoding="utf-8").splitlines(): line = line.strip() if not line or line.startswith("#") or "=" not in line: continue key, _, value = line.partition("=") env[key.strip()] = value.strip() return env def format_bytes(num: int | float) -> str: n = float(num) for unit in ("B", "KB", "MB", "GB", "TB"): if n < 1024 or unit == "TB": if unit == "B": return f"{int(n)} B" return f"{n:.1f} {unit}" n /= 1024 return f"{n:.1f} PB" def format_speed(num: float) -> str: return f"{format_bytes(num)}/s" def _varint_encode(n: int) -> bytes: out = bytearray() while n > 0x7F: out.append((n & 0x7F) | 0x80) n >>= 7 out.append(n) return bytes(out) def _varint_decode(data: bytes, i: int) -> tuple[int, int]: shift = 0 result = 0 while i < len(data): b = data[i] i += 1 result |= (b & 0x7F) << shift if not (b & 0x80): return result, i shift += 7 raise ValueError("truncated varint") def _skip_field(data: bytes, i: int, wire_type: int) -> int: if wire_type == 0: _, i = _varint_decode(data, i) elif wire_type == 1: i += 8 elif wire_type == 2: length, i = _varint_decode(data, i) i += length elif wire_type == 5: i += 4 else: raise ValueError(f"unsupported wire type {wire_type}") return i def _encode_query_stats(name: str) -> bytes: if not name: return b"" payload = name.encode("utf-8") return bytes([0x0A]) + _varint_encode(len(payload)) + payload def _decode_stat_message(data: bytes) -> tuple[str, int | None]: name: str | None = None value: int | None = None i = 0 while i < len(data): tag = data[i] i += 1 field = tag >> 3 wire = tag & 0x07 if field == 1 and wire == 2: length, i = _varint_decode(data, i) name = data[i : i + length].decode("utf-8") i += length elif field == 2 and wire == 0: value, i = _varint_decode(data, i) else: i = _skip_field(data, i, wire) return name or "", value def _decode_query_stats_response(data: bytes) -> dict[str, int]: stats: dict[str, int] = {} i = 0 while i < len(data): tag = data[i] i += 1 field = tag >> 3 wire = tag & 0x07 if field == 1 and wire == 2: length, i = _varint_decode(data, i) name, value = _decode_stat_message(data[i : i + length]) i += length if name and value is not None: stats[name] = value else: i = _skip_field(data, i, wire) return stats def _grpc_channel_get() -> grpc.Channel: global _grpc_channel if _grpc_channel is None: _grpc_channel = grpc.insecure_channel(V2RAY_ADDR) return _grpc_channel def fetch_v2ray_user_stats() -> tuple[dict[str, tuple[int, int]], bool]: """返回 ({uuid: (upload_bytes, download_bytes)}, ok)。""" channel = _grpc_channel_get() method = channel.unary_unary( GRPC_METHOD, request_serializer=_encode_query_stats, response_deserializer=_decode_query_stats_response, ) try: raw = method(b"user>>>") except grpc.RpcError: return {}, False users: dict[str, tuple[int, int]] = {} for name, value in raw.items(): parts = name.split(">>>") if len(parts) != 4 or parts[0] != "user" or parts[2] != "traffic": continue uid, direction = parts[1], parts[3] up, down = users.get(uid, (0, 0)) if direction == "uplink": users[uid] = (value, down) elif direction == "downlink": users[uid] = (up, value) return users, True def fetch_clash_connections() -> tuple[list[dict], bool]: env = _load_env() secret = env.get("CLASH_API_SECRET", "") url = f"http://{CLASH_ADDR}/connections" req = urllib.request.Request(url) if secret: req.add_header("Authorization", f"Bearer {secret}") try: with urllib.request.urlopen(req, timeout=3) as resp: payload = json.loads(resp.read().decode("utf-8")) except (urllib.error.URLError, TimeoutError, json.JSONDecodeError, OSError): return [], False return payload.get("connections") or [], True def _match_connection(conn: dict, uuid: str) -> bool: meta = conn.get("metadata") or {} user = str(meta.get("user") or meta.get("uid") or "") return user == uuid def _ensure_traffic_schema(conn: sqlite3.Connection) -> None: conn.executescript( """ CREATE TABLE IF NOT EXISTS traffic_counters ( node_id INTEGER PRIMARY KEY, upload_total INTEGER NOT NULL DEFAULT 0, download_total INTEGER NOT NULL DEFAULT 0, snapshot_upload INTEGER NOT NULL DEFAULT 0, snapshot_download INTEGER NOT NULL DEFAULT 0, updated_at TEXT, FOREIGN KEY (node_id) REFERENCES nodes(id) ON DELETE CASCADE ); """ ) for row in conn.execute("SELECT id FROM nodes").fetchall(): conn.execute( "INSERT OR IGNORE INTO traffic_counters (node_id) VALUES (?)", (row["id"],), ) def _update_traffic_totals(node_id: int, raw_up: int, raw_down: int) -> tuple[int, int]: conn = connect() _ensure_traffic_schema(conn) row = conn.execute( "SELECT upload_total, download_total, snapshot_upload, snapshot_download " "FROM traffic_counters WHERE node_id = ?", (node_id,), ).fetchone() if row is None: conn.execute("INSERT INTO traffic_counters (node_id) VALUES (?)", (node_id,)) conn.commit() total_up, total_down, snap_up, snap_down = 0, 0, 0, 0 else: total_up = int(row["upload_total"]) total_down = int(row["download_total"]) snap_up = int(row["snapshot_upload"]) snap_down = int(row["snapshot_download"]) if raw_up < snap_up or raw_down < snap_down: total_up += snap_up total_down += snap_down snap_up = 0 snap_down = 0 total_up += max(0, raw_up - snap_up) total_down += max(0, raw_down - snap_down) conn.execute( """ UPDATE traffic_counters SET upload_total = ?, download_total = ?, snapshot_upload = ?, snapshot_download = ?, updated_at = datetime('now') WHERE node_id = ? """, (total_up, total_down, raw_up, raw_down, node_id), ) conn.commit() conn.close() return total_up, total_down def _calc_speed(node_id: int, up: int, down: int) -> tuple[float, float]: now = time.time() prev = _speed_cache.get(node_id) _speed_cache[node_id] = (now, up, down) if not prev: return 0.0, 0.0 t0, u0, d0 = prev dt = now - t0 if dt <= 0: return 0.0, 0.0 return max(0.0, (up - u0) / dt), max(0.0, (down - d0) / dt) def collect_node_stats() -> dict: nodes = list_nodes() v2ray, v2ray_ok = fetch_v2ray_user_stats() connections, clash_ok = fetch_clash_connections() singbox_ok = v2ray_ok or clash_ok result_nodes: dict[str, dict] = {} summary_online = 0 summary_up_speed = 0.0 summary_down_speed = 0.0 for node in nodes: uid = node["uuid"] node_id = int(node["id"]) raw_up, raw_down = v2ray.get(uid, (0, 0)) total_up, total_down = _update_traffic_totals(node_id, raw_up, raw_down) up_speed, down_speed = _calc_speed(node_id, raw_up, raw_down) matched = [c for c in connections if _match_connection(c, uid)] online = len(matched) > 0 or (up_speed + down_speed) > 512 if online: summary_online += 1 summary_up_speed += up_speed summary_down_speed += down_speed result_nodes[str(node_id)] = { "online": online, "connections": len(matched), "upload_speed": round(up_speed), "download_speed": round(down_speed), "upload_total": total_up, "download_total": total_down, "upload_speed_human": format_speed(up_speed), "download_speed_human": format_speed(down_speed), "upload_total_human": format_bytes(total_up), "download_total_human": format_bytes(total_down), } return { "ok": True, "singbox": singbox_ok, "nodes": result_nodes, "summary": { "online": summary_online, "total_nodes": len(nodes), "upload_speed": round(summary_up_speed), "download_speed": round(summary_down_speed), "upload_speed_human": format_speed(summary_up_speed), "download_speed_human": format_speed(summary_down_speed), }, }