7c50b13c57
Enable Gradio queue, immediate pending feedback, segment progress, and gr.update for Audio so long syntheses show logs and playback correctly. Co-authored-by: Cursor <cursoragent@cursor.com>
869 lines
29 KiB
Python
869 lines
29 KiB
Python
"""
|
||
ChatTTS 本地语音合成服务
|
||
支持从参考人声提取 Speaker Embedding 并固定音色合成配音。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import inspect
|
||
import logging
|
||
import os
|
||
import re
|
||
import traceback
|
||
import uuid
|
||
import warnings
|
||
from dataclasses import replace
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
import numpy as np
|
||
import torch
|
||
from scipy.io import wavfile
|
||
|
||
from config import (
|
||
BASE_DIR,
|
||
CHATTTS_MODEL_DIR,
|
||
HF_ENDPOINT,
|
||
HF_HOME,
|
||
HF_HUB_DOWNLOAD_TIMEOUT,
|
||
OUTPUT_DIR,
|
||
SPEAKER_EMB_PATH,
|
||
SPEAKER_SAMPLE_MAX_SEC,
|
||
SPEAKER_SAMPLE_MIN_SEC,
|
||
TTS_MAX_CHARS_PER_CHUNK,
|
||
TTS_ENABLE_CACHE,
|
||
TTS_MAX_NEW_TOKEN,
|
||
TTS_MIN_NEW_TOKEN,
|
||
TTS_SAMPLE_RATE,
|
||
TTS_SPEED_PROMPT,
|
||
TTS_TEMPERATURE,
|
||
TTS_TOP_K,
|
||
TTS_TOP_P,
|
||
)
|
||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Process-wide ChatTTS instance; filled in lazily by get_chattts_instance().
_chat: Optional[Any] = None
# Message of the last failed load; while set, get_chattts_instance() returns it
# instead of retrying (reset_chattts_instance() clears it).
_chat_error: Optional[str] = None
|
||
|
||
|
||
def _ensure_hf_env() -> None:
    """Seed HuggingFace environment variables and create the cache directories.

    Values already present in ``os.environ`` are left untouched (``setdefault``).
    The original author notes this is meant to avoid the default 3s timeout
    when the hub is unreachable.
    """
    env_defaults = {
        "HF_ENDPOINT": HF_ENDPOINT,
        "HF_HUB_DOWNLOAD_TIMEOUT": str(HF_HUB_DOWNLOAD_TIMEOUT),
        "HF_HOME": str(HF_HOME),
    }
    for key, value in env_defaults.items():
        os.environ.setdefault(key, value)

    for directory in (HF_HOME, CHATTTS_MODEL_DIR):
        directory.mkdir(parents=True, exist_ok=True)
|
||
|
||
|
||
def _chattts_model_ready(model_dir: Path) -> bool:
    """Return True when ``model_dir`` looks like a pre-downloaded ChatTTS model.

    The directory qualifies if any of the following exists:
    ``config/path.yaml``, a ``*.pt`` file anywhere below ``asset/``, or a
    ``*.pt`` file directly inside ``model_dir``.
    """
    if not model_dir.is_dir():
        return False

    if (model_dir / "config" / "path.yaml").is_file():
        return True

    assets = model_dir / "asset"
    nested_weights = assets.rglob("*.pt") if assets.is_dir() else iter(())
    if next(nested_weights, None) is not None:
        return True

    # Last resort: weight files placed flat in the model directory.
    return next(model_dir.glob("*.pt"), None) is not None
|
||
|
||
|
||
def _build_load_error(exc: BaseException) -> str:
    """Build the multi-line, user-facing explanation for a failed ChatTTS load.

    The text embeds ``str(exc)`` plus the recovery commands and the configured
    model directory / HF mirror.
    """
    detail = str(exc)
    return (
        "ChatTTS 模型加载失败。\n"
        f"详情: {detail}\n"
        "\n"
        "常见原因:服务器无法访问 GitHub(read timeout=3)。\n"
        "解决办法(在服务器执行一次):\n"
        f" cd {BASE_DIR}\n"
        " bash scripts/download_chattts_models.sh\n"
        " pm2 restart trading_studio\n"
        "\n"
        f"模型将下载到: {CHATTTS_MODEL_DIR}\n"
        f"HF 镜像: {HF_ENDPOINT}"
    )
|
||
|
||
|
||
def _load_chat_model(chat) -> None:
    """Load ChatTTS weights into ``chat``: local directory first, download second.

    The ``load`` signature is inspected so that only keyword arguments the
    installed ChatTTS version accepts are passed. This now also applies to the
    base options (``compile`` / ``enable_cache``), which were previously passed
    unconditionally and could raise ``TypeError`` on versions lacking them.

    Args:
        chat: A ChatTTS ``Chat``-like object exposing ``load`` (or the legacy
            ``load_models``).

    Raises:
        RuntimeError: If neither loader method exists, or if ``load`` returns
            ``False`` for the local or the download attempt.
    """
    _ensure_hf_env()
    model_dir = CHATTTS_MODEL_DIR

    base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": TTS_ENABLE_CACHE}

    if not hasattr(chat, "load"):
        if hasattr(chat, "load_models"):
            # Legacy entry point: no signature probing was done here originally.
            chat.load_models(**base_kwargs)
            return
        raise RuntimeError("当前 ChatTTS 版本缺少 load / load_models 方法。")

    params = inspect.signature(chat.load).parameters
    accepts_var_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if accepts_var_kwargs:
        supported_base = dict(base_kwargs)
    else:
        supported_base = {k: v for k, v in base_kwargs.items() if k in params}
        skipped = sorted(set(base_kwargs) - set(supported_base))
        if skipped:
            logger.debug("ChatTTS.load does not accept %s; not passing them", skipped)

    # 1) Pre-downloaded model present -> load from disk.
    if _chattts_model_ready(model_dir):
        logger.info("ChatTTS 从本地目录加载 (source=custom): %s", model_dir)
        kwargs = dict(supported_base)
        if "source" in params:
            kwargs["source"] = "custom"
        if "custom_path" in params:
            kwargs["custom_path"] = str(model_dir)
        if chat.load(**kwargs) is False:
            raise RuntimeError(f"ChatTTS 本地加载失败,请检查 {model_dir}")
        return

    # 2) Nothing on disk -> let ChatTTS fetch the model (may hit the network).
    logger.warning(
        "未找到本地 ChatTTS 模型 (%s),尝试通过 HF 镜像下载…",
        model_dir,
    )
    kwargs = dict(supported_base)
    if "source" in params:
        kwargs["source"] = "local"
    if "cache_dir" in params:
        kwargs["cache_dir"] = str(model_dir)
    elif "source" in params:
        # No cache_dir support: fall back to the huggingface source.
        kwargs["source"] = "huggingface"

    if chat.load(**kwargs) is False:
        raise RuntimeError(
            "ChatTTS 在线下载失败。请执行: bash scripts/download_chattts_models.sh"
        )
|
||
|
||
|
||
def reset_chattts_instance() -> None:
    """Drop the cached ChatTTS instance and ask gpu_utils to release CUDA cache.

    Calls ``unload()`` on the instance when it offers one (failures are logged,
    not raised), clears both the cached instance and the cached load error,
    then invokes ``gpu_utils.release_cuda_cache()``.
    """
    global _chat, _chat_error

    if _chat is not None:
        try:
            if hasattr(_chat, "unload"):
                _chat.unload()
        except Exception:
            logger.exception("ChatTTS unload 失败")

    # Rebinding the global releases this module's reference to the model
    # before the cache release below.
    _chat = None
    _chat_error = None

    import gpu_utils

    gpu_utils.release_cuda_cache()
    logger.info("ChatTTS 模型已卸载,显存已尝试回收。")
|
||
|
||
|
||
def get_chattts_instance():
    """Return ``(chat, None)`` for the shared ChatTTS model or ``(None, error)``.

    The model is created on first use via ``_load_chat_model``. A failed load
    caches its error message, and later calls return that message without
    retrying until ``reset_chattts_instance()`` clears it.
    """
    global _chat, _chat_error

    if _chat is not None:
        return _chat, None
    if _chat_error is not None:
        return None, _chat_error

    try:
        _ensure_hf_env()
        import ChatTTS

        logger.info("正在加载 ChatTTS 模型...")
        instance = ChatTTS.Chat()
        _load_chat_model(instance)

        _chat = instance
        logger.info("ChatTTS 模型加载成功。")
        return _chat, None
    except ImportError as import_exc:
        _chat_error = (
            "未安装 ChatTTS,请参考 DEPLOY.md 安装。\n"
            f"原始错误: {import_exc}"
        )
        logger.exception("ChatTTS 导入失败")
    except Exception as load_exc:
        _chat_error = _build_load_error(load_exc)
        logger.exception("ChatTTS 初始化异常")

    return None, _chat_error
|
||
|
||
|
||
def _load_audio_via_ffmpeg(audio_path: str, sample_rate: int) -> np.ndarray:
    """Transcode ``audio_path`` to mono WAV with ffmpeg and read it as float32.

    Args:
        audio_path: Input file in any container ffmpeg can decode.
        sample_rate: Target sample rate passed to ffmpeg via ``-ar``.

    Returns:
        1-D float32 array of samples.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status (message carries
            the tail of stderr).
        subprocess.TimeoutExpired: If ffmpeg runs longer than 120 seconds.
    """
    import subprocess
    import tempfile

    import soundfile as sf

    # mkstemp creates the file atomically; tempfile.mktemp only returns a name
    # and is deprecated because another process can claim that name first.
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # ffmpeg reopens the path itself; "-y" overwrites the placeholder
    try:
        cmd = [
            "ffmpeg",
            "-y",
            "-i",
            audio_path,
            "-ac",
            "1",
            "-ar",
            str(sample_rate),
            "-f",
            "wav",
            tmp_path,
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
        if result.returncode != 0:
            raise RuntimeError(result.stderr[-500:] if result.stderr else "ffmpeg 转码失败")

        audio, _ = sf.read(tmp_path, dtype="float32", always_2d=False)
        if isinstance(audio, np.ndarray) and audio.ndim > 1:
            # Defensive down-mix in case the file still has several channels.
            audio = audio.mean(axis=1)
        return np.asarray(audio, dtype=np.float32)
    finally:
        Path(tmp_path).unlink(missing_ok=True)
|
||
|
||
|
||
def _load_audio_for_chattts(audio_path: str, sample_rate: int = TTS_SAMPLE_RATE) -> np.ndarray:
    """Load ``audio_path`` at ``sample_rate`` using the first backend that works.

    Backends are tried in order: ``ChatTTS.utils.load_audio``,
    ``tools.audio.load_audio``, the ffmpeg transcode helper, then librosa.

    Raises:
        RuntimeError: If every backend fails; the message lists the last three
            backend errors.
    """

    def _via_chattts_utils():
        from ChatTTS.utils import load_audio

        return load_audio(audio_path, sample_rate)

    def _via_tools_audio():
        from tools.audio import load_audio

        return load_audio(audio_path, sample_rate)

    def _via_ffmpeg():
        return _load_audio_via_ffmpeg(audio_path, sample_rate)

    def _via_librosa():
        import librosa

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            warnings.filterwarnings("ignore", message="PySoundFile failed")
            samples, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
        return samples

    backends = (
        ("ChatTTS.utils", _via_chattts_utils),
        ("tools.audio", _via_tools_audio),
        ("ffmpeg", _via_ffmpeg),
        ("librosa", _via_librosa),
    )

    failures: list[str] = []
    for label, loader in backends:
        try:
            return loader()
        except Exception as exc:
            failures.append(f"{label}: {exc}")

    raise RuntimeError(
        "无法读取音频文件,请上传 wav/mp3/m4a 或确认已安装 ffmpeg。\n"
        + "\n".join(failures[-3:])
    )
|
||
|
||
|
||
def _get_audio_duration_sec(audio: np.ndarray, sample_rate: int) -> float:
    """Return the length of ``audio`` in seconds (0.0 for ``None`` or empty input)."""
    if audio is None:
        return 0.0
    sample_count = len(audio)
    return sample_count / float(sample_rate) if sample_count else 0.0
|
||
|
||
|
||
def _encode_random_spk_emb(chat, tensor: torch.Tensor) -> Optional[str]:
    """Encode a speaker tensor with whichever encoder ``chat`` exposes.

    Tries ``chat.speaker._encode`` first, then ``chat._encode_spk_emb``;
    returns ``None`` when neither exists. Per the original note this is for
    randomly sampled speakers, not reference-audio samples.
    """
    encoder_sites = (
        (getattr(chat, "speaker", None), "_encode"),
        (chat, "_encode_spk_emb"),
    )
    for owner, method_name in encoder_sites:
        if owner is not None and hasattr(owner, method_name):
            return getattr(owner, method_name)(tensor)
    return None
|
||
|
||
|
||
def _is_valid_spk_emb_string(chat, spk_emb: str) -> bool:
    """Return True if ``chat.speaker._decode`` accepts ``spk_emb`` without raising.

    Returns False when ``chat`` has no ``speaker`` or the speaker has no
    ``_decode``. The original note says spk_emb and spk_smp use different
    encodings and an invalid string fails in lzma with "Corrupt input data".
    """
    decoder_owner = getattr(chat, "speaker", None)
    if decoder_owner is not None and hasattr(decoder_owner, "_decode"):
        try:
            decoder_owner._decode(spk_emb)
        except Exception:
            return False
        return True
    return False
|
||
|
||
|
||
def _normalize_speaker_for_infer(
    chat,
    payload: Dict[str, Any],
) -> Tuple[Optional[Dict[str, Optional[str]]], Optional[str]]:
    """Turn a stored speaker payload into the speaker kwargs used for inference.

    Resolution order:
      1. ``spk_smp`` present: use it with ``txt_smp`` (warn if no transcript).
      2. ``spk_emb`` is a non-blank string: use it as ``spk_emb`` when it
         decodes, otherwise treat it as a legacy mis-stored ``spk_smp``.
      3. ``spk_emb`` is a tensor: encode it via ``_encode_random_spk_emb``.

    Returns:
        ``(params, message)``. ``params`` is ``None`` when the payload is
        unusable; ``message`` is a warning or the failure reason, else ``None``.
    """
    sample_code = payload.get("spk_smp")
    transcript = (payload.get("txt_smp") or "").strip() or None
    embedding = payload.get("spk_emb")

    if sample_code:
        notice: Optional[str] = (
            None
            if transcript
            else (
                "未填写参考音频转写(txt_smp),音色克隆可能不稳定。"
                "建议在「音色锁定」补充精确转写后重新锁定。"
            )
        )
        return {"spk_smp": sample_code, "txt_smp": transcript, "spk_emb": None}, notice

    if isinstance(embedding, str) and embedding.strip():
        if not _is_valid_spk_emb_string(chat, embedding):
            # Legacy files stored the spk_smp code under the spk_emb key.
            legacy_params = {
                "spk_smp": embedding,
                "txt_smp": transcript,
                "spk_emb": None,
            }
            return legacy_params, (
                "检测到旧版音色文件格式,已自动按 spk_smp 加载。"
                "建议重新锁定音色并填写参考转写。"
            )
        return {"spk_emb": embedding, "spk_smp": None, "txt_smp": None}, None

    if isinstance(embedding, torch.Tensor):
        encoded = _encode_random_spk_emb(chat, embedding)
        if not encoded:
            return None, "旧版音色张量无法编码,请重新锁定音色。"
        return {"spk_emb": encoded, "spk_smp": None, "txt_smp": None}, None

    return None, "音色数据无效或已损坏,请重新锁定音色。"
|
||
|
||
|
||
def save_fixed_speaker(
    audio_sample_path: str,
    sample_transcript: str = "",
) -> Tuple[bool, str]:
    """Extract a speaker sample code from reference audio and persist it.

    The audio is loaded at ``TTS_SAMPLE_RATE``, rejected if shorter than
    ``SPEAKER_SAMPLE_MIN_SEC``, truncated to ``SPEAKER_SAMPLE_MAX_SEC`` when it
    exceeds that limit by more than 5 seconds, passed to
    ``chat.sample_audio_speaker`` and saved with ``torch.save`` to
    ``SPEAKER_EMB_PATH``.

    Args:
        audio_sample_path: Path of the reference voice recording.
        sample_transcript: Exact transcript of the recording (optional;
            ``None`` is treated like an empty string).

    Returns:
        ``(success, message)``. Exceptions during extraction are caught and
        reported in ``message``.
    """
    if not audio_sample_path:
        return False, "未提供音色参考音频。"

    # Tolerate None so a missing transcript does not blow up inside the try.
    transcript = (sample_transcript or "").strip()

    # Best effort: free the Whisper model first; any failure is only logged.
    try:
        from whisper_service import reset_whisper_model

        reset_whisper_model()
    except Exception:
        logger.debug("锁定音色前释放 Whisper 显存跳过", exc_info=True)

    chat, init_err = get_chattts_instance()
    if chat is None:
        return False, init_err or "ChatTTS 不可用。"

    try:
        audio = _load_audio_for_chattts(audio_sample_path, TTS_SAMPLE_RATE)
        duration = _get_audio_duration_sec(audio, TTS_SAMPLE_RATE)

        if duration < SPEAKER_SAMPLE_MIN_SEC:
            return False, (
                f"参考音频过短({duration:.1f}s),建议 {SPEAKER_SAMPLE_MIN_SEC}-"
                f"{SPEAKER_SAMPLE_MAX_SEC} 秒干净人声。"
            )
        if duration > SPEAKER_SAMPLE_MAX_SEC + 5:
            logger.warning("参考音频超过建议时长 %.1fs,将截取前 %ds", duration, SPEAKER_SAMPLE_MAX_SEC)
            # int() keeps the slice bound valid even if the limit is a float.
            max_samples = int(SPEAKER_SAMPLE_MAX_SEC * TTS_SAMPLE_RATE)
            audio = audio[:max_samples]
            # Report the duration actually used, not the pre-truncation one.
            duration = _get_audio_duration_sec(audio, TTS_SAMPLE_RATE)

        spk_smp = chat.sample_audio_speaker(audio)

        payload: Dict[str, Any] = {
            "version": 2,
            "spk_smp": spk_smp,
            "txt_smp": transcript,
            "created_at": datetime.now().isoformat(),
            "source_audio": str(audio_sample_path),
        }

        torch.save(payload, SPEAKER_EMB_PATH)

        msg = (
            f"音色已锁定并保存至 {SPEAKER_EMB_PATH}\n"
            f"参考时长: {duration:.1f}s"
        )
        if not transcript:
            msg += (
                "\n⚠️ 未填写参考转写:合成时可能报 Corrupt input data 或音色不稳。"
                "请填写与录音一致的精确转写后重新锁定。"
            )

        logger.info("Speaker Embedding 保存成功: %s", SPEAKER_EMB_PATH)
        return True, msg

    except Exception as exc:
        err = f"音色提取失败: {exc}\n{traceback.format_exc()}"
        logger.exception("save_fixed_speaker 失败")
        return False, err
|
||
|
||
|
||
def _load_speaker_payload() -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Read the saved speaker file and return ``(payload, None)`` or ``(None, error)``.

    A dict payload is returned as stored. A bare tensor (legacy format) is
    encoded through the ChatTTS instance into an ``spk_emb`` payload. Any
    other type, or any exception while reading, yields an error message.
    """
    if not SPEAKER_EMB_PATH.exists():
        return None, (
            f"未找到固定音色文件 `{SPEAKER_EMB_PATH.name}`。"
            "请先在【音色锁定】模块上传 10-30 秒参考人声。"
        )

    try:
        # NOTE(review): weights_only=False unpickles arbitrary objects; this is
        # only safe as long as the file is one this application wrote itself.
        stored = torch.load(SPEAKER_EMB_PATH, map_location="cpu", weights_only=False)

        if isinstance(stored, dict):
            return stored, None
        if not isinstance(stored, torch.Tensor):
            return None, "speaker_emb.pt 格式无效,请重新锁定音色。"

        # Legacy file: a raw tensor that must be encoded with the live model.
        chat, load_err = get_chattts_instance()
        if chat is None:
            return None, load_err
        encoded = _encode_random_spk_emb(chat, stored)
        if encoded:
            return {
                "spk_emb": encoded,
                "spk_smp": None,
                "txt_smp": "",
            }, None
        return None, "旧版音色张量无法读取,请重新锁定音色。"

    except Exception as exc:
        return None, f"读取 speaker_emb.pt 失败: {exc}"
