fix: stabilize AI coach chat against truncation and empty replies
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
+45
-15
@@ -124,11 +124,12 @@ def _openai_message_text(msg: dict) -> str:
|
||||
return text
|
||||
|
||||
|
||||
def _apply_max_tokens(body: dict, max_tokens: int | None) -> None:
|
||||
def _apply_max_tokens(body: dict, max_tokens: int | None, *, chat: bool = False) -> None:
|
||||
if max_tokens is not None and max_tokens > 0:
|
||||
mt = int(max_tokens)
|
||||
body["max_tokens"] = mt
|
||||
body["max_completion_tokens"] = mt
|
||||
if not chat:
|
||||
body["max_completion_tokens"] = mt
|
||||
|
||||
|
||||
def _openai_chat_completion(
|
||||
@@ -152,7 +153,7 @@ def _openai_chat_completion(
|
||||
"temperature": temperature,
|
||||
"stream": False,
|
||||
}
|
||||
_apply_max_tokens(body, max_tokens)
|
||||
_apply_max_tokens(body, max_tokens, chat=chat)
|
||||
r = requests.post(
|
||||
_openai_chat_url(),
|
||||
headers=headers,
|
||||
@@ -167,9 +168,27 @@ def _openai_chat_completion(
|
||||
choice = choices[0] or {}
|
||||
msg = choice.get("message") or {}
|
||||
text = _openai_message_text(msg)
|
||||
finish = str(choice.get("finish_reason") or "")
|
||||
if not text and chat and max_tokens:
|
||||
retry_body = dict(body)
|
||||
retry_body.pop("max_completion_tokens", None)
|
||||
r2 = requests.post(
|
||||
_openai_chat_url(),
|
||||
headers=headers,
|
||||
json=retry_body,
|
||||
timeout=_ai_timeout_seconds(image_count=image_count, chat=chat),
|
||||
)
|
||||
r2.raise_for_status()
|
||||
data2 = r2.json()
|
||||
choices2 = data2.get("choices") or []
|
||||
if choices2:
|
||||
msg2 = (choices2[0] or {}).get("message") or {}
|
||||
text2 = _openai_message_text(msg2)
|
||||
if text2:
|
||||
return text2, str((choices2[0] or {}).get("finish_reason") or finish)
|
||||
if not text:
|
||||
return "AI 生成失败:空内容", choice.get("finish_reason") or "error"
|
||||
return text, str(choice.get("finish_reason") or "")
|
||||
return "AI 生成失败:空内容", finish or "error"
|
||||
return text, finish
|
||||
|
||||
|
||||
def _generate_openai(
|
||||
@@ -264,7 +283,7 @@ _CHAT_CONTINUE_USER = (
|
||||
)
|
||||
_CHAT_END_CHARS = "。!?.!?\"」』))>】"
|
||||
_INCOMPLETE_TAIL_RE = re.compile(
|
||||
r"(会不会|是不是|够不够|能不能|要不要|如何|怎么|什么|哪里|多少|对吗|怎么样|"
|
||||
r"(不会|不能|没有|会不会|是不是|够不够|能不能|要不要|如何|怎么|什么|哪里|多少|对吗|怎么样|"
|
||||
r"这个\.\.\.|这个…|\.\.\.\d+\.|\d+\.)$"
|
||||
)
|
||||
|
||||
@@ -291,7 +310,7 @@ def _should_continue(reason: str, full_text: str) -> bool:
|
||||
|
||||
|
||||
def _chat_continue_message(full_text: str) -> str:
|
||||
tail = full_text[-900:] if len(full_text) > 900 else full_text
|
||||
tail = full_text[-500:] if len(full_text) > 500 else full_text
|
||||
return (
|
||||
f"{_CHAT_CONTINUE_USER}\n\n"
|
||||
f"已写到最后这几句:\n「{tail}」\n\n"
|
||||
@@ -299,6 +318,14 @@ def _chat_continue_message(full_text: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _chat_continue_system(system: str) -> str:
|
||||
return (
|
||||
f"{system.strip()}\n\n"
|
||||
"【续写模式】只输出断点后的剩余内容,不要重复前文;"
|
||||
"列表每条单独一行;必须以句号、问号或感叹号收尾。"
|
||||
)
|
||||
|
||||
|
||||
def ai_generate_chat(
|
||||
*,
|
||||
system: str,
|
||||
@@ -306,9 +333,9 @@ def ai_generate_chat(
|
||||
temperature: float = 0.5,
|
||||
images_b64: Optional[Sequence[str]] = None,
|
||||
max_tokens: int = 8192,
|
||||
max_continuations: int = 3,
|
||||
max_continuations: int = 4,
|
||||
) -> str:
|
||||
"""聊天专用:system/user 分消息;输出触顶时自动续写(携带已写全文)。"""
|
||||
"""聊天专用:system/user 分消息;输出触顶时轻量续写(不重复巨型上下文)。"""
|
||||
images = _collect_images(None, images_b64)
|
||||
max_rounds = max(1, int(max_continuations) + 1)
|
||||
try:
|
||||
@@ -336,7 +363,7 @@ def ai_generate_chat(
|
||||
messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
image_count=len(images),
|
||||
image_count=len(images) if attempt == 0 else 0,
|
||||
chat=True,
|
||||
)
|
||||
if chunk.startswith("AI 调用失败") or chunk.startswith("AI 生成失败"):
|
||||
@@ -346,8 +373,7 @@ def ai_generate_chat(
|
||||
if not _should_continue(reason, full) or attempt >= max_rounds - 1:
|
||||
break
|
||||
messages = [
|
||||
{"role": "system", "content": system.strip()},
|
||||
base_user_msg,
|
||||
{"role": "system", "content": _chat_continue_system(system)},
|
||||
{"role": "assistant", "content": full},
|
||||
{"role": "user", "content": _chat_continue_message(full)},
|
||||
]
|
||||
@@ -356,12 +382,14 @@ def ai_generate_chat(
|
||||
prompt = f"{system.strip()}\n\n---\n\n{user.strip()}"
|
||||
parts: list[str] = []
|
||||
for attempt in range(max_rounds):
|
||||
current_prompt = prompt
|
||||
if parts:
|
||||
full = "".join(parts)
|
||||
current_prompt = (
|
||||
f"{prompt}\n\n【你已写道】\n{full}\n\n{_chat_continue_message(full)}"
|
||||
f"{_chat_continue_system(system)}\n\n"
|
||||
f"【你已写道】\n{full}\n\n{_chat_continue_message(full)}"
|
||||
)
|
||||
else:
|
||||
current_prompt = prompt
|
||||
chunk, reason = _generate_ollama(
|
||||
current_prompt,
|
||||
images if not parts else [],
|
||||
@@ -371,11 +399,13 @@ def ai_generate_chat(
|
||||
)
|
||||
if chunk.startswith("AI 生成失败") and not parts:
|
||||
return chunk
|
||||
if chunk.startswith("AI 生成失败"):
|
||||
break
|
||||
parts.append(chunk)
|
||||
full = "".join(parts)
|
||||
if not _should_continue(reason, full) or attempt >= max_rounds - 1:
|
||||
break
|
||||
return "".join(parts).strip() or "AI 生成失败"
|
||||
return "".join(parts).strip() or "AI 生成失败:空内容"
|
||||
except requests.HTTPError as e:
|
||||
detail = ""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user