Spaces:
Running
Running
#!/usr/bin/env python3 | |
# tts_cli.py | |
"""
Example CLI for generating audio with Kokoro-StyleTTS2.

Usage:
    python tts_cli.py \
        --model /path/to/kokoro-v0_19.pth \
        --config /path/to/config.json \
        --text "Hello, my stinking friends from 1906! You stink." \
        --voicepack /path/to/af.pt \
        --output output.wav

Make sure that:
  1. `models.py` is in the same folder (providing `build_model`, `Decoder`, etc.).
  2. The required Python libraries are installed:
       pip install torch phonemizer munch soundfile pyyaml
     and the eSpeak (eSpeak NG) system library is available, because
     phonemizer's EspeakBackend depends on it.
  3. The model is a checkpoint that your `build_model` can load.

Adapt as needed!
"""
import argparse | |
import os | |
import re | |
import torch | |
import soundfile as sf | |
import numpy as np | |
from phonemizer import backend as phonemizer_backend | |
# Optionally point phonemizer at a specific eSpeak library via the
# project-local `espeak_util` helper. If that module cannot be imported,
# this step is skipped and phonemizer's default eSpeak lookup applies.
try:
    from espeak_util import set_espeak_library
except ImportError:
    pass
else:
    # Called outside the `try` so that an ImportError raised *inside* the
    # helper surfaces instead of being mistaken for "helper not installed".
    set_espeak_library()
# -------------------------------------------------------------------- | |
# Import from your local `models.py` (requires that file to be present). | |
# This example assumes `build_model` loads the entire TTS submodules | |
# (bert, bert_encoder, predictor, decoder, text_encoder). | |
# -------------------------------------------------------------------- | |
from models import build_model | |
def resplit_strings(arr):
    """
    Join a list of words into two space-separated halves of similar length.

    The cut is always placed between list elements (never inside one), at
    the position where the joined left and right strings differ least in
    length. On a tie, the earliest such position wins.

    Args:
        arr: list of strings (words or phrases).

    Returns:
        ``(left, right)`` strings. An empty list yields ``("", "")``; a
        one-element list yields ``(arr[0], "")``.
    """
    if not arr:
        return "", ""
    if len(arr) == 1:
        return arr[0], ""

    # Width of the fully joined string: every word plus one space per gap.
    joined_width = sum(len(word) for word in arr) + len(arr) - 1

    scored_cuts = []
    words_width = 0
    for cut in range(1, len(arr)):
        words_width += len(arr[cut - 1])
        # Words left of the cut plus the spaces between them.
        left_width = words_width + (cut - 1)
        # Everything else, minus the single space sitting at the cut itself.
        right_width = joined_width - left_width - 1
        scored_cuts.append((abs(left_width - right_width), cut))

    # Tuples compare by imbalance first, then by cut index, so ties resolve
    # to the earliest cut.
    _, best_cut = min(scored_cuts)
    return " ".join(arr[:best_cut]), " ".join(arr[best_cut:])
def _split_at_boundaries(text):
    """
    Split ``text`` into non-empty pieces at the most natural break points.

    Tries sentence-ending punctuation first, then clause punctuation, then
    plain spaces. Returns the pieces from the first tier that yields more
    than one piece; the result has a single element if no tier does.
    """
    for punctuation in (r"[.?!…]", r"[:,;—]"):
        # Break at a space that follows the punctuation mark, optionally with
        # a closing quote or » in between; the space itself is consumed.
        pattern = f"(?:(?<={punctuation})|(?<={punctuation}[\"'»])) "
        pieces = [piece for piece in re.split(pattern, text) if piece]
        if len(pieces) > 1:
            return pieces
    return [piece for piece in text.split(" ") if piece]


