Files
crypto_monitor/manual_trading_hub/hub_ai/store.py
T
dekun 582ada7e60 feat(hub): add data dashboard and AI chat with session history
Add /dashboard with daily PnL overview and loss alerts. Extend AI coach chat with history sidebar, delete/switch sessions, message copy, and trading vs general bot modes.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 10:42:33 +08:00

285 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""中控 AIJSON 持久化(与 hub_settings.json 同目录)。"""
from __future__ import annotations
import json
import os
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Optional
from hub_ai.config import CHAT_SESSION_RETENTION_DAYS, SUMMARY_RETENTION_DAYS
# Hub directory: two levels above this module's resolved file location.
HUB_DIR = Path(__file__).resolve().parent.parent
# JSON file that stores the generated AI summaries.
SUMMARIES_PATH = HUB_DIR.joinpath("hub_ai_summaries.json")
# JSON file that stores the chat sessions and the active session id.
CHAT_PATH = HUB_DIR.joinpath("hub_ai_chat.json")
def _now_str() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    # A non-empty format spec on a datetime is forwarded to strftime().
    return f"{datetime.now():%Y-%m-%d %H:%M:%S}"
def _atomic_write(path: Path, data: dict) -> None:
    """Serialize *data* as JSON and move it onto *path* via a temp file.

    The parent directory is created when missing. The JSON text (UTF-8,
    non-ASCII kept as-is, 2-space indent) is first written to a sibling
    file named ``<path>.tmp`` and then moved over *path* with
    ``os.replace``.

    Args:
        path: Final destination of the JSON document.
        data: JSON-serializable mapping to persist.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    staging = path.with_suffix(path.suffix + ".tmp")
    serialized = json.dumps(data, ensure_ascii=False, indent=2)
    with open(staging, "w", encoding="utf-8") as handle:
        handle.write(serialized)
    os.replace(staging, path)
def _load_json(path: Path, default: dict) -> dict:
    """Load a JSON object from *path*, or return a shallow copy of *default*.

    The copy of *default* is returned when *path* is not a regular file,
    when reading or parsing fails for any reason, or when the parsed
    document is not a JSON object.

    Args:
        path: File to read (decoded as UTF-8).
        default: Fallback mapping; never returned by identity.

    Returns:
        The parsed dict, or ``dict(default)``.
    """
    if path.is_file():
        try:
            parsed = json.loads(path.read_text(encoding="utf-8"))
        except Exception:
            # Best-effort read: an unreadable or corrupt file means "no data".
            parsed = None
        if isinstance(parsed, dict):
            return parsed
    return dict(default)
def _prune_summaries(items: list, *, keep_days: int) -> list:
    """Drop summaries older than the retention window and cap the result.

    An entry is kept when its ``trading_day`` string compares greater than
    or equal to the cutoff day (now minus ``max(1, keep_days)`` days,
    formatted ``YYYY-MM-DD``). Entries without a ``trading_day`` compare as
    the empty string and are therefore dropped. At most the last 500 kept
    entries are returned.

    Args:
        items: Summary dicts in stored order.
        keep_days: Retention window in days; values below 1 are treated as 1.

    Returns:
        A new list with the retained entries.
    """
    window = timedelta(days=max(1, keep_days))
    oldest_day = (datetime.now() - window).strftime("%Y-%m-%d")
    recent = []
    for entry in items:
        day = str(entry.get("trading_day") or "")
        # Plain string comparison: both sides are zero-padded ISO dates.
        if day >= oldest_day:
            recent.append(entry)
    return recent[-500:]
def _prune_chat_sessions(sessions: list, *, keep_days: int) -> list:
    """Drop chat sessions older than the retention window and cap the result.

    A session's age is taken from ``updated_at``, falling back to
    ``created_at``. Sessions whose timestamp cannot be parsed as
    ``YYYY-MM-DD HH:MM:SS`` are kept. At most the last 50 retained sessions
    (in input order) are returned.

    Args:
        sessions: Session dicts in stored order.
        keep_days: Retention window in days; values below 1 are treated as 1.

    Returns:
        A new list with the retained sessions.
    """
    oldest_allowed = datetime.now() - timedelta(days=max(1, keep_days))

    def _is_retained(session: dict) -> bool:
        stamp = str(session.get("updated_at") or session.get("created_at") or "")
        try:
            touched = datetime.strptime(stamp[:19], "%Y-%m-%d %H:%M:%S")
        except ValueError:
            # Unparseable timestamp: keep the session instead of discarding it.
            return True
        return touched >= oldest_allowed

    return [session for session in sessions if _is_retained(session)][-50:]
