#coding: utf-8
"""Dataset, collate function and DataLoader factory for text/audio pairs.

Each entry of the data list is a ``|``-separated line
``wave_path|text|speaker_id``; when the line does not have exactly three
fields a speaker id of ``0`` is appended. Besides the target utterance, every
item carries a random reference clip of the same speaker and a random
out-of-distribution (OOD) text read from a separate file.
"""
import os
import os.path as osp
import time
import random
import numpy as np
import soundfile as sf
import librosa

import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import DataLoader

import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

import pandas as pd

_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

# Export all symbols:
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)

# Symbol -> index into ``symbols``. ``_letters_ipa`` contains "'" twice; as
# with a sequential assignment loop, the later index wins.
dicts = {symbol: index for index, symbol in enumerate(symbols)}


class TextCleaner:
    """Convert a string into a list of indices from the ``dicts`` table."""

    def __init__(self, dummy=None):
        # ``dummy`` is unused.
        self.word_index_dictionary = dicts

    def __call__(self, text):
        """Return the index of every character of *text* found in the table.

        Characters missing from the table are skipped; the whole *text* is
        printed once per missing character so the offending input is visible.
        """
        indexes = []
        for char in text:
            try:
                indexes.append(self.word_index_dictionary[char])
            except KeyError:
                print(text)
        return indexes


np.random.seed(1)
random.seed(1)

SPECT_PARAMS = {
    "n_fft": 2048,
    "win_length": 1200,
    "hop_length": 300
}
MEL_PARAMS = {
    "n_mels": 80,
}

# NOTE(review): sample_rate is not passed, so torchaudio's default is used for
# the mel filterbank. Changing it would change the features every downstream
# consumer was built against — confirm before touching.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
mean, std = -4, 4


def preprocess(wave):
    """Return the normalized log-mel spectrogram of a NumPy waveform.

    The mel output gets a leading singleton dimension (``unsqueeze(0)``) and
    is normalized as ``(log(1e-5 + mel) - mean) / std`` with the module-level
    ``mean`` and ``std``.
    """
    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave_tensor)
    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
    return mel_tensor


