# Spaces: Sleeping
# (page-status text captured together with the source; not part of the program)
from typing import List, Union, Optional, Literal

import torch
import torch.nn as nn

import text
from utils import get_basic_config
from vocoder import load_hifigan
from vocoder.hifigan.denoiser import Denoiser

from .fastpitch.model import FastPitch as _FastPitch
from ..diacritizers import load_vowelizer
# Names of the vowelizer (diacritizer) back-ends accepted by the `vowelizer`
# parameters in this module.
_VOWELIZER_TYPE = Literal['shakkala', 'shakkelha']
def text_collate_fn(batch: List[torch.Tensor]):
    """Pad a list of 1-D id tensors into one batch sorted by length.

    Args:
        batch: List of 1-D tensors of token ids (one per utterance).

    Returns:
        text_ids_pad: LongTensor of shape ``(len(batch), max_len)``; row ``i``
            holds the ``i``-th longest sequence, right-padded with zeros.
        input_lens_sorted: Sequence lengths in descending order.
        reverse_ids: Indices such that ``text_ids_pad[reverse_ids]`` restores
            the original order of ``batch``.

    An empty ``batch`` yields empty tensors (``text_ids_pad`` has shape
    ``(0, 0)``) instead of raising ``IndexError``.
    """
    input_lens_sorted, input_sort_ids = torch.sort(
        torch.LongTensor([len(x) for x in batch]), descending=True)

    # Longest sequence comes first after the descending sort.
    max_input_len = int(input_lens_sorted[0]) if len(batch) > 0 else 0

    text_ids_pad = torch.zeros((len(batch), max_input_len), dtype=torch.long)
    for row, src in enumerate(input_sort_ids.tolist()):
        text_ids = batch[src]
        text_ids_pad[row, :text_ids.size(0)] = text_ids

    # argsort of the sort permutation is its inverse permutation.
    return text_ids_pad, input_lens_sorted, input_sort_ids.argsort()
