File size: 17,786 Bytes
2eaa44a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
#!/usr/bin/env python3
# tts_cli.py
"""
Example CLI for generating audio with Kokoro-StyleTTS2.

Usage:
  python tts_cli.py \
    --model /path/to/kokoro-v0_19.pth \
    --config /path/to/config.json \
    --text "Hello, my stinking friends from 1906! You stink." \
    --voicepack /path/to/af.pt \
    --output output.wav

Make sure:
1. `models.py` is in the same folder (with `build_model`, `Decoder`, etc.).
2. You have installed the needed libraries:
   pip install torch phonemizer munch soundfile pyyaml
   and the eSpeak (or eSpeak NG) system package, which phonemizer's
   EspeakBackend requires.
3. The model is a checkpoint that your `build_model` can load.

Adapt as needed!
"""

import argparse
import os
import re
import torch
import soundfile as sf
import numpy as np
from phonemizer import backend as phonemizer_backend

# Optional eSpeak setup: if a local `espeak_util` module is importable, call
# its set_espeak_library() before the phonemizer backends further down are
# created (presumably it points phonemizer at the eSpeak shared library --
# TODO confirm in espeak_util). When the module is absent the ImportError is
# ignored and phonemizer's own defaults apply.
# NOTE(review): this also hides ImportErrors raised *inside* espeak_util.
try:
    from espeak_util import set_espeak_library

    set_espeak_library()
except ImportError:
    pass

# --------------------------------------------------------------------
# Import from your local `models.py` (requires that file to be present).
#   This example assumes `build_model` loads the entire TTS submodules
#   (bert, bert_encoder, predictor, decoder, text_encoder).
# --------------------------------------------------------------------
from models import build_model


def resplit_strings(arr):
    """
    Join a list of text pieces back into two strings of near-equal length.

    The pieces are never broken apart. The cut goes between two pieces, and
    each half is re-joined with single spaces. Among equally balanced cut
    points, the earliest one wins.

    Args:
        arr: List of strings (words, sentences, clauses, ...).

    Returns:
        A ``(left, right)`` tuple of strings. An empty list gives
        ``("", "")`` and a one-element list gives ``(arr[0], "")``.
    """
    if not arr:
        return "", ""
    if len(arr) == 1:
        return arr[0], ""

    total = len(" ".join(arr))
    best_cut, best_gap = 0, float("inf")
    # head_len tracks len(" ".join(arr[:cut])); start at -1 so the first
    # "+ 1 for the separator" cancels out.
    head_len = -1
    for cut, piece in enumerate(arr[:-1], start=1):
        head_len += len(piece) + 1
        tail_len = total - head_len - 1  # minus the space at the cut
        gap = abs(head_len - tail_len)
        if gap < best_gap:
            best_cut, best_gap = cut, gap

    return " ".join(arr[:best_cut]), " ".join(arr[best_cut:])


def recursive_split(text, lang="a", max_tokens=510):
    """
    Split text into segments whose phoneme token count fits the model.

    Each candidate segment is phonemized (with normalization) and tokenized.
    If it has at most ``max_tokens`` tokens it is returned as-is. Otherwise
    the text is cut in half, preferring these boundaries in order:

    1. sentence punctuation (. ? ! ...);
    2. clause punctuation (: , ; em-dash);
    3. spaces;
    4. when no boundary exists at all, the middle character.

    Both halves are then processed recursively. The last fallback guarantees
    the recursion terminates, even for a single very long token.

    Args:
        text: Raw input text.
        lang: Key into ``phonemizers`` (e.g. "a" for en-us, "b" for en-gb).
        max_tokens: Maximum tokens per segment. The default of 510 leaves room
            for the two boundary tokens that ``forward_tts`` adds
            (510 + 2 = 512). This assumes a 512-position model context; confirm
            against your checkpoint's config if you use a different model.

    Returns:
        List of ``(segment_text, segment_phonemes)`` tuples in input order.
    """
    ps = phonemize_text(text, lang=lang, do_normalize=True)
    if len(tokenize(ps)) <= max_tokens:
        return [(text, ps)]

    for punctuation in [r"[.?!…]", r"[:,;—]"]:
        # Split on the space that follows the punctuation (optionally after a
        # closing quote), so the punctuation stays with the preceding piece.
        pattern = f"(?:(?<={punctuation})|(?<={punctuation}[\"'»])) "
        splits = re.split(pattern, text)
        if len(splits) > 1:
            break
    else:
        splits = text.split(" ")

    if len(splits) < 2:
        # No punctuation or whitespace boundary: resplit_strings would return
        # (text, "") and we would recurse on the same text forever. Cut by
        # characters instead; a 0/1-character text cannot be split further.
        if len(text) < 2:
            return [(text, ps)]
        mid = len(text) // 2
        left, right = text[:mid], text[mid:]
    else:
        left, right = resplit_strings(splits)

    return recursive_split(left, lang=lang, max_tokens=max_tokens) + recursive_split(
        right, lang=lang, max_tokens=max_tokens
    )


