Initial commit: add Trading Studio voice-over pipeline for quant trading review videos.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
+305
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
ChatTTS 本地语音合成服务
|
||||
支持从参考人声提取 Speaker Embedding 并固定音色合成配音。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.io import wavfile
|
||||
|
||||
from config import (
|
||||
OUTPUT_DIR,
|
||||
SPEAKER_EMB_PATH,
|
||||
SPEAKER_SAMPLE_MAX_SEC,
|
||||
SPEAKER_SAMPLE_MIN_SEC,
|
||||
TTS_SAMPLE_RATE,
|
||||
TTS_SPEED_PROMPT,
|
||||
TTS_TEMPERATURE,
|
||||
TTS_TOP_K,
|
||||
TTS_TOP_P,
|
||||
)
|
||||
|
||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Process-wide ChatTTS state managed by get_chattts_instance():
#   _chat       - the loaded ChatTTS.Chat object, or None until a load succeeds.
#   _chat_error - message of a failed load attempt; once set, later calls
#                 return it instead of trying to load the model again.
_chat = None
_chat_error: Optional[str] = None
|
||||
|
||||
|
||||
def _load_audio_for_chattts(audio_path: str, sample_rate: int = TTS_SAMPLE_RATE) -> np.ndarray:
    """
    Load an audio file at the sample rate ChatTTS expects.

    Loaders are tried in this order:

    1. ``ChatTTS.utils.load_audio``
    2. ``tools.audio.load_audio``
    3. ``librosa.load`` (forced to mono)

    An ``ImportError`` raised while importing *or* running one of the first
    two loaders moves on to the next candidate; every other exception is
    propagated to the caller.

    Args:
        audio_path: Path of the audio file to read.
        sample_rate: Target sample rate in Hz (defaults to ``TTS_SAMPLE_RATE``).

    Returns:
        The samples returned by whichever loader succeeded.
    """
    # First choice: the loader shipped inside the ChatTTS package.
    try:
        from ChatTTS.utils import load_audio as chattts_loader

        return chattts_loader(audio_path, sample_rate)
    except ImportError:
        pass

    # Second choice: the ``tools.audio`` helper module, when importable.
    try:
        from tools.audio import load_audio as tools_loader

        return tools_loader(audio_path, sample_rate)
    except ImportError:
        pass

    # Last resort: let librosa decode and resample the file.
    import librosa

    samples, _rate = librosa.load(audio_path, sr=sample_rate, mono=True)
    return samples
|
||||
|
||||
|
||||
def _get_audio_duration_sec(audio: np.ndarray, sample_rate: int) -> float:
    """
    Return the length of ``audio`` in seconds.

    Args:
        audio: Sequence of samples; ``None`` or an empty sequence counts as
            zero seconds. Only ``len(audio)`` is used, i.e. the size of the
            first axis.
        sample_rate: Samples per second.

    Returns:
        ``len(audio) / sample_rate`` as a float, or ``0.0`` when there is no
        audio.
    """
    sample_count = 0 if audio is None else len(audio)
    if sample_count == 0:
        return 0.0
    return sample_count / float(sample_rate)
