diff --git a/hub_ohlcv_lib.py b/hub_ohlcv_lib.py index 584ceb4..74e83fb 100644 --- a/hub_ohlcv_lib.py +++ b/hub_ohlcv_lib.py @@ -44,6 +44,13 @@ CHART_TIMEFRAME_ORDER = ( ) DAILY_PLUS_TIMEFRAMES = frozenset({"1d", "1w"}) +# 部分交易所 ccxt 无原生周期(如 Gate 无 6h/12h),或原生 K 线间隔异常时从细周期聚合 +OHLCV_AGGREGATE_FROM: dict[str, str] = { + "6h": "1h", + "8h": "1h", + "12h": "1h", +} + TIMEFRAME_MS: dict[str, int] = { "1m": 60_000, "3m": 3 * 60_000, @@ -189,6 +196,132 @@ def format_price_by_tick(value: Any, tick: Optional[float]) -> str: return text.rstrip("0").rstrip(".") if "." in text else text +def exchange_supports_timeframe(exchange, timeframe: str) -> bool: + tf = normalize_chart_timeframe(timeframe) + tfs = getattr(exchange, "timeframes", None) or {} + if not tfs: + return True + return tf in tfs + + +def _median_bar_step_ms(bars: list[dict[str, Any]]) -> Optional[int]: + if len(bars) < 2: + return None + steps: list[int] = [] + for i in range(1, min(len(bars), 64)): + step = int(bars[i]["open_time_ms"]) - int(bars[i - 1]["open_time_ms"]) + if step > 0: + steps.append(step) + if not steps: + return None + steps.sort() + return steps[len(steps) // 2] + + +def bars_spacing_matches_timeframe( + bars: list[dict[str, Any]], timeframe: str, *, tolerance: float = 0.08 +) -> bool: + if len(bars) < 2: + return True + period = TIMEFRAME_MS[normalize_chart_timeframe(timeframe)] + step = _median_bar_step_ms(bars) + if step is None: + return False + return abs(step - period) <= period * tolerance + + +def align_bar_open_ms(open_time_ms: int, period_ms: int) -> int: + return (int(open_time_ms) // period_ms) * period_ms + + +def aggregate_ohlcv_bars( + bars: list[dict[str, Any]], target_timeframe: str +) -> list[dict[str, Any]]: + """将细周期 OHLCV 聚合为目标周期(UTC 对齐 bucket)。""" + tf = normalize_chart_timeframe(target_timeframe) + period = TIMEFRAME_MS[tf] + buckets: dict[int, dict[str, Any]] = {} + for b in bars or []: + try: + key = align_bar_open_ms(int(b["open_time_ms"]), period) + o = float(b["open"]) + h = float(b["high"]) + l = float(b["low"]) + c = float(b["close"]) + v = float(b.get("volume") or 0) + except (KeyError, TypeError, ValueError): + continue + cur = buckets.get(key) + if cur is None: + buckets[key] = { + "open_time_ms": key, + "open": o, + "high": h, + "low": l, + "close": c, + "volume": v, + } + continue + cur["high"] = max(float(cur["high"]), h) + cur["low"] = min(float(cur["low"]), l) + cur["close"] = c + cur["volume"] = float(cur.get("volume") or 0) + v + return [buckets[k] for k in sorted(buckets.keys())] + + +def _next_since_from_batch(batch: list, period_ms: int) -> int: + last_ts = int(batch[-1][0]) + if len(batch) >= 2: + step = int(batch[-1][0]) - int(batch[-2][0]) + if step > 0: + return last_ts + step + return last_ts + period_ms + + +def _paginate_fetch_ohlcv( + exchange, + ex_sym: str, + timeframe: str, + *, + want: int, + since_ms: int | None, + period_ms: int, + chunk_max: int = 300, +) -> list[dict[str, Any]]: + tf = normalize_chart_timeframe(timeframe) + collected: list = [] + if since_ms is not None and int(since_ms) > 0: + since = int(since_ms) + else: + since = max(0, int(time.time() * 1000) - want * period_ms) + + guard = 0 + prev_since = None + while len(collected) < want and guard < 80: + guard += 1 + req_limit = min(chunk_max, want - len(collected)) + batch = exchange.fetch_ohlcv( + ex_sym, timeframe=tf, since=since, limit=req_limit + ) + if not batch: + break + collected.extend(batch) + next_since = _next_since_from_batch(batch, period_ms) + if prev_since is not None and next_since <= prev_since: + break + prev_since = since + since = next_since + + bars = _bars_to_dicts(collected) + uniq: dict[int, dict[str, Any]] = {} + for b in bars: + uniq[int(b["open_time_ms"])] = b + merged = [uniq[k] for k in sorted(uniq.keys())] + if len(merged) > want: + merged = merged[-want:] + return merged + + def _bars_to_dicts(ohlcv: list) -> list[dict[str, Any]]: out: list[dict[str, Any]] = [] for bar in ohlcv or []: @@ -231,44 +364,51 @@ def fetch_ohlcv_for_hub( ensure_markets_loaded() ex_sym = normalize_exchange_symbol(sym) want = max(1, min(int(limit or bar_limit_for_timeframe(tf)), 1500)) - chunk_max = 300 period = TIMEFRAME_MS[tf] - collected: list = [] + merged: list[dict[str, Any]] = [] + src_tf = OHLCV_AGGREGATE_FROM.get(tf) - if since_ms is not None and int(since_ms) > 0: - since = int(since_ms) - else: - # OKX/Gate 等无 since 时单次常被限制在 ~300 根,须从目标起点分页向前拉 - since = max(0, int(time.time() * 1000) - want * period) - - guard = 0 - prev_since = None - while len(collected) < want and guard < 80: - guard += 1 - req_limit = min(chunk_max, want - len(collected)) - batch = exchange.fetch_ohlcv( - ex_sym, timeframe=tf, since=since, limit=req_limit + if exchange_supports_timeframe(exchange, tf): + candidate = _paginate_fetch_ohlcv( + exchange, + ex_sym, + tf, + want=want, + since_ms=since_ms, + period_ms=period, ) - if not batch: - break - collected.extend(batch) - next_since = int(batch[-1][0]) + period - if prev_since is not None and next_since <= prev_since: - break - prev_since = since - since = next_since + if candidate and bars_spacing_matches_timeframe(candidate, tf): + merged = candidate - bars = _bars_to_dicts(collected) - if not bars: + if ( + not merged + and src_tf + and exchange_supports_timeframe(exchange, src_tf) + ): + src_period = TIMEFRAME_MS[normalize_chart_timeframe(src_tf)] + ratio = max(1, int(math.ceil(period / src_period))) + src_want = min(1500, want * ratio + ratio * 4) + src_bars = _paginate_fetch_ohlcv( + exchange, + ex_sym, + src_tf, + want=src_want, + since_ms=since_ms, + period_ms=src_period, + ) + if not src_bars or not bars_spacing_matches_timeframe(src_bars, src_tf): + return { + "ok": False, + "msg": f"无法获取 {tf} K 线(细周期 {src_tf} 数据异常)", + } + merged = aggregate_ohlcv_bars(src_bars, tf) + if len(merged) > want: + merged = merged[-want:] + + if not merged: return {"ok": False, "msg": "交易所未返回 K 线"} tick = price_tick_from_market(exchange, ex_sym) - uniq: dict[int, dict] = {} - for b in bars: - uniq[int(b["open_time_ms"])] = b - merged = [uniq[k] for k in sorted(uniq.keys())] - if len(merged) > want: - merged = merged[-want:] return { "ok": True, diff --git a/tests/test_hub_ohlcv_lib.py b/tests/test_hub_ohlcv_lib.py index bd9a1fa..57e3b00 100644 --- a/tests/test_hub_ohlcv_lib.py +++ b/tests/test_hub_ohlcv_lib.py @@ -3,17 +3,24 @@ from __future__ import annotations import unittest -from hub_ohlcv_lib import fetch_ohlcv_for_hub +from hub_ohlcv_lib import ( + aggregate_ohlcv_bars, + bars_spacing_matches_timeframe, + fetch_ohlcv_for_hub, +) class _FakeExchange: - def __init__(self, pages): + def __init__(self, pages, *, timeframes=None): self.pages = list(pages) self.calls = [] self.markets = {} + self.timeframes = timeframes if timeframes is not None else {} def fetch_ohlcv(self, symbol, timeframe=None, since=None, limit=None): - self.calls.append({"symbol": symbol, "since": since, "limit": limit}) + self.calls.append( + {"symbol": symbol, "since": since, "limit": limit, "timeframe": timeframe} + ) if not self.pages: return [] page = self.pages.pop(0) @@ -125,6 +132,61 @@ class TestHubOhlcvLib(unittest.TestCase): self.assertGreaterEqual(len(ex.calls), 3) self.assertAlmostEqual(out["bars"][-1]["close"], 3.05) + def test_aggregate_6h_from_1h_when_exchange_lacks_native(self): + """Gate 等无 6h 时应从 1h 聚合。""" + from hub_ohlcv_lib import TIMEFRAME_MS + + h1 = TIMEFRAME_MS["1h"] + h6 = TIMEFRAME_MS["6h"] + base = 1_700_000_000_000 + base = (base // h6) * h6 + one_h = [ + [base + i * h1, 100.0 + i, 101.0 + i, 99.0 + i, 100.5 + i, 10.0] + for i in range(24) + ] + ex = _FakeExchange( + [one_h], + timeframes={"1h": "1h", "4h": "4h", "8h": "8h"}, + ) + out = fetch_ohlcv_for_hub( + symbol="BTC/USDT", + timeframe="6h", + since_ms=base, + limit=4, + normalize_symbol_input=lambda s: str(s).strip().upper(), + normalize_exchange_symbol=lambda s: f"{s}:USDT" if ":" not in s else s, + ensure_markets_loaded=lambda: None, + exchange=ex, + ) + self.assertTrue(out.get("ok")) + bars = out.get("bars") or [] + self.assertEqual(len(bars), 4) + self.assertTrue(bars_spacing_matches_timeframe(bars, "6h")) + self.assertEqual(ex.calls[0]["timeframe"], "1h") + + def test_aggregate_ohlcv_bars_buckets(self): + from hub_ohlcv_lib import TIMEFRAME_MS + + h1 = TIMEFRAME_MS["1h"] + h6 = TIMEFRAME_MS["6h"] + base = (1_700_000_000_000 // h6) * h6 + src = [ + { + "open_time_ms": base + i * h1, + "open": 1.0, + "high": 2.0, + "low": 0.5, + "close": 1.5, + "volume": 1.0, + } + for i in range(6) + ] + out = aggregate_ohlcv_bars(src, "6h") + self.assertEqual(len(out), 1) + self.assertEqual(out[0]["volume"], 6.0) + self.assertEqual(out[0]["high"], 2.0) + self.assertEqual(out[0]["low"], 0.5) + if __name__ == "__main__": unittest.main()