From 19676943d027ea953ce7018e2227a6b2a48b04e1 Mon Sep 17 00:00:00 2001 From: dekun Date: Mon, 29 Jun 2026 10:21:44 +0800 Subject: [PATCH] Align margin display with CTP counter rates and position margin. Read margin ratios from CTP instrument query and margin-rate API instead of vnpy ContractData (which lacks ratios). Keep occupied margin on position UseMargin; use per-lot max rate for recommend table. Co-authored-by: Cursor --- contract_specs.py | 3 + install_trading.py | 49 ++++--- product_recommend.py | 2 +- recommend_store.py | 21 ++- vnpy_bridge.py | 307 ++++++++++++++++++++++++++++++++++++++++--- 5 files changed, 332 insertions(+), 50 deletions(-) diff --git a/contract_specs.py b/contract_specs.py index d30d4ff..ba085ef 100644 --- a/contract_specs.py +++ b/contract_specs.py @@ -92,6 +92,7 @@ def margin_one_lot( ) -> tuple[float, str, dict]: """1 手保证金。CTP 已连接时优先读柜台合约保证金率,否则用本地参考规格估算。 + direction 可为 long / short / max(多空费率取较大值,用于可开仓品种表)。 返回 (保证金, 来源 estimate|ctp, 合约规格片段)。 """ spec = get_contract_spec(ths_code) @@ -113,6 +114,8 @@ def margin_one_lot( merged["mult"] = ctp_spec["mult"] if ctp_spec.get("tick_size"): merged["tick_size"] = ctp_spec["tick_size"] + if ctp_spec.get("margin_rate"): + merged["margin_rate"] = ctp_spec["margin_rate"] return float(ctp_margin), "ctp", merged except Exception: pass diff --git a/install_trading.py b/install_trading.py index 2569b5d..d4dc572 100644 --- a/install_trading.py +++ b/install_trading.py @@ -287,27 +287,42 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se mode: str, capital: float, ) -> list[dict]: - """CTP 已连接时,用「权益−可用」校正占用保证金与仓位占比。""" + """仅在持仓缺少柜台保证金时补全;已有 CTP 持仓保证金的行不覆盖。""" if not ctp_status(mode).get("connected"): return rows - total_used = ctp_account_margin_used(mode) - if not total_used: - return rows active = [ r for r in rows if r.get("order_state") != "pending" and int(r.get("lots") or 0) > 0 ] if not active: return rows - if len(active) == 1: - row = active[0] - row["margin"] = total_used - row["margin_source"] = "ctp" - if capital > 0: - row["position_pct"] = round(total_used / capital * 100, 2) - return rows - weights: list[float] = [] + + def _has_ctp_margin(row: dict) -> bool: + return ( + float(row.get("margin") or 0) > 0 + and row.get("margin_source") == "ctp" + ) + + without_margin = [r for r in active if not _has_ctp_margin(r)] for row in active: + if _has_ctp_margin(row) and capital > 0: + m = float(row.get("margin") or 0) + row["position_pct"] = round(m / capital * 100, 2) + if not without_margin: + return rows + + total_used = ctp_account_margin_used(mode) + if not total_used: + return rows + known_sum = sum( + float(r.get("margin") or 0) for r in active if _has_ctp_margin(r) + ) + pool = max(0.0, float(total_used) - known_sum) if known_sum > 0 else float(total_used) + if pool <= 0: + return rows + + weights: list[float] = [] + for row in without_margin: sym = (row.get("symbol_code") or "").strip() lots = int(row.get("lots") or 0) entry = float(row.get("entry_price") or 0) @@ -318,13 +333,13 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se weights.append(0.0) total_weight = sum(weights) assigned = 0.0 - for i, row in enumerate(active): + for i, row in enumerate(without_margin): if total_weight <= 0: - margin = round(total_used / len(active), 2) - elif i == len(active) - 1: - margin = round(total_used - assigned, 2) + margin = round(pool / len(without_margin), 2) + elif i == len(without_margin) - 1: + margin = round(pool - assigned, 2) else: - margin = round(total_used * weights[i] / total_weight, 2) + margin = round(pool * weights[i] / total_weight, 2) assigned += margin row["margin"] = margin row["margin_source"] = "ctp" diff --git a/product_recommend.py b/product_recommend.py index 35769ab..3889409 100644 --- a/product_recommend.py +++ b/product_recommend.py @@ -216,7 +216,7 @@ def assess_product_for_capital( code_for_margin = (main_code or "").strip() or (ths + "8888") if p > 0 and ctp_connected: margin_one, margin_source, spec_used = margin_one_lot( - code_for_margin, p, trading_mode=trading_mode, + code_for_margin, p, direction="max", trading_mode=trading_mode, ) if spec_used.get("mult"): mult = spec_used["mult"] diff --git a/recommend_store.py b/recommend_store.py index edce71c..c5287b0 100644 --- a/recommend_store.py +++ b/recommend_store.py @@ -118,23 +118,19 @@ def _ctp_connected_for_mode(trading_mode: str) -> bool: def recommend_margin_used(trading_mode: str) -> float: - """当前持仓已占用保证金(CTP 柜台优先)。""" + """当前持仓已占用保证金(各持仓 CTP 回报之和,与柜台持仓保证金一致)。""" if not _ctp_connected_for_mode(trading_mode): return 0.0 try: - from vnpy_bridge import ctp_account_margin_used, ctp_list_positions + from vnpy_bridge import ctp_account_margin_used, ctp_sum_position_margins - used = ctp_account_margin_used(trading_mode) - if used is not None and used > 0: - return float(used) - total = 0.0 - for p in ctp_list_positions( + total = ctp_sum_position_margins( trading_mode, refresh_if_empty=False, refresh_margin=True, - ): - m = float(p.get("margin") or 0) - if m > 0: - total += m - return round(total, 2) if total > 0 else 0.0 + ) + if total > 0: + return total + used = ctp_account_margin_used(trading_mode) + return float(used) if used and used > 0 else 0.0 except Exception as exc: logger.debug("recommend_margin_used: %s", exc) return 0.0 @@ -196,6 +192,7 @@ def enrich_recommend_rows( margin_one, margin_source, spec_used = margin_one_lot( code_for_margin, price, + direction="max", trading_mode=trading_mode if ctp_connected else None, ) if spec_used.get("mult"): diff --git a/vnpy_bridge.py b/vnpy_bridge.py index d7a58e4..95667eb 100644 --- a/vnpy_bridge.py +++ b/vnpy_bridge.py @@ -8,6 +8,7 @@ from __future__ import annotations import logging import os +import re import threading import time from collections import deque @@ -237,6 +238,12 @@ class CtpBridge: self._commission_waiters: dict[int, threading.Event] = {} self._commission_lists: dict[int, list] = {} self._commission_hooked = False + self._margin_rate_waiters: dict[int, threading.Event] = {} + self._margin_rate_lists: dict[int, list] = {} + self._margin_rate_hooked = False + self._instrument_hooked = False + self._instrument_margin_ratios: dict[str, dict[str, float]] = {} + self._margin_per_lot: dict[str, float] = {} self._subscribed: set[str] = set() self._last_position_query_ts: float = 0.0 self._position_margins: dict[str, float] = {} @@ -305,6 +312,10 @@ class CtpBridge: raw = float(getattr(pos, attr, 0) or 0) if raw > 0: self._position_margins[self._position_margin_key(sym, d)] = raw + if vol > 0: + self._margin_per_lot[self._position_margin_key(sym, d)] = round( + raw / vol, 2, + ) break except Exception as exc: logger.debug("position margin cache: %s", exc) @@ -637,6 +648,7 @@ class CtpBridge: "请更新 .env 中 SIMNOW_TD_ADDRESS 为官网最新地址," "并在服务器执行 nc -zv 验证出网。" ) + self._ensure_instrument_margin_hooks() self._engine.connect(setting, GATEWAY_NAME) if self._wait_connected(mode, ctp_logs): self._connected_mode = mode @@ -943,6 +955,184 @@ class CtpBridge: """批量查询全部合约手续费(InstrumentID 留空)。""" return self._query_commission(mode=mode, timeout=45) + @staticmethod + def _parse_margin_ratio_row(data: dict) -> dict[str, float]: + long_r = float( + data.get("LongMarginRatioByMoney") + or data.get("LongMarginRatio") + or 0 + ) + short_r = float( + data.get("ShortMarginRatioByMoney") + or data.get("ShortMarginRatio") + or 0 + ) + return {"long": long_r, "short": short_r} + + def _cache_margin_ratio(self, sym: str, data: dict) -> None: + ratios = self._parse_margin_ratio_row(data) + if ratios["long"] <= 0 and ratios["short"] <= 0: + return + key = (sym or "").strip().lower() + if not key: + return + self._instrument_margin_ratios[key] = ratios + + def _ensure_instrument_margin_hooks(self) -> None: + """登录前挂钩:合约查询回报缓存保证金率;支持按需 reqQryInstrumentMarginRate。""" + if not self._engine: + return + try: + gw = self._engine.get_gateway(GATEWAY_NAME) + td = gw.td_api + except Exception: + return + bridge = self + + if not self._instrument_hooked: + orig = td.onRspQryInstrument + + def on_instrument(data: dict, error: dict, reqid: int, last: bool) -> None: + try: + if data and data.get("InstrumentID"): + bridge._cache_margin_ratio(str(data["InstrumentID"]), data) + except Exception as exc: + logger.debug("instrument margin cache: %s", exc) + return orig(data, error, reqid, last) + + td.onRspQryInstrument = on_instrument # type: ignore[method-assign] + self._instrument_hooked = True + + if self._margin_rate_hooked: + return + + def on_margin_rate(data: dict, error: dict, reqid: int, last: bool) -> None: + if error and int(error.get("ErrorID") or 0) != 0: + logger.debug( + "CTP margin rate error reqid=%s: %s", + reqid, + error.get("ErrorMsg") or error, + ) + if data and data.get("InstrumentID"): + bridge._margin_rate_lists.setdefault(reqid, []).append(dict(data)) + ev = bridge._margin_rate_waiters.get(reqid) + if last and ev: + ev.set() + + td.onRspQryInstrumentMarginRate = on_margin_rate # type: ignore[method-assign] + self._margin_rate_hooked = True + + def _query_instrument_margin_rate( + self, + *, + mode: str, + instrument_id: str, + exchange_id: str, + timeout: float = 6, + ) -> Optional[dict[str, float]]: + if self._connected_mode != mode or not self._engine: + return None + sym = (instrument_id or "").strip() + if not sym: + return None + cached = self._instrument_margin_ratios.get(sym.lower()) + if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0): + return cached + try: + gw = self._engine.get_gateway(GATEWAY_NAME) + td = gw.td_api + except Exception as exc: + logger.debug("margin rate query init: %s", exc) + return None + if not getattr(td, "login_status", False): + return None + if not hasattr(td, "reqQryInstrumentMarginRate"): + return None + self._ensure_instrument_margin_hooks() + reqid = int(getattr(td, "reqid", 0)) + 1 + td.reqid = reqid + ev = threading.Event() + self._margin_rate_waiters[reqid] = ev + req = { + "BrokerID": td.brokerid, + "InvestorID": td.userid, + "InstrumentID": sym, + "ExchangeID": exchange_id or "", + "InvestorRange": "1", + "HedgeFlag": "1", + } + with _ctp_td_lock: + ret = td.reqQryInstrumentMarginRate(req, reqid) + if ret != 0: + self._margin_rate_waiters.pop(reqid, None) + return None + ev.wait(timeout=timeout) + self._margin_rate_waiters.pop(reqid, None) + rows = self._margin_rate_lists.pop(reqid, []) + if not rows: + return None + ratios = self._parse_margin_ratio_row(rows[-1]) + if ratios["long"] > 0 or ratios["short"] > 0: + self._cache_margin_ratio(sym, rows[-1]) + return ratios + return None + + def _lookup_margin_ratios( + self, + sym: str, + ex_name: str, + *, + mode: Optional[str] = None, + ) -> Optional[dict[str, float]]: + key = (sym or "").strip().lower() + if not key: + return None + cached = self._instrument_margin_ratios.get(key) + if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0): + return cached + if mode and self._connected_mode == mode: + return self._query_instrument_margin_rate( + mode=mode, + instrument_id=sym, + exchange_id=ex_name, + ) + return None + + def _lookup_margin_per_lot(self, sym: str, direction: str) -> float: + return float( + self._margin_per_lot.get(self._position_margin_key(sym, direction), 0) or 0 + ) + + def _margin_from_ratios( + self, + price: float, + mult: float, + ratios: dict[str, float], + *, + direction: str, + ) -> Optional[float]: + long_r = float(ratios.get("long") or 0) + short_r = float(ratios.get("short") or 0) + d = (direction or "long").strip().lower() + if mult <= 0 or price <= 0: + return None + if d == "max": + candidates = [ + round(float(price) * mult * r, 2) + for r in (long_r, short_r) + if r > 0 + ] + return max(candidates) if candidates else None + if d == "short" and short_r > 0: + ratio = short_r + elif d != "short" and long_r > 0: + ratio = long_r + else: + ratio = max(long_r, short_r) + if ratio <= 0: + return None + return round(float(price) * mult * ratio, 2) + def _tick_key(self, symbol: str, ex_name: str) -> str: return f"{symbol.lower()}:{ex_name.upper()}" @@ -1240,6 +1430,50 @@ class CtpBridge: return letters.upper() + (digits[-3:] if len(digits) >= 4 else digits) return letters.lower() + digits + def _get_contract_for_ths(self, ths_code: str) -> Any: + """按同花顺代码查 CTP 合约;精确匹配失败时在同交易所按品种前缀回退。""" + if not self._engine: + return None + try: + sym, ex_name = ths_to_vnpy_symbol(ths_code) + exchange = to_vnpy_exchange(ex_name) + vt_symbol = f"{sym}.{exchange.value}" + contract = self._engine.get_contract(vt_symbol) + if contract: + return contract + m = re.match(r"^([A-Za-z]+)", (ths_code or "").strip()) + if not m: + return None + letters = m.group(1) + ex_val = exchange.value + candidates: list[Any] = [] + get_all = getattr(self._engine, "get_all_contracts", None) + pool = list(get_all()) if callable(get_all) else [] + if not pool: + raw = getattr(self._engine, "contracts", None) + if isinstance(raw, dict): + pool = list(raw.values()) + sym_prefix = sym[: len(letters)] if sym else letters.lower() + sym_prefix_up = letters.upper() + for c in pool: + c_ex = getattr(c, "exchange", None) + c_ex_val = str(c_ex.value if hasattr(c_ex, "value") else c_ex or "") + if c_ex_val != ex_val: + continue + c_sym = str(getattr(c, "symbol", "") or "") + if ( + c_sym.lower().startswith(sym_prefix.lower()) + or c_sym.upper().startswith(sym_prefix_up) + ): + candidates.append(c) + if not candidates: + return None + candidates.sort(key=lambda c: str(getattr(c, "symbol", "") or "")) + return candidates[0] + except Exception as exc: + logger.debug("_get_contract_for_ths %s: %s", ths_code, exc) + return None + def estimate_margin_one_lot( self, ths_code: str, @@ -1247,29 +1481,35 @@ class CtpBridge: *, direction: str = "long", ) -> Optional[float]: - """用 CTP 合约信息估算 1 手保证金(需已连接并完成合约查询)。""" + """1 手保证金:持仓实收 > CTP 保证金率查询 > 合约查询缓存。""" if not self._engine or not price or price <= 0: return None try: sym, ex_name = ths_to_vnpy_symbol(ths_code) - exchange = to_vnpy_exchange(ex_name) - vt_symbol = f"{sym}.{exchange.value}" - contract = self._engine.get_contract(vt_symbol) - if not contract: - return None - mult = float(getattr(contract, "size", 0) or 0) - long_r = float(getattr(contract, "long_margin_ratio", 0) or 0) - short_r = float(getattr(contract, "short_margin_ratio", 0) or 0) + contract = self._get_contract_for_ths(ths_code) + mult = float(getattr(contract, "size", 0) or 0) if contract else 0.0 + if mult <= 0: + mult = float(get_contract_spec(ths_code).get("mult") or 0) d = (direction or "long").strip().lower() - if d == "short" and short_r > 0: - ratio = short_r - elif d != "short" and long_r > 0: - ratio = long_r + if d == "max": + per_lots = [ + self._lookup_margin_per_lot(sym, side) + for side in ("long", "short") + ] + per_lots = [x for x in per_lots if x > 0] + if per_lots: + return max(per_lots) else: - ratio = max(long_r, short_r) - if mult <= 0 or ratio <= 0: - return None - return round(float(price) * mult * ratio, 2) + per_lot = self._lookup_margin_per_lot(sym, d) + if per_lot > 0: + return per_lot + mode = self._connected_mode + ratios = self._lookup_margin_ratios(sym, ex_name, mode=mode) + if ratios: + return self._margin_from_ratios( + price, mult, ratios, direction=d, + ) + return None except Exception as exc: logger.debug("estimate_margin_one_lot %s: %s", ths_code, exc) return None @@ -1308,9 +1548,7 @@ class CtpBridge: return None try: sym, ex_name = ths_to_vnpy_symbol(ths_code) - exchange = to_vnpy_exchange(ex_name) - vt_symbol = f"{sym}.{exchange.value}" - contract = self._engine.get_contract(vt_symbol) + contract = self._get_contract_for_ths(ths_code) if not contract: return None mult = float(getattr(contract, "size", 0) or 0) @@ -1324,6 +1562,18 @@ class CtpBridge: out: dict[str, Any] = {"mult": mult} if tick > 0: out["tick_size"] = tick + long_r = float(getattr(contract, "long_margin_ratio", 0) or 0) + short_r = float(getattr(contract, "short_margin_ratio", 0) or 0) + c_sym = str(getattr(contract, "symbol", "") or sym or "") + if c_sym and self._connected_mode: + queried = self._lookup_margin_ratios( + c_sym, ex_name, mode=self._connected_mode, + ) + if queried: + long_r = float(queried.get("long") or long_r) + short_r = float(queried.get("short") or short_r) + if long_r > 0 or short_r > 0: + out["margin_rate"] = max(long_r, short_r) return out except Exception as exc: logger.debug("lookup_contract_spec %s: %s", ths_code, exc) @@ -1763,6 +2013,23 @@ def ctp_get_account(mode: str) -> dict[str, Any]: return b.get_account() +def ctp_sum_position_margins( + mode: str, + *, + refresh_if_empty: bool = True, + refresh_margin: bool = False, +) -> float: + """各持仓 CTP 回报保证金之和(与柜台「实收保证金」一致)。""" + total = 0.0 + for p in ctp_list_positions( + mode, refresh_if_empty=refresh_if_empty, refresh_margin=refresh_margin, + ): + m = float(p.get("margin") or 0) + if m > 0: + total += m + return round(total, 2) if total > 0 else 0.0 + + def ctp_account_margin_used(mode: str) -> Optional[float]: """账户实际占用保证金 ≈ 权益 − 可用(与顶栏柜台资金一致)。""" b = get_bridge()