"""Search a good noise schedule for WaveGrad for a given number of inference iterations""" | |
import argparse
from itertools import product as cartesian_product

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.config import load_config
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.models import setup_model
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
    parser.add_argument("--config_path", type=str, help="Path to model config file.")
    parser.add_argument("--data_path", type=str, help="Path to data directory.")
    parser.add_argument("--output_path", type=str, help="Path to the output file, including file name and extension.")
    parser.add_argument(
        "--num_iter",
        type=int,
        help="Number of model inference iterations to optimize the noise schedule for.",
    )
    parser.add_argument("--use_cuda", action="store_true", help="Enable CUDA.")
    parser.add_argument("--num_samples", type=int, default=1, help="Number of data samples used for inference.")
    parser.add_argument(
        "--search_depth",
        type=int,
        default=3,
        help="Search granularity. Increasing this increases the run-time exponentially.",
    )
    # load config
    args = parser.parse_args()
    config = load_config(args.config_path)

    # setup audio processor
    ap = AudioProcessor(**config.audio)

    # load dataset
    _, train_data = load_wav_data(args.data_path, 0)
    train_data = train_data[: args.num_samples]
    dataset = WaveGradDataset(
        ap=ap,
        items=train_data,
        seq_len=-1,
        hop_len=ap.hop_length,
        pad_short=config.pad_short,
        conv_pad=config.conv_pad,
        is_training=True,
        return_segments=False,
        use_noise_augment=False,
        use_cache=False,
        verbose=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=dataset.collate_full_clips,
        drop_last=False,
        num_workers=config.num_loader_workers,
        pin_memory=False,
    )
    # setup the model
    model = setup_model(config)
    if args.use_cuda:
        model.cuda()
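
    # Search strategy used below: each candidate schedule is beta = base * 10 ** linspace(-6, -1, num_iter),
    # i.e. fixed log-spaced exponents scaled per step by one of the randomly drawn base values. An exhaustive
    # grid search over all base-value combinations keeps the schedule with the lowest mel-spectrogram MSE.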
    # setup optimization parameters
    base_values = sorted(10 * np.random.uniform(size=args.search_depth))
    print(f" > base values: {base_values}")
    exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
    best_error = float("inf")
    best_schedule = None  # pylint: disable=C0103
    total_search_iter = len(base_values) ** args.num_iter
    for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
        beta = exponents * base
        model.compute_noise_level(beta)
        for data in loader:
            mel, audio = data
            y_hat = model.inference(mel.cuda() if args.use_cuda else mel)

            if args.use_cuda:
                y_hat = y_hat.cpu()
            y_hat = y_hat.numpy()

            mel_hat = []
            for i in range(y_hat.shape[0]):
                m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
                mel_hat.append(torch.from_numpy(m))

            mel_hat = torch.stack(mel_hat)

            mse = torch.sum((mel - mel_hat) ** 2).mean()
            if mse.item() < best_error:
                best_error = mse.item()
                best_schedule = {"beta": beta}
                print(f" > Found a better schedule. - MSE: {mse.item()}")
                np.save(args.output_path, best_schedule)
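
# The saved schedule is a pickled dict; a minimal sketch for reloading it at inference time
# (the path below is a placeholder):
#   schedule = np.load("noise_schedule.npy", allow_pickle=True).item()
#   model.compute_noise_level(schedule["beta"])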