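# Gradio demo: zero-shot text-to-speech with Amphion's NaturalSpeech2,
# cloning the timbre of a reference recording via EnCodec codes.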
import gradio as gr
import argparse
import os
import torch
import soundfile as sf
import numpy as np
from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from encodec import EncodecModel
from encodec.utils import convert_audio
from utils.util import load_config
from text import text_to_sequence
from text.cmudict import valid_symbols
from text.g2p import preprocess_english, read_lexicon
import torchaudio
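
# Load the pretrained 24 kHz EnCodec codec; its encoder turns the reference
# speech into discrete codes and its decoder renders latents back to audio.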
def build_codec(device):
    encodec_model = EncodecModel.encodec_model_24khz()
    encodec_model = encodec_model.to(device=device)
    encodec_model.set_target_bandwidth(12.0)
    return encodec_model
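
# Instantiate NaturalSpeech2 and load the released checkpoint weights.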
def build_model(cfg, device):
    model = NaturalSpeech2(cfg.model)
    model.load_state_dict(
        torch.load(
            "ckpts/ns2/pytorch_model.bin",
            map_location="cpu",
        )
    )
    model = model.to(device=device)
    return model
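
# End-to-end inference: encode the reference prompt, convert the target text
# to phonemes, run the diffusion model, and decode the result to a waveform.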
def ns2_inference(
    prompt_audio_path,
    text,
    diffusion_steps=100,
):
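    # Pick a device and load the experiment configuration used for training.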
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.environ["WORK_DIR"] = "./"
    cfg = load_config("egs/tts/NaturalSpeech2/exp_config.json")
    model = build_model(cfg, device)
    codec = build_codec(device)
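    # Load the reference prompt and resample it to the codec's sample rate
    # and channel layout.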
    ref_wav_path = prompt_audio_path
    ref_wav, sr = torchaudio.load(ref_wav_path)
    ref_wav = convert_audio(ref_wav, sr, codec.sample_rate, codec.channels)
    ref_wav = ref_wav.unsqueeze(0).to(device=device)
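    # Encode the reference into discrete EnCodec codes (no gradients needed).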
    with torch.no_grad():
        encoded_frames = codec.encode(ref_wav)
        ref_code = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
    ref_mask = torch.ones(ref_code.shape[0], ref_code.shape[-1]).to(ref_code.device)
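    # Build the phoneme vocabulary, then map the input text to phoneme IDs
    # with the lexicon-based G2P frontend.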
    symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
    phone2id = {s: i for i, s in enumerate(symbols)}
    id2phone = {i: s for s, i in phone2id.items()}
    lexicon = read_lexicon(cfg.preprocess.lexicon_path)
    phone_seq = preprocess_english(text, lexicon)
    phone_id = np.array(
        [
            *map(
                phone2id.get,
                phone_seq.replace("{", "").replace("}", "").split(),
            )
        ]
    )
    phone_id = torch.from_numpy(phone_id).unsqueeze(0).to(device=device)
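    # Run diffusion inference conditioned on the reference codes, then decode
    # the predicted latents to a waveform with the codec decoder.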
    x0, prior_out = model.inference(ref_code, phone_id, ref_mask, diffusion_steps)
    latent_ref = codec.quantizer.vq.decode(ref_code.transpose(0, 1))
    rec_wav = codec.decoder(x0)
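    # Save the generated waveform under result/ and return its path to Gradio.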
    os.makedirs("result", exist_ok=True)
    result_file = "result/{}_zero_shot_result.wav".format(
        os.path.splitext(os.path.basename(prompt_audio_path))[0]
    )
    sf.write(
        result_file,
        rec_wav[0, 0].detach().cpu().numpy(),
        samplerate=24000,
    )
    return result_file
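
# Gradio UI: reference audio, target text, and a diffusion-steps slider.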
demo_inputs = [
    gr.Audio(
        sources=["upload", "microphone"],
        label="Upload a reference speech whose timbre you want to clone",
        type="filepath",
    ),
    gr.Textbox(
        value="Amphion is a toolkit that can speak, make sounds, and sing.",
        label="Text you want to generate",
        type="text",
    ),
    gr.Slider(
        10,
        1000,
        value=200,
        step=1,
        label="Diffusion Inference Steps",
        info="More steps generally improve synthesis quality at the cost of slower inference",
    ),
]
demo_outputs = gr.Audio(label="Generated Speech")
demo = gr.Interface(
    fn=ns2_inference,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title="Amphion Zero-Shot TTS NaturalSpeech2",
)
if __name__ == "__main__":
    demo.launch()