Fix ChatTTS load: pre-download via HF mirror, avoid GitHub timeout.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-12 15:16:27 +08:00
parent 1ab1ede1b5
commit aacdffac77
7 changed files with 487 additions and 311 deletions
+4
View File
@@ -7,3 +7,7 @@ OLLAMA_PORT=11434
# 可选:覆盖默认模型名 # 可选:覆盖默认模型名
# MODEL_NAME=huihui_ai/gemma-4-abliterated:e4b # MODEL_NAME=huihui_ai/gemma-4-abliterated:e4b
# ChatTTS 模型目录(预下载脚本写入)
# CHATTTS_MODEL_DIR=/opt/Trading_Studio/models/ChatTTS
# HF_ENDPOINT=https://hf-mirror.com
+2 -2
View File
@@ -20,8 +20,8 @@ env/
# 日志 # 日志
*.log *.log
# 运行时目录 models/
uploads/ hf_cache/
outputs/ outputs/
__pycache__/ __pycache__/
*.py[cod] *.py[cod]
+19 -4
View File
@@ -365,20 +365,35 @@ pip install -r requirements.txt
`requirements.txt` 安装。首次运行会自动下载 `small` 模型(约 500MB)至 HuggingFace 缓存。 `requirements.txt` 安装。首次运行会自动下载 `small` 模型(约 500MB)至 HuggingFace 缓存。
### 6.2 ChatTTS ### 6.2 ChatTTS(必须预下载,勿依赖 GitHub
从 GitHub 源码安装(已在 requirements.txt 中指定): 从 GitHub 源码安装 pip 包(已在 requirements.txt 中指定):
```bash ```bash
pip install ChatTTS @ git+https://github.com/2noise/ChatTTS.git pip install ChatTTS @ git+https://github.com/2noise/ChatTTS.git
``` ```
首次 `save_fixed_speaker``generate_voice` 时会下载模型权重(数 GB),请确保网络畅通或提前配置 HuggingFace 镜像: **重要:** 默认 `chat.load()` 会访问 **github.com** 下载 asset,国内服务器常报 `Read timed out (3)`
部署后**必须**执行预下载脚本(走 HuggingFace 镜像):
```bash ```bash
export HF_ENDPOINT=https://hf-mirror.com # 可选,国内加速 cd /opt/Trading_Studio
source venv/bin/activate
bash scripts/download_chattts_models.sh
pm2 restart trading_studio
``` ```
模型保存至 `/opt/Trading_Studio/models/ChatTTS`(约 1-2GB,不入 Git)。
`.env` 可自定义:
```ini
HF_ENDPOINT=https://hf-mirror.com
CHATTTS_MODEL_DIR=/opt/Trading_Studio/models/ChatTTS
```
下载完成后再在 Web UI 点击「锁定音色」。
### 6.3 Gradio ### 6.3 Gradio
```bash ```bash
+8
View File
@@ -103,6 +103,14 @@ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# ChatTTS 采样率(Hz # ChatTTS 采样率(Hz
TTS_SAMPLE_RATE = 24000 TTS_SAMPLE_RATE = 24000
# Local directory holding the ChatTTS model files. The model is pre-downloaded
# here and then loaded offline, to avoid timeouts when reaching GitHub.
# Overridable through the CHATTTS_MODEL_DIR environment variable.
CHATTTS_MODEL_DIR = Path(_env_str("CHATTTS_MODEL_DIR", str(BASE_DIR / "models" / "ChatTTS")))
# HuggingFace mirror endpoint (recommended for servers in mainland China),
# download timeout, and HuggingFace cache home. Each value can be overridden
# through the environment variable of the same name.
HF_ENDPOINT = _env_str("HF_ENDPOINT", "https://hf-mirror.com")
HF_HUB_DOWNLOAD_TIMEOUT = _env_int("HF_HUB_DOWNLOAD_TIMEOUT", 600)
HF_HOME = Path(_env_str("HF_HOME", str(BASE_DIR / "models" / "hf_cache")))
# 音色样本时长建议(秒) # 音色样本时长建议(秒)
SPEAKER_SAMPLE_MIN_SEC = 10 SPEAKER_SAMPLE_MIN_SEC = 10
SPEAKER_SAMPLE_MAX_SEC = 30 SPEAKER_SAMPLE_MAX_SEC = 30
+1
View File
@@ -21,5 +21,6 @@ librosa>=0.10.0
# 音频处理辅助 # 音频处理辅助
soundfile>=0.12.0 soundfile>=0.12.0
huggingface_hub>=0.20.0
# PM2 通过 Node.js 全局安装,不在 pip 范围内 # PM2 通过 Node.js 全局安装,不在 pip 范围内
+56
View File
@@ -0,0 +1,56 @@
#!/usr/bin/env bash
# Pre-download the ChatTTS model into a local directory through a HuggingFace
# mirror, so that loading the model never depends on reaching GitHub.
# Usage: bash scripts/download_chattts_models.sh
#
# Optional environment overrides:
#   CHATTTS_MODEL_DIR        target directory   (default: <repo>/models/ChatTTS)
#   HF_ENDPOINT              HuggingFace host   (default: https://hf-mirror.com)
#   HF_HUB_DOWNLOAD_TIMEOUT  download timeout   (default: 600)
#   HF_HOME                  HuggingFace cache  (default: <repo>/models/hf_cache)
set -euo pipefail

