Files
qihuo/recommend_stream.py
T
2026-06-25 17:29:10 +08:00

159 lines
5.5 KiB
Python

"""品种推荐 SSE 推送与后台刷新。"""
from __future__ import annotations
import json
import logging
import queue
import threading
import time
from typing import Callable, Optional
from db_conn import connect_db
from kline_stream import sse_format
from recommend_store import (
load_recommend_cache,
recommend_cache_needs_refresh,
recommend_payload,
refresh_recommend_cache,
)
# Module-level logger.
logger: logging.Logger = logging.getLogger(__name__)

# Default period (seconds) between worker checks for a needed recommendation refresh.
CHECK_INTERVAL_SEC: int = 60 * 60

# _refresh_lock guards _refresh_running, which ensures at most one background
# refresh thread is in flight at a time.
_refresh_lock = threading.Lock()
_refresh_running: bool = False
def schedule_recommend_refresh(
    *,
    db_path: str,
    get_capital_fn: Callable,
    quote_fn: Callable[[str], Optional[dict]],
    init_tables_fn: Callable | None = None,
    get_mode_fn: Callable[[], str] | None = None,
    get_max_margin_pct_fn: Callable[[], float] | None = None,
    get_sizing_mode_fn: Callable[[], str] | None = None,
    get_fixed_lots_fn: Callable[[], int] | None = None,
) -> None:
    """Refresh the recommendation cache on a background thread and broadcast it.

    Returns immediately so page requests are never blocked. At most one
    refresh runs at a time: if one is already in flight (tracked by the
    module-level ``_refresh_running`` flag under ``_refresh_lock``), this call
    is a no-op.

    The background thread opens its own connection with ``connect_db(db_path)``
    and optionally runs ``init_tables_fn(conn)``. It then reads the capital and
    loads the cached recommendations. If ``recommend_cache_needs_refresh`` says
    the cache is stale, the thread rebuilds it with ``refresh_recommend_cache``.
    Either way it builds a payload with ``recommend_payload`` and broadcasts
    ``{"ok": True, **payload}`` as a ``"recommend"`` event on ``recommend_hub``.

    Args:
        db_path: Database path passed to ``connect_db``.
        get_capital_fn: Called with the open connection; returns the capital
            (a falsy result is treated as 0).
        quote_fn: Quote lookup forwarded to ``refresh_recommend_cache``.
        init_tables_fn: Optional; called with the connection before anything
            else.
        get_mode_fn: Optional; returns the trading mode (default
            ``"simulation"``).
        get_max_margin_pct_fn: Optional; returns the max margin percentage
            (default ``30.0``).
        get_sizing_mode_fn: Optional; returns the sizing mode for the payload
            (default ``"fixed"``).
        get_fixed_lots_fn: Optional; returns the fixed lot count for the
            payload (default ``1``).

    Raises:
        RuntimeError: if the background thread cannot be started. The
            in-flight flag is cleared first, so later calls can retry.
        Errors raised inside the background thread are logged, not propagated.
    """
    global _refresh_running
    with _refresh_lock:
        if _refresh_running:
            return
        _refresh_running = True

    def _build_payload(conn, capital: float, max_pct: float, mode: str) -> dict:
        # Single place that builds the SSE payload for both the cache-hit path
        # and the refresh path (previously duplicated verbatim).
        return recommend_payload(
            conn,
            live_capital=capital,
            max_margin_pct=max_pct,
            trading_mode=mode,
            sizing_mode=get_sizing_mode_fn() if get_sizing_mode_fn else "fixed",
            fixed_lots=get_fixed_lots_fn() if get_fixed_lots_fn else 1,
        )

    def _run() -> None:
        global _refresh_running
        try:
            conn = connect_db(db_path)
            try:
                if init_tables_fn:
                    init_tables_fn(conn)
                capital = float(get_capital_fn(conn) or 0)
                mode = get_mode_fn() if get_mode_fn else "simulation"
                max_pct = (
                    float(get_max_margin_pct_fn()) if get_max_margin_pct_fn else 30.0
                )
                cached = load_recommend_cache(conn)
                if recommend_cache_needs_refresh(cached, capital=capital):
                    refresh_recommend_cache(
                        conn, capital, quote_fn, trading_mode=mode, max_margin_pct=max_pct,
                    )
                    # Guard against a missing cache row so a successful refresh
                    # is not reported as a failure by the handler below.
                    cached = load_recommend_cache(conn) or {}
                    logger.info(
                        "品种推荐后台刷新完成,capital=%.2f rows=%d",
                        capital, len(cached.get("rows") or []),
                    )
                payload = _build_payload(conn, capital, max_pct, mode)
            finally:
                conn.close()
            recommend_hub.broadcast("recommend", {"ok": True, **payload})
        except Exception as exc:
            # Keep the traceback: the message alone rarely identifies the failing step.
            logger.warning("recommend background refresh failed: %s", exc, exc_info=True)
        finally:
            with _refresh_lock:
                _refresh_running = False

    worker = threading.Thread(target=_run, daemon=True, name="recommend-refresh")
    try:
        worker.start()
    except BaseException:
        # _run's finally never executes if the thread fails to start; without
        # this reset every future refresh would be skipped forever.
        with _refresh_lock:
            _refresh_running = False
        raise
