首次上传
This commit is contained in:
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from .models import AlertRecord, Base, KvStore, RuntimeLog
|
||||
|
||||
DEFAULT_CHART_BAR = "1D"
|
||||
|
||||
|
||||
class Storage:
|
||||
def __init__(self, database_url: str) -> None:
|
||||
self.engine = create_async_engine(database_url, pool_pre_ping=True)
|
||||
self.session_factory = async_sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
|
||||
|
||||
async def init_db(self) -> None:
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await self._ensure_default_kv()
|
||||
|
||||
async def _ensure_default_kv(self) -> None:
|
||||
current = await self.get_kv("chart_bar")
|
||||
if current is None:
|
||||
await self.set_kv("chart_bar", DEFAULT_CHART_BAR)
|
||||
|
||||
async def get_kv(self, key: str) -> str | None:
|
||||
async with self.session_factory() as session:
|
||||
row = await session.get(KvStore, key)
|
||||
return row.value if row else None
|
||||
|
||||
async def set_kv(self, key: str, value: str) -> None:
|
||||
async with self.session_factory() as session:
|
||||
await session.execute(
|
||||
sqlite_insert(KvStore)
|
||||
.values(key=key, value=value, updated_at=datetime.utcnow())
|
||||
.on_conflict_do_update(
|
||||
index_elements=["key"],
|
||||
set_={"value": value, "updated_at": datetime.utcnow()},
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def has_recent_alert(
|
||||
self,
|
||||
symbol: str,
|
||||
*,
|
||||
chain: str,
|
||||
within_hours: float,
|
||||
) -> bool:
|
||||
"""同一 symbol + chain 在 within_hours 内是否已有告警(用于去重显示与推送)。"""
|
||||
if within_hours <= 0:
|
||||
return False
|
||||
sym = symbol.strip().upper()
|
||||
cutoff = datetime.utcnow() - timedelta(hours=within_hours)
|
||||
async with self.session_factory() as session:
|
||||
stmt = (
|
||||
select(AlertRecord.id)
|
||||
.where(
|
||||
AlertRecord.symbol == sym,
|
||||
AlertRecord.chain == chain,
|
||||
AlertRecord.created_at > cutoff,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
async def add_alert(
|
||||
self,
|
||||
symbol: str,
|
||||
venue: str,
|
||||
trigger_types: list[str],
|
||||
score: float,
|
||||
details: dict,
|
||||
) -> None:
|
||||
async with self.session_factory() as session:
|
||||
session.add(
|
||||
AlertRecord(
|
||||
symbol=symbol.strip().upper(),
|
||||
chain=venue,
|
||||
trigger_types=",".join(trigger_types),
|
||||
score=score,
|
||||
details_json=json.dumps(details, ensure_ascii=False),
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def add_log(self, level: str, message: str) -> None:
|
||||
async with self.session_factory() as session:
|
||||
session.add(RuntimeLog(level=level.upper(), message=message))
|
||||
await session.commit()
|
||||
|
||||
async def get_recent_alerts(self, limit: int = 100) -> list[dict]:
|
||||
async with self.session_factory() as session:
|
||||
stmt = select(AlertRecord).order_by(desc(AlertRecord.created_at)).limit(limit)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": row.id,
|
||||
"symbol": row.symbol,
|
||||
"chain": row.chain,
|
||||
"trigger_types": row.trigger_types.split(",") if row.trigger_types else [],
|
||||
"score": row.score,
|
||||
"details": json.loads(row.details_json),
|
||||
"created_at": row.created_at.isoformat(),
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_recent_logs(self, limit: int = 200) -> list[dict]:
|
||||
async with self.session_factory() as session:
|
||||
stmt = select(RuntimeLog).order_by(desc(RuntimeLog.created_at)).limit(limit)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": row.id,
|
||||
"level": row.level,
|
||||
"message": row.message,
|
||||
"created_at": row.created_at.isoformat(),
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_alerts_between(
|
||||
self,
|
||||
start_utc_naive: datetime,
|
||||
end_utc_naive: datetime,
|
||||
limit: int = 2000,
|
||||
) -> list[dict]:
|
||||
async with self.session_factory() as session:
|
||||
stmt = (
|
||||
select(AlertRecord)
|
||||
.where(AlertRecord.created_at >= start_utc_naive, AlertRecord.created_at < end_utc_naive)
|
||||
.order_by(desc(AlertRecord.created_at))
|
||||
.limit(limit)
|
||||
)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": row.id,
|
||||
"symbol": row.symbol,
|
||||
"chain": row.chain,
|
||||
"trigger_types": row.trigger_types.split(",") if row.trigger_types else [],
|
||||
"score": row.score,
|
||||
"details": json.loads(row.details_json),
|
||||
"created_at": row.created_at.isoformat(),
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.engine.dispose()
|
||||
Reference in New Issue
Block a user