6a1f2608b5
Co-authored-by: Cursor <cursoragent@cursor.com>
70 lines
2.4 KiB
Python
70 lines
2.4 KiB
Python
"""聊天滚动摘要:每轮后压缩历史,续聊只带摘要 + 当前消息。"""
|
|
from __future__ import annotations
|
|
|
|
from hub_ai.text_util import clip_text, is_ai_error_reply
|
|
from hub_ai.client import generate_text
|
|
from hub_ai.config import (
|
|
CHAT_ROLLING_SUMMARY_GEN_MAX_TOKENS,
|
|
CHAT_ROLLING_SUMMARY_MAX_CHARS,
|
|
CHAT_ROLLING_SUMMARY_TEMPERATURE,
|
|
)
|
|
from hub_ai.prompts import (
|
|
ROLLING_SUMMARY_GENERAL_SYSTEM,
|
|
ROLLING_SUMMARY_TRADING_SYSTEM,
|
|
build_rolling_summary_user_prompt,
|
|
)
|
|
from hub_ai.store import CHAT_BOT_GENERAL, update_session_rolling_summary
|
|
|
|
|
|
def refresh_session_rolling_summary(
    session_id: str,
    *,
    prior_summary: str,
    user_text: str,
    assistant_text: str,
    bot_mode: str,
) -> str:
    """Merge the previous summary with the latest turn and persist the result.

    The turn's user message is clipped to 1200 chars and the assistant reply
    to 1800 chars. If both clip to nothing, the prior summary is kept
    (re-clipped to ``CHAT_ROLLING_SUMMARY_MAX_CHARS``) without calling the
    model.

    Otherwise ``generate_text`` is called. The system prompt is the general
    one when ``bot_mode`` matches ``CHAT_BOT_GENERAL`` (after strip/lower),
    and the trading one otherwise. The reply is clipped to
    ``CHAT_ROLLING_SUMMARY_MAX_CHARS``.

    If the reply is an error reply (per ``is_ai_error_reply``) or is
    empty/blank, ``_fallback_summary`` is used instead. This ensures a
    failed or empty generation never overwrites the accumulated summary
    with an empty string.

    Args:
        session_id: Session whose rolling summary is updated.
        prior_summary: Summary stored before this turn (``None`` is treated
            as empty).
        user_text: The user's message for this turn.
        assistant_text: The assistant's reply for this turn.
        bot_mode: Chat bot mode, used only to pick the system prompt.

    Returns:
        The summary that was written via ``update_session_rolling_summary``.
    """
    prior = prior_summary or ""
    user_clip = clip_text(user_text, 1200)
    assistant_clip = clip_text(assistant_text, 1800)

    if not user_clip and not assistant_clip:
        # Nothing new this turn: keep the prior summary, bounded by the cap.
        summary = clip_text(prior, CHAT_ROLLING_SUMMARY_MAX_CHARS)
    else:
        system = (
            ROLLING_SUMMARY_GENERAL_SYSTEM
            if (bot_mode or "").strip().lower() == CHAT_BOT_GENERAL
            else ROLLING_SUMMARY_TRADING_SYSTEM
        )
        raw = generate_text(
            system=system,
            user=build_rolling_summary_user_prompt(
                prior_summary=prior,
                user_text=user_clip,
                assistant_text=assistant_clip,
            ),
            temperature=CHAT_ROLLING_SUMMARY_TEMPERATURE,
            max_tokens=CHAT_ROLLING_SUMMARY_GEN_MAX_TOKENS,
            max_continuations=1,
        )
        text = (raw or "").strip()
        # A blank reply is as useless as an error reply; persisting it would
        # erase the prior summary, so fall back to the local summary instead.
        if not text or is_ai_error_reply(raw):
            summary = _fallback_summary(prior, user_clip, assistant_clip)
        else:
            summary = clip_text(text, CHAT_ROLLING_SUMMARY_MAX_CHARS)

    update_session_rolling_summary(session_id, summary)
    return summary
|
|
|
|
|
|
def _fallback_summary(prior: str, user_text: str, assistant_text: str) -> str:
    """Build a summary locally, without the model, from the prior summary and this turn.

    The result is the stripped ``prior``, then a ``用户:`` (user) line with the
    user text clipped to 200 chars, then a ``教练:`` (coach) line with the
    assistant text clipped to 280 chars. Lines are joined by newlines and
    blank inputs are skipped. The result is capped at
    ``CHAT_ROLLING_SUMMARY_MAX_CHARS``.

    The prior summary is trimmed first so the newest turn fits under the
    cap. The prior is typically a previous summary already clipped to the
    same cap. Clipping only the joined text would therefore drop the newest
    turn, which is exactly the content this fallback exists to keep.
    """
    turn_parts: list[str] = []
    if user_text.strip():
        turn_parts.append(f"用户:{clip_text(user_text, 200)}")
    if assistant_text.strip():
        turn_parts.append(f"教练:{clip_text(assistant_text, 280)}")
    turn = "\n".join(turn_parts)

    prior_text = (prior or "").strip()
    if prior_text and turn:
        # Reserve len(turn) + 1 (newline separator) for the newest turn.
        budget = CHAT_ROLLING_SUMMARY_MAX_CHARS - len(turn) - 1
        if len(prior_text) > budget:
            prior_text = clip_text(prior_text, budget) if budget > 0 else ""

    parts = [part for part in (prior_text, turn) if part]
    # Final clip still bounds the result, e.g. if clip_text appends a marker.
    return clip_text("\n".join(parts), CHAT_ROLLING_SUMMARY_MAX_CHARS)
|