# NOTE(review): the lines below were scrape residue from the hosting page
# ("wasmdashai's picture", "first commit", revision 7694c84, "raw /
# history blame", "14.4 kB") — preserved here as a comment so the module parses.
from typing import List, Union, Optional, Literal
import text
import torch
import torch.nn as nn
from .tacotron2_ms import Tacotron2MS
from text.symbols import EOS_TOKENS, SEPARATOR_TOKEN
from utils import get_basic_config
from vocoder import load_hifigan
from vocoder.hifigan.denoiser import Denoiser
from ..diacritizers import load_vowelizer
# Names of the supported automatic diacritization (vowelization) backends.
_VOWELIZER_TYPE = Literal['shakkala', 'shakkelha']
def text_collate_fn(batch: List[torch.Tensor]):
    """Pad variable-length id sequences into one batch, longest first.

    Args:
        batch: list of 1-D LongTensors of token ids.

    Returns:
        Tuple of (zero-padded id matrix sorted by decreasing length,
        sorted lengths, permutation that restores the original order).
    """
    lengths = torch.LongTensor([item.size(0) for item in batch])
    sorted_lens, sort_ids = lengths.sort(descending=True)
    padded = torch.zeros((len(batch), int(sorted_lens[0])), dtype=torch.long)
    for row, src_idx in enumerate(sort_ids.tolist()):
        seq = batch[src_idx]
        padded[row, :seq.size(0)] = seq
    return padded, sorted_lens, sort_ids.argsort()
def needs_postprocessing(token: str):
    """Return True when the final phoneme *token* requires mel post-processing.

    Tokens in the set below are treated as clean utterance endings and
    need no extra handling.
    """
    clean_endings = {'a', 'i', 'u', 'aa', 'ii', 'uu', 'n', 'm', 'h'}
    return token not in clean_endings
def truncate_mel(mel_spec: torch.Tensor, ps_end):
    """Trim trailing frames from a mel spectrogram using attention weights.

    Keeps frames before the first position where ``ps_end`` reaches 80% of
    its maximum, then replicate-pads 3 frames so the ending is not abrupt.

    Args:
        mel_spec: spectrogram of shape [num_mels, num_frames].
        ps_end: per-frame attention weights on the end-marking token.

    Returns:
        Truncated and padded spectrogram [num_mels, cut + 3].
    """
    threshold = 0.8 * ps_end.max()
    cut_idx = next(idx for idx, weight in enumerate(ps_end) if weight >= threshold)
    trimmed = mel_spec[:, :cut_idx]
    return torch.nn.functional.pad(trimmed, (0, 3), mode='replicate')
def resize_mel(mel: torch.Tensor,
               rate: Union[int, float] = 1.0,
               mode: str = 'bicubic'):
    """Time-stretch a mel spectrogram by 2-D interpolation.

    Args:
        mel: mel spectrogram [num_mels, spec_length].
        rate: speed factor; >1 shortens the time axis, <1 lengthens it.
        mode: interpolation mode forwarded to
            ``torch.nn.functional.interpolate``.

    Returns:
        resized_mel [num_mels, new_spec_length]; the input tensor itself
        when the target length equals the current one.
    """
    num_mels, num_frames = mel.shape[-2:]
    # Keep the exact original expression: int(1/rate * Nt) and int(Nt/rate)
    # can differ by one frame under float rounding.
    new_frames = int(1 / rate * num_frames)
    if new_frames == num_frames:
        return mel
    batched = mel.unsqueeze(0).unsqueeze(0)
    resized = torch.nn.functional.interpolate(batched, (num_mels, new_frames),
                                              mode=mode)
    return resized.squeeze(0).squeeze(0)
