From 0e2e360ccfbddf88fb09c335c9d8fb77501e69ab Mon Sep 17 00:00:00 2001 From: dekun Date: Thu, 11 Jun 2026 00:35:01 +0800 Subject: [PATCH] fix: improve AI coach chat context, 128k window, and output limits Co-authored-by: Cursor --- ai_client.py | 30 ++++++++--- manual_trading_hub/hub_ai/chat.py | 13 +++-- manual_trading_hub/hub_ai/client.py | 8 ++- manual_trading_hub/hub_ai/config.py | 5 +- manual_trading_hub/hub_ai/context.py | 74 +++++++++++++++++++++++++--- manual_trading_hub/hub_ai/prompts.py | 19 ++++--- 6 files changed, 124 insertions(+), 25 deletions(-) diff --git a/ai_client.py b/ai_client.py index 4f87bbe..2a8603e 100644 --- a/ai_client.py +++ b/ai_client.py @@ -104,7 +104,13 @@ def _openai_chat_url() -> str: return f"{base}/chat/completions" -def _generate_openai(prompt: str, images: List[tuple], temperature: float) -> str: +def _generate_openai( + prompt: str, + images: List[tuple], + temperature: float, + *, + max_tokens: int | None = None, +) -> str: api_key = _openai_api_key() if not api_key: return "AI 调用失败:未配置 OPENAI_API_KEY(请在当前实例目录 .env 中设置,修改后需重启服务)" @@ -124,12 +130,14 @@ def _generate_openai(prompt: str, images: List[tuple], temperature: float) -> st messages = [{"role": "user", "content": content}] else: messages = [{"role": "user", "content": prompt}] - body = { + body: dict = { "model": _openai_model(), "messages": messages, "temperature": temperature, "stream": False, } + if max_tokens is not None and max_tokens > 0: + body["max_tokens"] = int(max_tokens) r = requests.post( _openai_chat_url(), headers=headers, @@ -145,12 +153,21 @@ def _generate_openai(prompt: str, images: List[tuple], temperature: float) -> st return (msg.get("content") or "").strip() or "AI 生成失败:空内容" -def _generate_ollama(prompt: str, images: List[tuple], temperature: float) -> str: +def _generate_ollama( + prompt: str, + images: List[tuple], + temperature: float, + *, + max_tokens: int | None = None, +) -> str: + options: dict = {"temperature": temperature} + if max_tokens is not None and max_tokens > 0: + options["num_predict"] = int(max_tokens) payload = { "model": _ollama_model(), "prompt": prompt, "stream": False, - "options": {"temperature": temperature}, + "options": options, } if images: payload["images"] = [b64 for b64, _mime in images] @@ -165,13 +182,14 @@ def ai_generate( image_paths: Optional[Sequence[str]] = None, images_b64: Optional[Sequence[str]] = None, temperature: float = 0.2, + max_tokens: int | None = None, ) -> str: """统一文本生成;失败时返回以「AI 调用失败」开头的说明。""" images = _collect_images(image_paths, images_b64) try: if _use_openai(): - return _generate_openai(prompt, images, temperature) - return _generate_ollama(prompt, images, temperature) + return _generate_openai(prompt, images, temperature, max_tokens=max_tokens) + return _generate_ollama(prompt, images, temperature, max_tokens=max_tokens) except requests.HTTPError as e: detail = "" try: diff --git a/manual_trading_hub/hub_ai/chat.py b/manual_trading_hub/hub_ai/chat.py index 60e05ce..21138c6 100644 --- a/manual_trading_hub/hub_ai/chat.py +++ b/manual_trading_hub/hub_ai/chat.py @@ -5,7 +5,13 @@ from typing import Any, Optional from hub_ai.attachments import parse_chat_attachments from hub_ai.client import generate_text, model_label -from hub_ai.config import CHAT_MAX_HISTORY_TURNS, CHAT_TEMPERATURE +from hub_ai.config import ( + CHAT_CONTEXT_MAX_CHARS, + CHAT_MAX_HISTORY_TURNS, + CHAT_MAX_OUTPUT_TOKENS, + CHAT_SUMMARY_EXCERPT_MAX_CHARS, + CHAT_TEMPERATURE, +) from hub_ai.context import build_daily_context, format_chat_context_for_chat from hub_ai.prompts import CHAT_SYSTEM, build_chat_user_prompt from hub_ai.store import ( @@ -81,8 +87,8 @@ def send_chat_message( attachments=parsed.get("attachment_meta") or [], ) - brief_ctx = format_chat_context_for_chat(ctx) - excerpt = summary_excerpt_for_chat(day) + brief_ctx = format_chat_context_for_chat(ctx, max_chars=CHAT_CONTEXT_MAX_CHARS) + excerpt = summary_excerpt_for_chat(day, max_chars=CHAT_SUMMARY_EXCERPT_MAX_CHARS) user_prompt = build_chat_user_prompt( context_text=brief_ctx, @@ -100,6 +106,7 @@ def send_chat_message( user=user_prompt, temperature=CHAT_TEMPERATURE, images_b64=parsed.get("images_b64") or None, + max_tokens=CHAT_MAX_OUTPUT_TOKENS, ) if reply.startswith("AI 调用失败"): return {"ok": False, "msg": reply, "session_id": sid} diff --git a/manual_trading_hub/hub_ai/client.py b/manual_trading_hub/hub_ai/client.py index 82c5389..d089368 100644 --- a/manual_trading_hub/hub_ai/client.py +++ b/manual_trading_hub/hub_ai/client.py @@ -22,6 +22,12 @@ def generate_text( user: str, temperature: float, images_b64: Optional[Sequence[str]] = None, + max_tokens: int | None = None, ) -> str: prompt = f"{system.strip()}\n\n---\n\n{user.strip()}" - return ai_generate(prompt, temperature=temperature, images_b64=images_b64) + return ai_generate( + prompt, + temperature=temperature, + images_b64=images_b64, + max_tokens=max_tokens, + ) diff --git a/manual_trading_hub/hub_ai/config.py b/manual_trading_hub/hub_ai/config.py index abaf6ff..81e78c7 100644 --- a/manual_trading_hub/hub_ai/config.py +++ b/manual_trading_hub/hub_ai/config.py @@ -7,7 +7,10 @@ HUB_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) SUMMARY_TEMPERATURE = 0.15 CHAT_TEMPERATURE = 0.5 -CHAT_MAX_HISTORY_TURNS = 20 +CHAT_MAX_HISTORY_TURNS = 40 +CHAT_MAX_OUTPUT_TOKENS = 2048 +CHAT_CONTEXT_MAX_CHARS = 128_000 +CHAT_SUMMARY_EXCERPT_MAX_CHARS = 8000 SUMMARY_RETENTION_DAYS = 90 CHAT_SESSION_RETENTION_DAYS = 60 FUND_HISTORY_DAYS = 180 diff --git a/manual_trading_hub/hub_ai/context.py b/manual_trading_hub/hub_ai/context.py index d78b655..6961fc2 100644 --- a/manual_trading_hub/hub_ai/context.py +++ b/manual_trading_hub/hub_ai/context.py @@ -9,7 +9,13 @@ from typing import Any, Optional import httpx -from hub_ai.config import FUND_HISTORY_DAYS, hub_agent_timeout, hub_flask_timeout, trading_day_reset_hour +from hub_ai.config import ( + CHAT_CONTEXT_MAX_CHARS, + FUND_HISTORY_DAYS, + hub_agent_timeout, + hub_flask_timeout, + trading_day_reset_hour, +) from hub_ai.fund_history import format_fund_history_text, get_fund_history, record_fund_snapshot from hub_trades_lib import current_trading_day, summarize_trades @@ -706,15 +712,71 @@ def format_chat_position_overview(payload: dict) -> str: return "\n".join(lines) -def format_chat_context_for_chat(payload: dict, max_chars: int = 5200) -> str: +def format_chat_context_slim(payload: dict) -> str: + """聊天专用:不含 180 日资金曲线与昨日平仓明细,避免挤占对话上下文。""" + totals = payload.get("totals") or {} + day = totals.get("trading_day") + lines = [ + f"【今日合计 {day}】平仓盈亏 {totals.get('total_pnl_u')}U | " + f"笔数 {totals.get('closed_count')}(胜{totals.get('win_count')}/负{totals.get('loss_count')})| " + f"实盘持仓 {totals.get('open_position_count', 0)} 仓 | 浮盈亏 {totals.get('float_pnl_u')}U", + "【说明】持仓=交易所实盘;趋势/关键位/监控单=本地计划,不等于已开仓。", + ] + for ac in payload.get("accounts") or []: + if ac.get("status") == "未监控": + lines.append(f"- {ac.get('name')}:未监控") + continue + st = ac.get("trade_stats") or {} + open_n = int(ac.get("open_position_count") or _account_open_position_count(ac)) + pos_txt = "空仓" if open_n <= 0 else f"{open_n}仓 浮盈亏{ac.get('float_pnl_u')}U" + mc = _monitor_counts(ac) + mon = [] + if mc["trends"]: + mon.append(f"趋势{mc['trends']}") + if mc["rolls"]: + mon.append(f"加仓{mc['rolls']}") + if mc["keys"]: + mon.append(f"关键位{mc['keys']}") + if mc["orders"]: + mon.append(f"监控单{mc['orders']}") + mon_txt = f";监控 {'/'.join(mon)}" if mon else "" + lines.append( + f"- {ac.get('name')}:{pos_txt} | 今日盈亏{st.get('total_pnl_u')}U " + f"({st.get('closed_count')}笔) | 资金{_fmt_fund(ac.get('funding_usdt'))} " + f"交易{_fmt_fund(ac.get('trading_usdt'))}{mon_txt}" + ) + trades = ac.get("trades") or [] + if trades: + for t in trades[:4]: + lines.append(f" · {_format_trade_line(t)}") + if len(trades) > 4: + lines.append(f" · …共{len(trades)}笔今日平仓") + positions = ac.get("positions") or [] + for p in positions[:4]: + if not isinstance(p, dict): + continue + sym = p.get("symbol") or "?" + side = p.get("side") or "?" + upnl = _position_float_pnl(p) + lines.append(f" · 持仓 {sym} {side} 浮盈亏{upnl:.4f}U") + return "\n".join(lines) + + +def format_chat_context_for_chat( + payload: dict, + max_chars: int = CHAT_CONTEXT_MAX_CHARS, +) -> str: overview = format_chat_position_overview(payload) - body = format_context_text(payload) + body = str(payload.get("text") or "").strip() or format_context_text(payload) text = overview + "\n\n" + body if len(text) <= max_chars: return text - budget = max(800, max_chars - len(overview) - 4) - return overview + "\n\n" + body[:budget].rstrip() + "..." + budget = max(2000, max_chars - len(overview) - 4) + return overview + "\n\n" + body[:budget].rstrip() + "…" -def format_chat_context_brief(payload: dict, max_chars: int = 4500) -> str: +def format_chat_context_brief( + payload: dict, + max_chars: int = CHAT_CONTEXT_MAX_CHARS, +) -> str: return format_chat_context_for_chat(payload, max_chars=max_chars) diff --git a/manual_trading_hub/hub_ai/prompts.py b/manual_trading_hub/hub_ai/prompts.py index 41d8555..5368bdc 100644 --- a/manual_trading_hub/hub_ai/prompts.py +++ b/manual_trading_hub/hub_ai/prompts.py @@ -49,6 +49,9 @@ CHAT_SYSTEM = """ - 用户口述与快照冲突时,以快照为准并口语说明「我这边看到是空仓/有N仓」。 - 若附带「今日总结摘要」,那是较早生成的缓存,**实盘持仓以【当前多账户快照】里的「实盘持仓总览」为准**,摘要里若提到持仓可能已过时。 - 若用户上传图片,可结合图中可见信息讨论,看不清的明确说看不清。 +- **优先接住【用户现在说】和【此前对话】**:用户聊心态、悔单、某笔操作时,先顺着这个话题回应,不要每句都复述账户资金数字。 +- **接续对话**:有【此前对话】时须接着聊,不要重复开场白,回复写完整,不要说到一半戛然而止。 +- 快照里的盈亏/资金仅在需要核对事实时引用;用户口述与快照冲突时,以快照为准并口语说明。 """.strip() @@ -71,19 +74,19 @@ def build_chat_user_prompt( user_message: str, attachment_note: str = "", ) -> str: - parts = [ - f"【交易日】{trading_day}", - "【当前多账户快照(含实盘持仓与本地监控,发送时已刷新)】", + parts = [f"【交易日】{trading_day}"] + if history_lines.strip(): + parts.extend(["【此前对话(须接续,勿重复开场)】", history_lines.strip()]) + parts.extend([ + "【当前多账户快照(事实参考;持仓以「实盘持仓总览」为准)】", context_text.strip() or "(无监控数据)", - ] + ]) if summary_excerpt.strip(): parts.extend([ - "【今日总结摘要(可能滞后,持仓以快照「实盘持仓总览」为准)】", + "【今日总结摘要(可能滞后,持仓以快照为准)】", summary_excerpt.strip(), ]) - if history_lines.strip(): - parts.extend(["【此前对话】", history_lines.strip()]) if attachment_note.strip(): parts.extend(["【用户附件说明】", attachment_note.strip()]) - parts.extend(["【用户现在说】", user_message.strip()]) + parts.extend(["【用户现在说(优先回应这一条)】", user_message.strip()]) return "\n\n".join(parts)