import os
import re

import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset

import text
from utils import read_lines_from_file, progbar
from utils.audio import MelSpectrogram


def text_mel_collate_fn(batch, pad_value=0):
    """Pad a batch of (text_ids, mel_spec) samples to common lengths.

    Args:
        batch: List[(text_ids, mel_spec)]
    Returns:
        text_ids_pad, input_lengths, mel_pad, gate_pad, output_lengths
    """
    input_lens_sorted, input_sort_ids = torch.sort(
        torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True)
    max_input_len = input_lens_sorted[0]

    num_mels = batch[0][1].size(0)
    max_target_len = max([x[1].size(1) for x in batch])

    text_ids_pad = torch.LongTensor(len(batch), max_input_len)
    mel_pad = torch.FloatTensor(len(batch), num_mels, max_target_len)
    gate_pad = torch.FloatTensor(len(batch), max_target_len)
    output_lengths = torch.LongTensor(len(batch))

    text_ids_pad.zero_()
    mel_pad.fill_(pad_value)
    gate_pad.zero_()

    # Sort samples by input length (descending) and pad to the batch maxima.
    for i in range(len(input_sort_ids)):
        text_ids, mel = batch[input_sort_ids[i]]
        text_ids_pad[i, :text_ids.size(0)] = text_ids
        mel_pad[i, :, :mel.size(1)] = mel
        # Gate target is 1 from the last real frame onward (stop token).
        gate_pad[i, mel.size(1) - 1:] = 1
        output_lengths[i] = mel.size(1)

    return text_ids_pad, input_lens_sorted, \
        mel_pad, gate_pad, output_lengths


def normalize_pitch(pitch, mean: float = 130.05478, std: float = 22.86267):
    # Zero frames mark unvoiced regions; keep them at exactly zero
    # after mean/std normalization of the voiced frames.
    zeros = (pitch == 0.0)
    pitch -= mean
    pitch /= std
    pitch[zeros] = 0.0
    return pitch


def remove_silence(energy_per_frame: torch.Tensor, thresh: float = -10.0):
    keep = energy_per_frame > thresh
    # keep silence at the end
    i = keep.size(0) - 1
    while not keep[i] and i > 0:
        keep[i] = True
        i -= 1
    return keep


def make_dataset_from_subdirs(folder_path):
    samples = []
    for root, _, fnames in os.walk(folder_path, followlinks=True):
        for fname in fnames:
            if fname.endswith('.wav'):
                samples.append(os.path.join(root, fname))
    return samples


def _process_line(label_pattern: str, line: str):
    match = re.search(label_pattern, line)
    if match is None:
        raise Exception(f'no match for line: {line}')
    res_dict = match.groupdict()

    if 'arabic' in res_dict:
        phonemes = text.arabic_to_phonemes(res_dict['arabic'])
    elif 'phonemes' in res_dict:
        phonemes = res_dict['phonemes']
    elif 'buckwalter' in res_dict:
        phonemes = text.buckwalter_to_phonemes(res_dict['buckwalter'])
    else:
        raise KeyError(f'no text group in pattern: {label_pattern}')

    if 'filename' in res_dict:
        filename = res_dict['filename']
    elif 'filestem' in res_dict:
        filename = f"{res_dict['filestem']}.wav"
    else:
        raise KeyError(f'no filename group in pattern: {label_pattern}')

    return phonemes, filename


class ArabDataset(Dataset):
    def __init__(self,
                 txtpath: str = 'tts data sample/text.txt',
                 wavpath: str = './',
                 label_pattern: str = '"(?P<filename>.*)" "(?P<phonemes>.*)"',
                 sr_target: int = 22050
                 ):
        super().__init__()

        self.mel_fn = MelSpectrogram()
        self.wav_path = wavpath
        self.label_pattern = label_pattern
        self.sr_target = sr_target

        self.data = self._process_textfile(txtpath)

    def _process_textfile(self, txtpath: str):
        lines = read_lines_from_file(txtpath)
        phoneme_mel_list = []

        for l_idx, line in enumerate(progbar(lines)):
            try:
                phonemes, filename = _process_line(
                    self.label_pattern, line)
            except Exception:
                print(f'invalid line {l_idx}: {line}')
                continue

            fpath = os.path.join(self.wav_path, filename)
            if not os.path.exists(fpath):
                print(f"{fpath} does not exist")
                continue

            try:
                tokens = text.phonemes_to_tokens(phonemes)
                token_ids = text.tokens_to_ids(tokens)
            except Exception:
                print(f'invalid phonemes at line {l_idx}: {line}')
                continue

            phoneme_mel_list.append((torch.LongTensor(token_ids), fpath))

        return phoneme_mel_list

    def _get_mel_from_fpath(self, fpath):
        wave, sr = torchaudio.load(fpath)
        if sr != self.sr_target:
            wave = torchaudio.functional.resample(
                wave, sr, self.sr_target, 64)

        mel_raw = self.mel_fn(wave)
        mel_log = mel_raw.clamp_min(1e-5).log().squeeze()

        # Trim low-energy (silent) frames before returning the log-mel.
        energy_per_frame = mel_log.mean(0)
        mel_log = mel_log[:, remove_silence(energy_per_frame)]

        return mel_log

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        phonemes, fpath = self.data[idx]
        mel_log = self._get_mel_from_fpath(fpath)
        return phonemes, mel_log


class ArabDataset4FastPitch(Dataset):
    def __init__(self,
                 txtpath: str = './data/train_phon.txt',
                 wavpath: str = 'G:/data/arabic-speech-corpus/wav_new',
                 label_pattern: str = '"(?P<filename>.*)" "(?P<phonemes>.*)"',
                 f0_dict_path: str = './data/pitch_dict.pt',
                 f0_mean: float = 130.05478,
                 f0_std: float = 22.86267,
                 sr_target: int = 22050
                 ):
        super().__init__()
        from models.fastpitch.fastpitch.data_function import BetaBinomialInterpolator

        self.mel_fn = MelSpectrogram()
        self.wav_path = wavpath
        self.label_pattern = label_pattern
        self.sr_target = sr_target

        self.f0_dict = torch.load(f0_dict_path)
        self.f0_mean = f0_mean
        self.f0_std = f0_std
        self.betabinomial_interpolator = BetaBinomialInterpolator()

        self.data = self._process_textfile(txtpath)

    def _process_textfile(self, txtpath: str):
        lines = read_lines_from_file(txtpath)
        phoneme_mel_pitch_list = []

        for l_idx, line in enumerate(progbar(lines)):
            try:
                phonemes, filename = _process_line(
                    self.label_pattern, line)
            except Exception:
                print(f'invalid line {l_idx}: {line}')
                continue

            fpath = os.path.join(self.wav_path, filename)
            if not os.path.exists(fpath):
                print(f"{fpath} does not exist")
                continue

            try:
                tokens = text.phonemes_to_tokens(phonemes)
                token_ids = text.tokens_to_ids(tokens)
            except Exception:
                print(f'invalid phonemes at line {l_idx}: {line}')
                continue

            wav_name = os.path.basename(fpath)
            pitch_mel = self.f0_dict[wav_name][None]

            phoneme_mel_pitch_list.append(
                (torch.LongTensor(token_ids), fpath, pitch_mel))

        return phoneme_mel_pitch_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        phonemes, fpath, pitch_mel = self.data[idx]

        wave, sr = torchaudio.load(fpath)
        if sr != self.sr_target:
            wave = torchaudio.functional.resample(
                wave, sr, self.sr_target, 64)

        mel_raw = self.mel_fn(wave)
        mel_log = mel_raw.clamp_min(1e-5).log().squeeze()

        # Trim silent frames from both the mel and the frame-level pitch
        # so the two stay aligned.
        keep = remove_silence(mel_log.mean(0))
        mel_log = mel_log[:, keep]
        pitch_mel = normalize_pitch(pitch_mel[:, keep],
                                    self.f0_mean, self.f0_std)

        energy = torch.norm(mel_log.float(), dim=0, p=2)
        attn_prior = torch.from_numpy(
            self.betabinomial_interpolator(mel_log.size(1), len(phonemes)))
        speaker = None

        return (phonemes, mel_log, len(phonemes),
                pitch_mel, energy, speaker, attn_prior, fpath)


class DynBatchDataset(ArabDataset4FastPitch):
    def __init__(self,
                 txtpath: str = './data/train_phon.txt',
                 wavpath: str = 'G:/data/arabic-speech-corpus/wav_new',
                 label_pattern: str = '"(?P<filename>.*)" "(?P<phonemes>.*)"',
                 f0_dict_path: str = './data/pitch_dict.pt',
                 f0_mean: float = 130.05478,
                 f0_std: float = 22.86267,
                 max_lengths: list[int] = [1000, 1300, 1850, 30000],
                 batch_sizes: list[int] = [10, 8, 6, 4],
                 ):
        super().__init__(txtpath=txtpath, wavpath=wavpath,
                         label_pattern=label_pattern,
                         f0_dict_path=f0_dict_path,
                         f0_mean=f0_mean, f0_std=f0_std)

        self.max_lens = [0,] + max_lengths
        self.b_sizes = batch_sizes

        self.id_batches = []
        self.shuffle()

    def shuffle(self):
        # Bucket samples by mel length so that longer utterances go into
        # smaller batches (roughly constant memory per batch).
        lens = [x[2].size(1) for x in self.data]  # x[2]: pitch

        ids_per_bs = {b: [] for b in self.b_sizes}
        for sample_id, mel_len in enumerate(lens):
            b_idx = next(i for i in range(len(self.max_lens) - 1)
                         if self.max_lens[i] <= mel_len < self.max_lens[i + 1])
            ids_per_bs[self.b_sizes[b_idx]].append(sample_id)

        id_batches = []
        for bs, ids in ids_per_bs.items():
            np.random.shuffle(ids)
            ids_chnk = [ids[i:i + bs] for i in range(0, len(ids), bs)]
            id_batches += ids_chnk

        self.id_batches = id_batches

    def __len__(self):
        return len(self.id_batches)

    def __getitem__(self, idx):
        # Each item is a whole pre-assembled batch of samples.
        batch = [super(DynBatchDataset, self).__getitem__(sample_id)
                 for sample_id in self.id_batches[idx]]
        return batch
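

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library): builds the
# Tacotron-style dataset and a DataLoader using the padding collate function
# above. The paths are assumptions; point them at your own transcript file
# and wav folder. Note that DynBatchDataset already yields whole batches, so
# it would instead be wrapped with batch_size=1 and a FastPitch collate.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = ArabDataset(txtpath='tts data sample/text.txt',  # assumed path
                          wavpath='./')
    loader = DataLoader(dataset, batch_size=4, shuffle=True,
                        collate_fn=text_mel_collate_fn)

    # Each batch: padded token ids, sorted input lengths, padded mels,
    # gate (stop-token) targets, and true mel lengths.
    text_ids, input_lens, mels, gates, output_lens = next(iter(loader))
    print(text_ids.shape, mels.shape, gates.shape)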