Files
Trading_Studio/whisper_service.py
T
dekun 0cce6cda7c Fix CUDA OOM by mutually unloading Whisper and ChatTTS on 8GB GPU.
Release GPU memory before TTS/ASR switches, lower TTS token limits, and set PYTORCH_CUDA_ALLOC_CONF in PM2.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-12 17:03:37 +08:00

224 lines
6.3 KiB
Python

"""
Faster-Whisper CUDA 语音识别服务
封装本地 GPU 加速的音频转写逻辑,适配 RTX 3060 Ti 8GB 显存。
"""
from __future__ import annotations
import logging
import os
import traceback
from pathlib import Path
from typing import Optional, Tuple
from config import (
BASE_DIR,
HF_ENDPOINT,
HF_HOME,
HF_HUB_DOWNLOAD_TIMEOUT,
WHISPER_COMPUTE_TYPE,
WHISPER_DEVICE,
WHISPER_HF_REPO,
WHISPER_LANGUAGE,
WHISPER_MODEL_DIR,
WHISPER_MODEL_SIZE,
)
logger = logging.getLogger(__name__)
# Process-wide cache of the loaded WhisperModel; stays None until
# get_whisper_model() succeeds, and is cleared by reset_whisper_model().
_model = None
# Message of the last failed load attempt. While this is set,
# get_whisper_model() returns it instead of retrying; reset_whisper_model()
# clears it.
_model_error: Optional[str] = None
def _ensure_hf_env() -> None:
    """Prepare the HuggingFace environment and the model directory.

    Each HF_* variable is only filled in when it is not already present in
    the process environment, so values exported by the caller win over the
    configured defaults. The Whisper model directory is created if missing.
    """
    env_defaults = (
        ("HF_ENDPOINT", HF_ENDPOINT),
        ("HF_HUB_DOWNLOAD_TIMEOUT", str(HF_HUB_DOWNLOAD_TIMEOUT)),
        ("HF_HOME", str(HF_HOME)),
    )
    for var_name, fallback in env_defaults:
        os.environ.setdefault(var_name, fallback)
    WHISPER_MODEL_DIR.mkdir(parents=True, exist_ok=True)
def _whisper_local_path() -> Optional[Path]:
    """Return the pre-downloaded local model directory, if one exists.

    The directory ``WHISPER_MODEL_DIR / WHISPER_MODEL_SIZE`` counts as a
    usable model only when it contains a ``model.bin`` file.

    Returns:
        The model directory, or ``None`` when ``model.bin`` is absent.
    """
    candidate = WHISPER_MODEL_DIR / WHISPER_MODEL_SIZE
    weights_file = candidate / "model.bin"
    return candidate if weights_file.is_file() else None
def _is_cuda_error(exc: BaseException) -> bool:
msg = str(exc).lower()
cuda_keywords = (
"cuda", "cudnn", "cublas", "gpu",
"out of memory", "no kernel image", "device-side assert",
)
return any(k in msg for k in cuda_keywords)
def _is_network_error(exc: BaseException) -> bool:
msg = str(exc).lower()
return any(
k in msg
for k in (
"network is unreachable",
"connection error",
"connecterror",
"timed out",
"couldn't connect",
"name resolution",
"hub",
)
)
def _build_load_error(exc: BaseException) -> str:
    """Compose the user-facing message for a failed Whisper model load.

    Network-looking failures get step-by-step instructions for
    pre-downloading the model; every other failure gets the formatted
    traceback appended. ``traceback.format_exc()`` reports the exception
    currently being handled, so call this from inside an ``except`` block.

    Args:
        exc: The exception raised while loading the model.

    Returns:
        A multi-line message joined with newlines.
    """
    header = [
        "Whisper 模型加载失败。",
        f"详情: {exc}",
        "",
    ]
    if not _is_network_error(exc):
        return "\n".join(header + [f"完整日志:\n{traceback.format_exc()}"])

    download_steps = [
        "原因:服务器无法访问 HuggingFace 下载模型(内网/无外网常见)。",
        "请在服务器执行(走 HF 镜像,仅需一次):",
        f" cd {BASE_DIR}",
        " bash scripts/download_whisper_models.sh",
        " pm2 restart trading_studio",
        "",
        f"模型将保存到: {WHISPER_MODEL_DIR / WHISPER_MODEL_SIZE}",
    ]
    return "\n".join(header + download_steps)
