diff --git a/.env.example b/.env.example index 4ad7ce7..6574b5f 100644 --- a/.env.example +++ b/.env.example @@ -13,3 +13,7 @@ OLLAMA_PORT=11434 # WHISPER_MODEL_DIR=/opt/Trading_Studio/models/whisper # WHISPER_MODEL_SIZE=small # HF_ENDPOINT=https://hf-mirror.com + +# 8GB 显存 OOM 时可调低(合成按段切分) +# TTS_MAX_CHARS_PER_CHUNK=150 +# TTS_MAX_NEW_TOKEN=768 diff --git a/DEPLOY.md b/DEPLOY.md index 7bec072..34b1809 100644 --- a/DEPLOY.md +++ b/DEPLOY.md @@ -733,10 +733,26 @@ nvidia-smi fuser -v /dev/nvidia* ``` -Whisper 与 ChatTTS 不会同时常驻最大显存,但首次加载模型时峰值较高。建议: +Whisper 与 ChatTTS **不能同时常驻** 8GB 显存(会 CUDA OOM)。应用已自动互斥卸载: -- 锁定 120W 功耗墙 -- `max_memory_restart: "6G"` 已在 PM2 配置中设置 +- 识别前卸载 ChatTTS +- 合成 / 锁定音色前卸载 Whisper + +若仍 OOM: + +```bash +pm2 restart trading_studio +nvidia-smi # 确认无其他占 GPU 进程 +``` + +在 `.env` 调低合成峰值: + +```ini +TTS_MAX_CHARS_PER_CHUNK=150 +TTS_MAX_NEW_TOKEN=768 +``` + +PM2 已配置 `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` 缓解碎片。建议锁定 120W 功耗墙。 ### 10.3 Whisper 模型加载失败 diff --git a/config.py b/config.py index 372d4e9..7666262 100644 --- a/config.py +++ b/config.py @@ -136,8 +136,11 @@ TTS_TOP_P = 0.7 TTS_TOP_K = 20 TTS_SPEED_PROMPT = "[speed_5]" -# 单段 TTS 最大字数(超长稿按句切分后逐段合成再拼接) -TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 280) +# 单段 TTS 最大字数(超长稿按句切分后逐段合成再拼接;8GB 显存建议 ≤200) +TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 200) + +# ChatTTS 单段最大生成 token(越小越省显存,长句会自动切多段) +TTS_MAX_NEW_TOKEN = _env_int("TTS_MAX_NEW_TOKEN", 1024) # --------------------------------------------------------------------------- # 上传临时文件目录 diff --git a/ecosystem.config.js b/ecosystem.config.js index e0a382f..73cf1e3 100644 --- a/ecosystem.config.js +++ b/ecosystem.config.js @@ -1,31 +1,32 @@ -/** - * PM2 进程守护配置 - * 标准安装路径: /opt/Trading_Studio - * 用法: pm2 start ecosystem.config.js - */ -const path = require("path"); - -const APP_DIR = __dirname; - -module.exports = { - apps: [ - { - name: "trading_studio", - script: path.join(APP_DIR, "app.py"), - interpreter: 
"""GPU memory reclamation helpers.

Written for an 8 GB card (RTX 3060 Ti), where Whisper and ChatTTS should not
both stay resident on the GPU at the same time.
"""

import gc
import logging

logger = logging.getLogger(__name__)

_BYTES_PER_GIB = 1024**3


def release_cuda_cache() -> None:
    """Run a GC pass and release PyTorch's cached CUDA memory (best effort).

    This helper is invoked from cleanup paths and from inside exception
    handlers, so it must never raise: a failure here would otherwise replace
    the error the caller is in the middle of reporting. Does nothing beyond
    ``gc.collect()`` when PyTorch is not installed or CUDA is unavailable.
    """
    gc.collect()
    try:
        import torch
    except ImportError:
        return

    try:
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()
        # ipc_collect is only called when this torch build exposes it.
        if hasattr(torch.cuda, "ipc_collect"):
            torch.cuda.ipc_collect()
    except Exception:
        # CUDA calls can raise (e.g. RuntimeError on a faulted context);
        # cache release is optional, so log and carry on.
        logger.warning("CUDA 缓存回收失败(已忽略)", exc_info=True)