|
||
|
||
|
||
def speaker_is_ready() -> Tuple[bool, str]:
    """Return ``(True, status)`` if the saved speaker payload loads, else ``(False, reason)``."""
    payload, load_err = _load_speaker_payload()
    if payload is not None:
        return True, f"已加载固定音色: {SPEAKER_EMB_PATH}"
    return False, load_err or "音色未配置。"
|
||
|
||
|
||
# Emoji / pictograph / dingbat ranges removed before synthesis.
_EMOJI_RE = re.compile(
    "["
    "\U0001F300-\U0001FAFF"
    "\U00002700-\U000027BF"
    "\U00002600-\U000026FF"
    "]+",
    flags=re.UNICODE,
)

# Everything from the first occurrence of any of these markers onward is cut
# by prepare_text_for_tts (editor notes appended after the script).
_TTS_NOTE_MARKERS = (
    "💡",
    "量化交易员的修改笔记",
    "修改笔记(供你参考)",
    "修改笔记",
    "供你参考",
)

# Parenthesised stage directions such as "(background music fades in)".
# Both ASCII and full-width parentheses are matched; the full-width forms are
# written as \uXXXX escapes so they cannot be lost to width normalisation.
_STAGE_DIRECTION_RE = re.compile(
    r"[(\uff08][^)\uff09]{0,80}(?:前奏|转场|语气|背景|BGM|配乐|节奏|环节)[^)\uff09]{0,80}[)\uff09]"
)

