Fix CUDA OOM by mutually unloading Whisper and ChatTTS on 8GB GPU.

Release GPU memory before TTS/ASR switches, lower TTS token limits, and set PYTORCH_CUDA_ALLOC_CONF in PM2.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-12 17:03:37 +08:00
parent 82f99c0b89
commit 0cce6cda7c
7 changed files with 169 additions and 40 deletions
+55 -4
View File
@@ -31,6 +31,7 @@ from config import (
SPEAKER_SAMPLE_MAX_SEC,
SPEAKER_SAMPLE_MIN_SEC,
TTS_MAX_CHARS_PER_CHUNK,
TTS_MAX_NEW_TOKEN,
TTS_SAMPLE_RATE,
TTS_SPEED_PROMPT,
TTS_TEMPERATURE,
@@ -92,7 +93,7 @@ def _load_chat_model(chat) -> None:
_ensure_hf_env()
model_dir = CHATTTS_MODEL_DIR
base_kwargs: Dict[str, Any] = {"compile": False}
base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": False}
if not hasattr(chat, "load"):
if hasattr(chat, "load_models"):
@@ -137,11 +138,26 @@ def _load_chat_model(chat) -> None:
def reset_chattts_instance() -> None:
"""释放 ChatTTS 实例(模型下载后重启前可调用)"""
"""卸载 ChatTTS 模型并回收 GPU 显存"""
global _chat, _chat_error
if _chat is not None:
try:
if hasattr(_chat, "unload"):
_chat.unload()
except Exception:
logger.exception("ChatTTS unload 失败")
try:
del _chat
except Exception:
pass
_chat = None
_chat_error = None
from gpu_utils import release_cuda_cache
release_cuda_cache()
logger.info("ChatTTS 模型已卸载,显存已尝试回收。")
def get_chattts_instance():
"""
@@ -348,6 +364,13 @@ def save_fixed_speaker(
if not audio_sample_path:
return False, "未提供音色参考音频。"
try:
from whisper_service import reset_whisper_model
reset_whisper_model()
except Exception:
logger.debug("锁定音色前释放 Whisper 显存跳过", exc_info=True)
chat, init_err = get_chattts_instance()
if chat is None:
return False, init_err or "ChatTTS 不可用。"
@@ -566,6 +589,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
if not refined_text or not refined_text.strip():
return False, "合成文本为空,请先完成润色。", None
# 合成前释放 Whisper,避免 8GB 显存上双模型 OOM
try:
from whisper_service import reset_whisper_model
reset_whisper_model()
except Exception:
logger.debug("合成前释放 Whisper 显存跳过", exc_info=True)
from gpu_utils import cuda_memory_summary, release_cuda_cache
release_cuda_cache()
logger.info("TTS 合成前 %s", cuda_memory_summary())
chat, init_err = get_chattts_instance()
if chat is None:
return False, init_err or "ChatTTS 不可用。", None
@@ -604,6 +640,7 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
temperature=TTS_TEMPERATURE,
top_P=TTS_TOP_P,
top_K=TTS_TOP_K,
max_new_token=TTS_MAX_NEW_TOKEN,
)
params_refine_text = ChatTTS.Chat.RefineTextParams(
@@ -619,9 +656,10 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
segment_wavs: List[np.ndarray] = []
for idx, chunk in enumerate(chunks, start=1):
release_cuda_cache()
wavs = chat.infer(
chunk,
skip_refine_text=False,
skip_refine_text=(idx > 1),
params_refine_text=params_refine_text,
params_infer_code=params_infer_code,
)
@@ -633,6 +671,7 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
None,
)
segment_wavs.append(np.asarray(wavs[0], dtype=np.float32))
release_cuda_cache()
wav_array = (
segment_wavs[0]
@@ -661,7 +700,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
except Exception as exc:
exc_msg = str(exc)
if "Corrupt input data" in exc_msg:
if "out of memory" in exc_msg.lower() or "OutOfMemoryError" in exc_msg:
release_cuda_cache()
err = (
"语音合成失败: GPU 显存不足(CUDA OOM)。\n"
"3060 Ti 8GB 无法同时运行 Whisper + ChatTTS。\n"
"处理步骤:\n"
"1. pm2 restart trading_studio 释放显存\n"
"2. 不要连续快速点识别+合成;合成前系统会自动卸载 Whisper\n"
"3. 若仍 OOM,在 .env 设置 TTS_MAX_CHARS_PER_CHUNK=150、TTS_MAX_NEW_TOKEN=768\n"
"4. 确认无其他程序占用 GPU: nvidia-smi\n"
f"技术详情: {exc_msg[:400]}"
)
elif "Corrupt input data" in exc_msg:
err = (
"语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n"
"处理步骤:\n"