# vits2 / app.py
# Author: wetdog — commit 7602717 ("add gradio app"), 4.26 kB
## VCTK
import torch
import os
import commons
import utils
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence
from scipy.io.wavfile import write
import gradio as gr
print("Running GRadio", gr.__version__)

# Checkpoint and inference config for the Catalan multispeaker VITS2 model.
model_path = "vits2_pytorch/G_390000.pth"
config_path = "vits2_pytorch/vits2_vctk_cat_inference.json"
hps = utils.get_hparams_from_file(config_path)

# The posterior encoder's input width depends on which encoder the checkpoint
# was trained with: a mel posterior encoder (VITS2) uses a fixed 80 channels,
# otherwise (VITS1) it uses filter_length // 2 + 1 linear-spectrogram bins.
# The chosen flag is mirrored into hps.data so data-side code sees the same
# setting as the model.
if (
    "use_mel_posterior_encoder" in hps.model.keys()
    and hps.model.use_mel_posterior_encoder
):
    print("Using mel posterior encoder for VITS2")
    # NOTE(review): hard-coded to 80; presumably must match the mel channel
    # count in hps.data — TODO confirm against the training config.
    posterior_channels = 80  # vits2
    hps.data.use_mel_posterior_encoder = True
else:
    print("Using lin posterior encoder for VITS1")
    posterior_channels = hps.data.filter_length // 2 + 1
    hps.data.use_mel_posterior_encoder = False

# Build the generator with the architecture described by hps.model, switch it
# to inference mode, then load the trained weights (no optimizer needed).
net_g = SynthesizerTrn(
    len(symbols),
    posterior_channels,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model,
)
net_g.eval()
_ = utils.load_checkpoint(model_path, net_g, None)
def get_text(text, hps):
    """Turn raw input text into a 1-D ``torch.LongTensor`` of symbol ids.

    The text is cleaned and encoded with ``text_to_sequence`` using the
    cleaners named in ``hps.data.text_cleaners``. When ``hps.data.add_blank``
    is truthy, the id sequence is additionally passed through
    ``commons.intersperse(..., 0)`` (presumably inserting blank id 0 between
    symbols, mirroring training-time preprocessing — TODO confirm).
    """
    symbol_ids = text_to_sequence(text, hps.data.text_cleaners)
    # NOTE: the original author left `cleaned_text_to_sequence(text)` as a
    # commented-out alternative, presumably for input that is already cleaned.
    if not hps.data.add_blank:
        return torch.LongTensor(symbol_ids)
    return torch.LongTensor(commons.intersperse(symbol_ids, 0))
def tts(
    text: str,
    speaker_id: int,
    speed: float,
    noise_scale: float = 0.667,
    noise_scale_w: float = 0.8,
):
    """Synthesize ``text`` with the loaded VITS2 model for one speaker.

    Args:
        text: Input sentence; encoded to symbol ids by ``get_text``.
        speaker_id: Speaker index, passed to the model as ``sid``.
        speed: Speaking-rate control, forwarded to the model as
            ``length_scale = 1 / speed``; must be positive.
        noise_scale: Forwarded unchanged to ``net_g.infer(noise_scale=...)``.
        noise_scale_w: Forwarded unchanged to ``net_g.infer(noise_scale_w=...)``.

    Returns:
        A ``(sampling_rate, waveform)`` tuple, where ``waveform`` is a float32
        NumPy array. This is the format ``gr.Audio`` accepts as an output
        value. The rate comes from the model config, not a hard-coded value.

    Raises:
        gr.Error: if ``text`` is empty/whitespace-only or ``speed`` is not
            positive (shown to the user by Gradio instead of a stack trace).
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to synthesize.")
    if speed <= 0:
        # length_scale is 1 / speed: zero would divide by zero, and a
        # negative value has no meaning as a length scale.
        raise gr.Error("Speed must be a positive number.")

    stn_tst = get_text(text, hps)
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0)  # add a batch dimension of size 1
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
        sid = torch.LongTensor([int(speaker_id)])
        # infer() returns a sequence whose first element holds the audio;
        # [0, 0] takes batch 0, channel 0 — presumably (batch, 1, samples),
        # TODO confirm against SynthesizerTrn.infer.
        audio = net_g.infer(
            x_tst,
            x_tst_lengths,
            sid=sid,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=1 / speed,
        )[0][0, 0]
        waveform = audio.cpu().float().numpy()

    # Read the rate from the model config so playback speed and pitch match
    # training; fall back to the previously hard-coded 22050 if it is absent.
    sampling_rate = getattr(hps.data, "sampling_rate", 22050)
    # gr.Audio takes (rate, samples) directly. gr.make_waveform renders a
    # video file (via ffmpeg), which an audio output does not need.
    return sampling_rate, waveform
## GUI space
# Static text rendered around the demo: an HTML page header, a Markdown
# description of the model, and an attribution line.
title = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<div
style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
> <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
VITS2 TTS Catalan Demo
</h1> </div>
</div>
"""
description = """
VITS2 is an end-to-end speech synthesis model that predicts a speech waveform conditional on an input text sequence. VITS2 improved the
training and inference efficiency and naturalness by introducing adversarial learning into the duration predictor. The transformer
block was added to the normalizing flows to capture the long-term dependency when transforming the distribution.
The synthesis quality was improved by incorporating Gaussian noise into the alignment search.
This model is being trained on the openslr69 and festcat datasets.
"""
article = "Model by Jungil Kong, et al. from SK telecom. Demo by BSC."
# Input widgets, listed in the same order as tts()'s positional parameters:
# text, speaker_id, speed, noise_scale, noise_scale_w.
text_input = gr.Textbox(
    value="m'ha costat desenvolupar molt una veu, i ara que la tinc no estaré en silenci.",
    max_lines=1,
    label="Input text",
)
# NOTE(review): ids run 1..47 here; confirm every value is a valid index for
# hps.data.n_speakers (a 47-entry speaker table would reject id 47).
speaker_input = gr.Slider(
    minimum=1,
    maximum=47,
    value=10,
    step=1,
    label="Speaker id",
    info="This model is trained on 47 speakers. You can prompt the model using one of these speaker ids.",
)
speed_input = gr.Slider(
    minimum=0.5, maximum=1.5, value=1, step=0.1, label="Speed"
)
noise_scale_input = gr.Slider(
    minimum=0.2, maximum=2.0, value=0.667, step=0.01, label="Noise scale"
)
noise_scale_w_input = gr.Slider(
    minimum=0.2, maximum=2.0, value=0.8, step=0.01, label="Noise scale w"
)

vits2_inference = gr.Interface(
    fn=tts,
    inputs=[
        text_input,
        speaker_input,
        speed_input,
        noise_scale_input,
        noise_scale_w_input,
    ],
    outputs=gr.Audio(),
)

# Page layout: header, description, the tabbed synthesis UI, attribution.
demo = gr.Blocks()
with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.TabbedInterface([vits2_inference], ["Multispeaker"])
    gr.Markdown(article)

# Queue holds at most 10 pending requests; the server binds all interfaces
# on port 7860 with the API page hidden.
demo.queue(max_size=10)
demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)