# Chinese numerals indexed by digit value (0-9).
_CN_DIGITS = "零一二三四五六七八九"

# Original note: the ChatTTS tokenizer is sensitive to bare ASCII digits and
# control characters, which can trigger a CUDA device-side assert.
_TTS_UNSAFE_CHAR_RE = re.compile(
    r"[\u200b-\u200f\u202a-\u202e\ufeff\x00-\x08\x0b\x0c\x0e-\x1f]"
)

# Matches every character that is NOT allowed in synthesis input. The allowed
# set covers CJK ideographs, ASCII letters/digits, whitespace and common
# punctuation in both ASCII and full-width form (the latter as escapes:
# comma, exclamation mark, question mark, semicolon, colon, parentheses,
# percent sign).
_TTS_ALLOWED_CHAR_RE = re.compile(
    r"[^\u4e00-\u9fff\u3400-\u4dbfA-Za-z0-9,。!?;:、「」『』()—…\-\s'\"《》%"
    r"\uff0c\uff01\uff1f\uff1b\uff1a\uff08\uff09\uff05]"
)
|
||
|
||
|
||
def _digits_to_chinese(text: str) -> str:
    """Replace every digit with its Chinese numeral, digit by digit.

    Each run of digits is spelled out positionally ("25" becomes the numerals
    for 2 and 5), not converted to a number word.
    """
    return re.sub(
        r"\d+",
        lambda run: "".join(_CN_DIGITS[int(digit)] for digit in run.group()),
        text,
    )
