Files
qihuo/sl_tp_guard.py
T

560 lines
18 KiB
Python

"""止盈止损守护:程序本地监控价位,触发后向 CTP 发平仓单(不向交易所挂 SL/TP 限价单)。"""
from __future__ import annotations
import logging
import threading
import time
from datetime import datetime
from typing import Any, Callable, Optional
from zoneinfo import ZoneInfo
from contract_specs import calc_position_metrics
from ctp_symbol import ths_to_vnpy_symbol
from fee_specs import calc_round_trip_fee
from market_sessions import is_trading_session
from symbols import ths_to_codes
from vnpy_bridge import (
ctp_cancel_order,
ctp_get_tick_price,
ctp_list_active_orders,
ctp_list_positions,
ctp_status,
execute_order,
)
logger = logging.getLogger(__name__)
TZ = ZoneInfo("Asia/Shanghai")
CHECK_INTERVAL_SEC = 1
CLOSED_MARKET_SLEEP_SEC = 30
DISCONNECTED_SLEEP_SEC = 5
PLACE_COOLDOWN_SEC = 3
_last_close_attempt: dict[int, float] = {}
_closing_monitors: set[int] = set()
_closing_lock = threading.Lock()
MONITOR_ORDER_COLUMNS = (
"ALTER TABLE trade_order_monitors ADD COLUMN sl_vt_order_id TEXT",
"ALTER TABLE trade_order_monitors ADD COLUMN tp_vt_order_id TEXT",
)
def ensure_monitor_order_columns(conn) -> None:
    """Add the legacy ``sl_vt_order_id`` / ``tp_vt_order_id`` columns if missing.

    Every statement in ``MONITOR_ORDER_COLUMNS`` is attempted independently.
    A "duplicate column" error means the column already exists and is ignored
    silently. Any other error stays non-fatal, because this runs on every guard
    pass. It is logged at debug level instead of being swallowed without a trace.

    Args:
        conn: DB connection exposing ``execute``.
    """
    for sql in MONITOR_ORDER_COLUMNS:
        try:
            conn.execute(sql)
        except Exception as exc:
            # SQLite reports an existing column as "duplicate column name: <col>".
            if "duplicate column" in str(exc).lower():
                continue
            logger.debug("ensure_monitor_order_columns failed (%s): %s", sql, exc)
def _tick_size(ths_code: str) -> float:
    """Return the contract's ``tick_size`` as a float, or 1.0 when it is missing/falsy."""
    from contract_specs import get_contract_spec

    spec = get_contract_spec(ths_code)
    raw_tick = spec.get("tick_size")
    return float(raw_tick or 1.0)
def _match_symbol(ctp_sym: str, ths: str) -> bool:
a = (ctp_sym or "").lower()
b = (ths or "").lower()
if a == b:
return True
try:
vnpy_sym, _ = ths_to_vnpy_symbol(ths)
return a == vnpy_sym.lower()
except Exception:
return False
def _close_order_direction(hold_direction: str) -> str:
return "short" if hold_direction == "long" else "long"
def _price_near(a: float, b: float, tick: float) -> bool:
return abs(float(a) - float(b)) <= max(tick * 0.501, 1e-9)
def _find_close_order(
    active_orders: list[dict],
    *,
    ths_code: str,
    hold_direction: str,
    price: float,
    tick: float,
) -> Optional[dict]:
    """Return the first active close order matching a holding's exit level.

    An order matches when all of the following hold:
    - its symbol matches *ths_code*;
    - its offset contains ``CLOSE``;
    - its direction is the one that closes *hold_direction*;
    - its price is within ~half a tick of *price*.

    Returns ``None`` when nothing matches.
    """
    wanted_dir = _close_order_direction(hold_direction)

    def _is_candidate(order: dict) -> bool:
        if not _match_symbol(order.get("symbol") or "", ths_code):
            return False
        if "CLOSE" not in (order.get("offset") or "").upper():
            return False
        if (order.get("direction") or "") != wanted_dir:
            return False
        return _price_near(order.get("price") or 0, price, tick)

    return next((order for order in active_orders if _is_candidate(order)), None)
