wasmdashai's picture
first commit
7694c84
raw
history blame
9.73 kB
import os
import re
import text
import torch
import torchaudio
import numpy as np
from torch.utils.data import Dataset
from utils import read_lines_from_file, progbar
from utils.audio import MelSpectrogram
def text_mel_collate_fn(batch, pad_value=0):
    """Collate (text_ids, mel) pairs into padded, length-sorted batch tensors.

    Args:
        batch: list of (text_ids LongTensor [T_in], mel FloatTensor [n_mels, T_out]).
        pad_value: fill value used for padded mel frames.

    Returns:
        text_ids_pad: LongTensor [B, max_T_in], zero-padded.
        input_lengths: LongTensor [B], text lengths sorted descending.
        mel_pad: FloatTensor [B, n_mels, max_T_out], padded with pad_value.
        gate_pad: FloatTensor [B, max_T_out], 1 from each mel's last frame on.
        output_lengths: LongTensor [B], mel lengths in the sorted order.
    """
    text_lengths = torch.LongTensor([len(item[0]) for item in batch])
    input_lengths, sort_order = torch.sort(text_lengths, dim=0, descending=True)

    batch_size = len(batch)
    n_mels = batch[0][1].size(0)
    max_input_len = int(input_lengths[0])
    max_target_len = max(item[1].size(1) for item in batch)

    text_ids_pad = torch.zeros(batch_size, max_input_len, dtype=torch.long)
    mel_pad = torch.full((batch_size, n_mels, max_target_len), float(pad_value))
    gate_pad = torch.zeros(batch_size, max_target_len)
    output_lengths = torch.LongTensor(batch_size)

    for row, src_idx in enumerate(sort_order):
        text_ids, mel = batch[src_idx]
        mel_len = mel.size(1)
        text_ids_pad[row, :text_ids.size(0)] = text_ids
        mel_pad[row, :, :mel_len] = mel
        # Stop token target: 1 from the final real frame onward.
        gate_pad[row, mel_len - 1:] = 1
        output_lengths[row] = mel_len

    return text_ids_pad, input_lengths, mel_pad, gate_pad, output_lengths
def normalize_pitch(pitch,
                    mean: float = 130.05478,
                    std: float = 22.86267):
    """Z-normalize a pitch contour in place, leaving zero (unvoiced) entries at zero.

    Mutates and returns the same tensor.
    """
    voiced = (pitch != 0.0)
    pitch[voiced] = (pitch[voiced] - mean) / std
    return pitch
def remove_silence(energy_per_frame: torch.Tensor,
                   thresh: float = -10.0):
    """Return a boolean mask of frames louder than `thresh`.

    Trailing low-energy frames are re-marked as kept so the clip retains
    its end-of-utterance silence (frame 0 is never force-kept).
    """
    keep = energy_per_frame > thresh
    # Walk backwards from the last frame, re-enabling trailing silence.
    idx = keep.size(0) - 1
    while idx > 0 and not keep[idx]:
        keep[idx] = True
        idx -= 1
    return keep
def make_dataset_from_subdirs(folder_path):
    """Recursively collect all .wav file paths under `folder_path` (symlinks followed)."""
    wav_paths = []
    for dirpath, _dirs, filenames in os.walk(folder_path, followlinks=True):
        wav_paths.extend(os.path.join(dirpath, name)
                         for name in filenames if name.endswith('.wav'))
    return wav_paths
