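"""Gradio demo for Kokoro TTS.

Lets the user pick a phonemizer engine (eSpeak or OpenPhonemizer), a model
checkpoint (.pth) and a voicepack (.pt), then synthesizes speech from the
input text. (Descriptive header added for clarity; it only restates what the
code below does.)
"""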
import gradio as gr
import os
import torch
# Import eSpeak TTS pipeline
from tts_cli import (
    build_model as build_model_espeak,
    generate_long_form_tts as generate_long_form_tts_espeak,
)
# Import OpenPhonemizer TTS pipeline
from tts_cli_op import (
    build_model as build_model_open,
    generate_long_form_tts as generate_long_form_tts_open,
)
from pretrained_models import Kokoro
# ---------------------------------------------------------------------
# Path to models and voicepacks
# ---------------------------------------------------------------------
MODELS_DIR = "pretrained_models/Kokoro"
VOICES_DIR = "pretrained_models/Kokoro/voices"
# ---------------------------------------------------------------------
# List the models (.pth) and voices (.pt)
# ---------------------------------------------------------------------
def get_models():
    return sorted([f for f in os.listdir(MODELS_DIR) if f.endswith(".pth")])


def get_voices():
    return sorted([f for f in os.listdir(VOICES_DIR) if f.endswith(".pt")])

# ---------------------------------------------------------------------
# We'll map engine selection -> (build_model_func, generate_func)
# ---------------------------------------------------------------------
ENGINES = {
    "espeak": (build_model_espeak, generate_long_form_tts_espeak),
    "openphonemizer": (build_model_open, generate_long_form_tts_open),
}
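
# Both pipelines are expected to expose the same interface (inferred from the
# calls in tts_inference below):
#   build_model(path, device)  -> dict of sub-modules
#   generate_long_form_tts(model, text, voicepack, speed) -> (audio, phonemes)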
# ---------------------------------------------------------------------
# The main inference function called by Gradio
# ---------------------------------------------------------------------
def tts_inference(text, engine, model_file, voice_file, speed=1.0):
    """
    text: Input string
    engine: "espeak" or "openphonemizer"
    model_file: Selected .pth from the models folder
    voice_file: Selected .pt from the voices folder
    speed: Speech speed
    """
    # 1) Map engine to the correct build_model + generate_long_form_tts
    build_fn, gen_fn = ENGINES[engine]

    # 2) Prepare paths
    model_path = os.path.join(MODELS_DIR, model_file)
    voice_path = os.path.join(VOICES_DIR, voice_file)

    # 3) Decide device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 4) Load model
    model = build_fn(model_path, device=device)
    # Put all sub-modules into eval mode
    for k, subm in model.items():
        if hasattr(subm, "eval"):
            subm.eval()

    # 5) Load voicepack
    voicepack = torch.load(voice_path, map_location=device)
    if hasattr(voicepack, "eval"):
        voicepack.eval()

    # 6) Generate TTS
    audio, phonemes = gen_fn(model, text, voicepack, speed=speed)
    sr = 22050  # or your actual sample rate
    return (sr, audio)  # Gradio expects (sample_rate, np_array)

# ---------------------------------------------------------------------
# Build Gradio App
# ---------------------------------------------------------------------
def create_gradio_app():
    model_list = get_models()
    voice_list = get_voices()

    css = """
    h4 {
        text-align: center;
        display: block;
    }
    h2 {
        text-align: center;
        display: block;
    }
    """

    with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
        gr.Markdown("## Kokoro TTS Demo: Choose engine, model, and voice")

        # Row 1: Text input
        text_input = gr.Textbox(
            label="Input Text",
            value="Hello, world! Testing both eSpeak and OpenPhonemizer. Can you believe that we live in 2024 and have access to advanced AI?",
            lines=3,
        )

        # Row 2: Engine selection
        engine_dropdown = gr.Dropdown(
            choices=["espeak", "openphonemizer"],
            value="openphonemizer",
            label="Phonemizer",
        )

        # Row 3: Model dropdown
        model_dropdown = gr.Dropdown(
            choices=model_list,
            value=model_list[0] if model_list else None,
            label="Model (.pth)",
        )

        # Row 4: Voice dropdown
        voice_dropdown = gr.Dropdown(
            choices=voice_list,
            value=voice_list[0] if voice_list else None,
            label="Voice (.pt)",
        )

        # Row 5: Speed slider
        speed_slider = gr.Slider(
            minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speech Speed"
        )

        # Generate button + audio output
        generate_btn = gr.Button("Generate")
        tts_output = gr.Audio(label="TTS Output")

        # Connect the button to our inference function
        generate_btn.click(
            fn=tts_inference,
            inputs=[
                text_input,
                engine_dropdown,
                model_dropdown,
                voice_dropdown,
                speed_slider,
            ],
            outputs=tts_output,
        )

        gr.Markdown(
            "#### Kokoro TTS Demo based on [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M)"
        )

    return demo

# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------
if __name__ == "__main__":
    app = create_gradio_app()
    app.launch()
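    # Starts a local Gradio server; pass share=True to app.launch() if a
    # public link is needed.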