Prevent duplicate strategy trade snapshots on plan close.

Finalize plans before writing snapshots, dedupe on startup and page load, and add a cleanup script for existing repeated rows.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-08 09:00:51 +08:00
parent ea92160d54
commit e71bfe095c
6 changed files with 355 additions and 11 deletions
+67
View File
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""清理 strategy_trade_snapshots 重复行(同计划 + 同结果仅保留 id 最大的一条)。
用法(在实例目录,如 crypto_monitor_gate_bot):
python ../scripts/dedupe_strategy_snapshots.py
python ../scripts/dedupe_strategy_snapshots.py --db crypto.db
"""
from __future__ import annotations
import argparse
import os
import sqlite3
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from strategy_snapshot_lib import dedupe_strategy_snapshots, init_strategy_snapshot_table # noqa: E402
def main() -> int:
parser = argparse.ArgumentParser(description="Dedupe strategy_trade_snapshots rows.")
parser.add_argument(
"--db",
default=os.getenv("DB_PATH", "crypto.db"),
help="SQLite database path (default: DB_PATH or crypto.db)",
)
parser.add_argument("--dry-run", action="store_true", help="Count only, do not delete")
args = parser.parse_args()
db_path = Path(args.db)
if not db_path.is_file():
print(f"DB not found: {db_path}", file=sys.stderr)
return 1
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
init_strategy_snapshot_table(conn)
before = conn.execute("SELECT COUNT(*) AS c FROM strategy_trade_snapshots").fetchone()["c"]
dup_groups = conn.execute(
"""SELECT strategy_type, source_id, result_label, COUNT(*) AS n
FROM strategy_trade_snapshots
GROUP BY strategy_type, source_id, result_label
HAVING n > 1
ORDER BY n DESC"""
).fetchall()
extra = sum(int(r["n"]) - 1 for r in dup_groups)
print(f"snapshots total={before}, duplicate rows to remove={extra}, groups={len(dup_groups)}")
for r in dup_groups[:20]:
print(
f" {r['strategy_type']} plan={r['source_id']} "
f"{r['result_label']} x{r['n']}"
)
if args.dry_run:
conn.close()
return 0
removed = dedupe_strategy_snapshots(conn)
conn.commit()
after = conn.execute("SELECT COUNT(*) AS c FROM strategy_trade_snapshots").fetchone()["c"]
conn.close()
print(f"removed={removed}, remaining={after}")
return 0
if __name__ == "__main__":
raise SystemExit(main())
+6
View File
@@ -8,6 +8,7 @@ from flask import flash, redirect, url_for
from strategy_snapshot_lib import (
STRATEGY_SNAPSHOTS_MAX_ROWS,
dedupe_strategy_snapshots,
list_strategy_snapshots_split,
)
@@ -15,6 +16,11 @@ from strategy_snapshot_lib import (
def load_strategy_records_page(
conn, *, limit: int = STRATEGY_SNAPSHOTS_MAX_ROWS
) -> dict[str, Any]:
try:
if dedupe_strategy_snapshots(conn):
conn.commit()
except Exception:
pass
trend, roll, symbols = list_strategy_snapshots_split(conn, limit=limit)
return {
"strategy_trend_records": trend,
+24
View File
@@ -12,6 +12,29 @@ from strategy_db import init_strategy_tables
from strategy_roll_lib import preview_roll
def _dedupe_strategy_snapshots_on_startup(cfg: dict[str, Any]) -> None:
"""启动时清理历史重复快照(同计划同结果仅保留最新一条)。"""
get_db = cfg.get("get_db")
if not callable(get_db):
return
try:
from strategy_snapshot_lib import dedupe_strategy_snapshots
conn = get_db()
try:
removed = dedupe_strategy_snapshots(conn)
if removed:
conn.commit()
print(
f"[strategy] deduped {removed} duplicate strategy_trade_snapshots",
flush=True,
)
finally:
conn.close()
except Exception as e:
print(f"[strategy] snapshot dedupe skipped: {e}", flush=True)
def install_strategy_trading(app: Flask, repo_root: str, app_module: Any = None, **build_kw) -> None:
"""在 app.py 末尾调用(login_required 已定义后)。仅注册 POST API;页面由各 app 的 render_main_page 渲染。"""
from strategy_config import build_strategy_config
@@ -24,6 +47,7 @@ def install_strategy_trading(app: Flask, repo_root: str, app_module: Any = None,
register_strategy_records(app, cfg)
app.extensions["strategy_roll_cfg"] = cfg
_dedupe_strategy_snapshots_on_startup(cfg)
def attach_strategy_templates(app: Flask, repo_root: str) -> None:
+54 -1
View File
@@ -119,6 +119,38 @@ def attach_trend_dca_levels(plan: dict) -> dict:
return d
def _snapshot_key_exists(
conn, strategy_type: str, source_id: int, result_label: str
) -> bool:
if source_id <= 0:
return False
label = (result_label or "").strip()
row = conn.execute(
"""SELECT 1 FROM strategy_trade_snapshots
WHERE strategy_type=? AND source_id=? AND result_label=?
LIMIT 1""",
(strategy_type, int(source_id), label),
).fetchone()
return row is not None
def dedupe_strategy_snapshots(conn) -> int:
"""删除同源同结果的重复快照,仅保留每组最大 id。"""
init_strategy_snapshot_table(conn)
cur = conn.execute(
"""DELETE FROM strategy_trade_snapshots
WHERE id IN (
SELECT s1.id FROM strategy_trade_snapshots s1
INNER JOIN strategy_trade_snapshots s2
ON s1.strategy_type = s2.strategy_type
AND s1.source_id = s2.source_id
AND s1.result_label = s2.result_label
AND s1.id < s2.id
)"""
)
return int(getattr(cur, "rowcount", 0) or 0)
def save_trend_plan_snapshot(
cfg: dict,
conn,
@@ -134,6 +166,9 @@ def save_trend_plan_snapshot(
plan_id = int(row.get("id") or 0)
if plan_id <= 0:
return
label = (result_label or "").strip()
if _snapshot_key_exists(conn, STRATEGY_TREND, plan_id, label):
return
m = cfg.get("app_module")
close_ts = (closed_at or "").strip() or (
m.app_now_str()
@@ -181,6 +216,9 @@ def save_roll_group_snapshot(
gid = int(g.get("id") or 0)
if gid <= 0:
return
label = (result_label or "结束").strip()
if _snapshot_key_exists(conn, STRATEGY_ROLL, gid, label):
return
legs = []
for leg in conn.execute(
"SELECT * FROM roll_legs WHERE roll_group_id=? ORDER BY leg_index ASC, id ASC",
@@ -231,6 +269,7 @@ def save_roll_group_snapshot(
def prune_strategy_snapshots(conn, *, keep: int = STRATEGY_SNAPSHOTS_MAX_ROWS) -> None:
"""仅保留最近 keep 条策略快照(按 closed_at / id 倒序)。"""
dedupe_strategy_snapshots(conn)
k = max(1, min(int(keep), 500))
conn.execute(
"""DELETE FROM strategy_trade_snapshots
@@ -360,6 +399,7 @@ def list_strategy_snapshots(conn, *, limit: int = 200) -> list[dict]:
(max(1, min(int(limit), 500)),),
).fetchall()
out = []
seen: dict[tuple[str, int, str], int] = {}
for r in rows:
d = _row_dict(r)
try:
@@ -368,7 +408,20 @@ def list_strategy_snapshots(conn, *, limit: int = 200) -> list[dict]:
d["snapshot"] = {}
st = (d.get("strategy_type") or "").strip()
d["strategy_label"] = "趋势回调" if st == STRATEGY_TREND else "顺势加仓"
out.append(enrich_strategy_snapshot_row(d))
enriched = enrich_strategy_snapshot_row(d)
try:
source_id = int(enriched.get("source_id") or 0)
except (TypeError, ValueError):
source_id = 0
key = (st, source_id, (enriched.get("result_label") or "").strip())
snap_id = int(enriched.get("id") or 0)
prev = seen.get(key)
if prev is not None and prev >= snap_id:
continue
if prev is not None:
out = [x for x in out if int(x.get("id") or 0) != prev]
seen[key] = snap_id
out.append(enriched)
return out
+53 -10
View File
@@ -667,6 +667,47 @@ def _trend_plan_trade_exists(conn, plan_id: int) -> bool:
return False
def _bump_session_capital_no_commit(
m, conn, session_date: str, pnl_amount: float
) -> float | None:
"""更新当日资金,不单独 commit(与 _finalize_plan 同一事务)。"""
try:
row = conn.execute(
"SELECT current_capital FROM trading_sessions WHERE session_date = ?",
(session_date,),
).fetchone()
if not row:
start_cap = float(getattr(m, "DAILY_START_CAPITAL", 0) or 0)
if start_cap <= 0:
ensure = getattr(m, "ensure_session", None)
if callable(ensure):
ensured = ensure(conn, session_date)
row = ensured
else:
return None
else:
conn.execute(
"INSERT OR IGNORE INTO trading_sessions "
"(session_date, start_capital, current_capital) VALUES (?,?,?)",
(session_date, start_cap, start_cap),
)
row = conn.execute(
"SELECT current_capital FROM trading_sessions WHERE session_date = ?",
(session_date,),
).fetchone()
if not row:
return None
new_capital = float(row["current_capital"]) + float(pnl_amount)
conn.execute(
"UPDATE trading_sessions SET current_capital = ?, updated_at = CURRENT_TIMESTAMP "
"WHERE session_date = ?",
(round(new_capital, 4), session_date),
)
return round(new_capital, 4)
except Exception:
return None
def _finalize_plan(cfg: dict, conn, row, result_label: str, exit_price: float) -> None:
m = _m(cfg)
plan_id = int(row["id"])
@@ -700,6 +741,13 @@ def _finalize_plan(cfg: dict, conn, row, result_label: str, exit_price: float) -
except (TypeError, ValueError):
pass
planned_rr = m.calc_rr_ratio(direction, avg_e, float(row["stop_loss"]), float(row["take_profit"]))
st = _plan_stop_status(result_label)
cur = conn.execute(
"UPDATE trend_pullback_plans SET status=?, message=? WHERE id=? AND status='active'",
(st, res, plan_id),
)
if not getattr(cur, "rowcount", 0):
return
try:
from strategy_snapshot_lib import save_trend_plan_snapshot
@@ -710,6 +758,7 @@ def _finalize_plan(cfg: dict, conn, row, result_label: str, exit_price: float) -
result_label=result_label,
exit_price=float(exit_price) if exit_price is not None else None,
pnl_amount=float(pnl_amount) if pnl_amount is not None else None,
closed_at=closed_at,
)
except Exception:
pass
@@ -717,9 +766,12 @@ def _finalize_plan(cfg: dict, conn, row, result_label: str, exit_price: float) -
cancel_symbol_orders(cfg, ex_sym)
except Exception:
pass
session_capital = None
if not _trend_plan_trade_exists(conn, plan_id):
session_date = row["session_date"] or m.get_trading_day()
session_capital = m.update_session_capital(conn, session_date, pnl_amount)
session_capital = _bump_session_capital_no_commit(
m, conn, session_date, pnl_amount
)
_call_insert_trade_record(
m,
plan_id,
@@ -746,15 +798,6 @@ def _finalize_plan(cfg: dict, conn, row, result_label: str, exit_price: float) -
entry_reason=ENTRY_REASON_TREND_PULLBACK,
),
)
else:
session_capital = None
st = _plan_stop_status(result_label)
cur = conn.execute(
"UPDATE trend_pullback_plans SET status=?, message=? WHERE id=? AND status='active'",
(st, res, plan_id),
)
if not getattr(cur, "rowcount", 0):
return
conn.commit()
try:
from strategy_wechat_notify import notify_trend_plan_ended
+151
View File
@@ -0,0 +1,151 @@
"""策略快照:同一计划同结果不重复写入。"""
from __future__ import annotations
import json
import sqlite3
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from strategy_snapshot_lib import ( # noqa: E402
STRATEGY_TREND,
dedupe_strategy_snapshots,
init_strategy_snapshot_table,
list_strategy_snapshots,
save_trend_plan_snapshot,
)
def _mem_conn() -> sqlite3.Connection:
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
init_strategy_snapshot_table(conn)
return conn
def test_save_trend_plan_snapshot_skips_duplicate_result():
conn = _mem_conn()
plan = {
"id": 42,
"symbol": "ONDO/USDT",
"exchange_symbol": "ONDO/USDT:USDT",
"direction": "short",
"status": "active",
"opened_at": "2026-06-08 08:00:00",
"legs_done": 4,
"dca_legs": 4,
"first_order_done": 1,
"grid_prices_json": "[]",
"leg_amounts_json": "[]",
}
cfg = {"app_module": type("M", (), {"app_now_str": staticmethod(lambda: "2026-06-08 08:41:00")})()}
save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.3)
save_trend_plan_snapshot(cfg, conn, plan, result_label="止损", pnl_amount=-2.4)
conn.commit()
rows = conn.execute(
"SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=? AND result_label=?",
(42, "止损"),
).fetchone()
assert int(rows["c"]) == 1
def test_dedupe_strategy_snapshots_handles_many_duplicates():
conn = _mem_conn()
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
for snap_id in range(1, 46):
conn.execute(
"""INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
) VALUES (?,?,?,?,?,?,?,?,?)""",
(
snap_id,
STRATEGY_TREND,
99,
"ONDO/USDT",
"止损",
payload,
"2026-06-08 08:41:00",
"2026-06-08 08:41:00",
-2.2,
),
)
conn.commit()
removed = dedupe_strategy_snapshots(conn)
conn.commit()
assert removed == 44
row = conn.execute(
"SELECT COUNT(*) AS c FROM strategy_trade_snapshots WHERE source_id=?",
(99,),
).fetchone()
assert int(row["c"]) == 1
def test_dedupe_strategy_snapshots_keeps_latest_id():
conn = _mem_conn()
payload = json.dumps({"symbol": "ONDO/USDT"}, ensure_ascii=False)
for snap_id, pnl in ((1, -2.23), (2, -2.31), (3, -2.38)):
conn.execute(
"""INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, result_label, snapshot_json, closed_at, created_at, pnl_amount
) VALUES (?,?,?,?,?,?,?,?,?)""",
(
snap_id,
STRATEGY_TREND,
5,
"ONDO/USDT",
"止损",
payload,
"2026-06-08 08:41:00",
"2026-06-08 08:41:00",
pnl,
),
)
conn.commit()
removed = dedupe_strategy_snapshots(conn)
conn.commit()
assert removed == 2
row = conn.execute(
"SELECT id, pnl_amount FROM strategy_trade_snapshots WHERE source_id=?",
(5,),
).fetchone()
assert int(row["id"]) == 3
assert abs(float(row["pnl_amount"]) - (-2.38)) < 1e-6
def test_list_strategy_snapshots_hides_duplicate_keys():
conn = _mem_conn()
payload = json.dumps({"symbol": "ONDO/USDT", "dca_levels": []}, ensure_ascii=False)
for snap_id in (10, 11, 12):
conn.execute(
"""INSERT INTO strategy_trade_snapshots (
id, strategy_type, source_id, symbol, direction, result_label,
snapshot_json, closed_at, created_at, pnl_amount
) VALUES (?,?,?,?,?,?,?,?,?,?)""",
(
snap_id,
STRATEGY_TREND,
7,
"ONDO/USDT",
"short",
"止损",
payload,
"2026-06-08 08:41:00",
"2026-06-08 08:41:00",
-2.2,
),
)
conn.commit()
rows = list_strategy_snapshots(conn, limit=50)
stop_rows = [r for r in rows if int(r.get("source_id") or 0) == 7]
assert len(stop_rows) == 1
assert int(stop_rows[0]["id"]) == 12
if __name__ == "__main__":
test_save_trend_plan_snapshot_skips_duplicate_result()
test_dedupe_strategy_snapshots_handles_many_duplicates()
test_dedupe_strategy_snapshots_keeps_latest_id()
test_list_strategy_snapshots_hides_duplicate_keys()
print("all ok")