Files
crypto_monitor/manual_trading_hub/hub_ai/rolling_summary.py
T

70 lines
2.4 KiB
Python

"""聊天滚动摘要:每轮后压缩历史,续聊只带摘要 + 当前消息。"""
from __future__ import annotations
from hub_ai.text_util import clip_text, is_ai_error_reply
from hub_ai.client import generate_text
from hub_ai.config import (
CHAT_ROLLING_SUMMARY_GEN_MAX_TOKENS,
CHAT_ROLLING_SUMMARY_MAX_CHARS,
CHAT_ROLLING_SUMMARY_TEMPERATURE,
)
from hub_ai.prompts import (
ROLLING_SUMMARY_GENERAL_SYSTEM,
ROLLING_SUMMARY_TRADING_SYSTEM,
build_rolling_summary_user_prompt,
)
from hub_ai.store import CHAT_BOT_GENERAL, update_session_rolling_summary
def refresh_session_rolling_summary(
    session_id: str,
    *,
    prior_summary: str,
    user_text: str,
    assistant_text: str,
    bot_mode: str,
) -> str:
    """Merge the previous summary with this turn into a new short summary and persist it.

    The new summary is always written to the session via
    ``update_session_rolling_summary`` and also returned.

    Args:
        session_id: Session whose rolling summary is updated.
        prior_summary: Summary produced after earlier turns (may be empty).
        user_text: This turn's user message; clipped to 1200 chars before use.
        assistant_text: This turn's assistant reply; clipped to 1800 chars.
        bot_mode: Chat bot mode. A value equal to ``CHAT_BOT_GENERAL`` after
            ``strip().lower()`` selects the general summary prompt; anything
            else (including empty/None) selects the trading prompt.

    Returns:
        The summary that was stored for the session.
    """
    # Normalise so downstream helpers (e.g. _fallback_summary's .strip()) never see None.
    prior_summary = prior_summary or ""
    user_clip = clip_text(user_text, 1200)
    assistant_clip = clip_text(assistant_text, 1800)

    if not user_clip and not assistant_clip:
        # Nothing new this turn: keep the prior summary (re-clipped) without a model call.
        summary = clip_text(prior_summary, CHAT_ROLLING_SUMMARY_MAX_CHARS)
    else:
        is_general = (bot_mode or "").strip().lower() == CHAT_BOT_GENERAL
        system = ROLLING_SUMMARY_GENERAL_SYSTEM if is_general else ROLLING_SUMMARY_TRADING_SYSTEM
        raw = generate_text(
            system=system,
            user=build_rolling_summary_user_prompt(
                prior_summary=prior_summary,
                user_text=user_clip,
                assistant_text=assistant_clip,
            ),
            temperature=CHAT_ROLLING_SUMMARY_TEMPERATURE,
            max_tokens=CHAT_ROLLING_SUMMARY_GEN_MAX_TOKENS,
            max_continuations=1,
        )
        if not (raw or "").strip() or is_ai_error_reply(raw):
            # An error reply or blank output must not overwrite the stored summary
            # (a blank one would erase all accumulated history); merge by rule instead.
            summary = _fallback_summary(prior_summary, user_clip, assistant_clip)
        else:
            summary = clip_text(raw, CHAT_ROLLING_SUMMARY_MAX_CHARS)

    update_session_rolling_summary(session_id, summary)
    return summary
def _fallback_summary(prior: str, user_text: str, assistant_text: str) -> str:
    """Build a rule-based rolling summary when the model call fails.

    The result is the prior summary followed by a short excerpt of this turn,
    one entry per line:

    - the user text, clipped to 200 chars, prefixed ``用户:``;
    - the assistant text, clipped to 280 chars, prefixed ``教练:``.

    The whole result is clipped to ``CHAT_ROLLING_SUMMARY_MAX_CHARS``.

    The prior summary is clipped first, to whatever budget remains after the
    current turn. Without this, a prior summary that already fills the cap
    would cause the final clip to drop the newest turn entirely.
    NOTE(review): this assumes ``clip_text`` keeps the head of the text;
    confirm in ``hub_ai.text_util``.

    Args:
        prior: Existing rolling summary (``None`` is treated as empty).
        user_text: This turn's user message (``None`` is treated as empty).
        assistant_text: This turn's assistant reply (``None`` is treated as empty).

    Returns:
        The merged summary, clipped to ``CHAT_ROLLING_SUMMARY_MAX_CHARS``.
    """
    turn_lines: list[str] = []
    if (user_text or "").strip():
        turn_lines.append(f"用户:{clip_text(user_text, 200)}")
    if (assistant_text or "").strip():
        turn_lines.append(f"教练:{clip_text(assistant_text, 280)}")
    turn = "\n".join(turn_lines)

    parts: list[str] = []
    prior_clean = (prior or "").strip()
    if prior_clean:
        # Reserve room for the current turn plus its newline separator.
        budget = CHAT_ROLLING_SUMMARY_MAX_CHARS - len(turn) - (1 if turn else 0)
        if budget > 0:
            clipped_prior = clip_text(prior_clean, budget)
            if clipped_prior:
                parts.append(clipped_prior)
    if turn:
        parts.append(turn)
    # Final clip is a safety net in case clip_text adds a marker beyond the budget.
    return clip_text("\n".join(parts), CHAT_ROLLING_SUMMARY_MAX_CHARS)