from __future__ import annotations import json from datetime import datetime, timedelta from sqlalchemy import desc, select from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from .models import AlertRecord, Base, KeyMonitor, KeyMonitorHistory, KvStore, RuntimeLog DEFAULT_CHART_BAR = "1D" class Storage: def __init__(self, database_url: str) -> None: self.engine = create_async_engine(database_url, pool_pre_ping=True) self.session_factory = async_sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession) async def init_db(self) -> None: async with self.engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) await self._ensure_default_kv() async def _ensure_default_kv(self) -> None: current = await self.get_kv("chart_bar") if current is None: await self.set_kv("chart_bar", DEFAULT_CHART_BAR) async def get_kv(self, key: str) -> str | None: async with self.session_factory() as session: row = await session.get(KvStore, key) return row.value if row else None async def set_kv(self, key: str, value: str) -> None: async with self.session_factory() as session: await session.execute( sqlite_insert(KvStore) .values(key=key, value=value, updated_at=datetime.utcnow()) .on_conflict_do_update( index_elements=["key"], set_={"value": value, "updated_at": datetime.utcnow()}, ) ) await session.commit() async def has_recent_alert( self, symbol: str, *, chain: str, within_hours: float, ) -> bool: """同一 symbol + chain 在 within_hours 内是否已有告警(用于去重显示与推送)。""" if within_hours <= 0: return False sym = symbol.strip().upper() cutoff = datetime.utcnow() - timedelta(hours=within_hours) async with self.session_factory() as session: stmt = ( select(AlertRecord.id) .where( AlertRecord.symbol == sym, AlertRecord.chain == chain, AlertRecord.created_at > cutoff, ) .limit(1) ) row = (await session.execute(stmt)).scalar_one_or_none() return row is not None async def add_alert( self, symbol: str, venue: str, trigger_types: list[str], score: float, details: dict, ) -> None: async with self.session_factory() as session: session.add( AlertRecord( symbol=symbol.strip().upper(), chain=venue, trigger_types=",".join(trigger_types), score=score, details_json=json.dumps(details, ensure_ascii=False), ) ) await session.commit() async def add_log(self, level: str, message: str) -> None: async with self.session_factory() as session: session.add(RuntimeLog(level=level.upper(), message=message)) await session.commit() async def get_recent_alerts(self, limit: int = 100) -> list[dict]: async with self.session_factory() as session: stmt = select(AlertRecord).order_by(desc(AlertRecord.created_at)).limit(limit) rows = (await session.execute(stmt)).scalars().all() return [ { "id": row.id, "symbol": row.symbol, "chain": row.chain, "trigger_types": row.trigger_types.split(",") if row.trigger_types else [], "score": row.score, "details": json.loads(row.details_json), "created_at": row.created_at.isoformat(), } for row in rows ] async def get_recent_logs(self, limit: int = 200) -> list[dict]: async with self.session_factory() as session: stmt = select(RuntimeLog).order_by(desc(RuntimeLog.created_at)).limit(limit) rows = (await session.execute(stmt)).scalars().all() return [ { "id": row.id, "level": row.level, "message": row.message, "created_at": row.created_at.isoformat(), } for row in rows ] async def get_alerts_between( self, start_utc_naive: datetime, end_utc_naive: datetime, limit: int = 2000, ) -> list[dict]: async with self.session_factory() as session: stmt = ( select(AlertRecord) .where(AlertRecord.created_at >= start_utc_naive, AlertRecord.created_at < end_utc_naive) .order_by(desc(AlertRecord.created_at)) .limit(limit) ) rows = (await session.execute(stmt)).scalars().all() return [ { "id": row.id, "symbol": row.symbol, "chain": row.chain, "trigger_types": row.trigger_types.split(",") if row.trigger_types else [], "score": row.score, "details": json.loads(row.details_json), "created_at": row.created_at.isoformat(), } for row in rows ] async def add_key_monitor( self, *, symbol: str, inst_id: str, monitor_type: str, direction: str, upper: float, lower: float, sl_tp_mode: str, manual_take_profit: float | None, stop_outside_pct: float, breakeven_enabled: int, note: str | None = None, ) -> int: async with self.session_factory() as session: row = KeyMonitor( symbol=symbol.strip().upper(), inst_id=inst_id.strip().upper(), monitor_type=monitor_type, direction=direction.strip().lower(), upper=upper, lower=lower, sl_tp_mode=sl_tp_mode, manual_take_profit=manual_take_profit, stop_outside_pct=stop_outside_pct, breakeven_enabled=breakeven_enabled, note=note, ) session.add(row) await session.commit() await session.refresh(row) return int(row.id) async def list_key_monitors(self) -> list[dict]: async with self.session_factory() as session: stmt = select(KeyMonitor).order_by(desc(KeyMonitor.created_at)) rows = (await session.execute(stmt)).scalars().all() return [_key_monitor_to_dict(r) for r in rows] async def get_key_monitor(self, kid: int) -> dict | None: async with self.session_factory() as session: row = await session.get(KeyMonitor, kid) return _key_monitor_to_dict(row) if row else None async def delete_key_monitor(self, kid: int) -> bool: async with self.session_factory() as session: row = await session.get(KeyMonitor, kid) if not row: return False await session.delete(row) await session.commit() return True async def finalize_key_monitor( self, row: dict, *, close_reason: str, last_alert_message: str | None, confirm_close: float | None, planned_sl: float | None, planned_tp: float | None, planned_rr: float | None, executor_signal_id: str | None, executor_status: str | None, checks: dict | None, ) -> int: async with self.session_factory() as session: hist = KeyMonitorHistory( key_monitor_id=int(row["id"]), symbol=row["symbol"], inst_id=row["inst_id"], monitor_type=row["monitor_type"], direction=row["direction"], upper=float(row["upper"]), lower=float(row["lower"]), sl_tp_mode=row["sl_tp_mode"], manual_take_profit=row.get("manual_take_profit"), stop_outside_pct=float(row["stop_outside_pct"]), confirm_close=confirm_close, planned_sl=planned_sl, planned_tp=planned_tp, planned_rr=planned_rr, executor_signal_id=executor_signal_id, executor_status=executor_status, checks_json=json.dumps(checks or {}, ensure_ascii=False), last_alert_message=last_alert_message, close_reason=close_reason, ) session.add(hist) active = await session.get(KeyMonitor, int(row["id"])) if active: await session.delete(active) await session.commit() await session.refresh(hist) return int(hist.id) async def list_key_monitor_history( self, *, limit: int = 500, since: datetime | None = None, ) -> list[dict]: async with self.session_factory() as session: stmt = select(KeyMonitorHistory).order_by(desc(KeyMonitorHistory.closed_at)).limit(limit) if since is not None: stmt = stmt.where(KeyMonitorHistory.closed_at >= since) rows = (await session.execute(stmt)).scalars().all() return [_key_history_to_dict(r) for r in rows] async def delete_key_monitor_history(self, hid: int) -> bool: async with self.session_factory() as session: row = await session.get(KeyMonitorHistory, hid) if not row: return False await session.delete(row) await session.commit() return True async def export_key_monitor_history_rows( self, *, start_utc: datetime, end_utc: datetime, ) -> list[dict]: async with self.session_factory() as session: stmt = ( select(KeyMonitorHistory) .where( KeyMonitorHistory.closed_at >= start_utc, KeyMonitorHistory.closed_at <= end_utc, ) .order_by(KeyMonitorHistory.id.asc()) ) rows = (await session.execute(stmt)).scalars().all() return [_key_history_to_dict(r) for r in rows] async def close(self) -> None: await self.engine.dispose() def _key_monitor_to_dict(row: KeyMonitor | None) -> dict: if row is None: return {} return { "id": row.id, "symbol": row.symbol, "inst_id": row.inst_id, "monitor_type": row.monitor_type, "direction": row.direction, "upper": row.upper, "lower": row.lower, "sl_tp_mode": row.sl_tp_mode, "manual_take_profit": row.manual_take_profit, "stop_outside_pct": row.stop_outside_pct, "breakeven_enabled": row.breakeven_enabled, "note": row.note, "created_at": row.created_at.isoformat(), } def _key_history_to_dict(row: KeyMonitorHistory) -> dict: return { "id": row.id, "key_monitor_id": row.key_monitor_id, "symbol": row.symbol, "inst_id": row.inst_id, "monitor_type": row.monitor_type, "direction": row.direction, "upper": row.upper, "lower": row.lower, "sl_tp_mode": row.sl_tp_mode, "manual_take_profit": row.manual_take_profit, "stop_outside_pct": row.stop_outside_pct, "confirm_close": row.confirm_close, "planned_sl": row.planned_sl, "planned_tp": row.planned_tp, "planned_rr": row.planned_rr, "executor_signal_id": row.executor_signal_id, "executor_status": row.executor_status, "checks_json": row.checks_json, "last_alert_message": row.last_alert_message, "close_reason": row.close_reason, "closed_at": row.closed_at.isoformat(), }