Files
dekun 467d160f4d perf(hub-ai): reduce CPU load during trading coach chat
Cache chat context, parallelize exchange fetches, skip fund history writes, defer rolling summary to a background thread, and cache markdown rendering on the client.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-14 01:59:43 +08:00

276 lines
9.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""中控 AI:单会话聊天(直到用户点击新开)。"""
from __future__ import annotations
import threading
from typing import Any, Optional
from hub_ai.attachments import parse_chat_attachments
from hub_ai.client import generate_text, model_label
from hub_ai.config import (
CHAT_CONTEXT_MAX_CHARS,
CHAT_FOLLOWUP_CONTEXT_MAX_CHARS,
CHAT_HISTORY_MAX_CHARS_PER_MSG,
CHAT_MAX_CONTINUATIONS,
CHAT_MAX_HISTORY_TURNS,
CHAT_MAX_OUTPUT_TOKENS,
CHAT_PROMPT_MAX_CHARS,
CHAT_SUMMARY_EXCERPT_MAX_CHARS,
CHAT_TEMPERATURE,
CHAT_USER_MESSAGE_MAX_CHARS,
trading_day_reset_hour,
)
from hub_trades_lib import current_trading_day
from hub_ai.context import (
build_chat_context,
format_chat_context_for_chat,
format_chat_position_overview,
)
from hub_ai.prompts import (
CHAT_GENERAL_SYSTEM,
CHAT_SYSTEM,
build_chat_user_prompt,
build_general_chat_user_prompt,
)
from hub_ai.rolling_summary import refresh_session_rolling_summary
from hub_ai.store import (
CHAT_BOT_GENERAL,
CHAT_BOT_TRADING,
append_chat_message,
create_new_session,
delete_chat_session,
ensure_active_session,
get_active_session,
list_chat_sessions,
load_chat_store,
set_active_session,
summary_excerpt_for_chat,
)
from hub_ai.text_util import clip_text, is_ai_error_reply
def _is_ai_error_reply(text: str) -> bool:
    """Report whether ``text`` is classified as an AI error reply.

    Module-private alias that defers entirely to
    ``hub_ai.text_util.is_ai_error_reply``; no extra logic is applied here.
    """
    verdict = is_ai_error_reply(text)
    return verdict
def _clip_text(text: str, max_chars: int) -> str:
    """Clip ``text`` to ``max_chars`` using the shared text utility.

    Module-private alias that defers entirely to
    ``hub_ai.text_util.clip_text``; the clipping rules live there.
    """
    clipped = clip_text(text, max_chars)
    return clipped
def _history_lines(
    messages: list[dict],
    max_turns: int = CHAT_MAX_HISTORY_TURNS,
    *,
    max_chars_per_msg: int = CHAT_HISTORY_MAX_CHARS_PER_MSG,
    total_max_chars: int | None = None,
) -> str:
    """Render the most recent chat messages as newline-separated history lines.

    Only ``user`` and ``assistant`` messages are considered. Assistant
    messages that are AI error replies are skipped. Each line has the form
    ``<role label>:<content>``, where the content is clipped to
    ``max_chars_per_msg`` and, when the message carries attachments, is
    suffixed with the names of up to three of them.

    Args:
        messages: Stored chat messages (dicts with ``role``, ``content`` and
            optionally ``attachments``). ``None`` is treated as empty.
        max_turns: Number of turns to keep; one turn is counted as two
            messages. Zero or negative yields an empty string.
        max_chars_per_msg: Per-message clipping limit.
        total_max_chars: Optional cap on the length of the joined result;
            the oldest lines are dropped until the text fits.

    Returns:
        The joined history text, or ``""`` when nothing qualifies.
    """
    keep = max(0, max_turns) * 2
    if keep == 0:
        # Guard explicitly: ``rows[-0:]`` would return the *whole* list,
        # which is the opposite of "keep no history".
        return ""

    rows = [m for m in (messages or []) if m.get("role") in ("user", "assistant")]
    rows = rows[-keep:]

    lines: list[str] = []
    for m in rows:
        is_user = m.get("role") == "user"
        content = str(m.get("content") or "").strip()
        if not is_user and _is_ai_error_reply(content):
            # Failed generations carry no conversational value for the model.
            continue
        att = m.get("attachments") or []
        if att:
            # Separate the names so adjacent file names stay distinguishable.
            names = "、".join(str(a.get("name") or "附件") for a in att[:3])
            content = f"{content} [附件: {names}]".strip()
        content = _clip_text(content, max_chars_per_msg)
        if content:
            role = "用户" if is_user else "搭档"
            # Delimit the speaker label from the text it introduces.
            lines.append(f"{role}:{content}")

    if total_max_chars and total_max_chars > 0:
        # Drop the oldest lines first until the joined text fits the budget.
        while lines and len("\n".join(lines)) > total_max_chars:
            lines.pop(0)
    return "\n".join(lines)
def _trading_context_bundle(ctx: dict[str, Any], *, prior_count: int) -> tuple[str, str]:
    """Build the ``(context_text, summary_excerpt)`` pair for a trading prompt.

    On the first turn of a session (``prior_count <= 0``) the full formatted
    chat context and the summary excerpt for the trading day are returned.
    On follow-up turns only the position overview plus a one-line totals
    snapshot is returned, clipped to the follow-up limit, and the excerpt is
    left empty.

    Args:
        ctx: Chat context dict; ``trading_day`` and ``totals`` are read here.
        prior_count: Number of earlier user/assistant messages in the session.

    Returns:
        A ``(brief, excerpt)`` tuple of strings.
    """
    totals = ctx.get("totals") or {}
    # Prefer the top-level trading day, falling back to the one inside totals.
    day = str(ctx.get("trading_day") or totals.get("trading_day") or "")

    first_turn = prior_count <= 0
    if first_turn:
        full_brief = format_chat_context_for_chat(ctx, max_chars=CHAT_CONTEXT_MAX_CHARS)
        day_excerpt = summary_excerpt_for_chat(day, max_chars=CHAT_SUMMARY_EXCERPT_MAX_CHARS)
        return full_brief, day_excerpt

    overview = format_chat_position_overview(ctx)
    snapshot_fields = (
        f"【续聊快照 {day}】平仓盈亏 {totals.get('total_pnl_u')}U",
        f"笔数 {totals.get('closed_count')}",
        f"持仓 {totals.get('open_position_count', 0)} 仓",
        f"浮盈亏 {totals.get('float_pnl_u')}U",
    )
    snapshot = " | ".join(snapshot_fields)
    followup_brief = _clip_text("\n".join((overview, snapshot)), CHAT_FOLLOWUP_CONTEXT_MAX_CHARS)
    return followup_brief, ""
def _history_budget(*sizes: int) -> int:
    """Return how many prompt characters remain for chat history.

    The given sizes (falsy values count as 0) plus a fixed 2200-character
    reserve are subtracted from ``CHAT_PROMPT_MAX_CHARS``. The result never
    drops below 1200.
    """
    reserve = 2200  # fixed overhead; presumably prompt scaffolding — TODO confirm
    floor = 1200
    consumed = reserve
    for size in sizes:
        consumed += int(size or 0)
    remaining = CHAT_PROMPT_MAX_CHARS - consumed
    return remaining if remaining > floor else floor
def _prompt_memory(session: dict, prior_msgs: list[dict]) -> tuple[str, str]:
"""续聊优先用滚动摘要;旧会话无摘要时仅带最近 1 轮兜底。"""
rolling = str(session.get("rolling_summary") or "").strip()
if rolling:
return rolling, ""
prior_count = len([m for m in prior_msgs if m.get("role") in ("user", "assistant")])
if prior_count <= 0:
return "", ""
tail = _history_lines(
prior_msgs,
max_turns=1,
max_chars_per_msg=CHAT_HISTORY_MAX_CHARS_PER_MSG,
)
return "", tail
def get_chat_state() -> dict[str, Any]:
    """Return the current chat state for the UI.

    The active session, when there is one, is given default ``bot_mode`` and
    ``rolling_summary`` keys if it lacks them (the session dict is updated in
    place).

    Returns:
        A dict with ``active_session_id``, ``session``, ``sessions`` and
        ``model``.
    """
    store = load_chat_store()
    active = get_active_session()
    if active:
        # Fill in keys that a session may be missing.
        for key, fallback in (("bot_mode", CHAT_BOT_TRADING), ("rolling_summary", "")):
            active.setdefault(key, fallback)

    state: dict[str, Any] = {
        "active_session_id": store.get("active_session_id"),
        "session": active,
    }
    state["sessions"] = list_chat_sessions()
    state["model"] = model_label()
    return state
def start_new_chat(*, trading_day: str, bot_mode: str = CHAT_BOT_TRADING) -> dict:
    """Create a new chat session and return it with the session list.

    Args:
        trading_day: Trading day passed through to the new session.
        bot_mode: Bot mode for the new session; defaults to the trading bot.

    Returns:
        A dict with ``ok`` (always ``True``), ``session``, ``sessions`` and
        ``model``.
    """
    created = create_new_session(trading_day=trading_day, bot_mode=bot_mode)
    # Listed after creation so the new session is part of the overview.
    overview = list_chat_sessions()
    return dict(ok=True, session=created, sessions=overview, model=model_label())
def switch_chat_session(session_id: str) -> dict[str, Any]:
    """Make ``session_id`` the active session and return the refreshed state.

    Returns:
        A dict with ``ok`` (always ``True``), ``session`` (whatever
        ``set_active_session`` returned), ``sessions`` and ``model``.
    """
    activated = set_active_session(session_id)
    overview = list_chat_sessions()
    response: dict[str, Any] = dict(ok=True, session=activated, sessions=overview)
    response["model"] = model_label()
    return response
def remove_chat_session(session_id: str) -> dict[str, Any]:
    """Delete a chat session and return the refreshed chat state.

    Returns:
        ``{"ok": False, "msg": "session_not_found"}`` when nothing was
        deleted; otherwise a dict with ``ok``, ``active_session_id``,
        ``session``, ``sessions`` and ``model``.
    """
    outcome = delete_chat_session(session_id)
    was_deleted, next_active_id = outcome
    if not was_deleted:
        return {"ok": False, "msg": "session_not_found"}

    # Re-read the active session after the deletion has been applied.
    current = get_active_session()
    response: dict[str, Any] = {
        "ok": True,
        "active_session_id": next_active_id,
        "session": current,
    }
    response["sessions"] = list_chat_sessions()
    response["model"] = model_label()
    return response
def send_chat_message(
    exchanges: list[dict],
    message: str,
    *,
    trading_day: str | None = None,
    raw_attachments: Optional[list[dict]] = None,
) -> dict[str, Any]:
    """Send one user message to the active chat session and return the reply.

    The prompt is built from the session's memory (rolling summary or the
    last turn) and, for trading-mode sessions, from a freshly built trading
    context. Both the user message and the assistant reply are stored only
    after a non-error reply has been generated. The rolling summary is then
    refreshed on a daemon thread so the caller does not wait for it.

    Args:
        exchanges: Exchange descriptors forwarded to ``build_chat_context``
            (used in trading mode only).
        message: Raw user text; may be empty when attachments are supplied.
        trading_day: Optional trading day (first 10 characters are used);
            defaults to the current trading day.
        raw_attachments: Optional raw attachment payloads to parse.

    Returns:
        On success, a dict with ``ok=True``, ``trading_day``, ``session``,
        ``sessions``, ``reply``, ``model`` and ``attachment_warnings``.
        On failure, ``{"ok": False, "msg": ...}``; when the model call itself
        failed, ``session_id`` is included as well.
    """
    # Local import keeps this block self-contained; only the background
    # failure path below needs logging.
    import logging

    attachment_body_max_chars = 3000

    text = (message or "").strip()
    parsed = parse_chat_attachments(raw_attachments or [])
    errors = parsed.get("errors") or []
    images = parsed.get("images_b64")
    text_append = parsed.get("text_append")
    attachment_note = parsed.get("attachment_note")

    if errors and not text and not images:
        # Nothing usable was sent: surface every attachment error, separated.
        return {"ok": False, "msg": ";".join(str(e) for e in errors)}
    if not text and not images and not text_append:
        return {"ok": False, "msg": "消息不能为空"}

    # Text shown in (and stored with) the chat history.
    user_visible = text
    if text_append:
        user_visible = (user_visible + "\n\n" + text_append).strip()
    if not user_visible and attachment_note:
        user_visible = f"(上传了 {attachment_note})"

    day = (trading_day or "").strip()[:10] or current_trading_day(
        reset_hour=trading_day_reset_hour()
    )
    session = ensure_active_session(trading_day=day)
    sid = session["id"]
    # Captured before the new messages are appended; it is the starting point
    # for the background summary refresh.
    prior_rolling = str(session.get("rolling_summary") or "")
    prior_msgs = session.get("messages") or []
    prior_count = sum(1 for m in prior_msgs if m.get("role") in ("user", "assistant"))

    # NOTE(review): when ``text`` is empty and ``text_append`` is present, the
    # attachment text is used as the user message and is appended again below.
    user_for_prompt = _clip_text(text or user_visible, CHAT_USER_MESSAGE_MAX_CHARS)
    rolling_summary, history_tail = _prompt_memory(session, prior_msgs)
    note_for_prompt = str(attachment_note or "")

    bot_mode = (session.get("bot_mode") or CHAT_BOT_TRADING).strip().lower()
    if bot_mode == CHAT_BOT_GENERAL:
        system_prompt = CHAT_GENERAL_SYSTEM
        user_prompt = build_general_chat_user_prompt(
            rolling_summary=rolling_summary,
            history_lines=history_tail,
            user_message=user_for_prompt,
            attachment_note=note_for_prompt,
        )
    else:
        system_prompt = CHAT_SYSTEM
        ctx = build_chat_context(exchanges, trading_day=day)
        # The context builder decides the trading day from here on.
        day = ctx["trading_day"]
        brief_ctx, excerpt = _trading_context_bundle(ctx, prior_count=prior_count)
        user_prompt = build_chat_user_prompt(
            context_text=brief_ctx,
            trading_day=day,
            summary_excerpt=excerpt,
            rolling_summary=rolling_summary,
            history_lines=history_tail,
            user_message=user_for_prompt,
            attachment_note=note_for_prompt,
        )
    if text_append:
        # Same handling for both bot modes, so it is applied once here.
        user_prompt += "\n\n【附件正文】\n" + _clip_text(text_append, attachment_body_max_chars)

    reply = generate_text(
        system=system_prompt,
        user=user_prompt,
        temperature=CHAT_TEMPERATURE,
        images_b64=images or None,
        max_tokens=CHAT_MAX_OUTPUT_TOKENS,
        max_continuations=CHAT_MAX_CONTINUATIONS,
    )
    if _is_ai_error_reply(reply):
        # Nothing is stored for a failed generation.
        return {"ok": False, "msg": reply, "session_id": sid}

    append_chat_message(
        sid,
        "user",
        user_visible,
        attachments=parsed.get("attachment_meta") or [],
    )
    session = append_chat_message(sid, "assistant", reply)

    summary_kwargs = {
        "session_id": sid,
        "prior_summary": prior_rolling,
        "user_text": user_visible,
        "assistant_text": reply,
        "bot_mode": bot_mode,
    }

    def _refresh_summary_bg() -> None:
        """Refresh the rolling summary; failures are logged, never raised."""
        try:
            refresh_session_rolling_summary(**summary_kwargs)
        except Exception:
            # Best-effort by design, but leave a trace so that a summary that
            # silently stops updating can be diagnosed.
            logging.getLogger(__name__).exception(
                "rolling summary refresh failed for session %s", sid
            )

    threading.Thread(target=_refresh_summary_bg, daemon=True).start()

    session = get_active_session() or session
    return {
        "ok": True,
        "trading_day": day,
        "session": session,
        "sessions": list_chat_sessions(),
        "reply": reply,
        "model": model_label(),
        "attachment_warnings": errors,
    }