467d160f4d
Cache chat context, parallelize exchange fetches, skip fund history writes, defer rolling summary to a background thread, and cache markdown rendering on the client. Co-authored-by: Cursor <cursoragent@cursor.com>
276 lines
9.2 KiB
Python
276 lines
9.2 KiB
Python
"""中控 AI:单会话聊天(直到用户点击新开)。"""
|
||
from __future__ import annotations
|
||
|
||
import threading
|
||
from typing import Any, Optional
|
||
|
||
from hub_ai.attachments import parse_chat_attachments
|
||
from hub_ai.client import generate_text, model_label
|
||
from hub_ai.config import (
|
||
CHAT_CONTEXT_MAX_CHARS,
|
||
CHAT_FOLLOWUP_CONTEXT_MAX_CHARS,
|
||
CHAT_HISTORY_MAX_CHARS_PER_MSG,
|
||
CHAT_MAX_CONTINUATIONS,
|
||
CHAT_MAX_HISTORY_TURNS,
|
||
CHAT_MAX_OUTPUT_TOKENS,
|
||
CHAT_PROMPT_MAX_CHARS,
|
||
CHAT_SUMMARY_EXCERPT_MAX_CHARS,
|
||
CHAT_TEMPERATURE,
|
||
CHAT_USER_MESSAGE_MAX_CHARS,
|
||
trading_day_reset_hour,
|
||
)
|
||
from hub_trades_lib import current_trading_day
|
||
from hub_ai.context import (
|
||
build_chat_context,
|
||
format_chat_context_for_chat,
|
||
format_chat_position_overview,
|
||
)
|
||
from hub_ai.prompts import (
|
||
CHAT_GENERAL_SYSTEM,
|
||
CHAT_SYSTEM,
|
||
build_chat_user_prompt,
|
||
build_general_chat_user_prompt,
|
||
)
|
||
from hub_ai.rolling_summary import refresh_session_rolling_summary
|
||
from hub_ai.store import (
|
||
CHAT_BOT_GENERAL,
|
||
CHAT_BOT_TRADING,
|
||
append_chat_message,
|
||
create_new_session,
|
||
delete_chat_session,
|
||
ensure_active_session,
|
||
get_active_session,
|
||
list_chat_sessions,
|
||
load_chat_store,
|
||
set_active_session,
|
||
summary_excerpt_for_chat,
|
||
)
|
||
from hub_ai.text_util import clip_text, is_ai_error_reply
|
||
|
||
|
||
def _is_ai_error_reply(text: str) -> bool:
    """Return the verdict of ``is_ai_error_reply`` for *text*.

    Module-private alias so the rest of this file can call the shared
    text utility under a local name.
    """
    verdict = is_ai_error_reply(text)
    return verdict
|
||
|
||
|
||
def _clip_text(text: str, max_chars: int) -> str:
    """Return ``clip_text(text, max_chars)``.

    Module-private alias for the shared clipping utility; the clipping
    rule itself lives in ``hub_ai.text_util``.
    """
    clipped = clip_text(text, max_chars)
    return clipped