|
||
|
||
|
||
def _normalize_tts_chunk(text: str) -> str:
    """Normalise one synthesis chunk into tokenizer-safe text.

    Steps, in order: drop control/zero-width characters, turn newlines into
    commas, spell digits as Chinese numerals, drop characters outside the
    allowed set, remove all whitespace, collapse comma runs, and trim leading
    or trailing commas / full stops.

    Whitespace is removed before comma runs are collapsed, so commas that
    were separated only by whitespace (e.g. from blank lines) also collapse
    into one. The full-width comma (U+FF0C) is treated like the ASCII comma.
    """
    text = _TTS_UNSAFE_CHAR_RE.sub("", text)
    text = text.replace("\r", "").replace("\n", ",")
    text = _digits_to_chinese(text)
    text = _TTS_ALLOWED_CHAR_RE.sub("", text)
    text = re.sub(r"\s+", "", text)
    text = re.sub(r"[,\uff0c]{2,}", ",", text)
    return text.strip(",\uff0c。 \t")
|
||
|
||
|
||
def prepare_text_for_tts(text: str) -> str:
    """Convert an LLM-polished script into plain text that can be read aloud.

    Removes trailing editor notes (everything from the first note marker),
    known model preambles, Markdown markup, stage directions, emoji and list
    bullets, then joins the remaining lines with a Chinese full stop and runs
    the result through ``_normalize_tts_chunk``.

    Args:
        text: Raw polished script; may be empty or ``None``.

    Returns:
        The cleaned text, or ``""`` when the input is empty.
    """
    if not text:
        return ""

    cleaned = text.replace("\r\n", "\n").strip()

    for marker in _TTS_NOTE_MARKERS:
        idx = cleaned.find(marker)
        if idx >= 0:
            cleaned = cleaned[:idx]

    # Strip common model preambles; the colon class covers ASCII and
    # full-width (U+FF1A) colons.
    for pattern in (
        r"^作为一名极其严谨的量化交易员.*?配音稿。\s*",
        r"^以下是为你润色后的文案[:\uff1a]*\s*",
        r"^以下(?:是|为).*?润色.*?文案[:\uff1a]*\s*",
    ):
        cleaned = re.sub(pattern, "", cleaned, count=1, flags=re.DOTALL)

    cleaned = re.sub(r"^\*{3,}\s*$", "", cleaned, flags=re.MULTILINE)
    cleaned = re.sub(r"^-{3,}\s*$", "", cleaned, flags=re.MULTILINE)
    cleaned = re.sub(r"^#{1,6}\s*", "", cleaned, flags=re.MULTILINE)
    cleaned = re.sub(r"\*\*([^*\n]+)\*\*", r"\1", cleaned)
    cleaned = re.sub(r"\*([^*\n]+)\*", r"\1", cleaned)
    cleaned = re.sub(r"__([^_\n]+)__", r"\1", cleaned)
    cleaned = _STAGE_DIRECTION_RE.sub("", cleaned)
    cleaned = _EMOJI_RE.sub("", cleaned)
    # Ordered-list prefix ("1. " / "1."). The (?!\d) guard keeps a line that
    # starts with a decimal such as "3.5" from losing its integer part.
    cleaned = re.sub(r"^\d+\.(?!\d)\s*", "", cleaned, flags=re.MULTILINE)
    cleaned = re.sub(r"^[-*]\s+", "", cleaned, flags=re.MULTILINE)
    cleaned = re.sub(r"[ \t]+\n", "\n", cleaned)
    cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)

    lines = [ln.strip() for ln in cleaned.split("\n")]
    # Drop blank lines and lines that are only leftover markup characters.
    lines = [ln for ln in lines if ln and not re.fullmatch(r"[*\-#]+", ln)]
    merged = "。".join(lines)
    return _normalize_tts_chunk(merged)
