Files
2026-05-22 22:15:46 +08:00

341 lines
12 KiB
Python

from __future__ import annotations

import json
from datetime import datetime, timedelta, timezone

from sqlalchemy import desc, select
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from .models import AlertRecord, Base, KeyMonitor, KeyMonitorHistory, KvStore, RuntimeLog
# Value seeded into the kv store under the "chart_bar" key by
# Storage._ensure_default_kv when no value exists yet (presumably a chart
# bar/timeframe identifier — confirm against the consumer of "chart_bar").
DEFAULT_CHART_BAR = "1D"
class Storage:
    """Async persistence layer for alerts, runtime logs, kv settings and key monitors.

    Every public coroutine opens its own short-lived session from
    ``session_factory``. Timestamps this class generates itself (kv
    ``updated_at``, the recent-alert cutoff) are naive datetimes holding UTC.
    """

    def __init__(self, database_url: str) -> None:
        # pool_pre_ping checks pooled connections before handing them out.
        self.engine = create_async_engine(database_url, pool_pre_ping=True)
        # expire_on_commit=False keeps loaded ORM attributes readable after commit.
        self.session_factory = async_sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)

    @staticmethod
    def _utcnow_naive() -> datetime:
        """Return the current UTC time as a naive datetime.

        Replacement for the deprecated ``datetime.utcnow()``. The tzinfo is
        stripped so results stay comparable with the naive UTC datetimes this
        class uses in its queries.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _alert_to_dict(row: AlertRecord) -> dict:
        """Serialize an AlertRecord row into a plain dict.

        ``trigger_types`` is split from its comma-joined storage form (empty or
        None becomes ``[]``). ``details_json`` is decoded into ``details``, and
        ``created_at`` is rendered as ISO-8601.
        """
        return {
            "id": row.id,
            "symbol": row.symbol,
            "chain": row.chain,
            "trigger_types": row.trigger_types.split(",") if row.trigger_types else [],
            "score": row.score,
            "details": json.loads(row.details_json),
            "created_at": row.created_at.isoformat(),
        }

    async def init_db(self) -> None:
        """Create any missing tables declared on ``Base.metadata``, then seed default kv entries."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        await self._ensure_default_kv()

    async def _ensure_default_kv(self) -> None:
        """Seed ``chart_bar`` with DEFAULT_CHART_BAR if it has no value yet."""
        current = await self.get_kv("chart_bar")
        if current is None:
            await self.set_kv("chart_bar", DEFAULT_CHART_BAR)

    async def get_kv(self, key: str) -> str | None:
        """Return the stored value for ``key``, or None if the key is absent."""
        async with self.session_factory() as session:
            row = await session.get(KvStore, key)
            return row.value if row else None

    async def set_kv(self, key: str, value: str) -> None:
        """Insert or overwrite ``key`` with ``value`` and refresh its ``updated_at``.

        NOTE(review): this uses the SQLite-dialect ``INSERT ... ON CONFLICT DO
        UPDATE`` construct. Confirm before pointing ``database_url`` at another
        backend.
        """
        # A single timestamp, so the insert path and the conflict-update path agree.
        now = self._utcnow_naive()
        stmt = (
            sqlite_insert(KvStore)
            .values(key=key, value=value, updated_at=now)
            .on_conflict_do_update(
                index_elements=["key"],
                set_={"value": value, "updated_at": now},
            )
        )
        async with self.session_factory() as session:
            await session.execute(stmt)
            await session.commit()

    async def has_recent_alert(
        self,
        symbol: str,
        *,
        chain: str,
        within_hours: float,
    ) -> bool:
        """Return True if an alert for the same symbol + chain exists within ``within_hours``.

        Used to deduplicate alerts for display and push notifications. The
        symbol is normalized (stripped, upper-cased) the same way ``add_alert``
        stores it; ``chain`` is compared as-is. A non-positive window always
        returns False.
        """
        if within_hours <= 0:
            return False
        sym = symbol.strip().upper()
        cutoff = self._utcnow_naive() - timedelta(hours=within_hours)
        async with self.session_factory() as session:
            stmt = (
                select(AlertRecord.id)
                .where(
                    AlertRecord.symbol == sym,
                    AlertRecord.chain == chain,
                    AlertRecord.created_at > cutoff,
                )
                .limit(1)
            )
            row = (await session.execute(stmt)).scalar_one_or_none()
            return row is not None

    async def add_alert(
        self,
        symbol: str,
        venue: str,
        trigger_types: list[str],
        score: float,
        details: dict,
    ) -> None:
        """Persist one alert.

        ``venue`` is stored in the ``chain`` column. ``trigger_types`` is stored
        comma-joined, and ``details`` is stored as JSON (non-ASCII kept
        verbatim).
        """
        async with self.session_factory() as session:
            session.add(
                AlertRecord(
                    symbol=symbol.strip().upper(),
                    chain=venue,
                    trigger_types=",".join(trigger_types),
                    score=score,
                    details_json=json.dumps(details, ensure_ascii=False),
                )
            )
            await session.commit()

    async def add_log(self, level: str, message: str) -> None:
        """Persist a runtime log line with ``level`` upper-cased."""
        async with self.session_factory() as session:
            session.add(RuntimeLog(level=level.upper(), message=message))
            await session.commit()

    async def get_recent_alerts(self, limit: int = 100) -> list[dict]:
        """Return up to ``limit`` alerts, newest ``created_at`` first."""
        async with self.session_factory() as session:
            stmt = select(AlertRecord).order_by(desc(AlertRecord.created_at)).limit(limit)
            rows = (await session.execute(stmt)).scalars().all()
            return [self._alert_to_dict(row) for row in rows]

    async def get_recent_logs(self, limit: int = 200) -> list[dict]:
        """Return up to ``limit`` runtime log lines, newest ``created_at`` first."""
        async with self.session_factory() as session:
            stmt = select(RuntimeLog).order_by(desc(RuntimeLog.created_at)).limit(limit)
            rows = (await session.execute(stmt)).scalars().all()
            return [
                {
                    "id": row.id,
                    "level": row.level,
                    "message": row.message,
                    "created_at": row.created_at.isoformat(),
                }
                for row in rows
            ]

    async def get_alerts_between(
        self,
        start_utc_naive: datetime,
        end_utc_naive: datetime,
        limit: int = 2000,
    ) -> list[dict]:
        """Return up to ``limit`` alerts with ``start <= created_at < end``, newest first.

        Both bounds are naive datetimes interpreted as UTC (half-open interval).
        """
        async with self.session_factory() as session:
            stmt = (
                select(AlertRecord)
                .where(AlertRecord.created_at >= start_utc_naive, AlertRecord.created_at < end_utc_naive)
                .order_by(desc(AlertRecord.created_at))
                .limit(limit)
            )
            rows = (await session.execute(stmt)).scalars().all()
            return [self._alert_to_dict(row) for row in rows]

    async def add_key_monitor(
        self,
        *,
        symbol: str,
        inst_id: str,
        monitor_type: str,
        direction: str,
        upper: float,
        lower: float,
        sl_tp_mode: str,
        manual_take_profit: float | None,
        stop_outside_pct: float,
        breakeven_enabled: int,
        note: str | None = None,
    ) -> int:
        """Create an active key monitor and return its new id.

        ``symbol`` and ``inst_id`` are stripped and upper-cased; ``direction``
        is stripped and lower-cased. Other values are stored as given.
        """
        async with self.session_factory() as session:
            row = KeyMonitor(
                symbol=symbol.strip().upper(),
                inst_id=inst_id.strip().upper(),
                monitor_type=monitor_type,
                direction=direction.strip().lower(),
                upper=upper,
                lower=lower,
                sl_tp_mode=sl_tp_mode,
                manual_take_profit=manual_take_profit,
                stop_outside_pct=stop_outside_pct,
                breakeven_enabled=breakeven_enabled,
                note=note,
            )
            session.add(row)
            await session.commit()
            await session.refresh(row)
            return int(row.id)

    async def list_key_monitors(self) -> list[dict]:
        """Return all active key monitors, newest ``created_at`` first."""
        async with self.session_factory() as session:
            stmt = select(KeyMonitor).order_by(desc(KeyMonitor.created_at))
            rows = (await session.execute(stmt)).scalars().all()
            return [_key_monitor_to_dict(r) for r in rows]

    async def get_key_monitor(self, kid: int) -> dict | None:
        """Return the active key monitor with id ``kid`` as a dict, or None if absent."""
        async with self.session_factory() as session:
            row = await session.get(KeyMonitor, kid)
            return _key_monitor_to_dict(row) if row else None

    async def delete_key_monitor(self, kid: int) -> bool:
        """Delete the active key monitor ``kid``; return False if it did not exist."""
        async with self.session_factory() as session:
            row = await session.get(KeyMonitor, kid)
            if not row:
                return False
            await session.delete(row)
            await session.commit()
            return True

    async def finalize_key_monitor(
        self,
        row: dict,
        *,
        close_reason: str,
        last_alert_message: str | None,
        confirm_close: float | None,
        planned_sl: float | None,
        planned_tp: float | None,
        planned_rr: float | None,
        executor_signal_id: str | None,
        executor_status: str | None,
        checks: dict | None,
    ) -> int:
        """Archive an active key monitor into history and delete the active row.

        The history insert and the delete share one commit.

        Args:
            row: the active monitor as a dict (presumably the output of
                ``get_key_monitor``). It must contain id, symbol, inst_id,
                monitor_type, direction, upper, lower, sl_tp_mode and
                stop_outside_pct; manual_take_profit is optional.
                ``breakeven_enabled`` and ``note`` are not copied into history.
            checks: stored as JSON; None is stored as ``{}``.

        Returns:
            The id of the new KeyMonitorHistory row. If the active row has
            already been deleted, the history row is still written.
        """
        monitor_id = int(row["id"])
        async with self.session_factory() as session:
            hist = KeyMonitorHistory(
                key_monitor_id=monitor_id,
                symbol=row["symbol"],
                inst_id=row["inst_id"],
                monitor_type=row["monitor_type"],
                direction=row["direction"],
                upper=float(row["upper"]),
                lower=float(row["lower"]),
                sl_tp_mode=row["sl_tp_mode"],
                manual_take_profit=row.get("manual_take_profit"),
                stop_outside_pct=float(row["stop_outside_pct"]),
                confirm_close=confirm_close,
                planned_sl=planned_sl,
                planned_tp=planned_tp,
                planned_rr=planned_rr,
                executor_signal_id=executor_signal_id,
                executor_status=executor_status,
                checks_json=json.dumps(checks or {}, ensure_ascii=False),
                last_alert_message=last_alert_message,
                close_reason=close_reason,
            )
            session.add(hist)
            active = await session.get(KeyMonitor, monitor_id)
            if active:
                await session.delete(active)
            await session.commit()
            await session.refresh(hist)
            return int(hist.id)

    async def list_key_monitor_history(
        self,
        *,
        limit: int = 500,
        since: datetime | None = None,
    ) -> list[dict]:
        """Return up to ``limit`` history rows, newest ``closed_at`` first.

        If ``since`` is given, only rows with ``closed_at >= since`` are returned.
        """
        async with self.session_factory() as session:
            # Build the statement in logical order: filter, then order, then limit.
            stmt = select(KeyMonitorHistory)
            if since is not None:
                stmt = stmt.where(KeyMonitorHistory.closed_at >= since)
            stmt = stmt.order_by(desc(KeyMonitorHistory.closed_at)).limit(limit)
            rows = (await session.execute(stmt)).scalars().all()
            return [_key_history_to_dict(r) for r in rows]

    async def delete_key_monitor_history(self, hid: int) -> bool:
        """Delete history row ``hid``; return False if it did not exist."""
        async with self.session_factory() as session:
            row = await session.get(KeyMonitorHistory, hid)
            if not row:
                return False
            await session.delete(row)
            await session.commit()
            return True

    async def export_key_monitor_history_rows(
        self,
        *,
        start_utc: datetime,
        end_utc: datetime,
    ) -> list[dict]:
        """Return all history rows with ``start_utc <= closed_at <= end_utc``, ordered by id ascending.

        Unlike ``get_alerts_between``, both bounds are inclusive.
        """
        async with self.session_factory() as session:
            stmt = (
                select(KeyMonitorHistory)
                .where(
                    KeyMonitorHistory.closed_at >= start_utc,
                    KeyMonitorHistory.closed_at <= end_utc,
                )
                .order_by(KeyMonitorHistory.id.asc())
            )
            rows = (await session.execute(stmt)).scalars().all()
            return [_key_history_to_dict(r) for r in rows]

    async def close(self) -> None:
        """Dispose of the engine and its connection pool."""
        await self.engine.dispose()
