341 lines
12 KiB
Python
341 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime, timedelta
|
|
|
|
from sqlalchemy import desc, select
|
|
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from .models import AlertRecord, Base, KeyMonitor, KeyMonitorHistory, KvStore, RuntimeLog
|
|
|
|
# Value stored under the "chart_bar" KV key by Storage._ensure_default_kv
# when that key has no entry yet.
DEFAULT_CHART_BAR = "1D"
|
|
|
|
|
|
class Storage:
    """Async persistence facade over the application's SQLAlchemy models.

    Each public method opens its own session from ``session_factory`` and
    commits (when it writes) before returning, so calls are independent
    units of work. Timestamps generated here are naive UTC values obtained
    from ``datetime.utcnow()``.

    NOTE(review): ``datetime.utcnow()`` is deprecated as of Python 3.12; it is
    kept so stored values stay naive UTC and comparable with existing rows.
    """

    def __init__(self, database_url: str) -> None:
        """Create the async engine and session factory for *database_url*."""
        self.engine = create_async_engine(database_url, pool_pre_ping=True)
        # expire_on_commit=False keeps loaded attributes readable after commit.
        self.session_factory = async_sessionmaker(
            self.engine,
            expire_on_commit=False,
            class_=AsyncSession,
        )

    async def init_db(self) -> None:
        """Create all tables declared on ``Base`` and seed default KV entries."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        await self._ensure_default_kv()

    async def _ensure_default_kv(self) -> None:
        """Store ``DEFAULT_CHART_BAR`` under ``"chart_bar"`` if the key is absent."""
        current = await self.get_kv("chart_bar")
        if current is None:
            await self.set_kv("chart_bar", DEFAULT_CHART_BAR)

    async def get_kv(self, key: str) -> str | None:
        """Return the value stored for *key*, or ``None`` when no row exists."""
        async with self.session_factory() as session:
            row = await session.get(KvStore, key)
            return row.value if row else None

    async def set_kv(self, key: str, value: str) -> None:
        """Insert or update the KV row for *key* (SQLite upsert on ``key``)."""
        # One timestamp for both branches so the insert and the conflict-update
        # paths record the same moment.
        now = datetime.utcnow()
        upsert = (
            sqlite_insert(KvStore)
            .values(key=key, value=value, updated_at=now)
            .on_conflict_do_update(
                index_elements=["key"],
                set_={"value": value, "updated_at": now},
            )
        )
        async with self.session_factory() as session:
            await session.execute(upsert)
            await session.commit()

    async def has_recent_alert(
        self,
        symbol: str,
        *,
        chain: str,
        within_hours: float,
    ) -> bool:
        """Tell whether an alert for the same symbol + chain exists within *within_hours*.

        Used to de-duplicate alert display and push notifications.

        Args:
            symbol: Symbol to look up; stripped and upper-cased before comparison.
            chain: Chain value, compared exactly as given.
            within_hours: Look-back window in hours; a value ``<= 0`` disables
                the check and returns ``False`` without querying.

        Returns:
            ``True`` if at least one matching alert was created strictly after
            ``utcnow() - within_hours``.
        """
        if within_hours <= 0:
            return False
        sym = symbol.strip().upper()
        cutoff = datetime.utcnow() - timedelta(hours=within_hours)
        async with self.session_factory() as session:
            stmt = (
                select(AlertRecord.id)
                .where(
                    AlertRecord.symbol == sym,
                    AlertRecord.chain == chain,
                    AlertRecord.created_at > cutoff,
                )
                .limit(1)
            )
            row = (await session.execute(stmt)).scalar_one_or_none()
            return row is not None

    async def add_alert(
        self,
        symbol: str,
        venue: str,
        trigger_types: list[str],
        score: float,
        details: dict,
    ) -> None:
        """Persist one alert record.

        Args:
            symbol: Stored stripped and upper-cased.
            venue: Stored in the record's ``chain`` column.
            trigger_types: Stored as a single comma-joined string.
            score: Stored as given.
            details: Serialized to JSON (non-ASCII characters kept verbatim).
        """
        async with self.session_factory() as session:
            session.add(
                AlertRecord(
                    symbol=symbol.strip().upper(),
                    chain=venue,
                    trigger_types=",".join(trigger_types),
                    score=score,
                    details_json=json.dumps(details, ensure_ascii=False),
                )
            )
            await session.commit()

    async def add_log(self, level: str, message: str) -> None:
        """Persist one runtime log line; *level* is stored upper-cased."""
        async with self.session_factory() as session:
            session.add(RuntimeLog(level=level.upper(), message=message))
            await session.commit()

    @staticmethod
    def _alert_to_dict(row: AlertRecord) -> dict:
        """Convert an alert row into the plain dict returned by the alert readers.

        ``trigger_types`` is split back into a list (empty string -> ``[]``) and
        ``details_json`` is decoded; an empty/NULL payload yields ``{}`` so a
        single blank row cannot break a whole listing.
        """
        return {
            "id": row.id,
            "symbol": row.symbol,
            "chain": row.chain,
            "trigger_types": row.trigger_types.split(",") if row.trigger_types else [],
            "score": row.score,
            "details": json.loads(row.details_json) if row.details_json else {},
            "created_at": row.created_at.isoformat(),
        }

    async def get_recent_alerts(self, limit: int = 100) -> list[dict]:
        """Return up to *limit* alerts as dicts, newest ``created_at`` first."""
        async with self.session_factory() as session:
            stmt = select(AlertRecord).order_by(desc(AlertRecord.created_at)).limit(limit)
            rows = (await session.execute(stmt)).scalars().all()
            return [self._alert_to_dict(row) for row in rows]

    async def get_recent_logs(self, limit: int = 200) -> list[dict]:
        """Return up to *limit* runtime logs as dicts, newest ``created_at`` first."""
        async with self.session_factory() as session:
            stmt = select(RuntimeLog).order_by(desc(RuntimeLog.created_at)).limit(limit)
            rows = (await session.execute(stmt)).scalars().all()
            return [
                {
                    "id": row.id,
                    "level": row.level,
                    "message": row.message,
                    "created_at": row.created_at.isoformat(),
                }
                for row in rows
            ]

    async def get_alerts_between(
        self,
        start_utc_naive: datetime,
        end_utc_naive: datetime,
        limit: int = 2000,
    ) -> list[dict]:
        """Return alerts with ``start <= created_at < end``, newest first.

        Args:
            start_utc_naive: Inclusive lower bound.
            end_utc_naive: Exclusive upper bound.
            limit: Maximum number of rows returned.
        """
        async with self.session_factory() as session:
            stmt = (
                select(AlertRecord)
                .where(
                    AlertRecord.created_at >= start_utc_naive,
                    AlertRecord.created_at < end_utc_naive,
                )
                .order_by(desc(AlertRecord.created_at))
                .limit(limit)
            )
            rows = (await session.execute(stmt)).scalars().all()
            return [self._alert_to_dict(row) for row in rows]

    async def add_key_monitor(
        self,
        *,
        symbol: str,
        inst_id: str,
        monitor_type: str,
        direction: str,
        upper: float,
        lower: float,
        sl_tp_mode: str,
        manual_take_profit: float | None,
        stop_outside_pct: float,
        breakeven_enabled: int,
        note: str | None = None,
    ) -> int:
        """Insert a key monitor and return its new primary key.

        ``symbol`` and ``inst_id`` are stored stripped and upper-cased,
        ``direction`` stripped and lower-cased; all other values are stored
        as given.
        """
        async with self.session_factory() as session:
            row = KeyMonitor(
                symbol=symbol.strip().upper(),
                inst_id=inst_id.strip().upper(),
                monitor_type=monitor_type,
                direction=direction.strip().lower(),
                upper=upper,
                lower=lower,
                sl_tp_mode=sl_tp_mode,
                manual_take_profit=manual_take_profit,
                stop_outside_pct=stop_outside_pct,
                breakeven_enabled=breakeven_enabled,
                note=note,
            )
            session.add(row)
            await session.commit()
            # Reload so database-generated values (the id) are populated.
            await session.refresh(row)
            return int(row.id)

    async def list_key_monitors(self) -> list[dict]:
        """Return every key monitor as a dict, newest ``created_at`` first."""
        async with self.session_factory() as session:
            stmt = select(KeyMonitor).order_by(desc(KeyMonitor.created_at))
            rows = (await session.execute(stmt)).scalars().all()
            return [_key_monitor_to_dict(r) for r in rows]

    async def get_key_monitor(self, kid: int) -> dict | None:
        """Return the key monitor with primary key *kid* as a dict, or ``None``."""
        async with self.session_factory() as session:
            row = await session.get(KeyMonitor, kid)
            return _key_monitor_to_dict(row) if row else None

    async def delete_key_monitor(self, kid: int) -> bool:
        """Delete key monitor *kid*; return ``False`` if it does not exist."""
        async with self.session_factory() as session:
            row = await session.get(KeyMonitor, kid)
            if not row:
                return False
            await session.delete(row)
            await session.commit()
            return True

    async def finalize_key_monitor(
        self,
        row: dict,
        *,
        close_reason: str,
        last_alert_message: str | None,
        confirm_close: float | None,
        planned_sl: float | None,
        planned_tp: float | None,
        planned_rr: float | None,
        executor_signal_id: str | None,
        executor_status: str | None,
        checks: dict | None,
    ) -> int:
        """Archive a key monitor into history and remove the active row.

        The history row and the deletion of the active monitor (if it still
        exists) are committed together.

        Args:
            row: Key-monitor dict; must contain ``id``, ``symbol``, ``inst_id``,
                ``monitor_type``, ``direction``, ``upper``, ``lower``,
                ``sl_tp_mode`` and ``stop_outside_pct``. ``manual_take_profit``
                is optional.
            close_reason: Stored on the history row.
            last_alert_message: Stored on the history row.
            confirm_close, planned_sl, planned_tp, planned_rr: Stored as given.
            executor_signal_id, executor_status: Stored as given.
            checks: Serialized to JSON; ``None`` is stored as ``{}``.

        Returns:
            Primary key of the created history row.
        """
        monitor_id = int(row["id"])
        async with self.session_factory() as session:
            hist = KeyMonitorHistory(
                key_monitor_id=monitor_id,
                symbol=row["symbol"],
                inst_id=row["inst_id"],
                monitor_type=row["monitor_type"],
                direction=row["direction"],
                upper=float(row["upper"]),
                lower=float(row["lower"]),
                sl_tp_mode=row["sl_tp_mode"],
                manual_take_profit=row.get("manual_take_profit"),
                stop_outside_pct=float(row["stop_outside_pct"]),
                confirm_close=confirm_close,
                planned_sl=planned_sl,
                planned_tp=planned_tp,
                planned_rr=planned_rr,
                executor_signal_id=executor_signal_id,
                executor_status=executor_status,
                checks_json=json.dumps(checks or {}, ensure_ascii=False),
                last_alert_message=last_alert_message,
                close_reason=close_reason,
            )
            session.add(hist)
            active = await session.get(KeyMonitor, monitor_id)
            if active:
                await session.delete(active)
            await session.commit()
            await session.refresh(hist)
            return int(hist.id)

    async def list_key_monitor_history(
        self,
        *,
        limit: int = 500,
        since: datetime | None = None,
    ) -> list[dict]:
        """Return up to *limit* history rows, newest ``closed_at`` first.

        Args:
            limit: Maximum number of rows returned.
            since: When given, only rows with ``closed_at >= since`` are returned.
        """
        # Build the statement in SQL clause order: filter, then order, then limit.
        stmt = select(KeyMonitorHistory)
        if since is not None:
            stmt = stmt.where(KeyMonitorHistory.closed_at >= since)
        stmt = stmt.order_by(desc(KeyMonitorHistory.closed_at)).limit(limit)
        async with self.session_factory() as session:
            rows = (await session.execute(stmt)).scalars().all()
            return [_key_history_to_dict(r) for r in rows]

    async def delete_key_monitor_history(self, hid: int) -> bool:
        """Delete history row *hid*; return ``False`` if it does not exist."""
        async with self.session_factory() as session:
            row = await session.get(KeyMonitorHistory, hid)
            if not row:
                return False
            await session.delete(row)
            await session.commit()
            return True

    async def export_key_monitor_history_rows(
        self,
        *,
        start_utc: datetime,
        end_utc: datetime,
    ) -> list[dict]:
        """Return history rows with ``start_utc <= closed_at <= end_utc``.

        Both bounds are inclusive; rows are ordered by ascending ``id``.
        """
        async with self.session_factory() as session:
            stmt = (
                select(KeyMonitorHistory)
                .where(
                    KeyMonitorHistory.closed_at >= start_utc,
                    KeyMonitorHistory.closed_at <= end_utc,
                )
                .order_by(KeyMonitorHistory.id.asc())
            )
            rows = (await session.execute(stmt)).scalars().all()
            return [_key_history_to_dict(r) for r in rows]

    async def close(self) -> None:
        """Dispose of the engine."""
        await self.engine.dispose()
|
|
|
|
|
|
def _key_monitor_to_dict(row: KeyMonitor | None) -> dict:
    """Flatten a key-monitor row into a plain dict.

    Returns an empty dict when *row* is ``None``. All columns are copied
    unchanged except ``created_at``, which is rendered via ``isoformat()``.
    """
    if row is None:
        return {}
    # Columns copied verbatim, in the order they appear in the result.
    plain_fields = (
        "id",
        "symbol",
        "inst_id",
        "monitor_type",
        "direction",
        "upper",
        "lower",
        "sl_tp_mode",
        "manual_take_profit",
        "stop_outside_pct",
        "breakeven_enabled",
        "note",
    )
    payload = {name: getattr(row, name) for name in plain_fields}
    payload["created_at"] = row.created_at.isoformat()
    return payload
|
|
|
|
|
|
def _key_history_to_dict(row: KeyMonitorHistory) -> dict:
    """Flatten a key-monitor history row into a plain dict.

    All columns are copied unchanged (``checks_json`` stays a raw string)
    except ``closed_at``, which is rendered via ``isoformat()``.
    """
    # Columns copied verbatim, in the order they appear in the result.
    passthrough = (
        "id",
        "key_monitor_id",
        "symbol",
        "inst_id",
        "monitor_type",
        "direction",
        "upper",
        "lower",
        "sl_tp_mode",
        "manual_take_profit",
        "stop_outside_pct",
        "confirm_close",
        "planned_sl",
        "planned_tp",
        "planned_rr",
        "executor_signal_id",
        "executor_status",
        "checks_json",
        "last_alert_message",
        "close_reason",
    )
    snapshot = {field: getattr(row, field) for field in passthrough}
    snapshot["closed_at"] = row.closed_at.isoformat()
    return snapshot
|