Spaces:
Sleeping
Sleeping
from typing import List, Union, Optional, Literal | |
import text | |
import torch | |
import torch.nn as nn | |
from .tacotron2_ms import Tacotron2MS | |
from text.symbols import EOS_TOKENS, SEPARATOR_TOKEN | |
from utils import get_basic_config | |
from vocoder import load_hifigan | |
from vocoder.hifigan.denoiser import Denoiser | |
from ..diacritizers import load_vowelizer | |
# Names accepted by the `vowelizer` arguments in this module; a name is
# passed to `load_vowelizer` to obtain the corresponding diacritizer.
_VOWELIZER_TYPE = Literal['shakkala', 'shakkelha']
def text_collate_fn(batch: List[torch.Tensor]):
    """Pad a list of 1-D id tensors into one batch, longest sequence first.

    Args:
        batch: List[text_ids], one id tensor per utterance.

    Returns:
        text_ids_pad: LongTensor [len(batch), max_len]; rows are ordered by
            decreasing length and right-padded with zeros.
        input_lens_sorted: the sequence lengths in that same descending order.
        reverse_ids: for every original position, the row it occupies in
            ``text_ids_pad`` (argsort of the sort permutation).
    """
    lengths = torch.LongTensor([len(ids) for ids in batch])
    input_lens_sorted, sort_order = torch.sort(lengths, descending=True)

    # The first entry is the longest sequence because of the descending sort.
    longest = int(input_lens_sorted[0])
    text_ids_pad = torch.zeros(len(batch), longest, dtype=torch.long)

    for row, src in enumerate(sort_order):
        ids = batch[src]
        text_ids_pad[row, :ids.size(0)] = ids

    return text_ids_pad, input_lens_sorted, sort_order.argsort()
def needs_postprocessing(token: str):
    """Return True unless `token` is one of the tokens exempt from postprocessing.

    The exempt tokens are 'a', 'i', 'u', 'aa', 'ii', 'uu', 'n', 'm' and 'h'.
    """
    exempt_tokens = ('a', 'i', 'u', 'aa', 'ii', 'uu', 'n', 'm', 'h')
    if token in exempt_tokens:
        return False
    return True
def truncate_mel(mel_spec: torch.Tensor, ps_end):
    """Cut a mel spectrogram where `ps_end` first nears its peak, then pad it.

    Args:
        mel_spec: spectrogram indexed as [channel, frame].
        ps_end: per-frame scores; the callers in this file pass one column
            of the attention alignments.

    Returns:
        The frames before the first index at which `ps_end` reaches 80% of
        its maximum, followed by three copies of the last kept frame.
    """
    threshold = 0.8 * ps_end.max()
    cut_frame = next(idx for idx, score in enumerate(ps_end)
                     if score >= threshold)
    kept = mel_spec[:, :cut_frame]
    # Replicate-pad three frames at the end of the time axis.
    return torch.nn.functional.pad(kept, (0, 3), mode='replicate')
def resize_mel(mel: torch.Tensor,
               rate: Union[int, float] = 1.0,
               mode: str = 'bicubic'):
    """Stretch or compress a mel spectrogram along its time axis.

    Args:
        mel: mel spectrogram [num_mels, spec_length]
        rate: the new length is ``int(1 / rate * spec_length)``, so values
            above 1 shorten the spectrogram and values below 1 lengthen it.
        mode: interpolation mode forwarded to
            ``torch.nn.functional.interpolate``.

    Returns:
        resized_mel [num_mels, new_spec_length]; the input tensor itself is
        returned when the length would not change.
    """
    n_mels, n_frames = mel.shape[-2:]
    n_frames_new = int(1 / rate * n_frames)

    if n_frames_new != n_frames:
        # interpolate expects [batch, channel, H, W]; add and strip two dims.
        resized = torch.nn.functional.interpolate(
            mel[None, None, ...], (n_mels, n_frames_new), mode=mode)
        return resized[0, 0]

    return mel