def segment_and_tokenize(long_text, lang="a"):
    """
    Break a long text into model-sized pieces.

    Args:
        long_text: The full input text.
        lang: Phonemizer language key forwarded to ``recursive_split``.

    Returns:
        List of ``(segment_text, segment_phonemes)`` tuples.
    """
    # No extra cleaning is done here: recursive_split already runs
    # phonemize_text(..., do_normalize=True) on every candidate segment.
    return recursive_split(long_text, lang=lang)


# -------------- Normalization & Phonemization Routines -------------- #
def parens_to_angles(s):
    """Return *s* with "(" replaced by "«" and ")" replaced by "»"."""
    return s.translate(str.maketrans("()", "«»"))


def split_num(num):
    """
    ``re.sub`` callback that verbalizes years and clock times.

    Args:
        num: Match object whose text is a decimal, an H:MM time, or a
            4-digit year (optionally suffixed with "s").

    Returns:
        * Decimals (anything containing "."): returned unchanged.
        * Times: "H o'clock", "H oh M" for single-digit minutes, otherwise "H M".
        * Years before 1100, and years whose last three digits are below 10
          (e.g. 2005): returned unchanged.
        * Other years: split into two-digit halves, e.g. "19 hundred",
          "19 oh 6", "20 24". A trailing "s" is kept.
    """
    token = num.group()
    if "." in token:
        return token

    if ":" in token:
        hours, minutes = map(int, token.split(":"))
        if minutes == 0:
            return f"{hours} o'clock"
        return f"{hours} oh {minutes}" if minutes < 10 else f"{hours} {minutes}"

    year = int(token[:4])
    if year < 1100 or year % 1000 < 10:
        return token

    century, rest = token[:2], int(token[2:4])
    plural = "s" if token.endswith("s") else ""
    if year % 1000 >= 100 and rest < 10:
        if rest == 0:
            return f"{century} hundred{plural}"
        return f"{century} oh {rest}{plural}"
    return f"{century} {rest}{plural}"


def flip_money(m):
    """
    ``re.sub`` callback that rewrites "$"/"£" amounts as words.

    Args:
        m: Match object whose text starts with "$" or "£".

    Returns:
        * Amounts ending in a letter (e.g. "$5 million"): "5 million dollars".
        * Whole amounts: "1 dollar", "5 pounds".
        * Amounts with a fractional part: "3 dollars and 50 cents" or
          "2 pounds and 1 penny". The fraction is right-padded to two
          digits, so ".5" means 50.
    """
    amount = m.group()
    is_dollar = amount[0] == "$"
    value = amount[1:]
    bill = "dollar" if is_dollar else "pound"

    if amount[-1].isalpha():
        return f"{value} {bill}s"
    if "." not in amount:
        return f"{value} {bill}{'' if value == '1' else 's'}"

    whole, frac = value.split(".")
    cents = int(frac.ljust(2, "0"))
    if is_dollar:
        coins = "cent" if cents == 1 else "cents"
    else:
        coins = "penny" if cents == 1 else "pence"
    return f"{whole} {bill}{'' if whole == '1' else 's'} and {cents} {coins}"