def recursive_split(text, lang="a", max_tokens=510):
    """
    Split ``text`` into segments whose phoneme token count fits the model.

    The text is phonemized (with normalization) and tokenized. If it yields
    at most ``max_tokens`` tokens, it is returned as a single segment.
    Otherwise it is cut near the middle and each half is processed
    recursively. Cut points are tried in this order:

    1. sentence punctuation (``. ? ! …``);
    2. clause punctuation (``: , ; —``);
    3. spaces.

    A run with no break point at all (e.g. one very long "word") is cut in
    half by characters, so every recursive call works on strictly shorter
    text and the recursion always terminates.

    Args:
        text: raw (un-normalized) input text.
        lang: key into ``phonemizers`` ('a' = en-us, 'b' = en-gb).
        max_tokens: maximum tokens per segment. The default of 510 leaves
            room for the two boundary tokens ``forward_tts`` adds around each
            segment. NOTE(review): assumes a 512-position text encoder, as in
            Kokoro v0.19; confirm for other checkpoints.

    Returns:
        List of ``(segment_text, segment_phonemes)`` tuples in reading order.
    """
    ps = phonemize_text(text, lang=lang, do_normalize=True)
    if len(tokenize(ps)) <= max_tokens or len(text) < 2:
        # Either it fits, or it is too short to split any further.
        return [(text, ps)]

    pieces = _split_at_boundaries(text)
    if len(pieces) > 1:
        # Cut between pieces, as close to the middle (by characters) as possible.
        left, right = resplit_strings(pieces)
    else:
        # No break point at all: cut by characters. Without this the text
        # would be "split" into (text, "") and recurse on itself forever.
        middle = len(text) // 2
        left, right = text[:middle], text[middle:]

    return recursive_split(left, lang=lang, max_tokens=max_tokens) + recursive_split(
        right, lang=lang, max_tokens=max_tokens
    )
def segment_and_tokenize(long_text, lang="a"):
    """
    Break ``long_text`` into model-sized segments.

    Thin wrapper around ``recursive_split``; no extra cleaning happens here
    (``recursive_split`` normalizes the text when it phonemizes it).

    Args:
        long_text: arbitrary-length input text.
        lang: key into ``phonemizers`` ('a' = en-us, 'b' = en-gb).

    Returns:
        List of ``(segment_text, segment_phonemes)`` tuples.
    """
    return recursive_split(long_text, lang=lang)
# -------------- Normalization & Phonemization Routines -------------- # | |
def parens_to_angles(s):
    """Replace round brackets with guillemets: ``(`` -> ``«`` and ``)`` -> ``»``."""
    return s.replace("(", "«").replace(")", "»")


def split_num(num):
    """
    ``re.sub`` callback that spells out times and four-digit years.

    - Matches containing ``.`` are returned unchanged (``point_num`` handles
      decimals later).
    - ``H:MM`` times become ``"H o'clock"``, ``"H oh M"`` or ``"H MM"``.
    - Years (optionally plural, e.g. ``1990s``) are read in pairs, e.g.
      ``"19 90s"``, ``"19 hundred"``, ``"19 oh 5"``. Years below 1100, and
      years whose last three digits are below 10 (e.g. 2005), stay as digits.
    """
    num = num.group()
    if "." in num:
        return num
    elif ":" in num:
        h, m = [int(n) for n in num.split(":")]
        if m == 0:
            return f"{h} o'clock"
        elif m < 10:
            return f"{h} oh {m}"
        return f"{h} {m}"
    year = int(num[:4])
    if year < 1100 or year % 1000 < 10:
        return num
    left, right = num[:2], int(num[2:4])
    s = "s" if num.endswith("s") else ""
    if 100 <= year % 1000 <= 999:
        if right == 0:
            return f"{left} hundred{s}"
        elif right < 10:
            return f"{left} oh {right}{s}"
    return f"{left} {right}{s}"


def flip_money(m):
    """
    ``re.sub`` callback that spells out a ``$``/``£`` amount.

    Examples: ``$1`` -> ``"1 dollar"``, ``$3 million`` -> ``"3 million
    dollars"``, ``£1.5`` -> ``"1 pound and 50 pence"``.
    """
    m = m.group()
    bill = "dollar" if m[0] == "$" else "pound"
    if m[-1].isalpha():
        return f"{m[1:]} {bill}s"
    elif "." not in m:
        s = "" if m[1:] == "1" else "s"
        return f"{m[1:]} {bill}{s}"
    b, c = m[1:].split(".")
    s = "" if b == "1" else "s"
    # One fractional digit means tenths, e.g. ".5" -> 50.
    c = int(c.ljust(2, "0"))
    coins = (
        f"cent{'' if c == 1 else 's'}"
        if m[0] == "$"
        else ("penny" if c == 1 else "pence")
    )
    return f"{b} {bill}{s} and {c} {coins}"