class Tacotron2(Tacotron2MS):
    """Text-to-mel model: `Tacotron2MS` plus a text front end and mel post-processing.

    Input text is optionally vowelized (diacritized), tokenized, converted to
    ids and passed to `self.infer`. The resulting mel spectrogram can be
    truncated at the end of the utterance and time-stretched.
    """

    def __init__(self,
                 checkpoint: Optional[str] = None,
                 n_symbol: int = 40,
                 decoder_max_step: int = 3000,
                 arabic_in: bool = True,
                 vowelizer: Optional[_VOWELIZER_TYPE] = None,
                 **kwargs):
        """
        Args:
            checkpoint: optional path to a checkpoint holding the weights
                under the key 'model' and, optionally, the symbol list
                under the key 'symbols'.
            n_symbol: forwarded to `Tacotron2MS`.
            decoder_max_step: forwarded to `Tacotron2MS`.
            arabic_in: if True the input is tokenized with
                `text.arabic_to_tokens`, otherwise with
                `text.buckwalter_to_tokens`.
            vowelizer: default vowelizer name; None disables vowelization
                unless a name is passed per call.
            **kwargs: forwarded to `Tacotron2MS`.
        """
        super().__init__(n_symbol=n_symbol,
                         decoder_max_step=decoder_max_step,
                         **kwargs)
        self.n_eos = len(EOS_TOKENS)
        self.arabic_in = arabic_in

        # Symbol -> id mapping; stays None when the checkpoint has no
        # 'symbols' entry (it is passed as-is to `text.tokens_to_ids`).
        self.phon_to_id = None
        if checkpoint is not None:
            # NOTE: torch.load unpickles the file - only load trusted checkpoints.
            state_dicts = torch.load(checkpoint, map_location='cpu')
            self.load_state_dict(state_dicts['model'])
            if 'symbols' in state_dicts:
                self.phon_to_id = {
                    phon: i for i, phon in enumerate(state_dicts['symbols'])}

        self.config = get_basic_config()

        # Vowelizers are cached by name and loaded lazily in `_vowelize`.
        self.vowelizers = {}
        if vowelizer is not None:
            self.vowelizers[vowelizer] = load_vowelizer(vowelizer, self.config)
        self.default_vowelizer = vowelizer

        self.eval()

    @property
    def device(self):
        """Device of the first model parameter.

        Exposed as a property because the inference methods below read it
        as an attribute (``tensor.to(self.device)``).
        """
        return next(self.parameters()).device

    def _vowelize(self, utterance: str,
                  vowelizer: Optional[_VOWELIZER_TYPE] = None):
        """Vowelize `utterance` with the given (or the default) vowelizer.

        The utterance is returned unchanged when neither a per-call nor a
        default vowelizer is set.
        """
        vowelizer = self.default_vowelizer if vowelizer is None else vowelizer
        if vowelizer is None:
            return utterance

        if vowelizer not in self.vowelizers:
            self.vowelizers[vowelizer] = load_vowelizer(vowelizer, self.config)
        utterance_ar = text.buckwalter_to_arabic(utterance)
        return self.vowelizers[vowelizer].predict(utterance_ar)

    def _tokenize(self, utterance: str,
                  vowelizer: Optional[_VOWELIZER_TYPE] = None):
        """Vowelize (if requested) and split `utterance` into tokens."""
        utterance = self._vowelize(utterance=utterance, vowelizer=vowelizer)
        if self.arabic_in:
            return text.arabic_to_tokens(utterance)
        return text.buckwalter_to_tokens(utterance)

    def _insert_separator(self, tokens: List[str]) -> bool:
        """Insert SEPARATOR_TOKEN before the trailing EOS tokens if needed.

        The token in front of the last `n_eos` tokens is checked with
        `needs_postprocessing`. `tokens` is modified in place.

        Returns:
            True if the separator was inserted, i.e. the mel spectrogram of
            this utterance should later be truncated.
        """
        if not needs_postprocessing(tokens[-self.n_eos - 1]):
            return False
        tokens.insert(-self.n_eos, SEPARATOR_TOKEN)
        return True

    def ttmel_single(self,
                     utterance: str,
                     speaker_id: int = 0,
                     speed: Union[int, float, None] = None,
                     vowelizer: Optional[_VOWELIZER_TYPE] = None,
                     postprocess_mel: bool = True,
                     ):
        """Synthesize the mel spectrogram of a single utterance.

        Args:
            utterance: input text.
            speaker_id: speaker id passed to `self.infer`.
            speed: if not None, the mel is resized with `resize_mel(rate=speed)`.
            vowelizer: vowelizer name overriding the default one.
            postprocess_mel: whether to insert a separator token and
                truncate the mel accordingly.

        Returns:
            mel spectrogram [F, T]
        """
        tokens = self._tokenize(utterance, vowelizer=vowelizer)
        process_mel = postprocess_mel and self._insert_separator(tokens)

        token_ids = text.tokens_to_ids(tokens, self.phon_to_id)
        ids_batch = torch.LongTensor(token_ids).unsqueeze(0).to(self.device)
        sid = torch.LongTensor([speaker_id]).to(self.device)

        mel_spec, _, alignments = self.infer(ids_batch, sid)
        mel_spec = mel_spec[0]

        if process_mel:
            # After insertion the separator sits right before the EOS
            # tokens, so this column holds the alignment to the separator.
            mel_spec = truncate_mel(mel_spec,
                                    alignments[0, :, -self.n_eos - 1])

        if speed is not None:
            mel_spec = resize_mel(mel_spec, rate=speed)

        return mel_spec  # [F, T]

    def ttmel_batch(self,
                    batch: List[str],
                    speaker_id: int = 0,
                    speed: Union[int, float, None] = None,
                    vowelizer: Optional[_VOWELIZER_TYPE] = None,
                    postprocess_mel: bool = True
                    ):
        """Synthesize mel spectrograms for a batch of utterances.

        Args:
            batch: list of input texts.
            speaker_id: speaker id used for every utterance.
            speed: if not None, each mel is resized with
                `resize_mel(rate=speed)`.
            vowelizer: vowelizer name overriding the default one.
            postprocess_mel: whether to insert separator tokens and
                truncate the mels accordingly.

        Returns:
            List of mel spectrograms in the order of `batch`.
        """
        batch_tokens = [self._tokenize(line, vowelizer=vowelizer)
                        for line in batch]

        # One flag per utterance (original order); `and` short-circuits so
        # no separator is inserted when postprocessing is disabled.
        list_postprocess = [
            postprocess_mel and self._insert_separator(tokens)
            for tokens in batch_tokens]

        batch_ids = [
            torch.LongTensor(text.tokens_to_ids(tokens, self.phon_to_id))
            for tokens in batch_tokens]

        (batch_ids_padded, batch_lens_sorted,
         reverse_sort_ids) = text_collate_fn(batch_ids)

        batch_ids_padded = batch_ids_padded.to(self.device)
        batch_lens_sorted = batch_lens_sorted.to(self.device)
        batch_sids = batch_lens_sorted*0 + speaker_id

        y_pred = self.infer(batch_ids_padded, batch_sids, batch_lens_sorted)
        mel_outputs_postnet, mel_specgram_lengths, alignments = y_pred

        mel_list = []
        # `i` is the position in the input batch, `row` the position of the
        # same utterance in the length-sorted model output.
        for i, row in enumerate(reverse_sort_ids):
            mel_len = mel_specgram_lengths[row]
            mel = mel_outputs_postnet[row, :, :mel_len]
            if list_postprocess[i]:
                ps_end = alignments[row,
                                    :mel_len,
                                    batch_lens_sorted[row] - self.n_eos - 1]
                mel = truncate_mel(mel, ps_end)
            if speed is not None:
                mel = resize_mel(mel, rate=speed)
            mel_list.append(mel)

        return mel_list

    def ttmel(self,
              text_input: Union[str, List[str]],
              speaker_id: int = 0,
              speed: Union[int, float, None] = None,
              batch_size: int = 8,
              vowelizer: Optional[_VOWELIZER_TYPE] = None,
              postprocess_mel: bool = True
              ):
        """Synthesize mel spectrogram(s) for a string or a list of strings.

        Args:
            text_input: a single utterance or a list of utterances.
            speaker_id: speaker id.
            speed: optional resize rate, see `resize_mel`.
            batch_size: number of utterances per `ttmel_batch` call; with 1
                every utterance goes through `ttmel_single`.
            vowelizer: vowelizer name overriding the default one.
            postprocess_mel: whether to postprocess the mels.

        Returns:
            One mel spectrogram for a string input, otherwise a list of mel
            spectrograms in input order (empty for an empty list).
        """
        # input: string
        if isinstance(text_input, str):
            return self.ttmel_single(text_input, speaker_id,
                                     speed, vowelizer,
                                     postprocess_mel)

        # input: list
        assert isinstance(text_input, list)
        if not text_input:
            # Nothing to synthesize; collating an empty batch would fail.
            return []

        if batch_size == 1:
            return [self.ttmel_single(sample, speaker_id,
                                      speed, vowelizer,
                                      postprocess_mel)
                    for sample in text_input]

        # batched inference (a single chunk when the list fits in one batch)
        mel_list = []
        for k in range(0, len(text_input), batch_size):
            mel_list += self.ttmel_batch(text_input[k:k + batch_size],
                                         speaker_id, speed, vowelizer,
                                         postprocess_mel)
        return mel_list