class Tacotron2(Tacotron2MS):
    """Arabic text-to-mel Tacotron2 with optional automatic diacritization.

    Extends ``Tacotron2MS`` with text tokenization (Arabic script or
    Buckwalter transliteration), optional vowelization, end-of-utterance
    mel truncation, and speed resizing.
    """

    def __init__(self,
                 checkpoint: Optional[str] = None,
                 n_symbol: int = 40,
                 decoder_max_step: int = 3000,
                 arabic_in: bool = True,
                 vowelizer: Optional[_VOWELIZER_TYPE] = None,
                 **kwargs):
        """
        Args:
            checkpoint: optional path to a state-dict file with a 'model'
                entry and, optionally, a 'symbols' phoneme table.
            n_symbol: size of the input symbol set.
            decoder_max_step: maximum decoder steps per utterance.
            arabic_in: if True input text is Arabic script, otherwise
                Buckwalter transliteration.
            vowelizer: default diacritization backend to preload
                ('shakkala' or 'shakkelha'); None disables vowelization.
        """
        super().__init__(n_symbol=n_symbol,
                         decoder_max_step=decoder_max_step,
                         **kwargs)
        # Number of end-of-sequence tokens the tokenizer appends.
        self.n_eos = len(EOS_TOKENS)
        self.arabic_in = arabic_in
        if checkpoint is not None:
            state_dicts = torch.load(checkpoint, map_location='cpu')
            self.load_state_dict(state_dicts['model'])
        self.config = get_basic_config()
        # Cache of loaded diacritization models, keyed by backend name.
        self.vowelizers = {}
        if vowelizer is not None:
            self.vowelizers[vowelizer] = load_vowelizer(vowelizer, self.config)
        self.default_vowelizer = vowelizer
        self.phon_to_id = None
        # Short-circuit keeps `state_dicts` unreferenced when checkpoint is None.
        if checkpoint is not None and 'symbols' in state_dicts:
            self.phon_to_id = {phon: i for i, phon in enumerate(state_dicts['symbols'])}
        self.eval()

    @property
    def device(self):
        # Device of the module's parameters (assumes all are on one device).
        return next(self.parameters()).device

    def _vowelize(self, utterance: str, vowelizer: Optional[_VOWELIZER_TYPE] = None):
        """Add diacritics to `utterance` with the chosen backend, if any."""
        vowelizer = self.default_vowelizer if vowelizer is None else vowelizer
        if vowelizer is not None:
            # Load and cache the backend on first use.
            if not vowelizer in self.vowelizers:
                self.vowelizers[vowelizer] = load_vowelizer(vowelizer, self.config)
                # print(f"loaded: {vowelizer}")
            # Vowelizer models consume Arabic script, so convert first.
            utterance_ar = text.buckwalter_to_arabic(utterance)
            utterance = self.vowelizers[vowelizer].predict(utterance_ar)
        return utterance

    def _tokenize(self, utterance: str, vowelizer: Optional[_VOWELIZER_TYPE] = None):
        """Vowelize then convert `utterance` to a list of phoneme tokens."""
        utterance = self._vowelize(utterance=utterance, vowelizer=vowelizer)
        if self.arabic_in:
            return text.arabic_to_tokens(utterance)
        return text.buckwalter_to_tokens(utterance)

    @torch.inference_mode()
    def ttmel_single(self,
                     utterance: str,
                     speaker_id: int = 0,
                     speed: Union[int, float, None] = None,
                     vowelizer: Optional[_VOWELIZER_TYPE] = None,
                     postprocess_mel: bool = True,
                     ):
        """Synthesize a mel spectrogram [F, T] for one utterance."""
        tokens = self._tokenize(utterance, vowelizer=vowelizer)
        process_mel = False
        # If the last real token (just before the EOS tokens) needs cleanup,
        # insert a separator so the alignment marks a truncation point.
        if postprocess_mel and needs_postprocessing(tokens[-self.n_eos-1]):
            tokens.insert(-self.n_eos, SEPARATOR_TOKEN)
            process_mel = True
        token_ids = text.tokens_to_ids(tokens, self.phon_to_id)
        ids_batch = torch.LongTensor(token_ids).unsqueeze(0).to(self.device)
        sid = torch.LongTensor([speaker_id]).to(self.device)
        # Infer spectrogram and wave
        mel_spec, _, alignments = self.infer(ids_batch, sid)
        mel_spec = mel_spec[0]
        if process_mel:
            # Attention column of the token right before the EOS tokens.
            mel_spec = truncate_mel(mel_spec, alignments[0, :, -self.n_eos-1])
        if speed is not None:
            mel_spec = resize_mel(mel_spec, rate=speed)
        return mel_spec # [F, T]

    @torch.inference_mode()
    def ttmel_batch(self,
                    batch: List[str],
                    speaker_id: int = 0,
                    speed: Union[int, float, None] = None,
                    vowelizer: Optional[_VOWELIZER_TYPE] = None,
                    postprocess_mel: bool = True
                    ):
        """Synthesize mel spectrograms for a list of utterances in one pass.

        Returns:
            List of [F, T] mel spectrograms in the original input order.
        """
        batch_tokens = [self._tokenize(line, vowelizer=vowelizer) for line in batch]
        # Per-utterance flags (original order): whether a separator token was
        # inserted and the resulting mel therefore needs end-truncation.
        list_postprocess = []
        if postprocess_mel:
            for i in range(len(batch_tokens)):
                process_mel = False
                if needs_postprocessing(batch_tokens[i][-self.n_eos-1]):
                    batch_tokens[i].insert(-self.n_eos, SEPARATOR_TOKEN)
                    process_mel = True
                list_postprocess.append(process_mel)
        batch_ids = [torch.LongTensor(
            text.tokens_to_ids(tokens, self.phon_to_id)
        ) for tokens in batch_tokens]
        # Pad and sort by decreasing length; keep ids to restore input order.
        batch = text_collate_fn(batch_ids)
        (
            batch_ids_padded, batch_lens_sorted,
            reverse_sort_ids
        ) = batch
        batch_ids_padded = batch_ids_padded.to(self.device)
        batch_lens_sorted = batch_lens_sorted.to(self.device)
        # Same speaker id for every utterance in the batch.
        batch_sids = batch_lens_sorted*0 + speaker_id
        y_pred = self.infer(batch_ids_padded, batch_sids, batch_lens_sorted)
        mel_outputs_postnet, mel_specgram_lengths, alignments = y_pred
        mel_list = []
        # `i` runs in original order; `id` indexes the length-sorted batch.
        for i, id in enumerate(reverse_sort_ids):
            mel = mel_outputs_postnet[id, :, :mel_specgram_lengths[id]]
            if postprocess_mel and list_postprocess[i]:
                # Attention weights on the token right before the EOS tokens.
                ps_end = alignments[id,
                                    :mel_specgram_lengths[id],
                                    batch_lens_sorted[id]-self.n_eos-1]
                mel = truncate_mel(mel, ps_end)
            if speed is not None:
                mel = resize_mel(mel, rate=speed)
            mel_list.append(mel)
        return mel_list

    def ttmel(self,
              text_input: Union[str, List[str]],
              speaker_id: int = 0,
              speed: Union[int, float, None] = None,
              batch_size: int = 8,
              vowelizer: Optional[_VOWELIZER_TYPE] = None,
              postprocess_mel: bool = True
              ):
        """Text-to-mel for a single string or a list of strings.

        Dispatches to ``ttmel_single`` for a string, otherwise runs
        ``ttmel_batch`` over chunks of at most ``batch_size`` utterances.

        Returns:
            A single [F, T] mel for string input; a list of mels (in input
            order) for list input.
        """
        # input: string
        if isinstance(text_input, str):
            return self.ttmel_single(text_input, speaker_id,
                                     speed, vowelizer,
                                     postprocess_mel)
        # input: list
        assert isinstance(text_input, list)
        batch = text_input
        mel_list = []
        if batch_size == 1:
            for sample in batch:
                mel = self.ttmel_single(sample, speaker_id,
                                        speed, vowelizer,
                                        postprocess_mel)
                mel_list.append(mel)
            return mel_list
        # infer one batch
        if len(batch) <= batch_size:
            return self.ttmel_batch(batch, speaker_id,
                                    speed, vowelizer,
                                    postprocess_mel)
        # batched inference
        batches = [batch[k:k+batch_size]
                   for k in range(0, len(batch), batch_size)]
        for batch in batches:
            mels = self.ttmel_batch(batch, speaker_id,
                                    speed, vowelizer,
                                    postprocess_mel)
            mel_list += mels
        return mel_list
