d07fc4b70d
Co-authored-by: Cursor <cursoragent@cursor.com>
322 lines
11 KiB
Python
322 lines
11 KiB
Python
# Copyright (c) 2025-2026 马建军. All rights reserved.
|
|
# 专有软件 — 未经授权禁止复制、传播、转售。
|
|
# 详见 LICENSE.zh-CN.txt
|
|
|
|
"""CTP 权威内存簿:委托、持仓、同步状态(事件增量 + 定期全量校准)。"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import threading
|
|
import time
|
|
from typing import Any, Callable, Optional
|
|
|
|
# Module logger; CtpTradingState reports change-callback failures through it.
logger = logging.getLogger(__name__)
|
|
|
|
# Seconds since the last finished full calibration after which
# CtpTradingState.needs_calibrate() reports that another one is due.
CALIBRATE_INTERVAL_SEC = 30.0
|
|
|
|
|
|
def position_key(exchange: str, symbol: str, direction: str) -> str:
    """Build the canonical position key ``EXCHANGE|symbol|direction``.

    The exchange is upper-cased, symbol and direction are lower-cased, and all
    three are trimmed. A missing direction defaults to ``"long"``. When no
    exchange is given the key degrades to ``symbol|direction``.
    """
    parts = [
        (symbol or "").strip().lower(),
        (direction or "long").strip().lower(),
    ]
    exch = (exchange or "").strip().upper()
    if exch:
        parts.insert(0, exch)
    return "|".join(parts)
|
|
|
|
|
|
def parse_position_key(key: str) -> tuple[str, str, str]:
    """Split a key produced by ``position_key`` into ``(exchange, symbol, direction)``.

    Two-part keys have no exchange, so ``""`` is returned for it. A key without
    any ``|`` is treated as a bare symbol (lower-cased) with direction ``"long"``.
    Any parts beyond the third are ignored.
    """
    text = key or ""
    pieces = text.split("|")
    count = len(pieces)
    if count == 1:
        return "", text.lower(), "long"
    if count == 2:
        sym, side = pieces
        return "", sym, side
    exch, sym, side = pieces[:3]
    return exch, sym, side
|
|
|
|
|
|
def reconcile_position_avg(
    old: Optional[dict[str, Any]],
    new: dict[str, Any],
    tick: Optional[float],
    *,
    trades: Optional[list[dict[str, Any]]] = None,
    ths_sym: str = "",
    tolerance: float = 0.5,
) -> dict[str, Any]:
    """Reconcile an incoming position row's average (entry) price with the stored one.

    With unchanged lots, the previously locked average price is kept, unless one
    of the following differs from it by at least *tolerance*:

    - the price from ``resolve_ctp_entry``, or
    - the P&L-derived entry from ``entry_from_ctp_pnl``.

    When the lots changed (roll / add / new position), the price from
    ``resolve_ctp_entry`` is used. If that is not positive, the row's own
    ``avg_price`` is used instead.

    Args:
        old: Previously stored row for the same position key, or None.
        new: Incoming row. It is not mutated; a copy is returned.
        tick: Latest tick price for the instrument, if known.
        trades: Optional trade records forwarded to ``resolve_ctp_entry``.
        ths_sym: Symbol passed to the entry-price helpers. Defaults to
            ``new["symbol"]``.
        tolerance: Minimum absolute price difference required to override a
            locked average price. The default of 0.5 is the previous
            hard-coded value.

    Returns:
        A copy of *new* with ``avg_price`` / ``avg_price_locked`` updated.
        Rows with ``lots <= 0`` are returned as an unchanged copy.
    """
    row = dict(new)
    lots = int(row.get("lots") or 0)
    if lots <= 0:
        return row

    # Deferred import: only open positions need the entry-price helpers.
    from ctp_entry_price import entry_from_ctp_pnl, resolve_ctp_entry

    direction = (row.get("direction") or "long").strip().lower()
    old_lots = int(old.get("lots") or 0) if old else 0
    lots_changed = not old or old_lots != lots
    sym = ths_sym or (row.get("symbol") or "")

    if (
        not lots_changed
        and old
        and old.get("avg_price_locked")
        and float(old.get("avg_price") or 0) > 0
    ):
        locked = float(old["avg_price"])
        corrected, _ = resolve_ctp_entry(sym, direction, row, trades, tick=tick)
        if corrected > 0 and abs(corrected - locked) >= tolerance:
            chosen = corrected
        else:
            # Consult the P&L-derived entry only when the primary resolver did
            # not override. Like ``corrected``, it must be a positive price.
            pnl_entry = entry_from_ctp_pnl(row, tick, ths_sym=sym)
            if pnl_entry and pnl_entry > 0 and abs(pnl_entry - locked) >= tolerance:
                chosen = pnl_entry
            else:
                chosen = locked
        row["avg_price"] = chosen
        row["avg_price_locked"] = True
        return row

    entry, _ = resolve_ctp_entry(sym, direction, row, trades, tick=tick)
    if entry > 0:
        row["avg_price"] = entry
        row["avg_price_locked"] = True
        return row

    pos_avg = float(row.get("avg_price") or 0)
    if pos_avg > 0:
        row["avg_price"] = pos_avg
        # Locked only if the lots changed or a tick price is known.
        row["avg_price_locked"] = lots_changed or bool(tick)
    return row
|
|
|
|
|
|
class CtpTradingState:
    """In-process CTP snapshot: counter (broker) reports are authoritative;
    SQLite only carries SL/TP metadata.

    Orders and positions are keyed dicts guarded by a re-entrant lock.
    Incremental events arrive through ``upsert_*`` / ``remove_*``, and
    ``calibrate_from_lists`` rebuilds the whole book from a full snapshot.
    Read accessors return copies, so callers cannot mutate the book outside
    the lock.
    """

    def __init__(self) -> None:
        self._lock = threading.RLock()
        self._orders: dict[str, dict[str, Any]] = {}  # order id -> order row
        self._positions: dict[str, dict[str, Any]] = {}  # position_key -> row
        self._tick_prices: dict[str, float] = {}  # "EXCHANGE|symbol" -> last price
        self._sync_state = "idle"  # one of "idle" / "syncing" / "ready"
        self._last_event_ts: float = 0.0
        self._last_calibrate_ts: float = 0.0
        self._on_change: Optional[Callable[[], None]] = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _tick_key(exchange: Optional[str], symbol: Optional[str]) -> str:
        """Normalised tick-cache key, trimmed and cased like ``position_key``."""
        return f"{(exchange or '').strip().upper()}|{(symbol or '').strip().lower()}"

    @staticmethod
    def _is_id_suffix(full: str, tail: str) -> bool:
        """Return True if *tail* is a separator-delimited suffix of *full*.

        Links a bare order id to a prefixed one (e.g. ``1_2_3`` and
        ``CTP.1_2_3``). It does not let ``1_2_3`` match ``11_2_3``, because
        the character before the suffix must not be alphanumeric or ``_``.
        """
        if not tail or len(full) <= len(tail) or not full.endswith(tail):
            return False
        boundary = full[-len(tail) - 1]
        return not (boundary.isalnum() or boundary == "_")

    # ------------------------------------------------------------ notification

    def set_change_callback(self, fn: Optional[Callable[[], None]]) -> None:
        """Register (or clear with None) the zero-argument change callback."""
        self._on_change = fn

    def _notify(self) -> None:
        """Record the event time and invoke the change callback, if any.

        Callback errors are logged and swallowed so a broken listener cannot
        break state updates.
        """
        self._last_event_ts = time.time()
        fn = self._on_change
        if fn:
            try:
                fn()
            except Exception as exc:
                logger.debug("trading state change callback: %s", exc, exc_info=True)

    # ------------------------------------------------------------- sync state

    @property
    def sync_state(self) -> str:
        with self._lock:
            return self._sync_state

    def sync_label(self) -> str:
        """Human-readable sync label; empty while idle."""
        st = self.sync_state
        if st == "syncing":
            return "同步中…"
        if st == "ready":
            return "已同步"
        return ""

    def begin_sync(self) -> None:
        with self._lock:
            self._sync_state = "syncing"

    def finish_sync(self) -> None:
        with self._lock:
            self._sync_state = "ready"
            self._last_calibrate_ts = time.time()

    def needs_calibrate(self) -> bool:
        """True if never synced, or if the last calibration is at least
        CALIBRATE_INTERVAL_SEC old."""
        with self._lock:
            if self._sync_state == "idle":
                return True
            return (time.time() - self._last_calibrate_ts) >= CALIBRATE_INTERVAL_SEC

    # ------------------------------------------------------------------ orders

    def upsert_order(self, row: dict[str, Any], *, notify: bool = True) -> None:
        """Store a copy of *row* keyed by ``order_id`` (or ``vt_order_id``).

        Rows without either id are ignored.
        """
        oid = str(row.get("order_id") or row.get("vt_order_id") or "").strip()
        if not oid:
            return
        with self._lock:
            self._orders[oid] = dict(row)
        if notify:
            self._notify()

    def remove_order(self, order_id: str, *, notify: bool = True) -> None:
        """Remove an order by exact id or, failing that, by separator-delimited
        suffix match.

        The suffix match covers a bare id versus a prefixed id. At most one
        order is removed, and the callback fires only if something was removed.
        """
        oid = (order_id or "").strip()
        if not oid:
            return
        removed = False
        with self._lock:
            if oid in self._orders:
                del self._orders[oid]
                removed = True
            else:
                match = next(
                    (
                        k
                        for k in self._orders
                        if self._is_id_suffix(k, oid) or self._is_id_suffix(oid, k)
                    ),
                    None,
                )
                if match is not None:
                    del self._orders[match]
                    removed = True
        if removed and notify:
            self._notify()

    def get_active_orders(self) -> list[dict[str, Any]]:
        """Return copies of all stored order rows."""
        with self._lock:
            return [dict(o) for o in self._orders.values()]

    # --------------------------------------------------------------- positions

    def get_position(self, pk: str) -> Optional[dict[str, Any]]:
        with self._lock:
            row = self._positions.get(pk)
            return dict(row) if row else None

    def try_lock_entry_prices(self) -> bool:
        """Correct position entry prices once tick prices are available.

        This also corrects rows that are already locked but whose price
        disagrees with ``resolve_ctp_entry`` by at least 0.5. Returns True if
        any row was updated. It does not fire the change callback.
        """
        from ctp_entry_price import resolve_ctp_entry

        # Best-effort vnpy -> THS symbol converter, resolved once and outside
        # the lock. If unavailable, the raw vnpy symbol is used.
        try:
            from vnpy_bridge import CtpBridge

            to_ths = CtpBridge._vnpy_sym_to_ths
        except Exception:
            to_ths = None

        changed = False
        with self._lock:
            for pk, row in list(self._positions.items()):
                ex = row.get("exchange") or ""
                sym = row.get("symbol") or ""
                tick = self.get_tick_price(ex, sym)
                if not tick or tick <= 0:
                    continue
                ths = sym
                if to_ths is not None:
                    try:
                        ths = to_ths(sym, ex) or sym
                    except Exception:
                        ths = sym
                # Normalised the same way reconcile_position_avg does.
                direction = (row.get("direction") or "long").strip().lower()
                entry, _ = resolve_ctp_entry(ths, direction, row, tick=tick)
                if not entry or entry <= 0:
                    continue
                current = float(row.get("avg_price") or 0)
                if row.get("avg_price_locked") and current > 0 and abs(entry - current) < 0.5:
                    continue
                updated = dict(row)
                updated["avg_price"] = entry
                updated["avg_price_locked"] = True
                self._positions[pk] = updated
                changed = True
        return changed

    def upsert_position(
        self,
        row: dict[str, Any],
        *,
        notify: bool = True,
        trades: Optional[list[dict[str, Any]]] = None,
        ths_sym: str = "",
    ) -> None:
        """Apply one position event.

        Rows with ``lots <= 0`` remove the position. Other rows are
        reconciled against the stored row via ``reconcile_position_avg``.
        """
        lots = int(row.get("lots") or 0)
        ex = row.get("exchange") or ""
        sym = row.get("symbol") or ""
        pk = position_key(ex, sym, row.get("direction") or "long")
        tick = self.get_tick_price(ex, sym)
        with self._lock:
            if lots <= 0:
                self._positions.pop(pk, None)
            else:
                merged = reconcile_position_avg(
                    self._positions.get(pk),
                    dict(row),
                    tick,
                    trades=trades,
                    ths_sym=ths_sym or sym,
                )
                merged["position_key"] = pk
                self._positions[pk] = merged
        if notify:
            self._notify()

    def remove_position(self, pk: str, *, notify: bool = True) -> None:
        with self._lock:
            self._positions.pop(pk, None)
        if notify:
            self._notify()

    def get_positions(self) -> list[dict[str, Any]]:
        """Return copies of all stored position rows."""
        with self._lock:
            return [dict(p) for p in self._positions.values()]

    def position_keys(self) -> set[str]:
        with self._lock:
            return set(self._positions.keys())

    # ------------------------------------------------------------------- ticks

    def set_tick_price(self, exchange: str, symbol: str, price: float) -> None:
        """Cache the last price for an instrument.

        Empty symbols and non-positive or None prices are ignored.
        """
        if not (symbol or "").strip() or price is None or price <= 0:
            return
        key = self._tick_key(exchange, symbol)
        with self._lock:
            self._tick_prices[key] = float(price)

    def get_tick_price(self, exchange: str, symbol: str) -> Optional[float]:
        key = self._tick_key(exchange, symbol)
        with self._lock:
            return self._tick_prices.get(key)

    # ------------------------------------------------------------- whole book

    def clear(self) -> None:
        with self._lock:
            self._orders.clear()
            self._positions.clear()
            self._tick_prices.clear()
            self._sync_state = "idle"

    def calibrate_from_lists(
        self,
        orders: list[dict[str, Any]],
        positions: list[dict[str, Any]],
        *,
        trades: Optional[list[dict[str, Any]]] = None,
        ths_for_vnpy_sym: Optional[Callable[[str, str], str]] = None,
    ) -> None:
        """Full calibration: rebuild the order/position books from the vnpy lists.

        Previous rows are read from a snapshot taken under the lock, so their
        locked average prices carry over. The new books are swapped in and the
        sync is marked finished atomically. The change callback fires once at
        the end.
        """
        self.begin_sync()
        with self._lock:
            prev_positions = dict(self._positions)

        new_orders: dict[str, dict[str, Any]] = {}
        for o in orders or []:
            oid = str(o.get("order_id") or o.get("vt_order_id") or "").strip()
            if oid:
                new_orders[oid] = dict(o)

        new_positions: dict[str, dict[str, Any]] = {}
        for p in positions or []:
            if int(p.get("lots") or 0) <= 0:
                continue
            ex = p.get("exchange") or ""
            sym = p.get("symbol") or ""
            pk = position_key(ex, sym, p.get("direction") or "long")
            row = dict(p)
            row["position_key"] = pk
            ths = sym
            if ths_for_vnpy_sym:
                try:
                    ths = ths_for_vnpy_sym(sym, ex) or sym
                except Exception:
                    ths = sym
            new_positions[pk] = reconcile_position_avg(
                prev_positions.get(pk),
                row,
                self.get_tick_price(ex, sym),
                trades=trades,
                ths_sym=ths,
            )

        with self._lock:
            self._orders = new_orders
            self._positions = new_positions
            self.finish_sync()
        self._notify()
|
|
|
|
|
|
# Module-level CtpTradingState instance, created at import time; importers of this
# module presumably share this one book (confirm against callers).
trading_state = CtpTradingState()
|