class FastPitch(_FastPitch):
    """Text-to-mel wrapper around the project's FastPitch network.

    On construction the checkpoint is read on CPU, the network is built from
    the checkpoint's ``'config'`` entry (falling back to
    ``models.fastpitch.net_config`` when absent), the ``'model'`` weights are
    loaded and the module is put in eval mode. Vowelizers are loaded on first
    use and cached in ``self.vowelizers``.
    """

    def __init__(self,
                 checkpoint: str,
                 arabic_in: bool = True,
                 vowelizer: Optional[_VOWELIZER_TYPE] = None,
                 **kwargs):
        """
        Args:
            checkpoint: Path to a checkpoint holding a ``'model'`` state dict
                and optionally ``'config'`` and ``'symbols'`` entries.
            arabic_in: If True, input is tokenized with
                ``text.arabic_to_tokens``; otherwise with
                ``text.buckwalter_to_tokens``.
            vowelizer: Default vowelizer name, or None for no vowelization.
            **kwargs: Accepted for call compatibility; not used here.
        """
        from models.fastpitch import net_config

        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the torch version/`weights_only` setting - only load trusted files.
        state_dicts = torch.load(checkpoint, map_location='cpu')
        if 'config' in state_dicts:
            # A config stored with the weights takes precedence over the default.
            net_config = state_dicts['config']

        super().__init__(**net_config)

        self.arabic_in = arabic_in
        self.load_state_dict(state_dicts['model'])

        self.config = get_basic_config()

        # Cache of loaded vowelizers keyed by name.
        self.vowelizers = {}
        if vowelizer is not None:
            self.vowelizers[vowelizer] = load_vowelizer(vowelizer, self.config)
        self.default_vowelizer = vowelizer

        # Optional symbol -> id table; stays None when the checkpoint has no
        # 'symbols' entry.
        self.phon_to_id = None
        if 'symbols' in state_dicts:
            self.phon_to_id = {phon: i for i, phon
                               in enumerate(state_dicts['symbols'])}

        self.eval()

    @property
    def device(self):
        """Device of the first model parameter.

        Exposed as a property because the inference methods below use it as
        an attribute (``tensor.to(self.device)``).
        """
        return next(self.parameters()).device

    def _vowelize(self, utterance: str,
                  vowelizer: Optional[_VOWELIZER_TYPE] = None):
        """Run the selected vowelizer on ``utterance``.

        Uses ``vowelizer`` if given, else the default chosen at construction.
        When neither is set the utterance is returned unchanged. The vowelizer
        is loaded and cached on first use.
        """
        vowelizer = self.default_vowelizer if vowelizer is None else vowelizer
        if vowelizer is None:
            return utterance

        if vowelizer not in self.vowelizers:
            self.vowelizers[vowelizer] = load_vowelizer(vowelizer, self.config)

        utterance_ar = text.buckwalter_to_arabic(utterance)
        return self.vowelizers[vowelizer].predict(utterance_ar)

    def _tokenize(self, utterance: str,
                  vowelizer: Optional[_VOWELIZER_TYPE] = None):
        """Vowelize (optionally) and convert ``utterance`` to tokens."""
        utterance = self._vowelize(utterance=utterance, vowelizer=vowelizer)
        if self.arabic_in:
            return text.arabic_to_tokens(utterance, append_space=False)
        return text.buckwalter_to_tokens(utterance, append_space=False)

    def ttmel_single(self,
                     utterance: str,
                     speed: float = 1,
                     speaker_id: int = 0,
                     vowelizer: Optional[_VOWELIZER_TYPE] = None
                     ):
        """Synthesize the mel spectrogram for one utterance.

        Args:
            utterance: Input text.
            speed: Passed to ``infer`` as ``pace``.
            speaker_id: Passed to ``infer`` as ``speaker``.
            vowelizer: Vowelizer override for this call.

        Returns:
            The first item of the first output of ``infer``.
        """
        tokens = self._tokenize(utterance, vowelizer=vowelizer)
        token_ids = text.tokens_to_ids(tokens, self.phon_to_id)
        ids_batch = torch.LongTensor(token_ids).unsqueeze(0).to(self.device)

        # Infer spectrogram
        mel_spec, *_ = self.infer(ids_batch,
                                  pace=speed,
                                  speaker=speaker_id)

        return mel_spec[0]  # [F, T]

    def ttmel_batch(self,
                    batch: List[str],
                    speed: float = 1,
                    speaker_id: int = 0,
                    vowelizer: Optional[_VOWELIZER_TYPE] = None
                    ):
        """Synthesize mel spectrograms for a list of utterances in one pass.

        Args:
            batch: Input texts.
            speed: Passed to ``infer`` as ``pace``.
            speaker_id: Passed to ``infer`` as ``speaker``.
            vowelizer: Vowelizer override for this call.

        Returns:
            List of mel spectrograms in the same order as ``batch``, each
            trimmed to its own length. Empty list for an empty ``batch``.
        """
        if not batch:
            return []

        batch_ids = [
            torch.LongTensor(
                text.tokens_to_ids(self._tokenize(line, vowelizer=vowelizer),
                                   self.phon_to_id))
            for line in batch
        ]

        # Inputs are padded and sorted by length; `reverse_sort_ids` maps the
        # sorted rows back to the caller's order.
        batch_ids_padded, _, reverse_sort_ids = text_collate_fn(batch_ids)
        batch_ids_padded = batch_ids_padded.to(self.device)

        y_pred = self.infer(batch_ids_padded, pace=speed, speaker=speaker_id)
        mel_outputs, mel_specgram_lengths, *_ = y_pred

        return [
            mel_outputs[row, :, :mel_specgram_lengths[row]]
            for row in reverse_sort_ids.tolist()
        ]

    def ttmel(self,
              text_input: Union[str, List[str]],
              speed: float = 1,
              speaker_id: int = 0,
              batch_size: int = 1,
              vowelizer: Optional[_VOWELIZER_TYPE] = None
              ):
        """Synthesize mel spectrogram(s) for a string or a list of strings.

        Args:
            text_input: One utterance, or a list of utterances.
            speed: Passed to ``infer`` as ``pace``.
            speaker_id: Passed to ``infer`` as ``speaker``.
            batch_size: Number of utterances per inference pass (list input).
            vowelizer: Vowelizer override for this call.

        Returns:
            A single mel spectrogram for string input, otherwise a list of
            mel spectrograms in input order.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 for list input.
        """
        # input: string
        if isinstance(text_input, str):
            return self.ttmel_single(text_input, speed=speed,
                                     speaker_id=speaker_id,
                                     vowelizer=vowelizer)

        # input: list
        assert isinstance(text_input, list)
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        if batch_size == 1:
            return [
                self.ttmel_single(sample, speed=speed,
                                  speaker_id=speaker_id,
                                  vowelizer=vowelizer)
                for sample in text_input
            ]

        # batched inference (a list no longer than batch_size is one chunk)
        mel_list = []
        for start in range(0, len(text_input), batch_size):
            chunk = text_input[start:start + batch_size]
            mel_list += self.ttmel_batch(chunk, speed=speed,
                                         speaker_id=speaker_id,
                                         vowelizer=vowelizer)
        return mel_list
