# Kokoro_TTS_Compare / tts_cli.py
# Wismut's picture
# initial commit
# 2eaa44a
# raw
# history blame
# 17.8 kB
#!/usr/bin/env python3
# tts_cli.py
"""
Example CLI for generating audio with Kokoro-StyleTTS2.
Usage:
python tts_cli.py \
--model /path/to/kokoro-v0_19.pth \
--config /path/to/config.json \
--text "Hello, my stinking friends from 1906! You stink." \
--voicepack /path/to/af.pt \
--output output.wav
Make sure:
1. `models.py` is in the same folder (with `build_model`, `Decoder`, etc.).
2. You have installed the needed libraries:
pip install torch phonemizer munch soundfile pyyaml
3. The model is a checkpoint that your `build_model` can load.
Adapt as needed!
"""
import argparse
import os
import re
import torch
import soundfile as sf
import numpy as np
from phonemizer import backend as phonemizer_backend
# Optional eSpeak setup: if a module named `espeak_util` is importable, call its
# `set_espeak_library()` before any phonemizer backend is created below.
# NOTE(review): presumably this points phonemizer at a specific eSpeak shared
# library -- confirm against espeak_util. When the module is missing the
# ImportError is deliberately ignored and nothing is configured here.
try:
    from espeak_util import set_espeak_library
    set_espeak_library()
except ImportError:
    pass
# --------------------------------------------------------------------
# Import from your local `models.py` (requires that file to be present).
# This example assumes `build_model` loads the entire TTS submodules
# (bert, bert_encoder, predictor, decoder, text_encoder).
# --------------------------------------------------------------------
from models import build_model
def resplit_strings(arr):
    """
    Cut a list of string pieces into two space-joined halves of similar length.

    Every cut position between two neighbouring pieces is scored by the
    absolute difference between the lengths of the two joined halves; the
    first position with the smallest score wins. Pieces are never broken up.

    Returns a `(left, right)` tuple of strings. An empty list yields
    `("", "")`, and a single piece yields `(piece, "")`.
    """
    if not arr:
        return "", ""
    if len(arr) == 1:
        return arr[0], ""

    sizes = [len(piece) for piece in arr]
    # Length of " ".join(arr): all characters plus one space per gap.
    joined_total = sum(sizes) + len(arr) - 1

    gaps = []
    chars_on_left = 0
    for cut in range(1, len(arr)):
        chars_on_left += sizes[cut - 1]
        left_size = chars_on_left + (cut - 1)
        # The right half is whatever remains once the separating space at the
        # cut is dropped: joined_total - left_size - 1.
        gaps.append(abs(2 * left_size + 1 - joined_total))

    # list.index returns the first occurrence, so ties go to the earliest cut.
    best_cut = 1 + gaps.index(min(gaps))
    return " ".join(arr[:best_cut]), " ".join(arr[best_cut:])
def recursive_split(text, lang="a", max_tokens=512):
    """
    Break `text` into segments whose token count stays below `max_tokens`.

    The text is phonemized (with normalization) and tokenized. If it yields
    fewer than `max_tokens` tokens it is returned as a single segment.
    Otherwise it is cut roughly in half and both halves are processed
    recursively. Cut points are tried in this order:

    1. a space following sentence punctuation (`.?!…`, optionally followed by
       a closing quote),
    2. a space following clause punctuation (`:,;—`),
    3. any space,
    4. the character midpoint of the text, used only when none of the above
       produces two non-empty halves (for example one very long word).
       Without this fallback the function would recurse on the same text
       forever.

    text: the input string
    lang: key into the module-level `phonemizers` mapping
    max_tokens: exclusive upper bound on tokens per segment (default 512)

    Returns a list of `(segment_text, segment_phonemes)` tuples.
    """
    ps = phonemize_text(text, lang=lang, do_normalize=True)
    tokens = tokenize(ps)
    if len(tokens) < max_tokens:
        return [(text, ps)]

    # Prefer sentence boundaries, then clause boundaries; the split consumes
    # only the space that follows the punctuation mark.
    for punctuation in [r"[.?!…]", r"[:,;—]"]:
        pattern = f"(?:(?<={punctuation})|(?<={punctuation}[\"'»])) "
        splits = re.split(pattern, text)
        if len(splits) > 1:
            break
    else:
        # No punctuation boundary found: fall back to plain whitespace.
        splits = text.split(" ")

    left, right = resplit_strings(splits)

    if not left or not right:
        # Nothing to split on (or one side came out empty), so recursing on
        # `left`/`right` would make no progress. Cut by characters instead.
        if len(text) < 2:
            return [(text, ps)]
        middle = len(text) // 2
        left, right = text[:middle], text[middle:]

    return recursive_split(left, lang=lang, max_tokens=max_tokens) + recursive_split(
        right, lang=lang, max_tokens=max_tokens
    )