|
||
|
||
|
||
def _history_lines(
    messages: list[dict],
    max_turns: int = CHAT_MAX_HISTORY_TURNS,
    *,
    max_chars_per_msg: int = CHAT_HISTORY_MAX_CHARS_PER_MSG,
    total_max_chars: int | None = None,
) -> str:
    """Render the most recent chat messages as newline-separated prompt lines.

    Only ``user`` / ``assistant`` messages are considered. The last
    ``max_turns * 2`` of them are kept, assistant messages recognised as AI
    error replies are skipped, up to three attachment names are appended to
    the content, and each message is clipped to ``max_chars_per_msg``.

    Args:
        messages: Stored chat messages (dicts with ``role``, ``content`` and
            optionally ``attachments``). ``None`` or empty yields ``""``.
        max_turns: Number of recent turns to keep (one turn is counted as
            two messages). Zero or negative yields ``""``.
        max_chars_per_msg: Per-message clip length passed to ``_clip_text``.
        total_max_chars: Optional cap on the length of the joined result;
            the oldest lines are dropped until the result fits. ``None`` or
            a non-positive value disables the cap.

    Returns:
        The rendered history, one ``role:content`` line per message.
    """
    if max_turns <= 0:
        # ``rows[-0:]`` is the WHOLE list, so a zero limit must short-circuit
        # instead of falling through to the slice below.
        return ""

    rows = [m for m in (messages or []) if m.get("role") in ("user", "assistant")]
    rows = rows[-max_turns * 2 :]

    lines: list[str] = []
    for m in rows:
        is_user = m.get("role") == "user"
        content = str(m.get("content") or "").strip()
        if not is_user and _is_ai_error_reply(content):
            # Failed assistant replies carry no useful context for the model.
            continue
        attachments = m.get("attachments") or []
        if attachments:
            names = "、".join(str(a.get("name") or "附件") for a in attachments[:3])
            content = f"{content} [附件: {names}]".strip()
        content = _clip_text(content, max_chars_per_msg)
        if content:
            role = "用户" if is_user else "搭档"
            lines.append(f"{role}:{content}")

    if total_max_chars and total_max_chars > 0:
        # Track the joined length incrementally rather than re-joining the
        # list after every removal: joined length = sum of lines + separators.
        joined_len = sum(len(line) for line in lines) + max(len(lines) - 1, 0)
        drop = 0
        while drop < len(lines) and joined_len > total_max_chars:
            joined_len -= len(lines[drop]) + 1  # the line plus its "\n"
            drop += 1
        lines = lines[drop:]

    return "\n".join(lines)
|
||
|
||
|
||
def _trading_context_bundle(ctx: dict[str, Any], *, prior_count: int) -> tuple[str, str]:
    """Build the ``(context_text, summary_excerpt)`` pair for a trading prompt.

    When the session has no prior messages (``prior_count <= 0``) the full
    formatted context and the day's summary excerpt are returned. For
    follow-up messages only the position overview plus a one-line totals
    snapshot is returned (clipped to ``CHAT_FOLLOWUP_CONTEXT_MAX_CHARS``)
    and the excerpt is empty.
    """
    totals = ctx.get("totals") or {}
    day = str(ctx.get("trading_day") or totals.get("trading_day") or "")

    if prior_count > 0:
        overview = format_chat_position_overview(ctx)
        closed_pnl = totals.get("total_pnl_u")
        closed_count = totals.get("closed_count")
        open_count = totals.get("open_position_count", 0)
        float_pnl = totals.get("float_pnl_u")
        snapshot = (
            f"【续聊快照 {day}】平仓盈亏 {closed_pnl}U | "
            f"笔数 {closed_count} | "
            f"持仓 {open_count} 仓 | "
            f"浮盈亏 {float_pnl}U"
        )
        followup = _clip_text(overview + "\n" + snapshot, CHAT_FOLLOWUP_CONTEXT_MAX_CHARS)
        return followup, ""

    # First message of the session: send the full context and daily excerpt.
    full_context = format_chat_context_for_chat(ctx, max_chars=CHAT_CONTEXT_MAX_CHARS)
    day_excerpt = summary_excerpt_for_chat(day, max_chars=CHAT_SUMMARY_EXCERPT_MAX_CHARS)
    return full_context, day_excerpt
|
||
|
||
|
||
def _history_budget(*sizes: int) -> int:
    """Return the character budget left for history in the chat prompt.

    The given sizes (falsy values count as 0) plus a fixed 2200-character
    allowance are subtracted from ``CHAT_PROMPT_MAX_CHARS``; the result is
    never lower than 1200.
    """
    consumed = 2200
    for size in sizes:
        consumed += int(size or 0)
    remaining = CHAT_PROMPT_MAX_CHARS - consumed
    return remaining if remaining > 1200 else 1200
|
||
|
||
|
||
def _prompt_memory(session: dict, prior_msgs: list[dict]) -> tuple[str, str]:
    """Choose the conversational memory to include in a follow-up prompt.

    The session's rolling summary is preferred. Sessions without one fall
    back to only the most recent turn of ``prior_msgs``.

    Returns:
        ``(rolling_summary, history_tail)``; at most one element is non-empty.
    """
    summary = str(session.get("rolling_summary") or "").strip()
    if summary:
        return summary, ""

    has_prior = any(m.get("role") in ("user", "assistant") for m in prior_msgs)
    if not has_prior:
        return "", ""

    last_turn = _history_lines(
        prior_msgs,
        max_turns=1,
        max_chars_per_msg=CHAT_HISTORY_MAX_CHARS_PER_MSG,
    )
    return "", last_turn
