"""交易监管:AI 评语与用户回聊。""" from __future__ import annotations import sys from pathlib import Path from typing import Any, Optional _REPO_ROOT = Path(__file__).resolve().parents[2] if str(_REPO_ROOT) not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) from ai_client import ai_generate # noqa: E402 from hub_ai.client import generate_text, model_label from hub_ai.config import ( CHAT_MAX_OUTPUT_TOKENS, CHAT_TEMPERATURE, trading_day_reset_hour, ) from hub_ai.context import build_chat_context, format_chat_context_for_chat, format_chat_position_overview from hub_ai.prompts import SUPERVISOR_SYSTEM, build_supervisor_ai_prompt, build_supervisor_chat_prompt from hub_ai.supervisor_store import ( append_supervisor_ai_message, ensure_supervisor_session, get_supervisor_session_state, ) from hub_ai.store import append_chat_message from hub_ai.text_util import is_ai_error_reply from hub_supervisor_lib import build_supervisor_fallback_reply from hub_trades_lib import current_trading_day SUPERVISOR_AI_MAX_TOKENS = 320 def generate_supervisor_ai_reply( *, event: dict, warnings: list[dict], trading_day: str, session_id: str, exchanges: list[dict], ) -> str: ctx = build_chat_context(exchanges, trading_day=trading_day) brief = format_chat_position_overview(ctx) + "\n" + format_chat_context_for_chat( ctx, max_chars=2400 ) user_prompt = build_supervisor_ai_prompt( context_text=brief, trading_day=trading_day, event=event, warnings=warnings, ) prompt = f"{SUPERVISOR_SYSTEM.strip()}\n\n---\n\n{user_prompt.strip()}" text = ai_generate(prompt, temperature=0.35, max_tokens=SUPERVISOR_AI_MAX_TOKENS) text = str(text or "").strip() if not text or is_ai_error_reply(text): return build_supervisor_fallback_reply(event, warnings) return text def make_supervisor_ai_reply_fn(exchanges: list[dict]): def _fn(*, event: dict, warnings: list[dict], trading_day: str, session_id: str) -> str: return generate_supervisor_ai_reply( event=event, warnings=warnings or [], trading_day=trading_day, session_id=session_id, exchanges=exchanges, ) return _fn def send_supervisor_chat( exchanges: list[dict], message: str, *, trading_day: str | None = None, ) -> dict[str, Any]: text = (message or "").strip() if not text: return {"ok": False, "msg": "消息不能为空"} day = (trading_day or "").strip()[:10] or current_trading_day( reset_hour=trading_day_reset_hour() ) session = ensure_supervisor_session(day) sid = str(session.get("id") or "") prior = session.get("messages") or [] ctx = build_chat_context(exchanges, trading_day=day) brief = format_chat_context_for_chat(ctx, max_chars=6000) recent = [] for m in prior[-8:]: role = m.get("role") if role not in ("user", "assistant", "system"): continue label = {"user": "用户", "assistant": "监管", "system": "系统"}.get(role, role) recent.append(f"{label}:{str(m.get('content') or '').strip()}") user_prompt = build_supervisor_chat_prompt( context_text=brief, trading_day=day, history_lines="\n".join(recent), user_message=text, ) reply = generate_text( system=SUPERVISOR_SYSTEM, user=user_prompt, temperature=min(0.4, CHAT_TEMPERATURE), max_tokens=min(768, CHAT_MAX_OUTPUT_TOKENS), max_continuations=1, ) reply = str(reply or "").strip() if not reply or is_ai_error_reply(reply): return {"ok": False, "msg": "AI 暂时不可用,请稍后再试", "session_id": sid} append_chat_message(sid, "user", text) session = append_supervisor_ai_message(sid, reply) state = get_supervisor_session_state(day) return { "ok": True, "trading_day": day, "session": session, "reply": reply, "model": model_label(), "message_count": state.get("message_count"), "unread_system": state.get("unread_system"), }