def _find_position(positions: list[dict], ths_code: str, direction: str) -> Optional[dict]:
for p in positions:
if int(p.get("lots") or 0) <= 0:
continue
if (p.get("direction") or "long") != direction:
continue
if _match_symbol(p.get("symbol") or "", ths_code):
return p
return None
def _can_close_now(monitor_id: int, *, cooldown: int = PLACE_COOLDOWN_SEC) -> bool:
    """True when at least *cooldown* seconds passed since the monitor's last close attempt."""
    previous = _last_close_attempt.get(monitor_id, 0.0)
    elapsed = time.time() - previous
    return elapsed >= cooldown
def _mark_close_attempt(monitor_id: int) -> None:
    """Remember the current ``time.time()`` as the monitor's latest close attempt."""
    now = time.time()
    _last_close_attempt[monitor_id] = now
def _try_acquire_close(monitor_id: int) -> bool:
    """Claim the close slot for *monitor_id*; False if another close already holds it."""
    with _closing_lock:
        already_closing = monitor_id in _closing_monitors
        if not already_closing:
            _closing_monitors.add(monitor_id)
        return not already_closing
def _release_close(monitor_id: int) -> None:
    """Free the close slot claimed by ``_try_acquire_close`` (no-op if not held)."""
    with _closing_lock:
        if monitor_id in _closing_monitors:
            _closing_monitors.remove(monitor_id)
def _monitor_type_label(raw: str) -> str:
mapping = {
"manual": "期货下单",
"trend": "趋势回调",
"roll": "顺势加仓",
}
return mapping.get(raw or "", raw or "程序监控")
# INSERT used for trade_logs rows written after a local SL/TP close.
_TRADE_LOG_INSERT_SQL = (
    "INSERT INTO trade_logs\n"
    "(symbol, symbol_name, market_code, sina_code, monitor_type, direction,\n"
    "entry_price, stop_loss, take_profit, close_price, lots, margin,\n"
    "holding_minutes, open_time, close_time, pnl, fee, pnl_net, result)\n"
    "VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"
)


def _write_trade_log(
    conn,
    mon: dict,
    *,
    close_price: float,
    reason: str,
    trading_mode: str,
    capital: float = 0.0,
) -> None:
    """Insert one ``trade_logs`` row for a monitor closed by a local SL/TP trigger.

    PnL and margin come from ``calc_position_metrics``; the fee comes from
    ``calc_round_trip_fee``. ``pnl_net`` is the PnL minus the fee, rounded to 2 dp.
    A missing stop-loss/take-profit level is replaced by the entry price, both
    for the metrics and for the stored row. ``result`` is "止盈" for
    ``reason == "take_profit"`` and "止损" otherwise.

    The row is not committed here. After inserting, the stats cache is refreshed
    on a best-effort basis: failures are only logged at debug level.
    """
    symbol = (mon.get("symbol") or "").strip()
    side = (mon.get("direction") or "long").strip().lower()
    entry = float(mon.get("entry_price") or close_price)
    raw_sl = mon.get("stop_loss")
    raw_tp = mon.get("take_profit")
    sl_level = entry if raw_sl is None else float(raw_sl)
    tp_level = entry if raw_tp is None else float(raw_tp)
    qty = float(mon.get("lots") or 1)
    opened_at = (mon.get("open_time") or "").strip()
    closed_at = datetime.now(TZ).strftime("%Y-%m-%dT%H:%M")

    codes = ths_to_codes(symbol) or {}
    metrics = calc_position_metrics(
        side, entry, sl_level, tp_level, qty, close_price, capital, symbol,
    )
    gross_pnl = metrics.get("float_pnl") or 0.0
    fee = calc_round_trip_fee(
        symbol, entry, close_price, qty, opened_at, closed_at, trading_mode=trading_mode,
    )
    try:
        from app import holding_to_minutes

        minutes = holding_to_minutes(opened_at, closed_at)
    except Exception:
        minutes = 0

    row = (
        symbol,
        mon.get("symbol_name") or symbol,
        mon.get("market_code") or codes.get("market_code") or "",
        codes.get("sina_code") or "",
        _monitor_type_label(mon.get("monitor_type") or ""),
        side,
        entry,
        raw_sl if raw_sl is not None else sl_level,
        raw_tp if raw_tp is not None else tp_level,
        close_price,
        qty,
        metrics.get("margin"),
        minutes,
        opened_at,
        closed_at,
        gross_pnl,
        fee,
        round(gross_pnl - fee, 2),
        "止盈" if reason == "take_profit" else "止损",
    )
    conn.execute(_TRADE_LOG_INSERT_SQL, row)

    try:
        from stats_engine import refresh_stats_cache

        refresh_stats_cache(conn, capital)
    except Exception as exc:
        logger.debug("stats refresh after SL/TP close: %s", exc)