|
||||
|
||||
|
||||
def get_chattts_instance() -> Tuple[Optional[Any], Optional[str]]:
    """
    Return the shared ChatTTS model, loading it on first use.

    The model object is cached in the module-level ``_chat``. If loading
    fails, the error text is cached in ``_chat_error`` and returned by every
    later call without retrying, so a failed load is only attempted once per
    process.

    ``compile=False`` is passed to the loader (the original note says this
    keeps the model usable on a 3060 Ti with 8 GB of VRAM).

    Returns:
        ``(chat, None)`` on success, or ``(None, error_message)`` when
        ChatTTS cannot be imported or the model cannot be loaded.
    """
    global _chat, _chat_error

    if _chat is not None:
        return _chat, None

    # A previous attempt already failed: report the cached error.
    if _chat_error is not None:
        return None, _chat_error

    try:
        import ChatTTS

        logger.info("正在加载 ChatTTS 模型...")
        chat = ChatTTS.Chat()

        # Support both API generations: load_models (old) / load (new).
        if hasattr(chat, "load_models"):
            loaded = chat.load_models(compile=False)
        elif hasattr(chat, "load"):
            loaded = chat.load(compile=False)
        else:
            raise RuntimeError("当前 ChatTTS 版本缺少 load / load_models 方法。")

        # Newer loaders return a bool; an explicit False means the weights
        # were not loaded. Older loaders return None, which is accepted.
        # Without this check an unusable model would be cached as "loaded".
        if loaded is False:
            raise RuntimeError(
                "ChatTTS 模型权重加载失败(load 返回 False),请检查模型文件或网络连接。"
            )

        _chat = chat
        logger.info("ChatTTS 模型加载成功。")
        return _chat, None

    except ImportError as exc:
        _chat_error = (
            "未安装 ChatTTS,请参考 DEPLOY.md 安装。\n"
            f"原始错误: {exc}"
        )
        logger.exception("ChatTTS 导入失败")
        return None, _chat_error

    except Exception as exc:
        # Boundary handler: callers get a message instead of an exception.
        _chat_error = f"ChatTTS 模型加载失败: {exc}\n{traceback.format_exc()}"
        logger.exception("ChatTTS 初始化异常")
        return None, _chat_error
|
||||
|
||||
|
||||
def _encode_spk_emb(chat, tensor_or_str: Any) -> str:
    """
    Turn a speaker embedding into the form handed to ChatTTS.

    Args:
        chat: ChatTTS model object; only its optional ``_encode_spk_emb``
            attribute is used.
        tensor_or_str: Either an already-encoded string or a raw embedding.

    Returns:
        * the input itself when it is already a string;
        * ``chat._encode_spk_emb(tensor_or_str)`` when ``chat`` provides that
          method;
        * otherwise the input unchanged (no conversion takes place, so in
          this case the result is not necessarily a ``str``).
    """
    needs_encoding = not isinstance(tensor_or_str, str)
    if needs_encoding and hasattr(chat, "_encode_spk_emb"):
        return chat._encode_spk_emb(tensor_or_str)

    # Already a string, or no encoder available: pass the value through.
    return tensor_or_str
