Compare commits
24 Commits
aacdffac77
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| d26bec085c | |||
| 22a87a4322 | |||
| fbfca0b890 | |||
| 56f14206dd | |||
| d63cb318b2 | |||
| ca49b2feed | |||
| 54523e39af | |||
| 1acba0349c | |||
| 2dd642598f | |||
| 541df29722 | |||
| 4255cf7cd7 | |||
| bdc63c04df | |||
| 7c50b13c57 | |||
| 97c11e08e0 | |||
| 038e00fbcf | |||
| 131cbf070a | |||
| eb71e28427 | |||
| 8be34a2fd5 | |||
| 1779449bba | |||
| 0cce6cda7c | |||
| 82f99c0b89 | |||
| f36056d293 | |||
| 0f5277c22e | |||
| 39e29fe6a9 |
@@ -10,4 +10,12 @@ OLLAMA_PORT=11434
|
||||
|
||||
# ChatTTS 模型目录(预下载脚本写入)
|
||||
# CHATTTS_MODEL_DIR=/opt/Trading_Studio/models/ChatTTS
|
||||
# WHISPER_MODEL_DIR=/opt/Trading_Studio/models/whisper
|
||||
# WHISPER_MODEL_SIZE=small
|
||||
# HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
# 8GB 显存 OOM 时可调低(合成按段切分)
|
||||
# TTS_MAX_CHARS_PER_CHUNK=150
|
||||
# TTS_MAX_NEW_TOKEN=768
|
||||
# TTS_MIN_NEW_TOKEN=16
|
||||
# TTS_ENABLE_CACHE=true
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
## 目录
|
||||
|
||||
0. [**一键部署(推荐)**](#0-一键部署推荐)
|
||||
0. [**一键部署(推荐)**](#0-一键部署推荐) — 含 [模型预下载](#08-ai-模型预下载内网服务器必做)、[服务器更新](#042-代码推送后的服务器更新推荐)、[手机麦克风](#09-手机找不到麦克风)
|
||||
1. [硬件与系统前提](#1-硬件与系统前提)
|
||||
2. [3060 Ti 120W 功耗墙配置](#2-3060-ti-120w-功耗墙配置)
|
||||
3. [NVIDIA 驱动与 CUDA](#3-nvidia-驱动与-cuda)
|
||||
@@ -96,6 +96,8 @@ bash deploy.sh
|
||||
http://<服务器局域网IP>:5683
|
||||
```
|
||||
|
||||
> **重要:首次部署后必须预下载 AI 模型**(Whisper + ChatTTS)。内网服务器无法访问 HuggingFace / GitHub 时,不执行此步会在 Web UI 报 `Network is unreachable` 或 `Read timed out`。详见 [0.8 AI 模型预下载](#08-ai-模型预下载内网服务器必做)。
|
||||
|
||||
### 0.3 脚本命令速查
|
||||
|
||||
```bash
|
||||
@@ -123,6 +125,45 @@ bash deploy.sh update
|
||||
> **git pull 报本地修改冲突?** 新版 `deploy.sh` 会自动 `stash` 后同步;若仍失败可手动:
|
||||
> `git fetch origin && git reset --hard origin/main`
|
||||
|
||||
### 0.4.2 代码推送后的服务器更新(推荐)
|
||||
|
||||
本地开发机 `git push` 到远端后,在 **Ubuntu 服务器**上同步并重启:
|
||||
|
||||
```bash
|
||||
cd /opt/Trading_Studio
|
||||
bash server-update.sh
|
||||
```
|
||||
|
||||
`server-update.sh` 会执行:
|
||||
|
||||
1. `git fetch origin main`
|
||||
2. `git reset --hard origin/main`(覆盖 CRLF 等幽灵改动;**Ollama 地址请写在 `.env`,勿改 `config.py`**)
|
||||
3. `pm2 restart trading_studio`
|
||||
|
||||
若本次更新涉及 **Whisper / ChatTTS 离线加载**(首次部署或新增模型脚本),还需预下载模型:
|
||||
|
||||
```bash
|
||||
cd /opt/Trading_Studio
|
||||
bash server-update.sh
|
||||
bash scripts/download_all_models.sh
|
||||
pm2 restart trading_studio
|
||||
```
|
||||
|
||||
| 脚本 | 作用 |
|
||||
|------|------|
|
||||
| `bash server-update.sh` | 强制与远端 `main` 同步 + PM2 重启 |
|
||||
| `bash scripts/download_all_models.sh` | 一次性下载 Whisper (small) + ChatTTS |
|
||||
| `bash scripts/download_whisper_models.sh small` | 仅下载 Whisper |
|
||||
| `bash scripts/download_chattts_models.sh` | 仅下载 ChatTTS |
|
||||
|
||||
验证 Whisper 是否就绪:
|
||||
|
||||
```bash
|
||||
ls -lh /opt/Trading_Studio/models/whisper/small/model.bin
|
||||
```
|
||||
|
||||
应看到约 **500MB** 的 `model.bin` 文件。
|
||||
|
||||
### 0.4.1 pip / PyTorch 下载超时
|
||||
|
||||
PyTorch + triton 约 2-3GB,国内网络默认启用清华镜像,并延长超时到 600 秒:
|
||||
@@ -161,6 +202,72 @@ SKIP_PYTORCH=1 bash deploy.sh deps
|
||||
4. 反代须透传 **WebSocket**(Gradio 必需)
|
||||
5. 用户通过 `https://你的域名` 访问后再安装 App
|
||||
|
||||
### 0.8 AI 模型预下载(内网服务器必做)
|
||||
|
||||
Trading Studio 的 **Whisper 语音识别** 与 **ChatTTS 音色合成** 均需在服务器本地存放模型文件。
|
||||
内网物理机通常无法访问 `huggingface.co` / `github.com`,若未预下载,Web UI 会出现:
|
||||
|
||||
| 模块 | 典型报错 |
|
||||
|------|----------|
|
||||
| Whisper | `Network is unreachable` / `ConnectError` |
|
||||
| ChatTTS | `Read timed out` / `github.com` 连接失败 |
|
||||
|
||||
**推荐:一键下载全部模型**
|
||||
|
||||
```bash
|
||||
cd /opt/Trading_Studio
|
||||
bash scripts/download_all_models.sh
|
||||
pm2 restart trading_studio
|
||||
```
|
||||
|
||||
**分步下载(可选)**
|
||||
|
||||
```bash
|
||||
# Whisper small(约 500MB,识别默认模型)
|
||||
bash scripts/download_whisper_models.sh small
|
||||
|
||||
# ChatTTS(约 1–2GB,音色锁定与合成必需)
|
||||
bash scripts/download_chattts_models.sh
|
||||
|
||||
pm2 restart trading_studio
|
||||
```
|
||||
|
||||
**模型落盘路径**
|
||||
|
||||
| 模型 | 目录 | 关键文件 |
|
||||
|------|------|----------|
|
||||
| Whisper `small` | `/opt/Trading_Studio/models/whisper/small/` | `model.bin` |
|
||||
| ChatTTS | `/opt/Trading_Studio/models/ChatTTS/` | `asset/` 等 |
|
||||
| HF 缓存 | `/opt/Trading_Studio/models/hf_cache/` | 下载中间缓存 |
|
||||
|
||||
**`.env` 可选配置**(复制 `.env.example` → `.env`):
|
||||
|
||||
```ini
|
||||
HF_ENDPOINT=https://hf-mirror.com
|
||||
WHISPER_MODEL_DIR=/opt/Trading_Studio/models/whisper
|
||||
WHISPER_MODEL_SIZE=small
|
||||
CHATTTS_MODEL_DIR=/opt/Trading_Studio/models/ChatTTS
|
||||
```
|
||||
|
||||
国内服务器建议保留 `HF_ENDPOINT=https://hf-mirror.com`(脚本与 `whisper_service.py` 均会读取)。
|
||||
|
||||
### 0.9 手机「找不到麦克风」
|
||||
|
||||
通过 `http://192.168.x.x:5683` 内网 HTTP 访问时,手机浏览器会显示 **「找不到麦克风」** 或 **「检测不到麦克风」**。
|
||||
这是浏览器安全策略:`getUserMedia`(麦克风)**仅在 HTTPS 或 localhost 下可用**,不是程序 bug。
|
||||
|
||||
| 访问方式 | 电脑录音 | 手机录音 |
|
||||
|----------|----------|----------|
|
||||
| `http://内网IP:5683` | 可能可用 | ❌ 不可用 |
|
||||
| `https://域名`(NPS + 云反代) | ✅ | ✅ |
|
||||
|
||||
**解决办法(任选其一):**
|
||||
|
||||
1. 按 [PWA_NPS.md](./PWA_NPS.md) 配置 **NPS 穿透 + 云服务器 HTTPS 域名**,用手机访问 `https://你的域名`
|
||||
2. 在 HTTP 内网环境下,使用音频区域的 **「上传」** 标签,上传手机「语音备忘录」导出的 `.m4a` / `.wav`(与现场录音效果相同)
|
||||
|
||||
Whisper 离线模型就绪后,**上传音频文件** 可正常识别;麦克风实时录音需 HTTPS。
|
||||
|
||||
---
|
||||
|
||||
### 0.5 PM2 运维(root 环境)
|
||||
@@ -181,11 +288,21 @@ tail -f /opt/Trading_Studio/logs/pm2-out.log
|
||||
```
|
||||
/opt/Trading_Studio/
|
||||
├── deploy.sh # 一键部署脚本
|
||||
├── server-update.sh # 强制同步远端 + PM2 重启
|
||||
├── app.py # Gradio 主入口
|
||||
├── venv/ # Python 虚拟环境
|
||||
├── scripts/
|
||||
│ ├── download_all_models.sh # Whisper + ChatTTS 一键下载
|
||||
│ ├── download_whisper_models.sh # Whisper 预下载(HF 镜像)
|
||||
│ └── download_chattts_models.sh # ChatTTS 预下载(HF 镜像)
|
||||
├── models/ # AI 模型(预下载脚本写入,不入 Git)
|
||||
│ ├── whisper/small/ # Faster-Whisper(含 model.bin)
|
||||
│ ├── ChatTTS/ # ChatTTS 权重
|
||||
│ └── hf_cache/ # HuggingFace 缓存
|
||||
├── logs/ # PM2 日志
|
||||
├── uploads/ # 上传临时文件
|
||||
├── outputs/ # 合成 wav 输出
|
||||
├── .env # 服务器本地配置(Ollama IP 等,不入 Git)
|
||||
├── speaker_emb.pt # 音色文件(Web UI 生成,需手动备份)
|
||||
└── trading_studio.log # 应用日志
|
||||
```
|
||||
@@ -361,9 +478,26 @@ cd /opt/Trading_Studio
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 6.1 Faster-Whisper
|
||||
### 6.1 Faster-Whisper(必须预下载)
|
||||
|
||||
随 `requirements.txt` 安装。首次运行会自动下载 `small` 模型(约 500MB)至 HuggingFace 缓存。
|
||||
随 `requirements.txt` 安装。`whisper_service.py` **优先从本地目录加载**,未预下载时会尝试在线拉取 HuggingFace 模型。
|
||||
|
||||
内网服务器无法访问外网时会报:
|
||||
|
||||
```
|
||||
Whisper 模型加载失败: Network is unreachable
|
||||
```
|
||||
|
||||
**处理:** 见 [0.8 AI 模型预下载](#08-ai-模型预下载内网服务器必做),或执行:
|
||||
|
||||
```bash
|
||||
bash scripts/download_whisper_models.sh small
|
||||
pm2 restart trading_studio
|
||||
```
|
||||
|
||||
本地路径:`/opt/Trading_Studio/models/whisper/small/model.bin`(约 500MB)。
|
||||
|
||||
可选环境变量(`.env`):`WHISPER_MODEL_DIR`、`WHISPER_MODEL_SIZE`、`HF_ENDPOINT`。
|
||||
|
||||
### 6.2 ChatTTS(必须预下载,勿依赖 GitHub)
|
||||
|
||||
@@ -373,27 +507,19 @@ pip install -r requirements.txt
|
||||
pip install ChatTTS @ git+https://github.com/2noise/ChatTTS.git
|
||||
```
|
||||
|
||||
**重要:** 默认 `chat.load()` 会访问 **github.com** 下载 asset,国内服务器常报 `Read timed out (3)`。
|
||||
部署后**必须**执行预下载脚本(走 HuggingFace 镜像):
|
||||
**重要:** 默认 `chat.load()` 会访问 **github.com** 下载 asset,国内/内网服务器常报 `Read timed out`。
|
||||
`tts_service.py` 已支持从 `models/ChatTTS` 离线加载,部署后**必须**预下载:
|
||||
|
||||
```bash
|
||||
cd /opt/Trading_Studio
|
||||
source venv/bin/activate
|
||||
bash scripts/download_chattts_models.sh
|
||||
pm2 restart trading_studio
|
||||
```
|
||||
|
||||
模型保存至 `/opt/Trading_Studio/models/ChatTTS`(约 1–2GB,不入 Git)。
|
||||
|
||||
`.env` 可自定义:
|
||||
|
||||
```ini
|
||||
HF_ENDPOINT=https://hf-mirror.com
|
||||
CHATTTS_MODEL_DIR=/opt/Trading_Studio/models/ChatTTS
|
||||
```
|
||||
|
||||
下载完成后再在 Web UI 点击「锁定音色」。
|
||||
|
||||
一键下载 Whisper + ChatTTS:`bash scripts/download_all_models.sh`
|
||||
|
||||
### 6.3 Gradio
|
||||
|
||||
```bash
|
||||
@@ -470,9 +596,11 @@ http://<本机局域网IP>:5683
|
||||
|
||||
### 8.1 验证清单
|
||||
|
||||
- [ ] `models/whisper/small/model.bin` 存在(`bash scripts/download_whisper_models.sh small`)
|
||||
- [ ] `models/ChatTTS/` 已预下载(`bash scripts/download_chattts_models.sh`)
|
||||
- [ ] 页面加载,Ollama 状态显示在线
|
||||
- [ ] 上传 10-30s 参考人声 → 音色锁定成功,生成 `speaker_emb.pt`
|
||||
- [ ] 上传复盘录音 → Whisper 识别出中文文本
|
||||
- [ ] 上传复盘录音 → Whisper 识别出中文文本(无需外网)
|
||||
- [ ] 点击润色 → 返回 Gemma4 处理后的文稿
|
||||
- [ ] 点击合成 → `outputs/` 下生成 24kHz wav
|
||||
|
||||
@@ -605,12 +733,45 @@ nvidia-smi
|
||||
fuser -v /dev/nvidia*
|
||||
```
|
||||
|
||||
Whisper 与 ChatTTS 不会同时常驻最大显存,但首次加载模型时峰值较高。建议:
|
||||
Whisper 与 ChatTTS **不能同时常驻** 8GB 显存(会 CUDA OOM)。应用已自动互斥卸载:
|
||||
|
||||
- 锁定 120W 功耗墙
|
||||
- `max_memory_restart: "6G"` 已在 PM2 配置中设置
|
||||
- 识别前卸载 ChatTTS
|
||||
- 合成 / 锁定音色前卸载 Whisper
|
||||
|
||||
### 10.3 Whisper CUDA 报错
|
||||
若仍 OOM:
|
||||
|
||||
```bash
|
||||
pm2 restart trading_studio
|
||||
nvidia-smi # 确认无其他占 GPU 进程
|
||||
```
|
||||
|
||||
在 `.env` 调低合成峰值:
|
||||
|
||||
```ini
|
||||
TTS_MAX_CHARS_PER_CHUNK=150
|
||||
TTS_MAX_NEW_TOKEN=768
|
||||
```
|
||||
|
||||
PM2 已配置 `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` 缓解碎片。建议锁定 120W 功耗墙。
|
||||
|
||||
### 10.3 Whisper 模型加载失败
|
||||
|
||||
#### A. `Network is unreachable` / `ConnectError`(内网无外网)
|
||||
|
||||
**原因:** 未预下载 Whisper 模型,程序尝试访问 HuggingFace Hub 失败。
|
||||
|
||||
**处理:**
|
||||
|
||||
```bash
|
||||
cd /opt/Trading_Studio
|
||||
bash scripts/download_whisper_models.sh small
|
||||
ls -lh models/whisper/small/model.bin # 确认约 500MB
|
||||
pm2 restart trading_studio
|
||||
```
|
||||
|
||||
若服务器可访问外网但 HuggingFace 慢,在 `.env` 中设置 `HF_ENDPOINT=https://hf-mirror.com` 后重试下载。
|
||||
|
||||
#### B. CUDA / 显存报错
|
||||
|
||||
```
|
||||
错误: CUDA initialization failed / out of memory
|
||||
@@ -620,7 +781,7 @@ Whisper 与 ChatTTS 不会同时常驻最大显存,但首次加载模型时峰
|
||||
|
||||
1. 重启 PM2 进程释放显存
|
||||
2. 确认 `compute_type="float16"`(已在 config.py 配置)
|
||||
3. 降级模型为 `base`(修改 `config.py` 中 `WHISPER_MODEL_SIZE`)
|
||||
3. 在 `.env` 中降级模型:`WHISPER_MODEL_SIZE=base`,并执行 `bash scripts/download_whisper_models.sh base`
|
||||
|
||||
### 10.4 Ollama 超时
|
||||
|
||||
@@ -634,14 +795,31 @@ Whisper 与 ChatTTS 不会同时常驻最大显存,但首次加载模型时峰
|
||||
2. 增大 `config.py` 中 `OLLAMA_TIMEOUT`
|
||||
3. 检查防火墙:`sudo ufw allow from 192.168.8.0/24 to any port 11434`(在 Ollama 节点)
|
||||
|
||||
### 10.5 ChatTTS 音色文件损坏
|
||||
### 10.5 ChatTTS 合成报 `Corrupt input data`
|
||||
|
||||
**原因:** 音色参数传错。`sample_audio_speaker()` 的结果应作为 **`spk_smp`**,不能同时误传给 **`spk_emb`**(LZMA 解压失败)。旧版 `speaker_emb.pt` 或未填参考转写时常见。
|
||||
|
||||
**处理:**
|
||||
|
||||
```bash
|
||||
rm /opt/Trading_Studio/speaker_emb.pt
|
||||
pm2 restart trading_studio
|
||||
```
|
||||
|
||||
在 Web UI「音色锁定」:
|
||||
|
||||
1. 上传 10–30 秒干净参考人声
|
||||
2. **填写与录音完全一致的「参考音频精确转写」**(必填)
|
||||
3. 重新点击「锁定音色」后再合成
|
||||
|
||||
### 10.6 ChatTTS 音色文件损坏
|
||||
|
||||
```bash
|
||||
rm speaker_emb.pt
|
||||
# 重新在 Web UI「音色锁定」上传参考人声
|
||||
# 重新在 Web UI「音色锁定」上传参考人声并填写转写
|
||||
```
|
||||
|
||||
### 10.6 端口 5683 被占用
|
||||
### 10.7 端口 5683 被占用
|
||||
|
||||
```bash
|
||||
sudo lsof -i :5683
|
||||
@@ -649,6 +827,13 @@ sudo lsof -i :5683
|
||||
ss -tlnp | grep 5683
|
||||
```
|
||||
|
||||
### 10.8 手机「找不到麦克风」
|
||||
|
||||
内网 `http://192.168.x.x:5683` 下手机无法使用实时录音,属浏览器 HTTPS 安全限制。
|
||||
完整说明与 NPS 穿透方案见 [0.9 手机「找不到麦克风」](#09-手机找不到麦克风) 与 [PWA_NPS.md](./PWA_NPS.md) 第九节。
|
||||
|
||||
**临时方案:** Web UI 音频区域使用 **「上传」** 导入录音文件,Whisper 识别流程相同。
|
||||
|
||||
---
|
||||
|
||||
## 附录:防火墙(本机 Gradio)
|
||||
@@ -664,16 +849,30 @@ sudo ufw reload
|
||||
|
||||
---
|
||||
|
||||
## 附录:config.py 关键常量速查
|
||||
## 附录:config.py / .env 关键配置速查
|
||||
|
||||
**服务器本地覆盖请用 `.env`**(`cp .env.example .env`),避免 `git pull` 冲突:
|
||||
|
||||
```ini
|
||||
OLLAMA_HOST=192.168.8.64
|
||||
OLLAMA_PORT=11434
|
||||
HF_ENDPOINT=https://hf-mirror.com
|
||||
WHISPER_MODEL_DIR=/opt/Trading_Studio/models/whisper
|
||||
WHISPER_MODEL_SIZE=small
|
||||
CHATTTS_MODEL_DIR=/opt/Trading_Studio/models/ChatTTS
|
||||
```
|
||||
|
||||
`config.py` 默认值(可被 `.env` 覆盖):
|
||||
|
||||
```python
|
||||
HOST = "0.0.0.0"
|
||||
PORT = 5683
|
||||
OLLAMA_URL = "http://192.168.8.64:11434/api/chat"
|
||||
MODEL_NAME = "huihui_ai/gemma-4-abliterated:e4b"
|
||||
WHISPER_MODEL_SIZE = "small"
|
||||
WHISPER_MODEL_SIZE = "small" # .env: WHISPER_MODEL_SIZE
|
||||
WHISPER_MODEL_DIR = "models/whisper" # .env: WHISPER_MODEL_DIR
|
||||
WHISPER_DEVICE = "cuda"
|
||||
WHISPER_COMPUTE_TYPE = "float16"
|
||||
CHATTTS_MODEL_DIR = "models/ChatTTS"
|
||||
HF_ENDPOINT = "https://hf-mirror.com"
|
||||
SPEAKER_EMB_PATH = "speaker_emb.pt"
|
||||
TTS_SAMPLE_RATE = 24000
|
||||
```
|
||||
|
||||
+3
-2
@@ -220,6 +220,7 @@ Trading Studio 应用层也会发送该头;若反代覆盖了响应头,需
|
||||
|
||||
## 相关文档
|
||||
|
||||
- 内网部署:`DEPLOY.md`
|
||||
- 服务器更新:`bash server-update.sh`
|
||||
- 内网部署与模型预下载:[DEPLOY.md](./DEPLOY.md)(§0.4.2 服务器更新、§0.8 模型预下载、§0.9 麦克风)
|
||||
- 服务器快速更新:`bash server-update.sh`(同步远端 `main` + PM2 重启)
|
||||
- 首次部署后下载模型:`bash scripts/download_all_models.sh`
|
||||
- 麦克风问题:见上文 **第九节**
|
||||
|
||||
@@ -66,12 +66,30 @@ bash deploy.sh
|
||||
|
||||
浏览器访问:`http://<服务器IP>:5683`
|
||||
|
||||
日常更新:
|
||||
**首次部署后必做 — 预下载 AI 模型**(内网服务器无外网时必需,否则 Whisper 报 `Network is unreachable`):
|
||||
|
||||
```bash
|
||||
cd /opt/Trading_Studio && bash deploy.sh update
|
||||
cd /opt/Trading_Studio
|
||||
bash scripts/download_all_models.sh
|
||||
pm2 restart trading_studio
|
||||
```
|
||||
|
||||
日常更新(代码已 `git push` 到远端后,在服务器执行):
|
||||
|
||||
```bash
|
||||
cd /opt/Trading_Studio
|
||||
bash server-update.sh
|
||||
```
|
||||
|
||||
若更新涉及模型脚本或首次部署,追加:
|
||||
|
||||
```bash
|
||||
bash scripts/download_all_models.sh
|
||||
pm2 restart trading_studio
|
||||
```
|
||||
|
||||
完整说明见 [DEPLOY.md §0.4.2 / §0.8](./DEPLOY.md)。
|
||||
|
||||
### 手动部署(开发调试)
|
||||
|
||||
```bash
|
||||
@@ -88,11 +106,25 @@ python app.py
|
||||
|
||||
## 使用流程
|
||||
|
||||
### 首次使用:锁定音色
|
||||
### 配音音色(全部本地 GPU,无需 API)
|
||||
|
||||
| 方式 | 说明 |
|
||||
|------|------|
|
||||
| **我的锁定音色** | 「音色锁定」上传你的人声 → 声音克隆(`speaker_emb.pt`) |
|
||||
| **预设男/女声** | ChatTTS 内置说话人,合成页下拉选择(类似微软音色列表) |
|
||||
|
||||
首次使用预设音色(服务器执行一次):
|
||||
|
||||
```bash
|
||||
bash scripts/generate_voice_presets.sh
|
||||
pm2 restart trading_studio
|
||||
```
|
||||
|
||||
### 首次使用:锁定音色(可选,用于克隆自己的声音)
|
||||
|
||||
1. 进入 **「音色锁定」** 标签页
|
||||
2. 上传 10-30 秒干净人声参考(你的碎碎念盲录样本)
|
||||
3. (可选)填写参考音频的精确转写,提升 zero-shot 还原度
|
||||
3. 填写参考音频的精确转写(强烈建议)
|
||||
4. 点击 **锁定音色** → 生成 `speaker_emb.pt`
|
||||
|
||||
### 日常生产
|
||||
@@ -116,7 +148,9 @@ python app.py
|
||||
| 中控端口 | `5683`(`0.0.0.0` 局域网可访问) |
|
||||
| Ollama 地址 | `http://192.168.8.64:11434` |
|
||||
| 模型名称 | `huihui_ai/gemma-4-abliterated:e4b` |
|
||||
| Whisper 模型 | `small` / CUDA / float16 |
|
||||
| Whisper 模型 | `small` / CUDA / float16,本地路径 `models/whisper/small/` |
|
||||
| ChatTTS 模型 | `models/ChatTTS/`(须预下载脚本) |
|
||||
| HF 镜像 | `HF_ENDPOINT=https://hf-mirror.com`(`.env` 可改) |
|
||||
| 音色文件 | `speaker_emb.pt` |
|
||||
| 音频输出 | `outputs/` 目录 |
|
||||
|
||||
@@ -170,16 +204,25 @@ outputs/
|
||||
```
|
||||
Trading_Studio/
|
||||
├── deploy.sh # 一键部署脚本(/opt + PM2)
|
||||
├── server-update.sh # 强制同步远端 main + PM2 重启
|
||||
├── app.py # Gradio 主入口
|
||||
├── config.py # 全局配置
|
||||
├── whisper_service.py # Whisper CUDA 识别
|
||||
├── config.py # 全局配置(Ollama 等请用 .env 覆盖)
|
||||
├── whisper_service.py # Whisper CUDA 识别(优先本地模型)
|
||||
├── llm_service.py # Ollama 远程润色
|
||||
├── tts_service.py # ChatTTS 音色与合成
|
||||
├── tts_service.py # ChatTTS 音色与合成(优先本地模型)
|
||||
├── scripts/
|
||||
│ ├── download_all_models.sh # Whisper + ChatTTS 一键下载
|
||||
│ ├── download_whisper_models.sh
|
||||
│ └── download_chattts_models.sh
|
||||
├── models/ # AI 模型(预下载,不入 Git)
|
||||
│ ├── whisper/small/
|
||||
│ └── ChatTTS/
|
||||
├── ecosystem.config.js # PM2 守护配置
|
||||
├── requirements.txt # Python 依赖
|
||||
├── .env.example # 服务器本地配置模板 → 复制为 .env
|
||||
├── README.md # 本文件
|
||||
├── DEPLOY.md # 部署指南(含一键部署教程)
|
||||
├── PWA_NPS.md # 云服务器反代 + NPS 穿透 + PWA 安装教程
|
||||
├── DEPLOY.md # 部署指南(含模型预下载、故障排查)
|
||||
├── PWA_NPS.md # HTTPS / NPS 穿透 / 手机麦克风教程
|
||||
├── .gitignore
|
||||
├── speaker_emb.pt # 音色文件(运行时生成,不入库)
|
||||
├── uploads/ # 上传临时目录
|
||||
@@ -201,15 +244,27 @@ Trading_Studio/
|
||||
|
||||
## 常见问题
|
||||
|
||||
**Q: Whisper 报 `Network is unreachable`?**
|
||||
A: 内网服务器无法访问 HuggingFace。执行 `bash scripts/download_whisper_models.sh small`,确认 `models/whisper/small/model.bin` 存在后 `pm2 restart trading_studio`。详见 [DEPLOY.md §0.8](./DEPLOY.md)。
|
||||
|
||||
**Q: Whisper 报 CUDA 错误?**
|
||||
A: 确认 `nvidia-smi` 正常,且未同时运行其他占显存任务。Whisper 使用 `float16` 已针对 8GB 优化。
|
||||
A: 确认 `nvidia-smi` 正常,且未同时运行其他占显存任务。Whisper 使用 `float16` 已针对 8GB 优化。可在 `.env` 设置 `WHISPER_MODEL_SIZE=base` 并重新下载。
|
||||
|
||||
**Q: ChatTTS 报 GitHub / 下载超时?**
|
||||
A: 执行 `bash scripts/download_chattts_models.sh`,或一键 `bash scripts/download_all_models.sh`。
|
||||
|
||||
**Q: Ollama 连接失败?**
|
||||
A: 在服务器上执行 `curl http://192.168.8.64:11434/api/tags` 验证连通性,确认模型已 `ollama pull`。
|
||||
A: 在服务器上执行 `curl http://192.168.8.64:11434/api/tags` 验证连通性,确认模型已 `ollama pull`。Ollama IP 写在 `.env` 的 `OLLAMA_HOST`。
|
||||
|
||||
**Q: 手机显示「找不到麦克风」?**
|
||||
A: `http://内网IP:5683` 非 HTTPS,浏览器禁用麦克风。请按 [PWA_NPS.md](./PWA_NPS.md) 配置 HTTPS 域名,或改用 Web UI **「上传」** 录音文件。
|
||||
|
||||
**Q: TTS 音色不稳定?**
|
||||
A: 重新锁定音色,填写参考音频精确转写,并保持 `temperature=0.3` 低随机性。
|
||||
|
||||
**Q: 合成报 `Corrupt input data`?**
|
||||
A: 音色参数格式问题。删除 `speaker_emb.pt`,重新锁定音色并**填写参考音频精确转写**。详见 [DEPLOY.md §10.5](./DEPLOY.md)。
|
||||
|
||||
**Q: 合成音频为空或噪声?**
|
||||
A: 检查润色文本长度(过短可能导致异常),确认 `speaker_emb.pt` 存在且有效。
|
||||
|
||||
|
||||
@@ -39,6 +39,13 @@ def _env_int(key: str, default: int) -> int:
|
||||
return default
|
||||
|
||||
|
||||
def _env_bool(key: str, default: bool) -> bool:
|
||||
raw = os.environ.get(key)
|
||||
if raw is None:
|
||||
return default
|
||||
return raw.strip().lower() in ("1", "true", "yes", "on")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 网络与服务
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -74,25 +81,40 @@ SYSTEM_PROMPT = (
|
||||
"如果原视频中有由于心态不好、违背交易纪律(如手贱乱开仓、提前平仓)"
|
||||
"导致少赚或亏损的部分,请用冷酷、严厉的语气狠狠地自我吐槽、反思该点。"
|
||||
"去掉所有无意义的口头禅,字数不做删减。"
|
||||
"【输出格式硬性要求】"
|
||||
"只输出可直接朗读的纯文本正文,不要 Markdown(禁止 #、**、---、列表符号、emoji)。"
|
||||
"不要写舞台提示(如前奏、转场、BGM、语气说明等括号备注)。"
|
||||
"不要写「以下是润色后的文案」等前言,也不要写修改笔记或点评。"
|
||||
"可用《》作为标题,正文按自然段换行即可。"
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 路径
|
||||
# ---------------------------------------------------------------------------
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
INSTALL_DIR = Path("/opt/Trading_Studio")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Faster-Whisper 配置
|
||||
# ---------------------------------------------------------------------------
|
||||
WHISPER_MODEL_SIZE = "small"
|
||||
WHISPER_MODEL_SIZE = _env_str("WHISPER_MODEL_SIZE", "small")
|
||||
WHISPER_DEVICE = "cuda"
|
||||
WHISPER_COMPUTE_TYPE = "float16"
|
||||
WHISPER_LANGUAGE = "zh"
|
||||
WHISPER_MODEL_DIR = Path(_env_str("WHISPER_MODEL_DIR", str(BASE_DIR / "models" / "whisper")))
|
||||
|
||||
WHISPER_HF_REPO = {
|
||||
"tiny": "Systran/faster-whisper-tiny",
|
||||
"base": "Systran/faster-whisper-base",
|
||||
"small": "Systran/faster-whisper-small",
|
||||
"medium": "Systran/faster-whisper-medium",
|
||||
"large-v2": "Systran/faster-whisper-large-v2",
|
||||
"large-v3": "Systran/faster-whisper-large-v3",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChatTTS 配置
|
||||
# ---------------------------------------------------------------------------
|
||||
# 标准生产安装路径(/opt,root 部署)
|
||||
INSTALL_DIR = Path("/opt/Trading_Studio")
|
||||
|
||||
# 项目根目录(开发/生产均自适应,以实际 app.py 所在目录为准)
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
|
||||
# 固定音色 Embedding 存储路径
|
||||
SPEAKER_EMB_PATH = BASE_DIR / "speaker_emb.pt"
|
||||
|
||||
@@ -120,6 +142,22 @@ TTS_TEMPERATURE = 0.3
|
||||
TTS_TOP_P = 0.7
|
||||
TTS_TOP_K = 20
|
||||
TTS_SPEED_PROMPT = "[speed_5]"
|
||||
# 多段拼接时各段必须使用同一随机种子,否则音色会像「换了个人」
|
||||
TTS_MANUAL_SEED = _env_int("TTS_MANUAL_SEED", 42)
|
||||
# 段间静音间隔(秒)
|
||||
TTS_SEGMENT_PAUSE_SEC = 0.35
|
||||
|
||||
# 单段 TTS 最大字数(超长稿按句切分后逐段合成再拼接;8GB 显存建议 ≤200)
|
||||
TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 200)
|
||||
|
||||
# ChatTTS 单段最大生成 token(越小越省显存,长句会自动切多段)
|
||||
TTS_MAX_NEW_TOKEN = _env_int("TTS_MAX_NEW_TOKEN", 1024)
|
||||
|
||||
# 至少生成多少 audio token 才允许结束(防止首 token EOS → 无限递归重试)
|
||||
TTS_MIN_NEW_TOKEN = _env_int("TTS_MIN_NEW_TOKEN", 16)
|
||||
|
||||
# GPT KV cache(关闭可省显存,但部分 transformers 版本会触发 CUDA assert)
|
||||
TTS_ENABLE_CACHE = _env_bool("TTS_ENABLE_CACHE", True)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 上传临时文件目录
|
||||
|
||||
@@ -21,6 +21,7 @@ module.exports = {
|
||||
env: {
|
||||
PYTHONUNBUFFERED: "1",
|
||||
CUDA_VISIBLE_DEVICES: "0",
|
||||
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True",
|
||||
},
|
||||
error_file: path.join(APP_DIR, "logs/pm2-error.log"),
|
||||
out_file: path.join(APP_DIR, "logs/pm2-out.log"),
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
"""GPU 显存回收工具(3060 Ti 8GB:Whisper 与 ChatTTS 不宜同时驻留)。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def release_cuda_cache() -> None:
|
||||
"""触发 GC 并清空 PyTorch CUDA 缓存。"""
|
||||
gc.collect()
|
||||
try:
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
torch.cuda.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
torch.cuda.empty_cache()
|
||||
if hasattr(torch.cuda, "ipc_collect"):
|
||||
torch.cuda.ipc_collect()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def is_cuda_runtime_error(exc: BaseException) -> bool:
|
||||
msg = str(exc).lower()
|
||||
return any(
|
||||
k in msg
|
||||
for k in (
|
||||
"cuda error",
|
||||
"device-side assert",
|
||||
"out of memory",
|
||||
"cublas",
|
||||
"illegal memory access",
|
||||
"an illegal instruction",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def cuda_memory_summary() -> str:
|
||||
"""返回简要显存占用(调试用)。"""
|
||||
try:
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
return "CUDA 不可用"
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
return f"GPU 显存: 已用 {(total - free) / 1024**3:.2f}GB / {total / 1024**3:.2f}GB"
|
||||
except Exception as exc:
|
||||
return f"显存查询失败: {exc}"
|
||||
@@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env bash
|
||||
# 一次性下载 Whisper + ChatTTS 全部模型(内网服务器部署必跑)
|
||||
set -euo pipefail
|
||||
ROOT="$(cd "$(dirname "$0")/.." && pwd)"
|
||||
cd "${ROOT}"
|
||||
|
||||
echo "========== 下载 Whisper (small) =========="
|
||||
bash scripts/download_whisper_models.sh small
|
||||
|
||||
echo ""
|
||||
echo "========== 下载 ChatTTS =========="
|
||||
bash scripts/download_chattts_models.sh
|
||||
|
||||
echo ""
|
||||
echo "[OK] 全部模型下载完成,请: pm2 restart trading_studio"
|
||||
@@ -0,0 +1,59 @@
|
||||
#!/usr/bin/env bash
|
||||
# 预下载 Faster-Whisper 模型(HF 镜像,内网服务器离线可用)
|
||||
# 用法: bash scripts/download_whisper_models.sh [tiny|base|small|medium|large-v3]
|
||||
set -euo pipefail
|
||||
|
||||
ROOT="$(cd "$(dirname "$0")/.." && pwd)"
|
||||
SIZE="${1:-small}"
|
||||
VENV_PY="${ROOT}/venv/bin/python"
|
||||
MODEL_DIR="${WHISPER_MODEL_DIR:-${ROOT}/models/whisper}/${SIZE}"
|
||||
|
||||
export HF_ENDPOINT="${HF_ENDPOINT:-https://hf-mirror.com}"
|
||||
export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-600}"
|
||||
export HF_HOME="${HF_HOME:-${ROOT}/models/hf_cache}"
|
||||
export MODEL_DIR
|
||||
export WHISPER_SIZE="${SIZE}"
|
||||
|
||||
echo "[INFO] Whisper 模型: ${SIZE}"
|
||||
echo "[INFO] 保存目录: ${MODEL_DIR}"
|
||||
echo "[INFO] HF 镜像: ${HF_ENDPOINT}"
|
||||
|
||||
if [[ ! -x "${VENV_PY}" ]]; then
|
||||
echo "[ERROR] 未找到 venv,请先 bash deploy.sh deps"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
"${VENV_PY}" -m pip install -q huggingface_hub
|
||||
|
||||
"${VENV_PY}" << 'PY'
|
||||
import os
|
||||
from pathlib import Path
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
size = os.environ["WHISPER_SIZE"]
|
||||
repos = {
|
||||
"tiny": "Systran/faster-whisper-tiny",
|
||||
"base": "Systran/faster-whisper-base",
|
||||
"small": "Systran/faster-whisper-small",
|
||||
"medium": "Systran/faster-whisper-medium",
|
||||
"large-v2": "Systran/faster-whisper-large-v2",
|
||||
"large-v3": "Systran/faster-whisper-large-v3",
|
||||
}
|
||||
repo = repos.get(size)
|
||||
if not repo:
|
||||
raise SystemExit(f"未知模型尺寸: {size}, 可选: {list(repos)}")
|
||||
|
||||
target = Path(os.environ["MODEL_DIR"])
|
||||
target.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"[INFO] 正在下载 {repo} ...")
|
||||
snapshot_download(repo_id=repo, local_dir=str(target), local_dir_use_symlinks=False)
|
||||
|
||||
if not (target / "model.bin").is_file():
|
||||
raise SystemExit(f"[ERROR] 下载不完整,未找到 model.bin: {target}")
|
||||
|
||||
print(f"[OK] Whisper 模型就绪: {target}")
|
||||
PY
|
||||
|
||||
echo ""
|
||||
echo "[OK] 请执行: pm2 restart trading_studio"
|
||||
@@ -0,0 +1,80 @@
|
||||
#!/usr/bin/env python3
|
||||
"""生成 ChatTTS 本地预设说话人(sample_random_speaker,走 GPU)。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
import torch
|
||||
|
||||
from config import CHATTTS_MODEL_DIR
|
||||
from tts_service import get_chattts_instance, reset_chattts_instance
|
||||
from voice_presets import (
|
||||
DEFAULT_MANIFEST,
|
||||
MANIFEST_PATH,
|
||||
PRESETS_DIR,
|
||||
VOICES_DIR,
|
||||
ensure_manifest,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ensure_manifest()
|
||||
PRESETS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
from whisper_service import reset_whisper_model
|
||||
|
||||
reset_whisper_model()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reset_chattts_instance()
|
||||
chat, err = get_chattts_instance()
|
||||
if chat is None:
|
||||
raise SystemExit(f"ChatTTS 加载失败: {err}")
|
||||
|
||||
if not hasattr(chat, "sample_random_speaker"):
|
||||
raise SystemExit("当前 ChatTTS 版本不支持 sample_random_speaker")
|
||||
|
||||
presets = DEFAULT_MANIFEST["presets"]
|
||||
print(f"[INFO] 生成 {len(presets)} 个预设音色 → {PRESETS_DIR}")
|
||||
|
||||
for item in presets:
|
||||
pid = item["id"]
|
||||
label = item["label"]
|
||||
out_path = PRESETS_DIR / f"{pid}.pt"
|
||||
|
||||
spk_emb = chat.sample_random_speaker()
|
||||
payload = {
|
||||
"version": 1,
|
||||
"preset": True,
|
||||
"id": pid,
|
||||
"label": label,
|
||||
"spk_emb": spk_emb,
|
||||
"spk_smp": None,
|
||||
"txt_smp": "",
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"source": "ChatTTS.sample_random_speaker",
|
||||
}
|
||||
torch.save(payload, out_path)
|
||||
print(f" [OK] {label} → {out_path.name}")
|
||||
|
||||
manifest = json.loads(MANIFEST_PATH.read_text(encoding="utf-8"))
|
||||
manifest["generated_at"] = datetime.now().isoformat()
|
||||
manifest["chattts_model"] = str(CHATTTS_MODEL_DIR)
|
||||
MANIFEST_PATH.write_text(
|
||||
json.dumps(manifest, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
print("[OK] 全部预设音色生成完成")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,16 @@
|
||||
#!/usr/bin/env bash
|
||||
# 生成本地 GPU 预设音色(ChatTTS 内置说话人,无需 API)
|
||||
# 用法: bash scripts/generate_voice_presets.sh
|
||||
set -euo pipefail
|
||||
ROOT="$(cd "$(dirname "$0")/.." && pwd)"
|
||||
VENV_PY="${ROOT}/venv/bin/python"
|
||||
|
||||
if [[ ! -x "${VENV_PY}" ]]; then
|
||||
echo "[ERROR] 未找到 venv,请先 bash deploy.sh deps"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "[INFO] 正在生成 ChatTTS 预设音色(本地 GPU)..."
|
||||
"${VENV_PY}" "${ROOT}/scripts/generate_voice_presets.py"
|
||||
echo "[OK] 预设音色已写入 ${ROOT}/voices/presets/"
|
||||
echo "[OK] 在 Web UI「配音合成」处可从下拉框选择"
|
||||
+532
-47
@@ -8,11 +8,14 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
import warnings
|
||||
from dataclasses import replace
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -28,7 +31,13 @@ from config import (
|
||||
SPEAKER_EMB_PATH,
|
||||
SPEAKER_SAMPLE_MAX_SEC,
|
||||
SPEAKER_SAMPLE_MIN_SEC,
|
||||
TTS_MAX_CHARS_PER_CHUNK,
|
||||
TTS_ENABLE_CACHE,
|
||||
TTS_MANUAL_SEED,
|
||||
TTS_MAX_NEW_TOKEN,
|
||||
TTS_MIN_NEW_TOKEN,
|
||||
TTS_SAMPLE_RATE,
|
||||
TTS_SEGMENT_PAUSE_SEC,
|
||||
TTS_SPEED_PROMPT,
|
||||
TTS_TEMPERATURE,
|
||||
TTS_TOP_K,
|
||||
@@ -89,7 +98,7 @@ def _load_chat_model(chat) -> None:
|
||||
_ensure_hf_env()
|
||||
model_dir = CHATTTS_MODEL_DIR
|
||||
|
||||
base_kwargs: Dict[str, Any] = {"compile": False}
|
||||
base_kwargs: Dict[str, Any] = {"compile": False, "enable_cache": TTS_ENABLE_CACHE}
|
||||
|
||||
if not hasattr(chat, "load"):
|
||||
if hasattr(chat, "load_models"):
|
||||
@@ -134,11 +143,26 @@ def _load_chat_model(chat) -> None:
|
||||
|
||||
|
||||
def reset_chattts_instance() -> None:
|
||||
"""释放 ChatTTS 实例(模型下载后重启前可调用)。"""
|
||||
"""卸载 ChatTTS 模型并回收 GPU 显存。"""
|
||||
global _chat, _chat_error
|
||||
if _chat is not None:
|
||||
try:
|
||||
if hasattr(_chat, "unload"):
|
||||
_chat.unload()
|
||||
except Exception:
|
||||
logger.exception("ChatTTS unload 失败")
|
||||
try:
|
||||
del _chat
|
||||
except Exception:
|
||||
pass
|
||||
_chat = None
|
||||
_chat_error = None
|
||||
|
||||
from gpu_utils import release_cuda_cache
|
||||
|
||||
release_cuda_cache()
|
||||
logger.info("ChatTTS 模型已卸载,显存已尝试回收。")
|
||||
|
||||
|
||||
def get_chattts_instance():
|
||||
"""
|
||||
@@ -179,29 +203,81 @@ def get_chattts_instance():
|
||||
return None, _chat_error
|
||||
|
||||
|
||||
def _load_audio_via_ffmpeg(audio_path: str, sample_rate: int) -> np.ndarray:
|
||||
"""通过 ffmpeg 转码为 wav 再读取,兼容手机 webm/m4a 等格式。"""
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
import soundfile as sf
|
||||
|
||||
tmp_path = tempfile.mktemp(suffix=".wav")
|
||||
try:
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
audio_path,
|
||||
"-ac",
|
||||
"1",
|
||||
"-ar",
|
||||
str(sample_rate),
|
||||
"-f",
|
||||
"wav",
|
||||
tmp_path,
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(result.stderr[-500:] if result.stderr else "ffmpeg 转码失败")
|
||||
|
||||
audio, _ = sf.read(tmp_path, dtype="float32", always_2d=False)
|
||||
if isinstance(audio, np.ndarray) and audio.ndim > 1:
|
||||
audio = audio.mean(axis=1)
|
||||
return np.asarray(audio, dtype=np.float32)
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
def _load_audio_for_chattts(audio_path: str, sample_rate: int = TTS_SAMPLE_RATE) -> np.ndarray:
|
||||
"""
|
||||
加载音频并重采样到 ChatTTS 所需采样率。
|
||||
优先使用 ChatTTS 自带工具,回退到 librosa。
|
||||
优先 ChatTTS 工具 → ffmpeg 转码 → librosa 兜底。
|
||||
"""
|
||||
errors: list[str] = []
|
||||
|
||||
try:
|
||||
from ChatTTS.utils import load_audio
|
||||
|
||||
return load_audio(audio_path, sample_rate)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
errors.append(f"ChatTTS.utils: {exc}")
|
||||
|
||||
try:
|
||||
from tools.audio import load_audio
|
||||
|
||||
return load_audio(audio_path, sample_rate)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
errors.append(f"tools.audio: {exc}")
|
||||
|
||||
import librosa
|
||||
try:
|
||||
return _load_audio_via_ffmpeg(audio_path, sample_rate)
|
||||
except Exception as exc:
|
||||
errors.append(f"ffmpeg: {exc}")
|
||||
|
||||
audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
|
||||
return audio
|
||||
try:
|
||||
import librosa
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
warnings.filterwarnings("ignore", message="PySoundFile failed")
|
||||
audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
|
||||
return audio
|
||||
except Exception as exc:
|
||||
errors.append(f"librosa: {exc}")
|
||||
|
||||
raise RuntimeError(
|
||||
"无法读取音频文件,请上传 wav/mp3/m4a 或确认已安装 ffmpeg。\n"
|
||||
+ "\n".join(errors[-3:])
|
||||
)
|
||||
|
||||
|
||||
def _get_audio_duration_sec(audio: np.ndarray, sample_rate: int) -> float:
|
||||
@@ -211,15 +287,69 @@ def _get_audio_duration_sec(audio: np.ndarray, sample_rate: int) -> float:
|
||||
return len(audio) / float(sample_rate)
|
||||
|
||||
|
||||
def _encode_spk_emb(chat, tensor_or_str: Any) -> str:
|
||||
"""将 Speaker Embedding 编码为 ChatTTS 可用的字符串格式。"""
|
||||
if isinstance(tensor_or_str, str):
|
||||
return tensor_or_str
|
||||
|
||||
def _encode_random_spk_emb(chat, tensor: torch.Tensor) -> Optional[str]:
|
||||
"""将随机说话人向量编码为 spk_emb 字符串(仅用于 sample_random,非参考音频)。"""
|
||||
speaker = getattr(chat, "speaker", None)
|
||||
if speaker is not None and hasattr(speaker, "_encode"):
|
||||
return speaker._encode(tensor)
|
||||
if hasattr(chat, "_encode_spk_emb"):
|
||||
return chat._encode_spk_emb(tensor_or_str)
|
||||
return chat._encode_spk_emb(tensor)
|
||||
return None
|
||||
|
||||
return tensor_or_str
|
||||
|
||||
def _is_valid_spk_emb_string(chat, spk_emb: str) -> bool:
|
||||
"""spk_emb 与 spk_smp 编码不同;非法字符串会在 lzma 解压时报 Corrupt input data。"""
|
||||
speaker = getattr(chat, "speaker", None)
|
||||
if speaker is None or not hasattr(speaker, "_decode"):
|
||||
return False
|
||||
try:
|
||||
speaker._decode(spk_emb)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_speaker_for_infer(
|
||||
chat,
|
||||
payload: Dict[str, Any],
|
||||
) -> Tuple[Optional[Dict[str, Optional[str]]], Optional[str]]:
|
||||
"""
|
||||
规范 ChatTTS 音色参数。
|
||||
参考音频克隆必须用 spk_smp + txt_smp,不能把 sample_audio_speaker 结果传给 spk_emb。
|
||||
"""
|
||||
spk_smp = payload.get("spk_smp")
|
||||
txt_smp = (payload.get("txt_smp") or "").strip() or None
|
||||
spk_emb = payload.get("spk_emb")
|
||||
warn: Optional[str] = None
|
||||
|
||||
if spk_smp:
|
||||
if not txt_smp:
|
||||
warn = (
|
||||
"未填写参考音频转写(txt_smp),音色克隆可能不稳定。"
|
||||
"建议在「音色锁定」补充精确转写后重新锁定。"
|
||||
)
|
||||
return {"spk_smp": spk_smp, "txt_smp": txt_smp, "spk_emb": None}, warn
|
||||
|
||||
if isinstance(spk_emb, str) and spk_emb.strip():
|
||||
if _is_valid_spk_emb_string(chat, spk_emb):
|
||||
return {"spk_emb": spk_emb, "spk_smp": None, "txt_smp": None}, None
|
||||
# 旧版误存:把 spk_smp 写进了 spk_emb
|
||||
return {
|
||||
"spk_smp": spk_emb,
|
||||
"txt_smp": txt_smp,
|
||||
"spk_emb": None,
|
||||
}, (
|
||||
"检测到旧版音色文件格式,已自动按 spk_smp 加载。"
|
||||
"建议重新锁定音色并填写参考转写。"
|
||||
)
|
||||
|
||||
if isinstance(spk_emb, torch.Tensor):
|
||||
encoded = _encode_random_spk_emb(chat, spk_emb)
|
||||
if encoded:
|
||||
return {"spk_emb": encoded, "spk_smp": None, "txt_smp": None}, None
|
||||
return None, "旧版音色张量无法编码,请重新锁定音色。"
|
||||
|
||||
return None, "音色数据无效或已损坏,请重新锁定音色。"
|
||||
|
||||
|
||||
def save_fixed_speaker(
|
||||
@@ -239,6 +369,13 @@ def save_fixed_speaker(
|
||||
if not audio_sample_path:
|
||||
return False, "未提供音色参考音频。"
|
||||
|
||||
try:
|
||||
from whisper_service import reset_whisper_model
|
||||
|
||||
reset_whisper_model()
|
||||
except Exception:
|
||||
logger.debug("锁定音色前释放 Whisper 显存跳过", exc_info=True)
|
||||
|
||||
chat, init_err = get_chattts_instance()
|
||||
if chat is None:
|
||||
return False, init_err or "ChatTTS 不可用。"
|
||||
@@ -258,10 +395,9 @@ def save_fixed_speaker(
|
||||
audio = audio[:max_samples]
|
||||
|
||||
spk_smp = chat.sample_audio_speaker(audio)
|
||||
spk_emb = _encode_spk_emb(chat, spk_smp)
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"spk_emb": spk_emb,
|
||||
"version": 2,
|
||||
"spk_smp": spk_smp,
|
||||
"txt_smp": sample_transcript.strip(),
|
||||
"created_at": datetime.now().isoformat(),
|
||||
@@ -275,7 +411,10 @@ def save_fixed_speaker(
|
||||
f"参考时长: {duration:.1f}s"
|
||||
)
|
||||
if not sample_transcript.strip():
|
||||
msg += "\n提示:填写参考音频精确转写可进一步提升音色还原度。"
|
||||
msg += (
|
||||
"\n⚠️ 未填写参考转写:合成时可能报 Corrupt input data 或音色不稳。"
|
||||
"请填写与录音一致的精确转写后重新锁定。"
|
||||
)
|
||||
|
||||
logger.info("Speaker Embedding 保存成功: %s", SPEAKER_EMB_PATH)
|
||||
return True, msg
|
||||
@@ -301,8 +440,11 @@ def _load_speaker_payload() -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
chat, err = get_chattts_instance()
|
||||
if chat is None:
|
||||
return None, err
|
||||
encoded = _encode_random_spk_emb(chat, payload)
|
||||
if not encoded:
|
||||
return None, "旧版音色张量无法读取,请重新锁定音色。"
|
||||
return {
|
||||
"spk_emb": _encode_spk_emb(chat, payload),
|
||||
"spk_emb": encoded,
|
||||
"spk_smp": None,
|
||||
"txt_smp": "",
|
||||
}, None
|
||||
@@ -324,12 +466,191 @@ def speaker_is_ready() -> Tuple[bool, str]:
|
||||
return True, f"已加载固定音色: {SPEAKER_EMB_PATH}"
|
||||
|
||||
|
||||
def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
_EMOJI_RE = re.compile(
|
||||
"["
|
||||
"\U0001F300-\U0001FAFF"
|
||||
"\U00002700-\U000027BF"
|
||||
"\U00002600-\U000026FF"
|
||||
"]+",
|
||||
flags=re.UNICODE,
|
||||
)
|
||||
|
||||
_TTS_NOTE_MARKERS = (
|
||||
"💡",
|
||||
"量化交易员的修改笔记",
|
||||
"修改笔记(供你参考)",
|
||||
"修改笔记",
|
||||
"供你参考",
|
||||
)
|
||||
|
||||
_STAGE_DIRECTION_RE = re.compile(
|
||||
r"[((][^))]{0,80}(?:前奏|转场|语气|背景|BGM|配乐|节奏|环节)[^))]{0,80}[))]"
|
||||
)
|
||||
|
||||
_CN_DIGITS = "零一二三四五六七八九"
|
||||
|
||||
# ChatTTS tokenizer 对裸 ASCII 数字、控制符敏感,易触发 CUDA device-side assert
|
||||
_TTS_UNSAFE_CHAR_RE = re.compile(
|
||||
r"[\u200b-\u200f\u202a-\u202e\ufeff\x00-\x08\x0b\x0c\x0e-\x1f]"
|
||||
)
|
||||
_TTS_ALLOWED_CHAR_RE = re.compile(
|
||||
r"[^\u4e00-\u9fff\u3400-\u4dbfA-Za-z0-9,。!?;:、「」『』()—…\-\s'\"《》%%]"
|
||||
)
|
||||
|
||||
|
||||
def _digits_to_chinese(text: str) -> str:
|
||||
def _repl(match: re.Match[str]) -> str:
|
||||
return "".join(_CN_DIGITS[int(ch)] for ch in match.group())
|
||||
|
||||
return re.sub(r"\d+", _repl, text)
|
||||
|
||||
|
||||
def _normalize_tts_chunk(text: str) -> str:
|
||||
"""单段合成用:去控制符、数字转中文、合并换行为逗号。"""
|
||||
text = _TTS_UNSAFE_CHAR_RE.sub("", text)
|
||||
text = text.replace("\r", "").replace("\n", ",")
|
||||
text = _digits_to_chinese(text)
|
||||
text = _TTS_ALLOWED_CHAR_RE.sub("", text)
|
||||
text = re.sub(r"[,,]{2,}", ",", text)
|
||||
text = re.sub(r"\s+", "", text)
|
||||
return text.strip(",。 \t")
|
||||
|
||||
|
||||
def prepare_text_for_tts(text: str) -> str:
|
||||
"""
|
||||
使用 ChatTTS 将润色后的文稿合成为 wav 配音。
|
||||
将 LLM 润色稿转为 ChatTTS 可朗读的纯文本。
|
||||
去除 Markdown、emoji、舞台提示、修改笔记等非朗读内容。
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
cleaned = text.replace("\r\n", "\n").strip()
|
||||
|
||||
for marker in _TTS_NOTE_MARKERS:
|
||||
idx = cleaned.find(marker)
|
||||
if idx >= 0:
|
||||
cleaned = cleaned[:idx]
|
||||
|
||||
# 去掉模型常见前言,从标题或正文起点开始
|
||||
for pattern in (
|
||||
r"^作为一名极其严谨的量化交易员.*?配音稿。\s*",
|
||||
r"^以下是为你润色后的文案[::]*\s*",
|
||||
r"^以下(?:是|为).*?润色.*?文案[::]*\s*",
|
||||
):
|
||||
cleaned = re.sub(pattern, "", cleaned, count=1, flags=re.DOTALL)
|
||||
|
||||
cleaned = re.sub(r"^\*{3,}\s*$", "", cleaned, flags=re.MULTILINE)
|
||||
cleaned = re.sub(r"^-{3,}\s*$", "", cleaned, flags=re.MULTILINE)
|
||||
cleaned = re.sub(r"^#{1,6}\s*", "", cleaned, flags=re.MULTILINE)
|
||||
cleaned = re.sub(r"\*\*([^*\n]+)\*\*", r"\1", cleaned)
|
||||
cleaned = re.sub(r"\*([^*\n]+)\*", r"\1", cleaned)
|
||||
cleaned = re.sub(r"__([^_\n]+)__", r"\1", cleaned)
|
||||
cleaned = _STAGE_DIRECTION_RE.sub("", cleaned)
|
||||
cleaned = _EMOJI_RE.sub("", cleaned)
|
||||
cleaned = re.sub(r"^\d+\.\s*", "", cleaned, flags=re.MULTILINE)
|
||||
cleaned = re.sub(r"^[-*]\s+", "", cleaned, flags=re.MULTILINE)
|
||||
cleaned = re.sub(r"[ \t]+\n", "\n", cleaned)
|
||||
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
|
||||
|
||||
lines = [ln.strip() for ln in cleaned.split("\n")]
|
||||
lines = [ln for ln in lines if ln and not re.fullmatch(r"[*\-#]+", ln)]
|
||||
merged = "。".join(lines)
|
||||
return _normalize_tts_chunk(merged)
|
||||
|
||||
|
||||
def split_text_for_tts(text: str, max_chars: int = TTS_MAX_CHARS_PER_CHUNK) -> List[str]:
|
||||
"""按句号/换行切分长稿,避免 ChatTTS 单段过长失败。"""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
if len(text) <= max_chars:
|
||||
return [text]
|
||||
|
||||
parts = re.split(r"(?<=[。!?!?;;])\s*|\n+", text)
|
||||
chunks: List[str] = []
|
||||
buf = ""
|
||||
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
candidate = f"{buf}{part}" if buf else part
|
||||
if len(candidate) <= max_chars:
|
||||
buf = candidate
|
||||
continue
|
||||
if buf:
|
||||
chunks.append(buf)
|
||||
buf = ""
|
||||
if len(part) <= max_chars:
|
||||
buf = part
|
||||
continue
|
||||
for i in range(0, len(part), max_chars):
|
||||
chunks.append(part[i : i + max_chars])
|
||||
|
||||
if buf:
|
||||
chunks.append(buf)
|
||||
|
||||
return [_normalize_tts_chunk(c) for c in chunks if c.strip()]
|
||||
|
||||
|
||||
def _is_cuda_runtime_error(exc: BaseException) -> bool:
|
||||
from gpu_utils import is_cuda_runtime_error
|
||||
|
||||
return is_cuda_runtime_error(exc)
|
||||
|
||||
|
||||
def _run_chattts_infer(
|
||||
chat: Any,
|
||||
chunk: str,
|
||||
params_refine_text: Any,
|
||||
params_infer_code: Any,
|
||||
) -> Any:
|
||||
"""单次 ChatTTS infer;split_text=False 避免段内再切分引发 mask 异常。"""
|
||||
return chat.infer(
|
||||
chunk,
|
||||
skip_refine_text=False,
|
||||
split_text=False,
|
||||
do_text_normalization=True,
|
||||
do_homophone_replacement=True,
|
||||
params_refine_text=params_refine_text,
|
||||
params_infer_code=params_infer_code,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_segment_peak(wav: np.ndarray, target_peak: float = 0.95) -> np.ndarray:
|
||||
"""各段单独归一化峰值,避免拼接后某段偏响/偏轻像换了人声。"""
|
||||
arr = np.asarray(wav, dtype=np.float32).flatten()
|
||||
peak = float(np.max(np.abs(arr))) or 1.0
|
||||
return arr / peak * target_peak
|
||||
|
||||
|
||||
def _concat_wavs(
|
||||
wavs: List[np.ndarray],
|
||||
sample_rate: int,
|
||||
pause_sec: float = TTS_SEGMENT_PAUSE_SEC,
|
||||
) -> np.ndarray:
|
||||
if not wavs:
|
||||
return np.array([], dtype=np.float32)
|
||||
|
||||
pause = np.zeros(int(sample_rate * pause_sec), dtype=np.float32)
|
||||
segments: List[np.ndarray] = []
|
||||
for i, wav in enumerate(wavs):
|
||||
segments.append(_normalize_segment_peak(wav))
|
||||
if i < len(wavs) - 1:
|
||||
segments.append(pause)
|
||||
return np.concatenate(segments)
|
||||
|
||||
|
||||
def generate_voice(
|
||||
refined_text: str,
|
||||
voice_id: str = "custom",
|
||||
) -> Tuple[bool, str, Optional[str]]:
|
||||
"""
|
||||
使用 ChatTTS(本地 GPU)将润色稿合成为 wav。
|
||||
|
||||
Args:
|
||||
refined_text: LLM 润色后的配音稿
|
||||
voice_id: ``custom`` 为锁定音色,``preset_*`` 为内置预设(见 voice_presets)
|
||||
|
||||
Returns:
|
||||
(success, message, output_wav_path_or_none)
|
||||
@@ -337,46 +658,160 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
if not refined_text or not refined_text.strip():
|
||||
return False, "合成文本为空,请先完成润色。", None
|
||||
|
||||
# 合成前释放 Whisper,避免 8GB 显存上双模型 OOM
|
||||
try:
|
||||
from whisper_service import reset_whisper_model
|
||||
|
||||
reset_whisper_model()
|
||||
except Exception:
|
||||
logger.debug("合成前释放 Whisper 显存跳过", exc_info=True)
|
||||
|
||||
from gpu_utils import cuda_memory_summary, release_cuda_cache
|
||||
|
||||
release_cuda_cache()
|
||||
logger.info("TTS 合成前 %s", cuda_memory_summary())
|
||||
|
||||
chat, init_err = get_chattts_instance()
|
||||
if chat is None:
|
||||
return False, init_err or "ChatTTS 不可用。", None
|
||||
|
||||
payload, spk_err = _load_speaker_payload()
|
||||
from voice_presets import load_voice_payload
|
||||
|
||||
payload, spk_err = load_voice_payload(voice_id)
|
||||
if payload is None:
|
||||
return False, spk_err or "请先锁定音色。", None
|
||||
return False, spk_err or "请先选择或生成可用音色。", None
|
||||
|
||||
try:
|
||||
import ChatTTS
|
||||
|
||||
spk_emb = payload.get("spk_emb")
|
||||
spk_smp = payload.get("spk_smp")
|
||||
txt_smp = payload.get("txt_smp", "")
|
||||
speak_text = prepare_text_for_tts(refined_text)
|
||||
if not speak_text:
|
||||
return (
|
||||
False,
|
||||
"清洗后无有效朗读文本。请删除 Markdown(#、**)、emoji、舞台提示和「修改笔记」,"
|
||||
"只保留可念出的正文后再合成。",
|
||||
None,
|
||||
)
|
||||
|
||||
params_infer_code = ChatTTS.Chat.InferCodeParams(
|
||||
prompt=TTS_SPEED_PROMPT,
|
||||
spk_emb=spk_emb,
|
||||
spk_smp=spk_smp if spk_smp else None,
|
||||
txt_smp=txt_smp if txt_smp else None,
|
||||
temperature=TTS_TEMPERATURE,
|
||||
top_P=TTS_TOP_P,
|
||||
top_K=TTS_TOP_K,
|
||||
chunks = split_text_for_tts(speak_text)
|
||||
if not chunks:
|
||||
return False, "无法切分朗读文本,请检查润色稿内容。", None
|
||||
|
||||
speaker_params, speaker_warn = _normalize_speaker_for_infer(chat, payload)
|
||||
if speaker_params is None:
|
||||
return False, speaker_warn or "音色参数无效,请重新锁定音色。", None
|
||||
if speaker_warn:
|
||||
logger.warning(speaker_warn)
|
||||
|
||||
chunk_temperature = (
|
||||
min(TTS_TEMPERATURE, 0.2) if len(chunks) > 1 else TTS_TEMPERATURE
|
||||
)
|
||||
infer_kwargs: Dict[str, Any] = {
|
||||
"prompt": TTS_SPEED_PROMPT,
|
||||
"spk_emb": speaker_params.get("spk_emb"),
|
||||
"spk_smp": speaker_params.get("spk_smp"),
|
||||
"txt_smp": speaker_params.get("txt_smp"),
|
||||
"temperature": chunk_temperature,
|
||||
"top_P": TTS_TOP_P,
|
||||
"top_K": TTS_TOP_K,
|
||||
"max_new_token": TTS_MAX_NEW_TOKEN,
|
||||
"min_new_token": TTS_MIN_NEW_TOKEN,
|
||||
"ensure_non_empty": False,
|
||||
}
|
||||
if "manual_seed" in inspect.signature(ChatTTS.Chat.InferCodeParams).parameters:
|
||||
infer_kwargs["manual_seed"] = TTS_MANUAL_SEED
|
||||
params_infer_code = ChatTTS.Chat.InferCodeParams(**infer_kwargs)
|
||||
|
||||
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
||||
prompt="[oral_2][laugh_0][break_4]",
|
||||
ensure_non_empty=False,
|
||||
min_new_token=4,
|
||||
)
|
||||
|
||||
wavs = chat.infer(
|
||||
refined_text.strip(),
|
||||
skip_refine_text=False,
|
||||
params_refine_text=params_refine_text,
|
||||
params_infer_code=params_infer_code,
|
||||
logger.info(
|
||||
"TTS 合成: 原文 %d 字 → 清洗后 %d 字,分 %d 段",
|
||||
len(refined_text),
|
||||
len(speak_text),
|
||||
len(chunks),
|
||||
)
|
||||
|
||||
if not wavs or len(wavs) == 0:
|
||||
return False, "ChatTTS 未生成有效音频。", None
|
||||
segment_wavs: List[np.ndarray] = []
|
||||
for idx, chunk in enumerate(chunks, start=1):
|
||||
if not chunk or len(chunk) < 2:
|
||||
continue
|
||||
release_cuda_cache()
|
||||
chunk_infer = params_infer_code
|
||||
wavs = None
|
||||
last_exc: Optional[BaseException] = None
|
||||
for attempt in range(3):
|
||||
try:
|
||||
wavs = _run_chattts_infer(
|
||||
chat, chunk, params_refine_text, chunk_infer
|
||||
)
|
||||
break
|
||||
except RecursionError as exc:
|
||||
last_exc = exc
|
||||
# 重试时仍保持同一 manual_seed,避免段内/段间音色突变
|
||||
if "manual_seed" in infer_kwargs and attempt < 2:
|
||||
chunk_infer = replace(
|
||||
params_infer_code,
|
||||
manual_seed=TTS_MANUAL_SEED + attempt + 1,
|
||||
)
|
||||
release_cuda_cache()
|
||||
except RuntimeError as exc:
|
||||
last_exc = exc
|
||||
if not _is_cuda_runtime_error(exc) or attempt >= 2:
|
||||
raise
|
||||
logger.warning(
|
||||
"第 %d 段 CUDA 异常,重置 ChatTTS 后重试 (%d/3): %s",
|
||||
idx,
|
||||
attempt + 1,
|
||||
exc,
|
||||
)
|
||||
reset_chattts_instance()
|
||||
release_cuda_cache()
|
||||
chat, reload_err = get_chattts_instance()
|
||||
if chat is None:
|
||||
raise RuntimeError(reload_err or "ChatTTS 重载失败") from exc
|
||||
if "manual_seed" in infer_kwargs:
|
||||
chunk_infer = replace(
|
||||
params_infer_code,
|
||||
manual_seed=TTS_MANUAL_SEED + attempt + 1,
|
||||
)
|
||||
if wavs is None:
|
||||
return (
|
||||
False,
|
||||
f"ChatTTS 第 {idx}/{len(chunks)} 段合成失败(递归重试耗尽)。"
|
||||
f"请检查音色转写是否填写,或缩短该段文本。"
|
||||
f" 详情: {last_exc}",
|
||||
None,
|
||||
)
|
||||
if not wavs or len(wavs) == 0:
|
||||
return (
|
||||
False,
|
||||
f"ChatTTS 第 {idx}/{len(chunks)} 段未生成音频。"
|
||||
f"(段内容前 40 字: {chunk[:40]}…)",
|
||||
None,
|
||||
)
|
||||
wav_arr = np.asarray(wavs[0], dtype=np.float32)
|
||||
if wav_arr.size == 0 or np.max(np.abs(wav_arr)) < 1e-6:
|
||||
return (
|
||||
False,
|
||||
f"ChatTTS 第 {idx}/{len(chunks)} 段生成了空音频。"
|
||||
"请重新锁定音色并填写参考转写,或缩短润色稿后重试。",
|
||||
None,
|
||||
)
|
||||
segment_wavs.append(wav_arr)
|
||||
release_cuda_cache()
|
||||
|
||||
wav_array = np.asarray(wavs[0], dtype=np.float32)
|
||||
if not segment_wavs:
|
||||
return False, "分段清洗后无有效文本,请缩短或简化润色稿后重试。", None
|
||||
|
||||
wav_array = (
|
||||
segment_wavs[0]
|
||||
if len(segment_wavs) == 1
|
||||
else _concat_wavs(segment_wavs, TTS_SAMPLE_RATE)
|
||||
)
|
||||
|
||||
peak = np.max(np.abs(wav_array)) or 1.0
|
||||
wav_int16 = (wav_array / peak * 32767).astype(np.int16)
|
||||
@@ -387,11 +822,61 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]:
|
||||
|
||||
wavfile.write(str(output_path), TTS_SAMPLE_RATE, wav_int16)
|
||||
|
||||
msg = f"配音合成成功: {output_path}"
|
||||
chunk_note = f",共 {len(chunks)} 段拼接" if len(chunks) > 1 else ""
|
||||
msg = (
|
||||
f"配音合成成功: {output_path}"
|
||||
f"(朗读 {len(speak_text)} 字{chunk_note})"
|
||||
)
|
||||
if speaker_warn:
|
||||
msg = f"{speaker_warn}\n{msg}"
|
||||
logger.info(msg)
|
||||
return True, msg, str(output_path)
|
||||
|
||||
except Exception as exc:
|
||||
err = f"语音合成失败: {exc}\n{traceback.format_exc()}"
|
||||
exc_msg = str(exc)
|
||||
if "out of memory" in exc_msg.lower() or "OutOfMemoryError" in exc_msg:
|
||||
release_cuda_cache()
|
||||
err = (
|
||||
"语音合成失败: GPU 显存不足(CUDA OOM)。\n"
|
||||
"3060 Ti 8GB 无法同时运行 Whisper + ChatTTS。\n"
|
||||
"处理步骤:\n"
|
||||
"1. pm2 restart trading_studio 释放显存\n"
|
||||
"2. 不要连续快速点识别+合成;合成前系统会自动卸载 Whisper\n"
|
||||
"3. 若仍 OOM,在 .env 设置 TTS_MAX_CHARS_PER_CHUNK=150、TTS_MAX_NEW_TOKEN=768\n"
|
||||
"4. 确认无其他程序占用 GPU: nvidia-smi\n"
|
||||
f"技术详情: {exc_msg[:400]}"
|
||||
)
|
||||
elif _is_cuda_runtime_error(exc):
|
||||
reset_chattts_instance()
|
||||
release_cuda_cache()
|
||||
err = (
|
||||
"语音合成失败: GPU/CUDA 异常(device-side assert 等)。\n"
|
||||
"常见原因:此前 OOM 导致 GPU 状态损坏,或文本含特殊字符。\n"
|
||||
"处理步骤:\n"
|
||||
"1. pm2 restart trading_studio(必须,清理 GPU 脏状态)\n"
|
||||
"2. 确认已填写参考音频转写并重新锁定音色\n"
|
||||
"3. 用 2-3 句短中文试合成\n"
|
||||
"4. 若仍失败,在 .env 设 TTS_ENABLE_CACHE=false 后重启\n"
|
||||
f"技术详情: {exc_msg[:500]}"
|
||||
)
|
||||
elif "recursion depth" in exc_msg.lower() or isinstance(exc, RecursionError):
|
||||
err = (
|
||||
"语音合成失败: ChatTTS 反复生成空结果导致递归超限。\n"
|
||||
"常见原因:未填写参考音频转写、润色稿含特殊符号、或音色文件异常。\n"
|
||||
"处理:重新锁定音色并填写转写 → 用较短纯文本试合成。\n"
|
||||
f"技术详情: {exc_msg[:400]}"
|
||||
)
|
||||
elif "Corrupt input data" in exc_msg:
|
||||
err = (
|
||||
"语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n"
|
||||
"处理步骤:\n"
|
||||
"1. 删除旧音色: rm speaker_emb.pt\n"
|
||||
"2. 在「音色锁定」重新上传参考人声\n"
|
||||
"3. 填写与录音一致的「参考音频精确转写」(必填)\n"
|
||||
"4. 重新点击锁定音色后再合成\n"
|
||||
f"技术详情: {exc_msg}"
|
||||
)
|
||||
else:
|
||||
err = f"语音合成失败: {exc}\n{traceback.format_exc()}"
|
||||
logger.exception("generate_voice 失败")
|
||||
return False, err, None
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
本地配音历史:扫描 outputs/ 下已生成的 wav,供 Gradio 下拉试听与下载。
|
||||
文件不会被自动删除,重启服务后仍可访问。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
from config import OUTPUT_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HISTORY_MAX_ITEMS = 50
|
||||
VOICEOVER_GLOB = "voiceover_*.wav"
|
||||
|
||||
|
||||
def list_voice_history(limit: int = HISTORY_MAX_ITEMS) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
返回 Gradio Dropdown 选项:(显示名, 文件绝对路径),按时间倒序。
|
||||
"""
|
||||
if not OUTPUT_DIR.is_dir():
|
||||
return []
|
||||
|
||||
files = sorted(
|
||||
OUTPUT_DIR.glob(VOICEOVER_GLOB),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
reverse=True,
|
||||
)[:limit]
|
||||
|
||||
choices: List[Tuple[str, str]] = []
|
||||
for path in files:
|
||||
try:
|
||||
st = path.stat()
|
||||
except OSError:
|
||||
logger.debug("跳过不可读历史文件: %s", path)
|
||||
continue
|
||||
ts = datetime.fromtimestamp(st.st_mtime).strftime("%Y-%m-%d %H:%M")
|
||||
size_mb = st.st_size / (1024 * 1024)
|
||||
label = f"{ts} · {path.name} ({size_mb:.1f} MB)"
|
||||
choices.append((label, str(path.resolve())))
|
||||
return choices
|
||||
|
||||
|
||||
def latest_voice_path() -> str | None:
|
||||
"""最新一条配音路径,无历史时返回 None。"""
|
||||
items = list_voice_history(limit=1)
|
||||
return items[0][1] if items else None
|
||||
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
本地 GPU 音色库(ChatTTS,无需云端 API)
|
||||
- custom:用户在「音色锁定」克隆的 speaker_emb.pt
|
||||
- preset_*:ChatTTS sample_random_speaker 生成的内置说话人(scripts/generate_voice_presets.sh)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from config import BASE_DIR, SPEAKER_EMB_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VOICES_DIR = Path(BASE_DIR) / "voices"
|
||||
PRESETS_DIR = VOICES_DIR / "presets"
|
||||
MANIFEST_PATH = VOICES_DIR / "manifest.json"
|
||||
|
||||
CUSTOM_VOICE_ID = "custom"
|
||||
DEFAULT_PRESET_VOICE_ID = "preset_01"
|
||||
DEFAULT_PRESET_VOICE_LABEL = "预设·沉稳男声"
|
||||
|
||||
# 生成脚本写入的预设元数据(.pt 文件不入 Git)
|
||||
DEFAULT_MANIFEST = {
|
||||
"presets": [
|
||||
{"id": "preset_01", "label": "预设·沉稳男声", "file": "presets/preset_01.pt"},
|
||||
{"id": "preset_02", "label": "预设·青年男声", "file": "presets/preset_02.pt"},
|
||||
{"id": "preset_03", "label": "预设·温柔女声", "file": "presets/preset_03.pt"},
|
||||
{"id": "preset_04", "label": "预设·活泼女声", "file": "presets/preset_04.pt"},
|
||||
{"id": "preset_05", "label": "预设·中性旁白", "file": "presets/preset_05.pt"},
|
||||
{"id": "preset_06", "label": "预设·纪录片风", "file": "presets/preset_06.pt"},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def ensure_manifest() -> None:
|
||||
VOICES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
PRESETS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
if not MANIFEST_PATH.is_file():
|
||||
MANIFEST_PATH.write_text(
|
||||
json.dumps(DEFAULT_MANIFEST, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def _read_manifest() -> Dict[str, Any]:
|
||||
ensure_manifest()
|
||||
try:
|
||||
return json.loads(MANIFEST_PATH.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
logger.warning("读取 manifest 失败: %s", exc)
|
||||
return DEFAULT_MANIFEST
|
||||
|
||||
|
||||
def list_voice_choices() -> List[Tuple[str, str]]:
|
||||
"""
|
||||
返回 Gradio Dropdown 选项:(显示名, voice_id)。
|
||||
仅列出磁盘上已存在的音色。
|
||||
"""
|
||||
choices: List[Tuple[str, str]] = []
|
||||
|
||||
if SPEAKER_EMB_PATH.is_file():
|
||||
choices.append(("我的锁定音色(声音克隆)", CUSTOM_VOICE_ID))
|
||||
|
||||
for preset in _read_manifest().get("presets", []):
|
||||
pid = preset.get("id", "")
|
||||
label = preset.get("label", pid)
|
||||
rel = preset.get("file", "")
|
||||
if pid and rel and (VOICES_DIR / rel).is_file():
|
||||
choices.append((label, pid))
|
||||
|
||||
if not choices:
|
||||
choices.append(
|
||||
(
|
||||
"(请先在「音色锁定」上传人声,或运行 generate_voice_presets.sh)",
|
||||
CUSTOM_VOICE_ID,
|
||||
)
|
||||
)
|
||||
return choices
|
||||
|
||||
|
||||
def default_voice_id() -> str:
|
||||
choices = list_voice_choices()
|
||||
if not choices:
|
||||
return DEFAULT_PRESET_VOICE_ID
|
||||
for _label, vid in choices:
|
||||
if vid == DEFAULT_PRESET_VOICE_ID:
|
||||
return vid
|
||||
for _label, vid in choices:
|
||||
if vid != CUSTOM_VOICE_ID:
|
||||
return vid
|
||||
return choices[0][1]
|
||||
|
||||
|
||||
def default_voice_label() -> str:
|
||||
for lbl, vid in list_voice_choices():
|
||||
if vid == DEFAULT_PRESET_VOICE_ID:
|
||||
return lbl
|
||||
labels = voice_choice_labels()
|
||||
return labels[0] if labels else DEFAULT_PRESET_VOICE_LABEL
|
||||
|
||||
|
||||
def voice_choice_labels() -> List[str]:
|
||||
return [c[0] for c in list_voice_choices()]
|
||||
|
||||
|
||||
def label_to_voice_id(label: str) -> str:
|
||||
for lbl, vid in list_voice_choices():
|
||||
if lbl == label:
|
||||
return vid
|
||||
return default_voice_id()
|
||||
|
||||
|
||||
def load_voice_payload(voice_id: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""按 voice_id 加载 ChatTTS 说话人数据。"""
|
||||
if voice_id == CUSTOM_VOICE_ID or not voice_id:
|
||||
if not SPEAKER_EMB_PATH.is_file():
|
||||
return None, (
|
||||
"未找到锁定音色。请在「音色锁定」上传参考人声,"
|
||||
"或选择下方「预设」音色(需先运行 scripts/generate_voice_presets.sh)。"
|
||||
)
|
||||
return _load_payload_file(SPEAKER_EMB_PATH)
|
||||
|
||||
for preset in _read_manifest().get("presets", []):
|
||||
if preset.get("id") != voice_id:
|
||||
continue
|
||||
path = VOICES_DIR / preset.get("file", "")
|
||||
if not path.is_file():
|
||||
return None, (
|
||||
f"预设音色「{preset.get('label', voice_id)}」尚未生成。\n"
|
||||
f"请在服务器执行: bash scripts/generate_voice_presets.sh"
|
||||
)
|
||||
return _load_payload_file(path)
|
||||
|
||||
return None, f"未知音色 ID: {voice_id}"
|
||||
|
||||
|
||||
def _load_payload_file(path: Path) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
try:
|
||||
payload = torch.load(path, map_location="cpu", weights_only=False)
|
||||
if isinstance(payload, torch.Tensor):
|
||||
return {"spk_emb": payload, "spk_smp": None, "txt_smp": ""}, None
|
||||
if isinstance(payload, dict):
|
||||
return payload, None
|
||||
return None, f"音色文件格式无效: {path.name}"
|
||||
except Exception as exc:
|
||||
return None, f"读取音色文件失败 ({path.name}): {exc}"
|
||||
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"presets": [
|
||||
{"id": "preset_01", "label": "预设·沉稳男声", "file": "presets/preset_01.pt"},
|
||||
{"id": "preset_02", "label": "预设·青年男声", "file": "presets/preset_02.pt"},
|
||||
{"id": "preset_03", "label": "预设·温柔女声", "file": "presets/preset_03.pt"},
|
||||
{"id": "preset_04", "label": "预设·活泼女声", "file": "presets/preset_04.pt"},
|
||||
{"id": "preset_05", "label": "预设·中性旁白", "file": "presets/preset_05.pt"},
|
||||
{"id": "preset_06", "label": "预设·纪录片风", "file": "presets/preset_06.pt"}
|
||||
]
|
||||
}
|
||||
+104
-36
@@ -6,43 +6,93 @@ Faster-Whisper CUDA 语音识别服务
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from config import (
|
||||
BASE_DIR,
|
||||
HF_ENDPOINT,
|
||||
HF_HOME,
|
||||
HF_HUB_DOWNLOAD_TIMEOUT,
|
||||
WHISPER_COMPUTE_TYPE,
|
||||
WHISPER_DEVICE,
|
||||
WHISPER_HF_REPO,
|
||||
WHISPER_LANGUAGE,
|
||||
WHISPER_MODEL_DIR,
|
||||
WHISPER_MODEL_SIZE,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全局懒加载模型实例,避免 Gradio 重复初始化占用显存
|
||||
_model = None
|
||||
_model_error: Optional[str] = None
|
||||
|
||||
|
||||
def _ensure_hf_env() -> None:
|
||||
os.environ.setdefault("HF_ENDPOINT", HF_ENDPOINT)
|
||||
os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", str(HF_HUB_DOWNLOAD_TIMEOUT))
|
||||
os.environ.setdefault("HF_HOME", str(HF_HOME))
|
||||
WHISPER_MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _whisper_local_path() -> Optional[Path]:
|
||||
"""返回已预下载的本地模型目录。"""
|
||||
local = WHISPER_MODEL_DIR / WHISPER_MODEL_SIZE
|
||||
if (local / "model.bin").is_file():
|
||||
return local
|
||||
return None
|
||||
|
||||
|
||||
def _is_cuda_error(exc: BaseException) -> bool:
|
||||
"""判断异常是否与 CUDA/GPU 相关。"""
|
||||
msg = str(exc).lower()
|
||||
cuda_keywords = (
|
||||
"cuda",
|
||||
"cudnn",
|
||||
"cublas",
|
||||
"gpu",
|
||||
"out of memory",
|
||||
"no kernel image",
|
||||
"device-side assert",
|
||||
"cuda", "cudnn", "cublas", "gpu",
|
||||
"out of memory", "no kernel image", "device-side assert",
|
||||
)
|
||||
return any(k in msg for k in cuda_keywords)
|
||||
|
||||
|
||||
def _is_network_error(exc: BaseException) -> bool:
|
||||
msg = str(exc).lower()
|
||||
return any(
|
||||
k in msg
|
||||
for k in (
|
||||
"network is unreachable",
|
||||
"connection error",
|
||||
"connecterror",
|
||||
"timed out",
|
||||
"couldn't connect",
|
||||
"name resolution",
|
||||
"hub",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _build_load_error(exc: BaseException) -> str:
|
||||
lines = [
|
||||
"Whisper 模型加载失败。",
|
||||
f"详情: {exc}",
|
||||
"",
|
||||
]
|
||||
if _is_network_error(exc):
|
||||
lines.extend([
|
||||
"原因:服务器无法访问 HuggingFace 下载模型(内网/无外网常见)。",
|
||||
"请在服务器执行(走 HF 镜像,仅需一次):",
|
||||
f" cd {BASE_DIR}",
|
||||
" bash scripts/download_whisper_models.sh",
|
||||
" pm2 restart trading_studio",
|
||||
"",
|
||||
f"模型将保存到: {WHISPER_MODEL_DIR / WHISPER_MODEL_SIZE}",
|
||||
])
|
||||
else:
|
||||
lines.append(f"完整日志:\n{traceback.format_exc()}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def get_whisper_model():
|
||||
"""
|
||||
获取或初始化 Faster-Whisper 模型。
|
||||
强制 device=cuda, compute_type=float16。
|
||||
"""
|
||||
"""获取或初始化 Faster-Whisper 模型(优先本地预下载)。"""
|
||||
global _model, _model_error
|
||||
|
||||
if _model is not None:
|
||||
@@ -52,18 +102,31 @@ def get_whisper_model():
|
||||
return None, _model_error
|
||||
|
||||
try:
|
||||
_ensure_hf_env()
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
local = _whisper_local_path()
|
||||
if local:
|
||||
model_id = str(local)
|
||||
logger.info("Whisper 从本地加载: %s", model_id)
|
||||
else:
|
||||
model_id = WHISPER_MODEL_SIZE
|
||||
logger.warning(
|
||||
"未找到本地 Whisper 模型 (%s),尝试在线下载(可能失败)…",
|
||||
WHISPER_MODEL_DIR / WHISPER_MODEL_SIZE,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"正在加载 Whisper 模型: size=%s, device=%s, compute_type=%s",
|
||||
WHISPER_MODEL_SIZE,
|
||||
"Whisper 加载: model=%s, device=%s, compute_type=%s",
|
||||
model_id,
|
||||
WHISPER_DEVICE,
|
||||
WHISPER_COMPUTE_TYPE,
|
||||
)
|
||||
_model = WhisperModel(
|
||||
WHISPER_MODEL_SIZE,
|
||||
model_id,
|
||||
device=WHISPER_DEVICE,
|
||||
compute_type=WHISPER_COMPUTE_TYPE,
|
||||
download_root=str(WHISPER_MODEL_DIR),
|
||||
)
|
||||
logger.info("Whisper 模型加载成功。")
|
||||
return _model, None
|
||||
@@ -79,29 +142,28 @@ def get_whisper_model():
|
||||
except Exception as exc:
|
||||
if _is_cuda_error(exc):
|
||||
_model_error = (
|
||||
"CUDA 初始化失败,请检查 NVIDIA 驱动、CUDA 运行时及 cuDNN 是否正确安装。\n"
|
||||
f"错误详情: {exc}\n"
|
||||
f"{traceback.format_exc()}"
|
||||
"CUDA 初始化失败,请检查 NVIDIA 驱动、CUDA 运行时及 cuDNN。\n"
|
||||
f"错误详情: {exc}"
|
||||
)
|
||||
else:
|
||||
_model_error = f"Whisper 模型加载失败: {exc}\n{traceback.format_exc()}"
|
||||
_model_error = _build_load_error(exc)
|
||||
logger.exception("Whisper 模型加载异常")
|
||||
return None, _model_error
|
||||
|
||||
|
||||
def transcribe_audio(audio_path: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
将音频文件转写为中文文本。
|
||||
|
||||
Args:
|
||||
audio_path: 本地音频文件绝对或相对路径
|
||||
|
||||
Returns:
|
||||
(success, text_or_error_message)
|
||||
"""
|
||||
"""将音频文件转写为中文文本。"""
|
||||
if not audio_path:
|
||||
return False, "未提供音频文件路径。"
|
||||
|
||||
# 识别前释放 ChatTTS,避免与 Whisper 同占 8GB 显存
|
||||
try:
|
||||
from tts_service import reset_chattts_instance
|
||||
|
||||
reset_chattts_instance()
|
||||
except Exception:
|
||||
logger.debug("释放 ChatTTS 显存时跳过", exc_info=True)
|
||||
|
||||
model, init_error = get_whisper_model()
|
||||
if model is None:
|
||||
return False, init_error or "Whisper 模型不可用。"
|
||||
@@ -114,10 +176,7 @@ def transcribe_audio(audio_path: str) -> Tuple[bool, str]:
|
||||
vad_filter=True,
|
||||
)
|
||||
|
||||
text_parts = []
|
||||
for segment in segments:
|
||||
text_parts.append(segment.text.strip())
|
||||
|
||||
text_parts = [segment.text.strip() for segment in segments]
|
||||
result_text = "".join(text_parts).strip()
|
||||
|
||||
if not result_text:
|
||||
@@ -137,8 +196,7 @@ def transcribe_audio(audio_path: str) -> Tuple[bool, str]:
|
||||
except Exception as exc:
|
||||
if _is_cuda_error(exc):
|
||||
err = (
|
||||
"CUDA 推理异常:显存可能不足或 GPU 状态异常。"
|
||||
"建议关闭其他占用显存的进程后重试。\n"
|
||||
"CUDA 推理异常:显存可能不足或 GPU 状态异常。\n"
|
||||
f"错误详情: {exc}"
|
||||
)
|
||||
else:
|
||||
@@ -149,7 +207,17 @@ def transcribe_audio(audio_path: str) -> Tuple[bool, str]:
|
||||
|
||||
|
||||
def reset_whisper_model() -> None:
|
||||
"""释放模型引用(用于调试或显存回收)。"""
|
||||
"""卸载 Whisper 模型并回收 GPU 显存。"""
|
||||
global _model, _model_error
|
||||
if _model is not None:
|
||||
try:
|
||||
del _model
|
||||
except Exception:
|
||||
pass
|
||||
_model = None
|
||||
_model_error = None
|
||||
|
||||
from gpu_utils import release_cuda_cache
|
||||
|
||||
release_cuda_cache()
|
||||
logger.info("Whisper 模型已卸载,显存已尝试回收。")
|
||||
|
||||
Reference in New Issue
Block a user