1b3a7f1bdc
Preserve trading state when CTP memory is empty, bootstrap equity/positions on page load, stabilize risk status from DB monitors, and remove app-layer manual close cooling periods. Co-authored-by: Cursor <cursoragent@cursor.com>
2419 lines
90 KiB
Python
2419 lines
90 KiB
Python
# Copyright (c) 2025-2026 马建军. All rights reserved.
|
||
# 专有软件 — 未经授权禁止复制、传播、转售。
|
||
# 严禁用于:带单/代客理财、向他人推荐期货品种或买卖建议、融资配资等业务。
|
||
# 详见 LICENSE.zh-CN.txt 与 docs/软件购买与使用协议.md
|
||
|
||
"""CTP 执行层:模拟盘 → SimNow;实盘 → 期货公司(vnpy_ctp)。"""
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import os
|
||
import re
|
||
import threading
|
||
import time
|
||
from collections import deque
|
||
from typing import Any, Callable, Optional
|
||
|
||
from locale_fix import ensure_process_locale
|
||
|
||
ensure_process_locale()
|
||
|
||
from ctp_settings import live_setting_dict, simnow_setting_dict
|
||
from ctp_symbol import ths_to_vnpy_symbol, to_vnpy_exchange
|
||
from contract_specs import get_contract_spec
|
||
|
||
logger = logging.getLogger(__name__)

# Name under which the vnpy CTP gateway is registered and looked up.
GATEWAY_NAME: str = "CTP"

# Connect wait budget and polling step used while waiting for login.
CONNECT_WAIT_SEC: int = 60
CONNECT_POLL_INTERVAL_SEC: float = 0.5

# Login back-off windows (seconds): broker ban (error 75) vs. ordinary rejected login.
LOGIN_BAN_COOLDOWN_SEC: int = 45 * 60
LOGIN_FAIL_COOLDOWN_SEC: int = 5 * 60

# Keys in the persisted settings store (fee_specs get_setting / set_setting).
CTP_COOLDOWN_UNTIL_KEY: str = "ctp_login_cooldown_until"
CTP_LAST_ERROR_KEY: str = "ctp_last_error"
|
||
|
||
|
||
def _persist_login_cooldown(seconds: float) -> None:
    """Persist a login-cooldown deadline (``time.time()`` based) to the settings store.

    The stored deadline only moves forward: if an equal or later deadline is
    already recorded, nothing is written. An unparsable stored value counts
    as 0. Negative *seconds* are treated as 0.
    """
    from fee_specs import get_setting, set_setting

    candidate = time.time() + max(0.0, seconds)
    try:
        current = float(get_setting(CTP_COOLDOWN_UNTIL_KEY, "0") or 0)
    except (TypeError, ValueError):
        current = 0.0
    if candidate <= current:
        return
    set_setting(CTP_COOLDOWN_UNTIL_KEY, str(candidate))
|
||
|
||
|
||
def _persisted_login_cooldown_remaining() -> int:
    """Return whole seconds left on the persisted login cooldown.

    Returns 0 when no deadline is stored, the deadline has passed, or the
    stored value cannot be parsed as a float.
    """
    from fee_specs import get_setting

    try:
        deadline = float(get_setting(CTP_COOLDOWN_UNTIL_KEY, "0") or 0)
        left = int(deadline - time.time())
    except (TypeError, ValueError):
        return 0
    return left if left > 0 else 0
|
||
|
||
|
||
def _clear_persisted_login_cooldown() -> None:
    """Reset the persisted login-cooldown deadline to "0" (no cooldown)."""
    from fee_specs import set_setting as write_setting

    write_setting(CTP_COOLDOWN_UNTIL_KEY, "0")
|
||
|
||
|
||
def _persist_last_error(msg: str) -> None:
    """Store the latest CTP error text, whitespace-trimmed (None/empty is stored as "")."""
    from fee_specs import set_setting

    cleaned = (msg or "").strip()
    set_setting(CTP_LAST_ERROR_KEY, cleaned)
|
||
|
||
|
||
def _load_persisted_last_error() -> str:
    """Return the persisted CTP error text, trimmed; "" when nothing is stored."""
    from fee_specs import get_setting

    stored = get_setting(CTP_LAST_ERROR_KEY, "")
    return (stored or "").strip()
|
||
|
||
# Hooks registered by the application layer via the set_*_callback functions (None = unset).
_position_refresh_callback: Optional[Callable[[], None]] = None
_tick_sl_tp_callback: Optional[Callable[[str, str, float], None]] = None
_tick_quote_callback: Optional[Callable[[], None]] = None
_ctp_connected_callback: Optional[Callable[[str], None]] = None

# Locks for the two debouncers below.
_position_refresh_debounce_lock = threading.Lock()
_tick_quote_timer_lock = threading.Lock()

# Rate limiter state: time.monotonic() of the last position refresh that fired.
_position_refresh_debounce_ts: float = 0.0

# Trailing-debounce timer for tick quote pushes, and its delay in seconds.
_tick_quote_timer: Optional[threading.Timer] = None
TICK_QUOTE_DEBOUNCE_SEC = 0.12
|
||
|
||
|
||
def set_position_refresh_callback(fn: Optional[Callable[[], None]]) -> None:
    """Register (or clear with None) the hook fired when positions/orders/trades change."""
    global _position_refresh_callback
    _position_refresh_callback = fn
|
||
|
||
|
||
def set_tick_sl_tp_callback(fn: Optional[Callable[[str, str, float], None]]) -> None:
    """Register the per-tick hook ``fn(exchange, symbol, last_price)`` used for local SL/TP triggering."""
    global _tick_sl_tp_callback
    _tick_sl_tp_callback = fn
|
||
|
||
|
||
def set_tick_quote_callback(fn: Optional[Callable[[], None]]) -> None:
    """Register the hook that pushes position last-price / floating PnL after ticks.

    Calls are debounced by ``_fire_tick_quote_callback_debounced``.
    """
    global _tick_quote_callback
    _tick_quote_callback = fn
|
||
|
||
|
||
def _fire_tick_quote_callback_debounced() -> None:
    """Trailing-debounce the tick quote push.

    Each call cancels any pending timer and arms a new daemon timer of
    TICK_QUOTE_DEBOUNCE_SEC seconds, so only the last call of a burst
    reaches the registered ``_tick_quote_callback``. The callback is read
    when the timer fires. Callback exceptions are logged at debug level.
    """
    global _tick_quote_timer

    def _deliver() -> None:
        callback = _tick_quote_callback
        if not callback:
            return
        try:
            callback()
        except Exception as exc:
            logger.debug("tick quote callback: %s", exc)

    with _tick_quote_timer_lock:
        pending = _tick_quote_timer
        if pending is not None:
            pending.cancel()
        timer = threading.Timer(TICK_QUOTE_DEBOUNCE_SEC, _deliver)
        timer.daemon = True
        _tick_quote_timer = timer
        timer.start()
|
||
|
||
|
||
def set_ctp_connected_callback(fn: Optional[Callable[[str], None]]) -> None:
    """Register the hook ``fn(mode)`` run after a CTP trading login succeeds.

    *mode* is "simulation" or "live".
    """
    global _ctp_connected_callback
    _ctp_connected_callback = fn
|
||
|
||
|
||
def _fire_ctp_connected_callback(mode: str) -> None:
    """Run the registered connected-callback with *mode* on a daemon thread.

    Does nothing if no callback is registered. Failure to start the thread
    is logged at debug level.
    """
    callback = _ctp_connected_callback
    if not callback:
        return
    try:
        worker = threading.Thread(
            target=callback, args=(mode,), daemon=True, name="ctp-connected-cb",
        )
        worker.start()
    except Exception as exc:
        logger.debug("ctp connected callback: %s", exc)
|
||
|
||
|
||
def _fire_position_refresh_callback() -> None:
    """Run the registered position-refresh callback on a new daemon thread.

    Does nothing if no callback is registered. Failure to start the thread
    is logged at debug level.
    """
    callback = _position_refresh_callback
    if not callback:
        return
    try:
        worker = threading.Thread(target=callback, daemon=True, name="ctp-position-refresh")
        worker.start()
    except Exception as exc:
        logger.debug("position refresh callback: %s", exc)
|
||
|
||
|
||
def _fire_position_refresh_callback_debounced(*, min_interval: float = 0.35) -> None:
    """Fire the position-refresh hook at most once per *min_interval* seconds.

    Leading-edge rate limit: a call inside the window is dropped, not
    deferred. The timestamp is updated under the lock; the callback fires
    after the lock is released.
    """
    global _position_refresh_debounce_ts

    now = time.monotonic()
    with _position_refresh_debounce_lock:
        too_soon = now - _position_refresh_debounce_ts < min_interval
        if not too_soon:
            _position_refresh_debounce_ts = now
    if too_soon:
        return
    _fire_position_refresh_callback()
|
||
|
||
|
||
def _fire_position_refresh_burst() -> None:
    """Fire the position-refresh hook now and again after 1.5, 4, 10 and 18 seconds.

    After login, CTP position reports may arrive in several batches, so the
    snapshot is refreshed repeatedly. The follow-up timers are daemon
    threads, like the tick-quote timer in this module. A pending refresh
    therefore never keeps the interpreter alive at shutdown.
    """
    _fire_position_refresh_callback()
    for delay in (1.5, 4.0, 10.0, 18.0):
        timer = threading.Timer(delay, _fire_position_refresh_callback)
        timer.daemon = True
        timer.start()
|
||
|
||
|
||
def _schedule_after_instruments_ready(bridge: "CtpBridge") -> None:
    """Query and calibrate positions shortly after CTP finishes loading instruments.

    Triggered from the gateway log hook when instrument loading succeeds.
    Per the original note, on SimNow this is about 10–20 s after login.

    Does nothing while the bridge is not connected. Triggers within 5 s of
    the previous accepted one are dropped.

    The work runs 0.4 s later on a daemon timer, so it never blocks
    interpreter shutdown. It is skipped if the bridge already has live
    positions. Errors are logged at debug level.
    """
    if not getattr(bridge, "_connected_mode", None):
        return
    now = time.monotonic()
    if now - float(getattr(bridge, "_last_instruments_ready_ts", 0) or 0) < 5.0:
        return
    bridge._last_instruments_ready_ts = now

    def _run() -> None:
        try:
            if bridge._has_live_positions():
                return
            bridge._ensure_instrument_margin_hooks()
            with _ctp_td_lock:
                bridge.request_position_snapshot(force=True)
            # Presumably gives the position replies time to arrive before calibrating.
            time.sleep(2.0)
            with _ctp_td_lock:
                bridge.calibrate_trading_state()
            _fire_position_refresh_callback()
            n = len(bridge._collect_positions())
            logger.info("CTP 合约加载完成,持仓 %s 条,已刷新快照", n)
        except Exception as exc:
            logger.debug("instruments ready refresh: %s", exc)

    timer = threading.Timer(0.4, _run)
    timer.daemon = True
    timer.start()
|
||
|
||
|
||
def _schedule_position_query_retries(bridge: "CtpBridge") -> None:
    """Schedule follow-up position queries after the delays in POSITION_QUERY_RETRY_DELAYS_SEC.

    When it fires, each retry:
      1. returns early if the bridge is disconnected or already has live positions;
      2. requests a non-forced position snapshot;
      3. waits 1 s, then calibrates the trading book;
      4. fires the position-refresh callback.

    Errors are logged at debug level. The timers are daemon threads: with
    the current constants they wait up to 95 s, and non-daemon timers would
    hold interpreter shutdown open until then.
    """

    def _run() -> None:
        if not bridge._connected_mode or bridge._has_live_positions():
            return
        try:
            bridge._ensure_instrument_margin_hooks()
            with _ctp_td_lock:
                bridge.request_position_snapshot(force=False)
            time.sleep(1.0)
            with _ctp_td_lock:
                bridge.calibrate_trading_state()
            _fire_position_refresh_callback()
        except Exception as exc:
            logger.debug("position query retry: %s", exc)

    for delay in POSITION_QUERY_RETRY_DELAYS_SEC:
        timer = threading.Timer(delay, _run)
        timer.daemon = True
        timer.start()
|
||
|
||
# Module-wide bridge instance (None until created) and the lock for creating it.
_bridge: Optional["CtpBridge"] = None
_bridge_lock = threading.Lock()

# Re-entrant lock held around CTP trading-API work in this module
# (snapshot requests, calibration, trade listing, connect).
_ctp_td_lock = threading.RLock()

# Position/trade query throttles and the post-connect retry schedule, in seconds.
POSITION_QUERY_MIN_INTERVAL_SEC = 5.0
POSITION_QUERY_RETRY_DELAYS_SEC = (22.0, 50.0, 95.0)
TRADE_QUERY_MIN_INTERVAL_SEC = 10.0
|
||
|
||
|
||
def _simnow_setting() -> dict[str, str]:
    """Return the SimNow front/account settings (system settings first, .env as fallback)."""
    settings = simnow_setting_dict()
    return settings
|
||
|
||
|
||
def _live_setting() -> dict[str, str]:
    """Return the live futures-broker CTP settings dict."""
    settings = live_setting_dict()
    return settings
|
||
|
||
|
||
def _setting_for_mode(mode: str) -> dict[str, str]:
    """Pick the vnpy_ctp setting dict: "simulation" gives SimNow, any other mode the live broker."""
    if mode == "simulation":
        return _simnow_setting()
    return _live_setting()
|
||
|
||
|
||
def _mode_label(mode: str) -> str:
    """Return the display label: "SimNow" for simulation, otherwise the live-broker label."""
    if mode != "simulation":
        return "期货公司实盘"
    return "SimNow"
|
||
|
||
|
||
def _parse_tcp_address(address: str) -> tuple[str, int]:
    """Split a CTP front address such as ``tcp://host:port`` into ``(host, port)``.

    The optional ``tcp://`` scheme is matched case-insensitively and
    surrounding whitespace is ignored. A bracketed IPv6 literal
    (``[::1]:10201``) is returned without brackets, which is the form
    ``socket.create_connection`` expects.

    Raises:
        ValueError: if there is no ``:port`` part, the host is empty, or the
            port is not an integer in 1..65535.
    """
    raw = (address or "").strip()
    if raw[:6].lower() == "tcp://":
        raw = raw[6:]
    if ":" not in raw:
        raise ValueError(f"无效 TCP 地址: {address}")
    host, port_s = raw.rsplit(":", 1)
    host = host.strip()
    if len(host) >= 2 and host.startswith("[") and host.endswith("]"):
        host = host[1:-1]
    if not host:
        raise ValueError(f"无效 TCP 地址: {address}")
    port = int(port_s)  # ValueError for non-numeric ports, as before
    if not 0 < port <= 65535:
        raise ValueError(f"无效 TCP 地址: {address}")
    return host, port
|
||
|
||
|
||
def probe_tcp_address(address: str, timeout: float = 5.0) -> tuple[bool, str]:
    """Check whether a CTP front's TCP port accepts a connection.

    Returns ``(True, "")`` when a connection opens; it is closed right away.
    Otherwise returns ``(False, <error text>)``; address-parse errors are
    reported the same way instead of being raised.
    """
    import socket

    try:
        endpoint = _parse_tcp_address(address)
        conn = socket.create_connection(endpoint, timeout=timeout)
        with conn:
            pass
    except Exception as exc:
        return False, str(exc)
    return True, ""
