diff --git a/scripts/dedupe_strategy_snapshots.py b/scripts/dedupe_strategy_snapshots.py new file mode 100644 index 0000000..bff8690 --- /dev/null +++ b/scripts/dedupe_strategy_snapshots.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +"""清理 strategy_trade_snapshots 重复行(同计划 + 同结果仅保留 id 最大的一条)。 + +用法(在实例目录,如 crypto_monitor_gate_bot): + python ../scripts/dedupe_strategy_snapshots.py + python ../scripts/dedupe_strategy_snapshots.py --db crypto.db +""" +from __future__ import annotations + +import argparse +import os +import sqlite3 +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) + +from strategy_snapshot_lib import dedupe_strategy_snapshots, init_strategy_snapshot_table # noqa: E402 + + +def main() -> int: + parser = argparse.ArgumentParser(description="Dedupe strategy_trade_snapshots rows.") + parser.add_argument( + "--db", + default=os.getenv("DB_PATH", "crypto.db"), + help="SQLite database path (default: DB_PATH or crypto.db)", + ) + parser.add_argument("--dry-run", action="store_true", help="Count only, do not delete") + args = parser.parse_args() + + db_path = Path(args.db) + if not db_path.is_file(): + print(f"DB not found: {db_path}", file=sys.stderr) + return 1 + + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + init_strategy_snapshot_table(conn) + before = conn.execute("SELECT COUNT(*) AS c FROM strategy_trade_snapshots").fetchone()["c"] + dup_groups = conn.execute( + """SELECT strategy_type, source_id, result_label, COUNT(*) AS n + FROM strategy_trade_snapshots + GROUP BY strategy_type, source_id, result_label + HAVING n > 1 + ORDER BY n DESC""" + ).fetchall() + extra = sum(int(r["n"]) - 1 for r in dup_groups) + print(f"snapshots total={before}, duplicate rows to remove={extra}, groups={len(dup_groups)}") + for r in dup_groups[:20]: + print( + f" {r['strategy_type']} plan={r['source_id']} " + f"{r['result_label']} x{r['n']}" + ) + if args.dry_run: + conn.close() + return 0 + removed = dedupe_strategy_snapshots(conn) + conn.commit() + after = conn.execute("SELECT COUNT(*) AS c FROM strategy_trade_snapshots").fetchone()["c"] + conn.close() + print(f"removed={removed}, remaining={after}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/strategy_records_register.py b/strategy_records_register.py index 0df37ff..5c1a604 100644 --- a/strategy_records_register.py +++ b/strategy_records_register.py @@ -8,6 +8,7 @@ from flask import flash, redirect, url_for from strategy_snapshot_lib import ( STRATEGY_SNAPSHOTS_MAX_ROWS, + dedupe_strategy_snapshots, list_strategy_snapshots_split, ) @@ -15,6 +16,11 @@ from strategy_snapshot_lib import ( def load_strategy_records_page( conn, *, limit: int = STRATEGY_SNAPSHOTS_MAX_ROWS ) -> dict[str, Any]: + try: + if dedupe_strategy_snapshots(conn): + conn.commit() + except Exception: + pass trend, roll, symbols = list_strategy_snapshots_split(conn, limit=limit) return { "strategy_trend_records": trend, diff --git a/strategy_register.py b/strategy_register.py index 26eafef..b736ed4 100644 --- a/strategy_register.py +++ b/strategy_register.py @@ -12,6 +12,29 @@ from strategy_db import init_strategy_tables from strategy_roll_lib import preview_roll +def _dedupe_strategy_snapshots_on_startup(cfg: dict[str, Any]) -> None: + """启动时清理历史重复快照(同计划同结果仅保留最新一条)。""" + get_db = cfg.get("get_db") + if not callable(get_db): + return + try: + from strategy_snapshot_lib import dedupe_strategy_snapshots + + conn = get_db() + try: + removed = dedupe_strategy_snapshots(conn) + if removed: + conn.commit() + print( + f"[strategy] deduped {removed} duplicate strategy_trade_snapshots", + flush=True, + ) + finally: + conn.close() + except Exception as e: + print(f"[strategy] snapshot dedupe skipped: {e}", flush=True) + + def install_strategy_trading(app: Flask, repo_root: str, app_module: Any = None, **build_kw) -> None: """在 app.py 末尾调用(login_required 已定义后)。仅注册 POST API;页面由各 app 的 render_main_page 渲染。""" from strategy_config import build_strategy_config @@ -24,6 +47,7 @@ def install_strategy_trading(app: Flask, repo_root: str, app_module: Any = None, register_strategy_records(app, cfg) app.extensions["strategy_roll_cfg"] = cfg + _dedupe_strategy_snapshots_on_startup(cfg) def attach_strategy_templates(app: Flask, repo_root: str) -> None: diff --git a/strategy_snapshot_lib.py b/strategy_snapshot_lib.py index 77df764..31b107e 100644 --- a/strategy_snapshot_lib.py +++ b/strategy_snapshot_lib.py @@ -119,6 +119,38 @@ def attach_trend_dca_levels(plan: dict) -> dict: return d +def _snapshot_key_exists( + conn, strategy_type: str, source_id: int, result_label: str +) -> bool: + if source_id <= 0: + return False + label = (result_label or "").strip() + row = conn.execute( + """SELECT 1 FROM strategy_trade_snapshots + WHERE strategy_type=? AND source_id=? AND result_label=? + LIMIT 1""", + (strategy_type, int(source_id), label), + ).fetchone() + return row is not None + + +def dedupe_strategy_snapshots(conn) -> int: + """删除同源同结果的重复快照,仅保留每组最大 id。""" + init_strategy_snapshot_table(conn) + cur = conn.execute( + """DELETE FROM strategy_trade_snapshots + WHERE id IN ( + SELECT s1.id FROM strategy_trade_snapshots s1 + INNER JOIN strategy_trade_snapshots s2 + ON s1.strategy_type = s2.strategy_type + AND s1.source_id = s2.source_id + AND s1.result_label = s2.result_label + AND s1.id < s2.id + )""" + ) + return int(getattr(cur, "rowcount", 0) or 0) + + def save_trend_plan_snapshot( cfg: dict, conn, @@ -134,6 +166,9 @@ def save_trend_plan_snapshot( plan_id = int(row.get("id") or 0) if plan_id <= 0: return + label = (result_label or "").strip() + if _snapshot_key_exists(conn, STRATEGY_TREND, plan_id, label): + return m = cfg.get("app_module") close_ts = (closed_at or "").strip() or ( m.app_now_str() @@ -181,6 +216,9 @@ def save_roll_group_snapshot( gid = int(g.get("id") or 0) if gid <= 0: return + label = (result_label or "结束").strip() + if _snapshot_key_exists(conn, STRATEGY_ROLL, gid, label): + return legs = [] for leg in conn.execute( "SELECT * FROM roll_legs WHERE roll_group_id=? ORDER BY leg_index ASC, id ASC", @@ -231,6 +269,7 @@ def save_roll_group_snapshot( def prune_strategy_snapshots(conn, *, keep: int = STRATEGY_SNAPSHOTS_MAX_ROWS) -> None: """仅保留最近 keep 条策略快照(按 closed_at / id 倒序)。""" + dedupe_strategy_snapshots(conn) k = max(1, min(int(keep), 500)) conn.execute( """DELETE FROM strategy_trade_snapshots @@ -360,6 +399,7 @@ def list_strategy_snapshots(conn, *, limit: int = 200) -> list[dict]: (max(1, min(int(limit), 500)),), ).fetchall() out = [] + seen: dict[tuple[str, int, str], int] = {} for r in rows: d = _row_dict(r) try: @@ -368,7 +408,20 @@ def list_strategy_snapshots(conn, *, limit: int = 200) -> list[dict]: d["snapshot"] = {} st = (d.get("strategy_type") or "").strip() d["strategy_label"] = "趋势回调" if st == STRATEGY_TREND else "顺势加仓" - out.append(enrich_strategy_snapshot_row(d)) + enriched = enrich_strategy_snapshot_row(d) + try: + source_id = int(enriched.get("source_id") or 0) + except (TypeError, ValueError): + source_id = 0 + key = (st, source_id, (enriched.get("result_label") or "").strip()) + snap_id = int(enriched.get("id") or 0) + prev = seen.get(key) + if prev is not None and prev >= snap_id: + continue + if prev is not None: + out = [x for x in out if int(x.get("id") or 0) != prev] + seen[key] = snap_id + out.append(enriched) return out diff --git a/strategy_trend_register.py b/strategy_trend_register.py index aad9b65..a054756 100644 --- a/strategy_trend_register.py +++ b/strategy_trend_register.py @@ -667,6 +667,47 @@ def _trend_plan_trade_exists(conn, plan_id: int) -> bool: return False +def _bump_session_capital_no_commit( + m, conn, session_date: str, pnl_amount: float +) -> float | None: + """更新当日资金,不单独 commit(与 _finalize_plan 同一事务)。""" + try: + row = conn.execute( + "SELECT current_capital FROM trading_sessions WHERE session_date = ?", + (session_date,), + ).fetchone() + if not row: + start_cap = float(getattr(m, "DAILY_START_CAPITAL", 0) or 0) + if start_cap <= 0: + ensure = getattr(m, "ensure_session", None) + if callable(ensure): + ensured = ensure(conn, session_date) + row = ensured + else: + return None + else: + conn.execute( + "INSERT OR IGNORE INTO trading_sessions " + "(session_date, start_capital, current_capital) VALUES (?,?,?)", + (session_date, start_cap, start_cap), + ) + row = conn.execute( + "SELECT current_capital FROM trading_sessions WHERE session_date = ?", + (session_date,), + ).fetchone() + if not row: + return None + new_capital = float(row["current_capital"]) + float(pnl_amount) + conn.execute( + "UPDATE trading_sessions SET current_capital = ?, updated_at = CURRENT_TIMESTAMP " + "WHERE session_date = ?", + (round(new_capital, 4), session_date), + ) + return round(new_capital, 4) + except Exception: + return None + + def _finalize_plan(cfg: dict, conn, row, result_label: str, exit_price: float) -> None: m = _m(cfg) plan_id = int(row["id"]) @@ -700,6 +741,13 @@ def _finalize_plan(cfg: dict, conn, row, result_label: str, exit_price: float) - except (TypeError, ValueError): pass planned_rr = m.calc_rr_ratio(direction, avg_e, float(row["stop_loss"]), float(row["take_profit"])) + st = _plan_stop_status(result_label) + cur = conn.execute( + "UPDATE trend_pullback_plans SET status=?, message=? WHERE id=? AND status='active'", + (st, res, plan_id), + ) + if not getattr(cur, "rowcount", 0): + return try: from strategy_snapshot_lib import save_trend_plan_snapshot @@ -710,6 +758,7 @@ def _finalize_plan(cfg: dict, conn, row, result_label: str, exit_price: float) - result_label=result_label, exit_price=float(exit_price) if exit_price is not None else None, pnl_amount=float(pnl_amount) if pnl_amount is not None else None, + closed_at=closed_at, ) except Exception: pass @@ -717,9 +766,12 @@ def _finalize_plan(cfg: dict, conn, row, result_label: str, exit_price: float) - cancel_symbol_orders(cfg, ex_sym) except Exception: pass + session_capital = None if not _trend_plan_trade_exists(conn, plan_id): session_date = row["session_date"] or m.get_trading_day() - session_capital = m.update_session_capital(conn, session_date, pnl_amount) + session_capital = _bump_session_capital_no_commit( + m, conn, session_date, pnl_amount + ) _call_insert_trade_record( m, plan_id, @@ -746,15 +798,6 @@ def _finalize_plan(cfg: dict, conn, row, result_label: str, exit_price: float) - entry_reason=ENTRY_REASON_TREND_PULLBACK, ), ) - else: - session_capital = None - st = _plan_stop_status(result_label) - cur = conn.execute( - "UPDATE trend_pullback_plans SET status=?, message=? WHERE id=? AND status='active'", - (st, res, plan_id), - ) - if not getattr(cur, "rowcount", 0): - return conn.commit() try: from strategy_wechat_notify import notify_trend_plan_ended diff --git a/tests/test_strategy_snapshot_dedup.py b/tests/test_strategy_snapshot_dedup.py new file mode 100644 index 0000000..ef1282c --- /dev/null +++ b/tests/test_strategy_snapshot_dedup.py @@ -0,0 +1,151 @@ +"""策略快照:同一计划同结果不重复写入。""" +from __future__ import annotations + +import json +import sqlite3 +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) + +from strategy_snapshot_lib import ( # noqa: E402 + STRATEGY_TREND, + dedupe_strategy_snapshots, + init_strategy_snapshot_table, + list_strategy_snapshots, + save_trend_plan_snapshot, +) + + +def _mem_conn() -> sqlite3.Connection: + conn = sqlite3.connect(":memory:") + conn.row_factory = sqlite3.Row + init_strategy_snapshot_table(conn) + return conn + + +def test_save_trend_plan_snapshot_skips_duplicate_result(): + conn = _mem_conn() + plan = { + "id": 42, + "symbol": "ONDO/USDT", + "exchange_symbol": "ONDO/USDT:USDT", + "direction": "short", + "status": "active", + "opened_at": "2026-06-08 08:00:00", + "legs_done": 4, + "dca_legs": 4, + "first_order_done": 1, + "grid_prices_json": "[]", + "leg_amounts_json": "[]", + } + cfg = {"app_module": type("M", (), {"app_now_str": staticmethod(lambda: "2026-06-08 08:41:00")})()} + save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.3) + save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.4) + conn.commit() + rows = conn.execute( + "SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=? AND result_label=?", + (42, "止损"), + ).fetchone() + assert int(rows["c"]) == 1 + + +def test_dedupe_strategy_snapshots_handles_many_duplicates(): + conn = _mem_conn() + payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False) + for snap_id in range(1, 46): + conn.execute( + """INSERT INTO strategy_trade_snapshots ( + id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount + ) VALUES (?,?,?,?,?,?,?,?,?)""", + ( + snap_id, + STRATEGY_TREND, + 99, + "ONDO/USDT", + "止损", + payload, + "2026-06-08 08:41:00", + "2026-06-08 08:41:00", + -2.2, + ), + ) + conn.commit() + removed = dedupe_strategy_snapshots(conn) + conn.commit() + assert removed == 44 + row = conn.execute( + "SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=?", + (99,), + ).fetchone() + assert int(row["c"]) == 1 + + +def test_dedupe_strategy_snapshots_keeps_latest_id(): + conn = _mem_conn() + payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False) + for snap_id, pnl in ((1, -2.23), (2, -2.31), (3, -2.38)): + conn.execute( + """INSERT INTO strategy_trade_snapshots ( + id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount + ) VALUES (?,?,?,?,?,?,?,?,?)""", + ( + snap_id, + STRATEGY_TREND, + 5, + "ONDO/USDT", + "止损", + payload, + "2026-06-08 08:41:00", + "2026-06-08 08:41:00", + pnl, + ), + ) + conn.commit() + removed = dedupe_strategy_snapshots(conn) + conn.commit() + assert removed == 2 + row = conn.execute( + "SELECT id, pnl_amount FROM strategy_trade_snapshots WHERE source_id=?", + (5,), + ).fetchone() + assert int(row["id"]) == 3 + assert abs(float(row["pnl_amount"]) - (-2.38)) < 1e-6 + + +def test_list_strategy_snapshots_hides_duplicate_keys(): + conn = _mem_conn() + payload = json.dumps({"symbol": "ONDO/USDT", "dca_levels": []}, ensure_ascii=False) + for snap_id in (10, 11, 12): + conn.execute( + """INSERT INTO strategy_trade_snapshots ( + id, strategy_type, source_id, symbol, direction, result_label, + snapshot_json, closed_at, created_at, pnl_amount + ) VALUES (?,?,?,?,?,?,?,?,?,?)""", + ( + snap_id, + STRATEGY_TREND, + 7, + "ONDO/USDT", + "short", + "止损", + payload, + "2026-06-08 08:41:00", + "2026-06-08 08:41:00", + -2.2, + ), + ) + conn.commit() + rows = list_strategy_snapshots(conn, limit=50) + stop_rows = [r for r in rows if int(r.get("source_id") or 0) == 7] + assert len(stop_rows) == 1 + assert int(stop_rows[0]["id"]) == 12 + + +if __name__ == "__main__": + test_save_trend_plan_snapshot_skips_duplicate_result() + test_dedupe_strategy_snapshots_handles_many_duplicates() + test_dedupe_strategy_snapshots_keeps_latest_id() + test_list_strategy_snapshots_hides_duplicate_keys() + print("all ok")