From 8be34a2fd5e23087b35d20fe6b24ca91e1f3dc64 Mon Sep 17 00:00:00 2001 From: dekun Date: Fri, 12 Jun 2026 17:13:57 +0800 Subject: [PATCH] Fix ChatTTS CUDA device-side assert with text sanitize and GPU recovery. Re-enable KV cache by default, normalize digits and unsafe chars, disable per-chunk split_text, and reload ChatTTS after CUDA errors. Co-authored-by: Cursor --- .env.example | 1 + config.py | 10 +++++ gpu_utils.py | 19 +++++++++ tts_service.py | 108 ++++++++++++++++++++++++++++++++++++++++++++----- 4 files changed, 128 insertions(+), 10 deletions(-) diff --git a/.env.example b/.env.example index 06be769..d9894d1 100644 --- a/.env.example +++ b/.env.example @@ -18,3 +18,4 @@ OLLAMA_PORT=11434 # TTS_MAX_CHARS_PER_CHUNK=150 # TTS_MAX_NEW_TOKEN=768 # TTS_MIN_NEW_TOKEN=16 +# TTS_ENABLE_CACHE=true diff --git a/config.py b/config.py index ce8a6a9..881e946 100644 --- a/config.py +++ b/config.py @@ -39,6 +39,13 @@ def _env_int(key: str, default: int) -> int: return default +def _env_bool(key: str, default: bool) -> bool: + raw = os.environ.get(key) + if raw is None: + return default + return raw.strip().lower() in ("1", "true", "yes", "on") + + # --------------------------------------------------------------------------- # 网络与服务 # --------------------------------------------------------------------------- @@ -145,6 +152,9 @@ TTS_MAX_NEW_TOKEN = _env_int("TTS_MAX_NEW_TOKEN", 1024) # 至少生成多少 audio token 才允许结束(防止首 token EOS → 无限递归重试) TTS_MIN_NEW_TOKEN = _env_int("TTS_MIN_NEW_TOKEN", 16) +# GPT KV cache(关闭可省显存,但部分 transformers 版本会触发 CUDA assert) +TTS_ENABLE_CACHE = _env_bool("TTS_ENABLE_CACHE", True) + # --------------------------------------------------------------------------- # 上传临时文件目录 # --------------------------------------------------------------------------- diff --git a/gpu_utils.py b/gpu_utils.py index 51cc4cb..ad08b54 100644 --- a/gpu_utils.py +++ b/gpu_utils.py @@ -15,6 +15,10 @@ def release_cuda_cache() -> None: import torch if torch.cuda.is_available(): + try: + torch.cuda.synchronize() + except Exception: + pass torch.cuda.empty_cache() if hasattr(torch.cuda, "ipc_collect"): torch.cuda.ipc_collect() @@ -22,6 +26,21 @@ def release_cuda_cache() -> None: pass +def is_cuda_runtime_error(exc: BaseException) -> bool: + msg = str(exc).lower() + return any( + k in msg + for k in ( + "cuda error", + "device-side assert", + "out of memory", + "cublas", + "illegal memory access", + "an illegal instruction", + ) + ) + + def cuda_memory_summary() -> str: """返回简要显存占用(调试用)。""" try: diff --git a/tts_service.py b/tts_service.py index b462f19..ed19c90 100644 --- a/tts_service.py +++ b/tts_service.py @@ -32,6 +32,7 @@ from config import ( SPEAKER_SAMPLE_MAX_SEC, SPEAKER_SAMPLE_MIN_SEC, TTS_MAX_CHARS_PER_CHUNK, + TTS_ENABLE_CACHE, TTS_MAX_NEW_TOKEN, TTS_MIN_NEW_TOKEN, TTS_SAMPLE_RATE, @@ -95,7 +96,7 @@ def _load_chat_model(chat) -> None: _ensure_hf_env() model_dir = CHATTTS_MODEL_DIR - base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": False} + base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": TTS_ENABLE_CACHE} if not hasattr(chat, "load"): if hasattr(chat, "load_models"): @@ -484,6 +485,34 @@ _STAGE_DIRECTION_RE = re.compile( r"[((][^))]{0,80}(?:前奏|转场|语气|背景|BGM|配乐|节奏|环节)[^))]{0,80}[))]" ) +_CN_DIGITS = "零一二三四五六七八九" + +# ChatTTS tokenizer 对裸 ASCII 数字、控制符敏感,易触发 CUDA device-side assert +_TTS_UNSAFE_CHAR_RE = re.compile( + r"[\u200b-\u200f\u202a-\u202e\ufeff\x00-\x08\x0b\x0c\x0e-\x1f]" +) +_TTS_ALLOWED_CHAR_RE = re.compile( + r"[^\u4e00-\u9fff\u3400-\u4dbfA-Za-z0-9,。!?;:、「」『』()—…\-\s'\"《》%%]" +) + + +def _digits_to_chinese(text: str) -> str: + def _repl(match: re.Match[str]) -> str: + return "".join(_CN_DIGITS[int(ch)] for ch in match.group()) + + return re.sub(r"\d+", _repl, text) + + +def _normalize_tts_chunk(text: str) -> str: + """单段合成用:去控制符、数字转中文、合并换行为逗号。""" + text = _TTS_UNSAFE_CHAR_RE.sub("", text) + text = text.replace("\r", "").replace("\n", ",") + text = _digits_to_chinese(text) + text = _TTS_ALLOWED_CHAR_RE.sub("", text) + text = re.sub(r"[,,]{2,}", ",", text) + text = re.sub(r"\s+", "", text) + return text.strip(",。 \t") + def prepare_text_for_tts(text: str) -> str: """ @@ -523,7 +552,8 @@ def prepare_text_for_tts(text: str) -> str: lines = [ln.strip() for ln in cleaned.split("\n")] lines = [ln for ln in lines if ln and not re.fullmatch(r"[*\-#]+", ln)] - return "\n".join(lines).strip() + merged = "。".join(lines) + return _normalize_tts_chunk(merged) def split_text_for_tts(text: str, max_chars: int = TTS_MAX_CHARS_PER_CHUNK) -> List[str]: @@ -558,7 +588,31 @@ def split_text_for_tts(text: str, max_chars: int = TTS_MAX_CHARS_PER_CHUNK) -> L if buf: chunks.append(buf) - return [c.strip() for c in chunks if c.strip()] + return [_normalize_tts_chunk(c) for c in chunks if c.strip()] + + +def _is_cuda_runtime_error(exc: BaseException) -> bool: + from gpu_utils import is_cuda_runtime_error + + return is_cuda_runtime_error(exc) + + +def _run_chattts_infer( + chat: Any, + chunk: str, + params_refine_text: Any, + params_infer_code: Any, +) -> Any: + """单次 ChatTTS infer;split_text=False 避免段内再切分引发 mask 异常。""" + return chat.infer( + chunk, + skip_refine_text=False, + split_text=False, + do_text_normalization=True, + do_homophone_replacement=True, + params_refine_text=params_refine_text, + params_infer_code=params_infer_code, + ) def _concat_wavs( @@ -662,24 +716,42 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: segment_wavs: List[np.ndarray] = [] for idx, chunk in enumerate(chunks, start=1): + if not chunk or len(chunk) < 2: + continue release_cuda_cache() - # manual_seed 每段不同;ensure_non_empty=False 避免空输出时无限递归 chunk_infer = replace(params_infer_code, manual_seed=42 + idx) wavs = None last_exc: Optional[BaseException] = None for attempt in range(3): try: - wavs = chat.infer( - chunk, - skip_refine_text=False, - params_refine_text=params_refine_text, - params_infer_code=chunk_infer, + wavs = _run_chattts_infer( + chat, chunk, params_refine_text, chunk_infer ) break except RecursionError as exc: last_exc = exc - chunk_infer.manual_seed = 1000 + idx * 10 + attempt + chunk_infer = replace( + chunk_infer, manual_seed=1000 + idx * 10 + attempt + ) release_cuda_cache() + except RuntimeError as exc: + last_exc = exc + if not _is_cuda_runtime_error(exc) or attempt >= 2: + raise + logger.warning( + "第 %d 段 CUDA 异常,重置 ChatTTS 后重试 (%d/3): %s", + idx, + attempt + 1, + exc, + ) + reset_chattts_instance() + release_cuda_cache() + chat, reload_err = get_chattts_instance() + if chat is None: + raise RuntimeError(reload_err or "ChatTTS 重载失败") from exc + chunk_infer = replace( + chunk_infer, manual_seed=2000 + idx * 10 + attempt + ) if wavs is None: return ( False, @@ -706,6 +778,9 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: segment_wavs.append(wav_arr) release_cuda_cache() + if not segment_wavs: + return False, "分段清洗后无有效文本,请缩短或简化润色稿后重试。", None + wav_array = ( segment_wavs[0] if len(segment_wavs) == 1 @@ -745,6 +820,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: "4. 确认无其他程序占用 GPU: nvidia-smi\n" f"技术详情: {exc_msg[:400]}" ) + elif _is_cuda_runtime_error(exc): + reset_chattts_instance() + release_cuda_cache() + err = ( + "语音合成失败: GPU/CUDA 异常(device-side assert 等)。\n" + "常见原因:此前 OOM 导致 GPU 状态损坏,或文本含特殊字符。\n" + "处理步骤:\n" + "1. pm2 restart trading_studio(必须,清理 GPU 脏状态)\n" + "2. 确认已填写参考音频转写并重新锁定音色\n" + "3. 用 2-3 句短中文试合成\n" + "4. 若仍失败,在 .env 设 TTS_ENABLE_CACHE=false 后重启\n" + f"技术详情: {exc_msg[:500]}" + ) elif "recursion depth" in exc_msg.lower() or isinstance(exc, RecursionError): err = ( "语音合成失败: ChatTTS 反复生成空结果导致递归超限。\n"