Spaces:
Runtime error
Runtime error
File size: 6,188 Bytes
a8c39f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import math
import torch
from typing import List, Optional
def init_weights(m, mean=0.0, std=0.01):
"""
Initialize the weights of a module.
Args:
m: The module to initialize.
mean: The mean of the normal distribution.
std: The standard deviation of the normal distribution.
"""
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
"""
Calculate the padding needed for a convolution.
Args:
kernel_size: The size of the kernel.
dilation: The dilation of the convolution.
"""
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
"""
Convert the pad shape to a list of integers.
Args:
pad_shape: The pad shape..
"""
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""
Calculate the KL divergence between two distributions.
Args:
m_p: The mean of the first distribution.
logs_p: The log of the standard deviation of the first distribution.
m_q: The mean of the second distribution.
logs_q: The log of the standard deviation of the second distribution.
"""
kl = (logs_q - logs_p) - 0.5
kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl
def slice_segments(
x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4, dim: int = 2
):
"""
Slice segments from a tensor, handling tensors with different numbers of dimensions.
Args:
x (torch.Tensor): The tensor to slice.
ids_str (torch.Tensor): The starting indices of the segments.
segment_size (int, optional): The size of each segment. Defaults to 4.
dim (int, optional): The dimension to slice across (2D or 3D tensors). Defaults to 2.
"""
if dim == 2:
ret = torch.zeros_like(x[:, :segment_size])
elif dim == 3:
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i].item()
idx_end = idx_str + segment_size
if dim == 2:
ret[i] = x[i, idx_str:idx_end]
else:
ret[i] = x[i, :, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
"""
Randomly slice segments from a tensor.
Args:
x: The tensor to slice.
x_lengths: The lengths of the sequences.
segment_size: The size of each segment.
"""
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size, dim=3)
return ret, ids_str
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
"""
Generate a 1D timing signal.
Args:
length: The length of the signal.
channels: The number of channels of the signal.
min_timescale: The minimum timescale.
max_timescale: The maximum timescale.
"""
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = torch.nn.functional.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def subsequent_mask(length):
"""
Generate a subsequent mask.
Args:
length: The length of the sequence.
"""
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
"""
Fused add tanh sigmoid multiply operation.
Args:
input_a: The first input tensor.
input_b: The second input tensor.
n_channels: The number of channels.
"""
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]:
"""
Convert the pad shape to a list of integers.
Args:
pad_shape: The pad shape.
"""
return torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist()
def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None):
"""
Generate a sequence mask.
Args:
length: The lengths of the sequences.
max_length: The maximum length of the sequences.
"""
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def clip_grad_value(parameters, clip_value, norm_type=2):
"""
Clip the gradients of a list of parameters.
Args:
parameters: The list of parameters to clip.
clip_value: The maximum value of the gradients.
norm_type: The type of norm to use for clipping.
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm
|