"""行情区 K 线:后台轮询订阅 + SSE 推送尾部 K 线(对齐监控区 board)。""" from __future__ import annotations import asyncio import json import os import time from collections.abc import AsyncIterator, Awaitable, Callable from dataclasses import dataclass from typing import Any from hub_board_cache import board_store HUB_CHART_POLL_INTERVAL = float(os.getenv("HUB_CHART_POLL_INTERVAL", "5")) HUB_CHART_SSE_HEARTBEAT_SEC = float(os.getenv("HUB_CHART_SSE_HEARTBEAT_SEC", "25")) HUB_CHART_WATCH_TTL_SEC = float(os.getenv("HUB_CHART_WATCH_TTL_SEC", "45")) HUB_CHART_POSITION_TIMEFRAME = (os.getenv("HUB_CHART_POSITION_TIMEFRAME", "5m") or "5m").strip() HUB_CHART_MAX_SERIES_PER_TICK = max(1, int(os.getenv("HUB_CHART_MAX_SERIES_PER_TICK", "24"))) HUB_CHART_SSE_TAIL_BARS = max(5, min(int(os.getenv("HUB_CHART_SSE_TAIL_BARS", "30")), 120)) PollFn = Callable[[], Awaitable[dict[str, Any]]] def series_key(exchange_key: str, symbol: str, timeframe: str) -> str: ex_k = (exchange_key or "").strip().lower() sym = (symbol or "").strip().upper() tf = (timeframe or "").strip() return f"{ex_k}|{sym}|{tf}" def parse_series_key(key: str) -> tuple[str, str, str] | None: parts = (key or "").split("|") if len(parts) != 3: return None ex_k, sym, tf = parts[0].strip().lower(), parts[1].strip().upper(), parts[2].strip() if not ex_k or not sym or not tf: return None return ex_k, sym, tf @dataclass class SeriesState: version: int = 0 updated_at: str | None = None fetched: int = 0 error: str | None = None class ChartPollStore: def __init__(self) -> None: self._lock = asyncio.Lock() self.version = 0 self.updated_at: str | None = None self.polling = False self.last_error: str | None = None self._watch_until: dict[str, float] = {} self._position_keys: set[str] = set() self._series: dict[str, SeriesState] = {} self._push_tails: dict[str, dict[str, Any]] = {} self._subscribers: list[asyncio.Queue[str | None]] = [] self._task: asyncio.Task | None = None self._stop = asyncio.Event() self._refresh = asyncio.Event() self._poll_fn: PollFn | None = None async def start(self, poll_fn: PollFn) -> None: if self._task and not self._task.done(): return self._poll_fn = poll_fn self._stop.clear() self._task = asyncio.create_task(self._loop(), name="hub-chart-poll") async def stop(self) -> None: self._stop.set() self._refresh.set() if self._task: self._task.cancel() try: await self._task except asyncio.CancelledError: pass self._task = None self._broadcast(close=True) def request_refresh(self) -> None: self._refresh.set() def touch_watch(self, exchange_key: str, symbol: str, timeframe: str) -> str: key = series_key(exchange_key, symbol, timeframe) self._watch_until[key] = time.monotonic() + HUB_CHART_WATCH_TTL_SEC return key def clear_watch(self, exchange_key: str, symbol: str, timeframe: str) -> None: key = series_key(exchange_key, symbol, timeframe) self._watch_until.pop(key, None) def sync_positions_from_rows(self, rows: list[Any]) -> None: keys: set[str] = set() tf = HUB_CHART_POSITION_TIMEFRAME for row in rows or []: if not isinstance(row, dict): continue ex_key = str(row.get("key") or row.get("exchange_key") or "").strip().lower() if not ex_key: ex_id = str(row.get("id") or "").strip() if ex_id: ex_key = ex_id.lower() if not ex_key: continue ag = row.get("agent") if isinstance(row.get("agent"), dict) else {} if ag.get("ok") is False: continue for pos in ag.get("positions") or []: if not isinstance(pos, dict): continue sym = str(pos.get("symbol") or "").strip().upper() if sym: keys.add(series_key(ex_key, sym, tf)) self._position_keys = keys def active_series_keys(self) -> list[str]: now = time.monotonic() watch = {k for k, until in self._watch_until.items() if until > now} merged = self._position_keys | watch return sorted(merged)[:HUB_CHART_MAX_SERIES_PER_TICK] def series_event_dict(self) -> dict[str, Any]: out: dict[str, Any] = {} for key, st in self._series.items(): out[key] = { "series_version": st.version, "updated_at": st.updated_at, "fetched": st.fetched, "error": st.error, } return out def event_dict(self, *, tails: dict[str, dict[str, Any]] | None = None) -> dict[str, Any]: out: dict[str, Any] = { "chart_version": self.version, "updated_at": self.updated_at, "polling": self.polling, "ok": self.last_error is None, "error": self.last_error, "series": self.series_event_dict(), "poll_interval_sec": HUB_CHART_POLL_INTERVAL, "position_timeframe": HUB_CHART_POSITION_TIMEFRAME, "push_tails": True, } tail_map = tails if tails is not None else self._push_tails if tail_map: out["tails"] = tail_map return out def series_version(self, exchange_key: str, symbol: str, timeframe: str) -> int: key = series_key(exchange_key, symbol, timeframe) st = self._series.get(key) return st.version if st else 0 async def _loop(self) -> None: assert self._poll_fn is not None while not self._stop.is_set(): await self._poll_once(self._poll_fn) if self._stop.is_set(): break self._refresh.clear() sleep_task = asyncio.create_task(asyncio.sleep(HUB_CHART_POLL_INTERVAL)) refresh_task = asyncio.create_task(self._refresh.wait()) done, pending = await asyncio.wait( {sleep_task, refresh_task}, return_when=asyncio.FIRST_COMPLETED, ) for t in pending: t.cancel() async def _poll_once(self, poll_fn: PollFn) -> None: async with self._lock: self.polling = True self._broadcast() try: snap = board_store.snapshot_dict() rows = snap.get("rows") if isinstance(snap, dict) else [] if isinstance(rows, list): self.sync_positions_from_rows(rows) result = await poll_fn() if not isinstance(result, dict): result = {"ok": False, "msg": "chart poll 返回无效"} except Exception as e: result = {"ok": False, "msg": str(e), "error": "chart_poll_failed"} async with self._lock: self.version += 1 self.updated_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) self.last_error = None if result.get("ok") is not False else ( str(result.get("msg") or result.get("error") or "chart_poll_failed") ) self.polling = False self._broadcast() def note_series_result( self, exchange_key: str, symbol: str, timeframe: str, *, ok: bool, fetched: int = 0, error: str | None = None, candles: list[dict[str, Any]] | None = None, price_tick: Any = None, ) -> None: key = series_key(exchange_key, symbol, timeframe) st = self._series.setdefault(key, SeriesState()) st.version += 1 st.updated_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) st.fetched = int(fetched or 0) st.error = error if not ok else None if ok and candles: tail = list(candles[-HUB_CHART_SSE_TAIL_BARS :]) if tail: self._push_tails[key] = { "series_version": st.version, "updated_at": st.updated_at, "fetched": st.fetched, "candles": tail, "price_tick": price_tick, } def _broadcast(self, *, close: bool = False) -> None: dead: list[asyncio.Queue[str | None]] = [] tails_snap = dict(self._push_tails) self._push_tails.clear() payload = None if close else json.dumps(self.event_dict(tails=tails_snap), ensure_ascii=False) for q in self._subscribers: try: q.put_nowait(payload) except asyncio.QueueFull: try: q.get_nowait() except asyncio.QueueEmpty: pass try: q.put_nowait(payload) except asyncio.QueueFull: dead.append(q) except Exception: dead.append(q) for q in dead: if q in self._subscribers: self._subscribers.remove(q) async def iter_sse(self) -> AsyncIterator[str]: q: asyncio.Queue[str | None] = asyncio.Queue(maxsize=32) self._subscribers.append(q) try: yield _sse_frame(self.event_dict()) while True: try: raw = await asyncio.wait_for(q.get(), timeout=HUB_CHART_SSE_HEARTBEAT_SEC) except asyncio.TimeoutError: yield ": heartbeat\n\n" continue if raw is None: break try: data = json.loads(raw) except Exception: data = self.event_dict() yield _sse_frame(data) finally: if q in self._subscribers: self._subscribers.remove(q) def _sse_frame(data: dict[str, Any]) -> str: body = json.dumps(data, ensure_ascii=False) return f"event: chart\ndata: {body}\n\n" chart_poll_store = ChartPollStore()