"""Hub K-line (OHLCV) SQLite store.

Per-timeframe retention, direct pulls from the exchange, and paged reads for the chart.
"""

from __future__ import annotations

import math
import os
import sqlite3
import time
from pathlib import Path
from typing import Any, Callable, Optional

from hub_ohlcv_lib import (
    HUB_KLINE_1M_MAX_BARS,
    HUB_KLINE_5M_1H_RETENTION_DAYS,
    TIMEFRAME_MS,
    YEAR_ROLLING_STORED,
    chart_chunk_limit,
    chart_initial_limit,
    chart_memory_cap,
    history_cutoff_ms_for_storage,
    normalize_chart_timeframe,
    normalize_price_tick,
    format_price_by_tick,
    last_closed_bar_open_ms,
    retention_policy_meta,
    round_ohlcv_bars_to_tick,
    seed_bar_target,
)

# A tail_refresh request is only honoured once the series already holds at least
# min(chart_initial_limit // 5, HUB_KLINE_MIN_BARS_BEFORE_TAIL) bars; otherwise a
# normal (non-tail) load runs instead.
HUB_KLINE_MIN_BARS_BEFORE_TAIL = 200
# Upper bound on the ``limit`` of full, history and backfill remote fetches.
HUB_KLINE_REMOTE_FETCH_CAP = 1500

_DEFAULT_RETENTION_DAYS = 15


def retention_days() -> int:
    """Return the legacy retention setting in days (``HUB_KLINE_RETENTION_DAYS``).

    Kept for compatibility with old configuration; the current per-timeframe policy
    is described by ``retention_policy_meta``. Non-integer values fall back to the
    default, and the result is clamped to at least 1.
    """
    try:
        return max(1, int(os.getenv("HUB_KLINE_RETENTION_DAYS", str(_DEFAULT_RETENTION_DAYS))))
    except ValueError:
        return _DEFAULT_RETENTION_DAYS


def default_db_path() -> Path:
    """Return the K-line database path.

    Uses ``HUB_KLINE_DB_PATH`` when it is set and non-blank. Otherwise the path is
    ``manual_trading_hub/data/hub_kline.db`` next to this module, and that directory
    is created if needed.
    """
    raw = (os.getenv("HUB_KLINE_DB_PATH") or "").strip()
    if raw:
        return Path(raw)
    hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data"
    hub_dir.mkdir(parents=True, exist_ok=True)
    return hub_dir / "hub_kline.db"


def _connect(db_path: Path | None = None) -> sqlite3.Connection:
    """Open an autocommit connection (WAL journal, synchronous=NORMAL, Row factory).

    ``isolation_level=None`` means every statement commits on its own unless the
    caller issues an explicit ``BEGIN``.
    """
    path = db_path or default_db_path()
    path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(path), timeout=30, isolation_level=None)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA synchronous=NORMAL")
    except BaseException:
        # Do not leak the handle if configuring the connection fails.
        conn.close()
        raise
    return conn


