首次上传
This commit is contained in:
@@ -0,0 +1,428 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
@dataclass
|
||||
class Bar:
|
||||
ts: str
|
||||
open: float
|
||||
high: float
|
||||
low: float
|
||||
close: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class Trade:
|
||||
side: str # LONG | SHORT
|
||||
entry_ts: str
|
||||
entry_price: float
|
||||
exit_ts: str
|
||||
exit_price: float
|
||||
reason: str
|
||||
gross_return_pct: float
|
||||
net_return_pct: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class BarWithEpoch:
|
||||
bar: Bar
|
||||
ts_epoch: int
|
||||
|
||||
|
||||
def _to_float(row: dict[str, str], key: str) -> float:
|
||||
raw = (row.get(key) or "").strip()
|
||||
if not raw:
|
||||
raise ValueError(f"Empty numeric field: {key}")
|
||||
return float(raw)
|
||||
|
||||
|
||||
def _read_csv(
|
||||
csv_path: Path,
|
||||
ts_col: str,
|
||||
open_col: str,
|
||||
high_col: str,
|
||||
low_col: str,
|
||||
close_col: str,
|
||||
) -> list[Bar]:
|
||||
bars: list[Bar] = []
|
||||
with csv_path.open("r", encoding="utf-8-sig", newline="") as f:
|
||||
reader = csv.DictReader(f)
|
||||
needed = {ts_col, open_col, high_col, low_col, close_col}
|
||||
missing = [c for c in needed if c not in (reader.fieldnames or [])]
|
||||
if missing:
|
||||
raise ValueError(f"CSV 缺少列: {missing}. 当前列: {reader.fieldnames}")
|
||||
for row in reader:
|
||||
bars.append(
|
||||
Bar(
|
||||
ts=str(row[ts_col]),
|
||||
open=_to_float(row, open_col),
|
||||
high=_to_float(row, high_col),
|
||||
low=_to_float(row, low_col),
|
||||
close=_to_float(row, close_col),
|
||||
)
|
||||
)
|
||||
if len(bars) < 200:
|
||||
raise ValueError(f"数据量过少: {len(bars)} 行,无法可靠回测。")
|
||||
return bars
|
||||
|
||||
|
||||
def _parse_ts_to_epoch_seconds(ts_raw: str) -> int:
|
||||
s = str(ts_raw).strip()
|
||||
if not s:
|
||||
raise ValueError("timestamp is empty")
|
||||
if s.isdigit() or (s.startswith("-") and s[1:].isdigit()):
|
||||
n = int(s)
|
||||
# 13 digits => milliseconds
|
||||
if abs(n) >= 10_000_000_000:
|
||||
return int(n / 1000)
|
||||
return n
|
||||
s_norm = s.replace("Z", "+00:00")
|
||||
try:
|
||||
return int(datetime.fromisoformat(s_norm).timestamp())
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Unsupported timestamp format: {s}") from exc
|
||||
|
||||
|
||||
def _sort_and_attach_epoch(bars: list[Bar]) -> list[BarWithEpoch]:
|
||||
enriched = [BarWithEpoch(bar=b, ts_epoch=_parse_ts_to_epoch_seconds(b.ts)) for b in bars]
|
||||
enriched.sort(key=lambda x: x.ts_epoch)
|
||||
return enriched
|
||||
|
||||
|
||||
def _aggregate_bars(bars_1m: list[Bar], timeframe_minutes: int) -> list[Bar]:
|
||||
if timeframe_minutes <= 1:
|
||||
return bars_1m
|
||||
src = _sort_and_attach_epoch(bars_1m)
|
||||
if not src:
|
||||
return []
|
||||
|
||||
out: list[Bar] = []
|
||||
bucket_sec = timeframe_minutes * 60
|
||||
|
||||
cur_bucket = None
|
||||
agg_open = agg_high = agg_low = agg_close = 0.0
|
||||
agg_ts = ""
|
||||
|
||||
for item in src:
|
||||
b = item.bar
|
||||
bucket = (item.ts_epoch // bucket_sec) * bucket_sec
|
||||
if cur_bucket is None or bucket != cur_bucket:
|
||||
if cur_bucket is not None:
|
||||
out.append(Bar(ts=agg_ts, open=agg_open, high=agg_high, low=agg_low, close=agg_close))
|
||||
cur_bucket = bucket
|
||||
agg_open = b.open
|
||||
agg_high = b.high
|
||||
agg_low = b.low
|
||||
agg_close = b.close
|
||||
agg_ts = datetime.utcfromtimestamp(bucket).isoformat() + "Z"
|
||||
else:
|
||||
agg_high = max(agg_high, b.high)
|
||||
agg_low = min(agg_low, b.low)
|
||||
agg_close = b.close
|
||||
|
||||
if cur_bucket is not None:
|
||||
out.append(Bar(ts=agg_ts, open=agg_open, high=agg_high, low=agg_low, close=agg_close))
|
||||
return out
|
||||
|
||||
|
||||
def _parse_timeframe_to_minutes(tf: str) -> int:
|
||||
s = tf.strip().lower()
|
||||
if s.endswith("m"):
|
||||
return int(s[:-1])
|
||||
if s.endswith("h"):
|
||||
return int(s[:-1]) * 60
|
||||
raise ValueError(f"Unsupported timeframe: {tf}. Use like 15m,30m,1h")
|
||||
|
||||
|
||||
def _calc_stats(equity_curve: Iterable[float], trades: list[Trade], initial_capital: float) -> dict[str, float]:
|
||||
curve = list(equity_curve)
|
||||
if not curve:
|
||||
return {}
|
||||
final_capital = curve[-1]
|
||||
total_return_pct = (final_capital / initial_capital - 1.0) * 100.0
|
||||
|
||||
peak = curve[0]
|
||||
max_dd = 0.0
|
||||
for eq in curve:
|
||||
if eq > peak:
|
||||
peak = eq
|
||||
dd = (eq / peak - 1.0) * 100.0
|
||||
if dd < max_dd:
|
||||
max_dd = dd
|
||||
|
||||
wins = [t for t in trades if t.net_return_pct > 0]
|
||||
losses = [t for t in trades if t.net_return_pct <= 0]
|
||||
win_rate = (len(wins) / len(trades) * 100.0) if trades else 0.0
|
||||
avg_win = sum(t.net_return_pct for t in wins) / len(wins) if wins else 0.0
|
||||
avg_loss = sum(t.net_return_pct for t in losses) / len(losses) if losses else 0.0
|
||||
profit_factor = (
|
||||
abs(sum(t.net_return_pct for t in wins) / sum(t.net_return_pct for t in losses))
|
||||
if losses and sum(t.net_return_pct for t in losses) != 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
return {
|
||||
"initial_capital": initial_capital,
|
||||
"final_capital": final_capital,
|
||||
"total_return_pct": total_return_pct,
|
||||
"max_drawdown_pct": max_dd,
|
||||
"total_trades": float(len(trades)),
|
||||
"win_rate_pct": win_rate,
|
||||
"avg_win_pct": avg_win,
|
||||
"avg_loss_pct": avg_loss,
|
||||
"profit_factor": profit_factor,
|
||||
}
|
||||
|
||||
|
||||
def run_backtest(
|
||||
bars: list[Bar],
|
||||
box_len: int,
|
||||
buf_pct: float,
|
||||
min_box_pct: float,
|
||||
sl_pct: float,
|
||||
tp_pct: float,
|
||||
commission_pct: float,
|
||||
initial_capital: float,
|
||||
) -> tuple[list[Trade], list[float]]:
|
||||
high_win: deque[float] = deque(maxlen=box_len)
|
||||
low_win: deque[float] = deque(maxlen=box_len)
|
||||
close_hist: list[float] = []
|
||||
|
||||
trades: list[Trade] = []
|
||||
equity_curve: list[float] = [initial_capital]
|
||||
capital = initial_capital
|
||||
|
||||
position = 0 # 1 long, -1 short, 0 flat
|
||||
entry_price = 0.0
|
||||
entry_ts = ""
|
||||
|
||||
for i, bar in enumerate(bars):
|
||||
# Build history first
|
||||
close_hist.append(bar.close)
|
||||
if i == 0:
|
||||
high_win.append(bar.high)
|
||||
low_win.append(bar.low)
|
||||
continue
|
||||
|
||||
# Exit check (intrabar, after entry bar)
|
||||
if position != 0:
|
||||
if position == 1:
|
||||
stop = entry_price * (1 - sl_pct / 100.0)
|
||||
take = entry_price * (1 + tp_pct / 100.0)
|
||||
exit_price = 0.0
|
||||
reason = ""
|
||||
# Conservative tie-break: stop first if both touched same bar
|
||||
if bar.low <= stop:
|
||||
exit_price, reason = stop, "SL"
|
||||
elif bar.high >= take:
|
||||
exit_price, reason = take, "TP"
|
||||
if reason:
|
||||
gross_ret = (exit_price / entry_price - 1.0) * 100.0
|
||||
net_ret = gross_ret - 2 * commission_pct
|
||||
capital *= 1 + net_ret / 100.0
|
||||
trades.append(
|
||||
Trade("LONG", entry_ts, entry_price, bar.ts, exit_price, reason, gross_ret, net_ret)
|
||||
)
|
||||
equity_curve.append(capital)
|
||||
position = 0
|
||||
elif position == -1:
|
||||
stop = entry_price * (1 + sl_pct / 100.0)
|
||||
take = entry_price * (1 - tp_pct / 100.0)
|
||||
exit_price = 0.0
|
||||
reason = ""
|
||||
if bar.high >= stop:
|
||||
exit_price, reason = stop, "SL"
|
||||
elif bar.low <= take:
|
||||
exit_price, reason = take, "TP"
|
||||
if reason:
|
||||
gross_ret = (entry_price / exit_price - 1.0) * 100.0
|
||||
net_ret = gross_ret - 2 * commission_pct
|
||||
capital *= 1 + net_ret / 100.0
|
||||
trades.append(
|
||||
Trade("SHORT", entry_ts, entry_price, bar.ts, exit_price, reason, gross_ret, net_ret)
|
||||
)
|
||||
equity_curve.append(capital)
|
||||
position = 0
|
||||
|
||||
# Need full lookback and previous close for crossover.
|
||||
if len(high_win) < box_len or len(low_win) < box_len or i < 2:
|
||||
high_win.append(bar.high)
|
||||
low_win.append(bar.low)
|
||||
continue
|
||||
|
||||
box_high = max(high_win)
|
||||
box_low = min(low_win)
|
||||
box_mid = (box_high + box_low) / 2.0
|
||||
box_pct = ((box_high - box_low) / box_mid * 100.0) if box_mid > 0 else 0.0
|
||||
box_ok = box_pct >= min_box_pct
|
||||
|
||||
up_line = box_high * (1 + buf_pct / 100.0)
|
||||
dn_line = box_low * (1 - buf_pct / 100.0)
|
||||
|
||||
prev_close = close_hist[-2]
|
||||
long_trig = box_ok and prev_close <= up_line and bar.close > up_line
|
||||
short_trig = box_ok and prev_close >= dn_line and bar.close < dn_line
|
||||
|
||||
# Reverse signal close at close price then flip.
|
||||
if position == 1 and short_trig:
|
||||
gross_ret = (bar.close / entry_price - 1.0) * 100.0
|
||||
net_ret = gross_ret - 2 * commission_pct
|
||||
capital *= 1 + net_ret / 100.0
|
||||
trades.append(Trade("LONG", entry_ts, entry_price, bar.ts, bar.close, "REVERSE", gross_ret, net_ret))
|
||||
equity_curve.append(capital)
|
||||
position = 0
|
||||
elif position == -1 and long_trig:
|
||||
gross_ret = (entry_price / bar.close - 1.0) * 100.0
|
||||
net_ret = gross_ret - 2 * commission_pct
|
||||
capital *= 1 + net_ret / 100.0
|
||||
trades.append(Trade("SHORT", entry_ts, entry_price, bar.ts, bar.close, "REVERSE", gross_ret, net_ret))
|
||||
equity_curve.append(capital)
|
||||
position = 0
|
||||
|
||||
if position == 0:
|
||||
if long_trig:
|
||||
position = 1
|
||||
entry_price = bar.close
|
||||
entry_ts = bar.ts
|
||||
elif short_trig:
|
||||
position = -1
|
||||
entry_price = bar.close
|
||||
entry_ts = bar.ts
|
||||
|
||||
high_win.append(bar.high)
|
||||
low_win.append(bar.low)
|
||||
|
||||
# Force close at final close
|
||||
if position != 0:
|
||||
last = bars[-1]
|
||||
if position == 1:
|
||||
gross_ret = (last.close / entry_price - 1.0) * 100.0
|
||||
side = "LONG"
|
||||
else:
|
||||
gross_ret = (entry_price / last.close - 1.0) * 100.0
|
||||
side = "SHORT"
|
||||
net_ret = gross_ret - 2 * commission_pct
|
||||
capital *= 1 + net_ret / 100.0
|
||||
trades.append(Trade(side, entry_ts, entry_price, last.ts, last.close, "FORCE_CLOSE", gross_ret, net_ret))
|
||||
equity_curve.append(capital)
|
||||
|
||||
return trades, equity_curve
|
||||
|
||||
|
||||
def _save_trades(path: Path, trades: list[Trade]) -> None:
|
||||
with path.open("w", encoding="utf-8", newline="") as f:
|
||||
w = csv.writer(f)
|
||||
w.writerow(
|
||||
[
|
||||
"side",
|
||||
"entry_ts",
|
||||
"entry_price",
|
||||
"exit_ts",
|
||||
"exit_price",
|
||||
"reason",
|
||||
"gross_return_pct",
|
||||
"net_return_pct",
|
||||
]
|
||||
)
|
||||
for t in trades:
|
||||
w.writerow(
|
||||
[
|
||||
t.side,
|
||||
t.entry_ts,
|
||||
f"{t.entry_price:.8f}",
|
||||
t.exit_ts,
|
||||
f"{t.exit_price:.8f}",
|
||||
t.reason,
|
||||
f"{t.gross_return_pct:.6f}",
|
||||
f"{t.net_return_pct:.6f}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="ETH 1m 裸K箱体突破回测")
|
||||
parser.add_argument("--csv", required=True, help="K线 CSV 路径")
|
||||
parser.add_argument("--ts-col", default="timestamp", help="时间列名")
|
||||
parser.add_argument("--open-col", default="open", help="开盘列名")
|
||||
parser.add_argument("--high-col", default="high", help="最高列名")
|
||||
parser.add_argument("--low-col", default="low", help="最低列名")
|
||||
parser.add_argument("--close-col", default="close", help="收盘列名")
|
||||
|
||||
parser.add_argument("--box-len", type=int, default=80, help="箱体回看K数")
|
||||
parser.add_argument("--buf-pct", type=float, default=0.03, help="突破缓冲百分比")
|
||||
parser.add_argument("--min-box-pct", type=float, default=1.5, help="最小箱体宽度百分比")
|
||||
parser.add_argument("--sl-pct", type=float, default=0.8, help="止损百分比")
|
||||
parser.add_argument("--tp-pct", type=float, default=2.4, help="止盈百分比")
|
||||
parser.add_argument("--commission-pct", type=float, default=0.05, help="单边手续费百分比")
|
||||
parser.add_argument("--capital", type=float, default=100000.0, help="初始资金")
|
||||
parser.add_argument("--out", default="runtime/backtest_trades.csv", help="交易明细输出路径")
|
||||
parser.add_argument(
|
||||
"--timeframes",
|
||||
default="15m,30m,1h",
|
||||
help="回测周期,逗号分隔;会从1m聚合,如: 15m,30m,1h",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
csv_path = Path(args.csv).expanduser().resolve()
|
||||
if not csv_path.exists():
|
||||
raise FileNotFoundError(f"CSV 不存在: {csv_path}")
|
||||
|
||||
bars = _read_csv(
|
||||
csv_path,
|
||||
ts_col=args.ts_col,
|
||||
open_col=args.open_col,
|
||||
high_col=args.high_col,
|
||||
low_col=args.low_col,
|
||||
close_col=args.close_col,
|
||||
)
|
||||
tfs = [x.strip() for x in str(args.timeframes).split(",") if x.strip()]
|
||||
if not tfs:
|
||||
raise ValueError("timeframes 不能为空")
|
||||
|
||||
base_out = Path(args.out).expanduser().resolve()
|
||||
base_out.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print("=== Backtest Done (1m聚合多周期) ===")
|
||||
print(f"source_1m_bars: {len(bars)}")
|
||||
print(f"source_period: {bars[0].ts} -> {bars[-1].ts}")
|
||||
print("")
|
||||
|
||||
for tf in tfs:
|
||||
minutes = _parse_timeframe_to_minutes(tf)
|
||||
agg = _aggregate_bars(bars, minutes)
|
||||
trades, curve = run_backtest(
|
||||
bars=agg,
|
||||
box_len=args.box_len,
|
||||
buf_pct=args.buf_pct,
|
||||
min_box_pct=args.min_box_pct,
|
||||
sl_pct=args.sl_pct,
|
||||
tp_pct=args.tp_pct,
|
||||
commission_pct=args.commission_pct,
|
||||
initial_capital=args.capital,
|
||||
)
|
||||
stats = _calc_stats(curve, trades, args.capital)
|
||||
out_path = base_out.with_name(f"{base_out.stem}_{tf}{base_out.suffix}")
|
||||
_save_trades(out_path, trades)
|
||||
|
||||
print(f"[{tf}] bars={len(agg)} trades={int(stats.get('total_trades', 0))}")
|
||||
print(f" period: {agg[0].ts} -> {agg[-1].ts}")
|
||||
print(f" final_capital: {stats.get('final_capital', 0):.2f}")
|
||||
print(f" total_return: {stats.get('total_return_pct', 0):.2f}%")
|
||||
print(f" max_drawdown: {stats.get('max_drawdown_pct', 0):.2f}%")
|
||||
print(f" win_rate: {stats.get('win_rate_pct', 0):.2f}%")
|
||||
print(f" profit_factor: {stats.get('profit_factor', 0):.3f}")
|
||||
print(f" trades_csv: {out_path}")
|
||||
print("")
|
||||
|
||||
print(f"generated_at: {datetime.now().isoformat(timespec='seconds')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user