def point_num(num):
    """
    ``re.sub`` callback that reads a decimal digit by digit.

    For example, "3.14" becomes "3 point 1 4".
    """
    integer_part, fraction = num.group().split(".")
    spelled_digits = " ".join(fraction)
    return f"{integer_part} point {spelled_digits}"


def normalize_text(text):
    """
    Normalize raw English text so the eSpeak phonemizer reads it well.

    Steps, in order:
      * Unify curly single quotes to ASCII apostrophes. Turn curly and angle
        double quotes into ASCII double quotes. Turn parentheses into
        guillemets via ``parens_to_angles``.
      * Map CJK full-width punctuation to ASCII, each followed by a space.
      * Convert tabs and other non-newline whitespace to spaces and collapse
        repeated spaces.
      * Expand a few titles and abbreviations (Dr., Mr., Ms., Mrs., etc.).
      * Verbalize years, clock times, money and decimals via ``split_num``,
        ``flip_money`` and ``point_num``.
      * Drop thousands separators and read "N-M" between digits as "N to M".
      * Adjust plural/possessive suffixes after capitals, and turn dotted
        initialisms (e.g. "U.S.A") into hyphenated letters.

    Returns:
        The normalized text with surrounding whitespace stripped.
    """
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    text = text.replace("«", chr(8220)).replace("»", chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    text = parens_to_angles(text)

    # Replace common full-width CJK punctuation (、。！，：；？) with ASCII
    # followed by a space. The source characters are written as \u escapes on
    # purpose: NFKC normalisation (editors, formatters, copy/paste) silently
    # folds them to ASCII. That makes this loop append a space after every
    # ASCII , : ; ! ?, which breaks "1,000" and "10:30" below.
    for full_width, ascii_char in zip(
        "\u3001\u3002\uff01\uff0c\uff1a\uff1b\uff1f", ",.!,:;?"
    ):
        text = text.replace(full_width, ascii_char + " ")

    text = re.sub(r"[^\S \n]", " ", text)
    text = re.sub(r"  +", " ", text)
    text = re.sub(r"(?<=\n) +(?=\n)", "", text)
    text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
    text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
    text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
    text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
    text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
    text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
    text = re.sub(
        r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)",
        split_num,
        text,
    )
    # Thousands separators: "1,000" -> "1000".
    text = re.sub(r"(?<=\d),(?=\d)", "", text)
    text = re.sub(
        r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
        flip_money,
        text,
    )
    text = re.sub(r"\d*\.\d+", point_num, text)
    # A hyphen between digits is read as a range ("3-5" -> "3 to 5"); minus
    # signs / subtraction are not distinguished.
    text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
    text = re.sub(r"(?<=\d)S", " S", text)
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", "s", text)
    text = re.sub(
        r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
    )
    text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
    return text.strip()


# -------------------------------------------------------------------
# Vocab and Symbol Mapping
# -------------------------------------------------------------------
def get_vocab():
    """
    Build the phoneme-symbol -> integer-id table.

    Id 0 is the pad symbol "$". It is followed by punctuation (including the
    space), the ASCII letters, and then the IPA symbols, numbered in that
    order. If a character appears more than once, its last position wins.
    """
    pad = "$"
    punctuation = ';:,.!?¡¿—…"«»“” '
    letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
    all_symbols = pad + punctuation + letters + letters_ipa
    return {symbol: index for index, symbol in enumerate(all_symbols)}


# Symbol -> integer id table, shared by tokenize() and phonemize_text().
VOCAB = get_vocab()


def tokenize(ps: str):
    """Map each character of the phoneme string *ps* to its VOCAB id; unknown characters are dropped."""
    return [VOCAB[symbol] for symbol in ps if symbol in VOCAB]


# -------------------------------------------------------------------
# eSpeak phonemizers keyed by the single-letter language codes used
# throughout this script:
#   'a' -> American English (en-us)
#   'b' -> British English (en-gb)
# preserve_punctuation / with_stress keep punctuation and stress marks in
# the output; both have entries in VOCAB.
# -------------------------------------------------------------------
phonemizers = {
    "a": phonemizer_backend.EspeakBackend(
        language="en-us", preserve_punctuation=True, with_stress=True
    ),
    "b": phonemizer_backend.EspeakBackend(
        language="en-gb", preserve_punctuation=True, with_stress=True
    ),
    # More backends can be registered here, e.g. a Japanese one under 'j'.
}


