Fix inconsistent voice across TTS segments
Use the same manual_seed for every chunk and normalize per-segment peaks before concatenation, so long voiceovers no longer sound like different speakers between segments.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -142,6 +142,10 @@ TTS_TEMPERATURE = 0.3
|
|||||||
TTS_TOP_P = 0.7
|
TTS_TOP_P = 0.7
|
||||||
TTS_TOP_K = 20
|
TTS_TOP_K = 20
|
||||||
TTS_SPEED_PROMPT = "[speed_5]"
|
TTS_SPEED_PROMPT = "[speed_5]"
|
||||||
|
# 多段拼接时各段必须使用同一随机种子,否则音色会像「换了个人」
|
||||||
|
TTS_MANUAL_SEED = _env_int("TTS_MANUAL_SEED", 42)
|
||||||
|
# 段间静音间隔(秒)
|
||||||
|
TTS_SEGMENT_PAUSE_SEC = 0.35
|
||||||
|
|
||||||
# 单段 TTS 最大字数(超长稿按句切分后逐段合成再拼接;8GB 显存建议 ≤200)
|
# 单段 TTS 最大字数(超长稿按句切分后逐段合成再拼接;8GB 显存建议 ≤200)
|
||||||
TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 200)
|
TTS_MAX_CHARS_PER_CHUNK = _env_int("TTS_MAX_CHARS_PER_CHUNK", 200)
|
||||||
|
|||||||
+36
-16
@@ -33,9 +33,11 @@ from config import (
|
|||||||
SPEAKER_SAMPLE_MIN_SEC,
|
SPEAKER_SAMPLE_MIN_SEC,
|
||||||
TTS_MAX_CHARS_PER_CHUNK,
|
TTS_MAX_CHARS_PER_CHUNK,
|
||||||
TTS_ENABLE_CACHE,
|
TTS_ENABLE_CACHE,
|
||||||
|
TTS_MANUAL_SEED,
|
||||||
TTS_MAX_NEW_TOKEN,
|
TTS_MAX_NEW_TOKEN,
|
||||||
TTS_MIN_NEW_TOKEN,
|
TTS_MIN_NEW_TOKEN,
|
||||||
TTS_SAMPLE_RATE,
|
TTS_SAMPLE_RATE,
|
||||||
|
TTS_SEGMENT_PAUSE_SEC,
|
||||||
TTS_SPEED_PROMPT,
|
TTS_SPEED_PROMPT,
|
||||||
TTS_TEMPERATURE,
|
TTS_TEMPERATURE,
|
||||||
TTS_TOP_K,
|
TTS_TOP_K,
|
||||||
@@ -615,10 +617,17 @@ def _run_chattts_infer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_segment_peak(wav: np.ndarray, target_peak: float = 0.95) -> np.ndarray:
|
||||||
|
"""各段单独归一化峰值,避免拼接后某段偏响/偏轻像换了人声。"""
|
||||||
|
arr = np.asarray(wav, dtype=np.float32).flatten()
|
||||||
|
peak = float(np.max(np.abs(arr))) or 1.0
|
||||||
|
return arr / peak * target_peak
|
||||||
|
|
||||||
|
|
||||||
def _concat_wavs(
|
def _concat_wavs(
|
||||||
wavs: List[np.ndarray],
|
wavs: List[np.ndarray],
|
||||||
sample_rate: int,
|
sample_rate: int,
|
||||||
pause_sec: float = 0.35,
|
pause_sec: float = TTS_SEGMENT_PAUSE_SEC,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
if not wavs:
|
if not wavs:
|
||||||
return np.array([], dtype=np.float32)
|
return np.array([], dtype=np.float32)
|
||||||
@@ -626,7 +635,7 @@ def _concat_wavs(
|
|||||||
pause = np.zeros(int(sample_rate * pause_sec), dtype=np.float32)
|
pause = np.zeros(int(sample_rate * pause_sec), dtype=np.float32)
|
||||||
segments: List[np.ndarray] = []
|
segments: List[np.ndarray] = []
|
||||||
for i, wav in enumerate(wavs):
|
for i, wav in enumerate(wavs):
|
||||||
segments.append(np.asarray(wav, dtype=np.float32).flatten())
|
segments.append(_normalize_segment_peak(wav))
|
||||||
if i < len(wavs) - 1:
|
if i < len(wavs) - 1:
|
||||||
segments.append(pause)
|
segments.append(pause)
|
||||||
return np.concatenate(segments)
|
return np.concatenate(segments)
|
||||||
@@ -694,18 +703,24 @@ def generate_voice(
|
|||||||
if speaker_warn:
|
if speaker_warn:
|
||||||
logger.warning(speaker_warn)
|
logger.warning(speaker_warn)
|
||||||
|
|
||||||
params_infer_code = ChatTTS.Chat.InferCodeParams(
|
chunk_temperature = (
|
||||||
prompt=TTS_SPEED_PROMPT,
|
min(TTS_TEMPERATURE, 0.2) if len(chunks) > 1 else TTS_TEMPERATURE
|
||||||
spk_emb=speaker_params.get("spk_emb"),
|
|
||||||
spk_smp=speaker_params.get("spk_smp"),
|
|
||||||
txt_smp=speaker_params.get("txt_smp"),
|
|
||||||
temperature=TTS_TEMPERATURE,
|
|
||||||
top_P=TTS_TOP_P,
|
|
||||||
top_K=TTS_TOP_K,
|
|
||||||
max_new_token=TTS_MAX_NEW_TOKEN,
|
|
||||||
min_new_token=TTS_MIN_NEW_TOKEN,
|
|
||||||
ensure_non_empty=False,
|
|
||||||
)
|
)
|
||||||
|
infer_kwargs: Dict[str, Any] = {
|
||||||
|
"prompt": TTS_SPEED_PROMPT,
|
||||||
|
"spk_emb": speaker_params.get("spk_emb"),
|
||||||
|
"spk_smp": speaker_params.get("spk_smp"),
|
||||||
|
"txt_smp": speaker_params.get("txt_smp"),
|
||||||
|
"temperature": chunk_temperature,
|
||||||
|
"top_P": TTS_TOP_P,
|
||||||
|
"top_K": TTS_TOP_K,
|
||||||
|
"max_new_token": TTS_MAX_NEW_TOKEN,
|
||||||
|
"min_new_token": TTS_MIN_NEW_TOKEN,
|
||||||
|
"ensure_non_empty": False,
|
||||||
|
}
|
||||||
|
if "manual_seed" in inspect.signature(ChatTTS.Chat.InferCodeParams).parameters:
|
||||||
|
infer_kwargs["manual_seed"] = TTS_MANUAL_SEED
|
||||||
|
params_infer_code = ChatTTS.Chat.InferCodeParams(**infer_kwargs)
|
||||||
|
|
||||||
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
||||||
prompt="[oral_2][laugh_0][break_4]",
|
prompt="[oral_2][laugh_0][break_4]",
|
||||||
@@ -725,7 +740,7 @@ def generate_voice(
|
|||||||
if not chunk or len(chunk) < 2:
|
if not chunk or len(chunk) < 2:
|
||||||
continue
|
continue
|
||||||
release_cuda_cache()
|
release_cuda_cache()
|
||||||
chunk_infer = replace(params_infer_code, manual_seed=42 + idx)
|
chunk_infer = params_infer_code
|
||||||
wavs = None
|
wavs = None
|
||||||
last_exc: Optional[BaseException] = None
|
last_exc: Optional[BaseException] = None
|
||||||
for attempt in range(3):
|
for attempt in range(3):
|
||||||
@@ -736,8 +751,11 @@ def generate_voice(
|
|||||||
break
|
break
|
||||||
except RecursionError as exc:
|
except RecursionError as exc:
|
||||||
last_exc = exc
|
last_exc = exc
|
||||||
|
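# NOTE: a retry does NOT reuse the identical seed; it uses TTS_MANUAL_SEED + attempt + 1.
# The offset no longer depends on the chunk index, but a retried chunk is still
# synthesized with a different seed than chunks that succeeded on the first attempt.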
|
||||||
|
if "manual_seed" in infer_kwargs and attempt < 2:
|
||||||
chunk_infer = replace(
|
chunk_infer = replace(
|
||||||
chunk_infer, manual_seed=1000 + idx * 10 + attempt
|
params_infer_code,
|
||||||
|
manual_seed=TTS_MANUAL_SEED + attempt + 1,
|
||||||
)
|
)
|
||||||
release_cuda_cache()
|
release_cuda_cache()
|
||||||
except RuntimeError as exc:
|
except RuntimeError as exc:
|
||||||
@@ -755,8 +773,10 @@ def generate_voice(
|
|||||||
chat, reload_err = get_chattts_instance()
|
chat, reload_err = get_chattts_instance()
|
||||||
if chat is None:
|
if chat is None:
|
||||||
raise RuntimeError(reload_err or "ChatTTS 重载失败") from exc
|
raise RuntimeError(reload_err or "ChatTTS 重载失败") from exc
|
||||||
|
if "manual_seed" in infer_kwargs:
|
||||||
chunk_infer = replace(
|
chunk_infer = replace(
|
||||||
chunk_infer, manual_seed=2000 + idx * 10 + attempt
|
params_infer_code,
|
||||||
|
manual_seed=TTS_MANUAL_SEED + attempt + 1,
|
||||||
)
|
)
|
||||||
if wavs is None:
|
if wavs is None:
|
||||||
return (
|
return (
|
||||||
|
|||||||
Reference in New Issue
Block a user