def get_whisper_model():
    """Return the shared Faster-Whisper model, loading it on first use.

    A pre-downloaded local model directory is preferred; when none is found
    the configured model size is passed through so the library can try to
    download it. Both outcomes are cached in module globals: a loaded model
    is reused, and a failure message is returned again on later calls
    without retrying until ``reset_whisper_model()`` clears it.

    Returns:
        ``(model, None)`` on success, or ``(None, error_message)`` on failure.
    """
    global _model, _model_error

    # Serve from the caches first: either a live model or a remembered error.
    if _model is not None:
        return _model, None
    if _model_error is not None:
        return None, _model_error

    try:
        _ensure_hf_env()
        # Imported lazily so a missing package becomes an error message
        # instead of breaking import of this module.
        from faster_whisper import WhisperModel

        local_dir = _whisper_local_path()
        if not local_dir:
            model_id = WHISPER_MODEL_SIZE
            logger.warning(
                "未找到本地 Whisper 模型 (%s),尝试在线下载(可能失败)…",
                WHISPER_MODEL_DIR / WHISPER_MODEL_SIZE,
            )
        else:
            model_id = str(local_dir)
            logger.info("Whisper 从本地加载: %s", model_id)

        logger.info(
            "Whisper 加载: model=%s, device=%s, compute_type=%s",
            model_id,
            WHISPER_DEVICE,
            WHISPER_COMPUTE_TYPE,
        )
        load_options = {
            "device": WHISPER_DEVICE,
            "compute_type": WHISPER_COMPUTE_TYPE,
            "download_root": str(WHISPER_MODEL_DIR),
        }
        _model = WhisperModel(model_id, **load_options)
        logger.info("Whisper 模型加载成功。")
        return _model, None
    except ImportError as exc:
        _model_error = (
            "未安装 faster-whisper,请执行: pip install faster-whisper\n"
            f"原始错误: {exc}"
        )
        logger.exception("faster-whisper 导入失败")
        return None, _model_error
    except Exception as exc:
        # CUDA-looking failures get a driver/runtime hint; everything else
        # goes through the generic load-error builder.
        _model_error = (
            (
                "CUDA 初始化失败,请检查 NVIDIA 驱动、CUDA 运行时及 cuDNN。\n"
                f"错误详情: {exc}"
            )
            if _is_cuda_error(exc)
            else _build_load_error(exc)
        )
        logger.exception("Whisper 模型加载异常")
        return None, _model_error
def transcribe_audio(audio_path: str) -> Tuple[bool, str]:
    """Transcribe an audio file to text with the shared Whisper model.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        ``(True, text)`` on success, or ``(False, error_message)`` when the
        path is empty, the model is unavailable, nothing was recognised, or
        transcription raised.
    """
    if not audio_path:
        return False, "未提供音频文件路径。"

    # Unload ChatTTS before recognition so it does not occupy the 8GB of GPU
    # memory together with Whisper. Best effort: any failure is only logged.
    try:
        from tts_service import reset_chattts_instance

        reset_chattts_instance()
    except Exception:
        logger.debug("释放 ChatTTS 显存时跳过", exc_info=True)

    model, init_error = get_whisper_model()
    if model is None:
        return False, init_error or "Whisper 模型不可用。"

    try:
        segments, info = model.transcribe(
            audio_path,
            language=WHISPER_LANGUAGE,
            beam_size=5,
            vad_filter=True,
        )
        # Consuming the segments happens inside the try block, so errors
        # raised while iterating them are reported through the handler below.
        result_text = "".join(piece.text.strip() for piece in segments).strip()

        if result_text:
            logger.info(
                "转写完成: 语言=%s, 概率=%.2f, 字数=%d",
                getattr(info, "language", "?"),
                getattr(info, "language_probability", 0.0),
                len(result_text),
            )
            return True, result_text

        return False, (
            "识别结果为空,请检查音频是否有效、音量是否足够,"
            f"或尝试更换格式。检测到语言: {getattr(info, 'language', 'unknown')}"
        )
    except Exception as exc:
        if not _is_cuda_error(exc):
            err = f"音频转写失败: {exc}\n{traceback.format_exc()}"
        else:
            err = (
                "CUDA 推理异常:显存可能不足或 GPU 状态异常。\n"
                f"错误详情: {exc}"
            )
        logger.exception("transcribe_audio 失败")
        return False, err
def reset_whisper_model() -> None:
    """Unload the Whisper model and try to reclaim GPU memory.

    Clears both module-level caches (the model and the remembered load
    error), so the next ``get_whisper_model()`` call attempts a fresh load,
    then asks ``gpu_utils.release_cuda_cache()`` to free cached GPU memory.

    Raises:
        ImportError: If ``gpu_utils`` cannot be imported.
    """
    global _model, _model_error
    # Rebind instead of `del`: deleting a global briefly removes the module
    # attribute, so a concurrent get_whisper_model() could hit NameError.
    # Rebinding to None drops this module's reference just the same.
    _model = None
    _model_error = None

    from gpu_utils import release_cuda_cache

    release_cuda_cache()
    logger.info("Whisper 模型已卸载,显存已尝试回收。")