Fix ChatTTS recursion depth exceeded on empty generation.

Disable ensure_non_empty retries, set min_new_token, always refine text, and use per-chunk manual_seed.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-12 17:10:26 +08:00
parent 0cce6cda7c
commit 1779449bba
3 changed files with 51 additions and 7 deletions
+1
View File
@@ -17,3 +17,4 @@ OLLAMA_PORT=11434
# 8GB 显存 OOM 时可调低(合成按段切分) # 8GB 显存 OOM 时可调低(合成按段切分)
# TTS_MAX_CHARS_PER_CHUNK=150 # TTS_MAX_CHARS_PER_CHUNK=150
# TTS_MAX_NEW_TOKEN=768 # TTS_MAX_NEW_TOKEN=768
# TTS_MIN_NEW_TOKEN=16
+3
View File
@@ -142,6 +142,9 @@ TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 200)
# ChatTTS 单段最大生成 token(越小越省显存,长句会自动切多段) # ChatTTS 单段最大生成 token(越小越省显存,长句会自动切多段)
TTS_MAX_NEW_TOKEN = _env_int("TTS_MAX_NEW_TOKEN", 1024) TTS_MAX_NEW_TOKEN = _env_int("TTS_MAX_NEW_TOKEN", 1024)
# 至少生成多少 audio token 才允许结束(防止首 token EOS → 无限递归重试)
TTS_MIN_NEW_TOKEN = _env_int("TTS_MIN_NEW_TOKEN", 16)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# 上传临时文件目录 # 上传临时文件目录
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
+47 -7
View File
@@ -12,6 +12,7 @@ import re
import traceback import traceback
import uuid import uuid
import warnings import warnings
from dataclasses import replace
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
@@ -32,6 +33,7 @@ from config import (
SPEAKER_SAMPLE_MIN_SEC, SPEAKER_SAMPLE_MIN_SEC,
TTS_MAX_CHARS_PER_CHUNK, TTS_MAX_CHARS_PER_CHUNK,
TTS_MAX_NEW_TOKEN, TTS_MAX_NEW_TOKEN,
TTS_MIN_NEW_TOKEN,
TTS_SAMPLE_RATE, TTS_SAMPLE_RATE,
TTS_SPEED_PROMPT, TTS_SPEED_PROMPT,
TTS_TEMPERATURE, TTS_TEMPERATURE,
@@ -641,10 +643,14 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
top_P=TTS_TOP_P, top_P=TTS_TOP_P,
top_K=TTS_TOP_K, top_K=TTS_TOP_K,
max_new_token=TTS_MAX_NEW_TOKEN, max_new_token=TTS_MAX_NEW_TOKEN,
min_new_token=TTS_MIN_NEW_TOKEN,
ensure_non_empty=False,
) )
params_refine_text = ChatTTS.Chat.RefineTextParams( params_refine_text = ChatTTS.Chat.RefineTextParams(
prompt="[oral_2][laugh_0][break_4]", prompt="[oral_2][laugh_0][break_4]",
ensure_non_empty=False,
min_new_token=4,
) )
logger.info( logger.info(
@@ -657,12 +663,31 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
segment_wavs: List[np.ndarray] = [] segment_wavs: List[np.ndarray] = []
for idx, chunk in enumerate(chunks, start=1): for idx, chunk in enumerate(chunks, start=1):
release_cuda_cache() release_cuda_cache()
wavs = chat.infer( # manual_seed 每段不同;ensure_non_empty=False 避免空输出时无限递归
chunk, chunk_infer = replace(params_infer_code, manual_seed=42 + idx)
skip_refine_text=(idx > 1), wavs = None
params_refine_text=params_refine_text, last_exc: Optional[BaseException] = None
params_infer_code=params_infer_code, for attempt in range(3):
) try:
wavs = chat.infer(
chunk,
skip_refine_text=False,
params_refine_text=params_refine_text,
params_infer_code=chunk_infer,
)
break
except RecursionError as exc:
last_exc = exc
chunk_infer.manual_seed = 1000 + idx * 10 + attempt
release_cuda_cache()
if wavs is None:
return (
False,
f"ChatTTS 第 {idx}/{len(chunks)} 段合成失败(递归重试耗尽)。"
f"请检查音色转写是否填写,或缩短该段文本。"
f" 详情: {last_exc}",
None,
)
if not wavs or len(wavs) == 0: if not wavs or len(wavs) == 0:
return ( return (
False, False,
@@ -670,7 +695,15 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
f"(段内容前 40 字: {chunk[:40]}…)", f"(段内容前 40 字: {chunk[:40]}…)",
None, None,
) )
segment_wavs.append(np.asarray(wavs[0], dtype=np.float32)) wav_arr = np.asarray(wavs[0], dtype=np.float32)
if wav_arr.size == 0 or np.max(np.abs(wav_arr)) < 1e-6:
return (
False,
f"ChatTTS 第 {idx}/{len(chunks)} 段生成了空音频。"
"请重新锁定音色并填写参考转写,或缩短润色稿后重试。",
None,
)
segment_wavs.append(wav_arr)
release_cuda_cache() release_cuda_cache()
wav_array = ( wav_array = (
@@ -712,6 +745,13 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
"4. 确认无其他程序占用 GPU: nvidia-smi\n" "4. 确认无其他程序占用 GPU: nvidia-smi\n"
f"技术详情: {exc_msg[:400]}" f"技术详情: {exc_msg[:400]}"
) )
elif "recursion depth" in exc_msg.lower() or isinstance(exc, RecursionError):
err = (
"语音合成失败: ChatTTS 反复生成空结果导致递归超限。\n"
"常见原因:未填写参考音频转写、润色稿含特殊符号、或音色文件异常。\n"
"处理:重新锁定音色并填写转写 → 用较短纯文本试合成。\n"
f"技术详情: {exc_msg[:400]}"
)
elif "Corrupt input data" in exc_msg: elif "Corrupt input data" in exc_msg:
err = ( err = (
"语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n" "语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n"