# Source listing metadata: 157 lines, 5.5 KiB, Python.
from __future__ import annotations

import json
from datetime import datetime, timedelta, timezone

from sqlalchemy import desc, select
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from .models import AlertRecord, Base, KvStore, RuntimeLog


# Value seeded into the "chart_bar" key-value entry when it is missing at start-up.
DEFAULT_CHART_BAR: str = "1D"


class Storage:
    """Async SQLAlchemy-backed persistence for settings, alerts and runtime logs.

    Timestamps written and compared here are naive ``datetime`` values in UTC
    (see the ``*_utc_naive`` parameters of :meth:`get_alerts_between`).

    NOTE(review): :meth:`set_kv` builds its upsert with the SQLite dialect's
    ``insert``, so ``database_url`` presumably must point at an SQLite database;
    confirm before using another backend.
    """

    def __init__(self, database_url: str) -> None:
        """Create the async engine and session factory.

        Args:
            database_url: SQLAlchemy async database URL.
        """
        # pool_pre_ping tests pooled connections before use so stale ones are replaced.
        self.engine = create_async_engine(database_url, pool_pre_ping=True)
        # expire_on_commit=False: ORM objects keep their loaded attributes after
        # commit, so reading them does not trigger an implicit (awaitable) refresh.
        self.session_factory = async_sessionmaker(
            self.engine, expire_on_commit=False, class_=AsyncSession
        )

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current time in UTC as a naive ``datetime``.

        Replacement for ``datetime.utcnow()``, which is deprecated since Python
        3.12. The tzinfo is stripped so the value stays comparable with the naive
        UTC datetimes this class stores and queries.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _normalize_symbol(symbol: str) -> str:
        """Return the canonical ``symbol`` column form: whitespace-trimmed, upper-cased."""
        return symbol.strip().upper()

    @staticmethod
    def _alert_to_dict(row: AlertRecord) -> dict:
        """Serialize an ``AlertRecord`` row into a JSON-friendly dict.

        ``trigger_types`` is stored comma-joined (see :meth:`add_alert`) and is
        split back into a list. ``details_json`` is decoded, falling back to ``{}``
        when the column is NULL or empty so such a row does not make the whole
        listing raise.
        """
        return {
            "id": row.id,
            "symbol": row.symbol,
            "chain": row.chain,
            "trigger_types": row.trigger_types.split(",") if row.trigger_types else [],
            "score": row.score,
            "details": json.loads(row.details_json) if row.details_json else {},
            "created_at": row.created_at.isoformat(),
        }

    # ------------------------------------------------------------------ schema / kv

    async def init_db(self) -> None:
        """Create any missing tables declared on ``Base.metadata`` and seed defaults."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        await self._ensure_default_kv()

    async def _ensure_default_kv(self) -> None:
        """Seed ``chart_bar`` with ``DEFAULT_CHART_BAR`` if it has no value yet."""
        if await self.get_kv("chart_bar") is None:
            await self.set_kv("chart_bar", DEFAULT_CHART_BAR)

    async def get_kv(self, key: str) -> str | None:
        """Return the stored value for *key*, or ``None`` if no row exists."""
        async with self.session_factory() as session:
            row = await session.get(KvStore, key)
            return row.value if row else None

    async def set_kv(self, key: str, value: str) -> None:
        """Insert or overwrite *key* with *value*, stamping ``updated_at``.

        Implemented as a single ``INSERT ... ON CONFLICT (key) DO UPDATE``.
        """
        # One timestamp for both the insert and the conflict-update path.
        now = self._utcnow()
        stmt = (
            sqlite_insert(KvStore)
            .values(key=key, value=value, updated_at=now)
            .on_conflict_do_update(
                index_elements=["key"],
                set_={"value": value, "updated_at": now},
            )
        )
        async with self.session_factory() as session:
            await session.execute(stmt)
            await session.commit()

    # ------------------------------------------------------------------ alerts / logs

    async def has_recent_alert(
        self,
        symbol: str,
        *,
        chain: str,
        within_hours: float,
    ) -> bool:
        """Return whether an alert for the same symbol + chain exists within the window.

        Used to de-duplicate alerts for display and push notifications.

        Args:
            symbol: Alert symbol; normalized the same way as in :meth:`add_alert`.
            chain: Chain/venue value, compared as-is.
            within_hours: Look-back window in hours. A non-positive value disables
                the check, and the method returns ``False``.
        """
        if within_hours <= 0:
            return False

        cutoff = self._utcnow() - timedelta(hours=within_hours)
        stmt = (
            select(AlertRecord.id)
            .where(
                AlertRecord.symbol == self._normalize_symbol(symbol),
                AlertRecord.chain == chain,
                AlertRecord.created_at > cutoff,
            )
            .limit(1)
        )
        async with self.session_factory() as session:
            found = (await session.execute(stmt)).scalar_one_or_none()
        return found is not None

    async def add_alert(
        self,
        symbol: str,
        venue: str,
        trigger_types: list[str],
        score: float,
        details: dict,
    ) -> None:
        """Persist a new alert.

        Args:
            symbol: Alert symbol; stored trimmed and upper-cased.
            venue: Stored in the ``chain`` column.
            trigger_types: Stored comma-joined. Entries must not contain commas,
                or they will be split apart when read back.
            score: Alert score.
            details: JSON-serializable payload stored in ``details_json``.
        """
        record = AlertRecord(
            symbol=self._normalize_symbol(symbol),
            chain=venue,
            trigger_types=",".join(trigger_types),
            score=score,
            details_json=json.dumps(details, ensure_ascii=False),
        )
        async with self.session_factory() as session:
            session.add(record)
            await session.commit()

    async def add_log(self, level: str, message: str) -> None:
        """Persist a runtime log entry; *level* is stored upper-cased."""
        async with self.session_factory() as session:
            session.add(RuntimeLog(level=level.upper(), message=message))
            await session.commit()

    async def get_recent_alerts(self, limit: int = 100) -> list[dict]:
        """Return up to *limit* alerts, newest ``created_at`` first, as dicts."""
        stmt = select(AlertRecord).order_by(desc(AlertRecord.created_at)).limit(limit)
        async with self.session_factory() as session:
            rows = (await session.execute(stmt)).scalars().all()
            return [self._alert_to_dict(row) for row in rows]

    async def get_recent_logs(self, limit: int = 200) -> list[dict]:
        """Return up to *limit* runtime log entries, newest ``created_at`` first."""
        stmt = select(RuntimeLog).order_by(desc(RuntimeLog.created_at)).limit(limit)
        async with self.session_factory() as session:
            rows = (await session.execute(stmt)).scalars().all()
            return [
                {
                    "id": row.id,
                    "level": row.level,
                    "message": row.message,
                    "created_at": row.created_at.isoformat(),
                }
                for row in rows
            ]

    async def get_alerts_between(
        self,
        start_utc_naive: datetime,
        end_utc_naive: datetime,
        limit: int = 2000,
    ) -> list[dict]:
        """Return alerts with ``start <= created_at < end``, newest first.

        Args:
            start_utc_naive: Inclusive lower bound, naive datetime in UTC.
            end_utc_naive: Exclusive upper bound, naive datetime in UTC.
            limit: Maximum number of rows returned.
        """
        stmt = (
            select(AlertRecord)
            .where(
                AlertRecord.created_at >= start_utc_naive,
                AlertRecord.created_at < end_utc_naive,
            )
            .order_by(desc(AlertRecord.created_at))
            .limit(limit)
        )
        async with self.session_factory() as session:
            rows = (await session.execute(stmt)).scalars().all()
            return [self._alert_to_dict(row) for row in rows]

    async def close(self) -> None:
        """Dispose of the engine and its connection pool."""
        await self.engine.dispose()