def cuda_memory_summary() -> str:
    """Return a short, human-readable GPU memory usage line (for debugging).

    Returns:
        A usage string built from ``torch.cuda.mem_get_info()``, a fixed
        message when CUDA is unavailable, or a failure message containing
        the exception text if the query (or importing torch) fails.
    """
    try:
        import torch

        if not torch.cuda.is_available():
            return "CUDA 不可用"
        free, total = torch.cuda.mem_get_info()
        used_gib = (total - free) / _BYTES_PER_GIB
        total_gib = total / _BYTES_PER_GIB
        return f"GPU 显存: 已用 {used_gib:.2f}GB / {total_gib:.2f}GB"
    except Exception as exc:
        return f"显存查询失败: {exc}"
def reset_chattts_instance() -> None:
    """Unload the ChatTTS model and try to reclaim its GPU memory.

    Clears the cached instance and the cached initialisation error, calls the
    instance's ``unload()`` when it has one, then asks ``gpu_utils`` to
    release cached CUDA memory. Safe to call when no model is loaded.
    """
    global _chat, _chat_error

    # Rebind instead of `del`: deleting a global unbinds the module-level
    # name, so anything reading `_chat` before it is reassigned would hit
    # NameError. Keep the old instance in a local for the unload call.
    chat, _chat = _chat, None
    _chat_error = None

    if chat is not None:
        try:
            if hasattr(chat, "unload"):
                chat.unload()
        except Exception:
            # Unload is best-effort; dropping the reference below still
            # lets the garbage collector reclaim the model.
            logger.exception("ChatTTS unload 失败")
        chat = None  # drop the last reference before collecting

    from gpu_utils import release_cuda_cache

    release_cuda_cache()
    logger.info("ChatTTS 模型已卸载,显存已尝试回收。")
cuda_memory_summary()) + chat, init_err = get_chattts_instance() if chat is None: return False, init_err or "ChatTTS 不可用。", None @@ -604,6 +640,7 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: temperature=TTS_TEMPERATURE, top_P=TTS_TOP_P, top_K=TTS_TOP_K, + max_new_token=TTS_MAX_NEW_TOKEN, ) params_refine_text = ChatTTS.Chat.RefineTextParams( @@ -619,9 +656,10 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: segment_wavs: List[np.ndarray] = [] for idx, chunk in enumerate(chunks, start=1): + release_cuda_cache() wavs = chat.infer( chunk, - skip_refine_text=False, + skip_refine_text=(idx > 1), params_refine_text=params_refine_text, params_infer_code=params_infer_code, ) @@ -633,6 +671,7 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: None, ) segment_wavs.append(np.asarray(wavs[0], dtype=np.float32)) + release_cuda_cache() wav_array = ( segment_wavs[0] @@ -661,7 +700,19 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: except Exception as exc: exc_msg = str(exc) - if "Corrupt input data" in exc_msg: + if "out of memory" in exc_msg.lower() or "OutOfMemoryError" in exc_msg: + release_cuda_cache() + err = ( + "语音合成失败: GPU 显存不足(CUDA OOM)。\n" + "3060 Ti 8GB 无法同时运行 Whisper + ChatTTS。\n" + "处理步骤:\n" + "1. pm2 restart trading_studio 释放显存\n" + "2. 不要连续快速点识别+合成;合成前系统会自动卸载 Whisper\n" + "3. 若仍 OOM,在 .env 设置 TTS_MAX_CHARS_PER_CHUNK=150、TTS_MAX_NEW_TOKEN=768\n" + "4. 
def reset_whisper_model() -> None:
    """Unload the Whisper model and try to reclaim its GPU memory.

    Clears the cached model and the cached initialisation error, then asks
    ``gpu_utils`` to release cached CUDA memory. Safe to call when no model
    is loaded.
    """
    global _model, _model_error

    # Rebinding to None drops the module's reference without ever leaving
    # the global name unbound (which `del _model` would do until the next
    # assignment).
    _model = None
    _model_error = None

    from gpu_utils import release_cuda_cache

    release_cuda_cache()
    logger.info("Whisper 模型已卸载,显存已尝试回收。")