perf(hub-ai): reduce CPU load during trading coach chat

Cache chat context, parallelize exchange fetches, skip fund history writes, defer rolling summary to a background thread, and cache markdown rendering on the client.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-14 01:59:43 +08:00
parent 28a23008f3
commit 467d160f4d
5 changed files with 148 additions and 39 deletions
+18 -9
View File
@@ -1,6 +1,7 @@
"""中控 AI:单会话聊天(直到用户点击新开)。""" """中控 AI:单会话聊天(直到用户点击新开)。"""
from __future__ import annotations from __future__ import annotations
import threading
from typing import Any, Optional from typing import Any, Optional
from hub_ai.attachments import parse_chat_attachments from hub_ai.attachments import parse_chat_attachments
@@ -20,7 +21,7 @@ from hub_ai.config import (
) )
from hub_trades_lib import current_trading_day from hub_trades_lib import current_trading_day
from hub_ai.context import ( from hub_ai.context import (
build_daily_context, build_chat_context,
format_chat_context_for_chat, format_chat_context_for_chat,
format_chat_position_overview, format_chat_position_overview,
) )
@@ -213,7 +214,7 @@ def send_chat_message(
user_prompt += "\n\n【附件正文】\n" + _clip_text(parsed["text_append"], 3000) user_prompt += "\n\n【附件正文】\n" + _clip_text(parsed["text_append"], 3000)
system_prompt = CHAT_GENERAL_SYSTEM system_prompt = CHAT_GENERAL_SYSTEM
else: else:
ctx = build_daily_context(exchanges, trading_day=day) ctx = build_chat_context(exchanges, trading_day=day)
day = ctx["trading_day"] day = ctx["trading_day"]
brief_ctx, excerpt = _trading_context_bundle(ctx, prior_count=prior_count) brief_ctx, excerpt = _trading_context_bundle(ctx, prior_count=prior_count)
user_prompt = build_chat_user_prompt( user_prompt = build_chat_user_prompt(
@@ -247,13 +248,21 @@ def send_chat_message(
attachments=parsed.get("attachment_meta") or [], attachments=parsed.get("attachment_meta") or [],
) )
session = append_chat_message(sid, "assistant", reply) session = append_chat_message(sid, "assistant", reply)
refresh_session_rolling_summary( summary_kwargs = {
sid, "session_id": sid,
prior_summary=prior_rolling, "prior_summary": prior_rolling,
user_text=user_visible, "user_text": user_visible,
assistant_text=reply, "assistant_text": reply,
bot_mode=bot_mode, "bot_mode": bot_mode,
) }
def _refresh_summary_bg() -> None:
try:
refresh_session_rolling_summary(**summary_kwargs)
except Exception:
pass
threading.Thread(target=_refresh_summary_bg, daemon=True).start()
session = get_active_session() or session session = get_active_session() or session
return { return {
"ok": True, "ok": True,
+1
View File
@@ -33,6 +33,7 @@ FUND_HISTORY_DAYS = 180
CHAT_MAX_ATTACHMENTS = 3 CHAT_MAX_ATTACHMENTS = 3
CHAT_MAX_IMAGE_BYTES = 4 * 1024 * 1024 CHAT_MAX_IMAGE_BYTES = 4 * 1024 * 1024
CHAT_MAX_TEXT_FILE_BYTES = 200 * 1024 CHAT_MAX_TEXT_FILE_BYTES = 200 * 1024
CHAT_CONTEXT_CACHE_TTL_SEC = _int_env("CHAT_CONTEXT_CACHE_TTL_SEC", 45)
def trading_day_reset_hour() -> int: def trading_day_reset_hour() -> int:
+82 -6
View File
@@ -5,7 +5,10 @@ import hashlib
import json import json
import os import os
import re import re
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta from datetime import datetime, timedelta
from threading import Lock
from typing import Any, Optional from typing import Any, Optional
import httpx import httpx
@@ -20,6 +23,17 @@ from hub_ai.config import (
from hub_ai.fund_history import format_fund_history_text, get_fund_history, record_fund_snapshot from hub_ai.fund_history import format_fund_history_text, get_fund_history, record_fund_snapshot
from hub_trades_lib import current_trading_day, summarize_trades from hub_trades_lib import current_trading_day, summarize_trades
_CHAT_CONTEXT_CACHE: dict[str, dict[str, Any]] = {}
_CHAT_CONTEXT_CACHE_LOCK = Lock()
_HUB_TPSL_MERGE_FN: Any = None
def _chat_context_cache_ttl_sec() -> float:
try:
return float(os.getenv("CHAT_CONTEXT_CACHE_TTL_SEC", "45") or "45")
except ValueError:
return 45.0
def _hub_token() -> str: def _hub_token() -> str:
return (os.getenv("HUB_BRIDGE_TOKEN") or os.getenv("CONTROL_TOKEN") or "").strip() return (os.getenv("HUB_BRIDGE_TOKEN") or os.getenv("CONTROL_TOKEN") or "").strip()
@@ -348,12 +362,20 @@ def _enrich_positions_exchange_tpsl(
price_snap: Optional[dict], price_snap: Optional[dict],
hub_mon: Optional[dict], hub_mon: Optional[dict],
) -> None: ) -> None:
global _HUB_TPSL_MERGE_FN
if not positions: if not positions:
return return
if _HUB_TPSL_MERGE_FN is None:
try: try:
from hub import _merge_flask_exchange_tpsl from hub import _merge_flask_exchange_tpsl
_merge_flask_exchange_tpsl( _HUB_TPSL_MERGE_FN = _merge_flask_exchange_tpsl
except Exception:
_HUB_TPSL_MERGE_FN = False
if not _HUB_TPSL_MERGE_FN:
return
try:
_HUB_TPSL_MERGE_FN(
{"agent": {"positions": positions}}, {"agent": {"positions": positions}},
price_snap if isinstance(price_snap, dict) else None, price_snap if isinstance(price_snap, dict) else None,
hub_mon if isinstance(hub_mon, dict) else None, hub_mon if isinstance(hub_mon, dict) else None,
@@ -362,7 +384,13 @@ def _enrich_positions_exchange_tpsl(
pass pass
def _fetch_account_bundle(client: httpx.Client, ex: dict, trading_day: str) -> dict[str, Any]: def _fetch_account_bundle(
client: httpx.Client,
ex: dict,
trading_day: str,
*,
for_chat: bool = False,
) -> dict[str, Any]:
name = ex.get("name") or ex.get("key") or ex.get("id") name = ex.get("name") or ex.get("key") or ex.get("id")
key = ex.get("key") or "" key = ex.get("key") or ""
enabled = bool(ex.get("enabled")) enabled = bool(ex.get("enabled"))
@@ -460,7 +488,7 @@ def _fetch_account_bundle(client: httpx.Client, ex: dict, trading_day: str) -> d
except Exception as exc: except Exception as exc:
base["issues"].append(f"成交接口: {exc}") base["issues"].append(f"成交接口: {exc}")
if prev_day: if prev_day and not for_chat:
try: try:
r = client.get( r = client.get(
f"{flask_url}/api/hub/trades/today", f"{flask_url}/api/hub/trades/today",
@@ -534,18 +562,35 @@ def _fetch_account_bundle(client: httpx.Client, ex: dict, trading_day: str) -> d
return base return base
def _fetch_account_bundle_isolated(ex: dict, trading_day: str, *, for_chat: bool) -> dict[str, Any]:
with httpx.Client() as client:
return _fetch_account_bundle(client, ex, trading_day, for_chat=for_chat)
def build_daily_context( def build_daily_context(
exchanges: list[dict], exchanges: list[dict],
*, *,
trading_day: Optional[str] = None, trading_day: Optional[str] = None,
for_chat: bool = False,
) -> dict[str, Any]: ) -> dict[str, Any]:
day = (trading_day or "").strip()[:10] or current_trading_day( day = (trading_day or "").strip()[:10] or current_trading_day(
reset_hour=trading_day_reset_hour() reset_hour=trading_day_reset_hour()
) )
accounts: list[dict] = [] ex_list = exchanges or []
if for_chat and len(ex_list) > 1:
workers = min(4, len(ex_list))
with ThreadPoolExecutor(max_workers=workers) as pool:
accounts = list(
pool.map(
lambda ex: _fetch_account_bundle_isolated(ex, day, for_chat=True),
ex_list,
)
)
else:
with httpx.Client() as client: with httpx.Client() as client:
for ex in exchanges or []: accounts = [
accounts.append(_fetch_account_bundle(client, ex, day)) _fetch_account_bundle(client, ex, day, for_chat=for_chat) for ex in ex_list
]
total_closed_pnl = 0.0 total_closed_pnl = 0.0
total_closed = total_win = total_loss = 0 total_closed = total_win = total_loss = 0
@@ -589,6 +634,10 @@ def build_daily_context(
"total_funding_usdt": round(total_funding, 4) if total_funding is not None else None, "total_funding_usdt": round(total_funding, 4) if total_funding is not None else None,
"total_trading_usdt": round(total_trading, 4) if total_trading is not None else None, "total_trading_usdt": round(total_trading, 4) if total_trading is not None else None,
} }
if for_chat:
fund_history: list = []
fund_history_text = ""
else:
snap_accounts = [ snap_accounts = [
{ {
**ac, **ac,
@@ -608,6 +657,9 @@ def build_daily_context(
"fund_history": fund_history, "fund_history": fund_history,
"fund_history_text": fund_history_text, "fund_history_text": fund_history_text,
} }
if for_chat:
text = format_chat_context_for_chat(payload)
else:
text = format_context_text(payload) text = format_context_text(payload)
digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:16] digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]
return { return {
@@ -622,6 +674,30 @@ def build_daily_context(
} }
def build_chat_context(
exchanges: list[dict],
*,
trading_day: Optional[str] = None,
force_refresh: bool = False,
) -> dict[str, Any]:
"""聊天专用上下文:并行拉取、跳过资金曲线/昨日成交,短 TTL 缓存。"""
day = (trading_day or "").strip()[:10] or current_trading_day(
reset_hour=trading_day_reset_hour()
)
ttl = _chat_context_cache_ttl_sec()
now = time.monotonic()
if not force_refresh and ttl > 0:
with _CHAT_CONTEXT_CACHE_LOCK:
hit = _CHAT_CONTEXT_CACHE.get(day)
if hit and (now - float(hit.get("ts") or 0)) < ttl:
return hit["ctx"]
ctx = build_daily_context(exchanges, trading_day=day, for_chat=True)
if ttl > 0:
with _CHAT_CONTEXT_CACHE_LOCK:
_CHAT_CONTEXT_CACHE[day] = {"ts": now, "ctx": ctx}
return ctx
def format_context_text(payload: dict) -> str: def format_context_text(payload: dict) -> str:
lines = [] lines = []
totals = payload.get("totals") or {} totals = payload.get("totals") or {}
+3 -1
View File
@@ -1,6 +1,7 @@
"""中控 AI FastAPI 路由。""" """中控 AI FastAPI 路由。"""
from __future__ import annotations from __future__ import annotations
import asyncio
from typing import Callable from typing import Callable
from fastapi import APIRouter, Body, File, Form, HTTPException, UploadFile from fastapi import APIRouter, Body, File, Form, HTTPException, UploadFile
@@ -153,7 +154,8 @@ def create_hub_ai_router(*, load_all_exchanges: Callable[[], list]) -> APIRouter
"data": data, "data": data,
} }
) )
result = send_chat_message( result = await asyncio.to_thread(
send_chat_message,
exchanges, exchanges,
message, message,
trading_day=_day(trading_day) if trading_day.strip() else None, trading_day=_day(trading_day) if trading_day.strip() else None,
+27 -6
View File
@@ -3623,15 +3623,33 @@
addAiChatPendingFiles(imageFiles); addAiChatPendingFiles(imageFiles);
} }
function renderHubMarkdown(text) { const AI_CHAT_MAX_ATTACHMENTS = 3;
let aiChatPendingFiles = [];
const aiChatMdCache = new Map();
const AI_CHAT_MD_CACHE_MAX = 120;
function renderHubMarkdown(text, cacheKey) {
const raw = String(text || ""); const raw = String(text || "");
if (typeof window !== "undefined" && window.AiReviewRender && window.AiReviewRender.renderMarkdown) { if (cacheKey && aiChatMdCache.has(cacheKey)) {
return window.AiReviewRender.renderMarkdown(raw); return aiChatMdCache.get(cacheKey);
} }
return esc(raw) let html;
if (typeof window !== "undefined" && window.AiReviewRender && window.AiReviewRender.renderMarkdown) {
html = window.AiReviewRender.renderMarkdown(raw);
} else {
html = esc(raw)
.replace(/\*\*(.+?)\*\*/g, "<strong>$1</strong>") .replace(/\*\*(.+?)\*\*/g, "<strong>$1</strong>")
.replace(/\n/g, "<br>"); .replace(/\n/g, "<br>");
} }
if (cacheKey) {
if (aiChatMdCache.size >= AI_CHAT_MD_CACHE_MAX) {
const firstKey = aiChatMdCache.keys().next().value;
if (firstKey != null) aiChatMdCache.delete(firstKey);
}
aiChatMdCache.set(cacheKey, html);
}
return html;
}
function scrollAiChatToEnd() { function scrollAiChatToEnd() {
const box = document.getElementById("ai-chat-messages"); const box = document.getElementById("ai-chat-messages");
@@ -3716,7 +3734,9 @@
!isUser && !isUser &&
!isThinking && !isThinking &&
/^(AI 调用失败|AI 生成失败)/.test(String(content || "").trim()); /^(AI 调用失败|AI 生成失败)/.test(String(content || "").trim());
const bubbleInner = isUser || isThinking ? esc(content || "") : renderHubMarkdown(content || ""); const mdKey =
!isUser && !isThinking && opts.cacheKey ? String(opts.cacheKey) : "";
const bubbleInner = isUser || isThinking ? esc(content || "") : renderHubMarkdown(content || "", mdKey);
const mdCls = !isUser && !isThinking ? " ai-result-md" : ""; const mdCls = !isUser && !isThinking ? " ai-result-md" : "";
const attList = Array.isArray(attachments) ? attachments : []; const attList = Array.isArray(attachments) ? attachments : [];
const attHtml = attList.length const attHtml = attList.length
@@ -3768,6 +3788,7 @@
box.innerHTML = `<p class="ai-placeholder">${hint}</p>`; box.innerHTML = `<p class="ai-placeholder">${hint}</p>`;
return; return;
} }
const sessionId = session && session.id ? String(session.id) : "local";
let html = msgs let html = msgs
.map((m, idx) => .map((m, idx) =>
renderAiChatRow( renderAiChatRow(
@@ -3775,7 +3796,7 @@
m.content || "", m.content || "",
null, null,
m.attachments, m.attachments,
{ botMode, msgIdx: idx } { botMode, msgIdx: idx, cacheKey: sessionId + ":" + idx }
) )
) )
.join(""); .join("");