|
||
|
||
|
||
def split_text_for_tts(text: str, max_chars: int = TTS_MAX_CHARS_PER_CHUNK) -> List[str]:
    """Split a long script into chunks of at most ``max_chars`` characters.

    Text that already fits is returned as a single, unmodified chunk.
    Otherwise the text is split after sentence terminators (ASCII and
    full-width ``! ? ;`` plus the Chinese full stop) and on newlines. The
    sentences are packed greedily, and a single sentence longer than
    ``max_chars`` is hard-cut. Each resulting chunk is normalised with
    ``_normalize_tts_chunk``; chunks that normalise to nothing are dropped.

    Args:
        text: Text to split.
        max_chars: Maximum chunk length; must be >= 1 for non-empty text.

    Returns:
        List of chunks (empty for blank input).

    Raises:
        ValueError: If ``text`` is non-empty and ``max_chars`` is < 1
            (previously 0 raised from ``range`` and negatives silently lost text).
    """
    text = text.strip()
    if not text:
        return []
    if max_chars < 1:
        raise ValueError("max_chars must be a positive integer")
    if len(text) <= max_chars:
        return [text]

    parts = re.split(r"(?<=[。!?;\uff01\uff1f\uff1b])\s*|\n+", text)
    chunks: List[str] = []
    buf = ""

    for part in parts:
        part = part.strip()
        if not part:
            continue
        candidate = f"{buf}{part}" if buf else part
        if len(candidate) <= max_chars:
            buf = candidate
            continue
        # The sentence does not fit into the current buffer: flush it first.
        if buf:
            chunks.append(buf)
            buf = ""
        if len(part) <= max_chars:
            buf = part
            continue
        # A single over-long sentence is cut at fixed width.
        chunks.extend(part[i : i + max_chars] for i in range(0, len(part), max_chars))

    if buf:
        chunks.append(buf)

    normalized = (_normalize_tts_chunk(chunk) for chunk in chunks)
    return [chunk for chunk in normalized if chunk]