def _key_monitor_to_dict(row: KeyMonitor | None) -> dict:
    """Serialize an active KeyMonitor ORM row into a plain dict.

    Returns an empty dict when ``row`` is None. ``created_at`` is rendered as
    an ISO-8601 string; every other column is copied verbatim, in a fixed
    key order with ``created_at`` last.
    """
    if row is not None:
        verbatim_columns = (
            "id",
            "symbol",
            "inst_id",
            "monitor_type",
            "direction",
            "upper",
            "lower",
            "sl_tp_mode",
            "manual_take_profit",
            "stop_outside_pct",
            "breakeven_enabled",
            "note",
        )
        serialized = {column: getattr(row, column) for column in verbatim_columns}
        serialized["created_at"] = row.created_at.isoformat()
        return serialized
    return {}
def _key_history_to_dict(row: KeyMonitorHistory) -> dict:
    """Serialize a KeyMonitorHistory ORM row into a plain dict.

    ``checks_json`` is passed through as the stored JSON string (not decoded).
    ``closed_at`` is rendered as ISO-8601 and placed last; every other column
    is copied verbatim, in a fixed key order.
    """
    passthrough = (
        "id",
        "key_monitor_id",
        "symbol",
        "inst_id",
        "monitor_type",
        "direction",
        "upper",
        "lower",
        "sl_tp_mode",
        "manual_take_profit",
        "stop_outside_pct",
        "confirm_close",
        "planned_sl",
        "planned_tp",
        "planned_rr",
        "executor_signal_id",
        "executor_status",
        "checks_json",
        "last_alert_message",
        "close_reason",
    )
    record = {attr: getattr(row, attr) for attr in passthrough}
    record["closed_at"] = row.closed_at.isoformat()
    return record