Initial commit: add Trading Studio voice-over pipeline for quant trading review videos.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-12 13:19:44 +08:00
commit 5e95d3af2f
10 changed files with 1862 additions and 0 deletions
+305
View File
@@ -0,0 +1,305 @@
"""
ChatTTS 本地语音合成服务
支持从参考人声提取 Speaker Embedding 并固定音色合成配音。
"""
from __future__ import annotations
import logging
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
from scipy.io import wavfile
from config import (
OUTPUT_DIR,
SPEAKER_EMB_PATH,
SPEAKER_SAMPLE_MAX_SEC,
SPEAKER_SAMPLE_MIN_SEC,
TTS_SAMPLE_RATE,
TTS_SPEED_PROMPT,
TTS_TEMPERATURE,
TTS_TOP_K,
TTS_TOP_P,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Process-wide ChatTTS state managed by get_chattts_instance():
# `_chat` holds the loaded model once initialisation has succeeded and
# `_chat_error` holds the message of a failed initialisation attempt.
_chat = None
_chat_error: Optional[str] = None
def _load_audio_for_chattts(audio_path: str, sample_rate: int = TTS_SAMPLE_RATE) -> np.ndarray:
    """Load an audio file at ``sample_rate`` using the first available loader.

    Loaders are tried in this order, moving on whenever one raises
    ``ImportError``:

    1. ``ChatTTS.utils.load_audio``
    2. ``tools.audio.load_audio``
    3. ``librosa.load`` (mono, resampled to ``sample_rate``)

    Args:
        audio_path: Path of the audio file to read.
        sample_rate: Sample rate handed to the loader.

    Returns:
        The samples returned by whichever loader succeeded.
    """
    # First choice: the loader shipped inside the ChatTTS package.
    try:
        from ChatTTS.utils import load_audio as chattts_load_audio

        return chattts_load_audio(audio_path, sample_rate)
    except ImportError:
        pass

    # Second choice: the standalone `tools.audio` helper module.
    try:
        from tools.audio import load_audio as tools_load_audio

        return tools_load_audio(audio_path, sample_rate)
    except ImportError:
        pass

    # Last resort: librosa; the sample rate it reports back is not needed.
    import librosa

    samples, _reported_rate = librosa.load(audio_path, sr=sample_rate, mono=True)
    return samples
def _get_audio_duration_sec(audio: np.ndarray, sample_rate: int) -> float:
    """Return the duration of ``audio`` in seconds.

    The duration is ``len(audio) / sample_rate``; ``None`` or an empty
    sequence yields ``0.0``.
    """
    sample_count = 0 if audio is None else len(audio)
    if sample_count == 0:
        return 0.0
    return sample_count / float(sample_rate)
def get_chattts_instance() -> Tuple[Optional[Any], Optional[str]]:
    """Return the shared ChatTTS model, loading it on the first call.

    The outcome of the first attempt is cached in the module globals
    ``_chat`` and ``_chat_error``: a loaded model is reused by every later
    call, and a failed attempt is not retried -- its message is returned
    again.

    The model is loaded with ``compile=False`` (original author's note: kept
    off for compatibility with an 8 GB RTX 3060 Ti).

    Returns:
        ``(chat, None)`` when the model is available, otherwise
        ``(None, error_message)``.
    """
    global _chat, _chat_error

    if _chat is not None:
        return _chat, None
    if _chat_error is not None:
        return None, _chat_error

    try:
        import ChatTTS

        logger.info("正在加载 ChatTTS 模型...")
        chat = ChatTTS.Chat()

        # The loader is named `load_models` in older releases and `load` in
        # newer ones; use whichever this installation provides.
        if hasattr(chat, "load_models"):
            loaded = chat.load_models(compile=False)
        elif hasattr(chat, "load"):
            loaded = chat.load(compile=False)
        else:
            raise RuntimeError("当前 ChatTTS 版本缺少 load / load_models 方法。")

        # A loader may report failure by returning False instead of raising
        # (NOTE(review): `Chat.load` is declared `-> bool` upstream -- confirm
        # for the installed version). Such a model must not be cached as
        # usable. A loader returning None is treated as successful, as before.
        if loaded is False:
            raise RuntimeError("ChatTTS 模型文件加载未成功(load 返回 False)。")

        _chat = chat
        logger.info("ChatTTS 模型加载成功。")
        return _chat, None
    except ImportError as exc:
        _chat_error = (
            "未安装 ChatTTS,请参考 DEPLOY.md 安装。\n"
            f"原始错误: {exc}"
        )
        logger.exception("ChatTTS 导入失败")
        return None, _chat_error
    except Exception as exc:
        # Keep the traceback in the cached message so it can be shown to the
        # caller later without access to the log.
        _chat_error = f"ChatTTS 模型加载失败: {exc}\n{traceback.format_exc()}"
        logger.exception("ChatTTS 初始化异常")
        return None, _chat_error
def _encode_spk_emb(chat, tensor_or_str: Any) -> str:
    """Return a speaker embedding in the form handed to ChatTTS.

    Args:
        chat: ChatTTS instance; its ``_encode_spk_emb`` method is used when
            it exists.
        tensor_or_str: Embedding that is either already a string or still
            needs encoding.

    Returns:
        ``tensor_or_str`` itself when it is already a string, otherwise the
        result of ``chat._encode_spk_emb(tensor_or_str)``. When ``chat`` has
        no such method the value is returned unchanged, so despite the
        annotation the result is not guaranteed to be a ``str``.
    """
    needs_encoding = not isinstance(tensor_or_str, str)
    if needs_encoding and hasattr(chat, "_encode_spk_emb"):
        encode = chat._encode_spk_emb
        return encode(tensor_or_str)
    # Either already encoded, or no encoder is available: pass it through.
    return tensor_or_str
def save_fixed_speaker(
    audio_sample_path: str,
    sample_transcript: str = "",
) -> Tuple[bool, str]:
    """Extract a speaker prompt from a reference recording and save it.

    The recording is loaded at ``TTS_SAMPLE_RATE``. Recordings shorter than
    ``SPEAKER_SAMPLE_MIN_SEC`` are rejected; recordings longer than
    ``SPEAKER_SAMPLE_MAX_SEC + 5`` seconds are cut to their first
    ``SPEAKER_SAMPLE_MAX_SEC`` seconds. The result of
    ``chat.sample_audio_speaker`` is written to ``SPEAKER_EMB_PATH`` with
    ``torch.save`` as a dict with the keys ``spk_emb``, ``spk_smp``,
    ``txt_smp``, ``created_at`` and ``source_audio``.

    Args:
        audio_sample_path: Path of the reference voice recording.
        sample_transcript: Transcript of the recording (optional; ``None``
            is treated like an empty string).

    Returns:
        ``(success, message)``. Failures are reported through the message;
        exceptions raised while processing are caught and logged.
    """
    if not audio_sample_path:
        return False, "未提供音色参考音频。"

    chat, init_err = get_chattts_instance()
    if chat is None:
        return False, init_err or "ChatTTS 不可用。"

    try:
        # Callers may pass None for an empty transcript field; normalise once
        # so it does not surface as an extraction failure.
        transcript = (sample_transcript or "").strip()

        audio = _load_audio_for_chattts(audio_sample_path, TTS_SAMPLE_RATE)
        duration = _get_audio_duration_sec(audio, TTS_SAMPLE_RATE)

        if duration < SPEAKER_SAMPLE_MIN_SEC:
            return False, (
                f"参考音频过短({duration:.1f}s),建议 {SPEAKER_SAMPLE_MIN_SEC}-"
                f"{SPEAKER_SAMPLE_MAX_SEC} 秒干净人声。"
            )

        # A 5 s tolerance applies before truncating to the recommended maximum.
        if duration > SPEAKER_SAMPLE_MAX_SEC + 5:
            logger.warning("参考音频超过建议时长 %.1fs,将截取前 %ds", duration, SPEAKER_SAMPLE_MAX_SEC)
            # Slice bounds must be integers even if the limit is configured
            # as a float.
            max_samples = int(SPEAKER_SAMPLE_MAX_SEC * TTS_SAMPLE_RATE)
            audio = audio[:max_samples]

        # Derive the speaker prompt from the reference audio.
        spk_smp = chat.sample_audio_speaker(audio)
        # _encode_spk_emb() returns strings unchanged, so when
        # sample_audio_speaker() yields a string both keys hold the same value.
        spk_emb = _encode_spk_emb(chat, spk_smp)

        payload: Dict[str, Any] = {
            "spk_emb": spk_emb,
            "spk_smp": spk_smp,
            "txt_smp": transcript,
            "created_at": datetime.now().isoformat(),
            "source_audio": str(audio_sample_path),
        }

        # Make sure the target directory exists before writing.
        Path(SPEAKER_EMB_PATH).parent.mkdir(parents=True, exist_ok=True)
        torch.save(payload, SPEAKER_EMB_PATH)

        msg = (
            f"音色已锁定并保存至 {SPEAKER_EMB_PATH}\n"
            f"参考时长: {duration:.1f}s"
        )
        if not transcript:
            msg += "\n提示:填写参考音频精确转写可进一步提升音色还原度。"
        logger.info("Speaker Embedding 保存成功: %s", SPEAKER_EMB_PATH)
        return True, msg
    except Exception as exc:
        err = f"音色提取失败: {exc}\n{traceback.format_exc()}"
        logger.exception("save_fixed_speaker 失败")
        return False, err
def _load_speaker_payload() -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Read the saved speaker payload from ``SPEAKER_EMB_PATH``.

    Returns:
        ``(payload, None)`` when a payload dict is available, otherwise
        ``(None, error_message)``. A file that contains only a tensor (older
        format) is converted to a payload dict with ``spk_smp`` set to
        ``None`` and an empty ``txt_smp``.
    """
    if not SPEAKER_EMB_PATH.exists():
        missing_msg = (
            f"未找到固定音色文件 `{SPEAKER_EMB_PATH.name}`。"
            "请先在【音色锁定】模块上传 10-30 秒参考人声。"
        )
        return None, missing_msg

    try:
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects; acceptable only while this file is written by this
        # application itself.
        stored = torch.load(SPEAKER_EMB_PATH, map_location="cpu", weights_only=False)

        if isinstance(stored, dict):
            return stored, None

        if not isinstance(stored, torch.Tensor):
            return None, "speaker_emb.pt 格式无效,请重新锁定音色。"

        # Older files hold just the embedding tensor; encoding it needs the model.
        chat, err = get_chattts_instance()
        if chat is None:
            return None, err
        legacy_payload = {
            "spk_emb": _encode_spk_emb(chat, stored),
            "spk_smp": None,
            "txt_smp": "",
        }
        return legacy_payload, None
    except Exception as exc:
        return None, f"读取 speaker_emb.pt 失败: {exc}"
def speaker_is_ready() -> Tuple[bool, str]:
    """Report whether a saved speaker payload can be loaded.

    Returns:
        ``(True, message naming SPEAKER_EMB_PATH)`` when
        ``_load_speaker_payload`` yields a payload, otherwise
        ``(False, reason)``.
    """
    payload, err = _load_speaker_payload()
    if payload is not None:
        return True, f"已加载固定音色: {SPEAKER_EMB_PATH}"
    # Fall back to a generic message when the loader gave no reason.
    return False, err or "音色未配置。"
def _blank_to_none(value: Any) -> Any:
    """Map ``None`` and the empty string to ``None``; pass other values through.

    An explicit check is used instead of truthiness so that tensor or array
    values are never evaluated as booleans.
    """
    if value is None:
        return None
    if isinstance(value, str) and not value:
        return None
    return value


def _to_int16_pcm(wav: Any) -> np.ndarray:
    """Flatten a mono waveform and peak-normalise it to 16-bit PCM.

    The waveform is flattened to 1-D so that a ``(1, T)`` array is written as
    T samples of one channel. The loudest sample is scaled to 32767; an
    all-zero waveform is left at zero. An empty input yields an empty array.
    """
    samples = np.asarray(wav, dtype=np.float32).reshape(-1)
    if samples.size == 0:
        return samples.astype(np.int16)
    peak = np.max(np.abs(samples)) or 1.0
    return (samples / peak * 32767).astype(np.int16)


def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
    """Synthesise ``refined_text`` with ChatTTS and write it as a wav file.

    The saved speaker payload supplies the voice; sampling parameters come
    from the ``TTS_*`` configuration constants. The first waveform returned
    by ``chat.infer`` is peak-normalised, converted to int16 and written to
    ``OUTPUT_DIR`` at ``TTS_SAMPLE_RATE`` under a timestamped file name.

    Args:
        refined_text: The polished voice-over script.

    Returns:
        ``(success, message, output_wav_path_or_none)``. Failures are
        reported through the message; exceptions raised during synthesis are
        caught and logged.
    """
    if not refined_text or not refined_text.strip():
        return False, "合成文本为空,请先完成润色。", None

    chat, init_err = get_chattts_instance()
    if chat is None:
        return False, init_err or "ChatTTS 不可用。", None

    payload, spk_err = _load_speaker_payload()
    if payload is None:
        return False, spk_err or "请先锁定音色。", None

    try:
        import ChatTTS

        spk_emb = payload.get("spk_emb")
        spk_smp = _blank_to_none(payload.get("spk_smp"))
        txt_smp = _blank_to_none(payload.get("txt_smp", ""))

        # save_fixed_speaker() stores the string returned by
        # sample_audio_speaker() under both "spk_emb" and "spk_smp". In that
        # case "spk_emb" is only a copy of the audio prompt, not a separate
        # speaker embedding, so it is not passed on.
        # NOTE(review): the ChatTTS zero-shot example passes an audio prompt
        # via spk_smp/txt_smp only -- confirm against the installed version.
        if spk_smp is not None and isinstance(spk_emb, str) and spk_emb == spk_smp:
            spk_emb = None

        params_infer_code = ChatTTS.Chat.InferCodeParams(
            prompt=TTS_SPEED_PROMPT,
            spk_emb=spk_emb,
            spk_smp=spk_smp,
            txt_smp=txt_smp,
            temperature=TTS_TEMPERATURE,
            top_P=TTS_TOP_P,
            top_K=TTS_TOP_K,
        )
        # Restrained delivery: low oral level, no laughter.
        params_refine_text = ChatTTS.Chat.RefineTextParams(
            prompt="[oral_2][laugh_0][break_4]",
        )

        wavs = chat.infer(
            refined_text.strip(),
            skip_refine_text=False,
            params_refine_text=params_refine_text,
            params_infer_code=params_infer_code,
        )

        # `len()` works for both a list of waveforms and a batched ndarray;
        # `not wavs` would raise for an ndarray with more than one element.
        if wavs is None or len(wavs) == 0:
            return False, "ChatTTS 未生成有效音频。", None

        wav_int16 = _to_int16_pcm(wavs[0])
        if wav_int16.size == 0:
            return False, "ChatTTS 未生成有效音频。", None

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # The short random suffix keeps names unique within the same second.
        filename = f"voiceover_{timestamp}_{uuid.uuid4().hex[:6]}.wav"
        output_path = OUTPUT_DIR / filename
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        wavfile.write(str(output_path), TTS_SAMPLE_RATE, wav_int16)

        msg = f"配音合成成功: {output_path}"
        logger.info(msg)
        return True, msg, str(output_path)
    except Exception as exc:
        err = f"语音合成失败: {exc}\n{traceback.format_exc()}"
        logger.exception("generate_voice 失败")
        return False, err, None