"""中控 K 线 SQLite 缓存:按需拉取、15 天滚动保留。""" from __future__ import annotations import os import sqlite3 import time from pathlib import Path from typing import Any, Callable, Optional from hub_ohlcv_lib import ( TIMEFRAME_MS, bar_limit_for_timeframe, chart_fetch_start_ms, format_price_by_tick, last_closed_bar_open_ms, normalize_chart_timeframe, window_start_ms, ) _DEFAULT_RETENTION_DAYS = 15 def retention_days() -> int: try: return max(1, int(os.getenv("HUB_KLINE_RETENTION_DAYS", str(_DEFAULT_RETENTION_DAYS)))) except ValueError: return _DEFAULT_RETENTION_DAYS def default_db_path() -> Path: raw = (os.getenv("HUB_KLINE_DB_PATH") or "").strip() if raw: return Path(raw) hub_dir = Path(__file__).resolve().parent / "manual_trading_hub" / "data" hub_dir.mkdir(parents=True, exist_ok=True) return hub_dir / "hub_kline.db" def _connect(db_path: Path | None = None) -> sqlite3.Connection: path = db_path or default_db_path() path.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(str(path), timeout=30, isolation_level=None) conn.row_factory = sqlite3.Row conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA synchronous=NORMAL") return conn def init_db(db_path: Path | None = None) -> None: conn = _connect(db_path) try: conn.execute( """ CREATE TABLE IF NOT EXISTS ohlcv_bars ( exchange_key TEXT NOT NULL, symbol TEXT NOT NULL, timeframe TEXT NOT NULL, open_time_ms INTEGER NOT NULL, open REAL NOT NULL, high REAL NOT NULL, low REAL NOT NULL, close REAL NOT NULL, volume REAL NOT NULL DEFAULT 0, updated_at INTEGER NOT NULL, PRIMARY KEY (exchange_key, symbol, timeframe, open_time_ms) ) """ ) conn.execute( """ CREATE INDEX IF NOT EXISTS idx_ohlcv_series ON ohlcv_bars (exchange_key, symbol, timeframe, open_time_ms) """ ) finally: conn.close() def purge_retention(db_path: Path | None = None, *, days: int | None = None) -> int: """删除早于 retention 的 K 线;返回删除行数。""" keep = days if days is not None else retention_days() cutoff = int(time.time() * 1000) - keep * 86400000 conn = _connect(db_path) try: cur = conn.execute("DELETE FROM ohlcv_bars WHERE open_time_ms < ?", (cutoff,)) return int(cur.rowcount or 0) finally: conn.close() def upsert_bars( exchange_key: str, symbol: str, timeframe: str, bars: list[dict[str, Any]], db_path: Path | None = None, ) -> int: if not bars: return 0 ex_k = (exchange_key or "").strip().lower() sym = (symbol or "").strip().upper() tf = normalize_chart_timeframe(timeframe) now = int(time.time()) conn = _connect(db_path) n = 0 try: for b in bars: try: oms = int(b["open_time_ms"]) conn.execute( """ INSERT INTO ohlcv_bars (exchange_key, symbol, timeframe, open_time_ms, open, high, low, close, volume, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?) ON CONFLICT(exchange_key, symbol, timeframe, open_time_ms) DO UPDATE SET open=excluded.open, high=excluded.high, low=excluded.low, close=excluded.close, volume=excluded.volume, updated_at=excluded.updated_at """, ( ex_k, sym, tf, oms, float(b["open"]), float(b["high"]), float(b["low"]), float(b["close"]), float(b.get("volume") or 0), now, ), ) n += 1 except (KeyError, TypeError, ValueError): continue finally: conn.close() return n def load_bars_range( exchange_key: str, symbol: str, timeframe: str, start_ms: int, end_ms: int, db_path: Path | None = None, ) -> list[dict[str, Any]]: ex_k = (exchange_key or "").strip().lower() sym = (symbol or "").strip().upper() tf = normalize_chart_timeframe(timeframe) conn = _connect(db_path) try: rows = conn.execute( """ SELECT open_time_ms, open, high, low, close, volume FROM ohlcv_bars WHERE exchange_key=? AND symbol=? AND timeframe=? AND open_time_ms >= ? AND open_time_ms <= ? ORDER BY open_time_ms ASC """, (ex_k, sym, tf, int(start_ms), int(end_ms)), ).fetchall() return [ { "open_time_ms": int(r["open_time_ms"]), "open": float(r["open"]), "high": float(r["high"]), "low": float(r["low"]), "close": float(r["close"]), "volume": float(r["volume"] or 0), } for r in rows ] finally: conn.close() def _to_chart_candles(bars: list[dict[str, Any]]) -> list[dict[str, Any]]: out = [] for b in bars: try: out.append( { "time": int(b["open_time_ms"] // 1000), "open": float(b["open"]), "high": float(b["high"]), "low": float(b["low"]), "close": float(b["close"]), "volume": float(b.get("volume") or 0), } ) except (KeyError, TypeError, ValueError): continue return out def _merge_bars(*groups: list[dict[str, Any]]) -> list[dict[str, Any]]: merged: dict[int, dict[str, Any]] = {} for g in groups: for b in g or []: try: merged[int(b["open_time_ms"])] = b except (KeyError, TypeError, ValueError): continue return [merged[k] for k in sorted(merged.keys())] def resolve_chart_bars( exchange_key: str, symbol: str, timeframe: str, remote_fetch: Callable[..., dict[str, Any]], *, db_path: Path | None = None, force_refresh: bool = False, ) -> dict[str, Any]: """ 按需:先读库,不足则 remote_fetch(symbol, timeframe, since_ms, limit) 补齐并写库。 无后台定时任务;每次调用时顺带 purge 15 天前数据。 """ init_db(db_path) purged = purge_retention(db_path) sym = (symbol or "").strip().upper() ex_k = (exchange_key or "").strip().lower() tf = normalize_chart_timeframe(timeframe) if not sym or not ex_k: return {"ok": False, "msg": "缺少 exchange 或 symbol"} need = bar_limit_for_timeframe(tf) now_ms = int(time.time() * 1000) fetch_start_ms = chart_fetch_start_ms(tf, need, now_ms) db_read_start_ms = window_start_ms(tf, need, retention_days(), now_ms) last_closed = last_closed_bar_open_ms(tf, now_ms) db_rows: list[dict[str, Any]] = [] if not force_refresh: period_ms = TIMEFRAME_MS[tf] db_rows = load_bars_range( ex_k, sym, tf, max(0, db_read_start_ms - period_ms), now_ms + period_ms, db_path ) newest_db = db_rows[-1]["open_time_ms"] if db_rows else None period_ms = TIMEFRAME_MS[tf] newest_ok = newest_db is not None and int(newest_db) >= int(last_closed) - period_ms need_fetch = force_refresh or len(db_rows) < need or not newest_ok fetched = 0 price_tick: Optional[float] = None remote_err: Optional[str] = None if need_fetch: since = fetch_start_ms if db_rows and not force_refresh: since = min(since, int(db_rows[0]["open_time_ms"])) remote = remote_fetch( symbol=sym, timeframe=tf, since_ms=since, limit=need + 20, ) if remote.get("ok") and remote.get("bars"): fetched = upsert_bars(ex_k, sym, tf, remote["bars"], db_path) price_tick = remote.get("price_tick") db_rows = load_bars_range(ex_k, sym, tf, fetch_start_ms, now_ms, db_path) else: remote_err = remote.get("msg") or remote.get("error") or "实例拉取 K 线失败" if not db_rows: return {"ok": False, "msg": remote_err, "purged": purged} if len(db_rows) > need: db_rows = db_rows[-need:] candles = _to_chart_candles(db_rows) if not candles: return {"ok": False, "msg": remote_err or "无 K 线数据", "purged": purged} from_cache = max(0, len(candles) - (1 if fetched else 0)) if fetched: from_cache = max(0, len(candles) - min(fetched, len(candles))) return { "ok": True, "symbol": sym, "exchange_key": ex_k, "timeframe": tf, "limit": need, "retention_days": retention_days(), "candles": candles, "from_cache": from_cache, "fetched": fetched, "purged": purged, "price_tick": price_tick, "stale": bool(remote_err), "stale_message": remote_err if remote_err else None, "updated_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), } def format_ohlcv_detail(bar: dict[str, Any] | None, tick: Optional[float]) -> dict[str, str]: if not bar: return {"open": "-", "high": "-", "low": "-", "close": "-", "volume": "-"} return { "open": format_price_by_tick(bar.get("open"), tick), "high": format_price_by_tick(bar.get("high"), tick), "low": format_price_by_tick(bar.get("low"), tick), "close": format_price_by_tick(bar.get("close"), tick), "volume": format_price_by_tick(bar.get("volume"), tick), }