def point_num(num):
    """``re.sub`` callback: ``"3.14"`` -> ``"3 point 1 4"`` (digits after the point read singly)."""
    a, b = num.group().split(".")
    return " point ".join([a, " ".join(b)])


def normalize_text(text):
    """
    Normalize raw text before phonemization.

    In order, this:

    - unifies curly quotes and guillemets to ASCII quotes, and turns round
      brackets into ``«`` / ``»``;
    - converts full-width CJK punctuation to ASCII followed by a space;
    - collapses whitespace;
    - expands Dr./Mr./Ms./Mrs.;
    - spells out times, years, money amounts and decimals;
    - rewrites digit ranges ("1-2" -> "1 to 2") and a few acronym patterns.

    Returns:
        The normalized text, stripped of surrounding whitespace.
    """
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    text = text.replace("«", chr(8220)).replace("»", chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    text = parens_to_angles(text)
    # Full-width CJK punctuation (、。！，：；？) -> ASCII + space. The source
    # characters are written as escapes on purpose: if they degrade to ASCII,
    # every ASCII '!', ',', ':', ';', '?' gains a trailing space, which breaks
    # the time ("12:30") and thousands-separator ("1,000") rules below.
    for fullwidth, ascii_char in zip(
        "\u3001\u3002\uff01\uff0c\uff1a\uff1b\uff1f", ",.!,:;?"
    ):
        text = text.replace(fullwidth, ascii_char + " ")
    text = re.sub(r"[^\S \n]", " ", text)
    text = re.sub(r" +", " ", text)
    text = re.sub(r"(?<=\n) +(?=\n)", "", text)
    text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
    text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
    text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
    text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
    text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
    text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
    text = re.sub(
        r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)",
        split_num,
        text,
    )
    text = re.sub(r"(?<=\d),(?=\d)", "", text)
    text = re.sub(
        r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
        flip_money,
        text,
    )
    text = re.sub(r"\d*\.\d+", point_num, text)
    text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)  # Could be minus; adjust if needed
    text = re.sub(r"(?<=\d)S", " S", text)
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", "s", text)
    text = re.sub(
        r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
    )
    text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
    return text.strip()
# ------------------------------------------------------------------- | |
# Vocab and Symbol Mapping | |
# ------------------------------------------------------------------- | |
def get_vocab():
    """
    Build the symbol -> token-id table used by the model.

    Ids are assigned in this order:

    - the pad symbol ``$`` (id 0);
    - punctuation, including the space;
    - ASCII letters;
    - IPA letters and diacritics.

    A symbol listed more than once keeps the id of its LAST occurrence (the
    apostrophe appears twice in the IPA set).
    """
    pad = "$"
    punctuation = ';:,.!?¡¿—…"«»“” '
    letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
    # A dict comprehension overwrites earlier ids of repeated symbols,
    # exactly like sequential assignment.
    return {
        symbol: index
        for index, symbol in enumerate(pad + punctuation + letters + letters_ipa)
    }


VOCAB = get_vocab()


def tokenize(ps: str):
    """Map each phoneme character in ``ps`` to its token id; characters not in ``VOCAB`` are dropped."""
    return [VOCAB[symbol] for symbol in ps if symbol in VOCAB]
# ------------------------------------------------------------------- | |
# Initialize a simple phonemizer | |
# For English: | |
# 'a' ~ en-us | |
# 'b' ~ en-gb | |
# ------------------------------------------------------------------- | |
def _english_backend(language):
    """Create an eSpeak backend that keeps punctuation and emits stress marks."""
    return phonemizer_backend.EspeakBackend(
        language=language, preserve_punctuation=True, with_stress=True
    )


