fix: AI coach chat continuation now carries full draft text

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-11 00:53:45 +08:00
parent 6169fee7b9
commit 7f1015f852
4 changed files with 56 additions and 31 deletions
+53 -28
View File
@@ -7,6 +7,7 @@ from __future__ import annotations
import base64 import base64
import os import os
import re
from typing import List, Optional, Sequence, Tuple from typing import List, Optional, Sequence, Tuple
import requests import requests
@@ -259,26 +260,43 @@ def ai_generate(
_CHAT_CONTINUE_USER = ( _CHAT_CONTINUE_USER = (
"你上一条回复在中途截断了。请从断点处继续写完,不要重复已写内容," "你上一条回复在中途截断了。请从断点处继续写完,不要重复已写内容,"
"保持同一语气,写完给出完整结尾" "保持同一语气;编号列表每条单独一行"
) )
_CHAT_END_CHARS = "。!?.!?\"」』))>】" _CHAT_END_CHARS = "。!?.!?\"」』))>】"
_INCOMPLETE_TAIL_RE = re.compile(
r"(会不会|是不是|够不够|能不能|要不要|如何|怎么|什么|哪里|多少|对吗|怎么样|"
r"这个\.\.\.|这个…|\.\.\.\d+\.|\d+\.)$"
)
def _looks_truncated(text: str) -> bool: def _looks_truncated(text: str) -> bool:
t = (text or "").rstrip() t = (text or "").rstrip()
if len(t) < 48: if len(t) < 16:
return False return False
if t[-1] in _CHAT_END_CHARS: if t[-1] in _CHAT_END_CHARS:
return False return False
if _INCOMPLETE_TAIL_RE.search(t):
return True
if t.endswith("") or t.endswith("..."): if t.endswith("") or t.endswith("..."):
return True return True
if re.search(r"\d+\.\s*$", t):
return True
return t[-1] not in ",、,;:\n" return t[-1] not in ",、,;:\n"
def _should_continue(reason: str, chunk: str) -> bool: def _should_continue(reason: str, full_text: str) -> bool:
if reason == "length": if reason in ("length", "max_tokens", "model_length"):
return True return True
return _looks_truncated(chunk) return _looks_truncated(full_text)
def _chat_continue_message(full_text: str) -> str:
tail = full_text[-900:] if len(full_text) > 900 else full_text
return (
f"{_CHAT_CONTINUE_USER}\n\n"
f"已写到最后这几句:\n{tail}\n\n"
f"请从断点接着写完。不要重复前文;最后一句话必须以句号、问号或感叹号结束。"
)
def ai_generate_chat( def ai_generate_chat(
@@ -290,28 +308,30 @@ def ai_generate_chat(
max_tokens: int = 8192, max_tokens: int = 8192,
max_continuations: int = 3, max_continuations: int = 3,
) -> str: ) -> str:
"""聊天专用:system/user 分消息;输出触顶时自动续写。""" """聊天专用:system/user 分消息;输出触顶时自动续写(携带已写全文)"""
images = _collect_images(None, images_b64) images = _collect_images(None, images_b64)
max_rounds = max(1, int(max_continuations) + 1)
try: try:
if _use_openai(): if _use_openai():
messages: list[dict] = [
{"role": "system", "content": system.strip()},
]
if images: if images:
content: List[dict] = [{"type": "text", "text": user.strip()}] user_content: List[dict] | str = [{"type": "text", "text": user.strip()}]
for b64, mime in images: for b64, mime in images:
content.append( user_content.append(
{ {
"type": "image_url", "type": "image_url",
"image_url": {"url": f"data:{mime};base64,{b64}"}, "image_url": {"url": f"data:{mime};base64,{b64}"},
} }
) )
messages.append({"role": "user", "content": content})
else: else:
messages.append({"role": "user", "content": user.strip()}) user_content = user.strip()
base_user_msg = {"role": "user", "content": user_content}
messages: list[dict] = [
{"role": "system", "content": system.strip()},
base_user_msg,
]
parts: list[str] = [] parts: list[str] = []
for _ in range(max(1, int(max_continuations) + 1)): for attempt in range(max_rounds):
chunk, reason = _openai_chat_completion( chunk, reason = _openai_chat_completion(
messages, messages,
temperature=temperature, temperature=temperature,
@@ -320,18 +340,28 @@ def ai_generate_chat(
chat=True, chat=True,
) )
if chunk.startswith("AI 调用失败") or chunk.startswith("AI 生成失败"): if chunk.startswith("AI 调用失败") or chunk.startswith("AI 生成失败"):
return chunk if not parts else "".join(parts) return chunk if not parts else "".join(parts).strip()
parts.append(chunk) parts.append(chunk)
if not _should_continue(reason, chunk): full = "".join(parts)
if not _should_continue(reason, full) or attempt >= max_rounds - 1:
break break
messages.append({"role": "assistant", "content": chunk}) messages = [
messages.append({"role": "user", "content": _CHAT_CONTINUE_USER}) {"role": "system", "content": system.strip()},
base_user_msg,
{"role": "assistant", "content": full},
{"role": "user", "content": _chat_continue_message(full)},
]
return "".join(parts).strip() or "AI 生成失败:空内容" return "".join(parts).strip() or "AI 生成失败:空内容"
prompt = f"{system.strip()}\n\n---\n\n{user.strip()}" prompt = f"{system.strip()}\n\n---\n\n{user.strip()}"
parts = [] parts: list[str] = []
current_prompt = prompt for attempt in range(max_rounds):
for _ in range(max(1, int(max_continuations) + 1)): current_prompt = prompt
if parts:
full = "".join(parts)
current_prompt = (
f"{prompt}\n\n【你已写道】\n{full}\n\n{_chat_continue_message(full)}"
)
chunk, reason = _generate_ollama( chunk, reason = _generate_ollama(
current_prompt, current_prompt,
images if not parts else [], images if not parts else [],
@@ -342,14 +372,9 @@ def ai_generate_chat(
if chunk.startswith("AI 生成失败") and not parts: if chunk.startswith("AI 生成失败") and not parts:
return chunk return chunk
parts.append(chunk) parts.append(chunk)
if not _should_continue(reason, chunk): full = "".join(parts)
if not _should_continue(reason, full) or attempt >= max_rounds - 1:
break break
tail = chunk[-400:] if len(chunk) > 400 else chunk
current_prompt = (
f"{prompt}\n\n{''.join(parts)}\n\n"
f"{_CHAT_CONTINUE_USER}\n\n"
f"(已写结尾片段供衔接:…{tail}"
)
return "".join(parts).strip() or "AI 生成失败" return "".join(parts).strip() or "AI 生成失败"
except requests.HTTPError as e: except requests.HTTPError as e:
detail = "" detail = ""
+1 -1
View File
@@ -87,7 +87,7 @@ HUB_TRUST_LAN=true
AI_TIMEOUT_SECONDS=120 AI_TIMEOUT_SECONDS=120
# AI 教练聊天:单次输出 token 上限与截断自动续写次数(默认 8192 / 3) # AI 教练聊天:单次输出 token 上限与截断自动续写次数(默认 8192 / 8)
# CHAT_MAX_OUTPUT_TOKENS=8192 # CHAT_MAX_OUTPUT_TOKENS=8192
# CHAT_MAX_CONTINUATIONS=3 # CHAT_MAX_CONTINUATIONS=8
# CHAT_AI_TIMEOUT_SECONDS=300 # CHAT_AI_TIMEOUT_SECONDS=300
# AI 提供方:openai(默认,OpenAI 兼容网关)| ollama(本机 Ollama) # AI 提供方:openai(默认,OpenAI 兼容网关)| ollama(本机 Ollama)
+1 -1
View File
@@ -17,7 +17,7 @@ SUMMARY_TEMPERATURE = 0.15
CHAT_TEMPERATURE = 0.5 CHAT_TEMPERATURE = 0.5
CHAT_MAX_HISTORY_TURNS = 40 CHAT_MAX_HISTORY_TURNS = 40
CHAT_MAX_OUTPUT_TOKENS = _int_env("CHAT_MAX_OUTPUT_TOKENS", 8192) CHAT_MAX_OUTPUT_TOKENS = _int_env("CHAT_MAX_OUTPUT_TOKENS", 8192)
CHAT_MAX_CONTINUATIONS = _int_env("CHAT_MAX_CONTINUATIONS", 3) CHAT_MAX_CONTINUATIONS = _int_env("CHAT_MAX_CONTINUATIONS", 8)
CHAT_CONTEXT_MAX_CHARS = 128_000 CHAT_CONTEXT_MAX_CHARS = 128_000
CHAT_SUMMARY_EXCERPT_MAX_CHARS = 8000 CHAT_SUMMARY_EXCERPT_MAX_CHARS = 8000
SUMMARY_RETENTION_DAYS = 90 SUMMARY_RETENTION_DAYS = 90
+1 -1
View File
@@ -50,7 +50,7 @@ CHAT_SYSTEM = """
- 若附带「今日总结摘要」,那是较早生成的缓存,**实盘持仓以【当前多账户快照】里的「实盘持仓总览」为准**,摘要里若提到持仓可能已过时。 - 若附带「今日总结摘要」,那是较早生成的缓存,**实盘持仓以【当前多账户快照】里的「实盘持仓总览」为准**,摘要里若提到持仓可能已过时。
- 若用户上传图片,可结合图中可见信息讨论,看不清的明确说看不清。 - 若用户上传图片,可结合图中可见信息讨论,看不清的明确说看不清。
- **优先接住【用户现在说】和【此前对话】**:用户聊心态、悔单、某笔操作时,先顺着这个话题回应,不要每句都复述账户资金数字。 - **优先接住【用户现在说】和【此前对话】**:用户聊心态、悔单、某笔操作时,先顺着这个话题回应,不要每句都复述账户资金数字。
- **接续对话**:有【此前对话】时须接着聊,不要重复开场白,回复写完整,不要说到一半戛然而止 - **接续对话**:有【此前对话】时须接着聊,不要重复开场白;整段回复必须写完,以句号/问号/感叹号收尾,不得停在半句话;编号列表每条单独一行
- 快照里的盈亏/资金仅在需要核对事实时引用;用户口述与快照冲突时,以快照为准并口语说明。 - 快照里的盈亏/资金仅在需要核对事实时引用;用户口述与快照冲突时,以快照为准并口语说明。
""".strip() """.strip()