|
||
|
||
|
||
def _is_cuda_runtime_error(exc: BaseException) -> bool:
    """Delegate to ``gpu_utils.is_cuda_runtime_error`` (imported lazily)."""
    import gpu_utils

    return gpu_utils.is_cuda_runtime_error(exc)
|
||
|
||
|
||
def _run_chattts_infer(
    chat: Any,
    chunk: str,
    params_refine_text: Any,
    params_infer_code: Any,
) -> Any:
    """Run one ``chat.infer`` call for ``chunk`` and return its result.

    Text refinement, text normalisation and homophone replacement are
    enabled. ``split_text`` is off; the original note says this avoids
    in-chunk re-splitting that triggered mask errors.
    """
    infer_options = {
        "skip_refine_text": False,
        "split_text": False,
        "do_text_normalization": True,
        "do_homophone_replacement": True,
        "params_refine_text": params_refine_text,
        "params_infer_code": params_infer_code,
    }
    return chat.infer(chunk, **infer_options)
|
||
|
||
|
||
def _concat_wavs(
    wavs: List[np.ndarray],
    sample_rate: int,
    pause_sec: float = 0.35,
) -> np.ndarray:
    """Join segments into one float32 waveform with silence between them.

    Each segment is flattened to 1-D float32, and ``int(sample_rate * pause_sec)``
    zero samples are inserted between consecutive segments (none after the last).
    An empty input yields an empty float32 array.
    """
    if not wavs:
        return np.array([], dtype=np.float32)

    flattened = [np.asarray(segment, dtype=np.float32).flatten() for segment in wavs]
    silence = np.zeros(int(sample_rate * pause_sec), dtype=np.float32)

    pieces: List[np.ndarray] = [flattened[0]]
    for segment in flattened[1:]:
        pieces.extend((silence, segment))
    return np.concatenate(pieces)