def init_db(db_path: Path | None = None) -> None:
    """Create the ``ohlcv_bars`` / ``ohlcv_symbol_meta`` tables and series index if missing."""
    conn = _connect(db_path)
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS ohlcv_bars (
                exchange_key TEXT NOT NULL,
                symbol TEXT NOT NULL,
                timeframe TEXT NOT NULL,
                open_time_ms INTEGER NOT NULL,
                open REAL NOT NULL,
                high REAL NOT NULL,
                low REAL NOT NULL,
                close REAL NOT NULL,
                volume REAL NOT NULL DEFAULT 0,
                updated_at INTEGER NOT NULL,
                PRIMARY KEY (exchange_key, symbol, timeframe, open_time_ms)
            )
            """
        )
        # NOTE(review): this index has the same columns/order as the PRIMARY KEY's
        # implicit index, so it looks redundant; kept for existing databases.
        conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_ohlcv_series
            ON ohlcv_bars (exchange_key, symbol, timeframe, open_time_ms)
            """
        )
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS ohlcv_symbol_meta (
                exchange_key TEXT NOT NULL,
                symbol TEXT NOT NULL,
                price_tick REAL,
                updated_at INTEGER NOT NULL,
                PRIMARY KEY (exchange_key, symbol)
            )
            """
        )
    finally:
        conn.close()


def save_symbol_price_tick(
    exchange_key: str,
    symbol: str,
    price_tick: float | None,
    db_path: Path | None = None,
) -> None:
    """Upsert the price tick for exchange+symbol.

    ``None``, values that cannot be converted to float, non-finite values and
    non-positive ticks are silently ignored.
    """
    if price_tick is None:
        return
    try:
        tick = float(price_tick)
    except (TypeError, ValueError):
        return
    # NaN/inf would pass a plain ``<= 0`` test but are never a usable tick.
    if not math.isfinite(tick) or tick <= 0:
        return
    ex_k = (exchange_key or "").strip().lower()
    sym = (symbol or "").strip().upper()
    conn = _connect(db_path)
    try:
        conn.execute(
            """
            INSERT INTO ohlcv_symbol_meta (exchange_key, symbol, price_tick, updated_at)
            VALUES (?,?,?,?)
            ON CONFLICT(exchange_key, symbol) DO UPDATE SET
                price_tick=excluded.price_tick,
                updated_at=excluded.updated_at
            """,
            (ex_k, sym, tick, int(time.time())),
        )
    finally:
        conn.close()


def load_symbol_price_tick(
    exchange_key: str,
    symbol: str,
    db_path: Path | None = None,
) -> float | None:
    """Return the stored price tick for exchange+symbol, or ``None`` if absent/invalid."""
    ex_k = (exchange_key or "").strip().lower()
    sym = (symbol or "").strip().upper()
    conn = _connect(db_path)
    try:
        row = conn.execute(
            "SELECT price_tick FROM ohlcv_symbol_meta WHERE exchange_key=? AND symbol=?",
            (ex_k, sym),
        ).fetchone()
        if not row or row["price_tick"] is None:
            return None
        return float(row["price_tick"])
    except (TypeError, ValueError):
        return None
    finally:
        conn.close()


def purge_timeframe_by_days(
    timeframe: str,
    days: int,
    db_path: Path | None = None,
) -> int:
    """Delete bars of ``timeframe`` (all series) opened more than ``days`` days ago.

    ``days`` is clamped to at least 1. Returns the number of rows deleted.
    """
    cutoff = int(time.time() * 1000) - max(1, int(days)) * 86400000
    tf = normalize_chart_timeframe(timeframe)
    conn = _connect(db_path)
    try:
        cur = conn.execute(
            "DELETE FROM ohlcv_bars WHERE timeframe=? AND open_time_ms < ?",
            (tf, cutoff),
        )
        return int(cur.rowcount or 0)
    finally:
        conn.close()


def purge_1m_bar_cap(db_path: Path | None = None, *, max_bars: int | None = None) -> int:
    """Keep only the newest ``max_bars`` 1m bars per exchange+symbol.

    ``max_bars`` defaults to ``HUB_KLINE_1M_MAX_BARS`` and is floored at 100.
    Returns the number of rows deleted.
    """
    cap = max(100, int(max_bars or HUB_KLINE_1M_MAX_BARS))
    conn = _connect(db_path)
    try:
        cur = conn.execute(
            """
            DELETE FROM ohlcv_bars
            WHERE timeframe='1m' AND rowid IN (
                SELECT rowid FROM (
                    SELECT rowid,
                           ROW_NUMBER() OVER (
                               PARTITION BY exchange_key, symbol
                               ORDER BY open_time_ms DESC
                           ) AS rn
                    FROM ohlcv_bars
                    WHERE timeframe='1m'
                )
                WHERE rn > ?
            )
            """,
            (cap,),
        )
        return int(cur.rowcount or 0)
    finally:
        conn.close()


def clear_series_bars(
    exchange_key: str,
    symbol: str,
    timeframe: str | None = None,
    db_path: Path | None = None,
) -> int:
    """Delete the bars of one exchange+symbol (optionally a single timeframe).

    Used to wipe a series before a full re-pull. Returns the number of rows deleted,
    or 0 when exchange or symbol is blank.
    """
    init_db(db_path)
    ex_k = (exchange_key or "").strip().lower()
    sym = (symbol or "").strip().upper()
    if not ex_k or not sym:
        return 0
    conn = _connect(db_path)
    try:
        if timeframe:
            tf = normalize_chart_timeframe(timeframe)
            cur = conn.execute(
                "DELETE FROM ohlcv_bars WHERE exchange_key=? AND symbol=? AND timeframe=?",
                (ex_k, sym, tf),
            )
        else:
            cur = conn.execute(
                "DELETE FROM ohlcv_bars WHERE exchange_key=? AND symbol=?",
                (ex_k, sym),
            )
        return int(cur.rowcount or 0)
    finally:
        conn.close()


def clear_all_bars(db_path: Path | None = None) -> int:
    """Delete every OHLCV row from the hub K-line database; returns rows deleted."""
    init_db(db_path)
    conn = _connect(db_path)
    try:
        cur = conn.execute("DELETE FROM ohlcv_bars")
        return int(cur.rowcount or 0)
    finally:
        conn.close()


def purge_retention(db_path: Path | None = None) -> int:
    """Apply the per-timeframe retention policy and return the total rows deleted.

    Timeframes in ``YEAR_ROLLING_STORED`` (per the original note: 5m/15m/1h/2h/4h,
    one year) are purged by age using ``HUB_KLINE_5M_1H_RETENTION_DAYS``; 1m is capped
    to the newest N bars per series; other timeframes (e.g. 1d/1w) are not purged here.
    """
    n = 0
    for tf in sorted(YEAR_ROLLING_STORED):
        n += purge_timeframe_by_days(tf, HUB_KLINE_5M_1H_RETENTION_DAYS, db_path)
    n += purge_1m_bar_cap(db_path)
    return n


def upsert_bars(
    exchange_key: str,
    symbol: str,
    timeframe: str,
    bars: list[dict[str, Any]],
    db_path: Path | None = None,
) -> int:
    """Insert or update ``bars`` for one series in a single transaction.

    Each bar needs ``open_time_ms``/``open``/``high``/``low``/``close`` (``volume``
    optional, defaults to 0). Malformed bars - missing keys, non-numeric values, or
    values SQLite rejects (e.g. NaN, which is stored as NULL and violates NOT NULL)
    - are skipped. Returns the number of bars written.
    """
    if not bars:
        return 0
    ex_k = (exchange_key or "").strip().lower()
    sym = (symbol or "").strip().upper()
    tf = normalize_chart_timeframe(timeframe)
    now = int(time.time())
    conn = _connect(db_path)
    n = 0
    try:
        # One transaction for the whole batch instead of one autocommit per row.
        conn.execute("BEGIN IMMEDIATE")
        try:
            for b in bars:
                try:
                    oms = int(b["open_time_ms"])
                    conn.execute(
                        """
                        INSERT INTO ohlcv_bars (
                            exchange_key, symbol, timeframe, open_time_ms,
                            open, high, low, close, volume, updated_at
                        ) VALUES (?,?,?,?,?,?,?,?,?,?)
                        ON CONFLICT(exchange_key, symbol, timeframe, open_time_ms)
                        DO UPDATE SET
                            open=excluded.open,
                            high=excluded.high,
                            low=excluded.low,
                            close=excluded.close,
                            volume=excluded.volume,
                            updated_at=excluded.updated_at
                        """,
                        (
                            ex_k,
                            sym,
                            tf,
                            oms,
                            float(b["open"]),
                            float(b["high"]),
                            float(b["low"]),
                            float(b["close"]),
                            float(b.get("volume") or 0),
                            now,
                        ),
                    )
                except (KeyError, TypeError, ValueError, sqlite3.IntegrityError):
                    # A constraint failure aborts only this statement; the
                    # surrounding transaction stays active.
                    continue
                n += 1
            conn.commit()
        except BaseException:
            if conn.in_transaction:
                conn.rollback()
            raise
    finally:
        conn.close()
    return n


def load_bars_range(
    exchange_key: str,
    symbol: str,
    timeframe: str,
    start_ms: int,
    end_ms: int,
    db_path: Path | None = None,
) -> list[dict[str, Any]]:
    """Return bars with ``start_ms <= open_time_ms <= end_ms``, oldest first."""
    ex_k = (exchange_key or "").strip().lower()
    sym = (symbol or "").strip().upper()
    tf = normalize_chart_timeframe(timeframe)
    conn = _connect(db_path)
    try:
        rows = conn.execute(
            """
            SELECT open_time_ms, open, high, low, close, volume
            FROM ohlcv_bars
            WHERE exchange_key=? AND symbol=? AND timeframe=?
              AND open_time_ms >= ? AND open_time_ms <= ?
            ORDER BY open_time_ms ASC
            """,
            (ex_k, sym, tf, int(start_ms), int(end_ms)),
        ).fetchall()
        return _rows_to_bars(rows)
    finally:
        conn.close()


def count_series_bars(
    exchange_key: str,
    symbol: str,
    timeframe: str,
    db_path: Path | None = None,
) -> int:
    """Return how many bars are stored for one exchange+symbol+timeframe."""
    ex_k = (exchange_key or "").strip().lower()
    sym = (symbol or "").strip().upper()
    tf = normalize_chart_timeframe(timeframe)
    conn = _connect(db_path)
    try:
        row = conn.execute(
            """
            SELECT COUNT(*) AS c FROM ohlcv_bars
            WHERE exchange_key=? AND symbol=? AND timeframe=?
            """,
            (ex_k, sym, tf),
        ).fetchone()
        return int(row["c"] or 0) if row else 0
    finally:
        conn.close()


def _remote_fetch_limit(
    *,
    need: int,
    force_refresh: bool,
    storage_tf: str,
    tail_only: bool,
) -> int:
    """Choose the ``limit`` for a remote fetch.

    Tail-only: ``need + 20`` capped at 300. Forced refresh: the seed target for the
    timeframe. Otherwise ``need + 20``. Non-tail results are capped at
    ``HUB_KLINE_REMOTE_FETCH_CAP``.
    """
    if tail_only:
        return min(need + 20, 300)
    cap = HUB_KLINE_REMOTE_FETCH_CAP
    if force_refresh:
        return min(seed_bar_target(storage_tf), cap)
    return min(max(need + 20, 1), cap)


def _since_ms_for_span(
    *,
    now_ms: int,
    period_ms: int,
    span_bars: int,
    cutoff_ms: int,
) -> int:
    """Return the start of a fetch window ending at ``now_ms``.

    The span must match the fetch limit so the returned data reaches the most recent
    bar. The start is never earlier than ``cutoff_ms``.
    """
    span = max(1, int(span_bars))
    return max(int(cutoff_ms), int(now_ms) - int(period_ms) * span)


def load_bars_latest(
    exchange_key: str,
    symbol: str,
    timeframe: str,
    limit: int,
    db_path: Path | None = None,
) -> list[dict[str, Any]]:
    """Return the newest ``limit`` bars (at least 1), ordered oldest first."""
    ex_k = (exchange_key or "").strip().lower()
    sym = (symbol or "").strip().upper()
    tf = normalize_chart_timeframe(timeframe)
    lim = max(1, int(limit))
    conn = _connect(db_path)
    try:
        rows = conn.execute(
            """
            SELECT open_time_ms, open, high, low, close, volume
            FROM ohlcv_bars
            WHERE exchange_key=? AND symbol=? AND timeframe=?
            ORDER BY open_time_ms DESC
            LIMIT ?
            """,
            (ex_k, sym, tf, lim),
        ).fetchall()
        return list(reversed(_rows_to_bars(rows)))
    finally:
        conn.close()


def load_bars_before(
    exchange_key: str,
    symbol: str,
    timeframe: str,
    before_ms: int,
    limit: int,
    db_path: Path | None = None,
) -> list[dict[str, Any]]:
    """Return up to ``limit`` bars opened strictly before ``before_ms``, oldest first."""
    ex_k = (exchange_key or "").strip().lower()
    sym = (symbol or "").strip().upper()
    tf = normalize_chart_timeframe(timeframe)
    lim = max(1, int(limit))
    bms = int(before_ms)
    conn = _connect(db_path)
    try:
        rows = conn.execute(
            """
            SELECT open_time_ms, open, high, low, close, volume
            FROM ohlcv_bars
            WHERE exchange_key=? AND symbol=? AND timeframe=? AND open_time_ms < ?
            ORDER BY open_time_ms DESC
            LIMIT ?
            """,
            (ex_k, sym, tf, bms, lim),
        ).fetchall()
        return list(reversed(_rows_to_bars(rows)))
    finally:
        conn.close()


def trim_contiguous_tail(
    bars: list[dict[str, Any]],
    period_ms: int,
    *,
    max_gap_factor: float = 3.0,
) -> tuple[list[dict[str, Any]], int]:
    """Keep only the most recent contiguous run of bars.

    Scanning from the newest bar backwards, the first gap larger than
    ``period_ms * max_gap_factor`` marks the split. Everything left of it is isolated
    data disconnected from the main run and is dropped. ``bars`` is expected to be
    sorted by ``open_time_ms`` ascending. Returns ``(kept_bars, split_index)``; the
    split index is 0 when nothing was dropped. An unusable ``period_ms`` falls back
    to 60 000 ms.
    """
    if len(bars) <= 1:
        return list(bars), 0
    try:
        period = max(1, int(period_ms))
    except (TypeError, ValueError):
        period = 60_000
    max_gap = int(period * max_gap_factor)
    split = 0
    for i in range(len(bars) - 1, 0, -1):
        gap = int(bars[i]["open_time_ms"]) - int(bars[i - 1]["open_time_ms"])
        if gap > max_gap:
            split = i
            break
    return bars[split:], split


def normalize_contiguous_db_rows(
    bars: list[dict[str, Any]],
    *,
    period_ms: int,
    exchange_key: str,
    symbol: str,
    timeframe: str,
    db_path: Path | None = None,
    purge_orphans: bool = True,
) -> list[dict[str, Any]]:
    """Drop an isolated prefix disconnected from the main run of bars.

    When ``purge_orphans`` is true and a prefix was dropped, the stored bars older
    than the kept run are also deleted from the database.
    """
    if len(bars) <= 1:
        return list(bars)
    trimmed, split_at = trim_contiguous_tail(bars, period_ms)
    if split_at > 0 and purge_orphans:
        purge_bars_open_before(
            exchange_key,
            symbol,
            timeframe,
            int(trimmed[0]["open_time_ms"]),
            db_path,
        )
    return trimmed


def purge_bars_open_before(
    exchange_key: str,
    symbol: str,
    timeframe: str,
    open_time_ms: int,
    db_path: Path | None = None,
) -> int:
    """Delete one series' bars opened before ``open_time_ms``; returns rows deleted.

    Used to clean up isolated history disconnected from the main run.
    """
    ex_k = (exchange_key or "").strip().lower()
    sym = (symbol or "").strip().upper()
    tf = normalize_chart_timeframe(timeframe)
    conn = _connect(db_path)
    try:
        cur = conn.execute(
            """
            DELETE FROM ohlcv_bars
            WHERE exchange_key=? AND symbol=? AND timeframe=? AND open_time_ms < ?
            """,
            (ex_k, sym, tf, int(open_time_ms)),
        )
        return int(cur.rowcount or 0)
    finally:
        conn.close()


def _rows_to_bars(rows: list[sqlite3.Row]) -> list[dict[str, Any]]:
    """Convert ``SELECT open_time_ms, open, high, low, close, volume`` rows to bar dicts."""
    return [
        {
            "open_time_ms": int(r["open_time_ms"]),
            "open": float(r["open"]),
            "high": float(r["high"]),
            "low": float(r["low"]),
            "close": float(r["close"]),
            "volume": float(r["volume"] or 0),
        }
        for r in rows
    ]


def _to_chart_candles(bars: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert bars to chart candles (``time`` in seconds); skip malformed bars."""
    out = []
    for b in bars:
        try:
            out.append(
                {
                    "time": int(b["open_time_ms"] // 1000),
                    "open": float(b["open"]),
                    "high": float(b["high"]),
                    "low": float(b["low"]),
                    "close": float(b["close"]),
                    "volume": float(b.get("volume") or 0),
                }
            )
        except (KeyError, TypeError, ValueError):
            continue
    return out


def _trim_display_bars(
    bars: list[dict[str, Any]],
    *,
    need: int,
    before_ms: int | None,
) -> list[dict[str, Any]]:
    """Optionally keep only bars before ``before_ms`` (when > 0), then keep the newest ``need``."""
    if not bars:
        return []
    if before_ms is not None and int(before_ms) > 0:
        bms = int(before_ms)
        bars = [b for b in bars if int(b["open_time_ms"]) < bms]
    return bars[-need:] if len(bars) > need else bars


def _safe_remote_fetch(
    remote_fetch: Callable[..., dict[str, Any]],
    **kwargs: Any,
) -> dict[str, Any]:
    """Call ``remote_fetch(**kwargs)`` and turn a raised exception into ``{"ok": False}``.

    A failing exchange call must not bypass the cached-DB fallback in
    ``resolve_chart_bars``.
    """
    try:
        return remote_fetch(**kwargs)
    except Exception as e:
        return {"ok": False, "msg": str(e)}


def resolve_chart_bars(
    exchange_key: str,
    symbol: str,
    timeframe: str,
    remote_fetch: Callable[..., dict[str, Any]],
    *,
    db_path: Path | None = None,
    force_refresh: bool = False,
    tail_refresh: bool = False,
    clear_db: bool = False,
    limit: int | None = None,
    before_ms: int | None = None,
) -> dict[str, Any]:
    """Serve chart candles from the local DB, pulling from the exchange when needed.

    Request modes:
      * initial screen (no ``before_ms``): the latest ``limit`` bars;
      * history page (``before_ms`` > 0, left drag): bars strictly older than ``before_ms``;
      * tail refresh (``tail_refresh``): only the most recent bars, once enough are stored.

    Every display timeframe is read directly from the same-named timeframe that was
    synced from the exchange into the DB (no resampling).

    ``remote_fetch`` is called with keyword arguments ``symbol``, ``timeframe``,
    ``since_ms`` and ``limit``. It should return a dict with ``ok`` and ``bars`` and
    may include ``price_tick``, ``msg`` or ``error``.

    Returns ``{"ok": False, "msg": ...}`` on failure. Otherwise it returns a payload
    with ``candles``, paging info (``oldest_ms``, ``newest_ms``, ``exhausted``) and
    bookkeeping (``fetched``, ``cleared``, ``purged``, ``stale``, ...).
    """
    init_db(db_path)
    purged = purge_retention(db_path)
    cleared = 0
    sym = (symbol or "").strip().upper()
    ex_k = (exchange_key or "").strip().lower()
    display_tf = normalize_chart_timeframe(timeframe)
    if not sym or not ex_k:
        return {"ok": False, "msg": "缺少 exchange 或 symbol"}
    storage_tf = display_tf
    is_history = before_ms is not None and int(before_ms) > 0
    need = int(
        limit
        or (chart_chunk_limit(display_tf) if is_history else chart_initial_limit(display_tf))
    )
    need = max(1, min(need, chart_memory_cap(display_tf)))
    now_ms = int(time.time() * 1000)
    period_display = TIMEFRAME_MS[display_tf]
    period_storage = TIMEFRAME_MS[storage_tf]
    series_bar_count = count_series_bars(ex_k, sym, storage_tf, db_path) if not is_history else 0
    if tail_refresh and not is_history:
        # A tail refresh only makes sense once the series has been seeded.
        min_seed = min(chart_initial_limit(display_tf) // 5, HUB_KLINE_MIN_BARS_BEFORE_TAIL)
        if series_bar_count < max(1, min_seed):
            tail_refresh = False
        else:
            need = min(need, 30)
    cutoff = history_cutoff_ms_for_storage(storage_tf, now_ms)
    if clear_db and not is_history and not tail_refresh:
        cleared = clear_series_bars(ex_k, sym, storage_tf, db_path)

    def load_display_rows() -> list[dict[str, Any]]:
        """Load the page of rows this request displays."""
        if is_history:
            rows = load_bars_before(ex_k, sym, storage_tf, int(before_ms), need, db_path)
            return _trim_display_bars(rows, need=need, before_ms=int(before_ms))
        return load_bars_latest(ex_k, sym, storage_tf, need, db_path)

    def normalize_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Drop (and purge from the DB) isolated bars left of the main contiguous run."""
        return normalize_contiguous_db_rows(
            rows,
            period_ms=period_display,
            exchange_key=ex_k,
            symbol=sym,
            timeframe=storage_tf,
            db_path=db_path,
        )

    db_rows: list[dict[str, Any]] = []
    if not force_refresh:
        db_rows = load_display_rows()
        if not is_history and db_rows:
            db_rows = normalize_rows(db_rows)

    last_closed = last_closed_bar_open_ms(display_tf, now_ms)
    newest_db = db_rows[-1]["open_time_ms"] if db_rows else None
    # History pages never need the newest bar; live views need it within one period.
    newest_ok = is_history or (
        newest_db is not None and int(newest_db) >= int(last_closed) - period_display
    )
    need_fetch = force_refresh or (not is_history and (len(db_rows) < need or not newest_ok))
    if is_history and len(db_rows) < need:
        need_fetch = True
    tail_only = False
    if tail_refresh and not is_history and db_rows and not force_refresh and not need_fetch:
        need_fetch = True
        tail_only = True

    fetched = 0
    price_tick: Optional[float] = None
    remote_err: Optional[str] = None
    if need_fetch:
        if is_history:
            anchor = int(before_ms) - period_display
            since = max(cutoff, anchor - period_storage * need)
            fetch_limit = min(need + 20, HUB_KLINE_REMOTE_FETCH_CAP)
        elif tail_only:
            anchor_ms = int(newest_db) if newest_db is not None else now_ms
            fetch_limit = _remote_fetch_limit(
                need=need, force_refresh=False, storage_tf=storage_tf, tail_only=True
            )
            since = _since_ms_for_span(
                now_ms=anchor_ms,
                period_ms=period_storage,
                span_bars=5,
                cutoff_ms=cutoff,
            )
        else:
            fetch_limit = _remote_fetch_limit(
                need=need,
                force_refresh=force_refresh,
                storage_tf=storage_tf,
                tail_only=False,
            )
            since = _since_ms_for_span(
                now_ms=now_ms,
                period_ms=period_storage,
                span_bars=fetch_limit,
                cutoff_ms=cutoff,
            )
        remote = _safe_remote_fetch(
            remote_fetch,
            symbol=sym,
            timeframe=storage_tf,
            since_ms=since,
            limit=fetch_limit,
        )
        if remote.get("ok") and remote.get("bars"):
            fetched = upsert_bars(ex_k, sym, storage_tf, remote["bars"], db_path)
            price_tick = remote.get("price_tick")
            if price_tick is not None:
                save_symbol_price_tick(ex_k, sym, price_tick, db_path)
            db_rows = load_display_rows()
            if not is_history and db_rows:
                db_rows = normalize_rows(db_rows)
            if not is_history and not tail_only and db_rows:
                newest_ms = int(db_rows[-1]["open_time_ms"])
                if newest_ms < int(last_closed) - period_display:
                    # The window fetch did not reach the latest closed bar: pull the gap.
                    gap_limit = min(
                        500,
                        int((now_ms - newest_ms) // period_storage) + 10,
                    )
                    if gap_limit > 1:
                        gap_remote = _safe_remote_fetch(
                            remote_fetch,
                            symbol=sym,
                            timeframe=storage_tf,
                            since_ms=newest_ms,
                            limit=gap_limit,
                        )
                        if gap_remote.get("ok") and gap_remote.get("bars"):
                            fetched += upsert_bars(
                                ex_k, sym, storage_tf, gap_remote["bars"], db_path
                            )
                            db_rows = normalize_rows(load_display_rows())
        else:
            remote_err = remote.get("msg") or remote.get("error") or "实例拉取 K 线失败"
            if not db_rows and not is_history:
                return {"ok": False, "msg": remote_err, "purged": purged}

    exhausted = False
    if is_history:
        if not db_rows:
            exhausted = True
        elif len(db_rows) < need:
            oldest = int(db_rows[0]["open_time_ms"])
            if cutoff > 0 and oldest <= cutoff + period_storage:
                # Reached the retention cutoff: nothing older will ever be stored.
                exhausted = True
            elif fetched == 0:
                exhausted = True

    if price_tick is None:
        price_tick = load_symbol_price_tick(ex_k, sym, db_path)
    if price_tick is None and not is_history:
        try:
            tick_probe = remote_fetch(
                symbol=sym,
                timeframe=storage_tf,
                since_ms=None,
                limit=3,
            )
            if tick_probe.get("ok"):
                price_tick = tick_probe.get("price_tick")
                if price_tick is not None:
                    save_symbol_price_tick(ex_k, sym, price_tick, db_path)
        except Exception:
            # Best effort: without a tick the bars are simply left unrounded.
            pass

    if not is_history and db_rows:
        db_rows = normalize_rows(db_rows)
    if not is_history and len(db_rows) < need:
        # Still short of a full screen: backfill older bars ending at the oldest row.
        missing = need - len(db_rows)
        backfill_limit = min(missing + 60, HUB_KLINE_REMOTE_FETCH_CAP)
        backfill_anchor = int(db_rows[0]["open_time_ms"]) if db_rows else now_ms
        backfill_since = _since_ms_for_span(
            now_ms=backfill_anchor,
            period_ms=period_storage,
            span_bars=backfill_limit,
            cutoff_ms=cutoff,
        )
        try:
            remote_back = remote_fetch(
                symbol=sym,
                timeframe=storage_tf,
                since_ms=backfill_since,
                limit=backfill_limit,
            )
            if remote_back.get("ok") and remote_back.get("bars"):
                fetched += upsert_bars(ex_k, sym, storage_tf, remote_back["bars"], db_path)
                if remote_back.get("price_tick") is not None:
                    price_tick = remote_back.get("price_tick")
                    save_symbol_price_tick(ex_k, sym, price_tick, db_path)
                db_rows = normalize_rows(load_display_rows())
            elif not remote_err:
                remote_err = (
                    remote_back.get("msg") or remote_back.get("error") or "实例补拉 K 线失败"
                )
        except Exception as e:
            if not remote_err:
                remote_err = str(e)

    price_tick = normalize_price_tick(price_tick)
    if db_rows and price_tick is not None:
        round_ohlcv_bars_to_tick(db_rows, price_tick)
    candles = _to_chart_candles(db_rows)
    if not is_history and not candles and not exhausted:
        return {"ok": False, "msg": remote_err or "无 K 线数据", "purged": purged}
    oldest_ms = int(db_rows[0]["open_time_ms"]) if db_rows else None
    newest_ms = int(db_rows[-1]["open_time_ms"]) if db_rows else None
    # fetched is never negative, so this equals len - min(fetched, len).
    from_cache = max(0, len(candles) - fetched)
    return {
        "ok": True,
        "symbol": sym,
        "exchange_key": ex_k,
        "timeframe": display_tf,
        "storage_timeframe": storage_tf,
        "limit": need,
        "before_ms": int(before_ms) if is_history else None,
        "oldest_ms": oldest_ms,
        "newest_ms": newest_ms,
        "exhausted": exhausted,
        "source": "remote" if fetched else "db",
        "retention_policy": retention_policy_meta(),
        "candles": candles,
        "from_cache": from_cache,
        "fetched": fetched,
        "cleared": cleared,
        "purged": purged,
        "price_tick": price_tick,
        "stale": bool(remote_err),
        "stale_message": remote_err if remote_err else None,
        "updated_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
    }


def format_ohlcv_detail(bar: dict[str, Any] | None, tick: Optional[float]) -> dict[str, str]:
    """Format a bar's OHLCV fields as strings using ``tick``; ``"-"`` for a missing bar."""
    if not bar:
        return {"open": "-", "high": "-", "low": "-", "close": "-", "volume": "-"}
    return {
        "open": format_price_by_tick(bar.get("open"), tick),
        "high": format_price_by_tick(bar.get("high"), tick),
        "low": format_price_by_tick(bar.get("low"), tick),
        "close": format_price_by_tick(bar.get("close"), tick),
        # NOTE(review): volume is formatted with the *price* tick too - confirm intended.
        "volume": format_price_by_tick(bar.get("volume"), tick),
    }