class Tacotron2Wave(nn.Module):
    """Text-to-speech pipeline: `Tacotron2` (text to mel) plus a HiFi-GAN vocoder."""

    def __init__(self,
                 model_sd_path: str,
                 vocoder_sd: Optional[str] = None,
                 vocoder_config: Optional[str] = None,
                 vowelizer: Optional[_VOWELIZER_TYPE] = None,
                 arabic_in: bool = True,
                 n_symbol: int = 40
                 ):
        """
        Args:
            model_sd_path: path to the Tacotron2 checkpoint; the weights are
                read from its 'model' entry.
            vocoder_sd: path to the vocoder weights.
            vocoder_config: path to the vocoder config. If either vocoder
                path is None, both are taken from `get_basic_config()`.
            vowelizer: default vowelizer name for the text model.
            arabic_in: forwarded to `Tacotron2`.
            n_symbol: forwarded to `Tacotron2`.
        """
        super().__init__()

        model = Tacotron2(n_symbol=n_symbol,
                          arabic_in=arabic_in,
                          vowelizer=vowelizer)
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dicts = torch.load(model_sd_path, map_location='cpu')
        model.load_state_dict(state_dicts['model'])
        self.model = model

        if vocoder_sd is None or vocoder_config is None:
            config = get_basic_config()
            vocoder_sd = config.vocoder_state_path
            vocoder_config = config.vocoder_config_path

        vocoder = load_hifigan(vocoder_sd, vocoder_config)
        self.vocoder = vocoder
        self.denoiser = Denoiser(vocoder)

        self.eval()

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    def forward(self, x):
        """Identity; synthesis is done through the `tts*` methods."""
        return x

    def _vocode(self, mel_spec, denoise: float):
        """Turn one mel spectrogram into a waveform on the CPU.

        The denoiser is applied only for `denoise > 0`.
        """
        wave = self.vocoder(mel_spec)
        if denoise > 0:
            wave = self.denoiser(wave, denoise)
        return wave[0].cpu()

    def tts_single(self,
                   text_input: str,
                   speed: Union[int, float, None] = None,
                   speaker_id: int = 0,
                   denoise: float = 0,
                   vowelizer: Optional[_VOWELIZER_TYPE] = None,
                   postprocess_mel: bool = True,
                   return_mel: bool = False
                   ):
        """Synthesize one utterance.

        Returns:
            The waveform, or `(waveform, mel_spec)` if `return_mel` is True.
        """
        mel_spec = self.model.ttmel_single(text_input, speaker_id,
                                           speed, vowelizer,
                                           postprocess_mel)
        wave = self._vocode(mel_spec, denoise)

        if return_mel:
            return wave, mel_spec
        return wave

    def tts_batch(self,
                  batch: List[str],
                  speed: Union[int, float, None] = None,
                  denoise: float = 0,
                  speaker_id: int = 0,
                  vowelizer: Optional[_VOWELIZER_TYPE] = None,
                  postprocess_mel: bool = True,
                  return_mel: bool = False
                  ):
        """Synthesize a batch of utterances.

        Returns:
            The list of waveforms, or `(wav_list, mel_list)` if
            `return_mel` is True.
        """
        mel_list = self.model.ttmel_batch(batch, speaker_id, speed,
                                          vowelizer,
                                          postprocess_mel)

        wav_list = [self._vocode(mel, denoise) for mel in mel_list]

        if return_mel:
            # The original evaluated this tuple without returning it, so
            # the mels were silently dropped.
            return wav_list, mel_list
        return wav_list

    def tts(self,
            text_buckw: Union[str, List[str]],
            speed: Union[int, float, None] = None,
            denoise: float = 0,
            speaker_id: int = 0,
            batch_size: int = 8,
            vowelizer: Optional[_VOWELIZER_TYPE] = None,
            postprocess_mel: bool = True,
            return_mel: bool = False
            ):
        """Synthesize speech for a string or a list of strings.

        Args:
            text_buckw (str|List[str]): Input text.
            speed (float): Speaking speed.
            denoise (float): Hifi-GAN Denoiser strength.
            speaker_id (int): Speaker Id.
            batch_size (int): Batch size for inference.
            vowelizer (None|str): options [None, `'shakkala'`, `'shakkelha'`].
            postprocess_mel (bool): Whether to postprocess.
            return_mel (bool): Whether to return the mel spectrogram(s).

        Returns:
            For a string: the waveform, or `(waveform, mel)` with
            `return_mel`. For a list: the list of waveforms, or
            `(wav_list, mel_list)` with `return_mel`. With
            `batch_size == 1` and `return_mel` the result is instead a list
            of `(waveform, mel)` pairs, one per utterance.
        """
        # input: string
        if isinstance(text_buckw, str):
            return self.tts_single(text_buckw, speaker_id=speaker_id,
                                   speed=speed, denoise=denoise,
                                   vowelizer=vowelizer,
                                   postprocess_mel=postprocess_mel,
                                   return_mel=return_mel)

        # input: list
        assert isinstance(text_buckw, list)

        if batch_size == 1:
            return [self.tts_single(sample, speaker_id=speaker_id,
                                    speed=speed, denoise=denoise,
                                    vowelizer=vowelizer,
                                    postprocess_mel=postprocess_mel,
                                    return_mel=return_mel)
                    for sample in text_buckw]

        # batched inference (a single chunk when the list fits in one batch;
        # no chunk at all for an empty list)
        wav_list = []
        mel_list = []
        for k in range(0, len(text_buckw), batch_size):
            result = self.tts_batch(text_buckw[k:k + batch_size],
                                    speaker_id=speaker_id,
                                    speed=speed, denoise=denoise,
                                    vowelizer=vowelizer,
                                    postprocess_mel=postprocess_mel,
                                    return_mel=return_mel)
            if return_mel:
                # Keep waves and mels in separate lists; extending wav_list
                # with the tuple itself would mix the two.
                wavs, mels = result
                mel_list += mels
            else:
                wavs = result
            wav_list += wavs

        if return_mel:
            return wav_list, mel_list
        return wav_list