Fix ChatTTS CUDA device-side assert with text sanitization and GPU recovery.
Re-enable KV cache by default, normalize digits and unsafe chars, disable per-chunk split_text, and reload ChatTTS after CUDA errors. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -18,3 +18,4 @@ OLLAMA_PORT=11434
|
|||||||
# TTS_MAX_CHARS_PER_CHUNK=150
|
# TTS_MAX_CHARS_PER_CHUNK=150
|
||||||
# TTS_MAX_NEW_TOKEN=768
|
# TTS_MAX_NEW_TOKEN=768
|
||||||
# TTS_MIN_NEW_TOKEN=16
|
# TTS_MIN_NEW_TOKEN=16
|
||||||
|
# TTS_ENABLE_CACHE=true
|
||||||
|
|||||||
@@ -39,6 +39,13 @@ def _env_int(key: str, default: int) -> int:
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _env_bool(key: str, default: bool) -> bool:
|
||||||
|
raw = os.environ.get(key)
|
||||||
|
if raw is None:
|
||||||
|
return default
|
||||||
|
return raw.strip().lower() in ("1", "true", "yes", "on")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# 网络与服务
|
# 网络与服务
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -145,6 +152,9 @@ TTS_MAX_NEW_TOKEN = _env_int("TTS_MAX_NEW_TOKEN", 1024)
|
|||||||
# 至少生成多少 audio token 才允许结束(防止首 token EOS → 无限递归重试)
|
# 至少生成多少 audio token 才允许结束(防止首 token EOS → 无限递归重试)
|
||||||
TTS_MIN_NEW_TOKEN = _env_int("TTS_MIN_NEW_TOKEN", 16)
|
TTS_MIN_NEW_TOKEN = _env_int("TTS_MIN_NEW_TOKEN", 16)
|
||||||
|
|
||||||
|
# GPT KV cache(关闭可省显存,但部分 transformers 版本会触发 CUDA assert)
|
||||||
|
TTS_ENABLE_CACHE = _env_bool("TTS_ENABLE_CACHE", True)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# 上传临时文件目录
|
# 上传临时文件目录
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -15,6 +15,10 @@ def release_cuda_cache() -> None:
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
try:
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
if hasattr(torch.cuda, "ipc_collect"):
|
if hasattr(torch.cuda, "ipc_collect"):
|
||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
@@ -22,6 +26,21 @@ def release_cuda_cache() -> None:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def is_cuda_runtime_error(exc: BaseException) -> bool:
    """Tell whether an exception looks like a CUDA/GPU runtime failure.

    The check is a case-insensitive substring match on the exception
    message. The explicit ``raise ... from ...`` chain (``__cause__``) is
    followed as well, so an error that merely wraps a CUDA failure is still
    recognised.

    Args:
        exc: The exception to classify.

    Returns:
        True if the message of ``exc`` or of any exception in its
        ``__cause__`` chain contains one of the known CUDA error markers.
    """
    markers = (
        "cuda error",
        "device-side assert",
        "out of memory",
        "cublas",
        "illegal memory access",
        "an illegal instruction",
    )
    visited = set()
    current = exc
    # Track visited objects so a cyclic cause chain cannot loop forever.
    while current is not None and id(current) not in visited:
        visited.add(id(current))
        message = str(current).lower()
        if any(marker in message for marker in markers):
            return True
        current = current.__cause__
    return False