def segment_and_tokenize(long_text, lang="a"):
    """
    Split a long text into a list of `(segment_text, segment_phonemes)` pairs.

    This is a thin entry point over `recursive_split`; no extra cleaning is
    applied to `long_text` here.
    """
    return recursive_split(long_text, lang=lang)
# -------------- Normalization & Phonemization Routines -------------- #
def parens_to_angles(s):
    """Return `s` with "(" replaced by "«" and ")" replaced by "»"."""
    return s.translate(str.maketrans("()", "«»"))
def split_num(num):
    """
    Regex substitution callback that rewrites clock times and years.

    num: a match object whose text is a decimal number, an `H:MM` time, or a
        four-digit number with an optional trailing "s".

    Behaviour by kind of match:
    - contains ".": returned unchanged;
    - contains ":": "10:00" -> "10 o'clock", "10:05" -> "10 oh 5",
      "10:30" -> "10 30";
    - four digits: values below 1100, or whose last three digits are below
      10, are returned unchanged; otherwise the number is split into two
      two-digit halves ("1990s" -> "19 90s"), with "hundred" / "oh" forms
      when the last three digits are at least 100 ("1900" -> "19 hundred",
      "1906" -> "19 oh 6").
    """
    token = num.group()
    if "." in token:
        return token

    if ":" in token:
        hour, minute = (int(part) for part in token.split(":"))
        if minute == 0:
            return f"{hour} o'clock"
        if minute < 10:
            return f"{hour} oh {minute}"
        return f"{hour} {minute}"

    year = int(token[:4])
    last_three = year % 1000
    if year < 1100 or last_three < 10:
        return token

    first_half = token[:2]
    second_half = int(token[2:4])
    plural = "s" if token.endswith("s") else ""
    # last_three can never exceed 999, so only the lower bound needs checking.
    if last_three >= 100:
        if second_half == 0:
            return f"{first_half} hundred{plural}"
        if second_half < 10:
            return f"{first_half} oh {second_half}{plural}"
    return f"{first_half} {second_half}{plural}"
def flip_money(m):
    """
    Regex substitution callback that moves the currency word after the amount.

    m: a match object whose text starts with "$" or "£" followed by a number,
        optionally followed by a magnitude word such as "million".

    Examples: "$1" -> "1 dollar", "$5 million" -> "5 million dollars",
    "£3.50" -> "3 pounds and 50 pence", "$1.01" -> "1 dollar and 1 cent".
    A single fractional digit is read as tens ("$2.5" -> 50 cents).
    """
    amount = m.group()
    symbol, value = amount[0], amount[1:]
    in_dollars = symbol == "$"
    unit = "dollar" if in_dollars else "pound"

    # Ends with a magnitude word ("hundred", "million", ...): always plural.
    if amount[-1].isalpha():
        return f"{value} {unit}s"

    # Whole amount without a fractional part.
    if "." not in amount:
        plural = "" if value == "1" else "s"
        return f"{value} {unit}{plural}"

    whole, fraction = value.split(".")
    plural = "" if whole == "1" else "s"
    minor = int(fraction.ljust(2, "0"))
    if in_dollars:
        minor_unit = "cent" if minor == 1 else "cents"
    else:
        minor_unit = "penny" if minor == 1 else "pence"
    return f"{whole} {unit}{plural} and {minor} {minor_unit}"
