"""Product-recommendation SSE push and background refresh."""

from __future__ import annotations

import json
import logging
import queue
import threading
import time
from typing import Callable, Optional

from db_conn import connect_db
from kline_stream import sse_format
from recommend_store import (
    load_recommend_cache,
    recommend_cache_needs_refresh,
    recommend_payload,
    refresh_recommend_cache,
)

logger = logging.getLogger(__name__)

# Default period (seconds) between worker checks of whether the cache needs rebuilding.
CHECK_INTERVAL_SEC = 3600
# Lower bound on the worker period so a tiny ``interval`` cannot hammer the DB / quote source.
_MIN_CHECK_INTERVAL_SEC = 300

# Guards ``_refresh_running`` so at most one background refresh thread runs at a time.
_refresh_lock = threading.Lock()
_refresh_running = False


def _build_payload(
    conn,
    *,
    capital: float,
    max_pct: float,
    mode: str,
    get_sizing_mode_fn: Callable[[], str] | None,
    get_fixed_lots_fn: Callable[[], int] | None,
) -> dict:
    """Build the recommendation payload for ``conn`` via ``recommend_payload``.

    Sizing mode falls back to ``"fixed"`` and fixed lots to ``1`` when the
    corresponding callback is not supplied.
    """
    return recommend_payload(
        conn,
        live_capital=capital,
        max_margin_pct=max_pct,
        trading_mode=mode,
        sizing_mode=get_sizing_mode_fn() if get_sizing_mode_fn else "fixed",
        fixed_lots=get_fixed_lots_fn() if get_fixed_lots_fn else 1,
    )


def schedule_recommend_refresh(
    *,
    db_path: str,
    get_capital_fn: Callable,
    quote_fn: Callable[[str], Optional[dict]],
    init_tables_fn: Callable | None = None,
    get_mode_fn: Callable[[], str] | None = None,
    get_max_margin_pct_fn: Callable[[], float] | None = None,
    get_sizing_mode_fn: Callable[[], str] | None = None,
    get_fixed_lots_fn: Callable[[], int] | None = None,
) -> None:
    """Refresh the recommendation cache on a background thread (never blocks page requests).

    If a refresh started by this function is still in progress, return
    immediately. Otherwise start a daemon thread that:

    1. opens ``db_path`` with ``connect_db`` and runs ``init_tables_fn`` if given;
    2. reads capital (``get_capital_fn(conn)``, ``None``/0 -> 0.0), trading mode
       (default ``"simulation"``) and max margin pct (default ``30.0``);
    3. calls ``refresh_recommend_cache`` only when ``recommend_cache_needs_refresh``
       reports the cached data as needing a refresh;
    4. broadcasts a ``"recommend"`` event with ``{"ok": True, **payload}`` to every
       ``recommend_hub`` subscriber.

    Errors inside the thread are logged (with traceback) and not propagated.

    Args:
        db_path: Database location passed to ``connect_db``.
        get_capital_fn: Called with the open connection; returns the live capital.
        quote_fn: Quote lookup forwarded to ``refresh_recommend_cache``.
        init_tables_fn: Optional schema initialiser, called with the connection.
        get_mode_fn: Optional provider of the trading mode.
        get_max_margin_pct_fn: Optional provider of the max margin percentage.
        get_sizing_mode_fn: Optional provider of the position-sizing mode.
        get_fixed_lots_fn: Optional provider of the fixed lot count.

    Raises:
        Whatever ``threading.Thread.start`` raises (e.g. ``RuntimeError`` when no
        new thread can be created). The in-progress flag is cleared first so a
        later call can retry instead of being skipped forever.
    """
    global _refresh_running
    with _refresh_lock:
        if _refresh_running:
            return
        _refresh_running = True

    def _run() -> None:
        global _refresh_running
        try:
            conn = connect_db(db_path)
            try:
                if init_tables_fn:
                    init_tables_fn(conn)
                capital = float(get_capital_fn(conn) or 0)
                mode = get_mode_fn() if get_mode_fn else "simulation"
                max_pct = float(get_max_margin_pct_fn()) if get_max_margin_pct_fn else 30.0
                cached = load_recommend_cache(conn)
                if recommend_cache_needs_refresh(cached, capital=capital):
                    refresh_recommend_cache(
                        conn,
                        capital,
                        quote_fn,
                        trading_mode=mode,
                        max_margin_pct=max_pct,
                    )
                    cached = load_recommend_cache(conn)
                    logger.info(
                        "品种推荐后台刷新完成,capital=%.2f rows=%d",
                        capital,
                        # Guard against an empty/None cache so logging cannot abort the broadcast.
                        len((cached or {}).get("rows") or []),
                    )
                payload = _build_payload(
                    conn,
                    capital=capital,
                    max_pct=max_pct,
                    mode=mode,
                    get_sizing_mode_fn=get_sizing_mode_fn,
                    get_fixed_lots_fn=get_fixed_lots_fn,
                )
            finally:
                conn.close()
            # Broadcast after the connection is released; payload is already materialised.
            recommend_hub.broadcast("recommend", {"ok": True, **payload})
        except Exception as exc:
            logger.warning("recommend background refresh failed: %s", exc, exc_info=True)
        finally:
            with _refresh_lock:
                _refresh_running = False

    try:
        threading.Thread(target=_run, daemon=True, name="recommend-refresh").start()
    except Exception:
        # _run never executes, so its ``finally`` cannot reset the flag; do it here.
        with _refresh_lock:
            _refresh_running = False
        raise


