From aacdffac778adf88a398cfbc44c6d187cb0db76f Mon Sep 17 00:00:00 2001 From: dekun Date: Fri, 12 Jun 2026 15:16:27 +0800 Subject: [PATCH] Fix ChatTTS load: pre-download via HF mirror, avoid GitHub timeout. Co-authored-by: Cursor --- .env.example | 4 + .gitignore | 4 +- DEPLOY.md | 23 +- config.py | 8 + requirements.txt | 1 + scripts/download_chattts_models.sh | 56 +++ tts_service.py | 702 ++++++++++++++++------------- 7 files changed, 487 insertions(+), 311 deletions(-) create mode 100644 scripts/download_chattts_models.sh diff --git a/.env.example b/.env.example index 33cbe2b..3d93425 100644 --- a/.env.example +++ b/.env.example @@ -7,3 +7,7 @@ OLLAMA_PORT=11434 # 可选:覆盖默认模型名 # MODEL_NAME=huihui_ai/gemma-4-abliterated:e4b + +# ChatTTS 模型目录(预下载脚本写入) +# CHATTTS_MODEL_DIR=/opt/Trading_Studio/models/ChatTTS +# HF_ENDPOINT=https://hf-mirror.com diff --git a/.gitignore b/.gitignore index c0286a4..e3c0e12 100644 --- a/.gitignore +++ b/.gitignore @@ -20,8 +20,8 @@ env/ # 日志 *.log -# 运行时目录 -uploads/ +models/ +hf_cache/ outputs/ __pycache__/ *.py[cod] diff --git a/DEPLOY.md b/DEPLOY.md index 46ca266..658c13c 100644 --- a/DEPLOY.md +++ b/DEPLOY.md @@ -365,20 +365,35 @@ pip install -r requirements.txt 随 `requirements.txt` 安装。首次运行会自动下载 `small` 模型(约 500MB)至 HuggingFace 缓存。 -### 6.2 ChatTTS +### 6.2 ChatTTS(必须预下载,勿依赖 GitHub) -从 GitHub 源码安装(已在 requirements.txt 中指定): +从 GitHub 源码安装 pip 包(已在 requirements.txt 中指定): ```bash pip install ChatTTS @ git+https://github.com/2noise/ChatTTS.git ``` -首次 `save_fixed_speaker` 或 `generate_voice` 时会下载模型权重(数 GB),请确保网络畅通或提前配置 HuggingFace 镜像: +**重要:** 默认 `chat.load()` 会访问 **github.com** 下载 asset,国内服务器常报 `Read timed out (3)`。 +部署后**必须**执行预下载脚本(走 HuggingFace 镜像): ```bash -export HF_ENDPOINT=https://hf-mirror.com # 可选,国内加速 +cd /opt/Trading_Studio +source venv/bin/activate +bash scripts/download_chattts_models.sh +pm2 restart trading_studio ``` +模型保存至 `/opt/Trading_Studio/models/ChatTTS`(约 1–2GB,不入 Git)。 + +`.env` 可自定义: + +```ini 
+HF_ENDPOINT=https://hf-mirror.com +CHATTTS_MODEL_DIR=/opt/Trading_Studio/models/ChatTTS +``` + +下载完成后再在 Web UI 点击「锁定音色」。 + ### 6.3 Gradio ```bash diff --git a/config.py b/config.py index 7e19fac..c0bc924 100644 --- a/config.py +++ b/config.py @@ -103,6 +103,14 @@ OUTPUT_DIR.mkdir(parents=True, exist_ok=True) # ChatTTS 采样率(Hz) TTS_SAMPLE_RATE = 24000 +# Local ChatTTS model directory (pre-downloaded, then loaded offline to avoid GitHub timeouts) +CHATTTS_MODEL_DIR = Path(_env_str("CHATTTS_MODEL_DIR", str(BASE_DIR / "models" / "ChatTTS"))) + +# HuggingFace mirror (recommended for servers in mainland China) +HF_ENDPOINT = _env_str("HF_ENDPOINT", "https://hf-mirror.com") +HF_HUB_DOWNLOAD_TIMEOUT = _env_int("HF_HUB_DOWNLOAD_TIMEOUT", 600) +HF_HOME = Path(_env_str("HF_HOME", str(BASE_DIR / "models" / "hf_cache"))) + # 音色样本时长建议(秒) SPEAKER_SAMPLE_MIN_SEC = 10 SPEAKER_SAMPLE_MAX_SEC = 30 diff --git a/requirements.txt b/requirements.txt index 780a454..db1084a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,5 +21,6 @@ librosa>=0.10.0 # 音频处理辅助 soundfile>=0.12.0 +huggingface_hub>=0.20.0 # PM2 通过 Node.js 全局安装,不在 pip 范围内 diff --git a/scripts/download_chattts_models.sh b/scripts/download_chattts_models.sh new file mode 100644 index 0000000..3d86f72 --- /dev/null +++ b/scripts/download_chattts_models.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash +# Pre-download the ChatTTS model locally (via the HuggingFace mirror, no GitHub access needed) +# Usage: bash scripts/download_chattts_models.sh +set -euo pipefail + +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +MODEL_DIR="${CHATTTS_MODEL_DIR:-${ROOT}/models/ChatTTS}" +export MODEL_DIR +VENV_PY="${ROOT}/venv/bin/python" + +export HF_ENDPOINT="${HF_ENDPOINT:-https://hf-mirror.com}" +export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-600}" +export HF_HOME="${HF_HOME:-${ROOT}/models/hf_cache}" # exported so the Python child process below sees the same cache dir + +echo "[INFO] ChatTTS 模型目录: ${MODEL_DIR}" +echo "[INFO] HF 镜像: ${HF_ENDPOINT}" +mkdir -p "${MODEL_DIR}" "${HF_HOME}" + +if [[ ! 
-x "${VENV_PY}" ]]; then + echo "[ERROR] 未找到 venv,请先 bash deploy.sh deps" + exit 1 +fi + +"${VENV_PY}" -m pip install -q huggingface_hub + +"${VENV_PY}" << PY +import os +from pathlib import Path +from huggingface_hub import snapshot_download + +target = Path(os.environ["MODEL_DIR"]) +target.mkdir(parents=True, exist_ok=True) + +print("[INFO] 正在从 HuggingFace 镜像下载 2Noise/ChatTTS(约 1-2GB)...") +snapshot_download( + repo_id="2Noise/ChatTTS", + local_dir=str(target), + local_dir_use_symlinks=False, +) +print("[OK] 下载完成:", target) + +# 简单校验 +checks = [ + target / "asset", + target / "config" / "path.yaml", +] +ok = any(p.exists() for p in checks) +if not ok: + print("[WARN] 目录结构异常,请检查下载是否完整") +else: + print("[OK] 模型目录结构校验通过") +PY + +echo "" +echo "[OK] 请执行: pm2 restart trading_studio" +echo " 然后重新点击「锁定音色」" diff --git a/tts_service.py b/tts_service.py index ddde3ea..bef1654 100644 --- a/tts_service.py +++ b/tts_service.py @@ -1,305 +1,397 @@ -""" -ChatTTS 本地语音合成服务 -支持从参考人声提取 Speaker Embedding 并固定音色合成配音。 -""" - -from __future__ import annotations - -import logging -import traceback -import uuid -from datetime import datetime -from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -import numpy as np -import torch -from scipy.io import wavfile - -from config import ( - OUTPUT_DIR, - SPEAKER_EMB_PATH, - SPEAKER_SAMPLE_MAX_SEC, - SPEAKER_SAMPLE_MIN_SEC, - TTS_SAMPLE_RATE, - TTS_SPEED_PROMPT, - TTS_TEMPERATURE, - TTS_TOP_K, - TTS_TOP_P, -) - -logger = logging.getLogger(__name__) - -# 全局 ChatTTS 实例 -_chat = None -_chat_error: Optional[str] = None - - -def _load_audio_for_chattts(audio_path: str, sample_rate: int = TTS_SAMPLE_RATE) -> np.ndarray: - """ - 加载音频并重采样到 ChatTTS 所需采样率。 - 优先使用 ChatTTS 自带工具,回退到 librosa。 - """ - try: - from ChatTTS.utils import load_audio - - return load_audio(audio_path, sample_rate) - except ImportError: - pass - - try: - from tools.audio import load_audio - - return load_audio(audio_path, sample_rate) - except ImportError: - pass - 
- import librosa - - audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True) - return audio - - -def _get_audio_duration_sec(audio: np.ndarray, sample_rate: int) -> float: - """计算音频时长(秒)。""" - if audio is None or len(audio) == 0: - return 0.0 - return len(audio) / float(sample_rate) - - -def get_chattts_instance(): - """ - 获取或初始化 ChatTTS 模型。 - 启用 GPU 加速,compile=False 以兼容 3060 Ti 8GB 显存。 - """ - global _chat, _chat_error - - if _chat is not None: - return _chat, None - - if _chat_error is not None: - return None, _chat_error - - try: - import ChatTTS - - logger.info("正在加载 ChatTTS 模型...") - chat = ChatTTS.Chat() - - # 兼容不同版本 API:load_models(旧版)/ load(新版) - if hasattr(chat, "load_models"): - chat.load_models(compile=False) - elif hasattr(chat, "load"): - chat.load(compile=False) - else: - raise RuntimeError("当前 ChatTTS 版本缺少 load / load_models 方法。") - - _chat = chat - logger.info("ChatTTS 模型加载成功。") - return _chat, None - - except ImportError as exc: - _chat_error = ( - "未安装 ChatTTS,请参考 DEPLOY.md 安装。\n" - f"原始错误: {exc}" - ) - logger.exception("ChatTTS 导入失败") - return None, _chat_error - - except Exception as exc: - _chat_error = f"ChatTTS 模型加载失败: {exc}\n{traceback.format_exc()}" - logger.exception("ChatTTS 初始化异常") - return None, _chat_error - - -def _encode_spk_emb(chat, tensor_or_str: Any) -> str: - """将 Speaker Embedding 编码为 ChatTTS 可用的字符串格式。""" - if isinstance(tensor_or_str, str): - return tensor_or_str - - if hasattr(chat, "_encode_spk_emb"): - return chat._encode_spk_emb(tensor_or_str) - - # 兜底:直接转字符串(部分版本可接受 tensor) - return tensor_or_str - - -def save_fixed_speaker( - audio_sample_path: str, - sample_transcript: str = "", -) -> Tuple[bool, str]: - """ - 从 10-30 秒干净人声中提取 Speaker Embedding 并序列化保存。 - - Args: - audio_sample_path: 参考人声 wav/mp3 等路径 - sample_transcript: 参考音频的精确转写(可选,有助于 zero-shot 音色还原) - - Returns: - (success, message) - """ - if not audio_sample_path: - return False, "未提供音色参考音频。" - - chat, init_err = get_chattts_instance() - if chat is None: - 
return False, init_err or "ChatTTS 不可用。" - - try: - audio = _load_audio_for_chattts(audio_sample_path, TTS_SAMPLE_RATE) - duration = _get_audio_duration_sec(audio, TTS_SAMPLE_RATE) - - if duration < SPEAKER_SAMPLE_MIN_SEC: - return False, ( - f"参考音频过短({duration:.1f}s),建议 {SPEAKER_SAMPLE_MIN_SEC}-" - f"{SPEAKER_SAMPLE_MAX_SEC} 秒干净人声。" - ) - if duration > SPEAKER_SAMPLE_MAX_SEC + 5: - logger.warning("参考音频超过建议时长 %.1fs,将截取前 %ds", duration, SPEAKER_SAMPLE_MAX_SEC) - max_samples = SPEAKER_SAMPLE_MAX_SEC * TTS_SAMPLE_RATE - audio = audio[:max_samples] - - # 从参考音频提取音色特征 - spk_smp = chat.sample_audio_speaker(audio) - - # 同时保存编码后的 spk_emb 字符串,便于 infer 时直接使用 - spk_emb = _encode_spk_emb(chat, spk_smp) - - payload: Dict[str, Any] = { - "spk_emb": spk_emb, - "spk_smp": spk_smp, - "txt_smp": sample_transcript.strip(), - "created_at": datetime.now().isoformat(), - "source_audio": str(audio_sample_path), - } - - torch.save(payload, SPEAKER_EMB_PATH) - - msg = ( - f"音色已锁定并保存至 {SPEAKER_EMB_PATH}\n" - f"参考时长: {duration:.1f}s" - ) - if not sample_transcript.strip(): - msg += "\n提示:填写参考音频精确转写可进一步提升音色还原度。" - - logger.info("Speaker Embedding 保存成功: %s", SPEAKER_EMB_PATH) - return True, msg - - except Exception as exc: - err = f"音色提取失败: {exc}\n{traceback.format_exc()}" - logger.exception("save_fixed_speaker 失败") - return False, err - - -def _load_speaker_payload() -> Tuple[Optional[Dict[str, Any]], Optional[str]]: - """加载本地 speaker_emb.pt。""" - if not SPEAKER_EMB_PATH.exists(): - return None, ( - f"未找到固定音色文件 `{SPEAKER_EMB_PATH.name}`。" - "请先在【音色锁定】模块上传 10-30 秒参考人声。" - ) - - try: - payload = torch.load(SPEAKER_EMB_PATH, map_location="cpu", weights_only=False) - - # 兼容旧版仅保存 tensor 的文件 - if isinstance(payload, torch.Tensor): - chat, err = get_chattts_instance() - if chat is None: - return None, err - return { - "spk_emb": _encode_spk_emb(chat, payload), - "spk_smp": None, - "txt_smp": "", - }, None - - if not isinstance(payload, dict): - return None, "speaker_emb.pt 格式无效,请重新锁定音色。" - - return 
payload, None - - except Exception as exc: - return None, f"读取 speaker_emb.pt 失败: {exc}" - - -def speaker_is_ready() -> Tuple[bool, str]: - """检查固定音色是否已配置。""" - payload, err = _load_speaker_payload() - if payload is None: - return False, err or "音色未配置。" - return True, f"已加载固定音色: {SPEAKER_EMB_PATH}" - - -def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: - """ - 使用 ChatTTS 将润色后的文稿合成为 wav 配音。 - - Args: - refined_text: LLM 润色后的配音稿 - - Returns: - (success, message, output_wav_path_or_none) - """ - if not refined_text or not refined_text.strip(): - return False, "合成文本为空,请先完成润色。", None - - chat, init_err = get_chattts_instance() - if chat is None: - return False, init_err or "ChatTTS 不可用。", None - - payload, spk_err = _load_speaker_payload() - if payload is None: - return False, spk_err or "请先锁定音色。", None - - try: - import ChatTTS - - spk_emb = payload.get("spk_emb") - spk_smp = payload.get("spk_smp") - txt_smp = payload.get("txt_smp", "") - - params_infer_code = ChatTTS.Chat.InferCodeParams( - prompt=TTS_SPEED_PROMPT, - spk_emb=spk_emb, - spk_smp=spk_smp if spk_smp else None, - txt_smp=txt_smp if txt_smp else None, - temperature=TTS_TEMPERATURE, - top_P=TTS_TOP_P, - top_K=TTS_TOP_K, - ) - - # 内向克制语气:降低 oral 强度 - params_refine_text = ChatTTS.Chat.RefineTextParams( - prompt="[oral_2][laugh_0][break_4]", - ) - - wavs = chat.infer( - refined_text.strip(), - skip_refine_text=False, - params_refine_text=params_refine_text, - params_infer_code=params_infer_code, - ) - - if not wavs or len(wavs) == 0: - return False, "ChatTTS 未生成有效音频。", None - - wav_array = np.asarray(wavs[0], dtype=np.float32) - - # 归一化并转 int16 - peak = np.max(np.abs(wav_array)) or 1.0 - wav_int16 = (wav_array / peak * 32767).astype(np.int16) - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - filename = f"voiceover_{timestamp}_{uuid.uuid4().hex[:6]}.wav" - output_path = OUTPUT_DIR / filename - - wavfile.write(str(output_path), TTS_SAMPLE_RATE, wav_int16) - - msg = f"配音合成成功: 
{output_path}" - logger.info(msg) - return True, msg, str(output_path) - - except Exception as exc: - err = f"语音合成失败: {exc}\n{traceback.format_exc()}" - logger.exception("generate_voice 失败") - return False, err, None +""" +ChatTTS 本地语音合成服务 +支持从参考人声提取 Speaker Embedding 并固定音色合成配音。 +""" + +from __future__ import annotations + +import inspect +import logging +import os +import traceback +import uuid +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +from scipy.io import wavfile + +from config import ( + BASE_DIR, + CHATTTS_MODEL_DIR, + HF_ENDPOINT, + HF_HOME, + HF_HUB_DOWNLOAD_TIMEOUT, + OUTPUT_DIR, + SPEAKER_EMB_PATH, + SPEAKER_SAMPLE_MAX_SEC, + SPEAKER_SAMPLE_MIN_SEC, + TTS_SAMPLE_RATE, + TTS_SPEED_PROMPT, + TTS_TEMPERATURE, + TTS_TOP_K, + TTS_TOP_P, +) + +logger = logging.getLogger(__name__) + +# 全局 ChatTTS 实例 +_chat = None +_chat_error: Optional[str] = None + + +def _ensure_hf_env() -> None: + """配置 HuggingFace 镜像与下载超时,避免默认 3s 访问 GitHub 超时。""" + os.environ.setdefault("HF_ENDPOINT", HF_ENDPOINT) + os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", str(HF_HUB_DOWNLOAD_TIMEOUT)) + os.environ.setdefault("HF_HOME", str(HF_HOME)) + HF_HOME.mkdir(parents=True, exist_ok=True) + CHATTTS_MODEL_DIR.mkdir(parents=True, exist_ok=True) + + +def _chattts_model_ready(model_dir: Path) -> bool: + """检查本地 ChatTTS 模型目录是否完整。""" + if not model_dir.is_dir(): + return False + if (model_dir / "config" / "path.yaml").is_file(): + return True + asset_dir = model_dir / "asset" + if asset_dir.is_dir() and any(asset_dir.rglob("*.pt")): + return True + if any(model_dir.glob("*.pt")): + return True + return False + + +def _build_load_error(exc: BaseException) -> str: + """生成用户可读的 ChatTTS 加载失败说明。""" + msg = str(exc) + hints = [ + "ChatTTS 模型加载失败。", + f"详情: {msg}", + "", + "常见原因:服务器无法访问 GitHub(read timeout=3)。", + "解决办法(在服务器执行一次):", + f" cd {BASE_DIR}", + " bash scripts/download_chattts_models.sh", + " pm2 
restart trading_studio", + "", + f"模型将下载到: {CHATTTS_MODEL_DIR}", + f"HF 镜像: {HF_ENDPOINT}", + ] + return "\n".join(hints) + + +def _load_chat_model(chat) -> None: + """按优先级加载 ChatTTS:本地 custom → 镜像下载到 cache_dir。""" + _ensure_hf_env() + model_dir = CHATTTS_MODEL_DIR + + base_kwargs: Dict[str, Any] = {"compile": False} + + if not hasattr(chat, "load"): + if hasattr(chat, "load_models"): + chat.load_models(**base_kwargs) + return + raise RuntimeError("当前 ChatTTS 版本缺少 load / load_models 方法。") + + sig = inspect.signature(chat.load) + params = sig.parameters + + # 1) 本地已预下载 → 完全离线,不访问 GitHub + if _chattts_model_ready(model_dir): + logger.info("ChatTTS 从本地目录加载 (source=custom): %s", model_dir) + kwargs = dict(base_kwargs) + if "source" in params: + kwargs["source"] = "custom" + if "custom_path" in params: + kwargs["custom_path"] = str(model_dir) + result = chat.load(**kwargs) + if result is False: + raise RuntimeError(f"ChatTTS 本地加载失败,请检查 {model_dir}") + return + + # 2) 未预下载 → 通过 HF 镜像下载到指定目录(仍可能尝试网络) + logger.warning( + "未找到本地 ChatTTS 模型 (%s),尝试通过 HF 镜像下载…", + model_dir, + ) + kwargs = dict(base_kwargs) + if "source" in params: + kwargs["source"] = "local" + if "cache_dir" in params: + kwargs["cache_dir"] = str(model_dir) + elif "source" in params: + kwargs["source"] = "huggingface" + + result = chat.load(**kwargs) + if result is False: + raise RuntimeError( + "ChatTTS 在线下载失败。请执行: bash scripts/download_chattts_models.sh" + ) + + +def reset_chattts_instance() -> None: + """释放 ChatTTS 实例(模型下载后重启前可调用)。""" + global _chat, _chat_error + _chat = None + _chat_error = None + + +def get_chattts_instance(): + """ + 获取或初始化 ChatTTS 模型。 + 启用 GPU 加速,compile=False 以兼容 3060 Ti 8GB 显存。 + """ + global _chat, _chat_error + + if _chat is not None: + return _chat, None + + if _chat_error is not None: + return None, _chat_error + + try: + _ensure_hf_env() + import ChatTTS + + logger.info("正在加载 ChatTTS 模型...") + chat = ChatTTS.Chat() + _load_chat_model(chat) + + _chat = chat + 
logger.info("ChatTTS 模型加载成功。") + return _chat, None + + except ImportError as exc: + _chat_error = ( + "未安装 ChatTTS,请参考 DEPLOY.md 安装。\n" + f"原始错误: {exc}" + ) + logger.exception("ChatTTS 导入失败") + return None, _chat_error + + except Exception as exc: + _chat_error = _build_load_error(exc) + logger.exception("ChatTTS 初始化异常") + return None, _chat_error + + +def _load_audio_for_chattts(audio_path: str, sample_rate: int = TTS_SAMPLE_RATE) -> np.ndarray: + """ + 加载音频并重采样到 ChatTTS 所需采样率。 + 优先使用 ChatTTS 自带工具,回退到 librosa。 + """ + try: + from ChatTTS.utils import load_audio + + return load_audio(audio_path, sample_rate) + except ImportError: + pass + + try: + from tools.audio import load_audio + + return load_audio(audio_path, sample_rate) + except ImportError: + pass + + import librosa + + audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True) + return audio + + +def _get_audio_duration_sec(audio: np.ndarray, sample_rate: int) -> float: + """计算音频时长(秒)。""" + if audio is None or len(audio) == 0: + return 0.0 + return len(audio) / float(sample_rate) + + +def _encode_spk_emb(chat, tensor_or_str: Any) -> str: + """将 Speaker Embedding 编码为 ChatTTS 可用的字符串格式。""" + if isinstance(tensor_or_str, str): + return tensor_or_str + + if hasattr(chat, "_encode_spk_emb"): + return chat._encode_spk_emb(tensor_or_str) + + return tensor_or_str + + +def save_fixed_speaker( + audio_sample_path: str, + sample_transcript: str = "", +) -> Tuple[bool, str]: + """ + 从 10-30 秒干净人声中提取 Speaker Embedding 并序列化保存。 + + Args: + audio_sample_path: 参考人声 wav/mp3 等路径 + sample_transcript: 参考音频的精确转写(可选,有助于 zero-shot 音色还原) + + Returns: + (success, message) + """ + if not audio_sample_path: + return False, "未提供音色参考音频。" + + chat, init_err = get_chattts_instance() + if chat is None: + return False, init_err or "ChatTTS 不可用。" + + try: + audio = _load_audio_for_chattts(audio_sample_path, TTS_SAMPLE_RATE) + duration = _get_audio_duration_sec(audio, TTS_SAMPLE_RATE) + + if duration < SPEAKER_SAMPLE_MIN_SEC: + return 
False, ( + f"参考音频过短({duration:.1f}s),建议 {SPEAKER_SAMPLE_MIN_SEC}-" + f"{SPEAKER_SAMPLE_MAX_SEC} 秒干净人声。" + ) + if duration > SPEAKER_SAMPLE_MAX_SEC + 5: + logger.warning("参考音频超过建议时长 %.1fs,将截取前 %ds", duration, SPEAKER_SAMPLE_MAX_SEC) + max_samples = SPEAKER_SAMPLE_MAX_SEC * TTS_SAMPLE_RATE + audio = audio[:max_samples] + + spk_smp = chat.sample_audio_speaker(audio) + spk_emb = _encode_spk_emb(chat, spk_smp) + + payload: Dict[str, Any] = { + "spk_emb": spk_emb, + "spk_smp": spk_smp, + "txt_smp": sample_transcript.strip(), + "created_at": datetime.now().isoformat(), + "source_audio": str(audio_sample_path), + } + + torch.save(payload, SPEAKER_EMB_PATH) + + msg = ( + f"音色已锁定并保存至 {SPEAKER_EMB_PATH}\n" + f"参考时长: {duration:.1f}s" + ) + if not sample_transcript.strip(): + msg += "\n提示:填写参考音频精确转写可进一步提升音色还原度。" + + logger.info("Speaker Embedding 保存成功: %s", SPEAKER_EMB_PATH) + return True, msg + + except Exception as exc: + err = f"音色提取失败: {exc}\n{traceback.format_exc()}" + logger.exception("save_fixed_speaker 失败") + return False, err + + +def _load_speaker_payload() -> Tuple[Optional[Dict[str, Any]], Optional[str]]: + """加载本地 speaker_emb.pt。""" + if not SPEAKER_EMB_PATH.exists(): + return None, ( + f"未找到固定音色文件 `{SPEAKER_EMB_PATH.name}`。" + "请先在【音色锁定】模块上传 10-30 秒参考人声。" + ) + + try: + payload = torch.load(SPEAKER_EMB_PATH, map_location="cpu", weights_only=False) + + if isinstance(payload, torch.Tensor): + chat, err = get_chattts_instance() + if chat is None: + return None, err + return { + "spk_emb": _encode_spk_emb(chat, payload), + "spk_smp": None, + "txt_smp": "", + }, None + + if not isinstance(payload, dict): + return None, "speaker_emb.pt 格式无效,请重新锁定音色。" + + return payload, None + + except Exception as exc: + return None, f"读取 speaker_emb.pt 失败: {exc}" + + +def speaker_is_ready() -> Tuple[bool, str]: + """检查固定音色是否已配置。""" + payload, err = _load_speaker_payload() + if payload is None: + return False, err or "音色未配置。" + return True, f"已加载固定音色: {SPEAKER_EMB_PATH}" + + +def 
generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: + """ + Synthesize the refined script into a wav voice-over with ChatTTS. + + Args: + refined_text: voice-over script already polished by the LLM + + Returns: + (success, message, output_wav_path_or_none) + """ + if not refined_text or not refined_text.strip(): + return False, "合成文本为空,请先完成润色。", None + + chat, init_err = get_chattts_instance() + if chat is None: + return False, init_err or "ChatTTS 不可用。", None + + payload, spk_err = _load_speaker_payload() + if payload is None: + return False, spk_err or "请先锁定音色。", None + + try: + import ChatTTS + + spk_emb = payload.get("spk_emb") + spk_smp = payload.get("spk_smp") + txt_smp = payload.get("txt_smp", "") + + params_infer_code = ChatTTS.Chat.InferCodeParams( + prompt=TTS_SPEED_PROMPT, + spk_emb=spk_emb, + spk_smp=spk_smp if spk_smp else None, + txt_smp=txt_smp if txt_smp else None, + temperature=TTS_TEMPERATURE, + top_P=TTS_TOP_P, + top_K=TTS_TOP_K, + ) + + params_refine_text = ChatTTS.Chat.RefineTextParams( + prompt="[oral_2][laugh_0][break_4]", + ) + + wavs = chat.infer( + refined_text.strip(), + skip_refine_text=False, + params_refine_text=params_refine_text, + params_infer_code=params_infer_code, + ) + + if wavs is None or len(wavs) == 0: # no truthiness test: infer() may return an ndarray, whose bool() is ambiguous + return False, "ChatTTS 未生成有效音频。", None + + wav_array = np.asarray(wavs[0], dtype=np.float32) + + peak = np.max(np.abs(wav_array)) or 1.0 + wav_int16 = (wav_array / peak * 32767).astype(np.int16) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"voiceover_{timestamp}_{uuid.uuid4().hex[:6]}.wav" + output_path = OUTPUT_DIR / filename + + wavfile.write(str(output_path), TTS_SAMPLE_RATE, wav_int16) + + msg = f"配音合成成功: {output_path}" + logger.info(msg) + return True, msg, str(output_path) + + except Exception as exc: + err = f"语音合成失败: {exc}\n{traceback.format_exc()}" + logger.exception("generate_voice 失败") + return False, err, None