def load_summaries_store() -> dict:
    """Load the summaries store from ``SUMMARIES_PATH``.

    Returns:
        The stored dict, or an empty version-1 store when ``_load_json``
        falls back to the default.
    """
    empty_store = {"version": 1, "summaries": []}
    return _load_json(SUMMARIES_PATH, empty_store)
def save_summaries_store(data: dict) -> None:
    """Prune and persist the summaries contained in *data*.

    Only the ``summaries`` list of *data* is used; it is pruned with
    ``SUMMARY_RETENTION_DAYS`` and written to ``SUMMARIES_PATH`` as a
    version-1 document.

    Args:
        data: Store dict whose ``summaries`` entry should be saved.
    """
    candidates = list(data.get("summaries") or [])
    retained = _prune_summaries(candidates, keep_days=SUMMARY_RETENTION_DAYS)
    payload = {"version": 1, "summaries": retained}
    _atomic_write(SUMMARIES_PATH, payload)
def append_summary(
    *,
    trading_day: str,
    content_md: str,
    model: str,
    context_hash: str,
    stats_snapshot: dict,
) -> dict:
    """Append a new summary record to the store and persist it.

    Args:
        trading_day: Trading day the summary belongs to.
        content_md: Summary body stored under ``content_md``.
        model: Stored under ``model``.
        context_hash: Stored under ``context_hash``.
        stats_snapshot: Stored under ``stats_snapshot``.

    Returns:
        The newly created record, including its generated ``id`` and
        ``generated_at`` timestamp.
    """
    store = load_summaries_store()
    record = dict(
        id=uuid.uuid4().hex,
        trading_day=trading_day,
        generated_at=_now_str(),
        model=model,
        context_hash=context_hash,
        content_md=content_md,
        stats_snapshot=stats_snapshot,
    )
    bucket = store.setdefault("summaries", [])
    bucket.append(record)
    save_summaries_store(store)
    return record
def list_summaries(*, trading_day: Optional[str] = None, limit: int = 30) -> list[dict]:
    """Return stored summaries, newest ``generated_at`` first.

    Args:
        trading_day: When truthy, only summaries whose ``trading_day``
            equals this string are returned.
        limit: Maximum number of rows; clamped to the range 1..100.

    Returns:
        A list of summary dicts.
    """
    rows = list(load_summaries_store().get("summaries") or [])
    if trading_day:
        rows = [row for row in rows if str(row.get("trading_day")) == trading_day]
    newest_first = sorted(
        rows,
        key=lambda row: str(row.get("generated_at") or ""),
        reverse=True,
    )
    cap = max(1, min(limit, 100))
    return newest_first[:cap]
def get_latest_summary(trading_day: str) -> Optional[dict]:
    """Return the most recently generated summary for *trading_day*, if any.

    Returns:
        The first row reported by ``list_summaries`` for that day, or
        ``None`` when it reports no rows.
    """
    newest = list_summaries(trading_day=trading_day, limit=1)
    if not newest:
        return None
    return newest[0]
def load_chat_store() -> dict:
    """Load the chat store from ``CHAT_PATH``.

    Returns:
        The stored dict with ``version`` and ``sessions`` guaranteed to be
        present (filled with ``1`` and ``[]`` when missing).
    """
    store = _load_json(
        CHAT_PATH,
        {"version": 1, "sessions": [], "active_session_id": None},
    )
    # A file written by hand or by an older version may lack these keys.
    for key, fallback in (("version", 1), ("sessions", [])):
        store.setdefault(key, fallback)
    return store
def save_chat_store(data: dict) -> None:
    """Prune and persist the chat sessions contained in *data*.

    Sessions are pruned with ``CHAT_SESSION_RETENTION_DAYS``. If the
    requested active session id no longer matches any retained session,
    the last retained session becomes active (or ``None`` when none are
    left). The result is written to ``CHAT_PATH`` as a version-1 document.

    Args:
        data: Store dict providing ``sessions`` and ``active_session_id``.
    """
    kept = _prune_chat_sessions(
        list(data.get("sessions") or []),
        keep_days=CHAT_SESSION_RETENTION_DAYS,
    )
    active_id = data.get("active_session_id")
    if active_id:
        known_ids = {str(entry.get("id")) for entry in kept}
        if str(active_id) not in known_ids:
            # The active session was pruned or never existed: fall back.
            active_id = kept[-1]["id"] if kept else None
    payload = {"version": 1, "sessions": kept, "active_session_id": active_id}
    _atomic_write(CHAT_PATH, payload)
