Files
crypto_monitor/manual_trading_hub/hub_ai/supervisor.py
T
dekun 65901c5577 Fix supervisor AI empty replies with fallback templates
Skip appending AI error strings to the session and use event-specific fallback commentary when the model returns empty content.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-23 20:10:33 +08:00

126 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""交易监管:AI 评语与用户回聊。"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Any, Optional
# Bootstrap: prepend the directory two levels above the one containing this
# file (Path(__file__).resolve().parents[2]) to sys.path so that the
# top-level ``from ai_client import ai_generate`` below can be resolved.
# This has to execute before those imports, which is why the first import
# after it carries ``# noqa: E402``.
# NOTE(review): assumes ai_client.py lives in that root directory — confirm.
_REPO_ROOT = Path(__file__).resolve().parents[2]
if str(_REPO_ROOT) not in sys.path:
    # Insert only once (guarded above) and at the front so it takes priority.
    sys.path.insert(0, str(_REPO_ROOT))
from ai_client import ai_generate # noqa: E402
from hub_ai.client import generate_text, model_label
from hub_ai.config import (
CHAT_MAX_OUTPUT_TOKENS,
CHAT_TEMPERATURE,
trading_day_reset_hour,
)
from hub_ai.context import build_chat_context, format_chat_context_for_chat, format_chat_position_overview
from hub_ai.prompts import SUPERVISOR_SYSTEM, build_supervisor_ai_prompt, build_supervisor_chat_prompt
from hub_ai.supervisor_store import (
append_supervisor_ai_message,
ensure_supervisor_session,
get_supervisor_session_state,
)
from hub_ai.store import append_chat_message
from hub_ai.text_util import is_ai_error_reply
from hub_supervisor_lib import build_supervisor_fallback_reply
from hub_trades_lib import current_trading_day
# Output-token cap passed to ``ai_generate`` for a single supervisor comment
# (generate_supervisor_ai_reply). Chat replies use their own, separate cap.
SUPERVISOR_AI_MAX_TOKENS: int = 320
def generate_supervisor_ai_reply(
    *,
    event: dict,
    warnings: list[dict],
    trading_day: str,
    session_id: str,
    exchanges: list[dict],
) -> str:
    """Produce the supervisor's AI commentary for one trading event.

    Builds a position overview plus a trimmed (2400-char) chat context for
    *trading_day*, sends the supervisor system prompt and the event prompt to
    ``ai_generate`` as one combined prompt, and returns the stripped reply.
    If the reply is empty or ``is_ai_error_reply`` flags it, the
    event-specific template from ``build_supervisor_fallback_reply`` is
    returned instead, so model errors are never surfaced as commentary.

    Args:
        event: The supervisor event being commented on.
        warnings: Warning dicts attached to the event.
        trading_day: Trading-day key passed to the context and prompt builders.
        session_id: Accepted for callback-interface compatibility; unused here.
        exchanges: Exchange configs used to build the trading context.

    Returns:
        The model's stripped reply, or the fallback template text.
    """
    context = build_chat_context(exchanges, trading_day=trading_day)
    overview = format_chat_position_overview(context)
    detail = format_chat_context_for_chat(context, max_chars=2400)

    event_prompt = build_supervisor_ai_prompt(
        context_text=overview + "\n" + detail,
        trading_day=trading_day,
        event=event,
        warnings=warnings,
    )
    # ai_generate takes a single prompt string, so the system text is
    # prepended with a visible separator rather than sent as a system role.
    combined_prompt = "\n\n---\n\n".join(
        (SUPERVISOR_SYSTEM.strip(), event_prompt.strip())
    )

    raw_reply = ai_generate(
        combined_prompt, temperature=0.35, max_tokens=SUPERVISOR_AI_MAX_TOKENS
    )
    reply = str(raw_reply or "").strip()
    if reply and not is_ai_error_reply(reply):
        return reply
    return build_supervisor_fallback_reply(event, warnings)