|
||
|
||
|
||
def _synthesize_chunk_with_retry(
    chat: Any,
    chunk: str,
    idx: int,
    params_refine_text: Any,
    params_infer_code: Any,
) -> Tuple[Any, Any, Optional[BaseException]]:
    """Synthesize one chunk, retrying up to three times with a new seed.

    A ``RecursionError`` retries with a different seed. A CUDA runtime error
    (per ``_is_cuda_runtime_error``) reloads the ChatTTS instance and retries.
    Any other ``RuntimeError``, or a CUDA error on the last attempt, is re-raised.

    Returns:
        ``(chat, wavs, last_exc)``. ``chat`` may be a freshly reloaded
        instance; ``wavs`` is ``None`` when all attempts ended in
        ``RecursionError``.
    """
    from gpu_utils import release_cuda_cache

    chunk_infer = replace(params_infer_code, manual_seed=42 + idx)
    wavs = None
    last_exc: Optional[BaseException] = None

    for attempt in range(3):
        try:
            wavs = _run_chattts_infer(chat, chunk, params_refine_text, chunk_infer)
            break
        # RecursionError subclasses RuntimeError, so it must be caught first.
        except RecursionError as exc:
            last_exc = exc
            chunk_infer = replace(chunk_infer, manual_seed=1000 + idx * 10 + attempt)
            release_cuda_cache()
        except RuntimeError as exc:
            last_exc = exc
            if not _is_cuda_runtime_error(exc) or attempt >= 2:
                raise
            logger.warning(
                "第 %d 段 CUDA 异常,重置 ChatTTS 后重试 (%d/3): %s",
                idx,
                attempt + 1,
                exc,
            )
            reset_chattts_instance()
            release_cuda_cache()
            chat, reload_err = get_chattts_instance()
            if chat is None:
                raise RuntimeError(reload_err or "ChatTTS 重载失败") from exc
            chunk_infer = replace(chunk_infer, manual_seed=2000 + idx * 10 + attempt)

    return chat, wavs, last_exc


def _describe_tts_failure(exc: BaseException) -> str:
    """Map a synthesis exception to the user-facing error text.

    Must be called while ``exc`` is being handled (the generic branch uses
    ``traceback.format_exc()``). Side effects: OOM releases the CUDA cache;
    a CUDA runtime error additionally resets the ChatTTS instance.
    """
    from gpu_utils import release_cuda_cache

    exc_msg = str(exc)

    if "out of memory" in exc_msg.lower() or "OutOfMemoryError" in exc_msg:
        release_cuda_cache()
        return (
            "语音合成失败: GPU 显存不足(CUDA OOM)。\n"
            "3060 Ti 8GB 无法同时运行 Whisper + ChatTTS。\n"
            "处理步骤:\n"
            "1. pm2 restart trading_studio 释放显存\n"
            "2. 不要连续快速点识别+合成;合成前系统会自动卸载 Whisper\n"
            "3. 若仍 OOM,在 .env 设置 TTS_MAX_CHARS_PER_CHUNK=150、TTS_MAX_NEW_TOKEN=768\n"
            "4. 确认无其他程序占用 GPU: nvidia-smi\n"
            f"技术详情: {exc_msg[:400]}"
        )

    if _is_cuda_runtime_error(exc):
        reset_chattts_instance()
        release_cuda_cache()
        return (
            "语音合成失败: GPU/CUDA 异常(device-side assert 等)。\n"
            "常见原因:此前 OOM 导致 GPU 状态损坏,或文本含特殊字符。\n"
            "处理步骤:\n"
            "1. pm2 restart trading_studio(必须,清理 GPU 脏状态)\n"
            "2. 确认已填写参考音频转写并重新锁定音色\n"
            "3. 用 2-3 句短中文试合成\n"
            "4. 若仍失败,在 .env 设 TTS_ENABLE_CACHE=false 后重启\n"
            f"技术详情: {exc_msg[:500]}"
        )

    if "recursion depth" in exc_msg.lower() or isinstance(exc, RecursionError):
        return (
            "语音合成失败: ChatTTS 反复生成空结果导致递归超限。\n"
            "常见原因:未填写参考音频转写、润色稿含特殊符号、或音色文件异常。\n"
            "处理:重新锁定音色并填写转写 → 用较短纯文本试合成。\n"
            f"技术详情: {exc_msg[:400]}"
        )

    if "Corrupt input data" in exc_msg:
        return (
            "语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n"
            "处理步骤:\n"
            "1. 删除旧音色: rm speaker_emb.pt\n"
            "2. 在「音色锁定」重新上传参考人声\n"
            "3. 填写与录音一致的「参考音频精确转写」(必填)\n"
            "4. 重新点击锁定音色后再合成\n"
            f"技术详情: {exc_msg}"
        )

    return f"语音合成失败: {exc}\n{traceback.format_exc()}"


