Fix ChatTTS load: pre-download via HF mirror, avoid GitHub timeout.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -7,3 +7,7 @@ OLLAMA_PORT=11434
|
|||||||
|
|
||||||
# 可选:覆盖默认模型名
|
# 可选:覆盖默认模型名
|
||||||
# MODEL_NAME=huihui_ai/gemma-4-abliterated:e4b
|
# MODEL_NAME=huihui_ai/gemma-4-abliterated:e4b
|
||||||
|
|
||||||
|
# ChatTTS 模型目录(预下载脚本写入)
|
||||||
|
# CHATTTS_MODEL_DIR=/opt/Trading_Studio/models/ChatTTS
|
||||||
|
# HF_ENDPOINT=https://hf-mirror.com
|
||||||
|
|||||||
+2
-2
@@ -20,8 +20,8 @@ env/
|
|||||||
# 日志
|
# 日志
|
||||||
*.log
|
*.log
|
||||||
|
|
||||||
# 运行时目录
|
models/
|
||||||
uploads/
|
hf_cache/
|
||||||
outputs/
|
outputs/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
|
|||||||
@@ -365,20 +365,35 @@ pip install -r requirements.txt
|
|||||||
|
|
||||||
随 `requirements.txt` 安装。首次运行会自动下载 `small` 模型(约 500MB)至 HuggingFace 缓存。
|
随 `requirements.txt` 安装。首次运行会自动下载 `small` 模型(约 500MB)至 HuggingFace 缓存。
|
||||||
|
|
||||||
### 6.2 ChatTTS
|
### 6.2 ChatTTS(必须预下载,勿依赖 GitHub)
|
||||||
|
|
||||||
从 GitHub 源码安装(已在 requirements.txt 中指定):
|
从 GitHub 源码安装 pip 包(已在 requirements.txt 中指定):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install ChatTTS @ git+https://github.com/2noise/ChatTTS.git
|
pip install ChatTTS @ git+https://github.com/2noise/ChatTTS.git
|
||||||
```
|
```
|
||||||
|
|
||||||
首次 `save_fixed_speaker` 或 `generate_voice` 时会下载模型权重(数 GB),请确保网络畅通或提前配置 HuggingFace 镜像:
|
**重要:** 默认 `chat.load()` 会访问 **github.com** 下载 asset,国内服务器常报 `Read timed out (3)`。
|
||||||
|
部署后**必须**执行预下载脚本(走 HuggingFace 镜像):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export HF_ENDPOINT=https://hf-mirror.com # 可选,国内加速
|
cd /opt/Trading_Studio
|
||||||
|
source venv/bin/activate
|
||||||
|
bash scripts/download_chattts_models.sh
|
||||||
|
pm2 restart trading_studio
|
||||||
```
|
```
|
||||||
|
|
||||||
|
模型保存至 `/opt/Trading_Studio/models/ChatTTS`(约 1–2GB,不入 Git)。
|
||||||
|
|
||||||
|
`.env` 可自定义:
|
||||||
|
|
||||||
|
```ini
|
||||||
|
HF_ENDPOINT=https://hf-mirror.com
|
||||||
|
CHATTTS_MODEL_DIR=/opt/Trading_Studio/models/ChatTTS
|
||||||
|
```
|
||||||
|
|
||||||
|
下载完成后再在 Web UI 点击「锁定音色」。
|
||||||
|
|
||||||
### 6.3 Gradio
|
### 6.3 Gradio
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -103,6 +103,14 @@ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|||||||
# ChatTTS 采样率(Hz)
|
# ChatTTS 采样率(Hz)
|
||||||
TTS_SAMPLE_RATE = 24000
|
TTS_SAMPLE_RATE = 24000
|
||||||
|
|
||||||
|
# ChatTTS 模型本地目录(预下载后离线加载,避免访问 GitHub 超时)
|
||||||
|
CHATTTS_MODEL_DIR = Path(_env_str("CHATTTS_MODEL_DIR", str(BASE_DIR / "models" / "ChatTTS")))
|
||||||
|
|
||||||
|
# HuggingFace 镜像(国内服务器推荐)
|
||||||
|
HF_ENDPOINT = _env_str("HF_ENDPOINT", "https://hf-mirror.com")
|
||||||
|
HF_HUB_DOWNLOAD_TIMEOUT = _env_int("HF_HUB_DOWNLOAD_TIMEOUT", 600)
|
||||||
|
HF_HOME = Path(_env_str("HF_HOME", str(BASE_DIR / "models" / "hf_cache")))
|
||||||
|
|
||||||
# 音色样本时长建议(秒)
|
# 音色样本时长建议(秒)
|
||||||
SPEAKER_SAMPLE_MIN_SEC = 10
|
SPEAKER_SAMPLE_MIN_SEC = 10
|
||||||
SPEAKER_SAMPLE_MAX_SEC = 30
|
SPEAKER_SAMPLE_MAX_SEC = 30
|
||||||
|
|||||||
@@ -21,5 +21,6 @@ librosa>=0.10.0
|
|||||||
|
|
||||||
# 音频处理辅助
|
# 音频处理辅助
|
||||||
soundfile>=0.12.0
|
soundfile>=0.12.0
|
||||||
|
huggingface_hub>=0.20.0
|
||||||
|
|
||||||
# PM2 通过 Node.js 全局安装,不在 pip 范围内
|
# PM2 通过 Node.js 全局安装,不在 pip 范围内
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# 预下载 ChatTTS 模型到本地(走 HuggingFace 镜像,不依赖 GitHub)
|
||||||
|
# 用法: bash scripts/download_chattts_models.sh
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
ROOT="$(cd "$(dirname "$0")/.." && pwd)"
|
||||||
|
MODEL_DIR="${CHATTTS_MODEL_DIR:-${ROOT}/models/ChatTTS}"
|
||||||
|
export MODEL_DIR
|
||||||
|
VENV_PY="${ROOT}/venv/bin/python"
|
||||||
|
|
||||||
|
export HF_ENDPOINT="${HF_ENDPOINT:-https://hf-mirror.com}"
|
||||||
|
export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-600}"
|
||||||
|
HF_HOME="${HF_HOME:-${ROOT}/models/hf_cache}"
|
||||||
|
|
||||||
|
echo "[INFO] ChatTTS 模型目录: ${MODEL_DIR}"
|
||||||
|
echo "[INFO] HF 镜像: ${HF_ENDPOINT}"
|
||||||
|
mkdir -p "${MODEL_DIR}" "${HF_HOME}"
|
||||||
|
|
||||||
|
if [[ ! -x "${VENV_PY}" ]]; then
|
||||||
|
echo "[ERROR] 未找到 venv,请先 bash deploy.sh deps"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
"${VENV_PY}" -m pip install -q huggingface_hub
|
||||||
|
|
||||||
|
"${VENV_PY}" << PY
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
target = Path(os.environ["MODEL_DIR"])
|
||||||
|
target.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print("[INFO] 正在从 HuggingFace 镜像下载 2Noise/ChatTTS(约 1-2GB)...")
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="2Noise/ChatTTS",
|
||||||
|
local_dir=str(target),
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
)
|
||||||
|
print("[OK] 下载完成:", target)
|
||||||
|
|
||||||
|
# 简单校验
|
||||||
|
checks = [
|
||||||
|
target / "asset",
|
||||||
|
target / "config" / "path.yaml",
|
||||||
|
]
|
||||||
|
ok = any(p.exists() for p in checks)
|
||||||
|
if not ok:
|
||||||
|
print("[WARN] 目录结构异常,请检查下载是否完整")
|
||||||
|
else:
|
||||||
|
print("[OK] 模型目录结构校验通过")
|
||||||
|
PY
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "[OK] 请执行: pm2 restart trading_studio"
|
||||||
|
echo " 然后重新点击「锁定音色」"
|
||||||
+397
-305
@@ -1,305 +1,397 @@
|
|||||||
"""
|
"""
|
||||||
ChatTTS 本地语音合成服务
|
ChatTTS 本地语音合成服务
|
||||||
支持从参考人声提取 Speaker Embedding 并固定音色合成配音。
|
支持从参考人声提取 Speaker Embedding 并固定音色合成配音。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import inspect
|
||||||
import traceback
|
import logging
|
||||||
import uuid
|
import os
|
||||||
from datetime import datetime
|
import traceback
|
||||||
from pathlib import Path
|
import uuid
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
import numpy as np
|
from typing import Any, Dict, Optional, Tuple
|
||||||
import torch
|
|
||||||
from scipy.io import wavfile
|
import numpy as np
|
||||||
|
import torch
|
||||||
from config import (
|
from scipy.io import wavfile
|
||||||
OUTPUT_DIR,
|
|
||||||
SPEAKER_EMB_PATH,
|
from config import (
|
||||||
SPEAKER_SAMPLE_MAX_SEC,
|
BASE_DIR,
|
||||||
SPEAKER_SAMPLE_MIN_SEC,
|
CHATTTS_MODEL_DIR,
|
||||||
TTS_SAMPLE_RATE,
|
HF_ENDPOINT,
|
||||||
TTS_SPEED_PROMPT,
|
HF_HOME,
|
||||||
TTS_TEMPERATURE,
|
HF_HUB_DOWNLOAD_TIMEOUT,
|
||||||
TTS_TOP_K,
|
OUTPUT_DIR,
|
||||||
TTS_TOP_P,
|
SPEAKER_EMB_PATH,
|
||||||
)
|
SPEAKER_SAMPLE_MAX_SEC,
|
||||||
|
SPEAKER_SAMPLE_MIN_SEC,
|
||||||
logger = logging.getLogger(__name__)
|
TTS_SAMPLE_RATE,
|
||||||
|
TTS_SPEED_PROMPT,
|
||||||
# 全局 ChatTTS 实例
|
TTS_TEMPERATURE,
|
||||||
_chat = None
|
TTS_TOP_K,
|
||||||
_chat_error: Optional[str] = None
|
TTS_TOP_P,
|
||||||
|
)
|
||||||
|
|
||||||
def _load_audio_for_chattts(audio_path: str, sample_rate: int = TTS_SAMPLE_RATE) -> np.ndarray:
|
logger = logging.getLogger(__name__)
|
||||||
"""
|
|
||||||
加载音频并重采样到 ChatTTS 所需采样率。
|
# 全局 ChatTTS 实例
|
||||||
优先使用 ChatTTS 自带工具,回退到 librosa。
|
_chat = None
|
||||||
"""
|
_chat_error: Optional[str] = None
|
||||||
try:
|
|
||||||
from ChatTTS.utils import load_audio
|
|
||||||
|
def _ensure_hf_env() -> None:
|
||||||
return load_audio(audio_path, sample_rate)
|
"""配置 HuggingFace 镜像与下载超时,避免默认 3s 访问 GitHub 超时。"""
|
||||||
except ImportError:
|
os.environ.setdefault("HF_ENDPOINT", HF_ENDPOINT)
|
||||||
pass
|
os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", str(HF_HUB_DOWNLOAD_TIMEOUT))
|
||||||
|
os.environ.setdefault("HF_HOME", str(HF_HOME))
|
||||||
try:
|
HF_HOME.mkdir(parents=True, exist_ok=True)
|
||||||
from tools.audio import load_audio
|
CHATTTS_MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
return load_audio(audio_path, sample_rate)
|
|
||||||
except ImportError:
|
def _chattts_model_ready(model_dir: Path) -> bool:
|
||||||
pass
|
"""检查本地 ChatTTS 模型目录是否完整。"""
|
||||||
|
if not model_dir.is_dir():
|
||||||
import librosa
|
return False
|
||||||
|
if (model_dir / "config" / "path.yaml").is_file():
|
||||||
audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
|
return True
|
||||||
return audio
|
asset_dir = model_dir / "asset"
|
||||||
|
if asset_dir.is_dir() and any(asset_dir.rglob("*.pt")):
|
||||||
|
return True
|
||||||
def _get_audio_duration_sec(audio: np.ndarray, sample_rate: int) -> float:
|
if any(model_dir.glob("*.pt")):
|
||||||
"""计算音频时长(秒)。"""
|
return True
|
||||||
if audio is None or len(audio) == 0:
|
return False
|
||||||
return 0.0
|
|
||||||
return len(audio) / float(sample_rate)
|
|
||||||
|
def _build_load_error(exc: BaseException) -> str:
|
||||||
|
"""生成用户可读的 ChatTTS 加载失败说明。"""
|
||||||
def get_chattts_instance():
|
msg = str(exc)
|
||||||
"""
|
hints = [
|
||||||
获取或初始化 ChatTTS 模型。
|
"ChatTTS 模型加载失败。",
|
||||||
启用 GPU 加速,compile=False 以兼容 3060 Ti 8GB 显存。
|
f"详情: {msg}",
|
||||||
"""
|
"",
|
||||||
global _chat, _chat_error
|
"常见原因:服务器无法访问 GitHub(read timeout=3)。",
|
||||||
|
"解决办法(在服务器执行一次):",
|
||||||
if _chat is not None:
|
f" cd {BASE_DIR}",
|
||||||
return _chat, None
|
" bash scripts/download_chattts_models.sh",
|
||||||
|
" pm2 restart trading_studio",
|
||||||
if _chat_error is not None:
|
"",
|
||||||
return None, _chat_error
|
f"模型将下载到: {CHATTTS_MODEL_DIR}",
|
||||||
|
f"HF 镜像: {HF_ENDPOINT}",
|
||||||
try:
|
]
|
||||||
import ChatTTS
|
return "\n".join(hints)
|
||||||
|
|
||||||
logger.info("正在加载 ChatTTS 模型...")
|
|
||||||
chat = ChatTTS.Chat()
|
def _load_chat_model(chat) -> None:
|
||||||
|
"""按优先级加载 ChatTTS:本地 custom → 镜像下载到 cache_dir。"""
|
||||||
# 兼容不同版本 API:load_models(旧版)/ load(新版)
|
_ensure_hf_env()
|
||||||
if hasattr(chat, "load_models"):
|
model_dir = CHATTTS_MODEL_DIR
|
||||||
chat.load_models(compile=False)
|
|
||||||
elif hasattr(chat, "load"):
|
base_kwargs: Dict[str, Any] = {"compile": False}
|
||||||
chat.load(compile=False)
|
|
||||||
else:
|
if not hasattr(chat, "load"):
|
||||||
raise RuntimeError("当前 ChatTTS 版本缺少 load / load_models 方法。")
|
if hasattr(chat, "load_models"):
|
||||||
|
chat.load_models(**base_kwargs)
|
||||||
_chat = chat
|
return
|
||||||
logger.info("ChatTTS 模型加载成功。")
|
raise RuntimeError("当前 ChatTTS 版本缺少 load / load_models 方法。")
|
||||||
return _chat, None
|
|
||||||
|
sig = inspect.signature(chat.load)
|
||||||
except ImportError as exc:
|
params = sig.parameters
|
||||||
_chat_error = (
|
|
||||||
"未安装 ChatTTS,请参考 DEPLOY.md 安装。\n"
|
# 1) 本地已预下载 → 完全离线,不访问 GitHub
|
||||||
f"原始错误: {exc}"
|
if _chattts_model_ready(model_dir):
|
||||||
)
|
logger.info("ChatTTS 从本地目录加载 (source=custom): %s", model_dir)
|
||||||
logger.exception("ChatTTS 导入失败")
|
kwargs = dict(base_kwargs)
|
||||||
return None, _chat_error
|
if "source" in params:
|
||||||
|
kwargs["source"] = "custom"
|
||||||
except Exception as exc:
|
if "custom_path" in params:
|
||||||
_chat_error = f"ChatTTS 模型加载失败: {exc}\n{traceback.format_exc()}"
|
kwargs["custom_path"] = str(model_dir)
|
||||||
logger.exception("ChatTTS 初始化异常")
|
result = chat.load(**kwargs)
|
||||||
return None, _chat_error
|
if result is False:
|
||||||
|
raise RuntimeError(f"ChatTTS 本地加载失败,请检查 {model_dir}")
|
||||||
|
return
|
||||||
def _encode_spk_emb(chat, tensor_or_str: Any) -> str:
|
|
||||||
"""将 Speaker Embedding 编码为 ChatTTS 可用的字符串格式。"""
|
# 2) 未预下载 → 通过 HF 镜像下载到指定目录(仍可能尝试网络)
|
||||||
if isinstance(tensor_or_str, str):
|
logger.warning(
|
||||||
return tensor_or_str
|
"未找到本地 ChatTTS 模型 (%s),尝试通过 HF 镜像下载…",
|
||||||
|
model_dir,
|
||||||
if hasattr(chat, "_encode_spk_emb"):
|
)
|
||||||
return chat._encode_spk_emb(tensor_or_str)
|
kwargs = dict(base_kwargs)
|
||||||
|
if "source" in params:
|
||||||
# 兜底:直接转字符串(部分版本可接受 tensor)
|
kwargs["source"] = "local"
|
||||||
return tensor_or_str
|
if "cache_dir" in params:
|
||||||
|
kwargs["cache_dir"] = str(model_dir)
|
||||||
|
elif "source" in params:
|
||||||
def save_fixed_speaker(
|
kwargs["source"] = "huggingface"
|
||||||
audio_sample_path: str,
|
|
||||||
sample_transcript: str = "",
|
result = chat.load(**kwargs)
|
||||||
) -> Tuple[bool, str]:
|
if result is False:
|
||||||
"""
|
raise RuntimeError(
|
||||||
从 10-30 秒干净人声中提取 Speaker Embedding 并序列化保存。
|
"ChatTTS 在线下载失败。请执行: bash scripts/download_chattts_models.sh"
|
||||||
|
)
|
||||||
Args:
|
|
||||||
audio_sample_path: 参考人声 wav/mp3 等路径
|
|
||||||
sample_transcript: 参考音频的精确转写(可选,有助于 zero-shot 音色还原)
|
def reset_chattts_instance() -> None:
|
||||||
|
"""释放 ChatTTS 实例(模型下载后重启前可调用)。"""
|
||||||
Returns:
|
global _chat, _chat_error
|
||||||
(success, message)
|
_chat = None
|
||||||
"""
|
_chat_error = None
|
||||||
if not audio_sample_path:
|
|
||||||
return False, "未提供音色参考音频。"
|
|
||||||
|
def get_chattts_instance():
|
||||||
chat, init_err = get_chattts_instance()
|
"""
|
||||||
if chat is None:
|
获取或初始化 ChatTTS 模型。
|
||||||
return False, init_err or "ChatTTS 不可用。"
|
启用 GPU 加速,compile=False 以兼容 3060 Ti 8GB 显存。
|
||||||
|
"""
|
||||||
try:
|
global _chat, _chat_error
|
||||||
audio = _load_audio_for_chattts(audio_sample_path, TTS_SAMPLE_RATE)
|
|
||||||
duration = _get_audio_duration_sec(audio, TTS_SAMPLE_RATE)
|
if _chat is not None:
|
||||||
|
return _chat, None
|
||||||
if duration < SPEAKER_SAMPLE_MIN_SEC:
|
|
||||||
return False, (
|
if _chat_error is not None:
|
||||||
f"参考音频过短({duration:.1f}s),建议 {SPEAKER_SAMPLE_MIN_SEC}-"
|
return None, _chat_error
|
||||||
f"{SPEAKER_SAMPLE_MAX_SEC} 秒干净人声。"
|
|
||||||
)
|
try:
|
||||||
if duration > SPEAKER_SAMPLE_MAX_SEC + 5:
|
_ensure_hf_env()
|
||||||
logger.warning("参考音频超过建议时长 %.1fs,将截取前 %ds", duration, SPEAKER_SAMPLE_MAX_SEC)
|
import ChatTTS
|
||||||
max_samples = SPEAKER_SAMPLE_MAX_SEC * TTS_SAMPLE_RATE
|
|
||||||
audio = audio[:max_samples]
|
logger.info("正在加载 ChatTTS 模型...")
|
||||||
|
chat = ChatTTS.Chat()
|
||||||
# 从参考音频提取音色特征
|
_load_chat_model(chat)
|
||||||
spk_smp = chat.sample_audio_speaker(audio)
|
|
||||||
|
_chat = chat
|
||||||
# 同时保存编码后的 spk_emb 字符串,便于 infer 时直接使用
|
logger.info("ChatTTS 模型加载成功。")
|
||||||
spk_emb = _encode_spk_emb(chat, spk_smp)
|
return _chat, None
|
||||||
|
|
||||||
payload: Dict[str, Any] = {
|
except ImportError as exc:
|
||||||
"spk_emb": spk_emb,
|
_chat_error = (
|
||||||
"spk_smp": spk_smp,
|
"未安装 ChatTTS,请参考 DEPLOY.md 安装。\n"
|
||||||
"txt_smp": sample_transcript.strip(),
|
f"原始错误: {exc}"
|
||||||
"created_at": datetime.now().isoformat(),
|
)
|
||||||
"source_audio": str(audio_sample_path),
|
logger.exception("ChatTTS 导入失败")
|
||||||
}
|
return None, _chat_error
|
||||||
|
|
||||||
torch.save(payload, SPEAKER_EMB_PATH)
|
except Exception as exc:
|
||||||
|
_chat_error = _build_load_error(exc)
|
||||||
msg = (
|
logger.exception("ChatTTS 初始化异常")
|
||||||
f"音色已锁定并保存至 {SPEAKER_EMB_PATH}\n"
|
return None, _chat_error
|
||||||
f"参考时长: {duration:.1f}s"
|
|
||||||
)
|
|
||||||
if not sample_transcript.strip():
|
def _load_audio_for_chattts(audio_path: str, sample_rate: int = TTS_SAMPLE_RATE) -> np.ndarray:
|
||||||
msg += "\n提示:填写参考音频精确转写可进一步提升音色还原度。"
|
"""
|
||||||
|
加载音频并重采样到 ChatTTS 所需采样率。
|
||||||
logger.info("Speaker Embedding 保存成功: %s", SPEAKER_EMB_PATH)
|
优先使用 ChatTTS 自带工具,回退到 librosa。
|
||||||
return True, msg
|
"""
|
||||||
|
try:
|
||||||
except Exception as exc:
|
from ChatTTS.utils import load_audio
|
||||||
err = f"音色提取失败: {exc}\n{traceback.format_exc()}"
|
|
||||||
logger.exception("save_fixed_speaker 失败")
|
return load_audio(audio_path, sample_rate)
|
||||||
return False, err
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
def _load_speaker_payload() -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
|
try:
|
||||||
"""加载本地 speaker_emb.pt。"""
|
from tools.audio import load_audio
|
||||||
if not SPEAKER_EMB_PATH.exists():
|
|
||||||
return None, (
|
return load_audio(audio_path, sample_rate)
|
||||||
f"未找到固定音色文件 `{SPEAKER_EMB_PATH.name}`。"
|
except ImportError:
|
||||||
"请先在【音色锁定】模块上传 10-30 秒参考人声。"
|
pass
|
||||||
)
|
|
||||||
|
import librosa
|
||||||
try:
|
|
||||||
payload = torch.load(SPEAKER_EMB_PATH, map_location="cpu", weights_only=False)
|
audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
|
||||||
|
return audio
|
||||||
# 兼容旧版仅保存 tensor 的文件
|
|
||||||
if isinstance(payload, torch.Tensor):
|
|
||||||
chat, err = get_chattts_instance()
|
def _get_audio_duration_sec(audio: np.ndarray, sample_rate: int) -> float:
|
||||||
if chat is None:
|
"""计算音频时长(秒)。"""
|
||||||
return None, err
|
if audio is None or len(audio) == 0:
|
||||||
return {
|
return 0.0
|
||||||
"spk_emb": _encode_spk_emb(chat, payload),
|
return len(audio) / float(sample_rate)
|
||||||
"spk_smp": None,
|
|
||||||
"txt_smp": "",
|
|
||||||
}, None
|
def _encode_spk_emb(chat, tensor_or_str: Any) -> str:
|
||||||
|
"""将 Speaker Embedding 编码为 ChatTTS 可用的字符串格式。"""
|
||||||
if not isinstance(payload, dict):
|
if isinstance(tensor_or_str, str):
|
||||||
return None, "speaker_emb.pt 格式无效,请重新锁定音色。"
|
return tensor_or_str
|
||||||
|
|
||||||
return payload, None
|
if hasattr(chat, "_encode_spk_emb"):
|
||||||
|
return chat._encode_spk_emb(tensor_or_str)
|
||||||
except Exception as exc:
|
|
||||||
return None, f"读取 speaker_emb.pt 失败: {exc}"
|
return tensor_or_str
|
||||||
|
|
||||||
|
|
||||||
def speaker_is_ready() -> Tuple[bool, str]:
|
def save_fixed_speaker(
|
||||||
"""检查固定音色是否已配置。"""
|
audio_sample_path: str,
|
||||||
payload, err = _load_speaker_payload()
|
sample_transcript: str = "",
|
||||||
if payload is None:
|
) -> Tuple[bool, str]:
|
||||||
return False, err or "音色未配置。"
|
"""
|
||||||
return True, f"已加载固定音色: {SPEAKER_EMB_PATH}"
|
从 10-30 秒干净人声中提取 Speaker Embedding 并序列化保存。
|
||||||
|
|
||||||
|
Args:
|
||||||
def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
audio_sample_path: 参考人声 wav/mp3 等路径
|
||||||
"""
|
sample_transcript: 参考音频的精确转写(可选,有助于 zero-shot 音色还原)
|
||||||
使用 ChatTTS 将润色后的文稿合成为 wav 配音。
|
|
||||||
|
Returns:
|
||||||
Args:
|
(success, message)
|
||||||
refined_text: LLM 润色后的配音稿
|
"""
|
||||||
|
if not audio_sample_path:
|
||||||
Returns:
|
return False, "未提供音色参考音频。"
|
||||||
(success, message, output_wav_path_or_none)
|
|
||||||
"""
|
chat, init_err = get_chattts_instance()
|
||||||
if not refined_text or not refined_text.strip():
|
if chat is None:
|
||||||
return False, "合成文本为空,请先完成润色。", None
|
return False, init_err or "ChatTTS 不可用。"
|
||||||
|
|
||||||
chat, init_err = get_chattts_instance()
|
try:
|
||||||
if chat is None:
|
audio = _load_audio_for_chattts(audio_sample_path, TTS_SAMPLE_RATE)
|
||||||
return False, init_err or "ChatTTS 不可用。", None
|
duration = _get_audio_duration_sec(audio, TTS_SAMPLE_RATE)
|
||||||
|
|
||||||
payload, spk_err = _load_speaker_payload()
|
if duration < SPEAKER_SAMPLE_MIN_SEC:
|
||||||
if payload is None:
|
return False, (
|
||||||
return False, spk_err or "请先锁定音色。", None
|
f"参考音频过短({duration:.1f}s),建议 {SPEAKER_SAMPLE_MIN_SEC}-"
|
||||||
|
f"{SPEAKER_SAMPLE_MAX_SEC} 秒干净人声。"
|
||||||
try:
|
)
|
||||||
import ChatTTS
|
if duration > SPEAKER_SAMPLE_MAX_SEC + 5:
|
||||||
|
logger.warning("参考音频超过建议时长 %.1fs,将截取前 %ds", duration, SPEAKER_SAMPLE_MAX_SEC)
|
||||||
spk_emb = payload.get("spk_emb")
|
max_samples = SPEAKER_SAMPLE_MAX_SEC * TTS_SAMPLE_RATE
|
||||||
spk_smp = payload.get("spk_smp")
|
audio = audio[:max_samples]
|
||||||
txt_smp = payload.get("txt_smp", "")
|
|
||||||
|
spk_smp = chat.sample_audio_speaker(audio)
|
||||||
params_infer_code = ChatTTS.Chat.InferCodeParams(
|
spk_emb = _encode_spk_emb(chat, spk_smp)
|
||||||
prompt=TTS_SPEED_PROMPT,
|
|
||||||
spk_emb=spk_emb,
|
payload: Dict[str, Any] = {
|
||||||
spk_smp=spk_smp if spk_smp else None,
|
"spk_emb": spk_emb,
|
||||||
txt_smp=txt_smp if txt_smp else None,
|
"spk_smp": spk_smp,
|
||||||
temperature=TTS_TEMPERATURE,
|
"txt_smp": sample_transcript.strip(),
|
||||||
top_P=TTS_TOP_P,
|
"created_at": datetime.now().isoformat(),
|
||||||
top_K=TTS_TOP_K,
|
"source_audio": str(audio_sample_path),
|
||||||
)
|
}
|
||||||
|
|
||||||
# 内向克制语气:降低 oral 强度
|
torch.save(payload, SPEAKER_EMB_PATH)
|
||||||
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
|
||||||
prompt="[oral_2][laugh_0][break_4]",
|
msg = (
|
||||||
)
|
f"音色已锁定并保存至 {SPEAKER_EMB_PATH}\n"
|
||||||
|
f"参考时长: {duration:.1f}s"
|
||||||
wavs = chat.infer(
|
)
|
||||||
refined_text.strip(),
|
if not sample_transcript.strip():
|
||||||
skip_refine_text=False,
|
msg += "\n提示:填写参考音频精确转写可进一步提升音色还原度。"
|
||||||
params_refine_text=params_refine_text,
|
|
||||||
params_infer_code=params_infer_code,
|
logger.info("Speaker Embedding 保存成功: %s", SPEAKER_EMB_PATH)
|
||||||
)
|
return True, msg
|
||||||
|
|
||||||
if not wavs or len(wavs) == 0:
|
except Exception as exc:
|
||||||
return False, "ChatTTS 未生成有效音频。", None
|
err = f"音色提取失败: {exc}\n{traceback.format_exc()}"
|
||||||
|
logger.exception("save_fixed_speaker 失败")
|
||||||
wav_array = np.asarray(wavs[0], dtype=np.float32)
|
return False, err
|
||||||
|
|
||||||
# 归一化并转 int16
|
|
||||||
peak = np.max(np.abs(wav_array)) or 1.0
|
def _load_speaker_payload() -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||||
wav_int16 = (wav_array / peak * 32767).astype(np.int16)
|
"""加载本地 speaker_emb.pt。"""
|
||||||
|
if not SPEAKER_EMB_PATH.exists():
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
return None, (
|
||||||
filename = f"voiceover_{timestamp}_{uuid.uuid4().hex[:6]}.wav"
|
f"未找到固定音色文件 `{SPEAKER_EMB_PATH.name}`。"
|
||||||
output_path = OUTPUT_DIR / filename
|
"请先在【音色锁定】模块上传 10-30 秒参考人声。"
|
||||||
|
)
|
||||||
wavfile.write(str(output_path), TTS_SAMPLE_RATE, wav_int16)
|
|
||||||
|
try:
|
||||||
msg = f"配音合成成功: {output_path}"
|
payload = torch.load(SPEAKER_EMB_PATH, map_location="cpu", weights_only=False)
|
||||||
logger.info(msg)
|
|
||||||
return True, msg, str(output_path)
|
if isinstance(payload, torch.Tensor):
|
||||||
|
chat, err = get_chattts_instance()
|
||||||
except Exception as exc:
|
if chat is None:
|
||||||
err = f"语音合成失败: {exc}\n{traceback.format_exc()}"
|
return None, err
|
||||||
logger.exception("generate_voice 失败")
|
return {
|
||||||
return False, err, None
|
"spk_emb": _encode_spk_emb(chat, payload),
|
||||||
|
"spk_smp": None,
|
||||||
|
"txt_smp": "",
|
||||||
|
}, None
|
||||||
|
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return None, "speaker_emb.pt 格式无效,请重新锁定音色。"
|
||||||
|
|
||||||
|
return payload, None
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
return None, f"读取 speaker_emb.pt 失败: {exc}"
|
||||||
|
|
||||||
|
|
||||||
|
def speaker_is_ready() -> Tuple[bool, str]:
|
||||||
|
"""检查固定音色是否已配置。"""
|
||||||
|
payload, err = _load_speaker_payload()
|
||||||
|
if payload is None:
|
||||||
|
return False, err or "音色未配置。"
|
||||||
|
return True, f"已加载固定音色: {SPEAKER_EMB_PATH}"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||||
|
"""
|
||||||
|
使用 ChatTTS 将润色后的文稿合成为 wav 配音。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
refined_text: LLM 润色后的配音稿
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(success, message, output_wav_path_or_none)
|
||||||
|
"""
|
||||||
|
if not refined_text or not refined_text.strip():
|
||||||
|
return False, "合成文本为空,请先完成润色。", None
|
||||||
|
|
||||||
|
chat, init_err = get_chattts_instance()
|
||||||
|
if chat is None:
|
||||||
|
return False, init_err or "ChatTTS 不可用。", None
|
||||||
|
|
||||||
|
payload, spk_err = _load_speaker_payload()
|
||||||
|
if payload is None:
|
||||||
|
return False, spk_err or "请先锁定音色。", None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ChatTTS
|
||||||
|
|
||||||
|
spk_emb = payload.get("spk_emb")
|
||||||
|
spk_smp = payload.get("spk_smp")
|
||||||
|
txt_smp = payload.get("txt_smp", "")
|
||||||
|
|
||||||
|
params_infer_code = ChatTTS.Chat.InferCodeParams(
|
||||||
|
prompt=TTS_SPEED_PROMPT,
|
||||||
|
spk_emb=spk_emb,
|
||||||
|
spk_smp=spk_smp if spk_smp else None,
|
||||||
|
txt_smp=txt_smp if txt_smp else None,
|
||||||
|
temperature=TTS_TEMPERATURE,
|
||||||
|
top_P=TTS_TOP_P,
|
||||||
|
top_K=TTS_TOP_K,
|
||||||
|
)
|
||||||
|
|
||||||
|
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
||||||
|
prompt="[oral_2][laugh_0][break_4]",
|
||||||
|
)
|
||||||
|
|
||||||
|
wavs = chat.infer(
|
||||||
|
refined_text.strip(),
|
||||||
|
skip_refine_text=False,
|
||||||
|
params_refine_text=params_refine_text,
|
||||||
|
params_infer_code=params_infer_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not wavs or len(wavs) == 0:
|
||||||
|
return False, "ChatTTS 未生成有效音频。", None
|
||||||
|
|
||||||
|
wav_array = np.asarray(wavs[0], dtype=np.float32)
|
||||||
|
|
||||||
|
peak = np.max(np.abs(wav_array)) or 1.0
|
||||||
|
wav_int16 = (wav_array / peak * 32767).astype(np.int16)
|
||||||
|
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
filename = f"voiceover_{timestamp}_{uuid.uuid4().hex[:6]}.wav"
|
||||||
|
output_path = OUTPUT_DIR / filename
|
||||||
|
|
||||||
|
wavfile.write(str(output_path), TTS_SAMPLE_RATE, wav_int16)
|
||||||
|
|
||||||
|
msg = f"配音合成成功: {output_path}"
|
||||||
|
logger.info(msg)
|
||||||
|
return True, msg, str(output_path)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
err = f"语音合成失败: {exc}\n{traceback.format_exc()}"
|
||||||
|
logger.exception("generate_voice 失败")
|
||||||
|
return False, err, None
|
||||||
|
|||||||
Reference in New Issue
Block a user