Fix ChatTTS recursion depth exceeded on empty generation.
Disable ensure_non_empty retries, set min_new_token, always refine text, and use per-chunk manual_seed. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -17,3 +17,4 @@ OLLAMA_PORT=11434
|
|||||||
# 8GB 显存 OOM 时可调低(合成按段切分)
|
# 8GB 显存 OOM 时可调低(合成按段切分)
|
||||||
# TTS_MAX_CHARS_PER_CHUNK=150
|
# TTS_MAX_CHARS_PER_CHUNK=150
|
||||||
# TTS_MAX_NEW_TOKEN=768
|
# TTS_MAX_NEW_TOKEN=768
|
||||||
|
# TTS_MIN_NEW_TOKEN=16
|
||||||
|
|||||||
@@ -142,6 +142,9 @@ TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 200)
|
|||||||
# ChatTTS 单段最大生成 token(越小越省显存,长句会自动切多段)
|
# ChatTTS 单段最大生成 token(越小越省显存,长句会自动切多段)
|
||||||
TTS_MAX_NEW_TOKEN = _env_int("TTS_MAX_NEW_TOKEN", 1024)
|
TTS_MAX_NEW_TOKEN = _env_int("TTS_MAX_NEW_TOKEN", 1024)
|
||||||
|
|
||||||
|
# 至少生成多少 audio token 才允许结束(防止首 token EOS → 无限递归重试)
|
||||||
|
TTS_MIN_NEW_TOKEN = _env_int("TTS_MIN_NEW_TOKEN", 16)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# 上传临时文件目录
|
# 上传临时文件目录
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
+43
-3
@@ -12,6 +12,7 @@ import re
|
|||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
|
from dataclasses import replace
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
@@ -32,6 +33,7 @@ from config import (
|
|||||||
SPEAKER_SAMPLE_MIN_SEC,
|
SPEAKER_SAMPLE_MIN_SEC,
|
||||||
TTS_MAX_CHARS_PER_CHUNK,
|
TTS_MAX_CHARS_PER_CHUNK,
|
||||||
TTS_MAX_NEW_TOKEN,
|
TTS_MAX_NEW_TOKEN,
|
||||||
|
TTS_MIN_NEW_TOKEN,
|
||||||
TTS_SAMPLE_RATE,
|
TTS_SAMPLE_RATE,
|
||||||
TTS_SPEED_PROMPT,
|
TTS_SPEED_PROMPT,
|
||||||
TTS_TEMPERATURE,
|
TTS_TEMPERATURE,
|
||||||
@@ -641,10 +643,14 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
|||||||
top_P=TTS_TOP_P,
|
top_P=TTS_TOP_P,
|
||||||
top_K=TTS_TOP_K,
|
top_K=TTS_TOP_K,
|
||||||
max_new_token=TTS_MAX_NEW_TOKEN,
|
max_new_token=TTS_MAX_NEW_TOKEN,
|
||||||
|
min_new_token=TTS_MIN_NEW_TOKEN,
|
||||||
|
ensure_non_empty=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
||||||
prompt="[oral_2][laugh_0][break_4]",
|
prompt="[oral_2][laugh_0][break_4]",
|
||||||
|
ensure_non_empty=False,
|
||||||
|
min_new_token=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -657,11 +663,30 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
|||||||
segment_wavs: List[np.ndarray] = []
|
segment_wavs: List[np.ndarray] = []
|
||||||
for idx, chunk in enumerate(chunks, start=1):
|
for idx, chunk in enumerate(chunks, start=1):
|
||||||
release_cuda_cache()
|
release_cuda_cache()
|
||||||
|
# manual_seed 每段不同;ensure_non_empty=False 避免空输出时无限递归
|
||||||
|
chunk_infer = replace(params_infer_code, manual_seed=42 + idx)
|
||||||
|
wavs = None
|
||||||
|
last_exc: Optional[BaseException] = None
|
||||||
|
for attempt in range(3):
|
||||||
|
try:
|
||||||
wavs = chat.infer(
|
wavs = chat.infer(
|
||||||
chunk,
|
chunk,
|
||||||
skip_refine_text=(idx > 1),
|
skip_refine_text=False,
|
||||||
params_refine_text=params_refine_text,
|
params_refine_text=params_refine_text,
|
||||||
params_infer_code=params_infer_code,
|
params_infer_code=chunk_infer,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except RecursionError as exc:
|
||||||
|
last_exc = exc
|
||||||
|
chunk_infer.manual_seed = 1000 + idx * 10 + attempt
|
||||||
|
release_cuda_cache()
|
||||||
|
if wavs is None:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"ChatTTS 第 {idx}/{len(chunks)} 段合成失败(递归重试耗尽)。"
|
||||||
|
f"请检查音色转写是否填写,或缩短该段文本。"
|
||||||
|
f" 详情: {last_exc}",
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
if not wavs or len(wavs) == 0:
|
if not wavs or len(wavs) == 0:
|
||||||
return (
|
return (
|
||||||
@@ -670,7 +695,15 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
|||||||
f"(段内容前 40 字: {chunk[:40]}…)",
|
f"(段内容前 40 字: {chunk[:40]}…)",
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
segment_wavs.append(np.asarray(wavs[0], dtype=np.float32))
|
wav_arr = np.asarray(wavs[0], dtype=np.float32)
|
||||||
|
if wav_arr.size == 0 or np.max(np.abs(wav_arr)) < 1e-6:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"ChatTTS 第 {idx}/{len(chunks)} 段生成了空音频。"
|
||||||
|
"请重新锁定音色并填写参考转写,或缩短润色稿后重试。",
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
segment_wavs.append(wav_arr)
|
||||||
release_cuda_cache()
|
release_cuda_cache()
|
||||||
|
|
||||||
wav_array = (
|
wav_array = (
|
||||||
@@ -712,6 +745,13 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
|||||||
"4. 确认无其他程序占用 GPU: nvidia-smi\n"
|
"4. 确认无其他程序占用 GPU: nvidia-smi\n"
|
||||||
f"技术详情: {exc_msg[:400]}"
|
f"技术详情: {exc_msg[:400]}"
|
||||||
)
|
)
|
||||||
|
elif "recursion depth" in exc_msg.lower() or isinstance(exc, RecursionError):
|
||||||
|
err = (
|
||||||
|
"语音合成失败: ChatTTS 反复生成空结果导致递归超限。\n"
|
||||||
|
"常见原因:未填写参考音频转写、润色稿含特殊符号、或音色文件异常。\n"
|
||||||
|
"处理:重新锁定音色并填写转写 → 用较短纯文本试合成。\n"
|
||||||
|
f"技术详情: {exc_msg[:400]}"
|
||||||
|
)
|
||||||
elif "Corrupt input data" in exc_msg:
|
elif "Corrupt input data" in exc_msg:
|
||||||
err = (
|
err = (
|
||||||
"语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n"
|
"语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n"
|
||||||
|
|||||||
Reference in New Issue
Block a user