fix: prevent AI coach chat replies from truncating mid-sentence
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
+194
-33
@@ -7,7 +7,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional, Sequence
|
from typing import List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -19,7 +19,12 @@ def _env_str(name: str, default: str = "") -> str:
|
|||||||
return str(v).strip()
|
return str(v).strip()
|
||||||
|
|
||||||
|
|
||||||
def _ai_timeout_seconds(*, image_count: int = 0) -> int:
|
def _ai_timeout_seconds(*, image_count: int = 0, chat: bool = False) -> int:
|
||||||
|
if chat:
|
||||||
|
try:
|
||||||
|
return max(30, int(_env_str("CHAT_AI_TIMEOUT_SECONDS", "300") or "300"))
|
||||||
|
except ValueError:
|
||||||
|
return 300
|
||||||
if image_count > 0:
|
if image_count > 0:
|
||||||
try:
|
try:
|
||||||
return max(30, int(_env_str("AI_REVIEW_TIMEOUT_SECONDS", "300") or "300"))
|
return max(30, int(_env_str("AI_REVIEW_TIMEOUT_SECONDS", "300") or "300"))
|
||||||
@@ -104,6 +109,68 @@ def _openai_chat_url() -> str:
|
|||||||
return f"{base}/chat/completions"
|
return f"{base}/chat/completions"
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_message_text(msg: dict) -> str:
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
parts.append(str(part.get("text") or ""))
|
||||||
|
content = "".join(parts)
|
||||||
|
text = str(content or "").strip()
|
||||||
|
if not text:
|
||||||
|
text = str(msg.get("reasoning_content") or "").strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_max_tokens(body: dict, max_tokens: int | None) -> None:
|
||||||
|
if max_tokens is not None and max_tokens > 0:
|
||||||
|
mt = int(max_tokens)
|
||||||
|
body["max_tokens"] = mt
|
||||||
|
body["max_completion_tokens"] = mt
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_chat_completion(
|
||||||
|
messages: list[dict],
|
||||||
|
*,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
image_count: int = 0,
|
||||||
|
chat: bool = False,
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
api_key = _openai_api_key()
|
||||||
|
if not api_key:
|
||||||
|
return "AI 调用失败:未配置 OPENAI_API_KEY(请在当前实例目录 .env 中设置,修改后需重启服务)", "error"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
body: dict = {
|
||||||
|
"model": _openai_model(),
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": temperature,
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
_apply_max_tokens(body, max_tokens)
|
||||||
|
r = requests.post(
|
||||||
|
_openai_chat_url(),
|
||||||
|
headers=headers,
|
||||||
|
json=body,
|
||||||
|
timeout=_ai_timeout_seconds(image_count=image_count, chat=chat),
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
choices = data.get("choices") or []
|
||||||
|
if not choices:
|
||||||
|
return "AI 生成失败:响应无 choices", "error"
|
||||||
|
choice = choices[0] or {}
|
||||||
|
msg = choice.get("message") or {}
|
||||||
|
text = _openai_message_text(msg)
|
||||||
|
if not text:
|
||||||
|
return "AI 生成失败:空内容", choice.get("finish_reason") or "error"
|
||||||
|
return text, str(choice.get("finish_reason") or "")
|
||||||
|
|
||||||
|
|
||||||
def _generate_openai(
|
def _generate_openai(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
images: List[tuple],
|
images: List[tuple],
|
||||||
@@ -111,13 +178,6 @@ def _generate_openai(
|
|||||||
*,
|
*,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
api_key = _openai_api_key()
|
|
||||||
if not api_key:
|
|
||||||
return "AI 调用失败:未配置 OPENAI_API_KEY(请在当前实例目录 .env 中设置,修改后需重启服务)"
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {api_key}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
if images:
|
if images:
|
||||||
content: List[dict] = [{"type": "text", "text": prompt}]
|
content: List[dict] = [{"type": "text", "text": prompt}]
|
||||||
for b64, mime in images:
|
for b64, mime in images:
|
||||||
@@ -130,27 +190,13 @@ def _generate_openai(
|
|||||||
messages = [{"role": "user", "content": content}]
|
messages = [{"role": "user", "content": content}]
|
||||||
else:
|
else:
|
||||||
messages = [{"role": "user", "content": prompt}]
|
messages = [{"role": "user", "content": prompt}]
|
||||||
body: dict = {
|
text, _reason = _openai_chat_completion(
|
||||||
"model": _openai_model(),
|
messages,
|
||||||
"messages": messages,
|
temperature=temperature,
|
||||||
"temperature": temperature,
|
max_tokens=max_tokens,
|
||||||
"stream": False,
|
image_count=len(images),
|
||||||
}
|
|
||||||
if max_tokens is not None and max_tokens > 0:
|
|
||||||
body["max_tokens"] = int(max_tokens)
|
|
||||||
r = requests.post(
|
|
||||||
_openai_chat_url(),
|
|
||||||
headers=headers,
|
|
||||||
json=body,
|
|
||||||
timeout=_ai_timeout_seconds(image_count=len(images)),
|
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
return text
|
||||||
data = r.json()
|
|
||||||
choices = data.get("choices") or []
|
|
||||||
if not choices:
|
|
||||||
return "AI 生成失败:响应无 choices"
|
|
||||||
msg = choices[0].get("message") or {}
|
|
||||||
return (msg.get("content") or "").strip() or "AI 生成失败:空内容"
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_ollama(
|
def _generate_ollama(
|
||||||
@@ -159,7 +205,8 @@ def _generate_ollama(
|
|||||||
temperature: float,
|
temperature: float,
|
||||||
*,
|
*,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
) -> str:
|
chat: bool = False,
|
||||||
|
) -> Tuple[str, str]:
|
||||||
options: dict = {"temperature": temperature}
|
options: dict = {"temperature": temperature}
|
||||||
if max_tokens is not None and max_tokens > 0:
|
if max_tokens is not None and max_tokens > 0:
|
||||||
options["num_predict"] = int(max_tokens)
|
options["num_predict"] = int(max_tokens)
|
||||||
@@ -171,9 +218,15 @@ def _generate_ollama(
|
|||||||
}
|
}
|
||||||
if images:
|
if images:
|
||||||
payload["images"] = [b64 for b64, _mime in images]
|
payload["images"] = [b64 for b64, _mime in images]
|
||||||
r = requests.post(_ollama_api(), json=payload, timeout=_ai_timeout_seconds(image_count=len(images)))
|
r = requests.post(
|
||||||
|
_ollama_api(),
|
||||||
|
json=payload,
|
||||||
|
timeout=_ai_timeout_seconds(image_count=len(images), chat=chat),
|
||||||
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
return (r.json().get("response") or "").strip() or "AI 生成失败"
|
data = r.json()
|
||||||
|
text = (data.get("response") or "").strip() or "AI 生成失败"
|
||||||
|
return text, str(data.get("done_reason") or "")
|
||||||
|
|
||||||
|
|
||||||
def ai_generate(
|
def ai_generate(
|
||||||
@@ -189,7 +242,115 @@ def ai_generate(
|
|||||||
try:
|
try:
|
||||||
if _use_openai():
|
if _use_openai():
|
||||||
return _generate_openai(prompt, images, temperature, max_tokens=max_tokens)
|
return _generate_openai(prompt, images, temperature, max_tokens=max_tokens)
|
||||||
return _generate_ollama(prompt, images, temperature, max_tokens=max_tokens)
|
text, _reason = _generate_ollama(prompt, images, temperature, max_tokens=max_tokens)
|
||||||
|
return text
|
||||||
|
except requests.HTTPError as e:
|
||||||
|
detail = ""
|
||||||
|
try:
|
||||||
|
detail = (e.response.text or "")[:500]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
prov = "OpenAI" if _use_openai() else "Ollama"
|
||||||
|
return f"AI 调用失败({prov} HTTP {e.response.status_code if e.response else '?'}):{detail or str(e)}"
|
||||||
|
except Exception as e:
|
||||||
|
prov = "OpenAI" if _use_openai() else "Ollama"
|
||||||
|
return f"AI 调用失败({prov}):{str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
_CHAT_CONTINUE_USER = (
|
||||||
|
"你上一条回复在中途截断了。请从断点处继续写完,不要重复已写内容,"
|
||||||
|
"保持同一语气,写完给出完整结尾。"
|
||||||
|
)
|
||||||
|
_CHAT_END_CHARS = "。!?.!?\"」』))>】"
|
||||||
|
|
||||||
|
|
||||||
|
def _looks_truncated(text: str) -> bool:
|
||||||
|
t = (text or "").rstrip()
|
||||||
|
if len(t) < 48:
|
||||||
|
return False
|
||||||
|
if t[-1] in _CHAT_END_CHARS:
|
||||||
|
return False
|
||||||
|
if t.endswith("…") or t.endswith("..."):
|
||||||
|
return True
|
||||||
|
return t[-1] not in ",、,;;::\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _should_continue(reason: str, chunk: str) -> bool:
|
||||||
|
if reason == "length":
|
||||||
|
return True
|
||||||
|
return _looks_truncated(chunk)
|
||||||
|
|
||||||
|
|
||||||
|
def ai_generate_chat(
|
||||||
|
*,
|
||||||
|
system: str,
|
||||||
|
user: str,
|
||||||
|
temperature: float = 0.5,
|
||||||
|
images_b64: Optional[Sequence[str]] = None,
|
||||||
|
max_tokens: int = 8192,
|
||||||
|
max_continuations: int = 3,
|
||||||
|
) -> str:
|
||||||
|
"""聊天专用:system/user 分消息;输出触顶时自动续写。"""
|
||||||
|
images = _collect_images(None, images_b64)
|
||||||
|
try:
|
||||||
|
if _use_openai():
|
||||||
|
messages: list[dict] = [
|
||||||
|
{"role": "system", "content": system.strip()},
|
||||||
|
]
|
||||||
|
if images:
|
||||||
|
content: List[dict] = [{"type": "text", "text": user.strip()}]
|
||||||
|
for b64, mime in images:
|
||||||
|
content.append(
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": f"data:{mime};base64,{b64}"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
messages.append({"role": "user", "content": content})
|
||||||
|
else:
|
||||||
|
messages.append({"role": "user", "content": user.strip()})
|
||||||
|
|
||||||
|
parts: list[str] = []
|
||||||
|
for _ in range(max(1, int(max_continuations) + 1)):
|
||||||
|
chunk, reason = _openai_chat_completion(
|
||||||
|
messages,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
image_count=len(images),
|
||||||
|
chat=True,
|
||||||
|
)
|
||||||
|
if chunk.startswith("AI 调用失败") or chunk.startswith("AI 生成失败"):
|
||||||
|
return chunk if not parts else "".join(parts)
|
||||||
|
parts.append(chunk)
|
||||||
|
if not _should_continue(reason, chunk):
|
||||||
|
break
|
||||||
|
messages.append({"role": "assistant", "content": chunk})
|
||||||
|
messages.append({"role": "user", "content": _CHAT_CONTINUE_USER})
|
||||||
|
return "".join(parts).strip() or "AI 生成失败:空内容"
|
||||||
|
|
||||||
|
prompt = f"{system.strip()}\n\n---\n\n{user.strip()}"
|
||||||
|
parts = []
|
||||||
|
current_prompt = prompt
|
||||||
|
for _ in range(max(1, int(max_continuations) + 1)):
|
||||||
|
chunk, reason = _generate_ollama(
|
||||||
|
current_prompt,
|
||||||
|
images if not parts else [],
|
||||||
|
temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
chat=True,
|
||||||
|
)
|
||||||
|
if chunk.startswith("AI 生成失败") and not parts:
|
||||||
|
return chunk
|
||||||
|
parts.append(chunk)
|
||||||
|
if not _should_continue(reason, chunk):
|
||||||
|
break
|
||||||
|
tail = chunk[-400:] if len(chunk) > 400 else chunk
|
||||||
|
current_prompt = (
|
||||||
|
f"{prompt}\n\n{''.join(parts)}\n\n"
|
||||||
|
f"{_CHAT_CONTINUE_USER}\n\n"
|
||||||
|
f"(已写结尾片段供衔接:…{tail})"
|
||||||
|
)
|
||||||
|
return "".join(parts).strip() or "AI 生成失败"
|
||||||
except requests.HTTPError as e:
|
except requests.HTTPError as e:
|
||||||
detail = ""
|
detail = ""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -85,6 +85,10 @@ HUB_TRUST_LAN=true
|
|||||||
# 与四实例相同变量名;默认 OpenAI 兼容网关(改 AI_PROVIDER=ollama 可走本机 Ollama)
|
# 与四实例相同变量名;默认 OpenAI 兼容网关(改 AI_PROVIDER=ollama 可走本机 Ollama)
|
||||||
# 详见 manual_trading_hub/AI教练说明.md 与仓库根 AI复盘与模型配置说明.md
|
# 详见 manual_trading_hub/AI教练说明.md 与仓库根 AI复盘与模型配置说明.md
|
||||||
AI_TIMEOUT_SECONDS=120
|
AI_TIMEOUT_SECONDS=120
|
||||||
|
# AI 教练聊天:单次输出 token 上限与截断自动续写次数(默认 8192 / 3)
|
||||||
|
# CHAT_MAX_OUTPUT_TOKENS=8192
|
||||||
|
# CHAT_MAX_CONTINUATIONS=3
|
||||||
|
# CHAT_AI_TIMEOUT_SECONDS=300
|
||||||
|
|
||||||
# AI 提供方:openai(默认,OpenAI 兼容网关)| ollama(本机 Ollama)
|
# AI 提供方:openai(默认,OpenAI 兼容网关)| ollama(本机 Ollama)
|
||||||
AI_PROVIDER=openai
|
AI_PROVIDER=openai
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from hub_ai.attachments import parse_chat_attachments
|
|||||||
from hub_ai.client import generate_text, model_label
|
from hub_ai.client import generate_text, model_label
|
||||||
from hub_ai.config import (
|
from hub_ai.config import (
|
||||||
CHAT_CONTEXT_MAX_CHARS,
|
CHAT_CONTEXT_MAX_CHARS,
|
||||||
|
CHAT_MAX_CONTINUATIONS,
|
||||||
CHAT_MAX_HISTORY_TURNS,
|
CHAT_MAX_HISTORY_TURNS,
|
||||||
CHAT_MAX_OUTPUT_TOKENS,
|
CHAT_MAX_OUTPUT_TOKENS,
|
||||||
CHAT_SUMMARY_EXCERPT_MAX_CHARS,
|
CHAT_SUMMARY_EXCERPT_MAX_CHARS,
|
||||||
@@ -107,6 +108,7 @@ def send_chat_message(
|
|||||||
temperature=CHAT_TEMPERATURE,
|
temperature=CHAT_TEMPERATURE,
|
||||||
images_b64=parsed.get("images_b64") or None,
|
images_b64=parsed.get("images_b64") or None,
|
||||||
max_tokens=CHAT_MAX_OUTPUT_TOKENS,
|
max_tokens=CHAT_MAX_OUTPUT_TOKENS,
|
||||||
|
max_continuations=CHAT_MAX_CONTINUATIONS,
|
||||||
)
|
)
|
||||||
if reply.startswith("AI 调用失败"):
|
if reply.startswith("AI 调用失败"):
|
||||||
return {"ok": False, "msg": reply, "session_id": sid}
|
return {"ok": False, "msg": reply, "session_id": sid}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ _REPO_ROOT = Path(__file__).resolve().parents[2]
|
|||||||
if str(_REPO_ROOT) not in sys.path:
|
if str(_REPO_ROOT) not in sys.path:
|
||||||
sys.path.insert(0, str(_REPO_ROOT))
|
sys.path.insert(0, str(_REPO_ROOT))
|
||||||
|
|
||||||
from ai_client import ai_generate, ai_provider_label # noqa: E402
|
from ai_client import ai_generate, ai_generate_chat, ai_provider_label # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def model_label() -> str:
|
def model_label() -> str:
|
||||||
@@ -23,11 +23,20 @@ def generate_text(
|
|||||||
temperature: float,
|
temperature: float,
|
||||||
images_b64: Optional[Sequence[str]] = None,
|
images_b64: Optional[Sequence[str]] = None,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
|
max_continuations: int = 3,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
if max_tokens is not None and max_tokens > 0:
|
||||||
|
return ai_generate_chat(
|
||||||
|
system=system,
|
||||||
|
user=user,
|
||||||
|
temperature=temperature,
|
||||||
|
images_b64=images_b64,
|
||||||
|
max_tokens=int(max_tokens),
|
||||||
|
max_continuations=max_continuations,
|
||||||
|
)
|
||||||
prompt = f"{system.strip()}\n\n---\n\n{user.strip()}"
|
prompt = f"{system.strip()}\n\n---\n\n{user.strip()}"
|
||||||
return ai_generate(
|
return ai_generate(
|
||||||
prompt,
|
prompt,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
images_b64=images_b64,
|
images_b64=images_b64,
|
||||||
max_tokens=max_tokens,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,10 +5,19 @@ import os
|
|||||||
|
|
||||||
HUB_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
HUB_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
|
||||||
|
def _int_env(key: str, default: int) -> int:
|
||||||
|
try:
|
||||||
|
return int(os.getenv(key, str(default)) or default)
|
||||||
|
except ValueError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
SUMMARY_TEMPERATURE = 0.15
|
SUMMARY_TEMPERATURE = 0.15
|
||||||
CHAT_TEMPERATURE = 0.5
|
CHAT_TEMPERATURE = 0.5
|
||||||
CHAT_MAX_HISTORY_TURNS = 40
|
CHAT_MAX_HISTORY_TURNS = 40
|
||||||
CHAT_MAX_OUTPUT_TOKENS = 2048
|
CHAT_MAX_OUTPUT_TOKENS = _int_env("CHAT_MAX_OUTPUT_TOKENS", 8192)
|
||||||
|
CHAT_MAX_CONTINUATIONS = _int_env("CHAT_MAX_CONTINUATIONS", 3)
|
||||||
CHAT_CONTEXT_MAX_CHARS = 128_000
|
CHAT_CONTEXT_MAX_CHARS = 128_000
|
||||||
CHAT_SUMMARY_EXCERPT_MAX_CHARS = 8000
|
CHAT_SUMMARY_EXCERPT_MAX_CHARS = 8000
|
||||||
SUMMARY_RETENTION_DAYS = 90
|
SUMMARY_RETENTION_DAYS = 90
|
||||||
|
|||||||
Reference in New Issue
Block a user