def phonemize_text(text, lang="a", do_normalize=True):
    """
    Convert text into the phoneme string fed to ``tokenize``.

    Args:
        text: Input text.
        lang: Key into ``phonemizers`` ("a" = en-us, "b" = en-gb).
        do_normalize: Run ``normalize_text`` first.

    Returns:
        The eSpeak phonemes after the fix-ups below, with characters not
        present in VOCAB removed and surrounding whitespace stripped. Returns
        an empty string if the phonemizer returns nothing.
    """
    if do_normalize:
        text = normalize_text(text)
    ps_list = phonemizers[lang].phonemize([text])
    ps = ps_list[0] if ps_list else ""

    # Re-stress the word "Kokoro" (presumably eSpeak's en-us / en-gb
    # renderings) onto its first syllable.
    ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
    # Symbol substitutions: ʲ -> j, r -> ɹ, x -> k, ɬ -> l. Must run before the
    # next regex, whose lookbehind relies on the converted ɹ.
    ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
    # Insert a space before "hundred" when it is glued to the previous word,
    # e.g. "nˈaɪnhˈʌndɹɪd" => "nˈaɪn hˈʌndɹɪd".
    ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
    # Attach a standalone " z" (followed by punctuation, space or end) to the
    # preceding word, e.g. a separately phonemized plural/possessive "s".
    ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
    # American English only: "nˈaɪnti" (ninety) => "nˈaɪndi", i.e. a flapped t,
    # unless the "ti" is followed by a length mark.
    if lang == "a":
        ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)

    # Drop every character that has no VOCAB id.
    ps = "".join(p for p in ps if p in VOCAB)
    return ps.strip()


# -------------------------------------------------------------------
# Utility for generating text masks
# -------------------------------------------------------------------
def length_to_mask(lengths):
    """
    Build a padding mask from per-sequence lengths.

    Args:
        lengths: 1-D integer tensor ``[B]`` of valid lengths.

    Returns:
        Bool tensor ``[B, max(lengths)]`` on ``lengths.device``. An entry is
        ``True`` at padded positions (index >= length) and ``False`` at valid
        ones.
    """
    longest = int(lengths.max())
    # 1-based positions: position p is padding exactly when p > length.
    positions = torch.arange(1, longest + 1, device=lengths.device)
    return positions.unsqueeze(0) > lengths.unsqueeze(1)