def _process_line(label_pattern: str, line: str):
match = re.search(label_pattern, line)
if match is None:
raise Exception(f'no match for line: {line}')
res_dict = match.groupdict()
if 'arabic' in res_dict:
phonemes = text.arabic_to_phonemes(res_dict['arabic'])
elif 'phonemes' in res_dict:
phonemes = res_dict['phonemes']
elif 'buckwalter' in res_dict:
phonemes = text.buckwalter_to_phonemes(res_dict['buckwalter'])
if 'filename' in res_dict:
filename = res_dict['filename']
elif 'filestem' in res_dict:
filename = f"{res_dict['filestem']}.wav"
return phonemes, filename
class ArabDataset(Dataset):
    """Dataset yielding (token_ids, log-mel) pairs parsed from a label file.

    Each line of ``txtpath`` is parsed with ``label_pattern`` into
    (phonemes, wav filename). Invalid lines and missing wav files are
    skipped with a printed warning. Audio is loaded lazily in
    ``__getitem__``.
    """

    def __init__(self,
                 txtpath: str = 'tts data sample/text.txt',
                 wavpath: str = './',
                 label_pattern: str = '"(?P<filename>.*)" "(?P<phonemes>.*)"',
                 sr_target: int = 22050
                 ):
        super().__init__()
        self.mel_fn = MelSpectrogram()  # waveform -> mel spectrogram
        self.wav_path = wavpath
        self.label_pattern = label_pattern
        self.sr_target = sr_target
        self.data = self._process_textfile(txtpath)

    def _process_textfile(self, txtpath: str):
        """Parse the label file into a list of (token_ids, wav_path) tuples."""
        lines = read_lines_from_file(txtpath)
        phoneme_mel_list = []
        for l_idx, line in enumerate(progbar(lines)):
            try:
                phonemes, filename = _process_line(
                    self.label_pattern, line)
            except Exception:  # narrowed from bare except: keep Ctrl-C working
                print(f'invalid line {l_idx}: {line}')
                continue

            fpath = os.path.join(self.wav_path, filename)
            if not os.path.exists(fpath):
                print(f"{fpath} does not exist")
                continue

            try:
                tokens = text.phonemes_to_tokens(phonemes)
                token_ids = text.tokens_to_ids(tokens)
            except Exception:  # narrowed from bare except
                print(f'invalid phonemes at line {l_idx}: {line}')
                continue

            phoneme_mel_list.append((torch.LongTensor(token_ids), fpath))
        return phoneme_mel_list

    def _get_mel_from_fpath(self, fpath):
        """Load a wav, resample to sr_target if needed, and return its log-mel
        spectrogram with low-energy frames removed (remove_silence keeps the
        trailing silent frames)."""
        wave, sr = torchaudio.load(fpath)
        if sr != self.sr_target:
            # 64 is lowpass_filter_width, passed positionally to resample.
            wave = torchaudio.functional.resample(wave, sr, self.sr_target, 64)
        mel_raw = self.mel_fn(wave)
        mel_log = mel_raw.clamp_min(1e-5).log().squeeze()
        energy_per_frame = mel_log.mean(0)
        mel_log = mel_log[:, remove_silence(energy_per_frame)]
        return mel_log

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        phonemes, fpath = self.data[idx]
        mel_log = self._get_mel_from_fpath(fpath)
        return phonemes, mel_log
class ArabDataset4FastPitch(Dataset):
    """Dataset for FastPitch training: each item is (token_ids, log-mel,
    text length, pitch, per-frame energy, speaker, attention prior, wav path).

    Pitch contours are read from a precomputed dict (``f0_dict_path``)
    keyed by wav basename and normalized with (f0_mean, f0_std).
    """

    def __init__(self,
                 txtpath: str = './data/train_phon.txt',
                 wavpath: str = 'G:/data/arabic-speech-corpus/wav_new',
                 label_pattern: str = '"(?P<filename>.*)" "(?P<phonemes>.*)"',
                 f0_dict_path: str = './data/pitch_dict.pt',
                 f0_mean: float = 130.05478,
                 f0_std: float = 22.86267,
                 sr_target: int = 22050
                 ):
        super().__init__()
        # Local import keeps the fastpitch dependency optional for users
        # of the other datasets in this module.
        from models.fastpitch.fastpitch.data_function import BetaBinomialInterpolator

        self.mel_fn = MelSpectrogram()
        self.wav_path = wavpath
        self.label_pattern = label_pattern
        self.sr_target = sr_target

        self.f0_dict = torch.load(f0_dict_path)  # wav basename -> pitch tensor
        self.f0_mean = f0_mean
        self.f0_std = f0_std
        self.betabinomial_interpolator = BetaBinomialInterpolator()

        self.data = self._process_textfile(txtpath)

    def _process_textfile(self, txtpath: str):
        """Parse the label file into (token_ids, wav_path, pitch) tuples."""
        lines = read_lines_from_file(txtpath)
        phoneme_mel_pitch_list = []
        for l_idx, line in enumerate(progbar(lines)):
            try:
                phonemes, filename = _process_line(
                    self.label_pattern, line)
            except Exception:  # narrowed from bare except: keep Ctrl-C working
                print(f'invalid line {l_idx}: {line}')
                continue

            fpath = os.path.join(self.wav_path, filename)
            if not os.path.exists(fpath):
                print(f"{fpath} does not exist")
                continue

            try:
                tokens = text.phonemes_to_tokens(phonemes)
                token_ids = text.tokens_to_ids(tokens)
            except Exception:  # narrowed from bare except
                print(f'invalid phonemes at line {l_idx}: {line}')
                continue

            wav_name = os.path.basename(fpath)
            pitch_mel = self.f0_dict[wav_name][None]  # add leading channel dim
            phoneme_mel_pitch_list.append(
                (torch.LongTensor(token_ids), fpath, pitch_mel))
        return phoneme_mel_pitch_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        phonemes, fpath, pitch_mel = self.data[idx]

        wave, sr = torchaudio.load(fpath)
        if sr != self.sr_target:
            # 64 is lowpass_filter_width, passed positionally to resample.
            wave = torchaudio.functional.resample(wave, sr, self.sr_target, 64)

        mel_raw = self.mel_fn(wave)
        mel_log = mel_raw.clamp_min(1e-5).log().squeeze()

        # Drop low-energy frames from both mel and pitch so they stay aligned.
        keep = remove_silence(mel_log.mean(0))
        mel_log = mel_log[:, keep]
        pitch_mel = normalize_pitch(pitch_mel[:, keep], self.f0_mean, self.f0_std)

        energy = torch.norm(mel_log.float(), dim=0, p=2)  # per-frame L2 energy
        attn_prior = torch.from_numpy(
            self.betabinomial_interpolator(mel_log.size(1), len(phonemes)))
        speaker = None

        return (phonemes, mel_log, len(phonemes), pitch_mel,
                energy, speaker, attn_prior,
                fpath)
