diff --git a/modules/ctp/ctp_entry_price.py b/modules/ctp/ctp_entry_price.py index 45aa5c5..929569e 100644 --- a/modules/ctp/ctp_entry_price.py +++ b/modules/ctp/ctp_entry_price.py @@ -1,7 +1,7 @@ # Copyright (c) 2025-2026 马建军. All rights reserved. # 详见 LICENSE.zh-CN.txt -"""CTP 持仓均价:仅使用柜台持仓回报(vnpy pos.price = PositionCost 加权)。""" +"""CTP 持仓均价:优先 CTP OpenCost(柜台开仓均价),其次成交加权。""" from __future__ import annotations from typing import Any, Optional @@ -45,6 +45,53 @@ def round_to_tick(price: float, sym: str) -> float: return round(round(price / tick) * tick, 4) +def compute_open_avg_from_trades( + sym: str, + direction: str, + trades: Optional[list[dict[str, Any]]], +) -> float: + """按开仓成交 FIFO 还原剩余持仓的开仓均价。""" + if not trades: + return 0.0 + want = (direction or "long").strip().lower() + open_vol = 0.0 + open_cost = 0.0 + for t in sorted(trades, key=lambda x: x.get("datetime") or ""): + if (t.get("offset") or "").strip().lower() != "open": + continue + pos_dir = (t.get("position_direction") or t.get("direction") or "long").strip().lower() + if pos_dir != want: + continue + if not symbols_match(t.get("symbol") or "", sym): + continue + lots = float(int(t.get("lots") or 0)) + px = float(t.get("price") or 0) + if lots <= 0 or px <= 0: + continue + open_vol += lots + open_cost += px * lots + if open_vol <= 0: + return 0.0 + for t in sorted(trades, key=lambda x: x.get("datetime") or ""): + if (t.get("offset") or "").strip().lower() != "close": + continue + pos_dir = (t.get("position_direction") or t.get("direction") or "long").strip().lower() + if pos_dir != want: + continue + if not symbols_match(t.get("symbol") or "", sym): + continue + lots = float(int(t.get("lots") or 0)) + if lots <= 0 or open_vol <= 0: + continue + avg = open_cost / open_vol + dec = min(lots, open_vol) + open_cost -= avg * dec + open_vol -= dec + if open_vol <= 0: + return 0.0 + return round(open_cost / open_vol, 4) + + def resolve_ctp_entry( sym: str, direction: str, @@ -53,11 +100,13 @@ def resolve_ctp_entry( *, tick: Optional[float] = None, ) -> tuple[float, str]: - """均价:仅柜台持仓价(trades/tick 参数保留兼容,不参与计算)。""" - del direction, trades, tick + """均价:优先 avg_price(OpenCost),否则成交加权。""" if not ctp: return 0.0, "none" pos_avg = float(ctp.get("avg_price") or 0) if pos_avg > 0: return round_to_tick(pos_avg, sym), "ctp" + trade_avg = compute_open_avg_from_trades(sym, direction or "long", trades) + if trade_avg > 0: + return round_to_tick(trade_avg, sym), "trades" return 0.0, "none" diff --git a/modules/ctp/vnpy_bridge.py b/modules/ctp/vnpy_bridge.py index b037d68..44407ac 100644 --- a/modules/ctp/vnpy_bridge.py +++ b/modules/ctp/vnpy_bridge.py @@ -342,6 +342,8 @@ class CtpBridge: self._last_position_query_ts: float = 0.0 self._position_margins: dict[str, float] = {} self._position_open_times: dict[str, str] = {} + self._position_open_avg: dict[str, float] = {} + self._position_open_cost_acc: dict[str, dict[str, float]] = {} self._margin_hooked = False self._trade_hooked = False self._trade_query_results: list[dict[str, Any]] = [] @@ -524,7 +526,7 @@ class CtpBridge: sym = getattr(pos, "symbol", "") or "" exchange = getattr(pos, "exchange", None) ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") - price = float(getattr(pos, "price", 0) or 0) + price = self._position_avg_from_vnpy(pos, sym=sym, ex_name=ex_name, direction=d) yd = int(getattr(pos, "yd_volume", 0) or 0) td = max(0, vol - yd) margin = self.estimate_position_margin(sym, ex_name, d, vol, price, pos=pos) @@ -1214,6 +1216,13 @@ class CtpBridge: def on_rsp_position( data: dict, error: dict, reqid: int, last: bool, ) -> None: + try: + if data: + bridge._ingest_position_open_cost(data) + if last: + bridge._finalize_position_open_cost_acc() + except Exception as exc: + logger.debug("position open avg cache: %s", exc) ret = orig_pos(data, error, reqid, last) if last: now = time.monotonic() @@ -1620,6 +1629,79 @@ class CtpBridge: def _position_margin_key(self, sym: str, direction: str) -> str: return f"{(sym or '').lower()}:{(direction or 'long').strip().lower()}" + @staticmethod + def _direction_from_ctp_posi(posi: Any) -> str: + s = str(posi or "").strip().upper() + if s in ("2", "SHORT", "NET_SHORT"): + return "short" + return "long" + + def _contract_mult(self, sym: str, ex_name: str = "") -> float: + ths = self._vnpy_sym_to_ths(sym, ex_name) or sym + try: + return float(get_contract_spec(ths).get("mult") or 0) + except Exception: + return 0.0 + + def _ingest_position_open_cost(self, data: dict) -> None: + """累积 CTP 持仓回报 OpenCost,用于计算开仓均价(柜台 开仓均价)。""" + if not data or not data.get("InstrumentID"): + return + vol = int(data.get("Position") or 0) + if vol <= 0: + return + open_cost = float(data.get("OpenCost") or 0) + if open_cost <= 0: + return + sym = str(data["InstrumentID"]) + direction = self._direction_from_ctp_posi(data.get("PosiDirection")) + mult = self._contract_mult(sym) + if mult <= 0: + return + key = self._position_margin_key(sym, direction) + acc = self._position_open_cost_acc.setdefault( + key, {"open_cost": 0.0, "vol": 0.0, "mult": mult}, + ) + acc["open_cost"] += open_cost + acc["vol"] += float(vol) + acc["mult"] = mult + + def _finalize_position_open_cost_acc(self) -> None: + for key, acc in list(self._position_open_cost_acc.items()): + vol = float(acc.get("vol") or 0) + mult = float(acc.get("mult") or 0) + open_cost = float(acc.get("open_cost") or 0) + if vol > 0 and mult > 0 and open_cost > 0: + self._position_open_avg[key] = open_cost / (vol * mult) + self._position_open_cost_acc.clear() + + def _lookup_position_open_avg(self, sym: str, direction: str) -> float: + return float( + self._position_open_avg.get(self._position_margin_key(sym, direction), 0) or 0 + ) + + def _position_avg_from_vnpy( + self, + pos: Any, + *, + sym: str, + ex_name: str, + direction: str, + ) -> float: + cached = self._lookup_position_open_avg(sym, direction) + if cached > 0: + return cached + try: + from modules.ctp.ctp_entry_price import compute_open_avg_from_trades + + trades = self.list_trades() + trade_avg = compute_open_avg_from_trades(sym, direction, trades) + if trade_avg > 0: + return trade_avg + except Exception as exc: + logger.debug("position avg from trades: %s", exc) + return float(getattr(pos, "price", 0) or 0) + def _lookup_position_open_time(self, sym: str, direction: str) -> str: return (self._position_open_times.get(self._position_margin_key(sym, direction)) or "").strip() @@ -1825,7 +1907,9 @@ class CtpBridge: sym = getattr(pos, "symbol", "") or "" exchange = getattr(pos, "exchange", None) ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") - price = float(getattr(pos, "price", 0) or 0) + price = self._position_avg_from_vnpy( + pos, sym=sym, ex_name=ex_name, direction=d, + ) margin = self.estimate_position_margin( sym, ex_name, d, vol, price, pos=pos, )