def get_active_session() -> Optional[dict]:
    """Return the session marked as active in the chat store.

    Returns:
        The session dict whose ``id`` matches ``active_session_id``
        (compared as strings), or ``None`` when no active id is set or no
        session matches.
    """
    store = load_chat_store()
    sid = store.get("active_session_id")
    if sid is None:
        # Without this guard str(None) would match a session whose id is
        # missing or null, reporting it as active by accident.
        return None
    wanted = str(sid)
    for s in store.get("sessions") or []:
        if str(s.get("id")) == wanted:
            return s
    return None
# Supported chat bot modes; "trading" is the default.
CHAT_BOT_TRADING = "trading"
CHAT_BOT_GENERAL = "general"
CHAT_BOT_MODES = frozenset({CHAT_BOT_TRADING, CHAT_BOT_GENERAL})


def _normalize_bot_mode(raw: Any) -> str:
    """Map an arbitrary value to one of ``CHAT_BOT_MODES``.

    Strings are stripped and lower-cased before the lookup. Anything that
    is not a string, or does not name a known mode, yields
    ``CHAT_BOT_TRADING``.

    Args:
        raw: Candidate mode, e.g. a value read back from the JSON store.

    Returns:
        A member of ``CHAT_BOT_MODES``.
    """
    if not isinstance(raw, str):
        # Values from a hand-edited or corrupt store may be numbers, lists,
        # etc.; calling .strip() on them would raise AttributeError.
        return CHAT_BOT_TRADING
    mode = raw.strip().lower()
    return mode if mode in CHAT_BOT_MODES else CHAT_BOT_TRADING
def create_new_session(
    *,
    trading_day: str,
    title: str = "新对话",
    bot_mode: str = CHAT_BOT_TRADING,
) -> dict:
    """Create a chat session, make it the active one, and persist the store.

    Args:
        trading_day: Stored under ``trading_day`` on the new session.
        title: Initial session title.
        bot_mode: Requested bot mode; normalized via ``_normalize_bot_mode``.

    Returns:
        The newly created session dict (with an empty ``messages`` list).
    """
    store = load_chat_store()
    # Read the clock once so created_at and updated_at of a brand-new
    # session are always identical, even across a second boundary.
    stamp = _now_str()
    session = {
        "id": uuid.uuid4().hex,
        "trading_day": trading_day,
        "title": title,
        "bot_mode": _normalize_bot_mode(bot_mode),
        "created_at": stamp,
        "updated_at": stamp,
        "messages": [],
    }
    store.setdefault("sessions", []).append(session)
    store["active_session_id"] = session["id"]
    save_chat_store(store)
    return session
def ensure_active_session(*, trading_day: str) -> dict:
    """Return the active chat session, creating a new one when there is none.

    Args:
        trading_day: Passed to ``create_new_session`` if a session must be
            created.

    Returns:
        The existing active session, or the freshly created one.
    """
    return get_active_session() or create_new_session(trading_day=trading_day)
def append_chat_message(
    session_id: str,
    role: str,
    content: str,
    *,
    attachments: Optional[list] = None,
) -> dict:
    """Append a message to a chat session, mark it active, and persist.

    The first user message also becomes the session title (first 24
    characters, newlines replaced by spaces) while the title is still
    unset or the default placeholder.

    Args:
        session_id: Id of the session to append to (compared as a string).
        role: Message role stored under ``role``.
        content: Message text; surrounding whitespace is stripped.
        attachments: Optional list copied onto the message when non-empty.

    Returns:
        The updated session dict.

    Raises:
        KeyError: ``"session_not_found"`` when no session has that id.
    """
    store = load_chat_store()
    wanted = str(session_id)
    target = next(
        (s for s in store.get("sessions") or [] if str(s.get("id")) == wanted),
        None,
    )
    if not target:
        raise KeyError("session_not_found")

    # One clock read keeps the message time and the session's updated_at
    # in agreement; one strip() serves both the message and the title.
    stamp = _now_str()
    text = content.strip()

    msg = {"role": role, "content": text, "at": stamp}
    if attachments:
        msg["attachments"] = list(attachments)
    target.setdefault("messages", []).append(msg)
    target["updated_at"] = stamp

    if role == "user" and target.get("title") in (None, "", "新对话"):
        title = text.replace("\n", " ")[:24]
        if title:
            target["title"] = title

    store["active_session_id"] = target["id"]
    save_chat_store(store)
    return target
