feat(hub): background chart poll with SSE for positions and market watch
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,259 @@
|
||||
"""行情区 K 线:后台轮询订阅 + SSE 版本通知(对齐监控区 board)。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from hub_board_cache import board_store
|
||||
|
||||
HUB_CHART_POLL_INTERVAL = float(os.getenv("HUB_CHART_POLL_INTERVAL", "5"))
|
||||
HUB_CHART_SSE_HEARTBEAT_SEC = float(os.getenv("HUB_CHART_SSE_HEARTBEAT_SEC", "25"))
|
||||
HUB_CHART_WATCH_TTL_SEC = float(os.getenv("HUB_CHART_WATCH_TTL_SEC", "45"))
|
||||
HUB_CHART_POSITION_TIMEFRAME = (os.getenv("HUB_CHART_POSITION_TIMEFRAME", "5m") or "5m").strip()
|
||||
HUB_CHART_MAX_SERIES_PER_TICK = max(1, int(os.getenv("HUB_CHART_MAX_SERIES_PER_TICK", "24")))
|
||||
|
||||
PollFn = Callable[[], Awaitable[dict[str, Any]]]
|
||||
|
||||
|
||||
def series_key(exchange_key: str, symbol: str, timeframe: str) -> str:
|
||||
ex_k = (exchange_key or "").strip().lower()
|
||||
sym = (symbol or "").strip().upper()
|
||||
tf = (timeframe or "").strip()
|
||||
return f"{ex_k}|{sym}|{tf}"
|
||||
|
||||
|
||||
def parse_series_key(key: str) -> tuple[str, str, str] | None:
|
||||
parts = (key or "").split("|")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
ex_k, sym, tf = parts[0].strip().lower(), parts[1].strip().upper(), parts[2].strip()
|
||||
if not ex_k or not sym or not tf:
|
||||
return None
|
||||
return ex_k, sym, tf
|
||||
|
||||
|
||||
@dataclass
|
||||
class SeriesState:
|
||||
version: int = 0
|
||||
updated_at: str | None = None
|
||||
fetched: int = 0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChartPollStore:
|
||||
def __init__(self) -> None:
|
||||
self._lock = asyncio.Lock()
|
||||
self.version = 0
|
||||
self.updated_at: str | None = None
|
||||
self.polling = False
|
||||
self.last_error: str | None = None
|
||||
self._watch_until: dict[str, float] = {}
|
||||
self._position_keys: set[str] = set()
|
||||
self._series: dict[str, SeriesState] = {}
|
||||
self._subscribers: list[asyncio.Queue[str | None]] = []
|
||||
self._task: asyncio.Task | None = None
|
||||
self._stop = asyncio.Event()
|
||||
self._refresh = asyncio.Event()
|
||||
self._poll_fn: PollFn | None = None
|
||||
|
||||
async def start(self, poll_fn: PollFn) -> None:
|
||||
if self._task and not self._task.done():
|
||||
return
|
||||
self._poll_fn = poll_fn
|
||||
self._stop.clear()
|
||||
self._task = asyncio.create_task(self._loop(), name="hub-chart-poll")
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._stop.set()
|
||||
self._refresh.set()
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._task = None
|
||||
self._broadcast(close=True)
|
||||
|
||||
def request_refresh(self) -> None:
|
||||
self._refresh.set()
|
||||
|
||||
def touch_watch(self, exchange_key: str, symbol: str, timeframe: str) -> str:
|
||||
key = series_key(exchange_key, symbol, timeframe)
|
||||
self._watch_until[key] = time.monotonic() + HUB_CHART_WATCH_TTL_SEC
|
||||
return key
|
||||
|
||||
def clear_watch(self, exchange_key: str, symbol: str, timeframe: str) -> None:
|
||||
key = series_key(exchange_key, symbol, timeframe)
|
||||
self._watch_until.pop(key, None)
|
||||
|
||||
def sync_positions_from_rows(self, rows: list[Any]) -> None:
|
||||
keys: set[str] = set()
|
||||
tf = HUB_CHART_POSITION_TIMEFRAME
|
||||
for row in rows or []:
|
||||
if not isinstance(row, dict):
|
||||
continue
|
||||
ex_key = str(row.get("key") or row.get("exchange_key") or "").strip().lower()
|
||||
if not ex_key:
|
||||
ex_id = str(row.get("id") or "").strip()
|
||||
if ex_id:
|
||||
ex_key = ex_id.lower()
|
||||
if not ex_key:
|
||||
continue
|
||||
ag = row.get("agent") if isinstance(row.get("agent"), dict) else {}
|
||||
if ag.get("ok") is False:
|
||||
continue
|
||||
for pos in ag.get("positions") or []:
|
||||
if not isinstance(pos, dict):
|
||||
continue
|
||||
sym = str(pos.get("symbol") or "").strip().upper()
|
||||
if sym:
|
||||
keys.add(series_key(ex_key, sym, tf))
|
||||
self._position_keys = keys
|
||||
|
||||
def active_series_keys(self) -> list[str]:
|
||||
now = time.monotonic()
|
||||
watch = {k for k, until in self._watch_until.items() if until > now}
|
||||
merged = self._position_keys | watch
|
||||
return sorted(merged)[:HUB_CHART_MAX_SERIES_PER_TICK]
|
||||
|
||||
def series_event_dict(self) -> dict[str, Any]:
|
||||
out: dict[str, Any] = {}
|
||||
for key, st in self._series.items():
|
||||
out[key] = {
|
||||
"series_version": st.version,
|
||||
"updated_at": st.updated_at,
|
||||
"fetched": st.fetched,
|
||||
"error": st.error,
|
||||
}
|
||||
return out
|
||||
|
||||
def event_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"chart_version": self.version,
|
||||
"updated_at": self.updated_at,
|
||||
"polling": self.polling,
|
||||
"ok": self.last_error is None,
|
||||
"error": self.last_error,
|
||||
"series": self.series_event_dict(),
|
||||
"poll_interval_sec": HUB_CHART_POLL_INTERVAL,
|
||||
"position_timeframe": HUB_CHART_POSITION_TIMEFRAME,
|
||||
}
|
||||
|
||||
def series_version(self, exchange_key: str, symbol: str, timeframe: str) -> int:
|
||||
key = series_key(exchange_key, symbol, timeframe)
|
||||
st = self._series.get(key)
|
||||
return st.version if st else 0
|
||||
|
||||
async def _loop(self) -> None:
|
||||
assert self._poll_fn is not None
|
||||
while not self._stop.is_set():
|
||||
await self._poll_once(self._poll_fn)
|
||||
if self._stop.is_set():
|
||||
break
|
||||
self._refresh.clear()
|
||||
sleep_task = asyncio.create_task(asyncio.sleep(HUB_CHART_POLL_INTERVAL))
|
||||
refresh_task = asyncio.create_task(self._refresh.wait())
|
||||
done, pending = await asyncio.wait(
|
||||
{sleep_task, refresh_task},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
|
||||
async def _poll_once(self, poll_fn: PollFn) -> None:
|
||||
async with self._lock:
|
||||
self.polling = True
|
||||
self._broadcast()
|
||||
try:
|
||||
snap = board_store.snapshot_dict()
|
||||
rows = snap.get("rows") if isinstance(snap, dict) else []
|
||||
if isinstance(rows, list):
|
||||
self.sync_positions_from_rows(rows)
|
||||
result = await poll_fn()
|
||||
if not isinstance(result, dict):
|
||||
result = {"ok": False, "msg": "chart poll 返回无效"}
|
||||
except Exception as e:
|
||||
result = {"ok": False, "msg": str(e), "error": "chart_poll_failed"}
|
||||
async with self._lock:
|
||||
self.version += 1
|
||||
self.updated_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
self.last_error = None if result.get("ok") is not False else (
|
||||
str(result.get("msg") or result.get("error") or "chart_poll_failed")
|
||||
)
|
||||
self.polling = False
|
||||
self._broadcast()
|
||||
|
||||
def note_series_result(
|
||||
self,
|
||||
exchange_key: str,
|
||||
symbol: str,
|
||||
timeframe: str,
|
||||
*,
|
||||
ok: bool,
|
||||
fetched: int = 0,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
key = series_key(exchange_key, symbol, timeframe)
|
||||
st = self._series.setdefault(key, SeriesState())
|
||||
st.version += 1
|
||||
st.updated_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
st.fetched = int(fetched or 0)
|
||||
st.error = error if not ok else None
|
||||
|
||||
def _broadcast(self, *, close: bool = False) -> None:
|
||||
dead: list[asyncio.Queue[str | None]] = []
|
||||
payload = None if close else json.dumps(self.event_dict(), ensure_ascii=False)
|
||||
for q in self._subscribers:
|
||||
try:
|
||||
q.put_nowait(payload)
|
||||
except asyncio.QueueFull:
|
||||
try:
|
||||
q.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
try:
|
||||
q.put_nowait(payload)
|
||||
except asyncio.QueueFull:
|
||||
dead.append(q)
|
||||
except Exception:
|
||||
dead.append(q)
|
||||
for q in dead:
|
||||
if q in self._subscribers:
|
||||
self._subscribers.remove(q)
|
||||
|
||||
async def iter_sse(self) -> AsyncIterator[str]:
|
||||
q: asyncio.Queue[str | None] = asyncio.Queue(maxsize=32)
|
||||
self._subscribers.append(q)
|
||||
try:
|
||||
yield _sse_frame(self.event_dict())
|
||||
while True:
|
||||
try:
|
||||
raw = await asyncio.wait_for(q.get(), timeout=HUB_CHART_SSE_HEARTBEAT_SEC)
|
||||
except asyncio.TimeoutError:
|
||||
yield ": heartbeat\n\n"
|
||||
continue
|
||||
if raw is None:
|
||||
break
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except Exception:
|
||||
data = self.event_dict()
|
||||
yield _sse_frame(data)
|
||||
finally:
|
||||
if q in self._subscribers:
|
||||
self._subscribers.remove(q)
|
||||
|
||||
|
||||
def _sse_frame(data: dict[str, Any]) -> str:
|
||||
body = json.dumps(data, ensure_ascii=False)
|
||||
return f"event: chart\ndata: {body}\n\n"
|
||||
|
||||
|
||||
chart_poll_store = ChartPollStore()
|
||||
Reference in New Issue
Block a user