|
||||||
|
|
||||||
|
|
||||||
def cuda_memory_summary() -> str:
|
def cuda_memory_summary() -> str:
|
||||||
"""返回简要显存占用(调试用)。"""
|
"""返回简要显存占用(调试用)。"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
+98
-10
@@ -32,6 +32,7 @@ from config import (
|
|||||||
SPEAKER_SAMPLE_MAX_SEC,
|
SPEAKER_SAMPLE_MAX_SEC,
|
||||||
SPEAKER_SAMPLE_MIN_SEC,
|
SPEAKER_SAMPLE_MIN_SEC,
|
||||||
TTS_MAX_CHARS_PER_CHUNK,
|
TTS_MAX_CHARS_PER_CHUNK,
|
||||||
|
TTS_ENABLE_CACHE,
|
||||||
TTS_MAX_NEW_TOKEN,
|
TTS_MAX_NEW_TOKEN,
|
||||||
TTS_MIN_NEW_TOKEN,
|
TTS_MIN_NEW_TOKEN,
|
||||||
TTS_SAMPLE_RATE,
|
TTS_SAMPLE_RATE,
|
||||||
@@ -95,7 +96,7 @@ def _load_chat_model(chat) -> None:
|
|||||||
_ensure_hf_env()
|
_ensure_hf_env()
|
||||||
model_dir = CHATTTS_MODEL_DIR
|
model_dir = CHATTTS_MODEL_DIR
|
||||||
|
|
||||||
base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": False}
|
base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": TTS_ENABLE_CACHE}
|
||||||
|
|
||||||
if not hasattr(chat, "load"):
|
if not hasattr(chat, "load"):
|
||||||
if hasattr(chat, "load_models"):
|
if hasattr(chat, "load_models"):
|
||||||
@@ -484,6 +485,34 @@ _STAGE_DIRECTION_RE = re.compile(
|
|||||||
r"[((][^))]{0,80}(?:前奏|转场|语气|背景|BGM|配乐|节奏|环节)[^))]{0,80}[))]"
|
r"[((][^))]{0,80}(?:前奏|转场|语气|背景|BGM|配乐|节奏|环节)[^))]{0,80}[))]"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_CN_DIGITS = "零一二三四五六七八九"
|
||||||
|
|
||||||
|
# ChatTTS tokenizer 对裸 ASCII 数字、控制符敏感,易触发 CUDA device-side assert
|
||||||
|
_TTS_UNSAFE_CHAR_RE = re.compile(
|
||||||
|
r"[\u200b-\u200f\u202a-\u202e\ufeff\x00-\x08\x0b\x0c\x0e-\x1f]"
|
||||||
|
)
|
||||||
|
_TTS_ALLOWED_CHAR_RE = re.compile(
|
||||||
|
r"[^\u4e00-\u9fff\u3400-\u4dbfA-Za-z0-9,。!?;:、「」『』()—…\-\s'\"《》%%]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _digits_to_chinese(text: str) -> str:
|
||||||
|
def _repl(match: re.Match[str]) -> str:
|
||||||
|
return "".join(_CN_DIGITS[int(ch)] for ch in match.group())
|
||||||
|
|
||||||
|
return re.sub(r"\d+", _repl, text)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_tts_chunk(text: str) -> str:
|
||||||
|
"""单段合成用:去控制符、数字转中文、合并换行为逗号。"""
|
||||||
|
text = _TTS_UNSAFE_CHAR_RE.sub("", text)
|
||||||
|
text = text.replace("\r", "").replace("\n", ",")
|
||||||
|
text = _digits_to_chinese(text)
|
||||||
|
text = _TTS_ALLOWED_CHAR_RE.sub("", text)
|
||||||
|
text = re.sub(r"[,,]{2,}", ",", text)
|
||||||
|
text = re.sub(r"\s+", "", text)
|
||||||
|
return text.strip(",。 \t")
|
||||||
|
|
||||||
|
|
||||||
def prepare_text_for_tts(text: str) -> str:
|
def prepare_text_for_tts(text: str) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -523,7 +552,8 @@ def prepare_text_for_tts(text: str) -> str:
|
|||||||
|
|
||||||
lines = [ln.strip() for ln in cleaned.split("\n")]
|
lines = [ln.strip() for ln in cleaned.split("\n")]
|
||||||
lines = [ln for ln in lines if ln and not re.fullmatch(r"[*\-#]+", ln)]
|
lines = [ln for ln in lines if ln and not re.fullmatch(r"[*\-#]+", ln)]
|
||||||
return "\n".join(lines).strip()
|
merged = "。".join(lines)
|
||||||
|
return _normalize_tts_chunk(merged)
|
||||||
|
|
||||||
|
|
||||||
def split_text_for_tts(text: str, max_chars: int = TTS_MAX_CHARS_PER_CHUNK) -> List[str]:
|
def split_text_for_tts(text: str, max_chars: int = TTS_MAX_CHARS_PER_CHUNK) -> List[str]:
|
||||||
@@ -558,7 +588,31 @@ def split_text_for_tts(text: str, max_chars: int = TTS_MAX_CHARS_PER_CHUNK) -> L
|
|||||||
if buf:
|
if buf:
|
||||||
chunks.append(buf)
|
chunks.append(buf)
|
||||||
|
|
||||||
return [c.strip() for c in chunks if c.strip()]
|
return [_normalize_tts_chunk(c) for c in chunks if c.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_cuda_runtime_error(exc: BaseException) -> bool:
    """Delegate CUDA-error classification to ``gpu_utils``.

    Args:
        exc: The exception to classify.

    Returns:
        Whatever ``gpu_utils.is_cuda_runtime_error`` returns for ``exc``.
    """
    # Imported at call time rather than at module import.
    from gpu_utils import is_cuda_runtime_error as _classify

    verdict = _classify(exc)
    return verdict
|
||||||
|
|
||||||
|
|
||||||
|
def _run_chattts_infer(
|
||||||
|
chat: Any,
|
||||||
|
chunk: str,
|
||||||
|
params_refine_text: Any,
|
||||||
|
params_infer_code: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""单次 ChatTTS infer;split_text=False 避免段内再切分引发 mask 异常。"""
|
||||||
|
return chat.infer(
|
||||||
|
chunk,
|
||||||
|
skip_refine_text=False,
|
||||||
|
split_text=False,
|
||||||
|
do_text_normalization=True,
|
||||||
|
do_homophone_replacement=True,
|
||||||
|
params_refine_text=params_refine_text,
|
||||||
|
params_infer_code=params_infer_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _concat_wavs(
|
def _concat_wavs(
|
||||||
@@ -662,24 +716,42 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
|||||||
|
|
||||||
segment_wavs: List[np.ndarray] = []
|
segment_wavs: List[np.ndarray] = []
|
||||||
for idx, chunk in enumerate(chunks, start=1):
|
for idx, chunk in enumerate(chunks, start=1):
|
||||||
|
if not chunk or len(chunk) < 2:
|
||||||
|
continue
|
||||||
release_cuda_cache()
|
release_cuda_cache()
|
||||||
# manual_seed 每段不同;ensure_non_empty=False 避免空输出时无限递归
|
|
||||||
chunk_infer = replace(params_infer_code, manual_seed=42 + idx)
|
chunk_infer = replace(params_infer_code, manual_seed=42 + idx)
|
||||||
wavs = None
|
wavs = None
|
||||||
last_exc: Optional[BaseException] = None
|
last_exc: Optional[BaseException] = None
|
||||||
for attempt in range(3):
|
for attempt in range(3):
|
||||||
try:
|
try:
|
||||||
wavs = chat.infer(
|
wavs = _run_chattts_infer(
|
||||||
chunk,
|
chat, chunk, params_refine_text, chunk_infer
|
||||||
skip_refine_text=False,
|
|
||||||
params_refine_text=params_refine_text,
|
|
||||||
params_infer_code=chunk_infer,
|
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
except RecursionError as exc:
|
except RecursionError as exc:
|
||||||
last_exc = exc
|
last_exc = exc
|
||||||
chunk_infer.manual_seed = 1000 + idx * 10 + attempt
|
chunk_infer = replace(
|
||||||
|
chunk_infer, manual_seed=1000 + idx * 10 + attempt
|
||||||
|
)
|
||||||
release_cuda_cache()
|
release_cuda_cache()
|
||||||
|
except RuntimeError as exc:
|
||||||
|
last_exc = exc
|
||||||
|
if not _is_cuda_runtime_error(exc) or attempt >= 2:
|
||||||
|
raise
|
||||||
|
logger.warning(
|
||||||
|
"第 %d 段 CUDA 异常,重置 ChatTTS 后重试 (%d/3): %s",
|
||||||
|
idx,
|
||||||
|
attempt + 1,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
reset_chattts_instance()
|
||||||
|
release_cuda_cache()
|
||||||
|
chat, reload_err = get_chattts_instance()
|
||||||
|
if chat is None:
|
||||||
|
raise RuntimeError(reload_err or "ChatTTS 重载失败") from exc
|
||||||
|
chunk_infer = replace(
|
||||||
|
chunk_infer, manual_seed=2000 + idx * 10 + attempt
|
||||||
|
)
|
||||||
if wavs is None:
|
if wavs is None:
|
||||||
return (
|
return (
|
||||||
False,
|
False,
|
||||||
@@ -706,6 +778,9 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
|||||||
segment_wavs.append(wav_arr)
|
segment_wavs.append(wav_arr)
|
||||||
release_cuda_cache()
|
release_cuda_cache()
|
||||||
|
|
||||||
|
if not segment_wavs:
|
||||||
|
return False, "分段清洗后无有效文本,请缩短或简化润色稿后重试。", None
|
||||||
|
|
||||||
wav_array = (
|
wav_array = (
|
||||||
segment_wavs[0]
|
segment_wavs[0]
|
||||||
if len(segment_wavs) == 1
|
if len(segment_wavs) == 1
|
||||||
@@ -745,6 +820,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
|||||||
"4. 确认无其他程序占用 GPU: nvidia-smi\n"
|
"4. 确认无其他程序占用 GPU: nvidia-smi\n"
|
||||||
f"技术详情: {exc_msg[:400]}"
|
f"技术详情: {exc_msg[:400]}"
|
||||||
)
|
)
|
||||||
|
elif _is_cuda_runtime_error(exc):
|
||||||
|
reset_chattts_instance()
|
||||||
|
release_cuda_cache()
|
||||||
|
err = (
|
||||||
|
"语音合成失败: GPU/CUDA 异常(device-side assert 等)。\n"
|
||||||
|
"常见原因:此前 OOM 导致 GPU 状态损坏,或文本含特殊字符。\n"
|
||||||
|
"处理步骤:\n"
|
||||||
|
"1. pm2 restart trading_studio(必须,清理 GPU 脏状态)\n"
|
||||||
|
"2. 确认已填写参考音频转写并重新锁定音色\n"
|
||||||
|
"3. 用 2-3 句短中文试合成\n"
|
||||||
|
"4. 若仍失败,在 .env 设 TTS_ENABLE_CACHE=false 后重启\n"
|
||||||
|
f"技术详情: {exc_msg[:500]}"
|
||||||
|
)
|
||||||
elif "recursion depth" in exc_msg.lower() or isinstance(exc, RecursionError):
|
elif "recursion depth" in exc_msg.lower() or isinstance(exc, RecursionError):
|
||||||
err = (
|
err = (
|
||||||
"语音合成失败: ChatTTS 反复生成空结果导致递归超限。\n"
|
"语音合成失败: ChatTTS 反复生成空结果导致递归超限。\n"
|
||||||
|
|||||||
Reference in New Issue
Block a user