Fix trade log equity_after to chain from initial capital by close time.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-30 21:41:37 +08:00
parent 8d2d09396b
commit 0b924fca87
4 changed files with 74 additions and 13 deletions
+10 -1
View File
@@ -1241,7 +1241,7 @@ def close_position(pid):
result = classify_close_result(direction, close_price, sl, tp) result = classify_close_result(direction, close_price, sl, tp)
minutes = holding_to_minutes(open_time, close_time) minutes = holding_to_minutes(open_time, close_time)
margin_pct = metrics.get("position_pct") margin_pct = metrics.get("position_pct")
from trade_log_lib import calc_equity_after from trade_log_lib import calc_equity_after, refresh_trade_log_equity_chain
equity_after = calc_equity_after(capital, pnl_net) equity_after = calc_equity_after(capital, pnl_net)
conn.execute( conn.execute(
"""INSERT INTO trade_logs """INSERT INTO trade_logs
@@ -1258,6 +1258,10 @@ def close_position(pid):
), ),
) )
conn.execute("DELETE FROM position_monitors WHERE id=?", (pid,)) conn.execute("DELETE FROM position_monitors WHERE id=?", (pid,))
try:
refresh_trade_log_equity_chain(conn, capital if capital > 0 else None)
except Exception as exc:
app.logger.debug("equity chain refresh after close: %s", exc)
conn.commit() conn.commit()
conn.close() conn.close()
touch_stats_cache() touch_stats_cache()
@@ -1301,6 +1305,11 @@ def update_trade(tid):
tid, tid,
), ),
) )
try:
cap = float(get_setting("live_capital", "0") or 0)
refresh_trade_log_equity_chain(conn, cap if cap > 0 else None)
except Exception as exc:
app.logger.debug("equity chain refresh after trade edit: %s", exc)
conn.commit() conn.commit()
conn.close() conn.close()
touch_stats_cache() touch_stats_cache()
+10 -1
View File
@@ -16,7 +16,12 @@ from contract_specs import calc_position_metrics
from ctp_symbol import ths_to_vnpy_symbol from ctp_symbol import ths_to_vnpy_symbol
from fee_specs import calc_round_trip_fee from fee_specs import calc_round_trip_fee
from symbols import ths_to_codes from symbols import ths_to_codes
from trade_log_lib import calc_equity_after, purge_duplicate_local_trade_logs, ensure_trade_log_columns from trade_log_lib import (
calc_equity_after,
purge_duplicate_local_trade_logs,
ensure_trade_log_columns,
refresh_trade_log_equity_chain,
)
from vnpy_bridge import ctp_list_trades, ctp_status from vnpy_bridge import ctp_list_trades, ctp_status
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -325,4 +330,8 @@ def sync_trade_logs_from_ctp(
purged = purge_duplicate_local_trade_logs(conn) purged = purge_duplicate_local_trade_logs(conn)
if purged: if purged:
stats["purged"] = purged stats["purged"] = purged
try:
refresh_trade_log_equity_chain(conn)
except Exception as exc:
logger.debug("equity chain refresh after ctp sync: %s", exc)
return stats return stats
+5 -1
View File
@@ -16,7 +16,7 @@ from zoneinfo import ZoneInfo
from contract_specs import calc_position_metrics from contract_specs import calc_position_metrics
from ctp_symbol import ths_to_vnpy_symbol from ctp_symbol import ths_to_vnpy_symbol
from fee_specs import calc_round_trip_fee from fee_specs import calc_round_trip_fee
from trade_log_lib import calc_equity_after from trade_log_lib import calc_equity_after, refresh_trade_log_equity_chain
from market_sessions import is_trading_session from market_sessions import is_trading_session
from symbols import ths_to_codes from symbols import ths_to_codes
from vnpy_bridge import ( from vnpy_bridge import (
@@ -336,6 +336,10 @@ def write_trade_log(
result if result in TRADE_RESULTS else "手动平仓", result if result in TRADE_RESULTS else "手动平仓",
), ),
) )
try:
refresh_trade_log_equity_chain(conn, capital if capital > 0 else None)
except Exception as exc:
logger.debug("equity chain refresh after trade log: %s", exc)
try: try:
from stats_engine import refresh_stats_cache from stats_engine import refresh_stats_cache
refresh_stats_cache(conn, capital) refresh_stats_cache(conn, capital)
+49 -10
View File
@@ -32,6 +32,42 @@ def calc_equity_after(capital: float, pnl_net: float) -> float | None:
return round(cap + float(pnl_net or 0), 2) return round(cap + float(pnl_net or 0), 2)
def _read_initial_capital(conn, initial_capital: float | None = None) -> float:
if initial_capital is not None and initial_capital > 0:
return float(initial_capital)
try:
row = conn.execute("SELECT value FROM settings WHERE key='live_capital'").fetchone()
return float(row[0] or 0) if row else 0.0
except (TypeError, ValueError):
return 0.0
def refresh_trade_log_equity_chain(
conn,
initial_capital: float | None = None,
) -> int:
"""按平仓时间顺序重算 trade_logs.equity_after(起始=参考资金 live_capital)。"""
base = _read_initial_capital(conn, initial_capital)
rows = [
dict(r)
for r in conn.execute(
"SELECT id, close_time, pnl_net FROM trade_logs ORDER BY close_time ASC, id ASC"
).fetchall()
]
running = float(base or 0)
updated = 0
for row in rows:
if running <= 0:
break
running = round(running + float(row.get("pnl_net") or 0), 2)
conn.execute(
"UPDATE trade_logs SET equity_after=? WHERE id=?",
(running, int(row["id"])),
)
updated += 1
return updated
def _norm_symbol(symbol: str) -> str: def _norm_symbol(symbol: str) -> str:
return (symbol or "").split(".")[0].strip().lower() return (symbol or "").split(".")[0].strip().lower()
@@ -105,23 +141,21 @@ def enrich_trades_for_records(
) )
running = float(initial_capital or 0) running = float(initial_capital or 0)
curve: list[dict[str, Any]] = [] curve: list[dict[str, Any]] = []
equity_by_id: dict[int, float | None] = {}
for t in chrono: for t in chrono:
_attach_symbol_meta(t) _attach_symbol_meta(t)
pnl_net = float(t.get("pnl_net") or 0) pnl_net = float(t.get("pnl_net") or 0)
eq = t.get("equity_after") if running > 0:
if eq is None: running = round(running + pnl_net, 2)
if running > 0: eq: float | None = running
eq = round(running + pnl_net, 2) else:
else: eq = None
eq = None equity_by_id[int(t.get("id") or 0)] = eq
t["equity_after"] = eq
if eq is not None:
running = float(eq)
cap_before = float(eq or 0) - pnl_net if eq is not None else 0.0
if t.get("margin_pct") is None: if t.get("margin_pct") is None:
margin = float(t.get("margin") or 0) margin = float(t.get("margin") or 0)
cap_before = float(eq or 0) - pnl_net if eq is not None else 0.0
if margin > 0 and cap_before > 0: if margin > 0 and cap_before > 0:
t["margin_pct"] = round(margin / cap_before * 100, 2) t["margin_pct"] = round(margin / cap_before * 100, 2)
@@ -132,4 +166,9 @@ def enrich_trades_for_records(
"id": int(t.get("id") or 0), "id": int(t.get("id") or 0),
}) })
for t in rows:
tid = int(t.get("id") or 0)
if tid in equity_by_id:
t["equity_after"] = equity_by_id[tid]
return rows, curve return rows, curve