Align margin display with CTP counter rates and position margin.
Read margin ratios from CTP instrument query and margin-rate API instead of vnpy ContractData (which lacks ratios). Keep occupied margin on position UseMargin; use per-lot max rate for recommend table. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -92,6 +92,7 @@ def margin_one_lot(
|
||||
) -> tuple[float, str, dict]:
|
||||
"""1 手保证金。CTP 已连接时优先读柜台合约保证金率,否则用本地参考规格估算。
|
||||
|
||||
direction 可为 long / short / max(多空费率取较大值,用于可开仓品种表)。
|
||||
返回 (保证金, 来源 estimate|ctp, 合约规格片段)。
|
||||
"""
|
||||
spec = get_contract_spec(ths_code)
|
||||
@@ -113,6 +114,8 @@ def margin_one_lot(
|
||||
merged["mult"] = ctp_spec["mult"]
|
||||
if ctp_spec.get("tick_size"):
|
||||
merged["tick_size"] = ctp_spec["tick_size"]
|
||||
if ctp_spec.get("margin_rate"):
|
||||
merged["margin_rate"] = ctp_spec["margin_rate"]
|
||||
return float(ctp_margin), "ctp", merged
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
+32
-17
@@ -287,27 +287,42 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
|
||||
mode: str,
|
||||
capital: float,
|
||||
) -> list[dict]:
|
||||
"""CTP 已连接时,用「权益−可用」校正占用保证金与仓位占比。"""
|
||||
"""仅在持仓缺少柜台保证金时补全;已有 CTP 持仓保证金的行不覆盖。"""
|
||||
if not ctp_status(mode).get("connected"):
|
||||
return rows
|
||||
total_used = ctp_account_margin_used(mode)
|
||||
if not total_used:
|
||||
return rows
|
||||
active = [
|
||||
r for r in rows
|
||||
if r.get("order_state") != "pending" and int(r.get("lots") or 0) > 0
|
||||
]
|
||||
if not active:
|
||||
return rows
|
||||
if len(active) == 1:
|
||||
row = active[0]
|
||||
row["margin"] = total_used
|
||||
row["margin_source"] = "ctp"
|
||||
if capital > 0:
|
||||
row["position_pct"] = round(total_used / capital * 100, 2)
|
||||
return rows
|
||||
weights: list[float] = []
|
||||
|
||||
def _has_ctp_margin(row: dict) -> bool:
|
||||
return (
|
||||
float(row.get("margin") or 0) > 0
|
||||
and row.get("margin_source") == "ctp"
|
||||
)
|
||||
|
||||
without_margin = [r for r in active if not _has_ctp_margin(r)]
|
||||
for row in active:
|
||||
if _has_ctp_margin(row) and capital > 0:
|
||||
m = float(row.get("margin") or 0)
|
||||
row["position_pct"] = round(m / capital * 100, 2)
|
||||
if not without_margin:
|
||||
return rows
|
||||
|
||||
total_used = ctp_account_margin_used(mode)
|
||||
if not total_used:
|
||||
return rows
|
||||
known_sum = sum(
|
||||
float(r.get("margin") or 0) for r in active if _has_ctp_margin(r)
|
||||
)
|
||||
pool = max(0.0, float(total_used) - known_sum) if known_sum > 0 else float(total_used)
|
||||
if pool <= 0:
|
||||
return rows
|
||||
|
||||
weights: list[float] = []
|
||||
for row in without_margin:
|
||||
sym = (row.get("symbol_code") or "").strip()
|
||||
lots = int(row.get("lots") or 0)
|
||||
entry = float(row.get("entry_price") or 0)
|
||||
@@ -318,13 +333,13 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
|
||||
weights.append(0.0)
|
||||
total_weight = sum(weights)
|
||||
assigned = 0.0
|
||||
for i, row in enumerate(active):
|
||||
for i, row in enumerate(without_margin):
|
||||
if total_weight <= 0:
|
||||
margin = round(total_used / len(active), 2)
|
||||
elif i == len(active) - 1:
|
||||
margin = round(total_used - assigned, 2)
|
||||
margin = round(pool / len(without_margin), 2)
|
||||
elif i == len(without_margin) - 1:
|
||||
margin = round(pool - assigned, 2)
|
||||
else:
|
||||
margin = round(total_used * weights[i] / total_weight, 2)
|
||||
margin = round(pool * weights[i] / total_weight, 2)
|
||||
assigned += margin
|
||||
row["margin"] = margin
|
||||
row["margin_source"] = "ctp"
|
||||
|
||||
@@ -216,7 +216,7 @@ def assess_product_for_capital(
|
||||
code_for_margin = (main_code or "").strip() or (ths + "8888")
|
||||
if p > 0 and ctp_connected:
|
||||
margin_one, margin_source, spec_used = margin_one_lot(
|
||||
code_for_margin, p, trading_mode=trading_mode,
|
||||
code_for_margin, p, direction="max", trading_mode=trading_mode,
|
||||
)
|
||||
if spec_used.get("mult"):
|
||||
mult = spec_used["mult"]
|
||||
|
||||
+9
-12
@@ -118,23 +118,19 @@ def _ctp_connected_for_mode(trading_mode: str) -> bool:
|
||||
|
||||
|
||||
def recommend_margin_used(trading_mode: str) -> float:
|
||||
"""当前持仓已占用保证金(CTP 柜台优先)。"""
|
||||
"""当前持仓已占用保证金(各持仓 CTP 回报之和,与柜台持仓保证金一致)。"""
|
||||
if not _ctp_connected_for_mode(trading_mode):
|
||||
return 0.0
|
||||
try:
|
||||
from vnpy_bridge import ctp_account_margin_used, ctp_list_positions
|
||||
from vnpy_bridge import ctp_account_margin_used, ctp_sum_position_margins
|
||||
|
||||
used = ctp_account_margin_used(trading_mode)
|
||||
if used is not None and used > 0:
|
||||
return float(used)
|
||||
total = 0.0
|
||||
for p in ctp_list_positions(
|
||||
total = ctp_sum_position_margins(
|
||||
trading_mode, refresh_if_empty=False, refresh_margin=True,
|
||||
):
|
||||
m = float(p.get("margin") or 0)
|
||||
if m > 0:
|
||||
total += m
|
||||
return round(total, 2) if total > 0 else 0.0
|
||||
)
|
||||
if total > 0:
|
||||
return total
|
||||
used = ctp_account_margin_used(trading_mode)
|
||||
return float(used) if used and used > 0 else 0.0
|
||||
except Exception as exc:
|
||||
logger.debug("recommend_margin_used: %s", exc)
|
||||
return 0.0
|
||||
@@ -196,6 +192,7 @@ def enrich_recommend_rows(
|
||||
margin_one, margin_source, spec_used = margin_one_lot(
|
||||
code_for_margin,
|
||||
price,
|
||||
direction="max",
|
||||
trading_mode=trading_mode if ctp_connected else None,
|
||||
)
|
||||
if spec_used.get("mult"):
|
||||
|
||||
+287
-20
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
@@ -237,6 +238,12 @@ class CtpBridge:
|
||||
self._commission_waiters: dict[int, threading.Event] = {}
|
||||
self._commission_lists: dict[int, list] = {}
|
||||
self._commission_hooked = False
|
||||
self._margin_rate_waiters: dict[int, threading.Event] = {}
|
||||
self._margin_rate_lists: dict[int, list] = {}
|
||||
self._margin_rate_hooked = False
|
||||
self._instrument_hooked = False
|
||||
self._instrument_margin_ratios: dict[str, dict[str, float]] = {}
|
||||
self._margin_per_lot: dict[str, float] = {}
|
||||
self._subscribed: set[str] = set()
|
||||
self._last_position_query_ts: float = 0.0
|
||||
self._position_margins: dict[str, float] = {}
|
||||
@@ -305,6 +312,10 @@ class CtpBridge:
|
||||
raw = float(getattr(pos, attr, 0) or 0)
|
||||
if raw > 0:
|
||||
self._position_margins[self._position_margin_key(sym, d)] = raw
|
||||
if vol > 0:
|
||||
self._margin_per_lot[self._position_margin_key(sym, d)] = round(
|
||||
raw / vol, 2,
|
||||
)
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.debug("position margin cache: %s", exc)
|
||||
@@ -637,6 +648,7 @@ class CtpBridge:
|
||||
"请更新 .env 中 SIMNOW_TD_ADDRESS 为官网最新地址,"
|
||||
"并在服务器执行 nc -zv 验证出网。"
|
||||
)
|
||||
self._ensure_instrument_margin_hooks()
|
||||
self._engine.connect(setting, GATEWAY_NAME)
|
||||
if self._wait_connected(mode, ctp_logs):
|
||||
self._connected_mode = mode
|
||||
@@ -943,6 +955,184 @@ class CtpBridge:
|
||||
"""批量查询全部合约手续费(InstrumentID 留空)。"""
|
||||
return self._query_commission(mode=mode, timeout=45)
|
||||
|
||||
@staticmethod
|
||||
def _parse_margin_ratio_row(data: dict) -> dict[str, float]:
|
||||
long_r = float(
|
||||
data.get("LongMarginRatioByMoney")
|
||||
or data.get("LongMarginRatio")
|
||||
or 0
|
||||
)
|
||||
short_r = float(
|
||||
data.get("ShortMarginRatioByMoney")
|
||||
or data.get("ShortMarginRatio")
|
||||
or 0
|
||||
)
|
||||
return {"long": long_r, "short": short_r}
|
||||
|
||||
def _cache_margin_ratio(self, sym: str, data: dict) -> None:
|
||||
ratios = self._parse_margin_ratio_row(data)
|
||||
if ratios["long"] <= 0 and ratios["short"] <= 0:
|
||||
return
|
||||
key = (sym or "").strip().lower()
|
||||
if not key:
|
||||
return
|
||||
self._instrument_margin_ratios[key] = ratios
|
||||
|
||||
def _ensure_instrument_margin_hooks(self) -> None:
|
||||
"""登录前挂钩:合约查询回报缓存保证金率;支持按需 reqQryInstrumentMarginRate。"""
|
||||
if not self._engine:
|
||||
return
|
||||
try:
|
||||
gw = self._engine.get_gateway(GATEWAY_NAME)
|
||||
td = gw.td_api
|
||||
except Exception:
|
||||
return
|
||||
bridge = self
|
||||
|
||||
if not self._instrument_hooked:
|
||||
orig = td.onRspQryInstrument
|
||||
|
||||
def on_instrument(data: dict, error: dict, reqid: int, last: bool) -> None:
|
||||
try:
|
||||
if data and data.get("InstrumentID"):
|
||||
bridge._cache_margin_ratio(str(data["InstrumentID"]), data)
|
||||
except Exception as exc:
|
||||
logger.debug("instrument margin cache: %s", exc)
|
||||
return orig(data, error, reqid, last)
|
||||
|
||||
td.onRspQryInstrument = on_instrument # type: ignore[method-assign]
|
||||
self._instrument_hooked = True
|
||||
|
||||
if self._margin_rate_hooked:
|
||||
return
|
||||
|
||||
def on_margin_rate(data: dict, error: dict, reqid: int, last: bool) -> None:
|
||||
if error and int(error.get("ErrorID") or 0) != 0:
|
||||
logger.debug(
|
||||
"CTP margin rate error reqid=%s: %s",
|
||||
reqid,
|
||||
error.get("ErrorMsg") or error,
|
||||
)
|
||||
if data and data.get("InstrumentID"):
|
||||
bridge._margin_rate_lists.setdefault(reqid, []).append(dict(data))
|
||||
ev = bridge._margin_rate_waiters.get(reqid)
|
||||
if last and ev:
|
||||
ev.set()
|
||||
|
||||
td.onRspQryInstrumentMarginRate = on_margin_rate # type: ignore[method-assign]
|
||||
self._margin_rate_hooked = True
|
||||
|
||||
def _query_instrument_margin_rate(
|
||||
self,
|
||||
*,
|
||||
mode: str,
|
||||
instrument_id: str,
|
||||
exchange_id: str,
|
||||
timeout: float = 6,
|
||||
) -> Optional[dict[str, float]]:
|
||||
if self._connected_mode != mode or not self._engine:
|
||||
return None
|
||||
sym = (instrument_id or "").strip()
|
||||
if not sym:
|
||||
return None
|
||||
cached = self._instrument_margin_ratios.get(sym.lower())
|
||||
if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0):
|
||||
return cached
|
||||
try:
|
||||
gw = self._engine.get_gateway(GATEWAY_NAME)
|
||||
td = gw.td_api
|
||||
except Exception as exc:
|
||||
logger.debug("margin rate query init: %s", exc)
|
||||
return None
|
||||
if not getattr(td, "login_status", False):
|
||||
return None
|
||||
if not hasattr(td, "reqQryInstrumentMarginRate"):
|
||||
return None
|
||||
self._ensure_instrument_margin_hooks()
|
||||
reqid = int(getattr(td, "reqid", 0)) + 1
|
||||
td.reqid = reqid
|
||||
ev = threading.Event()
|
||||
self._margin_rate_waiters[reqid] = ev
|
||||
req = {
|
||||
"BrokerID": td.brokerid,
|
||||
"InvestorID": td.userid,
|
||||
"InstrumentID": sym,
|
||||
"ExchangeID": exchange_id or "",
|
||||
"InvestorRange": "1",
|
||||
"HedgeFlag": "1",
|
||||
}
|
||||
with _ctp_td_lock:
|
||||
ret = td.reqQryInstrumentMarginRate(req, reqid)
|
||||
if ret != 0:
|
||||
self._margin_rate_waiters.pop(reqid, None)
|
||||
return None
|
||||
ev.wait(timeout=timeout)
|
||||
self._margin_rate_waiters.pop(reqid, None)
|
||||
rows = self._margin_rate_lists.pop(reqid, [])
|
||||
if not rows:
|
||||
return None
|
||||
ratios = self._parse_margin_ratio_row(rows[-1])
|
||||
if ratios["long"] > 0 or ratios["short"] > 0:
|
||||
self._cache_margin_ratio(sym, rows[-1])
|
||||
return ratios
|
||||
return None
|
||||
|
||||
def _lookup_margin_ratios(
|
||||
self,
|
||||
sym: str,
|
||||
ex_name: str,
|
||||
*,
|
||||
mode: Optional[str] = None,
|
||||
) -> Optional[dict[str, float]]:
|
||||
key = (sym or "").strip().lower()
|
||||
if not key:
|
||||
return None
|
||||
cached = self._instrument_margin_ratios.get(key)
|
||||
if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0):
|
||||
return cached
|
||||
if mode and self._connected_mode == mode:
|
||||
return self._query_instrument_margin_rate(
|
||||
mode=mode,
|
||||
instrument_id=sym,
|
||||
exchange_id=ex_name,
|
||||
)
|
||||
return None
|
||||
|
||||
def _lookup_margin_per_lot(self, sym: str, direction: str) -> float:
|
||||
return float(
|
||||
self._margin_per_lot.get(self._position_margin_key(sym, direction), 0) or 0
|
||||
)
|
||||
|
||||
def _margin_from_ratios(
|
||||
self,
|
||||
price: float,
|
||||
mult: float,
|
||||
ratios: dict[str, float],
|
||||
*,
|
||||
direction: str,
|
||||
) -> Optional[float]:
|
||||
long_r = float(ratios.get("long") or 0)
|
||||
short_r = float(ratios.get("short") or 0)
|
||||
d = (direction or "long").strip().lower()
|
||||
if mult <= 0 or price <= 0:
|
||||
return None
|
||||
if d == "max":
|
||||
candidates = [
|
||||
round(float(price) * mult * r, 2)
|
||||
for r in (long_r, short_r)
|
||||
if r > 0
|
||||
]
|
||||
return max(candidates) if candidates else None
|
||||
if d == "short" and short_r > 0:
|
||||
ratio = short_r
|
||||
elif d != "short" and long_r > 0:
|
||||
ratio = long_r
|
||||
else:
|
||||
ratio = max(long_r, short_r)
|
||||
if ratio <= 0:
|
||||
return None
|
||||
return round(float(price) * mult * ratio, 2)
|
||||
|
||||
def _tick_key(self, symbol: str, ex_name: str) -> str:
|
||||
return f"{symbol.lower()}:{ex_name.upper()}"
|
||||
|
||||
@@ -1240,6 +1430,50 @@ class CtpBridge:
|
||||
return letters.upper() + (digits[-3:] if len(digits) >= 4 else digits)
|
||||
return letters.lower() + digits
|
||||
|
||||
def _get_contract_for_ths(self, ths_code: str) -> Any:
|
||||
"""按同花顺代码查 CTP 合约;精确匹配失败时在同交易所按品种前缀回退。"""
|
||||
if not self._engine:
|
||||
return None
|
||||
try:
|
||||
sym, ex_name = ths_to_vnpy_symbol(ths_code)
|
||||
exchange = to_vnpy_exchange(ex_name)
|
||||
vt_symbol = f"{sym}.{exchange.value}"
|
||||
contract = self._engine.get_contract(vt_symbol)
|
||||
if contract:
|
||||
return contract
|
||||
m = re.match(r"^([A-Za-z]+)", (ths_code or "").strip())
|
||||
if not m:
|
||||
return None
|
||||
letters = m.group(1)
|
||||
ex_val = exchange.value
|
||||
candidates: list[Any] = []
|
||||
get_all = getattr(self._engine, "get_all_contracts", None)
|
||||
pool = list(get_all()) if callable(get_all) else []
|
||||
if not pool:
|
||||
raw = getattr(self._engine, "contracts", None)
|
||||
if isinstance(raw, dict):
|
||||
pool = list(raw.values())
|
||||
sym_prefix = sym[: len(letters)] if sym else letters.lower()
|
||||
sym_prefix_up = letters.upper()
|
||||
for c in pool:
|
||||
c_ex = getattr(c, "exchange", None)
|
||||
c_ex_val = str(c_ex.value if hasattr(c_ex, "value") else c_ex or "")
|
||||
if c_ex_val != ex_val:
|
||||
continue
|
||||
c_sym = str(getattr(c, "symbol", "") or "")
|
||||
if (
|
||||
c_sym.lower().startswith(sym_prefix.lower())
|
||||
or c_sym.upper().startswith(sym_prefix_up)
|
||||
):
|
||||
candidates.append(c)
|
||||
if not candidates:
|
||||
return None
|
||||
candidates.sort(key=lambda c: str(getattr(c, "symbol", "") or ""))
|
||||
return candidates[0]
|
||||
except Exception as exc:
|
||||
logger.debug("_get_contract_for_ths %s: %s", ths_code, exc)
|
||||
return None
|
||||
|
||||
def estimate_margin_one_lot(
|
||||
self,
|
||||
ths_code: str,
|
||||
@@ -1247,29 +1481,35 @@ class CtpBridge:
|
||||
*,
|
||||
direction: str = "long",
|
||||
) -> Optional[float]:
|
||||
"""用 CTP 合约信息估算 1 手保证金(需已连接并完成合约查询)。"""
|
||||
"""1 手保证金:持仓实收 > CTP 保证金率查询 > 合约查询缓存。"""
|
||||
if not self._engine or not price or price <= 0:
|
||||
return None
|
||||
try:
|
||||
sym, ex_name = ths_to_vnpy_symbol(ths_code)
|
||||
exchange = to_vnpy_exchange(ex_name)
|
||||
vt_symbol = f"{sym}.{exchange.value}"
|
||||
contract = self._engine.get_contract(vt_symbol)
|
||||
if not contract:
|
||||
return None
|
||||
mult = float(getattr(contract, "size", 0) or 0)
|
||||
long_r = float(getattr(contract, "long_margin_ratio", 0) or 0)
|
||||
short_r = float(getattr(contract, "short_margin_ratio", 0) or 0)
|
||||
contract = self._get_contract_for_ths(ths_code)
|
||||
mult = float(getattr(contract, "size", 0) or 0) if contract else 0.0
|
||||
if mult <= 0:
|
||||
mult = float(get_contract_spec(ths_code).get("mult") or 0)
|
||||
d = (direction or "long").strip().lower()
|
||||
if d == "short" and short_r > 0:
|
||||
ratio = short_r
|
||||
elif d != "short" and long_r > 0:
|
||||
ratio = long_r
|
||||
if d == "max":
|
||||
per_lots = [
|
||||
self._lookup_margin_per_lot(sym, side)
|
||||
for side in ("long", "short")
|
||||
]
|
||||
per_lots = [x for x in per_lots if x > 0]
|
||||
if per_lots:
|
||||
return max(per_lots)
|
||||
else:
|
||||
ratio = max(long_r, short_r)
|
||||
if mult <= 0 or ratio <= 0:
|
||||
return None
|
||||
return round(float(price) * mult * ratio, 2)
|
||||
per_lot = self._lookup_margin_per_lot(sym, d)
|
||||
if per_lot > 0:
|
||||
return per_lot
|
||||
mode = self._connected_mode
|
||||
ratios = self._lookup_margin_ratios(sym, ex_name, mode=mode)
|
||||
if ratios:
|
||||
return self._margin_from_ratios(
|
||||
price, mult, ratios, direction=d,
|
||||
)
|
||||
return None
|
||||
except Exception as exc:
|
||||
logger.debug("estimate_margin_one_lot %s: %s", ths_code, exc)
|
||||
return None
|
||||
@@ -1308,9 +1548,7 @@ class CtpBridge:
|
||||
return None
|
||||
try:
|
||||
sym, ex_name = ths_to_vnpy_symbol(ths_code)
|
||||
exchange = to_vnpy_exchange(ex_name)
|
||||
vt_symbol = f"{sym}.{exchange.value}"
|
||||
contract = self._engine.get_contract(vt_symbol)
|
||||
contract = self._get_contract_for_ths(ths_code)
|
||||
if not contract:
|
||||
return None
|
||||
mult = float(getattr(contract, "size", 0) or 0)
|
||||
@@ -1324,6 +1562,18 @@ class CtpBridge:
|
||||
out: dict[str, Any] = {"mult": mult}
|
||||
if tick > 0:
|
||||
out["tick_size"] = tick
|
||||
long_r = float(getattr(contract, "long_margin_ratio", 0) or 0)
|
||||
short_r = float(getattr(contract, "short_margin_ratio", 0) or 0)
|
||||
c_sym = str(getattr(contract, "symbol", "") or sym or "")
|
||||
if c_sym and self._connected_mode:
|
||||
queried = self._lookup_margin_ratios(
|
||||
c_sym, ex_name, mode=self._connected_mode,
|
||||
)
|
||||
if queried:
|
||||
long_r = float(queried.get("long") or long_r)
|
||||
short_r = float(queried.get("short") or short_r)
|
||||
if long_r > 0 or short_r > 0:
|
||||
out["margin_rate"] = max(long_r, short_r)
|
||||
return out
|
||||
except Exception as exc:
|
||||
logger.debug("lookup_contract_spec %s: %s", ths_code, exc)
|
||||
@@ -1763,6 +2013,23 @@ def ctp_get_account(mode: str) -> dict[str, Any]:
|
||||
return b.get_account()
|
||||
|
||||
|
||||
def ctp_sum_position_margins(
|
||||
mode: str,
|
||||
*,
|
||||
refresh_if_empty: bool = True,
|
||||
refresh_margin: bool = False,
|
||||
) -> float:
|
||||
"""各持仓 CTP 回报保证金之和(与柜台「实收保证金」一致)。"""
|
||||
total = 0.0
|
||||
for p in ctp_list_positions(
|
||||
mode, refresh_if_empty=refresh_if_empty, refresh_margin=refresh_margin,
|
||||
):
|
||||
m = float(p.get("margin") or 0)
|
||||
if m > 0:
|
||||
total += m
|
||||
return round(total, 2) if total > 0 else 0.0
|
||||
|
||||
|
||||
def ctp_account_margin_used(mode: str) -> Optional[float]:
|
||||
"""账户实际占用保证金 ≈ 权益 − 可用(与顶栏柜台资金一致)。"""
|
||||
b = get_bridge()
|
||||
|
||||
Reference in New Issue
Block a user