Spaces:
Runtime error
Runtime error
File size: 4,458 Bytes
a8c39f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import math
import torch
from typing import List, Optional
def init_weights(m, mean=0.0, std=0.01):
    """
    Refill the ``weight`` of convolution modules from a normal distribution.

    Only modules whose class name contains the substring "Conv" are touched;
    every other module is returned from immediately and left unchanged. The
    weight tensor is overwritten in place.

    Args:
        m: Module to (possibly) re-initialise.
        mean: Mean of the normal distribution used for the new weights.
        std: Standard deviation of that normal distribution.
    """
    # Class-name matching (rather than isinstance) catches any layer type whose
    # name mentions "Conv", e.g. Conv1d, Conv2d, ConvTranspose1d.
    if "Conv" not in m.__class__.__name__:
        return
    m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
    """
    Compute the per-side padding for a (possibly dilated) convolution.

    The result is ``(kernel_size - 1) * dilation / 2`` truncated toward zero
    with ``int``. For odd kernel sizes and stride 1 this padding keeps the
    output length equal to the input length.

    Args:
        kernel_size: Size of the convolution kernel.
        dilation: Dilation factor of the convolution.

    Returns:
        int: Padding to apply on each side.
    """
    dilated_extent = kernel_size * dilation - dilation
    return int(dilated_extent / 2)
def convert_pad_shape(pad_shape):
    """
    Reverse the outer sequence of ``pad_shape`` and flatten it one level.

    For example ``[[0, 0], [1, 2]]`` becomes ``[1, 2, 0, 0]``.

    NOTE: a second ``convert_pad_shape`` is defined later in this module and
    rebinds the name, so this definition is shadowed once the module finishes
    importing.

    Args:
        pad_shape: Sequence of per-dimension padding groups.

    Returns:
        list: The padding values, last group first, as a single flat list.
    """
    flattened = []
    for group in reversed(pad_shape):
        flattened.extend(group)
    return flattened
def slice_segments(
    x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4, dim: int = 2
):
    """
    Cut one fixed-length window per batch item out of ``x``.

    Window ``i`` covers positions ``ids_str[i] : ids_str[i] + segment_size``.
    It is taken along axis 1 when ``dim == 2`` and along axis 2 when
    ``dim == 3``, which is the last axis when ``x``'s rank equals ``dim``.

    Args:
        x (torch.Tensor): The tensor to slice; axis 0 is iterated as the batch.
        ids_str (torch.Tensor): Start index of the window for each batch item.
        segment_size (int, optional): Length of each window. Defaults to 4.
        dim (int, optional): Expected rank of ``x`` (2 or 3), which selects
            the sliced axis as described above. Defaults to 2.

    Returns:
        torch.Tensor: Tensor with the same dtype/device as ``x``, whose sliced
        axis has length ``segment_size``.

    Raises:
        ValueError: If ``dim`` is neither 2 nor 3.
    """
    # Without this check an unsupported dim left ``ret`` unbound and failed
    # later with an opaque UnboundLocalError.
    if dim not in (2, 3):
        raise ValueError(f"slice_segments supports dim=2 or dim=3, got dim={dim}")

    if dim == 2:
        ret = torch.zeros_like(x[:, :segment_size])
    else:
        ret = torch.zeros_like(x[:, :, :segment_size])

    for i in range(x.size(0)):
        idx_str = ids_str[i].item()
        idx_end = idx_str + segment_size
        if dim == 2:
            ret[i] = x[i, idx_str:idx_end]
        else:
            ret[i] = x[i, :, idx_str:idx_end]
    return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """
    Pick a random start for each batch item and slice a window of ``x``.

    ``x`` must unpack into exactly three sizes ``(b, d, t)``. Each start index
    is ``rand * (length - segment_size + 1)`` truncated to ``torch.long``. When
    ``length >= segment_size`` this lies in ``[0, length - segment_size]``, so
    the window fits inside the valid region. Shorter lengths are not guarded
    against here.

    Args:
        x: Tensor to slice; it is sliced along its last axis.
        x_lengths: Per-item valid lengths (a tensor or a scalar). Defaults to
            the full size of the last axis of ``x``.
        segment_size: Length of each window.

    Returns:
        tuple: ``(segments, ids_str)``, the sliced windows and the long start
        indices that were used.
    """
    batch, _, frames = x.size()
    lengths = frames if x_lengths is None else x_lengths
    start_range = lengths - segment_size + 1
    # Draw on the CPU generator, then move to x's device.
    uniform = torch.rand([batch]).to(device=x.device)
    ids_str = (uniform * start_range).to(dtype=torch.long)
    segments = slice_segments(x, ids_str, segment_size, dim=3)
    return segments, ids_str
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """
    Add two tensors and gate tanh of one channel slice by sigmoid of the rest.

    The sum ``input_a + input_b`` is split along axis 1 at ``n_channels[0]``.
    Channels before the split go through ``tanh``, channels from the split on
    go through ``sigmoid``, and the two results are multiplied element-wise.

    Args:
        input_a: First summand; indexed as ``[:, c, :]`` so it needs at least
            three axes.
        input_b: Second summand, added element-wise to ``input_a``.
        n_channels: Tensor whose first element is the split point on axis 1
            (TorchScript types this un-annotated argument as a Tensor).

    Returns:
        Element-wise product of the tanh branch and the sigmoid branch.
    """
    split = n_channels[0]
    summed = input_a + input_b
    tanh_branch = torch.tanh(summed[:, :split, :])
    sigmoid_branch = torch.sigmoid(summed[:, split:, :])
    return tanh_branch * sigmoid_branch
