Fix inconsistent voice across TTS segments

Use the same manual_seed for every chunk and normalize per-segment peaks before concat so long voiceovers no longer sound like different speakers between segments.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
dekun
2026-06-12 18:46:25 +08:00
parent 4255cf7cd7
commit 541df29722
2 changed files with 44 additions and 20 deletions
+4
View File
@@ -142,6 +142,10 @@ TTS_TEMPERATURE = 0.3
TTS_TOP_P = 0.7 TTS_TOP_P = 0.7
TTS_TOP_K = 20 TTS_TOP_K = 20
TTS_SPEED_PROMPT = "[speed_5]" TTS_SPEED_PROMPT = "[speed_5]"
# 多段拼接时各段必须使用同一随机种子,否则音色会像「换了个人」
TTS_MANUAL_SEED = _env_int("TTS_MANUAL_SEED", 42)
# 段间静音间隔(秒)
TTS_SEGMENT_PAUSE_SEC = 0.35
# 单段 TTS 最大字数(超长稿按句切分后逐段合成再拼接;8GB 显存建议 ≤200) # 单段 TTS 最大字数(超长稿按句切分后逐段合成再拼接;8GB 显存建议 ≤200)
TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 200) TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 200)
+36 -16
View File
@@ -33,9 +33,11 @@ from config import (
SPEAKER_SAMPLE_MIN_SEC, SPEAKER_SAMPLE_MIN_SEC,
TTS_MAX_CHARS_PER_CHUNK, TTS_MAX_CHARS_PER_CHUNK,
TTS_ENABLE_CACHE, TTS_ENABLE_CACHE,
TTS_MANUAL_SEED,
TTS_MAX_NEW_TOKEN, TTS_MAX_NEW_TOKEN,
TTS_MIN_NEW_TOKEN, TTS_MIN_NEW_TOKEN,
TTS_SAMPLE_RATE, TTS_SAMPLE_RATE,
TTS_SEGMENT_PAUSE_SEC,
TTS_SPEED_PROMPT, TTS_SPEED_PROMPT,
TTS_TEMPERATURE, TTS_TEMPERATURE,
TTS_TOP_K, TTS_TOP_K,
@@ -615,10 +617,17 @@ def _run_chattts_infer(
) )
def _normalize_segment_peak(wav: np.ndarray, target_peak: float = 0.95) -> np.ndarray:
    """Scale one audio segment so that its absolute peak equals ``target_peak``.

    Each segment is normalized on its own so that, after concatenation, no
    segment is noticeably louder or quieter than its neighbours.

    Args:
        wav: Audio samples of any shape; flattened to 1-D float32.
        target_peak: Absolute peak amplitude of the returned segment.

    Returns:
        A 1-D float32 array. Empty input, all-zero input, and input whose
        peak is not finite (NaN/Inf) are returned flattened but unscaled.
    """
    arr = np.asarray(wav, dtype=np.float32).flatten()
    # np.max raises ValueError on a zero-size array, so bail out early
    # instead of letting one empty segment abort the whole concatenation.
    if arr.size == 0:
        return arr
    peak = float(np.max(np.abs(arr)))
    # A zero peak means pure silence (nothing to scale); a non-finite peak
    # would turn every sample into NaN, so leave the data as it is.
    if peak == 0.0 or not np.isfinite(peak):
        return arr
    return arr / peak * target_peak
def _concat_wavs( def _concat_wavs(
wavs: List[np.ndarray], wavs: List[np.ndarray],
sample_rate: int, sample_rate: int,
pause_sec: float = 0.35, pause_sec: float = TTS_SEGMENT_PAUSE_SEC,
) -> np.ndarray: ) -> np.ndarray:
if not wavs: if not wavs:
return np.array([], dtype=np.float32) return np.array([], dtype=np.float32)
@@ -626,7 +635,7 @@ def _concat_wavs(
pause = np.zeros(int(sample_rate * pause_sec), dtype=np.float32) pause = np.zeros(int(sample_rate * pause_sec), dtype=np.float32)
segments: List[np.ndarray] = [] segments: List[np.ndarray] = []
for i, wav in enumerate(wavs): for i, wav in enumerate(wavs):
segments.append(np.asarray(wav, dtype=np.float32).flatten()) segments.append(_normalize_segment_peak(wav))
if i < len(wavs) - 1: if i < len(wavs) - 1:
segments.append(pause) segments.append(pause)
return np.concatenate(segments) return np.concatenate(segments)
@@ -694,18 +703,24 @@ def generate_voice(
if speaker_warn: if speaker_warn:
logger.warning(speaker_warn) logger.warning(speaker_warn)
params_infer_code = ChatTTS.Chat.InferCodeParams( chunk_temperature = (
prompt=TTS_SPEED_PROMPT, min(TTS_TEMPERATURE, 0.2) if len(chunks) > 1 else TTS_TEMPERATURE
spk_emb=speaker_params.get("spk_emb"),
spk_smp=speaker_params.get("spk_smp"),
txt_smp=speaker_params.get("txt_smp"),
temperature=TTS_TEMPERATURE,
top_P=TTS_TOP_P,
top_K=TTS_TOP_K,
max_new_token=TTS_MAX_NEW_TOKEN,
min_new_token=TTS_MIN_NEW_TOKEN,
ensure_non_empty=False,
) )
infer_kwargs: Dict[str, Any] = {
"prompt": TTS_SPEED_PROMPT,
"spk_emb": speaker_params.get("spk_emb"),
"spk_smp": speaker_params.get("spk_smp"),
"txt_smp": speaker_params.get("txt_smp"),
"temperature": chunk_temperature,
"top_P": TTS_TOP_P,
"top_K": TTS_TOP_K,
"max_new_token": TTS_MAX_NEW_TOKEN,
"min_new_token": TTS_MIN_NEW_TOKEN,
"ensure_non_empty": False,
}
if "manual_seed" in inspect.signature(ChatTTS.Chat.InferCodeParams).parameters:
infer_kwargs["manual_seed"] = TTS_MANUAL_SEED
params_infer_code = ChatTTS.Chat.InferCodeParams(**infer_kwargs)
params_refine_text = ChatTTS.Chat.RefineTextParams( params_refine_text = ChatTTS.Chat.RefineTextParams(
prompt="[oral_2][laugh_0][break_4]", prompt="[oral_2][laugh_0][break_4]",
@@ -725,7 +740,7 @@ def generate_voice(
if not chunk or len(chunk) < 2: if not chunk or len(chunk) < 2:
continue continue
release_cuda_cache() release_cuda_cache()
chunk_infer = replace(params_infer_code, manual_seed=42 + idx) chunk_infer = params_infer_code
wavs = None wavs = None
last_exc: Optional[BaseException] = None last_exc: Optional[BaseException] = None
for attempt in range(3): for attempt in range(3):
@@ -736,8 +751,11 @@ def generate_voice(
break break
except RecursionError as exc: except RecursionError as exc:
last_exc = exc last_exc = exc
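# On retry, switch to TTS_MANUAL_SEED + attempt + 1: the seed no longer depends on the chunk index, so every chunk retries with the same seed sequence (note: it is not identical to the first-attempt seed).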
if "manual_seed" in infer_kwargs and attempt < 2:
chunk_infer = replace( chunk_infer = replace(
chunk_infer, manual_seed=1000 + idx * 10 + attempt params_infer_code,
manual_seed=TTS_MANUAL_SEED + attempt + 1,
) )
release_cuda_cache() release_cuda_cache()
except RuntimeError as exc: except RuntimeError as exc:
@@ -755,8 +773,10 @@ def generate_voice(
chat, reload_err = get_chattts_instance() chat, reload_err = get_chattts_instance()
if chat is None: if chat is None:
raise RuntimeError(reload_err or "ChatTTS 重载失败") from exc raise RuntimeError(reload_err or "ChatTTS 重载失败") from exc
if "manual_seed" in infer_kwargs:
chunk_infer = replace( chunk_infer = replace(
chunk_infer, manual_seed=2000 + idx * 10 + attempt params_infer_code,
manual_seed=TTS_MANUAL_SEED + attempt + 1,
) )
if wavs is None: if wavs is None:
return ( return (