首次上传

This commit is contained in:
dekun
2026-05-16 22:25:48 +08:00
commit 2b8f902548
88 changed files with 16386 additions and 0 deletions
@@ -0,0 +1,428 @@
from __future__ import annotations
import argparse
import csv
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Iterable
@dataclass
class Bar:
ts: str
open: float
high: float
low: float
close: float
@dataclass
class Trade:
side: str # LONG | SHORT
entry_ts: str
entry_price: float
exit_ts: str
exit_price: float
reason: str
gross_return_pct: float
net_return_pct: float
@dataclass
class BarWithEpoch:
bar: Bar
ts_epoch: int
def _to_float(row: dict[str, str], key: str) -> float:
raw = (row.get(key) or "").strip()
if not raw:
raise ValueError(f"Empty numeric field: {key}")
return float(raw)
def _read_csv(
csv_path: Path,
ts_col: str,
open_col: str,
high_col: str,
low_col: str,
close_col: str,
) -> list[Bar]:
bars: list[Bar] = []
with csv_path.open("r", encoding="utf-8-sig", newline="") as f:
reader = csv.DictReader(f)
needed = {ts_col, open_col, high_col, low_col, close_col}
missing = [c for c in needed if c not in (reader.fieldnames or [])]
if missing:
raise ValueError(f"CSV 缺少列: {missing}. 当前列: {reader.fieldnames}")
for row in reader:
bars.append(
Bar(
ts=str(row[ts_col]),
open=_to_float(row, open_col),
high=_to_float(row, high_col),
low=_to_float(row, low_col),
close=_to_float(row, close_col),
)
)
if len(bars) < 200:
raise ValueError(f"数据量过少: {len(bars)} 行,无法可靠回测。")
return bars
def _parse_ts_to_epoch_seconds(ts_raw: str) -> int:
s = str(ts_raw).strip()
if not s:
raise ValueError("timestamp is empty")
if s.isdigit() or (s.startswith("-") and s[1:].isdigit()):
n = int(s)
# 13 digits => milliseconds
if abs(n) >= 10_000_000_000:
return int(n / 1000)
return n
s_norm = s.replace("Z", "+00:00")
try:
return int(datetime.fromisoformat(s_norm).timestamp())
except ValueError as exc:
raise ValueError(f"Unsupported timestamp format: {s}") from exc
def _sort_and_attach_epoch(bars: list[Bar]) -> list[BarWithEpoch]:
enriched = [BarWithEpoch(bar=b, ts_epoch=_parse_ts_to_epoch_seconds(b.ts)) for b in bars]
enriched.sort(key=lambda x: x.ts_epoch)
return enriched
def _aggregate_bars(bars_1m: list[Bar], timeframe_minutes: int) -> list[Bar]:
if timeframe_minutes <= 1:
return bars_1m
src = _sort_and_attach_epoch(bars_1m)
if not src:
return []
out: list[Bar] = []
bucket_sec = timeframe_minutes * 60
cur_bucket = None
agg_open = agg_high = agg_low = agg_close = 0.0
agg_ts = ""
for item in src:
b = item.bar
bucket = (item.ts_epoch // bucket_sec) * bucket_sec
if cur_bucket is None or bucket != cur_bucket:
if cur_bucket is not None:
out.append(Bar(ts=agg_ts, open=agg_open, high=agg_high, low=agg_low, close=agg_close))
cur_bucket = bucket
agg_open = b.open
agg_high = b.high
agg_low = b.low
agg_close = b.close
agg_ts = datetime.utcfromtimestamp(bucket).isoformat() + "Z"
else:
agg_high = max(agg_high, b.high)
agg_low = min(agg_low, b.low)
agg_close = b.close
if cur_bucket is not None:
out.append(Bar(ts=agg_ts, open=agg_open, high=agg_high, low=agg_low, close=agg_close))
return out
def _parse_timeframe_to_minutes(tf: str) -> int:
s = tf.strip().lower()
if s.endswith("m"):
return int(s[:-1])
if s.endswith("h"):
return int(s[:-1]) * 60
raise ValueError(f"Unsupported timeframe: {tf}. Use like 15m,30m,1h")
def _calc_stats(equity_curve: Iterable[float], trades: list[Trade], initial_capital: float) -> dict[str, float]:
curve = list(equity_curve)
if not curve:
return {}
final_capital = curve[-1]
total_return_pct = (final_capital / initial_capital - 1.0) * 100.0
peak = curve[0]
max_dd = 0.0
for eq in curve:
if eq > peak:
peak = eq
dd = (eq / peak - 1.0) * 100.0
if dd < max_dd:
max_dd = dd
wins = [t for t in trades if t.net_return_pct > 0]
losses = [t for t in trades if t.net_return_pct <= 0]
win_rate = (len(wins) / len(trades) * 100.0) if trades else 0.0
avg_win = sum(t.net_return_pct for t in wins) / len(wins) if wins else 0.0
avg_loss = sum(t.net_return_pct for t in losses) / len(losses) if losses else 0.0
profit_factor = (
abs(sum(t.net_return_pct for t in wins) / sum(t.net_return_pct for t in losses))
if losses and sum(t.net_return_pct for t in losses) != 0
else 0.0
)
return {
"initial_capital": initial_capital,
"final_capital": final_capital,
"total_return_pct": total_return_pct,
"max_drawdown_pct": max_dd,
"total_trades": float(len(trades)),
"win_rate_pct": win_rate,
"avg_win_pct": avg_win,
"avg_loss_pct": avg_loss,
"profit_factor": profit_factor,
}
def run_backtest(
bars: list[Bar],
box_len: int,
buf_pct: float,
min_box_pct: float,
sl_pct: float,
tp_pct: float,
commission_pct: float,
initial_capital: float,
) -> tuple[list[Trade], list[float]]:
high_win: deque[float] = deque(maxlen=box_len)
low_win: deque[float] = deque(maxlen=box_len)
close_hist: list[float] = []
trades: list[Trade] = []
equity_curve: list[float] = [initial_capital]
capital = initial_capital
position = 0 # 1 long, -1 short, 0 flat
entry_price = 0.0
entry_ts = ""
for i, bar in enumerate(bars):
# Build history first
close_hist.append(bar.close)
if i == 0:
high_win.append(bar.high)
low_win.append(bar.low)
continue
# Exit check (intrabar, after entry bar)
if position != 0:
if position == 1:
stop = entry_price * (1 - sl_pct / 100.0)
take = entry_price * (1 + tp_pct / 100.0)
exit_price = 0.0
reason = ""
# Conservative tie-break: stop first if both touched same bar
if bar.low <= stop:
exit_price, reason = stop, "SL"
elif bar.high >= take:
exit_price, reason = take, "TP"
if reason:
gross_ret = (exit_price / entry_price - 1.0) * 100.0
net_ret = gross_ret - 2 * commission_pct
capital *= 1 + net_ret / 100.0
trades.append(
Trade("LONG", entry_ts, entry_price, bar.ts, exit_price, reason, gross_ret, net_ret)
)
equity_curve.append(capital)
position = 0
elif position == -1:
stop = entry_price * (1 + sl_pct / 100.0)
take = entry_price * (1 - tp_pct / 100.0)
exit_price = 0.0
reason = ""
if bar.high >= stop:
exit_price, reason = stop, "SL"
elif bar.low <= take:
exit_price, reason = take, "TP"
if reason:
gross_ret = (entry_price / exit_price - 1.0) * 100.0
net_ret = gross_ret - 2 * commission_pct
capital *= 1 + net_ret / 100.0
trades.append(
Trade("SHORT", entry_ts, entry_price, bar.ts, exit_price, reason, gross_ret, net_ret)
)
equity_curve.append(capital)
position = 0
# Need full lookback and previous close for crossover.
if len(high_win) < box_len or len(low_win) < box_len or i < 2:
high_win.append(bar.high)
low_win.append(bar.low)
continue
box_high = max(high_win)
box_low = min(low_win)
box_mid = (box_high + box_low) / 2.0
box_pct = ((box_high - box_low) / box_mid * 100.0) if box_mid > 0 else 0.0
box_ok = box_pct >= min_box_pct
up_line = box_high * (1 + buf_pct / 100.0)
dn_line = box_low * (1 - buf_pct / 100.0)
prev_close = close_hist[-2]
long_trig = box_ok and prev_close <= up_line and bar.close > up_line
short_trig = box_ok and prev_close >= dn_line and bar.close < dn_line
# Reverse signal close at close price then flip.
if position == 1 and short_trig:
gross_ret = (bar.close / entry_price - 1.0) * 100.0
net_ret = gross_ret - 2 * commission_pct
capital *= 1 + net_ret / 100.0
trades.append(Trade("LONG", entry_ts, entry_price, bar.ts, bar.close, "REVERSE", gross_ret, net_ret))
equity_curve.append(capital)
position = 0
elif position == -1 and long_trig:
gross_ret = (entry_price / bar.close - 1.0) * 100.0
net_ret = gross_ret - 2 * commission_pct
capital *= 1 + net_ret / 100.0
trades.append(Trade("SHORT", entry_ts, entry_price, bar.ts, bar.close, "REVERSE", gross_ret, net_ret))
equity_curve.append(capital)
position = 0
if position == 0:
if long_trig:
position = 1
entry_price = bar.close
entry_ts = bar.ts
elif short_trig:
position = -1
entry_price = bar.close
entry_ts = bar.ts
high_win.append(bar.high)
low_win.append(bar.low)
# Force close at final close
if position != 0:
last = bars[-1]
if position == 1:
gross_ret = (last.close / entry_price - 1.0) * 100.0
side = "LONG"
else:
gross_ret = (entry_price / last.close - 1.0) * 100.0
side = "SHORT"
net_ret = gross_ret - 2 * commission_pct
capital *= 1 + net_ret / 100.0
trades.append(Trade(side, entry_ts, entry_price, last.ts, last.close, "FORCE_CLOSE", gross_ret, net_ret))
equity_curve.append(capital)
return trades, equity_curve
def _save_trades(path: Path, trades: list[Trade]) -> None:
with path.open("w", encoding="utf-8", newline="") as f:
w = csv.writer(f)
w.writerow(
[
"side",
"entry_ts",
"entry_price",
"exit_ts",
"exit_price",
"reason",
"gross_return_pct",
"net_return_pct",
]
)
for t in trades:
w.writerow(
[
t.side,
t.entry_ts,
f"{t.entry_price:.8f}",
t.exit_ts,
f"{t.exit_price:.8f}",
t.reason,
f"{t.gross_return_pct:.6f}",
f"{t.net_return_pct:.6f}",
]
)
def main() -> None:
parser = argparse.ArgumentParser(description="ETH 1m 裸K箱体突破回测")
parser.add_argument("--csv", required=True, help="K线 CSV 路径")
parser.add_argument("--ts-col", default="timestamp", help="时间列名")
parser.add_argument("--open-col", default="open", help="开盘列名")
parser.add_argument("--high-col", default="high", help="最高列名")
parser.add_argument("--low-col", default="low", help="最低列名")
parser.add_argument("--close-col", default="close", help="收盘列名")
parser.add_argument("--box-len", type=int, default=80, help="箱体回看K数")
parser.add_argument("--buf-pct", type=float, default=0.03, help="突破缓冲百分比")
parser.add_argument("--min-box-pct", type=float, default=1.5, help="最小箱体宽度百分比")
parser.add_argument("--sl-pct", type=float, default=0.8, help="止损百分比")
parser.add_argument("--tp-pct", type=float, default=2.4, help="止盈百分比")
parser.add_argument("--commission-pct", type=float, default=0.05, help="单边手续费百分比")
parser.add_argument("--capital", type=float, default=100000.0, help="初始资金")
parser.add_argument("--out", default="runtime/backtest_trades.csv", help="交易明细输出路径")
parser.add_argument(
"--timeframes",
default="15m,30m,1h",
help="回测周期,逗号分隔;会从1m聚合,如: 15m,30m,1h",
)
args = parser.parse_args()
csv_path = Path(args.csv).expanduser().resolve()
if not csv_path.exists():
raise FileNotFoundError(f"CSV 不存在: {csv_path}")
bars = _read_csv(
csv_path,
ts_col=args.ts_col,
open_col=args.open_col,
high_col=args.high_col,
low_col=args.low_col,
close_col=args.close_col,
)
tfs = [x.strip() for x in str(args.timeframes).split(",") if x.strip()]
if not tfs:
raise ValueError("timeframes 不能为空")
base_out = Path(args.out).expanduser().resolve()
base_out.parent.mkdir(parents=True, exist_ok=True)
print("=== Backtest Done (1m聚合多周期) ===")
print(f"source_1m_bars: {len(bars)}")
print(f"source_period: {bars[0].ts} -> {bars[-1].ts}")
print("")
for tf in tfs:
minutes = _parse_timeframe_to_minutes(tf)
agg = _aggregate_bars(bars, minutes)
trades, curve = run_backtest(
bars=agg,
box_len=args.box_len,
buf_pct=args.buf_pct,
min_box_pct=args.min_box_pct,
sl_pct=args.sl_pct,
tp_pct=args.tp_pct,
commission_pct=args.commission_pct,
initial_capital=args.capital,
)
stats = _calc_stats(curve, trades, args.capital)
out_path = base_out.with_name(f"{base_out.stem}_{tf}{base_out.suffix}")
_save_trades(out_path, trades)
print(f"[{tf}] bars={len(agg)} trades={int(stats.get('total_trades', 0))}")
print(f" period: {agg[0].ts} -> {agg[-1].ts}")
print(f" final_capital: {stats.get('final_capital', 0):.2f}")
print(f" total_return: {stats.get('total_return_pct', 0):.2f}%")
print(f" max_drawdown: {stats.get('max_drawdown_pct', 0):.2f}%")
print(f" win_rate: {stats.get('win_rate_pct', 0):.2f}%")
print(f" profit_factor: {stats.get('profit_factor', 0):.3f}")
print(f" trades_csv: {out_path}")
print("")
print(f"generated_at: {datetime.now().isoformat(timespec='seconds')}")
if __name__ == "__main__":
main()