ALeLacheur's picture
uploading audio diffusion attacks
5a9b731
"""
utils.py
Desc: A file for miscellaneous util functions
"""
import numpy as np
import torch
# MonoTransform, this does not exist in PyTorch anymore since it is a simple mean calculation. We provide an implementation here
class MonoTransform(object):
    """
    Convert an audio sample to a single (mono) channel by averaging channels.

    Args for __call__:
        sample: audio with shape (C, T) or (B, C, T), where C is the number of
            channels. Either a ``torch.Tensor`` or a ``numpy.ndarray``.

    Returns:
        The channel-averaged audio with the channel axis kept at size 1, i.e.
        shape (1, T) or (B, 1, T), as the same container type as the input.
        Integer tensors are promoted to float32 first, because ``torch.mean``
        is not defined for integer dtypes.

    Raises:
        ValueError: if ``sample`` has fewer than 2 dimensions.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        if sample.ndim < 2:
            raise ValueError(
                "MonoTransform expects shape (C, T) or (B, C, T), "
                f"got {tuple(sample.shape)}"
            )
        # The channel axis is second-to-last for both (C, T) and (B, C, T);
        # keep it (size 1) so downstream code still sees a channel dimension.
        if isinstance(sample, np.ndarray):
            return sample.mean(axis=-2, keepdims=True)
        if not torch.is_floating_point(sample):
            sample = sample.float()
        return sample.mean(dim=-2, keepdim=True)
"""
Below: Helper functions for Grad-TTS
"""
def duration_loss(logw, logw_, lengths):
    """
    Duration loss for the duration predictor.

    Squares the element-wise difference between the two log-duration tensors,
    sums it over every element, and normalizes by the total of ``lengths``.

    Args:
        logw: predicted log-durations (tensor).
        logw_: target log-durations, broadcast-compatible with ``logw``.
        lengths: tensor of sequence lengths; only its total is used.

    Returns:
        A scalar tensor.
    """
    residual = logw - logw_
    normalizer = torch.sum(lengths)
    return torch.square(residual).sum() / normalizer
def intersperse(lst, item):
    """
    Return a new list with ``item`` before, between and after every element.

    Used to add the blank symbol (``item``) around each token of ``lst``; the
    result has ``2 * len(lst) + 1`` entries, e.g. ``[a, b] -> [x, a, x, b, x]``.
    """
    spaced = [item]
    for token in lst:
        spaced.extend((token, item))
    return spaced
def sequence_mask(length, max_length=None):
    """
    Build a boolean mask whose entry ``[i, j]`` is True iff ``j < length[i]``.

    Args:
        length: tensor of lengths (a 1-D tensor gives a 2-D mask).
        max_length: number of mask columns; defaults to ``length.max()``.

    Returns:
        Bool tensor of shape ``(len(length), max_length)`` for 1-D ``length``,
        created on ``length.device``.
    """
    limit = length.max() if max_length is None else max_length
    positions = torch.arange(int(limit), dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    """
    Round ``length`` up to the nearest multiple of ``2 ** num_downsamplings_in_unet``.

    Computed in closed form rather than by repeatedly incrementing, which
    (a) never mutates a tensor argument in place and (b) terminates for
    non-integral lengths (e.g. 3.5 -> 4.0 with the default factor of 4).

    Args:
        length: int, float, or scalar tensor.
        num_downsamplings_in_unet: number of 2x downsamplings; must be >= 0.

    Returns:
        The smallest multiple of the factor that is >= ``length``, of the same
        numeric kind as ``length`` (e.g. int in, int out).

    Raises:
        ValueError: if ``num_downsamplings_in_unet`` is negative.
    """
    if num_downsamplings_in_unet < 0:
        raise ValueError(
            f"num_downsamplings_in_unet must be >= 0, got {num_downsamplings_in_unet}"
        )
    factor = 2 ** num_downsamplings_in_unet
    # Ceiling division via floor division of the negation; avoids `+=`, which
    # would modify a torch tensor argument in place.
    return -(-length // factor) * factor
def convert_pad_shape(pad_shape):
    """
    Flatten per-dimension ``[before, after]`` pad pairs into a single list.

    The pairs are given first-dimension-first; the output lists them
    last-dimension-first, which is the order ``torch.nn.functional.pad``
    consumes, e.g. ``[[1, 2], [3, 4]] -> [3, 4, 1, 2]``.
    """
    return [amount for dim_pad in reversed(pad_shape) for amount in dim_pad]
def generate_path(duration, mask):
    """
    Build a hard monotonic alignment path from per-position durations.

    With ``cum`` the cumulative sum of ``duration`` along dim 1, the output at
    ``[b, i, j]`` is 1 for frames ``cum[b, i-1] <= j < cum[b, i]`` (treating
    ``cum[b, -1]`` as 0) and 0 elsewhere, assuming non-negative durations;
    the result is then multiplied by ``mask``.

    Args:
        duration: durations with ``b * t_x`` elements, cumulated along dim 1;
            presumably shaped (b, t_x) — TODO confirm against callers.
        mask: tensor of shape (b, t_x, t_y); its dtype sets the output dtype.

    Returns:
        Tensor of shape (b, t_x, t_y) with ``mask``'s dtype.
    """
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    # Row i of the flattened mask is 1 for every frame j < cum_duration[i].
    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Subtract the same mask shifted by one position along dim 1 (a zero row
    # is padded in front, the last row dropped) so that only frames with
    # cum[i-1] <= j < cum[i] remain set.
    shifted = torch.nn.functional.pad(
        path, convert_pad_shape([[0, 0], [1, 0], [0, 0]])
    )[:, :-1]
    path = path - shifted
    return path * mask