class DynBatchDataset(ArabDataset4FastPitch):
    """ArabDataset4FastPitch variant with length-bucketed dynamic batching.

    Samples are grouped into batches whose size depends on mel length
    (longer samples -> smaller batches). ``__getitem__`` returns a whole
    batch (a list of parent-class items), so use it with a DataLoader
    batch_size of 1 and call ``shuffle()`` between epochs.
    """

    def __init__(self,
                 txtpath: str = './data/train_phon.txt',
                 wavpath: str = 'G:/data/arabic-speech-corpus/wav_new',
                 label_pattern: str = '"(?P<filename>.*)" "(?P<phonemes>.*)"',
                 f0_dict_path: str = './data/pitch_dict.pt',
                 f0_mean: float = 130.05478,
                 f0_std: float = 22.86267,
                 max_lengths=None,   # mel-length bucket upper bounds; default [1000, 1300, 1850, 30000]
                 batch_sizes=None,   # batch size per bucket; default [10, 8, 6, 4]
                 ):
        super().__init__(txtpath=txtpath, wavpath=wavpath,
                         label_pattern=label_pattern,
                         f0_dict_path=f0_dict_path,
                         f0_mean=f0_mean, f0_std=f0_std)
        # None-sentinel instead of mutable list defaults (shared-state pitfall).
        if max_lengths is None:
            max_lengths = [1000, 1300, 1850, 30000]
        if batch_sizes is None:
            batch_sizes = [10, 8, 6, 4]
        # Bucket i covers mel lengths in [max_lens[i], max_lens[i+1]).
        self.max_lens = [0, *max_lengths]
        self.b_sizes = list(batch_sizes)
        self.id_batches = []
        self.shuffle()

    def shuffle(self):
        """Re-bucket sample indices by mel length and rebuild shuffled batches."""
        lens = [sample[2].size(1) for sample in self.data]  # sample[2]: pitch
        ids_per_bs = {b: [] for b in self.b_sizes}
        for idx, mel_len in enumerate(lens):
            # First bucket whose [low, high) range contains mel_len.
            # (Renamed from 'i' to avoid shadowing the enumerate variable.)
            bucket = next(b for b in range(len(self.max_lens) - 1)
                          if self.max_lens[b] <= mel_len < self.max_lens[b + 1])
            ids_per_bs[self.b_sizes[bucket]].append(idx)

        id_batches = []
        for bs, ids in ids_per_bs.items():
            np.random.shuffle(ids)
            id_batches += [ids[i:i + bs] for i in range(0, len(ids), bs)]
        self.id_batches = id_batches

    def __len__(self):
        return len(self.id_batches)

    def __getitem__(self, idx):
        """Return one full batch: a list of parent-class samples."""
        # Explicit super(...) form is required inside a comprehension.
        batch = [super(DynBatchDataset, self).__getitem__(sample_id)
                 for sample_id in self.id_batches[idx]]
        return batch