from __future__ import annotations import json from datetime import datetime, timedelta from sqlalchemy import desc, select from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from .models import AlertRecord, Base, KvStore, RuntimeLog DEFAULT_CHART_BAR = "1D" class Storage: def __init__(self, database_url: str) -> None: self.engine = create_async_engine(database_url, pool_pre_ping=True) self.session_factory = async_sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession) async def init_db(self) -> None: async with self.engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) await self._ensure_default_kv() async def _ensure_default_kv(self) -> None: current = await self.get_kv("chart_bar") if current is None: await self.set_kv("chart_bar", DEFAULT_CHART_BAR) async def get_kv(self, key: str) -> str | None: async with self.session_factory() as session: row = await session.get(KvStore, key) return row.value if row else None async def set_kv(self, key: str, value: str) -> None: async with self.session_factory() as session: await session.execute( sqlite_insert(KvStore) .values(key=key, value=value, updated_at=datetime.utcnow()) .on_conflict_do_update( index_elements=["key"], set_={"value": value, "updated_at": datetime.utcnow()}, ) ) await session.commit() async def has_recent_alert( self, symbol: str, *, chain: str, within_hours: float, ) -> bool: """同一 symbol + chain 在 within_hours 内是否已有告警(用于去重显示与推送)。""" if within_hours <= 0: return False sym = symbol.strip().upper() cutoff = datetime.utcnow() - timedelta(hours=within_hours) async with self.session_factory() as session: stmt = ( select(AlertRecord.id) .where( AlertRecord.symbol == sym, AlertRecord.chain == chain, AlertRecord.created_at > cutoff, ) .limit(1) ) row = (await session.execute(stmt)).scalar_one_or_none() return row is not None async def add_alert( self, symbol: str, venue: str, trigger_types: list[str], score: float, details: dict, ) -> None: async with self.session_factory() as session: session.add( AlertRecord( symbol=symbol.strip().upper(), chain=venue, trigger_types=",".join(trigger_types), score=score, details_json=json.dumps(details, ensure_ascii=False), ) ) await session.commit() async def add_log(self, level: str, message: str) -> None: async with self.session_factory() as session: session.add(RuntimeLog(level=level.upper(), message=message)) await session.commit() async def get_recent_alerts(self, limit: int = 100) -> list[dict]: async with self.session_factory() as session: stmt = select(AlertRecord).order_by(desc(AlertRecord.created_at)).limit(limit) rows = (await session.execute(stmt)).scalars().all() return [ { "id": row.id, "symbol": row.symbol, "chain": row.chain, "trigger_types": row.trigger_types.split(",") if row.trigger_types else [], "score": row.score, "details": json.loads(row.details_json), "created_at": row.created_at.isoformat(), } for row in rows ] async def get_recent_logs(self, limit: int = 200) -> list[dict]: async with self.session_factory() as session: stmt = select(RuntimeLog).order_by(desc(RuntimeLog.created_at)).limit(limit) rows = (await session.execute(stmt)).scalars().all() return [ { "id": row.id, "level": row.level, "message": row.message, "created_at": row.created_at.isoformat(), } for row in rows ] async def get_alerts_between( self, start_utc_naive: datetime, end_utc_naive: datetime, limit: int = 2000, ) -> list[dict]: async with self.session_factory() as session: stmt = ( select(AlertRecord) .where(AlertRecord.created_at >= start_utc_naive, AlertRecord.created_at < end_utc_naive) .order_by(desc(AlertRecord.created_at)) .limit(limit) ) rows = (await session.execute(stmt)).scalars().all() return [ { "id": row.id, "symbol": row.symbol, "chain": row.chain, "trigger_types": row.trigger_types.split(",") if row.trigger_types else [], "score": row.score, "details": json.loads(row.details_json), "created_at": row.created_at.isoformat(), } for row in rows ] async def close(self) -> None: await self.engine.dispose()