# Uploaded by Datasculptor; duplicated from AIGC-Audio/AudioGPT (commit 98f685a).
import matplotlib
matplotlib.use('Agg')
from tasks.base_task import data_loader
from tasks.tts.fs2 import FastSpeech2Task
from tasks.tts.dataset_utils import FastSpeechDataset, BaseTTSDataset
import glob
import importlib
from utils.pitch_utils import norm_interp_f0, denorm_f0, f0_to_coarse
from inference.base_tts_infer import load_data_preprocessor
from data_gen.tts.emotion import inference as EmotionEncoder
from data_gen.tts.emotion.inference import embed_utterance as Embed_utterance
from data_gen.tts.emotion.inference import preprocess_wav
from tqdm import tqdm
from utils.hparams import hparams
from data_gen.tts.data_gen_utils import build_phone_encoder, build_word_encoder
import random
import torch
import torch.optim
import torch.nn.functional as F
import torch.utils.data
from utils.indexed_datasets import IndexedDataset
from resemblyzer import VoiceEncoder
import torch.distributions
import numpy as np
import utils
import os
class GenerSpeech_dataset(BaseTTSDataset):
    """TTS dataset whose samples carry speaker and emotion embeddings.

    Depending on ``prefix`` the items come from different places:

    * ``'valid'``: up to 300 randomly chosen items of the *train* split are
      read from ``{data_dir}/train`` and kept in memory.
    * ``'test'`` with a non-empty ``hparams['test_input_dir']``: items are
      built on the fly from the audio files found in that directory
      (see :meth:`load_test_inputs`).
    * otherwise: whatever the base class set up; ``_get_item`` opens
      ``{data_dir}/{prefix}`` lazily when ``self.indexed_ds`` is ``None``.
    """

    def __init__(self, prefix, shuffle=False, test_items=None, test_sizes=None, data_dir=None):
        """Initialise the dataset for the split named by ``prefix``.

        All arguments are forwarded unchanged to ``BaseTTSDataset.__init__``.
        """
        super().__init__(prefix, shuffle, test_items, test_sizes, data_dir)
        self.f0_mean, self.f0_std = hparams.get('f0_mean', None), hparams.get('f0_std', None)

        if prefix == 'valid':
            # Validation uses a random subset of the training data.
            train_ds = IndexedDataset(f'{self.data_dir}/train')
            train_sizes = np.load(f'{self.data_dir}/train_lengths.npy')
            chosen = list(range(len(train_ds)))
            random.shuffle(chosen)
            chosen = chosen[:300]
            self.sizes = train_sizes[chosen]
            self.indexed_ds = [train_ds[i] for i in chosen]
            # NOTE(review): the min_frames filter is assumed to belong to the
            # 'valid' branch (the original indentation was lost) - confirm.
            self.avail_idxs = list(range(len(self.sizes)))
            if hparams['min_frames'] > 0:
                self.avail_idxs = [x for x in self.avail_idxs if self.sizes[x] >= hparams['min_frames']]
            self.sizes = [self.sizes[i] for i in self.avail_idxs]

        if prefix == 'test' and hparams['test_input_dir'] != '':
            self.preprocessor, self.preprocess_args = load_data_preprocessor()
            self.indexed_ds, self.sizes = self.load_test_inputs(hparams['test_input_dir'])
            self.avail_idxs = list(range(len(self.sizes)))

    def load_test_inputs(self, test_input_dir):
        """Build dataset items from the ``*.wav`` / ``*.mp3`` files in a directory.

        For every audio file a transcript (``.txt``) and an alignment
        (``.TextGrid``) with the same stem are expected next to it. Only the
        first line of the transcript is used.

        Requires ``self.preprocessor`` and ``self.preprocess_args`` to be set
        (``__init__`` does this before calling the method).

        Returns:
            ``(items, sizes)`` where ``items`` is the list of dicts returned by
            the binarizer's ``process_item`` (extended with ``'emo_embed'`` and
            ``'spk_embed'``) and ``sizes`` is the list of their ``'len'`` values.
        """
        inp_wav_paths = sorted(glob.glob(f'{test_input_dir}/*.wav') + glob.glob(f'{test_input_dir}/*.mp3'))

        # The previous default module path ended in "base_binarizerr" (doubled
        # "r"), which looks like a typo. NOTE(review): confirm the module name.
        binarizer_path = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
        pkg, _, cls_name = binarizer_path.rpartition(".")
        binarizer_cls = getattr(importlib.import_module(pkg), cls_name)

        phone_encoder = build_phone_encoder(hparams['binary_data_dir'])
        word_encoder = build_word_encoder(hparams['binary_data_dir'])
        voice_encoder = VoiceEncoder().cuda()
        encoder = [phone_encoder, word_encoder]
        EmotionEncoder.load_model(hparams['emotion_encoder_path'])
        preprocessor, preprocess_args = self.preprocessor, self.preprocess_args

        items = []
        sizes = []
        for wav_fn in tqdm(inp_wav_paths):
            item_name = wav_fn[len(test_input_dir) + 1:].replace("/", "_")
            spk_id = emotion = 0
            # Swap only the real extension. A plain ``.replace('.wav', ...)``
            # left ``.mp3`` paths unchanged, so the audio file itself was
            # opened as the transcript.
            stem = os.path.splitext(wav_fn)[0]
            item2tgfn = stem + '.TextGrid'  # prepare textgrid alignment
            txtpath = stem + '.txt'  # prepare text
            with open(txtpath, 'r') as f:
                text_raw = f.readlines()
            ph, txt = preprocessor.txt_to_ph(preprocessor.txt_processor, text_raw[0], preprocess_args)
            item = binarizer_cls.process_item(item_name, ph, txt, item2tgfn, wav_fn, spk_id, emotion, encoder,
                                              hparams['binarization_args'])
            item['emo_embed'] = Embed_utterance(preprocess_wav(item['wav_fn']))
            item['spk_embed'] = voice_encoder.embed_utterance(item['wav'])
            items.append(item)
            sizes.append(item['len'])
        return items, sizes

    def _get_item(self, index):
        """Return the raw item for ``index``, mapped through ``avail_idxs`` if present."""
        if hasattr(self, 'avail_idxs') and self.avail_idxs is not None:
            index = self.avail_idxs[index]
        if self.indexed_ds is None:
            # Opened lazily on first access.
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[index]

    def __getitem__(self, index):
        """Convert the raw item at ``index`` into a dict of tensors/metadata.

        The mel is cut to ``hparams['max_frames']`` and then to a multiple of
        ``hparams['frames_multiple']``; phone tokens are cut to
        ``hparams['max_input_tokens']``. Pitch, speaker/emotion embeddings and
        (when ``hparams['use_word']`` is set) word-level fields are added.
        """
        hparams = self.hparams
        item = self._get_item(index)
        assert len(item['mel']) == self.sizes[index], (len(item['mel']), self.sizes[index])

        max_frames = hparams['max_frames']
        spec = torch.Tensor(item['mel'])[:max_frames]
        # Drop trailing frames so the length is a multiple of frames_multiple.
        max_frames = spec.shape[0] // hparams['frames_multiple'] * hparams['frames_multiple']
        spec = spec[:max_frames]
        phone = torch.LongTensor(item['phone'][:hparams['max_input_tokens']])
        sample = {
            "id": index,
            "item_name": item['item_name'],
            "text": item['txt'],
            "txt_token": phone,
            "mel": spec,
            "mel_nonpadding": spec.abs().sum(-1) > 0,
        }
        T = spec.shape[0]
        sample['mel2ph'] = mel2ph = torch.LongTensor(item['mel2ph'])[:T] if 'mel2ph' in item else None

        if hparams['use_pitch_embed']:
            assert 'f0' in item
            if hparams.get('normalize_pitch', False):
                # f0 aliases item["f0"]; the in-place edits below are therefore
                # also seen by norm_interp_f0 further down.
                f0 = item["f0"]
                voiced = f0 > 0
                # Check that voiced frames exist. ``len(f0 > 0)`` only measured
                # the mask length and let an empty selection reach ``.std()``.
                if voiced.any() and f0[voiced].std() > 0:
                    f0[voiced] = (f0[voiced] - f0[voiced].mean()) / f0[voiced].std() * hparams['f0_std'] + \
                        hparams['f0_mean']
                    # The mask is re-evaluated after normalisation, as before.
                    still_positive = f0 > 0
                    f0[still_positive] = f0[still_positive].clip(min=60, max=500)
                pitch = f0_to_coarse(f0)
                pitch = torch.LongTensor(pitch[:max_frames])
            else:
                pitch = torch.LongTensor(item.get("pitch"))[:max_frames] if "pitch" in item else None
            # NOTE(review): assumed to run for both branches above (the original
            # indentation was lost); otherwise ``uv`` would be undefined.
            f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
            uv = torch.FloatTensor(uv)
            f0 = torch.FloatTensor(f0)
        else:
            # Requires 'mel2ph' in the item: zeros_like cannot take None.
            f0 = uv = torch.zeros_like(mel2ph)
            pitch = None
        sample["f0"], sample["uv"], sample["pitch"] = f0, uv, pitch

        sample["spk_embed"] = torch.Tensor(item['spk_embed'])
        sample["emotion"] = item['emotion']
        sample["emo_embed"] = torch.Tensor(item['emo_embed'])

        if hparams.get('use_word', False):
            sample["ph_words"] = item["ph_words"]
            sample["word_tokens"] = torch.LongTensor(item["word_tokens"])
            sample["mel2word"] = torch.LongTensor(item.get("mel2word"))[:max_frames]
            sample["ph2word"] = torch.LongTensor(item['ph2word'][:hparams['max_input_tokens']])
        return sample

    def collater(self, samples):
        """Merge a list of ``__getitem__`` samples into one padded batch dict.

        Returns an empty dict for an empty ``samples`` list. ``'pitch'`` and
        ``'mel2ph'`` are ``None`` in the batch when the first sample has
        ``None`` for them.
        """
        if len(samples) == 0:
            return {}
        hparams = self.hparams

        ids = torch.LongTensor([s['id'] for s in samples])
        batch = {
            'id': ids,
            'item_name': [s['item_name'] for s in samples],
            'nsamples': len(samples),
            'text': [s['text'] for s in samples],
            'txt_tokens': utils.collate_1d([s['txt_token'] for s in samples], 0),
            'txt_lengths': torch.LongTensor([s['txt_token'].numel() for s in samples]),
            'mels': utils.collate_2d([s['mel'] for s in samples], 0.0),
            'mel_lengths': torch.LongTensor([s['mel'].shape[0] for s in samples]),
        }

        f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
        pitch = utils.collate_1d([s['pitch'] for s in samples]) if samples[0]['pitch'] is not None else None
        uv = utils.collate_1d([s['uv'] for s in samples])
        mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) if samples[0]['mel2ph'] is not None else None
        batch.update({
            'mel2ph': mel2ph,
            'pitch': pitch,
            'f0': f0,
            'uv': uv,
        })

        batch['spk_embed'] = torch.stack([s['spk_embed'] for s in samples])
        batch['emo_embed'] = torch.stack([s['emo_embed'] for s in samples])

        if hparams.get('use_word', False):
            batch['ph_words'] = [s['ph_words'] for s in samples]
            batch['word_tokens'] = utils.collate_1d([s['word_tokens'] for s in samples], 0)
            batch['mel2word'] = utils.collate_1d([s['mel2word'] for s in samples], 0)
            batch['ph2word'] = utils.collate_1d([s['ph2word'] for s in samples], 0)
        return batch