From 6169fee7b95285abf0e58fa07dfa2277511ae7a6 Mon Sep 17 00:00:00 2001 From: dekun Date: Thu, 11 Jun 2026 00:44:08 +0800 Subject: [PATCH] fix: prevent AI coach chat replies from truncating mid-sentence Co-authored-by: Cursor --- ai_client.py | 227 ++++++++++++++++++++++++---- manual_trading_hub/.env.example | 4 + manual_trading_hub/hub_ai/chat.py | 2 + manual_trading_hub/hub_ai/client.py | 13 +- manual_trading_hub/hub_ai/config.py | 11 +- 5 files changed, 221 insertions(+), 36 deletions(-) diff --git a/ai_client.py b/ai_client.py index 2a8603e..636bc8f 100644 --- a/ai_client.py +++ b/ai_client.py @@ -7,7 +7,7 @@ from __future__ import annotations import base64 import os -from typing import List, Optional, Sequence +from typing import List, Optional, Sequence, Tuple import requests @@ -19,7 +19,12 @@ def _env_str(name: str, default: str = "") -> str: return str(v).strip() -def _ai_timeout_seconds(*, image_count: int = 0) -> int: +def _ai_timeout_seconds(*, image_count: int = 0, chat: bool = False) -> int: + if chat: + try: + return max(30, int(_env_str("CHAT_AI_TIMEOUT_SECONDS", "300") or "300")) + except ValueError: + return 300 if image_count > 0: try: return max(30, int(_env_str("AI_REVIEW_TIMEOUT_SECONDS", "300") or "300")) @@ -104,6 +109,68 @@ def _openai_chat_url() -> str: return f"{base}/chat/completions" +def _openai_message_text(msg: dict) -> str: + content = msg.get("content") + if isinstance(content, list): + parts: list[str] = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + parts.append(str(part.get("text") or "")) + content = "".join(parts) + text = str(content or "").strip() + if not text: + text = str(msg.get("reasoning_content") or "").strip() + return text + + +def _apply_max_tokens(body: dict, max_tokens: int | None) -> None: + if max_tokens is not None and max_tokens > 0: + mt = int(max_tokens) + body["max_tokens"] = mt + body["max_completion_tokens"] = mt + + +def _openai_chat_completion( + messages: list[dict], + *, + temperature: float, + max_tokens: int | None = None, + image_count: int = 0, + chat: bool = False, +) -> Tuple[str, str]: + api_key = _openai_api_key() + if not api_key: + return "AI 调用失败:未配置 OPENAI_API_KEY(请在当前实例目录 .env 中设置,修改后需重启服务)", "error" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + body: dict = { + "model": _openai_model(), + "messages": messages, + "temperature": temperature, + "stream": False, + } + _apply_max_tokens(body, max_tokens) + r = requests.post( + _openai_chat_url(), + headers=headers, + json=body, + timeout=_ai_timeout_seconds(image_count=image_count, chat=chat), + ) + r.raise_for_status() + data = r.json() + choices = data.get("choices") or [] + if not choices: + return "AI 生成失败:响应无 choices", "error" + choice = choices[0] or {} + msg = choice.get("message") or {} + text = _openai_message_text(msg) + if not text: + return "AI 生成失败:空内容", choice.get("finish_reason") or "error" + return text, str(choice.get("finish_reason") or "") + + def _generate_openai( prompt: str, images: List[tuple], @@ -111,13 +178,6 @@ def _generate_openai( *, max_tokens: int | None = None, ) -> str: - api_key = _openai_api_key() - if not api_key: - return "AI 调用失败:未配置 OPENAI_API_KEY(请在当前实例目录 .env 中设置,修改后需重启服务)" - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - } if images: content: List[dict] = [{"type": "text", "text": prompt}] for b64, mime in images: @@ -130,27 +190,13 @@ def _generate_openai( messages = [{"role": "user", "content": content}] else: messages = [{"role": "user", "content": prompt}] - body: dict = { - "model": _openai_model(), - "messages": messages, - "temperature": temperature, - "stream": False, - } - if max_tokens is not None and max_tokens > 0: - body["max_tokens"] = int(max_tokens) - r = requests.post( - _openai_chat_url(), - headers=headers, - json=body, - timeout=_ai_timeout_seconds(image_count=len(images)), + text, _reason = _openai_chat_completion( + messages, + temperature=temperature, + max_tokens=max_tokens, + image_count=len(images), ) - r.raise_for_status() - data = r.json() - choices = data.get("choices") or [] - if not choices: - return "AI 生成失败:响应无 choices" - msg = choices[0].get("message") or {} - return (msg.get("content") or "").strip() or "AI 生成失败:空内容" + return text def _generate_ollama( @@ -159,7 +205,8 @@ def _generate_ollama( temperature: float, *, max_tokens: int | None = None, -) -> str: + chat: bool = False, +) -> Tuple[str, str]: options: dict = {"temperature": temperature} if max_tokens is not None and max_tokens > 0: options["num_predict"] = int(max_tokens) @@ -171,9 +218,15 @@ def _generate_ollama( } if images: payload["images"] = [b64 for b64, _mime in images] - r = requests.post(_ollama_api(), json=payload, timeout=_ai_timeout_seconds(image_count=len(images))) + r = requests.post( + _ollama_api(), + json=payload, + timeout=_ai_timeout_seconds(image_count=len(images), chat=chat), + ) r.raise_for_status() - return (r.json().get("response") or "").strip() or "AI 生成失败" + data = r.json() + text = (data.get("response") or "").strip() or "AI 生成失败" + return text, str(data.get("done_reason") or "") def ai_generate( @@ -189,7 +242,115 @@ def ai_generate( try: if _use_openai(): return _generate_openai(prompt, images, temperature, max_tokens=max_tokens) - return _generate_ollama(prompt, images, temperature, max_tokens=max_tokens) + text, _reason = _generate_ollama(prompt, images, temperature, max_tokens=max_tokens) + return text + except requests.HTTPError as e: + detail = "" + try: + detail = (e.response.text or "")[:500] + except Exception: + pass + prov = "OpenAI" if _use_openai() else "Ollama" + return f"AI 调用失败({prov} HTTP {e.response.status_code if e.response else '?'}):{detail or str(e)}" + except Exception as e: + prov = "OpenAI" if _use_openai() else "Ollama" + return f"AI 调用失败({prov}):{str(e)}" + + +_CHAT_CONTINUE_USER = ( + "你上一条回复在中途截断了。请从断点处继续写完,不要重复已写内容," + "保持同一语气,写完给出完整结尾。" +) +_CHAT_END_CHARS = "。!?.!?\"」』))>】" + + +def _looks_truncated(text: str) -> bool: + t = (text or "").rstrip() + if len(t) < 48: + return False + if t[-1] in _CHAT_END_CHARS: + return False + if t.endswith("…") or t.endswith("..."): + return True + return t[-1] not in ",、,;;::\n" + + +def _should_continue(reason: str, chunk: str) -> bool: + if reason == "length": + return True + return _looks_truncated(chunk) + + +def ai_generate_chat( + *, + system: str, + user: str, + temperature: float = 0.5, + images_b64: Optional[Sequence[str]] = None, + max_tokens: int = 8192, + max_continuations: int = 3, +) -> str: + """聊天专用:system/user 分消息;输出触顶时自动续写。""" + images = _collect_images(None, images_b64) + try: + if _use_openai(): + messages: list[dict] = [ + {"role": "system", "content": system.strip()}, + ] + if images: + content: List[dict] = [{"type": "text", "text": user.strip()}] + for b64, mime in images: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:{mime};base64,{b64}"}, + } + ) + messages.append({"role": "user", "content": content}) + else: + messages.append({"role": "user", "content": user.strip()}) + + parts: list[str] = [] + for _ in range(max(1, int(max_continuations) + 1)): + chunk, reason = _openai_chat_completion( + messages, + temperature=temperature, + max_tokens=max_tokens, + image_count=len(images), + chat=True, + ) + if chunk.startswith("AI 调用失败") or chunk.startswith("AI 生成失败"): + return chunk if not parts else "".join(parts) + parts.append(chunk) + if not _should_continue(reason, chunk): + break + messages.append({"role": "assistant", "content": chunk}) + messages.append({"role": "user", "content": _CHAT_CONTINUE_USER}) + return "".join(parts).strip() or "AI 生成失败:空内容" + + prompt = f"{system.strip()}\n\n---\n\n{user.strip()}" + parts = [] + current_prompt = prompt + for _ in range(max(1, int(max_continuations) + 1)): + chunk, reason = _generate_ollama( + current_prompt, + images if not parts else [], + temperature, + max_tokens=max_tokens, + chat=True, + ) + if chunk.startswith("AI 生成失败") and not parts: + return chunk + parts.append(chunk) + if not _should_continue(reason, chunk): + break + tail = chunk[-400:] if len(chunk) > 400 else chunk + current_prompt = ( + f"{prompt}\n\n{''.join(parts)}\n\n" + f"{_CHAT_CONTINUE_USER}\n\n" + f"(已写结尾片段供衔接:…{tail})" + ) + return "".join(parts).strip() or "AI 生成失败" except requests.HTTPError as e: detail = "" try: diff --git a/manual_trading_hub/.env.example b/manual_trading_hub/.env.example index 9fc1231..3076be4 100644 --- a/manual_trading_hub/.env.example +++ b/manual_trading_hub/.env.example @@ -85,6 +85,10 @@ HUB_TRUST_LAN=true # 与四实例相同变量名;默认 OpenAI 兼容网关(改 AI_PROVIDER=ollama 可走本机 Ollama) # 详见 manual_trading_hub/AI教练说明.md 与仓库根 AI复盘与模型配置说明.md AI_TIMEOUT_SECONDS=120 +# AI 教练聊天:单次输出 token 上限与截断自动续写次数(默认 8192 / 3) +# CHAT_MAX_OUTPUT_TOKENS=8192 +# CHAT_MAX_CONTINUATIONS=3 +# CHAT_AI_TIMEOUT_SECONDS=300 # AI 提供方:openai(默认,OpenAI 兼容网关)| ollama(本机 Ollama) AI_PROVIDER=openai diff --git a/manual_trading_hub/hub_ai/chat.py b/manual_trading_hub/hub_ai/chat.py index 21138c6..ceb7470 100644 --- a/manual_trading_hub/hub_ai/chat.py +++ b/manual_trading_hub/hub_ai/chat.py @@ -7,6 +7,7 @@ from hub_ai.attachments import parse_chat_attachments from hub_ai.client import generate_text, model_label from hub_ai.config import ( CHAT_CONTEXT_MAX_CHARS, + CHAT_MAX_CONTINUATIONS, CHAT_MAX_HISTORY_TURNS, CHAT_MAX_OUTPUT_TOKENS, CHAT_SUMMARY_EXCERPT_MAX_CHARS, @@ -107,6 +108,7 @@ def send_chat_message( temperature=CHAT_TEMPERATURE, images_b64=parsed.get("images_b64") or None, max_tokens=CHAT_MAX_OUTPUT_TOKENS, + max_continuations=CHAT_MAX_CONTINUATIONS, ) if reply.startswith("AI 调用失败"): return {"ok": False, "msg": reply, "session_id": sid} diff --git a/manual_trading_hub/hub_ai/client.py b/manual_trading_hub/hub_ai/client.py index d089368..3a485e2 100644 --- a/manual_trading_hub/hub_ai/client.py +++ b/manual_trading_hub/hub_ai/client.py @@ -9,7 +9,7 @@ _REPO_ROOT = Path(__file__).resolve().parents[2] if str(_REPO_ROOT) not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) -from ai_client import ai_generate, ai_provider_label # noqa: E402 +from ai_client import ai_generate, ai_generate_chat, ai_provider_label # noqa: E402 def model_label() -> str: @@ -23,11 +23,20 @@ def generate_text( temperature: float, images_b64: Optional[Sequence[str]] = None, max_tokens: int | None = None, + max_continuations: int = 3, ) -> str: + if max_tokens is not None and max_tokens > 0: + return ai_generate_chat( + system=system, + user=user, + temperature=temperature, + images_b64=images_b64, + max_tokens=int(max_tokens), + max_continuations=max_continuations, + ) prompt = f"{system.strip()}\n\n---\n\n{user.strip()}" return ai_generate( prompt, temperature=temperature, images_b64=images_b64, - max_tokens=max_tokens, ) diff --git a/manual_trading_hub/hub_ai/config.py b/manual_trading_hub/hub_ai/config.py index 81e78c7..aefa313 100644 --- a/manual_trading_hub/hub_ai/config.py +++ b/manual_trading_hub/hub_ai/config.py @@ -5,10 +5,19 @@ import os HUB_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +def _int_env(key: str, default: int) -> int: + try: + return int(os.getenv(key, str(default)) or default) + except ValueError: + return default + + SUMMARY_TEMPERATURE = 0.15 CHAT_TEMPERATURE = 0.5 CHAT_MAX_HISTORY_TURNS = 40 -CHAT_MAX_OUTPUT_TOKENS = 2048 +CHAT_MAX_OUTPUT_TOKENS = _int_env("CHAT_MAX_OUTPUT_TOKENS", 8192) +CHAT_MAX_CONTINUATIONS = _int_env("CHAT_MAX_CONTINUATIONS", 3) CHAT_CONTEXT_MAX_CHARS = 128_000 CHAT_SUMMARY_EXCERPT_MAX_CHARS = 8000 SUMMARY_RETENTION_DAYS = 90