Initial commit: add Trading Studio voice-over pipeline for quant trading review videos.

Co-authored-by: Cursor <cursoragent@cursor.com>
dekun
2026-06-12 13:19:44 +08:00
commit 5e95d3af2f
10 changed files with 1862 additions and 0 deletions
@@ -0,0 +1,46 @@
# Python virtual environments
venv/
.venv/
env/
# Model weights and speaker-embedding files (large; not committed)
*.pt
*.pth
*.onnx
*.bin
*.safetensors
# Audio artifacts
*.wav
*.mp3
*.flac
*.ogg
*.m4a
# Logs
*.log
# Runtime directories
uploads/
outputs/
__pycache__/
*.py[cod]
*$py.class
.Python
# IDE
.idea/
.vscode/
*.swp
*.swo
# OS files
.DS_Store
Thumbs.db
# Environment variables and secrets
.env
.env.*
# Gradio temporary files
gradio_cached_examples/
@@ -0,0 +1,488 @@
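# Trading Studio Deployment Guide (DEPLOY.md)
This guide covers the full environment setup and PM2-managed deployment on an **Ubuntu bare-metal server** with an RTX 3060 Ti capped at a 120W power limit. It applies to first-time installs and to migration or reinstall scenarios.
**Git repository:** https://git.bz121.com/dekun/Trading_Studio.git
---
## Table of Contents
1. [Hardware and System Prerequisites](#1-hardware-and-system-prerequisites)
2. [3060 Ti 120W Power Limit](#2-3060-ti-120w-power-limit)
3. [NVIDIA Driver and CUDA](#3-nvidia-driver-and-cuda)
4. [Python Virtual Environment](#4-python-virtual-environment)
5. [Installing PyTorch for CUDA 12.1](#5-installing-pytorch-for-cuda-121)
6. [Installing Project Dependencies](#6-installing-project-dependencies)
7. [Remote Ollama Node Configuration](#7-remote-ollama-node-configuration)
8. [First Run and Verification](#8-first-run-and-verification)
9. [PM2 Process Supervision](#9-pm2-process-supervision)
10. [Migration and Troubleshooting](#10-migration-and-troubleshooting)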
---
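## 1. Hardware and System Prerequisites
| Item | Requirement |
|------|-------------|
| GPU | NVIDIA RTX 3060 Ti 8GB |
| Power limit | 120W (capping recommended; see below) |
| OS | Ubuntu 22.04 / 24.04 LTS |
| RAM | ≥ 16GB |
| Disk | ≥ 30GB free (including model cache) |
| Network | LAN access to `192.168.8.64:11434` |
```bash
# Base tools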
sudo apt update && sudo apt upgrade -y
sudo apt install -y git curl wget build-essential \
python3 python3-venv python3-dev \
ffmpeg libsndfile1 portaudio19-dev
```
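> If apt cannot find `python3-venv`, install the version-specific package instead (for example `python3.10-venv` on Ubuntu 22.04 or `python3.12-venv` on 24.04).
---
## 2. 3060 Ti 120W Power Limit
Capping GPU power helps keep 24/7 operation stable, reduces thermal load, and avoids power spikes when Whisper and ChatTTS run concurrently.
### 2.1 The nvidia-smi Power Management Tool
`nvidia-smi` ships with the driver. Confirm that the GPU is visible: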
```bash
nvidia-smi
```
### 2.2 Set a Temporary 120W Power Limit
```bash
# Show the supported power range
nvidia-smi -q -d POWER | grep -A3 "Power Limit"
# Cap power at 120W (requires root)
sudo nvidia-smi -pl 120
```
### 2.3 Persist Across Reboots (Recommended)
Create a systemd service that reapplies the limit on every boot:
```bash
sudo tee /etc/systemd/system/nvidia-powerlimit.service << 'EOF'
[Unit]
Description=Set NVIDIA GPU Power Limit to 120W
After=multi-user.target
[Service]
Type=oneshot
ExecStart=/usr/bin/nvidia-smi -pl 120
RemainAfterExit=yes
[Install]
WantedBy=multi-user.target
EOF
sudo systemctl daemon-reload
sudo systemctl enable nvidia-powerlimit.service
sudo systemctl start nvidia-powerlimit.service
# Verify
nvidia-smi --query-gpu=power.limit --format=csv
```
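The CSV printed by the verification command can also be checked from a script. The sketch below is a hypothetical helper (not part of the repository) and assumes the output has one header row followed by one `<value> W` row per GPU; adjust the parsing if your driver prints something different.

```python
def parse_power_limits(csv_text: str) -> list[float]:
    """Return the power limit in watts for each GPU row in the CSV text."""
    limits = []
    for line in csv_text.strip().splitlines()[1:]:  # skip the header row
        value = line.strip().removesuffix("W").strip()
        try:
            limits.append(float(value))
        except ValueError:
            continue  # ignore non-numeric rows such as "[N/A]"
    return limits


def power_limit_ok(csv_text: str, expected_watts: float = 120.0,
                   tolerance: float = 0.5) -> bool:
    """True when at least one GPU is listed and every limit is within tolerance."""
    limits = parse_power_limits(csv_text)
    return bool(limits) and all(abs(w - expected_watts) <= tolerance for w in limits)


if __name__ == "__main__":
    sample = "power.limit [W]\n120.00 W\n"  # assumed sample output
    print(power_limit_ok(sample))
```

In practice the text would come from running the `nvidia-smi --query-gpu=power.limit --format=csv` command shown above, for example through `subprocess.run(..., capture_output=True, text=True)`.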
---
## 3. NVIDIA Driver and CUDA
### 3.1 Install the Driver (535+ or 550+ Recommended)
```bash
# Let Ubuntu pick and install the recommended driver
sudo ubuntu-drivers devices
sudo ubuntu-drivers autoinstall
# Or pin a version: sudo apt install nvidia-driver-550
sudo reboot
```
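Verify after the reboot:
```bash
nvidia-smi
nvcc --version # Optional: PyTorch works without nvcc installed
```
### 3.2 cuDNN (Required by Faster-Whisper / PyTorch)
The PyTorch cu121 wheels normally bundle the CUDA runtime libraries. If Whisper reports a cuDNN error:
```bash
# Install cuDNN for CUDA 12.x following NVIDIA's official documentation
# https://developer.nvidia.com/cudnn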
```
---
## 4. Python Virtual Environment
```bash
# Clone the project
cd ~
git clone https://git.bz121.com/dekun/Trading_Studio.git
cd Trading_Studio
# Create the virtual environment (it must be named venv to match the PM2 interpreter path)
python3 -m venv venv
# Activate it
source venv/bin/activate
# Upgrade pip
pip install --upgrade pip setuptools wheel
```
**Important:** the PM2 config sets `interpreter` to `./venv/bin/python`, so create `venv/` in the project root.
---
## 5. Installing PyTorch for CUDA 12.1
**Install this before any other GPU dependency** so that pip does not pull in the CPU-only build of torch.
```bash
source venv/bin/activate
pip install torch torchvision torchaudio \
--index-url https://download.pytorch.org/whl/cu121
```
Verify that CUDA is available:
```bash
python -c "
import torch
print('PyTorch:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())
print('GPU:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A')
"
```
Expected output looks like:
```
PyTorch: 2.x.x+cu121
CUDA available: True
GPU: NVIDIA GeForce RTX 3060 Ti
```
---
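## 6. Installing Project Dependencies
```bash
source venv/bin/activate
cd ~/Trading_Studio
# Install the remaining dependencies
pip install -r requirements.txt
```
### 6.1 Faster-Whisper
Installed via `requirements.txt`. The first run automatically downloads the `small` model (about 500MB) into the HuggingFace cache.
### 6.2 ChatTTS
Installed from the GitHub source (already listed in requirements.txt). To install it by hand, quote the requirement so the shell passes it to pip as a single argument:
```bash
pip install "ChatTTS @ git+https://github.com/2noise/ChatTTS.git"
```
The first call to `save_fixed_speaker` or `generate_voice` downloads the model weights (several GB). Make sure the network is reachable, or configure a HuggingFace mirror beforehand:
```bash
export HF_ENDPOINT=https://hf-mirror.com # Optional: faster downloads in mainland China
```
### 6.3 Gradio
```bash
# Quote the specifier; an unquoted > would be treated as a shell redirect
pip install "gradio>=4.44.0"
```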
---
## 7. Remote Ollama Node Configuration
Trading Studio's LLM polishing module connects to an Ollama node on the LAN; **no large model runs on this machine**.
| Setting | Value |
|---------|-------|
| Address | `http://192.168.8.64:11434` |
| API | `POST /api/chat` |
| Model | `huihui_ai/gemma-4-abliterated:e4b` |
| Streaming | `stream: false` |
### 7.1 On the Ollama Node (192.168.8.64)
```bash
# Install Ollama (if not already installed)
curl -fsSL https://ollama.com/install.sh | sh
# Pull the model
ollama pull huihui_ai/gemma-4-abliterated:e4b
# Allow LAN access (edit the systemd unit or set the environment variable)
sudo systemctl edit ollama
```
Add:
```ini
[Service]
Environment="OLLAMA_HOST=0.0.0.0:11434"
```
```bash
sudo systemctl daemon-reload
sudo systemctl restart ollama
```
### 7.2 Verify from This Machine (the Trading Studio Server)
```bash
curl http://192.168.8.64:11434/api/tags
curl http://192.168.8.64:11434/api/chat -d '{
"model": "huihui_ai/gemma-4-abliterated:e4b",
"messages": [{"role": "user", "content": "ping"}],
"stream": false
}'
```
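The second `curl` call sends a non-streaming chat request, and the reply text sits under `message.content` in the returned JSON (the same field that `llm_service._extract_content` reads). A minimal sketch of building that request body and reading the reply, using only the standard library; the helper names and the sample response are illustrative assumptions rather than code from the repository:

```python
import json


def build_chat_payload(model: str, user_text: str) -> dict:
    """Build a non-streaming /api/chat request body with a single user message."""
    return {
        "model": model,
        "messages": [{"role": "user", "content": user_text}],
        "stream": False,
    }


def extract_reply(response_json: dict) -> str:
    """Return the assistant text from an /api/chat response, or '' if absent."""
    message = response_json.get("message")
    if isinstance(message, dict):
        return str(message.get("content", "")).strip()
    return ""


if __name__ == "__main__":
    body = json.dumps(build_chat_payload("huihui_ai/gemma-4-abliterated:e4b", "ping"))
    # Hand-written sample shaped like a non-streaming reply (assumed, not captured).
    sample = json.loads('{"message": {"role": "assistant", "content": " pong "}, "done": true}')
    print(body)
    print(extract_reply(sample))
```

An empty string from `extract_reply` is a hint that the node answered but the model produced nothing, which is worth separating from a connection failure.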
---
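## 8. First Run and Verification
```bash
source venv/bin/activate
cd ~/Trading_Studio
# Run in the foreground (for debugging)
python app.py
```
Open in a browser:
```
http://<LAN-IP-of-this-machine>:5683
```
### 8.1 Verification Checklist
- [ ] The page loads and the Ollama status shows online
- [ ] Upload a 10-30s reference voice sample → the speaker lock succeeds and `speaker_emb.pt` is created
- [ ] Upload a trade-review recording → Whisper produces Chinese text
- [ ] Click polish → the Gemma4-processed script is returned
- [ ] Click synthesize → a 24kHz wav appears under `outputs/`
### 8.2 Log Locations
- Application log: `trading_studio.log` (project root)
- PM2 logs: `logs/pm2-out.log` and `logs/pm2-error.log`
```bash
mkdir -p logs
```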
---
## 9. PM2 Process Supervision
Trading Studio is designed to run under PM2, which restarts the Gradio service after a crash and starts it at boot.
### 9.1 Install Node.js and PM2
```bash
# Install Node.js 20 LTS
curl -fsSL https://deb.nodesource.com/setup_20.x | sudo -E bash -
sudo apt install -y nodejs
# Install PM2 globally
sudo npm install -g pm2
```
### 9.2 Option A: Use ecosystem.config.js (Recommended)
The project ships with `ecosystem.config.js`:
```javascript
module.exports = {
apps: [{
name: "trading_studio",
script: "app.py",
interpreter: "./venv/bin/python",
cwd: __dirname,
instances: 1,
autorestart: true,
max_memory_restart: "6G",
env: {
PYTHONUNBUFFERED: "1",
CUDA_VISIBLE_DEVICES: "0",
},
}],
};
```
Start:
```bash
cd ~/Trading_Studio
mkdir -p logs
pm2 start ecosystem.config.js
pm2 status
pm2 logs trading_studio --lines 50
```
### 9.3 Option B: Direct Command Line
```bash
cd ~/Trading_Studio
pm2 start app.py \
--name "trading_studio" \
--interpreter ./venv/bin/python
pm2 save
```
### 9.4 Start at Boot
```bash
pm2 startup
# Run the sudo command that it prints
pm2 save
```
### 9.5 Common Operations
```bash
pm2 restart trading_studio # Restart (after code changes)
pm2 stop trading_studio # Stop
pm2 delete trading_studio # Remove
pm2 monit # Live CPU/memory monitor
```
### 9.6 Redeploy After Updating the Code
```bash
cd ~/Trading_Studio
git pull
source venv/bin/activate
pip install -r requirements.txt # If there are new dependencies
pm2 restart trading_studio
```
---
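## 10. Migration and Troubleshooting
### 10.1 Migrating to a New Machine
1. Copy `speaker_emb.pt` (the speaker file is covered by `.gitignore`, so back it up manually)
2. Deploy the new machine by following this guide in full
3. Put `speaker_emb.pt` back in the project root
4. `pm2 restart trading_studio`
### 10.2 CUDA / VRAM Issues
```bash
# Check VRAM usage
nvidia-smi
# On OOM, make sure no other process is using the GPU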
fuser -v /dev/nvidia*
```
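Whisper and ChatTTS do not both hold their peak VRAM at the same time, but usage spikes when a model is first loaded. Recommendations:
- Keep the 120W power limit in place
- `max_memory_restart: "6G"` is already set in the PM2 config (note that PM2 measures the process's host RAM, not VRAM)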
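To check VRAM headroom from a script before loading a model, the `nvidia-smi` query interface can be parsed. This sketch assumes the single-GPU output of `nvidia-smi --query-gpu=memory.used,memory.total --format=csv,noheader,nounits`, which is one line such as `1234, 8192` in MiB. The helper names and the 4000 MiB threshold are illustrative, not measured requirements of Whisper or ChatTTS.

```python
def parse_vram_mib(csv_line: str) -> tuple[int, int]:
    """Split a 'used, total' line (MiB, no units) into two integers."""
    used, total = (int(part.strip()) for part in csv_line.split(","))
    return used, total


def has_headroom(csv_line: str, required_mib: int) -> bool:
    """True when free VRAM (total minus used) covers the requested amount."""
    used, total = parse_vram_mib(csv_line)
    return total - used >= required_mib


if __name__ == "__main__":
    line = "1234, 8192"  # assumed sample line for an 8GB card
    print(parse_vram_mib(line), has_headroom(line, required_mib=4000))
```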
### 10.3 Whisper CUDA Errors
```
Error: CUDA initialization failed / out of memory
```
Fix:
1. Restart the PM2 process to free VRAM
2. Confirm `compute_type="float16"` (already set in config.py)
3. Drop to the `base` model (change `WHISPER_MODEL_SIZE` in `config.py`)
### 10.4 Ollama Timeouts
```
连接 Ollama 超时(>60s)
```
Fix:
1. Confirm the model is preloaded on the Ollama node: `ollama run huihui_ai/gemma-4-abliterated:e4b`
2. Increase `OLLAMA_TIMEOUT` in `config.py`
3. Check the firewall: `sudo ufw allow from 192.168.8.0/24 to any port 11434` (on the Ollama node)
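Because a timeout often just means the model is still loading, retrying with a growing delay is one way to ride it out. The repository does not do this; the sketch below is a hypothetical wrapper around any function that returns the `(ok, message)` tuple used by `polish_text`, with an injectable `sleep` so that it can be exercised without waiting.

```python
import time
from typing import Callable, Tuple


def call_with_retry(
    fn: Callable[[], Tuple[bool, str]],
    attempts: int = 3,
    base_delay: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Tuple[bool, str]:
    """Call fn until it reports success, doubling the delay after each failure."""
    result: Tuple[bool, str] = (False, "no attempts made")
    for attempt in range(attempts):
        result = fn()
        if result[0]:
            return result
        if attempt < attempts - 1:
            sleep(base_delay * (2 ** attempt))  # 2s, 4s, 8s, ... by default
    return result
```

A call might look like `call_with_retry(lambda: polish_text(raw_text))`. Keep `attempts` small, since each failed attempt can already block for up to `OLLAMA_TIMEOUT` seconds.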
### 10.5 Corrupted ChatTTS Speaker File
```bash
rm speaker_emb.pt
# Then re-upload the reference voice under 「音色锁定」 (speaker lock) in the Web UI
```
### 10.6 Port 5683 Already in Use
```bash
sudo lsof -i :5683
# or
ss -tlnp | grep 5683
```
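The same check can be made from Python before starting the app, for example in a pre-flight script. This is a minimal sketch with a hypothetical helper name; it only tests whether something accepts TCP connections on the port and does not identify the owning process.

```python
import socket


def port_in_use(port: int, host: str = "127.0.0.1", timeout: float = 0.5) -> bool:
    """Return True if a TCP connection to host:port succeeds."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)
        return sock.connect_ex((host, port)) == 0


if __name__ == "__main__":
    print("5683 in use:", port_in_use(5683))
```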
---
## Appendix: Firewall (Gradio on This Machine)
To let other devices on the LAN reach the Web UI:
```bash
sudo ufw allow 5683/tcp
sudo ufw reload
```
URL: `http://<server-LAN-IP>:5683`
---
## Appendix: config.py Key Constants Quick Reference
```python
HOST = "0.0.0.0"
PORT = 5683
OLLAMA_URL = "http://192.168.8.64:11434/api/chat"
MODEL_NAME = "huihui_ai/gemma-4-abliterated:e4b"
WHISPER_MODEL_SIZE = "small"
WHISPER_DEVICE = "cuda"
WHISPER_COMPUTE_TYPE = "float16"
SPEAKER_EMB_PATH = "speaker_emb.pt"
TTS_SAMPLE_RATE = 24000
```
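`OLLAMA_URL` points at the chat endpoint, and `llm_service.check_ollama_health` derives the node's base address from it by splitting on `/api/`. The same derivation in isolation, as a small sketch (the function names here are illustrative):

```python
OLLAMA_URL = "http://192.168.8.64:11434/api/chat"


def ollama_base_url(chat_url: str) -> str:
    """Strip the trailing /api/... path to get scheme://host:port."""
    return chat_url.rsplit("/api/", 1)[0]


def ollama_tags_url(chat_url: str) -> str:
    """Build the /api/tags URL used for the lightweight health check."""
    return f"{ollama_base_url(chat_url)}/api/tags"


if __name__ == "__main__":
    print(ollama_base_url(OLLAMA_URL))
    print(ollama_tags_url(OLLAMA_URL))
```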
---
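**After deployment, complete the first speaker extraction in the 「音色锁定」 (speaker lock) module before starting routine voice-over production.**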
@@ -0,0 +1,197 @@
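# Trading Studio
**Local quant trading review → Bilibili long-form video voice-over production pipeline**
Trading Studio is an automated voice-over system that runs on an Ubuntu bare-metal server (RTX 3060 Ti), built for quantitative traders of digital assets. It closes the loop of "blind-recorded rambling → local GPU transcription → strict polishing by a LAN-hosted LLM → local GPU voice cloning" to produce voice-overs for reflective long-form Bilibili videos efficiently, as an aid to improving trading discipline.
**Git repository:** https://git.bz121.com/dekun/Trading_Studio.git
---
## System Overview
| Stage | Stack | Runs on |
|-------|-------|---------|
| Transcribing rambling recordings | Faster-Whisper (CUDA float16) | Local 3060 Ti |
| Discipline-review polishing | Gemma4 Abliterated @ Ollama | LAN `192.168.8.64` |
| Fixed-voice narration | ChatTTS (CUDA) | Local 3060 Ti |
| Web control panel | Gradio | Port **5683** |
---
## Architecture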
```
┌──────────────────────────────────────────────────────────────┐
│              Gradio control panel (app.py:5683)              │
├────────────────┬──────────────────────┬──────────────────────┤
│ Speaker lock   │ Transcription        │ Polish + synthesis   │
│ tts_service    │ whisper_service      │ llm_service          │
│                │                      │ tts_service          │
└───────┬────────┴──────────┬───────────┴──────────┬───────────┘
        │                   │                      │
        ▼                   ▼                      ▼
 speaker_emb.pt      Faster-Whisper           Ollama HTTP
 (persisted locally)  CUDA / small        192.168.8.64:11434
                                          gemma-4-abliterated
```
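### Module Separation
| File | Responsibility |
|------|----------------|
| `config.py` | Global configuration: Ollama address, model name, prompt, paths |
| `whisper_service.py` | Faster-Whisper CUDA transcription |
| `llm_service.py` | Non-streaming polishing over HTTP against the remote Ollama node |
| `tts_service.py` | ChatTTS speaker extraction and wav synthesis |
| `app.py` | Gradio front end and pipeline orchestration |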
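Because each service reports success as a tuple, the orchestration in `app.py` reduces to chaining the steps and stopping at the first failure. The sketch below illustrates that pattern with injected callables. It is a simplification, not the repository's code: the real `generate_voice` returns a three-element tuple, and the name `run_pipeline` is hypothetical.

```python
from typing import Callable, Optional, Tuple

Step = Callable[[str], Tuple[bool, str]]  # returns (ok, result_or_error)


def run_pipeline(
    audio_path: str,
    transcribe: Step,
    polish: Step,
    synthesize: Step,
    skip_polish: bool = False,
) -> Tuple[bool, str, Optional[str]]:
    """Chain transcribe -> polish -> synthesize; return (ok, script_or_error, wav_path)."""
    ok, raw = transcribe(audio_path)
    if not ok:
        return False, f"transcription failed: {raw}", None
    script = raw
    if not skip_polish:
        ok, script = polish(raw)
        if not ok:
            return False, f"polish failed: {script}", None
    ok, wav_or_err = synthesize(script)
    if not ok:
        return False, f"synthesis failed: {wav_or_err}", None
    return True, script, wav_or_err
```

Passing the steps in as callables keeps the services decoupled and lets the chain be tested with stand-ins, without a GPU or a reachable Ollama node.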
---
## Quick Start
> For full environment setup, see [DEPLOY.md](./DEPLOY.md)
```bash
# 1. Clone the repository
git clone https://git.bz121.com/dekun/Trading_Studio.git
cd Trading_Studio
# 2. Create the virtual environment and install dependencies (see DEPLOY.md for details)
python3 -m venv venv
source venv/bin/activate
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
# 3. Start the control panel
python app.py
# Open in a browser: http://<server-IP>:5683
```
---
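## Workflow
### First Use: Lock the Speaker
1. Open the **「音色锁定」** (speaker lock) tab
2. Upload 10-30 seconds of clean reference speech (a sample of your blind-recorded rambling)
3. (Optional) Enter an exact transcript of the reference audio to improve zero-shot fidelity
4. Click **锁定音色** (lock speaker) → `speaker_emb.pt` is created
### Daily Production
**Option A: step by step**
1. **Fast transcription**: upload the review recording → Whisper transcribes it
2. **Gemma4 discipline review**: one-click polishing into a logically clear, sternly worded reflection script
3. **ChatTTS synthesis**: outputs the finished 24kHz `.wav` voice-over
**Option B: one-click production**
After uploading the recording, click **启动全流程** (run full pipeline); the system chains the three steps automatically.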
---
## Core Configuration (config.py)
| Setting | Default |
|---------|---------|
| Control panel port | `5683` (bound to `0.0.0.0`, reachable on the LAN) |
| Ollama address | `http://192.168.8.64:11434` |
| Model name | `huihui_ai/gemma-4-abliterated:e4b` |
| Whisper model | `small` / CUDA / float16 |
| Speaker file | `speaker_emb.pt` |
| Audio output | `outputs/` directory |
---
## Running Under PM2
```bash
# Option 1: ecosystem config
pm2 start ecosystem.config.js
# Option 2: direct command
pm2 start app.py --name "trading_studio" --interpreter ./venv/bin/python
# Common management commands
pm2 status
pm2 logs trading_studio
pm2 restart trading_studio
pm2 save && pm2 startup # Start at boot (then run the sudo command that pm2 startup prints)
```
---
## .gitignore Configuration
Make sure the following artifacts are ignored when committing (already preset in `.gitignore`):
```gitignore
venv/
*.wav
*.pt
*.log
uploads/
outputs/
```
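**Notes:**
- `venv/`: Python virtual environment, created separately on each machine
- `*.wav`: recordings and synthesized audio, large and privacy-sensitive
- `*.pt`: ChatTTS speaker embedding and model weights
- `*.log`: runtime logs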
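For a quick sanity check of which paths those patterns cover, shell-style matching from the standard library is close enough for this short list. It is only an approximation of git's matching rules, and `git check-ignore -v <path>` remains the authoritative test. The helper name and the pattern split below are assumptions made for illustration.

```python
from fnmatch import fnmatch

IGNORED_FILE_PATTERNS = ["*.wav", "*.pt", "*.log"]
IGNORED_DIRS = ["venv", "uploads", "outputs"]


def is_ignored(relative_path: str) -> bool:
    """Approximate the ignore list: match directory names, then file globs."""
    parts = relative_path.split("/")
    if any(part in IGNORED_DIRS for part in parts[:-1]):
        return True  # the file lives under an ignored directory
    return any(fnmatch(parts[-1], pattern) for pattern in IGNORED_FILE_PATTERNS)


if __name__ == "__main__":
    for path in ["speaker_emb.pt", "outputs/take1.txt", "app.py"]:
        print(path, is_ignored(path))
```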
---
## Directory Layout
```
Trading_Studio/
├── app.py                 # Gradio entry point
├── config.py              # Global configuration
├── whisper_service.py     # Whisper CUDA transcription
├── llm_service.py         # Remote Ollama polishing
├── tts_service.py         # ChatTTS speaker lock and synthesis
├── ecosystem.config.js    # PM2 process config
├── requirements.txt       # Python dependencies
├── README.md              # This file
├── DEPLOY.md              # Deployment guide
├── .gitignore
├── speaker_emb.pt         # Speaker file (generated at runtime, not committed)
├── uploads/               # Temporary upload directory
└── outputs/               # Synthesized wav output
```
---
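## Hardware Requirements
- **GPU:** NVIDIA RTX 3060 Ti (8GB VRAM; capping at 120W is recommended)
- **OS:** Ubuntu 22.04 / 24.04 LTS
- **CUDA:** 12.1+ (matching PyTorch cu121)
- **LAN:** access to the Ollama node at `192.168.8.64:11434`
---
## FAQ
**Q: Whisper reports a CUDA error?**
A: Confirm that `nvidia-smi` works and that no other VRAM-heavy task is running at the same time. Whisper runs in `float16`, which is chosen to fit in 8GB.
**Q: Ollama connection fails?**
A: Run `curl http://192.168.8.64:11434/api/tags` on the server to check connectivity, and confirm the model has been fetched with `ollama pull`.
**Q: The TTS voice is unstable?**
A: Lock the speaker again, supply an exact transcript of the reference audio, and keep randomness low with `temperature=0.3`.
**Q: The synthesized audio is empty or noisy?**
A: Check the length of the polished text (very short text can cause anomalies), and confirm that `speaker_emb.pt` exists and is valid.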
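When a finished file sounds empty, inspecting it directly narrows things down. The sketch below reads a WAV with the standard library and reports sample rate, duration and peak amplitude. It assumes 16-bit PCM; a float-encoded WAV would need a different reader such as `soundfile`. The helper name is hypothetical.

```python
import sys
import wave
from array import array


def inspect_wav(path: str) -> dict:
    """Return sample rate, duration in seconds and peak amplitude of a 16-bit PCM WAV."""
    with wave.open(path, "rb") as wav:
        rate = wav.getframerate()
        frames = wav.getnframes()
        width = wav.getsampwidth()
        raw = wav.readframes(frames)
    peak = 0
    if width == 2 and raw:
        samples = array("h")
        samples.frombytes(raw)
        if sys.byteorder == "big":
            samples.byteswap()  # WAV data is little-endian
        peak = max(abs(s) for s in samples)
    duration = frames / rate if rate else 0.0
    return {"sample_rate": rate, "duration_sec": duration, "peak": peak}
```

A file from `outputs/` would be expected to report `sample_rate` 24000. A `peak` of 0 means pure silence, which points at synthesis rather than playback.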
---
## License
Private: for personal quant trading review only.
@@ -0,0 +1,378 @@
"""
Trading Studio: automated voice-over system for trading review videos
Gradio web control panel: speaker lock → Whisper transcription → Gemma4 polishing → ChatTTS synthesis
"""
from __future__ import annotations
import logging
import shutil
import sys
import uuid
from pathlib import Path
import gradio as gr
from config import (
GIT_REPO_URL,
HOST,
MODEL_NAME,
OLLAMA_URL,
PORT,
SPEAKER_EMB_PATH,
UPLOAD_DIR,
)
from llm_service import check_ollama_health, polish_text
from tts_service import generate_voice, save_fixed_speaker, speaker_is_ready
from whisper_service import transcribe_audio
# ---------------------------------------------------------------------------
# 日志
# ---------------------------------------------------------------------------
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
handlers=[
logging.StreamHandler(sys.stdout),
logging.FileHandler("trading_studio.log", encoding="utf-8"),
],
)
logger = logging.getLogger("trading_studio")
# ---------------------------------------------------------------------------
# Global UI state (Gradio State)
# ---------------------------------------------------------------------------
# raw_transcript / polished_script are passed along the pipeline


def _save_upload(upload_file) -> str | None:
    """Copy a Gradio upload into the local uploads directory and return the persisted path."""
    if upload_file is None:
        return None
    src = Path(upload_file)
    if not src.exists():
        return None
    dest = UPLOAD_DIR / f"{uuid.uuid4().hex}_{src.name}"
    shutil.copy2(src, dest)
    return str(dest)


# ---------------------------------------------------------------------------
# Module 1: speaker lock
# ---------------------------------------------------------------------------
def ui_lock_speaker(audio_file, sample_transcript: str) -> str:
    """Speaker lock: extract and save the Speaker Embedding from the reference voice."""
    path = _save_upload(audio_file)
    if not path:
        return "请上传 10-30 秒干净参考人声(wav/mp3 均可)。"
    ok, msg = save_fixed_speaker(path, sample_transcript or "")
    return msg if ok else f"{msg}"


def ui_speaker_status() -> str:
    """Refresh the speaker status."""
    ok, msg = speaker_is_ready()
    return f"{msg}" if ok else f"⚠️ {msg}"


# ---------------------------------------------------------------------------
# Module 2: fast audio transcription
# ---------------------------------------------------------------------------
def ui_transcribe(audio_file) -> tuple[str, str]:
    """Whisper transcription: return (transcript, status log)."""
    path = _save_upload(audio_file)
    if not path:
        return "", "请上传待识别的碎碎念录音。"
    ok, result = transcribe_audio(path)
    if ok:
        return result, f"✅ 识别完成,共 {len(result)} 字。"
    return "", f"{result}"


# ---------------------------------------------------------------------------
# Module 3: Gemma4 discipline review
# ---------------------------------------------------------------------------
def ui_polish(raw_text: str) -> tuple[str, str]:
    """LLM polishing: rewrite the transcript as a strict self-reflection script."""
    if not raw_text or not raw_text.strip():
        return "", "请先完成语音识别或手动粘贴转写文本。"
    ok, result = polish_text(raw_text)
    if ok:
        return result, f"✅ Gemma4 润色完成,共 {len(result)} 字。"
    return "", f"{result}"


def ui_check_ollama() -> str:
    """Check the status of the remote Ollama node."""
    ok, msg = check_ollama_health()
    return f"{msg}" if ok else f"{msg}"


# ---------------------------------------------------------------------------
# Module 4: ChatTTS audio synthesis
# ---------------------------------------------------------------------------
def ui_synthesize(polished_text: str) -> tuple[str | None, str]:
    """TTS synthesis: generate the final wav voice-over file."""
    if not polished_text or not polished_text.strip():
        return None, "请先完成 Gemma4 润色。"
    ok, msg, wav_path = generate_voice(polished_text)
    if ok and wav_path:
        return wav_path, f"{msg}"
    return None, f"{msg}"


# ---------------------------------------------------------------------------
# One-click pipeline
# ---------------------------------------------------------------------------
def ui_full_pipeline(
    audio_file,
    skip_polish: bool,
    manual_raw: str,
) -> tuple[str, str, str | None, str]:
    """
    Run the chain: transcribe → polish (optional) → synthesize.
    Returns (raw, polished, wav_path, log).
    """
    logs: list[str] = []
    # Step 1: transcription
    if manual_raw and manual_raw.strip():
        raw = manual_raw.strip()
        logs.append(f"使用手动输入转写稿({len(raw)} 字)。")
    else:
        path = _save_upload(audio_file)
        if not path:
            return "", "", None, "❌ 请上传录音或手动填写转写文本。"
        ok, result = transcribe_audio(path)
        if not ok:
            return "", "", None, f"❌ 识别失败: {result}"
        raw = result
        logs.append(f"✅ Whisper 识别完成({len(raw)} 字)。")
    # Step 2: polishing
    if skip_polish:
        polished = raw
        logs.append("已跳过 LLM 润色,直接使用原文。")
    else:
        ok, result = polish_text(raw)
        if not ok:
            return raw, "", None, f"❌ 润色失败: {result}\n" + "\n".join(logs)
        polished = result
        logs.append(f"✅ Gemma4 润色完成({len(polished)} 字)。")
    # Step 3: synthesis
    ok, msg, wav_path = generate_voice(polished)
    if not ok:
        return raw, polished, None, f"❌ 合成失败: {msg}\n" + "\n".join(logs)
    logs.append(f"{msg}")
    return raw, polished, wav_path, "\n".join(logs)
# ---------------------------------------------------------------------------
# Gradio 界面
# ---------------------------------------------------------------------------
CUSTOM_CSS = """
/* 硬核暗黑科技风 */
.gradio-container {
background: linear-gradient(160deg, #0a0a0f 0%, #12121a 40%, #0d0d12 100%) !important;
color: #c8c8d0 !important;
}
.dark-panel {
border: 1px solid #2a2a35;
border-radius: 8px;
padding: 16px;
background: rgba(18, 18, 26, 0.85);
margin-bottom: 12px;
}
h1, h2, h3 { color: #e8e8f0 !important; letter-spacing: 0.05em; }
.status-bar {
font-family: 'Consolas', 'Monaco', monospace;
font-size: 0.85em;
color: #7a7a90;
}
footer { visibility: hidden; }
"""
def build_app() -> gr.Blocks:
"""构建 Gradio 主界面。"""
theme = gr.themes.Monochrome(
primary_hue="slate",
secondary_hue="gray",
neutral_hue="slate",
font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "Consolas", "monospace"],
).set(
body_background_fill="#0a0a0f",
body_background_fill_dark="#0a0a0f",
block_background_fill="#12121a",
block_background_fill_dark="#12121a",
block_border_color="#2a2a35",
block_label_text_color="#9090a0",
input_background_fill="#1a1a24",
button_primary_background_fill="#3a3a50",
button_primary_background_fill_hover="#4a4a60",
)
with gr.Blocks(
title="Trading Studio | 交易复盘配音中控",
theme=theme,
css=CUSTOM_CSS,
) as demo:
gr.Markdown(
f"""
# ⚡ Trading Studio
**本地量化交易复盘 → B 站配音生产流水线**
`Whisper(CUDA)` → `Gemma4 @ {OLLAMA_URL}` → `ChatTTS(CUDA)`
> 仓库: [{GIT_REPO_URL}]({GIT_REPO_URL})
""",
elem_classes=["dark-panel"],
)
with gr.Row():
ollama_status = gr.Textbox(
label="Ollama 节点",
value=f"模型: {MODEL_NAME}",
interactive=False,
scale=3,
elem_classes=["status-bar"],
)
speaker_status = gr.Textbox(
label="音色状态",
value="检测中...",
interactive=False,
scale=2,
elem_classes=["status-bar"],
)
refresh_btn = gr.Button("🔄 刷新状态", scale=1)
refresh_btn.click(
fn=lambda: (ui_check_ollama(), ui_speaker_status()),
outputs=[ollama_status, speaker_status],
)
with gr.Tabs():
# ---- Tab 1: 音色锁定 ----
with gr.Tab("🎙️ 音色锁定"):
gr.Markdown(
"上传 **10-30 秒** 干净人声样本,系统将提取 Speaker Embedding "
f"并保存至 `{SPEAKER_EMB_PATH.name}`,后续合成 100% 还原音色。"
)
with gr.Row():
spk_audio = gr.Audio(
label="参考人声(碎碎念盲录样本)",
type="filepath",
sources=["upload", "microphone"],
)
spk_transcript = gr.Textbox(
label="参考音频精确转写(可选,提升还原度)",
placeholder="尽量与参考音频内容完全一致...",
lines=6,
)
lock_btn = gr.Button("🔒 锁定音色", variant="primary")
lock_log = gr.Textbox(label="锁定结果", lines=4, interactive=False)
lock_btn.click(ui_lock_speaker, [spk_audio, spk_transcript], lock_log)
# ---- Tab 2: 分步操作 ----
with gr.Tab("🔧 分步流水线"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Step 1 · 音频极速识别")
rec_audio = gr.Audio(
label="交易复盘碎碎念录音",
type="filepath",
sources=["upload", "microphone"],
)
transcribe_btn = gr.Button("⚡ Faster-Whisper 识别", variant="primary")
transcribe_log = gr.Textbox(label="识别日志", lines=2, interactive=False)
with gr.Column(scale=1):
gr.Markdown("### Step 2 · Gemma4 纪律审判")
raw_text = gr.Textbox(
label="转写原文(可编辑)",
lines=10,
placeholder="识别结果将显示在此,也可手动粘贴...",
)
polish_btn = gr.Button("⚖️ 远程 Gemma4 严厉润色", variant="primary")
polish_log = gr.Textbox(label="润色日志", lines=2, interactive=False)
with gr.Column(scale=1):
gr.Markdown("### Step 3 · ChatTTS 配音合成")
polished_text = gr.Textbox(
label="润色配音稿(可编辑)",
lines=10,
placeholder="润色结果将显示在此...",
)
synth_btn = gr.Button("🔊 合成配音 WAV", variant="primary")
synth_log = gr.Textbox(label="合成日志", lines=2, interactive=False)
output_audio = gr.Audio(label="成品配音", type="filepath")
transcribe_btn.click(ui_transcribe, rec_audio, [raw_text, transcribe_log])
polish_btn.click(ui_polish, raw_text, [polished_text, polish_log])
synth_btn.click(ui_synthesize, polished_text, [output_audio, synth_log])
# ---- Tab 3: 一键生产 ----
with gr.Tab("🚀 一键生产"):
gr.Markdown(
"上传碎碎念录音,系统自动完成 **识别 → 润色 → 合成** 全流程。"
)
with gr.Row():
pipe_audio = gr.Audio(
label="复盘录音",
type="filepath",
sources=["upload", "microphone"],
)
pipe_manual = gr.Textbox(
label="或手动输入转写(跳过识别)",
lines=4,
placeholder="若已有转写文本,可直接粘贴,留空则走 Whisper 识别",
)
skip_polish_cb = gr.Checkbox(
label="跳过 Gemma4 润色(仅测试 TTS)",
value=False,
)
pipeline_btn = gr.Button("▶ 启动全流程", variant="primary", size="lg")
pipeline_log = gr.Textbox(label="流水线日志", lines=6, interactive=False)
with gr.Row():
pipe_raw = gr.Textbox(label="转写原文", lines=6)
pipe_polished = gr.Textbox(label="润色稿", lines=6)
pipe_output = gr.Audio(label="成品配音", type="filepath")
pipeline_btn.click(
ui_full_pipeline,
[pipe_audio, skip_polish_cb, pipe_manual],
[pipe_raw, pipe_polished, pipe_output, pipeline_log],
)
demo.load(
fn=lambda: (ui_check_ollama(), ui_speaker_status()),
outputs=[ollama_status, speaker_status],
)
return demo


def main() -> None:
    """Entry point: launch the Gradio service."""
    logger.info("Trading Studio 启动中... HOST=%s PORT=%s", HOST, PORT)
    app = build_app()
    app.launch(
        server_name=HOST,
        server_port=PORT,
        share=False,
        show_error=True,
        allowed_paths=[str(Path(__file__).resolve().parent / "outputs")],
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,82 @@
"""
Trading Studio 全局配置模块
统一存放局域网节点、模型名称、固定 Prompt 及本地路径。
"""
from pathlib import Path
# ---------------------------------------------------------------------------
# 网络与服务
# ---------------------------------------------------------------------------
# 远程 Ollama 节点(局域网大模型审查润色)
OLLAMA_HOST = "192.168.8.64"
OLLAMA_PORT = 11434
OLLAMA_URL = f"http://{OLLAMA_HOST}:{OLLAMA_PORT}/api/chat"
# 指定无限制版 Gemma4 模型
MODEL_NAME = "huihui_ai/gemma-4-abliterated:e4b"
# Gradio 中控固定端口(硬性死规则)
HOST = "0.0.0.0"
PORT = 5683
# HTTP 请求超时(秒)
OLLAMA_TIMEOUT = 60
# ---------------------------------------------------------------------------
# LLM 系统 Prompt
# ---------------------------------------------------------------------------
SYSTEM_PROMPT = (
"你是一个冷静、极其严格的数字资产量化交易员。"
"请把下面这段口语化、包含结巴和逻辑混乱的交易复盘录音转写,"
"润色成一段逻辑清晰、行文通顺的 B 站长视频反思配音稿。"
"语气要内向、克制、严谨。"
"如果原视频中有由于心态不好、违背交易纪律(如手贱乱开仓、提前平仓)"
"导致少赚或亏损的部分,请用冷酷、严厉的语气狠狠地自我吐槽、反思该点。"
"去掉所有无意义的口头禅,字数不做删减。"
)
# ---------------------------------------------------------------------------
# Faster-Whisper 配置
# ---------------------------------------------------------------------------
WHISPER_MODEL_SIZE = "small"
WHISPER_DEVICE = "cuda"
WHISPER_COMPUTE_TYPE = "float16"
WHISPER_LANGUAGE = "zh"
# ---------------------------------------------------------------------------
# ChatTTS 配置
# ---------------------------------------------------------------------------
# 项目根目录
BASE_DIR = Path(__file__).resolve().parent
# 固定音色 Embedding 存储路径
SPEAKER_EMB_PATH = BASE_DIR / "speaker_emb.pt"
# 合成音频输出目录
OUTPUT_DIR = BASE_DIR / "outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# ChatTTS sample rate (Hz)
TTS_SAMPLE_RATE = 24000
# 音色样本时长建议(秒)
SPEAKER_SAMPLE_MIN_SEC = 10
SPEAKER_SAMPLE_MAX_SEC = 30
# TTS 推理默认参数(低 temperature 有助于音色稳定)
TTS_TEMPERATURE = 0.3
TTS_TOP_P = 0.7
TTS_TOP_K = 20
TTS_SPEED_PROMPT = "[speed_5]"
# ---------------------------------------------------------------------------
# 上传临时文件目录
# ---------------------------------------------------------------------------
UPLOAD_DIR = BASE_DIR / "uploads"
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
# ---------------------------------------------------------------------------
# Git 仓库(文档引用)
# ---------------------------------------------------------------------------
GIT_REPO_URL = "https://git.bz121.com/dekun/Trading_Studio.git"
@@ -0,0 +1,26 @@
/**
* PM2 进程守护配置
* 用法: pm2 start ecosystem.config.js
*/
module.exports = {
apps: [
{
name: "trading_studio",
script: "app.py",
interpreter: "./venv/bin/python",
cwd: __dirname,
instances: 1,
autorestart: true,
watch: false,
max_memory_restart: "6G",
env: {
PYTHONUNBUFFERED: "1",
CUDA_VISIBLE_DEVICES: "0",
},
error_file: "./logs/pm2-error.log",
out_file: "./logs/pm2-out.log",
log_date_format: "YYYY-MM-DD HH:mm:ss",
merge_logs: true,
},
],
};
@@ -0,0 +1,162 @@
"""
远程 Ollama LLM 润色服务
通过局域网 HTTP 请求 Gemma4 模型,对交易复盘转写稿进行纪律审判式润色。
"""
from __future__ import annotations
import logging
from typing import Tuple
import requests
from config import MODEL_NAME, OLLAMA_TIMEOUT, OLLAMA_URL, SYSTEM_PROMPT
logger = logging.getLogger(__name__)


def _build_payload(raw_text: str) -> dict:
    """Build the non-streaming request body for Ollama /api/chat."""
    return {
        "model": MODEL_NAME,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": (
                    "以下是我的交易复盘录音转写原文,请严格按系统要求润色:\n\n"
                    f"{raw_text}"
                ),
            },
        ],
        "stream": False,
        "options": {
            "temperature": 0.7,
            "num_predict": 4096,
        },
    }


def _extract_content(response_json: dict) -> str:
    """Extract the assistant text from an Ollama response JSON."""
    # Standard /api/chat format
    message = response_json.get("message")
    if isinstance(message, dict):
        content = message.get("content", "").strip()
        if content:
            return content
    # Fallback for the /api/generate format (some older versions or proxies)
    if "response" in response_json:
        content = str(response_json["response"]).strip()
        if content:
            return content
    raise ValueError(f"无法从 Ollama 响应中解析文本内容: {response_json}")


def polish_text(raw_text: str) -> Tuple[bool, str]:
    """
    Call the remote Ollama node to polish the raw transcript.

    Args:
        raw_text: raw spoken-style text produced by Whisper

    Returns:
        (success, polished_text_or_error_message)
    """
    if not raw_text or not raw_text.strip():
        return False, "润色输入为空,请先完成语音识别。"
    payload = _build_payload(raw_text.strip())
    try:
        logger.info("正在请求 Ollama: %s, model=%s", OLLAMA_URL, MODEL_NAME)
        response = requests.post(
            OLLAMA_URL,
            json=payload,
            timeout=OLLAMA_TIMEOUT,
        )
        response.raise_for_status()
        data = response.json()
        polished = _extract_content(data)
        if not polished:
            return False, "Ollama 返回内容为空,请检查模型是否正常加载。"
        logger.info("润色完成,输出字数: %d", len(polished))
        return True, polished
    except requests.exceptions.ConnectTimeout:
        err = (
            f"连接 Ollama 超时(>{OLLAMA_TIMEOUT}s)。"
            f"请确认 {OLLAMA_URL} 可达且 Ollama 服务已启动。"
        )
        logger.error(err)
        return False, err
    except requests.exceptions.ReadTimeout:
        err = (
            f"Ollama 响应超时(>{OLLAMA_TIMEOUT}s)。"
            "模型可能正在加载或生成长度过长,请稍后重试。"
        )
        logger.error(err)
        return False, err
    except requests.exceptions.ConnectionError as exc:
        err = (
            f"无法连接到 Ollama 节点 ({OLLAMA_URL})。"
            "请检查局域网连通性、防火墙及 Ollama 是否监听 0.0.0.0:11434。\n"
            f"详情: {exc}"
        )
        logger.error(err)
        return False, err
    except requests.exceptions.HTTPError as exc:
        status = exc.response.status_code if exc.response is not None else "?"
        body = exc.response.text[:500] if exc.response is not None else ""
        err = (
            f"Ollama HTTP 错误 ({status})。"
            f"请确认模型 `{MODEL_NAME}` 已通过 ollama pull 下载。\n"
            f"响应片段: {body}"
        )
        logger.error(err)
        return False, err
    except ValueError as exc:
        logger.error("Ollama 响应解析失败: %s", exc)
        return False, str(exc)
    except requests.exceptions.RequestException as exc:
        err = f"Ollama 请求异常: {exc}"
        logger.exception(err)
        return False, err
    except Exception as exc:
        err = f"润色过程发生未知错误: {exc}"
        logger.exception(err)
        return False, err


def check_ollama_health() -> Tuple[bool, str]:
    """
    Quickly check whether the Ollama node is online (without triggering full inference).

    Returns:
        (online, message)
    """
    base_url = OLLAMA_URL.rsplit("/api/", 1)[0]
    try:
        resp = requests.get(f"{base_url}/api/tags", timeout=10)
        resp.raise_for_status()
        tags = resp.json().get("models", [])
        model_names = [m.get("name", "") for m in tags]
        if any(MODEL_NAME.split(":")[0] in name for name in model_names):
            return True, f"Ollama 在线,已检测到模型: {MODEL_NAME}"
        # The two fragments are concatenated, so end the first with a full stop.
        return True, (
            f"Ollama 在线,但未找到模型 {MODEL_NAME}。"
            f"请执行: ollama pull {MODEL_NAME}"
        )
    except Exception as exc:
        return False, f"Ollama 不可达: {exc}"
@@ -0,0 +1,23 @@
# Trading Studio 依赖清单
# CUDA 版 PyTorch 请按 DEPLOY.md 单独安装(cu121),此处不重复指定
# Web 中控
gradio>=4.44.0
# 语音识别(CUDA 加速)
faster-whisper>=1.0.0
# 远程 LLM 通信
requests>=2.31.0
# 语音合成
ChatTTS @ git+https://github.com/2noise/ChatTTS.git
torchaudio>=2.1.0
scipy>=1.11.0
numpy>=1.24.0
librosa>=0.10.0
# 音频处理辅助
soundfile>=0.12.0
# PM2 通过 Node.js 全局安装,不在 pip 范围内
@@ -0,0 +1,305 @@
"""
ChatTTS 本地语音合成服务
支持从参考人声提取 Speaker Embedding 并固定音色合成配音。
"""
from __future__ import annotations
import logging
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
from scipy.io import wavfile
from config import (
OUTPUT_DIR,
SPEAKER_EMB_PATH,
SPEAKER_SAMPLE_MAX_SEC,
SPEAKER_SAMPLE_MIN_SEC,
TTS_SAMPLE_RATE,
TTS_SPEED_PROMPT,
TTS_TEMPERATURE,
TTS_TOP_K,
TTS_TOP_P,
)
logger = logging.getLogger(__name__)
# 全局 ChatTTS 实例
_chat = None
_chat_error: Optional[str] = None
def _load_audio_for_chattts(audio_path: str, sample_rate: int = TTS_SAMPLE_RATE) -> np.ndarray:
"""
加载音频并重采样到 ChatTTS 所需采样率。
优先使用 ChatTTS 自带工具,回退到 librosa。
"""
try:
from ChatTTS.utils import load_audio
return load_audio(audio_path, sample_rate)
except ImportError:
pass
try:
from tools.audio import load_audio
return load_audio(audio_path, sample_rate)
except ImportError:
pass
import librosa
audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
return audio
def _get_audio_duration_sec(audio: np.ndarray, sample_rate: int) -> float:
"""计算音频时长(秒)。"""
if audio is None or len(audio) == 0:
return 0.0
return len(audio) / float(sample_rate)


def get_chattts_instance():
    """
    Get or initialize the ChatTTS model.
    Uses GPU acceleration with compile=False to stay within the 3060 Ti's 8GB of VRAM.
    """
    global _chat, _chat_error
    if _chat is not None:
        return _chat, None
    if _chat_error is not None:
        return None, _chat_error
    try:
        import ChatTTS
        logger.info("正在加载 ChatTTS 模型...")
        chat = ChatTTS.Chat()
        # Support both API generations: load_models (older) / load (newer)
        if hasattr(chat, "load_models"):
            chat.load_models(compile=False)
        elif hasattr(chat, "load"):
            chat.load(compile=False)
        else:
            raise RuntimeError("当前 ChatTTS 版本缺少 load / load_models 方法。")
        _chat = chat
        logger.info("ChatTTS 模型加载成功。")
        return _chat, None
    except ImportError as exc:
        _chat_error = (
            "未安装 ChatTTS,请参考 DEPLOY.md 安装。\n"
            f"原始错误: {exc}"
        )
        logger.exception("ChatTTS 导入失败")
        return None, _chat_error
    except Exception as exc:
        _chat_error = f"ChatTTS 模型加载失败: {exc}\n{traceback.format_exc()}"
        logger.exception("ChatTTS 初始化异常")
        return None, _chat_error


def _encode_spk_emb(chat, tensor_or_str: Any) -> Any:
    """Encode a Speaker Embedding into the string format that ChatTTS accepts."""
    if isinstance(tensor_or_str, str):
        return tensor_or_str
    if hasattr(chat, "_encode_spk_emb"):
        return chat._encode_spk_emb(tensor_or_str)
    # Fallback: return the value unchanged (some versions accept a tensor directly)
    return tensor_or_str


def save_fixed_speaker(
    audio_sample_path: str,
    sample_transcript: str = "",
) -> Tuple[bool, str]:
    """
    Extract a speaker embedding from 10-30 seconds of clean speech and save it.

    Args:
        audio_sample_path: Path to the reference voice recording (wav, mp3, etc.).
        sample_transcript: Exact transcript of the reference audio (optional;
            helps zero-shot voice reproduction).

    Returns:
        (success, message)
    """
    if not audio_sample_path:
        return False, "No reference audio was provided."

    chat, init_err = get_chattts_instance()
    if chat is None:
        return False, init_err or "ChatTTS is unavailable."

    try:
        audio = _load_audio_for_chattts(audio_sample_path, TTS_SAMPLE_RATE)
        duration = _get_audio_duration_sec(audio, TTS_SAMPLE_RATE)

        if duration < SPEAKER_SAMPLE_MIN_SEC:
            return False, (
                f"Reference audio is too short ({duration:.1f}s); use "
                f"{SPEAKER_SAMPLE_MIN_SEC}-{SPEAKER_SAMPLE_MAX_SEC} seconds of clean speech."
            )
        if duration > SPEAKER_SAMPLE_MAX_SEC + 5:
            logger.warning(
                "Reference audio is %.1fs, longer than recommended; keeping the first %ds",
                duration,
                SPEAKER_SAMPLE_MAX_SEC,
            )
            # Slice indices must be integers; report the duration actually used.
            max_samples = int(SPEAKER_SAMPLE_MAX_SEC * TTS_SAMPLE_RATE)
            audio = audio[:max_samples]
            duration = _get_audio_duration_sec(audio, TTS_SAMPLE_RATE)

        # Extract the voice characteristics from the reference audio.
        spk_smp = chat.sample_audio_speaker(audio)
        # Also store the encoded spk_emb string so infer() can use it directly.
        spk_emb = _encode_spk_emb(chat, spk_smp)

        payload: Dict[str, Any] = {
            "spk_emb": spk_emb,
            "spk_smp": spk_smp,
            "txt_smp": sample_transcript.strip(),
            "created_at": datetime.now().isoformat(),
            "source_audio": str(audio_sample_path),
        }
        torch.save(payload, SPEAKER_EMB_PATH)

        msg = (
            f"Voice locked and saved to {SPEAKER_EMB_PATH}\n"
            f"Reference duration: {duration:.1f}s"
        )
        if not sample_transcript.strip():
            msg += "\nTip: an exact transcript of the reference audio can further improve voice fidelity."
        logger.info("Speaker embedding saved: %s", SPEAKER_EMB_PATH)
        return True, msg
    except Exception as exc:
        err = f"Voice extraction failed: {exc}\n{traceback.format_exc()}"
        logger.exception("save_fixed_speaker failed")
        return False, err
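

# --- Illustrative sketch (not part of the original module) -------------------
# The duration gate in save_fixed_speaker() can be tried without ChatTTS.
# This is a minimal stdlib-only sketch of the same idea: reject clips shorter
# than the minimum, and truncate clips that exceed the maximum by more than a
# grace period. The name example_clamp_reference_clip, its parameters, and the
# default limits are assumptions made for illustration; the real limits come
# from SPEAKER_SAMPLE_MIN_SEC and SPEAKER_SAMPLE_MAX_SEC in config.
from typing import List, Sequence


def example_clamp_reference_clip(
    samples: Sequence[float],
    sample_rate: int,
    min_sec: float = 10.0,
    max_sec: float = 30.0,
    grace_sec: float = 5.0,
) -> List[float]:
    """Return the clip, truncated to max_sec when it is too long."""
    duration = len(samples) / float(sample_rate)
    if duration < min_sec:
        raise ValueError(f"clip too short: {duration:.1f}s < {min_sec:.1f}s")
    if duration > max_sec + grace_sec:
        # Slice indices must be integers, hence the int() cast.
        return list(samples[: int(max_sec * sample_rate)])
    return list(samples)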


def _load_speaker_payload() -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Load the local speaker_emb.pt file."""
    if not SPEAKER_EMB_PATH.exists():
        return None, (
            f"Fixed-voice file `{SPEAKER_EMB_PATH.name}` was not found. "
            "Upload 10-30 seconds of reference speech in the voice lock module first."
        )
    try:
        # weights_only=False unpickles arbitrary objects; only load trusted local files.
        payload = torch.load(SPEAKER_EMB_PATH, map_location="cpu", weights_only=False)
        # Backward compatibility: older files stored only a tensor.
        if isinstance(payload, torch.Tensor):
            chat, err = get_chattts_instance()
            if chat is None:
                return None, err
            return {
                "spk_emb": _encode_spk_emb(chat, payload),
                "spk_smp": None,
                "txt_smp": "",
            }, None
        if not isinstance(payload, dict):
            return None, "speaker_emb.pt has an invalid format; lock the voice again."
        return payload, None
    except Exception as exc:
        return None, f"Failed to read speaker_emb.pt: {exc}"


def speaker_is_ready() -> Tuple[bool, str]:
    """Check whether a fixed voice has been configured."""
    payload, err = _load_speaker_payload()
    if payload is None:
        return False, err or "No voice has been configured."
    return True, f"Fixed voice loaded: {SPEAKER_EMB_PATH}"


def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
    """
    Synthesize the refined script into a wav voice-over with ChatTTS.

    Args:
        refined_text: Voice-over script after LLM refinement.

    Returns:
        (success, message, output_wav_path_or_none)
    """
    if not refined_text or not refined_text.strip():
        return False, "The text to synthesize is empty; refine the script first.", None

    chat, init_err = get_chattts_instance()
    if chat is None:
        return False, init_err or "ChatTTS is unavailable.", None

    payload, spk_err = _load_speaker_payload()
    if payload is None:
        return False, spk_err or "Lock a voice first.", None

    try:
        import ChatTTS

        spk_emb = payload.get("spk_emb")
        spk_smp = payload.get("spk_smp")
        txt_smp = payload.get("txt_smp", "")

        params_infer_code = ChatTTS.Chat.InferCodeParams(
            prompt=TTS_SPEED_PROMPT,
            spk_emb=spk_emb,
            spk_smp=spk_smp if spk_smp else None,
            txt_smp=txt_smp if txt_smp else None,
            temperature=TTS_TEMPERATURE,
            top_P=TTS_TOP_P,
            top_K=TTS_TOP_K,
        )
        # Restrained, introverted delivery: keep the oral intensity low.
        params_refine_text = ChatTTS.Chat.RefineTextParams(
            prompt="[oral_2][laugh_0][break_4]",
        )

        wavs = chat.infer(
            refined_text.strip(),
            skip_refine_text=False,
            params_refine_text=params_refine_text,
            params_infer_code=params_infer_code,
        )
        # Use an explicit length check: `not wavs` is ambiguous for NumPy arrays.
        if wavs is None or len(wavs) == 0:
            return False, "ChatTTS did not produce any valid audio.", None

        wav_array = np.asarray(wavs[0], dtype=np.float32)
        # Peak-normalize and convert to int16; fall back to 1.0 for silent output.
        peak = np.max(np.abs(wav_array)) or 1.0
        wav_int16 = (wav_array / peak * 32767).astype(np.int16)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"voiceover_{timestamp}_{uuid.uuid4().hex[:6]}.wav"
        output_path = OUTPUT_DIR / filename
        wavfile.write(str(output_path), TTS_SAMPLE_RATE, wav_int16)

        msg = f"Voice-over synthesized successfully: {output_path}"
        logger.info(msg)
        return True, msg, str(output_path)
    except Exception as exc:
        err = f"Speech synthesis failed: {exc}\n{traceback.format_exc()}"
        logger.exception("generate_voice failed")
        return False, err, None
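

# --- Illustrative sketch (not part of the original module) -------------------
# generate_voice() peak-normalizes the float waveform and writes 16-bit PCM.
# The sketch below shows the same arithmetic with the standard library only,
# which may help when checking the scaling by hand. The function names are
# hypothetical, and the sample rate is whatever the caller passes in (the
# module itself uses TTS_SAMPLE_RATE from config).
import array
import sys
import wave
from typing import Sequence


def example_peak_normalize_int16(samples: Sequence[float]) -> array.array:
    """Scale samples so the loudest one maps to +/-32767; silence stays silent."""
    # `or 1.0` avoids dividing by zero when every sample is 0.0 (or none exist).
    peak = max((abs(s) for s in samples), default=0.0) or 1.0
    # int() truncates toward zero, like a float-to-int16 cast.
    return array.array("h", (int(s / peak * 32767) for s in samples))


def example_write_mono_wav(path: str, pcm: array.array, sample_rate: int) -> None:
    """Write 16-bit mono PCM to a WAV file with the stdlib wave module."""
    data = array.array("h", pcm)
    if sys.byteorder == "big":
        data.byteswap()  # WAV stores samples as little-endian
    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 2 bytes per sample = int16
        wf.setframerate(sample_rate)
        wf.writeframes(data.tobytes())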
+155
@@ -0,0 +1,155 @@
"""
Faster-Whisper CUDA speech recognition service.

Wraps local GPU-accelerated audio transcription, sized for the 8GB of VRAM
on an RTX 3060 Ti.
"""

from __future__ import annotations

import logging
import traceback
from typing import Optional, Tuple

from config import (
    WHISPER_COMPUTE_TYPE,
    WHISPER_DEVICE,
    WHISPER_LANGUAGE,
    WHISPER_MODEL_SIZE,
)

logger = logging.getLogger(__name__)

# Module-level lazily loaded model, so Gradio does not reinitialize it and
# occupy VRAM more than once.
_model = None
_model_error: Optional[str] = None


def _is_cuda_error(exc: BaseException) -> bool:
    """Return True if the exception message looks CUDA/GPU related."""
    msg = str(exc).lower()
    cuda_keywords = (
        "cuda",
        "cudnn",
        "cublas",
        "gpu",
        "out of memory",
        "no kernel image",
        "device-side assert",
    )
    return any(k in msg for k in cuda_keywords)


def get_whisper_model():
    """
    Get or lazily initialize the Faster-Whisper model.

    Device and compute type come from config (WHISPER_DEVICE and
    WHISPER_COMPUTE_TYPE); the intended values are cuda and float16.
    A failed load is cached until reset_whisper_model() is called.
    """
    global _model, _model_error

    if _model is not None:
        return _model, None
    if _model_error is not None:
        return None, _model_error

    try:
        from faster_whisper import WhisperModel

        logger.info(
            "Loading Whisper model: size=%s, device=%s, compute_type=%s",
            WHISPER_MODEL_SIZE,
            WHISPER_DEVICE,
            WHISPER_COMPUTE_TYPE,
        )
        _model = WhisperModel(
            WHISPER_MODEL_SIZE,
            device=WHISPER_DEVICE,
            compute_type=WHISPER_COMPUTE_TYPE,
        )
        logger.info("Whisper model loaded successfully.")
        return _model, None
    except ImportError as exc:
        _model_error = (
            "faster-whisper is not installed; run: pip install faster-whisper\n"
            f"Original error: {exc}"
        )
        logger.exception("faster-whisper import failed")
        return None, _model_error
    except Exception as exc:
        if _is_cuda_error(exc):
            _model_error = (
                "CUDA initialization failed; check that the NVIDIA driver, "
                "CUDA runtime and cuDNN are installed correctly.\n"
                f"Error details: {exc}\n"
                f"{traceback.format_exc()}"
            )
        else:
            _model_error = f"Failed to load the Whisper model: {exc}\n{traceback.format_exc()}"
        logger.exception("Whisper model loading error")
        return None, _model_error
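

# --- Illustrative sketch (not part of the original module) -------------------
# get_whisper_model() follows a "load once, remember the outcome" pattern:
# the first call either stores the model or stores the error text, and later
# calls return the stored result without retrying. The class below is a
# minimal, dependency-free sketch of that pattern. ExampleLazyResource and
# its loader argument are hypothetical names used only for illustration.
from typing import Callable, Optional, Tuple


class ExampleLazyResource:
    """Load a resource on first use; remember either the object or the error."""

    def __init__(self, loader: Callable[[], object]) -> None:
        self._loader = loader
        self._value: Optional[object] = None
        self._error: Optional[str] = None

    def get(self) -> Tuple[Optional[object], Optional[str]]:
        if self._value is not None:
            return self._value, None
        if self._error is not None:
            # A previous attempt failed; do not retry until reset() is called.
            return None, self._error
        try:
            self._value = self._loader()
            return self._value, None
        except Exception as exc:
            self._error = f"load failed: {exc}"
            return None, self._error

    def reset(self) -> None:
        """Forget both the value and the cached error so get() tries again."""
        self._value = None
        self._error = None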


def transcribe_audio(audio_path: str) -> Tuple[bool, str]:
    """
    Transcribe an audio file to text (Chinese, per WHISPER_LANGUAGE in config).

    Args:
        audio_path: Absolute or relative path to a local audio file.

    Returns:
        (success, text_or_error_message)
    """
    if not audio_path:
        return False, "No audio file path was provided."

    model, init_error = get_whisper_model()
    if model is None:
        return False, init_error or "The Whisper model is unavailable."

    try:
        segments, info = model.transcribe(
            audio_path,
            language=WHISPER_LANGUAGE,
            beam_size=5,
            vad_filter=True,
        )
        # segments is consumed lazily, so decoding errors surface in this loop.
        text_parts = []
        for segment in segments:
            text_parts.append(segment.text.strip())

        result_text = "".join(text_parts).strip()
        if not result_text:
            return False, (
                "The transcription is empty; check that the audio is valid and loud enough, "
                f"or try another format. Detected language: {getattr(info, 'language', 'unknown')}"
            )

        logger.info(
            "Transcription finished: language=%s, probability=%.2f, chars=%d",
            getattr(info, "language", "?"),
            getattr(info, "language_probability", 0.0),
            len(result_text),
        )
        return True, result_text
    except Exception as exc:
        if _is_cuda_error(exc):
            err = (
                "CUDA inference error: VRAM may be insufficient or the GPU is in a bad state. "
                "Close other processes that use VRAM and try again.\n"
                f"Error details: {exc}"
            )
        else:
            err = f"Audio transcription failed: {exc}\n{traceback.format_exc()}"
        logger.exception("transcribe_audio failed")
        return False, err


def reset_whisper_model() -> None:
    """Drop the model reference and cached error (for debugging or VRAM reclamation)."""
    global _model, _model_error
    _model = None
    _model_error = None
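

# --- Illustrative sketch (not part of the original module) -------------------
# reset_whisper_model() only clears this module's reference. The object is
# freed only if nothing else still points at it. The stdlib-only sketch below
# makes that visible with a weak reference. ExampleModelSlot and _ExampleModel
# are hypothetical stand-ins; whether GPU memory is returned right away also
# depends on the inference backend, which this sketch does not model.
import gc
import weakref
from typing import Optional


class _ExampleModel:
    """Stand-in for a loaded model object."""


class ExampleModelSlot:
    """Holds one model reference and reports whether releasing it freed the object."""

    def __init__(self, model: Optional[_ExampleModel]) -> None:
        self.model = model

    def release(self) -> bool:
        if self.model is None:
            return True
        probe = weakref.ref(self.model)
        self.model = None
        gc.collect()  # also clears reference cycles, if any
        # True only when no other strong reference kept the object alive.
        return probe() is None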