|
||
|
||
|
||
def _format_ctp_failure(ctp_logs: list[str], *, td_address: str = "") -> str:
    """Turn captured CTP gateway log lines into a readable failure message.

    If *td_address* is given it is probed first, and an unreachable front
    wins over any log-based diagnosis. Otherwise the joined logs are matched
    in this priority order:
      1. login ban (error 75);
      2. handshake/encryption mismatch (4097);
      3. rejected credentials;
      4. disconnect;
      5. the last log line as-is.
    With no logs at all, a timeout hint is returned.
    """
    if td_address:
        reachable, probe_err = probe_tcp_address(td_address, timeout=4.0)
        if not reachable:
            return (
                f"SimNow 交易前置不可达:{td_address}({probe_err})。"
                "182.254.243.31 已停用,请改 .env 为官方前置 "
                "tcp://180.168.146.187:10201 / 10211,并确认服务器能访问该地址。"
            )

    joined = "\n".join(ctp_logs)
    lowered = joined.lower()
    last_line = ctp_logs[-1] if ctp_logs else ""

    if any(marker in joined for marker in ("连续登录失败", "登录被禁止", "代码:75")):
        return (
            "CTP 登录被临时禁止:连续失败次数过多(错误码 75)。"
            "请等待约 30~60 分钟后再试,先用快期确认投资者代码与密码正确,期间勿反复点「连接」。"
        )
    if "4097" in joined or "Decrypt handshake" in joined or "shake hand" in lowered:
        return (
            "CTP 握手失败(4097):vnpy_ctp 与 SimNow 前置加密不匹配。"
            "请执行 pip install -U vnpy vnpy_ctp 后重启,并确认 .env 中 SIMNOW_ENV=实盘"
        )
    if any(marker in joined for marker in ("不合法的登录", "密码", "账号")):
        return f"CTP 登录被拒:{last_line or '请检查投资者代码与密码(快期能否登录)'}"
    if "连接断开" in joined or "disconnect" in lowered:
        return f"CTP 连接断开:{last_line or '请检查前置地址与网络'}"
    if ctp_logs:
        return f"CTP 连接失败:{last_line}"
    return "CTP 连接超时:未收到柜台回报。请检查 SimNow 账号、前置地址、网络(nc 测端口),并用快期验证账号"
|
||
|
||
|
||
def round_to_tick(price: float, tick: float) -> float:
    """Snap *price* to the nearest multiple of *tick*.

    A non-positive *tick* returns ``float(price)`` unchanged. The step count
    uses Python's ``round`` (ties go to the even step). The result is
    rounded to 10 decimals to drop binary floating-point noise.
    """
    if tick > 0:
        return round(round(float(price) / tick) * tick, 10)
    return float(price)
|
||
|
||
|
||
def _is_long_direction(direction_obj: Any) -> bool:
    """Return True when the direction's text contains "多" or "LONG" (any case).

    None and empty values are treated as not long.
    """
    text = str(direction_obj or "")
    if "多" in text:
        return True
    return "LONG" in text.upper()
|
||
|
||
|
||
class CtpBridge:
|
||
def __init__(self) -> None:
    """Create the bridge, restore persisted error/cooldown state, and build the vnpy engine.

    ``_init_engine`` runs last because it installs event hooks that read the
    ``*_hooked`` flags set here. (``_trade_hooked`` used to be assigned twice;
    it is now initialised once.)
    """
    # vnpy engines; stay None if _init_engine fails.
    self._engine = None
    self._ee = None

    # Connection / login state.
    self._connected_mode: Optional[str] = None
    self._last_error: str = ""
    self._connect_lock = threading.Lock()
    self._connect_in_progress = False
    self._login_cooldown_until: float = 0.0  # time.monotonic() deadline
    self._restore_persisted_state()

    # Commission / margin-rate query plumbing (presumably keyed by request id — verify).
    self._commission_waiters: dict[int, threading.Event] = {}
    self._commission_lists: dict[int, list] = {}
    self._commission_hooked = False
    self._margin_rate_waiters: dict[int, threading.Event] = {}
    self._margin_rate_lists: dict[int, list] = {}
    self._margin_rate_hooked = False

    # Gateway-level hook bookkeeping (reset by _close_gateway).
    self._instrument_hooked = False
    self._hooks_td_api_id: Optional[int] = None
    self._ctp_log_hooked = False
    self._last_instruments_ready_ts: float = 0.0
    self._last_position_rsp_ts: float = 0.0

    # Margin / position caches.
    self._instrument_margin_ratios: dict[str, dict[str, float]] = {}
    self._margin_per_lot: dict[str, float] = {}
    self._subscribed: set[str] = set()
    self._last_position_query_ts: float = 0.0
    self._position_margins: dict[str, float] = {}
    self._position_open_times: dict[str, str] = {}
    self._margin_hooked = False

    # Trade query state.
    self._trade_hooked = False
    self._trade_query_results: list[dict[str, Any]] = []
    self._trade_query_event = threading.Event()
    self._last_trade_query_ts: float = 0.0
    self._last_connect_ok_ts: float = 0.0

    # vnpy event-hook flags.
    self._tick_hooked = False
    self._position_hooked = False
    self._order_hooked = False

    # Bar aggregation state (presumably 1-minute bars per symbol, per the name).
    self._bar_generators: dict[str, Any] = {}
    self._bars_1m: dict[str, deque] = {}

    self._init_engine()
|
||
|
||
def _init_engine(self) -> None:
    """Create the vnpy EventEngine/MainEngine, add the CTP gateway and install event hooks.

    Nothing is raised. A missing vnpy/vnpy_ctp install (ImportError) stores
    an install hint in ``self._last_error``; any other error stores its text.
    """
    ensure_process_locale()
    try:
        from vnpy.event import EventEngine
        from vnpy.trader.engine import MainEngine
        from vnpy_ctp import CtpGateway

        event_engine = EventEngine()
        self._ee = event_engine
        self._engine = MainEngine(event_engine)
        self._engine.add_gateway(CtpGateway)
        for install_hook in (
            self._ensure_position_event_hook,
            self._ensure_order_event_hook,
            self._ensure_trade_event_hook,
            self._ensure_ctp_log_hooks,
        ):
            install_hook()
    except ImportError:
        self._last_error = "未安装 vnpy / vnpy_ctp,请 pip install vnpy vnpy_ctp"
    except Exception as exc:
        self._last_error = str(exc)
|
||
|
||
def _ensure_position_event_hook(self) -> None:
    """Register, once, a vnpy EVENT_POSITION handler that mirrors positions into trading_state.

    Returns without registering if already hooked, if there is no event
    engine, or if vnpy's event module cannot be imported.

    For each position event the handler:
      1. converts the event with ``_position_row_from_vnpy``; when that yields
         a row, upserts it into ``trading_state`` with the trade list, which
         is read under ``_ctp_td_lock``;
      2. if the reported volume is <= 0, removes the book entry; otherwise
         caches the first positive margin attribute as total margin and as
         per-lot margin (rounded to 2 decimals);
      3. fires the position-refresh callback, also after an error (errors are
         logged at debug level).
    """
    if self._position_hooked or not self._ee:
        return
    try:
        from vnpy.trader.event import EVENT_POSITION
    except ImportError:
        return

    def _handle_position(event) -> None:
        try:
            from ctp_trading_state import trading_state

            pos = event.data
            row = self._position_row_from_vnpy(pos)
            if row:
                row_symbol = row.get("symbol") or ""
                row_exchange = row.get("exchange") or ""
                ths_code = CtpBridge._vnpy_sym_to_ths(row_symbol, row_exchange) or row_symbol
                with _ctp_td_lock:
                    trades = self.list_trades()
                trading_state.upsert_position(
                    row, notify=False, trades=trades, ths_sym=ths_code,
                )

            symbol = getattr(pos, "symbol", "") or ""
            side = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short"
            volume = int(getattr(pos, "volume", 0) or 0)
            if volume > 0:
                for attr in ("margin", "use_margin", "UseMargin"):
                    margin = float(getattr(pos, attr, 0) or 0)
                    if margin <= 0:
                        continue
                    self._position_margins[self._position_margin_key(symbol, side)] = margin
                    self._margin_per_lot[self._position_margin_key(symbol, side)] = round(
                        margin / volume, 2,
                    )
                    break
            else:
                exchange = getattr(pos, "exchange", None)
                ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "")
                from ctp_trading_state import position_key

                trading_state.remove_position(
                    position_key(ex_name, symbol, side), notify=False,
                )
        except Exception as exc:
            logger.debug("position margin cache: %s", exc)
        _fire_position_refresh_callback()

    self._ee.register(EVENT_POSITION, _handle_position)
    self._position_hooked = True
|
||
|
||
def _ensure_order_event_hook(self) -> None:
    """Register, once, a vnpy EVENT_ORDER handler that keeps trading_state's order book current.

    Returns without registering if already hooked, if there is no event
    engine, or if vnpy's event module cannot be imported.

    An order is removed when its status text contains a terminal marker
    (all traded / cancelled / rejected, in English or Chinese) or when it
    has no remaining lots; otherwise it is upserted. The position-refresh
    callback fires after each handled event, including after a logged
    error. An event that cannot be converted to a row returns early without
    firing it.
    """
    if self._order_hooked or not self._ee:
        return
    try:
        from vnpy.trader.event import EVENT_ORDER
    except ImportError:
        return

    terminal_markers = ("ALLTRADED", "CANCELLED", "REJECTED", "全部成交", "已撤销", "拒单")

    def _handle_order(event) -> None:
        try:
            from ctp_trading_state import trading_state

            row = self._order_row_from_vnpy(event.data)
            if not row:
                return
            status_text = str(row.get("status") or "")
            finished = any(marker in status_text for marker in terminal_markers)
            order_id = str(row.get("order_id") or row.get("vt_order_id") or "")
            if finished or int(row.get("lots") or 0) <= 0:
                trading_state.remove_order(order_id, notify=False)
            else:
                trading_state.upsert_order(row, notify=False)
        except Exception as exc:
            logger.debug("order event: %s", exc)
        _fire_position_refresh_callback()

    self._ee.register(EVENT_ORDER, _handle_order)
    self._order_hooked = True
|
||
|
||
def _ensure_trade_event_hook(self) -> None:
    """Register, once, a vnpy EVENT_TRADE handler that records position open times.

    Returns without registering if already hooked, if there is no event
    engine, or if vnpy's event module cannot be imported.

    For an "open" trade with both a symbol and a datetime, the handler stores
    that datetime in ``_position_open_times``, keyed by symbol and position
    direction (default "long"). The position-refresh callback fires after
    every trade event; errors are logged at debug level.
    """
    if self._trade_hooked or not self._ee:
        return
    try:
        from vnpy.trader.event import EVENT_TRADE
    except ImportError:
        return

    def _handle_trade(event) -> None:
        try:
            row = self._trade_row_from_vnpy(event.data)
            if row and row.get("offset") == "open":
                symbol = row.get("symbol") or ""
                side = row.get("position_direction") or "long"
                opened_at = row.get("datetime") or ""
                if symbol and opened_at:
                    self._position_open_times[self._position_margin_key(symbol, side)] = opened_at
        except Exception as exc:
            logger.debug("trade event: %s", exc)
        _fire_position_refresh_callback()

    self._ee.register(EVENT_TRADE, _handle_trade)
    self._trade_hooked = True
|
||
|
||
def _order_row_from_vnpy(self, order: Any) -> Optional[dict[str, Any]]:
    """Convert a vnpy OrderData-like object into an order-book row dict.

    ``lots`` is the unfilled remainder (volume - traded, floored at 0).
    ``direction`` is "short" only when ``str(direction)`` ends with "SHORT",
    otherwise "long". ``order_id`` prefers ``vt_orderid`` and falls back to
    ``orderid``. Returns None, logged at debug level, if conversion fails.
    """
    try:
        status_text = str(getattr(order, "status", None))
        total = int(getattr(order, "volume", 0) or 0)
        filled = int(getattr(order, "traded", 0) or 0)
        direction = getattr(order, "direction", None)
        is_short = direction is not None and str(direction).endswith("SHORT")
        offset = getattr(order, "offset", None)
        symbol = getattr(order, "symbol", "") or ""
        exchange = getattr(order, "exchange", None)
        ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "")
        vt_oid = str(getattr(order, "vt_orderid", "") or "")
        raw_oid = str(getattr(order, "orderid", "") or "")
        return dict(
            symbol=symbol,
            exchange=ex_name,
            direction="short" if is_short else "long",
            lots=max(0, total - filled),
            price=float(getattr(order, "price", 0) or 0),
            offset=str(offset or ""),
            order_id=vt_oid or raw_oid,
            vt_order_id=vt_oid,
            status=status_text,
        )
    except Exception as exc:
        logger.debug("order_row_from_vnpy: %s", exc)
        return None
|
||
|
||
def _position_row_from_vnpy(self, pos: Any) -> Optional[dict[str, Any]]:
    """Convert a vnpy PositionData-like object into a trading-book position row.

    Row keys:
      * symbol, exchange, direction ("long"/"short"), lots, avg_price, pnl, frozen;
      * margin, from ``estimate_position_margin``;
      * open_time, the looked-up open time or None;
      * yd_volume, and td_volume = lots - yd_volume floored at 0.

    A positive avg_price is then snapped with ``ctp_entry_price.round_to_tick``
    using the THS symbol. If that step fails it is logged and the raw price
    is kept. Returns None, logged at debug level, if conversion fails.
    """
    try:
        lots = int(getattr(pos, "volume", 0) or 0)
        side = "long" if _is_long_direction(getattr(pos, "direction", None)) else "short"
        symbol = getattr(pos, "symbol", "") or ""
        exchange = getattr(pos, "exchange", None)
        ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "")
        avg_price = float(getattr(pos, "price", 0) or 0)
        yd_lots = int(getattr(pos, "yd_volume", 0) or 0)
        td_lots = max(0, lots - yd_lots)
        margin = self.estimate_position_margin(symbol, ex_name, side, lots, avg_price, pos=pos)
        opened_at = self._lookup_position_open_time(symbol, side) or None
        pnl = float(getattr(pos, "pnl", 0) or 0)
        row: dict[str, Any] = dict(
            symbol=symbol,
            exchange=ex_name,
            direction=side,
            lots=lots,
            avg_price=avg_price,
            pnl=pnl,
            frozen=int(getattr(pos, "frozen", 0) or 0),
            margin=margin,
            open_time=opened_at,
            yd_volume=yd_lots,
            td_volume=td_lots,
        )
        try:
            # Tick-aware rounder keyed by THS code; this is not the module-level round_to_tick.
            from ctp_entry_price import round_to_tick as snap_to_tick

            ths_code = CtpBridge._vnpy_sym_to_ths(symbol, ex_name) or symbol
            if avg_price > 0:
                row["avg_price"] = snap_to_tick(avg_price, ths_code)
        except Exception as exc:
            logger.debug("position avg round: %s", exc)
        return row
    except Exception as exc:
        logger.debug("position_row_from_vnpy: %s", exc)
        return None
|
||
|
||
def calibrate_trading_state(self) -> None:
    """Fully re-sync the in-memory trading book from vnpy's caches; the counter is not queried.

    Active orders, positions and trades are read under ``_ctp_td_lock``.
    When connected but vnpy reports no positions, the account's used margin
    (``ctp_account_margin_used``, defined outside this view) is passed as
    ``preserve_positions_if_margin``. Presumably this lets trading_state keep
    its existing positions instead of wiping them; verify in
    ctp_trading_state. Errors are logged at debug level, not raised.
    """
    try:
        from ctp_trading_state import trading_state

        with _ctp_td_lock:
            orders = self.list_active_orders()
            positions = self._collect_positions()
            trades = self.list_trades()

        margin_floor = 0.0
        if self._connected_mode and not positions:
            try:
                margin_floor = float(ctp_account_margin_used(self._connected_mode) or 0)
            except Exception:
                margin_floor = 0.0

        trading_state.calibrate_from_lists(
            orders,
            positions,
            trades=trades,
            ths_for_vnpy_sym=lambda s, e: CtpBridge._vnpy_sym_to_ths(s, e) or s,
            preserve_positions_if_margin=margin_floor,
        )
    except Exception as exc:
        logger.debug("calibrate trading state: %s", exc)
