# Copyright (c) 2025-2026 马建军. All rights reserved. # 专有软件 — 未经授权禁止复制、传播、转售。 # 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。 # 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md """CTP 执行层:模拟盘 → SimNow;实盘 → 期货公司(vnpy_ctp)。""" from __future__ import annotations import logging import os import re import threading import time from collections import deque from typing import Any, Callable, Optional from locale_fix import ensure_process_locale ensure_process_locale() from ctp_settings import live_setting_dict, simnow_setting_dict from ctp_symbol import ths_to_vnpy_symbol, to_vnpy_exchange from contract_specs import get_contract_spec logger = logging.getLogger(__name__) GATEWAY_NAME = "CTP" CONNECT_WAIT_SEC = 60 CONNECT_POLL_INTERVAL_SEC = 0.5 LOGIN_BAN_COOLDOWN_SEC = 45 * 60 LOGIN_FAIL_COOLDOWN_SEC = 5 * 60 CTP_COOLDOWN_UNTIL_KEY = "ctp_login_cooldown_until" CTP_LAST_ERROR_KEY = "ctp_last_error" def _persist_login_cooldown(seconds: float) -> None: from fee_specs import get_setting, set_setting new_until = time.time() + max(0.0, seconds) try: old = float(get_setting(CTP_COOLDOWN_UNTIL_KEY, "0") or 0) except (TypeError, ValueError): old = 0.0 if new_until > old: set_setting(CTP_COOLDOWN_UNTIL_KEY, str(new_until)) def _persisted_login_cooldown_remaining() -> int: from fee_specs import get_setting try: until = float(get_setting(CTP_COOLDOWN_UNTIL_KEY, "0") or 0) return max(0, int(until - time.time())) except (TypeError, ValueError): return 0 def _clear_persisted_login_cooldown() -> None: from fee_specs import set_setting set_setting(CTP_COOLDOWN_UNTIL_KEY, "0") def _persist_last_error(msg: str) -> None: from fee_specs import set_setting set_setting(CTP_LAST_ERROR_KEY, (msg or "").strip()) def _load_persisted_last_error() -> str: from fee_specs import get_setting return (get_setting(CTP_LAST_ERROR_KEY, "") or "").strip() _position_refresh_callback: Optional[Callable[[], None]] = None _tick_sl_tp_callback: Optional[Callable[[str, str, float], None]] = None _ctp_connected_callback: Optional[Callable[[str], None]] = None _position_refresh_debounce_lock = threading.Lock() _position_refresh_debounce_ts: float = 0.0 def set_position_refresh_callback(fn: Optional[Callable[[], None]]) -> None: global _position_refresh_callback _position_refresh_callback = fn def set_tick_sl_tp_callback(fn: Optional[Callable[[str, str, float], None]]) -> None: """注册 tick 回调:exchange, symbol, last_price → 本地 SL/TP 触发。""" global _tick_sl_tp_callback _tick_sl_tp_callback = fn def set_ctp_connected_callback(fn: Optional[Callable[[str], None]]) -> None: """CTP 交易通道登录成功后回调(mode=simulation|live)。""" global _ctp_connected_callback _ctp_connected_callback = fn def _fire_ctp_connected_callback(mode: str) -> None: fn = _ctp_connected_callback if not fn: return try: threading.Thread( target=fn, args=(mode,), daemon=True, name="ctp-connected-cb", ).start() except Exception as exc: logger.debug("ctp connected callback: %s", exc) def _fire_position_refresh_callback() -> None: fn = _position_refresh_callback if not fn: return try: threading.Thread(target=fn, daemon=True, name="ctp-position-refresh").start() except Exception as exc: logger.debug("position refresh callback: %s", exc) def _fire_position_refresh_callback_debounced(*, min_interval: float = 0.35) -> None: global _position_refresh_debounce_ts now = time.monotonic() with _position_refresh_debounce_lock: if now - _position_refresh_debounce_ts < min_interval: return _position_refresh_debounce_ts = now _fire_position_refresh_callback() def _fire_position_refresh_burst() -> None: """连接后持仓回报可能分批到达,分多次触发快照刷新。""" _fire_position_refresh_callback() for delay in (1.5, 4.0, 10.0, 25.0): threading.Timer(delay, _fire_position_refresh_callback).start() _bridge: Optional["CtpBridge"] = None _bridge_lock = threading.Lock() _ctp_td_lock = threading.RLock() POSITION_QUERY_MIN_INTERVAL_SEC = 5.0 TRADE_QUERY_MIN_INTERVAL_SEC = 10.0 def _simnow_setting() -> dict[str, str]: """SimNow 仿真前置(系统设置优先,.env 兜底)。""" return simnow_setting_dict() def _live_setting() -> dict[str, str]: return live_setting_dict() def _setting_for_mode(mode: str) -> dict[str, str]: return _simnow_setting() if mode == "simulation" else _live_setting() def _mode_label(mode: str) -> str: return "SimNow" if mode == "simulation" else "期货公司实盘" def _parse_tcp_address(address: str) -> tuple[str, int]: raw = (address or "").strip() if raw.startswith("tcp://"): raw = raw[6:] if ":" not in raw: raise ValueError(f"无效 TCP 地址: {address}") host, port_s = raw.rsplit(":", 1) return host, int(port_s) def probe_tcp_address(address: str, timeout: float = 5.0) -> tuple[bool, str]: """探测 CTP 前置 TCP 是否可达。""" import socket try: host, port = _parse_tcp_address(address) with socket.create_connection((host, port), timeout=timeout): return True, "" except Exception as exc: return False, str(exc) def _format_ctp_failure(ctp_logs: list[str], *, td_address: str = "") -> str: """根据 CTP 网关日志拼出可读错误。""" if td_address: ok, err = probe_tcp_address(td_address, timeout=4.0) if not ok: return ( f"SimNow 交易前置不可达:{td_address}({err})。" "182.254.243.31 已停用,请改 .env 为官方前置 " "tcp://180.168.146.187:10201 / 10211,并确认服务器能访问该地址。" ) text = "\n".join(ctp_logs) if "连续登录失败" in text or "登录被禁止" in text or "代码:75" in text: return ( "CTP 登录被临时禁止:连续失败次数过多(错误码 75)。" "请等待约 30~60 分钟后再试,先用快期确认投资者代码与密码正确,期间勿反复点「连接」。" ) if "4097" in text or "Decrypt handshake" in text or "shake hand" in text.lower(): return ( "CTP 握手失败(4097):vnpy_ctp 与 SimNow 前置加密不匹配。" "请执行 pip install -U vnpy vnpy_ctp 后重启,并确认 .env 中 SIMNOW_ENV=实盘" ) if "不合法的登录" in text or "密码" in text or "账号" in text: tail = ctp_logs[-1] if ctp_logs else "" return f"CTP 登录被拒:{tail or '请检查投资者代码与密码(快期能否登录)'}" if "连接断开" in text or "disconnect" in text.lower(): tail = ctp_logs[-1] if ctp_logs else "" return f"CTP 连接断开:{tail or '请检查前置地址与网络'}" if ctp_logs: return f"CTP 连接失败:{ctp_logs[-1]}" return "CTP 连接超时:未收到柜台回报。请检查 SimNow 账号、前置地址、网络(nc 测端口),并用快期验证账号" def round_to_tick(price: float, tick: float) -> float: if tick <= 0: return float(price) steps = round(float(price) / tick) return round(steps * tick, 10) def _is_long_direction(direction_obj: Any) -> bool: s = str(direction_obj or "") return "LONG" in s.upper() or "多" in s class CtpBridge: def __init__(self) -> None: self._engine = None self._ee = None self._connected_mode: Optional[str] = None self._last_error: str = "" self._connect_lock = threading.Lock() self._connect_in_progress = False self._login_cooldown_until: float = 0.0 self._restore_persisted_state() self._commission_waiters: dict[int, threading.Event] = {} self._commission_lists: dict[int, list] = {} self._commission_hooked = False self._margin_rate_waiters: dict[int, threading.Event] = {} self._margin_rate_lists: dict[int, list] = {} self._margin_rate_hooked = False self._instrument_hooked = False self._instrument_margin_ratios: dict[str, dict[str, float]] = {} self._margin_per_lot: dict[str, float] = {} self._subscribed: set[str] = set() self._last_position_query_ts: float = 0.0 self._position_margins: dict[str, float] = {} self._position_open_times: dict[str, str] = {} self._margin_hooked = False self._trade_hooked = False self._trade_query_results: list[dict[str, Any]] = [] self._trade_query_event = threading.Event() self._last_trade_query_ts: float = 0.0 self._last_connect_ok_ts: float = 0.0 self._tick_hooked = False self._position_hooked = False self._order_hooked = False self._trade_hooked = False self._bar_generators: dict[str, Any] = {} self._bars_1m: dict[str, deque] = {} self._init_engine() def _init_engine(self) -> None: ensure_process_locale() try: from vnpy.event import EventEngine from vnpy.trader.engine import MainEngine from vnpy_ctp import CtpGateway self._ee = EventEngine() self._engine = MainEngine(self._ee) self._engine.add_gateway(CtpGateway) self._ensure_position_event_hook() self._ensure_order_event_hook() self._ensure_trade_event_hook() except ImportError: self._last_error = "未安装 vnpy / vnpy_ctp,请 pip install vnpy vnpy_ctp" except Exception as exc: self._last_error = str(exc) def _ensure_position_event_hook(self) -> None: if self._position_hooked or not self._ee: return try: from vnpy.trader.event import EVENT_POSITION except ImportError: return def _on_position(event) -> None: try: from ctp_trading_state import trading_state pos = event.data row = self._position_row_from_vnpy(pos) if row: trading_state.upsert_position(row, notify=False) sym = getattr(pos, "symbol", "") or "" d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short" vol = int(getattr(pos, "volume", 0) or 0) if vol <= 0: exchange = getattr(pos, "exchange", None) ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") from ctp_trading_state import position_key trading_state.remove_position( position_key(ex_name, sym, d), notify=False, ) else: for attr in ("margin", "use_margin", "UseMargin"): raw = float(getattr(pos, attr, 0) or 0) if raw > 0: self._position_margins[self._position_margin_key(sym, d)] = raw if vol > 0: self._margin_per_lot[self._position_margin_key(sym, d)] = round( raw / vol, 2, ) break except Exception as exc: logger.debug("position margin cache: %s", exc) _fire_position_refresh_callback_debounced() self._ee.register(EVENT_POSITION, _on_position) self._position_hooked = True def _ensure_order_event_hook(self) -> None: if self._order_hooked or not self._ee: return try: from vnpy.trader.event import EVENT_ORDER except ImportError: return def _on_order(event) -> None: try: from ctp_trading_state import trading_state order = event.data row = self._order_row_from_vnpy(order) if not row: return status_s = str(row.get("status") or "") terminal = any( x in status_s for x in ("ALLTRADED", "CANCELLED", "REJECTED", "全部成交", "已撤销", "拒单") ) oid = str(row.get("order_id") or row.get("vt_order_id") or "") if terminal or int(row.get("lots") or 0) <= 0: trading_state.remove_order(oid, notify=False) else: trading_state.upsert_order(row, notify=False) except Exception as exc: logger.debug("order event: %s", exc) _fire_position_refresh_callback_debounced(min_interval=0.2) self._ee.register(EVENT_ORDER, _on_order) self._order_hooked = True def _ensure_trade_event_hook(self) -> None: if self._trade_hooked or not self._ee: return try: from vnpy.trader.event import EVENT_TRADE except ImportError: return def _on_trade(event) -> None: try: trade = event.data row = self._trade_row_from_vnpy(trade) if row and row.get("offset") == "open": sym = row.get("symbol") or "" pd = row.get("position_direction") or "long" dt = row.get("datetime") or "" if sym and dt: self._position_open_times[self._position_margin_key(sym, pd)] = dt except Exception as exc: logger.debug("trade event: %s", exc) _fire_position_refresh_callback_debounced(min_interval=0.2) self._ee.register(EVENT_TRADE, _on_trade) self._trade_hooked = True def _order_row_from_vnpy(self, order: Any) -> Optional[dict[str, Any]]: try: status = getattr(order, "status", None) status_s = str(status) vol = int(getattr(order, "volume", 0) or 0) traded = int(getattr(order, "traded", 0) or 0) remain = max(0, vol - traded) direction = getattr(order, "direction", None) d = "long" if direction is not None and str(direction).endswith("SHORT"): d = "short" offset = getattr(order, "offset", None) sym = getattr(order, "symbol", "") or "" exchange = getattr(order, "exchange", None) ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") vt_oid = str(getattr(order, "vt_orderid", "") or "") order_id = str(getattr(order, "orderid", "") or "") return { "symbol": sym, "exchange": ex_name, "direction": d, "lots": remain, "price": float(getattr(order, "price", 0) or 0), "offset": str(offset or ""), "order_id": vt_oid or order_id, "vt_order_id": vt_oid, "status": status_s, } except Exception as exc: logger.debug("order_row_from_vnpy: %s", exc) return None def _position_row_from_vnpy(self, pos: Any) -> Optional[dict[str, Any]]: try: vol = int(getattr(pos, "volume", 0) or 0) d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short" sym = getattr(pos, "symbol", "") or "" exchange = getattr(pos, "exchange", None) ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") price = float(getattr(pos, "price", 0) or 0) yd = int(getattr(pos, "yd_volume", 0) or 0) td = max(0, vol - yd) margin = self.estimate_position_margin(sym, ex_name, d, vol, price, pos=pos) open_time = self._lookup_position_open_time(sym, d) or None return { "symbol": sym, "exchange": ex_name, "direction": d, "lots": vol, "avg_price": price, "pnl": float(getattr(pos, "pnl", 0) or 0), "frozen": int(getattr(pos, "frozen", 0) or 0), "margin": margin, "open_time": open_time, "yd_volume": yd, "td_volume": td, } except Exception as exc: logger.debug("position_row_from_vnpy: %s", exc) return None def calibrate_trading_state(self) -> None: """全量校准内存簿(读 vnpy 缓存,不 query 柜台)。""" try: from ctp_trading_state import trading_state with _ctp_td_lock: orders = self.list_active_orders() positions = self._collect_positions() trading_state.calibrate_from_lists(orders, positions) except Exception as exc: logger.debug("calibrate trading state: %s", exc) def available(self) -> bool: return self._engine is not None @property def last_error(self) -> str: return self._last_error @property def connected_mode(self) -> Optional[str]: return self._connected_mode def connect_in_progress(self) -> bool: return self._connect_in_progress def _restore_persisted_state(self) -> None: err = _load_persisted_last_error() if err: self._last_error = err db_remain = _persisted_login_cooldown_remaining() if db_remain > 0: self._login_cooldown_until = time.monotonic() + db_remain def login_cooldown_remaining(self) -> int: """距允许再次登录的剩余秒数(内存 + 数据库,重启后仍有效)。""" mem = max(0, int(self._login_cooldown_until - time.monotonic())) return max(mem, _persisted_login_cooldown_remaining()) def _is_login_cooldown_active(self) -> bool: return self.login_cooldown_remaining() > 0 def _set_login_cooldown(self, seconds: float) -> None: until = time.monotonic() + max(0.0, seconds) if until > self._login_cooldown_until: self._login_cooldown_until = until _persist_login_cooldown(seconds) def _clear_login_cooldown(self) -> None: self._login_cooldown_until = 0.0 _clear_persisted_login_cooldown() def _apply_login_failure_cooldown(self, ctp_logs: list[str]) -> None: text = "\n".join(ctp_logs) if "连续登录失败" in text or "登录被禁止" in text or "代码:75" in text: self._set_login_cooldown(LOGIN_BAN_COOLDOWN_SEC) elif any("登录失败" in m or "不合法的登录" in m for m in ctp_logs): self._set_login_cooldown(LOGIN_FAIL_COOLDOWN_SEC) def _login_cooldown_message(self) -> str: remain = self.login_cooldown_remaining() return ( f"CTP 登录冷却中,请 {remain // 60} 分 {remain % 60} 秒后再试" f"(避免连续失败被 SimNow 封禁)" ) def _close_gateway(self) -> None: """关闭 CTP 网关,避免半连接状态下重连卡在「连接登录」。""" if not self._engine: return try: gw = self._engine.get_gateway(GATEWAY_NAME) if gw: gw.close() except Exception as exc: logger.debug("gateway close: %s", exc) self._connected_mode = None try: from ctp_trading_state import trading_state trading_state.clear() except Exception: pass time.sleep(0.6) def _login_rejected(self, ctp_logs: list[str]) -> bool: return any( kw in m for m in ctp_logs for kw in ("登录失败", "不合法的登录", "登录被禁止", "连续登录失败") ) def _wait_connected(self, mode: str, ctp_logs: list[str] | None = None) -> bool: """等待账户回报或交易通道登录成功。""" if not self._engine: return False logs = ctp_logs or [] loops = max(1, int(CONNECT_WAIT_SEC / CONNECT_POLL_INTERVAL_SEC)) for _ in range(loops): if self._login_rejected(logs): return False try: if self._engine.get_all_accounts(): return True except Exception: pass if self._td_logged_in(): return True time.sleep(CONNECT_POLL_INTERVAL_SEC) return False def status(self, mode: str) -> dict[str, Any]: if self._connected_mode == mode: self.ping() st = _setting_for_mode(mode) missing = [k for k in ("用户名", "密码", "交易服务器") if not st.get(k)] cooldown = self.login_cooldown_remaining() connecting = bool(self._connect_in_progress and cooldown <= 0) last_error = self._last_error or _load_persisted_last_error() return { "vnpy_installed": self.available(), "connected": self._connected_mode == mode, "connecting": connecting, "connected_mode": self._connected_mode, "mode_label": _mode_label(mode), "missing_config": missing, "last_error": last_error, "login_cooldown_sec": cooldown, "broker_id": st.get("经纪商代码", ""), "td_address": st.get("交易服务器", ""), } def connect(self, mode: str, *, force: bool = False, scheduled: bool = False) -> None: from ctp_settings import CTP_DISABLED_HINT if not _ctp_connect_permitted(scheduled=scheduled): self._last_error = CTP_DISABLED_HINT _persist_last_error(CTP_DISABLED_HINT) raise RuntimeError(CTP_DISABLED_HINT) if self._connect_in_progress: raise RuntimeError("CTP 正在连接中,请稍候") if self._is_login_cooldown_active() and not force: msg = self._login_cooldown_message() self._last_error = msg raise RuntimeError(msg) if not self._engine: raise RuntimeError(self._last_error or "vnpy 引擎未初始化") if self._connected_mode == mode and not force: if self.ping(): return self._connected_mode = None setting = _setting_for_mode(mode) if not setting.get("用户名") or not setting.get("密码"): raise ValueError( f"{_mode_label(mode)}:请在 .env 配置 " f"{'SIMNOW_USER / SIMNOW_PASSWORD' if mode == 'simulation' else 'CTP_LIVE_USER / CTP_LIVE_PASSWORD'}" ) if not setting.get("交易服务器"): raise ValueError(f"{_mode_label(mode)}:未配置交易服务器地址") self._connect_in_progress = True try: with _ctp_td_lock: with self._connect_lock: if force and self._connected_mode: self._close_gateway() elif self._connected_mode and self._connected_mode != mode: try: self._engine.close() except Exception: pass self._connected_mode = None time.sleep(1) elif not (self._connected_mode == mode and self.ping()): self._close_gateway() ctp_logs: list[str] = [] from vnpy.trader.event import EVENT_LOG def _on_log(event) -> None: msg = getattr(event.data, "msg", "") or str(event.data) if msg: ctp_logs.append(str(msg)) if len(ctp_logs) > 40: ctp_logs.pop(0) logger.info("CTP | %s", msg) self._ee.register(EVENT_LOG, _on_log) try: ensure_process_locale() logger.info( "CTP 连接 [%s] user=%s td=%s env=%s", mode, setting.get("用户名"), setting.get("交易服务器"), setting.get("柜台环境", "实盘"), ) td_addr = setting.get("交易服务器", "") ok, err = probe_tcp_address(td_addr, timeout=5.0) if not ok: raise RuntimeError( f"SimNow 交易前置不可达:{td_addr}({err})。" "请更新 .env 中 SIMNOW_TD_ADDRESS 为官网最新地址," "并在服务器执行 nc -zv 验证出网。" ) self._ensure_instrument_margin_hooks() self._engine.connect(setting, GATEWAY_NAME) if self._wait_connected(mode, ctp_logs): self._connected_mode = mode self._last_connect_ok_ts = time.time() self._last_error = "" _persist_last_error("") self._clear_login_cooldown() logger.info("CTP 已连接 [%s] td_login=%s accounts=%s", mode, self._td_logged_in(), len(self._engine.get_all_accounts() or [])) self._schedule_fee_sync(mode) try: self.calibrate_trading_state() except Exception as exc: logger.debug("post-connect calibrate: %s", exc) _fire_position_refresh_burst() _fire_ctp_connected_callback(mode) return finally: self._ee.unregister(EVENT_LOG, _on_log) self._close_gateway() self._apply_login_failure_cooldown(ctp_logs) hint = _format_ctp_failure(ctp_logs, td_address=setting.get("交易服务器", "")) self._last_error = hint _persist_last_error(hint) logger.warning("CTP 连接失败 [%s]: %s | logs=%s", mode, hint, ctp_logs[-5:]) raise RuntimeError(hint) finally: self._connect_in_progress = False def start_connect_async( self, mode: str, *, force: bool = False, scheduled: bool = False, ) -> dict[str, Any]: """后台连接,不阻塞 HTTP 请求。""" from ctp_settings import CTP_DISABLED_HINT if not _ctp_connect_permitted(scheduled=scheduled): self._last_error = CTP_DISABLED_HINT _persist_last_error(CTP_DISABLED_HINT) return { "started": False, "connecting": False, "connected": False, "disabled": True, "error": CTP_DISABLED_HINT, } if self._connected_mode == mode and self.ping() and not force: return {"started": False, "connecting": False, "connected": True} if self._connect_in_progress: return {"started": False, "connecting": True, "connected": False} if self._is_login_cooldown_active() and not force: self._last_error = self._login_cooldown_message() return { "started": False, "connecting": False, "connected": False, "cooldown": True, } def _run() -> None: try: self.connect(mode, force=force, scheduled=scheduled) except Exception as exc: logger.warning("CTP 后台连接失败: %s", exc) threading.Thread(target=_run, daemon=True, name="ctp-connect-async").start() return {"started": True, "connecting": True, "connected": False} def ensure_connected(self, mode: str) -> None: if self._connected_mode == mode and self.ping(): return self.connect(mode) def require_connected(self, mode: str) -> None: """报单前检查:须已连接,不在此发起阻塞式 connect。""" if self._connect_in_progress: raise RuntimeError("CTP 连接中,请稍候再下单") if self._connected_mode != mode or not self.ping(): raise RuntimeError("请先连接 CTP(持仓监控页点击「连接 CTP」)") if not self._td_logged_in(): raise RuntimeError("CTP 交易通道未登录,请重连 CTP 后再下单") def _td_logged_in(self) -> bool: try: gw = self._engine.get_gateway(GATEWAY_NAME) td = gw.td_api return bool(getattr(td, "login_status", False)) except Exception: return False def _find_position(self, sym: str, ex_name: str, hold_direction: str) -> Any: if not self._engine: return None sym_l = sym.lower() ex_u = ex_name.upper() want_long = hold_direction == "long" try: for pos in self._engine.get_all_positions(): ps = (getattr(pos, "symbol", "") or "").lower() pe = getattr(pos, "exchange", None) pe_s = str(pe.value if hasattr(pe, "value") else pe or "").upper() if ps != sym_l or pe_s != ex_u: continue vol = int(getattr(pos, "volume", 0) or 0) if vol <= 0: continue is_long = _is_long_direction(getattr(pos, "direction", None)) if is_long == want_long: return pos except Exception as exc: logger.debug("find position: %s", exc) return None def _resolve_close_offset(self, sym: str, ex_name: str, hold_direction: str, lots: int) -> Any: from vnpy.trader.constant import Offset ex_u = (ex_name or "").upper() # 上期所/能源中心/郑商所/中金所须区分平今/平昨;大商所等可用通用 CLOSE if ex_u not in ("CZCE", "CFFEX", "SHFE", "INE"): return Offset.CLOSE pos = self._find_position(sym, ex_u, hold_direction) if not pos: for p in self._collect_positions(): ps = (p.get("symbol") or "").lower() if ps != sym.lower(): continue if (p.get("direction") or "long") != hold_direction: continue td = int(p.get("td_volume") or 0) yd = int(p.get("yd_volume") or 0) if td >= lots: return Offset.CLOSETODAY if yd >= lots: return Offset.CLOSEYESTERDAY if td + yd >= lots: return Offset.CLOSETODAY break if ex_u in ("SHFE", "INE", "CZCE"): return Offset.CLOSETODAY return Offset.CLOSE vol = int(getattr(pos, "volume", 0) or 0) yd = int(getattr(pos, "yd_volume", 0) or 0) today = max(0, vol - yd) if today >= lots: return Offset.CLOSETODAY return Offset.CLOSEYESTERDAY def _aggressive_limit_price( self, ths_code: str, sym: str, ex_name: str, direction: Any, tick: float, fallback: float, ) -> float: from vnpy.trader.constant import Direction self.subscribe_symbol(ths_code) lp = fallback detail = self.get_tick_detail(ths_code, mode=self._connected_mode or "") if detail.get("price"): lp = float(detail["price"]) slip = max(tick, tick * 3) if direction == Direction.LONG: lp = lp + slip else: lp = max(tick, lp - slip) return round_to_tick(lp, tick) def ping(self) -> bool: """检测连接是否仍有效;无效则清除 connected 状态。""" if not self._engine or not self._connected_mode: return False if self._td_logged_in(): return True try: if self._engine.get_all_accounts(): return True except Exception as exc: logger.debug("CTP ping failed: %s", exc) self._connected_mode = None return False def mark_disconnected(self) -> None: self._connected_mode = None def reconnect_after_settings_saved(self, mode: str) -> dict[str, Any]: """保存前置/账号后关闭旧连接,并用数据库中的新配置重连。""" from ctp_settings import is_ctp_auto_connect_enabled self._close_gateway() self._last_error = "" _persist_last_error("") if not is_ctp_auto_connect_enabled(): return {"started": False, "connecting": False, "connected": False, "disabled": True} return self.start_connect_async(mode, force=True) def _schedule_fee_sync(self, mode: str) -> None: """连接成功后触发每日同步检查(非每次全量)。""" def _run() -> None: time.sleep(45) try: from ctp_fee_worker import try_daily_ctp_fee_sync def _gs(key: str, default: str = "") -> str: from fee_specs import get_setting return get_setting(key, default) def _ss(key: str, val: str) -> None: from fee_specs import set_setting set_setting(key, val) try_daily_ctp_fee_sync( mode, get_setting=_gs, set_setting=_ss, force=False, ) except Exception as exc: logger.debug("CTP 手续费连接后检查: %s", exc) threading.Thread(target=_run, daemon=True, name="ctp-fee-sync-check").start() def _ensure_commission_callback(self) -> None: if self._commission_hooked or not self._engine: return try: gw = self._engine.get_gateway(GATEWAY_NAME) td = gw.td_api except Exception: return bridge = self def on_rsp(data: dict, error: dict, reqid: int, last: bool) -> None: if error and int(error.get("ErrorID") or 0) != 0: logger.debug( "CTP commission error reqid=%s: %s", reqid, error.get("ErrorMsg") or error, ) if data and data.get("InstrumentID"): bridge._commission_lists.setdefault(reqid, []).append(dict(data)) ev = bridge._commission_waiters.get(reqid) if last and ev: ev.set() td.onRspQryInstrumentCommissionRate = on_rsp # type: ignore[method-assign] self._commission_hooked = True def _query_commission( self, *, mode: str, instrument_id: str = "", exchange_id: str = "", timeout: float = 8, ) -> list[dict]: if self._connected_mode != mode or not self._engine: return [] try: gw = self._engine.get_gateway(GATEWAY_NAME) td = gw.td_api except Exception as exc: logger.debug("commission query init: %s", exc) return [] if not getattr(td, "login_status", False): return [] if not hasattr(td, "reqQryInstrumentCommissionRate"): return [] self._ensure_commission_callback() reqid = int(getattr(td, "reqid", 0)) + 1 td.reqid = reqid ev = threading.Event() self._commission_waiters[reqid] = ev req = { "BrokerID": td.brokerid, "InvestorID": td.userid, "InstrumentID": instrument_id or "", "ExchangeID": exchange_id or "", } ret = td.reqQryInstrumentCommissionRate(req, reqid) if ret != 0: self._commission_waiters.pop(reqid, None) return [] ev.wait(timeout=timeout) self._commission_waiters.pop(reqid, None) return self._commission_lists.pop(reqid, []) def query_instrument_commission(self, ths_code: str, *, mode: str) -> dict: """查询单合约 CTP 手续费率(需已连接)。""" try: sym, ex_name = ths_to_vnpy_symbol(ths_code) except Exception: return {} rows = self._query_commission( mode=mode, instrument_id=sym, exchange_id=ex_name, ) return rows[-1] if rows else {} def query_all_commissions(self, *, mode: str) -> list[dict]: """批量查询全部合约手续费(InstrumentID 留空)。""" return self._query_commission(mode=mode, timeout=45) @staticmethod def _parse_margin_ratio_row(data: dict) -> dict[str, float]: long_r = float( data.get("LongMarginRatioByMoney") or data.get("LongMarginRatio") or 0 ) short_r = float( data.get("ShortMarginRatioByMoney") or data.get("ShortMarginRatio") or 0 ) return {"long": long_r, "short": short_r} def _cache_margin_ratio(self, sym: str, data: dict) -> None: ratios = self._parse_margin_ratio_row(data) if ratios["long"] <= 0 and ratios["short"] <= 0: return key = (sym or "").strip().lower() if not key: return self._instrument_margin_ratios[key] = ratios def _ensure_instrument_margin_hooks(self) -> None: """登录前挂钩:合约查询回报缓存保证金率;支持按需 reqQryInstrumentMarginRate。""" if not self._engine: return try: gw = self._engine.get_gateway(GATEWAY_NAME) td = gw.td_api except Exception: return bridge = self if not self._instrument_hooked: orig = td.onRspQryInstrument def on_instrument(data: dict, error: dict, reqid: int, last: bool) -> None: try: if data and data.get("InstrumentID"): bridge._cache_margin_ratio(str(data["InstrumentID"]), data) except Exception as exc: logger.debug("instrument margin cache: %s", exc) return orig(data, error, reqid, last) td.onRspQryInstrument = on_instrument # type: ignore[method-assign] self._instrument_hooked = True if self._margin_rate_hooked: return def on_margin_rate(data: dict, error: dict, reqid: int, last: bool) -> None: if error and int(error.get("ErrorID") or 0) != 0: logger.debug( "CTP margin rate error reqid=%s: %s", reqid, error.get("ErrorMsg") or error, ) if data and data.get("InstrumentID"): bridge._margin_rate_lists.setdefault(reqid, []).append(dict(data)) ev = bridge._margin_rate_waiters.get(reqid) if last and ev: ev.set() td.onRspQryInstrumentMarginRate = on_margin_rate # type: ignore[method-assign] self._margin_rate_hooked = True def _query_instrument_margin_rate( self, *, mode: str, instrument_id: str, exchange_id: str, timeout: float = 6, ) -> Optional[dict[str, float]]: if self._connected_mode != mode or not self._engine: return None sym = (instrument_id or "").strip() if not sym: return None cached = self._instrument_margin_ratios.get(sym.lower()) if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0): return cached try: gw = self._engine.get_gateway(GATEWAY_NAME) td = gw.td_api except Exception as exc: logger.debug("margin rate query init: %s", exc) return None if not getattr(td, "login_status", False): return None if not hasattr(td, "reqQryInstrumentMarginRate"): return None self._ensure_instrument_margin_hooks() reqid = int(getattr(td, "reqid", 0)) + 1 td.reqid = reqid ev = threading.Event() self._margin_rate_waiters[reqid] = ev req = { "BrokerID": td.brokerid, "InvestorID": td.userid, "InstrumentID": sym, "ExchangeID": exchange_id or "", "InvestorRange": "1", "HedgeFlag": "1", } with _ctp_td_lock: ret = td.reqQryInstrumentMarginRate(req, reqid) if ret != 0: self._margin_rate_waiters.pop(reqid, None) return None ev.wait(timeout=timeout) self._margin_rate_waiters.pop(reqid, None) rows = self._margin_rate_lists.pop(reqid, []) if not rows: return None ratios = self._parse_margin_ratio_row(rows[-1]) if ratios["long"] > 0 or ratios["short"] > 0: self._cache_margin_ratio(sym, rows[-1]) return ratios return None def _lookup_margin_ratios( self, sym: str, ex_name: str, *, mode: Optional[str] = None, ) -> Optional[dict[str, float]]: key = (sym or "").strip().lower() if not key: return None cached = self._instrument_margin_ratios.get(key) if cached and (cached.get("long", 0) > 0 or cached.get("short", 0) > 0): return cached if mode and self._connected_mode == mode: return self._query_instrument_margin_rate( mode=mode, instrument_id=sym, exchange_id=ex_name, ) return None def _lookup_margin_per_lot(self, sym: str, direction: str) -> float: return float( self._margin_per_lot.get(self._position_margin_key(sym, direction), 0) or 0 ) def _margin_from_ratios( self, price: float, mult: float, ratios: dict[str, float], *, direction: str, ) -> Optional[float]: long_r = float(ratios.get("long") or 0) short_r = float(ratios.get("short") or 0) d = (direction or "long").strip().lower() if mult <= 0 or price <= 0: return None if d == "max": candidates = [ round(float(price) * mult * r, 2) for r in (long_r, short_r) if r > 0 ] return max(candidates) if candidates else None if d == "short" and short_r > 0: ratio = short_r elif d != "short" and long_r > 0: ratio = long_r else: ratio = max(long_r, short_r) if ratio <= 0: return None return round(float(price) * mult * ratio, 2) def _tick_key(self, symbol: str, ex_name: str) -> str: return f"{symbol.lower()}:{ex_name.upper()}" def _price_from_tick(self, tick: Any) -> Optional[float]: for attr in ("last_price", "bid_price_1", "ask_price_1", "pre_close"): try: v = float(getattr(tick, attr, 0) or 0) except (TypeError, ValueError): v = 0.0 if v > 0: return v return None def _lookup_tick(self, symbol: str, ex_name: str) -> Optional[float]: if not self._engine: return None sym_l = symbol.lower() ex_u = ex_name.upper() try: for tick in self._engine.get_all_ticks(): ts = (getattr(tick, "symbol", "") or "").lower() te = getattr(tick, "exchange", None) te_s = str(te.value if hasattr(te, "value") else te or "").upper() if ts == sym_l and te_s == ex_u: p = self._price_from_tick(tick) if p: return p except Exception as exc: logger.debug("lookup tick: %s", exc) return None def _bar_to_dict(self, bar: Any) -> dict: dt = getattr(bar, "datetime", None) d_str = dt.strftime("%Y-%m-%d %H:%M:%S") if dt else "" return { "d": d_str, "o": float(getattr(bar, "open_price", 0) or 0), "h": float(getattr(bar, "high_price", 0) or 0), "l": float(getattr(bar, "low_price", 0) or 0), "c": float(getattr(bar, "close_price", 0) or 0), "v": float(getattr(bar, "volume", 0) or 0), } def _ensure_bar_generator(self, sym: str, ex_name: str) -> None: key = self._tick_key(sym, ex_name) if key in self._bar_generators: return self._bars_1m[key] = deque(maxlen=4000) def on_bar(bar: Any) -> None: row = self._bar_to_dict(bar) if row.get("d"): self._bars_1m[key].append(row) try: from vnpy.trader.utility import BarGenerator self._bar_generators[key] = BarGenerator(on_bar=on_bar) except ImportError: logger.debug("BarGenerator unavailable") def _find_tick(self, symbol: str, ex_name: str) -> Any: if not self._engine: return None sym_l = symbol.lower() ex_u = ex_name.upper() try: for tick in self._engine.get_all_ticks(): ts = (getattr(tick, "symbol", "") or "").lower() te = getattr(tick, "exchange", None) te_s = str(te.value if hasattr(te, "value") else te or "").upper() if ts == sym_l and te_s == ex_u: return tick except Exception as exc: logger.debug("find tick: %s", exc) return None def _tick_to_bar(self, symbol: str, ex_name: str) -> Optional[dict]: tick = self._find_tick(symbol, ex_name) if not tick: return None lp = self._price_from_tick(tick) if not lp or lp <= 0: return None dt = getattr(tick, "datetime", None) d_str = dt.strftime("%Y-%m-%d %H:%M:%S") if dt else "" if not d_str: from datetime import datetime from zoneinfo import ZoneInfo d_str = datetime.now(ZoneInfo("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") o = float(getattr(tick, "open_price", 0) or lp) h = float(getattr(tick, "high_price", 0) or lp) lo = float(getattr(tick, "low_price", 0) or lp) return { "d": d_str, "o": o, "h": h, "l": lo, "c": lp, "v": float(getattr(tick, "volume", 0) or 0), } def _on_tick(self, tick: Any) -> None: sym = (getattr(tick, "symbol", "") or "").lower() te = getattr(tick, "exchange", None) ex_s = str(te.value if hasattr(te, "value") else te or "").upper() price = self._price_from_tick(tick) if price and price > 0: try: from ctp_trading_state import trading_state trading_state.set_tick_price(ex_s, sym, price) except Exception: pass fn = _tick_sl_tp_callback if fn: try: fn(ex_s, sym, float(price)) except Exception as exc: logger.debug("tick sl/tp callback: %s", exc) key = self._tick_key(sym, ex_s) bg = self._bar_generators.get(key) if not bg: return try: bg.update_tick(tick) except Exception as exc: logger.debug("bar gen tick: %s", exc) def _ensure_tick_handler(self) -> None: if self._tick_hooked or not self._ee: return try: from vnpy.trader.event import EVENT_TICK except ImportError: return def process_tick(event: Any) -> None: self._on_tick(event.data) self._ee.register(EVENT_TICK, process_tick) self._tick_hooked = True def get_kline_bars_1m(self, ths_code: str, *, mode: str) -> list[dict]: """订阅合约并返回 1 分钟 K 线(含正在形成的 bar)。""" if self._connected_mode != mode or not self._engine: return [] try: sym, ex_name = ths_to_vnpy_symbol(ths_code) except Exception: return [] key = self._tick_key(sym, ex_name) self._ensure_bar_generator(sym, ex_name) self.subscribe_symbol(ths_code) for _ in range(12): if self._bars_1m.get(key) and len(self._bars_1m[key]) > 0: break if self._lookup_tick(sym, ex_name): break time.sleep(0.2) bars_1m = list(self._bars_1m.get(key, [])) bg = self._bar_generators.get(key) if bg and getattr(bg, "bar", None): forming = self._bar_to_dict(bg.bar) if forming.get("d"): if not bars_1m or bars_1m[-1]["d"] != forming["d"]: bars_1m.append(forming) else: bars_1m[-1] = forming if not bars_1m: tick_bar = self._tick_to_bar(sym, ex_name) if tick_bar: bars_1m = [tick_bar] return bars_1m def get_tick_detail(self, ths_code: str, *, mode: str) -> dict[str, Any]: if self._connected_mode != mode or not self._engine: return {} try: sym, ex_name = ths_to_vnpy_symbol(ths_code) except Exception: return {} self.subscribe_symbol(ths_code) for _ in range(8): tick = self._find_tick(sym, ex_name) if tick: price = self._price_from_tick(tick) try: pre_close = float(getattr(tick, "pre_close", 0) or 0) except (TypeError, ValueError): pre_close = 0.0 return { "price": price, "pre_close": pre_close if pre_close > 0 else None, } time.sleep(0.2) return {} def subscribe_symbol(self, ths_code: str) -> None: if not self._engine or not self._connected_mode: return try: from vnpy.trader.object import SubscribeRequest sym, ex_name = ths_to_vnpy_symbol(ths_code) key = self._tick_key(sym, ex_name) self._ensure_bar_generator(sym, ex_name) if key in self._subscribed: return exchange = to_vnpy_exchange(ex_name) self._ensure_tick_handler() req = SubscribeRequest(symbol=sym, exchange=exchange) self._engine.subscribe(req, GATEWAY_NAME) self._subscribed.add(key) except Exception as exc: logger.debug("CTP subscribe %s: %s", ths_code, exc) def get_tick_price(self, ths_code: str, *, mode: str) -> Optional[float]: if self._connected_mode != mode or not self._engine: return None try: sym, ex_name = ths_to_vnpy_symbol(ths_code) except Exception: return None price = self._lookup_tick(sym, ex_name) if price: return price self.subscribe_symbol(ths_code) for _ in range(8): time.sleep(0.2) price = self._lookup_tick(sym, ex_name) if price: return price return None def get_account(self) -> dict[str, Any]: if not self._engine: return {} accounts = self._engine.get_all_accounts() if not accounts: return {} acc = accounts[0] return { "balance": float(getattr(acc, "balance", 0) or 0), "available": float(getattr(acc, "available", 0) or 0), "frozen": float(getattr(acc, "frozen", 0) or 0), "accountid": getattr(acc, "accountid", ""), } def _position_margin_key(self, sym: str, direction: str) -> str: return f"{(sym or '').lower()}:{(direction or 'long').strip().lower()}" def _lookup_position_open_time(self, sym: str, direction: str) -> str: return (self._position_open_times.get(self._position_margin_key(sym, direction)) or "").strip() @staticmethod def _parse_ctp_open_datetime(date_raw: str, time_raw: str = "") -> str: """CTP OpenDate + OpenTime → YYYY-MM-DD HH:MM[:SS]。""" d = (date_raw or "").strip() if len(d) >= 8 and d[:8].isdigit(): date_part = f"{d[:4]}-{d[4:6]}-{d[6:8]}" else: return "" t = (time_raw or "").strip().replace(":", "") if len(t) >= 6 and t[:6].isdigit(): return f"{date_part} {t[0:2]}:{t[2:4]}:{t[4:6]}" if len(t) >= 4 and t.isdigit(): return f"{date_part} {t[0:2]}:{t[2:4]}" return date_part def _parse_ctp_open_date(raw: str) -> str: return CtpBridge._parse_ctp_open_datetime(raw, "") def _install_position_margin_hook(self) -> None: """已禁用:monkey-patch CTP 持仓回调在并发下会触发 vnctptd 段错误。""" return def _lookup_position_margin(self, sym: str, direction: str) -> float: return float(self._position_margins.get(self._position_margin_key(sym, direction), 0) or 0) @staticmethod def _vnpy_sym_to_ths(sym: str, ex_name: str) -> str: import re s = (sym or "").strip() if not s: return "" ex = (ex_name or "").upper() m = re.match(r"^([A-Za-z]+)(\d+)$", s) if not m: return s letters, digits = m.group(1), m.group(2) if ex == "CZCE": return letters.upper() + (digits[-3:] if len(digits) >= 4 else digits) return letters.lower() + digits def _get_contract_for_ths(self, ths_code: str) -> Any: """按同花顺代码查 CTP 合约;精确匹配失败时在同交易所按品种前缀回退。""" if not self._engine: return None try: sym, ex_name = ths_to_vnpy_symbol(ths_code) exchange = to_vnpy_exchange(ex_name) vt_symbol = f"{sym}.{exchange.value}" contract = self._engine.get_contract(vt_symbol) if contract: return contract m = re.match(r"^([A-Za-z]+)", (ths_code or "").strip()) if not m: return None letters = m.group(1) ex_val = exchange.value candidates: list[Any] = [] get_all = getattr(self._engine, "get_all_contracts", None) pool = list(get_all()) if callable(get_all) else [] if not pool: raw = getattr(self._engine, "contracts", None) if isinstance(raw, dict): pool = list(raw.values()) sym_prefix = sym[: len(letters)] if sym else letters.lower() sym_prefix_up = letters.upper() for c in pool: c_ex = getattr(c, "exchange", None) c_ex_val = str(c_ex.value if hasattr(c_ex, "value") else c_ex or "") if c_ex_val != ex_val: continue c_sym = str(getattr(c, "symbol", "") or "") if ( c_sym.lower().startswith(sym_prefix.lower()) or c_sym.upper().startswith(sym_prefix_up) ): candidates.append(c) if not candidates: return None candidates.sort(key=lambda c: str(getattr(c, "symbol", "") or "")) return candidates[0] except Exception as exc: logger.debug("_get_contract_for_ths %s: %s", ths_code, exc) return None def estimate_margin_one_lot( self, ths_code: str, price: float, *, direction: str = "long", ) -> Optional[float]: """1 手保证金:持仓实收 > CTP 保证金率查询 > 合约查询缓存。""" if not self._engine or not price or price <= 0: return None try: sym, ex_name = ths_to_vnpy_symbol(ths_code) contract = self._get_contract_for_ths(ths_code) mult = float(getattr(contract, "size", 0) or 0) if contract else 0.0 if mult <= 0: mult = float(get_contract_spec(ths_code).get("mult") or 0) d = (direction or "long").strip().lower() if d == "max": per_lots = [ self._lookup_margin_per_lot(sym, side) for side in ("long", "short") ] per_lots = [x for x in per_lots if x > 0] if per_lots: return max(per_lots) else: per_lot = self._lookup_margin_per_lot(sym, d) if per_lot > 0: return per_lot mode = self._connected_mode ratios = self._lookup_margin_ratios(sym, ex_name, mode=mode) if ratios: return self._margin_from_ratios( price, mult, ratios, direction=d, ) return None except Exception as exc: logger.debug("estimate_margin_one_lot %s: %s", ths_code, exc) return None def estimate_position_margin( self, sym: str, ex_name: str, direction: str, lots: int, price: float, *, pos: Any = None, ) -> Optional[float]: """持仓占用保证金:优先 vnpy 字段,其次 CTP 合约保证金率估算。""" if lots <= 0 or price <= 0: return None if pos is not None: raw = float(getattr(pos, "margin", 0) or getattr(pos, "use_margin", 0) or 0) if raw > 0: return round(raw, 2) cached = self._lookup_position_margin(sym, direction) if cached > 0: return round(cached, 2) ths = self._vnpy_sym_to_ths(sym, ex_name) if not ths: return None per_lot = self.estimate_margin_one_lot(ths, price, direction=direction) if per_lot and per_lot > 0: return round(per_lot * lots, 2) return None def lookup_contract_spec(self, ths_code: str) -> Optional[dict]: """从 CTP 合约信息读取乘数与最小变动价位。""" if not self._engine: return None try: sym, ex_name = ths_to_vnpy_symbol(ths_code) contract = self._get_contract_for_ths(ths_code) if not contract: return None mult = float(getattr(contract, "size", 0) or 0) tick = float( getattr(contract, "pricetick", 0) or getattr(contract, "price_tick", 0) or 0 ) if mult <= 0: return None out: dict[str, Any] = {"mult": mult} if tick > 0: out["tick_size"] = tick long_r = float(getattr(contract, "long_margin_ratio", 0) or 0) short_r = float(getattr(contract, "short_margin_ratio", 0) or 0) c_sym = str(getattr(contract, "symbol", "") or sym or "") if c_sym and self._connected_mode: queried = self._lookup_margin_ratios( c_sym, ex_name, mode=self._connected_mode, ) if queried: long_r = float(queried.get("long") or long_r) short_r = float(queried.get("short") or short_r) if long_r > 0 or short_r > 0: out["margin_rate"] = max(long_r, short_r) return out except Exception as exc: logger.debug("lookup_contract_spec %s: %s", ths_code, exc) return None def _collect_positions(self) -> list[dict[str, Any]]: if not self._engine: return [] out: list[dict[str, Any]] = [] for pos in self._engine.get_all_positions(): vol = int(getattr(pos, "volume", 0) or 0) if vol <= 0: continue d = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short" sym = getattr(pos, "symbol", "") or "" exchange = getattr(pos, "exchange", None) ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") price = float(getattr(pos, "price", 0) or 0) margin = self.estimate_position_margin( sym, ex_name, d, vol, price, pos=pos, ) open_time = self._lookup_position_open_time(sym, d) or None yd = int(getattr(pos, "yd_volume", 0) or 0) td = max(0, vol - yd) out.append({ "symbol": sym, "exchange": ex_name, "direction": d, "lots": vol, "avg_price": price, "pnl": float(getattr(pos, "pnl", 0) or 0), "frozen": int(getattr(pos, "frozen", 0) or 0), "margin": margin, "open_time": open_time, "yd_volume": yd, "td_volume": td, }) return out def refresh_positions(self) -> None: """vnpy 内存缓存持仓;禁止 query_position(vnctptd 并发查询会段错误)。""" return def list_positions(self, *, refresh_if_empty: bool = True, refresh_margin: bool = False) -> list[dict[str, Any]]: del refresh_if_empty, refresh_margin with _ctp_td_lock: return self._collect_positions() @staticmethod def _parse_trade_offset(offset_obj: Any) -> str: s = str(offset_obj or "").upper() if "OPEN" in s: return "open" return "close" @staticmethod def _parse_trade_direction(direction_obj: Any) -> str: return "long" if _is_long_direction(direction_obj) else "short" @staticmethod def _position_direction_from_trade(trade_direction: str, offset: str) -> str: td = (trade_direction or "long").strip().lower() if (offset or "open").strip().lower() == "open": return td return "short" if td == "long" else "long" def _format_trade_datetime(self, dt_obj: Any, date_raw: str = "", time_raw: str = "") -> str: if dt_obj is not None: try: if hasattr(dt_obj, "strftime"): return dt_obj.strftime("%Y-%m-%d %H:%M:%S") text = str(dt_obj).strip() if text: return text[:19].replace("T", " ") except Exception: pass parsed = self._parse_ctp_open_datetime(date_raw, time_raw) return parsed or "" def _trade_row_from_vnpy(self, trade: Any) -> Optional[dict[str, Any]]: try: sym = (getattr(trade, "symbol", "") or "").strip() vol = int(getattr(trade, "volume", 0) or 0) if not sym or vol <= 0: return None direction = self._parse_trade_direction(getattr(trade, "direction", None)) offset = self._parse_trade_offset(getattr(trade, "offset", None)) exchange = getattr(trade, "exchange", None) ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") dt = self._format_trade_datetime(getattr(trade, "datetime", None)) trade_id = str(getattr(trade, "tradeid", "") or getattr(trade, "vt_tradeid", "") or "") order_id = str(getattr(trade, "orderid", "") or getattr(trade, "vt_orderid", "") or "") if not trade_id: trade_id = f"{order_id}:{sym}:{offset}:{direction}:{vol}:{getattr(trade, 'price', 0)}:{dt}" return { "trade_id": trade_id, "order_id": order_id, "symbol": sym, "exchange": ex_name, "direction": direction, "offset": offset, "position_direction": self._position_direction_from_trade(direction, offset), "lots": vol, "price": float(getattr(trade, "price", 0) or 0), "datetime": dt, "commission": round(float(getattr(trade, "commission", 0) or 0), 2), } except Exception as exc: logger.debug("trade_row_from_vnpy: %s", exc) return None def _trade_row_from_ctp_dict(self, data: dict) -> Optional[dict[str, Any]]: try: sym = (data.get("InstrumentID") or data.get("instrument_id") or "").strip() vol = int(float(data.get("Volume") or data.get("volume") or 0)) if not sym or vol <= 0: return None dir_raw = str(data.get("Direction") or data.get("direction") or "") direction = "long" if dir_raw in ("0", "2") or "LONG" in dir_raw.upper() or dir_raw == "多" else "short" off_raw = str(data.get("OffsetFlag") or data.get("offset") or "") if off_raw in ("0",) or "OPEN" in off_raw.upper(): offset = "open" else: offset = "close" price = float(data.get("Price") or data.get("price") or 0) trade_id = str(data.get("TradeID") or data.get("tradeid") or "").strip() order_sys = str(data.get("OrderSysID") or data.get("orderid") or "").strip() dt = self._format_trade_datetime( None, str(data.get("TradeDate") or data.get("trade_date") or ""), str(data.get("TradeTime") or data.get("trade_time") or ""), ) if not trade_id: trade_id = f"{order_sys}:{sym}:{offset}:{direction}:{vol}:{price}:{dt}" return { "trade_id": trade_id, "order_id": order_sys, "symbol": sym, "exchange": str(data.get("ExchangeID") or data.get("exchange") or ""), "direction": direction, "offset": offset, "position_direction": self._position_direction_from_trade(direction, offset), "lots": vol, "price": price, "datetime": dt, "commission": round( float(data.get("Commission") or data.get("commission") or 0), 2, ), } except Exception as exc: logger.debug("trade_row_from_ctp_dict: %s", exc) return None def _install_trade_query_hook(self) -> None: """不再 monkey-patch CTP 成交回调(易与并发查询冲突导致 vnctptd 段错误)。""" return @staticmethod def _engine_collection_items(raw: Any) -> list[Any]: """vnpy 不同版本可能返回 dict 或 list。""" if raw is None: return [] if isinstance(raw, dict): return list(raw.values()) if isinstance(raw, (list, tuple)): return list(raw) return [raw] def _collect_engine_trades(self) -> list[dict[str, Any]]: if not self._engine: return [] out: list[dict[str, Any]] = [] seen: set[str] = set() try: trades = self._engine.get_all_trades() except Exception: trades = None for trade in self._engine_collection_items(trades): row = self._trade_row_from_vnpy(trade) if not row: continue key = row["trade_id"] if key in seen: continue seen.add(key) out.append(row) return out def refresh_trades(self) -> None: """成交仅读 vnpy 内存回报;不调用 query_trade(避免 CTP 段错误)。""" return def list_trades(self, *, refresh: bool = False) -> list[dict[str, Any]]: with _ctp_td_lock: out = self._collect_engine_trades() out.sort(key=lambda r: (r.get("datetime") or "", r.get("trade_id") or "")) return out def list_active_orders(self) -> list[dict[str, Any]]: if not self._engine: return [] out: list[dict[str, Any]] = [] try: orders = self._engine.get_all_active_orders() except Exception: return [] for order in orders or []: status = getattr(order, "status", None) status_s = str(status) if status_s and not any(x in status_s for x in ("NOTTRADED", "PARTTRADED", "SUBMITTING")): continue vol = int(getattr(order, "volume", 0) or 0) traded = int(getattr(order, "traded", 0) or 0) remain = max(0, vol - traded) if remain <= 0: continue direction = getattr(order, "direction", None) d = "long" if direction is not None and str(direction).endswith("SHORT"): d = "short" offset = getattr(order, "offset", None) offset_s = str(offset or "") sym = getattr(order, "symbol", "") or "" exchange = getattr(order, "exchange", None) ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "") vt_oid = str(getattr(order, "vt_orderid", "") or "") order_id = str(getattr(order, "orderid", "") or "") out.append({ "symbol": sym, "exchange": ex_name, "direction": d, "lots": remain, "price": float(getattr(order, "price", 0) or 0), "offset": offset_s, "order_id": vt_oid or order_id, "vt_order_id": vt_oid, "status": status_s, }) return out def send_order( self, *, ths_code: str, offset: str, direction: str, lots: int, price: float, order_type: str = "limit", ) -> str: from vnpy.trader.constant import Direction, Offset, OrderType from vnpy.trader.object import OrderRequest if not self._engine: raise RuntimeError("CTP 未初始化") if not self._td_logged_in(): raise RuntimeError("CTP 交易通道未登录,请重连后再下单") sym, ex_name = ths_to_vnpy_symbol(ths_code) exchange = to_vnpy_exchange(ex_name) lots = max(1, int(lots)) tick = float(get_contract_spec(ths_code).get("tick_size") or 1.0) offset = (offset or "open").lower() direction = (direction or "long").lower() if offset in ("open", "open_long", "open_short"): d = Direction.LONG if direction == "long" or offset == "open_long" else Direction.SHORT off = Offset.OPEN elif offset in ("close", "close_long", "close_short"): hold = "long" if direction == "long" or offset == "close_long" else "short" if hold == "long": d = Direction.SHORT else: d = Direction.LONG off = self._resolve_close_offset(sym, ex_name, hold, lots) else: raise ValueError(f"未知开平: {offset}") use_market = (order_type or "limit").lower() == "market" if use_market: ot = OrderType.FAK price = self._aggressive_limit_price(ths_code, sym, ex_name, d, tick, price) else: ot = OrderType.LIMIT price = round_to_tick(float(price), tick) if price <= 0: raise ValueError("委托价格无效,请检查行情或手动填写价格") req = OrderRequest( symbol=sym, exchange=exchange, direction=d, type=ot, volume=lots, price=price, offset=off, ) logger.info( "CTP 报单 %s %s %s %s手 @%s offset=%s type=%s", sym, ex_name, d, lots, price, off, ot, ) with _ctp_td_lock: vt_orderid = self._engine.send_order(req, GATEWAY_NAME) if not vt_orderid: raise RuntimeError("CTP 拒单或未返回委托号(请检查合约代码、价格是否为最小变动价位整数倍)") return str(vt_orderid) def cancel_order(self, vt_orderid: str) -> bool: if not self._engine or not vt_orderid: return False try: with _ctp_td_lock: order = self._engine.get_order(vt_orderid) if order is None: return False req = order.create_cancel_request() self._engine.cancel_order(req, GATEWAY_NAME) logger.info("CTP 撤单 %s", vt_orderid) return True except Exception as exc: logger.warning("CTP 撤单失败 %s: %s", vt_orderid, exc) return False def get_bridge() -> CtpBridge: global _bridge with _bridge_lock: if _bridge is None: _bridge = CtpBridge() return _bridge def try_init_vnpy(_settings: dict | None = None) -> bool: return get_bridge().available() def vnpy_available() -> bool: return get_bridge().available() def _ctp_connect_permitted(*, scheduled: bool = False) -> bool: """scheduled=True:盘前/交易时段计划连接,不受「自动连接」开关限制。""" from ctp_settings import is_ctp_auto_connect_enabled if is_ctp_auto_connect_enabled(): return True if not scheduled: return False from ctp_premarket_connect import should_auto_connect_now return should_auto_connect_now() def ctp_disconnect(*, set_disabled_hint: bool = False) -> None: """主动断开 CTP 并清理内存状态。""" from ctp_settings import CTP_DISABLED_HINT b = get_bridge() b._close_gateway() if set_disabled_hint: b._last_error = CTP_DISABLED_HINT _persist_last_error(CTP_DISABLED_HINT) else: b._last_error = "" _persist_last_error("") def ctp_connect(mode: str, *, force: bool = False) -> dict[str, Any]: b = get_bridge() b.connect(mode, force=force) return b.status(mode) def ctp_start_connect(mode: str, *, force: bool = False, scheduled: bool = False) -> dict[str, Any]: """非阻塞发起连接,供 Web API 使用。""" b = get_bridge() info = b.start_connect_async(mode, force=force, scheduled=scheduled) st = b.status(mode) return {**info, "status": st} def ctp_try_auto_reconnect(mode: str) -> bool: """断线时静默异步重连;已连接且交易通道正常则不再重复 connect。""" if not _ctp_connect_permitted(scheduled=True): return False b = get_bridge() if not b.available(): return False if b.connect_in_progress(): return False if b.login_cooldown_remaining() > 0: return False st = _setting_for_mode(mode) if not st.get("用户名") or not st.get("密码") or not st.get("交易服务器"): return False if b.connected_mode == mode: if b._td_logged_in() or b.ping(): return True recent = time.time() - float(getattr(b, "_last_connect_ok_ts", 0) or 0) if recent < 120: logger.debug("CTP 跳过自动重连:刚连接 %.0fs", recent) return True td = st.get("交易服务器", "") ok, err = probe_tcp_address(td, timeout=4.0) if not ok: b._last_error = ( f"SimNow 交易前置不可达:{td}({err})。" "请更新 SIMNOW_TD_ADDRESS 并确认服务器出网。" ) return False info = b.start_connect_async(mode, force=False, scheduled=True) return bool( info.get("connected") or info.get("connecting") or info.get("started") ) def ctp_status(mode: str) -> dict[str, Any]: from ctp_settings import CTP_DISABLED_HINT, is_ctp_auto_connect_enabled auto = is_ctp_auto_connect_enabled() st = get_bridge().status(mode) st["auto_connect_enabled"] = auto if not auto: st["disabled_hint"] = CTP_DISABLED_HINT if not st.get("connected") and not st.get("connecting"): st["last_error"] = "" st["td_reachable"] = None return st if not st.get("connected") and not st.get("connecting"): setting = _setting_for_mode(mode) td = setting.get("交易服务器", "") if td: ok, err = probe_tcp_address(td, timeout=3.0) st["td_reachable"] = ok if not ok and not st.get("last_error"): st["last_error"] = ( f"SimNow 交易前置不可达:{td}({err})" ) return st def ctp_get_account(mode: str) -> dict[str, Any]: b = get_bridge() b.ensure_connected(mode) return b.get_account() def ctp_sum_position_margins( mode: str, *, refresh_if_empty: bool = True, refresh_margin: bool = False, ) -> float: """各持仓 CTP 回报保证金之和(与柜台「实收保证金」一致)。""" total = 0.0 for p in ctp_list_positions( mode, refresh_if_empty=refresh_if_empty, refresh_margin=refresh_margin, ): m = float(p.get("margin") or 0) if m > 0: total += m return round(total, 2) if total > 0 else 0.0 def ctp_account_margin_used(mode: str) -> Optional[float]: """账户实际占用保证金 ≈ 权益 − 可用(与顶栏柜台资金一致)。""" b = get_bridge() if b.connected_mode != mode or not b.ping(): return None try: acc = b.get_account() balance = float(acc.get("balance") or 0) available = float(acc.get("available") or 0) if balance <= 0: return None used = balance - available return round(used, 2) if used > 0 else None except Exception as exc: logger.debug("ctp_account_margin_used: %s", exc) return None def ctp_list_positions( mode: str, *, refresh_if_empty: bool = True, refresh_margin: bool = False, ) -> list[dict[str, Any]]: b = get_bridge() if b.connected_mode != mode or not b.ping(): return [] return b.list_positions(refresh_if_empty=refresh_if_empty, refresh_margin=refresh_margin) def ctp_list_active_orders(mode: str) -> list[dict[str, Any]]: b = get_bridge() b.ensure_connected(mode) return b.list_active_orders() def ctp_cancel_order(mode: str, vt_orderid: str) -> bool: b = get_bridge() b.ensure_connected(mode) return b.cancel_order(vt_orderid) def ctp_list_trades(mode: str, *, refresh: bool = False) -> list[dict[str, Any]]: b = get_bridge() if b.connected_mode != mode or not b.ping(): return [] return b.list_trades(refresh=refresh) def ctp_get_tick_price(mode: str, ths_code: str) -> Optional[float]: """CTP 柜台最新价(需已连接并订阅)。""" b = get_bridge() if b.connected_mode != mode: return None try: return b.get_tick_price(ths_code, mode=mode) except Exception as exc: logger.debug("ctp_get_tick_price: %s", exc) return None def ctp_get_tick_detail(mode: str, ths_code: str) -> dict[str, Any]: b = get_bridge() if b.connected_mode != mode: return {} try: return b.get_tick_detail(ths_code, mode=mode) except Exception as exc: logger.debug("ctp_get_tick_detail: %s", exc) return {} def ctp_estimate_margin_one_lot( mode: str, ths_code: str, price: float, *, direction: str = "long", ) -> Optional[float]: b = get_bridge() if b.connected_mode != mode or not b.ping(): return None try: return b.estimate_margin_one_lot(ths_code, price, direction=direction) except Exception as exc: logger.debug("ctp_estimate_margin_one_lot: %s", exc) return None def ctp_lookup_contract_spec(mode: str, ths_code: str) -> Optional[dict]: b = get_bridge() if b.connected_mode != mode or not b.ping(): return None try: return b.lookup_contract_spec(ths_code) except Exception as exc: logger.debug("ctp_lookup_contract_spec: %s", exc) return None def get_ctp_balance(mode: str) -> Optional[float]: try: acc = ctp_get_account(mode) bal = acc.get("balance") return float(bal) if bal else None except Exception as exc: logger.debug("get_ctp_balance: %s", exc) return None def execute_order( conn, *, mode: str, offset: str, symbol: str, direction: str, lots: int, price: float, settings: dict | None = None, order_type: str = "limit", ) -> dict[str, Any]: """统一下单:simulation=SimNow,live=期货公司 CTP。""" del conn, settings if mode not in ("simulation", "live"): raise ValueError("未知交易模式") if not vnpy_available(): raise ValueError( "请先安装 vnpy 与 vnpy_ctp:pip install vnpy vnpy_ctp\n" f"模拟盘需配置 .env 中 SIMNOW_USER / SIMNOW_PASSWORD 等" ) b = get_bridge() b.require_connected(mode) order_id = b.send_order( ths_code=symbol, offset=offset, direction=direction, lots=lots, price=price, order_type=order_type, ) return { "order_id": order_id, "mode": mode, "mode_label": _mode_label(mode), "symbol": symbol, "lots": lots, "price": price, }