class RecommendStreamHub:
    """Fan-out hub that delivers recommendation events to SSE subscribers.

    Each subscriber owns a bounded ``queue.Queue``. ``broadcast`` pushes the
    same message dict to every queue without blocking. The subscriber list is
    guarded by a lock, so subscribe/unsubscribe/broadcast may be called from
    different threads.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._subs: list[queue.Queue] = []

    def subscribe(self, maxsize: int = 8) -> queue.Queue:
        """Register and return a new subscriber queue.

        Args:
            maxsize: Capacity of the queue. Follows ``queue.Queue`` semantics,
                so a value <= 0 means unbounded.
        """
        q: queue.Queue = queue.Queue(maxsize=maxsize)
        with self._lock:
            self._subs.append(q)
        return q

    def unsubscribe(self, q: queue.Queue) -> None:
        """Remove *q* from the subscriber list; unknown queues are ignored."""
        with self._lock:
            try:
                self._subs.remove(q)
            except ValueError:
                pass

    def broadcast(self, event: str, data: dict) -> None:
        """Send ``{"event": event, "data": data}`` to every subscriber without blocking.

        If a subscriber's queue is full, its oldest message is dropped to make
        room. Each message is a full snapshot, so a slow consumer ends up with
        the newest payload instead of losing it. The same *data* object is
        shared by all subscribers; do not mutate it after broadcasting.
        """
        msg = {"event": event, "data": data}
        # Snapshot the list so queue operations happen outside the lock.
        with self._lock:
            subs = list(self._subs)
        for q in subs:
            self._offer(q, msg)

    @staticmethod
    def _offer(q: queue.Queue, msg: dict) -> None:
        """Put *msg* on *q*, evicting the oldest entry if *q* is full."""
        try:
            q.put_nowait(msg)
            return
        except queue.Full:
            pass
        try:
            q.get_nowait()
        except queue.Empty:
            pass  # the consumer drained it concurrently
        try:
            q.put_nowait(msg)
        except queue.Full:
            pass  # another producer refilled it first; best effort only


# Process-wide hub; schedule_recommend_refresh broadcasts refreshed payloads here.
recommend_hub = RecommendStreamHub()
def start_recommend_worker(
    *,
    db_path: str,
    get_capital_fn: Callable,
    quote_fn: Callable[[str], Optional[dict]],
    init_tables_fn: Callable | None = None,
    get_mode_fn: Callable[[], str] | None = None,
    get_max_margin_pct_fn: Callable[[], float] | None = None,
    get_sizing_mode_fn: Callable[[], str] | None = None,
    get_fixed_lots_fn: Callable[[], int] | None = None,
    interval: int = CHECK_INTERVAL_SEC,
    min_interval: int = 300,
) -> None:
    """Start a daemon thread that periodically triggers a recommendation refresh.

    The thread calls ``schedule_recommend_refresh`` once immediately, then
    again every ``max(min_interval, interval)`` seconds, forever. Whether a
    real rebuild happens on each tick is decided inside that function via
    ``recommend_cache_needs_refresh``. The original note says the intent is
    a roughly daily refresh; that cadence is not visible here. Results are
    broadcast to SSE subscribers by ``schedule_recommend_refresh``.

    All keyword arguments except ``interval`` and ``min_interval`` are
    forwarded unchanged to ``schedule_recommend_refresh``.

    Args:
        interval: Desired seconds between checks (default one hour).
        min_interval: Lower bound on the sleep so a tiny ``interval`` cannot
            hammer the database (default 300, the previous hard-coded floor).

    Exceptions raised while scheduling are logged and the loop continues.
    There is no stop mechanism; each call starts another worker thread.
    """
    refresh_kwargs = dict(
        db_path=db_path,
        get_capital_fn=get_capital_fn,
        quote_fn=quote_fn,
        init_tables_fn=init_tables_fn,
        get_mode_fn=get_mode_fn,
        get_max_margin_pct_fn=get_max_margin_pct_fn,
        get_sizing_mode_fn=get_sizing_mode_fn,
        get_fixed_lots_fn=get_fixed_lots_fn,
    )
    # Loop-invariant: compute the effective sleep once.
    sleep_sec = max(min_interval, interval)

    def _loop() -> None:
        while True:
            try:
                schedule_recommend_refresh(**refresh_kwargs)
            except Exception as exc:
                logger.warning("recommend worker failed: %s", exc, exc_info=True)
            time.sleep(sleep_sec)

    threading.Thread(target=_loop, daemon=True, name="recommend-worker").start()