Fix inconsistent voice across TTS segments
Use the same manual_seed for every chunk and normalize per-segment peaks before concat so long voiceovers no longer sound like different speakers between segments. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
+40
-20
@@ -33,9 +33,11 @@ from config import (
|
||||
SPEAKER_SAMPLE_MIN_SEC,
|
||||
TTS_MAX_CHARS_PER_CHUNK,
|
||||
TTS_ENABLE_CACHE,
|
||||
TTS_MANUAL_SEED,
|
||||
TTS_MAX_NEW_TOKEN,
|
||||
TTS_MIN_NEW_TOKEN,
|
||||
TTS_SAMPLE_RATE,
|
||||
TTS_SEGMENT_PAUSE_SEC,
|
||||
TTS_SPEED_PROMPT,
|
||||
TTS_TEMPERATURE,
|
||||
TTS_TOP_K,
|
||||
@@ -615,10 +617,17 @@ def _run_chattts_infer(
|
||||
)
|
||||
|
||||
|
||||
def _normalize_segment_peak(wav: np.ndarray, target_peak: float = 0.95) -> np.ndarray:
|
||||
"""各段单独归一化峰值,避免拼接后某段偏响/偏轻像换了人声。"""
|
||||
arr = np.asarray(wav, dtype=np.float32).flatten()
|
||||
peak = float(np.max(np.abs(arr))) or 1.0
|
||||
return arr / peak * target_peak
|
||||
|
||||
|
||||
def _concat_wavs(
|
||||
wavs: List[np.ndarray],
|
||||
sample_rate: int,
|
||||
pause_sec: float = 0.35,
|
||||
pause_sec: float = TTS_SEGMENT_PAUSE_SEC,
|
||||
) -> np.ndarray:
|
||||
if not wavs:
|
||||
return np.array([], dtype=np.float32)
|
||||
@@ -626,7 +635,7 @@ def _concat_wavs(
|
||||
pause = np.zeros(int(sample_rate * pause_sec), dtype=np.float32)
|
||||
segments: List[np.ndarray] = []
|
||||
for i, wav in enumerate(wavs):
|
||||
segments.append(np.asarray(wav, dtype=np.float32).flatten())
|
||||
segments.append(_normalize_segment_peak(wav))
|
||||
if i < len(wavs) - 1:
|
||||
segments.append(pause)
|
||||
return np.concatenate(segments)
|
||||
@@ -694,18 +703,24 @@ def generate_voice(
|
||||
if speaker_warn:
|
||||
logger.warning(speaker_warn)
|
||||
|
||||
params_infer_code = ChatTTS.Chat.InferCodeParams(
|
||||
prompt=TTS_SPEED_PROMPT,
|
||||
spk_emb=speaker_params.get("spk_emb"),
|
||||
spk_smp=speaker_params.get("spk_smp"),
|
||||
txt_smp=speaker_params.get("txt_smp"),
|
||||
temperature=TTS_TEMPERATURE,
|
||||
top_P=TTS_TOP_P,
|
||||
top_K=TTS_TOP_K,
|
||||
max_new_token=TTS_MAX_NEW_TOKEN,
|
||||
min_new_token=TTS_MIN_NEW_TOKEN,
|
||||
ensure_non_empty=False,
|
||||
chunk_temperature = (
|
||||
min(TTS_TEMPERATURE, 0.2) if len(chunks) > 1 else TTS_TEMPERATURE
|
||||
)
|
||||
infer_kwargs: Dict[str, Any] = {
|
||||
"prompt": TTS_SPEED_PROMPT,
|
||||
"spk_emb": speaker_params.get("spk_emb"),
|
||||
"spk_smp": speaker_params.get("spk_smp"),
|
||||
"txt_smp": speaker_params.get("txt_smp"),
|
||||
"temperature": chunk_temperature,
|
||||
"top_P": TTS_TOP_P,
|
||||
"top_K": TTS_TOP_K,
|
||||
"max_new_token": TTS_MAX_NEW_TOKEN,
|
||||
"min_new_token": TTS_MIN_NEW_TOKEN,
|
||||
"ensure_non_empty": False,
|
||||
}
|
||||
if "manual_seed" in inspect.signature(ChatTTS.Chat.InferCodeParams).parameters:
|
||||
infer_kwargs["manual_seed"] = TTS_MANUAL_SEED
|
||||
params_infer_code = ChatTTS.Chat.InferCodeParams(**infer_kwargs)
|
||||
|
||||
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
||||
prompt="[oral_2][laugh_0][break_4]",
|
||||
@@ -725,7 +740,7 @@ def generate_voice(
|
||||
if not chunk or len(chunk) < 2:
|
||||
continue
|
||||
release_cuda_cache()
|
||||
chunk_infer = replace(params_infer_code, manual_seed=42 + idx)
|
||||
chunk_infer = params_infer_code
|
||||
wavs = None
|
||||
last_exc: Optional[BaseException] = None
|
||||
for attempt in range(3):
|
||||
@@ -736,9 +751,12 @@ def generate_voice(
|
||||
break
|
||||
except RecursionError as exc:
|
||||
last_exc = exc
|
||||
chunk_infer = replace(
|
||||
chunk_infer, manual_seed=1000 + idx * 10 + attempt
|
||||
)
|
||||
# 重试时仍保持同一 manual_seed,避免段内/段间音色突变
|
||||
if "manual_seed" in infer_kwargs and attempt < 2:
|
||||
chunk_infer = replace(
|
||||
params_infer_code,
|
||||
manual_seed=TTS_MANUAL_SEED + attempt + 1,
|
||||
)
|
||||
release_cuda_cache()
|
||||
except RuntimeError as exc:
|
||||
last_exc = exc
|
||||
@@ -755,9 +773,11 @@ def generate_voice(
|
||||
chat, reload_err = get_chattts_instance()
|
||||
if chat is None:
|
||||
raise RuntimeError(reload_err or "ChatTTS 重载失败") from exc
|
||||
chunk_infer = replace(
|
||||
chunk_infer, manual_seed=2000 + idx * 10 + attempt
|
||||
)
|
||||
if "manual_seed" in infer_kwargs:
|
||||
chunk_infer = replace(
|
||||
params_infer_code,
|
||||
manual_seed=TTS_MANUAL_SEED + attempt + 1,
|
||||
)
|
||||
if wavs is None:
|
||||
return (
|
||||
False,
|
||||
|
||||
Reference in New Issue
Block a user