Fix CUDA OOM by mutually unloading Whisper and ChatTTS on 8GB GPU.

Release GPU memory before TTS/ASR switches, lower TTS token limits, and set PYTORCH_CUDA_ALLOC_CONF in PM2.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-12 17:03:37 +08:00
parent 82f99c0b89
commit 0cce6cda7c
7 changed files with 169 additions and 40 deletions
+4
View File
@@ -13,3 +13,7 @@ OLLAMA_PORT=11434
# WHISPER_MODEL_DIR=/opt/Trading_Studio/models/whisper # WHISPER_MODEL_DIR=/opt/Trading_Studio/models/whisper
# WHISPER_MODEL_SIZE=small # WHISPER_MODEL_SIZE=small
# HF_ENDPOINT=https://hf-mirror.com # HF_ENDPOINT=https://hf-mirror.com
# 8GB 显存 OOM 时可调低(合成按段切分)
# TTS_MAX_CHARS_PER_CHUNK=150
# TTS_MAX_NEW_TOKEN=768
+19 -3
View File
@@ -733,10 +733,26 @@ nvidia-smi
fuser -v /dev/nvidia* fuser -v /dev/nvidia*
``` ```
Whisper 与 ChatTTS 不会同时常驻最大显存,但首次加载模型时峰值较高。建议 Whisper 与 ChatTTS **不能同时常驻** 8GB 显存(会 CUDA OOM)。应用已自动互斥卸载
- 锁定 120W 功耗墙 - 识别前卸载 ChatTTS
- `max_memory_restart: "6G"` 已在 PM2 配置中设置 - 合成 / 锁定音色前卸载 Whisper
若仍 OOM
```bash
pm2 restart trading_studio
nvidia-smi # 确认无其他占 GPU 进程
```
`.env` 调低合成峰值:
```ini
TTS_MAX_CHARS_PER_CHUNK=150
TTS_MAX_NEW_TOKEN=768
```
PM2 已配置 `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` 缓解碎片。建议锁定 120W 功耗墙。
### 10.3 Whisper 模型加载失败 ### 10.3 Whisper 模型加载失败
+5 -2
View File
@@ -136,8 +136,11 @@ TTS_TOP_P = 0.7
TTS_TOP_K = 20 TTS_TOP_K = 20
TTS_SPEED_PROMPT = "[speed_5]" TTS_SPEED_PROMPT = "[speed_5]"
# 单段 TTS 最大字数(超长稿按句切分后逐段合成再拼接) # 单段 TTS 最大字数(超长稿按句切分后逐段合成再拼接8GB 显存建议 ≤200
TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 280) TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 200)
# ChatTTS 单段最大生成 token(越小越省显存,长句会自动切多段)
TTS_MAX_NEW_TOKEN = _env_int("TTS_MAX_NEW_TOKEN", 1024)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# 上传临时文件目录 # 上传临时文件目录
+32 -31
View File
@@ -1,31 +1,32 @@
/** /**
* PM2 进程守护配置 * PM2 进程守护配置
* 标准安装路径: /opt/Trading_Studio * 标准安装路径: /opt/Trading_Studio
* 用法: pm2 start ecosystem.config.js * 用法: pm2 start ecosystem.config.js
*/ */
const path = require("path"); const path = require("path");
const APP_DIR = __dirname; const APP_DIR = __dirname;
module.exports = { module.exports = {
apps: [ apps: [
{ {
name: "trading_studio", name: "trading_studio",
script: path.join(APP_DIR, "app.py"), script: path.join(APP_DIR, "app.py"),
interpreter: path.join(APP_DIR, "venv/bin/python"), interpreter: path.join(APP_DIR, "venv/bin/python"),
cwd: APP_DIR, cwd: APP_DIR,
instances: 1, instances: 1,
autorestart: true, autorestart: true,
watch: false, watch: false,
max_memory_restart: "6G", max_memory_restart: "6G",
env: { env: {
PYTHONUNBUFFERED: "1", PYTHONUNBUFFERED: "1",
CUDA_VISIBLE_DEVICES: "0", CUDA_VISIBLE_DEVICES: "0",
}, PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True",
error_file: path.join(APP_DIR, "logs/pm2-error.log"), },
out_file: path.join(APP_DIR, "logs/pm2-out.log"), error_file: path.join(APP_DIR, "logs/pm2-error.log"),
log_date_format: "YYYY-MM-DD HH:mm:ss", out_file: path.join(APP_DIR, "logs/pm2-out.log"),
merge_logs: true, log_date_format: "YYYY-MM-DD HH:mm:ss",
}, merge_logs: true,
], },
}; ],
};
+35
View File
@@ -0,0 +1,35 @@
"""GPU 显存回收工具(3060 Ti 8GBWhisper 与 ChatTTS 不宜同时驻留)。"""
from __future__ import annotations
import gc
import logging
logger = logging.getLogger(__name__)
def release_cuda_cache() -> None:
"""触发 GC 并清空 PyTorch CUDA 缓存。"""
gc.collect()
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
if hasattr(torch.cuda, "ipc_collect"):
torch.cuda.ipc_collect()
except ImportError:
pass
def cuda_memory_summary() -> str:
"""返回简要显存占用(调试用)。"""
try:
import torch
if not torch.cuda.is_available():
return "CUDA 不可用"
free, total = torch.cuda.mem_get_info()
return f"GPU 显存: 已用 {(total - free) / 1024**3:.2f}GB / {total / 1024**3:.2f}GB"
except Exception as exc:
return f"显存查询失败: {exc}"
+55 -4
View File
@@ -31,6 +31,7 @@ from config import (
SPEAKER_SAMPLE_MAX_SEC, SPEAKER_SAMPLE_MAX_SEC,
SPEAKER_SAMPLE_MIN_SEC, SPEAKER_SAMPLE_MIN_SEC,
TTS_MAX_CHARS_PER_CHUNK, TTS_MAX_CHARS_PER_CHUNK,
TTS_MAX_NEW_TOKEN,
TTS_SAMPLE_RATE, TTS_SAMPLE_RATE,
TTS_SPEED_PROMPT, TTS_SPEED_PROMPT,
TTS_TEMPERATURE, TTS_TEMPERATURE,
@@ -92,7 +93,7 @@ def _load_chat_model(chat) -> None:
_ensure_hf_env() _ensure_hf_env()
model_dir = CHATTTS_MODEL_DIR model_dir = CHATTTS_MODEL_DIR
base_kwargs: Dict[str, Any] = {"compile": False} base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": False}
if not hasattr(chat, "load"): if not hasattr(chat, "load"):
if hasattr(chat, "load_models"): if hasattr(chat, "load_models"):
@@ -137,11 +138,26 @@ def _load_chat_model(chat) -> None:
def reset_chattts_instance() -> None: def reset_chattts_instance() -> None:
"""释放 ChatTTS 实例(模型下载后重启前可调用)""" """卸载 ChatTTS 模型并回收 GPU 显存"""
global _chat, _chat_error global _chat, _chat_error
if _chat is not None:
try:
if hasattr(_chat, "unload"):
_chat.unload()
except Exception:
logger.exception("ChatTTS unload 失败")
try:
del _chat
except Exception:
pass
_chat = None _chat = None
_chat_error = None _chat_error = None
from gpu_utils import release_cuda_cache
release_cuda_cache()
logger.info("ChatTTS 模型已卸载,显存已尝试回收。")
def get_chattts_instance(): def get_chattts_instance():
""" """
@@ -348,6 +364,13 @@ def save_fixed_speaker(
if not audio_sample_path: if not audio_sample_path:
return False, "未提供音色参考音频。" return False, "未提供音色参考音频。"
try:
from whisper_service import reset_whisper_model
reset_whisper_model()
except Exception:
logger.debug("锁定音色前释放 Whisper 显存跳过", exc_info=True)
chat, init_err = get_chattts_instance() chat, init_err = get_chattts_instance()
if chat is None: if chat is None:
return False, init_err or "ChatTTS 不可用。" return False, init_err or "ChatTTS 不可用。"
@@ -566,6 +589,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
if not refined_text or not refined_text.strip(): if not refined_text or not refined_text.strip():
return False, "合成文本为空,请先完成润色。", None return False, "合成文本为空,请先完成润色。", None
# 合成前释放 Whisper,避免 8GB 显存上双模型 OOM
try:
from whisper_service import reset_whisper_model
reset_whisper_model()
except Exception:
logger.debug("合成前释放 Whisper 显存跳过", exc_info=True)
from gpu_utils import cuda_memory_summary, release_cuda_cache
release_cuda_cache()
logger.info("TTS 合成前 %s", cuda_memory_summary())
chat, init_err = get_chattts_instance() chat, init_err = get_chattts_instance()
if chat is None: if chat is None:
return False, init_err or "ChatTTS 不可用。", None return False, init_err or "ChatTTS 不可用。", None
@@ -604,6 +640,7 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
temperature=TTS_TEMPERATURE, temperature=TTS_TEMPERATURE,
top_P=TTS_TOP_P, top_P=TTS_TOP_P,
top_K=TTS_TOP_K, top_K=TTS_TOP_K,
max_new_token=TTS_MAX_NEW_TOKEN,
) )
params_refine_text = ChatTTS.Chat.RefineTextParams( params_refine_text = ChatTTS.Chat.RefineTextParams(
@@ -619,9 +656,10 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
segment_wavs: List[np.ndarray] = [] segment_wavs: List[np.ndarray] = []
for idx, chunk in enumerate(chunks, start=1): for idx, chunk in enumerate(chunks, start=1):
release_cuda_cache()
wavs = chat.infer( wavs = chat.infer(
chunk, chunk,
skip_refine_text=False, skip_refine_text=(idx > 1),
params_refine_text=params_refine_text, params_refine_text=params_refine_text,
params_infer_code=params_infer_code, params_infer_code=params_infer_code,
) )
@@ -633,6 +671,7 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
None, None,
) )
segment_wavs.append(np.asarray(wavs[0], dtype=np.float32)) segment_wavs.append(np.asarray(wavs[0], dtype=np.float32))
release_cuda_cache()
wav_array = ( wav_array = (
segment_wavs[0] segment_wavs[0]
@@ -661,7 +700,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
except Exception as exc: except Exception as exc:
exc_msg = str(exc) exc_msg = str(exc)
if "Corrupt input data" in exc_msg: if "out of memory" in exc_msg.lower() or "OutOfMemoryError" in exc_msg:
release_cuda_cache()
err = (
"语音合成失败: GPU 显存不足(CUDA OOM)。\n"
"3060 Ti 8GB 无法同时运行 Whisper + ChatTTS。\n"
"处理步骤:\n"
"1. pm2 restart trading_studio 释放显存\n"
"2. 不要连续快速点识别+合成;合成前系统会自动卸载 Whisper\n"
"3. 若仍 OOM,在 .env 设置 TTS_MAX_CHARS_PER_CHUNK=150、TTS_MAX_NEW_TOKEN=768\n"
"4. 确认无其他程序占用 GPU: nvidia-smi\n"
f"技术详情: {exc_msg[:400]}"
)
elif "Corrupt input data" in exc_msg:
err = ( err = (
"语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n" "语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n"
"处理步骤:\n" "处理步骤:\n"
+19
View File
@@ -156,6 +156,14 @@ def transcribe_audio(audio_path: str) -> Tuple[bool, str]:
if not audio_path: if not audio_path:
return False, "未提供音频文件路径。" return False, "未提供音频文件路径。"
# 识别前释放 ChatTTS,避免与 Whisper 同占 8GB 显存
try:
from tts_service import reset_chattts_instance
reset_chattts_instance()
except Exception:
logger.debug("释放 ChatTTS 显存时跳过", exc_info=True)
model, init_error = get_whisper_model() model, init_error = get_whisper_model()
if model is None: if model is None:
return False, init_error or "Whisper 模型不可用。" return False, init_error or "Whisper 模型不可用。"
@@ -199,6 +207,17 @@ def transcribe_audio(audio_path: str) -> Tuple[bool, str]:
def reset_whisper_model() -> None: def reset_whisper_model() -> None:
"""卸载 Whisper 模型并回收 GPU 显存。"""
global _model, _model_error global _model, _model_error
if _model is not None:
try:
del _model
except Exception:
pass
_model = None _model = None
_model_error = None _model_error = None
from gpu_utils import release_cuda_cache
release_cuda_cache()
logger.info("Whisper 模型已卸载,显存已尝试回收。")