def point_num(num):
    """
    Regex substitution callback that spells out a decimal number.

    The matched text is split at its "." and the fractional digits are spaced
    out one by one, e.g. "3.14" -> "3 point 1 4".
    """
    whole, fraction = num.group().split(".")
    spaced_digits = " ".join(fraction)
    return f"{whole} point {spaced_digits}"
def normalize_text(text):
    """
    Rewrite raw input text into a form that is easier to phonemize.

    The steps run in a fixed order because later rules rely on earlier ones:
    quote and bracket unification, full-width punctuation mapping, whitespace
    cleanup, title expansion, number / time / money verbalization, and
    finally handling of initialisms.

    text: the raw input string
    Returns the normalized string with surrounding whitespace stripped.
    """
    # Curly single quotes -> straight apostrophe.
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    # Guillemets and curly double quotes -> straight double quote.
    text = text.replace("«", chr(8220)).replace("»", chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    # Parentheses become guillemets (the vocabulary has no parentheses).
    text = parens_to_angles(text)
    # Map full-width CJK punctuation to ASCII plus a space. The source
    # characters are written as escapes so that only the full-width forms are
    # matched: padding ASCII "," or ":" with a space here would turn "1,000"
    # into "1, 000" and "10:30" into "10: 30", which defeats the
    # thousands-separator and clock-time rules below.
    for a, b in zip("\u3001\u3002\uff01\uff0c\uff1a\uff1b\uff1f", ",.!,:;?"):
        text = text.replace(a, b + " ")
    # Any whitespace other than space/newline becomes a space; collapse runs.
    text = re.sub(r"[^\S \n]", " ", text)
    text = re.sub(r" +", " ", text)
    text = re.sub(r"(?<=\n) +(?=\n)", "", text)
    # Expand common titles and drop the dot of "etc." mid-sentence.
    text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
    text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
    text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
    text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
    text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
    text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
    # Decimals (left untouched by split_num), four-digit years and H:MM times.
    text = re.sub(
        r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)",
        split_num,
        text,
    )
    # Remove thousands separators.
    text = re.sub(r"(?<=\d),(?=\d)", "", text)
    # "$5" -> "5 dollars" and similar.
    text = re.sub(
        r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
        flip_money,
        text,
    )
    # Remaining decimals are read digit by digit.
    text = re.sub(r"\d*\.\d+", point_num, text)
    text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)  # Could be minus; adjust if needed
    text = re.sub(r"(?<=\d)S", " S", text)
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", "s", text)
    # Dotted initialisms: "U.S.A. is" -> "U-S-A- is", then "A.B" -> "A-B".
    text = re.sub(
        r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
    )
    text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
    return text.strip()
# -------------------------------------------------------------------
# Vocab and Symbol Mapping
# -------------------------------------------------------------------
def get_vocab():
    """
    Build the symbol -> integer id table used for tokenization.

    Ids follow the position of each symbol in the concatenation of the pad
    symbol, punctuation, ASCII letters and IPA symbols, so the pad symbol "$"
    gets id 0. If a symbol appears more than once, its last position wins.
    """
    pad = "$"
    punctuation = ';:,.!?¡¿—…"«»“” '
    letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
    inventory = pad + punctuation + letters + letters_ipa
    return {symbol: index for index, symbol in enumerate(inventory)}
# Symbol -> id table, built once at import time and shared by `tokenize` and
# `phonemize_text`.
VOCAB = get_vocab()
def tokenize(ps: str):
    """Map each character of `ps` to its VOCAB id, dropping unknown characters."""
    return [VOCAB[symbol] for symbol in ps if symbol in VOCAB]