|
||||
|
||||
|
||||
def save_fixed_speaker(
    audio_sample_path: str,
    sample_transcript: str = "",
) -> Tuple[bool, str]:
    """
    Extract a voice prompt from a clean reference recording and save it.

    The recording is loaded at ``TTS_SAMPLE_RATE``, rejected when shorter
    than ``SPEAKER_SAMPLE_MIN_SEC``, and cut to ``SPEAKER_SAMPLE_MAX_SEC``
    seconds when it exceeds that limit by more than 5 seconds. The result of
    ``chat.sample_audio_speaker`` is written with ``torch.save`` to
    ``SPEAKER_EMB_PATH`` as a dict with the keys ``spk_emb``, ``spk_smp``,
    ``txt_smp``, ``created_at`` and ``source_audio``.

    Args:
        audio_sample_path: Path of the reference voice recording
            (wav/mp3/...).
        sample_transcript: Exact transcript of the reference audio
            (optional; helps zero-shot voice reproduction).

    Returns:
        ``(success, message)``. Failures are reported through the message;
        this function does not raise.
    """
    if not audio_sample_path:
        return False, "未提供音色参考音频。"

    chat, init_err = get_chattts_instance()
    if chat is None:
        return False, init_err or "ChatTTS 不可用。"

    # Accept None as "no transcript" instead of failing on .strip().
    transcript = (sample_transcript or "").strip()

    try:
        audio = _load_audio_for_chattts(audio_sample_path, TTS_SAMPLE_RATE)
        duration = _get_audio_duration_sec(audio, TTS_SAMPLE_RATE)

        if duration < SPEAKER_SAMPLE_MIN_SEC:
            return False, (
                f"参考音频过短({duration:.1f}s),建议 {SPEAKER_SAMPLE_MIN_SEC}-"
                f"{SPEAKER_SAMPLE_MAX_SEC} 秒干净人声。"
            )
        if duration > SPEAKER_SAMPLE_MAX_SEC + 5:
            logger.warning("参考音频超过建议时长 %.1fs,将截取前 %ds", duration, SPEAKER_SAMPLE_MAX_SEC)
            # Slice bounds must be integers even if the configured limit is
            # a float.
            max_samples = int(SPEAKER_SAMPLE_MAX_SEC * TTS_SAMPLE_RATE)
            audio = audio[:max_samples]
            # Report the length of the audio that is actually used.
            duration = _get_audio_duration_sec(audio, TTS_SAMPLE_RATE)

        # Older ChatTTS releases have no audio-based speaker sampling; give
        # a clear message instead of an AttributeError traceback.
        if not hasattr(chat, "sample_audio_speaker"):
            return False, (
                "当前 ChatTTS 版本不支持 sample_audio_speaker,"
                "无法从参考音频提取音色,请升级 ChatTTS。"
            )

        # Extract the voice characteristics from the reference audio.
        spk_smp = chat.sample_audio_speaker(audio)

        # A string result is the encoded audio prompt, which belongs in
        # ``spk_smp`` only; storing it again as ``spk_emb`` would later hand
        # an audio prompt to the speaker-embedding input. A non-string
        # result is still encoded as before.
        spk_emb = None if isinstance(spk_smp, str) else _encode_spk_emb(chat, spk_smp)

        payload: Dict[str, Any] = {
            "spk_emb": spk_emb,
            "spk_smp": spk_smp,
            "txt_smp": transcript,
            "created_at": datetime.now().isoformat(),
            "source_audio": str(audio_sample_path),
        }

        # Make sure the destination directory exists before writing.
        Path(SPEAKER_EMB_PATH).parent.mkdir(parents=True, exist_ok=True)
        torch.save(payload, SPEAKER_EMB_PATH)

        msg = (
            f"音色已锁定并保存至 {SPEAKER_EMB_PATH}\n"
            f"参考时长: {duration:.1f}s"
        )
        if not transcript:
            msg += "\n提示:填写参考音频精确转写可进一步提升音色还原度。"

        logger.info("Speaker Embedding 保存成功: %s", SPEAKER_EMB_PATH)
        return True, msg

    except Exception as exc:
        # Boundary handler: the caller receives a message, not an exception.
        err = f"音色提取失败: {exc}\n{traceback.format_exc()}"
        logger.exception("save_fixed_speaker 失败")
        return False, err
|
||||
|
||||
|
||||
def _load_speaker_payload() -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """
    Read the saved speaker data from ``SPEAKER_EMB_PATH``.

    Two on-disk layouts are accepted: a dict, returned as is, and a bare
    tensor (older files), which is wrapped into a dict with ``spk_emb``
    encoded through the ChatTTS model and empty ``spk_smp`` / ``txt_smp``.

    Returns:
        ``(payload, None)`` on success, or ``(None, error_message)`` when
        the file is missing, unreadable, of an unknown layout, or (for the
        tensor layout) when the ChatTTS model is unavailable.
    """
    if not SPEAKER_EMB_PATH.exists():
        missing_msg = (
            f"未找到固定音色文件 `{SPEAKER_EMB_PATH.name}`。"
            "请先在【音色锁定】模块上传 10-30 秒参考人声。"
        )
        return None, missing_msg

    try:
        # NOTE: weights_only=False unpickles arbitrary objects, so this file
        # must only ever come from a trusted source.
        stored = torch.load(SPEAKER_EMB_PATH, map_location="cpu", weights_only=False)

        # Current layout: the dict is handed back untouched.
        if isinstance(stored, dict):
            return stored, None

        if not isinstance(stored, torch.Tensor):
            return None, "speaker_emb.pt 格式无效,请重新锁定音色。"

        # Older layout: the file holds only a tensor, which needs the model
        # to be encoded.
        chat, err = get_chattts_instance()
        if chat is None:
            return None, err

        legacy_payload = {
            "spk_emb": _encode_spk_emb(chat, stored),
            "spk_smp": None,
            "txt_smp": "",
        }
        return legacy_payload, None

    except Exception as exc:
        return None, f"读取 speaker_emb.pt 失败: {exc}"
