Fix CUDA OOM by mutually unloading Whisper and ChatTTS on 8GB GPU.
Release GPU memory before TTS/ASR switches, lower TTS token limits, and set PYTORCH_CUDA_ALLOC_CONF in PM2. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -13,3 +13,7 @@ OLLAMA_PORT=11434
|
||||
# WHISPER_MODEL_DIR=/opt/Trading_Studio/models/whisper
|
||||
# WHISPER_MODEL_SIZE=small
|
||||
# HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
# 8GB 显存 OOM 时可调低(合成按段切分)
|
||||
# TTS_MAX_CHARS_PER_CHUNK=150
|
||||
# TTS_MAX_NEW_TOKEN=768
|
||||
|
||||
@@ -733,10 +733,26 @@ nvidia-smi
|
||||
fuser -v /dev/nvidia*
|
||||
```
|
||||
|
||||
Whisper 与 ChatTTS 不会同时常驻最大显存,但首次加载模型时峰值较高。建议:
|
||||
Whisper 与 ChatTTS **不能同时常驻** 8GB 显存(会 CUDA OOM)。应用已自动互斥卸载:
|
||||
|
||||
- 锁定 120W 功耗墙
|
||||
- `max_memory_restart: "6G"` 已在 PM2 配置中设置
|
||||
- 识别前卸载 ChatTTS
|
||||
- 合成 / 锁定音色前卸载 Whisper
|
||||
|
||||
若仍 OOM:
|
||||
|
||||
```bash
|
||||
pm2 restart trading_studio
|
||||
nvidia-smi # 确认无其他占 GPU 进程
|
||||
```
|
||||
|
||||
在 `.env` 调低合成峰值:
|
||||
|
||||
```ini
|
||||
TTS_MAX_CHARS_PER_CHUNK=150
|
||||
TTS_MAX_NEW_TOKEN=768
|
||||
```
|
||||
|
||||
PM2 已配置 `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` 缓解碎片。建议锁定 120W 功耗墙。
|
||||
|
||||
### 10.3 Whisper 模型加载失败
|
||||
|
||||
|
||||
@@ -136,8 +136,11 @@ TTS_TOP_P = 0.7
|
||||
TTS_TOP_K = 20
|
||||
TTS_SPEED_PROMPT = "[speed_5]"
|
||||
|
||||
# 单段 TTS 最大字数(超长稿按句切分后逐段合成再拼接)
|
||||
TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 280)
|
||||
# 单段 TTS 最大字数(超长稿按句切分后逐段合成再拼接;8GB 显存建议 ≤200)
|
||||
TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 200)
|
||||
|
||||
# ChatTTS 单段最大生成 token(越小越省显存,长句会自动切多段)
|
||||
TTS_MAX_NEW_TOKEN = _env_int("TTS_MAX_NEW_TOKEN", 1024)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 上传临时文件目录
|
||||
|
||||
@@ -21,6 +21,7 @@ module.exports = {
|
||||
env: {
|
||||
PYTHONUNBUFFERED: "1",
|
||||
CUDA_VISIBLE_DEVICES: "0",
|
||||
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True",
|
||||
},
|
||||
error_file: path.join(APP_DIR, "logs/pm2-error.log"),
|
||||
out_file: path.join(APP_DIR, "logs/pm2-out.log"),
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""GPU 显存回收工具(3060 Ti 8GB:Whisper 与 ChatTTS 不宜同时驻留)。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def release_cuda_cache() -> None:
|
||||
"""触发 GC 并清空 PyTorch CUDA 缓存。"""
|
||||
gc.collect()
|
||||
try:
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
if hasattr(torch.cuda, "ipc_collect"):
|
||||
torch.cuda.ipc_collect()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def cuda_memory_summary() -> str:
|
||||
"""返回简要显存占用(调试用)。"""
|
||||
try:
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
return "CUDA 不可用"
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
return f"GPU 显存: 已用 {(total - free) / 1024**3:.2f}GB / {total / 1024**3:.2f}GB"
|
||||
except Exception as exc:
|
||||
return f"显存查询失败: {exc}"
|
||||
+55
-4
@@ -31,6 +31,7 @@ from config import (
|
||||
SPEAKER_SAMPLE_MAX_SEC,
|
||||
SPEAKER_SAMPLE_MIN_SEC,
|
||||
TTS_MAX_CHARS_PER_CHUNK,
|
||||
TTS_MAX_NEW_TOKEN,
|
||||
TTS_SAMPLE_RATE,
|
||||
TTS_SPEED_PROMPT,
|
||||
TTS_TEMPERATURE,
|
||||
@@ -92,7 +93,7 @@ def _load_chat_model(chat) -> None:
|
||||
_ensure_hf_env()
|
||||
model_dir = CHATTTS_MODEL_DIR
|
||||
|
||||
base_kwargs: Dict[str, Any] = {"compile": False}
|
||||
base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": False}
|
||||
|
||||
if not hasattr(chat, "load"):
|
||||
if hasattr(chat, "load_models"):
|
||||
@@ -137,11 +138,26 @@ def _load_chat_model(chat) -> None:
|
||||
|
||||
|
||||
def reset_chattts_instance() -> None:
|
||||
"""释放 ChatTTS 实例(模型下载后重启前可调用)。"""
|
||||
"""卸载 ChatTTS 模型并回收 GPU 显存。"""
|
||||
global _chat, _chat_error
|
||||
if _chat is not None:
|
||||
try:
|
||||
if hasattr(_chat, "unload"):
|
||||
_chat.unload()
|
||||
except Exception:
|
||||
logger.exception("ChatTTS unload 失败")
|
||||
try:
|
||||
del _chat
|
||||
except Exception:
|
||||
pass
|
||||
_chat = None
|
||||
_chat_error = None
|
||||
|
||||
from gpu_utils import release_cuda_cache
|
||||
|
||||
release_cuda_cache()
|
||||
logger.info("ChatTTS 模型已卸载,显存已尝试回收。")
|
||||
|
||||
|
||||
def get_chattts_instance():
|
||||
"""
|
||||
@@ -348,6 +364,13 @@ def save_fixed_speaker(
|
||||
if not audio_sample_path:
|
||||
return False, "未提供音色参考音频。"
|
||||
|
||||
try:
|
||||
from whisper_service import reset_whisper_model
|
||||
|
||||
reset_whisper_model()
|
||||
except Exception:
|
||||
logger.debug("锁定音色前释放 Whisper 显存跳过", exc_info=True)
|
||||
|
||||
chat, init_err = get_chattts_instance()
|
||||
if chat is None:
|
||||
return False, init_err or "ChatTTS 不可用。"
|
||||
@@ -566,6 +589,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
if not refined_text or not refined_text.strip():
|
||||
return False, "合成文本为空,请先完成润色。", None
|
||||
|
||||
# 合成前释放 Whisper,避免 8GB 显存上双模型 OOM
|
||||
try:
|
||||
from whisper_service import reset_whisper_model
|
||||
|
||||
reset_whisper_model()
|
||||
except Exception:
|
||||
logger.debug("合成前释放 Whisper 显存跳过", exc_info=True)
|
||||
|
||||
from gpu_utils import cuda_memory_summary, release_cuda_cache
|
||||
|
||||
release_cuda_cache()
|
||||
logger.info("TTS 合成前 %s", cuda_memory_summary())
|
||||
|
||||
chat, init_err = get_chattts_instance()
|
||||
if chat is None:
|
||||
return False, init_err or "ChatTTS 不可用。", None
|
||||
@@ -604,6 +640,7 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
temperature=TTS_TEMPERATURE,
|
||||
top_P=TTS_TOP_P,
|
||||
top_K=TTS_TOP_K,
|
||||
max_new_token=TTS_MAX_NEW_TOKEN,
|
||||
)
|
||||
|
||||
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
||||
@@ -619,9 +656,10 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
|
||||
segment_wavs: List[np.ndarray] = []
|
||||
for idx, chunk in enumerate(chunks, start=1):
|
||||
release_cuda_cache()
|
||||
wavs = chat.infer(
|
||||
chunk,
|
||||
skip_refine_text=False,
|
||||
skip_refine_text=(idx > 1),
|
||||
params_refine_text=params_refine_text,
|
||||
params_infer_code=params_infer_code,
|
||||
)
|
||||
@@ -633,6 +671,7 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
None,
|
||||
)
|
||||
segment_wavs.append(np.asarray(wavs[0], dtype=np.float32))
|
||||
release_cuda_cache()
|
||||
|
||||
wav_array = (
|
||||
segment_wavs[0]
|
||||
@@ -661,7 +700,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
|
||||
except Exception as exc:
|
||||
exc_msg = str(exc)
|
||||
if "Corrupt input data" in exc_msg:
|
||||
if "out of memory" in exc_msg.lower() or "OutOfMemoryError" in exc_msg:
|
||||
release_cuda_cache()
|
||||
err = (
|
||||
"语音合成失败: GPU 显存不足(CUDA OOM)。\n"
|
||||
"3060 Ti 8GB 无法同时运行 Whisper + ChatTTS。\n"
|
||||
"处理步骤:\n"
|
||||
"1. pm2 restart trading_studio 释放显存\n"
|
||||
"2. 不要连续快速点识别+合成;合成前系统会自动卸载 Whisper\n"
|
||||
"3. 若仍 OOM,在 .env 设置 TTS_MAX_CHARS_PER_CHUNK=150、TTS_MAX_NEW_TOKEN=768\n"
|
||||
"4. 确认无其他程序占用 GPU: nvidia-smi\n"
|
||||
f"技术详情: {exc_msg[:400]}"
|
||||
)
|
||||
elif "Corrupt input data" in exc_msg:
|
||||
err = (
|
||||
"语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n"
|
||||
"处理步骤:\n"
|
||||
|
||||
@@ -156,6 +156,14 @@ def transcribe_audio(audio_path: str) -> Tuple[bool, str]:
|
||||
if not audio_path:
|
||||
return False, "未提供音频文件路径。"
|
||||
|
||||
# 识别前释放 ChatTTS,避免与 Whisper 同占 8GB 显存
|
||||
try:
|
||||
from tts_service import reset_chattts_instance
|
||||
|
||||
reset_chattts_instance()
|
||||
except Exception:
|
||||
logger.debug("释放 ChatTTS 显存时跳过", exc_info=True)
|
||||
|
||||
model, init_error = get_whisper_model()
|
||||
if model is None:
|
||||
return False, init_error or "Whisper 模型不可用。"
|
||||
@@ -199,6 +207,17 @@ def transcribe_audio(audio_path: str) -> Tuple[bool, str]:
|
||||
|
||||
|
||||
def reset_whisper_model() -> None:
|
||||
"""卸载 Whisper 模型并回收 GPU 显存。"""
|
||||
global _model, _model_error
|
||||
if _model is not None:
|
||||
try:
|
||||
del _model
|
||||
except Exception:
|
||||
pass
|
||||
_model = None
|
||||
_model_error = None
|
||||
|
||||
from gpu_utils import release_cuda_cache
|
||||
|
||||
release_cuda_cache()
|
||||
logger.info("Whisper 模型已卸载,显存已尝试回收。")
|
||||
|
||||
Reference in New Issue
Block a user