def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]:
    """
    Reverse the outer list of ``pad_shape`` and flatten it one level.

    For example ``[[0, 0], [0, 0], [1, 2]]`` becomes ``[1, 2, 0, 0, 0, 0]``.

    This is done with plain list operations. Building a tensor here only to
    reverse and flatten a handful of ints allocated memory on every call, and
    ``.int()`` silently wrapped values outside the int32 range. The explicit
    index loop keeps the function TorchScript-compatible.

    Args:
        pad_shape: List of per-dimension padding groups.

    Returns:
        List[int]: The padding values, last group first, as a flat list.
    """
    flat: List[int] = []
    for i in range(len(pad_shape) - 1, -1, -1):
        for value in pad_shape[i]:
            flat.append(value)
    return flat
def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None):
    """
    Build a boolean mask marking the valid positions of each sequence.

    For a 1-D ``length`` the result has shape ``(len(length), max_length)``,
    and entry ``[i, j]`` is True exactly when ``j < length[i]``.

    Args:
        length: Tensor of sequence lengths (presumably 1-D; TODO confirm with
            callers).
        max_length: Number of mask columns. Defaults to ``length.max()``.

    Returns:
        torch.Tensor: Boolean mask on the same device as ``length``.
    """
    width = length.max() if max_length is None else max_length
    positions = torch.arange(width, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
def clip_grad_value(parameters, clip_value, norm_type=2):
    """
    Clamp gradients element-wise and report their total norm.

    Parameters whose ``.grad`` is None are ignored. The norm is measured per
    parameter *before* that parameter's gradient is clamped, so the returned
    value is the pre-clipping norm.

    Args:
        parameters: A single tensor or an iterable of tensors whose gradients
            should be clipped.
        clip_value: Element-wise bound; each gradient is clamped in place to
            ``[-clip_value, clip_value]``. ``None`` disables clipping and only
            computes the norm.
        norm_type: Order of the reported norm; ``float("inf")`` gives the
            max-abs norm.

    Returns:
        float: Total norm of all gradients across ``parameters``.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    # sum(n ** p) ** (1 / p) degenerates when p is infinite: the sum becomes
    # inf or 0, and ``** (1 / inf)`` turns inf into 1.0. The inf-norm is
    # therefore the max of the per-parameter inf-norms.
    is_inf_norm = norm_type == math.inf
    total_norm = 0.0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type).item()
        if is_inf_norm:
            total_norm = max(total_norm, param_norm)
        else:
            total_norm += param_norm ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    if not is_inf_norm:
        total_norm = total_norm ** (1.0 / norm_type)
    return total_norm
|