File size: 6,035 Bytes
c968fc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
import numpy as np
from tqdm import tqdm
import torch
import json
from models.tts.base.tts_inferece import TTSInference
from models.tts.vits.vits_dataset import VITSTestDataset, VITSTestCollator
from models.tts.vits.vits import SynthesizerTrn
from processors.phone_extractor import phoneExtractor
from text.text_token_collation import phoneIDCollation
from utils.data_utils import *
class VitsInference(TTSInference):
    """Inference wrapper for the VITS end-to-end TTS model.

    Builds a ``SynthesizerTrn`` generator from the experiment config and
    provides two entry points: batched inference over a test dataloader
    (``inference_for_batches``) and single-utterance synthesis from
    ``self.args.text`` (``inference_for_single_utterance``).
    """

    def __init__(self, args=None, cfg=None):
        TTSInference.__init__(self, args, cfg)

    def _build_model(self):
        """Construct the VITS generator network.

        Returns:
            SynthesizerTrn: generator built from ``cfg.model`` plus the
            spectrogram dimensions derived from the preprocess config
            (n_fft//2 + 1 frequency bins, segment_size//hop_size frames).
        """
        net_g = SynthesizerTrn(
            self.cfg.model.text_token_num,
            self.cfg.preprocess.n_fft // 2 + 1,
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
            **self.cfg.model,
        )
        return net_g

    def _build_test_dataset(self):
        # Fix: parameter was misspelled "sefl"; it worked only because the
        # method never touches the instance. Renamed to the conventional
        # "self" so the signature reads correctly.
        """Return the (dataset class, collator class) pair for evaluation."""
        return VITSTestDataset, VITSTestCollator

    def build_save_dir(self, dataset, speaker):
        """Create and return the output directory for synthesized audio.

        The layout is:
            <output_dir>/tts_am_step-<step>_<mode>[/data_<dataset>][/spk_<speaker>]

        Args:
            dataset: dataset name appended as a subdirectory, or None to skip.
            speaker: speaker id appended as a subdirectory; -1 means
                single-speaker (no subdirectory).

        Returns:
            str: the created directory path (makedirs with exist_ok=True).
        """
        save_dir = os.path.join(
            self.args.output_dir,
            "tts_am_step-{}_{}".format(self.am_restore_step, self.args.mode),
        )
        if dataset is not None:
            save_dir = os.path.join(save_dir, "data_{}".format(dataset))
        if speaker != -1:
            save_dir = os.path.join(
                save_dir,
                "spk_{}".format(speaker),
            )
        os.makedirs(save_dir, exist_ok=True)
        print("Saving to ", save_dir)
        return save_dir

    def inference_for_batches(
        self, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
    ):
        """Run VITS inference over every batch in ``self.test_dataloader``.

        Args:
            noise_scale: sampling noise scale for the flow prior.
            noise_scale_w: noise scale for the stochastic duration predictor.
            length_scale: speech-rate control (>1 slows speech down).

        Returns:
            list[torch.Tensor]: one 1-D CPU float waveform per utterance,
            trimmed to its true length via the model's output mask.
        """
        ###### Construct test_batch ######
        n_batch = len(self.test_dataloader)
        now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
        print(
            "Model eval time: {}, batch_size = {}, n_batch = {}".format(
                now, self.test_batch_size, n_batch
            )
        )
        self.model.eval()

        ###### Inference for each batch ######
        pred_res = []
        with torch.no_grad():
            # Only show a progress bar when there is more than one batch.
            for i, batch_data in enumerate(
                self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
            ):
                spk_id = None
                if (
                    self.cfg.preprocess.use_spkid
                    and self.cfg.train.multi_speaker_training
                ):
                    spk_id = batch_data["spk_id"]
                outputs = self.model.infer(
                    batch_data["phone_seq"],
                    batch_data["phone_len"],
                    spk_id,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                )

                audios = outputs["y_hat"]
                masks = outputs["mask"]

                for idx in range(audios.size(0)):
                    audio = audios[idx, 0, :].data.cpu().float()
                    mask = masks[idx, :, :]
                    # Mask sums to the number of valid frames; multiply by
                    # hop_size to convert frames -> waveform samples.
                    audio_length = (
                        mask.sum([0, 1]).long() * self.cfg.preprocess.hop_size
                    )
                    audio_length = audio_length.cpu().numpy()
                    audio = audio[:audio_length]
                    pred_res.append(audio)

        return pred_res

    def inference_for_single_utterance(
        self, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
    ):
        """Synthesize a single utterance from ``self.args.text``.

        Pipeline: text -> phones -> phone ids (optionally interspersed with
        blank token 0) -> VITS inference. For multi-speaker models the
        speaker id is looked up from the experiment's spk2id mapping using
        ``self.args.speaker_name``.

        Args:
            noise_scale: sampling noise scale for the flow prior.
            noise_scale_w: noise scale for the stochastic duration predictor.
            length_scale: speech-rate control (>1 slows speech down).

        Returns:
            numpy.ndarray: 1-D float waveform.

        Raises:
            AssertionError: if the symbols dict is missing or the requested
                speaker name is not in spk2id.
        """
        text = self.args.text

        # get phone symbol file
        phone_symbol_file = None
        if self.cfg.preprocess.phone_extractor != "lexicon":
            phone_symbol_file = os.path.join(
                self.exp_dir, self.cfg.preprocess.symbols_dict
            )
            assert os.path.exists(phone_symbol_file)
        # convert text to phone sequence
        phone_extractor = phoneExtractor(self.cfg)
        phone_seq = phone_extractor.extract_phone(text)  # phone_seq: list
        # convert phone sequence to phone id sequence
        phon_id_collator = phoneIDCollation(
            self.cfg, symbols_dict_file=phone_symbol_file
        )
        phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)

        if self.cfg.preprocess.add_blank:
            # Intersperse blank token (id 0) between phones, as in VITS training.
            phone_id_seq = intersperse(phone_id_seq, 0)

        # convert phone id sequence to tensor
        phone_id_seq = np.array(phone_id_seq)
        phone_id_seq = torch.from_numpy(phone_id_seq)

        # get speaker id if multi-speaker training and use speaker id
        speaker_id = None
        if self.cfg.preprocess.use_spkid and self.cfg.train.multi_speaker_training:
            spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
            with open(spk2id_file, "r") as f:
                spk2id = json.load(f)
                speaker_name = self.args.speaker_name
                assert (
                    speaker_name in spk2id
                ), f"Speaker {speaker_name} not found in the spk2id keys. \
                    Please make sure you've specified the correct speaker name in infer_speaker_name."
                speaker_id = spk2id[speaker_name]
                # NOTE(review): int32 indices are accepted by recent torch
                # embeddings, but batched inference passes the dataloader's
                # dtype — confirm both paths feed the same dtype to infer().
                speaker_id = torch.from_numpy(
                    np.array([speaker_id], dtype=np.int32)
                ).unsqueeze(0)

        with torch.no_grad():
            x_tst = phone_id_seq.to(self.device).unsqueeze(0)
            x_tst_lengths = torch.LongTensor([phone_id_seq.size(0)]).to(self.device)
            if speaker_id is not None:
                speaker_id = speaker_id.to(self.device)
            outputs = self.model.infer(
                x_tst,
                x_tst_lengths,
                sid=speaker_id,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )

            audio = outputs["y_hat"][0, 0].data.cpu().float().numpy()

        return audio
|