Files
Trading_Studio/app.py
T
2026-06-12 14:20:05 +08:00

552 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Trading Studio — 自动化交易复盘视频配音系统
Gradio Web 中控:音色锁定 → Whisper 识别 → Gemma4 润色 → ChatTTS 合成
"""
from __future__ import annotations
import logging
import shutil
import sys
import uuid
from pathlib import Path
import gradio as gr
from config import (
GIT_REPO_URL,
HOST,
MODEL_NAME,
OLLAMA_URL,
PORT,
SPEAKER_EMB_PATH,
UPLOAD_DIR,
)
from llm_service import check_ollama_health, polish_text
from tts_service import generate_voice, save_fixed_speaker, speaker_is_ready
from whisper_service import transcribe_audio
# ---------------------------------------------------------------------------
# 日志
# ---------------------------------------------------------------------------
# Send INFO-and-above records to stdout and to a UTF-8 encoded log file.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
    handlers=[
        logging.StreamHandler(stream=sys.stdout),
        logging.FileHandler(filename="trading_studio.log", encoding="utf-8"),
    ],
)
logger = logging.getLogger("trading_studio")
# ---------------------------------------------------------------------------
# Global UI state (Gradio State)
# ---------------------------------------------------------------------------
# raw_transcript / polished_script are passed along the pipeline
def _save_upload(upload_file) -> str | None:
"""将 Gradio 上传文件复制到本地 uploads 目录,返回持久化路径。"""
if upload_file is None:
return None
src = Path(upload_file)
if not src.exists():
return None
dest = UPLOAD_DIR / f"{uuid.uuid4().hex}_{src.name}"
shutil.copy2(src, dest)
return str(dest)
# ---------------------------------------------------------------------------
# 模块 1:音色锁定
# ---------------------------------------------------------------------------
def ui_lock_speaker(audio_file, sample_transcript: str) -> tuple[str, str]:
    """Lock the voice: extract and save a speaker embedding from a reference clip.

    Args:
        audio_file: Path of the uploaded reference recording, or ``None``.
        sample_transcript: Optional transcript of the reference recording;
            an empty string is passed on when it is falsy.

    Returns:
        ``(result message, speaker status card HTML)``.
    """
    saved_path = _save_upload(audio_file)
    if saved_path:
        ok, msg = save_fixed_speaker(saved_path, sample_transcript or "")
        outcome = msg if ok else f"{msg}"
    else:
        outcome = "请上传 10-30 秒干净参考人声(wav/mp3 均可)。"
    # The status card is refreshed in every case so the UI reflects the
    # current state after the attempt.
    return outcome, ui_speaker_status_html()
def ui_speaker_status() -> str:
    """Return the speaker status as plain text (intended for a log box)."""
    ready, msg = speaker_is_ready()
    if ready:
        return f"{msg}"
    return f"⚠️ {msg}"
# ---------------------------------------------------------------------------
# 模块 2:音频极速识别
# ---------------------------------------------------------------------------
def ui_transcribe(audio_file) -> tuple[str, str]:
    """Run Whisper recognition and return ``(transcript, status log)``.

    The transcript is an empty string when no audio was supplied or when
    ``transcribe_audio`` reports failure; the log then carries the reason.
    """
    saved_path = _save_upload(audio_file)
    if not saved_path:
        return "", "请上传待识别的碎碎念录音。"
    ok, result = transcribe_audio(saved_path)
    if not ok:
        # On failure ``result`` holds the message returned by the service.
        return "", f"{result}"
    return result, f"✅ 识别完成,共 {len(result)} 字。"
# ---------------------------------------------------------------------------
# Module 3: Gemma4 discipline review
# ---------------------------------------------------------------------------
def ui_polish(raw_text: str) -> tuple[str, str]:
    """Polish the transcript with the LLM and return ``(polished text, status log)``.

    Blank input short-circuits with a hint; on failure the polished text is
    empty and the log carries the message returned by ``polish_text``.
    """
    if not (raw_text and raw_text.strip()):
        return "", "请先完成语音识别或手动粘贴转写文本。"
    ok, result = polish_text(raw_text)
    if not ok:
        return "", f"{result}"
    return result, f"✅ Gemma4 润色完成,共 {len(result)} 字。"
def ui_check_ollama() -> str:
    """Check the remote Ollama node and return its status message as text."""
    _healthy, msg = check_ollama_health()
    # Healthy and unhealthy outcomes are both rendered as the bare message.
    return f"{msg}"
# ---------------------------------------------------------------------------
# Module 4: ChatTTS audio synthesis
# ---------------------------------------------------------------------------
def ui_synthesize(polished_text: str) -> tuple[str | None, str]:
    """Synthesize the final voice-over and return ``(wav path or None, status log)``.

    The path is only returned when ``generate_voice`` reports success and
    actually supplies one; the log is the service message in either case.
    """
    if not (polished_text and polished_text.strip()):
        return None, "请先完成 Gemma4 润色。"
    ok, msg, wav_path = generate_voice(polished_text)
    audio = wav_path if (ok and wav_path) else None
    return audio, f"{msg}"
# ---------------------------------------------------------------------------
# 一键流水线
# ---------------------------------------------------------------------------
def ui_full_pipeline(
    audio_file,
    skip_polish: bool,
    manual_raw: str,
) -> tuple[str, str, str | None, str]:
    """Run recognition, polishing (optional) and synthesis in one go.

    Args:
        audio_file: Path of the uploaded recording; only used when
            ``manual_raw`` is blank.
        skip_polish: When true the transcript is synthesized as-is.
        manual_raw: Transcript typed by the user; when non-blank it replaces
            the recognition step.

    Returns:
        ``(raw, polished, wav_path, log)``. A failing step returns the results
        gathered so far, ``None`` for the audio, and a log that starts with
        the failure message.
    """
    logs: list[str] = []

    # Step 1: obtain the raw transcript (manual text wins over the recording).
    manual_text = manual_raw and manual_raw.strip()
    if manual_text:
        raw = manual_text
        logs.append(f"使用手动输入转写稿({len(raw)} 字)。")
    else:
        saved_path = _save_upload(audio_file)
        if not saved_path:
            return "", "", None, "❌ 请上传录音或手动填写转写文本。"
        ok, result = transcribe_audio(saved_path)
        if not ok:
            return "", "", None, f"❌ 识别失败: {result}"
        raw = result
        logs.append(f"✅ Whisper 识别完成({len(raw)} 字)。")

    # Step 2: polish the transcript unless the caller opted out.
    if not skip_polish:
        ok, result = polish_text(raw)
        if not ok:
            history = "\n".join(logs)
            return raw, "", None, f"❌ 润色失败: {result}\n" + history
        polished = result
        logs.append(f"✅ Gemma4 润色完成({len(polished)} 字)。")
    else:
        polished = raw
        logs.append("已跳过 LLM 润色,直接使用原文。")

    # Step 3: synthesize the voice-over.
    ok, msg, wav_path = generate_voice(polished)
    if not ok:
        history = "\n".join(logs)
        return raw, polished, None, f"❌ 合成失败: {msg}\n" + history
    logs.append(f"{msg}")
    return raw, polished, wav_path, "\n".join(logs)
# ---------------------------------------------------------------------------
# Gradio 界面
# ---------------------------------------------------------------------------
# Stylesheet passed to Gradio through ``launch(css=CUSTOM_CSS)`` in ``main()``.
# Besides restyling the stock ``.gradio-container`` elements it defines the
# ``.dark-panel`` class used by the header Markdown in ``build_app()`` and the
# ``.status-card`` / ``.status-ok`` / ``.status-warn`` / ``.status-err``
# classes emitted by ``_status_html()``.
CUSTOM_CSS = """
/* ========== 高对比度暗色主题(确保文字清晰可读) ========== */
.gradio-container {
background: #0f1419 !important;
color: #eef2f7 !important;
font-size: 15px !important;
max-width: 1400px !important;
}
/* 全局文字 */
.gradio-container p,
.gradio-container span,
.gradio-container label,
.gradio-container .prose,
.gradio-container .markdown-text,
.gradio-container .md {
color: #eef2f7 !important;
}
/* 标题 */
.gradio-container h1 {
color: #ffffff !important;
font-size: 1.75rem !important;
font-weight: 700 !important;
}
.gradio-container h2,
.gradio-container h3 {
color: #dbeafe !important;
font-size: 1.15rem !important;
font-weight: 600 !important;
}
/* 标签 */
.gradio-container .block-label,
.gradio-container label,
.gradio-container .label-wrap span {
color: #93c5fd !important;
font-weight: 600 !important;
font-size: 0.95rem !important;
}
/* 输入框 / 文本域 */
.gradio-container textarea,
.gradio-container input[type="text"],
.gradio-container .wrap textarea,
.gradio-container .wrap input {
color: #ffffff !important;
background: #1a2332 !important;
border: 1px solid #4b5563 !important;
font-size: 0.95rem !important;
line-height: 1.6 !important;
}
.gradio-container textarea::placeholder,
.gradio-container input::placeholder {
color: #9ca3af !important;
opacity: 1 !important;
}
/* 只读日志框 */
.gradio-container .wrap .readonly textarea {
background: #111827 !important;
color: #e5e7eb !important;
border-color: #374151 !important;
}
/* Tab 标签页 */
.gradio-container .tab-nav button {
color: #9ca3af !important;
font-weight: 600 !important;
font-size: 0.95rem !important;
padding: 10px 18px !important;
}
.gradio-container .tab-nav button.selected {
color: #ffffff !important;
background: #1e3a5f !important;
border-bottom: 3px solid #3b82f6 !important;
}
/* 按钮 */
.gradio-container button.primary,
.gradio-container .primary {
background: #2563eb !important;
color: #ffffff !important;
font-weight: 700 !important;
font-size: 0.95rem !important;
border: 1px solid #3b82f6 !important;
}
.gradio-container button.primary:hover {
background: #1d4ed8 !important;
}
.gradio-container button.secondary,
.gradio-container button:not(.primary) {
color: #e5e7eb !important;
background: #374151 !important;
border: 1px solid #6b7280 !important;
font-weight: 600 !important;
}
/* 顶部信息面板 */
.dark-panel {
border: 1px solid #3b82f6 !important;
border-radius: 10px !important;
padding: 18px 20px !important;
background: #1a2332 !important;
margin-bottom: 16px !important;
}
.dark-panel code {
color: #93c5fd !important;
background: #0f172a !important;
padding: 2px 6px !important;
border-radius: 4px !important;
}
/* 状态卡片 */
.status-card {
border-radius: 10px;
padding: 14px 16px;
margin: 4px 0;
border-left: 5px solid #6b7280;
background: #1f2937;
min-height: 72px;
}
.status-card .status-title {
font-size: 0.85rem;
font-weight: 700;
color: #93c5fd !important;
margin-bottom: 8px;
letter-spacing: 0.03em;
}
.status-card .status-body {
font-size: 0.92rem;
line-height: 1.55;
color: #f3f4f6 !important;
word-break: break-word;
}
.status-ok {
border-left-color: #22c55e !important;
background: #14291a !important;
}
.status-ok .status-body { color: #bbf7d0 !important; }
.status-warn {
border-left-color: #f59e0b !important;
background: #2a2010 !important;
}
.status-warn .status-body { color: #fde68a !important; }
.status-err {
border-left-color: #ef4444 !important;
background: #2a1515 !important;
}
.status-err .status-body { color: #fecaca !important; }
/* 音频上传区 */
.gradio-container .audio-container,
.gradio-container .upload-container {
border: 2px dashed #4b5563 !important;
background: #1a2332 !important;
}
footer { visibility: hidden; }
"""
def _status_html(title: str, message: str, level: str = "warn") -> str:
"""生成高对比度状态卡片 HTML。level: ok | warn | err"""
icons = {"ok": "", "warn": "⚠️", "err": ""}
icon = icons.get(level, "")
# 去掉 message 里重复的 emoji,避免双图标
clean = message.lstrip("✅❌⚠️ ").strip()
return (
f'<div class="status-card status-{level}">'
f'<div class="status-title">{icon} {title}</div>'
f'<div class="status-body">{clean}</div>'
f"</div>"
)
def ui_check_ollama_html() -> str:
    """Check the Ollama node and render the outcome as a status card."""
    healthy, msg = check_ollama_health()
    level = "ok" if healthy else "err"
    return _status_html("Ollama 节点", msg, level)
def ui_speaker_status_html() -> str:
    """Query the speaker state and render it as a status card."""
    ready, msg = speaker_is_ready()
    level = "ok" if ready else "warn"
    return _status_html("音色状态", msg, level)
def build_theme() -> gr.themes.Base:
    """Build the high-contrast dark theme.

    Per the original note, Gradio 6.0 expects the theme to be passed to
    ``launch()``; ``main()`` does exactly that.
    """
    sans_stack = [gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"]
    mono_stack = [gr.themes.GoogleFont("JetBrains Mono"), "Consolas", "monospace"]
    theme = gr.themes.Base(
        primary_hue="blue",
        secondary_hue="blue",
        neutral_hue="slate",
        font=sans_stack,
        font_mono=mono_stack,
    )
    # Colour overrides; the *_dark variants repeat the same values.
    overrides = {
        "body_background_fill": "#0f1419",
        "body_background_fill_dark": "#0f1419",
        "body_text_color": "#eef2f7",
        "body_text_color_dark": "#eef2f7",
        "block_background_fill": "#1a2332",
        "block_background_fill_dark": "#1a2332",
        "block_border_color": "#4b5563",
        "block_title_text_color": "#ffffff",
        "block_label_text_color": "#93c5fd",
        "input_background_fill": "#1a2332",
        "input_background_fill_dark": "#1a2332",
        "button_primary_background_fill": "#2563eb",
        "button_primary_background_fill_hover": "#1d4ed8",
        "button_primary_text_color": "#ffffff",
        "button_secondary_background_fill": "#374151",
        "button_secondary_text_color": "#e5e7eb",
        "border_color_primary": "#3b82f6",
    }
    return theme.set(**overrides)
def build_app() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire its event handlers.

    The page consists of a header panel, a status row (Ollama node and
    speaker state, with a refresh button) and three tabs: speaker locking,
    the step-by-step pipeline, and the one-click pipeline. Both status cards
    are also refreshed when the page loads.

    Returns:
        The assembled, not yet launched, ``gr.Blocks`` instance.
    """
    # NOTE(review): the container nesting below (which components sit inside
    # which Row/Column) follows the creation order of the components —
    # confirm against the rendered layout.
    ollama_base = OLLAMA_URL.replace('/api/chat', '')
    with gr.Blocks(
        title="Trading Studio | 交易复盘配音中控",
    ) as demo:
        gr.Markdown(
            f"""
# ⚡ Trading Studio
**本地量化交易复盘 → B 站配音生产流水线**
| 模块 | 说明 |
|------|------|
| Whisper | 本地 GPU 语音识别 |
| Gemma4 | `{MODEL_NAME}` @ `{ollama_base}` |
| ChatTTS | 本地 GPU 固定音色合成 |
> 仓库: [{GIT_REPO_URL}]({GIT_REPO_URL})
""",
            elem_classes=["dark-panel"],
        )

        # Status row: placeholders are shown until the load hook fills them.
        with gr.Row():
            ollama_status = gr.HTML(value=_status_html("Ollama 节点", "正在检测...", "warn"))
            speaker_status = gr.HTML(value=_status_html("音色状态", "正在检测...", "warn"))
            refresh_btn = gr.Button("🔄 刷新状态", variant="secondary", scale=0)
        refresh_btn.click(
            fn=lambda: (ui_check_ollama_html(), ui_speaker_status_html()),
            outputs=[ollama_status, speaker_status],
        )

        with gr.Tabs():
            # ---- Tab 1: speaker locking ----
            with gr.Tab("🎙️ 音色锁定"):
                gr.Markdown(
                    "上传 **10-30 秒** 干净人声样本,系统将提取 Speaker Embedding "
                    f"并保存至 `{SPEAKER_EMB_PATH.name}`,后续合成 100% 还原音色。"
                )
                with gr.Row():
                    spk_audio = gr.Audio(
                        label="参考人声(碎碎念盲录样本)",
                        type="filepath",
                        sources=["upload", "microphone"],
                    )
                    spk_transcript = gr.Textbox(
                        label="参考音频精确转写(可选,提升还原度)",
                        placeholder="尽量与参考音频内容完全一致...",
                        lines=6,
                    )
                lock_btn = gr.Button("🔒 锁定音色", variant="primary")
                lock_log = gr.Textbox(label="锁定结果", lines=4, interactive=False)
                lock_btn.click(
                    fn=ui_lock_speaker,
                    inputs=[spk_audio, spk_transcript],
                    outputs=[lock_log, speaker_status],
                )

            # ---- Tab 2: step-by-step pipeline ----
            with gr.Tab("🔧 分步流水线"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Step 1 · 音频极速识别")
                        rec_audio = gr.Audio(
                            label="交易复盘碎碎念录音",
                            type="filepath",
                            sources=["upload", "microphone"],
                        )
                        transcribe_btn = gr.Button("⚡ Faster-Whisper 识别", variant="primary")
                        transcribe_log = gr.Textbox(label="识别日志", lines=2, interactive=False)
                    with gr.Column(scale=1):
                        gr.Markdown("### Step 2 · Gemma4 纪律审判")
                        raw_text = gr.Textbox(
                            label="转写原文(可编辑)",
                            lines=10,
                            placeholder="识别结果将显示在此,也可手动粘贴...",
                        )
                        polish_btn = gr.Button("⚖️ 远程 Gemma4 严厉润色", variant="primary")
                        polish_log = gr.Textbox(label="润色日志", lines=2, interactive=False)
                    with gr.Column(scale=1):
                        gr.Markdown("### Step 3 · ChatTTS 配音合成")
                        polished_text = gr.Textbox(
                            label="润色配音稿(可编辑)",
                            lines=10,
                            placeholder="润色结果将显示在此...",
                        )
                        synth_btn = gr.Button("🔊 合成配音 WAV", variant="primary")
                        synth_log = gr.Textbox(label="合成日志", lines=2, interactive=False)
                        output_audio = gr.Audio(label="成品配音", type="filepath")
                # Each step feeds the editable textbox of the next one.
                transcribe_btn.click(
                    fn=ui_transcribe,
                    inputs=rec_audio,
                    outputs=[raw_text, transcribe_log],
                )
                polish_btn.click(
                    fn=ui_polish,
                    inputs=raw_text,
                    outputs=[polished_text, polish_log],
                )
                synth_btn.click(
                    fn=ui_synthesize,
                    inputs=polished_text,
                    outputs=[output_audio, synth_log],
                )

            # ---- Tab 3: one-click production ----
            with gr.Tab("🚀 一键生产"):
                gr.Markdown(
                    "上传碎碎念录音,系统自动完成 **识别 → 润色 → 合成** 全流程。"
                )
                with gr.Row():
                    pipe_audio = gr.Audio(
                        label="复盘录音",
                        type="filepath",
                        sources=["upload", "microphone"],
                    )
                    pipe_manual = gr.Textbox(
                        label="或手动输入转写(跳过识别)",
                        lines=4,
                        placeholder="若已有转写文本,可直接粘贴,留空则走 Whisper 识别",
                    )
                skip_polish_cb = gr.Checkbox(
                    label="跳过 Gemma4 润色(仅测试 TTS)",
                    value=False,
                )
                pipeline_btn = gr.Button("▶ 启动全流程", variant="primary", size="lg")
                pipeline_log = gr.Textbox(label="流水线日志", lines=6, interactive=False)
                with gr.Row():
                    pipe_raw = gr.Textbox(label="转写原文", lines=6)
                    pipe_polished = gr.Textbox(label="润色稿", lines=6)
                pipe_output = gr.Audio(label="成品配音", type="filepath")
                pipeline_btn.click(
                    fn=ui_full_pipeline,
                    inputs=[pipe_audio, skip_polish_cb, pipe_manual],
                    outputs=[pipe_raw, pipe_polished, pipe_output, pipeline_log],
                )

        # Replace the "checking..." placeholders as soon as the page loads.
        demo.load(
            fn=lambda: (ui_check_ollama_html(), ui_speaker_status_html()),
            outputs=[ollama_status, speaker_status],
        )
    return demo
def main() -> None:
    """Entry point: build the UI and start the Gradio server."""
    logger.info("Trading Studio 启动中... HOST=%s PORT=%s", HOST, PORT)
    demo = build_app()
    # Only the "outputs" folder next to this file is added to allowed_paths.
    outputs_dir = Path(__file__).resolve().parent / "outputs"
    launch_options = {
        "server_name": HOST,
        "server_port": PORT,
        "share": False,
        "show_error": True,
        "theme": build_theme(),
        "css": CUSTOM_CSS,
        "allowed_paths": [str(outputs_dir)],
    }
    demo.launch(**launch_options)


if __name__ == "__main__":
    main()