def _sl_triggered(direction: str, sl: float, mark: float, tick: float) -> bool:
buf = max(tick * 0.01, 1e-9)
if direction == "long":
return mark <= sl + buf
return mark >= sl - buf
def _tp_triggered(direction: str, tp: float, mark: float, tick: float) -> bool:
buf = max(tick * 0.01, 1e-9)
if direction == "long":
return mark >= tp - buf
return mark <= tp + buf
def cancel_monitor_exit_orders(
    conn,
    mon: dict,
    *,
    mode: str,
) -> int:
    """Cancel legacy exchange-side SL/TP close orders still resting for *mon*.

    Two kinds of orders are targeted, separately for stop-loss and take-profit:
    - the order id stored in ``sl_vt_order_id`` / ``tp_vt_order_id``;
    - any active close order on the same symbol, in the closing direction,
      priced within ~half a tick of the monitor's level.

    Stored ids are cleared from the row when they stop being useful:
    - both columns are cleared when any cancel succeeded (as before);
    - otherwise, an individual column is cleared when its id no longer appears
      among the active orders.

    Without the second rule a stale id was never cleared.
    ``check_monitors_locally`` then re-ran this scan on every pass, including the
    price-based cancellation.

    Args:
        conn: DB connection holding ``trade_order_monitors``.
        mon: Monitor row as a dict (must contain ``id``).
        mode: Trading mode passed to the CTP bridge.

    Returns:
        Number of cancels that ``ctp_cancel_order`` reported as successful;
        0 when CTP is not connected.
    """
    ensure_monitor_order_columns(conn)
    if not ctp_status(mode).get("connected"):
        return 0
    sym = (mon.get("symbol") or "").strip()
    direction = (mon.get("direction") or "long").strip().lower()
    tick = _tick_size(sym)
    active = ctp_list_active_orders(mode) or []
    active_ids = {str(o.get("order_id") or "").strip() for o in active}
    active_ids.discard("")
    cancelled = 0
    seen: set[str] = set()
    stale_columns: list[str] = []

    def _try_cancel(vt_id: str) -> None:
        nonlocal cancelled
        oid = str(vt_id or "").strip()
        if not oid or oid in seen:
            return
        seen.add(oid)
        if ctp_cancel_order(mode, oid):
            cancelled += 1

    for kind, price_key in (("sl", "stop_loss"), ("tp", "take_profit")):
        raw = mon.get(price_key)
        try:
            px = float(raw) if raw is not None else None
        except (TypeError, ValueError):
            px = None
        column = f"{kind}_vt_order_id"
        stored = str(mon.get(column) or "")
        if stored:
            # An id that is not among the active orders can never be cancelled
            # again; remember it so the row stops pointing at it.
            if stored.strip() not in active_ids:
                stale_columns.append(column)
            _try_cancel(stored)
        if px is not None:
            found = _find_close_order(
                active, ths_code=sym, hold_direction=direction, price=px, tick=tick,
            )
            if found:
                _try_cancel(str(found.get("order_id") or ""))

    if cancelled:
        conn.execute(
            "UPDATE trade_order_monitors SET sl_vt_order_id=NULL, tp_vt_order_id=NULL WHERE id=?",
            (mon["id"],),
        )
        conn.commit()
    elif stale_columns:
        for column in stale_columns:
            # ``column`` is one of the two literal names built above, never external input.
            conn.execute(
                f"UPDATE trade_order_monitors SET {column}=NULL WHERE id=?",
                (mon["id"],),
            )
        conn.commit()
    return cancelled
