Align margin display with CTP counter rates and position margin.

Read margin ratios from CTP instrument query and margin-rate API instead of vnpy ContractData (which lacks ratios). Keep occupied margin on position UseMargin; use per-lot max rate for recommend table.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-29 10:21:44 +08:00
parent 71c480a587
commit 19676943d0
5 changed files with 332 additions and 50 deletions
+3
View File
@@ -92,6 +92,7 @@ def margin_one_lot(
) -> tuple[float, str, dict]:
"""1 手保证金。CTP 已连接时优先读柜台合约保证金率,否则用本地参考规格估算。
direction 可为 long / short / max(多空费率取较大值,用于可开仓品种表)。
返回 (保证金, 来源 estimate|ctp, 合约规格片段)。
"""
spec = get_contract_spec(ths_code)
@@ -113,6 +114,8 @@ def margin_one_lot(
merged["mult"] = ctp_spec["mult"]
if ctp_spec.get("tick_size"):
merged["tick_size"] = ctp_spec["tick_size"]
if ctp_spec.get("margin_rate"):
merged["margin_rate"] = ctp_spec["margin_rate"]
return float(ctp_margin), "ctp", merged
except Exception:
pass
+32 -17
View File
@@ -287,27 +287,42 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
mode: str,
capital: float,
) -> list[dict]:
"""CTP 已连接时,用「权益−可用」校正占用保证金与仓位占比"""
"""仅在持仓缺少柜台保证金时补全;已有 CTP 持仓保证金的行不覆盖"""
if not ctp_status(mode).get("connected"):
return rows
total_used = ctp_account_margin_used(mode)
if not total_used:
return rows
active = [
r for r in rows
if r.get("order_state") != "pending" and int(r.get("lots") or 0) > 0
]
if not active:
return rows
if len(active) == 1:
row = active[0]
row["margin"] = total_used
row["margin_source"] = "ctp"
if capital > 0:
row["position_pct"] = round(total_used / capital * 100, 2)
return rows
weights: list[float] = []
def _has_ctp_margin(row: dict) -> bool:
return (
float(row.get("margin") or 0) > 0
and row.get("margin_source") == "ctp"
)
without_margin = [r for r in active if not _has_ctp_margin(r)]
for row in active:
if _has_ctp_margin(row) and capital > 0:
m = float(row.get("margin") or 0)
row["position_pct"] = round(m / capital * 100, 2)
if not without_margin:
return rows
total_used = ctp_account_margin_used(mode)
if not total_used:
return rows
known_sum = sum(
float(r.get("margin") or 0) for r in active if _has_ctp_margin(r)
)
pool = max(0.0, float(total_used) - known_sum) if known_sum > 0 else float(total_used)
if pool <= 0:
return rows
weights: list[float] = []
for row in without_margin:
sym = (row.get("symbol_code") or "").strip()
lots = int(row.get("lots") or 0)
entry = float(row.get("entry_price") or 0)
@@ -318,13 +333,13 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
weights.append(0.0)
total_weight = sum(weights)
assigned = 0.0
for i, row in enumerate(active):
for i, row in enumerate(without_margin):
if total_weight <= 0:
margin = round(total_used / len(active), 2)
elif i == len(active) - 1:
margin = round(total_used - assigned, 2)
margin = round(pool / len(without_margin), 2)
elif i == len(without_margin) - 1:
margin = round(pool - assigned, 2)
else:
margin = round(total_used * weights[i] / total_weight, 2)
margin = round(pool * weights[i] / total_weight, 2)
assigned += margin
row["margin"] = margin
row["margin_source"] = "ctp"
+1 -1
View File
@@ -216,7 +216,7 @@ def assess_product_for_capital(
code_for_margin = (main_code or "").strip() or (ths + "8888")
if p > 0 and ctp_connected:
margin_one, margin_source, spec_used = margin_one_lot(
code_for_margin, p, trading_mode=trading_mode,
code_for_margin, p, direction="max", trading_mode=trading_mode,
)
if spec_used.get("mult"):
mult = spec_used["mult"]
+9 -12
View File
@@ -118,23 +118,19 @@ def _ctp_connected_for_mode(trading_mode: str) -> bool:
def recommend_margin_used(trading_mode: str) -> float:
"""当前持仓已占用保证金(CTP 柜台优先)。"""
"""当前持仓已占用保证金(各持仓 CTP 回报之和,与柜台持仓保证金一致)。"""
if not _ctp_connected_for_mode(trading_mode):
return 0.0
try:
from vnpy_bridge import ctp_account_margin_used, ctp_list_positions
from vnpy_bridge import ctp_account_margin_used, ctp_sum_position_margins
used = ctp_account_margin_used(trading_mode)
if used is not None and used > 0:
return float(used)
total = 0.0
for p in ctp_list_positions(
total = ctp_sum_position_margins(
trading_mode, refresh_if_empty=False, refresh_margin=True,
):
m = float(p.get("margin") or 0)
if m > 0:
total += m
return round(total, 2) if total > 0 else 0.0
)
if total > 0:
return total
used = ctp_account_margin_used(trading_mode)
return float(used) if used and used > 0 else 0.0
except Exception as exc:
logger.debug("recommend_margin_used: %s", exc)
return 0.0
@@ -196,6 +192,7 @@ def enrich_recommend_rows(
margin_one, margin_source, spec_used = margin_one_lot(
code_for_margin,
price,
direction="max",
trading_mode=trading_mode if ctp_connected else None,
)
if spec_used.get("mult"):
+286 -19
View File
@@ -8,6 +8,7 @@ from __future__ import annotations
import logging
import os
import re
import threading
import time
from collections import deque
@@ -237,6 +238,12 @@ class CtpBridge:
self._commission_waiters: dict[int, threading.Event] = {}
self._commission_lists: dict[int, list] = {}
self._commission_hooked = False
self._margin_rate_waiters: dict[int, threading.Event] = {}
self._margin_rate_lists: dict[int, list] = {}
self._margin_rate_hooked = False
self._instrument_hooked = False
self._instrument_margin_ratios: dict[str, dict[str, float]] = {}
self._margin_per_lot: dict[str, float] = {}
self._subscribed: set[str] = set()
self._last_position_query_ts: float = 0.0
self._position_margins: dict[str, float] = {}
@@ -305,6 +312,10 @@ class CtpBridge:
raw = float(getattr(pos, attr, 0) or 0)
if raw > 0:
self._position_margins[self._position_margin_key(sym, d)] = raw
if vol > 0:
self._margin_per_lot[self._position_margin_key(sym, d)] = round(
raw / vol, 2,
)
break
except Exception as exc:
logger.debug("position margin cache: %s", exc)
@@ -637,6 +648,7 @@ class CtpBridge:
"请更新 .env 中 SIMNOW_TD_ADDRESS 为官网最新地址,"
"并在服务器执行 nc -zv 验证出网。"
)
self._ensure_instrument_margin_hooks()
self._engine.connect(setting, GATEWAY_NAME)
if self._wait_connected(mode, ctp_logs):
self._connected_mode = mode
@@ -943,6 +955,184 @@ class CtpBridge:
"""批量查询全部合约手续费(InstrumentID 留空)。"""
return self._query_commission(mode=mode, timeout=45)
@staticmethod
def _parse_margin_ratio_row(data: dict) -> dict[str, float]:
long_r = float(
data.get("LongMarginRatioByMoney")
or data.get("LongMarginRatio")
or 0
)
short_r = float(
data.get("ShortMarginRatioByMoney")
or data.get("ShortMarginRatio")
or 0
)
return {"long": long_r, "short": short_r}
def _cache_margin_ratio(self, sym: str, data: dict) -> None:
ratios = self._parse_margin_ratio_row(data)
if ratios["long"] <= 0 and ratios["short"] <= 0:
return
key = (sym or "").strip().lower()
if not key:
return
self._instrument_margin_ratios[key] = ratios
def _ensure_instrument_margin_hooks(self) -> None:
"""登录前挂钩:合约查询回报缓存保证金率;支持按需 reqQryInstrumentMarginRate。"""
if not self._engine:
return
try:
gw = self._engine.get_gateway(GATEWAY_NAME)
td = gw.td_api
except Exception:
return
bridge = self
if not self._instrument_hooked:
orig = td.onRspQryInstrument
def on_instrument(data: dict, error: dict, reqid: int, last: bool) -> None:
try:
if data and data.get("InstrumentID"):
bridge._cache_margin_ratio(str(data["InstrumentID"]), data)
except Exception as exc:
logger.debug("instrument margin cache: %s", exc)
return orig(data, error, reqid, last)
td.onRspQryInstrument = on_instrument # type: ignore[method-assign]
self._instrument_hooked = True
if self._margin_rate_hooked:
return
def on_margin_rate(data: dict, error: dict, reqid: int, last: bool) -> None:
if error and int(error.get("ErrorID") or 0) != 0:
logger.debug(
"CTP margin rate error reqid=%s: %s",
reqid,
error.get("ErrorMsg") or error,
)
if data and data.get("InstrumentID"):
bridge._margin_rate_lists.setdefault(reqid, []).append(dict(data))
ev = bridge._margin_rate_waiters.get(reqid)
if last and ev:
ev.set()
td.onRspQryInstrumentMarginRate = on_margin_rate # type: ignore[method-assign]
self._margin_rate_hooked = True
def _query_instrument_margin_rate(
self,
*,
mode: str,
instrument_id: str,
exchange_id: str,
timeout: float = 6,
) -> Optional[dict[str, float]]:
if self._connected_mode != mode or not self._engine:
return None
sym = (instrument_id or "").strip()
if not sym:
return None
cached = self._instrument_margin_ratios.get(sym.lower())
if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0):
return cached
try:
gw = self._engine.get_gateway(GATEWAY_NAME)
td = gw.td_api
except Exception as exc:
logger.debug("margin rate query init: %s", exc)
return None
if not getattr(td, "login_status", False):
return None
if not hasattr(td, "reqQryInstrumentMarginRate"):
return None
self._ensure_instrument_margin_hooks()
reqid = int(getattr(td, "reqid", 0)) + 1
td.reqid = reqid
ev = threading.Event()
self._margin_rate_waiters[reqid] = ev
req = {
"BrokerID": td.brokerid,
"InvestorID": td.userid,
"InstrumentID": sym,
"ExchangeID": exchange_id or "",
"InvestorRange": "1",
"HedgeFlag": "1",
}
with _ctp_td_lock:
ret = td.reqQryInstrumentMarginRate(req, reqid)
if ret != 0:
self._margin_rate_waiters.pop(reqid, None)
return None
ev.wait(timeout=timeout)
self._margin_rate_waiters.pop(reqid, None)
rows = self._margin_rate_lists.pop(reqid, [])
if not rows:
return None
ratios = self._parse_margin_ratio_row(rows[-1])
if ratios["long"] > 0 or ratios["short"] > 0:
self._cache_margin_ratio(sym, rows[-1])
return ratios
return None
def _lookup_margin_ratios(
self,
sym: str,
ex_name: str,
*,
mode: Optional[str] = None,
) -> Optional[dict[str, float]]:
key = (sym or "").strip().lower()
if not key:
return None
cached = self._instrument_margin_ratios.get(key)
if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0):
return cached
if mode and self._connected_mode == mode:
return self._query_instrument_margin_rate(
mode=mode,
instrument_id=sym,
exchange_id=ex_name,
)
return None
def _lookup_margin_per_lot(self, sym: str, direction: str) -> float:
return float(
self._margin_per_lot.get(self._position_margin_key(sym, direction), 0) or 0
)
def _margin_from_ratios(
self,
price: float,
mult: float,
ratios: dict[str, float],
*,
direction: str,
) -> Optional[float]:
long_r = float(ratios.get("long") or 0)
short_r = float(ratios.get("short") or 0)
d = (direction or "long").strip().lower()
if mult <= 0 or price <= 0:
return None
if d == "max":
candidates = [
round(float(price) * mult * r, 2)
for r in (long_r, short_r)
if r > 0
]
return max(candidates) if candidates else None
if d == "short" and short_r > 0:
ratio = short_r
elif d != "short" and long_r > 0:
ratio = long_r
else:
ratio = max(long_r, short_r)
if ratio <= 0:
return None
return round(float(price) * mult * ratio, 2)
def _tick_key(self, symbol: str, ex_name: str) -> str:
return f"{symbol.lower()}:{ex_name.upper()}"
@@ -1240,6 +1430,50 @@ class CtpBridge:
return letters.upper() + (digits[-3:] if len(digits) >= 4 else digits)
return letters.lower() + digits
def _get_contract_for_ths(self, ths_code: str) -> Any:
"""按同花顺代码查 CTP 合约;精确匹配失败时在同交易所按品种前缀回退。"""
if not self._engine:
return None
try:
sym, ex_name = ths_to_vnpy_symbol(ths_code)
exchange = to_vnpy_exchange(ex_name)
vt_symbol = f"{sym}.{exchange.value}"
contract = self._engine.get_contract(vt_symbol)
if contract:
return contract
m = re.match(r"^([A-Za-z]+)", (ths_code or "").strip())
if not m:
return None
letters = m.group(1)
ex_val = exchange.value
candidates: list[Any] = []
get_all = getattr(self._engine, "get_all_contracts", None)
pool = list(get_all()) if callable(get_all) else []
if not pool:
raw = getattr(self._engine, "contracts", None)
if isinstance(raw, dict):
pool = list(raw.values())
sym_prefix = sym[: len(letters)] if sym else letters.lower()
sym_prefix_up = letters.upper()
for c in pool:
c_ex = getattr(c, "exchange", None)
c_ex_val = str(c_ex.value if hasattr(c_ex, "value") else c_ex or "")
if c_ex_val != ex_val:
continue
c_sym = str(getattr(c, "symbol", "") or "")
if (
c_sym.lower().startswith(sym_prefix.lower())
or c_sym.upper().startswith(sym_prefix_up)
):
candidates.append(c)
if not candidates:
return None
candidates.sort(key=lambda c: str(getattr(c, "symbol", "") or ""))
return candidates[0]
except Exception as exc:
logger.debug("_get_contract_for_ths %s: %s", ths_code, exc)
return None
def estimate_margin_one_lot(
self,
ths_code: str,
@@ -1247,29 +1481,35 @@ class CtpBridge:
*,
direction: str = "long",
) -> Optional[float]:
"""用 CTP 合约信息估算 1 手保证金(需已连接并完成合约查询"""
"""1 手保证金:持仓实收 > CTP 保证金率查询 > 合约查询缓存"""
if not self._engine or not price or price <= 0:
return None
try:
sym, ex_name = ths_to_vnpy_symbol(ths_code)
exchange = to_vnpy_exchange(ex_name)
vt_symbol = f"{sym}.{exchange.value}"
contract = self._engine.get_contract(vt_symbol)
if not contract:
return None
mult = float(getattr(contract, "size", 0) or 0)
long_r = float(getattr(contract, "long_margin_ratio", 0) or 0)
short_r = float(getattr(contract, "short_margin_ratio", 0) or 0)
contract = self._get_contract_for_ths(ths_code)
mult = float(getattr(contract, "size", 0) or 0) if contract else 0.0
if mult <= 0:
mult = float(get_contract_spec(ths_code).get("mult") or 0)
d = (direction or "long").strip().lower()
if d == "short" and short_r > 0:
ratio = short_r
elif d != "short" and long_r > 0:
ratio = long_r
if d == "max":
per_lots = [
self._lookup_margin_per_lot(sym, side)
for side in ("long", "short")
]
per_lots = [x for x in per_lots if x > 0]
if per_lots:
return max(per_lots)
else:
ratio = max(long_r, short_r)
if mult <= 0 or ratio <= 0:
per_lot = self._lookup_margin_per_lot(sym, d)
if per_lot > 0:
return per_lot
mode = self._connected_mode
ratios = self._lookup_margin_ratios(sym, ex_name, mode=mode)
if ratios:
return self._margin_from_ratios(
price, mult, ratios, direction=d,
)
return None
return round(float(price) * mult * ratio, 2)
except Exception as exc:
logger.debug("estimate_margin_one_lot %s: %s", ths_code, exc)
return None
@@ -1308,9 +1548,7 @@ class CtpBridge:
return None
try:
sym, ex_name = ths_to_vnpy_symbol(ths_code)
exchange = to_vnpy_exchange(ex_name)
vt_symbol = f"{sym}.{exchange.value}"
contract = self._engine.get_contract(vt_symbol)
contract = self._get_contract_for_ths(ths_code)
if not contract:
return None
mult = float(getattr(contract, "size", 0) or 0)
@@ -1324,6 +1562,18 @@ class CtpBridge:
out: dict[str, Any] = {"mult": mult}
if tick > 0:
out["tick_size"] = tick
long_r = float(getattr(contract, "long_margin_ratio", 0) or 0)
short_r = float(getattr(contract, "short_margin_ratio", 0) or 0)
c_sym = str(getattr(contract, "symbol", "") or sym or "")
if c_sym and self._connected_mode:
queried = self._lookup_margin_ratios(
c_sym, ex_name, mode=self._connected_mode,
)
if queried:
long_r = float(queried.get("long") or long_r)
short_r = float(queried.get("short") or short_r)
if long_r > 0 or short_r > 0:
out["margin_rate"] = max(long_r, short_r)
return out
except Exception as exc:
logger.debug("lookup_contract_spec %s: %s", ths_code, exc)
@@ -1763,6 +2013,23 @@ def ctp_get_account(mode: str) -> dict[str, Any]:
return b.get_account()
def ctp_sum_position_margins(
mode: str,
*,
refresh_if_empty: bool = True,
refresh_margin: bool = False,
) -> float:
"""各持仓 CTP 回报保证金之和(与柜台「实收保证金」一致)。"""
total = 0.0
for p in ctp_list_positions(
mode, refresh_if_empty=refresh_if_empty, refresh_margin=refresh_margin,
):
m = float(p.get("margin") or 0)
if m > 0:
total += m
return round(total, 2) if total > 0 else 0.0
def ctp_account_margin_used(mode: str) -> Optional[float]:
"""账户实际占用保证金 ≈ 权益 − 可用(与顶栏柜台资金一致)。"""
b = get_bridge()