Fix roll average entry: CTP trade-weighted avg, sync after fill, live entry for preview.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-30 10:42:56 +08:00
parent 6e954da4e1
commit e6208e403e
5 changed files with 239 additions and 68 deletions
+91
View File
@@ -0,0 +1,91 @@
# Copyright (c) 2025-2026 马建军. All rights reserved.
# 详见 LICENSE.zh-CN.txt
"""CTP 持仓均价:成交加权 / 柜台持仓价(滚仓加仓后以柜台为准)。"""
from __future__ import annotations
from typing import Any, Optional
from ctp_symbol import ths_to_vnpy_symbol
def symbols_match(ctp_sym: str, ths: str) -> bool:
a = (ctp_sym or "").lower()
b = (ths or "").lower()
if a == b:
return True
if a and b and a.split(".")[0] == b.split(".")[0]:
return True
try:
vnpy_sym, _ = ths_to_vnpy_symbol(ths)
if a == vnpy_sym.lower():
return True
except Exception:
pass
try:
vnpy_sym, _ = ths_to_vnpy_symbol(ctp_sym)
if vnpy_sym.lower() == b.split(".")[0]:
return True
except Exception:
pass
return False
def avg_from_trades(
trades: list[dict[str, Any]],
sym: str,
direction: str,
*,
expect_lots: int = 0,
) -> Optional[float]:
"""按成交回报移动加权均价(滚仓多笔开仓后应与柜台一致)。"""
direction = (direction or "long").strip().lower()
vol = 0
cost = 0.0
for t in sorted(trades, key=lambda x: (x.get("datetime") or "", x.get("trade_id") or "")):
if not symbols_match(t.get("symbol") or "", sym):
continue
off = (t.get("offset") or "").strip().lower()
pos_dir = (
t.get("position_direction") or t.get("direction") or "long"
).strip().lower()
if pos_dir != direction:
continue
lots = int(t.get("lots") or 0)
px = float(t.get("price") or 0)
if lots <= 0 or px <= 0:
continue
if off == "open":
cost += px * lots
vol += lots
elif off == "close" and vol > 0:
avg = cost / vol
dec = min(lots, vol)
cost -= avg * dec
vol -= dec
if vol <= 0:
return None
if expect_lots > 0 and vol != expect_lots:
return None
return round(cost / vol, 4)
def resolve_ctp_entry(
sym: str,
direction: str,
ctp: Optional[dict[str, Any]],
trades: Optional[list[dict[str, Any]]] = None,
) -> tuple[float, str]:
"""均价:成交加权 > 柜台 PositionCost 持仓价。"""
if not ctp:
return 0.0, "none"
direction = (direction or "long").strip().lower()
lots = int(ctp.get("lots") or 0)
if trades:
trade_avg = avg_from_trades(trades, sym, direction, expect_lots=lots)
if trade_avg and trade_avg > 0:
return float(trade_avg), "trades"
pos_avg = float(ctp.get("avg_price") or 0)
if pos_avg > 0:
return pos_avg, "ctp"
return 0.0, "none"
+53 -12
View File
@@ -12,6 +12,9 @@ from typing import Any, Callable, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CALIBRATE_INTERVAL_SEC = 30.0
def position_key(exchange: str, symbol: str, direction: str) -> str: def position_key(exchange: str, symbol: str, direction: str) -> str:
"""统一持仓键:exchange|symbol|direction""" """统一持仓键:exchange|symbol|direction"""
ex = (exchange or "").strip().upper() ex = (exchange or "").strip().upper()
@@ -72,16 +75,24 @@ def reconcile_position_avg(
old: Optional[dict[str, Any]], old: Optional[dict[str, Any]],
new: dict[str, Any], new: dict[str, Any],
tick: Optional[float], tick: Optional[float],
*,
trades: Optional[list[dict[str, Any]]] = None,
ths_sym: str = "",
) -> dict[str, Any]: ) -> dict[str, Any]:
"""手数不变时锁定均价;新开/加仓时用柜台盈亏快照校正一次""" """手数不变时锁定均价;滚仓/加仓(手数变化)时以柜台加权均价为准"""
from ctp_entry_price import resolve_ctp_entry
row = dict(new) row = dict(new)
lots = int(row.get("lots") or 0) lots = int(row.get("lots") or 0)
if lots <= 0: if lots <= 0:
return row return row
direction = (row.get("direction") or "long").strip().lower()
old_lots = int(old.get("lots") or 0) if old else 0 old_lots = int(old.get("lots") or 0) if old else 0
lots_changed = not old or old_lots != lots
if ( if (
old not lots_changed
and old_lots == lots and old
and old.get("avg_price_locked") and old.get("avg_price_locked")
and float(old.get("avg_price") or 0) > 0 and float(old.get("avg_price") or 0) > 0
): ):
@@ -89,14 +100,24 @@ def reconcile_position_avg(
row["avg_price_locked"] = True row["avg_price_locked"] = True
return row return row
refined = avg_price_from_ctp_pnl(row, tick) sym = ths_sym or (row.get("symbol") or "")
pos_avg = float(row.get("avg_price") or 0) entry, _src = resolve_ctp_entry(sym, direction, row, trades)
if refined and refined > 0: if entry > 0:
row["avg_price"] = refined row["avg_price"] = entry
row["avg_price_locked"] = True row["avg_price_locked"] = True
elif pos_avg > 0: return row
pos_avg = float(row.get("avg_price") or 0)
if pos_avg > 0:
row["avg_price"] = pos_avg row["avg_price"] = pos_avg
row["avg_price_locked"] = bool(tick and refined) row["avg_price_locked"] = lots_changed or bool(tick)
return row
if not lots_changed:
refined = avg_price_from_ctp_pnl(row, tick)
if refined and refined > 0:
row["avg_price"] = refined
row["avg_price_locked"] = True
return row return row
@@ -203,7 +224,14 @@ class CtpTradingState:
changed = True changed = True
return changed return changed
def upsert_position(self, row: dict[str, Any], *, notify: bool = True) -> None: def upsert_position(
self,
row: dict[str, Any],
*,
notify: bool = True,
trades: Optional[list[dict[str, Any]]] = None,
ths_sym: str = "",
) -> None:
lots = int(row.get("lots") or 0) lots = int(row.get("lots") or 0)
ex = row.get("exchange") or "" ex = row.get("exchange") or ""
sym = row.get("symbol") or "" sym = row.get("symbol") or ""
@@ -215,7 +243,9 @@ class CtpTradingState:
self._positions.pop(pk, None) self._positions.pop(pk, None)
else: else:
old = self._positions.get(pk) old = self._positions.get(pk)
row = reconcile_position_avg(old, dict(row), tick) row = reconcile_position_avg(
old, dict(row), tick, trades=trades, ths_sym=ths_sym or sym,
)
row["position_key"] = pk row["position_key"] = pk
self._positions[pk] = row self._positions[pk] = row
if notify: if notify:
@@ -262,6 +292,9 @@ class CtpTradingState:
self, self,
orders: list[dict[str, Any]], orders: list[dict[str, Any]],
positions: list[dict[str, Any]], positions: list[dict[str, Any]],
*,
trades: Optional[list[dict[str, Any]]] = None,
ths_for_vnpy_sym: Optional[Callable[[str, str], str]] = None,
) -> None: ) -> None:
"""全量校准:以 vnpy 内存为准重建订单/持仓簿。""" """全量校准:以 vnpy 内存为准重建订单/持仓簿。"""
self.begin_sync() self.begin_sync()
@@ -283,7 +316,15 @@ class CtpTradingState:
row["position_key"] = pk row["position_key"] = pk
old = self._positions.get(pk) old = self._positions.get(pk)
tick = self.get_tick_price(ex, sym) tick = self.get_tick_price(ex, sym)
new_positions[pk] = reconcile_position_avg(old, row, tick) ths = sym
if ths_for_vnpy_sym:
try:
ths = ths_for_vnpy_sym(sym, ex) or sym
except Exception:
ths = sym
new_positions[pk] = reconcile_position_avg(
old, row, tick, trades=trades, ths_sym=ths,
)
with self._lock: with self._lock:
self._orders = new_orders self._orders = new_orders
self._positions = new_positions self._positions = new_positions
+72 -53
View File
@@ -120,6 +120,7 @@ from trading_context import (
is_ctp_connected, is_ctp_connected,
trading_mode_label, trading_mode_label,
) )
from ctp_entry_price import resolve_ctp_entry
from ctp_symbol import ths_to_vnpy_symbol from ctp_symbol import ths_to_vnpy_symbol
from ctp_trading_state import position_key, trading_state from ctp_trading_state import position_key, trading_state
from vnpy_bridge import ( from vnpy_bridge import (
@@ -542,48 +543,31 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
pass pass
return False return False
def _ctp_avg_entry_from_trades( def _live_entry_price(
mode: str, sym: str, direction: str, *, expect_lots: int = 0, sym: str,
) -> Optional[float]: direction: str,
"""按成交回报加权计算持仓均价(与柜台一致)。""" mode: str,
fallback: float = 0.0,
) -> float:
"""滚仓/展示用均价:优先柜台成交加权与持仓价。"""
if not ctp_status(mode).get("connected"): if not ctp_status(mode).get("connected"):
return None return fallback
trades: list = []
try: try:
trades = sorted( trades = ctp_list_trades(mode)
ctp_list_trades(mode),
key=lambda t: (t.get("datetime") or "", t.get("trade_id") or ""),
)
except Exception: except Exception:
return None pass
direction = (direction or "long").strip().lower() for p in trading_state.get_positions() or _ctp_positions(
vol = 0 mode, refresh_if_empty=False,
cost = 0.0 ):
for t in trades: if (p.get("direction") or "long") != (direction or "long"):
if not _match_ctp_symbol(t.get("symbol") or "", sym):
continue continue
off = (t.get("offset") or "").strip().lower() if not _match_ctp_symbol(p.get("symbol") or "", sym):
pos_dir = (
t.get("position_direction") or t.get("direction") or "long"
).strip().lower()
if pos_dir != direction:
continue continue
lots = int(t.get("lots") or 0) entry, _ = resolve_ctp_entry(sym, direction, p, trades)
px = float(t.get("price") or 0) if entry > 0:
if lots <= 0 or px <= 0: return float(entry)
continue return fallback
if off == "open":
cost += px * lots
vol += lots
elif off == "close" and vol > 0:
avg = cost / vol
dec = min(lots, vol)
cost -= avg * dec
vol -= dec
if vol <= 0:
return None
if expect_lots > 0 and vol != expect_lots:
return None
return round(cost / vol, 4)
def _resolve_ctp_entry_price( def _resolve_ctp_entry_price(
mode: str, mode: str,
@@ -591,22 +575,15 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
direction: str, direction: str,
ctp: Optional[dict], ctp: Optional[dict],
) -> tuple[float, str]: ) -> tuple[float, str]:
"""持仓均价:成交加权 > 柜台持仓价(锁定后不随 tick 变化)。"""
if not ctp: if not ctp:
return 0.0, "none" return 0.0, "none"
direction = (direction or "long").strip().lower() trades: list = []
lots = int(ctp.get("lots") or 0) if ctp_status(mode).get("connected"):
try:
trade_avg = _ctp_avg_entry_from_trades( trades = ctp_list_trades(mode)
mode, sym, direction, expect_lots=lots, except Exception:
) pass
if trade_avg and trade_avg > 0: return resolve_ctp_entry(sym, direction, ctp, trades)
return float(trade_avg), "trades"
pos_avg = float(ctp.get("avg_price") or 0)
if pos_avg > 0:
return pos_avg, "ctp"
return 0.0, "none"
def _open_commission_from_ctp_trades( def _open_commission_from_ctp_trades(
mode: str, sym: str, direction: str, mode: str, sym: str, direction: str,
@@ -3172,6 +3149,9 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
spec = get_contract_spec(sym) spec = get_contract_spec(sym)
capital = _capital(conn) capital = _capital(conn)
mark = _roll_mark_price(sym, mon, mode) mark = _roll_mark_price(sym, mon, mode)
entry_existing = _live_entry_price(
sym, mon["direction"], mode, float(mon.get("entry_price") or 0),
)
add_mode = (d.get("add_mode") or ADD_MODE_MARKET).strip().lower() add_mode = (d.get("add_mode") or ADD_MODE_MARKET).strip().lower()
if add_mode in FIB_MODES: if add_mode in FIB_MODES:
return None, "斐波加仓已停用,请选市价或突破" return None, "斐波加仓已停用,请选市价或突破"
@@ -3183,7 +3163,7 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
direction=mon["direction"], direction=mon["direction"],
symbol=sym, symbol=sym,
qty_existing=float(mon["lots"]), qty_existing=float(mon["lots"]),
entry_existing=float(mon["entry_price"]), entry_existing=entry_existing,
initial_take_profit=float(mon["take_profit"] or 0), initial_take_profit=float(mon["take_profit"] or 0),
add_mode=add_mode, add_mode=add_mode,
new_stop_loss=float(d.get("new_stop_loss") or 0), new_stop_loss=float(d.get("new_stop_loss") or 0),
@@ -3283,8 +3263,44 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
f"滚仓成交 {sym} {add_mode_label(add_mode)} +{preview['add_lots']}" f"滚仓成交 {sym} {add_mode_label(add_mode)} +{preview['add_lots']}"
f"新止损 {new_sl} 合计 {new_lots}" f"新止损 {new_sl} 合计 {new_lots}"
) )
_schedule_roll_entry_sync(mon_id, sym, mon["direction"], mode)
return True, "成交" return True, "成交"
def _schedule_roll_entry_sync(
mon_id: int, sym: str, direction: str, mode: str,
) -> None:
"""滚仓成交后从柜台同步加权均价到手数监控。"""
def _run() -> None:
import time as _time
_time.sleep(1.5)
try:
conn = get_db()
try:
init_strategy_tables(conn)
capital = _capital(conn)
synced = False
for p in trading_state.get_positions() or _ctp_positions(mode):
if (p.get("direction") or "long") != (direction or "long"):
continue
if not _match_ctp_symbol(p.get("symbol") or "", sym):
continue
_sync_monitor_from_ctp(
conn, mon_id, sym, direction, mode, ctp=p, capital=capital,
)
synced = True
break
if synced:
commit_retry(conn)
finally:
conn.close()
if synced:
_push_position_snapshot_async(fast=False)
except Exception as exc:
logger.debug("roll entry sync: %s", exc)
threading.Thread(target=_run, daemon=True, name="roll-entry-sync").start()
def _submit_roll_pending( def _submit_roll_pending(
conn, conn,
*, *,
@@ -3358,6 +3374,7 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
fill_roll_leg_fn=_fill_roll_leg_cb, fill_roll_leg_fn=_fill_roll_leg_cb,
is_trading_session_fn=is_trading_session, is_trading_session_fn=is_trading_session,
get_risk_budget_fn=lambda: get_fixed_amount(get_setting), get_risk_budget_fn=lambda: get_fixed_amount(get_setting),
get_entry_price_fn=lambda sym, d, fb: _live_entry_price(sym, d, mode, fb),
) )
conn.commit() conn.commit()
finally: finally:
@@ -3380,7 +3397,9 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
direction = (mon.get("direction") or "long").strip().lower() direction = (mon.get("direction") or "long").strip().lower()
price = float(preview.get("add_price") or 0) price = float(preview.get("add_price") or 0)
qty_existing = float(mon.get("lots") or 0) qty_existing = float(mon.get("lots") or 0)
entry_existing = float(mon.get("entry_price") or 0) entry_existing = _live_entry_price(
sym, direction, mode, float(mon.get("entry_price") or 0),
)
mult = int(get_contract_spec(sym).get("mult") or 1) mult = int(get_contract_spec(sym).get("mult") or 1)
roll_pct = get_roll_max_margin_pct(get_setting) roll_pct = get_roll_max_margin_pct(get_setting)
add_lots = int(preview.get("add_lots") or 0) add_lots = int(preview.get("add_lots") or 0)
+8 -1
View File
@@ -71,6 +71,7 @@ def check_roll_monitors(
fill_roll_leg_fn: Callable[[dict, dict, dict, dict], tuple[bool, str]], fill_roll_leg_fn: Callable[[dict, dict, dict, dict], tuple[bool, str]],
is_trading_session_fn: Callable[[], bool], is_trading_session_fn: Callable[[], bool],
get_risk_budget_fn: Callable[[], float], get_risk_budget_fn: Callable[[], float],
get_entry_price_fn: Optional[Callable[[str, str, float], float]] = None,
) -> None: ) -> None:
"""扫描 pending 滚仓腿,标记价穿越则重算手数并市价成交。""" """扫描 pending 滚仓腿,标记价穿越则重算手数并市价成交。"""
if not is_trading_session_fn(): if not is_trading_session_fn():
@@ -114,6 +115,12 @@ def check_roll_monitors(
"entry_price": leg["mon_entry"], "entry_price": leg["mon_entry"],
"take_profit": leg["mon_tp"] or leg["initial_take_profit"], "take_profit": leg["mon_tp"] or leg["initial_take_profit"],
} }
entry_fb = float(leg["mon_entry"] or 0)
entry_existing = (
get_entry_price_fn(sym, direction, entry_fb)
if get_entry_price_fn
else entry_fb
)
grp = { grp = {
"id": leg["roll_group_id"], "id": leg["roll_group_id"],
"order_monitor_id": leg["order_monitor_id"], "order_monitor_id": leg["order_monitor_id"],
@@ -124,7 +131,7 @@ def check_roll_monitors(
direction=direction, direction=direction,
symbol=sym, symbol=sym,
qty_existing=float(leg["mon_lots"] or 0), qty_existing=float(leg["mon_lots"] or 0),
entry_existing=float(leg["mon_entry"] or 0), entry_existing=entry_existing,
initial_take_profit=float(leg["mon_tp"] or leg["initial_take_profit"] or 0), initial_take_profit=float(leg["mon_tp"] or leg["initial_take_profit"] or 0),
add_mode=mode, add_mode=mode,
new_stop_loss=float(leg["new_stop_loss"] or 0), new_stop_loss=float(leg["new_stop_loss"] or 0),
+15 -2
View File
@@ -326,7 +326,14 @@ class CtpBridge:
pos = event.data pos = event.data
row = self._position_row_from_vnpy(pos) row = self._position_row_from_vnpy(pos)
if row: if row:
trading_state.upsert_position(row, notify=False) sym = row.get("symbol") or ""
ex = row.get("exchange") or ""
ths = CtpBridge._vnpy_sym_to_ths(sym, ex) or sym
with _ctp_td_lock:
trades = self.list_trades()
trading_state.upsert_position(
row, notify=False, trades=trades, ths_sym=ths,
)
sym = getattr(pos, "symbol", "") or "" sym = getattr(pos, "symbol", "") or ""
d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short" d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short"
vol = int(getattr(pos, "volume", 0) or 0) vol = int(getattr(pos, "volume", 0) or 0)
@@ -482,7 +489,13 @@ class CtpBridge:
with _ctp_td_lock: with _ctp_td_lock:
orders = self.list_active_orders() orders = self.list_active_orders()
positions = self._collect_positions() positions = self._collect_positions()
trading_state.calibrate_from_lists(orders, positions) trades = self.list_trades()
trading_state.calibrate_from_lists(
orders,
positions,
trades=trades,
ths_for_vnpy_sym=lambda s, e: CtpBridge._vnpy_sym_to_ths(s, e) or s,
)
except Exception as exc: except Exception as exc:
logger.debug("calibrate trading state: %s", exc) logger.debug("calibrate trading state: %s", exc)