def reconcile_monitors_without_position(conn, mode: str) -> int:
    """Close active monitors whose position CTP no longer reports.

    For every ``active`` monitor with no position matching its symbol and
    direction, the function tries to cancel leftover legacy SL/TP orders
    (failures are logged, not raised) and sets the monitor's status to ``closed``.
    It commits once if anything changed.

    The monitor's symbol and direction are normalised (``strip`` / ``lower``)
    exactly as ``check_monitors_locally`` and ``_execute_local_close`` do. A row
    those functions treat as held is therefore never closed here because of
    stray whitespace or capitalisation.

    NOTE(review): if ``ctp_list_positions`` ever returns an empty list
    transiently (e.g. right after connecting), every active monitor is closed.
    Confirm that the bridge only reports positions once they are fully loaded.

    Returns:
        Number of monitors closed; 0 when CTP is not connected.
    """
    if not ctp_status(mode).get("connected"):
        return 0
    held: set[tuple[str, str]] = set()
    for p in ctp_list_positions(mode):
        if int(p.get("lots") or 0) <= 0:
            continue
        held.add(((p.get("symbol") or "").lower(), p.get("direction") or "long"))

    closed = 0
    for r in conn.execute("SELECT * FROM trade_order_monitors WHERE status='active'").fetchall():
        mon = dict(r)
        mon_sym = (mon.get("symbol") or "").strip()
        mon_dir = (mon.get("direction") or "long").strip().lower()
        if any(
            pos_dir == mon_dir and _match_symbol(pos_sym, mon_sym)
            for pos_sym, pos_dir in held
        ):
            continue
        try:
            cancel_monitor_exit_orders(conn, mon, mode=mode)
        except Exception as exc:
            logger.warning("cancel exit orders monitor=%s: %s", mon.get("id"), exc)
        conn.execute("UPDATE trade_order_monitors SET status='closed' WHERE id=?", (mon["id"],))
        closed += 1
    if closed:
        conn.commit()
    return closed
def _execute_local_close(
    conn,
    mon: dict,
    *,
    mode: str,
    mark: float,
    reason: str,
    capital: float = 0.0,
    notify_fn: Callable[[str], None] | None = None,
) -> None:
    """Flatten the position behind *mon* with a market order and close the monitor.

    Steps:
    1. Re-check that CTP still reports a matching position. If not, defer to
       ``reconcile_monitors_without_position`` and return.
    2. Cancel any legacy exchange-side SL/TP orders.
    3. Send a market close via ``execute_order``.
    4. Write a ``trade_logs`` row.
    5. Mark the monitor ``closed`` and commit.
    6. Optionally notify.

    Once ``execute_order`` has returned, the monitor is always marked closed,
    even if writing the trade log fails (that failure is logged with traceback).
    Previously the failure propagated before the status update and left the
    monitor ``active``, so a later scan could send another market close for the
    same position.

    NOTE(review): ``lots`` comes from the CTP position, so this closes the
    whole position even if the monitor's own ``lots`` is smaller (e.g. several
    monitors on one position). The trade log still records ``mon["lots"]``.

    Args:
        conn: DB connection holding ``trade_order_monitors`` / ``trade_logs``.
        mon: Monitor row as a dict (must contain ``id``).
        mode: Trading mode passed through to the CTP bridge.
        mark: Price that triggered the close; used as order price and log price.
        reason: ``"take_profit"``, or anything else, which is treated as stop-loss.
        capital: Account capital forwarded to the trade-log metrics.
        notify_fn: Optional callback that receives a human-readable message.

    Raises:
        Whatever the position lookup, order cancellation, ``execute_order`` or
        the final status update raise.
    """
    sym = (mon.get("symbol") or "").strip()
    direction = (mon.get("direction") or "long").strip().lower()
    pos = _find_position(ctp_list_positions(mode), sym, direction)
    if not pos:
        reconcile_monitors_without_position(conn, mode)
        return
    lots = int(pos.get("lots") or mon.get("lots") or 1)
    offset = "close_long" if direction == "long" else "close_short"
    cancel_monitor_exit_orders(conn, mon, mode=mode)
    execute_order(
        conn,
        mode=mode,
        offset=offset,
        symbol=sym,
        direction=direction,
        lots=lots,
        price=mark,
        order_type="market",
    )
    # The close order is out: from here on nothing may leave the monitor active.
    logged = True
    try:
        _write_trade_log(
            conn,
            mon,
            close_price=mark,
            reason=reason,
            trading_mode=mode,
            capital=capital,
        )
    except Exception:
        logged = False
        logger.exception("SL/TP trade log write failed monitor=%s", mon.get("id"))
    conn.execute("UPDATE trade_order_monitors SET status='closed' WHERE id=?", (mon["id"],))
    conn.commit()
    label = "止盈" if reason == "take_profit" else "止损"
    logger.info(
        "止盈止损本地触发 monitor=%s reason=%s %s %s %d手 @%s",
        mon.get("id"), reason, sym, direction, lots, mark,
    )
    if notify_fn:
        # Only claim the trade was recorded when the trade-log insert succeeded.
        suffix = "已记入交易记录" if logged else "交易记录写入失败"
        try:
            notify_fn(f"{label}平仓 {sym} {direction} {lots}手 @{mark},{suffix}")
        except Exception as exc:
            logger.debug("SL/TP notify failed: %s", exc)