# -------------------------------------------------------------------
# The forward pass for inference (from your snippet).
# This version references `model.predictor`, `model.decoder`, etc.
# -------------------------------------------------------------------
@torch.no_grad()
def forward_tts(model, tokens, ref_s, speed=1.0):
    """
    Run one inference pass that turns token ids plus a style vector into audio.

    Pipeline: BERT -> duration predictor -> hard alignment -> F0/noise
    predictor -> text encoder -> decoder. Gradients are disabled by the
    ``torch.no_grad`` decorator.

    Args:
        model: Munch with submodules ``bert``, ``bert_encoder``, ``predictor``,
            ``decoder`` and ``text_encoder``.
        tokens: list[int] of token ids without boundary tokens. A 0 is added at
            both ends here.
        ref_s: 2-D reference style tensor. Columns ``:128`` go to the decoder
            and columns ``128:`` go to the prosody predictor. Presumably shaped
            ``[1, 256]``; TODO confirm against the voicepack. Its device decides
            where all inputs are placed.
        speed: Predicted durations are divided by this value. Values > 1
            shorten the output and values < 1 lengthen it.

    Returns:
        NumPy array from ``decoder(...).squeeze()`` moved to CPU (presumably a
        1-D waveform; confirm against the decoder).
    """
    device = ref_s.device
    tokens_t = torch.LongTensor([[0, *tokens, 0]]).to(device)  # add boundary tokens
    input_lengths = torch.LongTensor([tokens_t.shape[-1]]).to(device)
    # True marks padding; with a single full-length sequence nothing is padded.
    text_mask = length_to_mask(input_lengths).to(device)

    # 1. Encode with BERT (attention_mask is 1 for real tokens, 0 for padding)
    bert_dur = model.bert(tokens_t, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

    # 2. Prosody predictor
    s = ref_s[
        :, 128:
    ]  # columns 128: of the style feed the predictor; columns :128 feed the decoder (step 6)

    d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)
    duration = model.predictor.duration_proj(x)
    # Per-token duration in frames = sum of sigmoid outputs over the last
    # axis, scaled by 1/speed.
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    # Round to whole frames; every token gets at least one frame.
    pred_dur = torch.round(duration).clamp(min=1).long()

    # 3. Expand alignment: a 0/1 matrix [num_tokens, total_frames] where row i
    #    is 1 over the pred_dur[0, i] consecutive frames belonging to token i.
    total_len = pred_dur.sum().item()
    pred_aln_trg = torch.zeros(input_lengths, total_len, device=device)
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        n = pred_dur[0, i].item()
        pred_aln_trg[i, c_frame : c_frame + n] = 1
        c_frame += n

    # 4. Run F0 + Noise predictor on token features stretched to frame rate
    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0)
    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

    # 5. Text encoder -> asr features, stretched to frame rate the same way
    t_en = model.text_encoder(tokens_t, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg.unsqueeze(0)

    # 6. Decode audio using the first 128 style columns
    audio = model.decoder(asr, F0_pred, N_pred, ref_s[:, :128])  # B x audio_len
    return audio.squeeze().cpu().numpy()


def generate_tts(model, text, voicepack, lang="a", speed=1.0):
    """
    Synthesize a single utterance, without segmentation.

    Args:
        model: The Munch returned by ``build_model(...)``. The device of the
            first ``model.bert`` parameter decides where the style vector is
            moved.
        text: Input text, normalized and phonemized with ``lang``.
        voicepack: Reference style embeddings indexed by token count
            (``voicepack[len(tokens)]``), e.g. a tensor or dict.
            ``voicepack[-1]`` is used when that index does not exist.
        lang: Key into ``phonemizers`` ("a" = en-us, "b" = en-gb).
        speed: Speech speed factor (> 1 faster, < 1 slower).

    Returns:
        ``(audio, phonemes)``. ``audio`` is the NumPy array produced by
        ``forward_tts``, or ``None`` when the text yields no tokens.
        ``phonemes`` is the phoneme string that was used.
    """
    # 1. Phonemize
    ps = phonemize_text(text, lang=lang, do_normalize=True)
    tokens = tokenize(ps)
    if not tokens:
        return None, ps

    # 2. Pick the reference style for this token count, falling back to the
    #    last entry when the count is out of range. Only lookup failures are
    #    caught, so KeyboardInterrupt and real bugs still propagate.
    try:
        ref_s = voicepack[len(tokens)]
    except (IndexError, KeyError):
        ref_s = voicepack[-1]
    # Follow the model's actual device (handles e.g. "cuda:1", not only "cuda").
    ref_s = ref_s.to(next(model.bert.parameters()).device)

    # 3. Generate
    audio = forward_tts(model, tokens, ref_s, speed=speed)
    return audio, ps


def generate_long_form_tts(model, full_text, voicepack, lang="a", speed=1.0):
    """
    Synthesize arbitrarily long text by splitting it into segments.

    Each segment from ``segment_and_tokenize`` is synthesized with
    ``generate_tts``. Segments that yield no tokens are skipped, and the
    resulting audio arrays are concatenated back to back along axis 0.

    Returns:
        ``(final_audio, phonemes_per_segment)``, or ``(None, [])`` when no
        segment produced audio. The dtype of ``final_audio`` is whatever the
        decoder outputs.
    """
    segments = segment_and_tokenize(full_text, lang=lang)
    total = len(segments)

    audio_chunks, all_phonemes = [], []
    for index, (seg_text, _seg_ps) in enumerate(segments, start=1):
        print(f"[LongForm] Generating chunk {index}/{total}: {seg_text[:40]}...")
        audio, used_phonemes = generate_tts(
            model, seg_text, voicepack, lang=lang, speed=speed
        )
        if audio is None:
            print(f"[LongForm] Skipped empty segment {index}...")
            continue
        audio_chunks.append(audio)
        all_phonemes.append(used_phonemes)

    if not audio_chunks:
        return None, []
    return np.concatenate(audio_chunks, axis=0), all_phonemes


# -------------------------------------------------------------------
# Main CLI
# -------------------------------------------------------------------
def _build_arg_parser():
    """Return the argparse parser that describes this CLI's options."""
    parser = argparse.ArgumentParser(description="Kokoro-StyleTTS2 CLI Example")
    parser.add_argument(
        "--model",
        type=str,
        default="pretrained_models/Kokoro/kokoro-v0_19.pth",
        help="Path to your model checkpoint (e.g. kokoro-v0_19.pth).",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="pretrained_models/Kokoro/config.json",
        help="Path to config.json (used by build_model).",
    )
    parser.add_argument(
        "--text",
        type=str,
        default="Hello world! This is Kokoro, a new text-to-speech model based on StyleTTS2 from 2024!",
        help="Text to be converted into speech.",
    )
    parser.add_argument(
        "--voicepack",
        type=str,
        default="pretrained_models/Kokoro/voices/af.pt",
        help="Path to a .pt file for your reference embedding(s).",
    )
    parser.add_argument(
        "--output", type=str, default="output.wav", help="Output WAV filename."
    )
    parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Speech speed factor, e.g. 0.8 slower, 1.2 faster, etc.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to run inference on.",
    )
    parser.add_argument(
        "--lang",
        type=str,
        default="a",
        choices=sorted(phonemizers),
        help="Phonemizer language key ('a' = en-us, 'b' = en-gb).",
    )
    parser.add_argument(
        "--sample-rate",
        type=int,
        default=24000,
        help=(
            "Sample rate written to the WAV header. Kokoro v0.19 generates "
            "24 kHz audio; a mismatched rate changes playback speed and pitch."
        ),
    )
    return parser