|
||
|
||
def available(self) -> bool:
    """Return True once the vnpy MainEngine has been created."""
    engine = self._engine
    return engine is not None
|
||
|
||
@property
def last_error(self) -> str:
    """Return the most recent connection/login error text ("" when none is recorded)."""
    message = self._last_error
    return message
|
||
|
||
@property
def connected_mode(self) -> Optional[str]:
    """Return the mode currently connected ("simulation"/"live"), or None."""
    mode = self._connected_mode
    return mode
|
||
|
||
def connect_in_progress(self) -> bool:
    """Return whether a ``connect()`` call is currently running."""
    busy = self._connect_in_progress
    return busy
|
||
|
||
def _restore_persisted_state(self) -> None:
    """Reload the last error and login cooldown that were persisted before a restart.

    The persisted wall-clock remainder becomes a ``time.monotonic()`` deadline
    for the in-memory cooldown. Nothing is changed when values are empty or zero.
    """
    saved_error = _load_persisted_last_error()
    if saved_error:
        self._last_error = saved_error
    seconds_left = _persisted_login_cooldown_remaining()
    if seconds_left <= 0:
        return
    self._login_cooldown_until = time.monotonic() + seconds_left
|
||
|
||
def login_cooldown_remaining(self) -> int:
    """Return seconds until login is allowed again: the larger of the in-memory and persisted cooldowns.

    The persisted value survives process restarts. The result is never negative.
    """
    in_memory = int(self._login_cooldown_until - time.monotonic())
    persisted = _persisted_login_cooldown_remaining()
    return max(0, in_memory, persisted)
|
||
|
||
def _is_login_cooldown_active(self) -> bool:
    """Return True while any login cooldown (in-memory or persisted) has time left."""
    remaining = self.login_cooldown_remaining()
    return remaining > 0
|
||
|
||
def _set_login_cooldown(self, seconds: float) -> None:
    """Extend the login cooldown to at least *seconds* from now, in memory and persisted.

    An existing later deadline is never shortened. Negative values count as 0.
    """
    deadline = time.monotonic() + max(0.0, seconds)
    self._login_cooldown_until = max(self._login_cooldown_until, deadline)
    _persist_login_cooldown(seconds)
|
||
|
||
def _clear_login_cooldown(self) -> None:
    """Drop both the in-memory and the persisted login cooldown."""
    self._login_cooldown_until = 0.0
    _clear_persisted_login_cooldown()
|
||
|
||
def _apply_login_failure_cooldown(self, ctp_logs: list[str]) -> None:
    """Start a login cooldown based on the failure seen in *ctp_logs*.

    A ban (consecutive failures / login forbidden / error 75) gets
    LOGIN_BAN_COOLDOWN_SEC. An ordinary failed or illegal login gets
    LOGIN_FAIL_COOLDOWN_SEC. Any other log content sets no cooldown.
    """
    joined = "\n".join(ctp_logs)
    if any(marker in joined for marker in ("连续登录失败", "登录被禁止", "代码:75")):
        self._set_login_cooldown(LOGIN_BAN_COOLDOWN_SEC)
        return
    if any("登录失败" in line or "不合法的登录" in line for line in ctp_logs):
        self._set_login_cooldown(LOGIN_FAIL_COOLDOWN_SEC)
|
||
|
||
def _login_cooldown_message(self) -> str:
    """Build the user-facing "retry after M min S s" message for the active login cooldown."""
    minutes, seconds = divmod(self.login_cooldown_remaining(), 60)
    return (
        f"CTP 登录冷却中,请 {minutes} 分 {seconds} 秒后再试"
        f"(避免连续失败被 SimNow 封禁)"
    )
|
||
|
||
def _close_gateway(self) -> None:
    """Close the CTP gateway and reset connection-scoped state.

    Closing first avoids a reconnect getting stuck in a half-connected login.
    Gateway close errors are logged at debug level. Afterwards it clears the
    connected mode and per-connection hook and query bookkeeping, empties
    ``trading_state`` on a best-effort basis, and pauses 0.6 s.
    Does nothing when there is no engine.
    """
    if not self._engine:
        return
    try:
        gateway = self._engine.get_gateway(GATEWAY_NAME)
        if gateway:
            gateway.close()
    except Exception as exc:
        logger.debug("gateway close: %s", exc)

    self._connected_mode = None
    self._hooks_td_api_id = None
    self._instrument_hooked = self._margin_rate_hooked = False
    self._last_position_query_ts = self._last_instruments_ready_ts = 0.0

    try:
        from ctp_trading_state import trading_state

        trading_state.clear()
    except Exception:
        pass  # best effort: a missing or broken trading_state must not block reconnects
    time.sleep(0.6)
|
||
|
||
def _ensure_ctp_log_hooks(self) -> None:
    """Register, once, a persistent vnpy log listener that reacts to instrument loading.

    When a log message contains "合约信息查询成功" (instrument query
    succeeded), the post-instrument position refresh is scheduled. Per the
    original note, td_api may have been replaced after a reconnect, so this
    is hooked on the log stream. Errors are logged at debug level. Does
    nothing if already hooked, if there is no event engine, or if vnpy is
    missing.
    """
    if self._ctp_log_hooked or not self._ee:
        return
    try:
        from vnpy.trader.event import EVENT_LOG
    except ImportError:
        return

    def _watch_log(event) -> None:
        try:
            text = str(getattr(event.data, "msg", "") or str(event.data))
            if "合约信息查询成功" in text:
                _schedule_after_instruments_ready(self)
        except Exception as exc:
            logger.debug("ctp log hook: %s", exc)

    self._ee.register(EVENT_LOG, _watch_log)
    self._ctp_log_hooked = True
|
||
|
||
def _login_rejected(self, ctp_logs: list[str]) -> bool:
    """Return True if any captured log line reports a failed, illegal or banned login."""
    markers = ("登录失败", "不合法的登录", "登录被禁止", "连续登录失败")
    for line in ctp_logs:
        for marker in markers:
            if marker in line:
                return True
    return False
|
||
|
||
def _wait_connected(self, mode: str, ctp_logs: list[str] | None = None) -> bool:
    """Poll until CTP reports an account or the trading channel is logged in.

    Polls every CONNECT_POLL_INTERVAL_SEC, for up to CONNECT_WAIT_SEC, with at
    least one attempt. Returns True on success. Returns False when there is
    no engine, when the wait times out, or as soon as *ctp_logs* shows a
    rejected login.

    *ctp_logs* is the caller's live list that the log handler appends to
    while we wait. It must be used as-is: ``ctp_logs or []`` would swap an
    initially empty list for a private copy, and rejections logged later
    would then never be seen. *mode* is currently unused.
    """
    if not self._engine:
        return False
    logs = ctp_logs if ctp_logs is not None else []
    attempts = max(1, int(CONNECT_WAIT_SEC / CONNECT_POLL_INTERVAL_SEC))
    for _ in range(attempts):
        if self._login_rejected(logs):
            return False
        try:
            has_account = bool(self._engine.get_all_accounts())
        except Exception:
            has_account = False
        if has_account or self._td_logged_in():
            return True
        time.sleep(CONNECT_POLL_INTERVAL_SEC)
    return False
|
||
|
||
def status(self, mode: str) -> dict[str, Any]:
    """Return a connection status snapshot for *mode*, for the UI.

    If the bridge believes it is connected in *mode*, ``ping()`` runs first;
    it may clear a dead connection. ``connecting`` is only reported while no
    login cooldown is active. ``last_error`` falls back to the persisted
    value when memory holds none.
    """
    if self._connected_mode == mode:
        self.ping()
    cfg = _setting_for_mode(mode)
    missing = [key for key in ("用户名", "密码", "交易服务器") if not cfg.get(key)]
    cooldown = self.login_cooldown_remaining()
    error_text = self._last_error or _load_persisted_last_error()
    snapshot: dict[str, Any] = {}
    snapshot["vnpy_installed"] = self.available()
    snapshot["connected"] = self._connected_mode == mode
    snapshot["connecting"] = bool(self._connect_in_progress and cooldown <= 0)
    snapshot["connected_mode"] = self._connected_mode
    snapshot["mode_label"] = _mode_label(mode)
    snapshot["missing_config"] = missing
    snapshot["last_error"] = error_text
    snapshot["login_cooldown_sec"] = cooldown
    snapshot["broker_id"] = cfg.get("经纪商代码", "")
    snapshot["td_address"] = cfg.get("交易服务器", "")
    return snapshot
|
||
|
||
def connect(self, mode: str, *, force: bool = False, scheduled: bool = False) -> None:
    """Connect and log in to the CTP trading front for *mode*, blocking until it finishes.

    Args:
        mode: "simulation" (SimNow) or any other value for the live broker.
        force: reconnect even when already connected, and ignore the login cooldown.
        scheduled: forwarded to ``_ctp_connect_permitted``.

    Raises:
        RuntimeError: connecting is not permitted, a connect is already running,
            the login cooldown is active, the vnpy engine is missing, the
            trading front is unreachable, or login fails or times out.
        ValueError: account credentials or the trading-front address are missing.

    Both an unreachable front and a failed login are stored in ``_last_error``
    and persisted. ``status()`` can therefore show the reason even when
    ``start_connect_async`` ran the connect in the background.

    When the connection switches mode (simulation <-> live), only the CTP
    gateway is closed. ``MainEngine.close()`` would also stop the shared
    EventEngine, which is never restarted, so every event hook would go
    silent.
    """
    from ctp_settings import CTP_DISABLED_HINT

    if not _ctp_connect_permitted(scheduled=scheduled):
        self._last_error = CTP_DISABLED_HINT
        _persist_last_error(CTP_DISABLED_HINT)
        raise RuntimeError(CTP_DISABLED_HINT)
    if self._connect_in_progress:
        raise RuntimeError("CTP 正在连接中,请稍候")
    if self._is_login_cooldown_active() and not force:
        msg = self._login_cooldown_message()
        self._last_error = msg
        raise RuntimeError(msg)
    if not self._engine:
        raise RuntimeError(self._last_error or "vnpy 引擎未初始化")
    if self._connected_mode == mode and not force:
        if self.ping():
            return
        self._connected_mode = None

    setting = _setting_for_mode(mode)
    if not setting.get("用户名") or not setting.get("密码"):
        raise ValueError(
            f"{_mode_label(mode)}:请在 .env 配置 "
            f"{'SIMNOW_USER / SIMNOW_PASSWORD' if mode == 'simulation' else 'CTP_LIVE_USER / CTP_LIVE_PASSWORD'}"
        )
    if not setting.get("交易服务器"):
        raise ValueError(f"{_mode_label(mode)}:未配置交易服务器地址")

    self._connect_in_progress = True
    try:
        with _ctp_td_lock:
            with self._connect_lock:
                switching_mode = bool(self._connected_mode) and self._connected_mode != mode
                if (force and self._connected_mode) or switching_mode:
                    self._close_gateway()
                elif not (self._connected_mode == mode and self.ping()):
                    self._close_gateway()

                # Capture gateway log lines for this attempt (last 40 kept).
                ctp_logs: list[str] = []
                from vnpy.trader.event import EVENT_LOG

                def _on_log(event) -> None:
                    msg = getattr(event.data, "msg", "") or str(event.data)
                    if msg:
                        ctp_logs.append(str(msg))
                        if len(ctp_logs) > 40:
                            ctp_logs.pop(0)
                        logger.info("CTP | %s", msg)

                self._ee.register(EVENT_LOG, _on_log)
                try:
                    ensure_process_locale()
                    logger.info(
                        "CTP 连接 [%s] user=%s td=%s env=%s",
                        mode,
                        setting.get("用户名"),
                        setting.get("交易服务器"),
                        setting.get("柜台环境", "实盘"),
                    )
                    td_addr = setting.get("交易服务器", "")
                    ok, err = probe_tcp_address(td_addr, timeout=5.0)
                    if not ok:
                        if mode == "simulation":
                            hint = (
                                f"SimNow 交易前置不可达:{td_addr}({err})。"
                                "请更新 .env 中 SIMNOW_TD_ADDRESS 为官网最新地址,"
                                "并在服务器执行 nc -zv 验证出网。"
                            )
                        else:
                            hint = (
                                f"{_mode_label(mode)} 交易前置不可达:{td_addr}({err})。"
                                "请检查系统设置中的交易服务器地址,"
                                "并在服务器执行 nc -zv 验证出网。"
                            )
                        # Record the reason so status() can show it; without this,
                        # a background connect fails with no visible error.
                        self._last_error = hint
                        _persist_last_error(hint)
                        raise RuntimeError(hint)
                    self._ensure_instrument_margin_hooks()
                    self._engine.connect(setting, GATEWAY_NAME)
                    if self._wait_connected(mode, ctp_logs):
                        self._connected_mode = mode
                        self._last_connect_ok_ts = time.time()
                        self._last_error = ""
                        _persist_last_error("")
                        self._clear_login_cooldown()
                        logger.info(
                            "CTP 已连接 [%s] td_login=%s accounts=%s",
                            mode,
                            self._td_logged_in(),
                            len(self._engine.get_all_accounts() or []),
                        )
                        self._schedule_fee_sync(mode)
                        try:
                            self.calibrate_trading_state()
                        except Exception as exc:
                            logger.debug("post-connect calibrate: %s", exc)
                        self._ensure_instrument_margin_hooks()
                        _fire_position_refresh_burst()
                        _schedule_position_query_retries(self)
                        _fire_ctp_connected_callback(mode)
                        return
                finally:
                    self._ee.unregister(EVENT_LOG, _on_log)

                # Login did not succeed: tear down, back off, record and raise.
                self._close_gateway()
                self._apply_login_failure_cooldown(ctp_logs)
                hint = _format_ctp_failure(ctp_logs, td_address=setting.get("交易服务器", ""))
                self._last_error = hint
                _persist_last_error(hint)
                logger.warning("CTP 连接失败 [%s]: %s | logs=%s", mode, hint, ctp_logs[-5:])
                raise RuntimeError(hint)
    finally:
        self._connect_in_progress = False
|
||
|
||
def start_connect_async(
    self, mode: str, *, force: bool = False, scheduled: bool = False,
) -> dict[str, Any]:
    """Start ``connect()`` on a background daemon thread so HTTP requests do not block.

    Returns a status dict and does not start when: connecting is disabled
    (``disabled``/``error``), already connected (unless *force*), a connect
    is already running, or a login cooldown is active (``cooldown``, unless
    *force*). Otherwise it returns ``started=True``. Failures in the
    background connect are logged as warnings.
    """
    from ctp_settings import CTP_DISABLED_HINT

    if not _ctp_connect_permitted(scheduled=scheduled):
        self._last_error = CTP_DISABLED_HINT
        _persist_last_error(CTP_DISABLED_HINT)
        return dict(
            started=False, connecting=False, connected=False,
            disabled=True, error=CTP_DISABLED_HINT,
        )
    already_up = self._connected_mode == mode and self.ping() and not force
    if already_up:
        return dict(started=False, connecting=False, connected=True)
    if self._connect_in_progress:
        return dict(started=False, connecting=True, connected=False)
    if self._is_login_cooldown_active() and not force:
        self._last_error = self._login_cooldown_message()
        return dict(started=False, connecting=False, connected=False, cooldown=True)

    def _background_connect() -> None:
        try:
            self.connect(mode, force=force, scheduled=scheduled)
        except Exception as exc:
            logger.warning("CTP 后台连接失败: %s", exc)

    worker = threading.Thread(target=_background_connect, daemon=True, name="ctp-connect-async")
    worker.start()
    return dict(started=True, connecting=True, connected=False)
