diff --git a/scripts/backfill_trend_strategy_snapshots.py b/scripts/backfill_trend_strategy_snapshots.py new file mode 100644 index 0000000..151273f --- /dev/null +++ b/scripts/backfill_trend_strategy_snapshots.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +"""补录缺失的趋势回调策略结束快照(strategy_trade_snapshots)。 + +适用:gate_bot 等在计划结束(止盈/止损/手动)时因 strategy_trend_cfg 未注册而漏写快照的历史数据。 +保本移交路径通常已有快照,本脚本默认跳过「已有任意快照」的计划。 + +用法(在仓库根目录): + python scripts/backfill_trend_strategy_snapshots.py \\ + --db crypto_monitor_gate_bot/crypto.db --dry-run + python scripts/backfill_trend_strategy_snapshots.py \\ + --db crypto_monitor_gate_bot/crypto.db --apply +""" +from __future__ import annotations + +import argparse +import sqlite3 +import sys +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from strategy_snapshot_lib import ( # noqa: E402 + STRATEGY_TREND, + init_strategy_snapshot_table, + save_trend_plan_snapshot, +) + +PLAN_STATUS_LABEL = { + "stopped_sl": "止损", + "stopped_tp": "止盈", + "stopped_manual": "手动平仓", + "stopped_handoff": "保本移交", +} + +TRADE_RESULT_LABEL = { + "止损": "止损", + "止盈": "止盈", + "手动平仓": "手动平仓", + "移动止盈": "止盈", + "保本止盈": "止盈", + "强制清仓": "手动平仓", +} + + +def _row_dict(row) -> dict: + if row is None: + return {} + try: + return dict(row) + except Exception: + return {} + + +def infer_exit_price( + direction: str, + entry: float | None, + margin: float | None, + leverage: float | None, + pnl: float | None, +) -> float | None: + """由本地 calc_pnl 口径反推平仓价(供补录快照 exit_price)。""" + try: + trigger = float(entry) + margin_f = float(margin) + lev = float(leverage) + pnl_f = float(pnl) + except (TypeError, ValueError): + return None + if trigger <= 0 or margin_f <= 0 or lev <= 0: + return None + notional = margin_f * lev + if notional <= 0: + return None + ratio = pnl_f / notional + if (direction or "long").strip().lower() == "short": + return round(trigger * (1.0 - ratio), 10) + return round(trigger * (1.0 + ratio), 10) + + +def resolve_result_label(plan: dict, trade: dict | None) -> str: + status = (plan.get("status") or "").strip() + if status in PLAN_STATUS_LABEL: + return PLAN_STATUS_LABEL[status] + if trade: + res = (trade.get("result") or "").strip() + if res in TRADE_RESULT_LABEL: + return TRADE_RESULT_LABEL[res] + if res: + return res + msg = (plan.get("message") or "").strip() + if msg: + return msg[:32] + return "结束" + + +def find_missing_plans( + conn: sqlite3.Connection, + *, + plan_id: int | None = None, + since: str | None = None, +) -> list[dict]: + sql = """ + SELECT p.* + FROM trend_pullback_plans p + WHERE TRIM(COALESCE(p.status, '')) != 'active' + AND NOT EXISTS ( + SELECT 1 FROM strategy_trade_snapshots s + WHERE s.strategy_type = ? AND s.source_id = p.id + ) + """ + params: list[object] = [STRATEGY_TREND] + if plan_id is not None: + sql += " AND p.id = ?" + params.append(int(plan_id)) + if since: + sql += " AND COALESCE(p.opened_at, '') >= ?" + params.append(since.strip()) + sql += " ORDER BY p.id ASC" + rows = conn.execute(sql, params).fetchall() + return [_row_dict(r) for r in rows] + + +def fetch_trade_for_plan(conn: sqlite3.Connection, plan_id: int) -> dict | None: + row = conn.execute( + """ + SELECT * FROM trade_records + WHERE trend_plan_id = ? + ORDER BY COALESCE(closed_at_ms, 0) DESC, id DESC + LIMIT 1 + """, + (int(plan_id),), + ).fetchone() + return _row_dict(row) if row else None + + +def backfill_one(conn: sqlite3.Connection, plan: dict, *, dry_run: bool) -> dict: + plan_id = int(plan["id"]) + trade = fetch_trade_for_plan(conn, plan_id) + result_label = resolve_result_label(plan, trade) + pnl_amount = None + closed_at = None + exit_price = None + entry = plan.get("avg_entry_price") or plan.get("live_price_ref") + margin = plan.get("plan_margin_capital") + leverage = plan.get("leverage") + + if trade: + pnl_amount = trade.get("pnl_amount") + closed_at = trade.get("closed_at") + entry = trade.get("trigger_price") or entry + margin = trade.get("margin_capital") or margin + leverage = trade.get("leverage") or leverage + exit_price = infer_exit_price( + plan.get("direction") or trade.get("direction") or "long", + entry, + margin, + leverage, + pnl_amount, + ) + + info = { + "plan_id": plan_id, + "symbol": plan.get("symbol"), + "status": plan.get("status"), + "result_label": result_label, + "closed_at": closed_at, + "pnl_amount": pnl_amount, + "exit_price": exit_price, + "legs_done": plan.get("legs_done"), + "dca_legs": plan.get("dca_legs"), + "has_trade": bool(trade), + } + + if dry_run: + return info + + save_trend_plan_snapshot( + {}, + conn, + plan, + result_label=result_label, + exit_price=exit_price, + pnl_amount=float(pnl_amount) if pnl_amount is not None else None, + closed_at=closed_at, + ) + return info + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Backfill missing trend_pullback strategy_trade_snapshots rows." + ) + parser.add_argument("--db", required=True, help="Path to instance sqlite db") + parser.add_argument("--plan-id", type=int, help="Only backfill this trend plan id") + parser.add_argument( + "--since", + help="Only plans with opened_at >= YYYY-MM-DD (optional)", + ) + parser.add_argument("--dry-run", action="store_true", help="Preview only (default)") + parser.add_argument("--apply", action="store_true", help="Write snapshots") + args = parser.parse_args() + if not args.dry_run and not args.apply: + args.dry_run = True + + db_path = Path(args.db).expanduser().resolve() + if not db_path.is_file(): + print(f"[ERR] DB not found: {db_path}") + return 1 + + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + init_strategy_snapshot_table(conn) + + missing = find_missing_plans( + conn, plan_id=args.plan_id, since=args.since + ) + if not missing: + print("[INFO] No closed trend plans missing strategy snapshots.") + conn.close() + return 0 + + print(f"[INFO] Found {len(missing)} plan(s) without strategy snapshot.") + applied = 0 + for plan in missing: + info = backfill_one(conn, plan, dry_run=not args.apply) + trade_hint = "有交易记录" if info["has_trade"] else "无交易记录" + print( + f" - plan #{info['plan_id']} {info['symbol']} " + f"status={info['status']} → {info['result_label']} " + f"closed={info['closed_at'] or '—'} pnl={info['pnl_amount']} " + f"补仓 {info['legs_done']}/{info['dca_legs']} ({trade_hint})" + ) + applied += 1 + + if args.apply: + conn.commit() + print(f"[OK] Backfilled {applied} snapshot(s).") + else: + print("[DRY-RUN] No changes written. Re-run with --apply to commit.") + + conn.close() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/strategy_snapshot_lib.py b/strategy_snapshot_lib.py index 513051e..4446fda 100644 --- a/strategy_snapshot_lib.py +++ b/strategy_snapshot_lib.py @@ -127,6 +127,7 @@ def save_trend_plan_snapshot( result_label: str, exit_price: float | None = None, pnl_amount: float | None = None, + closed_at: str | None = None, ) -> None: init_strategy_snapshot_table(conn) row = _row_dict(plan_row) @@ -134,7 +135,7 @@ def save_trend_plan_snapshot( if plan_id <= 0: return m = cfg.get("app_module") - closed_at = ( + close_ts = (closed_at or "").strip() or ( m.app_now_str() if m is not None and hasattr(m, "app_now_str") else datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") @@ -158,10 +159,10 @@ def save_trend_plan_snapshot( result_label, row.get("status"), row.get("opened_at"), - closed_at, + close_ts, pnl_amount, _json_dumps(payload), - closed_at, + close_ts, ), ) prune_strategy_snapshots(conn, keep=STRATEGY_SNAPSHOTS_MAX_ROWS) diff --git a/tests/test_backfill_trend_snapshots.py b/tests/test_backfill_trend_snapshots.py new file mode 100644 index 0000000..18be2a4 --- /dev/null +++ b/tests/test_backfill_trend_snapshots.py @@ -0,0 +1,22 @@ +"""Tests for trend strategy snapshot backfill helpers.""" +from scripts.backfill_trend_strategy_snapshots import ( + infer_exit_price, + resolve_result_label, +) + + +def test_infer_exit_price_short_stop_loss(): + exit_p = infer_exit_price("short", 0.336, 4.85, 10, -2.45) + assert exit_p is not None + assert abs(exit_p - 0.353) < 0.002 + + +def test_resolve_result_label_from_plan_status(): + plan = {"status": "stopped_sl", "message": "stopped_sl"} + assert resolve_result_label(plan, None) == "止损" + + +def test_resolve_result_label_prefers_plan_status(): + plan = {"status": "stopped_sl"} + trade = {"result": "移动止盈"} + assert resolve_result_label(plan, trade) == "止损"