# -------------------------------------------------------------------
# Initialize a simple phonemizer
# For English:
# 'a' ~ en-us
# 'b' ~ en-gb
# -------------------------------------------------------------------
# One eSpeak backend per language key; both keep punctuation and stress marks.
# You can add more, e.g. 'j': some Japanese phonemizer, etc.
phonemizers = {
    key: phonemizer_backend.EspeakBackend(
        language=espeak_voice, preserve_punctuation=True, with_stress=True
    )
    for key, espeak_voice in (("a", "en-us"), ("b", "en-gb"))
}
def phonemize_text(text, lang="a", do_normalize=True):
    """
    Convert text into a phoneme string restricted to symbols in VOCAB.

    text: input string
    lang: key into the module-level `phonemizers` mapping
    do_normalize: when true, run `normalize_text` on the input first

    Returns the post-processed phoneme string, stripped of surrounding
    whitespace. An empty phonemizer result yields an empty string.
    """
    source = normalize_text(text) if do_normalize else text
    results = phonemizers[lang].phonemize([source])
    ps = results[0] if results else ""

    # Literal fix-ups, applied in this order: two pronunciation overrides
    # followed by single-symbol substitutions.
    fixups = (
        ("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ"),
        ("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ"),
        ("ʲ", "j"),
        ("r", "ɹ"),
        ("x", "k"),
        ("ɬ", "l"),
    )
    for found, replacement in fixups:
        ps = ps.replace(found, replacement)

    # Insert a space before "hˈʌndɹɪd" when it is glued to a preceding word.
    ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
    # Attach a detached word-final " z" to the preceding word.
    ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
    # For language key "a" only: "nˈaɪnti" -> "nˈaɪndi".
    if lang == "a":
        ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)

    # Drop every symbol the vocabulary does not know.
    kept = "".join(symbol for symbol in ps if symbol in VOCAB)
    return kept.strip()
# -------------------------------------------------------------------
# Utility for generating text masks
# -------------------------------------------------------------------
def length_to_mask(lengths):
    """
    Build a boolean padding mask from a 1-D tensor of sequence lengths.

    lengths: integer tensor of shape [B]
    Returns a bool tensor of shape [B, max(lengths)] that is True at
    positions at or beyond each sequence's length.
    """
    longest = lengths.max()
    # 1-based position numbers, shape [1, longest].
    positions = (torch.arange(longest, device=lengths.device) + 1).unsqueeze(0)
    # Broadcasting against [B, 1] gives the [B, longest] mask.
    return positions > lengths.unsqueeze(1)
