8be34a2fd5
Re-enable KV cache by default, normalize digits and unsafe chars, disable per-chunk split_text, and reload ChatTTS after CUDA errors. Co-authored-by: Cursor <cursoragent@cursor.com>
55 lines
1.4 KiB
Python
55 lines
1.4 KiB
Python
"""GPU 显存回收工具(3060 Ti 8GB:Whisper 与 ChatTTS 不宜同时驻留)。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import gc
|
||
import logging
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def release_cuda_cache() -> None:
    """Run the garbage collector, then release PyTorch's cached CUDA memory.

    Best-effort cleanup:

    * ``gc.collect()`` always runs.
    * If torch is not installed or CUDA is unavailable, nothing else happens.
    * Errors from the CUDA calls are logged, not raised. After a CUDA fault
      (e.g. a device-side assert) the context is poisoned and
      ``empty_cache()`` itself raises. A cleanup helper must not turn that
      into a second failure for its caller.
    """
    # Drop unreachable Python objects first so their tensors return to the
    # caching allocator before we ask it to release memory.
    gc.collect()

    try:
        import torch
    except ImportError:
        return  # torch not installed: no GPU-side cache to release

    if not torch.cuda.is_available():
        return

    # Let queued kernels finish before releasing cached blocks. A failure
    # here usually means the context is already broken; keep going anyway.
    try:
        torch.cuda.synchronize()
    except Exception as exc:
        logger.debug("torch.cuda.synchronize() 失败: %s", exc)

    try:
        torch.cuda.empty_cache()
        # Older torch builds may lack ipc_collect.
        if hasattr(torch.cuda, "ipc_collect"):
            torch.cuda.ipc_collect()
    except Exception as exc:
        logger.warning("释放 CUDA 缓存失败: %s", exc)
|
||
|
||
|
||
# Lower-case substrings that identify a CUDA-side failure in an exception
# message. "cudnn_status" catches cuDNN failures (e.g.
# "cuDNN error: CUDNN_STATUS_EXECUTION_FAILED"), which do not contain
# "cuda error".
_CUDA_ERROR_MARKERS = (
    "cuda error",
    "device-side assert",
    "out of memory",
    "cublas",
    "cudnn_status",
    "illegal memory access",
    "an illegal instruction",
)


def is_cuda_runtime_error(exc: BaseException) -> bool:
    """Return True if *exc*, or any exception chained to it, looks CUDA-related.

    The check is a case-insensitive substring match of each exception's
    ``str()`` against ``_CUDA_ERROR_MARKERS``. Explicit (``__cause__``) and
    implicit (``__context__``) chains are followed, because libraries often
    re-raise a CUDA fault wrapped in their own exception type. Each exception
    is visited at most once, so cyclic chains terminate.
    """
    seen: set[int] = set()
    pending: list[BaseException] = [exc]
    while pending:
        current = pending.pop()
        if id(current) in seen:
            continue
        seen.add(id(current))

        msg = str(current).lower()
        if any(marker in msg for marker in _CUDA_ERROR_MARKERS):
            return True

        for linked in (current.__cause__, current.__context__):
            if linked is not None:
                pending.append(linked)
    return False
|
||
|
||
|
||
def cuda_memory_summary() -> str:
    """Return a one-line, human-readable GPU memory usage report (for debugging).

    "Used" is computed from ``torch.cuda.mem_get_info()`` as total minus free,
    shown in GiB with two decimals. If CUDA is unavailable, a fixed
    "not available" message is returned. Any exception, including a missing
    torch install, is folded into the returned string instead of propagating.
    """
    bytes_per_gib = 1024**3
    try:
        import torch

        if torch.cuda.is_available():
            free_bytes, total_bytes = torch.cuda.mem_get_info()
            used_gib = (total_bytes - free_bytes) / bytes_per_gib
            total_gib = total_bytes / bytes_per_gib
            return f"GPU 显存: 已用 {used_gib:.2f}GB / {total_gib:.2f}GB"
        return "CUDA 不可用"
    except Exception as err:
        return f"显存查询失败: {err}"
|