Fix ChatTTS CUDA device-side assert with text sanitize and GPU recovery.

Re-enable KV cache by default, normalize digits and unsafe chars, disable per-chunk split_text, and reload ChatTTS after CUDA errors.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-12 17:13:57 +08:00
parent 1779449bba
commit 8be34a2fd5
4 changed files with 128 additions and 10 deletions
+1
View File
@@ -18,3 +18,4 @@ OLLAMA_PORT=11434
# TTS_MAX_CHARS_PER_CHUNK=150
# TTS_MAX_NEW_TOKEN=768
# TTS_MIN_NEW_TOKEN=16
# TTS_ENABLE_CACHE=true
+10
View File
@@ -39,6 +39,13 @@ def _env_int(key: str, default: int) -> int:
return default
def _env_bool(key: str, default: bool) -> bool:
raw = os.environ.get(key)
if raw is None:
return default
return raw.strip().lower() in ("1", "true", "yes", "on")
# ---------------------------------------------------------------------------
# 网络与服务
# ---------------------------------------------------------------------------
@@ -145,6 +152,9 @@ TTS_MAX_NEW_TOKEN = _env_int("TTS_MAX_NEW_TOKEN", 1024)
# 至少生成多少 audio token 才允许结束(防止首 token EOS → 无限递归重试)
TTS_MIN_NEW_TOKEN = _env_int("TTS_MIN_NEW_TOKEN", 16)
# GPT KV cache(关闭可省显存,但部分 transformers 版本会触发 CUDA assert
TTS_ENABLE_CACHE = _env_bool("TTS_ENABLE_CACHE", True)
# ---------------------------------------------------------------------------
# 上传临时文件目录
# ---------------------------------------------------------------------------
+19
View File
@@ -15,6 +15,10 @@ def release_cuda_cache() -> None:
import torch
if torch.cuda.is_available():
try:
torch.cuda.synchronize()
except Exception:
pass
torch.cuda.empty_cache()
if hasattr(torch.cuda, "ipc_collect"):
torch.cuda.ipc_collect()
@@ -22,6 +26,21 @@ def release_cuda_cache() -> None:
pass
def is_cuda_runtime_error(exc: BaseException) -> bool:
msg = str(exc).lower()
return any(
k in msg
for k in (
"cuda error",
"device-side assert",
"out of memory",
"cublas",
"illegal memory access",
"an illegal instruction",
)
)
def cuda_memory_summary() -> str:
"""返回简要显存占用(调试用)。"""
try:
+98 -10
View File
@@ -32,6 +32,7 @@ from config import (
SPEAKER_SAMPLE_MAX_SEC,
SPEAKER_SAMPLE_MIN_SEC,
TTS_MAX_CHARS_PER_CHUNK,
TTS_ENABLE_CACHE,
TTS_MAX_NEW_TOKEN,
TTS_MIN_NEW_TOKEN,
TTS_SAMPLE_RATE,
@@ -95,7 +96,7 @@ def _load_chat_model(chat) -> None:
_ensure_hf_env()
model_dir = CHATTTS_MODEL_DIR
base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": False}
base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": TTS_ENABLE_CACHE}
if not hasattr(chat, "load"):
if hasattr(chat, "load_models"):
@@ -484,6 +485,34 @@ _STAGE_DIRECTION_RE = re.compile(
r"[(][^)]{0,80}(?:前奏|转场|语气|背景|BGM|配乐|节奏|环节)[^)]{0,80}[)]"
)
_CN_DIGITS = "零一二三四五六七八九"
# ChatTTS tokenizer 对裸 ASCII 数字、控制符敏感,易触发 CUDA device-side assert
_TTS_UNSAFE_CHAR_RE = re.compile(
r"[\u200b-\u200f\u202a-\u202e\ufeff\x00-\x08\x0b\x0c\x0e-\x1f]"
)
_TTS_ALLOWED_CHAR_RE = re.compile(
r"[^\u4e00-\u9fff\u3400-\u4dbfA-Za-z0-9,。!?;:、「」『』()—…\-\s'\"《》%%]"
)
def _digits_to_chinese(text: str) -> str:
def _repl(match: re.Match[str]) -> str:
return "".join(_CN_DIGITS[int(ch)] for ch in match.group())
return re.sub(r"\d+", _repl, text)
def _normalize_tts_chunk(text: str) -> str:
"""单段合成用:去控制符、数字转中文、合并换行为逗号。"""
text = _TTS_UNSAFE_CHAR_RE.sub("", text)
text = text.replace("\r", "").replace("\n", "")
text = _digits_to_chinese(text)
text = _TTS_ALLOWED_CHAR_RE.sub("", text)
text = re.sub(r"[,]{2,}", "", text)
text = re.sub(r"\s+", "", text)
return text.strip(",。 \t")
def prepare_text_for_tts(text: str) -> str:
"""
@@ -523,7 +552,8 @@ def prepare_text_for_tts(text: str) -> str:
lines = [ln.strip() for ln in cleaned.split("\n")]
lines = [ln for ln in lines if ln and not re.fullmatch(r"[*\-#]+", ln)]
return "\n".join(lines).strip()
merged = "".join(lines)
return _normalize_tts_chunk(merged)
def split_text_for_tts(text: str, max_chars: int = TTS_MAX_CHARS_PER_CHUNK) -> List[str]:
@@ -558,7 +588,31 @@ def split_text_for_tts(text: str, max_chars: int = TTS_MAX_CHARS_PER_CHUNK) -> L
if buf:
chunks.append(buf)
return [c.strip() for c in chunks if c.strip()]
return [_normalize_tts_chunk(c) for c in chunks if c.strip()]
def _is_cuda_runtime_error(exc: BaseException) -> bool:
from gpu_utils import is_cuda_runtime_error
return is_cuda_runtime_error(exc)
def _run_chattts_infer(
chat: Any,
chunk: str,
params_refine_text: Any,
params_infer_code: Any,
) -> Any:
"""单次 ChatTTS infersplit_text=False 避免段内再切分引发 mask 异常。"""
return chat.infer(
chunk,
skip_refine_text=False,
split_text=False,
do_text_normalization=True,
do_homophone_replacement=True,
params_refine_text=params_refine_text,
params_infer_code=params_infer_code,
)
def _concat_wavs(
@@ -662,24 +716,42 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
segment_wavs: List[np.ndarray] = []
for idx, chunk in enumerate(chunks, start=1):
if not chunk or len(chunk) < 2:
continue
release_cuda_cache()
# manual_seed 每段不同;ensure_non_empty=False 避免空输出时无限递归
chunk_infer = replace(params_infer_code, manual_seed=42 + idx)
wavs = None
last_exc: Optional[BaseException] = None
for attempt in range(3):
try:
wavs = chat.infer(
chunk,
skip_refine_text=False,
params_refine_text=params_refine_text,
params_infer_code=chunk_infer,
wavs = _run_chattts_infer(
chat, chunk, params_refine_text, chunk_infer
)
break
except RecursionError as exc:
last_exc = exc
chunk_infer.manual_seed = 1000 + idx * 10 + attempt
chunk_infer = replace(
chunk_infer, manual_seed=1000 + idx * 10 + attempt
)
release_cuda_cache()
except RuntimeError as exc:
last_exc = exc
if not _is_cuda_runtime_error(exc) or attempt >= 2:
raise
logger.warning(
"%d 段 CUDA 异常,重置 ChatTTS 后重试 (%d/3): %s",
idx,
attempt + 1,
exc,
)
reset_chattts_instance()
release_cuda_cache()
chat, reload_err = get_chattts_instance()
if chat is None:
raise RuntimeError(reload_err or "ChatTTS 重载失败") from exc
chunk_infer = replace(
chunk_infer, manual_seed=2000 + idx * 10 + attempt
)
if wavs is None:
return (
False,
@@ -706,6 +778,9 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
segment_wavs.append(wav_arr)
release_cuda_cache()
if not segment_wavs:
return False, "分段清洗后无有效文本,请缩短或简化润色稿后重试。", None
wav_array = (
segment_wavs[0]
if len(segment_wavs) == 1
@@ -745,6 +820,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
"4. 确认无其他程序占用 GPU: nvidia-smi\n"
f"技术详情: {exc_msg[:400]}"
)
elif _is_cuda_runtime_error(exc):
reset_chattts_instance()
release_cuda_cache()
err = (
"语音合成失败: GPU/CUDA 异常(device-side assert 等)。\n"
"常见原因:此前 OOM 导致 GPU 状态损坏,或文本含特殊字符。\n"
"处理步骤:\n"
"1. pm2 restart trading_studio(必须,清理 GPU 脏状态)\n"
"2. 确认已填写参考音频转写并重新锁定音色\n"
"3. 用 2-3 句短中文试合成\n"
"4. 若仍失败,在 .env 设 TTS_ENABLE_CACHE=false 后重启\n"
f"技术详情: {exc_msg[:500]}"
)
elif "recursion depth" in exc_msg.lower() or isinstance(exc, RecursionError):
err = (
"语音合成失败: ChatTTS 反复生成空结果导致递归超限。\n"