|
||
|
||
|
||
def get_chat_state() -> dict[str, Any]:
    """Return the current chat state: active session, session list and model.

    If an active session exists, its ``bot_mode`` and ``rolling_summary``
    keys are filled with defaults when missing.
    """
    store = load_chat_store()
    active = get_active_session()
    if active:
        for key, fallback in (("bot_mode", CHAT_BOT_TRADING), ("rolling_summary", "")):
            active.setdefault(key, fallback)

    state: dict[str, Any] = {
        "active_session_id": store.get("active_session_id"),
        "session": active,
    }
    state["sessions"] = list_chat_sessions()
    state["model"] = model_label()
    return state
|
||
|
||
|
||
def start_new_chat(*, trading_day: str, bot_mode: str = CHAT_BOT_TRADING) -> dict:
    """Create a new chat session and return it with the refreshed session list.

    Returns:
        A dict with ``ok`` (always ``True``), the new ``session``, the
        ``sessions`` list and the ``model`` label.
    """
    created = create_new_session(trading_day=trading_day, bot_mode=bot_mode)
    response = {"ok": True, "session": created}
    response["sessions"] = list_chat_sessions()
    response["model"] = model_label()
    return response
|
||
|
||
|
||
def switch_chat_session(session_id: str) -> dict[str, Any]:
    """Make ``session_id`` the active session and return the updated state.

    Returns:
        A dict with ``ok`` (always ``True``), the value returned by
        ``set_active_session`` as ``session``, the ``sessions`` list and
        the ``model`` label.
    """
    activated = set_active_session(session_id)
    response: dict[str, Any] = {"ok": True, "session": activated}
    response["sessions"] = list_chat_sessions()
    response["model"] = model_label()
    return response
|
||
|
||
|
||
def remove_chat_session(session_id: str) -> dict[str, Any]:
    """Delete a chat session and report which session is active afterwards.

    Returns:
        ``{"ok": False, "msg": "session_not_found"}`` when nothing was
        deleted; otherwise a dict with ``ok``, ``active_session_id``, the
        current active ``session``, the ``sessions`` list and the ``model``
        label.
    """
    was_deleted, next_active_id = delete_chat_session(session_id)
    if was_deleted:
        current = get_active_session()
        return {
            "ok": True,
            "active_session_id": next_active_id,
            "session": current,
            "sessions": list_chat_sessions(),
            "model": model_label(),
        }
    return {"ok": False, "msg": "session_not_found"}