# Language key -> phonemizer backend. Both backends are created at import time.
# More languages can be registered here (e.g. 'j' for a Japanese phonemizer).
phonemizers = {
    "a": _english_backend("en-us"),
    "b": _english_backend("en-gb"),
}
def phonemize_text(text, lang="a", do_normalize=True):
    """
    Convert ``text`` into a phoneme string restricted to symbols in ``VOCAB``.

    Args:
        text: input text.
        lang: key into ``phonemizers`` ('a' = en-us, 'b' = en-gb).
        do_normalize: run ``normalize_text`` on the input first.

    Returns:
        The stripped phoneme string, or ``""`` if the backend returned nothing.
    """
    source = normalize_text(text) if do_normalize else text
    results = phonemizers[lang].phonemize([source])
    if not results:
        phonemes = ""
    else:
        phonemes = results[0]

    # Fixed substitutions, applied in this order. The first two presumably
    # re-spell the word "Kokoro" (en-us / en-gb renderings).
    for old, new in (
        ("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ"),
        ("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ"),
        ("ʲ", "j"),
        ("r", "ɹ"),
        ("x", "k"),
        ("ɬ", "l"),
    ):
        phonemes = phonemes.replace(old, new)

    # Insert a space between a preceding letter/ɹ/ː and "hˈʌndɹɪd" ("hundred").
    phonemes = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", phonemes)
    # Attach a detached " z" (before punctuation, space or end) to the previous word.
    phonemes = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", phonemes)
    if lang == "a":
        # en-us only: "ti" right after "nˈaɪn" (e.g. "ninety") becomes "di",
        # unless followed by a length mark.
        phonemes = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", phonemes)

    # Keep only characters the tokenizer knows.
    return "".join(filter(VOCAB.__contains__, phonemes)).strip()
# ------------------------------------------------------------------- | |
# Utility for generating text masks | |
# ------------------------------------------------------------------- | |
def length_to_mask(lengths):
    """
    Build a padding mask from per-sequence lengths.

    Args:
        lengths: 1-D integer tensor of shape ``[B]`` holding each sequence's length.

    Returns:
        Bool tensor of shape ``[B, max(lengths)]``. An entry is True at padded
        positions (index >= length) and False at valid ones.
    """
    positions = torch.arange(lengths.max(), device=lengths.device)
    # Broadcast [1, T] against [B, 1]: (position + 1) > length marks padding.
    return (positions.unsqueeze(0) + 1) > lengths.unsqueeze(1)