def _session_list_item(s: dict, *, active_id: Optional[str]) -> dict:
    """Build the compact list entry for one stored chat session.

    The preview is taken from the most recent user message; if that yields
    nothing, the last message of any role is used instead. Previews are
    single-line and limited to 48 characters.

    Args:
        s: Stored session dict.
        active_id: Id of the currently active session, if any.

    Returns:
        A dict with ``id``, ``title``, ``bot_mode``, ``trading_day``,
        ``created_at``, ``updated_at``, ``message_count``, ``preview`` and
        a boolean ``is_active``.
    """
    msgs = s.get("messages") or []

    def _one_line(message: dict) -> str:
        return str(message.get("content") or "").replace("\n", " ")[:48]

    latest_user = next((m for m in reversed(msgs) if m.get("role") == "user"), None)
    preview = _one_line(latest_user) if latest_user is not None else ""
    if not preview and msgs:
        preview = _one_line(msgs[-1])

    sid = str(s.get("id") or "")
    return {
        "id": sid,
        "title": s.get("title") or "新对话",
        "bot_mode": _normalize_bot_mode(s.get("bot_mode")),
        "trading_day": s.get("trading_day"),
        "created_at": s.get("created_at"),
        "updated_at": s.get("updated_at"),
        "message_count": len(msgs),
        "preview": preview,
        # bool() matters: `sid and ...` would yield "" (not False) for a
        # session without an id, leaking a string into a boolean field.
        "is_active": bool(sid) and sid == str(active_id or ""),
    }
def list_chat_sessions(*, limit: int = 50) -> list[dict]:
    """Return list entries for stored chat sessions, newest update first.

    Args:
        limit: Maximum number of entries; clamped to the range 1..100.

    Returns:
        One ``_session_list_item`` dict per session.
    """
    store = load_chat_store()
    current_id = store.get("active_session_id")
    entries = list(store.get("sessions") or [])
    for entry in entries:
        # Sessions stored without a bot mode are treated as trading sessions.
        if "bot_mode" not in entry:
            entry["bot_mode"] = CHAT_BOT_TRADING
    ordered = sorted(
        entries,
        key=lambda entry: str(entry.get("updated_at") or ""),
        reverse=True,
    )
    cap = max(1, min(limit, 100))
    return [_session_list_item(entry, active_id=current_id) for entry in ordered[:cap]]
def set_active_session(session_id: str) -> dict:
    """Mark the session with *session_id* as active and persist the store.

    Args:
        session_id: Id of the session to activate (compared as a string).

    Returns:
        The activated session dict, with ``bot_mode`` filled in when it
        was missing.

    Raises:
        KeyError: ``"session_not_found"`` when no session has that id.
    """
    store = load_chat_store()
    match = next(
        (
            entry
            for entry in store.get("sessions") or []
            if str(entry.get("id")) == str(session_id)
        ),
        None,
    )
    if not match:
        raise KeyError("session_not_found")
    match.setdefault("bot_mode", CHAT_BOT_TRADING)
    store["active_session_id"] = match["id"]
    save_chat_store(store)
    return match
def delete_chat_session(session_id: str) -> tuple[bool, Optional[str]]:
    """Remove the session with *session_id* from the store and persist it.

    When the removed session was the active one, the first remaining
    session becomes active (or ``None`` when no sessions remain).

    Args:
        session_id: Id of the session to delete (compared as a string).

    Returns:
        ``(False, None)`` when no session has that id; otherwise
        ``(True, active_id)`` where ``active_id`` is the active session id
        handed to ``save_chat_store``.
    """
    store = load_chat_store()
    before = list(store.get("sessions") or [])
    doomed = str(session_id)
    survivors = [entry for entry in before if str(entry.get("id")) != doomed]
    if len(survivors) == len(before):
        # Nothing matched, so leave the store untouched.
        return False, None

    next_active = store.get("active_session_id")
    if str(next_active) == doomed:
        next_active = survivors[0]["id"] if survivors else None

    store["sessions"] = survivors
    store["active_session_id"] = next_active
    save_chat_store(store)
    return True, next_active
def summary_excerpt_for_chat(trading_day: str, max_chars: int = 600) -> str:
    """Return the latest summary text for *trading_day*, truncated for chat.

    Args:
        trading_day: Day whose latest summary should be excerpted.
        max_chars: Maximum length of the returned string.

    Returns:
        ``""`` when there is no summary; the full stripped text when it
        fits; otherwise a prefix ending in ``"..."`` that is at most
        *max_chars* long. When *max_chars* is too small to hold the
        ellipsis, a bare prefix of at most *max_chars* characters is
        returned.
    """
    latest = get_latest_summary(trading_day)
    if not latest:
        return ""
    text = str(latest.get("content_md") or "").strip()
    if len(text) <= max_chars:
        return text
    ellipsis = "..."
    if max_chars < len(ellipsis):
        # A negative slice bound would count from the end of the text and
        # produce a result longer than the requested limit.
        return text[: max(0, max_chars)]
    return text[: max_chars - len(ellipsis)].rstrip() + ellipsis