|
||
|
||
def ensure_connected(self, mode: str) -> None:
    """Call the blocking ``connect(mode)`` unless already connected in *mode* with a live ping."""
    if not (self._connected_mode == mode and self.ping()):
        self.connect(mode)
|
||
|
||
def require_connected(self, mode: str) -> None:
    """Pre-order check: raise RuntimeError unless connected in *mode* with the trading channel logged in.

    This never starts a blocking connect itself.
    """
    if self._connect_in_progress:
        raise RuntimeError("CTP 连接中,请稍候再下单")
    linked = self._connected_mode == mode and self.ping()
    if not linked:
        raise RuntimeError("请先连接 CTP(持仓监控页点击「连接 CTP」)")
    if not self._td_logged_in():
        raise RuntimeError("CTP 交易通道未登录,请重连 CTP 后再下单")
|
||
|
||
def _td_logged_in(self) -> bool:
    """Return the CTP gateway td_api ``login_status`` as a bool; any lookup error gives False."""
    try:
        td_api = self._engine.get_gateway(GATEWAY_NAME).td_api
        return bool(getattr(td_api, "login_status", False))
    except Exception:
        return False
|
||
|
||
def _find_position(self, sym: str, ex_name: str, hold_direction: str) -> Any:
    """Return the first vnpy position for *sym*/*ex_name* on side *hold_direction* with volume > 0.

    Symbol matching ignores case, and the exchange is compared upper-cased.
    Returns None when there is no engine, no match, or the scan raises
    (logged at debug level).
    """
    if not self._engine:
        return None
    wanted_sym = sym.lower()
    wanted_ex = ex_name.upper()
    wants_long = hold_direction == "long"
    try:
        for candidate in self._engine.get_all_positions():
            cand_sym = (getattr(candidate, "symbol", "") or "").lower()
            raw_ex = getattr(candidate, "exchange", None)
            cand_ex = str(raw_ex.value if hasattr(raw_ex, "value") else raw_ex or "").upper()
            if cand_sym != wanted_sym or cand_ex != wanted_ex:
                continue
            if int(getattr(candidate, "volume", 0) or 0) <= 0:
                continue
            if _is_long_direction(getattr(candidate, "direction", None)) == wants_long:
                return candidate
    except Exception as exc:
        logger.debug("find position: %s", exc)
    return None
|
||
|
||
def _resolve_close_offset(self, sym: str, ex_name: str, hold_direction: str, lots: int) -> Any:
    """Choose the vnpy Offset for closing *lots* of the *hold_direction* position.

    Per this code's convention, SHFE/INE/CZCE/CFFEX get today/yesterday-specific
    offsets, and every other exchange uses plain CLOSE.

    With a live vnpy position, CLOSETODAY is used if today's volume covers
    *lots*, else CLOSEYESTERDAY. Without one, the first matching snapshot
    row decides, and as a last resort SHFE/INE/CZCE get CLOSETODAY while
    CFFEX gets CLOSE.
    """
    from vnpy.trader.constant import Offset

    ex_u = (ex_name or "").upper()
    if ex_u not in ("CZCE", "CFFEX", "SHFE", "INE"):
        return Offset.CLOSE

    pos = self._find_position(sym, ex_u, hold_direction)
    if pos:
        total = int(getattr(pos, "volume", 0) or 0)
        yesterday = int(getattr(pos, "yd_volume", 0) or 0)
        if max(0, total - yesterday) >= lots:
            return Offset.CLOSETODAY
        return Offset.CLOSEYESTERDAY

    # No live vnpy position object: fall back to the collected snapshot rows.
    for row in self._collect_positions():
        if (row.get("symbol") or "").lower() != sym.lower():
            continue
        if (row.get("direction") or "long") != hold_direction:
            continue
        today = int(row.get("td_volume") or 0)
        yesterday = int(row.get("yd_volume") or 0)
        if today >= lots:
            return Offset.CLOSETODAY
        if yesterday >= lots:
            return Offset.CLOSEYESTERDAY
        if today + yesterday >= lots:
            return Offset.CLOSETODAY
        break
    return Offset.CLOSETODAY if ex_u in ("SHFE", "INE", "CZCE") else Offset.CLOSE
|
||
|
||
def _aggressive_limit_price(
    self,
    ths_code: str,
    sym: str,
    ex_name: str,
    direction: Any,
    tick: float,
    fallback: float,
) -> float:
    """Price a marketable limit order about three ticks through the latest price.

    The symbol is subscribed, and its latest tick price is used when
    available, otherwise *fallback*. Buys (Direction.LONG) add
    ``max(tick, 3 * tick)``. Sells subtract it, floored at one *tick*.
    The result is snapped to *tick* with ``round_to_tick``.
    *sym* and *ex_name* are not used here.
    """
    from vnpy.trader.constant import Direction

    self.subscribe_symbol(ths_code)
    quote = self.get_tick_detail(ths_code, mode=self._connected_mode or "")
    base = float(quote["price"]) if quote.get("price") else fallback
    slippage = max(tick, tick * 3)
    if direction == Direction.LONG:
        target = base + slippage
    else:
        target = max(tick, base - slippage)
    return round_to_tick(target, tick)
|
||
|
||
def ping(self) -> bool:
    """Report whether the CTP session still looks alive.

    The session counts as alive if the trading API is logged in, or if the
    engine holds at least one account. Otherwise the connected mode is cleared
    and False is returned. Returns False immediately when there is no engine or
    no connected mode.
    """
    if not (self._engine and self._connected_mode):
        return False
    if self._td_logged_in():
        return True
    try:
        has_accounts = bool(self._engine.get_all_accounts())
    except Exception as exc:
        logger.debug("CTP ping failed: %s", exc)
        has_accounts = False
    if has_accounts:
        return True
    self._connected_mode = None
    return False
|
||
|
||
def mark_disconnected(self) -> None:
    """Forget the connected mode so later calls treat CTP as disconnected."""
    self._connected_mode = None
|
||
|
||
def reconnect_after_settings_saved(self, mode: str) -> dict[str, Any]:
    """Close the current gateway and reconnect using the newly saved settings.

    Both the in-memory and the persisted last error are cleared first. If CTP
    auto-connect is disabled, no connection is started and a status dict with
    ``disabled=True`` is returned. Otherwise a forced async connect is started
    and its result is returned.
    """
    from ctp_settings import is_ctp_auto_connect_enabled

    self._close_gateway()
    self._last_error = ""
    _persist_last_error("")
    if is_ctp_auto_connect_enabled():
        return self.start_connect_async(mode, force=True)
    return {"started": False, "connecting": False, "connected": False, "disabled": True}
|
||
|
||
def _schedule_fee_sync(self, mode: str) -> None:
    """Start a daemon thread that runs the once-per-day CTP fee sync check.

    The thread sleeps 45 s, then calls ``try_daily_ctp_fee_sync`` with
    ``force=False`` (not a full resync). Settings are read and written through
    ``fee_specs``. Any failure is logged at debug level only.
    """

    def _read_setting(key: str, default: str = "") -> str:
        from fee_specs import get_setting

        return get_setting(key, default)

    def _write_setting(key: str, val: str) -> None:
        from fee_specs import set_setting

        set_setting(key, val)

    def _worker() -> None:
        time.sleep(45)
        try:
            from ctp_fee_worker import try_daily_ctp_fee_sync

            try_daily_ctp_fee_sync(
                mode,
                get_setting=_read_setting,
                set_setting=_write_setting,
                force=False,
            )
        except Exception as exc:
            logger.debug("CTP 手续费连接后检查: %s", exc)

    worker = threading.Thread(target=_worker, daemon=True, name="ctp-fee-sync-check")
    worker.start()
|
||
|
||
def _ensure_commission_callback(self) -> None:
    """Install the commission-rate response callback on the CTP trading API once.

    Each non-empty response row is appended to ``_commission_lists[reqid]``. On
    the last row, the waiter event registered in ``_commission_waiters`` for that
    request id is set. Error responses are only logged at debug level.

    NOTE(review): unlike the margin hooks, nothing here detects a rebuilt
    ``td_api``; confirm ``_commission_hooked`` is reset when the gateway is
    recreated.
    """
    if self._commission_hooked or not self._engine:
        return
    try:
        td_api = self._engine.get_gateway(GATEWAY_NAME).td_api
    except Exception:
        return
    owner = self

    def _on_commission_rate(data: dict, error: dict, reqid: int, last: bool) -> None:
        if error and int(error.get("ErrorID") or 0) != 0:
            logger.debug(
                "CTP commission error reqid=%s: %s",
                reqid,
                error.get("ErrorMsg") or error,
            )
        if data and data.get("InstrumentID"):
            owner._commission_lists.setdefault(reqid, []).append(dict(data))
        waiter = owner._commission_waiters.get(reqid)
        if last and waiter:
            waiter.set()

    td_api.onRspQryInstrumentCommissionRate = _on_commission_rate  # type: ignore[method-assign]
    self._commission_hooked = True
|
||
|
||
def _query_commission(
    self,
    *,
    mode: str,
    instrument_id: str = "",
    exchange_id: str = "",
    timeout: float = 8,
) -> list[dict]:
    """Send a CTP commission-rate query and wait for its response rows.

    Args:
        mode: Connection mode the caller expects; returns [] if the bridge is
            connected in a different mode (or has no engine).
        instrument_id: Instrument to query; empty queries every instrument.
        exchange_id: Exchange filter; may be empty.
        timeout: Seconds to wait for the final response row.

    Returns:
        The raw response dicts collected for this request. The list may be
        partial if the wait timed out. Returns [] when the trading API is
        unavailable or not logged in, or when the request could not be sent.
    """
    if self._connected_mode != mode or not self._engine:
        return []
    try:
        td = self._engine.get_gateway(GATEWAY_NAME).td_api
    except Exception as exc:
        logger.debug("commission query init: %s", exc)
        return []
    if not getattr(td, "login_status", False):
        return []
    if not hasattr(td, "reqQryInstrumentCommissionRate"):
        return []
    self._ensure_commission_callback()
    ev = threading.Event()
    # Serialize with the other trading-API requests (margin-rate and position
    # queries send under the same lock, because concurrent vnctptd queries can
    # crash). The request id is allocated under the lock too, so two concurrent
    # queries cannot reuse the same id.
    with _ctp_td_lock:
        reqid = int(getattr(td, "reqid", 0)) + 1
        td.reqid = reqid
        self._commission_waiters[reqid] = ev
        req = {
            "BrokerID": td.brokerid,
            "InvestorID": td.userid,
            "InstrumentID": instrument_id or "",
            "ExchangeID": exchange_id or "",
        }
        ret = td.reqQryInstrumentCommissionRate(req, reqid)
    if ret != 0:
        self._commission_waiters.pop(reqid, None)
        self._commission_lists.pop(reqid, None)
        return []
    # Wait outside the lock so the response callback and other requests proceed.
    ev.wait(timeout=timeout)
    self._commission_waiters.pop(reqid, None)
    return self._commission_lists.pop(reqid, [])
|
||
|
||
def query_instrument_commission(self, ths_code: str, *, mode: str) -> dict:
    """Return the CTP commission-rate row for one contract (requires a connection).

    The THS code is converted to a vnpy symbol and exchange first; an
    unconvertible code yields {}. When several rows come back, the last one is
    returned. Returns {} when nothing was received.
    """
    try:
        instrument, exchange = ths_to_vnpy_symbol(ths_code)
    except Exception:
        return {}
    rows = self._query_commission(
        mode=mode,
        instrument_id=instrument,
        exchange_id=exchange,
    )
    if not rows:
        return {}
    return rows[-1]
|
||
|
||
def query_all_commissions(self, *, mode: str) -> list[dict]:
    """Query commission rates for every instrument (blank InstrumentID), waiting up to 45 s."""
    return self._query_commission(timeout=45, mode=mode)
|
||
|
||
@staticmethod
def _parse_margin_ratio_row(data: dict) -> dict[str, float]:
    """Extract long/short margin ratios from a CTP instrument or margin-rate row.

    For each side, the ``*ByMoney`` field is used if truthy, otherwise the plain
    ``*MarginRatio`` field. A side with neither becomes 0.0.
    """

    def _first_truthy(*keys: str) -> float:
        for key in keys:
            value = data.get(key)
            if value:
                return float(value)
        return float(0)

    long_ratio = _first_truthy("LongMarginRatioByMoney", "LongMarginRatio")
    short_ratio = _first_truthy("ShortMarginRatioByMoney", "ShortMarginRatio")
    return {"long": long_ratio, "short": short_ratio}
|
||
|
||
def _cache_margin_ratio(self, sym: str, data: dict) -> None:
    """Store an instrument's margin ratios under its stripped, lower-cased symbol.

    Rows whose long and short ratios are both non-positive are ignored, and so
    are blank symbols.
    """
    parsed = self._parse_margin_ratio_row(data)
    if parsed["long"] <= 0 and parsed["short"] <= 0:
        return
    cache_key = (sym or "").strip().lower()
    if cache_key:
        self._instrument_margin_ratios[cache_key] = parsed
|
||
|
||
def _ensure_instrument_margin_hooks(self) -> None:
    """Install bridge wrappers around CTP trading-API query callbacks.

    Per the original note, this is hooked before login and must be re-hooked
    after ``td_api`` is rebuilt. A rebuilt ``td_api`` is detected by a changed
    ``id()``, which resets both hook flags so the wrappers are installed on the
    new object.

    - ``onRspQryInstrument``: caches margin ratios from each instrument row; on
      the last row it calls ``_schedule_after_instruments_ready``. It then
      delegates to the original handler.
    - ``onRspQryInvestorPosition``: delegates to the original handler first. On
      the last row, at most once per 30 s, it schedules a trading-state
      calibration followed by the position-refresh callback.
    - ``onRspQryInstrumentMarginRate``: collects response rows into
      ``_margin_rate_lists[reqid]`` and sets the matching waiter event on the
      last row.
    """
    if not self._engine:
        return
    try:
        gw = self._engine.get_gateway(GATEWAY_NAME)
        td = gw.td_api
    except Exception:
        return
    # Closures below must reach the bridge even though they replace td_api methods.
    bridge = self
    td_id = id(td)
    # A different td_api object means earlier wrappers were lost with the old one.
    if td_id != self._hooks_td_api_id:
        self._hooks_td_api_id = td_id
        self._instrument_hooked = False
        self._margin_rate_hooked = False

    if not self._instrument_hooked:
        orig_inst = td.onRspQryInstrument

        def on_instrument(data: dict, error: dict, reqid: int, last: bool) -> None:
            try:
                if data and data.get("InstrumentID"):
                    bridge._cache_margin_ratio(str(data["InstrumentID"]), data)
            except Exception as exc:
                logger.debug("instrument margin cache: %s", exc)
            if last:
                _schedule_after_instruments_ready(bridge)
            # Keep the original vnpy handling of instrument rows.
            return orig_inst(data, error, reqid, last)

        td.onRspQryInstrument = on_instrument  # type: ignore[method-assign]

        orig_pos = td.onRspQryInvestorPosition

        def on_rsp_position(
            data: dict, error: dict, reqid: int, last: bool,
        ) -> None:
            # Let vnpy process the row first so its position cache is updated.
            ret = orig_pos(data, error, reqid, last)
            if last:
                now = time.monotonic()
                # Throttle: skip the follow-up if one was scheduled in the last 30 s.
                if now - bridge._last_position_rsp_ts < 30.0:
                    return ret
                bridge._last_position_rsp_ts = now

                def _after_position_query() -> None:
                    try:
                        # NOTE(review): the extra delay presumably lets vnpy finish
                        # publishing the position events — confirm.
                        time.sleep(1.5)
                        with _ctp_td_lock:
                            bridge.calibrate_trading_state()
                        _fire_position_refresh_callback()
                    except Exception as exc:
                        logger.debug("position rsp refresh: %s", exc)

                # Run off the CTP callback thread so this callback returns promptly.
                threading.Timer(0.2, _after_position_query).start()
            return ret

        td.onRspQryInvestorPosition = on_rsp_position  # type: ignore[method-assign]
        self._instrument_hooked = True

    if self._margin_rate_hooked:
        return

    def on_margin_rate(data: dict, error: dict, reqid: int, last: bool) -> None:
        if error and int(error.get("ErrorID") or 0) != 0:
            logger.debug(
                "CTP margin rate error reqid=%s: %s",
                reqid,
                error.get("ErrorMsg") or error,
            )
        if data and data.get("InstrumentID"):
            bridge._margin_rate_lists.setdefault(reqid, []).append(dict(data))
        ev = bridge._margin_rate_waiters.get(reqid)
        # Wake the waiting _query_instrument_margin_rate call on the final row.
        if last and ev:
            ev.set()

    td.onRspQryInstrumentMarginRate = on_margin_rate  # type: ignore[method-assign]
    self._margin_rate_hooked = True