ROOT="$(cd "$(dirname "$0")/.." && pwd)"
MODEL_DIR="${CHATTTS_MODEL_DIR:-${ROOT}/models/ChatTTS}"
# Exported because the embedded Python below reads it from os.environ.
export MODEL_DIR
VENV_PY="${ROOT}/venv/bin/python"

export HF_ENDPOINT="${HF_ENDPOINT:-https://hf-mirror.com}"
export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-600}"
# Must be exported as well: a plain assignment is invisible to the Python
# child process, which would then fall back to the default cache location.
export HF_HOME="${HF_HOME:-${ROOT}/models/hf_cache}"

echo "[INFO] ChatTTS 模型目录: ${MODEL_DIR}"
echo "[INFO] HF 镜像: ${HF_ENDPOINT}"
mkdir -p "${MODEL_DIR}" "${HF_HOME}"

if [[ ! -x "${VENV_PY}" ]]; then
    echo "[ERROR] 未找到 venv,请先 bash deploy.sh deps"
    exit 1
fi

"${VENV_PY}" -m pip install -q huggingface_hub

# The delimiter is quoted so the shell passes the Python source through
# verbatim (no parameter or command substitution inside the heredoc).
"${VENV_PY}" << 'PY'
import os
from pathlib import Path

from huggingface_hub import snapshot_download

target = Path(os.environ["MODEL_DIR"])
target.mkdir(parents=True, exist_ok=True)

print("[INFO] 正在从 HuggingFace 镜像下载 2Noise/ChatTTS(约 1-2GB)...")
snapshot_download(
    repo_id="2Noise/ChatTTS",
    local_dir=str(target),
    local_dir_use_symlinks=False,
)
print("[OK] 下载完成:", target)

# Lightweight sanity check: at least one of the expected entries must exist.
checks = [
    target / "asset",
    target / "config" / "path.yaml",
]
ok = any(p.exists() for p in checks)
if not ok:
    print("[WARN] 目录结构异常,请检查下载是否完整")
else:
    print("[OK] 模型目录结构校验通过")
PY

