perf(hub-ai): reduce CPU load during trading coach chat
Cache chat context, parallelize exchange fetches, skip fund history writes, defer rolling summary to a background thread, and cache markdown rendering on the client. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""中控 AI:单会话聊天(直到用户点击新开)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
|
||||
from hub_ai.attachments import parse_chat_attachments
|
||||
@@ -20,7 +21,7 @@ from hub_ai.config import (
|
||||
)
|
||||
from hub_trades_lib import current_trading_day
|
||||
from hub_ai.context import (
|
||||
build_daily_context,
|
||||
build_chat_context,
|
||||
format_chat_context_for_chat,
|
||||
format_chat_position_overview,
|
||||
)
|
||||
@@ -213,7 +214,7 @@ def send_chat_message(
|
||||
user_prompt += "\n\n【附件正文】\n" + _clip_text(parsed["text_append"], 3000)
|
||||
system_prompt = CHAT_GENERAL_SYSTEM
|
||||
else:
|
||||
ctx = build_daily_context(exchanges, trading_day=day)
|
||||
ctx = build_chat_context(exchanges, trading_day=day)
|
||||
day = ctx["trading_day"]
|
||||
brief_ctx, excerpt = _trading_context_bundle(ctx, prior_count=prior_count)
|
||||
user_prompt = build_chat_user_prompt(
|
||||
@@ -247,13 +248,21 @@ def send_chat_message(
|
||||
attachments=parsed.get("attachment_meta") or [],
|
||||
)
|
||||
session = append_chat_message(sid, "assistant", reply)
|
||||
refresh_session_rolling_summary(
|
||||
sid,
|
||||
prior_summary=prior_rolling,
|
||||
user_text=user_visible,
|
||||
assistant_text=reply,
|
||||
bot_mode=bot_mode,
|
||||
)
|
||||
summary_kwargs = {
|
||||
"session_id": sid,
|
||||
"prior_summary": prior_rolling,
|
||||
"user_text": user_visible,
|
||||
"assistant_text": reply,
|
||||
"bot_mode": bot_mode,
|
||||
}
|
||||
|
||||
def _refresh_summary_bg() -> None:
|
||||
try:
|
||||
refresh_session_rolling_summary(**summary_kwargs)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
threading.Thread(target=_refresh_summary_bg, daemon=True).start()
|
||||
session = get_active_session() or session
|
||||
return {
|
||||
"ok": True,
|
||||
|
||||
@@ -33,6 +33,7 @@ FUND_HISTORY_DAYS = 180
|
||||
CHAT_MAX_ATTACHMENTS = 3
|
||||
CHAT_MAX_IMAGE_BYTES = 4 * 1024 * 1024
|
||||
CHAT_MAX_TEXT_FILE_BYTES = 200 * 1024
|
||||
CHAT_CONTEXT_CACHE_TTL_SEC = _int_env("CHAT_CONTEXT_CACHE_TTL_SEC", 45)
|
||||
|
||||
|
||||
def trading_day_reset_hour() -> int:
|
||||
|
||||
@@ -5,7 +5,10 @@ import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
@@ -20,6 +23,17 @@ from hub_ai.config import (
|
||||
from hub_ai.fund_history import format_fund_history_text, get_fund_history, record_fund_snapshot
|
||||
from hub_trades_lib import current_trading_day, summarize_trades
|
||||
|
||||
_CHAT_CONTEXT_CACHE: dict[str, dict[str, Any]] = {}
|
||||
_CHAT_CONTEXT_CACHE_LOCK = Lock()
|
||||
_HUB_TPSL_MERGE_FN: Any = None
|
||||
|
||||
|
||||
def _chat_context_cache_ttl_sec() -> float:
|
||||
try:
|
||||
return float(os.getenv("CHAT_CONTEXT_CACHE_TTL_SEC", "45") or "45")
|
||||
except ValueError:
|
||||
return 45.0
|
||||
|
||||
|
||||
def _hub_token() -> str:
|
||||
return (os.getenv("HUB_BRIDGE_TOKEN") or os.getenv("CONTROL_TOKEN") or "").strip()
|
||||
@@ -348,12 +362,20 @@ def _enrich_positions_exchange_tpsl(
|
||||
price_snap: Optional[dict],
|
||||
hub_mon: Optional[dict],
|
||||
) -> None:
|
||||
global _HUB_TPSL_MERGE_FN
|
||||
if not positions:
|
||||
return
|
||||
try:
|
||||
from hub import _merge_flask_exchange_tpsl
|
||||
if _HUB_TPSL_MERGE_FN is None:
|
||||
try:
|
||||
from hub import _merge_flask_exchange_tpsl
|
||||
|
||||
_merge_flask_exchange_tpsl(
|
||||
_HUB_TPSL_MERGE_FN = _merge_flask_exchange_tpsl
|
||||
except Exception:
|
||||
_HUB_TPSL_MERGE_FN = False
|
||||
if not _HUB_TPSL_MERGE_FN:
|
||||
return
|
||||
try:
|
||||
_HUB_TPSL_MERGE_FN(
|
||||
{"agent": {"positions": positions}},
|
||||
price_snap if isinstance(price_snap, dict) else None,
|
||||
hub_mon if isinstance(hub_mon, dict) else None,
|
||||
@@ -362,7 +384,13 @@ def _enrich_positions_exchange_tpsl(
|
||||
pass
|
||||
|
||||
|
||||
def _fetch_account_bundle(client: httpx.Client, ex: dict, trading_day: str) -> dict[str, Any]:
|
||||
def _fetch_account_bundle(
|
||||
client: httpx.Client,
|
||||
ex: dict,
|
||||
trading_day: str,
|
||||
*,
|
||||
for_chat: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
name = ex.get("name") or ex.get("key") or ex.get("id")
|
||||
key = ex.get("key") or ""
|
||||
enabled = bool(ex.get("enabled"))
|
||||
@@ -460,7 +488,7 @@ def _fetch_account_bundle(client: httpx.Client, ex: dict, trading_day: str) -> d
|
||||
except Exception as exc:
|
||||
base["issues"].append(f"成交接口: {exc}")
|
||||
|
||||
if prev_day:
|
||||
if prev_day and not for_chat:
|
||||
try:
|
||||
r = client.get(
|
||||
f"{flask_url}/api/hub/trades/today",
|
||||
@@ -534,18 +562,35 @@ def _fetch_account_bundle(client: httpx.Client, ex: dict, trading_day: str) -> d
|
||||
return base
|
||||
|
||||
|
||||
def _fetch_account_bundle_isolated(ex: dict, trading_day: str, *, for_chat: bool) -> dict[str, Any]:
    """Fetch one exchange's account bundle through its own ``httpx.Client``.

    A fresh client is opened for this single call and closed before the
    result is returned; the actual work is delegated to
    ``_fetch_account_bundle``.
    """
    with httpx.Client() as own_client:
        bundle = _fetch_account_bundle(own_client, ex, trading_day, for_chat=for_chat)
    return bundle
|
||||
|
||||
|
||||
def build_daily_context(
|
||||
exchanges: list[dict],
|
||||
*,
|
||||
trading_day: Optional[str] = None,
|
||||
for_chat: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
day = (trading_day or "").strip()[:10] or current_trading_day(
|
||||
reset_hour=trading_day_reset_hour()
|
||||
)
|
||||
accounts: list[dict] = []
|
||||
with httpx.Client() as client:
|
||||
for ex in exchanges or []:
|
||||
accounts.append(_fetch_account_bundle(client, ex, day))
|
||||
ex_list = exchanges or []
|
||||
if for_chat and len(ex_list) > 1:
|
||||
workers = min(4, len(ex_list))
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
accounts = list(
|
||||
pool.map(
|
||||
lambda ex: _fetch_account_bundle_isolated(ex, day, for_chat=True),
|
||||
ex_list,
|
||||
)
|
||||
)
|
||||
else:
|
||||
with httpx.Client() as client:
|
||||
accounts = [
|
||||
_fetch_account_bundle(client, ex, day, for_chat=for_chat) for ex in ex_list
|
||||
]
|
||||
|
||||
total_closed_pnl = 0.0
|
||||
total_closed = total_win = total_loss = 0
|
||||
@@ -589,17 +634,21 @@ def build_daily_context(
|
||||
"total_funding_usdt": round(total_funding, 4) if total_funding is not None else None,
|
||||
"total_trading_usdt": round(total_trading, 4) if total_trading is not None else None,
|
||||
}
|
||||
snap_accounts = [
|
||||
{
|
||||
**ac,
|
||||
"monitored": ac.get("status") != "未监控",
|
||||
}
|
||||
for ac in accounts
|
||||
]
|
||||
record_fund_snapshot(day, snap_accounts, keep_days=FUND_HISTORY_DAYS)
|
||||
fund_history = get_fund_history(anchor_day=day, keep_days=FUND_HISTORY_DAYS)
|
||||
account_names = {str(ac.get("key") or ac.get("id")): ac.get("name") for ac in accounts}
|
||||
fund_history_text = format_fund_history_text(fund_history, account_names=account_names)
|
||||
if for_chat:
|
||||
fund_history: list = []
|
||||
fund_history_text = ""
|
||||
else:
|
||||
snap_accounts = [
|
||||
{
|
||||
**ac,
|
||||
"monitored": ac.get("status") != "未监控",
|
||||
}
|
||||
for ac in accounts
|
||||
]
|
||||
record_fund_snapshot(day, snap_accounts, keep_days=FUND_HISTORY_DAYS)
|
||||
fund_history = get_fund_history(anchor_day=day, keep_days=FUND_HISTORY_DAYS)
|
||||
account_names = {str(ac.get("key") or ac.get("id")): ac.get("name") for ac in accounts}
|
||||
fund_history_text = format_fund_history_text(fund_history, account_names=account_names)
|
||||
payload = {
|
||||
"trading_day": day,
|
||||
"prev_trading_day": previous_trading_day(day),
|
||||
@@ -608,7 +657,10 @@ def build_daily_context(
|
||||
"fund_history": fund_history,
|
||||
"fund_history_text": fund_history_text,
|
||||
}
|
||||
text = format_context_text(payload)
|
||||
if for_chat:
|
||||
text = format_chat_context_for_chat(payload)
|
||||
else:
|
||||
text = format_context_text(payload)
|
||||
digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]
|
||||
return {
|
||||
"trading_day": day,
|
||||
@@ -622,6 +674,30 @@ def build_daily_context(
|
||||
}
|
||||
|
||||
|
||||
def build_chat_context(
    exchanges: list[dict],
    *,
    trading_day: Optional[str] = None,
    force_refresh: bool = False,
) -> dict[str, Any]:
    """Build the chat-only context, served from a short-TTL per-day cache.

    Delegates to ``build_daily_context(..., for_chat=True)`` (parallel
    fetches, no fund-history curve, no previous-day trades) and caches the
    result in ``_CHAT_CONTEXT_CACHE`` keyed by trading day.

    Args:
        exchanges: Exchange config dicts, passed through unchanged.
        trading_day: Day string; only its first 10 characters are used. When
            empty or ``None`` the current trading day is used.
        force_refresh: Skip the cache lookup and rebuild. The fresh result is
            still stored when caching is enabled.

    Returns:
        The context dict produced by ``build_daily_context``. A cache hit
        returns the very same dict object that was stored, so callers should
        treat it as read-only.
    """
    day = (trading_day or "").strip()[:10] or current_trading_day(
        reset_hour=trading_day_reset_hour()
    )
    ttl = _chat_context_cache_ttl_sec()
    caching_enabled = ttl > 0
    # Timestamp taken before the fetch, so an entry's age includes build time.
    now = time.monotonic()

    if caching_enabled and not force_refresh:
        with _CHAT_CONTEXT_CACHE_LOCK:
            hit = _CHAT_CONTEXT_CACHE.get(day)
            if hit and (now - float(hit.get("ts") or 0)) < ttl:
                return hit["ctx"]

    ctx = build_daily_context(exchanges, trading_day=day, for_chat=True)

    if caching_enabled:
        with _CHAT_CONTEXT_CACHE_LOCK:
            # Evict expired entries (typically earlier trading days) so the
            # cache cannot grow without bound in a long-running process.
            expired_days = [
                cached_day
                for cached_day, entry in _CHAT_CONTEXT_CACHE.items()
                if (now - float(entry.get("ts") or 0)) >= ttl
            ]
            for cached_day in expired_days:
                del _CHAT_CONTEXT_CACHE[cached_day]
            _CHAT_CONTEXT_CACHE[day] = {"ts": now, "ctx": ctx}
    return ctx
|
||||
|
||||
|
||||
def format_context_text(payload: dict) -> str:
|
||||
lines = []
|
||||
totals = payload.get("totals") or {}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""中控 AI FastAPI 路由。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import APIRouter, Body, File, Form, HTTPException, UploadFile
|
||||
@@ -153,7 +154,8 @@ def create_hub_ai_router(*, load_all_exchanges: Callable[[], list]) -> APIRouter
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
result = send_chat_message(
|
||||
result = await asyncio.to_thread(
|
||||
send_chat_message,
|
||||
exchanges,
|
||||
message,
|
||||
trading_day=_day(trading_day) if trading_day.strip() else None,
|
||||
|
||||
Reference in New Issue
Block a user