import os
import re
import sys
import warnings

import numpy as np
import pandas as pd
import tensorflow as tf
from scipy.io import wavfile
from tensorflow.keras.preprocessing.sequence import pad_sequences

sys.path.append(os.path.dirname(__file__))
from g2p.g2p_en.g2p import G2p

# Lists of variable-length per-utterance arrays can trip NumPy's ragged-array
# deprecation warning; silence it once here.
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)


def dataloader(fs=16000, keyword='', wav_path_or_object=None, g2p=None,
               features='both'):
    """Build a single-sample batch for keyword spotting.

    Args:
        fs: sampling rate of the audio in Hz.
        keyword: target keyword text to compare against.
        wav_path_or_object: path to a wav file, or a raw waveform array.
        g2p: a G2p instance (callable for phonemes, .embedding() for vectors).
        features: 'phoneme', 'g2p_embed', or 'both'.
    """
    # CMU-style phoneme inventory; index 0 is reserved for padding.
    phonemes = [""] + ['AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2',
                       'AO0', 'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2',
                       'B', 'CH', 'D', 'DH', 'EH0', 'EH1', 'EH2', 'ER0', 'ER1', 'ER2',
                       'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH0', 'IH1', 'IH2',
                       'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
                       'OW0', 'OW1', 'OW2', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
                       'T', 'TH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2',
                       'V', 'W', 'Y', 'Z', 'ZH', ' ']
    p2idx = {p: idx for idx, p in enumerate(phonemes)}
    idx2p = {idx: p for idx, p in enumerate(phonemes)}

    wav = wav_path_or_object
    if isinstance(wav, str):
        duration = float(wavfile.read(wav)[1].shape[-1]) / fs
        # The parent directory name encodes the spoken word.
        anchor_text = wav.split('/')[-2].lower()
    else:
        duration = float(wav.shape[-1]) / fs
        anchor_text = keyword  # raw arrays carry no path; fall back to the keyword
    label = 1

    target_dict = {0: {
        'wav': wav,
        'wav_label': anchor_text,
        'text': keyword,
        'duration': duration,
        'label': label,
    }}
    data = pd.DataFrame.from_dict(target_dict, 'index')

    # g2p & p2idx by the g2p_en package: word -> phonemes -> indices/embeddings.
    data['phoneme'] = data['text'].apply(lambda x: g2p(re.sub(r"[^a-zA-Z0-9]+", ' ', x)))
    data['pIndex'] = data['phoneme'].apply(lambda x: [p2idx[t] for t in x])
    data['g2p_embed'] = data['text'].apply(lambda x: g2p.embedding(x))
    data['wav_phoneme'] = data['wav_label'].apply(lambda x: g2p(re.sub(r"[^a-zA-Z0-9]+", ' ', x)))
    data['wav_pIndex'] = data['wav_phoneme'].apply(lambda x: [p2idx[t] for t in x])

    # Sort by duration so the longest clip sits last (used for maxlen_a below).
    data = data.sort_values(by='duration').reset_index(drop=True)
    wav_list = data['wav'].values
    idx_list = data['pIndex'].values
    emb_list = data['g2p_embed'].values
    lab_list = data['label'].values
    sIdx_list = data['wav_pIndex'].values

    # Pad targets: text lengths round up to the next multiple of 10 characters,
    # audio length to the next multiple of 0.5 s (in samples).
    maxlen_t = int((int(data['text'].apply(len).max() / 10) + 1) * 10)
    maxlen_a = int((int(data['duration'].values[-1] / 0.5) + 1) * fs / 2)
    maxlen_l = int((int(data['wav_label'].apply(len).max() / 10) + 1) * 10)

    indices = [0]  # single-sample batch

    # Load inputs.
    if isinstance(wav_path_or_object, str):
        batch_x = [np.array(wavfile.read(wav_list[i])[1]).astype(np.float32) / 32768.0
                   for i in indices]
    else:
        batch_x = [wav_list[i] / 32768.0 for i in indices]
    if features == 'both':
        batch_p = [np.array(idx_list[i]).astype(np.int32) for i in indices]
        batch_e = [np.array(emb_list[i]).astype(np.float32) for i in indices]
    elif features == 'phoneme':
        batch_y = [np.array(idx_list[i]).astype(np.int32) for i in indices]
    elif features == 'g2p_embed':
        batch_y = [np.array(emb_list[i]).astype(np.float32) for i in indices]

    # Load outputs.
    batch_z = [np.array([lab_list[i]]).astype(np.float32) for i in indices]
    batch_l = [np.array(sIdx_list[i]).astype(np.int32) for i in indices]

    # Padding and masking. Note that the *unpadded* batch_l is returned below;
    # pad_batch_l is kept for parity with the other padded features.
    pad_batch_x = pad_sequences(batch_x, maxlen=maxlen_a, value=0.0, padding='post',
                                dtype=batch_x[0].dtype)
    if features == 'both':
        pad_batch_p = pad_sequences(batch_p, maxlen=maxlen_t, value=0.0, padding='post',
                                    dtype=batch_p[0].dtype)
        pad_batch_e = pad_sequences(batch_e, maxlen=maxlen_t, value=0.0, padding='post',
                                    dtype=batch_e[0].dtype)
    else:
        pad_batch_y = pad_sequences(batch_y, maxlen=maxlen_t, value=0.0, padding='post',
                                    dtype=batch_y[0].dtype)
    pad_batch_z = pad_sequences(batch_z, value=0.0, padding='post', dtype=batch_z[0].dtype)
    pad_batch_l = pad_sequences(batch_l, maxlen=maxlen_l, value=0.0, padding='post',
                                dtype=batch_l[0].dtype)

    if features == 'both':
        return pad_batch_x, pad_batch_p, pad_batch_e, pad_batch_z, batch_l
    else:
        return pad_batch_x, pad_batch_y, pad_batch_z, batch_l


def convert_sequence_to_dataset(dataloader, wav, text, features):
    """Wrap the batch tuple returned by dataloader() in a tf.data.Dataset."""
    fs = 16000
    duration = float(wavfile.read(wav)[1].shape[-1]) / fs
    maxlen_t = int((int(len(text) / 10) + 1) * 10)
    maxlen_a = int((int(duration / 0.5) + 1) * fs / 2)

    def data_generator():
        # `dataloader` already holds the materialised batch tuple; yield it once.
        yield dataloader

    if features == 'both':
        data_dataset = tf.data.Dataset.from_generator(
            data_generator,
            output_signature=(
                tf.TensorSpec(shape=(None, maxlen_a), dtype=tf.float32),       # audio
                tf.TensorSpec(shape=(None, maxlen_t), dtype=tf.int32),         # phoneme indices
                tf.TensorSpec(shape=(None, maxlen_t, 256), dtype=tf.float32),  # g2p embeddings
                tf.TensorSpec(shape=(None, 1), dtype=tf.float32),              # labels
                tf.TensorSpec(shape=(None, None), dtype=tf.int32),             # wav-label phonemes
            ))
    else:
        data_dataset = tf.data.Dataset.from_generator(
            data_generator,
            output_signature=(
                tf.TensorSpec(shape=(None, maxlen_a), dtype=tf.float32),
                tf.TensorSpec(shape=(None, maxlen_t) if features == 'phoneme'
                              else (None, maxlen_t, 256),
                              dtype=tf.int32 if features == 'phoneme' else tf.float32),
                tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
                tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            ))
    # data_dataset = data_dataset.cache()
    data_dataset = data_dataset.prefetch(1)
    return data_dataset
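

# --- Usage sketch (illustrative, not part of the original pipeline) ---
# A minimal example of wiring dataloader() and convert_sequence_to_dataset()
# together. The wav path and keyword below are hypothetical; the sketch assumes
# a directory layout of .../<spoken_word>/<file>.wav and that G2p() from the
# bundled g2p package is callable for phonemes and exposes .embedding(), as
# used above.
if __name__ == '__main__':
    g2p = G2p()
    wav_path = 'data/speech_commands/visual/sample_0001.wav'  # hypothetical path
    batch = dataloader(fs=16000, keyword='visual',
                       wav_path_or_object=wav_path, g2p=g2p, features='both')
    dataset = convert_sequence_to_dataset(batch, wav_path, 'visual', features='both')
    # One element: (audio, phoneme indices, g2p embeddings, label, wav-label phonemes).
    for x, p, e, z, l in dataset:
        print(x.shape, p.shape, e.shape, z.shape, l.shape)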