Align margin display with CTP counter rates and position margin.

Read margin ratios from CTP instrument query and margin-rate API instead of vnpy ContractData (which lacks ratios). Keep occupied margin on position UseMargin; use per-lot max rate for recommend table.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-29 10:21:44 +08:00
parent 71c480a587
commit 19676943d0
5 changed files with 332 additions and 50 deletions
+287 -20
View File
@@ -8,6 +8,7 @@ from __future__ import annotations
import logging
import os
import re
import threading
import time
from collections import deque
@@ -237,6 +238,12 @@ class CtpBridge:
self._commission_waiters: dict[int, threading.Event] = {}
self._commission_lists: dict[int, list] = {}
self._commission_hooked = False
self._margin_rate_waiters: dict[int, threading.Event] = {}
self._margin_rate_lists: dict[int, list] = {}
self._margin_rate_hooked = False
self._instrument_hooked = False
self._instrument_margin_ratios: dict[str, dict[str, float]] = {}
self._margin_per_lot: dict[str, float] = {}
self._subscribed: set[str] = set()
self._last_position_query_ts: float = 0.0
self._position_margins: dict[str, float] = {}
@@ -305,6 +312,10 @@ class CtpBridge:
raw = float(getattr(pos, attr, 0) or 0)
if raw > 0:
self._position_margins[self._position_margin_key(sym, d)] = raw
if vol > 0:
self._margin_per_lot[self._position_margin_key(sym, d)] = round(
raw / vol, 2,
)
break
except Exception as exc:
logger.debug("position margin cache: %s", exc)
@@ -637,6 +648,7 @@ class CtpBridge:
"请更新 .env 中 SIMNOW_TD_ADDRESS 为官网最新地址,"
"并在服务器执行 nc -zv 验证出网。"
)
self._ensure_instrument_margin_hooks()
self._engine.connect(setting, GATEWAY_NAME)
if self._wait_connected(mode, ctp_logs):
self._connected_mode = mode
@@ -943,6 +955,184 @@ class CtpBridge:
"""批量查询全部合约手续费(InstrumentID 留空)。"""
return self._query_commission(mode=mode, timeout=45)
@staticmethod
def _parse_margin_ratio_row(data: dict) -> dict[str, float]:
long_r = float(
data.get("LongMarginRatioByMoney")
or data.get("LongMarginRatio")
or 0
)
short_r = float(
data.get("ShortMarginRatioByMoney")
or data.get("ShortMarginRatio")
or 0
)
return {"long": long_r, "short": short_r}
def _cache_margin_ratio(self, sym: str, data: dict) -> None:
ratios = self._parse_margin_ratio_row(data)
if ratios["long"] <= 0 and ratios["short"] <= 0:
return
key = (sym or "").strip().lower()
if not key:
return
self._instrument_margin_ratios[key] = ratios
def _ensure_instrument_margin_hooks(self) -> None:
"""登录前挂钩:合约查询回报缓存保证金率;支持按需 reqQryInstrumentMarginRate。"""
if not self._engine:
return
try:
gw = self._engine.get_gateway(GATEWAY_NAME)
td = gw.td_api
except Exception:
return
bridge = self
if not self._instrument_hooked:
orig = td.onRspQryInstrument
def on_instrument(data: dict, error: dict, reqid: int, last: bool) -> None:
try:
if data and data.get("InstrumentID"):
bridge._cache_margin_ratio(str(data["InstrumentID"]), data)
except Exception as exc:
logger.debug("instrument margin cache: %s", exc)
return orig(data, error, reqid, last)
td.onRspQryInstrument = on_instrument # type: ignore[method-assign]
self._instrument_hooked = True
if self._margin_rate_hooked:
return
def on_margin_rate(data: dict, error: dict, reqid: int, last: bool) -> None:
if error and int(error.get("ErrorID") or 0) != 0:
logger.debug(
"CTP margin rate error reqid=%s: %s",
reqid,
error.get("ErrorMsg") or error,
)
if data and data.get("InstrumentID"):
bridge._margin_rate_lists.setdefault(reqid, []).append(dict(data))
ev = bridge._margin_rate_waiters.get(reqid)
if last and ev:
ev.set()
td.onRspQryInstrumentMarginRate = on_margin_rate # type: ignore[method-assign]
self._margin_rate_hooked = True
def _query_instrument_margin_rate(
self,
*,
mode: str,
instrument_id: str,
exchange_id: str,
timeout: float = 6,
) -> Optional[dict[str, float]]:
if self._connected_mode != mode or not self._engine:
return None
sym = (instrument_id or "").strip()
if not sym:
return None
cached = self._instrument_margin_ratios.get(sym.lower())
if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0):
return cached
try:
gw = self._engine.get_gateway(GATEWAY_NAME)
td = gw.td_api
except Exception as exc:
logger.debug("margin rate query init: %s", exc)
return None
if not getattr(td, "login_status", False):
return None
if not hasattr(td, "reqQryInstrumentMarginRate"):
return None
self._ensure_instrument_margin_hooks()
reqid = int(getattr(td, "reqid", 0)) + 1
td.reqid = reqid
ev = threading.Event()
self._margin_rate_waiters[reqid] = ev
req = {
"BrokerID": td.brokerid,
"InvestorID": td.userid,
"InstrumentID": sym,
"ExchangeID": exchange_id or "",
"InvestorRange": "1",
"HedgeFlag": "1",
}
with _ctp_td_lock:
ret = td.reqQryInstrumentMarginRate(req, reqid)
if ret != 0:
self._margin_rate_waiters.pop(reqid, None)
return None
ev.wait(timeout=timeout)
self._margin_rate_waiters.pop(reqid, None)
rows = self._margin_rate_lists.pop(reqid, [])
if not rows:
return None
ratios = self._parse_margin_ratio_row(rows[-1])
if ratios["long"] > 0 or ratios["short"] > 0:
self._cache_margin_ratio(sym, rows[-1])
return ratios
return None
def _lookup_margin_ratios(
self,
sym: str,
ex_name: str,
*,
mode: Optional[str] = None,
) -> Optional[dict[str, float]]:
key = (sym or "").strip().lower()
if not key:
return None
cached = self._instrument_margin_ratios.get(key)
if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0):
return cached
if mode and self._connected_mode == mode:
return self._query_instrument_margin_rate(
mode=mode,
instrument_id=sym,
exchange_id=ex_name,
)
return None
def _lookup_margin_per_lot(self, sym: str, direction: str) -> float:
return float(
self._margin_per_lot.get(self._position_margin_key(sym, direction), 0) or 0
)
def _margin_from_ratios(
self,
price: float,
mult: float,
ratios: dict[str, float],
*,
direction: str,
) -> Optional[float]:
long_r = float(ratios.get("long") or 0)
short_r = float(ratios.get("short") or 0)
d = (direction or "long").strip().lower()
if mult <= 0 or price <= 0:
return None
if d == "max":
candidates = [
round(float(price) * mult * r, 2)
for r in (long_r, short_r)
if r > 0
]
return max(candidates) if candidates else None
if d == "short" and short_r > 0:
ratio = short_r
elif d != "short" and long_r > 0:
ratio = long_r
else:
ratio = max(long_r, short_r)
if ratio <= 0:
return None
return round(float(price) * mult * ratio, 2)
def _tick_key(self, symbol: str, ex_name: str) -> str:
return f"{symbol.lower()}:{ex_name.upper()}"
@@ -1240,6 +1430,50 @@ class CtpBridge:
return letters.upper() + (digits[-3:] if len(digits) >= 4 else digits)
return letters.lower() + digits
def _get_contract_for_ths(self, ths_code: str) -> Any:
"""按同花顺代码查 CTP 合约;精确匹配失败时在同交易所按品种前缀回退。"""
if not self._engine:
return None
try:
sym, ex_name = ths_to_vnpy_symbol(ths_code)
exchange = to_vnpy_exchange(ex_name)
vt_symbol = f"{sym}.{exchange.value}"
contract = self._engine.get_contract(vt_symbol)
if contract:
return contract
m = re.match(r"^([A-Za-z]+)", (ths_code or "").strip())
if not m:
return None
letters = m.group(1)
ex_val = exchange.value
candidates: list[Any] = []
get_all = getattr(self._engine, "get_all_contracts", None)
pool = list(get_all()) if callable(get_all) else []
if not pool:
raw = getattr(self._engine, "contracts", None)
if isinstance(raw, dict):
pool = list(raw.values())
sym_prefix = sym[: len(letters)] if sym else letters.lower()
sym_prefix_up = letters.upper()
for c in pool:
c_ex = getattr(c, "exchange", None)
c_ex_val = str(c_ex.value if hasattr(c_ex, "value") else c_ex or "")
if c_ex_val != ex_val:
continue
c_sym = str(getattr(c, "symbol", "") or "")
if (
c_sym.lower().startswith(sym_prefix.lower())
or c_sym.upper().startswith(sym_prefix_up)
):
candidates.append(c)
if not candidates:
return None
candidates.sort(key=lambda c: str(getattr(c, "symbol", "") or ""))
return candidates[0]
except Exception as exc:
logger.debug("_get_contract_for_ths %s: %s", ths_code, exc)
return None
def estimate_margin_one_lot(
self,
ths_code: str,
@@ -1247,29 +1481,35 @@ class CtpBridge:
*,
direction: str = "long",
) -> Optional[float]:
"""用 CTP 合约信息估算 1 手保证金(需已连接并完成合约查询"""
"""1 手保证金:持仓实收 > CTP 保证金率查询 > 合约查询缓存"""
if not self._engine or not price or price <= 0:
return None
try:
sym, ex_name = ths_to_vnpy_symbol(ths_code)
exchange = to_vnpy_exchange(ex_name)
vt_symbol = f"{sym}.{exchange.value}"
contract = self._engine.get_contract(vt_symbol)
if not contract:
return None
mult = float(getattr(contract, "size", 0) or 0)
long_r = float(getattr(contract, "long_margin_ratio", 0) or 0)
short_r = float(getattr(contract, "short_margin_ratio", 0) or 0)
contract = self._get_contract_for_ths(ths_code)
mult = float(getattr(contract, "size", 0) or 0) if contract else 0.0
if mult <= 0:
mult = float(get_contract_spec(ths_code).get("mult") or 0)
d = (direction or "long").strip().lower()
if d == "short" and short_r > 0:
ratio = short_r
elif d != "short" and long_r > 0:
ratio = long_r
if d == "max":
per_lots = [
self._lookup_margin_per_lot(sym, side)
for side in ("long", "short")
]
per_lots = [x for x in per_lots if x > 0]
if per_lots:
return max(per_lots)
else:
ratio = max(long_r, short_r)
if mult <= 0 or ratio <= 0:
return None
return round(float(price) * mult * ratio, 2)
per_lot = self._lookup_margin_per_lot(sym, d)
if per_lot > 0:
return per_lot
mode = self._connected_mode
ratios = self._lookup_margin_ratios(sym, ex_name, mode=mode)
if ratios:
return self._margin_from_ratios(
price, mult, ratios, direction=d,
)
return None
except Exception as exc:
logger.debug("estimate_margin_one_lot %s: %s", ths_code, exc)
return None
@@ -1308,9 +1548,7 @@ class CtpBridge:
return None
try:
sym, ex_name = ths_to_vnpy_symbol(ths_code)
exchange = to_vnpy_exchange(ex_name)
vt_symbol = f"{sym}.{exchange.value}"
contract = self._engine.get_contract(vt_symbol)
contract = self._get_contract_for_ths(ths_code)
if not contract:
return None
mult = float(getattr(contract, "size", 0) or 0)
@@ -1324,6 +1562,18 @@ class CtpBridge:
out: dict[str, Any] = {"mult": mult}
if tick > 0:
out["tick_size"] = tick
long_r = float(getattr(contract, "long_margin_ratio", 0) or 0)
short_r = float(getattr(contract, "short_margin_ratio", 0) or 0)
c_sym = str(getattr(contract, "symbol", "") or sym or "")
if c_sym and self._connected_mode:
queried = self._lookup_margin_ratios(
c_sym, ex_name, mode=self._connected_mode,
)
if queried:
long_r = float(queried.get("long") or long_r)
short_r = float(queried.get("short") or short_r)
if long_r > 0 or short_r > 0:
out["margin_rate"] = max(long_r, short_r)
return out
except Exception as exc:
logger.debug("lookup_contract_spec %s: %s", ths_code, exc)
@@ -1763,6 +2013,23 @@ def ctp_get_account(mode: str) -> dict[str, Any]:
return b.get_account()
def ctp_sum_position_margins(
mode: str,
*,
refresh_if_empty: bool = True,
refresh_margin: bool = False,
) -> float:
"""各持仓 CTP 回报保证金之和(与柜台「实收保证金」一致)。"""
total = 0.0
for p in ctp_list_positions(
mode, refresh_if_empty=refresh_if_empty, refresh_margin=refresh_margin,
):
m = float(p.get("margin") or 0)
if m > 0:
total += m
return round(total, 2) if total > 0 else 0.0
def ctp_account_margin_used(mode: str) -> Optional[float]:
"""账户实际占用保证金 ≈ 权益 − 可用(与顶栏柜台资金一致)。"""
b = get_bridge()