Fix CUDA OOM by mutually unloading Whisper and ChatTTS on 8GB GPU.
Release GPU memory before switching between TTS and ASR, lower the TTS token limits, and set PYTORCH_CUDA_ALLOC_CONF in the PM2 config.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
+55
-4
@@ -31,6 +31,7 @@ from config import (
|
||||
SPEAKER_SAMPLE_MAX_SEC,
|
||||
SPEAKER_SAMPLE_MIN_SEC,
|
||||
TTS_MAX_CHARS_PER_CHUNK,
|
||||
TTS_MAX_NEW_TOKEN,
|
||||
TTS_SAMPLE_RATE,
|
||||
TTS_SPEED_PROMPT,
|
||||
TTS_TEMPERATURE,
|
||||
@@ -92,7 +93,7 @@ def _load_chat_model(chat) -> None:
|
||||
_ensure_hf_env()
|
||||
model_dir = CHATTTS_MODEL_DIR
|
||||
|
||||
base_kwargs: Dict[str, Any] = {"compile": False}
|
||||
base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": False}
|
||||
|
||||
if not hasattr(chat, "load"):
|
||||
if hasattr(chat, "load_models"):
|
||||
@@ -137,11 +138,26 @@ def _load_chat_model(chat) -> None:
|
||||
|
||||
|
||||
def reset_chattts_instance() -> None:
    """Unload the cached ChatTTS model and try to reclaim GPU memory.

    If a model instance is cached in the module-level ``_chat`` and it
    exposes an ``unload`` method, that method is called; a failure there
    is logged and otherwise ignored so the reset always completes. The
    cached instance and the cached initialisation error are then cleared,
    and ``gpu_utils.release_cuda_cache`` is invoked.
    """
    global _chat, _chat_error

    if _chat is not None:
        try:
            if hasattr(_chat, "unload"):
                _chat.unload()
        except Exception:
            # Best effort: a failing unload must not prevent the reset.
            logger.exception("ChatTTS unload 失败")

    # Rebinding to None drops this module's reference to the model. A
    # separate `del _chat` is unnecessary and would briefly remove the
    # global name, exposing concurrent readers to a NameError.
    _chat = None
    _chat_error = None

    # Imported at call time, as in the rest of this file's GPU helpers.
    from gpu_utils import release_cuda_cache

    release_cuda_cache()
    logger.info("ChatTTS 模型已卸载,显存已尝试回收。")
|
||||
|
||||
|
||||
def get_chattts_instance():
|
||||
"""
|
||||
@@ -348,6 +364,13 @@ def save_fixed_speaker(
|
||||
if not audio_sample_path:
|
||||
return False, "未提供音色参考音频。"
|
||||
|
||||
try:
|
||||
from whisper_service import reset_whisper_model
|
||||
|
||||
reset_whisper_model()
|
||||
except Exception:
|
||||
logger.debug("锁定音色前释放 Whisper 显存跳过", exc_info=True)
|
||||
|
||||
chat, init_err = get_chattts_instance()
|
||||
if chat is None:
|
||||
return False, init_err or "ChatTTS 不可用。"
|
||||
@@ -566,6 +589,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
if not refined_text or not refined_text.strip():
|
||||
return False, "合成文本为空,请先完成润色。", None
|
||||
|
||||
# 合成前释放 Whisper,避免 8GB 显存上双模型 OOM
|
||||
try:
|
||||
from whisper_service import reset_whisper_model
|
||||
|
||||
reset_whisper_model()
|
||||
except Exception:
|
||||
logger.debug("合成前释放 Whisper 显存跳过", exc_info=True)
|
||||
|
||||
from gpu_utils import cuda_memory_summary, release_cuda_cache
|
||||
|
||||
release_cuda_cache()
|
||||
logger.info("TTS 合成前 %s", cuda_memory_summary())
|
||||
|
||||
chat, init_err = get_chattts_instance()
|
||||
if chat is None:
|
||||
return False, init_err or "ChatTTS 不可用。", None
|
||||
@@ -604,6 +640,7 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
temperature=TTS_TEMPERATURE,
|
||||
top_P=TTS_TOP_P,
|
||||
top_K=TTS_TOP_K,
|
||||
max_new_token=TTS_MAX_NEW_TOKEN,
|
||||
)
|
||||
|
||||
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
||||
@@ -619,9 +656,10 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
|
||||
segment_wavs: List[np.ndarray] = []
|
||||
for idx, chunk in enumerate(chunks, start=1):
|
||||
release_cuda_cache()
|
||||
wavs = chat.infer(
|
||||
chunk,
|
||||
skip_refine_text=False,
|
||||
skip_refine_text=(idx > 1),
|
||||
params_refine_text=params_refine_text,
|
||||
params_infer_code=params_infer_code,
|
||||
)
|
||||
@@ -633,6 +671,7 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
None,
|
||||
)
|
||||
segment_wavs.append(np.asarray(wavs[0], dtype=np.float32))
|
||||
release_cuda_cache()
|
||||
|
||||
wav_array = (
|
||||
segment_wavs[0]
|
||||
@@ -661,7 +700,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
|
||||
except Exception as exc:
|
||||
exc_msg = str(exc)
|
||||
if "Corrupt input data" in exc_msg:
|
||||
if "out of memory" in exc_msg.lower() or "OutOfMemoryError" in exc_msg:
|
||||
release_cuda_cache()
|
||||
err = (
|
||||
"语音合成失败: GPU 显存不足(CUDA OOM)。\n"
|
||||
"3060 Ti 8GB 无法同时运行 Whisper + ChatTTS。\n"
|
||||
"处理步骤:\n"
|
||||
"1. pm2 restart trading_studio 释放显存\n"
|
||||
"2. 不要连续快速点识别+合成;合成前系统会自动卸载 Whisper\n"
|
||||
"3. 若仍 OOM,在 .env 设置 TTS_MAX_CHARS_PER_CHUNK=150、TTS_MAX_NEW_TOKEN=768\n"
|
||||
"4. 确认无其他程序占用 GPU: nvidia-smi\n"
|
||||
f"技术详情: {exc_msg[:400]}"
|
||||
)
|
||||
elif "Corrupt input data" in exc_msg:
|
||||
err = (
|
||||
"语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n"
|
||||
"处理步骤:\n"
|
||||
|
||||
Reference in New Issue
Block a user