|
||
|
||
|
||
def send_chat_message(
    exchanges: list[dict],
    message: str,
    *,
    trading_day: str | None = None,
    raw_attachments: Optional[list[dict]] = None,
) -> dict[str, Any]:
    """Send one user message to the active chat session and store the reply.

    Flow: parse attachments, validate that there is something to send,
    resolve the trading day and active session, build the prompt for the
    session's bot mode, call the model, persist the user and assistant
    messages, and start a background refresh of the rolling summary.

    Args:
        exchanges: Exchange descriptors passed to ``build_chat_context``
            (used only for the trading bot mode).
        message: Raw user text; surrounding whitespace is stripped.
        trading_day: Optional day string; only its first 10 characters are
            used. When empty, ``current_trading_day`` decides.
        raw_attachments: Optional raw attachment payloads for
            ``parse_chat_attachments``.

    Returns:
        On success a dict with ``ok=True``, ``trading_day``, ``session``,
        ``sessions``, ``reply``, ``model`` and ``attachment_warnings``.
        On failure ``{"ok": False, "msg": ...}``; when the model reply is an
        error reply the dict also carries ``session_id`` and nothing is
        persisted to the session.
    """
    # Function-scope import so the new logging dependency travels with this
    # block and does not require touching the module's import section.
    import logging

    # Clip length applied to attachment text appended to the prompt.
    attachment_text_max_chars = 3000

    text = (message or "").strip()
    parsed = parse_chat_attachments(raw_attachments or [])
    if parsed.get("errors") and not text and not parsed.get("images_b64"):
        return {"ok": False, "msg": ";".join(parsed["errors"])}
    if not text and not parsed.get("images_b64") and not parsed.get("text_append"):
        return {"ok": False, "msg": "消息不能为空"}

    # What is stored and shown as the user's message: typed text plus any
    # attachment text, or a placeholder note for attachment-only messages.
    user_visible = text
    if parsed.get("text_append"):
        user_visible = (user_visible + "\n\n" + parsed["text_append"]).strip()
    if not user_visible and parsed.get("attachment_note"):
        user_visible = f"(上传了 {parsed['attachment_note']})"

    day = (trading_day or "").strip()[:10] or current_trading_day(
        reset_hour=trading_day_reset_hour()
    )
    session = ensure_active_session(trading_day=day)
    sid = session["id"]
    # Capture the pre-reply state before new messages are appended.
    prior_rolling = str(session.get("rolling_summary") or "")
    prior_msgs = session.get("messages") or []
    prior_count = len([m for m in prior_msgs if m.get("role") in ("user", "assistant")])
    user_for_prompt = _clip_text(text or user_visible, CHAT_USER_MESSAGE_MAX_CHARS)
    rolling_summary, history_tail = _prompt_memory(session, prior_msgs)
    attachment_note = str(parsed.get("attachment_note") or "")

    bot_mode = (session.get("bot_mode") or CHAT_BOT_TRADING).strip().lower()
    if bot_mode == CHAT_BOT_GENERAL:
        user_prompt = build_general_chat_user_prompt(
            rolling_summary=rolling_summary,
            history_lines=history_tail,
            user_message=user_for_prompt,
            attachment_note=attachment_note,
        )
        system_prompt = CHAT_GENERAL_SYSTEM
    else:
        ctx = build_chat_context(exchanges, trading_day=day)
        # The context builder's trading day is authoritative from here on.
        day = ctx["trading_day"]
        brief_ctx, excerpt = _trading_context_bundle(ctx, prior_count=prior_count)
        user_prompt = build_chat_user_prompt(
            context_text=brief_ctx,
            trading_day=day,
            summary_excerpt=excerpt,
            rolling_summary=rolling_summary,
            history_lines=history_tail,
            user_message=user_for_prompt,
            attachment_note=attachment_note,
        )
        system_prompt = CHAT_SYSTEM

    # Attachment text is appended identically for both bot modes.
    if parsed.get("text_append"):
        user_prompt += "\n\n【附件正文】\n" + _clip_text(
            parsed["text_append"], attachment_text_max_chars
        )

    reply = generate_text(
        system=system_prompt,
        user=user_prompt,
        temperature=CHAT_TEMPERATURE,
        images_b64=parsed.get("images_b64") or None,
        max_tokens=CHAT_MAX_OUTPUT_TOKENS,
        max_continuations=CHAT_MAX_CONTINUATIONS,
    )
    if _is_ai_error_reply(reply):
        # Return before persisting so a failed exchange is not stored.
        return {"ok": False, "msg": reply, "session_id": sid}

    append_chat_message(
        sid,
        "user",
        user_visible,
        attachments=parsed.get("attachment_meta") or [],
    )
    session = append_chat_message(sid, "assistant", reply)
    summary_kwargs = {
        "session_id": sid,
        "prior_summary": prior_rolling,
        "user_text": user_visible,
        "assistant_text": reply,
        "bot_mode": bot_mode,
    }

    def _refresh_summary_bg() -> None:
        """Refresh the rolling summary; failures are logged, never raised."""
        try:
            refresh_session_rolling_summary(**summary_kwargs)
        except Exception:
            # Still best-effort, but no longer silent: record the traceback
            # so a broken summary refresh can be diagnosed.
            logging.getLogger(__name__).exception(
                "rolling summary refresh failed for session %s", sid
            )

    # The summary refresh runs off the request path; the response does not
    # wait for it.
    threading.Thread(target=_refresh_summary_bg, daemon=True).start()
    session = get_active_session() or session
    return {
        "ok": True,
        "trading_day": day,
        "session": session,
        "sessions": list_chat_sessions(),
        "reply": reply,
        "model": model_label(),
        "attachment_warnings": parsed.get("errors") or [],
    }
|