# ------------------------------------------------------------------- | |
# The forward pass for inference (from your snippet). | |
# This version references `model.predictor`, `model.decoder`, etc. | |
# ------------------------------------------------------------------- | |
@torch.no_grad()
def forward_tts(model, tokens, ref_s, speed=1.0):
    """
    Run one inference pass and return the generated audio as a NumPy array.

    Runs under ``torch.no_grad()``. Without it, the decoder output carries
    autograd history whenever the model parameters require grad, and the
    final ``.numpy()`` call raises ``RuntimeError``. It also avoids holding
    the activation graph in memory.

    Args:
        model: Munch of submodules (``bert``, ``bert_encoder``, ``predictor``,
            ``decoder``, ``text_encoder``) as returned by ``build_model``.
        tokens: list[int] from ``tokenize``, WITHOUT boundary tokens (a 0 is
            added at each end here).
        ref_s: reference style tensor. ``ref_s[:, :128]`` conditions the
            decoder and ``ref_s[:, 128:]`` the prosody predictor.
            NOTE(review): presumably shape ``[1, 256]``; confirm against
            the voicepack format.
        speed: speed factor. Predicted durations are divided by it, so
            values > 1 speak faster and values < 1 speak slower.

    Returns:
        The squeezed decoder output as a CPU NumPy array (presumably 1-D
        audio samples).
    """
    device = ref_s.device
    tokens_t = torch.LongTensor([[0, *tokens, 0]]).to(device)  # add boundary tokens
    input_lengths = torch.LongTensor([tokens_t.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)

    # 1. Encode with BERT. The mask is True at padding, so invert it for attention.
    bert_dur = model.bert(tokens_t, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

    # 2. Duration prediction, conditioned on the second half of the style vector.
    s = ref_s[:, 128:]
    d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)
    duration = model.predictor.duration_proj(x)
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    pred_dur = torch.round(duration).clamp(min=1).long()

    # 3. Hard alignment matrix [n_tokens, n_frames]. Row i is 1 over the
    #    pred_dur[0, i] consecutive frames assigned to token i. This is built
    #    by repeating each column of an identity matrix, which avoids a
    #    Python loop with a .item() sync per token.
    n_tokens = tokens_t.shape[-1]
    pred_aln_trg = torch.repeat_interleave(
        torch.eye(n_tokens, device=device), pred_dur[0], dim=1
    )

    # 4. Expand predictor features to frame rate, then predict F0 and noise.
    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0)
    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

    # 5. Text encoder features, aligned to frames.
    t_en = model.text_encoder(tokens_t, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg.unsqueeze(0)

    # 6. Decode audio, conditioned on the first half of the style vector.
    audio = model.decoder(asr, F0_pred, N_pred, ref_s[:, :128])
    return audio.squeeze().cpu().numpy()
def generate_tts(model, text, voicepack, lang="a", speed=1.0, max_tokens=510):
    """
    Synthesize a single, reasonably short piece of text.

    For long text use ``generate_long_form_tts``, which splits the input first.

    Args:
        model: the Munch returned by ``build_model(...)``.
        text: input text (normalized internally before phonemization).
        voicepack: reference style embeddings indexed by token count
            (``voicepack[len(tokens)]``). Falls back to ``voicepack[-1]`` when
            that index or key is missing.
        lang: key into ``phonemizers`` ('a' = en-us, 'b' = en-gb).
        speed: speech speed factor passed to ``forward_tts``.
        max_tokens: longer token sequences are truncated to this many tokens,
            with a warning. With the two boundary tokens ``forward_tts`` adds,
            the default keeps the input at 512 positions
            (NOTE(review): presumably the text encoder's limit; confirm).

    Returns:
        ``(audio, phonemes)``. ``audio`` is ``None`` when the text produced
        no tokens. ``phonemes`` is the full phoneme string, before any
        truncation.
    """
    # 1. Phonemize and tokenize.
    ps = phonemize_text(text, lang=lang, do_normalize=True)
    tokens = tokenize(ps)
    if not tokens:
        return None, ps
    if len(tokens) > max_tokens:
        print(
            f"[TTS] Warning: {len(tokens)} tokens exceed the limit of "
            f"{max_tokens}; truncating."
        )
        tokens = tokens[:max_tokens]

    # 2. Select the reference style for this token count.
    try:
        ref_s = voicepack[len(tokens)]
    except (IndexError, KeyError):
        # The index is out of range for this voicepack; use its last entry.
        ref_s = voicepack[-1]
    # Move the style onto the exact device that holds the model weights.
    ref_s = ref_s.to(next(model.bert.parameters()).device)

    # 3. Generate.
    audio = forward_tts(model, tokens, ref_s, speed=speed)
    return audio, ps
def generate_long_form_tts(model, full_text, voicepack, lang="a", speed=1.0):
    """
    Synthesize arbitrarily long text by splitting it into model-sized segments.

    Each segment from ``segment_and_tokenize`` is synthesized with
    ``generate_tts``. Segments that produce no audio are reported and
    skipped, and the remaining waveforms are concatenated end to end.

    Returns:
        ``(audio, phonemes)``: the concatenated NumPy waveform and the list
        of per-segment phoneme strings. Returns ``(None, [])`` if no segment
        produced audio.
    """
    segments = segment_and_tokenize(full_text, lang=lang)

    waveforms = []
    phoneme_strings = []
    for i, (seg_text, _seg_ps) in enumerate(segments, 1):
        print(f"[LongForm] Generating chunk {i}/{len(segments)}: {seg_text[:40]}...")
        audio, used_phonemes = generate_tts(
            model, seg_text, voicepack, lang=lang, speed=speed
        )
        if audio is None:
            print(f"[LongForm] Skipped empty segment {i}...")
            continue
        waveforms.append(audio)
        phoneme_strings.append(used_phonemes)

    if not waveforms:
        return None, []
    return np.concatenate(waveforms, axis=0), phoneme_strings
# ------------------------------------------------------------------- | |
# Main CLI | |
# ------------------------------------------------------------------- | |
def main():
    """
    Command-line entry point.

    Parses arguments, loads the model and voicepack, synthesizes ``--text``
    (splitting long input into segments) and writes the result to
    ``--output`` as a WAV file at ``--sample-rate``.

    Raises:
        FileNotFoundError: if the config or voicepack path is not an
            existing file.
    """
    parser = argparse.ArgumentParser(description="Kokoro-StyleTTS2 CLI Example")
    parser.add_argument(
        "--model",
        type=str,
        default="pretrained_models/Kokoro/kokoro-v0_19.pth",
        help="Path to your model checkpoint (e.g. kokoro-v0_19.pth).",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="pretrained_models/Kokoro/config.json",
        help="Path to config.json (currently only checked for existence).",
    )
    parser.add_argument(
        "--text",
        type=str,
        default="Hello world! This is Kokoro, a new text-to-speech model based on StyleTTS2 from 2024!",
        help="Text to be converted into speech.",
    )
    parser.add_argument(
        "--voicepack",
        type=str,
        default="pretrained_models/Kokoro/voices/af.pt",
        help="Path to a .pt file for your reference embedding(s).",
    )
    parser.add_argument(
        "--output", type=str, default="output.wav", help="Output WAV filename."
    )
    parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Speech speed factor, e.g. 0.8 slower, 1.2 faster, etc.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to run inference on.",
    )
    parser.add_argument(
        "--lang",
        type=str,
        default="a",
        choices=["a", "b"],
        help="Phonemizer language: 'a' = American English (en-us), 'b' = British English (en-gb).",
    )
    parser.add_argument(
        "--sample-rate",
        type=int,
        default=24000,
        help="Sample rate of the written WAV file (Kokoro v0.19 generates 24 kHz audio).",
    )
    args = parser.parse_args()

    # 1. Validate input files before the slow model load.
    if not os.path.isfile(args.config):
        raise FileNotFoundError(f"config.json not found: {args.config}")
    if not os.path.isfile(args.voicepack):
        raise FileNotFoundError(f"Voicepack file not found: {args.voicepack}")

    # 2. Pick the device, falling back to CPU (visibly) if CUDA is unavailable.
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA requested but not available; falling back to CPU.")
        device = "cpu"
    else:
        device = args.device

    # 3. Build the model. build_model receives only the checkpoint path and
    #    device; --config is not passed to it.
    print(f"Loading model from: {args.model}")
    model = build_model(args.model, device)
    # build_model returns a Munch (dict) of submodules, so each submodule is
    # switched to eval mode individually.
    for submodule in model.values():
        if isinstance(submodule, torch.nn.Module):
            submodule.eval()

    # 4. Load the voicepack.
    print(f"Loading voicepack from: {args.voicepack}")
    # NOTE(security): depending on the torch version's `weights_only`
    # default, torch.load may unpickle arbitrary objects. Only load
    # voicepacks from trusted sources. weights_only=True is not forced here
    # because nn.Module voicepacks are accepted below.
    vp = torch.load(args.voicepack, map_location=device)
    if isinstance(vp, torch.nn.Module):
        vp.eval()

    # 5. Generate audio.
    print(f"Generating speech for text: {args.text}")
    audio, phonemes = generate_long_form_tts(
        model, args.text, vp, lang=args.lang, speed=args.speed
    )
    if audio is None:
        print("No tokens were generated (maybe empty text?). Exiting.")
        return

    # 6. Write the WAV. The rate must match the model's output rate, or
    #    playback will be sped up or slowed down (and pitch-shifted).
    print(f"Writing output to: {args.output}")
    sf.write(args.output, audio, args.sample_rate)
    print("Finished!")
    print(f"Phonemes used: {phonemes}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()