diff --git a/.env.example b/.env.example index 6574b5f..06be769 100644 --- a/.env.example +++ b/.env.example @@ -17,3 +17,4 @@ OLLAMA_PORT=11434 # 8GB 显存 OOM 时可调低(合成按段切分) # TTS_MAX_CHARS_PER_CHUNK=150 # TTS_MAX_NEW_TOKEN=768 +# TTS_MIN_NEW_TOKEN=16 diff --git a/config.py b/config.py index 7666262..ce8a6a9 100644 --- a/config.py +++ b/config.py @@ -142,6 +142,9 @@ TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 200) # ChatTTS 单段最大生成 token(越小越省显存,长句会自动切多段) TTS_MAX_NEW_TOKEN = _env_int("TTS_MAX_NEW_TOKEN", 1024) +# 至少生成多少 audio token 才允许结束(防止首 token EOS → 无限递归重试) +TTS_MIN_NEW_TOKEN = _env_int("TTS_MIN_NEW_TOKEN", 16) + # --------------------------------------------------------------------------- # 上传临时文件目录 # --------------------------------------------------------------------------- diff --git a/tts_service.py b/tts_service.py index 30d573d..b462f19 100644 --- a/tts_service.py +++ b/tts_service.py @@ -12,6 +12,7 @@ import re import traceback import uuid import warnings +from dataclasses import replace from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -32,6 +33,7 @@ from config import ( SPEAKER_SAMPLE_MIN_SEC, TTS_MAX_CHARS_PER_CHUNK, TTS_MAX_NEW_TOKEN, + TTS_MIN_NEW_TOKEN, TTS_SAMPLE_RATE, TTS_SPEED_PROMPT, TTS_TEMPERATURE, @@ -641,10 +643,14 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: top_P=TTS_TOP_P, top_K=TTS_TOP_K, max_new_token=TTS_MAX_NEW_TOKEN, + min_new_token=TTS_MIN_NEW_TOKEN, + ensure_non_empty=False, ) params_refine_text = ChatTTS.Chat.RefineTextParams( prompt="[oral_2][laugh_0][break_4]", + ensure_non_empty=False, + min_new_token=4, ) logger.info( @@ -657,12 +663,31 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: segment_wavs: List[np.ndarray] = [] for idx, chunk in enumerate(chunks, start=1): release_cuda_cache() - wavs = chat.infer( - chunk, - skip_refine_text=(idx > 1), - params_refine_text=params_refine_text, - params_infer_code=params_infer_code, - ) + # manual_seed 每段不同;ensure_non_empty=False 避免空输出时无限递归 + chunk_infer = replace(params_infer_code, manual_seed=42 + idx) + wavs = None + last_exc: Optional[BaseException] = None + for attempt in range(3): + try: + wavs = chat.infer( + chunk, + skip_refine_text=False, + params_refine_text=params_refine_text, + params_infer_code=chunk_infer, + ) + break + except RecursionError as exc: + last_exc = exc + chunk_infer.manual_seed = 1000 + idx * 10 + attempt + release_cuda_cache() + if wavs is None: + return ( + False, + f"ChatTTS 第 {idx}/{len(chunks)} 段合成失败(递归重试耗尽)。" + f"请检查音色转写是否填写,或缩短该段文本。" + f" 详情: {last_exc}", + None, + ) if not wavs or len(wavs) == 0: return ( False, @@ -670,7 +695,15 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: f"(段内容前 40 字: {chunk[:40]}…)", None, ) - segment_wavs.append(np.asarray(wavs[0], dtype=np.float32)) + wav_arr = np.asarray(wavs[0], dtype=np.float32) + if wav_arr.size == 0 or np.max(np.abs(wav_arr)) < 1e-6: + return ( + False, + f"ChatTTS 第 {idx}/{len(chunks)} 段生成了空音频。" + "请重新锁定音色并填写参考转写,或缩短润色稿后重试。", + None, + ) + segment_wavs.append(wav_arr) release_cuda_cache() wav_array = ( @@ -712,6 +745,13 @@ def generate_voice(refined_text: str) -> Tuple[bool, str, Optional[str]]: "4. 确认无其他程序占用 GPU: nvidia-smi\n" f"技术详情: {exc_msg[:400]}" ) + elif "recursion depth" in exc_msg.lower() or isinstance(exc, RecursionError): + err = ( + "语音合成失败: ChatTTS 反复生成空结果导致递归超限。\n" + "常见原因:未填写参考音频转写、润色稿含特殊符号、或音色文件异常。\n" + "处理:重新锁定音色并填写转写 → 用较短纯文本试合成。\n" + f"技术详情: {exc_msg[:400]}" + ) elif "Corrupt input data" in exc_msg: err = ( "语音合成失败: 音色数据损坏或格式不兼容(Corrupt input data)。\n"