class Tacotron2Wave(nn.Module):
    """End-to-end TTS: Tacotron2 text-to-mel followed by a HiFi-GAN vocoder.

    Wraps a :class:`Tacotron2` acoustic model, a HiFi-GAN vocoder and a
    spectral denoiser; exposes string / list-of-strings synthesis via
    :meth:`tts`.
    """

    def __init__(self,
                 model_sd_path: str,
                 vocoder_sd: Optional[str] = None,
                 vocoder_config: Optional[str] = None,
                 vowelizer: Optional[_VOWELIZER_TYPE] = None,
                 arabic_in: bool = True,
                 n_symbol: int = 40
                 ):
        """
        Args:
            model_sd_path: path to the Tacotron2 state-dict file (must
                contain a 'model' entry).
            vocoder_sd: path to the HiFi-GAN state dict; when this or
                ``vocoder_config`` is None, both fall back to the paths
                from ``get_basic_config()``.
            vocoder_config: path to the HiFi-GAN config file.
            vowelizer: default diacritization backend for the acoustic model.
            arabic_in: if True input text is Arabic script, otherwise
                Buckwalter transliteration.
            n_symbol: size of the input symbol set.
        """
        super().__init__()
        model = Tacotron2(n_symbol=n_symbol,
                          arabic_in=arabic_in,
                          vowelizer=vowelizer)
        state_dicts = torch.load(model_sd_path, map_location='cpu')
        model.load_state_dict(state_dicts['model'])
        self.model = model
        # Fall back to the bundled vocoder when either path is missing.
        if vocoder_sd is None or vocoder_config is None:
            config = get_basic_config()
            vocoder_sd = config.vocoder_state_path
            vocoder_config = config.vocoder_config_path
        vocoder = load_hifigan(vocoder_sd, vocoder_config)
        self.vocoder = vocoder
        self.denoiser = Denoiser(vocoder)
        self.eval()

    @property
    def device(self):
        # Device of the module's parameters (assumes all are on one device).
        return next(self.parameters()).device

    def forward(self, x):
        # Identity pass-through; inference goes through tts()/tts_single().
        return x

    @torch.inference_mode()
    def tts_single(self,
                   text_input: str,
                   speed: Union[int, float, None] = None,
                   speaker_id: int = 0,
                   denoise: float = 0,
                   vowelizer: Optional[_VOWELIZER_TYPE] = None,
                   postprocess_mel: bool = True,
                   return_mel: bool = False
                   ):
        """Synthesize a waveform (CPU tensor) for one utterance.

        Returns the wave, or ``(wave, mel_spec)`` when ``return_mel`` is True.
        """
        mel_spec = self.model.ttmel_single(text_input, speaker_id,
                                           speed, vowelizer,
                                           postprocess_mel)
        wave = self.vocoder(mel_spec)
        if denoise > 0:
            wave = self.denoiser(wave, denoise)
        if return_mel:
            return wave[0].cpu(), mel_spec
        return wave[0].cpu()

    @torch.inference_mode()
    def tts_batch(self,
                  batch: List[str],
                  speed: Union[int, float, None] = None,
                  denoise: float = 0,
                  speaker_id: int = 0,
                  vowelizer: Optional[_VOWELIZER_TYPE] = None,
                  postprocess_mel: bool = True,
                  return_mel: bool = False
                  ):
        """Synthesize waveforms for a list of utterances.

        Returns a list of CPU wave tensors in input order, or
        ``(wav_list, mel_list)`` when ``return_mel`` is True.
        """
        mel_list = self.model.ttmel_batch(batch, speaker_id, speed,
                                          vowelizer,
                                          postprocess_mel)
        wav_list = []
        for mel in mel_list:
            wav_inferred = self.vocoder(mel)
            if denoise > 0:
                wav_inferred = self.denoiser(wav_inferred, denoise)
            wav_list.append(wav_inferred[0].cpu())
        if return_mel:
            # BUG FIX: previously this line evaluated the tuple without
            # returning it, so return_mel=True silently returned wav_list only.
            return wav_list, mel_list
        return wav_list

    def tts(self,
            text_buckw: Union[str, List[str]],
            speed: Union[int, float, None] = None,
            denoise: float = 0,
            speaker_id: int = 0,
            batch_size: int = 8,
            vowelizer: Optional[_VOWELIZER_TYPE] = None,
            postprocess_mel: bool = True,
            return_mel: bool = False
            ):
        """
        Args:
            text_buckw (str|List[str]): Input text.
            speed (float): Speaking speed.
            denoise (float): Hifi-GAN Denoiser strength.
            speaker_id (int): Speaker Id.
            batch_size (int): Batch size for inference.
            vowelizer (None|str): options [None, `'shakkala'`, `'shakkelha'`].
            postprocess_mel (bool): Whether to postprocess.
            return_mel (bool): Whether to return the mel spectrogram(s).
        """
        # input: string
        if isinstance(text_buckw, str):
            return self.tts_single(text_buckw, speaker_id=speaker_id,
                                   speed=speed, denoise=denoise,
                                   vowelizer=vowelizer,
                                   postprocess_mel=postprocess_mel,
                                   return_mel=return_mel)
        # input: list
        assert isinstance(text_buckw, list)
        batch = text_buckw
        wav_list = []
        if batch_size == 1:
            for sample in batch:
                wav = self.tts_single(sample, speaker_id=speaker_id,
                                      speed=speed, denoise=denoise,
                                      vowelizer=vowelizer,
                                      postprocess_mel=postprocess_mel,
                                      return_mel=return_mel)
                wav_list.append(wav)
            return wav_list
        # infer one batch
        if len(batch) <= batch_size:
            return self.tts_batch(batch, speaker_id=speaker_id,
                                  speed=speed, denoise=denoise,
                                  vowelizer=vowelizer,
                                  postprocess_mel=postprocess_mel,
                                  return_mel=return_mel)
        # batched inference
        batches = [batch[k:k+batch_size]
                   for k in range(0, len(batch), batch_size)]
        for batch in batches:
            wavs = self.tts_batch(batch, speaker_id=speaker_id,
                                  speed=speed, denoise=denoise,
                                  vowelizer=vowelizer,
                                  postprocess_mel=postprocess_mel,
                                  return_mel=return_mel)
            wav_list += wavs
        return wav_list