|
||
|
||
def _query_instrument_margin_rate(
    self,
    *,
    mode: str,
    instrument_id: str,
    exchange_id: str,
    timeout: float = 6,
) -> Optional[dict[str, float]]:
    """Return {"long", "short"} margin ratios for one instrument, querying CTP if needed.

    A cached entry with a positive ratio is returned without a query.
    Otherwise a speculation-hedge margin-rate request is sent under
    ``_ctp_td_lock``, and the call waits up to ``timeout`` seconds for the
    reply. Positive results are cached. Returns None when the mode does not
    match, the symbol is blank, the API is unavailable, sending fails, or no
    usable ratio arrives.
    """
    if self._connected_mode != mode or not self._engine:
        return None
    instrument = (instrument_id or "").strip()
    if not instrument:
        return None
    known = self._instrument_margin_ratios.get(instrument.lower())
    if known and (known.get("long", 0) > 0 or known.get("short", 0) > 0):
        return known
    try:
        td_api = self._engine.get_gateway(GATEWAY_NAME).td_api
    except Exception as exc:
        logger.debug("margin rate query init: %s", exc)
        return None
    if not getattr(td_api, "login_status", False):
        return None
    if not hasattr(td_api, "reqQryInstrumentMarginRate"):
        return None
    self._ensure_instrument_margin_hooks()
    request_id = int(getattr(td_api, "reqid", 0)) + 1
    td_api.reqid = request_id
    finished = threading.Event()
    self._margin_rate_waiters[request_id] = finished
    payload = {
        "BrokerID": td_api.brokerid,
        "InvestorID": td_api.userid,
        "InstrumentID": instrument,
        "ExchangeID": exchange_id or "",
        "InvestorRange": "1",
        "HedgeFlag": "1",
    }
    with _ctp_td_lock:
        status = td_api.reqQryInstrumentMarginRate(payload, request_id)
    if status != 0:
        self._margin_rate_waiters.pop(request_id, None)
        return None
    finished.wait(timeout=timeout)
    self._margin_rate_waiters.pop(request_id, None)
    replies = self._margin_rate_lists.pop(request_id, [])
    if not replies:
        return None
    newest = replies[-1]
    parsed = self._parse_margin_ratio_row(newest)
    if parsed["long"] > 0 or parsed["short"] > 0:
        self._cache_margin_ratio(instrument, newest)
        return parsed
    return None
|
||
|
||
def _lookup_margin_ratios(
    self,
    sym: str,
    ex_name: str,
    *,
    mode: Optional[str] = None,
) -> Optional[dict[str, float]]:
    """Return cached margin ratios for ``sym``, or query CTP when connected in ``mode``.

    Returns None for a blank symbol. Also returns None on a cache miss when
    ``mode`` is empty or differs from the connected mode.
    """
    cache_key = (sym or "").strip().lower()
    if not cache_key:
        return None
    known = self._instrument_margin_ratios.get(cache_key)
    if known and (known.get("long", 0) > 0 or known.get("short", 0) > 0):
        return known
    can_query = bool(mode) and self._connected_mode == mode
    if not can_query:
        return None
    return self._query_instrument_margin_rate(
        mode=mode,
        instrument_id=sym,
        exchange_id=ex_name,
    )
|
||
|
||
def _lookup_margin_per_lot(self, sym: str, direction: str) -> float:
    """Return the cached per-lot margin for (sym, direction), or 0.0 when unknown."""
    slot = self._position_margin_key(sym, direction)
    stored = self._margin_per_lot.get(slot, 0)
    return float(stored or 0)
|
||
|
||
def _margin_from_ratios(
    self,
    price: float,
    mult: float,
    ratios: dict[str, float],
    *,
    direction: str,
) -> Optional[float]:
    """Compute one lot's margin, ``price * mult * ratio``, rounded to 2 decimals.

    ``direction`` "short" uses the short ratio, and any other value uses the
    long ratio. If the chosen ratio is not positive, the larger of the two is
    used instead. "max" returns the larger of the per-side margins.

    Returns None when price/mult are not positive or no usable ratio exists.
    """
    long_ratio = float(ratios.get("long") or 0)
    short_ratio = float(ratios.get("short") or 0)
    side = (direction or "long").strip().lower()
    if mult <= 0 or price <= 0:
        return None
    notional = float(price) * mult
    if side == "max":
        margins = [round(notional * r, 2) for r in (long_ratio, short_ratio) if r > 0]
        return max(margins) if margins else None
    preferred = short_ratio if side == "short" else long_ratio
    chosen = preferred if preferred > 0 else max(long_ratio, short_ratio)
    if chosen <= 0:
        return None
    return round(notional * chosen, 2)
|
||
|
||
def _tick_key(self, symbol: str, ex_name: str) -> str:
    """Build the ``symbol:EXCHANGE`` cache key (lower-case symbol, upper-case exchange)."""
    return ":".join((symbol.lower(), ex_name.upper()))
|
||
|
||
def _price_from_tick(self, tick: Any) -> Optional[float]:
    """Return the first positive price among last, bid1, ask1 and pre-close.

    Values that cannot be converted to float are skipped. Returns None when no
    field yields a positive price.
    """
    for field in ("last_price", "bid_price_1", "ask_price_1", "pre_close"):
        try:
            candidate = float(getattr(tick, field, 0) or 0)
        except (TypeError, ValueError):
            continue
        if candidate > 0:
            return candidate
    return None
|
||
|
||
def _lookup_tick(self, symbol: str, ex_name: str) -> Optional[float]:
    """Return a usable price from the engine's cached ticks for symbol/exchange.

    Matching ignores case. Ticks that match but have no positive price are
    skipped. Returns None when there is no engine, nothing usable is found, or
    iteration fails (the failure is logged at debug level).
    """
    if not self._engine:
        return None
    wanted = (symbol.lower(), ex_name.upper())
    try:
        for tick in self._engine.get_all_ticks():
            raw_ex = getattr(tick, "exchange", None)
            tick_ex = str(raw_ex.value if hasattr(raw_ex, "value") else raw_ex or "").upper()
            tick_sym = (getattr(tick, "symbol", "") or "").lower()
            if (tick_sym, tick_ex) != wanted:
                continue
            found = self._price_from_tick(tick)
            if found:
                return found
    except Exception as exc:
        logger.debug("lookup tick: %s", exc)
    return None
|
||
|
||
def _bar_to_dict(self, bar: Any) -> dict:
    """Convert a vnpy bar into ``{"d","o","h","l","c","v"}``; missing numbers become 0.0.

    ``d`` is formatted ``%Y-%m-%d %H:%M:%S``, or "" when the bar has no datetime.
    """
    stamp = getattr(bar, "datetime", None)

    def _num(attr: str) -> float:
        return float(getattr(bar, attr, 0) or 0)

    return {
        "d": stamp.strftime("%Y-%m-%d %H:%M:%S") if stamp else "",
        "o": _num("open_price"),
        "h": _num("high_price"),
        "l": _num("low_price"),
        "c": _num("close_price"),
        "v": _num("volume"),
    }
|
||
|
||
def _ensure_bar_generator(self, sym: str, ex_name: str) -> None:
    """Create the 1-minute bar buffer and vnpy BarGenerator for a symbol, once.

    Completed bars that have a datetime are appended to a 4000-entry deque in
    ``_bars_1m``. If vnpy's BarGenerator cannot be imported, only the buffer is
    (re)created and a debug message is logged.
    """
    slot = self._tick_key(sym, ex_name)
    if slot in self._bar_generators:
        return
    self._bars_1m[slot] = deque(maxlen=4000)

    def _store_completed(bar: Any) -> None:
        as_row = self._bar_to_dict(bar)
        if as_row.get("d"):
            self._bars_1m[slot].append(as_row)

    try:
        from vnpy.trader.utility import BarGenerator

        self._bar_generators[slot] = BarGenerator(on_bar=_store_completed)
    except ImportError:
        logger.debug("BarGenerator unavailable")
|
||
|
||
def _find_tick(self, symbol: str, ex_name: str) -> Any:
    """Return the first cached engine tick whose symbol/exchange match (case-insensitive).

    Returns None when there is no engine, no match, or iteration fails (the
    failure is logged at debug level).
    """
    if not self._engine:
        return None
    target_sym = symbol.lower()
    target_ex = ex_name.upper()
    try:
        for candidate in self._engine.get_all_ticks():
            raw_ex = getattr(candidate, "exchange", None)
            cand_ex = str(raw_ex.value if hasattr(raw_ex, "value") else raw_ex or "").upper()
            cand_sym = (getattr(candidate, "symbol", "") or "").lower()
            if cand_sym == target_sym and cand_ex == target_ex:
                return candidate
    except Exception as exc:
        logger.debug("find tick: %s", exc)
    return None
