Fix roll average entry: CTP trade-weighted avg, sync after fill, live entry for preview.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-30 10:42:56 +08:00
parent 6e954da4e1
commit e6208e403e
5 changed files with 239 additions and 68 deletions
+72 -53
View File
@@ -120,6 +120,7 @@ from trading_context import (
is_ctp_connected,
trading_mode_label,
)
from ctp_entry_price import resolve_ctp_entry
from ctp_symbol import ths_to_vnpy_symbol
from ctp_trading_state import position_key, trading_state
from vnpy_bridge import (
@@ -542,48 +543,31 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
pass
return False
def _ctp_avg_entry_from_trades(
mode: str, sym: str, direction: str, *, expect_lots: int = 0,
) -> Optional[float]:
"""按成交回报加权计算持仓均价(与柜台一致)。"""
def _live_entry_price(
sym: str,
direction: str,
mode: str,
fallback: float = 0.0,
) -> float:
"""滚仓/展示用均价:优先柜台成交加权与持仓价。"""
if not ctp_status(mode).get("connected"):
return None
return fallback
trades: list = []
try:
trades = sorted(
ctp_list_trades(mode),
key=lambda t: (t.get("datetime") or "", t.get("trade_id") or ""),
)
trades = ctp_list_trades(mode)
except Exception:
return None
direction = (direction or "long").strip().lower()
vol = 0
cost = 0.0
for t in trades:
if not _match_ctp_symbol(t.get("symbol") or "", sym):
pass
for p in trading_state.get_positions() or _ctp_positions(
mode, refresh_if_empty=False,
):
if (p.get("direction") or "long") != (direction or "long"):
continue
off = (t.get("offset") or "").strip().lower()
pos_dir = (
t.get("position_direction") or t.get("direction") or "long"
).strip().lower()
if pos_dir != direction:
if not _match_ctp_symbol(p.get("symbol") or "", sym):
continue
lots = int(t.get("lots") or 0)
px = float(t.get("price") or 0)
if lots <= 0 or px <= 0:
continue
if off == "open":
cost += px * lots
vol += lots
elif off == "close" and vol > 0:
avg = cost / vol
dec = min(lots, vol)
cost -= avg * dec
vol -= dec
if vol <= 0:
return None
if expect_lots > 0 and vol != expect_lots:
return None
return round(cost / vol, 4)
entry, _ = resolve_ctp_entry(sym, direction, p, trades)
if entry > 0:
return float(entry)
return fallback
def _resolve_ctp_entry_price(
mode: str,
@@ -591,22 +575,15 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
direction: str,
ctp: Optional[dict],
) -> tuple[float, str]:
"""持仓均价:成交加权 > 柜台持仓价(锁定后不随 tick 变化)。"""
if not ctp:
return 0.0, "none"
direction = (direction or "long").strip().lower()
lots = int(ctp.get("lots") or 0)
trade_avg = _ctp_avg_entry_from_trades(
mode, sym, direction, expect_lots=lots,
)
if trade_avg and trade_avg > 0:
return float(trade_avg), "trades"
pos_avg = float(ctp.get("avg_price") or 0)
if pos_avg > 0:
return pos_avg, "ctp"
return 0.0, "none"
trades: list = []
if ctp_status(mode).get("connected"):
try:
trades = ctp_list_trades(mode)
except Exception:
pass
return resolve_ctp_entry(sym, direction, ctp, trades)
def _open_commission_from_ctp_trades(
mode: str, sym: str, direction: str,
@@ -3172,6 +3149,9 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
spec = get_contract_spec(sym)
capital = _capital(conn)
mark = _roll_mark_price(sym, mon, mode)
entry_existing = _live_entry_price(
sym, mon["direction"], mode, float(mon.get("entry_price") or 0),
)
add_mode = (d.get("add_mode") or ADD_MODE_MARKET).strip().lower()
if add_mode in FIB_MODES:
return None, "斐波加仓已停用,请选市价或突破"
@@ -3183,7 +3163,7 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
direction=mon["direction"],
symbol=sym,
qty_existing=float(mon["lots"]),
entry_existing=float(mon["entry_price"]),
entry_existing=entry_existing,
initial_take_profit=float(mon["take_profit"] or 0),
add_mode=add_mode,
new_stop_loss=float(d.get("new_stop_loss") or 0),
@@ -3283,8 +3263,44 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
f"滚仓成交 {sym} {add_mode_label(add_mode)} +{preview['add_lots']}"
f"新止损 {new_sl} 合计 {new_lots}"
)
_schedule_roll_entry_sync(mon_id, sym, mon["direction"], mode)
return True, "成交"
def _schedule_roll_entry_sync(
mon_id: int, sym: str, direction: str, mode: str,
) -> None:
"""滚仓成交后从柜台同步加权均价到手数监控。"""
def _run() -> None:
import time as _time
_time.sleep(1.5)
try:
conn = get_db()
try:
init_strategy_tables(conn)
capital = _capital(conn)
synced = False
for p in trading_state.get_positions() or _ctp_positions(mode):
if (p.get("direction") or "long") != (direction or "long"):
continue
if not _match_ctp_symbol(p.get("symbol") or "", sym):
continue
_sync_monitor_from_ctp(
conn, mon_id, sym, direction, mode, ctp=p, capital=capital,
)
synced = True
break
if synced:
commit_retry(conn)
finally:
conn.close()
if synced:
_push_position_snapshot_async(fast=False)
except Exception as exc:
logger.debug("roll entry sync: %s", exc)
threading.Thread(target=_run, daemon=True, name="roll-entry-sync").start()
def _submit_roll_pending(
conn,
*,
@@ -3358,6 +3374,7 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
fill_roll_leg_fn=_fill_roll_leg_cb,
is_trading_session_fn=is_trading_session,
get_risk_budget_fn=lambda: get_fixed_amount(get_setting),
get_entry_price_fn=lambda sym, d, fb: _live_entry_price(sym, d, mode, fb),
)
conn.commit()
finally:
@@ -3380,7 +3397,9 @@ def install_trading(app, *, login_required, require_nav, get_db, get_setting, se
direction = (mon.get("direction") or "long").strip().lower()
price = float(preview.get("add_price") or 0)
qty_existing = float(mon.get("lots") or 0)
entry_existing = float(mon.get("entry_price") or 0)
entry_existing = _live_entry_price(
sym, direction, mode, float(mon.get("entry_price") or 0),
)
mult = int(get_contract_spec(sym).get("mult") or 1)
roll_pct = get_roll_max_margin_pct(get_setting)
add_lots = int(preview.get("add_lots") or 0)