feat(hub): add AI coach page with daily summary and chat
Aggregate four-account trades via hub_ai module and /api/hub/trades/today; store sessions in JSON; default OpenAI config matches instances. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""中控 AI 模块:今日总结 + 交易员聊天(与实例 ai_review 分离)。"""
|
||||
@@ -0,0 +1,88 @@
|
||||
"""中控 AI:单会话聊天(直到用户点击新开)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from hub_ai.client import generate_text, model_label
|
||||
from hub_ai.config import CHAT_MAX_HISTORY_TURNS, CHAT_TEMPERATURE
|
||||
from hub_ai.context import build_daily_context, format_chat_context_brief
|
||||
from hub_ai.prompts import CHAT_SYSTEM, build_chat_user_prompt
|
||||
from hub_ai.store import (
|
||||
append_chat_message,
|
||||
create_new_session,
|
||||
ensure_active_session,
|
||||
get_active_session,
|
||||
load_chat_store,
|
||||
summary_excerpt_for_chat,
|
||||
)
|
||||
|
||||
|
||||
def _history_lines(messages: list[dict], max_turns: int = CHAT_MAX_HISTORY_TURNS) -> str:
|
||||
rows = [m for m in (messages or []) if m.get("role") in ("user", "assistant")]
|
||||
rows = rows[-max_turns * 2 :]
|
||||
lines = []
|
||||
for m in rows:
|
||||
role = "用户" if m.get("role") == "user" else "搭档"
|
||||
lines.append(f"{role}:{m.get('content') or ''}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def get_chat_state() -> dict[str, Any]:
|
||||
store = load_chat_store()
|
||||
session = get_active_session()
|
||||
return {
|
||||
"active_session_id": store.get("active_session_id"),
|
||||
"session": session,
|
||||
"model": model_label(),
|
||||
}
|
||||
|
||||
|
||||
def start_new_chat(*, trading_day: str) -> dict:
|
||||
session = create_new_session(trading_day=trading_day)
|
||||
return {"ok": True, "session": session, "model": model_label()}
|
||||
|
||||
|
||||
def send_chat_message(
|
||||
exchanges: list[dict],
|
||||
message: str,
|
||||
*,
|
||||
trading_day: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
text = (message or "").strip()
|
||||
if not text:
|
||||
return {"ok": False, "msg": "消息不能为空"}
|
||||
|
||||
ctx = build_daily_context(exchanges, trading_day=trading_day)
|
||||
day = ctx["trading_day"]
|
||||
session = ensure_active_session(trading_day=day)
|
||||
sid = session["id"]
|
||||
history = _history_lines(session.get("messages") or [])
|
||||
|
||||
append_chat_message(sid, "user", text)
|
||||
|
||||
brief_ctx = format_chat_context_brief(ctx)
|
||||
excerpt = summary_excerpt_for_chat(day)
|
||||
|
||||
user_prompt = build_chat_user_prompt(
|
||||
context_text=brief_ctx,
|
||||
trading_day=day,
|
||||
summary_excerpt=excerpt,
|
||||
history_lines=history,
|
||||
user_message=text,
|
||||
)
|
||||
reply = generate_text(
|
||||
system=CHAT_SYSTEM,
|
||||
user=user_prompt,
|
||||
temperature=CHAT_TEMPERATURE,
|
||||
)
|
||||
if reply.startswith("AI 调用失败"):
|
||||
return {"ok": False, "msg": reply, "session_id": sid}
|
||||
|
||||
session = append_chat_message(sid, "assistant", reply)
|
||||
return {
|
||||
"ok": True,
|
||||
"trading_day": day,
|
||||
"session": session,
|
||||
"reply": reply,
|
||||
"model": model_label(),
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
"""中控 AI 模型调用(共用 ai_client 配置,逻辑独立)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
_REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(_REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_REPO_ROOT))
|
||||
|
||||
from ai_client import ai_generate, ai_provider_label # noqa: E402
|
||||
|
||||
|
||||
def model_label() -> str:
|
||||
return ai_provider_label()
|
||||
|
||||
|
||||
def generate_text(*, system: str, user: str, temperature: float) -> str:
|
||||
prompt = f"{system.strip()}\n\n---\n\n{user.strip()}"
|
||||
return ai_generate(prompt, temperature=temperature)
|
||||
@@ -0,0 +1,33 @@
|
||||
"""中控 AI 配置(读 hub .env,与实例同名 AI 变量)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
HUB_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
SUMMARY_TEMPERATURE = 0.15
|
||||
CHAT_TEMPERATURE = 0.5
|
||||
CHAT_MAX_HISTORY_TURNS = 20
|
||||
SUMMARY_RETENTION_DAYS = 90
|
||||
CHAT_SESSION_RETENTION_DAYS = 60
|
||||
|
||||
|
||||
def trading_day_reset_hour() -> int:
|
||||
try:
|
||||
return int(os.getenv("TRADING_DAY_RESET_HOUR", "8") or "8")
|
||||
except ValueError:
|
||||
return 8
|
||||
|
||||
|
||||
def hub_flask_timeout() -> float:
|
||||
try:
|
||||
return float(os.getenv("HUB_FLASK_TIMEOUT", "10") or "10")
|
||||
except ValueError:
|
||||
return 10.0
|
||||
|
||||
|
||||
def hub_agent_timeout() -> float:
|
||||
try:
|
||||
return float(os.getenv("HUB_AGENT_TIMEOUT", "8") or "8")
|
||||
except ValueError:
|
||||
return 8.0
|
||||
@@ -0,0 +1,278 @@
|
||||
"""中控 AI:四户数据聚合为结构化上下文。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from hub_ai.config import hub_agent_timeout, hub_flask_timeout, trading_day_reset_hour
|
||||
from hub_trades_lib import current_trading_day, summarize_trades
|
||||
|
||||
|
||||
def _hub_token() -> str:
|
||||
return (os.getenv("HUB_BRIDGE_TOKEN") or os.getenv("CONTROL_TOKEN") or "").strip()
|
||||
|
||||
|
||||
def _hub_headers() -> dict[str, str]:
|
||||
tok = _hub_token()
|
||||
return {"X-Hub-Token": tok} if tok else {}
|
||||
|
||||
|
||||
def _agent_headers() -> dict[str, str]:
|
||||
tok = (os.getenv("CONTROL_TOKEN") or os.getenv("HUB_BRIDGE_TOKEN") or "").strip()
|
||||
return {"X-Control-Token": tok} if tok else {}
|
||||
|
||||
|
||||
def _safe_float(v: Any) -> Optional[float]:
|
||||
try:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
return float(v)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _position_float_pnl(pos: dict) -> float:
|
||||
for key in ("unrealized_pnl", "unrealizedPnl", "upnl"):
|
||||
v = _safe_float(pos.get(key))
|
||||
if v is not None:
|
||||
return v
|
||||
return 0.0
|
||||
|
||||
|
||||
def _collect_open_issues(
|
||||
*,
|
||||
monitored: bool,
|
||||
agent_ok: bool,
|
||||
flask_ok: bool,
|
||||
positions: list,
|
||||
hub_mon: Optional[dict],
|
||||
day_pnl: float,
|
||||
) -> list[str]:
|
||||
issues: list[str] = []
|
||||
if not monitored:
|
||||
return issues
|
||||
if not agent_ok:
|
||||
issues.append("Agent 连接异常")
|
||||
if not flask_ok:
|
||||
issues.append("Flask 监控连接异常")
|
||||
if day_pnl < -0.01:
|
||||
issues.append(f"当日平仓亏损 {day_pnl:.2f}U")
|
||||
float_pnl = sum(_position_float_pnl(p) for p in positions if isinstance(p, dict))
|
||||
if float_pnl < -0.5:
|
||||
issues.append(f"当前浮亏 {float_pnl:.2f}U")
|
||||
if isinstance(hub_mon, dict) and hub_mon.get("ok") is not False:
|
||||
orders = hub_mon.get("orders") or []
|
||||
trends = hub_mon.get("trends") or []
|
||||
if positions and not orders and not trends:
|
||||
issues.append("交易所有持仓但无本地 active 监控/趋势计划")
|
||||
return issues
|
||||
|
||||
|
||||
def _fetch_account_bundle(client: httpx.Client, ex: dict, trading_day: str) -> dict[str, Any]:
|
||||
name = ex.get("name") or ex.get("key") or ex.get("id")
|
||||
key = ex.get("key") or ""
|
||||
enabled = bool(ex.get("enabled"))
|
||||
env_disabled = bool(ex.get("env_disabled"))
|
||||
monitored = enabled and not env_disabled
|
||||
|
||||
base: dict[str, Any] = {
|
||||
"id": ex.get("id"),
|
||||
"key": key,
|
||||
"name": name,
|
||||
"enabled": enabled,
|
||||
"env_disabled": env_disabled,
|
||||
"status": "未监控" if not monitored else "已监控",
|
||||
"trades": [],
|
||||
"trade_stats": summarize_trades([]),
|
||||
"positions": [],
|
||||
"float_pnl_u": 0.0,
|
||||
"balance_usdt": None,
|
||||
"issues": [],
|
||||
"agent_ok": False,
|
||||
"flask_ok": False,
|
||||
"hub_monitor": None,
|
||||
"active_orders": 0,
|
||||
"active_trends": 0,
|
||||
}
|
||||
if not monitored:
|
||||
base["issues"] = []
|
||||
return base
|
||||
|
||||
agent_url = (ex.get("agent_url") or "").rstrip("/")
|
||||
flask_url = (ex.get("flask_url") or "").rstrip("/")
|
||||
agent_body = None
|
||||
if agent_url:
|
||||
try:
|
||||
r = client.get(
|
||||
f"{agent_url}/status",
|
||||
headers=_agent_headers(),
|
||||
timeout=hub_agent_timeout(),
|
||||
)
|
||||
if r.status_code == 200:
|
||||
agent_body = r.json()
|
||||
base["agent_ok"] = True
|
||||
except Exception as exc:
|
||||
base["issues"].append(f"Agent: {exc}")
|
||||
|
||||
if isinstance(agent_body, dict):
|
||||
base["balance_usdt"] = _safe_float(agent_body.get("balance_usdt"))
|
||||
positions = agent_body.get("positions") or []
|
||||
if isinstance(positions, list):
|
||||
base["positions"] = positions
|
||||
base["float_pnl_u"] = round(
|
||||
sum(_position_float_pnl(p) for p in positions if isinstance(p, dict)), 4
|
||||
)
|
||||
|
||||
hub_mon = None
|
||||
if flask_url:
|
||||
try:
|
||||
r = client.get(
|
||||
f"{flask_url}/api/hub/trades/today",
|
||||
headers=_hub_headers(),
|
||||
params={"trading_day": trading_day},
|
||||
timeout=hub_flask_timeout(),
|
||||
)
|
||||
if r.status_code == 200:
|
||||
trades_body = r.json()
|
||||
if isinstance(trades_body, dict) and trades_body.get("ok"):
|
||||
base["trades"] = trades_body.get("trades") or []
|
||||
base["trade_stats"] = trades_body.get("stats") or summarize_trades(base["trades"])
|
||||
base["flask_ok"] = True
|
||||
except Exception as exc:
|
||||
base["issues"].append(f"成交接口: {exc}")
|
||||
|
||||
try:
|
||||
r = client.get(
|
||||
f"{flask_url}/api/hub/monitor",
|
||||
headers=_hub_headers(),
|
||||
timeout=hub_flask_timeout(),
|
||||
)
|
||||
if r.status_code == 200:
|
||||
hub_mon = r.json()
|
||||
if isinstance(hub_mon, dict) and hub_mon.get("ok") is not False:
|
||||
base["hub_monitor"] = hub_mon
|
||||
base["flask_ok"] = True
|
||||
base["active_orders"] = len(hub_mon.get("orders") or [])
|
||||
base["active_trends"] = len(hub_mon.get("trends") or [])
|
||||
except Exception as exc:
|
||||
if "成交接口" not in str(base["issues"]):
|
||||
base["issues"].append(f"监控接口: {exc}")
|
||||
|
||||
if monitored and not base["agent_ok"] and not base["flask_ok"]:
|
||||
base["status"] = "连接异常"
|
||||
elif base["issues"]:
|
||||
base["status"] = "已监控·需关注"
|
||||
|
||||
day_pnl = float((base.get("trade_stats") or {}).get("total_pnl_u") or 0)
|
||||
base["issues"].extend(
|
||||
_collect_open_issues(
|
||||
monitored=monitored,
|
||||
agent_ok=base["agent_ok"],
|
||||
flask_ok=base["flask_ok"],
|
||||
positions=base["positions"],
|
||||
hub_mon=hub_mon if isinstance(hub_mon, dict) else None,
|
||||
day_pnl=day_pnl,
|
||||
)
|
||||
)
|
||||
base["issues"] = list(dict.fromkeys(base["issues"]))
|
||||
return base
|
||||
|
||||
|
||||
def build_daily_context(
|
||||
exchanges: list[dict],
|
||||
*,
|
||||
trading_day: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
day = (trading_day or "").strip()[:10] or current_trading_day(
|
||||
reset_hour=trading_day_reset_hour()
|
||||
)
|
||||
accounts: list[dict] = []
|
||||
with httpx.Client() as client:
|
||||
for ex in exchanges or []:
|
||||
accounts.append(_fetch_account_bundle(client, ex, day))
|
||||
|
||||
total_closed_pnl = 0.0
|
||||
total_closed = total_win = total_loss = 0
|
||||
total_float = 0.0
|
||||
for ac in accounts:
|
||||
if ac.get("status") == "未监控":
|
||||
continue
|
||||
st = ac.get("trade_stats") or {}
|
||||
total_closed_pnl += float(st.get("total_pnl_u") or 0)
|
||||
total_closed += int(st.get("closed_count") or 0)
|
||||
total_win += int(st.get("win_count") or 0)
|
||||
total_loss += int(st.get("loss_count") or 0)
|
||||
total_float += float(ac.get("float_pnl_u") or 0)
|
||||
|
||||
totals = {
|
||||
"trading_day": day,
|
||||
"total_pnl_u": round(total_closed_pnl, 4),
|
||||
"closed_count": total_closed,
|
||||
"win_count": total_win,
|
||||
"loss_count": total_loss,
|
||||
"float_pnl_u": round(total_float, 4),
|
||||
}
|
||||
payload = {"trading_day": day, "totals": totals, "accounts": accounts}
|
||||
text = format_context_text(payload)
|
||||
digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]
|
||||
return {"trading_day": day, "totals": totals, "accounts": accounts, "text": text, "context_hash": digest}
|
||||
|
||||
|
||||
def format_context_text(payload: dict) -> str:
|
||||
lines = []
|
||||
totals = payload.get("totals") or {}
|
||||
lines.append(
|
||||
f"【合计】交易日 {totals.get('trading_day')} | "
|
||||
f"平仓盈亏 {totals.get('total_pnl_u')}U | "
|
||||
f"笔数 {totals.get('closed_count')}(胜{totals.get('win_count')}/负{totals.get('loss_count')})| "
|
||||
f"监控户浮盈亏合计 {totals.get('float_pnl_u')}U"
|
||||
)
|
||||
lines.append("")
|
||||
for ac in payload.get("accounts") or []:
|
||||
st = ac.get("trade_stats") or {}
|
||||
lines.append(f"--- 账户:{ac.get('name')} ({ac.get('key')}) ---")
|
||||
lines.append(f"状态:{ac.get('status')}")
|
||||
if ac.get("status") == "未监控":
|
||||
lines.append("")
|
||||
continue
|
||||
lines.append(
|
||||
f"当日平仓:{st.get('closed_count')} 笔,盈亏 {st.get('total_pnl_u')}U "
|
||||
f"(胜{st.get('win_count')}/负{st.get('loss_count')})"
|
||||
)
|
||||
lines.append(f"合约可用余额:{ac.get('balance_usdt') if ac.get('balance_usdt') is not None else '未知'} USDT")
|
||||
lines.append(f"当前持仓浮盈亏:{ac.get('float_pnl_u')}U | 下单监控 {ac.get('active_orders')} | 趋势计划 {ac.get('active_trends')}")
|
||||
positions = ac.get("positions") or []
|
||||
if positions:
|
||||
lines.append("持仓:")
|
||||
for p in positions[:8]:
|
||||
if not isinstance(p, dict):
|
||||
continue
|
||||
sym = p.get("symbol") or "?"
|
||||
side = p.get("side") or "?"
|
||||
contracts = p.get("contracts") or p.get("size") or "?"
|
||||
upnl = _position_float_pnl(p)
|
||||
lines.append(f" - {sym} {side} 张数{contracts} 浮盈亏{upnl:.4f}U")
|
||||
trades = ac.get("trades") or []
|
||||
if trades:
|
||||
lines.append("当日平仓明细:")
|
||||
for t in trades[:15]:
|
||||
lines.append(
|
||||
f" - {t.get('symbol')} {t.get('direction')} {t.get('result')} "
|
||||
f"{t.get('pnl_amount')}U @ {t.get('closed_at') or '?'}"
|
||||
)
|
||||
issues = ac.get("issues") or []
|
||||
if issues:
|
||||
lines.append("关注点:" + ";".join(issues))
|
||||
lines.append("")
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
|
||||
def format_chat_context_brief(payload: dict, max_chars: int = 2500) -> str:
|
||||
text = format_context_text(payload)
|
||||
if len(text) <= max_chars:
|
||||
return text
|
||||
return text[: max_chars - 3].rstrip() + "..."
|
||||
@@ -0,0 +1,74 @@
|
||||
"""中控 AI 提示词(与实例 ai_review 分离)。"""
|
||||
|
||||
SUMMARY_SYSTEM = """
|
||||
你是多账户加密货币合约交易的台账助手。只根据用户提供的结构化数据输出中文 Markdown,语气克制、偏冷、客观,像值班记录。
|
||||
|
||||
硬性规则:
|
||||
- 只能陈述数据中明确出现的数字与事实;禁止编造成交、止损、扛单、行情预测。
|
||||
- 未监控的账户必须标注「未监控」,不得臆测其盈亏。
|
||||
- 连接失败或数据缺失的账户如实写明,不要猜测。
|
||||
- 不要用安慰、说教、建议口吻(那些属于聊天功能)。
|
||||
- 禁止夸张词(致命、崩溃、灾难等)。
|
||||
|
||||
输出格式(Markdown,标题必须一致):
|
||||
**今日交易总结({trading_day})**
|
||||
|
||||
**1. 总览**
|
||||
- **合计盈亏(U)**:…
|
||||
- **平仓笔数**:…(胜 / 负 / 平)
|
||||
- **当前持仓浮盈亏(U)**:…(仅汇总已监控且有数据的账户)
|
||||
|
||||
**2. 分户明细**
|
||||
每个账户一行:账户名 | 状态(已监控/未监控/连接异常) | 当日平仓盈亏 | 笔数 | 当前浮盈亏 | 备注
|
||||
|
||||
**3. 需关注**
|
||||
仅有依据时列出(如:某户当日亏损最大、浮亏偏大、Flask/Agent 异常、有持仓但无本地监控等);若无则写「无」。
|
||||
|
||||
**4. 数据说明**
|
||||
列出数据缺口(某户未启用、接口失败等)。
|
||||
""".strip()
|
||||
|
||||
|
||||
CHAT_SYSTEM = """
|
||||
你是和用户一起盯盘的老搭档交易员,熟悉他多个交易所账户的分工。用中文、口语化、短句交流。
|
||||
|
||||
语气要求:
|
||||
- 先理解对方的压力和情绪,再轻轻帮他把事想清楚(安慰、体贴)。
|
||||
- 可以指出执行或心态上的偏差点,但用商量、陪伴的口吻,绝不用教育、训诫、上课、列清单式说教。
|
||||
- 不要「第1点第2点你应该…」;不要「作为你的教练我必须…」。
|
||||
- 不预测涨跌,不保证收益,不替用户做决定。
|
||||
- 只能依据提供的监控与交易数据说话;看不到的就说「我这边看不到,你可以去 xx 实例页确认」。
|
||||
|
||||
若附带「今日总结摘要」,可自然引用,但保持口语,不要复读整份报告。
|
||||
""".strip()
|
||||
|
||||
|
||||
def build_summary_user_prompt(context_text: str, trading_day: str) -> str:
|
||||
return f"""
|
||||
交易日:{trading_day}
|
||||
|
||||
以下为中控聚合的多账户数据(含未监控账户标记):
|
||||
|
||||
{context_text}
|
||||
""".strip()
|
||||
|
||||
|
||||
def build_chat_user_prompt(
|
||||
*,
|
||||
context_text: str,
|
||||
trading_day: str,
|
||||
summary_excerpt: str,
|
||||
history_lines: str,
|
||||
user_message: str,
|
||||
) -> str:
|
||||
parts = [
|
||||
f"【交易日】{trading_day}",
|
||||
"【当前多账户快照】",
|
||||
context_text.strip() or "(无监控数据)",
|
||||
]
|
||||
if summary_excerpt.strip():
|
||||
parts.extend(["【今日总结摘要(供参考)】", summary_excerpt.strip()])
|
||||
if history_lines.strip():
|
||||
parts.extend(["【此前对话】", history_lines.strip()])
|
||||
parts.extend(["【用户现在说】", user_message.strip()])
|
||||
return "\n\n".join(parts)
|
||||
@@ -0,0 +1,108 @@
|
||||
"""中控 AI FastAPI 路由。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from hub_ai.chat import get_chat_state, send_chat_message, start_new_chat
|
||||
from hub_ai.client import model_label
|
||||
from hub_ai.config import trading_day_reset_hour
|
||||
from hub_ai.context import build_daily_context
|
||||
from hub_ai.store import get_latest_summary, list_summaries
|
||||
from hub_ai.summary import generate_daily_summary
|
||||
from hub_trades_lib import current_trading_day
|
||||
|
||||
|
||||
class ChatSendBody(BaseModel):
|
||||
message: str = ""
|
||||
trading_day: str = ""
|
||||
|
||||
|
||||
class SummaryGenerateBody(BaseModel):
|
||||
trading_day: str = ""
|
||||
force: bool = False
|
||||
|
||||
|
||||
class ChatNewBody(BaseModel):
|
||||
trading_day: str = ""
|
||||
|
||||
|
||||
def create_hub_ai_router(*, load_all_exchanges: Callable[[], list]) -> APIRouter:
|
||||
router = APIRouter(prefix="/api/ai", tags=["hub-ai"])
|
||||
|
||||
def _day(raw: str = "") -> str:
|
||||
d = (raw or "").strip()[:10]
|
||||
return d or current_trading_day(reset_hour=trading_day_reset_hour())
|
||||
|
||||
@router.get("/meta")
|
||||
def api_ai_meta():
|
||||
return {
|
||||
"ok": True,
|
||||
"model": model_label(),
|
||||
"trading_day_reset_hour": trading_day_reset_hour(),
|
||||
"trading_day": current_trading_day(reset_hour=trading_day_reset_hour()),
|
||||
"storage": {
|
||||
"summaries": "hub_ai_summaries.json",
|
||||
"chat": "hub_ai_chat.json",
|
||||
},
|
||||
}
|
||||
|
||||
@router.get("/context")
|
||||
def api_ai_context(trading_day: str = ""):
|
||||
exchanges = load_all_exchanges()
|
||||
ctx = build_daily_context(exchanges, trading_day=_day(trading_day))
|
||||
return {"ok": True, **ctx}
|
||||
|
||||
@router.get("/summary")
|
||||
def api_ai_summary_list(trading_day: str = ""):
|
||||
day = _day(trading_day) if trading_day.strip() else ""
|
||||
items = list_summaries(trading_day=day or None, limit=20)
|
||||
latest = get_latest_summary(_day(trading_day)) if trading_day.strip() else (
|
||||
items[0] if items else None
|
||||
)
|
||||
return {
|
||||
"ok": True,
|
||||
"trading_day": _day(trading_day) if trading_day.strip() else None,
|
||||
"summaries": items,
|
||||
"latest": latest,
|
||||
"model": model_label(),
|
||||
}
|
||||
|
||||
@router.post("/summary/generate")
|
||||
def api_ai_summary_generate(body: SummaryGenerateBody = SummaryGenerateBody()):
|
||||
exchanges = load_all_exchanges()
|
||||
result = generate_daily_summary(
|
||||
exchanges,
|
||||
trading_day=_day(body.trading_day) if body.trading_day.strip() else None,
|
||||
force=bool(body.force),
|
||||
)
|
||||
if not result.get("ok"):
|
||||
raise HTTPException(status_code=502, detail=result.get("msg") or "生成失败")
|
||||
result.pop("context", None)
|
||||
return result
|
||||
|
||||
@router.get("/chat/session")
|
||||
def api_ai_chat_session():
|
||||
state = get_chat_state()
|
||||
return {"ok": True, **state, "model": model_label()}
|
||||
|
||||
@router.post("/chat/new")
|
||||
def api_ai_chat_new(body: ChatNewBody = ChatNewBody()):
|
||||
day = _day(body.trading_day)
|
||||
return start_new_chat(trading_day=day)
|
||||
|
||||
@router.post("/chat/send")
|
||||
def api_ai_chat_send(body: ChatSendBody):
|
||||
exchanges = load_all_exchanges()
|
||||
result = send_chat_message(
|
||||
exchanges,
|
||||
body.message,
|
||||
trading_day=_day(body.trading_day) if body.trading_day.strip() else None,
|
||||
)
|
||||
if not result.get("ok"):
|
||||
raise HTTPException(status_code=502, detail=result.get("msg") or "发送失败")
|
||||
return result
|
||||
|
||||
return router
|
||||
@@ -0,0 +1,195 @@
|
||||
"""中控 AI:JSON 持久化(与 hub_settings.json 同目录)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from hub_ai.config import CHAT_SESSION_RETENTION_DAYS, SUMMARY_RETENTION_DAYS
|
||||
|
||||
HUB_DIR = Path(__file__).resolve().parent.parent
|
||||
SUMMARIES_PATH = HUB_DIR / "hub_ai_summaries.json"
|
||||
CHAT_PATH = HUB_DIR / "hub_ai_chat.json"
|
||||
|
||||
|
||||
def _now_str() -> str:
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def _atomic_write(path: Path, data: dict) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(path.suffix + ".tmp")
|
||||
tmp.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
os.replace(tmp, path)
|
||||
|
||||
|
||||
def _load_json(path: Path, default: dict) -> dict:
|
||||
if not path.is_file():
|
||||
return dict(default)
|
||||
try:
|
||||
loaded = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(loaded, dict):
|
||||
return loaded
|
||||
except Exception:
|
||||
pass
|
||||
return dict(default)
|
||||
|
||||
|
||||
def _prune_summaries(items: list, *, keep_days: int) -> list:
|
||||
cutoff = (datetime.now() - timedelta(days=max(1, keep_days))).strftime("%Y-%m-%d")
|
||||
out = [x for x in items if str(x.get("trading_day") or "") >= cutoff]
|
||||
return out[-500:]
|
||||
|
||||
|
||||
def _prune_chat_sessions(sessions: list, *, keep_days: int) -> list:
|
||||
cutoff_dt = datetime.now() - timedelta(days=max(1, keep_days))
|
||||
out = []
|
||||
for s in sessions:
|
||||
ts = str(s.get("updated_at") or s.get("created_at") or "")
|
||||
try:
|
||||
dt = datetime.strptime(ts[:19], "%Y-%m-%d %H:%M:%S")
|
||||
except ValueError:
|
||||
out.append(s)
|
||||
continue
|
||||
if dt >= cutoff_dt:
|
||||
out.append(s)
|
||||
return out[-50:]
|
||||
|
||||
|
||||
def load_summaries_store() -> dict:
|
||||
return _load_json(SUMMARIES_PATH, {"version": 1, "summaries": []})
|
||||
|
||||
|
||||
def save_summaries_store(data: dict) -> None:
|
||||
summaries = _prune_summaries(
|
||||
list(data.get("summaries") or []),
|
||||
keep_days=SUMMARY_RETENTION_DAYS,
|
||||
)
|
||||
_atomic_write(SUMMARIES_PATH, {"version": 1, "summaries": summaries})
|
||||
|
||||
|
||||
def append_summary(
|
||||
*,
|
||||
trading_day: str,
|
||||
content_md: str,
|
||||
model: str,
|
||||
context_hash: str,
|
||||
stats_snapshot: dict,
|
||||
) -> dict:
|
||||
store = load_summaries_store()
|
||||
row = {
|
||||
"id": uuid.uuid4().hex,
|
||||
"trading_day": trading_day,
|
||||
"generated_at": _now_str(),
|
||||
"model": model,
|
||||
"context_hash": context_hash,
|
||||
"content_md": content_md,
|
||||
"stats_snapshot": stats_snapshot,
|
||||
}
|
||||
store.setdefault("summaries", []).append(row)
|
||||
save_summaries_store(store)
|
||||
return row
|
||||
|
||||
|
||||
def list_summaries(*, trading_day: Optional[str] = None, limit: int = 30) -> list[dict]:
|
||||
store = load_summaries_store()
|
||||
items = list(store.get("summaries") or [])
|
||||
if trading_day:
|
||||
items = [x for x in items if str(x.get("trading_day")) == trading_day]
|
||||
items.sort(key=lambda x: str(x.get("generated_at") or ""), reverse=True)
|
||||
return items[: max(1, min(limit, 100))]
|
||||
|
||||
|
||||
def get_latest_summary(trading_day: str) -> Optional[dict]:
|
||||
rows = list_summaries(trading_day=trading_day, limit=1)
|
||||
return rows[0] if rows else None
|
||||
|
||||
|
||||
def load_chat_store() -> dict:
|
||||
default = {"version": 1, "sessions": [], "active_session_id": None}
|
||||
data = _load_json(CHAT_PATH, default)
|
||||
data.setdefault("version", 1)
|
||||
data.setdefault("sessions", [])
|
||||
return data
|
||||
|
||||
|
||||
def save_chat_store(data: dict) -> None:
|
||||
sessions = _prune_chat_sessions(
|
||||
list(data.get("sessions") or []),
|
||||
keep_days=CHAT_SESSION_RETENTION_DAYS,
|
||||
)
|
||||
active = data.get("active_session_id")
|
||||
ids = {str(s.get("id")) for s in sessions}
|
||||
if active and str(active) not in ids:
|
||||
active = sessions[-1]["id"] if sessions else None
|
||||
_atomic_write(
|
||||
CHAT_PATH,
|
||||
{"version": 1, "sessions": sessions, "active_session_id": active},
|
||||
)
|
||||
|
||||
|
||||
def get_active_session() -> Optional[dict]:
|
||||
store = load_chat_store()
|
||||
sid = store.get("active_session_id")
|
||||
for s in store.get("sessions") or []:
|
||||
if str(s.get("id")) == str(sid):
|
||||
return s
|
||||
return None
|
||||
|
||||
|
||||
def create_new_session(*, trading_day: str, title: str = "新对话") -> dict:
|
||||
store = load_chat_store()
|
||||
session = {
|
||||
"id": uuid.uuid4().hex,
|
||||
"trading_day": trading_day,
|
||||
"title": title,
|
||||
"created_at": _now_str(),
|
||||
"updated_at": _now_str(),
|
||||
"messages": [],
|
||||
}
|
||||
store.setdefault("sessions", []).append(session)
|
||||
store["active_session_id"] = session["id"]
|
||||
save_chat_store(store)
|
||||
return session
|
||||
|
||||
|
||||
def ensure_active_session(*, trading_day: str) -> dict:
|
||||
active = get_active_session()
|
||||
if active:
|
||||
return active
|
||||
return create_new_session(trading_day=trading_day)
|
||||
|
||||
|
||||
def append_chat_message(session_id: str, role: str, content: str) -> dict:
|
||||
store = load_chat_store()
|
||||
sessions = store.get("sessions") or []
|
||||
target = None
|
||||
for s in sessions:
|
||||
if str(s.get("id")) == str(session_id):
|
||||
target = s
|
||||
break
|
||||
if not target:
|
||||
raise KeyError("session_not_found")
|
||||
msg = {"role": role, "content": content.strip(), "at": _now_str()}
|
||||
target.setdefault("messages", []).append(msg)
|
||||
target["updated_at"] = _now_str()
|
||||
if role == "user" and (target.get("title") in (None, "", "新对话")):
|
||||
title = content.strip().replace("\n", " ")[:24]
|
||||
if title:
|
||||
target["title"] = title
|
||||
store["active_session_id"] = target["id"]
|
||||
save_chat_store(store)
|
||||
return target
|
||||
|
||||
|
||||
def summary_excerpt_for_chat(trading_day: str, max_chars: int = 600) -> str:
|
||||
latest = get_latest_summary(trading_day)
|
||||
if not latest:
|
||||
return ""
|
||||
text = str(latest.get("content_md") or "").strip()
|
||||
if len(text) <= max_chars:
|
||||
return text
|
||||
return text[: max_chars - 3].rstrip() + "..."
|
||||
@@ -0,0 +1,69 @@
|
||||
"""中控 AI:今日总结生成。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from hub_ai.client import generate_text, model_label
|
||||
from hub_ai.context import build_daily_context
|
||||
from hub_ai.prompts import SUMMARY_SYSTEM, build_summary_user_prompt
|
||||
from hub_ai.store import append_summary, get_latest_summary, list_summaries
|
||||
|
||||
|
||||
def generate_daily_summary(
|
||||
exchanges: list[dict],
|
||||
*,
|
||||
trading_day: str | None = None,
|
||||
force: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
ctx = build_daily_context(exchanges, trading_day=trading_day)
|
||||
day = ctx["trading_day"]
|
||||
if not force:
|
||||
latest = get_latest_summary(day)
|
||||
if latest and latest.get("context_hash") == ctx.get("context_hash"):
|
||||
return {
|
||||
"ok": True,
|
||||
"cached": True,
|
||||
"trading_day": day,
|
||||
"summary": latest,
|
||||
"model": latest.get("model") or model_label(),
|
||||
}
|
||||
|
||||
system = SUMMARY_SYSTEM.replace("{trading_day}", day)
|
||||
user = build_summary_user_prompt(ctx["text"], day)
|
||||
content = generate_text(system=system, user=user, temperature=0.15)
|
||||
if content.startswith("AI 调用失败"):
|
||||
return {"ok": False, "msg": content, "trading_day": day}
|
||||
|
||||
stats_snapshot = {
|
||||
"totals": ctx.get("totals"),
|
||||
"by_account": {
|
||||
str(ac.get("key") or ac.get("id")): {
|
||||
"name": ac.get("name"),
|
||||
"status": ac.get("status"),
|
||||
"pnl_u": (ac.get("trade_stats") or {}).get("total_pnl_u"),
|
||||
"closed_count": (ac.get("trade_stats") or {}).get("closed_count"),
|
||||
"float_pnl_u": ac.get("float_pnl_u"),
|
||||
"issues": ac.get("issues") or [],
|
||||
}
|
||||
for ac in ctx.get("accounts") or []
|
||||
},
|
||||
}
|
||||
row = append_summary(
|
||||
trading_day=day,
|
||||
content_md=content,
|
||||
model=model_label(),
|
||||
context_hash=ctx.get("context_hash") or "",
|
||||
stats_snapshot=stats_snapshot,
|
||||
)
|
||||
return {
|
||||
"ok": True,
|
||||
"cached": False,
|
||||
"trading_day": day,
|
||||
"summary": row,
|
||||
"model": model_label(),
|
||||
"context": ctx,
|
||||
}
|
||||
|
||||
|
||||
def summary_list(trading_day: str | None = None) -> list[dict]:
|
||||
return list_summaries(trading_day=trading_day)
|
||||
Reference in New Issue
Block a user