def make_supervisor_ai_reply_fn(exchanges: list[dict]):
    """Return a keyword-only callback that generates supervisor commentary.

    The returned callable closes over *exchanges* and forwards its keyword
    arguments to ``generate_supervisor_ai_reply``. A falsy ``warnings``
    value (e.g. ``None``) is replaced with an empty list before forwarding.
    """
    bound_exchanges = exchanges

    def _reply(*, event: dict, warnings: list[dict], trading_day: str, session_id: str) -> str:
        normalized_warnings = warnings if warnings else []
        return generate_supervisor_ai_reply(
            exchanges=bound_exchanges,
            event=event,
            warnings=normalized_warnings,
            trading_day=trading_day,
            session_id=session_id,
        )

    return _reply
def _format_history_lines(messages: list, limit: int = 8) -> str:
    """Render the last *limit* chat messages as ``"<label>: <content>"`` lines.

    Only ``user`` / ``assistant`` / ``system`` messages are kept (labelled
    user / supervisor / system in Chinese); any other role, and any non-dict
    entry, is skipped. The ": " separator keeps the speaker label from running
    straight into the message text in the prompt.
    """
    if limit <= 0:
        # messages[-0:] would return the whole list, not an empty tail.
        return ""
    labels = {"user": "用户", "assistant": "监管", "system": "系统"}
    lines = []
    for m in messages[-limit:]:
        if not isinstance(m, dict):
            continue
        role = m.get("role")
        if role not in ("user", "assistant", "system"):
            continue
        content = str(m.get("content") or "").strip()
        lines.append(f"{labels[role]}: {content}")
    return "\n".join(lines)


def send_supervisor_chat(
    exchanges: list[dict],
    message: str,
    *,
    trading_day: str | None = None,
) -> dict[str, Any]:
    """Answer one user chat message inside the supervisor session for a day.

    The prompt is built from the trading context (up to 6000 chars), the last
    8 session messages, and the user's message, then sent to ``generate_text``
    with the supervisor system prompt.

    Args:
        exchanges: Exchange configs used to build the trading context.
        message: The user's chat text; surrounding whitespace is stripped.
        trading_day: Optional trading-day key. At most its first 10 characters
            are used (presumably ``YYYY-MM-DD`` — verify). When blank, the
            current trading day is computed from the configured reset hour.

    Returns:
        - ``{"ok": False, "msg": ...}`` if the message is empty.
        - ``{"ok": False, "msg": ..., "session_id": ...}`` if the model returns
          nothing or an error reply. In that case nothing is written to the
          session, not even the user's message.
        - On success: ``ok``, ``trading_day``, the updated ``session``,
          ``reply``, ``model``, ``message_count`` and ``unread_system``.
    """
    text = (message or "").strip()
    if not text:
        return {"ok": False, "msg": "消息不能为空"}

    day = (trading_day or "").strip()[:10] or current_trading_day(
        reset_hour=trading_day_reset_hour()
    )
    session = ensure_supervisor_session(day)
    sid = str(session.get("id") or "")
    prior = session.get("messages") or []

    ctx = build_chat_context(exchanges, trading_day=day)
    brief = format_chat_context_for_chat(ctx, max_chars=6000)
    user_prompt = build_supervisor_chat_prompt(
        context_text=brief,
        trading_day=day,
        history_lines=_format_history_lines(prior, limit=8),
        user_message=text,
    )
    reply = generate_text(
        system=SUPERVISOR_SYSTEM,
        user=user_prompt,
        # Keep supervisor chat conservative and short regardless of global chat config.
        temperature=min(0.4, CHAT_TEMPERATURE),
        max_tokens=min(768, CHAT_MAX_OUTPUT_TOKENS),
        max_continuations=1,
    )
    reply = str(reply or "").strip()
    if not reply or is_ai_error_reply(reply):
        # Do not persist error strings (or the orphaned user turn) into the session.
        return {"ok": False, "msg": "AI 暂时不可用,请稍后再试", "session_id": sid}

    append_chat_message(sid, "user", text)
    session = append_supervisor_ai_message(sid, reply)
    state = get_supervisor_session_state(day)
    return {
        "ok": True,
        "trading_day": day,
        "session": session,
        "reply": reply,
        "model": model_label(),
        "message_count": state.get("message_count"),
        "unread_system": state.get("unread_system"),
    }