65901c5577
Skip appending AI error strings to the session and use event-specific fallback commentary when the model returns empty content. Co-authored-by: Cursor <cursoragent@cursor.com>
126 lines
4.1 KiB
Python
126 lines
4.1 KiB
Python
"""交易监管:AI 评语与用户回聊。"""
|
||
from __future__ import annotations
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import Any, Optional
|
||
|
||
# Import-time bootstrap: compute the repository root and put it at the front of
# ``sys.path``.
# ``parents[2]`` is the directory two levels above the directory containing
# this file (after resolving symlinks).
# Presumably top-level modules such as ``ai_client`` live there; that import
# comes after this code, which is why it carries ``# noqa: E402``.
# The membership check avoids inserting a duplicate entry when the module is
# imported more than once.
_REPO_ROOT = Path(__file__).resolve().parents[2]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
|
||
|
||
from ai_client import ai_generate # noqa: E402
|
||
|
||
from hub_ai.client import generate_text, model_label
|
||
from hub_ai.config import (
|
||
CHAT_MAX_OUTPUT_TOKENS,
|
||
CHAT_TEMPERATURE,
|
||
trading_day_reset_hour,
|
||
)
|
||
from hub_ai.context import build_chat_context, format_chat_context_for_chat, format_chat_position_overview
|
||
from hub_ai.prompts import SUPERVISOR_SYSTEM, build_supervisor_ai_prompt, build_supervisor_chat_prompt
|
||
from hub_ai.supervisor_store import (
|
||
append_supervisor_ai_message,
|
||
ensure_supervisor_session,
|
||
get_supervisor_session_state,
|
||
)
|
||
from hub_ai.store import append_chat_message
|
||
from hub_ai.text_util import is_ai_error_reply
|
||
from hub_supervisor_lib import build_supervisor_fallback_reply
|
||
from hub_trades_lib import current_trading_day
|
||
|
||
# Output-token cap for the one-shot supervisor commentary.
# It is passed as ``max_tokens`` to ``ai_generate`` in
# ``generate_supervisor_ai_reply``.
SUPERVISOR_AI_MAX_TOKENS = 320
|
||
|
||
|
||
def generate_supervisor_ai_reply(
    *,
    event: dict,
    warnings: list[dict],
    trading_day: str,
    session_id: str,
    exchanges: list[dict],
) -> str:
    """Generate short AI commentary for one supervisor event.

    A compact brief is built from ``exchanges`` for ``trading_day``: the
    position overview followed by the chat context, truncated to 2400 chars.
    That brief, ``event`` and ``warnings`` are turned into a user prompt.
    ``SUPERVISOR_SYSTEM`` is prepended, separated by a ``---`` divider, and
    the result is sent to ``ai_generate``.

    Args:
        event: The supervisor event to comment on. It is passed through to the
            prompt builder and to the fallback builder.
        warnings: Warning records attached to the event.
        trading_day: Trading day the chat context is built for.
        session_id: Supervisor session id. It is accepted for interface
            compatibility but is not used here.
        exchanges: Exchange records handed to ``build_chat_context``.

    Returns:
        The stripped model reply. If that reply is empty or is recognised by
        ``is_ai_error_reply``, returns
        ``build_supervisor_fallback_reply(event, warnings)`` instead.
    """
    context = build_chat_context(exchanges, trading_day=trading_day)
    overview = format_chat_position_overview(context)
    detail = format_chat_context_for_chat(context, max_chars=2400)

    supervisor_prompt = build_supervisor_ai_prompt(
        context_text=overview + "\n" + detail,
        trading_day=trading_day,
        event=event,
        warnings=warnings,
    )
    combined_prompt = "\n\n---\n\n".join(
        (SUPERVISOR_SYSTEM.strip(), supervisor_prompt.strip())
    )

    raw_answer = ai_generate(
        combined_prompt, temperature=0.35, max_tokens=SUPERVISOR_AI_MAX_TOKENS
    )
    answer = str(raw_answer or "").strip()

    # Only a non-empty, non-error reply is returned; anything else gets the
    # event-specific fallback commentary.
    if answer and not is_ai_error_reply(answer):
        return answer
    return build_supervisor_fallback_reply(event, warnings)
|
||
|
||
|
||
def make_supervisor_ai_reply_fn(exchanges: list[dict]):
    """Build a reply callback bound to ``exchanges``.

    The returned callable accepts keyword-only ``event``, ``warnings``,
    ``trading_day`` and ``session_id``. It forwards them, together with the
    captured ``exchanges``, to ``generate_supervisor_ai_reply`` and returns
    that function's string result. A falsy ``warnings`` value is replaced by
    an empty list before forwarding.
    """

    def _reply(*, event: dict, warnings: list[dict], trading_day: str, session_id: str) -> str:
        normalized_warnings = warnings if warnings else []
        return generate_supervisor_ai_reply(
            exchanges=exchanges,
            event=event,
            warnings=normalized_warnings,
            trading_day=trading_day,
            session_id=session_id,
        )

    return _reply
|
||
|
||
|
||
def send_supervisor_chat(
|
||
exchanges: list[dict],
|
||
message: str,
|
||
*,
|
||
trading_day: str | None = None,
|
||
) -> dict[str, Any]:
|
||
text = (message or "").strip()
|
||
if not text:
|
||
return {"ok": False, "msg": "消息不能为空"}
|
||
day = (trading_day or "").strip()[:10] or current_trading_day(
|
||
reset_hour=trading_day_reset_hour()
|
||
)
|
||
session = ensure_supervisor_session(day)
|
||
sid = str(session.get("id") or "")
|
||
prior = session.get("messages") or []
|
||
ctx = build_chat_context(exchanges, trading_day=day)
|
||
brief = format_chat_context_for_chat(ctx, max_chars=6000)
|
||
recent = []
|
||
for m in prior[-8:]:
|
||
role = m.get("role")
|
||
if role not in ("user", "assistant", "system"):
|
||
continue
|
||
label = {"user": "用户", "assistant": "监管", "system": "系统"}.get(role, role)
|
||
recent.append(f"{label}:{str(m.get('content') or '').strip()}")
|
||
user_prompt = build_supervisor_chat_prompt(
|
||
context_text=brief,
|
||
trading_day=day,
|
||
history_lines="\n".join(recent),
|
||
user_message=text,
|
||
)
|
||
reply = generate_text(
|
||
system=SUPERVISOR_SYSTEM,
|
||
user=user_prompt,
|
||
temperature=min(0.4, CHAT_TEMPERATURE),
|
||
max_tokens=min(768, CHAT_MAX_OUTPUT_TOKENS),
|
||
max_continuations=1,
|
||
)
|
||
reply = str(reply or "").strip()
|
||
if not reply or is_ai_error_reply(reply):
|
||
return {"ok": False, "msg": "AI 暂时不可用,请稍后再试", "session_id": sid}
|
||
append_chat_message(sid, "user", text)
|
||
session = append_supervisor_ai_message(sid, reply)
|
||
state = get_supervisor_session_state(day)
|
||
return {
|
||
"ok": True,
|
||
"trading_day": day,
|
||
"session": session,
|
||
"reply": reply,
|
||
"model": model_label(),
|
||
"message_count": state.get("message_count"),
|
||
"unread_system": state.get("unread_system"),
|
||
}
|