fix: prevent AI coach chat replies from truncating mid-sentence

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-11 00:44:08 +08:00
parent 0e2e360ccf
commit 6169fee7b9
5 changed files with 221 additions and 36 deletions
+194 -33
View File
@@ -7,7 +7,7 @@ from __future__ import annotations
import base64 import base64
import os import os
from typing import List, Optional, Sequence from typing import List, Optional, Sequence, Tuple
import requests import requests
@@ -19,7 +19,12 @@ def _env_str(name: str, default: str = "") -> str:
return str(v).strip() return str(v).strip()
def _ai_timeout_seconds(*, image_count: int = 0) -> int: def _ai_timeout_seconds(*, image_count: int = 0, chat: bool = False) -> int:
if chat:
try:
return max(30, int(_env_str("CHAT_AI_TIMEOUT_SECONDS", "300") or "300"))
except ValueError:
return 300
if image_count > 0: if image_count > 0:
try: try:
return max(30, int(_env_str("AI_REVIEW_TIMEOUT_SECONDS", "300") or "300")) return max(30, int(_env_str("AI_REVIEW_TIMEOUT_SECONDS", "300") or "300"))
@@ -104,6 +109,68 @@ def _openai_chat_url() -> str:
return f"{base}/chat/completions" return f"{base}/chat/completions"
def _openai_message_text(msg: dict) -> str:
content = msg.get("content")
if isinstance(content, list):
parts: list[str] = []
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
parts.append(str(part.get("text") or ""))
content = "".join(parts)
text = str(content or "").strip()
if not text:
text = str(msg.get("reasoning_content") or "").strip()
return text
def _apply_max_tokens(body: dict, max_tokens: int | None) -> None:
if max_tokens is not None and max_tokens > 0:
mt = int(max_tokens)
body["max_tokens"] = mt
body["max_completion_tokens"] = mt
def _openai_chat_completion(
    messages: list[dict],
    *,
    temperature: float,
    max_tokens: int | None = None,
    image_count: int = 0,
    chat: bool = False,
) -> Tuple[str, str]:
    """Send one non-streaming chat-completions request and return its reply.

    Args:
        messages: Chat messages placed verbatim in the request body.
        temperature: Sampling temperature placed in the request body.
        max_tokens: Optional output-token cap, applied via ``_apply_max_tokens``.
        image_count: Forwarded to ``_ai_timeout_seconds`` to pick the timeout.
        chat: Forwarded to ``_ai_timeout_seconds`` to pick the timeout.

    Returns:
        A ``(text, finish_reason)`` pair. On success ``text`` is the message
        text of the first choice and ``finish_reason`` is its string form
        (empty when absent). When the API key is missing, the response has no
        choices, or the message text is empty, ``text`` is a failure message
        and the second element is ``"error"`` (or the choice's own
        ``finish_reason`` in the empty-text case, when it has one).

    Raises:
        requests.HTTPError: Propagated from ``raise_for_status()`` on a
            non-success HTTP status.
    """
    token = _openai_api_key()
    if not token:
        return (
            "AI 调用失败:未配置 OPENAI_API_KEY(请在当前实例目录 .env 中设置,修改后需重启服务)",
            "error",
        )

    payload: dict = dict(
        model=_openai_model(),
        messages=messages,
        temperature=temperature,
        stream=False,
    )
    _apply_max_tokens(payload, max_tokens)

    response = requests.post(
        _openai_chat_url(),
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        json=payload,
        timeout=_ai_timeout_seconds(image_count=image_count, chat=chat),
    )
    response.raise_for_status()

    candidates = response.json().get("choices") or []
    if not candidates:
        return "AI 生成失败:响应无 choices", "error"

    # Only the first choice is consulted.
    first = candidates[0] or {}
    finish = first.get("finish_reason")
    reply = _openai_message_text(first.get("message") or {})
    if reply:
        return reply, str(finish or "")
    return "AI 生成失败:空内容", finish or "error"
