fix: AI coach chat continuation now carries full draft text
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
+53
-28
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
import requests
|
||||
@@ -259,26 +260,43 @@ def ai_generate(
|
||||
|
||||
# Follow-up user prompt asking the model to resume a reply that was cut off
# mid-output, continuing from the break without repeating earlier text.
# NOTE(review): this view is a rendered diff, so the last two string lines look
# like the removed and the added variant of the same line -- confirm which one
# the file actually keeps.
_CHAT_CONTINUE_USER = (
|
||||
"你上一条回复在中途截断了。请从断点处继续写完,不要重复已写内容,"
|
||||
"保持同一语气,写完给出完整结尾。"
|
||||
"保持同一语气;编号列表每条单独一行。"
|
||||
)
|
||||
# Characters accepted as a clean ending (sentence punctuation, closing quotes
# and closing brackets, half- and full-width).
_CHAT_END_CHARS = "。!?.!?\"」』))>】"

# Tails that signal a reply stopped mid-thought: a dangling question word,
# a trailing "这个..." / "这个…", or a bare list marker such as "3.".
_INCOMPLETE_TAIL_RE = re.compile(
    r"(会不会|是不是|够不够|能不能|要不要|如何|怎么|什么|哪里|多少|对吗|怎么样|"
    r"这个\.\.\.|这个…|\.\.\.\d+\.|\d+\.)$"
)


def _looks_truncated(text: str) -> bool:
    """Heuristically decide whether *text* was cut off before it finished.

    Args:
        text: The reply accumulated so far; ``None`` or empty is accepted.

    Returns:
        ``True`` when the trailing characters suggest an unfinished reply,
        ``False`` when the reply is too short to judge or ends cleanly.
    """
    t = (text or "").rstrip()
    if len(t) < 16:
        # Too short to make a meaningful call.
        return False
    # These checks must run before the end-character test: "." is itself an
    # end character, so a trailing "..." or a bare list marker like "3."
    # would otherwise be reported as a finished sentence.
    if t.endswith(("…", "...")):
        return True
    if _INCOMPLETE_TAIL_RE.search(t):
        return True
    if t[-1] in _CHAT_END_CHARS:
        return False
    # A trailing comma/colon-style separator is not reported as truncated;
    # any other non-terminal character is.
    return t[-1] not in ",、,;;::\n"
|
||||
|
||||
|
||||
def _should_continue(reason: str, full_text: str) -> bool:
    """Decide whether another continuation round should be requested.

    Args:
        reason: Finish/stop reason reported by the backend for the last
            round. Compared case-insensitively; ``None`` is tolerated.
        full_text: Everything generated so far across all rounds.

    Returns:
        ``True`` when the backend reports a length stop, otherwise whatever
        ``_looks_truncated`` says about *full_text*.
    """
    # Normalize so spellings such as "MAX_TOKENS" or " length " still count
    # as a length stop.
    normalized = str(reason or "").strip().lower()
    if normalized in ("length", "max_tokens", "model_length"):
        return True
    return _looks_truncated(full_text)
|
||||
|
||||
|
||||
def _chat_continue_message(full_text: str) -> str:
    """Build the follow-up user message that asks the model to keep writing.

    The message is the shared continuation prompt, followed by a quote of the
    last 900 characters already written (or all of *full_text* when shorter),
    followed by a closing instruction.

    Args:
        full_text: The reply text produced so far.

    Returns:
        The assembled prompt string.
    """
    # A negative slice already clamps to the string length, so no explicit
    # length check is needed.
    recent = full_text[-900:]
    header = f"{_CHAT_CONTINUE_USER}\n\n"
    quoted = f"已写到最后这几句:\n「{recent}」\n\n"
    closing = "请从断点接着写完。不要重复前文;最后一句话必须以句号、问号或感叹号结束。"
    return header + quoted + closing
|
||||
|
||||
|
||||
def ai_generate_chat(
|
||||
@@ -290,28 +308,30 @@ def ai_generate_chat(
|
||||
max_tokens: int = 8192,
|
||||
max_continuations: int = 3,
|
||||
) -> str:
|
||||
"""聊天专用:system/user 分消息;输出触顶时自动续写。"""
|
||||
"""聊天专用:system/user 分消息;输出触顶时自动续写(携带已写全文)。"""
|
||||
images = _collect_images(None, images_b64)
|
||||
max_rounds = max(1, int(max_continuations) + 1)
|
||||
try:
|
||||
if _use_openai():
|
||||
messages: list[dict] = [
|
||||
{"role": "system", "content": system.strip()},
|
||||
]
|
||||
if images:
|
||||
content: List[dict] = [{"type": "text", "text": user.strip()}]
|
||||
user_content: List[dict] | str = [{"type": "text", "text": user.strip()}]
|
||||
for b64, mime in images:
|
||||
content.append(
|
||||
user_content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{mime};base64,{b64}"},
|
||||
}
|
||||
)
|
||||
messages.append({"role": "user", "content": content})
|
||||
else:
|
||||
messages.append({"role": "user", "content": user.strip()})
|
||||
user_content = user.strip()
|
||||
base_user_msg = {"role": "user", "content": user_content}
|
||||
messages: list[dict] = [
|
||||
{"role": "system", "content": system.strip()},
|
||||
base_user_msg,
|
||||
]
|
||||
|
||||
parts: list[str] = []
|
||||
for _ in range(max(1, int(max_continuations) + 1)):
|
||||
for attempt in range(max_rounds):
|
||||
chunk, reason = _openai_chat_completion(
|
||||
messages,
|
||||
temperature=temperature,
|
||||
@@ -320,18 +340,28 @@ def ai_generate_chat(
|
||||
chat=True,
|
||||
)
|
||||
if chunk.startswith("AI 调用失败") or chunk.startswith("AI 生成失败"):
|
||||
return chunk if not parts else "".join(parts)
|
||||
return chunk if not parts else "".join(parts).strip()
|
||||
parts.append(chunk)
|
||||
if not _should_continue(reason, chunk):
|
||||
full = "".join(parts)
|
||||
if not _should_continue(reason, full) or attempt >= max_rounds - 1:
|
||||
break
|
||||
messages.append({"role": "assistant", "content": chunk})
|
||||
messages.append({"role": "user", "content": _CHAT_CONTINUE_USER})
|
||||
messages = [
|
||||
{"role": "system", "content": system.strip()},
|
||||
base_user_msg,
|
||||
{"role": "assistant", "content": full},
|
||||
{"role": "user", "content": _chat_continue_message(full)},
|
||||
]
|
||||
return "".join(parts).strip() or "AI 生成失败:空内容"
|
||||
|
||||
prompt = f"{system.strip()}\n\n---\n\n{user.strip()}"
|
||||
parts = []
|
||||
current_prompt = prompt
|
||||
for _ in range(max(1, int(max_continuations) + 1)):
|
||||
parts: list[str] = []
|
||||
for attempt in range(max_rounds):
|
||||
current_prompt = prompt
|
||||
if parts:
|
||||
full = "".join(parts)
|
||||
current_prompt = (
|
||||
f"{prompt}\n\n【你已写道】\n{full}\n\n{_chat_continue_message(full)}"
|
||||
)
|
||||
chunk, reason = _generate_ollama(
|
||||
current_prompt,
|
||||
images if not parts else [],
|
||||
@@ -342,14 +372,9 @@ def ai_generate_chat(
|
||||
if chunk.startswith("AI 生成失败") and not parts:
|
||||
return chunk
|
||||
parts.append(chunk)
|
||||
if not _should_continue(reason, chunk):
|
||||
full = "".join(parts)
|
||||
if not _should_continue(reason, full) or attempt >= max_rounds - 1:
|
||||
break
|
||||
tail = chunk[-400:] if len(chunk) > 400 else chunk
|
||||
current_prompt = (
|
||||
f"{prompt}\n\n{''.join(parts)}\n\n"
|
||||
f"{_CHAT_CONTINUE_USER}\n\n"
|
||||
f"(已写结尾片段供衔接:…{tail})"
|
||||
)
|
||||
return "".join(parts).strip() or "AI 生成失败"
|
||||
except requests.HTTPError as e:
|
||||
detail = ""
|
||||
|
||||
Reference in New Issue
Block a user