Files
Trading_Studio/gpu_utils.py
T
dekun 8be34a2fd5 Fix ChatTTS CUDA device-side assert with text sanitize and GPU recovery.
Re-enable KV cache by default, normalize digits and unsafe chars, disable per-chunk split_text, and reload ChatTTS after CUDA errors.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-12 17:13:57 +08:00

55 lines
1.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""GPU 显存回收工具(3060 Ti 8GBWhisper 与 ChatTTS 不宜同时驻留)。"""
from __future__ import annotations
import gc
import logging
logger = logging.getLogger(__name__)
def release_cuda_cache() -> None:
    """Run a garbage-collection pass, then release PyTorch's cached CUDA memory.

    This is a best-effort cleanup helper:

    * If PyTorch cannot be imported, or reports that CUDA is unavailable,
      only the ``gc.collect()`` pass runs and the function returns.
    * Otherwise ``torch.cuda.synchronize()``, ``torch.cuda.empty_cache()``
      and (when the installed PyTorch provides it) ``torch.cuda.ipc_collect()``
      are attempted in that order. Each step is guarded on its own, so one
      failing step is logged as a warning and the remaining steps still run.

    Returns:
        None. No exception raised by the individual CUDA steps propagates
        to the caller.
    """
    # Drop unreachable Python objects first so any tensors they hold can be
    # returned to the caching allocator before it is emptied.
    gc.collect()

    try:
        import torch
    except ImportError:
        # PyTorch is optional here; nothing further to release.
        return

    if not torch.cuda.is_available():
        return

    for step_name in ("synchronize", "empty_cache", "ipc_collect"):
        step = getattr(torch.cuda, step_name, None)
        if step is None:
            # Not exposed by this PyTorch build; skip it.
            continue
        try:
            step()
        except Exception as exc:
            # NOTE(review): presumably hit when the CUDA context is already
            # in an error state (e.g. after a device-side assert); confirm
            # against the recovery path that calls this helper.
            logger.warning(
                "release_cuda_cache: torch.cuda.%s() failed: %s", step_name, exc
            )
def is_cuda_runtime_error(exc: BaseException) -> bool:
    """Report whether an exception's message looks like a CUDA runtime failure.

    The check is a case-insensitive substring match of ``str(exc)`` against
    a fixed list of markers; the exception's type is not inspected.

    Args:
        exc: The exception whose message is examined.

    Returns:
        True if any marker occurs in the lower-cased message, else False.
    """
    markers = (
        "cuda error",
        "device-side assert",
        "out of memory",
        "cublas",
        "illegal memory access",
        "an illegal instruction",
    )
    text = str(exc).lower()
    for marker in markers:
        if marker in text:
            return True
    return False
def cuda_memory_summary() -> str:
    """Return a short, human-readable GPU memory usage string (for debugging).

    The figures come from ``torch.cuda.mem_get_info()``; "used" is computed
    as total minus free, and both values are shown in units of 1024**3 bytes
    with two decimals.

    Returns:
        The usage summary when CUDA is available, a fixed "CUDA unavailable"
        message when it is not, or a failure message embedding the exception
        text if anything raised along the way (including a failed
        ``import torch``). This function does not raise ``Exception``
        subclasses itself.
    """
    bytes_per_gb = 1024**3
    try:
        import torch

        if torch.cuda.is_available():
            free_bytes, total_bytes = torch.cuda.mem_get_info()
            used_gb = (total_bytes - free_bytes) / bytes_per_gb
            total_gb = total_bytes / bytes_per_gb
            return f"GPU 显存: 已用 {used_gb:.2f}GB / {total_gb:.2f}GB"
        return "CUDA 不可用"
    except Exception as err:
        return f"显存查询失败: {err}"