# -------------------------------------------------------------------
# The forward pass for inference (from your snippet).
# This version references `model.predictor`, `model.decoder`, etc.
# -------------------------------------------------------------------
@torch.no_grad()
def forward_tts(model, tokens, ref_s, speed=1.0):
    """
    Run one inference pass and return the synthesized waveform.

    model: object exposing the submodules used below as attributes
        (bert, bert_encoder, predictor, text_encoder, decoder)
    tokens: list[int] of symbol ids, without boundary tokens
    ref_s: reference embedding tensor; columns [:128] are passed to the
        decoder and columns [128:] to the predictor
        (assumes a 2-D tensor with batch size 1 -- TODO confirm)
    speed: divisor applied to the predicted durations, so values above 1
        shorten the output

    Returns the decoder output, squeezed and converted to a NumPy array on CPU.
    """
    device = ref_s.device
    decoder_ref = ref_s[:, :128]
    predictor_ref = ref_s[:, 128:]

    # Surround the sequence with the boundary id 0 and add a batch dimension.
    padded_tokens = [0, *tokens, 0]
    tokens_t = torch.LongTensor([padded_tokens]).to(device)
    input_lengths = torch.LongTensor([tokens_t.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)

    # 1. Encode with BERT (the attention mask is the inverse of the padding mask)
    attention_mask = (~text_mask).int()
    bert_dur = model.bert(tokens_t, attention_mask=attention_mask)
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

    # 2. Prosody predictor: one duration (in frames) per token, at least 1
    d = model.predictor.text_encoder(d_en, predictor_ref, input_lengths, text_mask)
    lstm_out, _ = model.predictor.lstm(d)
    duration_logits = model.predictor.duration_proj(lstm_out)
    duration = torch.sigmoid(duration_logits).sum(axis=-1) / speed
    pred_dur = torch.round(duration).clamp(min=1).long()

    # 3. Expand alignment: row i is 1 over the consecutive frames of token i
    total_frames = pred_dur.sum().item()
    alignment = torch.zeros(input_lengths, total_frames, device=device)
    frame_cursor = 0
    for row in range(alignment.size(0)):
        span = pred_dur[0, row].item()
        alignment[row, frame_cursor : frame_cursor + span] = 1
        frame_cursor += span
    batched_alignment = alignment.unsqueeze(0)

    # 4. Run F0 + Noise predictor on the frame-expanded prosody features
    en = d.transpose(-1, -2) @ batched_alignment
    F0_pred, N_pred = model.predictor.F0Ntrain(en, predictor_ref)

    # 5. Text encoder -> frame-expanded text features
    t_en = model.text_encoder(tokens_t, input_lengths, text_mask)
    asr = t_en @ batched_alignment

    # 6. Decode audio
    audio = model.decoder(asr, F0_pred, N_pred, decoder_ref)  # B x audio_len
    return audio.squeeze().cpu().numpy()
def generate_tts(model, text, voicepack, lang="a", speed=1.0):
    """
    Synthesize audio for a single piece of text.

    model: the object returned by build_model(...); `model.bert` is used to
        find out which device the model lives on
    text: the input text (string)
    voicepack: indexable collection of reference embeddings, looked up by the
        number of tokens in the utterance
    lang: key into the module-level `phonemizers` mapping ('a' or 'b')
    speed: speech speed factor

    Returns `(audio, phonemes)`. `audio` is None when the text produces no
    tokens; otherwise it is whatever `forward_tts` returns.
    """
    # 1. Phonemize
    ps = phonemize_text(text, lang=lang, do_normalize=True)
    tokens = tokenize(ps)
    if not tokens:
        return None, ps

    # 2. Retrieve reference style, indexed by token count.
    try:
        ref_s = voicepack[len(tokens)]
    except (IndexError, KeyError):
        # Token count is outside the voicepack: fall back to the last entry.
        # Only lookup failures are caught, so unrelated errors (and
        # KeyboardInterrupt) are no longer swallowed.
        ref_s = voicepack[-1]

    # Follow the model's actual device rather than a bare "cuda"/"cpu" string,
    # so models on a non-default GPU or another accelerator work too.
    ref_s = ref_s.to(next(model.bert.parameters()).device)

    # 3. Generate
    audio = forward_tts(model, tokens, ref_s, speed=speed)
    return audio, ps
def generate_long_form_tts(model, full_text, voicepack, lang="a", speed=1.0):
    """
    Synthesize a long text segment by segment and join the audio.

    The text is split with `segment_and_tokenize`; every segment is passed to
    `generate_tts`, and segments that yield no audio are skipped with a
    message. Progress is printed for each segment.

    Returns `(final_audio, list_of_segment_phonemes)`, or `(None, [])` when
    no segment produced audio.
    """
    segments = segment_and_tokenize(full_text, lang=lang)
    segment_count = len(segments)

    audio_chunks, all_phonemes = [], []
    # The phonemes computed during segmentation are not reused here;
    # generate_tts phonemizes each segment again.
    for i, (seg_text, _seg_ps) in enumerate(segments, 1):
        print(f"[LongForm] Generating chunk {i}/{segment_count}: {seg_text[:40]}...")
        audio, used_phonemes = generate_tts(
            model, seg_text, voicepack, lang=lang, speed=speed
        )
        if audio is None:
            print(f"[LongForm] Skipped empty segment {i}...")
            continue
        audio_chunks.append(audio)
        all_phonemes.append(used_phonemes)

    if not audio_chunks:
        return None, []
    return np.concatenate(audio_chunks, axis=0), all_phonemes
# -------------------------------------------------------------------
# Main CLI
# -------------------------------------------------------------------
def _build_arg_parser():
    """Create the command-line argument parser for the TTS example."""
    parser = argparse.ArgumentParser(description="Kokoro-StyleTTS2 CLI Example")
    parser.add_argument(
        "--model",
        type=str,
        default="pretrained_models/Kokoro/kokoro-v0_19.pth",
        help="Path to your model checkpoint (e.g. kokoro-v0_19.pth).",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="pretrained_models/Kokoro/config.json",
        help="Path to config.json (used by build_model).",
    )
    parser.add_argument(
        "--text",
        type=str,
        default="Hello world! This is Kokoro, a new text-to-speech model based on StyleTTS2 from 2024!",
        help="Text to be converted into speech.",
    )
    parser.add_argument(
        "--voicepack",
        type=str,
        default="pretrained_models/Kokoro/voices/af.pt",
        help="Path to a .pt file for your reference embedding(s).",
    )
    parser.add_argument(
        "--output", type=str, default="output.wav", help="Output WAV filename."
    )
    parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Speech speed factor, e.g. 0.8 slower, 1.2 faster, etc.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to run inference on.",
    )
    parser.add_argument(
        "--lang",
        type=str,
        default="a",
        choices=sorted(phonemizers),
        help="Phonemizer key: 'a' for en-us, 'b' for en-gb.",
    )
    parser.add_argument(
        "--sample-rate",
        type=int,
        default=24000,
        help="Sample rate written to the WAV header (Kokoro v0.19 outputs 24 kHz).",
    )
    return parser


