Fix ChatTTS CUDA device-side assert with text sanitization and GPU recovery.
Re-enable the KV cache by default, convert digits to Chinese numerals and strip unsafe characters, disable per-chunk split_text, and reload ChatTTS after CUDA errors. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
+98
-10
@@ -32,6 +32,7 @@ from config import (
|
||||
SPEAKER_SAMPLE_MAX_SEC,
|
||||
SPEAKER_SAMPLE_MIN_SEC,
|
||||
TTS_MAX_CHARS_PER_CHUNK,
|
||||
TTS_ENABLE_CACHE,
|
||||
TTS_MAX_NEW_TOKEN,
|
||||
TTS_MIN_NEW_TOKEN,
|
||||
TTS_SAMPLE_RATE,
|
||||
@@ -95,7 +96,7 @@ def _load_chat_model(chat) -> None:
|
||||
_ensure_hf_env()
|
||||
model_dir = CHATTTS_MODEL_DIR
|
||||
|
||||
base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": False}
|
||||
base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": TTS_ENABLE_CACHE}
|
||||
|
||||
if not hasattr(chat, "load"):
|
||||
if hasattr(chat, "load_models"):
|
||||
@@ -484,6 +485,34 @@ _STAGE_DIRECTION_RE = re.compile(
|
||||
r"[((][^))]{0,80}(?:前奏|转场|语气|背景|BGM|配乐|节奏|环节)[^))]{0,80}[))]"
|
||||
)
|
||||
|
||||
# Chinese numerals indexed by digit value, e.g. _CN_DIGITS[3] == "三".
_CN_DIGITS = "零一二三四五六七八九"

# The ChatTTS tokenizer is sensitive to bare ASCII digits and control
# characters, which can easily trigger a CUDA device-side assert.
# This pattern matches the characters that are stripped outright:
# U+200B-U+200F, U+202A-U+202E, U+FEFF, and the C0 control characters other
# than tab, line feed and carriage return.
_TTS_UNSAFE_CHAR_RE = re.compile(r"[\u200b-\u200f\u202a-\u202e\ufeff\x00-\x08\x0b\x0c\x0e-\x1f]")

# Despite the name, this is a NEGATED class: it matches every character that
# is NOT on the allow-list (CJK ideographs, ASCII letters/digits, the listed
# punctuation, and whitespace), so `.sub("", ...)` deletes disallowed text.
_TTS_ALLOWED_CHAR_RE = re.compile(r"[^\u4e00-\u9fff\u3400-\u4dbfA-Za-z0-9,。!?;:、「」『』()—…\-\s'\"《》%%]")
|
||||
|
||||
|
||||
def _digits_to_chinese(text: str) -> str:
    """Spell out every decimal digit in *text* as its Chinese numeral.

    Digits are converted one at a time with no place-value words, so
    "2024" becomes "二零二四". All other characters are left untouched.
    """
    return re.sub(r"\d", lambda digit: _CN_DIGITS[int(digit.group())], text)
|
||||
|
||||
|
||||
def _normalize_tts_chunk(text: str) -> str:
    """Sanitize a single chunk of text before it is synthesized.

    Steps, in order:
      1. drop characters matched by ``_TTS_UNSAFE_CHAR_RE``;
      2. delete carriage returns and turn each newline into a comma;
      3. convert digits to Chinese numerals;
      4. drop every character matched by ``_TTS_ALLOWED_CHAR_RE``
         (i.e. everything outside the allow-list);
      5. remove all remaining whitespace;
      6. collapse runs of commas into one;
      7. trim leading/trailing commas, "。", spaces and tabs.

    Returns the cleaned string, which may be empty.
    """
    text = _TTS_UNSAFE_CHAR_RE.sub("", text)
    text = text.replace("\r", "").replace("\n", ",")
    text = _digits_to_chinese(text)
    text = _TTS_ALLOWED_CHAR_RE.sub("", text)
    # Whitespace must go BEFORE the comma collapse: otherwise commas that
    # were separated only by whitespace (e.g. "a, \nb" -> "a, ,b") become
    # adjacent afterwards and a doubled comma survives in the output.
    text = re.sub(r"\s+", "", text)
    # NOTE(review): the class lists "," twice; presumably one of them was a
    # full-width comma before this text was normalized - confirm.
    text = re.sub(r"[,,]{2,}", ",", text)
    return text.strip(",。 \t")
|
||||
|
||||
|
||||
def prepare_text_for_tts(text: str) -> str:
|
||||
"""
|
||||
@@ -523,7 +552,8 @@ def prepare_text_for_tts(text: str) -> str:
|
||||
|
||||
lines = [ln.strip() for ln in cleaned.split("\n")]
|
||||
lines = [ln for ln in lines if ln and not re.fullmatch(r"[*\-#]+", ln)]
|
||||
return "\n".join(lines).strip()
|
||||
merged = "。".join(lines)
|
||||
return _normalize_tts_chunk(merged)
|
||||
|
||||
|
||||
def split_text_for_tts(text: str, max_chars: int = TTS_MAX_CHARS_PER_CHUNK) -> List[str]:
|
||||
@@ -558,7 +588,31 @@ def split_text_for_tts(text: str, max_chars: int = TTS_MAX_CHARS_PER_CHUNK) -> L
|
||||
if buf:
|
||||
chunks.append(buf)
|
||||
|
||||
return [c.strip() for c in chunks if c.strip()]
|
||||
return [_normalize_tts_chunk(c) for c in chunks if c.strip()]
|
||||
|
||||
|
||||
def _is_cuda_runtime_error(exc: BaseException) -> bool:
    """Report whether *exc* is a CUDA runtime error, as judged by ``gpu_utils``.

    ``gpu_utils`` is imported when the function is called rather than at
    module import time.
    """
    import gpu_utils

    return gpu_utils.is_cuda_runtime_error(exc)
|
||||
|
||||
|
||||
def _run_chattts_infer(
    chat: Any,
    chunk: str,
    params_refine_text: Any,
    params_infer_code: Any,
) -> Any:
    """Run one ``chat.infer`` call for a single text chunk.

    ``split_text`` is turned off so the chunk is not split again inside the
    call; re-splitting within a segment was observed to cause mask errors.

    Args:
        chat: Object exposing an ``infer`` method.
        chunk: Text of the segment to synthesize.
        params_refine_text: Forwarded unchanged as ``params_refine_text``.
        params_infer_code: Forwarded unchanged as ``params_infer_code``.

    Returns:
        Whatever ``chat.infer`` returns.
    """
    infer_options = {
        "skip_refine_text": False,
        "split_text": False,
        "do_text_normalization": True,
        "do_homophone_replacement": True,
        "params_refine_text": params_refine_text,
        "params_infer_code": params_infer_code,
    }
    return chat.infer(chunk, **infer_options)
|
||||
|
||||
|
||||
def _concat_wavs(
|
||||
@@ -662,24 +716,42 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
|
||||
segment_wavs: List[np.ndarray] = []
|
||||
for idx, chunk in enumerate(chunks, start=1):
|
||||
if not chunk or len(chunk) < 2:
|
||||
continue
|
||||
release_cuda_cache()
|
||||
# manual_seed 每段不同;ensure_non_empty=False 避免空输出时无限递归
|
||||
chunk_infer = replace(params_infer_code, manual_seed=42 + idx)
|
||||
wavs = None
|
||||
last_exc: Optional[BaseException] = None
|
||||
for attempt in range(3):
|
||||
try:
|
||||
wavs = chat.infer(
|
||||
chunk,
|
||||
skip_refine_text=False,
|
||||
params_refine_text=params_refine_text,
|
||||
params_infer_code=chunk_infer,
|
||||
wavs = _run_chattts_infer(
|
||||
chat, chunk, params_refine_text, chunk_infer
|
||||
)
|
||||
break
|
||||
except RecursionError as exc:
|
||||
last_exc = exc
|
||||
chunk_infer.manual_seed = 1000 + idx * 10 + attempt
|
||||
chunk_infer = replace(
|
||||
chunk_infer, manual_seed=1000 + idx * 10 + attempt
|
||||
)
|
||||
release_cuda_cache()
|
||||
except RuntimeError as exc:
|
||||
last_exc = exc
|
||||
if not _is_cuda_runtime_error(exc) or attempt >= 2:
|
||||
raise
|
||||
logger.warning(
|
||||
"第 %d 段 CUDA 异常,重置 ChatTTS 后重试 (%d/3): %s",
|
||||
idx,
|
||||
attempt + 1,
|
||||
exc,
|
||||
)
|
||||
reset_chattts_instance()
|
||||
release_cuda_cache()
|
||||
chat, reload_err = get_chattts_instance()
|
||||
if chat is None:
|
||||
raise RuntimeError(reload_err or "ChatTTS 重载失败") from exc
|
||||
chunk_infer = replace(
|
||||
chunk_infer, manual_seed=2000 + idx * 10 + attempt
|
||||
)
|
||||
if wavs is None:
|
||||
return (
|
||||
False,
|
||||
@@ -706,6 +778,9 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
segment_wavs.append(wav_arr)
|
||||
release_cuda_cache()
|
||||
|
||||
if not segment_wavs:
|
||||
return False, "分段清洗后无有效文本,请缩短或简化润色稿后重试。", None
|
||||
|
||||
wav_array = (
|
||||
segment_wavs[0]
|
||||
if len(segment_wavs) == 1
|
||||
@@ -745,6 +820,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
"4. 确认无其他程序占用 GPU: nvidia-smi\n"
|
||||
f"技术详情: {exc_msg[:400]}"
|
||||
)
|
||||
elif _is_cuda_runtime_error(exc):
|
||||
reset_chattts_instance()
|
||||
release_cuda_cache()
|
||||
err = (
|
||||
"语音合成失败: GPU/CUDA 异常(device-side assert 等)。\n"
|
||||
"常见原因:此前 OOM 导致 GPU 状态损坏,或文本含特殊字符。\n"
|
||||
"处理步骤:\n"
|
||||
"1. pm2 restart trading_studio(必须,清理 GPU 脏状态)\n"
|
||||
"2. 确认已填写参考音频转写并重新锁定音色\n"
|
||||
"3. 用 2-3 句短中文试合成\n"
|
||||
"4. 若仍失败,在 .env 设 TTS_ENABLE_CACHE=false 后重启\n"
|
||||
f"技术详情: {exc_msg[:500]}"
|
||||
)
|
||||
elif "recursion depth" in exc_msg.lower() or isinstance(exc, RecursionError):
|
||||
err = (
|
||||
"语音合成失败: ChatTTS 反复生成空结果导致递归超限。\n"
|
||||
|
||||
Reference in New Issue
Block a user