def check_monitors_locally(
conn,
mode: str,
*,
capital: float = 0.0,
notify_fn: Callable[[str], None] | None = None,
) -> int:
"""扫描 active 监控,本地比对行情;触发止盈/止损(含跳空穿透)后立刻市价平仓并记交易记录。"""
ensure_monitor_order_columns(conn)
if not ctp_status(mode).get("connected"):
return 0
if not is_trading_session():
return 0
reconcile_monitors_without_position(conn, mode)
closed = 0
rows = conn.execute(
"SELECT * FROM trade_order_monitors WHERE status='active'"
).fetchall()
for r in rows:
mon = dict(r)
mid = int(mon.get("id") or 0)
sym = (mon.get("symbol") or "").strip()
direction = (mon.get("direction") or "long").strip().lower()
if mon.get("sl_vt_order_id") or mon.get("tp_vt_order_id"):
cancel_monitor_exit_orders(conn, mon, mode=mode)
sl = mon.get("stop_loss")
tp = mon.get("take_profit")
try:
sl_f = float(sl) if sl is not None else None
tp_f = float(tp) if tp is not None else None
except (TypeError, ValueError):
sl_f, tp_f = None, None
if sl_f is None and tp_f is None:
continue
positions = ctp_list_positions(mode)
if not _find_position(positions, sym, direction):
continue
mark = ctp_get_tick_price(mode, sym)
if mark is None or mark <= 0:
continue
tick = _tick_size(sym)
reason = None
if tp_f is not None and _tp_triggered(direction, tp_f, mark, tick):
reason = "take_profit"
elif sl_f is not None and _sl_triggered(direction, sl_f, mark, tick):
reason = "stop_loss"
if not reason:
continue
if mid > 0 and not _can_close_now(mid):
continue
if mid > 0 and not _try_acquire_close(mid):
continue
try:
_execute_local_close(
conn,
mon,
mode=mode,
mark=mark,
reason=reason,
capital=capital,
notify_fn=notify_fn,
)
if mid > 0:
_mark_close_attempt(mid)
closed += 1
except Exception as exc:
logger.warning("SL/TP local close failed monitor=%s: %s", mid, exc)
finally:
if mid > 0:
_release_close(mid)
return closed
def place_monitor_exit_orders(
    conn,
    mon: dict,
    *,
    mode: str,
    force: bool = False,
) -> dict[str, Any]:
    """Backward-compatible entry point that no longer places exchange SL/TP orders.

    In local-monitoring mode it only cancels legacy exchange-side orders for
    *mon*. *force* is accepted for signature compatibility and ignored.

    Returns:
        On success, a dict with keys ``ok``, ``message``, ``placed`` (always
        empty) and ``local_monitor``. When CTP is not connected, a dict with
        keys ``ok`` (False), ``error`` and ``placed``.
    """
    del force
    ensure_monitor_order_columns(conn)
    if not ctp_status(mode).get("connected"):
        return {"ok": False, "error": "CTP 未连接", "placed": []}
    n_cancelled = cancel_monitor_exit_orders(conn, mon, mode=mode)
    suffix = f";已撤销旧版柜台挂单 {n_cancelled}" if n_cancelled else ""
    message = "程序本地监控中,不向交易所挂止盈止损单" + suffix
    return {"ok": True, "message": message, "placed": [], "local_monitor": True}
