"""中控 AI:单会话聊天(直到用户点击新开)。""" from __future__ import annotations import threading from typing import Any, Optional from hub_ai.attachments import parse_chat_attachments from hub_ai.client import generate_text, model_label from hub_ai.config import ( CHAT_CONTEXT_MAX_CHARS, CHAT_FOLLOWUP_CONTEXT_MAX_CHARS, CHAT_HISTORY_MAX_CHARS_PER_MSG, CHAT_MAX_CONTINUATIONS, CHAT_MAX_HISTORY_TURNS, CHAT_MAX_OUTPUT_TOKENS, CHAT_PROMPT_MAX_CHARS, CHAT_SUMMARY_EXCERPT_MAX_CHARS, CHAT_TEMPERATURE, CHAT_USER_MESSAGE_MAX_CHARS, trading_day_reset_hour, ) from hub_trades_lib import current_trading_day from hub_ai.context import ( build_chat_context, format_chat_context_for_chat, format_chat_position_overview, ) from hub_ai.prompts import ( CHAT_GENERAL_SYSTEM, CHAT_SYSTEM, build_chat_user_prompt, build_general_chat_user_prompt, ) from hub_ai.rolling_summary import refresh_session_rolling_summary from hub_ai.store import ( CHAT_BOT_GENERAL, CHAT_BOT_TRADING, append_chat_message, create_new_session, delete_chat_session, ensure_active_session, get_active_session, list_chat_sessions, load_chat_store, set_active_session, summary_excerpt_for_chat, ) from hub_ai.text_util import clip_text, is_ai_error_reply def _is_ai_error_reply(text: str) -> bool: return is_ai_error_reply(text) def _clip_text(text: str, max_chars: int) -> str: return clip_text(text, max_chars) def _history_lines( messages: list[dict], max_turns: int = CHAT_MAX_HISTORY_TURNS, *, max_chars_per_msg: int = CHAT_HISTORY_MAX_CHARS_PER_MSG, total_max_chars: int | None = None, ) -> str: rows = [m for m in (messages or []) if m.get("role") in ("user", "assistant")] rows = rows[-max_turns * 2 :] lines = [] for m in rows: role = "用户" if m.get("role") == "user" else "搭档" content = str(m.get("content") or "").strip() if m.get("role") == "assistant" and _is_ai_error_reply(content): continue att = m.get("attachments") or [] if att: names = "、".join(str(a.get("name") or "附件") for a in att[:3]) content = f"{content} [附件: {names}]".strip() content = _clip_text(content, max_chars_per_msg) if content: lines.append(f"{role}:{content}") if total_max_chars and total_max_chars > 0: while lines and len("\n".join(lines)) > total_max_chars: lines.pop(0) return "\n".join(lines) def _trading_context_bundle(ctx: dict[str, Any], *, prior_count: int) -> tuple[str, str]: day = str(ctx.get("trading_day") or (ctx.get("totals") or {}).get("trading_day") or "") if prior_count <= 0: brief = format_chat_context_for_chat(ctx, max_chars=CHAT_CONTEXT_MAX_CHARS) excerpt = summary_excerpt_for_chat(day, max_chars=CHAT_SUMMARY_EXCERPT_MAX_CHARS) return brief, excerpt totals = ctx.get("totals") or {} overview = format_chat_position_overview(ctx) slim = ( f"【续聊快照 {day}】平仓盈亏 {totals.get('total_pnl_u')}U | " f"笔数 {totals.get('closed_count')} | " f"持仓 {totals.get('open_position_count', 0)} 仓 | " f"浮盈亏 {totals.get('float_pnl_u')}U" ) brief = _clip_text(overview + "\n" + slim, CHAT_FOLLOWUP_CONTEXT_MAX_CHARS) return brief, "" def _history_budget(*sizes: int) -> int: used = sum(int(s or 0) for s in sizes) + 2200 return max(1200, CHAT_PROMPT_MAX_CHARS - used) def _prompt_memory(session: dict, prior_msgs: list[dict]) -> tuple[str, str]: """续聊优先用滚动摘要;旧会话无摘要时仅带最近 1 轮兜底。""" rolling = str(session.get("rolling_summary") or "").strip() if rolling: return rolling, "" prior_count = len([m for m in prior_msgs if m.get("role") in ("user", "assistant")]) if prior_count <= 0: return "", "" tail = _history_lines( prior_msgs, max_turns=1, max_chars_per_msg=CHAT_HISTORY_MAX_CHARS_PER_MSG, ) return "", tail def get_chat_state() -> dict[str, Any]: store = load_chat_store() session = get_active_session() if session: session.setdefault("bot_mode", CHAT_BOT_TRADING) session.setdefault("rolling_summary", "") return { "active_session_id": store.get("active_session_id"), "session": session, "sessions": list_chat_sessions(), "model": model_label(), } def start_new_chat(*, trading_day: str, bot_mode: str = CHAT_BOT_TRADING) -> dict: session = create_new_session(trading_day=trading_day, bot_mode=bot_mode) return { "ok": True, "session": session, "sessions": list_chat_sessions(), "model": model_label(), } def switch_chat_session(session_id: str) -> dict[str, Any]: session = set_active_session(session_id) return { "ok": True, "session": session, "sessions": list_chat_sessions(), "model": model_label(), } def remove_chat_session(session_id: str) -> dict[str, Any]: deleted, new_active = delete_chat_session(session_id) if not deleted: return {"ok": False, "msg": "session_not_found"} session = get_active_session() return { "ok": True, "active_session_id": new_active, "session": session, "sessions": list_chat_sessions(), "model": model_label(), } def send_chat_message( exchanges: list[dict], message: str, *, trading_day: str | None = None, raw_attachments: Optional[list[dict]] = None, ) -> dict[str, Any]: text = (message or "").strip() parsed = parse_chat_attachments(raw_attachments or []) if parsed.get("errors") and not text and not parsed.get("images_b64"): return {"ok": False, "msg": ";".join(parsed["errors"])} if not text and not parsed.get("images_b64") and not parsed.get("text_append"): return {"ok": False, "msg": "消息不能为空"} user_visible = text if parsed.get("text_append"): user_visible = (user_visible + "\n\n" + parsed["text_append"]).strip() if not user_visible and parsed.get("attachment_note"): user_visible = f"(上传了 {parsed['attachment_note']})" day = (trading_day or "").strip()[:10] or current_trading_day( reset_hour=trading_day_reset_hour() ) session = ensure_active_session(trading_day=day) sid = session["id"] prior_rolling = str(session.get("rolling_summary") or "") prior_msgs = session.get("messages") or [] prior_count = len([m for m in prior_msgs if m.get("role") in ("user", "assistant")]) user_for_prompt = _clip_text(text or user_visible, CHAT_USER_MESSAGE_MAX_CHARS) rolling_summary, history_tail = _prompt_memory(session, prior_msgs) bot_mode = (session.get("bot_mode") or CHAT_BOT_TRADING).strip().lower() if bot_mode == CHAT_BOT_GENERAL: user_prompt = build_general_chat_user_prompt( rolling_summary=rolling_summary, history_lines=history_tail, user_message=user_for_prompt, attachment_note=str(parsed.get("attachment_note") or ""), ) if parsed.get("text_append"): user_prompt += "\n\n【附件正文】\n" + _clip_text(parsed["text_append"], 3000) system_prompt = CHAT_GENERAL_SYSTEM else: ctx = build_chat_context(exchanges, trading_day=day) day = ctx["trading_day"] brief_ctx, excerpt = _trading_context_bundle(ctx, prior_count=prior_count) user_prompt = build_chat_user_prompt( context_text=brief_ctx, trading_day=day, summary_excerpt=excerpt, rolling_summary=rolling_summary, history_lines=history_tail, user_message=user_for_prompt, attachment_note=str(parsed.get("attachment_note") or ""), ) if parsed.get("text_append"): user_prompt += "\n\n【附件正文】\n" + _clip_text(parsed["text_append"], 3000) system_prompt = CHAT_SYSTEM reply = generate_text( system=system_prompt, user=user_prompt, temperature=CHAT_TEMPERATURE, images_b64=parsed.get("images_b64") or None, max_tokens=CHAT_MAX_OUTPUT_TOKENS, max_continuations=CHAT_MAX_CONTINUATIONS, ) if _is_ai_error_reply(reply): return {"ok": False, "msg": reply, "session_id": sid} append_chat_message( sid, "user", user_visible, attachments=parsed.get("attachment_meta") or [], ) session = append_chat_message(sid, "assistant", reply) summary_kwargs = { "session_id": sid, "prior_summary": prior_rolling, "user_text": user_visible, "assistant_text": reply, "bot_mode": bot_mode, } def _refresh_summary_bg() -> None: try: refresh_session_rolling_summary(**summary_kwargs) except Exception: pass threading.Thread(target=_refresh_summary_bg, daemon=True).start() session = get_active_session() or session return { "ok": True, "trading_day": day, "session": session, "sessions": list_chat_sessions(), "reply": reply, "model": model_label(), "attachment_warnings": parsed.get("errors") or [], }