|
||
|
||
def _tick_to_bar(self, symbol: str, ex_name: str) -> Optional[dict]:
    """Build a single pseudo-bar from the latest cached tick, or None if unavailable.

    Open, high and low fall back to the tick price when missing. If the tick
    has no datetime, the current Asia/Shanghai time is used.
    """
    tick = self._find_tick(symbol, ex_name)
    if not tick:
        return None
    last = self._price_from_tick(tick)
    if not last or last <= 0:
        return None
    stamp = getattr(tick, "datetime", None)
    when = stamp.strftime("%Y-%m-%d %H:%M:%S") if stamp else ""
    if not when:
        from datetime import datetime
        from zoneinfo import ZoneInfo

        when = datetime.now(ZoneInfo("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
    return {
        "d": when,
        "o": float(getattr(tick, "open_price", 0) or last),
        "h": float(getattr(tick, "high_price", 0) or last),
        "l": float(getattr(tick, "low_price", 0) or last),
        "c": last,
        "v": float(getattr(tick, "volume", 0) or 0),
    }
|
||
|
||
def _on_tick(self, tick: Any) -> None:
    """Handle one vnpy tick event.

    When the tick carries a positive price, this updates the shared trading
    state's tick price (best effort), invokes the SL/TP tick callback if one is
    registered, and triggers the debounced quote callback. The tick is then
    fed to the symbol's BarGenerator, if one exists.
    """
    symbol = (getattr(tick, "symbol", "") or "").lower()
    raw_ex = getattr(tick, "exchange", None)
    exchange = str(raw_ex.value if hasattr(raw_ex, "value") else raw_ex or "").upper()
    price = self._price_from_tick(tick)
    if price and price > 0:
        try:
            from ctp_trading_state import trading_state

            trading_state.set_tick_price(exchange, symbol, price)
        except Exception:
            pass
        sl_tp_hook = _tick_sl_tp_callback
        if sl_tp_hook:
            try:
                sl_tp_hook(exchange, symbol, float(price))
            except Exception as exc:
                logger.debug("tick sl/tp callback: %s", exc)
        _fire_tick_quote_callback_debounced()
    generator = self._bar_generators.get(self._tick_key(symbol, exchange))
    if not generator:
        return
    try:
        generator.update_tick(tick)
    except Exception as exc:
        logger.debug("bar gen tick: %s", exc)
|
||
|
||
def _ensure_tick_handler(self) -> None:
    """Register ``_on_tick`` for vnpy EVENT_TICK on the event engine, only once.

    Does nothing if already hooked, if there is no event engine, or if vnpy's
    event module cannot be imported.
    """
    if self._tick_hooked or not self._ee:
        return
    try:
        from vnpy.trader.event import EVENT_TICK
    except ImportError:
        return

    def _dispatch(event: Any) -> None:
        self._on_tick(event.data)

    self._ee.register(EVENT_TICK, _dispatch)
    self._tick_hooked = True
|
||
|
||
def get_kline_bars_1m(self, ths_code: str, *, mode: str) -> list[dict]:
    """Subscribe to a contract and return its 1-minute bars, including the bar in progress.

    The call waits up to about 2.4 s (12 x 0.2 s) for a completed bar or a
    tick. The forming bar replaces the last completed bar when their
    timestamps match; otherwise it is appended. With no bars at all, a single
    pseudo-bar built from the latest tick is returned. Returns [] when not
    connected in ``mode`` or when the code cannot be converted.
    """
    if self._connected_mode != mode or not self._engine:
        return []
    try:
        sym, ex_name = ths_to_vnpy_symbol(ths_code)
    except Exception:
        return []
    slot = self._tick_key(sym, ex_name)
    self._ensure_bar_generator(sym, ex_name)
    self.subscribe_symbol(ths_code)
    for _attempt in range(12):
        if self._bars_1m.get(slot) or self._lookup_tick(sym, ex_name):
            break
        time.sleep(0.2)
    bars = list(self._bars_1m.get(slot, []))
    generator = self._bar_generators.get(slot)
    in_progress = getattr(generator, "bar", None) if generator else None
    if in_progress:
        forming = self._bar_to_dict(in_progress)
        if forming.get("d"):
            if bars and bars[-1]["d"] == forming["d"]:
                bars[-1] = forming
            else:
                bars.append(forming)
    if not bars:
        from_tick = self._tick_to_bar(sym, ex_name)
        if from_tick:
            bars = [from_tick]
    return bars
|
||
|
||
def get_tick_detail(self, ths_code: str, *, mode: str) -> dict[str, Any]:
    """Return ``{"price", "pre_close"}`` for a contract, or {} when unavailable.

    After subscribing, this polls for a cached tick up to 8 times, 0.2 s apart.
    ``pre_close`` is None unless it is positive. Returns {} when not connected
    in ``mode``, when the code cannot be converted, or when no tick arrives.
    """
    if self._connected_mode != mode or not self._engine:
        return {}
    try:
        sym, ex_name = ths_to_vnpy_symbol(ths_code)
    except Exception:
        return {}
    self.subscribe_symbol(ths_code)
    remaining = 8
    while remaining:
        remaining -= 1
        latest = self._find_tick(sym, ex_name)
        if latest:
            current = self._price_from_tick(latest)
            try:
                previous = float(getattr(latest, "pre_close", 0) or 0)
            except (TypeError, ValueError):
                previous = 0.0
            return {
                "price": current,
                "pre_close": previous if previous > 0 else None,
            }
        time.sleep(0.2)
    return {}
|
||
|
||
def subscribe_symbol(self, ths_code: str) -> None:
    """Subscribe to market data for a contract, if not already subscribed.

    The bar generator is set up even when the symbol is already subscribed.
    The tick handler is installed before the first subscription. Failures are
    logged at debug level and otherwise ignored.
    """
    if not self._engine or not self._connected_mode:
        return
    try:
        from vnpy.trader.object import SubscribeRequest

        sym, ex_name = ths_to_vnpy_symbol(ths_code)
        slot = self._tick_key(sym, ex_name)
        self._ensure_bar_generator(sym, ex_name)
        if slot in self._subscribed:
            return
        vn_exchange = to_vnpy_exchange(ex_name)
        self._ensure_tick_handler()
        self._engine.subscribe(SubscribeRequest(symbol=sym, exchange=vn_exchange), GATEWAY_NAME)
        self._subscribed.add(slot)
    except Exception as exc:
        logger.debug("CTP subscribe %s: %s", ths_code, exc)
|
||
|
||
def get_tick_price(self, ths_code: str, *, mode: str) -> Optional[float]:
    """Return the latest price for a contract, subscribing and polling if needed.

    A price already in the engine's tick cache is returned immediately.
    Otherwise the symbol is subscribed and the cache is re-checked up to 8
    times, 0.2 s apart. Returns None when not connected in ``mode``, when the
    code cannot be converted, or when no price arrives.
    """
    if self._connected_mode != mode or not self._engine:
        return None
    try:
        sym, ex_name = ths_to_vnpy_symbol(ths_code)
    except Exception:
        return None
    cached = self._lookup_tick(sym, ex_name)
    if cached:
        return cached
    self.subscribe_symbol(ths_code)
    polls_left = 8
    while polls_left > 0:
        polls_left -= 1
        time.sleep(0.2)
        fresh = self._lookup_tick(sym, ex_name)
        if fresh:
            return fresh
    return None
|
||
|
||
def get_account(self) -> dict[str, Any]:
    """Summarise the engine's first account (balance, available, frozen, accountid).

    Returns {} when there is no engine or no account.
    """
    engine = self._engine
    if not engine:
        return {}
    accounts = engine.get_all_accounts()
    if not accounts:
        return {}
    first = accounts[0]
    summary: dict[str, Any] = {
        field: float(getattr(first, field, 0) or 0)
        for field in ("balance", "available", "frozen")
    }
    summary["accountid"] = getattr(first, "accountid", "")
    return summary
|
||
|
||
def _position_margin_key(self, sym: str, direction: str) -> str:
    """Build the ``symbol:direction`` key; the direction defaults to "long"."""
    sym_part = (sym or "").lower()
    side_part = (direction or "long").strip().lower()
    return sym_part + ":" + side_part
|
||
|
||
def _lookup_position_open_time(self, sym: str, direction: str) -> str:
    """Return the cached open time for (sym, direction), stripped; "" when unknown."""
    slot = self._position_margin_key(sym, direction)
    stamp = self._position_open_times.get(slot) or ""
    return stamp.strip()
|
||
|
||
@staticmethod
def _parse_ctp_open_datetime(date_raw: str, time_raw: str = "") -> str:
    """Convert CTP ``OpenDate`` + ``OpenTime`` to ``YYYY-MM-DD[ HH:MM[:SS]]``.

    The date must start with 8 digits, otherwise "" is returned. Colons in the
    time are ignored. Six or more leading digits give seconds precision; four or
    five all-digit characters give minutes. Any other time is dropped.
    """
    raw_date = (date_raw or "").strip()
    if not (len(raw_date) >= 8 and raw_date[:8].isdigit()):
        return ""
    day = "-".join((raw_date[:4], raw_date[4:6], raw_date[6:8]))
    clock = (time_raw or "").strip().replace(":", "")
    if len(clock) >= 6 and clock[:6].isdigit():
        return f"{day} {clock[:2]}:{clock[2:4]}:{clock[4:6]}"
    if len(clock) >= 4 and clock.isdigit():
        return f"{day} {clock[:2]}:{clock[2:4]}"
    return day
|
||
|
||
@staticmethod
def _parse_ctp_open_date(raw: str) -> str:
    """Convert a CTP ``OpenDate`` (``YYYYMMDD...``) to ``YYYY-MM-DD``; "" if malformed.

    Declared static so it can be called as ``CtpBridge._parse_ctp_open_date(raw)``
    and also as ``self._parse_ctp_open_date(raw)``. Without the decorator, an
    instance call binds ``self`` to ``raw`` and raises TypeError.
    """
    return CtpBridge._parse_ctp_open_datetime(raw, "")
|
||
|
||
def _install_position_margin_hook(self) -> None:
    """Deliberately a no-op.

    Monkey-patching the CTP position callback caused vnctptd segfaults under
    concurrency, so this hook is disabled.
    """
    return None
|
||
|
||
def _lookup_position_margin(self, sym: str, direction: str) -> float:
    """Return the cached position margin for (sym, direction), or 0.0 when unknown."""
    slot = self._position_margin_key(sym, direction)
    stored = self._position_margins.get(slot, 0)
    return float(stored or 0)
|
||
|
||
@staticmethod
def _vnpy_sym_to_ths(sym: str, ex_name: str) -> str:
    """Convert a vnpy futures symbol to its THS-style code.

    CZCE codes are upper-case product letters with a 3-digit month (a 4-digit
    month is cut to its last 3 digits). Other exchanges use lower-case product
    letters with the digits unchanged. Symbols that are not ``letters+digits``
    are returned stripped but otherwise unchanged. Blank input gives "".
    """
    text = (sym or "").strip()
    if not text:
        return ""
    exchange = (ex_name or "").upper()
    parts = re.match(r"^([A-Za-z]+)(\d+)$", text)
    if parts is None:
        return text
    product, month = parts.groups()
    if exchange != "CZCE":
        return product.lower() + month
    return product.upper() + (month[-3:] if len(month) >= 4 else month)
|
||
|
||
def _get_contract_for_ths(self, ths_code: str) -> Any:
    """Find the CTP contract for a THS code, falling back to a same-product contract.

    An exact ``symbol.EXCHANGE`` lookup is tried first. If that misses, the
    candidates are futures contracts (``letters+digits`` symbols) on the same
    exchange whose product letters exactly equal the code's product letters,
    compared case-insensitively. Among those, the lexicographically smallest
    symbol is returned.

    Returns:
        The vnpy contract object, or None when nothing matches or a lookup
        fails (failures are logged at debug level).
    """
    if not self._engine:
        return None
    try:
        sym, ex_name = ths_to_vnpy_symbol(ths_code)
        exchange = to_vnpy_exchange(ex_name)
        vt_symbol = f"{sym}.{exchange.value}"
        contract = self._engine.get_contract(vt_symbol)
        if contract:
            return contract
        m = re.match(r"^([A-Za-z]+)", (ths_code or "").strip())
        if not m:
            return None
        # Accept the product letters from both the THS code and the vnpy
        # symbol; they normally agree and differ only in case.
        products = {m.group(1).lower()}
        sym_m = re.match(r"^([A-Za-z]+)", sym or "")
        if sym_m:
            products.add(sym_m.group(1).lower())
        ex_val = exchange.value
        get_all = getattr(self._engine, "get_all_contracts", None)
        pool = list(get_all()) if callable(get_all) else []
        if not pool:
            raw = getattr(self._engine, "contracts", None)
            if isinstance(raw, dict):
                pool = list(raw.values())
        candidates: list[Any] = []
        for c in pool:
            c_ex = getattr(c, "exchange", None)
            c_ex_val = str(c_ex.value if hasattr(c_ex, "value") else c_ex or "")
            if c_ex_val != ex_val:
                continue
            c_sym = str(getattr(c, "symbol", "") or "")
            # Require the whole product prefix to match. A plain startswith()
            # would let "p" match "pg"/"pp", "j" match "jd"/"jm" or "c" match
            # "cs" on DCE. The fullmatch also excludes option symbols such as
            # "m2505-C-3000".
            c_m = re.fullmatch(r"([A-Za-z]+)\d+", c_sym)
            if c_m and c_m.group(1).lower() in products:
                candidates.append(c)
        if not candidates:
            return None
        # min() returns the first smallest item, same as sorted(...)[0].
        return min(candidates, key=lambda c: str(getattr(c, "symbol", "") or ""))
    except Exception as exc:
        logger.debug("_get_contract_for_ths %s: %s", ths_code, exc)
        return None
|
||
|
||
def estimate_margin_one_lot(
    self,
    ths_code: str,
    price: float,
    *,
    direction: str = "long",
) -> Optional[float]:
    """Estimate the margin for one lot of ``ths_code`` at ``price``.

    Sources, in priority order:
    1. the cached per-lot margin from actual positions;
    2. CTP margin ratios, from the cache or a live query, times
       ``price * multiplier``.

    The multiplier comes from the CTP contract, falling back to the local
    contract spec. ``direction`` may be "long", "short" or "max" (the larger
    side). Returns None when there is no engine, the price is not positive,
    nothing is known, or an error occurs (logged at debug level).
    """
    if not self._engine or not price or price <= 0:
        return None
    try:
        sym, ex_name = ths_to_vnpy_symbol(ths_code)
        contract = self._get_contract_for_ths(ths_code)
        multiplier = float(getattr(contract, "size", 0) or 0) if contract else 0.0
        if multiplier <= 0:
            multiplier = float(get_contract_spec(ths_code).get("mult") or 0)
        side = (direction or "long").strip().lower()
        sides = ("long", "short") if side == "max" else (side,)
        known = [v for v in (self._lookup_margin_per_lot(sym, s) for s in sides) if v > 0]
        if known:
            return max(known)
        ratios = self._lookup_margin_ratios(sym, ex_name, mode=self._connected_mode)
        if not ratios:
            return None
        return self._margin_from_ratios(price, multiplier, ratios, direction=side)
    except Exception as exc:
        logger.debug("estimate_margin_one_lot %s: %s", ths_code, exc)
        return None
|
||
|
||
def estimate_position_margin(
    self,
    sym: str,
    ex_name: str,
    direction: str,
    lots: int,
    price: float,
    *,
    pos: Any = None,
) -> Optional[float]:
    """Estimate the margin occupied by a position, rounded to 2 decimals.

    Sources, in priority order: the ``margin`` / ``use_margin`` fields of
    ``pos``; the cached position margin; the per-lot estimate times ``lots``.
    Returns None for non-positive lots/price, or when no estimate is possible.
    """
    if lots <= 0 or price <= 0:
        return None
    if pos is not None:
        reported = float(getattr(pos, "margin", 0) or getattr(pos, "use_margin", 0) or 0)
        if reported > 0:
            return round(reported, 2)
    remembered = self._lookup_position_margin(sym, direction)
    if remembered > 0:
        return round(remembered, 2)
    ths = self._vnpy_sym_to_ths(sym, ex_name)
    if not ths:
        return None
    one_lot = self.estimate_margin_one_lot(ths, price, direction=direction)
    if not one_lot or one_lot <= 0:
        return None
    return round(one_lot * lots, 2)
|
||
|
||
def lookup_contract_spec(self, ths_code: str) -> Optional[dict]:
    """Read the multiplier, tick size and margin rate for a contract from CTP.

    The result always has ``mult``. It has ``tick_size`` when the tick is
    positive, and ``margin_rate`` (the larger of long/short) when a positive
    ratio is known. Queried ratios override the contract's own ratio fields.
    Returns None when there is no engine or contract, the multiplier is not
    positive, or an error occurs (logged at debug level).
    """
    if not self._engine:
        return None
    try:
        sym, ex_name = ths_to_vnpy_symbol(ths_code)
        contract = self._get_contract_for_ths(ths_code)
        if not contract:
            return None
        multiplier = float(getattr(contract, "size", 0) or 0)
        tick_size = float(
            getattr(contract, "pricetick", 0)
            or getattr(contract, "price_tick", 0)
            or 0
        )
        if multiplier <= 0:
            return None
        spec: dict[str, Any] = {"mult": multiplier}
        if tick_size > 0:
            spec["tick_size"] = tick_size
        long_ratio = float(getattr(contract, "long_margin_ratio", 0) or 0)
        short_ratio = float(getattr(contract, "short_margin_ratio", 0) or 0)
        contract_sym = str(getattr(contract, "symbol", "") or sym or "")
        if contract_sym and self._connected_mode:
            live = self._lookup_margin_ratios(
                contract_sym, ex_name, mode=self._connected_mode,
            )
            if live:
                long_ratio = float(live.get("long") or long_ratio)
                short_ratio = float(live.get("short") or short_ratio)
        if long_ratio > 0 or short_ratio > 0:
            spec["margin_rate"] = max(long_ratio, short_ratio)
        return spec
    except Exception as exc:
        logger.debug("lookup_contract_spec %s: %s", ths_code, exc)
        return None
|
||
|
||
def _collect_positions(self) -> list[dict[str, Any]]:
    """Convert the engine's non-empty vnpy positions into plain dict rows.

    Each row carries symbol, exchange, direction, lots, average price, pnl,
    frozen lots, estimated margin, open time (or None), and the
    yesterday/today volume split.
    """
    if not self._engine:
        return []
    rows: list[dict[str, Any]] = []
    for position in self._engine.get_all_positions():
        volume = int(getattr(position, "volume", 0) or 0)
        if volume <= 0:
            continue
        side = "long" if _is_long_direction(getattr(position, "direction", None)) else "short"
        symbol = getattr(position, "symbol", "") or ""
        raw_ex = getattr(position, "exchange", None)
        exchange = str(raw_ex.value if hasattr(raw_ex, "value") else raw_ex or "")
        avg_price = float(getattr(position, "price", 0) or 0)
        margin = self.estimate_position_margin(
            symbol, exchange, side, volume, avg_price, pos=position,
        )
        opened_at = self._lookup_position_open_time(symbol, side) or None
        yesterday = int(getattr(position, "yd_volume", 0) or 0)
        rows.append({
            "symbol": symbol,
            "exchange": exchange,
            "direction": side,
            "lots": volume,
            "avg_price": avg_price,
            "pnl": float(getattr(position, "pnl", 0) or 0),
            "frozen": int(getattr(position, "frozen", 0) or 0),
            "margin": margin,
            "open_time": opened_at,
            "yd_volume": yesterday,
            "td_volume": max(0, volume - yesterday),
        })
    return rows
|
||
|
||
def refresh_positions(self) -> None:
    """Deliberately a no-op: positions come from vnpy's in-memory cache.

    ``query_position`` is never called here, because concurrent vnctptd
    queries can segfault.
    """
    return None
|
||
|
||
def _has_live_positions(self) -> bool:
    """Return True if the engine currently holds any non-empty position.

    The positions are collected under ``_ctp_td_lock``. Any error yields False.
    """
    if not self._engine:
        return False
    try:
        with _ctp_td_lock:
            current = self._collect_positions()
            return bool(current)
    except Exception:
        return False
|
||
|
||
def request_position_snapshot(self, *, force: bool = False) -> None:
    """Ask CTP for an investor-position snapshot to fill vnpy's position cache.

    Unless ``force`` is set, the request is skipped when positions are already
    cached, or when the last request was sent less than
    ``POSITION_QUERY_MIN_INTERVAL_SEC`` ago. The raw ``reqQryInvestorPosition``
    call (made under ``_ctp_td_lock``) is preferred, with the gateway's
    ``query_position`` as fallback. Errors are logged at debug level.
    """
    if not self._engine or not self._connected_mode:
        return
    if not force:
        if self._has_live_positions():
            return
    now = time.monotonic()
    if not force and now - self._last_position_query_ts < POSITION_QUERY_MIN_INTERVAL_SEC:
        return
    try:
        self._ensure_instrument_margin_hooks()
        gateway = self._engine.get_gateway(GATEWAY_NAME)
        td_api = getattr(gateway, "td_api", None) if gateway else None
        if not td_api or not getattr(td_api, "login_status", False):
            logger.debug("CTP 持仓查询跳过:交易未登录")
            return
        if hasattr(td_api, "reqQryInvestorPosition"):
            request_id = int(getattr(td_api, "reqid", 0)) + 1
            td_api.reqid = request_id
            payload = {
                "BrokerID": getattr(td_api, "brokerid", ""),
                "InvestorID": getattr(td_api, "userid", ""),
            }
            with _ctp_td_lock:
                status = td_api.reqQryInvestorPosition(payload, request_id)
            if status != 0:
                logger.debug("CTP 持仓查询发送失败 ret=%s", status)
                return
            self._last_position_query_ts = now
            logger.info("CTP 已请求持仓查询 reqid=%s", request_id)
        elif gateway and hasattr(gateway, "query_position"):
            gateway.query_position()
            self._last_position_query_ts = now
            logger.info("CTP 已请求持仓查询(gateway)")
    except Exception as exc:
        logger.debug("request_position_snapshot: %s", exc)
|
||
|
||
def list_positions(self, *, refresh_if_empty: bool = True, refresh_margin: bool = False) -> list[dict[str, Any]]:
    """Return the current positions from vnpy's in-memory cache, read under ``_ctp_td_lock``.

    ``refresh_if_empty`` and ``refresh_margin`` are kept for backward
    compatibility and ignored; no CTP query is issued here.
    """
    _ = (refresh_if_empty, refresh_margin)
    with _ctp_td_lock:
        snapshot = self._collect_positions()
    return snapshot
|
||
|
||
@staticmethod
def _parse_trade_offset(offset_obj: Any) -> str:
    """Return "open" if the offset's text contains "OPEN" (any case), otherwise "close"."""
    return "open" if "OPEN" in str(offset_obj or "").upper() else "close"
|
||
|
||
@staticmethod
def _parse_trade_direction(direction_obj: Any) -> str:
    """Map a trade direction value to "long" or "short" via ``_is_long_direction``."""
    if _is_long_direction(direction_obj):
        return "long"
    return "short"
|
||
|
||
@staticmethod
def _position_direction_from_trade(trade_direction: str, offset: str) -> str:
    """Return the side of the position a trade affects.

    An opening trade affects its own side. A closing trade affects the
    opposite side: a buy-to-close reduces a short, so "long" maps to "short"
    and anything else maps to "long". Missing values default to "long" and
    "open".
    """
    side = (trade_direction or "long").strip().lower()
    if (offset or "open").strip().lower() == "open":
        return side
    return {"long": "short"}.get(side, "long")
|
||
|
||
def _format_trade_datetime(self, dt_obj: Any, date_raw: str = "", time_raw: str = "") -> str:
    """Format a trade timestamp as ``YYYY-MM-DD HH:MM:SS``.

    ``dt_obj`` is preferred: a datetime-like value is formatted with
    ``strftime``, and other values use their text, truncated to 19 characters
    with ``T`` replaced by a space. When that yields nothing, the raw CTP
    date/time strings are parsed instead. Returns "" if nothing is usable.
    """
    if dt_obj is not None:
        try:
            if hasattr(dt_obj, "strftime"):
                return dt_obj.strftime("%Y-%m-%d %H:%M:%S")
            as_text = str(dt_obj).strip()
            if as_text:
                return as_text[:19].replace("T", " ")
        except Exception:
            pass
    return self._parse_ctp_open_datetime(date_raw, time_raw) or ""
|
||
|
||
def _trade_row_from_vnpy(self, trade: Any) -> Optional[dict[str, Any]]:
    """Convert a vnpy TradeData-like object into a normalized trade dict.

    When the trade has no trade id, a synthetic id is built from order id,
    symbol, offset, direction, volume, price and time. Returns None for a
    blank symbol, a non-positive volume, or on error (logged at debug level).
    """
    try:
        symbol = (getattr(trade, "symbol", "") or "").strip()
        volume = int(getattr(trade, "volume", 0) or 0)
        if not symbol or volume <= 0:
            return None
        side = self._parse_trade_direction(getattr(trade, "direction", None))
        open_close = self._parse_trade_offset(getattr(trade, "offset", None))
        raw_ex = getattr(trade, "exchange", None)
        exchange = str(raw_ex.value if hasattr(raw_ex, "value") else raw_ex or "")
        when = self._format_trade_datetime(getattr(trade, "datetime", None))
        trade_id = str(getattr(trade, "tradeid", "") or getattr(trade, "vt_tradeid", "") or "")
        order_id = str(getattr(trade, "orderid", "") or getattr(trade, "vt_orderid", "") or "")
        if not trade_id:
            trade_id = f"{order_id}:{symbol}:{open_close}:{side}:{volume}:{getattr(trade, 'price', 0)}:{when}"
        return {
            "trade_id": trade_id,
            "order_id": order_id,
            "symbol": symbol,
            "exchange": exchange,
            "direction": side,
            "offset": open_close,
            "position_direction": self._position_direction_from_trade(side, open_close),
            "lots": volume,
            "price": float(getattr(trade, "price", 0) or 0),
            "datetime": when,
            "commission": round(float(getattr(trade, "commission", 0) or 0), 2),
        }
    except Exception as exc:
        logger.debug("trade_row_from_vnpy: %s", exc)
        return None
|
||
|
||
def _trade_row_from_ctp_dict(self, data: dict) -> Optional[dict[str, Any]]:
    """Convert a raw CTP trade record (dict) into a plain trade row.

    Either CTP-style keys (``InstrumentID``, ``Volume``, ``Direction``,
    ``OffsetFlag``, ...) or lower-case keys (``instrument_id``, ``volume``,
    ``direction``, ``offset``, ...) are accepted.

    Returns None when the symbol is missing, the volume is non-positive, or
    any conversion raises (logged at debug level). Without a ``TradeID`` a
    synthetic id is built from order id, symbol, offset, direction, volume,
    price and timestamp.
    """
    try:
        sym = (data.get("InstrumentID") or data.get("instrument_id") or "").strip()
        vol = int(float(data.get("Volume") or data.get("volume") or 0))
        if not sym or vol <= 0:
            return None

        dir_raw = str(data.get("Direction") or data.get("direction") or "")
        # "0" is CTP's buy flag; "多" is vnpy's Direction.LONG value. "2" is
        # also treated as long (presumably PosiDirection Long — TODO confirm).
        direction = "long" if dir_raw in ("0", "2") or "LONG" in dir_raw.upper() or dir_raw == "多" else "short"

        off_raw = str(data.get("OffsetFlag") or data.get("offset") or "")
        # "0" is CTP's open flag; "开" is vnpy's Offset.OPEN value, mirroring
        # the "多" handling above. Without it a vnpy-style open would be
        # misread as a close. Every other flag (close / close-today /
        # close-yesterday / force-close, or missing) is treated as a close.
        if off_raw == "0" or "OPEN" in off_raw.upper() or off_raw == "开":
            offset = "open"
        else:
            offset = "close"

        price = float(data.get("Price") or data.get("price") or 0)
        trade_id = str(data.get("TradeID") or data.get("tradeid") or "").strip()
        order_sys = str(data.get("OrderSysID") or data.get("orderid") or "").strip()
        dt = self._format_trade_datetime(
            None,
            str(data.get("TradeDate") or data.get("trade_date") or ""),
            str(data.get("TradeTime") or data.get("trade_time") or ""),
        )
        if not trade_id:
            trade_id = f"{order_sys}:{sym}:{offset}:{direction}:{vol}:{price}:{dt}"

        return {
            "trade_id": trade_id,
            "order_id": order_sys,
            "symbol": sym,
            "exchange": str(data.get("ExchangeID") or data.get("exchange") or ""),
            "direction": direction,
            "offset": offset,
            "position_direction": self._position_direction_from_trade(direction, offset),
            "lots": vol,
            "price": price,
            "datetime": dt,
            "commission": round(
                float(data.get("Commission") or data.get("commission") or 0), 2,
            ),
        }
    except Exception as exc:
        logger.debug("trade_row_from_ctp_dict: %s", exc)
        return None
|
||
|
||
def _install_trade_query_hook(self) -> None:
    """Intentionally a no-op.

    The CTP trade callback is no longer monkey-patched: doing so could race
    with concurrent queries and crash vnctptd with a segfault.
    """
    return None
|
||
|
||
@staticmethod
def _engine_collection_items(raw: Any) -> list[Any]:
    """Normalize a vnpy engine collection into a list.

    Different vnpy versions may hand back a dict (its values are used) or a
    list/tuple (copied). None becomes an empty list, and any other single
    object is wrapped in a one-element list.
    """
    if isinstance(raw, dict):
        return list(raw.values())
    if isinstance(raw, (list, tuple)):
        return list(raw)
    return [] if raw is None else [raw]
|
||
|
||
def _collect_engine_trades(self) -> list[dict[str, Any]]:
    """Build de-duplicated trade rows from the vnpy engine's in-memory trades.

    Rows are keyed by ``trade_id``; the first occurrence wins and insertion
    order is kept. A missing engine, or an exception from
    ``get_all_trades``, yields an empty list.
    """
    if not self._engine:
        return []
    try:
        raw_trades = self._engine.get_all_trades()
    except Exception:
        raw_trades = None

    rows_by_id: dict[str, dict[str, Any]] = {}
    for item in self._engine_collection_items(raw_trades):
        row = self._trade_row_from_vnpy(item)
        if row and row["trade_id"] not in rows_by_id:
            rows_by_id[row["trade_id"]] = row
    return list(rows_by_id.values())
|
||
|
||
def refresh_trades(self) -> None:
    """Deliberately does nothing.

    Trades are read only from vnpy's in-memory callbacks. ``query_trade`` is
    never called because it can segfault the CTP trader API.
    """
    return None
|
||
|
||
def list_trades(self, *, refresh: bool = False) -> list[dict[str, Any]]:
    """Return in-memory trade rows ordered by (datetime, trade_id).

    ``refresh`` is accepted for API compatibility and ignored; trades are
    never re-queried from CTP. Collection happens under ``_ctp_td_lock``.
    """
    with _ctp_td_lock:
        collected = self._collect_engine_trades()

    def _order_key(row: dict[str, Any]) -> tuple[str, str]:
        # Missing values sort first as empty strings.
        return (row.get("datetime") or "", row.get("trade_id") or "")

    return sorted(collected, key=_order_key)
|
||
|
||
def list_active_orders(self) -> list[dict[str, Any]]:
    """Return working orders (remaining volume > 0) from vnpy's memory.

    Each row reports the *remaining* lots (``volume - traded``), a
    ``"long"``/``"short"`` direction, the raw offset/status text, and both the
    vt order id and the gateway order id. Orders whose status text does not
    contain NOTTRADED, PARTTRADED or SUBMITTING are skipped.

    Returns an empty list when the engine is missing or the query raises.
    """
    if not self._engine:
        return []
    try:
        orders = self._engine.get_all_active_orders()
    except Exception:
        return []
    # Like _engine_collection_items, accept a dict keyed by order id as well
    # as a list. Iterating a dict directly would yield key strings, and every
    # "order" would then be silently filtered out by the status check below.
    if isinstance(orders, dict):
        orders = list(orders.values())

    out: list[dict[str, Any]] = []
    for order in orders or []:
        status = getattr(order, "status", None)
        status_s = str(status)
        if status_s and not any(x in status_s for x in ("NOTTRADED", "PARTTRADED", "SUBMITTING")):
            continue
        vol = int(getattr(order, "volume", 0) or 0)
        traded = int(getattr(order, "traded", 0) or 0)
        remain = max(0, vol - traded)
        if remain <= 0:
            continue
        direction = getattr(order, "direction", None)
        d = "long"
        if direction is not None and str(direction).endswith("SHORT"):
            d = "short"
        offset = getattr(order, "offset", None)
        offset_s = str(offset or "")
        sym = getattr(order, "symbol", "") or ""
        exchange = getattr(order, "exchange", None)
        ex_name = str(exchange.value if hasattr(exchange, "value") else exchange or "")
        vt_oid = str(getattr(order, "vt_orderid", "") or "")
        order_id = str(getattr(order, "orderid", "") or "")
        out.append({
            "symbol": sym,
            "exchange": ex_name,
            "direction": d,
            "lots": remain,
            "price": float(getattr(order, "price", 0) or 0),
            "offset": offset_s,
            "order_id": vt_oid or order_id,
            "vt_order_id": vt_oid,
            "status": status_s,
        })
    return out
|
||
|
||
def send_order(
    self,
    *,
    ths_code: str,
    offset: str,
    direction: str,
    lots: int,
    price: float,
    order_type: str = "limit",
) -> str:
    """Send one order through the CTP gateway and return its vt_orderid.

    Args:
        ths_code: Instrument code, converted with ``ths_to_vnpy_symbol`` and
            ``to_vnpy_exchange``.
        offset: ``"open"`` / ``"close"`` or the side-qualified forms
            ``"open_long"``, ``"open_short"``, ``"close_long"``,
            ``"close_short"``. A side suffix takes precedence over *direction*.
        direction: Position side (``"long"``/``"short"``) being opened or
            closed; blank defaults to ``"long"``.
        lots: Order volume; values below 1 are clamped to 1.
        price: Limit price, rounded to the contract tick. For market orders
            it is passed to ``_aggressive_limit_price`` instead.
        order_type: ``"limit"`` (default) or ``"market"`` (sent as FAK).

    Returns:
        The vt_orderid returned by the vnpy engine.

    Raises:
        RuntimeError: Engine not initialised, trade channel not logged in, or
            CTP returned no order id.
        ValueError: Unknown *offset*, or the resolved price is missing / <= 0.
    """
    from vnpy.trader.constant import Direction, Offset, OrderType
    from vnpy.trader.object import OrderRequest

    if not self._engine:
        raise RuntimeError("CTP 未初始化")
    if not self._td_logged_in():
        raise RuntimeError("CTP 交易通道未登录,请重连后再下单")

    sym, ex_name = ths_to_vnpy_symbol(ths_code)
    exchange = to_vnpy_exchange(ex_name)
    lots = max(1, int(lots))
    tick = float(get_contract_spec(ths_code).get("tick_size") or 1.0)

    offset = (offset or "open").strip().lower()
    direction = (direction or "long").strip().lower()

    # An explicit side suffix on the offset must win over *direction*.
    # Otherwise a blank/default "long" direction turns "open_short" into a
    # LONG open and "close_short" into closing the long position.
    if offset.endswith("_long"):
        side = "long"
    elif offset.endswith("_short"):
        side = "short"
    else:
        side = direction

    if offset in ("open", "open_long", "open_short"):
        d = Direction.LONG if side == "long" else Direction.SHORT
        off = Offset.OPEN
    elif offset in ("close", "close_long", "close_short"):
        hold = "long" if side == "long" else "short"
        # Closing a long position sells; closing a short position buys.
        d = Direction.SHORT if hold == "long" else Direction.LONG
        # Presumably picks CLOSE / CLOSETODAY / CLOSEYESTERDAY from holdings —
        # see _resolve_close_offset.
        off = self._resolve_close_offset(sym, ex_name, hold, lots)
    else:
        raise ValueError(f"未知开平: {offset}")

    if (order_type or "limit").lower() == "market":
        # "Market" orders are sent as FAK at an aggressive limit price.
        ot = OrderType.FAK
        price = self._aggressive_limit_price(ths_code, sym, ex_name, d, tick, price)
    else:
        ot = OrderType.LIMIT
        # Prices that are not a multiple of the tick size get rejected by CTP.
        price = round_to_tick(float(price), tick)
    # Validate after both branches so an unusable price (None/0) is rejected
    # locally instead of being sent to the exchange.
    if not price or price <= 0:
        raise ValueError("委托价格无效,请检查行情或手动填写价格")

    req = OrderRequest(
        symbol=sym,
        exchange=exchange,
        direction=d,
        type=ot,
        volume=lots,
        price=price,
        offset=off,
    )
    logger.info(
        "CTP 报单 %s %s %s %s手 @%s offset=%s type=%s",
        sym, ex_name, d, lots, price, off, ot,
    )
    with _ctp_td_lock:
        vt_orderid = self._engine.send_order(req, GATEWAY_NAME)
    if not vt_orderid:
        raise RuntimeError("CTP 拒单或未返回委托号(请检查合约代码、价格是否为最小变动价位整数倍)")
    return str(vt_orderid)
|
||
|
||
def cancel_order(self, vt_orderid: str) -> bool:
    """Request cancellation of a working order by its vt_orderid.

    Returns True once the cancel request has been handed to the engine, and
    False when there is no engine, no id, the order is unknown, or any step
    raises (logged as a warning).
    """
    engine = self._engine
    if not (engine and vt_orderid):
        return False
    try:
        with _ctp_td_lock:
            target = engine.get_order(vt_orderid)
            if target is None:
                return False
            engine.cancel_order(target.create_cancel_request(), GATEWAY_NAME)
            logger.info("CTP 撤单 %s", vt_orderid)
            return True
    except Exception as exc:
        logger.warning("CTP 撤单失败 %s: %s", vt_orderid, exc)
        return False
|
||
|
||
|
||
def get_bridge() -> CtpBridge:
    """Return the process-wide CtpBridge, creating it lazily under ``_bridge_lock``."""
    global _bridge
    with _bridge_lock:
        if _bridge is not None:
            return _bridge
        _bridge = CtpBridge()
        return _bridge
|
||
|
||
|
||
def try_init_vnpy(_settings: dict | None = None) -> bool:
    """Compatibility shim: report whether the bridge says vnpy is usable; ``_settings`` is ignored."""
    bridge = get_bridge()
    return bridge.available()
|
||
|
||
|
||
def vnpy_available() -> bool:
    """Report whether the shared CTP bridge considers vnpy/vnpy_ctp available."""
    bridge = get_bridge()
    return bridge.available()
|
||
|
||
|
||
def _ctp_connect_permitted(*, scheduled: bool = False) -> bool:
    """Decide whether an automatic CTP connect may be attempted now.

    Always allowed when the "auto connect" setting is on. Otherwise only
    scheduled (pre-market / trading-session) connects are considered, and
    ``should_auto_connect_now`` has the final say.
    """
    from ctp_settings import is_ctp_auto_connect_enabled

    if not is_ctp_auto_connect_enabled():
        if scheduled:
            from ctp_premarket_connect import should_auto_connect_now

            return should_auto_connect_now()
        return False
    return True
|
||
|
||
|
||
def ctp_disconnect(*, set_disabled_hint: bool = False) -> None:
    """Close the CTP gateway and reset the bridge's last error.

    With ``set_disabled_hint`` the last error becomes ``CTP_DISABLED_HINT``;
    otherwise it is cleared. The value is stored on the bridge and persisted
    via ``_persist_last_error``.
    """
    from ctp_settings import CTP_DISABLED_HINT

    bridge = get_bridge()
    bridge._close_gateway()
    message = CTP_DISABLED_HINT if set_disabled_hint else ""
    bridge._last_error = message
    _persist_last_error(message)
|
||
|
||
|
||
def ctp_connect(mode: str, *, force: bool = False) -> dict[str, Any]:
    """Connect the shared bridge in *mode* (blocking) and return its status dict."""
    bridge = get_bridge()
    bridge.connect(mode, force=force)
    status = bridge.status(mode)
    return status
|
||
|
||
|
||
def ctp_start_connect(mode: str, *, force: bool = False, scheduled: bool = False) -> dict[str, Any]:
    """Start a non-blocking connect (for the Web API).

    Returns the info dict from ``start_connect_async`` plus a ``"status"``
    key holding the bridge status taken right afterwards.
    """
    bridge = get_bridge()
    launch_info = bridge.start_connect_async(mode, force=force, scheduled=scheduled)
    current = bridge.status(mode)
    merged = dict(launch_info)
    merged["status"] = current
    return merged
|
||
|
||
|
||
def ctp_try_auto_reconnect(mode: str) -> bool:
    """Silently start an asynchronous CTP reconnect for *mode* when needed.

    Returns True when no reconnect is required or one is under way:
      - already connected in *mode* and the trade channel is logged in or
        answers ``ping``;
      - a connect succeeded less than 120 s ago (per ``_last_connect_ok_ts``);
      - ``start_connect_async`` reports connected / connecting / started.

    Returns False when auto-connect is not permitted, vnpy is unavailable, a
    connect is already in progress, a login cooldown is active, or
    credentials / trade server are missing. It also returns False when the
    trade front fails a TCP probe; the bridge's ``_last_error`` is then set.
    """
    if not _ctp_connect_permitted(scheduled=True):
        return False
    b = get_bridge()
    if not b.available():
        return False
    if b.connect_in_progress():
        return False
    if b.login_cooldown_remaining() > 0:
        return False
    st = _setting_for_mode(mode)
    if not st.get("用户名") or not st.get("密码") or not st.get("交易服务器"):
        return False
    if b.connected_mode == mode and (b._td_logged_in() or b.ping()):
        return True
    # Grace window after a successful connect: avoid back-to-back logins,
    # which risk the login cooldown tracked elsewhere in this module.
    recent = time.time() - float(getattr(b, "_last_connect_ok_ts", 0) or 0)
    if recent < 120:
        logger.debug("CTP 跳过自动重连:刚连接 %.0fs", recent)
        return True
    td = st.get("交易服务器", "")
    ok, err = probe_tcp_address(td, timeout=4.0)
    if not ok:
        if mode == "simulation":
            b._last_error = (
                f"SimNow 交易前置不可达:{td}({err})。"
                "请更新 SIMNOW_TD_ADDRESS 并确认服务器出网。"
            )
        else:
            # Live mode talks to the broker's front, not SimNow: do not point
            # the user at SIMNOW_TD_ADDRESS.
            b._last_error = (
                f"期货公司交易前置不可达:{td}({err})。"
                "请检查实盘交易服务器地址并确认服务器出网。"
            )
        return False
    info = b.start_connect_async(mode, force=False, scheduled=True)
    return bool(
        info.get("connected")
        or info.get("connecting")
        or info.get("started")
    )
|
||
|
||
|
||
def ctp_status(mode: str) -> dict[str, Any]:
    """Return the bridge status for *mode*, enriched for the UI.

    Always adds ``auto_connect_enabled``. When auto-connect is off, it adds
    ``disabled_hint``, clears ``last_error`` if idle, sets ``td_reachable``
    to None and returns without probing.

    Otherwise, when neither connected nor connecting, the configured trade
    front is TCP-probed; ``td_reachable`` is set and, if unreachable with no
    prior error, a mode-appropriate ``last_error`` is filled in.
    """
    from ctp_settings import CTP_DISABLED_HINT, is_ctp_auto_connect_enabled

    auto = is_ctp_auto_connect_enabled()
    st = get_bridge().status(mode)
    st["auto_connect_enabled"] = auto
    if not auto:
        st["disabled_hint"] = CTP_DISABLED_HINT
        if not st.get("connected") and not st.get("connecting"):
            st["last_error"] = ""
        st["td_reachable"] = None
        return st
    if not st.get("connected") and not st.get("connecting"):
        setting = _setting_for_mode(mode)
        td = setting.get("交易服务器", "")
        if td:
            ok, err = probe_tcp_address(td, timeout=3.0)
            st["td_reachable"] = ok
            if not ok and not st.get("last_error"):
                if mode == "simulation":
                    st["last_error"] = f"SimNow 交易前置不可达:{td}({err})"
                else:
                    # Live mode connects to the broker's front, not SimNow.
                    st["last_error"] = f"期货公司交易前置不可达:{td}({err})"
    return st
|
||
|
||
|
||
def ctp_get_account(mode: str) -> dict[str, Any]:
    """Ensure the bridge is connected in *mode*, then return its account snapshot."""
    bridge = get_bridge()
    bridge.ensure_connected(mode)
    account = bridge.get_account()
    return account
|
||
|
||
|
||
def ctp_sum_position_margins(
    mode: str,
    *,
    refresh_if_empty: bool = True,
    refresh_margin: bool = False,
) -> float:
    """Sum the CTP-reported margin across all positions.

    Intended to match the counter's "actual margin collected" figure.
    Non-positive margins are ignored. The total is rounded to 2 decimals,
    and 0.0 is returned when no positive margin is found.
    """
    positions = ctp_list_positions(
        mode, refresh_if_empty=refresh_if_empty, refresh_margin=refresh_margin,
    )
    per_position = (float(p.get("margin") or 0) for p in positions)
    total = sum((m for m in per_position if m > 0), 0.0)
    if total <= 0:
        return 0.0
    return round(total, 2)
|
||
|
||
|
||
def ctp_account_margin_used(mode: str) -> Optional[float]:
    """Estimate margin in use as account balance minus available funds.

    This matches the counter funds shown in the top bar. Returns None when
    not connected in *mode*, the balance is not positive, the difference is
    not positive, or reading the account raises (logged at debug level).
    """
    bridge = get_bridge()
    if not (bridge.connected_mode == mode and bridge.ping()):
        return None
    try:
        account = bridge.get_account()
        balance = float(account.get("balance") or 0)
        available = float(account.get("available") or 0)
        if balance <= 0:
            return None
        in_use = balance - available
        return round(in_use, 2) if in_use > 0 else None
    except Exception as exc:
        logger.debug("ctp_account_margin_used: %s", exc)
        return None
|
||
|
||
|
||
def ctp_list_positions(
    mode: str,
    *,
    refresh_if_empty: bool = True,
    refresh_margin: bool = False,
) -> list[dict[str, Any]]:
    """List positions from the bridge; empty when not connected in *mode* or ``ping`` fails."""
    bridge = get_bridge()
    if not (bridge.connected_mode == mode and bridge.ping()):
        return []
    return bridge.list_positions(
        refresh_if_empty=refresh_if_empty,
        refresh_margin=refresh_margin,
    )
|
||
|
||
|
||
def ctp_list_active_orders(mode: str) -> list[dict[str, Any]]:
    """Ensure a connection in *mode*, then list the bridge's working orders."""
    bridge = get_bridge()
    bridge.ensure_connected(mode)
    working = bridge.list_active_orders()
    return working
|
||
|
||
|
||
def ctp_cancel_order(mode: str, vt_orderid: str) -> bool:
    """Ensure a connection in *mode*, then ask the bridge to cancel *vt_orderid*."""
    bridge = get_bridge()
    bridge.ensure_connected(mode)
    cancelled = bridge.cancel_order(vt_orderid)
    return cancelled
|
||
|
||
|
||
def ctp_list_trades(mode: str, *, refresh: bool = False) -> list[dict[str, Any]]:
    """List trade rows from the bridge; empty when not connected in *mode* or ``ping`` fails."""
    bridge = get_bridge()
    if not (bridge.connected_mode == mode and bridge.ping()):
        return []
    return bridge.list_trades(refresh=refresh)
|
||
|
||
|
||
def ctp_get_tick_price(mode: str, ths_code: str) -> Optional[float]:
    """Return the latest CTP tick price for *ths_code*.

    Requires an active connection and market-data subscription. Returns None
    when the bridge is not connected in *mode* or the lookup raises (logged
    at debug level).
    """
    bridge = get_bridge()
    if bridge.connected_mode == mode:
        try:
            return bridge.get_tick_price(ths_code, mode=mode)
        except Exception as exc:
            logger.debug("ctp_get_tick_price: %s", exc)
    return None
|
||
|
||
|
||
def ctp_get_tick_detail(mode: str, ths_code: str) -> dict[str, Any]:
    """Return the bridge's tick detail for *ths_code*.

    Returns an empty dict when not connected in *mode* or when the lookup
    raises (logged at debug level).
    """
    bridge = get_bridge()
    if bridge.connected_mode == mode:
        try:
            return bridge.get_tick_detail(ths_code, mode=mode)
        except Exception as exc:
            logger.debug("ctp_get_tick_detail: %s", exc)
    return {}
|
||
|
||
|
||
def ctp_estimate_margin_one_lot(
    mode: str,
    ths_code: str,
    price: float,
    *,
    direction: str = "long",
) -> Optional[float]:
    """Ask the bridge for the one-lot margin of *ths_code* at *price*.

    Returns None when not connected in *mode*, ``ping`` fails, or the
    estimate raises (logged at debug level).
    """
    bridge = get_bridge()
    if bridge.connected_mode == mode and bridge.ping():
        try:
            return bridge.estimate_margin_one_lot(ths_code, price, direction=direction)
        except Exception as exc:
            logger.debug("ctp_estimate_margin_one_lot: %s", exc)
    return None
|
||
|
||
|
||
def ctp_lookup_contract_spec(mode: str, ths_code: str) -> Optional[dict]:
    """Fetch the contract spec for *ths_code* from the connected bridge.

    Returns None when not connected in *mode*, ``ping`` fails, or the lookup
    raises (logged at debug level).
    """
    bridge = get_bridge()
    if bridge.connected_mode == mode and bridge.ping():
        try:
            return bridge.lookup_contract_spec(ths_code)
        except Exception as exc:
            logger.debug("ctp_lookup_contract_spec: %s", exc)
    return None
|
||
|
||
|
||
def get_ctp_balance(mode: str) -> Optional[float]:
    """Return the account balance as a float.

    Returns None when the balance is missing/falsy or fetching the account
    raises (logged at debug level).
    """
    try:
        balance = ctp_get_account(mode).get("balance")
        if not balance:
            return None
        return float(balance)
    except Exception as exc:
        logger.debug("get_ctp_balance: %s", exc)
        return None
|
||
|
||
|
||
def execute_order(
    conn,
    *,
    mode: str,
    offset: str,
    symbol: str,
    direction: str,
    lots: int,
    price: float,
    settings: dict | None = None,
    order_type: str = "limit",
) -> dict[str, Any]:
    """Unified order entry point: "simulation" routes to SimNow, "live" to the broker's CTP.

    ``conn`` and ``settings`` are accepted for call-site compatibility and
    ignored. The returned dict echoes the *requested* symbol, lots and
    price together with the CTP order id and mode label.

    Raises:
        ValueError: Unknown *mode*, or vnpy / vnpy_ctp is not available.
    """
    del conn, settings
    if mode not in ("simulation", "live"):
        raise ValueError("未知交易模式")
    if not vnpy_available():
        raise ValueError(
            "请先安装 vnpy 与 vnpy_ctp:pip install vnpy vnpy_ctp\n"
            "模拟盘需配置 .env 中 SIMNOW_USER / SIMNOW_PASSWORD 等"
        )

    bridge = get_bridge()
    bridge.require_connected(mode)
    placed_id = bridge.send_order(
        ths_code=symbol,
        offset=offset,
        direction=direction,
        lots=lots,
        price=price,
        order_type=order_type,
    )
    return dict(
        order_id=placed_id,
        mode=mode,
        mode_label=_mode_label(mode),
        symbol=symbol,
        lots=lots,
        price=price,
    )
|