echo ""
echo "[OK] 请执行: pm2 restart trading_studio"
echo " 然后重新点击「锁定音色」"
+144 -52
View File
@@ -5,7 +5,9 @@ ChatTTS 本地语音合成服务
from __future__ import annotations from __future__ import annotations
import inspect
import logging import logging
import os
import traceback import traceback
import uuid import uuid
from datetime import datetime from datetime import datetime
@@ -17,6 +19,11 @@ import torch
from scipy.io import wavfile from scipy.io import wavfile
from config import ( from config import (
BASE_DIR,
CHATTTS_MODEL_DIR,
HF_ENDPOINT,
HF_HOME,
HF_HUB_DOWNLOAD_TIMEOUT,
OUTPUT_DIR, OUTPUT_DIR,
SPEAKER_EMB_PATH, SPEAKER_EMB_PATH,
SPEAKER_SAMPLE_MAX_SEC, SPEAKER_SAMPLE_MAX_SEC,
@@ -35,6 +42,143 @@ _chat = None
_chat_error: Optional[str] = None _chat_error: Optional[str] = None
def _ensure_hf_env() -> None:
    """Provide HuggingFace environment defaults and create the model directories.

    ``HF_ENDPOINT``, ``HF_HUB_DOWNLOAD_TIMEOUT`` and ``HF_HOME`` are only set
    when the process environment does not already define them, so values
    exported by the operator always win. Afterwards the HuggingFace cache
    directory and the ChatTTS model directory are created if missing.
    """
    fallback_env = {
        "HF_ENDPOINT": HF_ENDPOINT,
        "HF_HUB_DOWNLOAD_TIMEOUT": str(HF_HUB_DOWNLOAD_TIMEOUT),
        "HF_HOME": str(HF_HOME),
    }
    for variable, value in fallback_env.items():
        os.environ.setdefault(variable, value)

    for directory in (HF_HOME, CHATTTS_MODEL_DIR):
        directory.mkdir(parents=True, exist_ok=True)
def _chattts_model_ready(model_dir: Path) -> bool:
"""检查本地 ChatTTS 模型目录是否完整。"""
if not model_dir.is_dir():
return False
if (model_dir / "config" / "path.yaml").is_file():
return True
asset_dir = model_dir / "asset"
if asset_dir.is_dir() and any(asset_dir.rglob("*.pt")):
return True
if any(model_dir.glob("*.pt")):
return True
return False
def _build_load_error(exc: BaseException) -> str:
    """Build a user-readable explanation for a failed ChatTTS load.

    The message contains the failure detail followed by the recovery steps
    (run the pre-download script, restart the service) and the configured
    model directory and HuggingFace endpoint.

    Args:
        exc: The exception raised while loading the model.

    Returns:
        A multi-line message suitable for showing to the user.
    """
    # Some exceptions carry no message; show the exception type instead so
    # the detail line is never blank.
    detail = str(exc) or type(exc).__name__
    lines = [
        "ChatTTS 模型加载失败。",
        f"详情: {detail}",
        "",
        "常见原因:服务器无法访问 GitHub(read timeout=3)。",
        "解决办法(在服务器执行一次):",
        f" cd {BASE_DIR}",
        " bash scripts/download_chattts_models.sh",
        " pm2 restart trading_studio",
        "",
        f"模型将下载到: {CHATTTS_MODEL_DIR}",
        f"HF 镜像: {HF_ENDPOINT}",
    ]
    return "\n".join(lines)
def _load_chat_model(chat) -> None:
    """Load model weights into ``chat``, preferring the local model directory.

    Strategy, in order:

    1. ``chat`` has no ``load`` method: call the legacy ``load_models`` if it
       exists, otherwise raise.
    2. ``CHATTTS_MODEL_DIR`` passes ``_chattts_model_ready``: call ``load``
       with ``source="custom"`` and ``custom_path`` pointing at it.
    3. Otherwise: let ``load`` fetch the model itself.

    Optional keyword arguments are only passed when the installed ``load``
    signature declares them; ``compile=False`` is always passed.

    Args:
        chat: A ChatTTS ``Chat`` instance.

    Raises:
        RuntimeError: If ``chat`` exposes neither ``load`` nor ``load_models``,
            or if ``load`` returns ``False``.
    """
    _ensure_hf_env()
    local_dir = CHATTTS_MODEL_DIR
    common: Dict[str, Any] = {"compile": False}

    if hasattr(chat, "load"):
        accepted = inspect.signature(chat.load).parameters
    elif hasattr(chat, "load_models"):
        chat.load_models(**common)
        return
    else:
        raise RuntimeError("当前 ChatTTS 版本缺少 load / load_models 方法。")

    # 1) Pre-downloaded model present: load it from disk.
    if _chattts_model_ready(local_dir):
        logger.info("ChatTTS 从本地目录加载 (source=custom): %s", local_dir)
        offline_options = {"source": "custom", "custom_path": str(local_dir)}
        call_args = dict(common)
        call_args.update(
            (key, value) for key, value in offline_options.items() if key in accepted
        )
        if chat.load(**call_args) is False:
            raise RuntimeError(f"ChatTTS 本地加载失败,请检查 {local_dir}")
        return

    # 2) Nothing pre-downloaded: fall back to an online load.
    logger.warning(
        "未找到本地 ChatTTS 模型 (%s),尝试通过 HF 镜像下载…",
        local_dir,
    )
    call_args = dict(common)
    if "cache_dir" in accepted:
        # NOTE(review): confirm what source="local" does in ChatTTS versions
        # that accept cache_dir; it may not go through the HF mirror.
        if "source" in accepted:
            call_args["source"] = "local"
        call_args["cache_dir"] = str(local_dir)
    elif "source" in accepted:
        call_args["source"] = "huggingface"

    if chat.load(**call_args) is False:
        raise RuntimeError(
            "ChatTTS 在线下载失败。请执行: bash scripts/download_chattts_models.sh"
        )
def reset_chattts_instance() -> None:
    """Drop the cached ChatTTS instance and the cached load error.

    Both module-level slots are reset to ``None``. Intended to be called
    after the model has been downloaded, before restarting.
    """
    global _chat, _chat_error
    _chat = _chat_error = None
def get_chattts_instance():
    """
    Return the shared ChatTTS model, loading it on first use.

    The result of the first attempt is cached at module level: a loaded
    model is reused, and a failure message is returned again on later calls
    without retrying (until ``reset_chattts_instance`` clears it).

    Returns:
        ``(chat, None)`` on success, or ``(None, error_message)`` on failure.
    """
    global _chat, _chat_error

    cached_model, cached_error = _chat, _chat_error
    if cached_model is not None:
        return cached_model, None
    if cached_error is not None:
        return None, cached_error

    try:
        # Set the HuggingFace environment before ChatTTS is imported
        # (presumably so import-time configuration sees it; confirm).
        _ensure_hf_env()
        import ChatTTS

        logger.info("正在加载 ChatTTS 模型...")
        instance = ChatTTS.Chat()
        _load_chat_model(instance)
        _chat = instance
        logger.info("ChatTTS 模型加载成功。")
        return _chat, None
    except ImportError as import_err:
        message = (
            "未安装 ChatTTS,请参考 DEPLOY.md 安装。\n"
            f"原始错误: {import_err}"
        )
        _chat_error = message
        logger.exception("ChatTTS 导入失败")
        return None, _chat_error
    except Exception as load_err:
        _chat_error = _build_load_error(load_err)
        logger.exception("ChatTTS 初始化异常")
        return None, _chat_error
def _load_audio_for_chattts(audio_path: str, sample_rate: int = TTS_SAMPLE_RATE) -> np.ndarray: def _load_audio_for_chattts(audio_path: str, sample_rate: int = TTS_SAMPLE_RATE) -> np.ndarray:
""" """
加载音频并重采样到 ChatTTS 所需采样率。 加载音频并重采样到 ChatTTS 所需采样率。
@@ -67,51 +211,6 @@ def _get_audio_duration_sec(audio: np.ndarray, sample_rate: int) -> float:
return len(audio) / float(sample_rate) return len(audio) / float(sample_rate)
def get_chattts_instance():
    """
    Return the shared ChatTTS model, loading it on first use.

    The model is loaded with ``compile=False`` (per the original author's
    note, for compatibility with an 8 GB RTX 3060 Ti). The outcome of the
    first attempt, model or error message, is cached at module level.

    Returns:
        ``(chat, None)`` on success, or ``(None, error_message)`` on failure.
    """
    global _chat, _chat_error

    if _chat is not None:
        return _chat, None
    if _chat_error is not None:
        return None, _chat_error

    try:
        import ChatTTS

        logger.info("正在加载 ChatTTS 模型...")
        model = ChatTTS.Chat()

        # Support both API generations: load_models (older) / load (newer).
        loader_name = next(
            (name for name in ("load_models", "load") if hasattr(model, name)),
            None,
        )
        if loader_name is None:
            raise RuntimeError("当前 ChatTTS 版本缺少 load / load_models 方法。")
        getattr(model, loader_name)(compile=False)

        _chat = model
        logger.info("ChatTTS 模型加载成功。")
        return _chat, None
    except ImportError as import_err:
        _chat_error = (
            "未安装 ChatTTS,请参考 DEPLOY.md 安装。\n"
            f"原始错误: {import_err}"
        )
        logger.exception("ChatTTS 导入失败")
        return None, _chat_error
    except Exception as load_err:
        _chat_error = f"ChatTTS 模型加载失败: {load_err}\n{traceback.format_exc()}"
        logger.exception("ChatTTS 初始化异常")
        return None, _chat_error
def _encode_spk_emb(chat, tensor_or_str: Any) -> str: def _encode_spk_emb(chat, tensor_or_str: Any) -> str:
"""将 Speaker Embedding 编码为 ChatTTS 可用的字符串格式。""" """将 Speaker Embedding 编码为 ChatTTS 可用的字符串格式。"""
if isinstance(tensor_or_str, str): if isinstance(tensor_or_str, str):
@@ -120,7 +219,6 @@ def _encode_spk_emb(chat, tensor_or_str: Any) -> str:
if hasattr(chat, "_encode_spk_emb"): if hasattr(chat, "_encode_spk_emb"):
return chat._encode_spk_emb(tensor_or_str) return chat._encode_spk_emb(tensor_or_str)
# 兜底:直接转字符串(部分版本可接受 tensor)
return tensor_or_str return tensor_or_str
@@ -159,10 +257,7 @@ def save_fixed_speaker(
max_samples = SPEAKER_SAMPLE_MAX_SEC * TTS_SAMPLE_RATE max_samples = SPEAKER_SAMPLE_MAX_SEC * TTS_SAMPLE_RATE
audio = audio[:max_samples] audio = audio[:max_samples]
# 从参考音频提取音色特征
spk_smp = chat.sample_audio_speaker(audio) spk_smp = chat.sample_audio_speaker(audio)
# 同时保存编码后的 spk_emb 字符串,便于 infer 时直接使用
spk_emb = _encode_spk_emb(chat, spk_smp) spk_emb = _encode_spk_emb(chat, spk_smp)
payload: Dict[str, Any] = { payload: Dict[str, Any] = {
@@ -202,7 +297,6 @@ def _load_speaker_payload() -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
try: try:
payload = torch.load(SPEAKER_EMB_PATH, map_location="cpu", weights_only=False) payload = torch.load(SPEAKER_EMB_PATH, map_location="cpu", weights_only=False)
# 兼容旧版仅保存 tensor 的文件
if isinstance(payload, torch.Tensor): if isinstance(payload, torch.Tensor):
chat, err = get_chattts_instance() chat, err = get_chattts_instance()
if chat is None: if chat is None:
@@ -268,7 +362,6 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
top_K=TTS_TOP_K, top_K=TTS_TOP_K,
) )
# 内向克制语气:降低 oral 强度
params_refine_text = ChatTTS.Chat.RefineTextParams( params_refine_text = ChatTTS.Chat.RefineTextParams(
prompt="[oral_2][laugh_0][break_4]", prompt="[oral_2][laugh_0][break_4]",
) )
@@ -285,7 +378,6 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
wav_array = np.asarray(wavs[0], dtype=np.float32) wav_array = np.asarray(wavs[0], dtype=np.float32)
# 归一化并转 int16
peak = np.max(np.abs(wav_array)) or 1.0 peak = np.max(np.abs(wav_array)) or 1.0
wav_int16 = (wav_array / peak * 32767).astype(np.int16) wav_int16 = (wav_array / peak * 32767).astype(np.int16)