"""交易监管:后台扫描 + SSE 版本通知。""" from __future__ import annotations import asyncio import json import os from collections.abc import AsyncIterator, Awaitable, Callable from typing import Any SUPERVISOR_POLL_INTERVAL_SEC = float(os.getenv("SUPERVISOR_POLL_INTERVAL_SEC", "30")) SUPERVISOR_SSE_HEARTBEAT_SEC = float(os.getenv("SUPERVISOR_SSE_HEARTBEAT_SEC", "25")) TickFn = Callable[[], Awaitable[dict[str, Any]]] class SupervisorStore: def __init__(self) -> None: self._lock = asyncio.Lock() self.version = 0 self.last_result: dict[str, Any] | None = None self.last_error: str | None = None self._subscribers: list[asyncio.Queue[str | None]] = [] self._task: asyncio.Task | None = None self._stop = asyncio.Event() self._refresh = asyncio.Event() self._tick_fn: TickFn | None = None async def start(self, tick_fn: TickFn) -> None: if self._task and not self._task.done(): return self._tick_fn = tick_fn self._stop.clear() self._task = asyncio.create_task(self._loop(), name="hub-supervisor-poll") async def stop(self) -> None: self._stop.set() self._refresh.set() if self._task: self._task.cancel() try: await self._task except asyncio.CancelledError: pass self._task = None self._broadcast(close=True) def request_refresh(self) -> None: self._refresh.set() def bump(self) -> None: self.version += 1 self._broadcast() def event_dict(self) -> dict[str, Any]: r = self.last_result or {} return { "supervisor_version": self.version, "ok": r.get("ok", True), "events": r.get("events", 0), "trading_day": r.get("trading_day"), "session_id": r.get("session_id"), "error": self.last_error, } async def _loop(self) -> None: assert self._tick_fn is not None while not self._stop.is_set(): await self._tick_once(self._tick_fn) if self._stop.is_set(): break self._refresh.clear() sleep_task = asyncio.create_task(asyncio.sleep(SUPERVISOR_POLL_INTERVAL_SEC)) refresh_task = asyncio.create_task(self._refresh.wait()) done, pending = await asyncio.wait( {sleep_task, refresh_task}, return_when=asyncio.FIRST_COMPLETED, ) for t in pending: t.cancel() async def _tick_once(self, tick_fn: TickFn) -> None: async with self._lock: try: result = await tick_fn() if not isinstance(result, dict): result = {"ok": False, "msg": "invalid_tick"} except Exception as e: result = {"ok": False, "msg": str(e)} self.last_error = str(e) else: self.last_error = None if result.get("ok") is not False else str( result.get("msg") or "tick_failed" ) self.last_result = result if int(result.get("events") or 0) > 0: self.version += 1 self._broadcast() def _broadcast(self, *, close: bool = False) -> None: dead: list[asyncio.Queue[str | None]] = [] payload = None if close else json.dumps(self.event_dict(), ensure_ascii=False) for q in self._subscribers: try: q.put_nowait(payload) except asyncio.QueueFull: try: q.get_nowait() except asyncio.QueueEmpty: pass try: q.put_nowait(payload) except asyncio.QueueFull: dead.append(q) except Exception: dead.append(q) for q in dead: if q in self._subscribers: self._subscribers.remove(q) async def iter_sse(self) -> AsyncIterator[str]: q: asyncio.Queue[str | None] = asyncio.Queue(maxsize=32) self._subscribers.append(q) try: yield _sse_frame(self.event_dict()) while True: try: raw = await asyncio.wait_for(q.get(), timeout=SUPERVISOR_SSE_HEARTBEAT_SEC) except asyncio.TimeoutError: yield ": heartbeat\n\n" continue if raw is None: break try: data = json.loads(raw) except Exception: data = self.event_dict() yield _sse_frame(data) finally: if q in self._subscribers: self._subscribers.remove(q) def _sse_frame(data: dict[str, Any]) -> str: body = json.dumps(data, ensure_ascii=False) return f"event: supervisor\ndata: {body}\n\n" supervisor_store = SupervisorStore()