def main():
    """
    CLI entry point.

    Parses the arguments, loads the model and voicepack, synthesizes
    ``--text`` (segmented as needed) and writes the result to a WAV file.

    Raises:
        FileNotFoundError: if ``--config`` or ``--voicepack`` does not exist.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    # forward_tts divides durations by speed, so it must be strictly positive.
    if args.speed <= 0:
        parser.error("--speed must be a positive number")
    if args.sample_rate <= 0:
        parser.error("--sample-rate must be a positive integer")

    # 1. Build the model. Only the existence of --config is checked here;
    #    build_model() below receives just the checkpoint path and device.
    if not os.path.isfile(args.config):
        raise FileNotFoundError(f"config.json not found: {args.config}")

    use_cuda = args.device == "cuda" and torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    if args.device == "cuda" and not use_cuda:
        print("CUDA requested but not available; falling back to CPU.")
    print(f"Loading model from: {args.model}")
    model = build_model(args.model, device)

    # build_model returns a Munch of submodules rather than one nn.Module,
    # so switch each submodule to eval mode individually.
    for subm in model.values():
        if isinstance(subm, torch.nn.Module):
            subm.eval()

    # 2. Load voicepack
    if not os.path.isfile(args.voicepack):
        raise FileNotFoundError(f"Voicepack file not found: {args.voicepack}")
    print(f"Loading voicepack from: {args.voicepack}")
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code; only load voicepacks from trusted sources. weights_only=True would
    # be safer for plain tensors, but it rejects nn.Module voicepacks (handled
    # below) and needs a recent torch.
    vp = torch.load(args.voicepack, map_location=device)
    if isinstance(vp, torch.nn.Module):
        vp.eval()

    # 3. Generate audio
    print(f"Generating speech for text: {args.text}")
    audio, phonemes = generate_long_form_tts(
        model, args.text, vp, lang=args.lang, speed=args.speed
    )
    if audio is None:
        print("No tokens were generated (maybe empty text?). Exiting.")
        return

    # 4. Write WAV at the model's output rate
    print(f"Writing output to: {args.output}")
    sf.write(args.output, audio, args.sample_rate)

    print("Finished!")
    print(f"Phonemes used: {phonemes}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()