refactor: 将共用代码迁入 lib/ 模块化目录

统一 strategy、key_monitor、trade、hub 等共用库到 lib/ 子包,并补充 lib-structure 文档,便于四所与中控维护。

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-07-02 16:23:09 +08:00
parent 4742a0bb9d
commit 5797d49d8a
190 changed files with 27946 additions and 27499 deletions
+1
View File
@@ -0,0 +1 @@
"""Shared library package."""
+36
View File
@@ -0,0 +1,36 @@
"""中控调用实例 API 时的鉴权(Flask request 头 X-Hub-Token)。SSO 见 hub_sso.py。"""
from __future__ import annotations
import os
from lib.hub.hub_sso import (
HUB_SSO_TTL_SEC,
hub_bridge_token,
mint_hub_sso_token,
safe_next_path,
verify_hub_sso_token,
)
__all__ = [
"HUB_SSO_TTL_SEC",
"hub_bridge_token",
"mint_hub_sso_token",
"safe_next_path",
"verify_hub_sso_token",
"request_allowed",
]
def request_allowed(session_logged_in: bool, auth_disabled: bool) -> bool:
if auth_disabled or session_logged_in:
return True
tok = hub_bridge_token()
if not tok:
return False
try:
from flask import request
except ImportError:
return False
if request.headers.get("X-Hub-Token") == tok:
return True
return False
File diff suppressed because it is too large Load Diff
+498
View File
@@ -0,0 +1,498 @@
"""中控历史测算:趋势回调 / 滚仓,以损定仓(按交易所精度与张数规则)。"""
from __future__ import annotations
from typing import Any, Callable, Optional, Tuple
from lib.strategy.strategy_roll_lib import max_roll_legs
from lib.strategy.strategy_trend_lib import (
build_trend_preview_level_rows,
calc_risk_fraction,
compute_trend_plan_core,
validate_trend_bounds,
)
DEFAULT_DCA_LEGS = 5
MARGIN_BUFFER = 0.95
def _resolve_market(
exchange_id: str,
base: str,
) -> Tuple[Optional[dict[str, Any]], Optional[Callable[[float], Optional[float]]], Optional[str]]:
from lib.hub.hub_calculator_market_lib import get_calculator_market, make_amount_precise_fn_from_market
market, err = get_calculator_market(exchange_id, base)
if err or not market:
return None, None, err or "无法解析合约"
amount_precise = make_amount_precise_fn_from_market(market)
return market, amount_precise, None
def calc_trend_calculator(
*,
direction: str,
capital_usdt: float,
risk_percent: float,
leverage: int,
entry_price: float,
stop_loss: float,
add_upper: float,
take_profit: float,
dca_legs: int = DEFAULT_DCA_LEGS,
exchange_id: str = "0",
base: str = "ETH",
) -> Tuple[Optional[dict[str, Any]], Optional[str]]:
market, amount_precise, merr = _resolve_market(exchange_id, base)
if merr or not market or not amount_precise:
return None, merr or "无法解析合约"
contract_size = float(market.get("contract_size") or 1.0)
exchange_symbol = market["exchange_symbol"]
direction = (direction or "long").strip().lower()
if direction not in ("long", "short"):
return None, "方向须为 long 或 short"
try:
capital = float(capital_usdt)
rp = float(risk_percent)
lev = int(leverage)
entry = float(entry_price)
sl = float(stop_loss)
upper = float(add_upper)
tp = float(take_profit)
legs = max(1, int(dca_legs))
cs = float(contract_size) if contract_size else 1.0
except (TypeError, ValueError):
return None, "参数格式错误"
if capital <= 0 or rp <= 0 or lev <= 0 or entry <= 0 or sl <= 0 or upper <= 0 or tp <= 0:
return None, "资金、风险、杠杆与价格须大于 0"
bound_err = validate_trend_bounds(direction, sl, upper)
if bound_err:
return None, bound_err
rf = calc_risk_fraction(direction, upper, sl)
if rf is None or rf <= 0:
return None, "止损与补仓区间边界组合无法计算风险比例"
risk_budget = capital * (rp / 100.0)
notional = risk_budget / rf
margin_plan = min(notional / float(lev), capital * MARGIN_BUFFER)
if margin_plan <= 0:
return None, "计划保证金过小"
target_amt = _amount_from_margin(margin_plan, lev, entry, cs)
if target_amt is None or target_amt <= 0:
return None, "无法计算计划张数,请检查入场价与杠杆"
target_amt = amount_precise(target_amt)
if target_amt is None or target_amt <= 0:
return None, "计划张数低于交易所最小精度"
def _amount_precise(_symbol: str, amount: float) -> Optional[float]:
return amount_precise(amount)
payload, err = compute_trend_plan_core(
direction=direction,
stop_loss=sl,
add_upper=upper,
risk_percent=rp,
snapshot_usdt=capital,
leverage=lev,
live_price=entry,
target_order_amount=target_amt,
exchange_symbol=exchange_symbol,
dca_legs=legs,
amount_precise=_amount_precise,
min_amount=float(market.get("min_amount") or 0.0),
full_margin_buffer_ratio=MARGIN_BUFFER,
)
if err:
return None, err
payload["take_profit"] = tp
payload["leverage"] = lev
payload["contract_size"] = cs
preview, rows = build_trend_preview_level_rows(payload)
px_dec = int(market.get("price_decimals") or 4)
amt_dec = int(market.get("amount_decimals") or 4)
def _f(v: Any, nd: int | None = None) -> Any:
if v is None:
return None
try:
return round(float(v), nd if nd is not None else 8)
except (TypeError, ValueError):
return v
table = []
for row in rows:
table.append(
{
"label": row.get("label"),
"price": _f(row.get("price"), px_dec),
"contracts": _f(row.get("contracts"), amt_dec),
"avg_entry": _f(row.get("avg_entry"), px_dec),
"profit_u": _f(row.get("profit_u")),
"risk_u": _f(row.get("risk_u")),
"rr": _f(row.get("rr"), 4),
}
)
return {
"direction": direction,
"capital_usdt": _f(capital),
"risk_percent": _f(rp, 2),
"risk_budget_u": _f(preview.get("preview_risk_amount_u")),
"leverage": lev,
"entry_price": _f(entry, px_dec),
"stop_loss": _f(sl, px_dec),
"add_upper": _f(upper, px_dec),
"take_profit": _f(tp, px_dec),
"plan_margin_u": _f(preview.get("plan_margin_capital")),
"target_contracts": _f(preview.get("target_order_amount"), amt_dec),
"first_contracts": _f(preview.get("first_order_amount"), amt_dec),
"dca_legs": int(preview.get("dca_legs") or legs),
"first_profit_u": _f(preview.get("preview_first_profit_u")),
"first_rr": _f(preview.get("preview_target_rr"), 4),
"market": market,
"rows": table,
}, None
def _amount_from_margin(
margin_capital: float,
leverage: int,
price: float,
contract_size: float,
) -> Optional[float]:
try:
margin = float(margin_capital)
lev = int(leverage)
px = float(price)
cs = float(contract_size) if contract_size else 1.0
except (TypeError, ValueError):
return None
if margin <= 0 or lev <= 0 or px <= 0 or cs <= 0:
return None
notional = margin * lev
return notional / (px * cs)
def _round(v: Any, nd: int = 4) -> Any:
if v is None:
return None
try:
return round(float(v), nd)
except (TypeError, ValueError):
return v
def _money_rr(profit_u: Optional[float], risk_u: Optional[float]) -> Optional[float]:
try:
if risk_u is None or float(risk_u) <= 0 or profit_u is None:
return None
return round(float(profit_u) / float(risk_u), 4)
except (TypeError, ValueError):
return None
def calc_initial_roll_qty(
direction: str,
entry_price: float,
stop_loss: float,
risk_budget_usdt: float,
contract_size: float = 1.0,
) -> Tuple[Optional[float], Optional[str]]:
"""首仓以损定仓:打到初始止损亏损 = 风险预算。"""
try:
entry = float(entry_price)
sl = float(stop_loss)
budget = float(risk_budget_usdt)
cs = float(contract_size) if contract_size else 1.0
except (TypeError, ValueError):
return None, "参数格式错误"
if entry <= 0 or sl <= 0 or budget <= 0 or cs <= 0:
return None, "入场价、止损与风险预算须大于 0"
direction = (direction or "long").strip().lower()
if direction == "short":
per_unit = (sl - entry) * cs
if per_unit <= 0:
return None, "做空:止损价须高于首仓入场价"
else:
per_unit = (entry - sl) * cs
if per_unit <= 0:
return None, "做多:止损价须低于首仓入场价"
return budget / per_unit, None
def solve_add_amount_for_total_risk(
direction: str,
qty_existing: float,
entry_existing: float,
add_price: float,
new_stop: float,
risk_budget_usdt: float,
contract_size: float = 1.0,
) -> Tuple[Optional[float], Optional[str]]:
"""合并持仓打到新止损总亏损 = 风险预算,反推本次加仓张数。"""
try:
q1 = float(qty_existing)
e1 = float(entry_existing)
e2 = float(add_price)
sl = float(new_stop)
b = float(risk_budget_usdt)
cs = float(contract_size) if contract_size else 1.0
except (TypeError, ValueError):
return None, "参数格式错误"
if q1 <= 0 or e1 <= 0 or e2 <= 0 or b <= 0 or cs <= 0:
return None, "持仓或风险预算无效"
direction = (direction or "long").strip().lower()
if direction == "short":
denom = sl - e2
numer = b / cs - q1 * (sl - e1)
if denom <= 0:
return None, "做空:新止损须高于限价加仓价"
else:
denom = e2 - sl
numer = b / cs - q1 * (e1 - sl)
if denom <= 0:
return None, "做多:新止损须低于限价/市价加仓价"
q2 = numer / denom
if q2 <= 0:
return None, "按当前新止损与总风险%,无需加仓或无法再加(已满足风险上限)"
return q2, None
def _roll_leg_preview(
*,
direction: str,
qty_existing: float,
entry_existing: float,
take_profit: float,
add_price: float,
new_stop_loss: float,
risk_budget: float,
contract_size: float,
amount_precise: Callable[[float], Optional[float]],
) -> Tuple[Optional[dict[str, Any]], Optional[str]]:
direction = (direction or "long").strip().lower()
try:
tp = float(take_profit)
sl = float(new_stop_loss)
entry_add = float(add_price)
e1 = float(entry_existing)
except (TypeError, ValueError):
return None, "止损/止盈格式错误"
if sl <= 0 or tp <= 0 or entry_add <= 0:
return None, "止损与首仓止盈须大于0"
if direction == "long":
if sl >= entry_add:
return None, "做多:新止损须低于加仓价"
if tp <= e1:
return None, "做多:首仓止盈须高于当前持仓均价参考"
else:
if sl <= entry_add:
return None, "做空:新止损须高于加仓价"
if tp >= e1:
return None, "做空:首仓止盈须低于当前持仓均价参考"
q2_raw, err = solve_add_amount_for_total_risk(
direction,
qty_existing,
entry_existing,
entry_add,
sl,
risk_budget,
contract_size,
)
if err:
return None, err
q2 = amount_precise(float(q2_raw))
if q2 is None or q2 <= 0:
return None, "加仓张数低于交易所最小精度"
new_qty = float(qty_existing) + float(q2)
new_avg = (float(qty_existing) * float(entry_existing) + float(q2) * entry_add) / new_qty
cs = float(contract_size) if contract_size else 1.0
if direction == "long":
loss_at_sl = (new_avg - sl) * new_qty * cs
reward_at_tp = (tp - new_avg) * new_qty * cs
else:
loss_at_sl = (sl - new_avg) * new_qty * cs
reward_at_tp = (new_avg - tp) * new_qty * cs
return {
"add_amount_raw": q2,
"qty_after": new_qty,
"avg_entry_after": new_avg,
"add_price": entry_add,
"new_stop_loss": sl,
"loss_at_sl_usdt": loss_at_sl,
"reward_at_tp_usdt": reward_at_tp,
}, None
def calc_roll_calculator(
*,
direction: str,
capital_usdt: float,
risk_percent: float,
entry_price: float,
stop_loss: float,
take_profit: float,
add_legs: list[dict[str, float]] | None = None,
legs_done: int = 0,
exchange_id: str = "0",
base: str = "ETH",
) -> Tuple[Optional[dict[str, Any]], Optional[str]]:
"""
滚仓历史测算:首仓自动以损定仓;止盈锁定首仓价;最多 3 次滚仓加仓。
add_legs: [{add_price, new_stop_loss}, ...],按顺序链式计算。
legs_done: 已完成滚仓次数(仅标记,仍参与链式状态推进)。
"""
market, amount_precise, merr = _resolve_market(exchange_id, base)
if merr or not market or not amount_precise:
return None, merr or "无法解析合约"
contract_size = float(market.get("contract_size") or 1.0)
px_dec = int(market.get("price_decimals") or 4)
amt_dec = int(market.get("amount_decimals") or 4)
direction = (direction or "long").strip().lower()
if direction not in ("long", "short"):
return None, "方向须为 long 或 short"
try:
capital = float(capital_usdt)
rp = float(risk_percent)
entry = float(entry_price)
initial_sl = float(stop_loss)
tp = float(take_profit)
done = max(0, int(legs_done))
except (TypeError, ValueError):
return None, "参数格式错误"
if capital <= 0 or rp <= 0 or entry <= 0 or initial_sl <= 0 or tp <= 0:
return None, "资金、风险与价格须大于 0"
if done > max_roll_legs(direction):
return None, f"已完成滚仓次数不能超过 {max_roll_legs(direction)}"
legs_in: list[dict[str, float]] = []
for raw in add_legs or []:
if not isinstance(raw, dict):
continue
try:
ap = float(raw.get("add_price"))
nsl = float(raw.get("new_stop_loss"))
except (TypeError, ValueError):
return None, "加仓价与新止损须为有效数字"
if ap <= 0 or nsl <= 0:
return None, "加仓价与新止损须大于 0"
legs_in.append({"add_price": ap, "new_stop_loss": nsl})
if done + len(legs_in) > max_roll_legs(direction):
return None, f"已完成 {done} 次 + 待测算 {len(legs_in)} 次,合计不能超过 {max_roll_legs(direction)} 次滚仓"
if direction == "long":
if tp <= entry:
return None, "做多:止盈价须高于首仓入场价"
else:
if tp >= entry:
return None, "做空:止盈价须低于首仓入场价"
risk_budget = capital * (rp / 100.0)
qty, err = calc_initial_roll_qty(direction, entry, initial_sl, risk_budget, contract_size)
if err:
return None, err
if qty is None or qty <= 0:
return None, "无法计算首仓张数"
qty_p = amount_precise(float(qty))
if qty_p is None or qty_p <= 0:
return None, "首仓张数低于交易所最小精度"
qty_f = float(qty_p)
avg = entry
rows: list[dict[str, Any]] = []
cs = contract_size
if direction == "long":
first_loss = (avg - initial_sl) * qty_f * cs
first_profit = (tp - avg) * qty_f * cs
else:
first_loss = (initial_sl - avg) * qty_f * cs
first_profit = (avg - tp) * qty_f * cs
rows.append(
{
"label": "首仓",
"leg_index": 0,
"already_done": False,
"entry_or_add_price": _round(entry, px_dec),
"stop_loss": _round(initial_sl, px_dec),
"add_contracts": _round(qty_f, amt_dec),
"total_contracts": _round(qty_f, amt_dec),
"avg_entry": _round(avg, px_dec),
"take_profit": _round(tp, px_dec),
"loss_at_sl_u": _round(first_loss),
"profit_at_tp_u": _round(first_profit),
"rr": _money_rr(first_profit, first_loss),
}
)
current_qty = qty_f
current_avg = avg
for i, leg in enumerate(legs_in):
leg_no = i + 1
preview, err = _roll_leg_preview(
direction=direction,
qty_existing=current_qty,
entry_existing=current_avg,
take_profit=tp,
add_price=leg["add_price"],
new_stop_loss=leg["new_stop_loss"],
risk_budget=risk_budget,
contract_size=cs,
amount_precise=amount_precise,
)
if err:
return None, f"滚仓第 {leg_no} 次:{err}"
if not preview:
return None, f"滚仓第 {leg_no} 次计算失败"
current_qty = float(preview["qty_after"])
current_avg = float(preview["avg_entry_after"])
loss = preview.get("loss_at_sl_usdt")
reward = preview.get("reward_at_tp_usdt")
rows.append(
{
"label": f"滚仓{leg_no}",
"leg_index": leg_no,
"already_done": leg_no <= done,
"entry_or_add_price": _round(preview.get("add_price"), px_dec),
"stop_loss": _round(preview.get("new_stop_loss"), px_dec),
"add_contracts": _round(preview.get("add_amount_raw"), amt_dec),
"total_contracts": _round(current_qty, amt_dec),
"avg_entry": _round(current_avg, px_dec),
"take_profit": _round(tp, px_dec),
"loss_at_sl_u": _round(loss),
"profit_at_tp_u": _round(reward),
"rr": _money_rr(reward, loss),
}
)
last = rows[-1]
return {
"direction": direction,
"capital_usdt": _round(capital),
"risk_percent": _round(rp, 2),
"risk_budget_u": _round(risk_budget),
"entry_price": _round(entry, px_dec),
"stop_loss": _round(initial_sl, px_dec),
"take_profit": _round(tp, px_dec),
"legs_done": done,
"roll_legs_planned": len(legs_in),
"first_contracts": _round(qty_f, amt_dec),
"final_contracts": last.get("total_contracts"),
"final_avg_entry": last.get("avg_entry"),
"final_loss_at_sl_u": last.get("loss_at_sl_u"),
"final_profit_at_tp_u": last.get("profit_at_tp_u"),
"final_rr": last.get("rr"),
"market": market,
"rows": rows,
}, None
+257
View File
@@ -0,0 +1,257 @@
"""计算器:从已配置交易实例读取 USDT 永续合约精度与张数规则。"""
from __future__ import annotations
import json
import threading
import time
import urllib.error
import urllib.request
from typing import Any, Callable, Optional, Tuple
from urllib.parse import urlencode
try:
from settings_store import enabled_exchanges, load_settings
except ImportError:
from manual_trading_hub.settings_store import enabled_exchanges, load_settings
MARKET_CACHE: dict[str, tuple[float, dict[str, Any]]] = {}
MARKET_LOCK = threading.Lock()
MARKET_TTL_SEC = 300.0
HUB_FLASK_TIMEOUT = float(__import__("os").getenv("HUB_FLASK_TIMEOUT", "20"))
def normalize_base_symbol(text: str) -> str:
s = str(text or "").upper().strip()
for suf in ("USDT:USDT", "/USDT:USDT", "/USDT", "USDT", "-USDT-SWAP"):
if s.endswith(suf) and len(s) > len(suf):
s = s[: -len(suf)].strip("-/")
break
if "/" in s:
s = s.split("/", 1)[0].strip()
if ":" in s:
s = s.split(":", 1)[0].strip()
return s
def resolve_usdt_perp_symbol(exchange: Any, base: str) -> Tuple[Optional[str], Optional[str]]:
base_u = normalize_base_symbol(base)
if not base_u:
return None, "请输入币种,如 ETH"
candidates = [f"{base_u}/USDT:USDT", f"{base_u}/USDT"]
markets = getattr(exchange, "markets", None) or {}
for sym in candidates:
m = markets.get(sym)
if not m:
continue
if m.get("active") is False:
continue
if m.get("swap") or m.get("linear") or m.get("contract"):
return sym, None
for sym, m in markets.items():
if m.get("active") is False:
continue
if not (m.get("swap") or m.get("linear")):
continue
if (m.get("quote") or "").upper() != "USDT":
continue
if (m.get("base") or "").upper() == base_u:
return sym, None
return None, f"未找到 {base_u}/USDT 永续合约"
def _decimals_from_precision_value(value: Any) -> Optional[int]:
if value in (None, ""):
return None
try:
p = float(value)
except (TypeError, ValueError):
return None
if p >= 1 and abs(p - round(p)) < 1e-9 and p <= 12:
return int(round(p))
if 0 < p < 1:
s = f"{p:.12f}".rstrip("0")
if "." in s:
return min(12, len(s.split(".", 1)[1]))
return None
def _decimals_from_ccxt_str(text: str) -> int:
s = str(text or "").strip()
if not s or "." not in s:
return 0
frac = s.split(".", 1)[1]
if not frac:
return 0
return min(12, len(frac.rstrip("0") or frac))
def amount_decimals_from_exchange(exchange: Any, exchange_symbol: str) -> int:
try:
return _decimals_from_ccxt_str(exchange.amount_to_precision(exchange_symbol, 1.23456789))
except Exception:
market = exchange.market(exchange_symbol)
prec = (market.get("precision") or {}).get("amount")
d = _decimals_from_precision_value(prec)
return d if d is not None else 4
def price_decimals_from_exchange(
exchange: Any, exchange_symbol: str, price_tick: Optional[float]
) -> int:
from lib.hub.hub_ohlcv_lib import normalize_price_tick
tick = normalize_price_tick(price_tick)
if tick and tick > 0:
if tick >= 1:
return 0
s = f"{tick:.12f}".rstrip("0")
if "." in s:
return min(12, len(s.split(".", 1)[1]))
try:
return _decimals_from_ccxt_str(exchange.price_to_precision(exchange_symbol, 12345.678901234))
except Exception:
market = exchange.market(exchange_symbol)
prec = (market.get("precision") or {}).get("price")
d = _decimals_from_precision_value(prec)
return d if d is not None else 4
def make_amount_precise_fn_from_market(market: dict[str, Any]) -> Callable[[float], Optional[float]]:
dec = max(0, int(market.get("amount_decimals") or 4))
min_amt = market.get("min_amount")
def _fn(amount: float) -> Optional[float]:
try:
v = float(amount)
except (TypeError, ValueError):
return None
if v <= 0:
return None
factor = 10**dec
v = int(v * factor + 1e-12) / factor
if min_amt is not None:
try:
if v < float(min_amt):
return None
except (TypeError, ValueError):
pass
if v <= 0:
return None
return v
return _fn
def find_exchange(exchange_id: str) -> dict | None:
needle = str(exchange_id or "").strip()
if not needle:
return None
for ex in load_settings().get("exchanges") or []:
if str(ex.get("id") or "").strip() == needle:
return ex
if str(ex.get("key") or "").strip().lower() == needle.lower():
return ex
return None
def list_calculator_exchanges() -> list[dict[str, Any]]:
rows: list[dict[str, Any]] = []
for ex in enabled_exchanges():
rows.append(
{
"id": str(ex.get("id") or ""),
"key": str(ex.get("key") or ""),
"name": str(ex.get("name") or ex.get("key") or ""),
"enabled": bool(ex.get("enabled")),
}
)
return rows
def _hub_headers() -> dict[str, str]:
import os
token = (os.getenv("HUB_BRIDGE_TOKEN") or os.getenv("CONTROL_TOKEN") or "").strip()
if token:
return {"X-Hub-Token": token}
return {}
def fetch_instance_market_sync(ex: dict, *, base: str) -> dict[str, Any]:
base_url = (ex.get("flask_url") or "").rstrip("/")
if not base_url:
return {"ok": False, "msg": "未配置 flask_url"}
params = urlencode({"base": normalize_base_symbol(base) or base})
url = f"{base_url}/api/hub/market?{params}"
req = urllib.request.Request(url, headers=_hub_headers(), method="GET")
try:
with urllib.request.urlopen(req, timeout=HUB_FLASK_TIMEOUT) as resp:
status = int(getattr(resp, "status", 200) or 200)
raw = resp.read().decode("utf-8", errors="replace")
data = json.loads(raw) if raw else {}
if not isinstance(data, dict):
return {"ok": False, "msg": "无效 JSON"}
if status >= 400:
data.setdefault("ok", False)
return data
except urllib.error.HTTPError as exc:
try:
raw = exc.read().decode("utf-8", errors="replace")
body = json.loads(raw) if raw else {}
except Exception:
body = {"ok": False, "msg": raw if "raw" in locals() else str(exc)}
if isinstance(body, dict):
body.setdefault("ok", False)
return body
return {"ok": False, "msg": f"HTTP {exc.code}"}
except Exception as exc:
return {"ok": False, "msg": str(exc)}
def _enrich_market_from_settings(ex: dict, payload: dict[str, Any]) -> dict[str, Any]:
out = dict(payload)
out["exchange_id"] = str(ex.get("id") or "")
out["exchange_key"] = str(ex.get("key") or "")
out["exchange_name"] = str(ex.get("name") or ex.get("key") or "")
out["exchange_label"] = out["exchange_name"]
return out
def get_calculator_market(
exchange_id: str,
base: str,
*,
ex: dict | None = None,
) -> Tuple[Optional[dict[str, Any]], Optional[str]]:
"""从系统设置中的交易实例拉取合约精度(与实盘一致)。"""
row = ex or find_exchange(exchange_id)
if not row:
return None, "未找到该交易所配置"
if not row.get("enabled"):
return None, f"{row.get('name') or exchange_id} 未启用"
base_u = normalize_base_symbol(base)
if not base_u:
return None, "请输入币种,如 ETH"
cache_key = f"{row.get('id')}:{base_u}"
now = time.time()
with MARKET_LOCK:
cached = MARKET_CACHE.get(cache_key)
if cached and now - cached[0] < MARKET_TTL_SEC:
return dict(cached[1]), None
remote = fetch_instance_market_sync(row, base=base_u)
if not remote.get("ok"):
return None, str(remote.get("msg") or "实例返回失败")
data = _enrich_market_from_settings(row, remote)
with MARKET_LOCK:
MARKET_CACHE[cache_key] = (now, data)
return data, None
def clear_market_cache() -> None:
with MARKET_LOCK:
MARKET_CACHE.clear()
+453
View File
@@ -0,0 +1,453 @@
"""中控开仓计划:进行中 / 历史归档 / 胜率统计。"""
from __future__ import annotations
import os
import sqlite3
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
from zoneinfo import ZoneInfo
PLAN_TYPES = {
"trend": "趋势单",
"swing": "波段单",
"intraday": "日内短线",
}
TREND_TIMEFRAMES = ("5m", "15m", "30m", "1h", "4h", "1d")
ENTRY_TIMEFRAMES = ("1m", "5m", "15m", "30m", "1h")
DIRECTIONS = {"long": "", "short": ""}
ENTRY_SCHEMES = {
"breakout": "突破方案",
"false_breakout": "假突破突破方案",
"box_inflection": "箱体拐点方案",
}
RESULTS = {"win": "", "loss": ""}
STAT_DIMENSIONS = ("symbol", "trend_tf", "entry_scheme")
DISPLAY_TZ = ZoneInfo(
(os.getenv("HUB_ENTRY_PLAN_TZ") or os.getenv("HUB_VOLUME_RANK_TZ") or "Asia/Shanghai").strip()
or "Asia/Shanghai"
)
def default_db_path() -> Path:
raw = (os.getenv("HUB_ENTRY_PLAN_DB_PATH") or "").strip()
if raw:
return Path(raw)
hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data"
hub_dir.mkdir(parents=True, exist_ok=True)
return hub_dir / "hub_entry_plans.db"
def _now_ms() -> int:
return int(time.time() * 1000)
def _connect(db_path: Path | None = None) -> sqlite3.Connection:
path = db_path or default_db_path()
path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(path), timeout=30, isolation_level=None)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
return conn
def init_db(db_path: Path | None = None) -> None:
conn = _connect(db_path)
try:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS entry_plans (
id INTEGER PRIMARY KEY AUTOINCREMENT,
plan_date TEXT NOT NULL,
exchange_key TEXT NOT NULL,
symbol TEXT NOT NULL,
plan_type TEXT NOT NULL,
trend_timeframe TEXT NOT NULL,
entry_timeframe TEXT NOT NULL,
direction TEXT NOT NULL,
target_level TEXT NOT NULL DEFAULT '',
current_range TEXT NOT NULL DEFAULT '',
entry_scheme TEXT NOT NULL,
result TEXT,
pnl_amount REAL,
note TEXT NOT NULL DEFAULT '',
status TEXT NOT NULL DEFAULT 'active',
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
archived_at INTEGER
)
"""
)
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_entry_plans_status_date
ON entry_plans (status, plan_date DESC, id DESC)
"""
)
finally:
conn.close()
def normalize_plan_symbol(raw: str) -> str:
s = str(raw or "").strip().upper()
if not s:
raise ValueError("缺少币种")
if ":" in s:
s = s.split(":", 1)[0]
if "/" in s:
base, quote = s.split("/", 1)
base = base.strip()
quote = (quote or "USDT").strip() or "USDT"
if not base:
raise ValueError("币种无效")
return f"{base}/{quote}"
if s.endswith("USDT") and len(s) > 4:
return f"{s[:-4]}/{s[-4:]}"
return f"{s}/USDT"
def _validate_choice(value: str, allowed: dict[str, str] | tuple[str, ...], field: str) -> str:
key = str(value or "").strip().lower()
if isinstance(allowed, dict):
if key not in allowed:
raise ValueError(f"{field} 无效")
return key
if key not in allowed:
raise ValueError(f"{field} 无效")
return key
def _row_to_dict(row: sqlite3.Row | None) -> dict[str, Any] | None:
if row is None:
return None
d = dict(row)
d["plan_type_label"] = PLAN_TYPES.get(d.get("plan_type") or "", d.get("plan_type") or "")
d["direction_label"] = DIRECTIONS.get(d.get("direction") or "", d.get("direction") or "")
d["entry_scheme_label"] = ENTRY_SCHEMES.get(
d.get("entry_scheme") or "", d.get("entry_scheme") or ""
) or "待填写"
res = d.get("result")
d["result_label"] = RESULTS.get(res, "") if res else ""
return d
def _parse_optional_pnl(raw: Any) -> float | None:
if raw is None or raw == "":
return None
try:
return round(float(raw), 4)
except (TypeError, ValueError) as e:
raise ValueError("盈亏金额无效") from e
def create_entry_plan(payload: dict[str, Any], *, db_path: Path | None = None) -> dict[str, Any]:
init_db(db_path)
plan_date = str(payload.get("plan_date") or "").strip()[:10]
if not plan_date:
raise ValueError("缺少 plan_date")
exchange_key = str(payload.get("exchange_key") or "").strip().lower()
if not exchange_key:
raise ValueError("缺少 exchange_key")
symbol = normalize_plan_symbol(payload.get("symbol") or "")
plan_type = _validate_choice(payload.get("plan_type"), PLAN_TYPES, "类型")
trend_tf = _validate_choice(payload.get("trend_timeframe"), TREND_TIMEFRAMES, "趋势周期")
entry_tf = _validate_choice(payload.get("entry_timeframe"), ENTRY_TIMEFRAMES, "入场周期")
direction = _validate_choice(payload.get("direction"), DIRECTIONS, "方向")
entry_scheme = ""
if payload.get("entry_scheme"):
entry_scheme = _validate_choice(payload.get("entry_scheme"), ENTRY_SCHEMES, "入场方案")
target_level = str(payload.get("target_level") or "").strip()
current_range = str(payload.get("current_range") or "").strip()
note = str(payload.get("note") or "").strip()
now = _now_ms()
conn = _connect(db_path)
try:
cur = conn.execute(
"""
INSERT INTO entry_plans (
plan_date, exchange_key, symbol, plan_type, trend_timeframe, entry_timeframe,
direction, target_level, current_range, entry_scheme, note, status,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'active', ?, ?)
""",
(
plan_date,
exchange_key,
symbol,
plan_type,
trend_tf,
entry_tf,
direction,
target_level,
current_range,
entry_scheme,
note,
now,
now,
),
)
row = conn.execute(
"SELECT * FROM entry_plans WHERE id=?",
(int(cur.lastrowid),),
).fetchone()
return _row_to_dict(row) or {}
finally:
conn.close()
def list_entry_plans(
*,
status: str = "active",
db_path: Path | None = None,
) -> list[dict[str, Any]]:
init_db(db_path)
st = (status or "active").strip().lower()
if st not in ("active", "archived"):
raise ValueError("status 无效")
conn = _connect(db_path)
try:
rows = conn.execute(
"""
SELECT * FROM entry_plans
WHERE status=?
ORDER BY plan_date DESC, id DESC
""",
(st,),
).fetchall()
return [_row_to_dict(r) for r in rows if r]
finally:
conn.close()
def get_entry_plan(plan_id: int, *, db_path: Path | None = None) -> dict[str, Any] | None:
init_db(db_path)
conn = _connect(db_path)
try:
row = conn.execute("SELECT * FROM entry_plans WHERE id=?", (int(plan_id),)).fetchone()
return _row_to_dict(row)
finally:
conn.close()
def update_entry_plan(
plan_id: int,
payload: dict[str, Any],
*,
db_path: Path | None = None,
) -> dict[str, Any] | None:
init_db(db_path)
conn = _connect(db_path)
try:
row = conn.execute("SELECT * FROM entry_plans WHERE id=?", (int(plan_id),)).fetchone()
if not row:
return None
if row["status"] == "archived":
raise ValueError("已归档计划不可修改")
fields: dict[str, Any] = {}
if "plan_date" in payload:
qd = str(payload.get("plan_date") or "").strip()[:10]
if not qd:
raise ValueError("缺少 plan_date")
fields["plan_date"] = qd
if "exchange_key" in payload:
ex = str(payload.get("exchange_key") or "").strip().lower()
if not ex:
raise ValueError("缺少 exchange_key")
fields["exchange_key"] = ex
if "symbol" in payload:
fields["symbol"] = normalize_plan_symbol(payload.get("symbol") or "")
if "plan_type" in payload:
fields["plan_type"] = _validate_choice(payload.get("plan_type"), PLAN_TYPES, "类型")
if "trend_timeframe" in payload:
fields["trend_timeframe"] = _validate_choice(
payload.get("trend_timeframe"), TREND_TIMEFRAMES, "趋势周期"
)
if "entry_timeframe" in payload:
fields["entry_timeframe"] = _validate_choice(
payload.get("entry_timeframe"), ENTRY_TIMEFRAMES, "入场周期"
)
if "direction" in payload:
fields["direction"] = _validate_choice(payload.get("direction"), DIRECTIONS, "方向")
if "entry_scheme" in payload:
fields["entry_scheme"] = _validate_choice(
payload.get("entry_scheme"), ENTRY_SCHEMES, "入场方案"
)
if "target_level" in payload:
fields["target_level"] = str(payload.get("target_level") or "").strip()
if "current_range" in payload:
fields["current_range"] = str(payload.get("current_range") or "").strip()
if "note" in payload:
fields["note"] = str(payload.get("note") or "").strip()
if "pnl_amount" in payload:
fields["pnl_amount"] = _parse_optional_pnl(payload.get("pnl_amount"))
archive_now = False
if "result" in payload:
res_raw = payload.get("result")
if res_raw is None or str(res_raw).strip() == "":
fields["result"] = None
else:
fields["result"] = _validate_choice(res_raw, RESULTS, "结果")
archive_now = True
if not fields:
return _row_to_dict(row)
now = _now_ms()
fields["updated_at"] = now
if archive_now:
scheme_val = fields.get("entry_scheme", row["entry_scheme"])
if not str(scheme_val or "").strip():
raise ValueError("归档前请在进行中计划里选择入场方案")
fields["status"] = "archived"
fields["archived_at"] = now
sets = ", ".join(f"{k}=?" for k in fields)
conn.execute(
f"UPDATE entry_plans SET {sets} WHERE id=?",
(*fields.values(), int(plan_id)),
)
updated = conn.execute("SELECT * FROM entry_plans WHERE id=?", (int(plan_id),)).fetchone()
return _row_to_dict(updated)
finally:
conn.close()
def delete_entry_plan(plan_id: int, *, db_path: Path | None = None) -> bool:
init_db(db_path)
conn = _connect(db_path)
try:
row = conn.execute("SELECT status FROM entry_plans WHERE id=?", (int(plan_id),)).fetchone()
if not row:
return False
if row["status"] != "active":
raise ValueError("仅进行中的计划可删除")
cur = conn.execute("DELETE FROM entry_plans WHERE id=? AND status='active'", (int(plan_id),))
return int(cur.rowcount or 0) > 0
finally:
conn.close()
def _today_iso() -> str:
return datetime.now(DISPLAY_TZ).strftime("%Y-%m-%d")
def resolve_stats_date_bounds(
*,
period: str = "all",
date_from: str = "",
date_to: str = "",
) -> tuple[str | None, str | None, str]:
"""返回 (date_from, date_to, label)all 时 bounds 为 None。"""
p = (period or "all").strip().lower() or "all"
today = _today_iso()
if p == "all":
return None, None, "全部历史"
if p == "week":
day_dt = datetime.strptime(today, "%Y-%m-%d")
monday = (day_dt - timedelta(days=day_dt.weekday())).strftime("%Y-%m-%d")
return monday, today, f"本周 {monday}{today}"
if p == "month":
day_dt = datetime.strptime(today, "%Y-%m-%d")
first = day_dt.replace(day=1).strftime("%Y-%m-%d")
return first, today, f"本月 {first}{today}"
if p == "range":
df = (date_from or "").strip()[:10] or today
dt = (date_to or "").strip()[:10] or df
if df > dt:
df, dt = dt, df
label = f"区间 {df}{dt}" if df != dt else f"区间 {df}"
return df, dt, label
return None, None, "全部历史"
def compute_entry_plan_stats(
*,
dimension: str = "symbol",
period: str = "all",
date_from: str = "",
date_to: str = "",
db_path: Path | None = None,
) -> dict[str, Any]:
init_db(db_path)
dim = (dimension or "symbol").strip().lower()
if dim not in STAT_DIMENSIONS:
raise ValueError("dimension 无效")
df_bound, dt_bound, period_label = resolve_stats_date_bounds(
period=period, date_from=date_from, date_to=date_to
)
col_map = {
"symbol": "symbol",
"trend_tf": "trend_timeframe",
"entry_scheme": "entry_scheme",
}
col = col_map[dim]
conn = _connect(db_path)
try:
where = "status='archived' AND result IN ('win','loss')"
params: list[Any] = []
if df_bound:
where += " AND plan_date >= ? AND plan_date <= ?"
params.extend([df_bound, dt_bound])
rows = conn.execute(
f"""
SELECT {col} AS dim_key,
COUNT(*) AS total,
SUM(CASE WHEN result='win' THEN 1 ELSE 0 END) AS win_count,
SUM(CASE WHEN result='loss' THEN 1 ELSE 0 END) AS loss_count
FROM entry_plans
WHERE {where}
GROUP BY {col}
ORDER BY total DESC, dim_key ASC
""",
params,
).fetchall()
items = []
for r in rows:
total = int(r["total"] or 0)
wins = int(r["win_count"] or 0)
losses = int(r["loss_count"] or 0)
key = str(r["dim_key"] or "")
label = key
if dim == "entry_scheme":
label = ENTRY_SCHEMES.get(key, key)
elif dim == "trend_tf":
label = key
win_rate = round(wins / total * 100, 1) if total else None
items.append(
{
"key": key,
"label": label,
"total": total,
"win_count": wins,
"loss_count": losses,
"win_rate": win_rate,
}
)
return {
"dimension": dim,
"period": period,
"period_label": period_label,
"date_from": df_bound,
"date_to": dt_bound,
"items": items,
}
finally:
conn.close()
def meta_payload(exchanges: list[dict[str, Any]] | None = None) -> dict[str, Any]:
return {
"plan_types": [{"value": k, "label": v} for k, v in PLAN_TYPES.items()],
"trend_timeframes": list(TREND_TIMEFRAMES),
"entry_timeframes": list(ENTRY_TIMEFRAMES),
"directions": [{"value": k, "label": v} for k, v in DIRECTIONS.items()],
"entry_schemes": [{"value": k, "label": v} for k, v in ENTRY_SCHEMES.items()],
"results": [{"value": k, "label": v} for k, v in RESULTS.items()],
"stat_dimensions": [
{"value": "symbol", "label": "币种"},
{"value": "trend_tf", "label": "趋势周期"},
{"value": "entry_scheme", "label": "入场方案"},
],
"exchanges": exchanges or [],
}
+407
View File
@@ -0,0 +1,407 @@
"""中控资金概况:分户日快照(180 交易日)、总资金曲线与回撤。"""
from __future__ import annotations
import json
import os
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Optional
from lib.hub.hub_trades_lib import current_trading_day
HUB_DIR = Path(__file__).resolve().parent / "manual_trading_hub"
FUND_HISTORY_PATH = HUB_DIR / "hub_fund_history.json"
LEGACY_FUND_HISTORY_PATH = HUB_DIR / "hub_ai_fund_history.json"
try:
FUND_HISTORY_DAYS = max(30, int(os.getenv("HUB_FUND_HISTORY_DAYS", "180") or "180"))
except ValueError:
FUND_HISTORY_DAYS = 180
FUND_HISTORY_START_DAY = (os.getenv("HUB_FUND_HISTORY_START_DAY") or "2026-06-09").strip()[:10]
def fund_history_start_day() -> str:
return FUND_HISTORY_START_DAY or "2026-06-09"
def _now_str() -> str:
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def _safe_float(value: Any) -> Optional[float]:
try:
v = float(value)
return v if v >= 0 else None
except (TypeError, ValueError):
return None
def account_total_usdt(funding: Any, trading: Any) -> Optional[float]:
"""资金户 + 交易户;任一侧缺失则不计入(返回 None)。"""
fu = _safe_float(funding)
tu = _safe_float(trading)
if fu is None or tu is None:
return None
return round(fu + tu, 4)
def compute_drawdown(values: list[float]) -> dict[str, Any]:
"""基于资金权益序列计算峰值回撤(U 与 %)。"""
peak = 0.0
max_dd_u = 0.0
peak_at_end = 0.0
for v in values:
if not isinstance(v, (int, float)):
continue
fv = float(v)
if fv > peak:
peak = fv
dd = peak - fv
if dd > max_dd_u:
max_dd_u = dd
peak_at_end = peak
max_dd_u = round(max_dd_u, 4)
peak_at_end = round(peak_at_end, 4)
max_dd_pct = round((max_dd_u / peak_at_end) * 100, 2) if peak_at_end > 0 else None
return {
"peak_usdt": peak_at_end,
"max_drawdown_u": max_dd_u,
"max_drawdown_pct": max_dd_pct,
}
def _atomic_write(path: Path, data: dict) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(path.suffix + ".tmp")
tmp.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
os.replace(tmp, path)
def _prune_days(
days: dict,
*,
keep_days: int,
anchor_day: str,
start_day: Optional[str] = None,
) -> dict:
try:
anchor = datetime.strptime(anchor_day[:10], "%Y-%m-%d")
except ValueError:
anchor = datetime.now()
rolling_cutoff = (anchor - timedelta(days=max(1, keep_days) - 1)).strftime("%Y-%m-%d")
start = (start_day or fund_history_start_day()).strip()[:10]
cutoff = max(rolling_cutoff, start) if start else rolling_cutoff
return {k: v for k, v in (days or {}).items() if str(k) >= cutoff}
def _migrate_legacy_store(days: dict) -> dict:
if not LEGACY_FUND_HISTORY_PATH.is_file():
return days
try:
loaded = json.loads(LEGACY_FUND_HISTORY_PATH.read_text(encoding="utf-8"))
legacy_days = loaded.get("days") if isinstance(loaded, dict) else {}
if not isinstance(legacy_days, dict):
return days
merged = dict(days)
for day, block in legacy_days.items():
if day in merged:
continue
if isinstance(block, dict) and block.get("accounts"):
merged[day] = block
return merged
except Exception:
return days
def _load_store() -> dict:
if not FUND_HISTORY_PATH.is_file():
store = {"version": 1, "days": _migrate_legacy_store({})}
if store["days"]:
_atomic_write(FUND_HISTORY_PATH, store)
return store
try:
loaded = json.loads(FUND_HISTORY_PATH.read_text(encoding="utf-8"))
if isinstance(loaded, dict):
loaded.setdefault("version", 1)
days = dict(loaded.get("days") or {})
loaded["days"] = _migrate_legacy_store(days)
return loaded
except Exception:
pass
return {"version": 1, "days": {}}
def record_fund_snapshot(
trading_day: str,
accounts: list[dict],
*,
keep_days: int = FUND_HISTORY_DAYS,
reset_hour: int = 8,
) -> dict[str, Any]:
"""写入当日各户资金账户/交易账户余额,并裁剪历史。"""
day = (trading_day or "").strip()[:10] or current_trading_day(reset_hour=reset_hour)
start = fund_history_start_day()
if start and day < start:
return _load_store().get("days") or {}
store = _load_store()
days = dict(store.get("days") or {})
row_accounts: dict[str, dict] = {}
for ac in accounts or []:
key = str(ac.get("key") or ac.get("id") or "").strip()
if not key:
continue
if not ac.get("monitored"):
continue
fu = _safe_float(ac.get("funding_usdt"))
tu = _safe_float(ac.get("trading_usdt"))
total = account_total_usdt(fu, tu)
if total is None:
continue
row_accounts[key] = {
"name": ac.get("name"),
"funding_usdt": fu,
"trading_usdt": tu,
"total_usdt": total,
"recorded_at": _now_str(),
}
if row_accounts:
days[day] = {"accounts": row_accounts, "updated_at": _now_str()}
days = _prune_days(
days, keep_days=keep_days, anchor_day=day, start_day=fund_history_start_day()
)
_atomic_write(FUND_HISTORY_PATH, {"version": 1, "days": days})
return days
def record_fund_snapshot_from_board(
rows: list[dict],
*,
keep_days: int = FUND_HISTORY_DAYS,
reset_hour: int = 8,
) -> dict[str, Any]:
"""监控板行写入当日快照(仅 account_ok 且资金/交易户齐全)。"""
day = current_trading_day(reset_hour=reset_hour)
accounts = []
for row in rows or []:
if not isinstance(row, dict):
continue
if not row.get("account_ok"):
continue
accounts.append(
{
"key": row.get("key") or row.get("id"),
"name": row.get("name"),
"funding_usdt": row.get("funding_usdt"),
"trading_usdt": row.get("trading_usdt"),
"monitored": True,
}
)
return record_fund_snapshot(day, accounts, keep_days=keep_days, reset_hour=reset_hour)
def get_fund_history(*, anchor_day: str, keep_days: int = FUND_HISTORY_DAYS) -> dict[str, dict]:
store = _load_store()
return _prune_days(
dict(store.get("days") or {}),
keep_days=keep_days,
anchor_day=anchor_day,
start_day=fund_history_start_day(),
)
def _exchange_monitored(ex: dict) -> bool:
return bool(ex.get("enabled")) and not bool(ex.get("env_disabled"))
def _live_row_for_exchange(ex: dict, rows_by_key: dict[str, dict]) -> Optional[dict]:
key = str(ex.get("key") or "").strip()
if not key:
return None
return rows_by_key.get(key)
def _series_from_history(
history: dict[str, dict],
account_keys: list[str],
) -> list[dict[str, Any]]:
out: list[dict[str, Any]] = []
for day in sorted(history.keys()):
block = history.get(day) or {}
ac_map = block.get("accounts") or {}
total = 0.0
n = 0
for key in account_keys:
ac = ac_map.get(key) or {}
t = account_total_usdt(ac.get("funding_usdt"), ac.get("trading_usdt"))
if t is None:
t = _safe_float(ac.get("total_usdt"))
if t is None:
continue
total += t
n += 1
if n > 0:
out.append({"day": day, "total_usdt": round(total, 4)})
return out
def _account_series(history: dict[str, dict], key: str) -> list[dict[str, Any]]:
out: list[dict[str, Any]] = []
for day in sorted(history.keys()):
ac = (history.get(day) or {}).get("accounts", {}).get(key) or {}
t = account_total_usdt(ac.get("funding_usdt"), ac.get("trading_usdt"))
if t is None:
t = _safe_float(ac.get("total_usdt"))
if t is None:
continue
out.append(
{
"day": day,
"total_usdt": t,
"funding_usdt": _safe_float(ac.get("funding_usdt")),
"trading_usdt": _safe_float(ac.get("trading_usdt")),
}
)
return out
def build_fund_overview(
exchanges: list[dict],
*,
board_rows: Optional[list[dict]] = None,
trading_day: Optional[str] = None,
keep_days: int = FUND_HISTORY_DAYS,
reset_hour: int = 8,
updated_at: Optional[str] = None,
) -> dict[str, Any]:
day = (trading_day or "").strip()[:10] or current_trading_day(reset_hour=reset_hour)
history = get_fund_history(anchor_day=day, keep_days=keep_days)
rows_by_key: dict[str, dict] = {}
for row in board_rows or []:
if isinstance(row, dict):
k = str(row.get("key") or "").strip()
if k:
rows_by_key[k] = row
monitored_keys: list[str] = []
accounts_out: list[dict[str, Any]] = []
live_total = 0.0
live_known = 0
for ex in exchanges or []:
if not _exchange_monitored(ex):
continue
key = str(ex.get("key") or "").strip()
monitored = True
row = _live_row_for_exchange(ex, rows_by_key)
fu = tu = total = None
data_ok = False
if row and row.get("account_ok"):
fu = _safe_float(row.get("funding_usdt"))
tu = _safe_float(row.get("trading_usdt"))
total = account_total_usdt(fu, tu)
data_ok = total is not None
if data_ok:
live_total += total
live_known += 1
series = _account_series(history, key) if key else []
dd = compute_drawdown([p["total_usdt"] for p in series]) if series else {
"peak_usdt": None,
"max_drawdown_u": None,
"max_drawdown_pct": None,
}
day_delta = None
if series:
if len(series) >= 2:
day_delta = round(series[-1]["total_usdt"] - series[-2]["total_usdt"], 4)
elif data_ok and total is not None:
day_delta = round(total - series[-1]["total_usdt"], 4)
accounts_out.append(
{
"id": ex.get("id"),
"key": key,
"name": ex.get("name") or key,
"monitored": monitored,
"data_ok": data_ok,
"funding_usdt": fu,
"trading_usdt": tu,
"total_usdt": total,
"series": series,
"drawdown": dd,
"day_delta_usdt": day_delta,
}
)
if key:
monitored_keys.append(key)
total_series = _series_from_history(history, monitored_keys)
if live_known > 0:
last_day = total_series[-1]["day"] if total_series else None
live_point = round(live_total, 4)
if last_day == day and total_series:
total_series[-1]["total_usdt"] = live_point
total_series[-1]["live"] = True
else:
total_series.append({"day": day, "total_usdt": live_point, "live": True})
total_dd = compute_drawdown([p["total_usdt"] for p in total_series]) if total_series else {
"peak_usdt": None,
"max_drawdown_u": None,
"max_drawdown_pct": None,
}
total_day_delta = None
if total_series:
if len(total_series) >= 2:
total_day_delta = round(
total_series[-1]["total_usdt"] - total_series[-2]["total_usdt"], 4
)
return {
"ok": True,
"trading_day": day,
"reset_hour": reset_hour,
"keep_days": keep_days,
"history_start_day": fund_history_start_day(),
"updated_at": updated_at,
"totals": {
"monitored_count": len(monitored_keys),
"live_known_count": live_known,
"total_usdt": round(live_total, 4) if live_known > 0 else None,
"day_delta_usdt": total_day_delta,
"series": total_series,
"drawdown": total_dd,
},
"accounts": accounts_out,
}
def format_fund_history_text(
history: dict[str, dict],
*,
account_names: Optional[dict[str, str]] = None,
) -> str:
if not history:
return "(暂无资金历史快照)"
names = account_names or {}
lines = ["【资金快照(资金账户 + 交易账户 USDT)】"]
for day in sorted(history.keys()):
block = history.get(day) or {}
ac_map = block.get("accounts") or {}
if not ac_map:
continue
parts = []
for key, ac in ac_map.items():
label = names.get(key) or ac.get("name") or key
fu = ac.get("funding_usdt")
tu = ac.get("trading_usdt")
tot = ac.get("total_usdt")
if tot is None:
tot = account_total_usdt(fu, tu)
fu_txt = f"{fu}U" if fu is not None else "未知"
tu_txt = f"{tu}U" if tu is not None else "未知"
tot_txt = f"{tot}U" if tot is not None else "未知"
parts.append(f"{label}: 合计{tot_txt}(资金{fu_txt}/交易{tu_txt}")
lines.append(f"- {day}: " + "".join(parts))
return "\n".join(lines) if len(lines) > 1 else "(暂无资金历史快照)"
+98
View File
@@ -0,0 +1,98 @@
"""中控:本机 CPU / 内存 / 磁盘 / 网络快照(监控区服务器状态条)。"""
from __future__ import annotations
import os
import socket
import time
from typing import Any
_state: dict[str, Any] = {
"primed": False,
"net_ts": 0.0,
"net_sent": 0,
"net_recv": 0,
}
def _disk_path() -> str:
raw = (os.getenv("HUB_HOST_DISK_PATH") or "").strip()
if raw:
return raw
if os.name == "nt":
drive = (os.environ.get("SystemDrive") or "C:").strip()
return drive if drive.endswith(("\\", "/")) else drive + "\\"
return "/"
def _safe_int(value: Any) -> int:
try:
return int(value)
except (TypeError, ValueError):
return 0
def get_host_status() -> dict[str, Any]:
try:
import psutil
except ImportError:
return {
"ok": False,
"msg": "未安装 psutil,请在 manual-trading-hub 环境执行 pip install psutil",
}
now = time.time()
if not _state["primed"]:
psutil.cpu_percent(interval=None)
_state["primed"] = True
cpu_pct = float(psutil.cpu_percent(interval=None))
cpu_count = int(psutil.cpu_count(logical=True) or 0)
vm = psutil.virtual_memory()
disk_path = _disk_path()
du = psutil.disk_usage(disk_path)
net = psutil.net_io_counters()
sent_rate = 0.0
recv_rate = 0.0
if net is not None and _state["net_ts"] > 0:
dt = max(0.001, now - float(_state["net_ts"]))
sent_rate = max(0.0, (net.bytes_sent - int(_state["net_sent"])) / dt)
recv_rate = max(0.0, (net.bytes_recv - int(_state["net_recv"])) / dt)
if net is not None:
_state["net_ts"] = now
_state["net_sent"] = int(net.bytes_sent)
_state["net_recv"] = int(net.bytes_recv)
disk_total = _safe_int(du.total)
disk_used = _safe_int(du.used)
disk_pct = round(disk_used / disk_total * 100, 1) if disk_total > 0 else 0.0
boot = float(psutil.boot_time())
return {
"ok": True,
"hostname": socket.gethostname(),
"uptime_sec": max(0, int(now - boot)),
"cpu": {
"percent": round(cpu_pct, 1),
"count": cpu_count,
},
"memory": {
"total_bytes": _safe_int(vm.total),
"used_bytes": _safe_int(vm.used),
"percent": round(float(vm.percent), 1),
},
"disk": {
"path": disk_path,
"total_bytes": disk_total,
"used_bytes": disk_used,
"percent": disk_pct,
},
"network": {
"bytes_sent": _safe_int(net.bytes_sent if net else 0),
"bytes_recv": _safe_int(net.bytes_recv if net else 0),
"sent_rate_bps": round(sent_rate, 1),
"recv_rate_bps": round(recv_rate, 1),
},
"updated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
}
+881
View File
@@ -0,0 +1,881 @@
"""中控 K 线 SQLite:分周期保留、交易所直拉、分页读取。"""
from __future__ import annotations
import os
import sqlite3
import time
from pathlib import Path
from typing import Any, Callable, Optional
from lib.hub.hub_ohlcv_lib import (
HUB_KLINE_1M_MAX_BARS,
HUB_KLINE_5M_1H_RETENTION_DAYS,
TIMEFRAME_MS,
YEAR_ROLLING_STORED,
chart_chunk_limit,
chart_initial_limit,
chart_memory_cap,
history_cutoff_ms_for_storage,
normalize_chart_timeframe,
normalize_price_tick,
format_price_by_tick,
last_closed_bar_open_ms,
retention_policy_meta,
round_ohlcv_bars_to_tick,
seed_bar_target,
)
HUB_KLINE_MIN_BARS_BEFORE_TAIL = 200
HUB_KLINE_REMOTE_FETCH_CAP = 1500
_DEFAULT_RETENTION_DAYS = 15
def retention_days() -> int:
"""兼容旧配置;新策略见 retention_policy_meta。"""
try:
return max(1, int(os.getenv("HUB_KLINE_RETENTION_DAYS", str(_DEFAULT_RETENTION_DAYS))))
except ValueError:
return _DEFAULT_RETENTION_DAYS
def default_db_path() -> Path:
raw = (os.getenv("HUB_KLINE_DB_PATH") or "").strip()
if raw:
return Path(raw)
hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data"
hub_dir.mkdir(parents=True, exist_ok=True)
return hub_dir / "hub_kline.db"
def _connect(db_path: Path | None = None) -> sqlite3.Connection:
path = db_path or default_db_path()
path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(path), timeout=30, isolation_level=None)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
return conn
def init_db(db_path: Path | None = None) -> None:
conn = _connect(db_path)
try:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS ohlcv_bars (
exchange_key TEXT NOT NULL,
symbol TEXT NOT NULL,
timeframe TEXT NOT NULL,
open_time_ms INTEGER NOT NULL,
open REAL NOT NULL,
high REAL NOT NULL,
low REAL NOT NULL,
close REAL NOT NULL,
volume REAL NOT NULL DEFAULT 0,
updated_at INTEGER NOT NULL,
PRIMARY KEY (exchange_key, symbol, timeframe, open_time_ms)
)
"""
)
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_ohlcv_series
ON ohlcv_bars (exchange_key, symbol, timeframe, open_time_ms)
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS ohlcv_symbol_meta (
exchange_key TEXT NOT NULL,
symbol TEXT NOT NULL,
price_tick REAL,
updated_at INTEGER NOT NULL,
PRIMARY KEY (exchange_key, symbol)
)
"""
)
finally:
conn.close()
def save_symbol_price_tick(
exchange_key: str,
symbol: str,
price_tick: float | None,
db_path: Path | None = None,
) -> None:
tick = price_tick
if tick is None:
return
try:
t = float(tick)
except (TypeError, ValueError):
return
if t <= 0:
return
ex_k = (exchange_key or "").strip().lower()
sym = (symbol or "").strip().upper()
conn = _connect(db_path)
try:
conn.execute(
"""
INSERT INTO ohlcv_symbol_meta (exchange_key, symbol, price_tick, updated_at)
VALUES (?,?,?,?)
ON CONFLICT(exchange_key, symbol) DO UPDATE SET
price_tick=excluded.price_tick,
updated_at=excluded.updated_at
""",
(ex_k, sym, t, int(time.time())),
)
finally:
conn.close()
def load_symbol_price_tick(
exchange_key: str,
symbol: str,
db_path: Path | None = None,
) -> float | None:
ex_k = (exchange_key or "").strip().lower()
sym = (symbol or "").strip().upper()
conn = _connect(db_path)
try:
row = conn.execute(
"SELECT price_tick FROM ohlcv_symbol_meta WHERE exchange_key=? AND symbol=?",
(ex_k, sym),
).fetchone()
if not row or row["price_tick"] is None:
return None
return float(row["price_tick"])
except (TypeError, ValueError):
return None
finally:
conn.close()
def purge_timeframe_by_days(
timeframe: str,
days: int,
db_path: Path | None = None,
) -> int:
cutoff = int(time.time() * 1000) - max(1, int(days)) * 86400000
tf = normalize_chart_timeframe(timeframe)
conn = _connect(db_path)
try:
cur = conn.execute(
"DELETE FROM ohlcv_bars WHERE timeframe=? AND open_time_ms < ?",
(tf, cutoff),
)
return int(cur.rowcount or 0)
finally:
conn.close()
def purge_1m_bar_cap(db_path: Path | None = None, *, max_bars: int | None = None) -> int:
cap = max(100, int(max_bars or HUB_KLINE_1M_MAX_BARS))
conn = _connect(db_path)
try:
cur = conn.execute(
"""
DELETE FROM ohlcv_bars
WHERE timeframe='1m' AND rowid IN (
SELECT rowid FROM (
SELECT rowid,
ROW_NUMBER() OVER (
PARTITION BY exchange_key, symbol
ORDER BY open_time_ms DESC
) AS rn
FROM ohlcv_bars
WHERE timeframe='1m'
) WHERE rn > ?
)
""",
(cap,),
)
return int(cur.rowcount or 0)
finally:
conn.close()
def clear_series_bars(
exchange_key: str,
symbol: str,
timeframe: str | None = None,
db_path: Path | None = None,
) -> int:
"""删除某交易所+币种 K 线(可指定周期);用于清库后全量重拉。"""
init_db(db_path)
ex_k = (exchange_key or "").strip().lower()
sym = (symbol or "").strip().upper()
if not ex_k or not sym:
return 0
conn = _connect(db_path)
try:
if timeframe:
tf = normalize_chart_timeframe(timeframe)
cur = conn.execute(
"DELETE FROM ohlcv_bars WHERE exchange_key=? AND symbol=? AND timeframe=?",
(ex_k, sym, tf),
)
else:
cur = conn.execute(
"DELETE FROM ohlcv_bars WHERE exchange_key=? AND symbol=?",
(ex_k, sym),
)
return int(cur.rowcount or 0)
finally:
conn.close()
def clear_all_bars(db_path: Path | None = None) -> int:
"""清空 hub K 线库全部 OHLCV 行。"""
init_db(db_path)
conn = _connect(db_path)
try:
cur = conn.execute("DELETE FROM ohlcv_bars")
return int(cur.rowcount or 0)
finally:
conn.close()
def purge_retention(db_path: Path | None = None) -> int:
"""按周期策略清理:5m/15m/1h/2h/4h 一年;1m 保留最近 N 根;1d/1w 不删。"""
n = 0
for tf in sorted(YEAR_ROLLING_STORED):
n += purge_timeframe_by_days(tf, HUB_KLINE_5M_1H_RETENTION_DAYS, db_path)
n += purge_1m_bar_cap(db_path)
return n
def upsert_bars(
exchange_key: str,
symbol: str,
timeframe: str,
bars: list[dict[str, Any]],
db_path: Path | None = None,
) -> int:
if not bars:
return 0
ex_k = (exchange_key or "").strip().lower()
sym = (symbol or "").strip().upper()
tf = normalize_chart_timeframe(timeframe)
now = int(time.time())
conn = _connect(db_path)
n = 0
try:
for b in bars:
try:
oms = int(b["open_time_ms"])
conn.execute(
"""
INSERT INTO ohlcv_bars
(exchange_key, symbol, timeframe, open_time_ms, open, high, low, close, volume, updated_at)
VALUES (?,?,?,?,?,?,?,?,?,?)
ON CONFLICT(exchange_key, symbol, timeframe, open_time_ms) DO UPDATE SET
open=excluded.open,
high=excluded.high,
low=excluded.low,
close=excluded.close,
volume=excluded.volume,
updated_at=excluded.updated_at
""",
(
ex_k,
sym,
tf,
oms,
float(b["open"]),
float(b["high"]),
float(b["low"]),
float(b["close"]),
float(b.get("volume") or 0),
now,
),
)
n += 1
except (KeyError, TypeError, ValueError):
continue
finally:
conn.close()
return n
def load_bars_range(
exchange_key: str,
symbol: str,
timeframe: str,
start_ms: int,
end_ms: int,
db_path: Path | None = None,
) -> list[dict[str, Any]]:
ex_k = (exchange_key or "").strip().lower()
sym = (symbol or "").strip().upper()
tf = normalize_chart_timeframe(timeframe)
conn = _connect(db_path)
try:
rows = conn.execute(
"""
SELECT open_time_ms, open, high, low, close, volume
FROM ohlcv_bars
WHERE exchange_key=? AND symbol=? AND timeframe=?
AND open_time_ms >= ? AND open_time_ms <= ?
ORDER BY open_time_ms ASC
""",
(ex_k, sym, tf, int(start_ms), int(end_ms)),
).fetchall()
return _rows_to_bars(rows)
finally:
conn.close()
def count_series_bars(
exchange_key: str,
symbol: str,
timeframe: str,
db_path: Path | None = None,
) -> int:
ex_k = (exchange_key or "").strip().lower()
sym = (symbol or "").strip().upper()
tf = normalize_chart_timeframe(timeframe)
conn = _connect(db_path)
try:
row = conn.execute(
"""
SELECT COUNT(*) AS c FROM ohlcv_bars
WHERE exchange_key=? AND symbol=? AND timeframe=?
""",
(ex_k, sym, tf),
).fetchone()
return int(row["c"] or 0) if row else 0
finally:
conn.close()
def _remote_fetch_limit(
*,
need: int,
force_refresh: bool,
storage_tf: str,
tail_only: bool,
) -> int:
if tail_only:
return min(need + 20, 300)
cap = HUB_KLINE_REMOTE_FETCH_CAP
if force_refresh:
return min(seed_bar_target(storage_tf), cap)
return min(max(need + 20, 1), cap)
def _since_ms_for_span(
*,
now_ms: int,
period_ms: int,
span_bars: int,
cutoff_ms: int,
) -> int:
"""拉取窗口起点:跨度必须与 fetch_limit 一致,保证数据能铺到最近。"""
span = max(1, int(span_bars))
return max(int(cutoff_ms), int(now_ms) - int(period_ms) * span)
def load_bars_latest(
exchange_key: str,
symbol: str,
timeframe: str,
limit: int,
db_path: Path | None = None,
) -> list[dict[str, Any]]:
ex_k = (exchange_key or "").strip().lower()
sym = (symbol or "").strip().upper()
tf = normalize_chart_timeframe(timeframe)
lim = max(1, int(limit))
conn = _connect(db_path)
try:
rows = conn.execute(
"""
SELECT open_time_ms, open, high, low, close, volume
FROM ohlcv_bars
WHERE exchange_key=? AND symbol=? AND timeframe=?
ORDER BY open_time_ms DESC
LIMIT ?
""",
(ex_k, sym, tf, lim),
).fetchall()
return list(reversed(_rows_to_bars(rows)))
finally:
conn.close()
def load_bars_before(
exchange_key: str,
symbol: str,
timeframe: str,
before_ms: int,
limit: int,
db_path: Path | None = None,
) -> list[dict[str, Any]]:
ex_k = (exchange_key or "").strip().lower()
sym = (symbol or "").strip().upper()
tf = normalize_chart_timeframe(timeframe)
lim = max(1, int(limit))
bms = int(before_ms)
conn = _connect(db_path)
try:
rows = conn.execute(
"""
SELECT open_time_ms, open, high, low, close, volume
FROM ohlcv_bars
WHERE exchange_key=? AND symbol=? AND timeframe=?
AND open_time_ms < ?
ORDER BY open_time_ms DESC
LIMIT ?
""",
(ex_k, sym, tf, bms, lim),
).fetchall()
return list(reversed(_rows_to_bars(rows)))
finally:
conn.close()
def trim_contiguous_tail(
bars: list[dict[str, Any]],
period_ms: int,
*,
max_gap_factor: float = 3.0,
) -> tuple[list[dict[str, Any]], int]:
"""只保留最近一段连续 K 线,丢弃左侧与主段断开的孤立数据。"""
if len(bars) <= 1:
return list(bars), 0
try:
period = max(1, int(period_ms))
except (TypeError, ValueError):
period = 60_000
max_gap = int(period * max_gap_factor)
split = 0
for i in range(len(bars) - 1, 0, -1):
gap = int(bars[i]["open_time_ms"]) - int(bars[i - 1]["open_time_ms"])
if gap > max_gap:
split = i
break
return bars[split:], split
def normalize_contiguous_db_rows(
bars: list[dict[str, Any]],
*,
period_ms: int,
exchange_key: str,
symbol: str,
timeframe: str,
db_path: Path | None = None,
purge_orphans: bool = True,
) -> list[dict[str, Any]]:
"""去掉与主段断开的孤立前缀;可选同步清理库内孤立数据。"""
if len(bars) <= 1:
return list(bars)
trimmed, split_at = trim_contiguous_tail(bars, period_ms)
if split_at > 0 and purge_orphans:
purge_bars_open_before(
exchange_key,
symbol,
timeframe,
int(trimmed[0]["open_time_ms"]),
db_path,
)
return trimmed
def purge_bars_open_before(
exchange_key: str,
symbol: str,
timeframe: str,
open_time_ms: int,
db_path: Path | None = None,
) -> int:
"""删除某品种周期下早于 open_time_ms 的 K 线(清理与主段断开的孤立历史)。"""
ex_k = (exchange_key or "").strip().lower()
sym = (symbol or "").strip().upper()
tf = normalize_chart_timeframe(timeframe)
conn = _connect(db_path)
try:
cur = conn.execute(
"""
DELETE FROM ohlcv_bars
WHERE exchange_key=? AND symbol=? AND timeframe=? AND open_time_ms < ?
""",
(ex_k, sym, tf, int(open_time_ms)),
)
return int(cur.rowcount or 0)
finally:
conn.close()
def _rows_to_bars(rows) -> list[dict[str, Any]]:
return [
{
"open_time_ms": int(r["open_time_ms"]),
"open": float(r["open"]),
"high": float(r["high"]),
"low": float(r["low"]),
"close": float(r["close"]),
"volume": float(r["volume"] or 0),
}
for r in rows
]
def _to_chart_candles(bars: list[dict[str, Any]]) -> list[dict[str, Any]]:
out = []
for b in bars:
try:
out.append(
{
"time": int(b["open_time_ms"] // 1000),
"open": float(b["open"]),
"high": float(b["high"]),
"low": float(b["low"]),
"close": float(b["close"]),
"volume": float(b.get("volume") or 0),
}
)
except (KeyError, TypeError, ValueError):
continue
return out
def _trim_display_bars(
bars: list[dict[str, Any]],
*,
need: int,
before_ms: int | None,
) -> list[dict[str, Any]]:
if not bars:
return []
if before_ms is not None and int(before_ms) > 0:
bms = int(before_ms)
bars = [b for b in bars if int(b["open_time_ms"]) < bms]
if len(bars) > need:
bars = bars[-need:]
return bars
if len(bars) > need:
bars = bars[-need:]
return bars
def resolve_chart_bars(
exchange_key: str,
symbol: str,
timeframe: str,
remote_fetch: Callable[..., dict[str, Any]],
*,
db_path: Path | None = None,
force_refresh: bool = False,
tail_refresh: bool = False,
clear_db: bool = False,
limit: int | None = None,
before_ms: int | None = None,
) -> dict[str, Any]:
"""
分页读库:首屏 / 左拖 before_ms / 尾部 tail_refresh。
各展示周期均直读交易所同步入库的同名 K 线。
"""
init_db(db_path)
purged = purge_retention(db_path)
cleared = 0
sym = (symbol or "").strip().upper()
ex_k = (exchange_key or "").strip().lower()
display_tf = normalize_chart_timeframe(timeframe)
if not sym or not ex_k:
return {"ok": False, "msg": "缺少 exchange 或 symbol"}
storage_tf = display_tf
is_history = before_ms is not None and int(before_ms) > 0
need = int(
limit
or (chart_chunk_limit(display_tf) if is_history else chart_initial_limit(display_tf))
)
need = max(1, min(need, chart_memory_cap(display_tf)))
now_ms = int(time.time() * 1000)
period_display = TIMEFRAME_MS[display_tf]
period_storage = TIMEFRAME_MS[storage_tf]
series_bar_count = (
count_series_bars(ex_k, sym, storage_tf, db_path) if not is_history else 0
)
if tail_refresh and not is_history:
min_seed = min(chart_initial_limit(display_tf) // 5, HUB_KLINE_MIN_BARS_BEFORE_TAIL)
if series_bar_count < max(1, min_seed):
tail_refresh = False
else:
need = min(need, 30)
cutoff = history_cutoff_ms_for_storage(storage_tf, now_ms)
if clear_db and not is_history and not tail_refresh:
cleared = clear_series_bars(ex_k, sym, storage_tf, db_path)
def load_display_rows() -> list[dict[str, Any]]:
if is_history:
rows = load_bars_before(ex_k, sym, storage_tf, int(before_ms), need, db_path)
return _trim_display_bars(rows, need=need, before_ms=int(before_ms))
return load_bars_latest(ex_k, sym, storage_tf, need, db_path)
db_rows: list[dict[str, Any]] = []
if not force_refresh:
db_rows = load_display_rows()
if not is_history and db_rows:
db_rows = normalize_contiguous_db_rows(
db_rows,
period_ms=period_display,
exchange_key=ex_k,
symbol=sym,
timeframe=storage_tf,
db_path=db_path,
)
last_closed = last_closed_bar_open_ms(display_tf, now_ms)
newest_db = db_rows[-1]["open_time_ms"] if db_rows else None
if is_history:
newest_ok = True
else:
newest_ok = newest_db is not None and int(newest_db) >= int(last_closed) - period_display
need_fetch = force_refresh or (
not is_history and (len(db_rows) < need or not newest_ok)
)
if is_history and len(db_rows) < need:
need_fetch = True
tail_only = False
if tail_refresh and not is_history and db_rows and not force_refresh and not need_fetch:
need_fetch = True
tail_only = True
fetched = 0
price_tick: Optional[float] = None
remote_err: Optional[str] = None
if need_fetch:
if is_history:
bms = int(before_ms)
anchor = bms - period_display
since = max(cutoff, anchor - period_storage * need)
fetch_limit = min(need + 20, 1500)
elif tail_only:
anchor_ms = int(newest_db) if newest_db is not None else now_ms
fetch_limit = _remote_fetch_limit(
need=need, force_refresh=False, storage_tf=storage_tf, tail_only=True
)
since = _since_ms_for_span(
now_ms=anchor_ms,
period_ms=period_storage,
span_bars=5,
cutoff_ms=cutoff,
)
else:
fetch_limit = _remote_fetch_limit(
need=need,
force_refresh=force_refresh,
storage_tf=storage_tf,
tail_only=False,
)
since = _since_ms_for_span(
now_ms=now_ms,
period_ms=period_storage,
span_bars=fetch_limit,
cutoff_ms=cutoff,
)
remote = remote_fetch(
symbol=sym,
timeframe=storage_tf,
since_ms=since,
limit=fetch_limit,
)
if remote.get("ok") and remote.get("bars"):
fetched = upsert_bars(ex_k, sym, storage_tf, remote["bars"], db_path)
price_tick = remote.get("price_tick")
if price_tick is not None:
save_symbol_price_tick(ex_k, sym, price_tick, db_path)
db_rows = load_display_rows()
if not is_history and db_rows:
db_rows = normalize_contiguous_db_rows(
db_rows,
period_ms=period_display,
exchange_key=ex_k,
symbol=sym,
timeframe=storage_tf,
db_path=db_path,
)
if not is_history and not tail_only and db_rows:
newest_ms = int(db_rows[-1]["open_time_ms"])
if newest_ms < int(last_closed) - period_display:
gap_limit = min(
500,
int((now_ms - newest_ms) // period_storage) + 10,
)
if gap_limit > 1:
gap_remote = remote_fetch(
symbol=sym,
timeframe=storage_tf,
since_ms=newest_ms,
limit=gap_limit,
)
if gap_remote.get("ok") and gap_remote.get("bars"):
fetched += upsert_bars(
ex_k, sym, storage_tf, gap_remote["bars"], db_path
)
db_rows = load_display_rows()
db_rows = normalize_contiguous_db_rows(
db_rows,
period_ms=period_display,
exchange_key=ex_k,
symbol=sym,
timeframe=storage_tf,
db_path=db_path,
)
else:
remote_err = remote.get("msg") or remote.get("error") or "实例拉取 K 线失败"
if not db_rows:
if is_history:
exhausted = True
else:
return {"ok": False, "msg": remote_err, "purged": purged}
exhausted = False
if is_history:
if not db_rows:
exhausted = True
elif len(db_rows) < need:
oldest = int(db_rows[0]["open_time_ms"])
if cutoff > 0 and oldest <= cutoff + period_storage:
exhausted = True
elif fetched == 0:
exhausted = True
if price_tick is None:
price_tick = load_symbol_price_tick(ex_k, sym, db_path)
if price_tick is None and not is_history:
try:
tick_probe = remote_fetch(
symbol=sym,
timeframe=storage_tf,
since_ms=None,
limit=3,
)
if tick_probe.get("ok"):
price_tick = tick_probe.get("price_tick")
if price_tick is not None:
save_symbol_price_tick(ex_k, sym, price_tick, db_path)
except Exception:
pass
if not is_history and db_rows:
db_rows = normalize_contiguous_db_rows(
db_rows,
period_ms=period_display,
exchange_key=ex_k,
symbol=sym,
timeframe=storage_tf,
db_path=db_path,
)
if not is_history and len(db_rows) < need:
missing = need - len(db_rows)
backfill_limit = min(missing + 60, HUB_KLINE_REMOTE_FETCH_CAP)
if db_rows:
oldest = int(db_rows[0]["open_time_ms"])
backfill_since = _since_ms_for_span(
now_ms=oldest,
period_ms=period_storage,
span_bars=backfill_limit,
cutoff_ms=cutoff,
)
else:
backfill_since = _since_ms_for_span(
now_ms=now_ms,
period_ms=period_storage,
span_bars=backfill_limit,
cutoff_ms=cutoff,
)
try:
remote_back = remote_fetch(
symbol=sym,
timeframe=storage_tf,
since_ms=backfill_since,
limit=backfill_limit,
)
if remote_back.get("ok") and remote_back.get("bars"):
fetched += upsert_bars(ex_k, sym, storage_tf, remote_back["bars"], db_path)
if remote_back.get("price_tick") is not None:
price_tick = remote_back.get("price_tick")
save_symbol_price_tick(ex_k, sym, price_tick, db_path)
db_rows = load_display_rows()
db_rows = normalize_contiguous_db_rows(
db_rows,
period_ms=period_display,
exchange_key=ex_k,
symbol=sym,
timeframe=storage_tf,
db_path=db_path,
)
elif not remote_err:
remote_err = (
remote_back.get("msg")
or remote_back.get("error")
or "实例补拉 K 线失败"
)
except Exception as e:
if not remote_err:
remote_err = str(e)
price_tick = normalize_price_tick(price_tick)
if db_rows and price_tick is not None:
round_ohlcv_bars_to_tick(db_rows, price_tick)
candles = _to_chart_candles(db_rows)
if not is_history and not candles and not exhausted:
return {"ok": False, "msg": remote_err or "无 K 线数据", "purged": purged}
oldest_ms = int(db_rows[0]["open_time_ms"]) if db_rows else None
newest_ms = int(db_rows[-1]["open_time_ms"]) if db_rows else None
from_cache = max(0, len(candles) - min(fetched, len(candles))) if fetched else len(candles)
return {
"ok": True,
"symbol": sym,
"exchange_key": ex_k,
"timeframe": display_tf,
"storage_timeframe": storage_tf,
"limit": need,
"before_ms": int(before_ms) if is_history else None,
"oldest_ms": oldest_ms,
"newest_ms": newest_ms,
"exhausted": exhausted,
"source": "remote" if fetched else "db",
"retention_policy": retention_policy_meta(),
"candles": candles,
"from_cache": from_cache,
"fetched": fetched,
"cleared": cleared,
"purged": purged,
"price_tick": price_tick,
"stale": bool(remote_err),
"stale_message": remote_err if remote_err else None,
"updated_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
}
def format_ohlcv_detail(bar: dict[str, Any] | None, tick: Optional[float]) -> dict[str, str]:
if not bar:
return {"open": "-", "high": "-", "low": "-", "close": "-", "volume": "-"}
return {
"open": format_price_by_tick(bar.get("open"), tick),
"high": format_price_by_tick(bar.get("high"), tick),
"low": format_price_by_tick(bar.get("low"), tick),
"close": format_price_by_tick(bar.get("close"), tick),
"volume": format_price_by_tick(bar.get("volume"), tick),
}
+311
View File
@@ -0,0 +1,311 @@
"""中控宏观关键数据日历:手动录入 FOMC / CPI / 非农档发布时间,±1h 风控前置窗口。"""
from __future__ import annotations
import os
import sqlite3
import time
from datetime import datetime
from pathlib import Path
from typing import Any
from zoneinfo import ZoneInfo
from lib.hub.hub_symbol_archive_lib import parse_wall_clock_ms
DISPLAY_TZ = ZoneInfo(os.getenv("APP_TIMEZONE", "Asia/Shanghai"))
MACRO_EVENT_TYPES = ("fomc", "cpi", "employment")
MACRO_EVENT_LABELS: dict[str, str] = {
"fomc": "FOMC 联邦基金利率",
"cpi": "美国 CPI 通胀",
"employment": "就业与劳工数据",
}
WINDOW_BEFORE_MS = int(os.getenv("HUB_MACRO_WINDOW_BEFORE_SEC", str(3600))) * 1000
WINDOW_AFTER_MS = int(os.getenv("HUB_MACRO_WINDOW_AFTER_SEC", str(3600))) * 1000
IMMINENT_BEFORE_MS = int(os.getenv("HUB_MACRO_IMMINENT_BEFORE_SEC", str(1800))) * 1000
LIST_FUTURE_DAYS = int(os.getenv("HUB_MACRO_LIST_FUTURE_DAYS", "60"))
def default_db_path() -> Path:
raw = (os.getenv("HUB_MACRO_CALENDAR_DB_PATH") or "").strip()
if raw:
return Path(raw)
hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data"
hub_dir.mkdir(parents=True, exist_ok=True)
return hub_dir / "hub_macro_calendar.db"
def _connect(db_path: Path | None = None) -> sqlite3.Connection:
path = db_path or default_db_path()
path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(path), timeout=30, isolation_level=None)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
return conn
def init_db(db_path: Path | None = None) -> None:
conn = _connect(db_path)
try:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS macro_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
event_type TEXT NOT NULL,
event_at_ms INTEGER NOT NULL,
note TEXT NOT NULL DEFAULT '',
created_at_ms INTEGER NOT NULL,
updated_at_ms INTEGER NOT NULL
)
"""
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_macro_events_at ON macro_events(event_at_ms)"
)
finally:
conn.close()
def normalize_event_type(raw: str) -> str:
key = (raw or "").strip().lower()
if key not in MACRO_EVENT_TYPES:
raise ValueError(f"事件类型须为: {', '.join(MACRO_EVENT_LABELS.values())}")
return key
def parse_event_at_ms(raw: Any) -> int:
ms = parse_wall_clock_ms(raw, tz=DISPLAY_TZ)
if ms is None:
raise ValueError("发布时间格式错误,请使用 YYYY-MM-DD HH:MM 或 YYYY-MM-DDTHH:MM")
return int(ms)
def format_event_at(ms: int) -> str:
dt = datetime.fromtimestamp(ms / 1000, tz=DISPLAY_TZ)
return dt.strftime("%Y-%m-%d %H:%M")
def _row_to_dict(row: sqlite3.Row) -> dict[str, Any]:
ms = int(row["event_at_ms"])
et = str(row["event_type"])
return {
"id": int(row["id"]),
"event_type": et,
"event_type_label": MACRO_EVENT_LABELS.get(et, et),
"event_at_ms": ms,
"event_at": format_event_at(ms),
"note": str(row["note"] or ""),
"created_at_ms": int(row["created_at_ms"]),
"updated_at_ms": int(row["updated_at_ms"]),
}
def _window_bounds(event_at_ms: int) -> tuple[int, int]:
start = int(event_at_ms) - WINDOW_BEFORE_MS
end = int(event_at_ms) + WINDOW_AFTER_MS
return start, end
def enrich_alert(row: dict[str, Any], now_ms: int | None = None) -> dict[str, Any] | None:
now = int(now_ms if now_ms is not None else time.time() * 1000)
event_at_ms = int(row["event_at_ms"])
window_start, window_end = _window_bounds(event_at_ms)
if now < window_start or now > window_end:
return None
imminent = now >= (event_at_ms - IMMINENT_BEFORE_MS) and now <= window_end
mins_to_event = max(0, int((event_at_ms - now) / 60000))
mins_from_event = max(0, int((now - event_at_ms) / 60000))
return {
**row,
"window_start_ms": window_start,
"window_end_ms": window_end,
"window_start": format_event_at(window_start),
"window_end": format_event_at(window_end),
"phase": "imminent" if imminent else "window",
"phase_label": "即将发布" if imminent and now < event_at_ms else "高波动窗口",
"minutes_to_event": mins_to_event if now < event_at_ms else 0,
"minutes_from_event": mins_from_event if now >= event_at_ms else 0,
}
def list_events(
*,
now_ms: int | None = None,
include_expired_hours: int = 24,
db_path: Path | None = None,
) -> list[dict[str, Any]]:
init_db(db_path)
now = int(now_ms if now_ms is not None else time.time() * 1000)
horizon = now + LIST_FUTURE_DAYS * 86400 * 1000
expired_cutoff = now - max(0, int(include_expired_hours)) * 3600 * 1000 - WINDOW_AFTER_MS
conn = _connect(db_path)
try:
rows = conn.execute(
"""
SELECT * FROM macro_events
WHERE event_at_ms >= ? AND event_at_ms <= ?
ORDER BY event_at_ms ASC, id ASC
""",
(expired_cutoff, horizon),
).fetchall()
return [_row_to_dict(r) for r in rows]
finally:
conn.close()
def get_event(event_id: int, db_path: Path | None = None) -> dict[str, Any] | None:
init_db(db_path)
conn = _connect(db_path)
try:
row = conn.execute("SELECT * FROM macro_events WHERE id=?", (int(event_id),)).fetchone()
return _row_to_dict(row) if row else None
finally:
conn.close()
def _assert_no_duplicate(
conn: sqlite3.Connection,
event_type: str,
event_at_ms: int,
*,
exclude_id: int | None = None,
) -> None:
if exclude_id is None:
row = conn.execute(
"SELECT id FROM macro_events WHERE event_type=? AND event_at_ms=? LIMIT 1",
(event_type, int(event_at_ms)),
).fetchone()
else:
row = conn.execute(
"""
SELECT id FROM macro_events
WHERE event_type=? AND event_at_ms=? AND id<>?
LIMIT 1
""",
(event_type, int(event_at_ms), int(exclude_id)),
).fetchone()
if row:
raise ValueError("同类型、同发布时间的记录已存在")
def create_event(
event_type: str,
event_at: Any,
*,
note: str = "",
db_path: Path | None = None,
) -> dict[str, Any]:
init_db(db_path)
et = normalize_event_type(event_type)
event_at_ms = parse_event_at_ms(event_at)
note_s = str(note or "").strip()[:500]
now_ms = int(time.time() * 1000)
conn = _connect(db_path)
try:
_assert_no_duplicate(conn, et, event_at_ms)
cur = conn.execute(
"""
INSERT INTO macro_events (event_type, event_at_ms, note, created_at_ms, updated_at_ms)
VALUES (?, ?, ?, ?, ?)
""",
(et, event_at_ms, note_s, now_ms, now_ms),
)
eid = int(cur.lastrowid)
finally:
conn.close()
row = get_event(eid, db_path=db_path)
assert row is not None
return row
def update_event(
event_id: int,
*,
event_type: str | None = None,
event_at: Any | None = None,
note: str | None = None,
db_path: Path | None = None,
) -> dict[str, Any] | None:
init_db(db_path)
existing = get_event(event_id, db_path=db_path)
if not existing:
return None
et = normalize_event_type(event_type if event_type is not None else existing["event_type"])
event_at_ms = (
parse_event_at_ms(event_at) if event_at is not None else int(existing["event_at_ms"])
)
note_s = existing["note"] if note is None else str(note or "").strip()[:500]
now_ms = int(time.time() * 1000)
conn = _connect(db_path)
try:
_assert_no_duplicate(conn, et, event_at_ms, exclude_id=int(event_id))
conn.execute(
"""
UPDATE macro_events
SET event_type=?, event_at_ms=?, note=?, updated_at_ms=?
WHERE id=?
""",
(et, event_at_ms, note_s, now_ms, int(event_id)),
)
finally:
conn.close()
return get_event(event_id, db_path=db_path)
def delete_event(event_id: int, db_path: Path | None = None) -> bool:
init_db(db_path)
conn = _connect(db_path)
try:
cur = conn.execute("DELETE FROM macro_events WHERE id=?", (int(event_id),))
return cur.rowcount > 0
finally:
conn.close()
def list_active_alerts(
now_ms: int | None = None,
db_path: Path | None = None,
) -> list[dict[str, Any]]:
now = int(now_ms if now_ms is not None else time.time() * 1000)
lookback = now - WINDOW_BEFORE_MS - IMMINENT_BEFORE_MS
lookahead = now + WINDOW_AFTER_MS
init_db(db_path)
conn = _connect(db_path)
try:
rows = conn.execute(
"""
SELECT * FROM macro_events
WHERE event_at_ms >= ? AND event_at_ms <= ?
ORDER BY event_at_ms ASC, id ASC
""",
(lookback, lookahead),
).fetchall()
finally:
conn.close()
alerts: list[dict[str, Any]] = []
for row in rows:
item = enrich_alert(_row_to_dict(row), now_ms=now)
if item:
alerts.append(item)
return alerts
def build_banner_message(alert: dict[str, Any], *, has_positions: bool) -> str:
label = alert.get("event_type_label") or alert.get("event_type") or "宏观数据"
phase = alert.get("phase") or "window"
if has_positions:
if phase == "imminent" and int(alert.get("minutes_to_event") or 0) > 0:
return (
f"{label}」即将发布(约 {alert['minutes_to_event']} 分钟),"
"注意仓位风险:勿加仓,检查止损/减仓"
)
return f"{label}」高波动窗口(±1h),注意仓位风险:勿加仓,检查止损/减仓"
if phase == "imminent" and int(alert.get("minutes_to_event") or 0) > 0:
return (
f"{label}」即将发布(约 {alert['minutes_to_event']} 分钟),"
"建议等待,避免新开仓"
)
return f"{label}」高波动窗口(±1h),建议等待,避免新开仓"
+81
View File
@@ -0,0 +1,81 @@
"""实例 USDT 永续合约信息(与实盘 ccxt 精度一致)。"""
from __future__ import annotations
from typing import Any, Callable, Optional, Tuple
from lib.hub.hub_calculator_market_lib import (
amount_decimals_from_exchange,
normalize_base_symbol,
price_decimals_from_exchange,
resolve_usdt_perp_symbol,
)
from lib.hub.hub_ohlcv_lib import normalize_price_tick, price_tick_from_market
def fetch_usdt_swap_market_info(
*,
base_or_symbol: str,
normalize_symbol_input: Callable[[str], str],
normalize_exchange_symbol: Callable[[str], str],
ensure_markets_loaded: Callable[[], None],
exchange: Any,
exchange_id: str = "",
) -> dict[str, Any]:
"""供各实例 /api/hub/market 调用。"""
raw = str(base_or_symbol or "").strip()
if not raw:
return {"ok": False, "msg": "请输入币种,如 ETH"}
try:
ensure_markets_loaded()
except Exception as exc:
return {"ok": False, "msg": f"加载市场失败: {exc}"}
base_u = normalize_base_symbol(raw)
hub_sym = normalize_symbol_input(raw if base_u else raw)
try:
ex_sym = normalize_exchange_symbol(hub_sym)
except Exception:
ex_sym = hub_sym
sym, err = resolve_usdt_perp_symbol(exchange, base_u or hub_sym)
if err and ex_sym:
markets = getattr(exchange, "markets", None) or {}
if ex_sym in markets:
sym = ex_sym
err = None
if err or not sym:
return {"ok": False, "msg": err or f"未找到 {base_u or raw}/USDT 永续合约"}
market = exchange.market(sym)
try:
contract_size = float(market.get("contractSize") or 1.0)
except (TypeError, ValueError):
contract_size = 1.0
if contract_size <= 0:
contract_size = 1.0
price_tick = normalize_price_tick(price_tick_from_market(exchange, sym))
amt_dec = amount_decimals_from_exchange(exchange, sym)
px_dec = price_decimals_from_exchange(exchange, sym, price_tick)
min_amount = None
try:
min_amount = float((market.get("limits") or {}).get("amount", {}).get("min"))
except (TypeError, ValueError):
min_amount = None
base_out = (market.get("base") or base_u or "").upper() or base_u
return {
"ok": True,
"exchange": (exchange_id or "").strip().lower(),
"base": base_out,
"exchange_symbol": sym,
"display_symbol": f"{base_out}/USDT" if base_out else sym,
"contract_size": contract_size,
"price_tick": price_tick,
"price_decimals": px_dec,
"amount_decimals": amt_dec,
"min_amount": min_amount,
}
+692
View File
@@ -0,0 +1,692 @@
"""中控行情区:各实例 ccxt OHLCV 拉取(hub_bridge /api/hub/ohlcv 共用)。"""
from __future__ import annotations
import math
import os
import time
from typing import Any, Callable, Optional
CHART_TIMEFRAMES = frozenset(
{
"1m",
"5m",
"15m",
"1h",
"2h",
"4h",
"1d",
"1w",
}
)
CHART_TIMEFRAME_ORDER = (
"1m",
"5m",
"15m",
"1h",
"2h",
"4h",
"1d",
"1w",
)
DAILY_PLUS_TIMEFRAMES = frozenset({"1d", "1w"})
# 入库 / 同步真源(各周期直拉交易所,不做本地聚合)
STORED_TIMEFRAMES = frozenset(CHART_TIMEFRAMES)
PERMANENT_STORED_TIMEFRAMES = frozenset({"1d", "1w"})
YEAR_ROLLING_STORED = frozenset({"5m", "15m", "1h", "2h", "4h"})
# 行情区不做展示周期聚合;保留空映射供兼容读取
CHART_DISPLAY_AGGREGATE_FROM: dict[str, str] = {}
SMALL_DISPLAY_TFS = frozenset({"1m", "5m", "15m"})
MID_DISPLAY_TFS = frozenset({"1h", "2h", "4h"})
HUB_KLINE_1M_MAX_BARS = max(1000, int(os.getenv("HUB_KLINE_1M_MAX_BARS", "10000")))
HUB_KLINE_5M_1H_RETENTION_DAYS = max(30, int(os.getenv("HUB_KLINE_5M_1H_RETENTION_DAYS", "365")))
HUB_KLINE_SEED_BARS = max(100, int(os.getenv("HUB_KLINE_SEED_BARS", "500")))
# 交易所无原生周期时的远程拉取 fallback(行情区当前无映射)
OHLCV_AGGREGATE_FROM: dict[str, str] = {}
TIMEFRAME_MS: dict[str, int] = {
"1m": 60_000,
"5m": 5 * 60_000,
"15m": 15 * 60_000,
"1h": 60 * 60_000,
"2h": 2 * 60 * 60_000,
"4h": 4 * 60 * 60_000,
"12h": 12 * 60 * 60_000,
"1d": 24 * 60 * 60_000,
"1w": 7 * 24 * 60 * 60_000,
}
def normalize_chart_timeframe(raw: str | None, default: str = "5m") -> str:
tf = (raw or default).strip().lower()
return tf if tf in CHART_TIMEFRAMES else default
def normalize_perpetual_symbol(symbol: str) -> str:
"""BTC/USDT → BTC/USDT:USDT(与四所 ccxt swap 行情一致)。"""
sym = (symbol or "").strip().upper()
if not sym:
return ""
if ":" in sym:
return sym
if "/" in sym:
base, quote = sym.split("/", 1)
quote_clean = quote.split(":")[0]
return f"{base}/{quote_clean}:{quote_clean}"
return sym
def sync_timeframe_for_display(timeframe: str) -> str:
"""展示周期对应的入库 / 同步周期。"""
tf = normalize_chart_timeframe(timeframe)
return CHART_DISPLAY_AGGREGATE_FROM.get(tf, tf)
def aggregation_source_for_display(timeframe: str) -> str | None:
tf = normalize_chart_timeframe(timeframe)
return CHART_DISPLAY_AGGREGATE_FROM.get(tf)
def aggregate_ratio(display_tf: str, source_tf: str) -> int:
d = normalize_chart_timeframe(display_tf)
s = normalize_chart_timeframe(source_tf)
return max(1, int(TIMEFRAME_MS[d] // TIMEFRAME_MS[s]))
def chart_initial_limit(timeframe: str) -> int:
tf = normalize_chart_timeframe(timeframe)
if tf in SMALL_DISPLAY_TFS:
return 2000
if tf in MID_DISPLAY_TFS:
return 1000
if tf in DAILY_PLUS_TIMEFRAMES:
return 500
return 500
def chart_chunk_limit(timeframe: str) -> int:
tf = normalize_chart_timeframe(timeframe)
if tf in SMALL_DISPLAY_TFS:
return 500
if tf == "1w":
return 150
if tf in MID_DISPLAY_TFS:
return 300
return 200
def chart_memory_cap(timeframe: str) -> int:
tf = normalize_chart_timeframe(timeframe)
if tf in SMALL_DISPLAY_TFS:
return 5000
if tf == "1w":
return 500
return 1000
def bar_limit_for_timeframe(timeframe: str) -> int:
return chart_memory_cap(timeframe)
def storage_retention_days(storage_tf: str) -> int | None:
"""None 表示不按天截断(1m 按根数;1d/1w 永久)。"""
tf = normalize_chart_timeframe(storage_tf)
if tf in YEAR_ROLLING_STORED:
return HUB_KLINE_5M_1H_RETENTION_DAYS
return None
def history_cutoff_ms_for_storage(storage_tf: str, now_ms: int | None = None) -> int:
days = storage_retention_days(storage_tf)
if days is None:
return 0
now = int(now_ms if now_ms is not None else time.time() * 1000)
return max(0, now - int(days) * 86400000)
def seed_bar_target(storage_tf: str) -> int:
tf = normalize_chart_timeframe(storage_tf)
if tf == "1m":
return HUB_KLINE_1M_MAX_BARS
if tf in YEAR_ROLLING_STORED:
period = TIMEFRAME_MS[tf]
return min(
int(86400000 * HUB_KLINE_5M_1H_RETENTION_DAYS / period) + 20,
150000,
)
return HUB_KLINE_SEED_BARS
def retention_policy_meta() -> dict[str, Any]:
year = {"mode": "days", "days": HUB_KLINE_5M_1H_RETENTION_DAYS}
return {
"1m": {"mode": "bars", "max_bars": HUB_KLINE_1M_MAX_BARS},
"5m": dict(year),
"15m": dict(year),
"1h": dict(year),
"2h": dict(year),
"4h": dict(year),
"1d": {"mode": "permanent"},
"1w": {"mode": "permanent"},
"aggregate_from": {},
}
def last_closed_bar_open_ms(timeframe: str, now_ms: int | None = None) -> int:
"""上一根已收盘 K 的 open_time(毫秒 UTC)。"""
tf = normalize_chart_timeframe(timeframe)
period = TIMEFRAME_MS[tf]
now = int(now_ms if now_ms is not None else time.time() * 1000)
current_open = (now // period) * period
return int(current_open - period)
def window_start_ms(timeframe: str, need: int, retention_days: int, now_ms: int | None = None) -> int:
"""本地库清理/读库窗口:不超过 retention_days。"""
now = int(now_ms if now_ms is not None else time.time() * 1000)
period = TIMEFRAME_MS[normalize_chart_timeframe(timeframe)]
retention_cutoff = now - max(1, int(retention_days)) * 86400000
want = now - max(1, int(need)) * period
return max(retention_cutoff, want)
def chart_fetch_start_ms(timeframe: str, need: int, now_ms: int | None = None) -> int:
"""行情展示拉取起点:按 need 根回看(日线 500 / 日内 1000),不受 DB 保留天数限制。"""
now = int(now_ms if now_ms is not None else time.time() * 1000)
period = TIMEFRAME_MS[normalize_chart_timeframe(timeframe)]
return max(0, now - max(1, int(need)) * period)
def _positive_float(value: Any) -> Optional[float]:
if value in (None, ""):
return None
try:
v = float(value)
except (TypeError, ValueError):
return None
return v if v > 0 else None
def _price_tick_from_market_info(info: dict) -> Optional[float]:
"""从 market.info 解析 tick(含币安 PRICE_FILTER.filters)。"""
for key in ("tickSize", "tickSz", "price_increment", "order_price_round", "quote_increment"):
v = _positive_float(info.get(key))
if v is not None:
return v
for key in ("pricePrecision", "price_precision"):
raw = info.get(key)
if raw in (None, ""):
continue
try:
p = float(raw)
except (TypeError, ValueError):
continue
if p >= 1 and abs(p - round(p)) < 1e-9 and p <= 12:
return 10 ** (-int(p))
if 0 < p < 1:
return p
filters = info.get("filters")
if isinstance(filters, list):
for f in filters:
if not isinstance(f, dict):
continue
if str(f.get("filterType") or "").upper() != "PRICE_FILTER":
continue
v = _positive_float(f.get("tickSize"))
if v is not None:
return v
return None
def round_price_to_tick(value: Any, tick: Optional[float]) -> Optional[float]:
"""按交易所 tick 对齐价格(K 线/标记线与坐标轴一致)。"""
t = normalize_price_tick(tick)
if t is None:
return None
try:
v = float(value)
except (TypeError, ValueError):
return None
n = round(v / t) * t
d = _decimals_from_tick(t)
return float(f"{n:.{d}f}")
def round_ohlcv_bars_to_tick(bars: list[dict[str, Any]], tick: Optional[float]) -> None:
t = normalize_price_tick(tick)
if t is None:
return
for b in bars:
for key in ("open", "high", "low", "close"):
if key in b:
rounded = round_price_to_tick(b.get(key), t)
if rounded is not None:
b[key] = rounded
def price_tick_from_market(exchange, exchange_symbol: str) -> Optional[float]:
"""最小价格变动单位(与交易所 tick / price_to_precision 一致)。"""
try:
if not getattr(exchange, "markets", None):
exchange.load_markets()
market = exchange.market(exchange_symbol)
except Exception:
return None
info = market.get("info") or {}
if isinstance(info, dict):
tick = _price_tick_from_market_info(info)
if tick is not None:
return tick
limits = market.get("limits") or {}
price_limits = limits.get("price") or {}
if price_limits.get("min") not in (None, ""):
try:
v = float(price_limits["min"])
if v > 0:
return v
except (TypeError, ValueError):
pass
try:
sample = exchange.price_to_precision(exchange_symbol, 12345.678901234)
s = str(sample).strip()
if "." in s:
frac = s.split(".", 1)[1]
if frac:
return 10 ** (-len(frac))
return 1.0
except Exception:
pass
prec = (market.get("precision") or {}).get("price")
if prec is not None:
try:
p = float(prec)
if p >= 1 and abs(p - round(p)) < 1e-9 and p <= 12:
return 10 ** (-int(p))
if 0 < p < 1:
return p
except (TypeError, ValueError):
pass
return None
def normalize_price_tick(tick: Optional[float]) -> Optional[float]:
"""将 tick 对齐为 10^-n,避免浮点噪声导致前端 lightweight-charts unexpected base。"""
if tick is None:
return None
try:
t = float(tick)
except (TypeError, ValueError):
return None
if t <= 0:
return None
if t >= 1:
return t
try:
exp = int(round(-math.log10(t)))
except (ValueError, OverflowError):
return None
exp = max(0, min(12, exp))
return 10 ** (-exp)
def _decimals_from_tick(tick: float) -> int:
if tick >= 1:
return 0
s = f"{tick:.12f}".rstrip("0")
if "." in s:
frac = s.split(".", 1)[1]
if frac:
return min(12, len(frac))
return max(0, min(12, int(round(-math.log10(tick)))))
def format_price_by_tick(value: Any, tick: Optional[float]) -> str:
if value in (None, ""):
return "-"
try:
v = float(value)
except (TypeError, ValueError):
return str(value)
if v == 0:
return "0"
if tick and tick > 0:
return f"{v:.{_decimals_from_tick(float(tick))}f}"
av = abs(v)
if av >= 10000:
d = 2
elif av >= 100:
d = 3
elif av >= 1:
d = 4
elif av >= 0.01:
d = 6
else:
d = 8
text = f"{v:.{d}f}"
return text.rstrip("0").rstrip(".") if "." in text else text
def exchange_supports_timeframe(exchange, timeframe: str) -> bool:
tf = normalize_chart_timeframe(timeframe)
tfs = getattr(exchange, "timeframes", None) or {}
if not tfs:
return True
return tf in tfs
def _median_bar_step_ms(bars: list[dict[str, Any]]) -> Optional[int]:
if len(bars) < 2:
return None
steps: list[int] = []
for i in range(1, min(len(bars), 64)):
step = int(bars[i]["open_time_ms"]) - int(bars[i - 1]["open_time_ms"])
if step > 0:
steps.append(step)
if not steps:
return None
steps.sort()
return steps[len(steps) // 2]
def bars_spacing_matches_timeframe(
bars: list[dict[str, Any]], timeframe: str, *, tolerance: float = 0.08
) -> bool:
if len(bars) < 2:
return True
period = TIMEFRAME_MS[normalize_chart_timeframe(timeframe)]
step = _median_bar_step_ms(bars)
if step is None:
return False
return abs(step - period) <= period * tolerance
def align_bar_open_ms(open_time_ms: int, period_ms: int) -> int:
return (int(open_time_ms) // period_ms) * period_ms
def snap_to_bar_grid(ts_ms: int, origin_ms: int, step_ms: int) -> int:
step = max(1, int(step_ms))
origin = int(origin_ms)
if ts_ms <= origin:
return origin
idx = (int(ts_ms) - origin + step - 1) // step
return origin + idx * step
def fill_missing_ohlcv_bars(
bars: list[dict[str, Any]],
period_ms: int,
start_ms: int | None = None,
end_ms: int | None = None,
) -> list[dict[str, Any]]:
"""细周期缺口用上一根收盘价填平,保证聚合后 K 线时间轴连续。"""
by_ts: dict[int, dict[str, Any]] = {}
for b in bars or []:
try:
by_ts[int(b["open_time_ms"])] = b
except (KeyError, TypeError, ValueError):
continue
if not by_ts:
return []
keys = sorted(by_ts.keys())
step_ms = max(1, int(period_ms))
origin = keys[0]
aligned_start = snap_to_bar_grid(
int(start_ms if start_ms is not None else keys[0]), origin, step_ms
)
aligned_end = max(
int(end_ms if end_ms is not None else keys[-1]),
keys[-1],
)
out: list[dict[str, Any]] = []
last: dict[str, Any] | None = None
for ts_key in keys:
if ts_key <= aligned_start:
last = by_ts[ts_key]
ts = aligned_start
while ts <= aligned_end:
cur = by_ts.get(ts)
if cur is not None:
last = cur
out.append(cur)
elif last is not None:
c = float(last["close"])
out.append(
{
"open_time_ms": ts,
"open": c,
"high": c,
"low": c,
"close": c,
"volume": 0.0,
"filled": True,
}
)
ts += step_ms
return out
def aggregate_ohlcv_bars(
bars: list[dict[str, Any]], target_timeframe: str
) -> list[dict[str, Any]]:
"""将细周期 OHLCV 聚合为目标周期(UTC 对齐 bucket)。"""
tf = normalize_chart_timeframe(target_timeframe)
period = TIMEFRAME_MS[tf]
buckets: dict[int, dict[str, Any]] = {}
for b in bars or []:
try:
key = align_bar_open_ms(int(b["open_time_ms"]), period)
o = float(b["open"])
h = float(b["high"])
l = float(b["low"])
c = float(b["close"])
v = float(b.get("volume") or 0)
except (KeyError, TypeError, ValueError):
continue
cur = buckets.get(key)
if cur is None:
buckets[key] = {
"open_time_ms": key,
"open": o,
"high": h,
"low": l,
"close": c,
"volume": v,
}
continue
cur["high"] = max(float(cur["high"]), h)
cur["low"] = min(float(cur["low"]), l)
cur["close"] = c
cur["volume"] = float(cur.get("volume") or 0) + v
return [buckets[k] for k in sorted(buckets.keys())]
def _next_since_from_batch(batch: list, period_ms: int) -> int:
last_ts = int(batch[-1][0])
if len(batch) >= 2:
step = int(batch[-1][0]) - int(batch[-2][0])
if step > 0:
return last_ts + step
return last_ts + period_ms
def _paginate_fetch_ohlcv(
exchange,
ex_sym: str,
timeframe: str,
*,
want: int,
since_ms: int | None,
period_ms: int,
chunk_max: int = 300,
) -> list[dict[str, Any]]:
tf = normalize_chart_timeframe(timeframe)
collected: list = []
if since_ms is not None and int(since_ms) > 0:
since = int(since_ms)
else:
since = max(0, int(time.time() * 1000) - want * period_ms)
now_ms = int(time.time() * 1000)
guard = 0
prev_since = None
while len(collected) < want and guard < 80:
guard += 1
if since >= now_ms:
break
req_limit = min(chunk_max, want - len(collected))
try:
batch = exchange.fetch_ohlcv(
ex_sym, timeframe=tf, since=since, limit=req_limit
)
except Exception as e:
err = str(e).lower()
if collected and (
"from" in err
and "to" in err
or "invalid request parameter" in err
):
break
raise
if not batch:
break
collected.extend(batch)
next_since = _next_since_from_batch(batch, period_ms)
if next_since >= now_ms:
break
if prev_since is not None and next_since <= prev_since:
break
prev_since = since
since = next_since
bars = _bars_to_dicts(collected)
uniq: dict[int, dict[str, Any]] = {}
for b in bars:
uniq[int(b["open_time_ms"])] = b
merged = [uniq[k] for k in sorted(uniq.keys())]
if len(merged) > want:
merged = merged[-want:]
return merged
def _bars_to_dicts(ohlcv: list) -> list[dict[str, Any]]:
out: list[dict[str, Any]] = []
for bar in ohlcv or []:
if not bar or len(bar) < 6:
continue
try:
out.append(
{
"open_time_ms": int(bar[0]),
"open": float(bar[1]),
"high": float(bar[2]),
"low": float(bar[3]),
"close": float(bar[4]),
"volume": float(bar[5]),
}
)
except (TypeError, ValueError):
continue
return out
def fetch_ohlcv_for_hub(
*,
symbol: str,
timeframe: str,
since_ms: int | None = None,
limit: int = 500,
normalize_symbol_input: Callable[[Any], str],
normalize_exchange_symbol: Callable[[str], str],
ensure_markets_loaded: Callable[[], None],
exchange,
friendly_error: Callable[[Exception], str] | None = None,
) -> dict[str, Any]:
"""从 ccxt 拉 OHLCV,供 hub_bridge /api/hub/ohlcv 返回。"""
tf = normalize_chart_timeframe(timeframe)
sym = normalize_symbol_input(symbol)
if not sym:
return {"ok": False, "msg": "symbol 不能为空"}
try:
ensure_markets_loaded()
ex_sym = normalize_exchange_symbol(sym)
want = max(1, min(int(limit or bar_limit_for_timeframe(tf)), 1500))
period = TIMEFRAME_MS[tf]
merged: list[dict[str, Any]] = []
src_tf = OHLCV_AGGREGATE_FROM.get(tf)
if exchange_supports_timeframe(exchange, tf):
candidate = _paginate_fetch_ohlcv(
exchange,
ex_sym,
tf,
want=want,
since_ms=since_ms,
period_ms=period,
)
if candidate and bars_spacing_matches_timeframe(candidate, tf):
merged = candidate
if (
not merged
and src_tf
and exchange_supports_timeframe(exchange, src_tf)
):
src_period = TIMEFRAME_MS[normalize_chart_timeframe(src_tf)]
ratio = max(1, int(math.ceil(period / src_period)))
src_want = min(1500, want * ratio + ratio * 4)
src_bars = _paginate_fetch_ohlcv(
exchange,
ex_sym,
src_tf,
want=src_want,
since_ms=since_ms,
period_ms=src_period,
)
if not src_bars or not bars_spacing_matches_timeframe(src_bars, src_tf):
return {
"ok": False,
"msg": f"无法获取 {tf} K 线(细周期 {src_tf} 数据异常)",
}
merged = aggregate_ohlcv_bars(src_bars, tf)
if len(merged) > want:
merged = merged[-want:]
if not merged:
try:
tail = exchange.fetch_ohlcv(
ex_sym, timeframe=tf, limit=min(want, 300)
)
merged = _bars_to_dicts(tail or [])
if len(merged) > want:
merged = merged[-want:]
except Exception:
pass
if not merged:
return {"ok": False, "msg": "交易所未返回 K 线"}
tick = normalize_price_tick(price_tick_from_market(exchange, ex_sym))
round_ohlcv_bars_to_tick(merged, tick)
return {
"ok": True,
"symbol": sym,
"exchange_symbol": ex_sym,
"timeframe": tf,
"price_tick": tick,
"bars": merged,
}
except Exception as e:
msg = friendly_error(e) if friendly_error else str(e)
return {"ok": False, "msg": f"K线加载失败:{msg}"}
+249
View File
@@ -0,0 +1,249 @@
"""ccxt 持仓标记价解析(实例 price_snapshot 与中控子代理共用)。"""
from __future__ import annotations
import math
from typing import Any, Callable
def _finite_or_none(x: Any) -> float | None:
try:
f = float(x)
return f if math.isfinite(f) else None
except (TypeError, ValueError):
return None
def _coerce_float(*values: Any) -> float | None:
for v in values:
if v is None or v == "":
continue
px = _finite_or_none(v)
if px is not None and px > 0:
return px
return None
def position_contracts(p: dict[str, Any]) -> float:
raw = p.get("contracts")
if raw is not None:
try:
return float(raw)
except (TypeError, ValueError):
pass
info = p.get("info") or {}
if not isinstance(info, dict):
info = {}
for k in ("positionAmt", "positionamt", "pos", "size"):
if k in info:
try:
v = float(info[k])
if v != 0:
return v
except (TypeError, ValueError):
pass
return 0.0
def position_side_from_ccxt(p: dict[str, Any], contracts: float | None = None) -> str:
s = (p.get("side") or "").lower()
if s in ("long", "short"):
return s
c = contracts if contracts is not None else position_contracts(p)
if c > 0:
return "long"
if c < 0:
return "short"
return "long"
def parse_position_entry_price(p: dict[str, Any]) -> float | None:
"""四所 ccxt 持仓开仓均价。"""
if not isinstance(p, dict):
return None
info = p.get("info") or {}
if not isinstance(info, dict):
info = {}
return _coerce_float(
p.get("entryPrice"),
p.get("entry_price"),
p.get("average"),
info.get("entryPrice"),
info.get("entry_price"),
info.get("avgPx"),
info.get("avgEntryPrice"),
info.get("avg_entry_price"),
info.get("avgPrice"),
info.get("openAvgPx"),
)
def estimate_linear_swap_upnl_usdt(
side: str,
entry: float | None,
mark: float | None,
contracts: float | None,
contract_size: float | None = None,
) -> float | None:
"""U 本位线性永续:浮盈 = (标记价 - 开仓价) × 张数 × contractSize(空头取反)。"""
e = _finite_or_none(entry)
m = _finite_or_none(mark)
c = _finite_or_none(contracts)
if e is None or m is None or c is None or c <= 0:
return None
mult = _finite_or_none(contract_size)
if mult is None or mult <= 0:
mult = 1.0
diff = (m - e) if (side or "long").strip().lower() == "long" else (e - m)
return round(diff * abs(c) * mult, 2)
def resolve_position_display_upnl(
side: str,
entry: float | None,
mark: float | None,
contracts: float | None,
contract_size: float | None,
exchange_upnl: float | None,
) -> float | None:
"""展示用浮盈:优先与标记价/张数一致的推算;与交易所值偏差过大时用推算值。"""
computed = estimate_linear_swap_upnl_usdt(
side, entry, mark, contracts, contract_size
)
if computed is None:
return exchange_upnl
if exchange_upnl is None:
return computed
ref = max(abs(computed), 1.0)
if abs(exchange_upnl - computed) / ref > 0.2:
return computed
return exchange_upnl
def _coerce_signed(*values: Any) -> float | None:
"""解析可正可负的数值(未实现盈亏等)。"""
for v in values:
if v is None or v == "":
continue
f = _finite_or_none(v)
if f is not None:
return f
return None
def parse_position_unrealized_pnl(p: dict[str, Any]) -> float | None:
"""四所 ccxt 持仓统一解析未实现盈亏(Gate/OKX/Binance 字段名不一致)。"""
if not isinstance(p, dict):
return None
info = p.get("info") or {}
if not isinstance(info, dict):
info = {}
return _coerce_signed(
p.get("unrealizedPnl"),
p.get("unrealisedPnl"),
p.get("unrealized_pnl"),
p.get("unrealised_pnl"),
info.get("unrealised_pnl"),
info.get("unrealized_pnl"),
info.get("unrealisedPnl"),
info.get("unrealizedPnl"),
info.get("upl"),
info.get("uplLast"),
)
def enrich_ccxt_position_metrics_out(
position: dict[str, Any],
out: dict[str, Any],
*,
contract_size: float = 1.0,
funds_decimals: int = 2,
) -> dict[str, Any]:
"""
四所 parse_ccxt_position_metrics 产出后统一:
- 标记价用 hub 兜底
- 未实现盈亏 = resolve(交易所值, entry/mark/张数/contractSize 推算)
"""
if not isinstance(position, dict) or not isinstance(out, dict):
return out
mark = _finite_or_none(out.get("mark_price"))
if mark is None or mark <= 0:
mp = parse_position_mark_price(position)
if mp is not None and mp > 0:
out["mark_price"] = round(mp, 8)
mark = mp
exchange_upnl = parse_position_unrealized_pnl(position)
if exchange_upnl is None:
exchange_upnl = _coerce_signed(out.get("unrealized_pnl"))
c = position_contracts(position)
if abs(c) < 1e-12:
return out
side = position_side_from_ccxt(position, c)
entry = parse_position_entry_price(position)
if entry is not None and entry > 0:
out["entry_price"] = round(entry, 8)
cs = contract_size if contract_size and contract_size > 0 else 1.0
upnl = resolve_position_display_upnl(
side, entry, mark, abs(c), cs, exchange_upnl
)
if upnl is not None:
out["unrealized_pnl"] = round(upnl, funds_decimals)
return out
def parse_position_mark_price(p: dict[str, Any]) -> float | None:
"""四所 ccxt 持仓统一解析标记价(与 crypto_monitor_* parse_ccxt_position_metrics 口径一致)。"""
if not isinstance(p, dict):
return None
info = p.get("info") or {}
if not isinstance(info, dict):
info = {}
mark = _coerce_float(
p.get("markPrice"),
p.get("mark_price"),
p.get("mark"),
info.get("markPx"),
info.get("mark_price"),
info.get("markPrice"),
)
if mark is not None:
return mark
contracts = position_contracts(p)
if abs(contracts) >= 1e-12:
notional = _finite_or_none(p.get("notional"))
if notional is not None and abs(notional) > 0:
return abs(notional) / abs(contracts)
return None
def build_position_marks_list(
positions: list,
*,
format_mark_display: Callable[[str, float], str] | None = None,
) -> list[dict[str, Any]]:
"""从 fetch_positions 结果生成 position_marks,供 price_snapshot / 中控合并。"""
out: list[dict[str, Any]] = []
for p in positions or []:
if not isinstance(p, dict):
continue
c = position_contracts(p)
if abs(c) < 1e-12:
continue
mark = parse_position_mark_price(p)
if mark is None or mark <= 0:
continue
sym = (p.get("symbol") or "").strip()
side = position_side_from_ccxt(p, c)
row: dict[str, Any] = {
"symbol": sym,
"side": side,
"mark_price": mark,
}
if format_mark_display and sym:
try:
row["mark_price_display"] = format_mark_display(sym, mark)
except Exception:
row["mark_price_display"] = f"{mark:g}"
else:
row["mark_price_display"] = f"{mark:g}"
out.append(row)
return out
+166
View File
@@ -0,0 +1,166 @@
"""
实例浏览器 SSO(复用 HUB_BRIDGE_TOKEN)。无 Flask 依赖,供中控 FastAPI 与各实例共用。
"""
from __future__ import annotations
import base64
import hashlib
import hmac
import json
import os
import secrets
import threading
import time
HUB_SSO_TTL_SEC = int(os.getenv("HUB_SSO_TTL_SEC", "7200"))
HUB_EMBED_BOOTSTRAP_TTL_SEC = int(os.getenv("HUB_EMBED_BOOTSTRAP_TTL_SEC", "120"))
_used_nonces: dict[str, float] = {}
_nonce_lock = threading.Lock()
def hub_bridge_token() -> str:
return (os.getenv("HUB_BRIDGE_TOKEN") or "").strip()
def safe_next_path(raw: str | None) -> str:
p = (raw or "/").strip()
if not p.startswith("/") or p.startswith("//"):
return "/"
if "://" in p:
return "/"
return p
def _sso_secret() -> str:
return hub_bridge_token()
def _b64url_encode(data: bytes) -> str:
return base64.urlsafe_b64encode(data).decode().rstrip("=")
def _b64url_decode(data: str) -> bytes:
pad = "=" * (-len(data) % 4)
return base64.urlsafe_b64decode(data + pad)
def _prune_used_nonces() -> None:
now = time.time()
with _nonce_lock:
dead = [k for k, exp in _used_nonces.items() if exp <= now]
for k in dead:
del _used_nonces[k]
def mint_hub_sso_token(exchange_key: str, next_path: str = "/") -> str | None:
secret = _sso_secret()
ex = (exchange_key or "").strip().lower()
if not secret or not ex:
return None
payload = {
"ex": ex,
"exp": int(time.time()) + max(60, HUB_SSO_TTL_SEC),
"nonce": secrets.token_urlsafe(16),
"next": safe_next_path(next_path),
}
body = _b64url_encode(json.dumps(payload, separators=(",", ":")).encode())
sig = hmac.new(secret.encode(), body.encode(), hashlib.sha256).hexdigest()
return f"{body}.{sig}"
def verify_hub_sso_token(
token: str | None, expected_exchange: str
) -> tuple[bool, str, str | None]:
secret = _sso_secret()
expected = (expected_exchange or "").strip().lower()
if not secret or not expected:
return False, "/", "未配置 HUB_BRIDGE_TOKEN"
raw = (token or "").strip()
if "." not in raw:
return False, "/", "token 无效"
body, sig = raw.rsplit(".", 1)
try:
expect_sig = hmac.new(secret.encode(), body.encode(), hashlib.sha256).hexdigest()
if not hmac.compare_digest(expect_sig, sig):
return False, "/", "签名校验失败"
payload = json.loads(_b64url_decode(body).decode())
except Exception:
return False, "/", "token 解析失败"
if not isinstance(payload, dict):
return False, "/", "payload 无效"
if str(payload.get("ex") or "").lower() != expected:
return False, "/", "实例不匹配"
try:
exp = int(payload.get("exp") or 0)
except (TypeError, ValueError):
return False, "/", "exp 无效"
if exp < int(time.time()):
return False, "/", "链接已过期"
nonce = str(payload.get("nonce") or "")
if not nonce:
return False, "/", "nonce 缺失"
_prune_used_nonces()
with _nonce_lock:
if nonce in _used_nonces:
return False, "/", "链接已使用"
_used_nonces[nonce] = float(exp)
return True, safe_next_path(str(payload.get("next") or "/")), None
def mint_hub_embed_bootstrap(exchange_key: str, next_path: str = "/") -> str | None:
"""iframe 内嵌登录引导 token(短效、单次),供 /hub-embed-auth 写入 SameSite=None Cookie。"""
secret = _sso_secret()
ex = (exchange_key or "").strip().lower()
if not secret or not ex:
return None
payload = {
"kind": "embed",
"ex": ex,
"exp": int(time.time()) + max(30, HUB_EMBED_BOOTSTRAP_TTL_SEC),
"nonce": secrets.token_urlsafe(16),
"next": safe_next_path(next_path),
}
body = _b64url_encode(json.dumps(payload, separators=(",", ":")).encode())
sig = hmac.new(secret.encode(), body.encode(), hashlib.sha256).hexdigest()
return f"{body}.{sig}"
def verify_hub_embed_bootstrap(
token: str | None, expected_exchange: str
) -> tuple[bool, str, str | None]:
secret = _sso_secret()
expected = (expected_exchange or "").strip().lower()
if not secret or not expected:
return False, "/", "未配置 HUB_BRIDGE_TOKEN"
raw = (token or "").strip()
if "." not in raw:
return False, "/", "token 无效"
body, sig = raw.rsplit(".", 1)
try:
expect_sig = hmac.new(secret.encode(), body.encode(), hashlib.sha256).hexdigest()
if not hmac.compare_digest(expect_sig, sig):
return False, "/", "签名校验失败"
payload = json.loads(_b64url_decode(body).decode())
except Exception:
return False, "/", "token 解析失败"
if not isinstance(payload, dict) or payload.get("kind") != "embed":
return False, "/", "token 类型无效"
if str(payload.get("ex") or "").lower() != expected:
return False, "/", "实例不匹配"
try:
exp = int(payload.get("exp") or 0)
except (TypeError, ValueError):
return False, "/", "exp 无效"
if exp < int(time.time()):
return False, "/", "链接已过期"
nonce = str(payload.get("nonce") or "")
if not nonce:
return False, "/", "nonce 缺失"
key = f"embed:{nonce}"
_prune_used_nonces()
with _nonce_lock:
if key in _used_nonces:
return False, "/", "链接已使用"
_used_nonces[key] = float(exp)
return True, safe_next_path(str(payload.get("next") or "/")), None
File diff suppressed because it is too large Load Diff
+638
View File
@@ -0,0 +1,638 @@
"""各实例当日平仓记录查询(供 hub_bridge /api/hub/trades/today 与中控 AI 聚合)。"""
from __future__ import annotations
from datetime import datetime, timedelta
from typing import Any, Callable, Optional
from lib.strategy.strategy_trade_labels import (
MONITOR_TYPE_ROLL,
MONITOR_TYPE_TREND_PULLBACK,
entry_reason_for_monitor_type,
)
from lib.trade.time_close_lib import TIME_CLOSE_RESULT
TRADE_COMPLETED_RESULTS = (
"止盈",
"止损",
"保本止盈",
"移动止盈",
"手动平仓",
"强制清仓",
"外部平仓",
TIME_CLOSE_RESULT,
)
def trading_day_from_dt(dt: datetime, reset_hour: int = 8) -> str:
"""与实例 get_trading_day 一致:小时 < reset_hour 归属上一日历日。"""
if dt.hour < reset_hour:
dt = dt - timedelta(days=1)
return dt.strftime("%Y-%m-%d")
def current_trading_day(*, now: datetime | None = None, reset_hour: int = 8) -> str:
return trading_day_from_dt(now or datetime.now(), reset_hour)
def parse_dt_for_trading_day(raw: Any) -> datetime | None:
if raw is None:
return None
s = str(raw).strip().replace("Z", "").replace("T", " ")
if not s:
return None
for fmt, ln in (("%Y-%m-%d %H:%M:%S", 19), ("%Y-%m-%d %H:%M", 16), ("%Y-%m-%d", 10)):
try:
return datetime.strptime(s[:ln], fmt)
except ValueError:
continue
return None
def trading_day_window_bounds(trading_day: str, reset_hour: int = 8) -> tuple[str, str]:
"""交易日 [reset_hour, 次日 reset_hour) 对应的北京时间字符串区间(闭区间)。"""
day = datetime.strptime((trading_day or "").strip()[:10], "%Y-%m-%d")
start = day.replace(hour=reset_hour, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1) - timedelta(seconds=1)
return start.strftime("%Y-%m-%d %H:%M:%S"), end.strftime("%Y-%m-%d %H:%M:%S")
def _row_dict(row, row_to_dict: Optional[Callable] = None) -> dict:
if row is None:
return {}
if row_to_dict:
try:
return dict(row_to_dict(row))
except Exception:
pass
try:
keys = row.keys() if hasattr(row, "keys") else ()
if keys:
return {k: row[k] for k in keys}
except Exception:
pass
try:
return dict(row)
except Exception:
return {}
def _effective_field(d: dict, reviewed_key: str, base_key: str, default: Any = None) -> Any:
rv = d.get(reviewed_key)
if rv is not None and str(rv).strip() != "":
return rv
bv = d.get(base_key)
if bv is not None and str(bv).strip() != "":
return bv
return default
def format_hold_minutes(minutes: Any) -> str:
try:
total = int(minutes or 0)
except (TypeError, ValueError):
return "0分钟"
if total <= 0:
return "0分钟"
hours = total // 60
mins = total % 60
if hours:
return f"{hours}小时{mins}分钟"
return f"{mins}分钟"
def _normalize_monitor_type_label(raw: Any) -> str:
mt = str(raw or "").strip()
if mt in ("trend_pullback", "trend"):
return MONITOR_TYPE_TREND_PULLBACK
if mt in ("roll",):
return MONITOR_TYPE_ROLL
return mt
def effective_entry_type(d: dict) -> str:
"""复盘开仓类型优先,与实例交易记录 effective_entry_reason 一致。"""
er = _effective_field(d, "reviewed_entry_reason", "entry_reason")
if er is not None and str(er).strip():
return str(er).strip()
mt = _normalize_monitor_type_label(d.get("monitor_type"))
er2 = entry_reason_for_monitor_type(mt)
if er2:
return er2
kst = str(d.get("key_signal_type") or "").strip()
if kst:
return kst
legacy = str(d.get("entry_type") or "").strip()
if legacy and legacy not in ("trend_pullback", "roll", "trend"):
return _normalize_monitor_type_label(legacy) or legacy
return mt
def display_entry_type_label(d: dict) -> str:
"""档案/列表展示用开仓类型(不回落为「下单监控」若已有复盘或建档类型)。"""
label = effective_entry_type(d).strip()
if not label:
return ""
return _normalize_monitor_type_label(label) or label
def effective_hold_minutes(
d: dict,
*,
opened_ms: int | None = None,
closed_ms: int | None = None,
) -> int:
hm = _effective_field(d, "reviewed_hold_minutes", "hold_minutes")
if hm is not None and str(hm).strip() != "":
try:
return max(0, int(hm))
except (TypeError, ValueError):
pass
hs = _effective_field(d, "reviewed_hold_seconds", "hold_seconds")
if hs is not None and str(hs).strip() != "":
try:
return max(0, int(int(hs) // 60))
except (TypeError, ValueError):
pass
oms = opened_ms if opened_ms is not None else d.get("opened_at_ms")
cms = closed_ms if closed_ms is not None else d.get("closed_at_ms")
try:
oms_i = int(oms) if oms not in (None, "") else None
cms_i = int(cms) if cms not in (None, "") else None
except (TypeError, ValueError):
oms_i = cms_i = None
if oms_i and cms_i and cms_i > oms_i:
return max(0, int((cms_i - oms_i) // 60_000))
return 0
def _effective_pnl(d: dict) -> float:
reviewed = d.get("reviewed_pnl_amount")
if reviewed is not None and str(reviewed).strip() != "":
try:
return float(reviewed)
except (TypeError, ValueError):
pass
ex = d.get("exchange_realized_pnl")
if ex is not None and str(ex).strip() != "":
try:
return float(ex)
except (TypeError, ValueError):
pass
try:
return float(d.get("pnl_amount") or 0)
except (TypeError, ValueError):
return 0.0
def _trade_close_dt(d: dict) -> datetime | None:
raw = _effective_field(d, "reviewed_closed_at", "closed_at")
if raw is None or str(raw).strip() == "":
raw = d.get("created_at") or d.get("opened_at")
return parse_dt_for_trading_day(raw)
def _normalize_trade_row(
d: dict,
*,
trading_day: str,
reset_hour: int,
) -> dict[str, Any] | None:
effective_result = str(_effective_field(d, "reviewed_result", "result") or "").strip()
if effective_result not in TRADE_COMPLETED_RESULTS:
return None
close_dt = _trade_close_dt(d)
if not close_dt:
return None
if trading_day_from_dt(close_dt, reset_hour) != trading_day:
return None
pnl = _effective_pnl(d)
closed_at = _effective_field(d, "reviewed_closed_at", "closed_at")
opened_at = _effective_field(d, "reviewed_opened_at", "opened_at")
return {
"symbol": d.get("symbol"),
"direction": d.get("direction"),
"result": effective_result,
"pnl_amount": round(pnl, 4),
"closed_at": closed_at,
"opened_at": opened_at,
"monitor_type": d.get("monitor_type"),
"actual_rr": d.get("actual_rr"),
"planned_rr": d.get("planned_rr"),
"trade_style": d.get("trade_style"),
"entry_reason": d.get("entry_reason"),
"reviewed": bool(d.get("reviewed_at") or d.get("reviewed_result")),
}
def fetch_trades_for_trading_day(
conn,
trading_day: str,
*,
row_to_dict_fn: Optional[Callable] = None,
reset_hour: int = 8,
limit: int = 200,
) -> list[dict[str, Any]]:
"""返回指定交易日的已平仓记录(与 /records 交易记录一致,复盘字段优先)。"""
day = (trading_day or "").strip()[:10]
if not day:
return []
lim = max(1, min(int(limit or 200), 500))
start_bj, end_bj = trading_day_window_bounds(day, reset_hour)
ts_expr = "REPLACE(COALESCE(reviewed_closed_at, closed_at, created_at, opened_at), 'T', ' ')"
rows = conn.execute(
f"""
SELECT symbol, direction, result, reviewed_result, pnl_amount, reviewed_pnl_amount,
exchange_realized_pnl, closed_at, reviewed_closed_at, opened_at, reviewed_opened_at,
created_at, monitor_type, actual_rr, planned_rr, trade_style, entry_reason,
reviewed_at
FROM trade_records
WHERE {ts_expr} >= ? AND {ts_expr} <= ?
ORDER BY {ts_expr} ASC
LIMIT ?
""",
(start_bj, end_bj, lim * 3),
).fetchall()
out: list[dict[str, Any]] = []
for row in rows:
d = _row_dict(row, row_to_dict_fn)
norm = _normalize_trade_row(d, trading_day=day, reset_hour=reset_hour)
if norm:
out.append(norm)
if len(out) >= lim:
break
return out
def _normalize_archive_trade_row(
d: dict,
*,
exchange_key: str = "",
reset_hour: int = 8,
) -> dict[str, Any] | None:
"""全历史档案用:已平仓记录(不按交易日截断)。"""
effective_result = str(_effective_field(d, "reviewed_result", "result") or "").strip()
if effective_result not in TRADE_COMPLETED_RESULTS:
return None
close_dt = _trade_close_dt(d)
if not close_dt:
return None
pnl = _effective_pnl(d)
closed_at = _effective_field(d, "reviewed_closed_at", "closed_at")
opened_at = _effective_field(d, "reviewed_opened_at", "opened_at")
opened_ms = d.get("opened_at_ms")
closed_ms = d.get("closed_at_ms")
if opened_ms in (None, ""):
odt = parse_dt_for_trading_day(opened_at)
opened_ms = int(odt.timestamp() * 1000) if odt else None
if closed_ms in (None, ""):
cdt = close_dt
closed_ms = int(cdt.timestamp() * 1000) if cdt else None
try:
trade_id = int(d.get("id"))
except (TypeError, ValueError):
return None
opened_ms_i = int(opened_ms) if opened_ms else None
closed_ms_i = int(closed_ms) if closed_ms else None
hold_m = effective_hold_minutes(d, opened_ms=opened_ms_i, closed_ms=closed_ms_i)
entry_type = display_entry_type_label(d)
reviewed = bool(
d.get("reviewed_at")
or d.get("reviewed_result")
or d.get("reviewed_opened_at")
or d.get("reviewed_closed_at")
or d.get("reviewed_entry_reason")
or d.get("reviewed_hold_minutes")
)
return {
"id": trade_id,
"exchange_key": (exchange_key or "").strip().lower(),
"symbol": (d.get("symbol") or "").strip().upper(),
"direction": d.get("direction"),
"result": effective_result,
"pnl_amount": round(pnl, 4),
"closed_at": closed_at,
"opened_at": opened_at,
"opened_at_ms": opened_ms_i,
"closed_at_ms": closed_ms_i,
"monitor_type": _normalize_monitor_type_label(d.get("monitor_type")),
"entry_type": entry_type,
"entry_reason": entry_type,
"hold_minutes": hold_m,
"hold_minutes_text": format_hold_minutes(hold_m),
"actual_rr": d.get("actual_rr"),
"planned_rr": d.get("planned_rr"),
"trade_style": d.get("trade_style"),
"trigger_price": d.get("trigger_price"),
"stop_loss": _effective_field(d, "reviewed_stop_loss", "stop_loss"),
"take_profit": _effective_field(d, "reviewed_take_profit", "take_profit"),
"reviewed": reviewed,
"trading_day": trading_day_from_dt(close_dt, reset_hour),
"exchange_turnover_usdt": d.get("exchange_turnover_usdt"),
"exchange_commission_usdt": d.get("exchange_commission_usdt"),
}
_SNAPSHOT_STATUS_TO_RESULT = {
"stopped_sl": "止损",
"stopped_tp": "止盈",
"stopped_manual": "手动平仓",
"stopped_external": "外部平仓",
}
def _table_columns(conn, table: str) -> set[str]:
try:
rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
except Exception:
return set()
out: set[str] = set()
for r in rows:
try:
out.add(str(r[1]))
except (IndexError, TypeError):
try:
out.add(str(r["name"]))
except Exception:
continue
return out
def _archive_ts_expr(cols: set[str]) -> str:
parts = [c for c in ("reviewed_closed_at", "closed_at", "created_at", "opened_at") if c in cols]
if not parts:
return "''"
return f"REPLACE(COALESCE({', '.join(parts)}), 'T', ' ')"
def _archive_trade_select_sql(cols: set[str]) -> str:
wanted = [
"id",
"symbol",
"direction",
"result",
"reviewed_result",
"pnl_amount",
"reviewed_pnl_amount",
"exchange_realized_pnl",
"closed_at",
"reviewed_closed_at",
"opened_at",
"reviewed_opened_at",
"opened_at_ms",
"closed_at_ms",
"created_at",
"monitor_type",
"key_signal_type",
"actual_rr",
"planned_rr",
"trade_style",
"entry_reason",
"reviewed_entry_reason",
"hold_minutes",
"reviewed_hold_minutes",
"hold_seconds",
"reviewed_hold_seconds",
"trigger_price",
"stop_loss",
"take_profit",
"reviewed_stop_loss",
"reviewed_take_profit",
"reviewed_at",
"trend_plan_id",
"exchange_turnover_usdt",
"exchange_commission_usdt",
]
select_cols = [c for c in wanted if c in cols]
if "id" not in select_cols:
select_cols = ["id"] + select_cols
return ", ".join(select_cols)
def _existing_trend_plan_ids(conn) -> set[int]:
cols = _table_columns(conn, "trade_records")
if "trend_plan_id" not in cols:
return set()
rows = conn.execute(
"SELECT DISTINCT trend_plan_id FROM trade_records WHERE trend_plan_id IS NOT NULL"
).fetchall()
out: set[int] = set()
for row in rows:
d = _row_dict(row)
try:
out.add(int(d.get("trend_plan_id")))
except (TypeError, ValueError):
continue
return out
def _normalize_snapshot_archive_row(
snap: dict,
*,
exchange_key: str = "",
reset_hour: int = 8,
) -> dict[str, Any] | None:
result = str(snap.get("result_label") or "").strip()
if not result:
result = _SNAPSHOT_STATUS_TO_RESULT.get(
str(snap.get("status_at_close") or "").strip(), ""
)
if result not in TRADE_COMPLETED_RESULTS:
return None
closed_at = snap.get("closed_at")
close_dt = parse_dt_for_trading_day(closed_at)
if not close_dt:
return None
opened_at = snap.get("opened_at")
opened_ms = _parse_ms_from_row(snap.get("opened_at"))
closed_ms = _parse_ms_from_row(closed_at)
try:
snap_id = int(snap.get("id"))
except (TypeError, ValueError):
return None
try:
pnl = float(snap.get("pnl_amount") or 0)
except (TypeError, ValueError):
pnl = 0.0
st = str(snap.get("strategy_type") or "").strip()
monitor_type = _normalize_monitor_type_label(
"trend_pullback" if st == "trend_pullback" else ("roll" if st == "roll" else st)
)
hold_m = effective_hold_minutes(
{},
opened_ms=opened_ms,
closed_ms=closed_ms,
)
entry_type = entry_reason_for_monitor_type(monitor_type) or monitor_type
return {
"id": -snap_id,
"exchange_key": (exchange_key or "").strip().lower(),
"symbol": (snap.get("symbol") or "").strip().upper(),
"direction": snap.get("direction"),
"result": result,
"pnl_amount": round(pnl, 4),
"closed_at": closed_at,
"opened_at": opened_at,
"opened_at_ms": opened_ms,
"closed_at_ms": closed_ms,
"monitor_type": monitor_type,
"entry_type": entry_type,
"entry_reason": entry_type,
"hold_minutes": hold_m,
"hold_minutes_text": format_hold_minutes(hold_m),
"from_snapshot": True,
"snapshot_id": snap_id,
"trend_plan_id": snap.get("source_id"),
"reviewed": False,
"trading_day": trading_day_from_dt(close_dt, reset_hour),
}
def _parse_ms_from_row(raw: Any) -> int | None:
if raw in (None, ""):
return None
try:
if isinstance(raw, (int, float)):
v = int(raw)
return v if v > 1_000_000_000_000 else v * 1000
except (TypeError, ValueError):
pass
dt = parse_dt_for_trading_day(raw)
return int(dt.timestamp() * 1000) if dt else None
def _fetch_strategy_snapshots_for_archive(
conn,
*,
exchange_key: str = "",
days: int = 365,
reset_hour: int = 8,
limit: int = 2000,
skip_plan_ids: set[int] | None = None,
) -> list[dict[str, Any]]:
cols = _table_columns(conn, "strategy_trade_snapshots")
if not cols:
return []
lim = max(1, min(int(limit or 2000), 5000))
day_span = max(1, min(int(days or 365), 3650))
cutoff = datetime.now() - timedelta(days=day_span)
cutoff_s = cutoff.strftime("%Y-%m-%d %H:%M:%S")
ts_expr = "REPLACE(COALESCE(closed_at, opened_at, created_at), 'T', ' ')"
rows = conn.execute(
f"""
SELECT * FROM strategy_trade_snapshots
WHERE {ts_expr} >= ?
ORDER BY {ts_expr} DESC
LIMIT ?
""",
(cutoff_s, lim * 2),
).fetchall()
skip = skip_plan_ids or set()
out: list[dict[str, Any]] = []
for row in rows:
d = _row_dict(row)
try:
source_id = int(d.get("source_id") or 0)
except (TypeError, ValueError):
source_id = 0
if source_id > 0 and source_id in skip:
continue
norm = _normalize_snapshot_archive_row(
d, exchange_key=exchange_key, reset_hour=reset_hour
)
if norm:
out.append(norm)
if len(out) >= lim:
break
return out
def fetch_trades_for_archive(
conn,
*,
exchange_key: str = "",
days: int = 365,
row_to_dict_fn: Optional[Callable] = None,
reset_hour: int = 8,
limit: int = 2000,
include_strategy_snapshots: bool = True,
) -> list[dict[str, Any]]:
"""返回近 N 天已平仓记录(trade_records + 未落库的 strategy 快照)。"""
lim = max(1, min(int(limit or 2000), 5000))
day_span = max(1, min(int(days or 365), 3650))
cutoff = datetime.now() - timedelta(days=day_span)
cutoff_s = cutoff.strftime("%Y-%m-%d %H:%M:%S")
cols = _table_columns(conn, "trade_records")
if not cols:
records: list[dict[str, Any]] = []
else:
ts_expr = _archive_ts_expr(cols)
sql = f"""
SELECT {_archive_trade_select_sql(cols)}
FROM trade_records
WHERE {ts_expr} >= ?
ORDER BY {ts_expr} DESC
LIMIT ?
"""
rows = conn.execute(sql, (cutoff_s, lim * 2)).fetchall()
records = []
for row in rows:
d = _row_dict(row, row_to_dict_fn)
norm = _normalize_archive_trade_row(
d, exchange_key=exchange_key, reset_hour=reset_hour
)
if norm:
records.append(norm)
if len(records) >= lim:
break
if not include_strategy_snapshots:
return records
skip_ids = _existing_trend_plan_ids(conn)
for rec in records:
try:
pid = int(rec.get("trend_plan_id") or 0)
except (TypeError, ValueError):
pid = 0
if pid > 0:
skip_ids.add(pid)
snaps = _fetch_strategy_snapshots_for_archive(
conn,
days=days,
exchange_key=exchange_key,
reset_hour=reset_hour,
limit=max(0, lim - len(records)),
skip_plan_ids=skip_ids,
)
merged = records + snaps
merged.sort(
key=lambda x: int(x.get("closed_at_ms") or 0),
reverse=True,
)
return merged[:lim]
def summarize_trades(trades: list[dict]) -> dict[str, Any]:
"""单笔列表 → 笔数 / 盈亏 / 胜败统计。"""
total_pnl = 0.0
win = loss = flat = 0
for t in trades or []:
try:
pnl = float(t.get("pnl_amount") or 0)
except (TypeError, ValueError):
pnl = 0.0
total_pnl += pnl
if pnl > 1e-9:
win += 1
elif pnl < -1e-9:
loss += 1
else:
flat += 1
return {
"closed_count": len(trades or []),
"win_count": win,
"loss_count": loss,
"flat_count": flat,
"total_pnl_u": round(total_pnl, 4),
}
+595
View File
@@ -0,0 +1,595 @@
"""行情区:各交易所 USDT 永续昨日成交额 Top N(每日 8:00 快照)。"""
from __future__ import annotations
import json
import os
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Callable
from zoneinfo import ZoneInfo
from lib.hub.hub_trades_lib import trading_day_from_dt
TOP_N_DEFAULT = 20
CACHE_VERSION = 3
LIQUIDITY_RANK_CACHE_VERSION = 1
def volume_rank_reset_hour() -> int:
try:
return max(0, min(23, int(os.getenv("HUB_VOLUME_RANK_RESET_HOUR", "8"))))
except ValueError:
return 8
def volume_rank_timezone() -> ZoneInfo:
name = (os.getenv("HUB_VOLUME_RANK_TZ") or "Asia/Shanghai").strip() or "Asia/Shanghai"
try:
return ZoneInfo(name)
except Exception:
return ZoneInfo("Asia/Shanghai")
def rank_date_label(*, now: datetime | None = None, reset_hour: int | None = None) -> str:
"""8 点更新后展示的「昨日」交易日(与 TRADING_DAY_RESET_HOUR 口径一致)。"""
rh = volume_rank_reset_hour() if reset_hour is None else reset_hour
tz = volume_rank_timezone()
dt = now.astimezone(tz) if now else datetime.now(tz)
cur_td = trading_day_from_dt(dt.replace(tzinfo=None), rh)
cur = datetime.strptime(cur_td, "%Y-%m-%d").date()
return (cur - timedelta(days=1)).isoformat()
def seconds_until_next_reset(
*,
now: datetime | None = None,
reset_hour: int | None = None,
) -> float:
rh = volume_rank_reset_hour() if reset_hour is None else reset_hour
tz = volume_rank_timezone()
dt = now.astimezone(tz) if now else datetime.now(tz)
nxt = dt.replace(hour=rh, minute=0, second=0, microsecond=0)
if dt >= nxt:
nxt += timedelta(days=1)
return max(1.0, (nxt - dt).total_seconds())
def default_cache_path() -> Path:
raw = (os.getenv("HUB_VOLUME_RANK_CACHE_PATH") or "").strip()
if raw:
return Path(raw)
hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data"
hub_dir.mkdir(parents=True, exist_ok=True)
return hub_dir / "hub_volume_rank.json"
def _safe_float(v: Any) -> float | None:
try:
n = float(v)
return n if n == n else None
except (TypeError, ValueError):
return None
def _ticker_base(sym_text: str) -> str:
s = str(sym_text or "").upper().strip()
if ":" in s:
s = s.split(":", 1)[0]
if "/" in s:
return s.split("/", 1)[0].strip()
if "-" in s:
return s.split("-", 1)[0].strip()
if s.endswith("USDT"):
return s[:-4].strip()
return s
def _hub_symbol_from_base(base: str, quote: str = "USDT") -> str:
b = str(base or "").strip().upper()
q = str(quote or "USDT").strip().upper()
return f"{b}/{q}" if b else ""
def _hub_symbol_from_market(market: dict | None, fallback_symbol: str) -> str:
if market:
base = str(market.get("base") or "").strip().upper()
quote = str(market.get("quote") or "USDT").strip().upper()
if base:
return f"{base}/{quote}"
fb = str(fallback_symbol or "").upper().strip()
if ":" in fb:
fb = fb.split(":", 1)[0]
if "/" in fb:
return fb
base = _ticker_base(fb)
return f"{base}/USDT" if base else fb
def _okx_turnover_usdt(row: dict | None) -> float | None:
"""OKX SWAP:成交额(USDT) ≈ volCcy24h(基础币) × last。"""
if not isinstance(row, dict):
return None
base_vol = _safe_float(row.get("volCcy24h"))
if base_vol is None or base_vol <= 0:
return None
last = _safe_float(row.get("last") or row.get("lastPx"))
if last is None or last <= 0:
return None
return float(base_vol * last)
def _quote_volume_from_ticker(
ticker: dict | None,
market: dict | None,
*,
exchange_id: str = "",
) -> float | None:
ex_id = str(exchange_id or "").lower()
t = ticker or {}
info = t.get("info") if isinstance(t.get("info"), dict) else {}
if ex_id == "okx":
row = dict(info)
if row.get("last") is None:
row["last"] = t.get("last")
qv = _okx_turnover_usdt(row)
if qv is not None and qv > 0:
return qv
qv = _safe_float(t.get("quoteVolume"))
if qv is not None and qv > 0:
return qv
if ex_id in ("gateio", "gate"):
for key in (
"volume_24h_quote",
"volume_24h_settle",
"quote_volume",
"vol_24h",
"turnover",
):
qv = _safe_float(info.get(key))
if qv is not None and qv > 0:
return qv
for key in ("quoteVolume", "volCcy24h", "vol24h", "turnover24h", "amount24", "turnover"):
qv = _safe_float(info.get(key))
if qv is not None and qv > 0:
if key == "volCcy24h" and ex_id == "okx":
last = _safe_float(info.get("last") or info.get("lastPx") or t.get("last"))
if last:
return qv * last
return qv
bv = _safe_float(t.get("baseVolume"))
lp = _safe_float(t.get("last")) or _safe_float(t.get("close"))
if bv is not None and lp is not None and bv > 0 and lp > 0:
return bv * lp
if info:
bv = _safe_float(info.get("volCcy24h") or info.get("vol24h") or info.get("volume"))
lp = _safe_float(info.get("last") or info.get("lastPx") or info.get("markPrice"))
if bv is not None and lp is not None and bv > 0 and lp > 0:
return bv * lp
return None
def _is_usdt_linear_swap(market: dict | None, symbol: str) -> bool:
if not market:
su = str(symbol or "").upper()
return "USDT" in su and (":USDT" in su or "/USDT" in su or su.endswith("USDT"))
if not market.get("swap") and market.get("type") not in ("swap", "future"):
return False
if str(market.get("quote") or "").upper() != "USDT":
return False
if market.get("linear") is False:
return False
if market.get("active") is False:
return False
settle = str(market.get("settle") or "").upper()
if settle and settle != "USDT":
return False
return True
def _lookup_ticker(tickers: dict, sym: str, market: dict | None) -> dict | None:
if not tickers:
return None
t = tickers.get(sym)
if t:
return t
if not market:
return None
base = market.get("base")
quote = market.get("quote") or "USDT"
settle = market.get("settle") or quote
candidates = [
sym,
f"{base}/{quote}:{settle}",
f"{base}/{quote}",
f"{base}{quote}",
market.get("id"),
]
for key in candidates:
if not key:
continue
t = tickers.get(key)
if t:
return t
return None
def _merge_scores(scored: dict[str, tuple[str, float]]) -> list[tuple[str, str, float]]:
rows = [(sym, base, vol) for base, (sym, vol) in scored.items() if sym and base and vol > 0]
rows.sort(key=lambda x: x[2], reverse=True)
return rows
def _scores_from_okx(exchange) -> list[tuple[str, str, float]]:
by_base: dict[str, tuple[str, float]] = {}
if hasattr(exchange, "publicGetMarketTickers"):
try:
resp = exchange.publicGetMarketTickers({"instType": "SWAP"})
for row in (resp or {}).get("data") or []:
if not isinstance(row, dict):
continue
inst = str(row.get("instId") or "").upper()
parts = inst.split("-")
if len(parts) < 3 or parts[-1] != "SWAP" or parts[1] != "USDT":
continue
base = parts[0].strip()
if not base:
continue
qv = _okx_turnover_usdt(row)
if qv is None or qv <= 0:
continue
sym = _hub_symbol_from_base(base)
prev = by_base.get(base)
if prev is None or qv > prev[1]:
by_base[base] = (sym, float(qv))
if by_base:
return _merge_scores(by_base)
except Exception:
pass
try:
tickers = exchange.fetch_tickers(params={"instType": "SWAP"})
except Exception:
tickers = exchange.fetch_tickers()
return _scores_from_markets(exchange, tickers or {}, "okx")
def _scores_from_binance(exchange) -> list[tuple[str, str, float]]:
by_base: dict[str, tuple[str, float]] = {}
if hasattr(exchange, "fapiPublicGetTicker24hr"):
try:
rows = exchange.fapiPublicGetTicker24hr()
if isinstance(rows, list):
for row in rows:
if not isinstance(row, dict):
continue
raw = str(row.get("symbol") or "").upper()
if not raw.endswith("USDT"):
continue
base = raw[:-4]
if not base:
continue
qv = _safe_float(row.get("quoteVolume"))
if qv is None or qv <= 0:
bv = _safe_float(row.get("volume"))
lp = _safe_float(row.get("lastPrice") or row.get("weightedAvgPrice"))
if bv and lp:
qv = bv * lp
if qv is None or qv <= 0:
continue
sym = _hub_symbol_from_base(base)
prev = by_base.get(base)
if prev is None or qv > prev[1]:
by_base[base] = (sym, float(qv))
if by_base:
return _merge_scores(by_base)
except Exception:
pass
return []
def _scores_from_gate(exchange) -> list[tuple[str, str, float]]:
by_base: dict[str, tuple[str, float]] = {}
for method_name in ("publicFuturesGetSettleTickers", "publicFuturesGetUsdtTickers"):
fn = getattr(exchange, method_name, None)
if not callable(fn):
continue
try:
rows = fn({"settle": "usdt"})
if isinstance(rows, list):
for row in rows:
if not isinstance(row, dict):
continue
contract = str(row.get("contract") or row.get("name") or "").upper()
if not contract:
continue
base = contract.replace("_USDT", "").replace("USDT", "").strip("_")
if not base:
continue
qv = _safe_float(row.get("volume_24h_quote") or row.get("volume_24h_settle"))
if qv is None or qv <= 0:
bv = _safe_float(row.get("volume_24h_base"))
lp = _safe_float(row.get("last") or row.get("mark_price"))
if bv and lp:
qv = bv * lp
if qv is None or qv <= 0:
continue
sym = _hub_symbol_from_base(base)
prev = by_base.get(base)
if prev is None or qv > prev[1]:
by_base[base] = (sym, float(qv))
if by_base:
return _merge_scores(by_base)
except Exception:
continue
return []
def _scores_from_markets(
exchange,
tickers: dict,
exchange_id: str,
) -> list[tuple[str, str, float]]:
by_base: dict[str, tuple[str, float]] = {}
markets = getattr(exchange, "markets", None) or {}
for sym, mk in markets.items():
try:
if not _is_usdt_linear_swap(mk, sym):
continue
ticker = _lookup_ticker(tickers, sym, mk)
qv = _quote_volume_from_ticker(ticker, mk, exchange_id=exchange_id)
if qv is None or qv <= 0:
continue
hub_sym = _hub_symbol_from_market(mk, sym)
base = _ticker_base(hub_sym)
if not base:
continue
prev = by_base.get(base)
if prev is None or qv > prev[1]:
by_base[base] = (hub_sym, float(qv))
except Exception:
continue
return _merge_scores(by_base)
def _collect_scores(exchange, exchange_id: str) -> list[tuple[str, str, float]]:
ex_id = str(exchange_id or "").lower()
if ex_id == "okx":
return _scores_from_okx(exchange)
if ex_id == "binance":
return _scores_from_binance(exchange)
if ex_id in ("gateio", "gate", "gate_bot"):
return _scores_from_gate(exchange)
tickers = exchange.fetch_tickers()
return _scores_from_markets(exchange, tickers or {}, ex_id)
def _uses_lightweight_volume_scores(exchange_id: str) -> bool:
ex_id = str(exchange_id or "").lower()
return ex_id in ("okx", "binance", "gateio", "gate", "gate_bot")
def build_usdt_swap_volume_ranks(
exchange,
ensure_markets_loaded: Callable[[], None],
*,
exchange_id: str | None = None,
) -> tuple[dict[str, int], int]:
"""
全市场 USDT 永续 24h 成交额排名(base -> rank)。
优先各所轻量 ticker API,避免 fetch_tickers() 拉全市场(Gate/Binance 内存优化)。
"""
ex_id = str(exchange_id or getattr(exchange, "id", "") or "").lower()
if not _uses_lightweight_volume_scores(ex_id):
ensure_markets_loaded()
scored = _collect_scores(exchange, ex_id)
ranks: dict[str, int] = {}
for idx, (_sym, base, _qv) in enumerate(scored, 1):
if base and base not in ranks:
ranks[base] = idx
return ranks, len(scored)
def resolve_daily_volume_rank(
target_base: str,
cache: dict[str, Any],
*,
now_ts: float,
ttl_sec: float,
exchange,
ensure_markets_loaded: Callable[[], None],
exchange_id: str | None = None,
cache_version: int = LIQUIDITY_RANK_CACHE_VERSION,
) -> tuple[int | None, int]:
"""关键位门控:按 base 查 24h 成交额全市场排名;cache 带 TTL。"""
cached_ok = (
cache.get("version") == cache_version
and cache.get("updated_at")
and now_ts - float(cache["updated_at"]) < ttl_sec
)
if not cached_ok:
try:
ranks, total = build_usdt_swap_volume_ranks(
exchange,
ensure_markets_loaded,
exchange_id=exchange_id,
)
if total > 0 and ranks:
cache["ranks"] = ranks
cache["total"] = total
cache["version"] = cache_version
cache["updated_at"] = now_ts
except Exception:
pass
ranks = cache.get("ranks") or {}
total = int(cache.get("total") or 0)
base = str(target_base or "").strip().upper()
return ranks.get(base), total
def fetch_usdt_swap_volume_rank(
exchange,
ensure_markets_loaded: Callable[[], None],
*,
top_n: int = TOP_N_DEFAULT,
rank_date: str | None = None,
exchange_id: str | None = None,
) -> dict[str, Any]:
"""从 ccxt 拉全市场 USDT 永续 ticker,按 24h 成交额(USDT) 取 Top N。"""
top_n = max(1, min(int(top_n or TOP_N_DEFAULT), 100))
ensure_markets_loaded()
ex_id = str(exchange_id or getattr(exchange, "id", "") or "").lower()
try:
scored = _collect_scores(exchange, ex_id)
except Exception as e:
return {"ok": False, "msg": str(e)}
items = []
for idx, (hub_sym, base, qv) in enumerate(scored[:top_n], 1):
items.append(
{
"rank": idx,
"symbol": hub_sym,
"base": base,
"volume_quote": round(qv, 4),
}
)
return {
"ok": True,
"rank_date": rank_date or rank_date_label(),
"items": items,
"total_symbols": len(scored),
"exchange_id": ex_id,
"fetched_at": datetime.now(volume_rank_timezone()).isoformat(timespec="seconds"),
}
def format_volume_quote(value: float | None) -> str:
n = _safe_float(value)
if n is None or n <= 0:
return ""
if n >= 1e9:
return f"{n / 1e9:.2f}B"
if n >= 1e6:
return f"{n / 1e6:.2f}M"
if n >= 1e3:
return f"{n / 1e3:.2f}K"
return f"{n:.0f}"
def load_volume_rank_cache(path: Path | None = None) -> dict[str, Any]:
p = path or default_cache_path()
if not p.is_file():
return {"version": CACHE_VERSION, "exchanges": {}}
try:
data = json.loads(p.read_text(encoding="utf-8"))
if not isinstance(data, dict):
return {"version": CACHE_VERSION, "exchanges": {}}
if int(data.get("version") or 0) < CACHE_VERSION:
return {"version": CACHE_VERSION, "exchanges": {}}
data.setdefault("version", CACHE_VERSION)
data.setdefault("exchanges", {})
return data
except Exception:
return {"version": CACHE_VERSION, "exchanges": {}}
def save_volume_rank_cache(data: dict[str, Any], path: Path | None = None) -> None:
p = path or default_cache_path()
p.parent.mkdir(parents=True, exist_ok=True)
payload = dict(data)
payload["version"] = CACHE_VERSION
payload["updated_at"] = datetime.now(volume_rank_timezone()).isoformat(timespec="seconds")
p.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
def merge_exchange_rank(
cache: dict[str, Any],
exchange_key: str,
payload: dict[str, Any],
) -> dict[str, Any]:
ex_k = str(exchange_key or "").strip().lower()
if not ex_k or not payload.get("ok"):
return cache
exchanges = dict(cache.get("exchanges") or {})
exchanges[ex_k] = {
"rank_date": payload.get("rank_date"),
"items": payload.get("items") or [],
"total_symbols": int(payload.get("total_symbols") or 0),
"fetched_at": payload.get("fetched_at"),
"error": None,
}
out = dict(cache)
out["exchanges"] = exchanges
out["rank_date"] = payload.get("rank_date") or cache.get("rank_date")
return out
def _exchange_rank_row_stale(row: dict[str, Any] | None) -> bool:
if not row:
return True
items = row.get("items") or []
if len(items) < TOP_N_DEFAULT:
return True
total = int(row.get("total_symbols") or 0)
if total > 0 and total < TOP_N_DEFAULT:
return True
return False
def cache_needs_refresh(
cache: dict[str, Any],
*,
expected_rank_date: str | None = None,
required_keys: list[str] | None = None,
) -> bool:
expected = expected_rank_date or rank_date_label()
if int(cache.get("version") or 0) < CACHE_VERSION:
return True
exchanges = cache.get("exchanges") or {}
if not exchanges:
return True
if str(cache.get("rank_date") or "") != expected:
return True
keys = required_keys or list(exchanges.keys())
if not keys:
return True
for key in keys:
ex_k = str(key or "").strip().lower()
if not ex_k:
continue
if _exchange_rank_row_stale(exchanges.get(ex_k)):
return True
return False
def get_cached_rank(
cache: dict[str, Any],
exchange_key: str,
*,
top_n: int = TOP_N_DEFAULT,
) -> dict[str, Any]:
ex_k = str(exchange_key or "").strip().lower()
ex_data = (cache.get("exchanges") or {}).get(ex_k) or {}
items = list(ex_data.get("items") or [])[: max(1, int(top_n))]
stale = _exchange_rank_row_stale(ex_data)
return {
"ok": True,
"exchange_key": ex_k,
"rank_date": ex_data.get("rank_date") or cache.get("rank_date"),
"updated_at": cache.get("updated_at"),
"items": items,
"item_count": len(items),
"expected_count": int(top_n),
"total_symbols": int(ex_data.get("total_symbols") or 0),
"stale": stale,
"error": ex_data.get("error"),
}