From e6208e403efd38c927e22a0c1104f0b1b3c2f890 Mon Sep 17 00:00:00 2001 From: dekun Date: Tue, 30 Jun 2026 10:42:56 +0800 Subject: [PATCH] Fix roll average entry: CTP trade-weighted avg, sync after fill, live entry for preview. Co-authored-by: Cursor --- ctp_entry_price.py | 91 +++++++++++++++++++ ctp_trading_state.py | 65 +++++++++++--- install_trading.py | 125 +++++++++++++++----------- strategy/strategy_roll_monitor_lib.py | 9 +- vnpy_bridge.py | 17 +++- 5 files changed, 239 insertions(+), 68 deletions(-) create mode 100644 ctp_entry_price.py diff --git a/ctp_entry_price.py b/ctp_entry_price.py new file mode 100644 index 0000000..e66fd53 --- /dev/null +++ b/ctp_entry_price.py @@ -0,0 +1,91 @@ +# Copyright (c) 2025-2026 马建军. All rights reserved. +# 详见 LICENSE.zh-CN.txt + +"""CTP 持仓均价:成交加权 / 柜台持仓价(滚仓加仓后以柜台为准)。""" +from __future__ import annotations + +from typing import Any, Optional + +from ctp_symbol import ths_to_vnpy_symbol + + +def symbols_match(ctp_sym: str, ths: str) -> bool: + a = (ctp_sym or "").lower() + b = (ths or "").lower() + if a == b: + return True + if a and b and a.split(".")[0] == b.split(".")[0]: + return True + try: + vnpy_sym, _ = ths_to_vnpy_symbol(ths) + if a == vnpy_sym.lower(): + return True + except Exception: + pass + try: + vnpy_sym, _ = ths_to_vnpy_symbol(ctp_sym) + if vnpy_sym.lower() == b.split(".")[0]: + return True + except Exception: + pass + return False + + +def avg_from_trades( + trades: list[dict[str, Any]], + sym: str, + direction: str, + *, + expect_lots: int = 0, +) -> Optional[float]: + """按成交回报移动加权均价(滚仓多笔开仓后应与柜台一致)。""" + direction = (direction or "long").strip().lower() + vol = 0 + cost = 0.0 + for t in sorted(trades, key=lambda x: (x.get("datetime") or "", x.get("trade_id") or "")): + if not symbols_match(t.get("symbol") or "", sym): + continue + off = (t.get("offset") or "").strip().lower() + pos_dir = ( + t.get("position_direction") or t.get("direction") or "long" + ).strip().lower() + if pos_dir != direction: + continue + lots = int(t.get("lots") or 0) + px = float(t.get("price") or 0) + if lots <= 0 or px <= 0: + continue + if off == "open": + cost += px * lots + vol += lots + elif off == "close" and vol > 0: + avg = cost / vol + dec = min(lots, vol) + cost -= avg * dec + vol -= dec + if vol <= 0: + return None + if expect_lots > 0 and vol != expect_lots: + return None + return round(cost / vol, 4) + + +def resolve_ctp_entry( + sym: str, + direction: str, + ctp: Optional[dict[str, Any]], + trades: Optional[list[dict[str, Any]]] = None, +) -> tuple[float, str]: + """均价:成交加权 > 柜台 PositionCost 持仓价。""" + if not ctp: + return 0.0, "none" + direction = (direction or "long").strip().lower() + lots = int(ctp.get("lots") or 0) + if trades: + trade_avg = avg_from_trades(trades, sym, direction, expect_lots=lots) + if trade_avg and trade_avg > 0: + return float(trade_avg), "trades" + pos_avg = float(ctp.get("avg_price") or 0) + if pos_avg > 0: + return pos_avg, "ctp" + return 0.0, "none" diff --git a/ctp_trading_state.py b/ctp_trading_state.py index cb1703f..bf3251c 100644 --- a/ctp_trading_state.py +++ b/ctp_trading_state.py @@ -12,6 +12,9 @@ from typing import Any, Callable, Optional logger = logging.getLogger(__name__) +CALIBRATE_INTERVAL_SEC = 30.0 + + def position_key(exchange: str, symbol: str, direction: str) -> str: """统一持仓键:exchange|symbol|direction""" ex = (exchange or "").strip().upper() @@ -72,16 +75,24 @@ def reconcile_position_avg( old: Optional[dict[str, Any]], new: dict[str, Any], tick: Optional[float], + *, + trades: Optional[list[dict[str, Any]]] = None, + ths_sym: str = "", ) -> dict[str, Any]: - """手数不变时锁定均价;新开/加仓时用柜台盈亏快照校正一次。""" + """手数不变时锁定均价;滚仓/加仓(手数变化)时以柜台加权均价为准。""" + from ctp_entry_price import resolve_ctp_entry + row = dict(new) lots = int(row.get("lots") or 0) if lots <= 0: return row + direction = (row.get("direction") or "long").strip().lower() old_lots = int(old.get("lots") or 0) if old else 0 + lots_changed = not old or old_lots != lots + if ( - old - and old_lots == lots + not lots_changed + and old and old.get("avg_price_locked") and float(old.get("avg_price") or 0) > 0 ): @@ -89,14 +100,24 @@ def reconcile_position_avg( row["avg_price_locked"] = True return row - refined = avg_price_from_ctp_pnl(row, tick) - pos_avg = float(row.get("avg_price") or 0) - if refined and refined > 0: - row["avg_price"] = refined + sym = ths_sym or (row.get("symbol") or "") + entry, _src = resolve_ctp_entry(sym, direction, row, trades) + if entry > 0: + row["avg_price"] = entry row["avg_price_locked"] = True - elif pos_avg > 0: + return row + + pos_avg = float(row.get("avg_price") or 0) + if pos_avg > 0: row["avg_price"] = pos_avg - row["avg_price_locked"] = bool(tick and refined) + row["avg_price_locked"] = lots_changed or bool(tick) + return row + + if not lots_changed: + refined = avg_price_from_ctp_pnl(row, tick) + if refined and refined > 0: + row["avg_price"] = refined + row["avg_price_locked"] = True return row @@ -203,7 +224,14 @@ class CtpTradingState: changed = True return changed - def upsert_position(self, row: dict[str, Any], *, notify: bool = True) -> None: + def upsert_position( + self, + row: dict[str, Any], + *, + notify: bool = True, + trades: Optional[list[dict[str, Any]]] = None, + ths_sym: str = "", + ) -> None: lots = int(row.get("lots") or 0) ex = row.get("exchange") or "" sym = row.get("symbol") or "" @@ -215,7 +243,9 @@ class CtpTradingState: self._positions.pop(pk, None) else: old = self._positions.get(pk) - row = reconcile_position_avg(old, dict(row), tick) + row = reconcile_position_avg( + old, dict(row), tick, trades=trades, ths_sym=ths_sym or sym, + ) row["position_key"] = pk self._positions[pk] = row if notify: @@ -262,6 +292,9 @@ class CtpTradingState: self, orders: list[dict[str, Any]], positions: list[dict[str, Any]], + *, + trades: Optional[list[dict[str, Any]]] = None, + ths_for_vnpy_sym: Optional[Callable[[str, str], str]] = None, ) -> None: """全量校准:以 vnpy 内存为准重建订单/持仓簿。""" self.begin_sync() @@ -283,7 +316,15 @@ class CtpTradingState: row["position_key"] = pk old = self._positions.get(pk) tick = self.get_tick_price(ex, sym) - new_positions[pk] = reconcile_position_avg(old, row, tick) + ths = sym + if ths_for_vnpy_sym: + try: + ths = ths_for_vnpy_sym(sym, ex) or sym + except Exception: + ths = sym + new_positions[pk] = reconcile_position_avg( + old, row, tick, trades=trades, ths_sym=ths, + ) with self._lock: self._orders = new_orders self._positions = new_positions diff --git a/install_trading.py b/install_trading.py index b4f1982..337b22a 100644 --- a/install_trading.py +++ b/install_trading.py @@ -120,6 +120,7 @@ from trading_context import ( is_ctp_connected, trading_mode_label, ) +from ctp_entry_price import resolve_ctp_entry from ctp_symbol import ths_to_vnpy_symbol from ctp_trading_state import position_key, trading_state from vnpy_bridge import ( @@ -542,48 +543,31 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se pass return False - def _ctp_avg_entry_from_trades( - mode: str, sym: str, direction: str, *, expect_lots: int = 0, - ) -> Optional[float]: - """按成交回报加权计算持仓均价(与柜台一致)。""" + def _live_entry_price( + sym: str, + direction: str, + mode: str, + fallback: float = 0.0, + ) -> float: + """滚仓/展示用均价:优先柜台成交加权与持仓价。""" if not ctp_status(mode).get("connected"): - return None + return fallback + trades: list = [] try: - trades = sorted( - ctp_list_trades(mode), - key=lambda t: (t.get("datetime") or "", t.get("trade_id") or ""), - ) + trades = ctp_list_trades(mode) except Exception: - return None - direction = (direction or "long").strip().lower() - vol = 0 - cost = 0.0 - for t in trades: - if not _match_ctp_symbol(t.get("symbol") or "", sym): + pass + for p in trading_state.get_positions() or _ctp_positions( + mode, refresh_if_empty=False, + ): + if (p.get("direction") or "long") != (direction or "long"): continue - off = (t.get("offset") or "").strip().lower() - pos_dir = ( - t.get("position_direction") or t.get("direction") or "long" - ).strip().lower() - if pos_dir != direction: + if not _match_ctp_symbol(p.get("symbol") or "", sym): continue - lots = int(t.get("lots") or 0) - px = float(t.get("price") or 0) - if lots <= 0 or px <= 0: - continue - if off == "open": - cost += px * lots - vol += lots - elif off == "close" and vol > 0: - avg = cost / vol - dec = min(lots, vol) - cost -= avg * dec - vol -= dec - if vol <= 0: - return None - if expect_lots > 0 and vol != expect_lots: - return None - return round(cost / vol, 4) + entry, _ = resolve_ctp_entry(sym, direction, p, trades) + if entry > 0: + return float(entry) + return fallback def _resolve_ctp_entry_price( mode: str, @@ -591,22 +575,15 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se direction: str, ctp: Optional[dict], ) -> tuple[float, str]: - """持仓均价:成交加权 > 柜台持仓价(锁定后不随 tick 变化)。""" if not ctp: return 0.0, "none" - direction = (direction or "long").strip().lower() - lots = int(ctp.get("lots") or 0) - - trade_avg = _ctp_avg_entry_from_trades( - mode, sym, direction, expect_lots=lots, - ) - if trade_avg and trade_avg > 0: - return float(trade_avg), "trades" - - pos_avg = float(ctp.get("avg_price") or 0) - if pos_avg > 0: - return pos_avg, "ctp" - return 0.0, "none" + trades: list = [] + if ctp_status(mode).get("connected"): + try: + trades = ctp_list_trades(mode) + except Exception: + pass + return resolve_ctp_entry(sym, direction, ctp, trades) def _open_commission_from_ctp_trades( mode: str, sym: str, direction: str, @@ -3172,6 +3149,9 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se spec = get_contract_spec(sym) capital = _capital(conn) mark = _roll_mark_price(sym, mon, mode) + entry_existing = _live_entry_price( + sym, mon["direction"], mode, float(mon.get("entry_price") or 0), + ) add_mode = (d.get("add_mode") or ADD_MODE_MARKET).strip().lower() if add_mode in FIB_MODES: return None, "斐波加仓已停用,请选市价或突破" @@ -3183,7 +3163,7 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se direction=mon["direction"], symbol=sym, qty_existing=float(mon["lots"]), - entry_existing=float(mon["entry_price"]), + entry_existing=entry_existing, initial_take_profit=float(mon["take_profit"] or 0), add_mode=add_mode, new_stop_loss=float(d.get("new_stop_loss") or 0), @@ -3283,8 +3263,44 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se f"滚仓成交 {sym} {add_mode_label(add_mode)} +{preview['add_lots']}手 " f"新止损 {new_sl} 合计 {new_lots}手" ) + _schedule_roll_entry_sync(mon_id, sym, mon["direction"], mode) return True, "成交" + def _schedule_roll_entry_sync( + mon_id: int, sym: str, direction: str, mode: str, + ) -> None: + """滚仓成交后从柜台同步加权均价到手数监控。""" + def _run() -> None: + import time as _time + + _time.sleep(1.5) + try: + conn = get_db() + try: + init_strategy_tables(conn) + capital = _capital(conn) + synced = False + for p in trading_state.get_positions() or _ctp_positions(mode): + if (p.get("direction") or "long") != (direction or "long"): + continue + if not _match_ctp_symbol(p.get("symbol") or "", sym): + continue + _sync_monitor_from_ctp( + conn, mon_id, sym, direction, mode, ctp=p, capital=capital, + ) + synced = True + break + if synced: + commit_retry(conn) + finally: + conn.close() + if synced: + _push_position_snapshot_async(fast=False) + except Exception as exc: + logger.debug("roll entry sync: %s", exc) + + threading.Thread(target=_run, daemon=True, name="roll-entry-sync").start() + def _submit_roll_pending( conn, *, @@ -3358,6 +3374,7 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se fill_roll_leg_fn=_fill_roll_leg_cb, is_trading_session_fn=is_trading_session, get_risk_budget_fn=lambda: get_fixed_amount(get_setting), + get_entry_price_fn=lambda sym, d, fb: _live_entry_price(sym, d, mode, fb), ) conn.commit() finally: @@ -3380,7 +3397,9 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se direction = (mon.get("direction") or "long").strip().lower() price = float(preview.get("add_price") or 0) qty_existing = float(mon.get("lots") or 0) - entry_existing = float(mon.get("entry_price") or 0) + entry_existing = _live_entry_price( + sym, direction, mode, float(mon.get("entry_price") or 0), + ) mult = int(get_contract_spec(sym).get("mult") or 1) roll_pct = get_roll_max_margin_pct(get_setting) add_lots = int(preview.get("add_lots") or 0) diff --git a/strategy/strategy_roll_monitor_lib.py b/strategy/strategy_roll_monitor_lib.py index f605123..9c7ab32 100644 --- a/strategy/strategy_roll_monitor_lib.py +++ b/strategy/strategy_roll_monitor_lib.py @@ -71,6 +71,7 @@ def check_roll_monitors( fill_roll_leg_fn: Callable[[dict, dict, dict, dict], tuple[bool, str]], is_trading_session_fn: Callable[[], bool], get_risk_budget_fn: Callable[[], float], + get_entry_price_fn: Optional[Callable[[str, str, float], float]] = None, ) -> None: """扫描 pending 滚仓腿,标记价穿越则重算手数并市价成交。""" if not is_trading_session_fn(): @@ -114,6 +115,12 @@ def check_roll_monitors( "entry_price": leg["mon_entry"], "take_profit": leg["mon_tp"] or leg["initial_take_profit"], } + entry_fb = float(leg["mon_entry"] or 0) + entry_existing = ( + get_entry_price_fn(sym, direction, entry_fb) + if get_entry_price_fn + else entry_fb + ) grp = { "id": leg["roll_group_id"], "order_monitor_id": leg["order_monitor_id"], @@ -124,7 +131,7 @@ def check_roll_monitors( direction=direction, symbol=sym, qty_existing=float(leg["mon_lots"] or 0), - entry_existing=float(leg["mon_entry"] or 0), + entry_existing=entry_existing, initial_take_profit=float(leg["mon_tp"] or leg["initial_take_profit"] or 0), add_mode=mode, new_stop_loss=float(leg["new_stop_loss"] or 0), diff --git a/vnpy_bridge.py b/vnpy_bridge.py index 5a855f9..65ef6fd 100644 --- a/vnpy_bridge.py +++ b/vnpy_bridge.py @@ -326,7 +326,14 @@ class CtpBridge: pos = event.data row = self._position_row_from_vnpy(pos) if row: - trading_state.upsert_position(row, notify=False) + sym = row.get("symbol") or "" + ex = row.get("exchange") or "" + ths = CtpBridge._vnpy_sym_to_ths(sym, ex) or sym + with _ctp_td_lock: + trades = self.list_trades() + trading_state.upsert_position( + row, notify=False, trades=trades, ths_sym=ths, + ) sym = getattr(pos, "symbol", "") or "" d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short" vol = int(getattr(pos, "volume", 0) or 0) @@ -482,7 +489,13 @@ class CtpBridge: with _ctp_td_lock: orders = self.list_active_orders() positions = self._collect_positions() - trading_state.calibrate_from_lists(orders, positions) + trades = self.list_trades() + trading_state.calibrate_from_lists( + orders, + positions, + trades=trades, + ths_for_vnpy_sym=lambda s, e: CtpBridge._vnpy_sym_to_ths(s, e) or s, + ) except Exception as exc: logger.debug("calibrate trading state: %s", exc)