diff --git a/modules/ctp/ctp_entry_price.py b/modules/ctp/ctp_entry_price.py index 929569e..dccfbb1 100644 --- a/modules/ctp/ctp_entry_price.py +++ b/modules/ctp/ctp_entry_price.py @@ -4,7 +4,7 @@ """CTP 持仓均价:优先 CTP OpenCost(柜台开仓均价),其次成交加权。""" from __future__ import annotations -from typing import Any, Optional +from typing import Any, Callable, Optional from modules.core.contract_specs import get_contract_spec from modules.ctp.ctp_symbol import ths_to_vnpy_symbol @@ -99,14 +99,20 @@ def resolve_ctp_entry( trades: Optional[list[dict[str, Any]]] = None, *, tick: Optional[float] = None, + open_avg_lookup: Optional[Callable[[str, str], float]] = None, ) -> tuple[float, str]: - """均价:优先 avg_price(OpenCost),否则成交加权。""" - if not ctp: - return 0.0, "none" - pos_avg = float(ctp.get("avg_price") or 0) - if pos_avg > 0: - return round_to_tick(pos_avg, sym), "ctp" - trade_avg = compute_open_avg_from_trades(sym, direction or "long", trades) + """均价:OpenCost 缓存 → 成交加权 → vnpy PositionCost。""" + del tick + want = (direction or "long").strip().lower() + if open_avg_lookup: + cached = float(open_avg_lookup(sym, want) or 0) + if cached > 0: + return round_to_tick(cached, sym), "open_cost" + trade_avg = compute_open_avg_from_trades(sym, want, trades) if trade_avg > 0: return round_to_tick(trade_avg, sym), "trades" + if ctp: + pos_avg = float(ctp.get("avg_price") or 0) + if pos_avg > 0: + return round_to_tick(pos_avg, sym), "position_cost" return 0.0, "none" diff --git a/modules/ctp/vnpy_bridge.py b/modules/ctp/vnpy_bridge.py index 44407ac..0f4bbe7 100644 --- a/modules/ctp/vnpy_bridge.py +++ b/modules/ctp/vnpy_bridge.py @@ -344,6 +344,7 @@ class CtpBridge: self._position_open_times: dict[str, str] = {} self._position_open_avg: dict[str, float] = {} self._position_open_cost_acc: dict[str, dict[str, float]] = {} + self._last_open_cost_query_ts: float = 0.0 self._margin_hooked = False self._trade_hooked = False self._trade_query_results: list[dict[str, Any]] = [] @@ -558,11 +559,33 @@ class CtpBridge: logger.debug("position_row_from_vnpy: %s", exc) return None + def _open_cost_cache_incomplete(self) -> bool: + if not self._engine: + return False + for pos in self._engine.get_all_positions(): + vol = int(getattr(pos, "volume", 0) or 0) + if vol <= 0: + continue + sym = getattr(pos, "symbol", "") or "" + d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short" + if self._lookup_position_open_avg(sym, d) <= 0: + return True + return False + def calibrate_trading_state(self) -> None: """全量校准内存簿(读 vnpy 缓存,不 query 柜台)。""" try: from modules.ctp.ctp_trading_state import trading_state + if self._connected_mode and self._open_cost_cache_incomplete(): + now = time.monotonic() + if now - self._last_open_cost_query_ts >= 45.0: + self._last_open_cost_query_ts = now + try: + self.request_position_snapshot(force=True) + except Exception as exc: + logger.debug("open cost refresh query: %s", exc) + with _ctp_td_lock: orders = self.list_active_orders() positions = self._collect_positions() @@ -1218,6 +1241,8 @@ class CtpBridge: ) -> None: try: if data: + if not bridge._position_open_cost_acc: + bridge._position_open_avg.clear() bridge._ingest_position_open_cost(data) if last: bridge._finalize_position_open_cost_acc() @@ -1627,13 +1652,17 @@ class CtpBridge: } def _position_margin_key(self, sym: str, direction: str) -> str: - return f"{(sym or '').lower()}:{(direction or 'long').strip().lower()}" + base = (sym or "").split(".")[0].strip().lower() + return f"{base}:{(direction or 'long').strip().lower()}" @staticmethod def _direction_from_ctp_posi(posi: Any) -> str: - s = str(posi or "").strip().upper() - if s in ("2", "SHORT", "NET_SHORT"): + """CTP PosiDirection: 2=多头, 3=空头。""" + s = str(posi or "").strip() + if s in ("3",) or s.upper() in ("SHORT", "NET_SHORT"): return "short" + if s in ("2",) or s.upper() in ("LONG", "NET_LONG"): + return "long" return "long" def _contract_mult(self, sym: str, ex_name: str = "") -> float: @@ -1650,7 +1679,9 @@ class CtpBridge: vol = int(data.get("Position") or 0) if vol <= 0: return - open_cost = float(data.get("OpenCost") or 0) + open_cost = float( + data.get("OpenCost") or data.get("open_cost") or data.get("OpenAmount") or 0 + ) if open_cost <= 0: return sym = str(data["InstrumentID"]) @@ -1692,9 +1723,18 @@ class CtpBridge: if cached > 0: return cached try: - from modules.ctp.ctp_entry_price import compute_open_avg_from_trades + from modules.ctp.ctp_entry_price import compute_open_avg_from_trades, resolve_ctp_entry trades = self.list_trades() + entry, src = resolve_ctp_entry( + sym, + direction, + {"avg_price": float(getattr(pos, "price", 0) or 0)}, + trades, + open_avg_lookup=self._lookup_position_open_avg, + ) + if entry > 0 and src != "position_cost": + return entry trade_avg = compute_open_avg_from_trades(sym, direction, trades) if trade_avg > 0: return trade_avg diff --git a/modules/trading/install.py b/modules/trading/install.py index 35107ba..93c2db1 100644 --- a/modules/trading/install.py +++ b/modules/trading/install.py @@ -831,13 +831,23 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se direction: str, ctp: Optional[dict], ) -> tuple[float, str]: - del mode, direction - if not ctp: - return 0.0, "none" - avg = float(ctp.get("avg_price") or 0) - if avg > 0: - return round_to_tick(avg, sym), "ctp" - return 0.0, "none" + from modules.ctp.ctp_entry_price import resolve_ctp_entry + + trades = None + open_avg_lookup = None + if ctp_status(mode).get("connected"): + try: + trades = ctp_list_trades(mode) + open_avg_lookup = get_bridge()._lookup_position_open_avg + except Exception: + pass + return resolve_ctp_entry( + sym, + direction, + ctp, + trades, + open_avg_lookup=open_avg_lookup, + ) def _open_commission_from_ctp_trades( mode: str, sym: str, direction: str, @@ -1492,18 +1502,20 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se if lots <= 0: return None + entry_src = "monitor" if mon else "ctp" if ctp: ctp_lots = int(ctp.get("lots") or 0) if ctp_lots > 0: lots = ctp_lots ths_sym = _ctp_pos_to_ths_code(ctp) or sym - resolved_entry, _entry_src = _resolve_ctp_entry_price( + resolved_entry, entry_src = _resolve_ctp_entry_price( mode, ths_sym, direction, ctp, ) if resolved_entry > 0: entry = resolved_entry elif float(ctp.get("avg_price") or 0) > 0: entry = float(ctp.get("avg_price") or 0) + entry_src = "position_cost" ctp_margin = float(ctp.get("margin") or 0) if (margin is None or float(margin or 0) <= 0) and ctp_margin > 0: margin = ctp_margin @@ -1536,7 +1548,7 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se float_pnl = pos_tmp.get("float_pnl") if ctp and ctp_status(mode).get("connected"): ctp_pnl = float(ctp.get("pnl") or 0) - if ctp_pnl != 0: + if entry_src == "position_cost" and ctp_pnl != 0: float_pnl = round(ctp_pnl, 2) fee_info = calc_fee_breakdown(