class FastPitch2Wave(nn.Module):
    """Text-to-waveform pipeline: FastPitch mel model + HiFi-GAN vocoder.

    Holds a ``FastPitch`` model, a vocoder loaded with ``load_hifigan`` and a
    ``Denoiser`` built on that vocoder.
    """

    def __init__(self,
                 model_sd_path: str,
                 vocoder_sd: Optional[str] = None,
                 vocoder_config: Optional[str] = None,
                 vowelizer: Optional[_VOWELIZER_TYPE] = None,
                 arabic_in: bool = True,
                 ):
        """
        Args:
            model_sd_path: Path to the FastPitch checkpoint.
            vocoder_sd: Path to the vocoder weights.
            vocoder_config: Path to the vocoder config.
            vowelizer: Default vowelizer name, or None.
            arabic_in: Forwarded to ``FastPitch``.

        Note:
            If either ``vocoder_sd`` or ``vocoder_config`` is None, BOTH are
            taken from ``get_basic_config()``.
        """
        super().__init__()

        # FastPitch.__init__ reads the checkpoint and loads its 'model'
        # weights itself, so the file is not read a second time here.
        self.model = FastPitch(model_sd_path,
                               arabic_in=arabic_in,
                               vowelizer=vowelizer)

        if vocoder_sd is None or vocoder_config is None:
            config = get_basic_config()
            vocoder_sd = config.vocoder_state_path
            vocoder_config = config.vocoder_config_path

        vocoder = load_hifigan(vocoder_sd, vocoder_config)
        self.vocoder = vocoder
        self.denoiser = Denoiser(vocoder)

        self.eval()

    def device(self):
        """Return the device of the first parameter of this module."""
        return next(self.parameters()).device

    def forward(self, x):
        """Identity; synthesis is done through the ``tts*`` methods."""
        return x

    def tts_single(self,
                   text_buckw: str,
                   speed: float = 1,
                   speaker_id: int = 0,
                   denoise: float = 0,
                   vowelizer: Optional[_VOWELIZER_TYPE] = None,
                   return_mel=False):
        """Synthesize one utterance.

        Args:
            text_buckw: Input text.
            speed: Speaking speed.
            speaker_id: Speaker id.
            denoise: Denoiser strength; the denoiser runs only if > 0.
            vowelizer: Vowelizer override for this call.
            return_mel: Also return the mel spectrogram.

        Returns:
            The waveform moved to CPU, or ``(wave, mel_spec)`` when
            ``return_mel`` is True.
        """
        mel_spec = self.model.ttmel_single(text_buckw, speed,
                                           speaker_id, vowelizer)
        wave = self.vocoder(mel_spec)
        if denoise > 0:
            wave = self.denoiser(wave, denoise)

        wave_cpu = wave[0].cpu()
        if return_mel:
            return wave_cpu, mel_spec
        return wave_cpu

    def tts_batch(self,
                  batch: List[str],
                  speed: float = 1,
                  speaker_id: int = 0,
                  denoise: float = 0,
                  vowelizer: Optional[_VOWELIZER_TYPE] = None,
                  return_mel: bool = False
                  ):
        """Synthesize a list of utterances with one mel-inference pass.

        The vocoder (and denoiser, if ``denoise > 0``) is applied to each mel
        spectrogram individually.

        Returns:
            List of CPU waveforms in input order, or
            ``(wav_list, mel_list)`` when ``return_mel`` is True.
        """
        if not batch:
            return ([], []) if return_mel else []

        mel_list = self.model.ttmel_batch(batch, speed,
                                          speaker_id, vowelizer)

        wav_list = []
        for mel in mel_list:
            wav_inferred = self.vocoder(mel)
            if denoise > 0:
                wav_inferred = self.denoiser(wav_inferred, denoise)
            wav_list.append(wav_inferred[0].cpu())

        if return_mel:
            # Previously this tuple was built but never returned.
            return wav_list, mel_list
        return wav_list

    def tts(self,
            text_input: Union[str, List[str]],
            speed: float = 1.,
            denoise: float = 0,
            speaker_id: int = 0,
            batch_size: int = 2,
            vowelizer: Optional[_VOWELIZER_TYPE] = None,
            return_mel: bool = False):
        """Synthesize speech for a string or a list of strings.

        Args:
            text_input (str|List[str]): Input text.
            speed (float): Speaking speed.
            denoise (float): HiFi-GAN denoiser strength.
            speaker_id (int): Speaker id.
            batch_size (int): Batch size for inference.
            vowelizer (None|str): options [None, `'shakkala'`, `'shakkelha'`].
            return_mel (bool): Whether to return the mel spectrogram(s).

        Returns:
            * str input: waveform, or ``(waveform, mel)`` if ``return_mel``.
            * list input, ``batch_size == 1``: list of per-sample results
              (waveforms, or ``(waveform, mel)`` tuples if ``return_mel``).
            * list input, ``batch_size > 1``: list of waveforms, or
              ``(wav_list, mel_list)`` if ``return_mel``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 for list input.
        """
        # input: string
        if isinstance(text_input, str):
            return self.tts_single(text_input, speaker_id=speaker_id,
                                   speed=speed, denoise=denoise,
                                   vowelizer=vowelizer,
                                   return_mel=return_mel)

        # input: list
        assert isinstance(text_input, list)
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        if batch_size == 1:
            # Kept as a list of per-sample results for backward compatibility.
            return [
                self.tts_single(sample, speaker_id=speaker_id,
                                speed=speed, denoise=denoise,
                                vowelizer=vowelizer,
                                return_mel=return_mel)
                for sample in text_input
            ]

        # batched inference (a list no longer than batch_size is one chunk)
        wav_list = []
        mel_list = []
        for start in range(0, len(text_input), batch_size):
            chunk = text_input[start:start + batch_size]
            result = self.tts_batch(chunk, speaker_id=speaker_id,
                                    speed=speed, denoise=denoise,
                                    vowelizer=vowelizer,
                                    return_mel=return_mel)
            if return_mel:
                wavs, mels = result
                mel_list += mels
            else:
                wavs = result
            wav_list += wavs

        if return_mel:
            return wav_list, mel_list
        return wav_list