class RecommendStreamHub:
    """Fan-out hub that delivers recommendation events to SSE subscriber queues.

    Each subscriber gets a bounded queue. Broadcasting never blocks. When a
    subscriber's queue is full, its oldest pending message is discarded so the
    newest message still gets through.
    """

    # Per-subscriber queue bound; keeps a stalled client from growing memory without limit.
    _QUEUE_MAXSIZE = 8

    def __init__(self) -> None:
        self._lock = threading.Lock()  # protects self._subs
        self._subs: list[queue.Queue] = []

    def subscribe(self) -> queue.Queue:
        """Register and return a new bounded queue that receives broadcast messages."""
        q: queue.Queue = queue.Queue(maxsize=self._QUEUE_MAXSIZE)
        with self._lock:
            self._subs.append(q)
        return q

    def unsubscribe(self, q: queue.Queue) -> None:
        """Remove ``q`` from the subscriber list; queues that are not registered are ignored."""
        with self._lock:
            try:
                self._subs.remove(q)
            except ValueError:
                pass

    def broadcast(self, event: str, data: dict) -> None:
        """Send ``{"event": event, "data": data}`` to every current subscriber without blocking."""
        msg = {"event": event, "data": data}
        # Snapshot under the lock and deliver outside it, so delivery never holds up (un)subscribe.
        with self._lock:
            subs = list(self._subs)
        for q in subs:
            self._offer(q, msg)

    @staticmethod
    def _offer(q: queue.Queue, msg: dict) -> None:
        """Put ``msg`` on ``q``; if ``q`` is full, evict its oldest item once and retry (best effort)."""
        try:
            q.put_nowait(msg)
            return
        except queue.Full:
            pass
        # The newest message reflects the latest state; prefer dropping the oldest pending one.
        try:
            q.get_nowait()
        except queue.Empty:
            pass  # consumer drained it concurrently; a slot is now free anyway
        try:
            q.put_nowait(msg)
        except queue.Full:
            # Another producer refilled the slot concurrently; drop rather than block.
            pass


recommend_hub = RecommendStreamHub()


def start_recommend_worker(
    *,
    db_path: str,
    get_capital_fn: Callable,
    quote_fn: Callable[[str], Optional[dict]],
    init_tables_fn: Callable | None = None,
    get_mode_fn: Callable[[], str] | None = None,
    get_max_margin_pct_fn: Callable[[], float] | None = None,
    get_sizing_mode_fn: Callable[[], str] | None = None,
    get_fixed_lots_fn: Callable[[], int] | None = None,
    interval: int = CHECK_INTERVAL_SEC,
) -> None:
    """Start a daemon thread that periodically triggers a recommendation refresh.

    The first check runs immediately. After that a check runs every
    ``max(300, interval)`` seconds (hourly by default). Each check calls
    ``schedule_recommend_refresh`` with the same arguments. That function
    rebuilds the cache only when ``recommend_cache_needs_refresh`` says so, and
    pushes the resulting payload to SSE subscribers. Exceptions raised while
    scheduling are logged and the loop keeps running.
    """
    sleep_sec = max(_MIN_CHECK_INTERVAL_SEC, interval)

    def _loop() -> None:
        while True:
            try:
                schedule_recommend_refresh(
                    db_path=db_path,
                    get_capital_fn=get_capital_fn,
                    quote_fn=quote_fn,
                    init_tables_fn=init_tables_fn,
                    get_mode_fn=get_mode_fn,
                    get_max_margin_pct_fn=get_max_margin_pct_fn,
                    get_sizing_mode_fn=get_sizing_mode_fn,
                    get_fixed_lots_fn=get_fixed_lots_fn,
                )
            except Exception as exc:
                logger.warning("recommend worker failed: %s", exc, exc_info=True)
            time.sleep(sleep_sec)

    threading.Thread(target=_loop, daemon=True, name="recommend-worker").start()