fix: AI coach chat continuation now carries full draft text
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
+53
-28
@@ -7,6 +7,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from typing import List, Optional, Sequence, Tuple
|
from typing import List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@@ -259,26 +260,43 @@ def ai_generate(
|
|||||||
|
|
||||||
_CHAT_CONTINUE_USER = (
|
_CHAT_CONTINUE_USER = (
|
||||||
"你上一条回复在中途截断了。请从断点处继续写完,不要重复已写内容,"
|
"你上一条回复在中途截断了。请从断点处继续写完,不要重复已写内容,"
|
||||||
"保持同一语气,写完给出完整结尾。"
|
"保持同一语气;编号列表每条单独一行。"
|
||||||
)
|
)
|
||||||
_CHAT_END_CHARS = "。!?.!?\"」』))>】"
|
_CHAT_END_CHARS = "。!?.!?\"」』))>】"
|
||||||
|
_INCOMPLETE_TAIL_RE = re.compile(
|
||||||
|
r"(会不会|是不是|够不够|能不能|要不要|如何|怎么|什么|哪里|多少|对吗|怎么样|"
|
||||||
|
r"这个\.\.\.|这个…|\.\.\.\d+\.|\d+\.)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _looks_truncated(text: str) -> bool:
|
def _looks_truncated(text: str) -> bool:
|
||||||
t = (text or "").rstrip()
|
t = (text or "").rstrip()
|
||||||
if len(t) < 48:
|
if len(t) < 16:
|
||||||
return False
|
return False
|
||||||
if t[-1] in _CHAT_END_CHARS:
|
if t[-1] in _CHAT_END_CHARS:
|
||||||
return False
|
return False
|
||||||
|
if _INCOMPLETE_TAIL_RE.search(t):
|
||||||
|
return True
|
||||||
if t.endswith("…") or t.endswith("..."):
|
if t.endswith("…") or t.endswith("..."):
|
||||||
return True
|
return True
|
||||||
|
if re.search(r"\d+\.\s*$", t):
|
||||||
|
return True
|
||||||
return t[-1] not in ",、,;;::\n"
|
return t[-1] not in ",、,;;::\n"
|
||||||
|
|
||||||
|
|
||||||
def _should_continue(reason: str, chunk: str) -> bool:
|
def _should_continue(reason: str, full_text: str) -> bool:
|
||||||
if reason == "length":
|
if reason in ("length", "max_tokens", "model_length"):
|
||||||
return True
|
return True
|
||||||
return _looks_truncated(chunk)
|
return _looks_truncated(full_text)
|
||||||
|
|
||||||
|
|
||||||
|
def _chat_continue_message(full_text: str) -> str:
|
||||||
|
tail = full_text[-900:] if len(full_text) > 900 else full_text
|
||||||
|
return (
|
||||||
|
f"{_CHAT_CONTINUE_USER}\n\n"
|
||||||
|
f"已写到最后这几句:\n「{tail}」\n\n"
|
||||||
|
f"请从断点接着写完。不要重复前文;最后一句话必须以句号、问号或感叹号结束。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def ai_generate_chat(
|
def ai_generate_chat(
|
||||||
@@ -290,28 +308,30 @@ def ai_generate_chat(
|
|||||||
max_tokens: int = 8192,
|
max_tokens: int = 8192,
|
||||||
max_continuations: int = 3,
|
max_continuations: int = 3,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""聊天专用:system/user 分消息;输出触顶时自动续写。"""
|
"""聊天专用:system/user 分消息;输出触顶时自动续写(携带已写全文)。"""
|
||||||
images = _collect_images(None, images_b64)
|
images = _collect_images(None, images_b64)
|
||||||
|
max_rounds = max(1, int(max_continuations) + 1)
|
||||||
try:
|
try:
|
||||||
if _use_openai():
|
if _use_openai():
|
||||||
messages: list[dict] = [
|
|
||||||
{"role": "system", "content": system.strip()},
|
|
||||||
]
|
|
||||||
if images:
|
if images:
|
||||||
content: List[dict] = [{"type": "text", "text": user.strip()}]
|
user_content: List[dict] | str = [{"type": "text", "text": user.strip()}]
|
||||||
for b64, mime in images:
|
for b64, mime in images:
|
||||||
content.append(
|
user_content.append(
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {"url": f"data:{mime};base64,{b64}"},
|
"image_url": {"url": f"data:{mime};base64,{b64}"},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
messages.append({"role": "user", "content": content})
|
|
||||||
else:
|
else:
|
||||||
messages.append({"role": "user", "content": user.strip()})
|
user_content = user.strip()
|
||||||
|
base_user_msg = {"role": "user", "content": user_content}
|
||||||
|
messages: list[dict] = [
|
||||||
|
{"role": "system", "content": system.strip()},
|
||||||
|
base_user_msg,
|
||||||
|
]
|
||||||
|
|
||||||
parts: list[str] = []
|
parts: list[str] = []
|
||||||
for _ in range(max(1, int(max_continuations) + 1)):
|
for attempt in range(max_rounds):
|
||||||
chunk, reason = _openai_chat_completion(
|
chunk, reason = _openai_chat_completion(
|
||||||
messages,
|
messages,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -320,18 +340,28 @@ def ai_generate_chat(
|
|||||||
chat=True,
|
chat=True,
|
||||||
)
|
)
|
||||||
if chunk.startswith("AI 调用失败") or chunk.startswith("AI 生成失败"):
|
if chunk.startswith("AI 调用失败") or chunk.startswith("AI 生成失败"):
|
||||||
return chunk if not parts else "".join(parts)
|
return chunk if not parts else "".join(parts).strip()
|
||||||
parts.append(chunk)
|
parts.append(chunk)
|
||||||
if not _should_continue(reason, chunk):
|
full = "".join(parts)
|
||||||
|
if not _should_continue(reason, full) or attempt >= max_rounds - 1:
|
||||||
break
|
break
|
||||||
messages.append({"role": "assistant", "content": chunk})
|
messages = [
|
||||||
messages.append({"role": "user", "content": _CHAT_CONTINUE_USER})
|
{"role": "system", "content": system.strip()},
|
||||||
|
base_user_msg,
|
||||||
|
{"role": "assistant", "content": full},
|
||||||
|
{"role": "user", "content": _chat_continue_message(full)},
|
||||||
|
]
|
||||||
return "".join(parts).strip() or "AI 生成失败:空内容"
|
return "".join(parts).strip() or "AI 生成失败:空内容"
|
||||||
|
|
||||||
prompt = f"{system.strip()}\n\n---\n\n{user.strip()}"
|
prompt = f"{system.strip()}\n\n---\n\n{user.strip()}"
|
||||||
parts = []
|
parts: list[str] = []
|
||||||
current_prompt = prompt
|
for attempt in range(max_rounds):
|
||||||
for _ in range(max(1, int(max_continuations) + 1)):
|
current_prompt = prompt
|
||||||
|
if parts:
|
||||||
|
full = "".join(parts)
|
||||||
|
current_prompt = (
|
||||||
|
f"{prompt}\n\n【你已写道】\n{full}\n\n{_chat_continue_message(full)}"
|
||||||
|
)
|
||||||
chunk, reason = _generate_ollama(
|
chunk, reason = _generate_ollama(
|
||||||
current_prompt,
|
current_prompt,
|
||||||
images if not parts else [],
|
images if not parts else [],
|
||||||
@@ -342,14 +372,9 @@ def ai_generate_chat(
|
|||||||
if chunk.startswith("AI 生成失败") and not parts:
|
if chunk.startswith("AI 生成失败") and not parts:
|
||||||
return chunk
|
return chunk
|
||||||
parts.append(chunk)
|
parts.append(chunk)
|
||||||
if not _should_continue(reason, chunk):
|
full = "".join(parts)
|
||||||
|
if not _should_continue(reason, full) or attempt >= max_rounds - 1:
|
||||||
break
|
break
|
||||||
tail = chunk[-400:] if len(chunk) > 400 else chunk
|
|
||||||
current_prompt = (
|
|
||||||
f"{prompt}\n\n{''.join(parts)}\n\n"
|
|
||||||
f"{_CHAT_CONTINUE_USER}\n\n"
|
|
||||||
f"(已写结尾片段供衔接:…{tail})"
|
|
||||||
)
|
|
||||||
return "".join(parts).strip() or "AI 生成失败"
|
return "".join(parts).strip() or "AI 生成失败"
|
||||||
except requests.HTTPError as e:
|
except requests.HTTPError as e:
|
||||||
detail = ""
|
detail = ""
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ HUB_TRUST_LAN=true
|
|||||||
AI_TIMEOUT_SECONDS=120
|
AI_TIMEOUT_SECONDS=120
|
||||||
# AI 教练聊天:单次输出 token 上限与截断自动续写次数(默认 8192 / 3)
|
# AI 教练聊天:单次输出 token 上限与截断自动续写次数(默认 8192 / 3)
|
||||||
# CHAT_MAX_OUTPUT_TOKENS=8192
|
# CHAT_MAX_OUTPUT_TOKENS=8192
|
||||||
# CHAT_MAX_CONTINUATIONS=3
|
# CHAT_MAX_CONTINUATIONS=8
|
||||||
# CHAT_AI_TIMEOUT_SECONDS=300
|
# CHAT_AI_TIMEOUT_SECONDS=300
|
||||||
|
|
||||||
# AI 提供方:openai(默认,OpenAI 兼容网关)| ollama(本机 Ollama)
|
# AI 提供方:openai(默认,OpenAI 兼容网关)| ollama(本机 Ollama)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ SUMMARY_TEMPERATURE = 0.15
|
|||||||
CHAT_TEMPERATURE = 0.5
|
CHAT_TEMPERATURE = 0.5
|
||||||
CHAT_MAX_HISTORY_TURNS = 40
|
CHAT_MAX_HISTORY_TURNS = 40
|
||||||
CHAT_MAX_OUTPUT_TOKENS = _int_env("CHAT_MAX_OUTPUT_TOKENS", 8192)
|
CHAT_MAX_OUTPUT_TOKENS = _int_env("CHAT_MAX_OUTPUT_TOKENS", 8192)
|
||||||
CHAT_MAX_CONTINUATIONS = _int_env("CHAT_MAX_CONTINUATIONS", 3)
|
CHAT_MAX_CONTINUATIONS = _int_env("CHAT_MAX_CONTINUATIONS", 8)
|
||||||
CHAT_CONTEXT_MAX_CHARS = 128_000
|
CHAT_CONTEXT_MAX_CHARS = 128_000
|
||||||
CHAT_SUMMARY_EXCERPT_MAX_CHARS = 8000
|
CHAT_SUMMARY_EXCERPT_MAX_CHARS = 8000
|
||||||
SUMMARY_RETENTION_DAYS = 90
|
SUMMARY_RETENTION_DAYS = 90
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ CHAT_SYSTEM = """
|
|||||||
- 若附带「今日总结摘要」,那是较早生成的缓存,**实盘持仓以【当前多账户快照】里的「实盘持仓总览」为准**,摘要里若提到持仓可能已过时。
|
- 若附带「今日总结摘要」,那是较早生成的缓存,**实盘持仓以【当前多账户快照】里的「实盘持仓总览」为准**,摘要里若提到持仓可能已过时。
|
||||||
- 若用户上传图片,可结合图中可见信息讨论,看不清的明确说看不清。
|
- 若用户上传图片,可结合图中可见信息讨论,看不清的明确说看不清。
|
||||||
- **优先接住【用户现在说】和【此前对话】**:用户聊心态、悔单、某笔操作时,先顺着这个话题回应,不要每句都复述账户资金数字。
|
- **优先接住【用户现在说】和【此前对话】**:用户聊心态、悔单、某笔操作时,先顺着这个话题回应,不要每句都复述账户资金数字。
|
||||||
- **接续对话**:有【此前对话】时须接着聊,不要重复开场白,回复写完整,不要说到一半戛然而止。
|
- **接续对话**:有【此前对话】时须接着聊,不要重复开场白;整段回复必须写完,以句号/问号/感叹号收尾,不得停在半句话;编号列表每条单独一行。
|
||||||
- 快照里的盈亏/资金仅在需要核对事实时引用;用户口述与快照冲突时,以快照为准并口语说明。
|
- 快照里的盈亏/资金仅在需要核对事实时引用;用户口述与快照冲突时,以快照为准并口语说明。
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user