|
||||
|
||||
|
||||
def speaker_is_ready() -> Tuple[bool, str]:
    """
    Report whether a fixed voice has been configured.

    Returns:
        ``(True, message naming SPEAKER_EMB_PATH)`` when the saved speaker
        data loads, otherwise ``(False, reason)``.
    """
    payload, err = _load_speaker_payload()
    if payload is not None:
        return True, f"已加载固定音色: {SPEAKER_EMB_PATH}"
    return False, err or "音色未配置。"
|
||||
|
||||
|
||||
def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
    """
    Synthesize the polished script into a wav voice-over with ChatTTS.

    The saved speaker data is loaded, ``chat.infer`` is run on the stripped
    text, and the first returned waveform is peak-normalized, converted to
    int16 and written to ``OUTPUT_DIR`` as
    ``voiceover_<timestamp>_<6 hex chars>.wav`` at ``TTS_SAMPLE_RATE``.

    Args:
        refined_text: Voice-over script after LLM polishing.

    Returns:
        ``(success, message, output_wav_path_or_none)``. Failures are
        reported through the message; this function does not raise.
    """
    if not refined_text or not refined_text.strip():
        return False, "合成文本为空,请先完成润色。", None

    chat, init_err = get_chattts_instance()
    if chat is None:
        return False, init_err or "ChatTTS 不可用。", None

    payload, spk_err = _load_speaker_payload()
    if payload is None:
        return False, spk_err or "请先锁定音色。", None

    try:
        import ChatTTS

        spk_emb = payload.get("spk_emb")
        spk_smp = payload.get("spk_smp")
        txt_smp = payload.get("txt_smp", "")

        # Treat None / "" as "no audio prompt" without truth-testing other
        # objects (a multi-element tensor raises on bool()).
        if spk_smp is None or (isinstance(spk_smp, str) and not spk_smp):
            spk_smp = None

        # Speaker files may carry the audio-prompt string under both keys.
        # An audio prompt is not a speaker embedding, so in that case only
        # pass it as ``spk_smp``.
        if isinstance(spk_emb, str) and isinstance(spk_smp, str) and spk_emb == spk_smp:
            spk_emb = None

        params_infer_code = ChatTTS.Chat.InferCodeParams(
            prompt=TTS_SPEED_PROMPT,
            spk_emb=spk_emb,
            spk_smp=spk_smp,
            txt_smp=txt_smp if txt_smp else None,
            temperature=TTS_TEMPERATURE,
            top_P=TTS_TOP_P,
            top_K=TTS_TOP_K,
        )

        # Reserved, restrained tone: keep the oral intensity low.
        params_refine_text = ChatTTS.Chat.RefineTextParams(
            prompt="[oral_2][laugh_0][break_4]",
        )

        wavs = chat.infer(
            refined_text.strip(),
            skip_refine_text=False,
            params_refine_text=params_refine_text,
            params_infer_code=params_infer_code,
        )

        # ``not wavs`` raises for a NumPy array with more than one element,
        # so test for None / zero length explicitly.
        if wavs is None or len(wavs) == 0:
            return False, "ChatTTS 未生成有效音频。", None

        # Flatten to 1-D: a (1, N) waveform would otherwise be written as a
        # single frame with N channels.
        wav_array = np.asarray(wavs[0], dtype=np.float32).reshape(-1)
        if wav_array.size == 0:
            return False, "ChatTTS 未生成有效音频。", None

        # Peak-normalize and convert to int16; silence keeps a divisor of 1.
        peak = float(np.max(np.abs(wav_array))) or 1.0
        wav_int16 = (wav_array / peak * 32767).astype(np.int16)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"voiceover_{timestamp}_{uuid.uuid4().hex[:6]}.wav"
        # Make sure the output directory exists before writing.
        output_dir = Path(OUTPUT_DIR)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / filename

        wavfile.write(str(output_path), TTS_SAMPLE_RATE, wav_int16)

        msg = f"配音合成成功: {output_path}"
        logger.info(msg)
        return True, msg, str(output_path)

    except Exception as exc:
        # Boundary handler: the caller receives a message, not an exception.
        err = f"语音合成失败: {exc}\n{traceback.format_exc()}"
        logger.exception("generate_voice 失败")
        return False, err, None
|
||||
Reference in New Issue
Block a user