def main():
    """
    Command-line entry point: load the model and voicepack, synthesize the
    requested text and write it to a WAV file.

    Raises FileNotFoundError when the checkpoint, config or voicepack path
    does not point to a file. Exits through argparse when an argument is
    invalid (including a non-positive --speed).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Durations are divided by the speed factor, so it must be positive.
    if args.speed <= 0:
        parser.error("--speed must be a positive number")

    # Fail fast on missing inputs before any expensive model loading.
    if not os.path.isfile(args.config):
        raise FileNotFoundError(f"config.json not found: {args.config}")
    if not os.path.isfile(args.model):
        raise FileNotFoundError(f"Model checkpoint not found: {args.model}")
    if not os.path.isfile(args.voicepack):
        raise FileNotFoundError(f"Voicepack file not found: {args.voicepack}")

    # 1. Build model using the local build_model function.
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA requested but not available; falling back to CPU.")
        device = "cpu"

    print(f"Loading model from: {args.model}")
    model = build_model(
        args.model, device
    )  # This requires that `args.model` is the checkpoint path

    # `build_model` returns a mapping of submodules rather than one module,
    # so each submodule is switched to eval mode individually.
    for subm in model.values():
        if isinstance(subm, torch.nn.Module):
            subm.eval()

    # 2. Load voicepack
    print(f"Loading voicepack from: {args.voicepack}")
    # NOTE(security): torch.load unpickles the file; only load voicepacks from
    # sources you trust.
    vp = torch.load(args.voicepack, map_location=device)
    # If your voicepack is an nn.Module, set it to eval as well
    if isinstance(vp, torch.nn.Module):
        vp.eval()

    # 3. Generate audio
    print(f"Generating speech for text: {args.text}")
    audio, phonemes = generate_long_form_tts(
        model, args.text, vp, lang=args.lang, speed=args.speed
    )
    if audio is None:
        print("No tokens were generated (maybe empty text?). Exiting.")
        return

    # 4. Write WAV. The header must carry the model's output rate; writing
    # 24 kHz audio as 22050 Hz makes playback slower and lower-pitched.
    print(f"Writing output to: {args.output}")
    sf.write(args.output, audio, args.sample_rate)
    print("Finished!")
    print(f"Phonemes used: {phonemes}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()