Fix roll average entry: CTP trade-weighted avg, sync after fill, live entry for preview.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) 2025-2026 马建军. All rights reserved.
|
||||
# 详见 LICENSE.zh-CN.txt
|
||||
|
||||
"""CTP 持仓均价:成交加权 / 柜台持仓价(滚仓加仓后以柜台为准)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from ctp_symbol import ths_to_vnpy_symbol
|
||||
|
||||
|
||||
def symbols_match(ctp_sym: str, ths: str) -> bool:
|
||||
a = (ctp_sym or "").lower()
|
||||
b = (ths or "").lower()
|
||||
if a == b:
|
||||
return True
|
||||
if a and b and a.split(".")[0] == b.split(".")[0]:
|
||||
return True
|
||||
try:
|
||||
vnpy_sym, _ = ths_to_vnpy_symbol(ths)
|
||||
if a == vnpy_sym.lower():
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
vnpy_sym, _ = ths_to_vnpy_symbol(ctp_sym)
|
||||
if vnpy_sym.lower() == b.split(".")[0]:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def avg_from_trades(
|
||||
trades: list[dict[str, Any]],
|
||||
sym: str,
|
||||
direction: str,
|
||||
*,
|
||||
expect_lots: int = 0,
|
||||
) -> Optional[float]:
|
||||
"""按成交回报移动加权均价(滚仓多笔开仓后应与柜台一致)。"""
|
||||
direction = (direction or "long").strip().lower()
|
||||
vol = 0
|
||||
cost = 0.0
|
||||
for t in sorted(trades, key=lambda x: (x.get("datetime") or "", x.get("trade_id") or "")):
|
||||
if not symbols_match(t.get("symbol") or "", sym):
|
||||
continue
|
||||
off = (t.get("offset") or "").strip().lower()
|
||||
pos_dir = (
|
||||
t.get("position_direction") or t.get("direction") or "long"
|
||||
).strip().lower()
|
||||
if pos_dir != direction:
|
||||
continue
|
||||
lots = int(t.get("lots") or 0)
|
||||
px = float(t.get("price") or 0)
|
||||
if lots <= 0 or px <= 0:
|
||||
continue
|
||||
if off == "open":
|
||||
cost += px * lots
|
||||
vol += lots
|
||||
elif off == "close" and vol > 0:
|
||||
avg = cost / vol
|
||||
dec = min(lots, vol)
|
||||
cost -= avg * dec
|
||||
vol -= dec
|
||||
if vol <= 0:
|
||||
return None
|
||||
if expect_lots > 0 and vol != expect_lots:
|
||||
return None
|
||||
return round(cost / vol, 4)
|
||||
|
||||
|
||||
def resolve_ctp_entry(
|
||||
sym: str,
|
||||
direction: str,
|
||||
ctp: Optional[dict[str, Any]],
|
||||
trades: Optional[list[dict[str, Any]]] = None,
|
||||
) -> tuple[float, str]:
|
||||
"""均价:成交加权 > 柜台 PositionCost 持仓价。"""
|
||||
if not ctp:
|
||||
return 0.0, "none"
|
||||
direction = (direction or "long").strip().lower()
|
||||
lots = int(ctp.get("lots") or 0)
|
||||
if trades:
|
||||
trade_avg = avg_from_trades(trades, sym, direction, expect_lots=lots)
|
||||
if trade_avg and trade_avg > 0:
|
||||
return float(trade_avg), "trades"
|
||||
pos_avg = float(ctp.get("avg_price") or 0)
|
||||
if pos_avg > 0:
|
||||
return pos_avg, "ctp"
|
||||
return 0.0, "none"
|
||||
+53
-12
@@ -12,6 +12,9 @@ from typing import Any, Callable, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CALIBRATE_INTERVAL_SEC = 30.0
|
||||
|
||||
|
||||
def position_key(exchange: str, symbol: str, direction: str) -> str:
|
||||
"""统一持仓键:exchange|symbol|direction"""
|
||||
ex = (exchange or "").strip().upper()
|
||||
@@ -72,16 +75,24 @@ def reconcile_position_avg(
|
||||
old: Optional[dict[str, Any]],
|
||||
new: dict[str, Any],
|
||||
tick: Optional[float],
|
||||
*,
|
||||
trades: Optional[list[dict[str, Any]]] = None,
|
||||
ths_sym: str = "",
|
||||
) -> dict[str, Any]:
|
||||
"""手数不变时锁定均价;新开/加仓时用柜台盈亏快照校正一次。"""
|
||||
"""手数不变时锁定均价;滚仓/加仓(手数变化)时以柜台加权均价为准。"""
|
||||
from ctp_entry_price import resolve_ctp_entry
|
||||
|
||||
row = dict(new)
|
||||
lots = int(row.get("lots") or 0)
|
||||
if lots <= 0:
|
||||
return row
|
||||
direction = (row.get("direction") or "long").strip().lower()
|
||||
old_lots = int(old.get("lots") or 0) if old else 0
|
||||
lots_changed = not old or old_lots != lots
|
||||
|
||||
if (
|
||||
old
|
||||
and old_lots == lots
|
||||
not lots_changed
|
||||
and old
|
||||
and old.get("avg_price_locked")
|
||||
and float(old.get("avg_price") or 0) > 0
|
||||
):
|
||||
@@ -89,14 +100,24 @@ def reconcile_position_avg(
|
||||
row["avg_price_locked"] = True
|
||||
return row
|
||||
|
||||
refined = avg_price_from_ctp_pnl(row, tick)
|
||||
pos_avg = float(row.get("avg_price") or 0)
|
||||
if refined and refined > 0:
|
||||
row["avg_price"] = refined
|
||||
sym = ths_sym or (row.get("symbol") or "")
|
||||
entry, _src = resolve_ctp_entry(sym, direction, row, trades)
|
||||
if entry > 0:
|
||||
row["avg_price"] = entry
|
||||
row["avg_price_locked"] = True
|
||||
elif pos_avg > 0:
|
||||
return row
|
||||
|
||||
pos_avg = float(row.get("avg_price") or 0)
|
||||
if pos_avg > 0:
|
||||
row["avg_price"] = pos_avg
|
||||
row["avg_price_locked"] = bool(tick and refined)
|
||||
row["avg_price_locked"] = lots_changed or bool(tick)
|
||||
return row
|
||||
|
||||
if not lots_changed:
|
||||
refined = avg_price_from_ctp_pnl(row, tick)
|
||||
if refined and refined > 0:
|
||||
row["avg_price"] = refined
|
||||
row["avg_price_locked"] = True
|
||||
return row
|
||||
|
||||
|
||||
@@ -203,7 +224,14 @@ class CtpTradingState:
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
def upsert_position(self, row: dict[str, Any], *, notify: bool = True) -> None:
|
||||
def upsert_position(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
*,
|
||||
notify: bool = True,
|
||||
trades: Optional[list[dict[str, Any]]] = None,
|
||||
ths_sym: str = "",
|
||||
) -> None:
|
||||
lots = int(row.get("lots") or 0)
|
||||
ex = row.get("exchange") or ""
|
||||
sym = row.get("symbol") or ""
|
||||
@@ -215,7 +243,9 @@ class CtpTradingState:
|
||||
self._positions.pop(pk, None)
|
||||
else:
|
||||
old = self._positions.get(pk)
|
||||
row = reconcile_position_avg(old, dict(row), tick)
|
||||
row = reconcile_position_avg(
|
||||
old, dict(row), tick, trades=trades, ths_sym=ths_sym or sym,
|
||||
)
|
||||
row["position_key"] = pk
|
||||
self._positions[pk] = row
|
||||
if notify:
|
||||
@@ -262,6 +292,9 @@ class CtpTradingState:
|
||||
self,
|
||||
orders: list[dict[str, Any]],
|
||||
positions: list[dict[str, Any]],
|
||||
*,
|
||||
trades: Optional[list[dict[str, Any]]] = None,
|
||||
ths_for_vnpy_sym: Optional[Callable[[str, str], str]] = None,
|
||||
) -> None:
|
||||
"""全量校准:以 vnpy 内存为准重建订单/持仓簿。"""
|
||||
self.begin_sync()
|
||||
@@ -283,7 +316,15 @@ class CtpTradingState:
|
||||
row["position_key"] = pk
|
||||
old = self._positions.get(pk)
|
||||
tick = self.get_tick_price(ex, sym)
|
||||
new_positions[pk] = reconcile_position_avg(old, row, tick)
|
||||
ths = sym
|
||||
if ths_for_vnpy_sym:
|
||||
try:
|
||||
ths = ths_for_vnpy_sym(sym, ex) or sym
|
||||
except Exception:
|
||||
ths = sym
|
||||
new_positions[pk] = reconcile_position_avg(
|
||||
old, row, tick, trades=trades, ths_sym=ths,
|
||||
)
|
||||
with self._lock:
|
||||
self._orders = new_orders
|
||||
self._positions = new_positions
|
||||
|
||||
+72
-53
@@ -120,6 +120,7 @@ from trading_context import (
|
||||
is_ctp_connected,
|
||||
trading_mode_label,
|
||||
)
|
||||
from ctp_entry_price import resolve_ctp_entry
|
||||
from ctp_symbol import ths_to_vnpy_symbol
|
||||
from ctp_trading_state import position_key, trading_state
|
||||
from vnpy_bridge import (
|
||||
@@ -542,48 +543,31 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
|
||||
pass
|
||||
return False
|
||||
|
||||
def _ctp_avg_entry_from_trades(
|
||||
mode: str, sym: str, direction: str, *, expect_lots: int = 0,
|
||||
) -> Optional[float]:
|
||||
"""按成交回报加权计算持仓均价(与柜台一致)。"""
|
||||
def _live_entry_price(
|
||||
sym: str,
|
||||
direction: str,
|
||||
mode: str,
|
||||
fallback: float = 0.0,
|
||||
) -> float:
|
||||
"""滚仓/展示用均价:优先柜台成交加权与持仓价。"""
|
||||
if not ctp_status(mode).get("connected"):
|
||||
return None
|
||||
return fallback
|
||||
trades: list = []
|
||||
try:
|
||||
trades = sorted(
|
||||
ctp_list_trades(mode),
|
||||
key=lambda t: (t.get("datetime") or "", t.get("trade_id") or ""),
|
||||
)
|
||||
trades = ctp_list_trades(mode)
|
||||
except Exception:
|
||||
return None
|
||||
direction = (direction or "long").strip().lower()
|
||||
vol = 0
|
||||
cost = 0.0
|
||||
for t in trades:
|
||||
if not _match_ctp_symbol(t.get("symbol") or "", sym):
|
||||
pass
|
||||
for p in trading_state.get_positions() or _ctp_positions(
|
||||
mode, refresh_if_empty=False,
|
||||
):
|
||||
if (p.get("direction") or "long") != (direction or "long"):
|
||||
continue
|
||||
off = (t.get("offset") or "").strip().lower()
|
||||
pos_dir = (
|
||||
t.get("position_direction") or t.get("direction") or "long"
|
||||
).strip().lower()
|
||||
if pos_dir != direction:
|
||||
if not _match_ctp_symbol(p.get("symbol") or "", sym):
|
||||
continue
|
||||
lots = int(t.get("lots") or 0)
|
||||
px = float(t.get("price") or 0)
|
||||
if lots <= 0 or px <= 0:
|
||||
continue
|
||||
if off == "open":
|
||||
cost += px * lots
|
||||
vol += lots
|
||||
elif off == "close" and vol > 0:
|
||||
avg = cost / vol
|
||||
dec = min(lots, vol)
|
||||
cost -= avg * dec
|
||||
vol -= dec
|
||||
if vol <= 0:
|
||||
return None
|
||||
if expect_lots > 0 and vol != expect_lots:
|
||||
return None
|
||||
return round(cost / vol, 4)
|
||||
entry, _ = resolve_ctp_entry(sym, direction, p, trades)
|
||||
if entry > 0:
|
||||
return float(entry)
|
||||
return fallback
|
||||
|
||||
def _resolve_ctp_entry_price(
|
||||
mode: str,
|
||||
@@ -591,22 +575,15 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
|
||||
direction: str,
|
||||
ctp: Optional[dict],
|
||||
) -> tuple[float, str]:
|
||||
"""持仓均价:成交加权 > 柜台持仓价(锁定后不随 tick 变化)。"""
|
||||
if not ctp:
|
||||
return 0.0, "none"
|
||||
direction = (direction or "long").strip().lower()
|
||||
lots = int(ctp.get("lots") or 0)
|
||||
|
||||
trade_avg = _ctp_avg_entry_from_trades(
|
||||
mode, sym, direction, expect_lots=lots,
|
||||
)
|
||||
if trade_avg and trade_avg > 0:
|
||||
return float(trade_avg), "trades"
|
||||
|
||||
pos_avg = float(ctp.get("avg_price") or 0)
|
||||
if pos_avg > 0:
|
||||
return pos_avg, "ctp"
|
||||
return 0.0, "none"
|
||||
trades: list = []
|
||||
if ctp_status(mode).get("connected"):
|
||||
try:
|
||||
trades = ctp_list_trades(mode)
|
||||
except Exception:
|
||||
pass
|
||||
return resolve_ctp_entry(sym, direction, ctp, trades)
|
||||
|
||||
def _open_commission_from_ctp_trades(
|
||||
mode: str, sym: str, direction: str,
|
||||
@@ -3172,6 +3149,9 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
|
||||
spec = get_contract_spec(sym)
|
||||
capital = _capital(conn)
|
||||
mark = _roll_mark_price(sym, mon, mode)
|
||||
entry_existing = _live_entry_price(
|
||||
sym, mon["direction"], mode, float(mon.get("entry_price") or 0),
|
||||
)
|
||||
add_mode = (d.get("add_mode") or ADD_MODE_MARKET).strip().lower()
|
||||
if add_mode in FIB_MODES:
|
||||
return None, "斐波加仓已停用,请选市价或突破"
|
||||
@@ -3183,7 +3163,7 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
|
||||
direction=mon["direction"],
|
||||
symbol=sym,
|
||||
qty_existing=float(mon["lots"]),
|
||||
entry_existing=float(mon["entry_price"]),
|
||||
entry_existing=entry_existing,
|
||||
initial_take_profit=float(mon["take_profit"] or 0),
|
||||
add_mode=add_mode,
|
||||
new_stop_loss=float(d.get("new_stop_loss") or 0),
|
||||
@@ -3283,8 +3263,44 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
|
||||
f"滚仓成交 {sym} {add_mode_label(add_mode)} +{preview['add_lots']}手 "
|
||||
f"新止损 {new_sl} 合计 {new_lots}手"
|
||||
)
|
||||
_schedule_roll_entry_sync(mon_id, sym, mon["direction"], mode)
|
||||
return True, "成交"
|
||||
|
||||
def _schedule_roll_entry_sync(
|
||||
mon_id: int, sym: str, direction: str, mode: str,
|
||||
) -> None:
|
||||
"""滚仓成交后从柜台同步加权均价到手数监控。"""
|
||||
def _run() -> None:
|
||||
import time as _time
|
||||
|
||||
_time.sleep(1.5)
|
||||
try:
|
||||
conn = get_db()
|
||||
try:
|
||||
init_strategy_tables(conn)
|
||||
capital = _capital(conn)
|
||||
synced = False
|
||||
for p in trading_state.get_positions() or _ctp_positions(mode):
|
||||
if (p.get("direction") or "long") != (direction or "long"):
|
||||
continue
|
||||
if not _match_ctp_symbol(p.get("symbol") or "", sym):
|
||||
continue
|
||||
_sync_monitor_from_ctp(
|
||||
conn, mon_id, sym, direction, mode, ctp=p, capital=capital,
|
||||
)
|
||||
synced = True
|
||||
break
|
||||
if synced:
|
||||
commit_retry(conn)
|
||||
finally:
|
||||
conn.close()
|
||||
if synced:
|
||||
_push_position_snapshot_async(fast=False)
|
||||
except Exception as exc:
|
||||
logger.debug("roll entry sync: %s", exc)
|
||||
|
||||
threading.Thread(target=_run, daemon=True, name="roll-entry-sync").start()
|
||||
|
||||
def _submit_roll_pending(
|
||||
conn,
|
||||
*,
|
||||
@@ -3358,6 +3374,7 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
|
||||
fill_roll_leg_fn=_fill_roll_leg_cb,
|
||||
is_trading_session_fn=is_trading_session,
|
||||
get_risk_budget_fn=lambda: get_fixed_amount(get_setting),
|
||||
get_entry_price_fn=lambda sym, d, fb: _live_entry_price(sym, d, mode, fb),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
@@ -3380,7 +3397,9 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
|
||||
direction = (mon.get("direction") or "long").strip().lower()
|
||||
price = float(preview.get("add_price") or 0)
|
||||
qty_existing = float(mon.get("lots") or 0)
|
||||
entry_existing = float(mon.get("entry_price") or 0)
|
||||
entry_existing = _live_entry_price(
|
||||
sym, direction, mode, float(mon.get("entry_price") or 0),
|
||||
)
|
||||
mult = int(get_contract_spec(sym).get("mult") or 1)
|
||||
roll_pct = get_roll_max_margin_pct(get_setting)
|
||||
add_lots = int(preview.get("add_lots") or 0)
|
||||
|
||||
@@ -71,6 +71,7 @@ def check_roll_monitors(
|
||||
fill_roll_leg_fn: Callable[[dict, dict, dict, dict], tuple[bool, str]],
|
||||
is_trading_session_fn: Callable[[], bool],
|
||||
get_risk_budget_fn: Callable[[], float],
|
||||
get_entry_price_fn: Optional[Callable[[str, str, float], float]] = None,
|
||||
) -> None:
|
||||
"""扫描 pending 滚仓腿,标记价穿越则重算手数并市价成交。"""
|
||||
if not is_trading_session_fn():
|
||||
@@ -114,6 +115,12 @@ def check_roll_monitors(
|
||||
"entry_price": leg["mon_entry"],
|
||||
"take_profit": leg["mon_tp"] or leg["initial_take_profit"],
|
||||
}
|
||||
entry_fb = float(leg["mon_entry"] or 0)
|
||||
entry_existing = (
|
||||
get_entry_price_fn(sym, direction, entry_fb)
|
||||
if get_entry_price_fn
|
||||
else entry_fb
|
||||
)
|
||||
grp = {
|
||||
"id": leg["roll_group_id"],
|
||||
"order_monitor_id": leg["order_monitor_id"],
|
||||
@@ -124,7 +131,7 @@ def check_roll_monitors(
|
||||
direction=direction,
|
||||
symbol=sym,
|
||||
qty_existing=float(leg["mon_lots"] or 0),
|
||||
entry_existing=float(leg["mon_entry"] or 0),
|
||||
entry_existing=entry_existing,
|
||||
initial_take_profit=float(leg["mon_tp"] or leg["initial_take_profit"] or 0),
|
||||
add_mode=mode,
|
||||
new_stop_loss=float(leg["new_stop_loss"] or 0),
|
||||
|
||||
+15
-2
@@ -326,7 +326,14 @@ class CtpBridge:
|
||||
pos = event.data
|
||||
row = self._position_row_from_vnpy(pos)
|
||||
if row:
|
||||
trading_state.upsert_position(row, notify=False)
|
||||
sym = row.get("symbol") or ""
|
||||
ex = row.get("exchange") or ""
|
||||
ths = CtpBridge._vnpy_sym_to_ths(sym, ex) or sym
|
||||
with _ctp_td_lock:
|
||||
trades = self.list_trades()
|
||||
trading_state.upsert_position(
|
||||
row, notify=False, trades=trades, ths_sym=ths,
|
||||
)
|
||||
sym = getattr(pos, "symbol", "") or ""
|
||||
d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short"
|
||||
vol = int(getattr(pos, "volume", 0) or 0)
|
||||
@@ -482,7 +489,13 @@ class CtpBridge:
|
||||
with _ctp_td_lock:
|
||||
orders = self.list_active_orders()
|
||||
positions = self._collect_positions()
|
||||
trading_state.calibrate_from_lists(orders, positions)
|
||||
trades = self.list_trades()
|
||||
trading_state.calibrate_from_lists(
|
||||
orders,
|
||||
positions,
|
||||
trades=trades,
|
||||
ths_for_vnpy_sym=lambda s, e: CtpBridge._vnpy_sym_to_ths(s, e) or s,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("calibrate trading state: %s", exc)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user