def generate_voice(
    refined_text: str,
    voice_id: str = "custom",
    progress_callback=None,
) -> Tuple[bool, str, Optional[str]]:
    """Synthesize ``refined_text`` to a WAV file with ChatTTS.

    The text is cleaned, split into chunks and synthesized chunk by chunk
    (with retries). The segments are joined with short pauses,
    peak-normalised to int16 and written into ``OUTPUT_DIR``.

    Args:
        refined_text: The polished voice-over script.
        voice_id: ``custom`` for the locked speaker, ``preset_*`` for built-in
            presets (resolved by ``voice_presets.load_voice_payload``).
        progress_callback: Optional callable invoked as
            ``progress_callback(chunk_index, chunk_count)`` before each chunk;
            exceptions it raises are logged and ignored.

    Returns:
        ``(success, message, output_wav_path_or_none)``. Failures are reported
        through ``message`` rather than raised.
    """
    if not refined_text or not refined_text.strip():
        return False, "合成文本为空,请先完成润色。", None

    # Best effort: unload Whisper before synthesis (original note: avoids
    # running two models at once on an 8GB GPU).
    try:
        from whisper_service import reset_whisper_model

        reset_whisper_model()
    except Exception:
        logger.debug("合成前释放 Whisper 显存跳过", exc_info=True)

    from gpu_utils import cuda_memory_summary, release_cuda_cache

    release_cuda_cache()
    logger.info("TTS 合成前 %s", cuda_memory_summary())

    chat, init_err = get_chattts_instance()
    if chat is None:
        return False, init_err or "ChatTTS 不可用。", None

    from voice_presets import load_voice_payload

    payload, spk_err = load_voice_payload(voice_id)
    if payload is None:
        return False, spk_err or "请先选择或生成可用音色。", None

    try:
        import ChatTTS

        speak_text = prepare_text_for_tts(refined_text)
        if not speak_text:
            return (
                False,
                "清洗后无有效朗读文本。请删除 Markdown(#、**)、emoji、舞台提示和「修改笔记」,"
                "只保留可念出的正文后再合成。",
                None,
            )

        chunks = split_text_for_tts(speak_text)
        if not chunks:
            return False, "无法切分朗读文本,请检查润色稿内容。", None

        speaker_params, speaker_warn = _normalize_speaker_for_infer(chat, payload)
        if speaker_params is None:
            return False, speaker_warn or "音色参数无效,请重新锁定音色。", None
        if speaker_warn:
            logger.warning(speaker_warn)

        params_infer_code = ChatTTS.Chat.InferCodeParams(
            prompt=TTS_SPEED_PROMPT,
            spk_emb=speaker_params.get("spk_emb"),
            spk_smp=speaker_params.get("spk_smp"),
            txt_smp=speaker_params.get("txt_smp"),
            temperature=TTS_TEMPERATURE,
            top_P=TTS_TOP_P,
            top_K=TTS_TOP_K,
            max_new_token=TTS_MAX_NEW_TOKEN,
            min_new_token=TTS_MIN_NEW_TOKEN,
            ensure_non_empty=False,
        )

        params_refine_text = ChatTTS.Chat.RefineTextParams(
            prompt="[oral_2][laugh_0][break_4]",
            ensure_non_empty=False,
            min_new_token=4,
        )

        logger.info(
            "TTS 合成: 原文 %d 字 → 清洗后 %d 字,分 %d 段",
            len(refined_text),
            len(speak_text),
            len(chunks),
        )

        segment_wavs: List[np.ndarray] = []
        for idx, chunk in enumerate(chunks, start=1):
            # Skip chunks too short to synthesize.
            if not chunk or len(chunk) < 2:
                continue
            if progress_callback is not None:
                try:
                    progress_callback(idx, len(chunks))
                except Exception:
                    logger.debug("TTS 进度回调失败", exc_info=True)
            release_cuda_cache()

            # The helper may hand back a reloaded instance; keep using it.
            chat, wavs, last_exc = _synthesize_chunk_with_retry(
                chat, chunk, idx, params_refine_text, params_infer_code
            )

            if wavs is None:
                return (
                    False,
                    f"ChatTTS 第 {idx}/{len(chunks)} 段合成失败(递归重试耗尽)。"
                    f"请检查音色转写是否填写,或缩短该段文本。"
                    f" 详情: {last_exc}",
                    None,
                )
            # len() works for lists and ndarrays alike (truth-testing an
            # ndarray with several elements would raise).
            if len(wavs) == 0:
                return (
                    False,
                    f"ChatTTS 第 {idx}/{len(chunks)} 段未生成音频。"
                    f"(段内容前 40 字: {chunk[:40]}…)",
                    None,
                )
            wav_arr = np.asarray(wavs[0], dtype=np.float32)
            if wav_arr.size == 0 or np.max(np.abs(wav_arr)) < 1e-6:
                return (
                    False,
                    f"ChatTTS 第 {idx}/{len(chunks)} 段生成了空音频。"
                    "请重新锁定音色并填写参考转写,或缩短润色稿后重试。",
                    None,
                )
            segment_wavs.append(wav_arr)
            release_cuda_cache()

        if not segment_wavs:
            return False, "分段清洗后无有效文本,请缩短或简化润色稿后重试。", None

        wav_array = (
            segment_wavs[0]
            if len(segment_wavs) == 1
            else _concat_wavs(segment_wavs, TTS_SAMPLE_RATE)
        )

        # Peak-normalise to full int16 scale; "or 1.0" guards a zero peak.
        peak = np.max(np.abs(wav_array)) or 1.0
        wav_int16 = (wav_array / peak * 32767).astype(np.int16)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"voiceover_{timestamp}_{uuid.uuid4().hex[:6]}.wav"
        # Make sure the target directory exists before writing.
        Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
        output_path = OUTPUT_DIR / filename

        wavfile.write(str(output_path), TTS_SAMPLE_RATE, wav_int16)

        chunk_note = f",共 {len(chunks)} 段拼接" if len(chunks) > 1 else ""
        msg = (
            f"配音合成成功: {output_path}"
            f"(朗读 {len(speak_text)} 字{chunk_note})"
        )
        if speaker_warn:
            msg = f"{speaker_warn}\n{msg}"
        logger.info(msg)
        return True, msg, str(output_path)

    except Exception as exc:
        err = _describe_tts_failure(exc)
        logger.exception("generate_voice 失败")
        return False, err, None
|