def _generate_openai( def _generate_openai(
prompt: str, prompt: str,
images: List[tuple], images: List[tuple],
@@ -111,13 +178,6 @@ def _generate_openai(
*, *,
max_tokens: int | None = None, max_tokens: int | None = None,
) -> str: ) -> str:
api_key = _openai_api_key()
if not api_key:
return "AI 调用失败:未配置 OPENAI_API_KEY(请在当前实例目录 .env 中设置,修改后需重启服务)"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
if images: if images:
content: List[dict] = [{"type": "text", "text": prompt}] content: List[dict] = [{"type": "text", "text": prompt}]
for b64, mime in images: for b64, mime in images:
@@ -130,27 +190,13 @@ def _generate_openai(
messages = [{"role": "user", "content": content}] messages = [{"role": "user", "content": content}]
else: else:
messages = [{"role": "user", "content": prompt}] messages = [{"role": "user", "content": prompt}]
body: dict = { text, _reason = _openai_chat_completion(
"model": _openai_model(), messages,
"messages": messages, temperature=temperature,
"temperature": temperature, max_tokens=max_tokens,
"stream": False, image_count=len(images),
}
if max_tokens is not None and max_tokens > 0:
body["max_tokens"] = int(max_tokens)
r = requests.post(
_openai_chat_url(),
headers=headers,
json=body,
timeout=_ai_timeout_seconds(image_count=len(images)),
) )
r.raise_for_status() return text
data = r.json()
choices = data.get("choices") or []
if not choices:
return "AI 生成失败:响应无 choices"
msg = choices[0].get("message") or {}
return (msg.get("content") or "").strip() or "AI 生成失败:空内容"
def _generate_ollama( def _generate_ollama(
@@ -159,7 +205,8 @@ def _generate_ollama(
temperature: float, temperature: float,
*, *,
max_tokens: int | None = None, max_tokens: int | None = None,
) -> str: chat: bool = False,
) -> Tuple[str, str]:
options: dict = {"temperature": temperature} options: dict = {"temperature": temperature}
if max_tokens is not None and max_tokens > 0: if max_tokens is not None and max_tokens > 0:
options["num_predict"] = int(max_tokens) options["num_predict"] = int(max_tokens)
@@ -171,9 +218,15 @@ def _generate_ollama(
} }
if images: if images:
payload["images"] = [b64 for b64, _mime in images] payload["images"] = [b64 for b64, _mime in images]
r = requests.post(_ollama_api(), json=payload, timeout=_ai_timeout_seconds(image_count=len(images))) r = requests.post(
_ollama_api(),
json=payload,
timeout=_ai_timeout_seconds(image_count=len(images), chat=chat),
)
r.raise_for_status() r.raise_for_status()
return (r.json().get("response") or "").strip() or "AI 生成失败" data = r.json()
text = (data.get("response") or "").strip() or "AI 生成失败"
return text, str(data.get("done_reason") or "")
def ai_generate( def ai_generate(
@@ -189,7 +242,115 @@ def ai_generate(
try: try:
if _use_openai(): if _use_openai():
return _generate_openai(prompt, images, temperature, max_tokens=max_tokens) return _generate_openai(prompt, images, temperature, max_tokens=max_tokens)
return _generate_ollama(prompt, images, temperature, max_tokens=max_tokens) text, _reason = _generate_ollama(prompt, images, temperature, max_tokens=max_tokens)
return text
except requests.HTTPError as e:
detail = ""
try:
detail = (e.response.text or "")[:500]
except Exception:
pass
prov = "OpenAI" if _use_openai() else "Ollama"
return f"AI 调用失败({prov} HTTP {e.response.status_code if e.response else '?'}):{detail or str(e)}"
except Exception as e:
prov = "OpenAI" if _use_openai() else "Ollama"
return f"AI 调用失败({prov}):{str(e)}"
# Follow-up user message sent to the model when a chat reply appears to be cut
# off: it asks the model to continue from the break point without repeating
# what was already written, keep the same tone, and finish with a proper ending.
_CHAT_CONTINUE_USER = (
"你上一条回复在中途截断了。请从断点处继续写完,不要重复已写内容,"
"保持同一语气,写完给出完整结尾。"
)
_CHAT_END_CHARS = "。!?.!?\"」』))>】"
def _looks_truncated(text: str) -> bool:
t = (text or "").rstrip()
if len(t) < 48:
return False
if t[-1] in _CHAT_END_CHARS:
return False
if t.endswith("") or t.endswith("..."):
return True
return t[-1] not in ",、,;:\n"
def _should_continue(reason: str, chunk: str) -> bool:
if reason == "length":
return True
return _looks_truncated(chunk)
def ai_generate_chat(
*,
system: str,
user: str,
temperature: float = 0.5,
images_b64: Optional[Sequence[str]] = None,
max_tokens: int = 8192,
max_continuations: int = 3,
) -> str:
"""聊天专用:system/user 分消息;输出触顶时自动续写。"""
images = _collect_images(None, images_b64)
try:
if _use_openai():
messages: list[dict] = [
{"role": "system", "content": system.strip()},
]
if images:
content: List[dict] = [{"type": "text", "text": user.strip()}]
for b64, mime in images:
content.append(
{
"type": "image_url",
"image_url": {"url": f"data:{mime};base64,{b64}"},
}
)
messages.append({"role": "user", "content": content})
else:
messages.append({"role": "user", "content": user.strip()})
parts: list[str] = []
for _ in range(max(1, int(max_continuations) + 1)):
chunk, reason = _openai_chat_completion(
messages,
temperature=temperature,
max_tokens=max_tokens,
image_count=len(images),
chat=True,
)
if chunk.startswith("AI 调用失败") or chunk.startswith("AI 生成失败"):
return chunk if not parts else "".join(parts)
parts.append(chunk)
if not _should_continue(reason, chunk):
break
messages.append({"role": "assistant", "content": chunk})
messages.append({"role": "user", "content": _CHAT_CONTINUE_USER})
return "".join(parts).strip() or "AI 生成失败:空内容"
prompt = f"{system.strip()}\n\n---\n\n{user.strip()}"
parts = []
current_prompt = prompt
for _ in range(max(1, int(max_continuations) + 1)):
chunk, reason = _generate_ollama(
current_prompt,
images if not parts else [],
temperature,
max_tokens=max_tokens,
chat=True,
)
if chunk.startswith("AI 生成失败") and not parts:
return chunk
parts.append(chunk)
if not _should_continue(reason, chunk):
break
tail = chunk[-400:] if len(chunk) > 400 else chunk
current_prompt = (
f"{prompt}\n\n{''.join(parts)}\n\n"
f"{_CHAT_CONTINUE_USER}\n\n"
f"(已写结尾片段供衔接:…{tail}"
)
return "".join(parts).strip() or "AI 生成失败"
except requests.HTTPError as e: except requests.HTTPError as e:
detail = "" detail = ""
try: try:
+4
View File
@@ -85,6 +85,10 @@ HUB_TRUST_LAN=true
# 与四实例相同变量名;默认 OpenAI 兼容网关(改 AI_PROVIDER=ollama 可走本机 Ollama # 与四实例相同变量名;默认 OpenAI 兼容网关(改 AI_PROVIDER=ollama 可走本机 Ollama
# 详见 manual_trading_hub/AI教练说明.md 与仓库根 AI复盘与模型配置说明.md # 详见 manual_trading_hub/AI教练说明.md 与仓库根 AI复盘与模型配置说明.md
AI_TIMEOUT_SECONDS=120 AI_TIMEOUT_SECONDS=120
# AI 教练聊天:单次输出 token 上限、截断自动续写次数与请求超时秒数(默认 8192 / 3 / 300;超时最小 30 秒)
# CHAT_MAX_OUTPUT_TOKENS=8192
# CHAT_MAX_CONTINUATIONS=3
# CHAT_AI_TIMEOUT_SECONDS=300
# AI 提供方:openai(默认,OpenAI 兼容网关)| ollama(本机 Ollama # AI 提供方:openai(默认,OpenAI 兼容网关)| ollama(本机 Ollama
AI_PROVIDER=openai AI_PROVIDER=openai
+2
View File
@@ -7,6 +7,7 @@ from hub_ai.attachments import parse_chat_attachments
from hub_ai.client import generate_text, model_label from hub_ai.client import generate_text, model_label
from hub_ai.config import ( from hub_ai.config import (
CHAT_CONTEXT_MAX_CHARS, CHAT_CONTEXT_MAX_CHARS,
CHAT_MAX_CONTINUATIONS,
CHAT_MAX_HISTORY_TURNS, CHAT_MAX_HISTORY_TURNS,
CHAT_MAX_OUTPUT_TOKENS, CHAT_MAX_OUTPUT_TOKENS,
CHAT_SUMMARY_EXCERPT_MAX_CHARS, CHAT_SUMMARY_EXCERPT_MAX_CHARS,
@@ -107,6 +108,7 @@ def send_chat_message(
temperature=CHAT_TEMPERATURE, temperature=CHAT_TEMPERATURE,
images_b64=parsed.get("images_b64") or None, images_b64=parsed.get("images_b64") or None,
max_tokens=CHAT_MAX_OUTPUT_TOKENS, max_tokens=CHAT_MAX_OUTPUT_TOKENS,
max_continuations=CHAT_MAX_CONTINUATIONS,
) )
if reply.startswith("AI 调用失败"): if reply.startswith("AI 调用失败"):
return {"ok": False, "msg": reply, "session_id": sid} return {"ok": False, "msg": reply, "session_id": sid}
+11 -2
View File
@@ -9,7 +9,7 @@ _REPO_ROOT = Path(__file__).resolve().parents[2]
if str(_REPO_ROOT) not in sys.path: if str(_REPO_ROOT) not in sys.path:
sys.path.insert(0, str(_REPO_ROOT)) sys.path.insert(0, str(_REPO_ROOT))
from ai_client import ai_generate, ai_provider_label # noqa: E402 from ai_client import ai_generate, ai_generate_chat, ai_provider_label # noqa: E402
def model_label() -> str: def model_label() -> str:
@@ -23,11 +23,20 @@ def generate_text(
temperature: float, temperature: float,
images_b64: Optional[Sequence[str]] = None, images_b64: Optional[Sequence[str]] = None,
max_tokens: int | None = None, max_tokens: int | None = None,
max_continuations: int = 3,
) -> str: ) -> str:
if max_tokens is not None and max_tokens > 0:
return ai_generate_chat(
system=system,
user=user,
temperature=temperature,
images_b64=images_b64,
max_tokens=int(max_tokens),
max_continuations=max_continuations,
)
prompt = f"{system.strip()}\n\n---\n\n{user.strip()}" prompt = f"{system.strip()}\n\n---\n\n{user.strip()}"
return ai_generate( return ai_generate(
prompt, prompt,
temperature=temperature, temperature=temperature,
images_b64=images_b64, images_b64=images_b64,
max_tokens=max_tokens,
) )
+10 -1
View File
@@ -5,10 +5,19 @@ import os
HUB_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) HUB_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def _int_env(key: str, default: int) -> int:
try:
return int(os.getenv(key, str(default)) or default)
except ValueError:
return default
SUMMARY_TEMPERATURE = 0.15 SUMMARY_TEMPERATURE = 0.15
CHAT_TEMPERATURE = 0.5 CHAT_TEMPERATURE = 0.5
CHAT_MAX_HISTORY_TURNS = 40 CHAT_MAX_HISTORY_TURNS = 40
CHAT_MAX_OUTPUT_TOKENS = 2048 CHAT_MAX_OUTPUT_TOKENS = _int_env("CHAT_MAX_OUTPUT_TOKENS", 8192)
CHAT_MAX_CONTINUATIONS = _int_env("CHAT_MAX_CONTINUATIONS", 3)
CHAT_CONTEXT_MAX_CHARS = 128_000 CHAT_CONTEXT_MAX_CHARS = 128_000
CHAT_SUMMARY_EXCERPT_MAX_CHARS = 8000 CHAT_SUMMARY_EXCERPT_MAX_CHARS = 8000
SUMMARY_RETENTION_DAYS = 90 SUMMARY_RETENTION_DAYS = 90