From 5e95d3af2f9305b255b9dc3a8aa988e4a6f97f2b Mon Sep 17 00:00:00 2001 From: dekun Date: Fri, 12 Jun 2026 13:19:44 +0800 Subject: [PATCH] Initial commit: add Trading Studio voice-over pipeline for quant trading review videos. Co-authored-by: Cursor --- .gitignore | 46 +++++ DEPLOY.md | 488 ++++++++++++++++++++++++++++++++++++++++++++ README.md | 197 ++++++++++++++++++ app.py | 378 ++++++++++++++++++++++++++++++++++ config.py | 82 ++++++++ ecosystem.config.js | 26 +++ llm_service.py | 162 +++++++++++++++ requirements.txt | 23 +++ tts_service.py | 305 +++++++++++++++++++++++++++ whisper_service.py | 155 ++++++++++++++ 10 files changed, 1862 insertions(+) create mode 100644 .gitignore create mode 100644 DEPLOY.md create mode 100644 README.md create mode 100644 app.py create mode 100644 config.py create mode 100644 ecosystem.config.js create mode 100644 llm_service.py create mode 100644 requirements.txt create mode 100644 tts_service.py create mode 100644 whisper_service.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f71b090 --- /dev/null +++ b/.gitignore @@ -0,0 +1,46 @@ +# Python 虚拟环境 +venv/ +.venv/ +env/ + +# 模型权重与音色文件(体积大,不入库) +*.pt +*.pth +*.onnx +*.bin +*.safetensors + +# 音频产物 +*.wav +*.mp3 +*.flac +*.ogg +*.m4a + +# 日志 +*.log + +# 运行时目录 +uploads/ +outputs/ +__pycache__/ +*.py[cod] +*$py.class +.Python + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# 系统文件 +.DS_Store +Thumbs.db + +# 环境变量与密钥 +.env +.env.* + +# Gradio 临时 +gradio_cached_examples/ diff --git a/DEPLOY.md b/DEPLOY.md new file mode 100644 index 0000000..fbc07a8 --- /dev/null +++ b/DEPLOY.md @@ -0,0 +1,488 @@ +# Trading Studio 部署指南 (DEPLOY.md) + +本文档面向 **Ubuntu 物理服务器**(搭载 RTX 3060 Ti,已锁定 120W 功耗墙)的完整环境配置与 PM2 常驻部署流程。适用于首次安装或迁移重装场景。 + +**Git 仓库:** https://git.bz121.com/dekun/Trading_Studio.git + +--- + +## 目录 + +1. [硬件与系统前提](#1-硬件与系统前提) +2. [3060 Ti 120W 功耗墙配置](#2-3060-ti-120w-功耗墙配置) +3. [NVIDIA 驱动与 CUDA](#3-nvidia-驱动与-cuda) +4. [Python 虚拟环境](#4-python-虚拟环境) +5. [PyTorch CUDA 12.1 安装](#5-pytorch-cuda-121-安装) +6. [项目依赖安装](#6-项目依赖安装) +7. [远程 Ollama 节点配置](#7-远程-ollama-节点配置) +8. [首次运行与验证](#8-首次运行与验证) +9. [PM2 进程守护](#9-pm2-进程守护) +10. [迁移与故障排查](#10-迁移与故障排查) + +--- + +## 1. 硬件与系统前提 + +| 项目 | 要求 | +|------|------| +| GPU | NVIDIA RTX 3060 Ti 8GB | +| 功耗墙 | 120W(推荐锁定,见下文) | +| 系统 | Ubuntu 22.04 / 24.04 LTS | +| 内存 | ≥ 16GB | +| 磁盘 | ≥ 30GB 可用(含模型缓存) | +| 网络 | 局域网可访问 `192.168.8.64:11434` | + +```bash +# 基础工具 +sudo apt update && sudo apt upgrade -y +sudo apt install -y git curl wget build-essential \ + python3 python3-venv python3-dev \ + ffmpeg libsndfile1 portaudio19-dev +``` + +> 若 `python3-venv` 包名报错,使用 `python3-venv`。 + +--- + +## 2. 3060 Ti 120W 功耗墙配置 + +锁定 GPU 功耗有助于稳定 7×24 运行、降低散热压力,避免 Whisper + ChatTTS 并发时触发功耗波动。 + +### 2.1 安装 nvidia-smi 功耗管理工具 + +驱动安装后自带 `nvidia-smi`。确认 GPU 可见: + +```bash +nvidia-smi +``` + +### 2.2 临时设置 120W 功耗上限 + +```bash +# 查看支持的功耗范围 +nvidia-smi -q -d POWER | grep -A3 "Power Limit" + +# 设置最大功耗为 120W(需 root) +sudo nvidia-smi -pl 120 +``` + +### 2.3 开机持久化(推荐) + +创建 systemd 服务,每次启动自动应用: + +```bash +sudo tee /etc/systemd/system/nvidia-powerlimit.service << 'EOF' +[Unit] +Description=Set NVIDIA GPU Power Limit to 120W +After=multi-user.target + +[Service] +Type=oneshot +ExecStart=/usr/bin/nvidia-smi -pl 120 +RemainAfterExit=yes + +[Install] +WantedBy=multi-user.target +EOF + +sudo systemctl daemon-reload +sudo systemctl enable nvidia-powerlimit.service +sudo systemctl start nvidia-powerlimit.service + +# 验证 +nvidia-smi --query-gpu=power.limit --format=csv +``` + +--- + +## 3. NVIDIA 驱动与 CUDA + +### 3.1 安装驱动(推荐 535+ 或 550+) + +```bash +# Ubuntu 自动安装推荐驱动 +sudo ubuntu-drivers devices +sudo ubuntu-drivers autoinstall +# 或指定版本: sudo apt install nvidia-driver-550 + +sudo reboot +``` + +重启后验证: + +```bash +nvidia-smi +nvcc --version # 若未安装 nvcc 不影响 PyTorch,可选 +``` + +### 3.2 cuDNN(Faster-Whisper / PyTorch 需要) + +PyTorch cu121 wheel 通常自带运行时库。若 Whisper 报 cuDNN 错误: + +```bash +# 参考 NVIDIA 官方文档安装 cuDNN for CUDA 12.x +# https://developer.nvidia.com/cudnn +``` + +--- + +## 4. Python 虚拟环境 + +```bash +# 克隆项目 +cd ~ +git clone https://git.bz121.com/dekun/Trading_Studio.git +cd Trading_Studio + +# 创建虚拟环境(必须使用 venv,与 PM2 interpreter 路径一致) +python3 -m venv venv + +# 激活 +source venv/bin/activate + +# 升级 pip +pip install --upgrade pip setuptools wheel +``` + +**重要:** PM2 配置中 `interpreter` 指向 `./venv/bin/python`,请确保在项目根目录创建 `venv/`。 + +--- + +## 5. PyTorch CUDA 12.1 安装 + +**必须先于其他 GPU 依赖安装**,避免 pip 拉取 CPU 版 torch。 + +```bash +source venv/bin/activate + +pip install torch torchvision torchaudio \ + --index-url https://download.pytorch.org/whl/cu121 +``` + +验证 CUDA 可用: + +```bash +python -c " +import torch +print('PyTorch:', torch.__version__) +print('CUDA available:', torch.cuda.is_available()) +print('GPU:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A') +" +``` + +期望输出类似: + +``` +PyTorch: 2.x.x+cu121 +CUDA available: True +GPU: NVIDIA GeForce RTX 3060 Ti +``` + +--- + +## 6. 项目依赖安装 + +```bash +source venv/bin/activate +cd ~/Trading_Studio + +# 安装其余依赖 +pip install -r requirements.txt +``` + +### 6.1 Faster-Whisper + +随 `requirements.txt` 安装。首次运行会自动下载 `small` 模型(约 500MB)至 HuggingFace 缓存。 + +### 6.2 ChatTTS + +从 GitHub 源码安装(已在 requirements.txt 中指定): + +```bash +pip install ChatTTS @ git+https://github.com/2noise/ChatTTS.git +``` + +首次 `save_fixed_speaker` 或 `generate_voice` 时会下载模型权重(数 GB),请确保网络畅通或提前配置 HuggingFace 镜像: + +```bash +export HF_ENDPOINT=https://hf-mirror.com # 可选,国内加速 +``` + +### 6.3 Gradio + +```bash +pip install gradio>=4.44.0 +``` + +--- + +## 7. 远程 Ollama 节点配置 + +Trading Studio 的 LLM 润色模块连接局域网 Ollama,**不在本机运行大模型**。 + +| 配置项 | 值 | +|--------|-----| +| 地址 | `http://192.168.8.64:11434` | +| API | `POST /api/chat` | +| 模型 | `huihui_ai/gemma-4-abliterated:e4b` | +| 流式 | `stream: false` | + +### 7.1 在 Ollama 节点(192.168.8.64)上 + +```bash +# 安装 Ollama(若未安装) +curl -fsSL https://ollama.com/install.sh | sh + +# 拉取模型 +ollama pull huihui_ai/gemma-4-abliterated:e4b + +# 允许局域网访问(编辑 systemd 或环境变量) +sudo systemctl edit ollama +``` + +添加: + +```ini +[Service] +Environment="OLLAMA_HOST=0.0.0.0:11434" +``` + +```bash +sudo systemctl daemon-reload +sudo systemctl restart ollama +``` + +### 7.2 在本机(Trading Studio 服务器)验证 + +```bash +curl http://192.168.8.64:11434/api/tags + +curl http://192.168.8.64:11434/api/chat -d '{ + "model": "huihui_ai/gemma-4-abliterated:e4b", + "messages": [{"role": "user", "content": "ping"}], + "stream": false +}' +``` + +--- + +## 8. 首次运行与验证 + +```bash +source venv/bin/activate +cd ~/Trading_Studio + +# 前台启动(调试) +python app.py +``` + +浏览器访问: + +``` +http://<本机局域网IP>:5683 +``` + +### 8.1 验证清单 + +- [ ] 页面加载,Ollama 状态显示在线 +- [ ] 上传 10-30s 参考人声 → 音色锁定成功,生成 `speaker_emb.pt` +- [ ] 上传复盘录音 → Whisper 识别出中文文本 +- [ ] 点击润色 → 返回 Gemma4 处理后的文稿 +- [ ] 点击合成 → `outputs/` 下生成 24kHz wav + +### 8.2 日志位置 + +- 应用日志:`trading_studio.log`(项目根目录) +- PM2 日志:`logs/pm2-out.log`、`logs/pm2-error.log` + +```bash +mkdir -p logs +``` + +--- + +## 9. PM2 进程守护 + +Trading Studio 原生支持 PM2 常驻管理,确保 Gradio 服务崩溃后自动重启、开机自启。 + +### 9.1 安装 Node.js 与 PM2 + +```bash +# 安装 Node.js 20 LTS +curl -fsSL https://deb.nodesource.com/setup_20.x | sudo -E bash - +sudo apt install -y nodejs + +# 全局安装 PM2 +sudo npm install -g pm2 +``` + +### 9.2 方式 A:使用 ecosystem.config.js(推荐) + +项目已内置 `ecosystem.config.js`: + +```javascript +module.exports = { + apps: [{ + name: "trading_studio", + script: "app.py", + interpreter: "./venv/bin/python", + cwd: __dirname, + instances: 1, + autorestart: true, + max_memory_restart: "6G", + env: { + PYTHONUNBUFFERED: "1", + CUDA_VISIBLE_DEVICES: "0", + }, + }], +}; +``` + +启动: + +```bash +cd ~/Trading_Studio +mkdir -p logs + +pm2 start ecosystem.config.js +pm2 status +pm2 logs trading_studio --lines 50 +``` + +### 9.3 方式 B:直接命令行 + +```bash +cd ~/Trading_Studio + +pm2 start app.py \ + --name "trading_studio" \ + --interpreter ./venv/bin/python + +pm2 save +``` + +### 9.4 开机自启 + +```bash +pm2 startup +# 按提示执行输出的 sudo 命令 + +pm2 save +``` + +### 9.5 常用运维命令 + +```bash +pm2 restart trading_studio # 重启(改代码后) +pm2 stop trading_studio # 停止 +pm2 delete trading_studio # 移除 +pm2 monit # 实时监控 CPU/内存 +``` + +### 9.6 更新代码后重新部署 + +```bash +cd ~/Trading_Studio +git pull +source venv/bin/activate +pip install -r requirements.txt # 若有新依赖 +pm2 restart trading_studio +``` + +--- + +## 10. 迁移与故障排查 + +### 10.1 迁移到新机器 + +1. 复制 `speaker_emb.pt`(音色文件,在 `.gitignore` 中,需手动备份) +2. 新机器按本文档完整部署 +3. 将 `speaker_emb.pt` 放回项目根目录 +4. `pm2 restart trading_studio` + +### 10.2 CUDA / 显存问题 + +```bash +# 查看显存占用 +nvidia-smi + +# 若 OOM,确保无其他 GPU 进程 +fuser -v /dev/nvidia* +``` + +Whisper 与 ChatTTS 不会同时常驻最大显存,但首次加载模型时峰值较高。建议: + +- 锁定 120W 功耗墙 +- `max_memory_restart: "6G"` 已在 PM2 配置中设置 + +### 10.3 Whisper CUDA 报错 + +``` +错误: CUDA initialization failed / out of memory +``` + +处理: + +1. 重启 PM2 进程释放显存 +2. 确认 `compute_type="float16"`(已在 config.py 配置) +3. 降级模型为 `base`(修改 `config.py` 中 `WHISPER_MODEL_SIZE`) + +### 10.4 Ollama 超时 + +``` +连接 Ollama 超时(>60s) +``` + +处理: + +1. 确认 Ollama 节点模型已预加载:`ollama run huihui_ai/gemma-4-abliterated:e4b` +2. 增大 `config.py` 中 `OLLAMA_TIMEOUT` +3. 检查防火墙:`sudo ufw allow from 192.168.8.0/24 to any port 11434`(在 Ollama 节点) + +### 10.5 ChatTTS 音色文件损坏 + +```bash +rm speaker_emb.pt +# 重新在 Web UI「音色锁定」上传参考人声 +``` + +### 10.6 端口 5683 被占用 + +```bash +sudo lsof -i :5683 +# 或 +ss -tlnp | grep 5683 +``` + +--- + +## 附录:防火墙(本机 Gradio) + +若需局域网其他设备访问 Web UI: + +```bash +sudo ufw allow 5683/tcp +sudo ufw reload +``` + +访问地址:`http://<服务器局域网IP>:5683` + +--- + +## 附录:config.py 关键常量速查 + +```python +HOST = "0.0.0.0" +PORT = 5683 +OLLAMA_URL = "http://192.168.8.64:11434/api/chat" +MODEL_NAME = "huihui_ai/gemma-4-abliterated:e4b" +WHISPER_MODEL_SIZE = "small" +WHISPER_DEVICE = "cuda" +WHISPER_COMPUTE_TYPE = "float16" +SPEAKER_EMB_PATH = "speaker_emb.pt" +TTS_SAMPLE_RATE = 24000 +``` + +--- + +**部署完成后,请先在「音色锁定」模块完成首次音色提取,再进行日常复盘配音生产。** diff --git a/README.md b/README.md new file mode 100644 index 0000000..e067280 --- /dev/null +++ b/README.md @@ -0,0 +1,197 @@ +# Trading Studio + +**本地量化交易复盘 → B 站长视频配音生产流水线** + +Trading Studio 是一套运行在 Ubuntu 物理服务器(RTX 3060 Ti)上的自动化配音系统,专为数字资产量化交易员设计。通过「盲录碎碎念 → 本地 GPU 识别 → 局域网大模型严厉润色 → 本地 GPU 声音克隆」的闭环,高效产出 B 站反思类长视频配音,辅助交易纪律的自我进化。 + +**Git 仓库:** https://git.bz121.com/dekun/Trading_Studio.git + +--- + +## 系统定位 + +| 环节 | 技术栈 | 运行位置 | +|------|--------|----------| +| 碎碎念录音转写 | Faster-Whisper (CUDA float16) | 本地 3060 Ti | +| 纪律审判式润色 | Gemma4 Abliterated @ Ollama | 局域网 `192.168.8.64` | +| 固定音色配音 | ChatTTS (CUDA) | 本地 3060 Ti | +| Web 中控 | Gradio | 端口 **5683** | + +--- + +## 架构说明 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Gradio 中控 (app.py:5683) │ +├──────────────┬──────────────────────┬───────────────────────┤ +│ 音色锁定 │ 音频识别 │ 润色 + 合成 │ +│ tts_service │ whisper_service │ llm_service │ +│ │ │ tts_service │ +└──────┬───────┴──────────┬───────────┴──────────┬────────────┘ + │ │ │ + ▼ ▼ ▼ + speaker_emb.pt Faster-Whisper Ollama HTTP + (本地持久化) CUDA / small 192.168.8.64:11434 + gemma-4-abliterated +``` + +### 模块解耦 + +| 文件 | 职责 | +|------|------| +| `config.py` | Ollama 地址、模型名、Prompt、路径等全局配置 | +| `whisper_service.py` | Faster-Whisper CUDA 转写 | +| `llm_service.py` | 远程 Ollama HTTP 非流式润色 | +| `tts_service.py` | ChatTTS 音色提取与 wav 合成 | +| `app.py` | Gradio 前端与流程编排 | + +--- + +## 快速开始 + +> 完整环境配置请参阅 [DEPLOY.md](./DEPLOY.md) + +```bash +# 1. 克隆仓库 +git clone https://git.bz121.com/dekun/Trading_Studio.git +cd Trading_Studio + +# 2. 创建虚拟环境并安装依赖(详见 DEPLOY.md) +python3 -m venv venv +source venv/bin/activate +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 +pip install -r requirements.txt + +# 3. 启动中控 +python app.py +# 浏览器访问: http://<服务器IP>:5683 +``` + +--- + +## 使用流程 + +### 首次使用:锁定音色 + +1. 进入 **「音色锁定」** 标签页 +2. 上传 10-30 秒干净人声参考(你的碎碎念盲录样本) +3. (可选)填写参考音频的精确转写,提升 zero-shot 还原度 +4. 点击 **锁定音色** → 生成 `speaker_emb.pt` + +### 日常生产 + +**方式 A — 分步操作:** + +1. **音频极速识别**:上传复盘录音 → Whisper 转写 +2. **Gemma4 纪律审判**:一键润色,生成逻辑清晰、语气严厉的反思稿 +3. **ChatTTS 合成**:输出 24kHz `.wav` 成品配音 + +**方式 B — 一键生产:** + +上传录音后点击 **启动全流程**,系统自动串联三步。 + +--- + +## 核心配置(config.py) + +| 配置项 | 默认值 | +|--------|--------| +| 中控端口 | `5683`(`0.0.0.0` 局域网可访问) | +| Ollama 地址 | `http://192.168.8.64:11434` | +| 模型名称 | `huihui_ai/gemma-4-abliterated:e4b` | +| Whisper 模型 | `small` / CUDA / float16 | +| 音色文件 | `speaker_emb.pt` | +| 音频输出 | `outputs/` 目录 | + +--- + +## PM2 守护运行 + +```bash +# 方式 1:ecosystem 配置 +pm2 start ecosystem.config.js + +# 方式 2:直接命令 +pm2 start app.py --name "trading_studio" --interpreter ./venv/bin/python + +# 常用管理 +pm2 status +pm2 logs trading_studio +pm2 restart trading_studio +pm2 save && pm2 startup # 开机自启 +``` + +--- + +## .gitignore 配置 + +提交 Git 时请确保忽略以下产物(已在 `.gitignore` 中预设): + +```gitignore +venv/ +*.wav +*.pt +*.log +uploads/ +outputs/ +``` + +**说明:** + +- `venv/` — Python 虚拟环境,每台机器独立创建 +- `*.wav` — 录音与合成音频,体积大且含隐私 +- `*.pt` — ChatTTS 音色 Embedding 与模型权重 +- `*.log` — 运行日志 + +--- + +## 目录结构 + +``` +Trading_Studio/ +├── app.py # Gradio 主入口 +├── config.py # 全局配置 +├── whisper_service.py # Whisper CUDA 识别 +├── llm_service.py # Ollama 远程润色 +├── tts_service.py # ChatTTS 音色与合成 +├── ecosystem.config.js # PM2 守护配置 +├── requirements.txt # Python 依赖 +├── README.md # 本文件 +├── DEPLOY.md # 部署指南 +├── .gitignore +├── speaker_emb.pt # 音色文件(运行时生成,不入库) +├── uploads/ # 上传临时目录 +└── outputs/ # 合成 wav 输出 +``` + +--- + +## 硬件要求 + +- **GPU:** NVIDIA RTX 3060 Ti(8GB 显存,建议锁定 120W 功耗墙) +- **系统:** Ubuntu 22.04 / 24.04 LTS +- **CUDA:** 12.1+(与 PyTorch cu121 匹配) +- **局域网:** 可访问 `192.168.8.64:11434` 的 Ollama 节点 + +--- + +## 常见问题 + +**Q: Whisper 报 CUDA 错误?** +A: 确认 `nvidia-smi` 正常,且未同时运行其他占显存任务。Whisper 使用 `float16` 已针对 8GB 优化。 + +**Q: Ollama 连接失败?** +A: 在服务器上执行 `curl http://192.168.8.64:11434/api/tags` 验证连通性,确认模型已 `ollama pull`。 + +**Q: TTS 音色不稳定?** +A: 重新锁定音色,填写参考音频精确转写,并保持 `temperature=0.3` 低随机性。 + +**Q: 合成音频为空或噪声?** +A: 检查润色文本长度(过短可能导致异常),确认 `speaker_emb.pt` 存在且有效。 + +--- + +## License + +Private — 仅供个人量化交易复盘使用。 diff --git a/app.py b/app.py new file mode 100644 index 0000000..8cc182a --- /dev/null +++ b/app.py @@ -0,0 +1,378 @@ +""" +Trading Studio — 自动化交易复盘视频配音系统 +Gradio Web 中控:音色锁定 → Whisper 识别 → Gemma4 润色 → ChatTTS 合成 +""" + +from __future__ import annotations + +import logging +import shutil +import sys +import uuid +from pathlib import Path + +import gradio as gr + +from config import ( + GIT_REPO_URL, + HOST, + MODEL_NAME, + OLLAMA_URL, + PORT, + SPEAKER_EMB_PATH, + UPLOAD_DIR, +) +from llm_service import check_ollama_health, polish_text +from tts_service import generate_voice, save_fixed_speaker, speaker_is_ready +from whisper_service import transcribe_audio + +# --------------------------------------------------------------------------- +# 日志 +# --------------------------------------------------------------------------- +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + handlers=[ + logging.StreamHandler(sys.stdout), + logging.FileHandler("trading_studio.log", encoding="utf-8"), + ], +) +logger = logging.getLogger("trading_studio") + +# --------------------------------------------------------------------------- +# 全局 UI 状态(Gradio State) +# --------------------------------------------------------------------------- +# raw_transcript / polished_script 在流水线中传递 + + +def _save_upload(upload_file) -> str | None: + """将 Gradio 上传文件复制到本地 uploads 目录,返回持久化路径。""" + if upload_file is None: + return None + + src = Path(upload_file) + if not src.exists(): + return None + + dest = UPLOAD_DIR / f"{uuid.uuid4().hex}_{src.name}" + shutil.copy2(src, dest) + return str(dest) + + +# --------------------------------------------------------------------------- +# 模块 1:音色锁定 +# --------------------------------------------------------------------------- +def ui_lock_speaker(audio_file, sample_transcript: str) -> str: + """【音色锁定】从参考人声提取并保存 Speaker Embedding。""" + path = _save_upload(audio_file) + if not path: + return "请上传 10-30 秒干净参考人声(wav/mp3 均可)。" + + ok, msg = save_fixed_speaker(path, sample_transcript or "") + return msg if ok else f"❌ {msg}" + + +def ui_speaker_status() -> str: + """刷新音色状态。""" + ok, msg = speaker_is_ready() + return f"✅ {msg}" if ok else f"⚠️ {msg}" + + +# --------------------------------------------------------------------------- +# 模块 2:音频极速识别 +# --------------------------------------------------------------------------- +def ui_transcribe(audio_file) -> tuple[str, str]: + """【Whisper 识别】返回 (转写文本, 状态日志)。""" + path = _save_upload(audio_file) + if not path: + return "", "请上传待识别的碎碎念录音。" + + ok, result = transcribe_audio(path) + if ok: + return result, f"✅ 识别完成,共 {len(result)} 字。" + return "", f"❌ {result}" + + +# --------------------------------------------------------------------------- +# 模块 3:Gemma4 纪律审判 +# --------------------------------------------------------------------------- +def ui_polish(raw_text: str) -> tuple[str, str]: + """【LLM 润色】对转写稿进行严厉自我反思式润色。""" + if not raw_text or not raw_text.strip(): + return "", "请先完成语音识别或手动粘贴转写文本。" + + ok, result = polish_text(raw_text) + if ok: + return result, f"✅ Gemma4 润色完成,共 {len(result)} 字。" + return "", f"❌ {result}" + + +def ui_check_ollama() -> str: + """检测远程 Ollama 节点状态。""" + ok, msg = check_ollama_health() + return f"✅ {msg}" if ok else f"❌ {msg}" + + +# --------------------------------------------------------------------------- +# 模块 4:ChatTTS 音频合成 +# --------------------------------------------------------------------------- +def ui_synthesize(polished_text: str) -> tuple[str | None, str]: + """【TTS 合成】生成最终 wav 配音文件。""" + if not polished_text or not polished_text.strip(): + return None, "请先完成 Gemma4 润色。" + + ok, msg, wav_path = generate_voice(polished_text) + if ok and wav_path: + return wav_path, f"✅ {msg}" + return None, f"❌ {msg}" + + +# --------------------------------------------------------------------------- +# 一键流水线 +# --------------------------------------------------------------------------- +def ui_full_pipeline( + audio_file, + skip_polish: bool, + manual_raw: str, +) -> tuple[str, str, str | None, str]: + """ + 串联执行:识别 → 润色(可跳过)→ 合成。 + 返回 (raw, polished, wav_path, log) + """ + logs: list[str] = [] + + # Step 1: 识别 + if manual_raw and manual_raw.strip(): + raw = manual_raw.strip() + logs.append(f"使用手动输入转写稿({len(raw)} 字)。") + else: + path = _save_upload(audio_file) + if not path: + return "", "", None, "❌ 请上传录音或手动填写转写文本。" + ok, result = transcribe_audio(path) + if not ok: + return "", "", None, f"❌ 识别失败: {result}" + raw = result + logs.append(f"✅ Whisper 识别完成({len(raw)} 字)。") + + # Step 2: 润色 + if skip_polish: + polished = raw + logs.append("已跳过 LLM 润色,直接使用原文。") + else: + ok, result = polish_text(raw) + if not ok: + return raw, "", None, f"❌ 润色失败: {result}\n" + "\n".join(logs) + polished = result + logs.append(f"✅ Gemma4 润色完成({len(polished)} 字)。") + + # Step 3: 合成 + ok, msg, wav_path = generate_voice(polished) + if not ok: + return raw, polished, None, f"❌ 合成失败: {msg}\n" + "\n".join(logs) + + logs.append(f"✅ {msg}") + return raw, polished, wav_path, "\n".join(logs) + + +# --------------------------------------------------------------------------- +# Gradio 界面 +# --------------------------------------------------------------------------- +CUSTOM_CSS = """ +/* 硬核暗黑科技风 */ +.gradio-container { + background: linear-gradient(160deg, #0a0a0f 0%, #12121a 40%, #0d0d12 100%) !important; + color: #c8c8d0 !important; +} +.dark-panel { + border: 1px solid #2a2a35; + border-radius: 8px; + padding: 16px; + background: rgba(18, 18, 26, 0.85); + margin-bottom: 12px; +} +h1, h2, h3 { color: #e8e8f0 !important; letter-spacing: 0.05em; } +.status-bar { + font-family: 'Consolas', 'Monaco', monospace; + font-size: 0.85em; + color: #7a7a90; +} +footer { visibility: hidden; } +""" + + +def build_app() -> gr.Blocks: + """构建 Gradio 主界面。""" + theme = gr.themes.Monochrome( + primary_hue="slate", + secondary_hue="gray", + neutral_hue="slate", + font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"], + font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "Consolas", "monospace"], + ).set( + body_background_fill="#0a0a0f", + body_background_fill_dark="#0a0a0f", + block_background_fill="#12121a", + block_background_fill_dark="#12121a", + block_border_color="#2a2a35", + block_label_text_color="#9090a0", + input_background_fill="#1a1a24", + button_primary_background_fill="#3a3a50", + button_primary_background_fill_hover="#4a4a60", + ) + + with gr.Blocks( + title="Trading Studio | 交易复盘配音中控", + theme=theme, + css=CUSTOM_CSS, + ) as demo: + gr.Markdown( + f""" +# ⚡ Trading Studio +**本地量化交易复盘 → B 站配音生产流水线** + +`Whisper(CUDA)` → `Gemma4 @ {OLLAMA_URL}` → `ChatTTS(CUDA)` + +> 仓库: [{GIT_REPO_URL}]({GIT_REPO_URL}) + """, + elem_classes=["dark-panel"], + ) + + with gr.Row(): + ollama_status = gr.Textbox( + label="Ollama 节点", + value=f"模型: {MODEL_NAME}", + interactive=False, + scale=3, + elem_classes=["status-bar"], + ) + speaker_status = gr.Textbox( + label="音色状态", + value="检测中...", + interactive=False, + scale=2, + elem_classes=["status-bar"], + ) + refresh_btn = gr.Button("🔄 刷新状态", scale=1) + + refresh_btn.click( + fn=lambda: (ui_check_ollama(), ui_speaker_status()), + outputs=[ollama_status, speaker_status], + ) + + with gr.Tabs(): + # ---- Tab 1: 音色锁定 ---- + with gr.Tab("🎙️ 音色锁定"): + gr.Markdown( + "上传 **10-30 秒** 干净人声样本,系统将提取 Speaker Embedding " + f"并保存至 `{SPEAKER_EMB_PATH.name}`,后续合成 100% 还原音色。" + ) + with gr.Row(): + spk_audio = gr.Audio( + label="参考人声(碎碎念盲录样本)", + type="filepath", + sources=["upload", "microphone"], + ) + spk_transcript = gr.Textbox( + label="参考音频精确转写(可选,提升还原度)", + placeholder="尽量与参考音频内容完全一致...", + lines=6, + ) + lock_btn = gr.Button("🔒 锁定音色", variant="primary") + lock_log = gr.Textbox(label="锁定结果", lines=4, interactive=False) + lock_btn.click(ui_lock_speaker, [spk_audio, spk_transcript], lock_log) + + # ---- Tab 2: 分步操作 ---- + with gr.Tab("🔧 分步流水线"): + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown("### Step 1 · 音频极速识别") + rec_audio = gr.Audio( + label="交易复盘碎碎念录音", + type="filepath", + sources=["upload", "microphone"], + ) + transcribe_btn = gr.Button("⚡ Faster-Whisper 识别", variant="primary") + transcribe_log = gr.Textbox(label="识别日志", lines=2, interactive=False) + + with gr.Column(scale=1): + gr.Markdown("### Step 2 · Gemma4 纪律审判") + raw_text = gr.Textbox( + label="转写原文(可编辑)", + lines=10, + placeholder="识别结果将显示在此,也可手动粘贴...", + ) + polish_btn = gr.Button("⚖️ 远程 Gemma4 严厉润色", variant="primary") + polish_log = gr.Textbox(label="润色日志", lines=2, interactive=False) + + with gr.Column(scale=1): + gr.Markdown("### Step 3 · ChatTTS 配音合成") + polished_text = gr.Textbox( + label="润色配音稿(可编辑)", + lines=10, + placeholder="润色结果将显示在此...", + ) + synth_btn = gr.Button("🔊 合成配音 WAV", variant="primary") + synth_log = gr.Textbox(label="合成日志", lines=2, interactive=False) + output_audio = gr.Audio(label="成品配音", type="filepath") + + transcribe_btn.click(ui_transcribe, rec_audio, [raw_text, transcribe_log]) + polish_btn.click(ui_polish, raw_text, [polished_text, polish_log]) + synth_btn.click(ui_synthesize, polished_text, [output_audio, synth_log]) + + # ---- Tab 3: 一键生产 ---- + with gr.Tab("🚀 一键生产"): + gr.Markdown( + "上传碎碎念录音,系统自动完成 **识别 → 润色 → 合成** 全流程。" + ) + with gr.Row(): + pipe_audio = gr.Audio( + label="复盘录音", + type="filepath", + sources=["upload", "microphone"], + ) + pipe_manual = gr.Textbox( + label="或手动输入转写(跳过识别)", + lines=4, + placeholder="若已有转写文本,可直接粘贴,留空则走 Whisper 识别", + ) + skip_polish_cb = gr.Checkbox( + label="跳过 Gemma4 润色(仅测试 TTS)", + value=False, + ) + pipeline_btn = gr.Button("▶ 启动全流程", variant="primary", size="lg") + pipeline_log = gr.Textbox(label="流水线日志", lines=6, interactive=False) + with gr.Row(): + pipe_raw = gr.Textbox(label="转写原文", lines=6) + pipe_polished = gr.Textbox(label="润色稿", lines=6) + pipe_output = gr.Audio(label="成品配音", type="filepath") + + pipeline_btn.click( + ui_full_pipeline, + [pipe_audio, skip_polish_cb, pipe_manual], + [pipe_raw, pipe_polished, pipe_output, pipeline_log], + ) + + demo.load( + fn=lambda: (ui_check_ollama(), ui_speaker_status()), + outputs=[ollama_status, speaker_status], + ) + + return demo + + +def main() -> None: + """主入口:启动 Gradio 服务。""" + logger.info("Trading Studio 启动中... HOST=%s PORT=%s", HOST, PORT) + app = build_app() + app.launch( + server_name=HOST, + server_port=PORT, + share=False, + show_error=True, + allowed_paths=[str(Path(__file__).resolve().parent / "outputs")], + ) + + +if __name__ == "__main__": + main() diff --git a/config.py b/config.py new file mode 100644 index 0000000..37b13af --- /dev/null +++ b/config.py @@ -0,0 +1,82 @@ +""" +Trading Studio 全局配置模块 +统一存放局域网节点、模型名称、固定 Prompt 及本地路径。 +""" + +from pathlib import Path + +# --------------------------------------------------------------------------- +# 网络与服务 +# --------------------------------------------------------------------------- +# 远程 Ollama 节点(局域网大模型审查润色) +OLLAMA_HOST = "192.168.8.64" +OLLAMA_PORT = 11434 +OLLAMA_URL = f"http://{OLLAMA_HOST}:{OLLAMA_PORT}/api/chat" + +# 指定无限制版 Gemma4 模型 +MODEL_NAME = "huihui_ai/gemma-4-abliterated:e4b" + +# Gradio 中控固定端口(硬性死规则) +HOST = "0.0.0.0" +PORT = 5683 + +# HTTP 请求超时(秒) +OLLAMA_TIMEOUT = 60 + +# --------------------------------------------------------------------------- +# LLM 系统 Prompt +# --------------------------------------------------------------------------- +SYSTEM_PROMPT = ( + "你是一个冷静、极其严格的数字资产量化交易员。" + "请把下面这段口语化、包含结巴和逻辑混乱的交易复盘录音转写," + "润色成一段逻辑清晰、行文通顺的 B 站长视频反思配音稿。" + "语气要内向、克制、严谨。" + "如果原视频中有由于心态不好、违背交易纪律(如手贱乱开仓、提前平仓)" + "导致少赚或亏损的部分,请用冷酷、严厉的语气狠狠地自我吐槽、反思该点。" + "去掉所有无意义的口头禅,字数不做删减。" +) + +# --------------------------------------------------------------------------- +# Faster-Whisper 配置 +# --------------------------------------------------------------------------- +WHISPER_MODEL_SIZE = "small" +WHISPER_DEVICE = "cuda" +WHISPER_COMPUTE_TYPE = "float16" +WHISPER_LANGUAGE = "zh" + +# --------------------------------------------------------------------------- +# ChatTTS 配置 +# --------------------------------------------------------------------------- +# 项目根目录 +BASE_DIR = Path(__file__).resolve().parent + +# 固定音色 Embedding 存储路径 +SPEAKER_EMB_PATH = BASE_DIR / "speaker_emb.pt" + +# 合成音频输出目录 +OUTPUT_DIR = BASE_DIR / "outputs" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +# ChatTTS 采样率(Hz) +TTS_SAMPLE_RATE = 24000 + +# 音色样本时长建议(秒) +SPEAKER_SAMPLE_MIN_SEC = 10 +SPEAKER_SAMPLE_MAX_SEC = 30 + +# TTS 推理默认参数(低 temperature 有助于音色稳定) +TTS_TEMPERATURE = 0.3 +TTS_TOP_P = 0.7 +TTS_TOP_K = 20 +TTS_SPEED_PROMPT = "[speed_5]" + +# --------------------------------------------------------------------------- +# 上传临时文件目录 +# --------------------------------------------------------------------------- +UPLOAD_DIR = BASE_DIR / "uploads" +UPLOAD_DIR.mkdir(parents=True, exist_ok=True) + +# --------------------------------------------------------------------------- +# Git 仓库(文档引用) +# --------------------------------------------------------------------------- +GIT_REPO_URL = "https://git.bz121.com/dekun/Trading_Studio.git" diff --git a/ecosystem.config.js b/ecosystem.config.js new file mode 100644 index 0000000..79957ef --- /dev/null +++ b/ecosystem.config.js @@ -0,0 +1,26 @@ +/** + * PM2 进程守护配置 + * 用法: pm2 start ecosystem.config.js + */ +module.exports = { + apps: [ + { + name: "trading_studio", + script: "app.py", + interpreter: "./venv/bin/python", + cwd: __dirname, + instances: 1, + autorestart: true, + watch: false, + max_memory_restart: "6G", + env: { + PYTHONUNBUFFERED: "1", + CUDA_VISIBLE_DEVICES: "0", + }, + error_file: "./logs/pm2-error.log", + out_file: "./logs/pm2-out.log", + log_date_format: "YYYY-MM-DD HH:mm:ss", + merge_logs: true, + }, + ], +}; diff --git a/llm_service.py b/llm_service.py new file mode 100644 index 0000000..e0bc5be --- /dev/null +++ b/llm_service.py @@ -0,0 +1,162 @@ +""" +远程 Ollama LLM 润色服务 +通过局域网 HTTP 请求 Gemma4 模型,对交易复盘转写稿进行纪律审判式润色。 +""" + +from __future__ import annotations + +import logging +from typing import Tuple + +import requests + +from config import MODEL_NAME, OLLAMA_TIMEOUT, OLLAMA_URL, SYSTEM_PROMPT + +logger = logging.getLogger(__name__) + + +def _build_payload(raw_text: str) -> dict: + """构造 Ollama /api/chat 非流式请求体。""" + return { + "model": MODEL_NAME, + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + { + "role": "user", + "content": ( + "以下是我的交易复盘录音转写原文,请严格按系统要求润色:\n\n" + f"{raw_text}" + ), + }, + ], + "stream": False, + "options": { + "temperature": 0.7, + "num_predict": 4096, + }, + } + + +def _extract_content(response_json: dict) -> str: + """从 Ollama 响应 JSON 中提取 assistant 文本。""" + # /api/chat 标准格式 + message = response_json.get("message") + if isinstance(message, dict): + content = message.get("content", "").strip() + if content: + return content + + # 兼容 /api/generate 格式(部分旧版或代理) + if "response" in response_json: + content = str(response_json["response"]).strip() + if content: + return content + + raise ValueError(f"无法从 Ollama 响应中解析文本内容: {response_json}") + + +def polish_text(raw_text: str) -> Tuple[bool, str]: + """ + 调用远程 Ollama 对原始转写文本进行润色。 + + Args: + raw_text: Whisper 转写得到的原始口语文本 + + Returns: + (success, polished_text_or_error_message) + """ + if not raw_text or not raw_text.strip(): + return False, "润色输入为空,请先完成语音识别。" + + payload = _build_payload(raw_text.strip()) + + try: + logger.info("正在请求 Ollama: %s, model=%s", OLLAMA_URL, MODEL_NAME) + response = requests.post( + OLLAMA_URL, + json=payload, + timeout=OLLAMA_TIMEOUT, + ) + response.raise_for_status() + + data = response.json() + polished = _extract_content(data) + + if not polished: + return False, "Ollama 返回内容为空,请检查模型是否正常加载。" + + logger.info("润色完成,输出字数: %d", len(polished)) + return True, polished + + except requests.exceptions.ConnectTimeout: + err = ( + f"连接 Ollama 超时(>{OLLAMA_TIMEOUT}s)。" + f"请确认 {OLLAMA_URL} 可达且 Ollama 服务已启动。" + ) + logger.error(err) + return False, err + + except requests.exceptions.ReadTimeout: + err = ( + f"Ollama 响应超时(>{OLLAMA_TIMEOUT}s)。" + "模型可能正在加载或生成长度过长,请稍后重试。" + ) + logger.error(err) + return False, err + + except requests.exceptions.ConnectionError as exc: + err = ( + f"无法连接到 Ollama 节点 ({OLLAMA_URL})。" + "请检查局域网连通性、防火墙及 Ollama 是否监听 0.0.0.0:11434。\n" + f"详情: {exc}" + ) + logger.error(err) + return False, err + + except requests.exceptions.HTTPError as exc: + status = exc.response.status_code if exc.response is not None else "?" + body = exc.response.text[:500] if exc.response is not None else "" + err = ( + f"Ollama HTTP 错误 ({status})。" + f"请确认模型 `{MODEL_NAME}` 已通过 ollama pull 下载。\n" + f"响应片段: {body}" + ) + logger.error(err) + return False, err + + except ValueError as exc: + logger.error("Ollama 响应解析失败: %s", exc) + return False, str(exc) + + except requests.exceptions.RequestException as exc: + err = f"Ollama 请求异常: {exc}" + logger.exception(err) + return False, err + + except Exception as exc: + err = f"润色过程发生未知错误: {exc}" + logger.exception(err) + return False, err + + +def check_ollama_health() -> Tuple[bool, str]: + """ + 快速检测 Ollama 节点是否在线(不触发完整推理)。 + + Returns: + (online, message) + """ + base_url = OLLAMA_URL.rsplit("/api/", 1)[0] + try: + resp = requests.get(f"{base_url}/api/tags", timeout=10) + resp.raise_for_status() + tags = resp.json().get("models", []) + model_names = [m.get("name", "") for m in tags] + if any(MODEL_NAME.split(":")[0] in name for name in model_names): + return True, f"Ollama 在线,已检测到模型: {MODEL_NAME}" + return True, ( + f"Ollama 在线,但未找到模型 {MODEL_NAME}," + f"请执行: ollama pull {MODEL_NAME}" + ) + except Exception as exc: + return False, f"Ollama 不可达: {exc}" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..95856c1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,23 @@ +# Trading Studio 依赖清单 +# CUDA 版 PyTorch 请按 DEPLOY.md 单独安装(cu121),此处不重复指定 + +# Web 中控 +gradio>=4.44.0 + +# 语音识别(CUDA 加速) +faster-whisper>=1.0.0 + +# 远程 LLM 通信 +requests>=2.31.0 + +# 语音合成 +ChatTTS @ git+https://github.com/2noise/ChatTTS.git +torchaudio>=2.1.0 +scipy>=1.11.0 +numpy>=1.24.0 +librosa>=0.10.0 + +# 音频处理辅助 +soundfile>=0.12.0 + +# PM2 通过 Node.js 全局安装,不在 pip 范围内 diff --git a/tts_service.py b/tts_service.py new file mode 100644 index 0000000..ddde3ea --- /dev/null +++ b/tts_service.py @@ -0,0 +1,305 @@ +""" +ChatTTS 本地语音合成服务 +支持从参考人声提取 Speaker Embedding 并固定音色合成配音。 +""" + +from __future__ import annotations + +import logging +import traceback +import uuid +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +from scipy.io import wavfile + +from config import ( + OUTPUT_DIR, + SPEAKER_EMB_PATH, + SPEAKER_SAMPLE_MAX_SEC, + SPEAKER_SAMPLE_MIN_SEC, + TTS_SAMPLE_RATE, + TTS_SPEED_PROMPT, + TTS_TEMPERATURE, + TTS_TOP_K, + TTS_TOP_P, +) + +logger = logging.getLogger(__name__) + +# 全局 ChatTTS 实例 +_chat = None +_chat_error: Optional[str] = None + + +def _load_audio_for_chattts(audio_path: str, sample_rate: int = TTS_SAMPLE_RATE) -> np.ndarray: + """ + 加载音频并重采样到 ChatTTS 所需采样率。 + 优先使用 ChatTTS 自带工具,回退到 librosa。 + """ + try: + from ChatTTS.utils import load_audio + + return load_audio(audio_path, sample_rate) + except ImportError: + pass + + try: + from tools.audio import load_audio + + return load_audio(audio_path, sample_rate) + except ImportError: + pass + + import librosa + + audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True) + return audio + + +def _get_audio_duration_sec(audio: np.ndarray, sample_rate: int) -> float: + """计算音频时长(秒)。""" + if audio is None or len(audio) == 0: + return 0.0 + return len(audio) / float(sample_rate) + + +def get_chattts_instance(): + """ + 获取或初始化 ChatTTS 模型。 + 启用 GPU 加速,compile=False 以兼容 3060 Ti 8GB 显存。 + """ + global _chat, _chat_error + + if _chat is not None: + return _chat, None + + if _chat_error is not None: + return None, _chat_error + + try: + import ChatTTS + + logger.info("正在加载 ChatTTS 模型...") + chat = ChatTTS.Chat() + + # 兼容不同版本 API:load_models(旧版)/ load(新版) + if hasattr(chat, "load_models"): + chat.load_models(compile=False) + elif hasattr(chat, "load"): + chat.load(compile=False) + else: + raise RuntimeError("当前 ChatTTS 版本缺少 load / load_models 方法。") + + _chat = chat + logger.info("ChatTTS 模型加载成功。") + return _chat, None + + except ImportError as exc: + _chat_error = ( + "未安装 ChatTTS,请参考 DEPLOY.md 安装。\n" + f"原始错误: {exc}" + ) + logger.exception("ChatTTS 导入失败") + return None, _chat_error + + except Exception as exc: + _chat_error = f"ChatTTS 模型加载失败: {exc}\n{traceback.format_exc()}" + logger.exception("ChatTTS 初始化异常") + return None, _chat_error + + +def _encode_spk_emb(chat, tensor_or_str: Any) -> str: + """将 Speaker Embedding 编码为 ChatTTS 可用的字符串格式。""" + if isinstance(tensor_or_str, str): + return tensor_or_str + + if hasattr(chat, "_encode_spk_emb"): + return chat._encode_spk_emb(tensor_or_str) + + # 兜底:直接转字符串(部分版本可接受 tensor) + return tensor_or_str + + +def save_fixed_speaker( + audio_sample_path: str, + sample_transcript: str = "", +) -> Tuple[bool, str]: + """ + 从 10-30 秒干净人声中提取 Speaker Embedding 并序列化保存。 + + Args: + audio_sample_path: 参考人声 wav/mp3 等路径 + sample_transcript: 参考音频的精确转写(可选,有助于 zero-shot 音色还原) + + Returns: + (success, message) + """ + if not audio_sample_path: + return False, "未提供音色参考音频。" + + chat, init_err = get_chattts_instance() + if chat is None: + return False, init_err or "ChatTTS 不可用。" + + try: + audio = _load_audio_for_chattts(audio_sample_path, TTS_SAMPLE_RATE) + duration = _get_audio_duration_sec(audio, TTS_SAMPLE_RATE) + + if duration < SPEAKER_SAMPLE_MIN_SEC: + return False, ( + f"参考音频过短({duration:.1f}s),建议 {SPEAKER_SAMPLE_MIN_SEC}-" + f"{SPEAKER_SAMPLE_MAX_SEC} 秒干净人声。" + ) + if duration > SPEAKER_SAMPLE_MAX_SEC + 5: + logger.warning("参考音频超过建议时长 %.1fs,将截取前 %ds", duration, SPEAKER_SAMPLE_MAX_SEC) + max_samples = SPEAKER_SAMPLE_MAX_SEC * TTS_SAMPLE_RATE + audio = audio[:max_samples] + + # 从参考音频提取音色特征 + spk_smp = chat.sample_audio_speaker(audio) + + # 同时保存编码后的 spk_emb 字符串,便于 infer 时直接使用 + spk_emb = _encode_spk_emb(chat, spk_smp) + + payload: Dict[str, Any] = { + "spk_emb": spk_emb, + "spk_smp": spk_smp, + "txt_smp": sample_transcript.strip(), + "created_at": datetime.now().isoformat(), + "source_audio": str(audio_sample_path), + } + + torch.save(payload, SPEAKER_EMB_PATH) + + msg = ( + f"音色已锁定并保存至 {SPEAKER_EMB_PATH}\n" + f"参考时长: {duration:.1f}s" + ) + if not sample_transcript.strip(): + msg += "\n提示:填写参考音频精确转写可进一步提升音色还原度。" + + logger.info("Speaker Embedding 保存成功: %s", SPEAKER_EMB_PATH) + return True, msg + + except Exception as exc: + err = f"音色提取失败: {exc}\n{traceback.format_exc()}" + logger.exception("save_fixed_speaker 失败") + return False, err + + +def _load_speaker_payload() -> Tuple[Optional[Dict[str, Any]], Optional[str]]: + """加载本地 speaker_emb.pt。""" + if not SPEAKER_EMB_PATH.exists(): + return None, ( + f"未找到固定音色文件 `{SPEAKER_EMB_PATH.name}`。" + "请先在【音色锁定】模块上传 10-30 秒参考人声。" + ) + + try: + payload = torch.load(SPEAKER_EMB_PATH, map_location="cpu", weights_only=False) + + # 兼容旧版仅保存 tensor 的文件 + if isinstance(payload, torch.Tensor): + chat, err = get_chattts_instance() + if chat is None: + return None, err + return { + "spk_emb": _encode_spk_emb(chat, payload), + "spk_smp": None, + "txt_smp": "", + }, None + + if not isinstance(payload, dict): + return None, "speaker_emb.pt 格式无效,请重新锁定音色。" + + return payload, None + + except Exception as exc: + return None, f"读取 speaker_emb.pt 失败: {exc}" + + +def speaker_is_ready() -> Tuple[bool, str]: + """检查固定音色是否已配置。""" + payload, err = _load_speaker_payload() + if payload is None: + return False, err or "音色未配置。" + return True, f"已加载固定音色: {SPEAKER_EMB_PATH}" + + +def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: + """ + 使用 ChatTTS 将润色后的文稿合成为 wav 配音。 + + Args: + refined_text: LLM 润色后的配音稿 + + Returns: + (success, message, output_wav_path_or_none) + """ + if not refined_text or not refined_text.strip(): + return False, "合成文本为空,请先完成润色。", None + + chat, init_err = get_chattts_instance() + if chat is None: + return False, init_err or "ChatTTS 不可用。", None + + payload, spk_err = _load_speaker_payload() + if payload is None: + return False, spk_err or "请先锁定音色。", None + + try: + import ChatTTS + + spk_emb = payload.get("spk_emb") + spk_smp = payload.get("spk_smp") + txt_smp = payload.get("txt_smp", "") + + params_infer_code = ChatTTS.Chat.InferCodeParams( + prompt=TTS_SPEED_PROMPT, + spk_emb=spk_emb, + spk_smp=spk_smp if spk_smp else None, + txt_smp=txt_smp if txt_smp else None, + temperature=TTS_TEMPERATURE, + top_P=TTS_TOP_P, + top_K=TTS_TOP_K, + ) + + # 内向克制语气:降低 oral 强度 + params_refine_text = ChatTTS.Chat.RefineTextParams( + prompt="[oral_2][laugh_0][break_4]", + ) + + wavs = chat.infer( + refined_text.strip(), + skip_refine_text=False, + params_refine_text=params_refine_text, + params_infer_code=params_infer_code, + ) + + if not wavs or len(wavs) == 0: + return False, "ChatTTS 未生成有效音频。", None + + wav_array = np.asarray(wavs[0], dtype=np.float32) + + # 归一化并转 int16 + peak = np.max(np.abs(wav_array)) or 1.0 + wav_int16 = (wav_array / peak * 32767).astype(np.int16) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"voiceover_{timestamp}_{uuid.uuid4().hex[:6]}.wav" + output_path = OUTPUT_DIR / filename + + wavfile.write(str(output_path), TTS_SAMPLE_RATE, wav_int16) + + msg = f"配音合成成功: {output_path}" + logger.info(msg) + return True, msg, str(output_path) + + except Exception as exc: + err = f"语音合成失败: {exc}\n{traceback.format_exc()}" + logger.exception("generate_voice 失败") + return False, err, None diff --git a/whisper_service.py b/whisper_service.py new file mode 100644 index 0000000..2853fc5 --- /dev/null +++ b/whisper_service.py @@ -0,0 +1,155 @@ +""" +Faster-Whisper CUDA 语音识别服务 +封装本地 GPU 加速的音频转写逻辑,适配 RTX 3060 Ti 8GB 显存。 +""" + +from __future__ import annotations + +import logging +import traceback +from typing import Optional, Tuple + +from config import ( + WHISPER_COMPUTE_TYPE, + WHISPER_DEVICE, + WHISPER_LANGUAGE, + WHISPER_MODEL_SIZE, +) + +logger = logging.getLogger(__name__) + +# 全局懒加载模型实例,避免 Gradio 重复初始化占用显存 +_model = None +_model_error: Optional[str] = None + + +def _is_cuda_error(exc: BaseException) -> bool: + """判断异常是否与 CUDA/GPU 相关。""" + msg = str(exc).lower() + cuda_keywords = ( + "cuda", + "cudnn", + "cublas", + "gpu", + "out of memory", + "no kernel image", + "device-side assert", + ) + return any(k in msg for k in cuda_keywords) + + +def get_whisper_model(): + """ + 获取或初始化 Faster-Whisper 模型。 + 强制 device=cuda, compute_type=float16。 + """ + global _model, _model_error + + if _model is not None: + return _model, None + + if _model_error is not None: + return None, _model_error + + try: + from faster_whisper import WhisperModel + + logger.info( + "正在加载 Whisper 模型: size=%s, device=%s, compute_type=%s", + WHISPER_MODEL_SIZE, + WHISPER_DEVICE, + WHISPER_COMPUTE_TYPE, + ) + _model = WhisperModel( + WHISPER_MODEL_SIZE, + device=WHISPER_DEVICE, + compute_type=WHISPER_COMPUTE_TYPE, + ) + logger.info("Whisper 模型加载成功。") + return _model, None + + except ImportError as exc: + _model_error = ( + "未安装 faster-whisper,请执行: pip install faster-whisper\n" + f"原始错误: {exc}" + ) + logger.exception("faster-whisper 导入失败") + return None, _model_error + + except Exception as exc: + if _is_cuda_error(exc): + _model_error = ( + "CUDA 初始化失败,请检查 NVIDIA 驱动、CUDA 运行时及 cuDNN 是否正确安装。\n" + f"错误详情: {exc}\n" + f"{traceback.format_exc()}" + ) + else: + _model_error = f"Whisper 模型加载失败: {exc}\n{traceback.format_exc()}" + logger.exception("Whisper 模型加载异常") + return None, _model_error + + +def transcribe_audio(audio_path: str) -> Tuple[bool, str]: + """ + 将音频文件转写为中文文本。 + + Args: + audio_path: 本地音频文件绝对或相对路径 + + Returns: + (success, text_or_error_message) + """ + if not audio_path: + return False, "未提供音频文件路径。" + + model, init_error = get_whisper_model() + if model is None: + return False, init_error or "Whisper 模型不可用。" + + try: + segments, info = model.transcribe( + audio_path, + language=WHISPER_LANGUAGE, + beam_size=5, + vad_filter=True, + ) + + text_parts = [] + for segment in segments: + text_parts.append(segment.text.strip()) + + result_text = "".join(text_parts).strip() + + if not result_text: + return False, ( + "识别结果为空,请检查音频是否有效、音量是否足够," + f"或尝试更换格式。检测到语言: {getattr(info, 'language', 'unknown')}" + ) + + logger.info( + "转写完成: 语言=%s, 概率=%.2f, 字数=%d", + getattr(info, "language", "?"), + getattr(info, "language_probability", 0.0), + len(result_text), + ) + return True, result_text + + except Exception as exc: + if _is_cuda_error(exc): + err = ( + "CUDA 推理异常:显存可能不足或 GPU 状态异常。" + "建议关闭其他占用显存的进程后重试。\n" + f"错误详情: {exc}" + ) + else: + err = f"音频转写失败: {exc}\n{traceback.format_exc()}" + + logger.exception("transcribe_audio 失败") + return False, err + + +def reset_whisper_model() -> None: + """释放模型引用(用于调试或显存回收)。""" + global _model, _model_error + _model = None + _model_error = None