class FilePathDataset(torch.utils.data.Dataset):
    """Map-style dataset over ``wave_path|text[|speaker_id]`` lines.

    Args:
        data_list: iterable of ``|``-separated lines.
        root_path: directory that the wave paths are joined onto.
        sr: sample rate that loaded audio is resampled to (default 24000).
        data_augmentation: stored as ``self.data_augmentation``; forced to
            False when ``validation`` is true.
        validation: marks the dataset as a validation split.
        OOD_data: file of OOD texts, one per line. If the first ``|`` field
            of the first line contains ``.wav``, the second field of every
            line is used; otherwise the first.
        min_length: minimum character length of the sampled OOD text.

    Raises:
        ValueError: if no OOD text is at least ``min_length`` characters long
            (sampling would otherwise never terminate).
    """

    def __init__(self,
                 data_list,
                 root_path,
                 sr=24000,
                 data_augmentation=False,
                 validation=False,
                 OOD_data="Data/OOD_texts.txt",
                 min_length=50,
                 ):
        _data_list = [l.strip().split('|') for l in data_list]
        # Rows without exactly three fields get a speaker id of 0 appended.
        self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
        self.text_cleaner = TextCleaner()
        self.sr = sr

        self.df = pd.DataFrame(self.data_list)

        self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)

        self.mean, self.std = -4, 4
        self.data_augmentation = data_augmentation and (not validation)
        self.max_mel_length = 192

        self.min_length = min_length
        with open(OOD_data, 'r', encoding='utf-8') as f:
            tl = f.readlines()
        idx = 1 if '.wav' in tl[0].split('|')[0] else 0
        self.ptexts = [t.split('|')[idx] for t in tl]
        # _sample_ood_text rejection-samples until it finds a long-enough
        # text; without any candidate it would loop forever, so fail here.
        if not any(len(t) >= min_length for t in self.ptexts):
            raise ValueError(
                f"no OOD text in {OOD_data!r} has at least {min_length} characters")

        self.root_path = root_path

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return one training item.

        Returns ``(speaker_id, acoustic_feature, text_tensor, ref_text,
        ref_mel_tensor, ref_label, path, wave)``:

        - ``acoustic_feature`` is the normalized log-mel of the utterance,
          trimmed to an even number of frames.
        - ``ref_mel_tensor`` and ``ref_label`` come from a random clip of the
          same speaker.
        - ``ref_text`` is an encoded random OOD text.
        """
        data = self.data_list[idx]
        path = data[0]

        wave, text_tensor, speaker_id = self._load_tensor(data)

        mel_tensor = preprocess(wave).squeeze()

        acoustic_feature = mel_tensor.squeeze()
        length_feature = acoustic_feature.size(1)
        # Drop one trailing frame if needed so the frame count is even.
        acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]

        # Order matters for RNG reproducibility: reference first, then OOD text.
        ref_mel_tensor, ref_label = self._sample_reference(speaker_id)
        ref_text = self._sample_ood_text()

        return speaker_id, acoustic_feature, text_tensor, ref_text, ref_mel_tensor, ref_label, path, wave

    def _sample_reference(self, speaker_id):
        """Load a random clip of *speaker_id*; return ``(ref_mel, speaker_id)``."""
        # Compare as strings: parsed rows hold the id as text, while rows
        # padded in __init__ hold the int 0, which never equals the string "0".
        same_speaker = self.df[self.df[2].astype(str) == str(speaker_id)]
        ref_data = same_speaker.sample(n=1).iloc[0].tolist()
        return self._load_data(ref_data[:3])

    def _sample_ood_text(self):
        """Return a random OOD text of at least ``min_length`` characters.

        The text is encoded with ``text_cleaner``, wrapped with index 0 at
        both ends, and returned as a LongTensor.
        """
        ps = ""
        while len(ps) < self.min_length:
            # randint's upper bound is exclusive, so every line is eligible.
            rand_idx = np.random.randint(0, len(self.ptexts))
            ps = self.ptexts[rand_idx]

        text = self.text_cleaner(ps)
        text.insert(0, 0)
        text.append(0)
        return torch.LongTensor(text)

    def _load_tensor(self, data):
        """Load and encode one ``(wave_path, text, speaker_id)`` entry.

        Returns ``(wave, text, speaker_id)``:

        - ``wave`` is the single-channel waveform at ``self.sr``, padded
          with 5000 zeros on each side.
        - ``text`` is the encoded text wrapped with index 0 at both ends,
          as a LongTensor.
        - ``speaker_id`` is the speaker id as an int.
        """
        wave_path, text, speaker_id = data
        speaker_id = int(speaker_id)
        wave, sr = sf.read(osp.join(self.root_path, wave_path))
        if wave.ndim > 1:
            # Multi-channel audio comes back as (frames, channels); keep the
            # first channel.
            wave = wave[:, 0]
        if sr != self.sr:
            wave = librosa.resample(wave, orig_sr=sr, target_sr=self.sr)
            print(wave_path, sr)

        wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)

        text = self.text_cleaner(text)

        text.insert(0, 0)
        text.append(0)

        text = torch.LongTensor(text)

        return wave, text, speaker_id

    def _load_data(self, data):
        """Return ``(mel, speaker_id)`` for *data*.

        Mels longer than ``self.max_mel_length`` frames are cropped to that
        length at a random offset.
        """
        wave, text_tensor, speaker_id = self._load_tensor(data)
        mel_tensor = preprocess(wave).squeeze()

        mel_length = mel_tensor.size(1)
        if mel_length > self.max_mel_length:
            # The last valid start is mel_length - max_mel_length, and
            # randint's upper bound is exclusive, hence the +1.
            random_start = np.random.randint(0, mel_length - self.max_mel_length + 1)
            mel_tensor = mel_tensor[:, random_start:random_start + self.max_mel_length]

        return mel_tensor, speaker_id


class Collater(object):
    """Pad and batch items produced by ``FilePathDataset.__getitem__``.

    Args:
        return_wave (bool): stored as ``self.return_wave``. The returned
            batch includes the waveforms regardless of this flag.
    """

    def __init__(self, return_wave=False):
        self.text_pad_index = 0
        self.min_mel_length = 192
        self.max_mel_length = 192
        self.return_wave = return_wave

    def __call__(self, batch):
        """Collate dataset items into zero-padded tensors.

        Each item is ``(speaker_id, mel, text, ref_text, ref_mel, ref_label,
        path, wave)``. Items are sorted by mel frame count, longest first,
        before padding.

        Returns:
            ``(waves, texts, input_lengths, ref_texts, ref_lengths, mels,
            output_lengths, ref_mels)``:

            - ``waves`` is a list of the raw waveforms.
            - ``texts`` and ``ref_texts`` are padded to the longest sequence
              in the batch.
            - ``mels`` is padded to the longest mel in the batch.
            - ``ref_mels`` is padded to ``self.max_mel_length`` frames.
        """
        batch_size = len(batch)

        # Sort by mel length, longest first.
        lengths = [b[1].shape[1] for b in batch]
        batch_indexes = np.argsort(lengths)[::-1]
        batch = [batch[bid] for bid in batch_indexes]

        nmels = batch[0][1].size(0)
        max_mel_length = max(b[1].shape[1] for b in batch)
        max_text_length = max(b[2].shape[0] for b in batch)
        max_rtext_length = max(b[3].shape[0] for b in batch)

        mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
        texts = torch.zeros((batch_size, max_text_length)).long()
        ref_texts = torch.zeros((batch_size, max_rtext_length)).long()

        input_lengths = torch.zeros(batch_size).long()
        ref_lengths = torch.zeros(batch_size).long()
        output_lengths = torch.zeros(batch_size).long()
        ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float()
        waves = [None for _ in range(batch_size)]

        for bid, (_label, mel, text, ref_text, ref_mel, _ref_label, _path, wave) in enumerate(batch):
            mel_size = mel.size(1)
            text_size = text.size(0)
            rtext_size = ref_text.size(0)

            mels[bid, :, :mel_size] = mel
            texts[bid, :text_size] = text
            ref_texts[bid, :rtext_size] = ref_text
            input_lengths[bid] = text_size
            ref_lengths[bid] = rtext_size
            output_lengths[bid] = mel_size

            ref_mel_size = ref_mel.size(1)
            ref_mels[bid, :, :ref_mel_size] = ref_mel

            waves[bid] = wave

        return waves, texts, input_lengths, ref_texts, ref_lengths, mels, output_lengths, ref_mels


def build_dataloader(path_list,
                     root_path,
                     validation=False,
                     OOD_data="Data/OOD_texts.txt",
                     min_length=50,
                     batch_size=4, num_workers=1, device='cpu',
                     collate_config=None, dataset_config=None):
    """Build a DataLoader over ``FilePathDataset`` batched by ``Collater``.

    Behaviour:

    - Shuffling is enabled unless ``validation`` is true.
    - An incomplete final batch is dropped.
    - Memory is pinned when ``device`` is not ``'cpu'``.

    ``collate_config`` and ``dataset_config`` are optional dicts of extra
    keyword arguments for ``Collater`` and ``FilePathDataset`` respectively.
    """
    # None instead of {} defaults: a shared mutable default is a latent hazard.
    collate_config = {} if collate_config is None else collate_config
    dataset_config = {} if dataset_config is None else dataset_config

    dataset = FilePathDataset(path_list, root_path, OOD_data=OOD_data, min_length=min_length, validation=validation, **dataset_config)
    collate_fn = Collater(**collate_config)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=(not validation),
                             num_workers=num_workers,
                             drop_last=True,
                             collate_fn=collate_fn,
                             pin_memory=(device != 'cpu'))

    return data_loader