def monitor_order_status(
    mon: dict,
    *,
    mode: str,
    ths_code: str,
    direction: str,
) -> dict[str, bool]:
    """Report the local SL/TP monitoring state of *mon* (not exchange order state).

    A level counts as monitored when it is present and parses as a float. The
    two levels are evaluated independently: previously a malformed
    ``take_profit`` also made the stop-loss report as not monitored. The legacy
    ``*_order_active`` keys mirror ``*_monitoring``. The ``needs_*_order`` keys
    are always False, because no exchange orders are placed any more.

    *mode*, *ths_code* and *direction* are accepted for API compatibility only.
    *mon* may be ``None`` or empty; in that case nothing is monitored.
    """
    del mode, ths_code, direction

    def _has_level(key: str) -> bool:
        raw = mon.get(key) if mon else None
        if raw is None:
            return False
        try:
            float(raw)
        except (TypeError, ValueError):
            return False
        return True

    sl_on = _has_level("stop_loss")
    tp_on = _has_level("take_profit")
    return {
        "sl_order_active": sl_on,
        "tp_order_active": tp_on,
        "sl_monitoring": sl_on,
        "tp_monitoring": tp_on,
        "needs_sl_order": False,
        "needs_tp_order": False,
    }
def sync_all_sl_tp_orders(conn, mode: str) -> int:
    """Legacy worker entry point kept for backward compatibility; now a no-op.

    Exchange-side SL/TP orders are no longer synced. Local checks are driven by
    ``start_sl_tp_guard_worker`` via ``check_monitors_locally``, so this function
    touches neither *conn* nor CTP.

    Returns:
        Always ``0``.
    """
    del conn, mode
    return 0
def start_sl_tp_guard_worker(
    *,
    db_path: str,
    get_mode_fn: Callable[[], str],
    init_tables_fn: Callable | None = None,
    get_capital_fn: Callable | None = None,
    notify_fn: Callable[[str], None] | None = None,
    interval: int = CHECK_INTERVAL_SEC,
) -> None:
    """Start the daemon thread ``sl-tp-guard`` that drives local SL/TP monitoring.

    The thread waits 8 s, then loops forever:
    - Outside trading sessions it sleeps ``CLOSED_MARKET_SLEEP_SEC``.
    - While CTP is disconnected it sleeps ``DISCONNECTED_SLEEP_SEC``.
    - Otherwise it opens a fresh DB connection and runs *init_tables_fn* if given.
      If any active monitor has a stop-loss or take-profit, it calls
      ``check_monitors_locally``. The connection is closed afterwards.
    - Between passes it sleeps ``max(1, interval)`` seconds, or at least 5 s
      when nothing is monitored.
    - Errors are logged as warnings and the loop continues.
    """
    from db_conn import connect_db

    active_count_sql = (
        "SELECT COUNT(*) AS n FROM trade_order_monitors\n"
        "WHERE status='active'\n"
        "AND (stop_loss IS NOT NULL OR take_profit IS NOT NULL)"
    )

    def _current_capital(db) -> float:
        # Capital is optional; any lookup failure falls back to 0.
        if not get_capital_fn:
            return 0.0
        try:
            return float(get_capital_fn(db) or 0)
        except Exception:
            return 0.0

    def _loop() -> None:
        time.sleep(8)
        while True:
            pause = max(1, interval)
            try:
                if not is_trading_session():
                    time.sleep(CLOSED_MARKET_SLEEP_SEC)
                    continue
                mode = get_mode_fn()
                if not ctp_status(mode).get("connected"):
                    time.sleep(DISCONNECTED_SLEEP_SEC)
                    continue
                db = connect_db(db_path)
                try:
                    if init_tables_fn:
                        init_tables_fn(db)
                    pending = db.execute(active_count_sql).fetchone()["n"]
                    if pending:
                        triggered = check_monitors_locally(
                            db,
                            mode,
                            capital=_current_capital(db),
                            notify_fn=notify_fn,
                        )
                        if triggered:
                            logger.info("止盈止损本地监控: 触发平仓 %d", triggered)
                    else:
                        pause = max(pause, 5)
                finally:
                    db.close()
            except Exception as exc:
                logger.warning("sl_tp_guard worker: %s", exc)
            time.sleep(pause)

    threading.Thread(target=_loop, daemon=True, name="sl-tp-guard").start()