Fix CUDA OOM on 8GB GPUs by keeping Whisper and ChatTTS mutually exclusive in GPU memory.
Unload the other model and release GPU memory before each TTS/ASR switch, lower TTS token limits, and set PYTORCH_CUDA_ALLOC_CONF in PM2. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -13,3 +13,7 @@ OLLAMA_PORT=11434
|
|||||||
# WHISPER_MODEL_DIR=/opt/Trading_Studio/models/whisper
|
# WHISPER_MODEL_DIR=/opt/Trading_Studio/models/whisper
|
||||||
# WHISPER_MODEL_SIZE=small
|
# WHISPER_MODEL_SIZE=small
|
||||||
# HF_ENDPOINT=https://hf-mirror.com
|
# HF_ENDPOINT=https://hf-mirror.com
|
||||||
|
|
||||||
|
# 8GB 显存 OOM 时可调低(合成按段切分)
|
||||||
|
# TTS_MAX_CHARS_PER_CHUNK=150
|
||||||
|
# TTS_MAX_NEW_TOKEN=768
|
||||||
|
|||||||
@@ -733,10 +733,26 @@ nvidia-smi
|
|||||||
fuser -v /dev/nvidia*
|
fuser -v /dev/nvidia*
|
||||||
```
|
```
|
||||||
|
|
||||||
Whisper 与 ChatTTS 不会同时常驻最大显存,但首次加载模型时峰值较高。建议:
|
Whisper 与 ChatTTS **不能同时常驻** 8GB 显存(会 CUDA OOM)。应用已自动互斥卸载:
|
||||||
|
|
||||||
- 锁定 120W 功耗墙
|
- 识别前卸载 ChatTTS
|
||||||
- `max_memory_restart: "6G"` 已在 PM2 配置中设置
|
- 合成 / 锁定音色前卸载 Whisper
|
||||||
|
|
||||||
|
若仍 OOM:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pm2 restart trading_studio
|
||||||
|
nvidia-smi # 确认无其他占 GPU 进程
|
||||||
|
```
|
||||||
|
|
||||||
|
在 `.env` 调低合成峰值:
|
||||||
|
|
||||||
|
```ini
|
||||||
|
TTS_MAX_CHARS_PER_CHUNK=150
|
||||||
|
TTS_MAX_NEW_TOKEN=768
|
||||||
|
```
|
||||||
|
|
||||||
|
PM2 已配置 `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` 缓解碎片。建议锁定 120W 功耗墙。
|
||||||
|
|
||||||
### 10.3 Whisper 模型加载失败
|
### 10.3 Whisper 模型加载失败
|
||||||
|
|
||||||
|
|||||||
@@ -136,8 +136,11 @@ TTS_TOP_P = 0.7
|
|||||||
TTS_TOP_K = 20
|
TTS_TOP_K = 20
|
||||||
TTS_SPEED_PROMPT = "[speed_5]"
|
TTS_SPEED_PROMPT = "[speed_5]"
|
||||||
|
|
||||||
# 单段 TTS 最大字数(超长稿按句切分后逐段合成再拼接)
|
# 单段 TTS 最大字数(超长稿按句切分后逐段合成再拼接;8GB 显存建议 ≤200)
|
||||||
TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 280)
|
TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 200)
|
||||||
|
|
||||||
|
# ChatTTS 单段最大生成 token(越小越省显存,长句会自动切多段)
|
||||||
|
TTS_MAX_NEW_TOKEN = _env_int("TTS_MAX_NEW_TOKEN", 1024)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# 上传临时文件目录
|
# 上传临时文件目录
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ module.exports = {
|
|||||||
env: {
|
env: {
|
||||||
PYTHONUNBUFFERED: "1",
|
PYTHONUNBUFFERED: "1",
|
||||||
CUDA_VISIBLE_DEVICES: "0",
|
CUDA_VISIBLE_DEVICES: "0",
|
||||||
|
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True",
|
||||||
},
|
},
|
||||||
error_file: path.join(APP_DIR, "logs/pm2-error.log"),
|
error_file: path.join(APP_DIR, "logs/pm2-error.log"),
|
||||||
out_file: path.join(APP_DIR, "logs/pm2-out.log"),
|
out_file: path.join(APP_DIR, "logs/pm2-out.log"),
|
||||||
|
|||||||
@@ -0,0 +1,35 @@
|
|||||||
|
"""GPU 显存回收工具(3060 Ti 8GB:Whisper 与 ChatTTS 不宜同时驻留)。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def release_cuda_cache() -> None:
    """Run a garbage-collection pass, then empty PyTorch's CUDA cache.

    This is a best-effort cleanup helper. It does nothing beyond the GC pass
    when PyTorch is not installed or CUDA is unavailable, and a ``RuntimeError``
    raised by the CUDA calls is logged instead of propagated, so callers can
    use it from their own error handlers without masking the original failure.
    """
    # Collect first so objects kept alive only by reference cycles are
    # finalized before the cache is emptied.
    gc.collect()

    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so there is no CUDA cache to release.
        return

    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # ipc_collect may be missing on some builds, hence the guard.
            if hasattr(torch.cuda, "ipc_collect"):
                torch.cuda.ipc_collect()
    except RuntimeError:
        # Cleanup must never turn into a new failure for the caller.
        logger.warning("释放 CUDA 缓存失败(已忽略)", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_memory_summary() -> str:
    """Return a one-line description of current GPU memory usage (debug aid).

    The result is always a string: either the used/total figures in GB, a
    note that CUDA is unavailable, or a failure message that embeds the
    exception text. Nothing is raised to the caller.
    """
    try:
        import torch

        if torch.cuda.is_available():
            free_bytes, total_bytes = torch.cuda.mem_get_info()
            used_gb = (total_bytes - free_bytes) / 1024**3
            total_gb = total_bytes / 1024**3
            return f"GPU 显存: 已用 {used_gb:.2f}GB / {total_gb:.2f}GB"
        return "CUDA 不可用"
    except Exception as exc:
        return f"显存查询失败: {exc}"
|
||||||
+55
-4
@@ -31,6 +31,7 @@ from config import (
|
|||||||
SPEAKER_SAMPLE_MAX_SEC,
|
SPEAKER_SAMPLE_MAX_SEC,
|
||||||
SPEAKER_SAMPLE_MIN_SEC,
|
SPEAKER_SAMPLE_MIN_SEC,
|
||||||
TTS_MAX_CHARS_PER_CHUNK,
|
TTS_MAX_CHARS_PER_CHUNK,
|
||||||
|
TTS_MAX_NEW_TOKEN,
|
||||||
TTS_SAMPLE_RATE,
|
TTS_SAMPLE_RATE,
|
||||||
TTS_SPEED_PROMPT,
|
TTS_SPEED_PROMPT,
|
||||||
TTS_TEMPERATURE,
|
TTS_TEMPERATURE,
|
||||||
@@ -92,7 +93,7 @@ def _load_chat_model(chat) -> None:
|
|||||||
_ensure_hf_env()
|
_ensure_hf_env()
|
||||||
model_dir = CHATTTS_MODEL_DIR
|
model_dir = CHATTTS_MODEL_DIR
|
||||||
|
|
||||||
base_kwargs: Dict[str, Any] = {"compile": False}
|
base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": False}
|
||||||
|
|
||||||
if not hasattr(chat, "load"):
|
if not hasattr(chat, "load"):
|
||||||
if hasattr(chat, "load_models"):
|
if hasattr(chat, "load_models"):
|
||||||
@@ -137,11 +138,26 @@ def _load_chat_model(chat) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def reset_chattts_instance() -> None:
    """Unload the ChatTTS model and try to reclaim its GPU memory.

    Calls the instance's ``unload()`` when it has one (a failure there is
    logged and otherwise ignored), clears the cached instance and the cached
    initialization error, then asks ``gpu_utils.release_cuda_cache`` to free
    cached CUDA memory. Safe to call when no model is loaded.
    """
    global _chat, _chat_error

    if _chat is not None and hasattr(_chat, "unload"):
        try:
            _chat.unload()
        except Exception:
            # Best effort: the reference is dropped below even if unload fails.
            logger.exception("ChatTTS unload 失败")

    # Rebinding to None is enough to drop the reference. A `del` on the global
    # would leave the name unbound until the next assignment, so a concurrent
    # reader could hit NameError in that window.
    _chat = None
    _chat_error = None

    # Imported at function scope, as the original code does.
    from gpu_utils import release_cuda_cache

    release_cuda_cache()
    logger.info("ChatTTS 模型已卸载,显存已尝试回收。")
|
||||||
|
|
||||||
|
|
||||||
def get_chattts_instance():
|
def get_chattts_instance():
|
||||||
"""
|
"""
|
||||||
@@ -348,6 +364,13 @@ def save_fixed_speaker(
|
|||||||
if not audio_sample_path:
|
if not audio_sample_path:
|
||||||
return False, "未提供音色参考音频。"
|
return False, "未提供音色参考音频。"
|
||||||
|
|
||||||
|
try:
|
||||||
|
from whisper_service import reset_whisper_model
|
||||||
|
|
||||||
|
reset_whisper_model()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("锁定音色前释放 Whisper 显存跳过", exc_info=True)
|
||||||
|
|
||||||
chat, init_err = get_chattts_instance()
|
chat, init_err = get_chattts_instance()
|
||||||
if chat is None:
|
if chat is None:
|
||||||
return False, init_err or "ChatTTS 不可用。"
|
return False, init_err or "ChatTTS 不可用。"
|
||||||
@@ -566,6 +589,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
|||||||
if not refined_text or not refined_text.strip():
|
if not refined_text or not refined_text.strip():
|
||||||
return False, "合成文本为空,请先完成润色。", None
|
return False, "合成文本为空,请先完成润色。", None
|
||||||
|
|
||||||
|
# 合成前释放 Whisper,避免 8GB 显存上双模型 OOM
|
||||||
|
try:
|
||||||
|
from whisper_service import reset_whisper_model
|
||||||
|
|
||||||
|
reset_whisper_model()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("合成前释放 Whisper 显存跳过", exc_info=True)
|
||||||
|
|
||||||
|
from gpu_utils import cuda_memory_summary, release_cuda_cache
|
||||||
|
|
||||||
|
release_cuda_cache()
|
||||||
|
logger.info("TTS 合成前 %s", cuda_memory_summary())
|
||||||
|
|
||||||
chat, init_err = get_chattts_instance()
|
chat, init_err = get_chattts_instance()
|
||||||
if chat is None:
|
if chat is None:
|
||||||
return False, init_err or "ChatTTS 不可用。", None
|
return False, init_err or "ChatTTS 不可用。", None
|
||||||
@@ -604,6 +640,7 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
|||||||
temperature=TTS_TEMPERATURE,
|
temperature=TTS_TEMPERATURE,
|
||||||
top_P=TTS_TOP_P,
|
top_P=TTS_TOP_P,
|
||||||
top_K=TTS_TOP_K,
|
top_K=TTS_TOP_K,
|
||||||
|
max_new_token=TTS_MAX_NEW_TOKEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
||||||
@@ -619,9 +656,10 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
|||||||
|
|
||||||
segment_wavs: List[np.ndarray] = []
|
segment_wavs: List[np.ndarray] = []
|
||||||
for idx, chunk in enumerate(chunks, start=1):
|
for idx, chunk in enumerate(chunks, start=1):
|
||||||
|
release_cuda_cache()
|
||||||
wavs = chat.infer(
|
wavs = chat.infer(
|
||||||
chunk,
|
chunk,
|
||||||
skip_refine_text=False,
|
skip_refine_text=(idx > 1),
|
||||||
params_refine_text=params_refine_text,
|
params_refine_text=params_refine_text,
|
||||||
params_infer_code=params_infer_code,
|
params_infer_code=params_infer_code,
|
||||||
)
|
)
|
||||||
@@ -633,6 +671,7 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
segment_wavs.append(np.asarray(wavs[0], dtype=np.float32))
|
segment_wavs.append(np.asarray(wavs[0], dtype=np.float32))
|
||||||
|
release_cuda_cache()
|
||||||
|
|
||||||
wav_array = (
|
wav_array = (
|
||||||
segment_wavs[0]
|
segment_wavs[0]
|
||||||
@@ -661,7 +700,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
|||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
exc_msg = str(exc)
|
exc_msg = str(exc)
|
||||||
if "Corrupt input data" in exc_msg:
|
if "out of memory" in exc_msg.lower() or "OutOfMemoryError" in exc_msg:
|
||||||
|
release_cuda_cache()
|
||||||
|
err = (
|
||||||
|
"语音合成失败: GPU 显存不足(CUDA OOM)。\n"
|
||||||
|
"3060 Ti 8GB 无法同时运行 Whisper + ChatTTS。\n"
|
||||||
|
"处理步骤:\n"
|
||||||
|
"1. pm2 restart trading_studio 释放显存\n"
|
||||||
|
"2. 不要连续快速点识别+合成;合成前系统会自动卸载 Whisper\n"
|
||||||
|
"3. 若仍 OOM,在 .env 设置 TTS_MAX_CHARS_PER_CHUNK=150、TTS_MAX_NEW_TOKEN=768\n"
|
||||||
|
"4. 确认无其他程序占用 GPU: nvidia-smi\n"
|
||||||
|
f"技术详情: {exc_msg[:400]}"
|
||||||
|
)
|
||||||
|
elif "Corrupt input data" in exc_msg:
|
||||||
err = (
|
err = (
|
||||||
"语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n"
|
"语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n"
|
||||||
"处理步骤:\n"
|
"处理步骤:\n"
|
||||||
|
|||||||
@@ -156,6 +156,14 @@ def transcribe_audio(audio_path: str) -> Tuple[bool, str]:
|
|||||||
if not audio_path:
|
if not audio_path:
|
||||||
return False, "未提供音频文件路径。"
|
return False, "未提供音频文件路径。"
|
||||||
|
|
||||||
|
# 识别前释放 ChatTTS,避免与 Whisper 同占 8GB 显存
|
||||||
|
try:
|
||||||
|
from tts_service import reset_chattts_instance
|
||||||
|
|
||||||
|
reset_chattts_instance()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("释放 ChatTTS 显存时跳过", exc_info=True)
|
||||||
|
|
||||||
model, init_error = get_whisper_model()
|
model, init_error = get_whisper_model()
|
||||||
if model is None:
|
if model is None:
|
||||||
return False, init_error or "Whisper 模型不可用。"
|
return False, init_error or "Whisper 模型不可用。"
|
||||||
@@ -199,6 +207,17 @@ def transcribe_audio(audio_path: str) -> Tuple[bool, str]:
|
|||||||
|
|
||||||
|
|
||||||
def reset_whisper_model() -> None:
    """Unload the Whisper model and try to reclaim its GPU memory.

    Clears the cached model and the cached initialization error, then asks
    ``gpu_utils.release_cuda_cache`` to free cached CUDA memory. Safe to call
    when no model is loaded.
    """
    global _model, _model_error

    # Rebinding to None is enough to drop the reference. A `del` on the global
    # would leave the name unbound until the next assignment, so a concurrent
    # reader could hit NameError in that window.
    _model = None
    _model_error = None

    # Imported at function scope, as the original code does.
    from gpu_utils import release_cuda_cache

    release_cuda_cache()
    logger.info("Whisper 模型已卸载,显存已尝试回收。")
|
||||||
|
|||||||
Reference in New Issue
Block a user