import math

from torch import nn
from torch.nn import functional as F

from .modules import Conv1d1x1, ResidualConv1dGLU
from .upsample import ConvInUpsampleNetwork


def receptive_field_size(
    total_layers, num_cycles, kernel_size, dilation=lambda x: 2**x
):
    """Compute the receptive field size.

    Args:
        total_layers (int): Total number of layers.
        num_cycles (int): Number of dilation cycles.
        kernel_size (int): Kernel size of the dilated convolutions.
        dilation (lambda): Lambda computing the dilation factor of each
            layer within a cycle. Pass ``lambda x: 1`` to disable dilated
            convolution.

    Returns:
        int: Receptive field size in samples.
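
    Example (illustrative): with 24 layers in 4 cycles and kernel size 3,
    each cycle uses dilations 1, 2, 4, 8, 16, 32 (sum 63), so the receptive
    field is (3 - 1) * 4 * 63 + 1 = 505 samples.

        >>> receptive_field_size(24, 4, 3)
        505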
    """
    assert total_layers % num_cycles == 0

    layers_per_cycle = total_layers // num_cycles
    # The dilation pattern restarts at the beginning of every cycle
    dilations = [dilation(i % layers_per_cycle) for i in range(total_layers)]
    return (kernel_size - 1) * sum(dilations) + 1


class WaveNet(nn.Module):
    """The WaveNet model that supports local conditioning.

    Args:
        out_channels (int): Output channels. If the input is a mu-law
            quantized one-hot vector, this must equal the number of quantize
            channels; otherwise it is num_mixtures x 3 (pi, mu, log_scale).
        layers (int): Total number of layers.
        stacks (int): Number of dilation cycles.
        residual_channels (int): Residual input / output channels.
        gate_channels (int): Gated activation channels.
        skip_out_channels (int): Skip connection channels.
        kernel_size (int): Kernel size of convolution layers.
        dropout (float): Dropout probability.
        input_dim (int): Number of mel-spectrogram dimensions.
        upsample_scales (list): List of upsample scales.
            ``np.prod(upsample_scales)`` must equal the hop size.
        freq_axis_kernel_size (int): Frequency-axis kernel size for the
            transposed convolution layers used for upsampling. If you only
            care about time-axis upsampling, set this to 1.
        scalar_input (bool): If True, scalar input in [-1, 1] is expected;
            otherwise a quantized one-hot vector is expected.
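
    Example (illustrative; the values below are placeholders, and any object
    exposing these ``VOCODER`` attributes works):

        >>> from types import SimpleNamespace
        >>> cfg = SimpleNamespace(VOCODER=SimpleNamespace(
        ...     SCALAR_INPUT=False, OUT_CHANNELS=256, INPUT_DIM=80,
        ...     RESIDUAL_CHANNELS=128, LAYERS=24, STACKS=4, GATE_CHANNELS=256,
        ...     KERNEL_SIZE=3, SKIP_OUT_CHANNELS=128, DROPOUT=0.05,
        ...     UPSAMPLE_SCALES=[4, 4, 4, 4], MEL_FRAME_PAD=2))
        >>> model = WaveNet(cfg)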
    """

    def __init__(self, cfg):
        super(WaveNet, self).__init__()
        self.cfg = cfg
        self.scalar_input = self.cfg.VOCODER.SCALAR_INPUT
        self.out_channels = self.cfg.VOCODER.OUT_CHANNELS
        self.cin_channels = self.cfg.VOCODER.INPUT_DIM
        self.residual_channels = self.cfg.VOCODER.RESIDUAL_CHANNELS
        self.layers = self.cfg.VOCODER.LAYERS
        self.stacks = self.cfg.VOCODER.STACKS
        self.gate_channels = self.cfg.VOCODER.GATE_CHANNELS
        self.kernel_size = self.cfg.VOCODER.KERNEL_SIZE
        self.skip_out_channels = self.cfg.VOCODER.SKIP_OUT_CHANNELS
        self.dropout = self.cfg.VOCODER.DROPOUT
        self.upsample_scales = self.cfg.VOCODER.UPSAMPLE_SCALES
        self.mel_frame_pad = self.cfg.VOCODER.MEL_FRAME_PAD

        assert self.layers % self.stacks == 0

        layers_per_stack = self.layers // self.stacks
        if self.scalar_input:
            self.first_conv = Conv1d1x1(1, self.residual_channels)
        else:
            self.first_conv = Conv1d1x1(self.out_channels, self.residual_channels)

        # Stack of residual dilated conv blocks; the dilation doubles with
        # each layer and resets at the start of every stack.
        self.conv_layers = nn.ModuleList()
        for layer in range(self.layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualConv1dGLU(
                self.residual_channels,
                self.gate_channels,
                kernel_size=self.kernel_size,
                skip_out_channels=self.skip_out_channels,
                bias=True,
                dilation=dilation,
                dropout=self.dropout,
                cin_channels=self.cin_channels,
            )
            self.conv_layers.append(conv)

        # Post-network applied to the summed skip connections
        self.last_conv_layers = nn.ModuleList(
            [
                nn.ReLU(inplace=True),
                Conv1d1x1(self.skip_out_channels, self.skip_out_channels),
                nn.ReLU(inplace=True),
                Conv1d1x1(self.skip_out_channels, self.out_channels),
            ]
        )

        # Upsamples mel-spectrogram frames to the audio sample rate
        self.upsample_net = ConvInUpsampleNetwork(
            upsample_scales=self.upsample_scales,
            cin_pad=self.mel_frame_pad,
            cin_channels=self.cin_channels,
        )

        self.receptive_field = receptive_field_size(
            self.layers, self.stacks, self.kernel_size
        )

    def forward(self, x, mel, softmax=False):
        """Forward step.

        Args:
            x (Tensor): One-hot encoded audio signal of shape (B x C x T).
            mel (Tensor): Local conditioning features of shape
                (B x cin_channels x frames); upsampled internally so the
                time axis matches ``x``.
            softmax (bool): Whether to apply softmax to the output.

        Returns:
            Tensor: Output of shape (B x out_channels x T).
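
        Example (illustrative; assumes the upsample network consumes the
        ``mel_frame_pad`` context frames on each side and stretches the
        remaining frames by ``np.prod(upsample_scales)``, here 4**4 = 256):

            >>> import torch
            >>> x = torch.zeros(1, 256, 10 * 256)     # (B, C, T)
            >>> mel = torch.zeros(1, 80, 10 + 2 * 2)  # 10 frames + 2-frame pad
            >>> y = model(x, mel, softmax=True)       # -> (1, 256, 2560)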
        """
        B, _, T = x.size()

        # Upsample local conditioning features so their time resolution
        # matches the audio samples
        mel = self.upsample_net(mel)
        assert mel.shape[-1] == x.shape[-1]

        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, mel)
            skips += h
        # Scale the accumulated skip outputs so their variance stays
        # independent of the number of layers
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        x = F.softmax(x, dim=1) if softmax else x

        return x

    def clear_buffer(self):
        self.first_conv.clear_buffer()
        for f in self.conv_layers:
            f.clear_buffer()
        for f in self.last_conv_layers:
            # Only the Conv1d1x1 layers hold incremental-inference buffers;
            # the ReLU layers do not, so skip them.
            try:
                f.clear_buffer()
            except AttributeError:
                pass

    def make_generation_fast_(self):
        def remove_weight_norm(m):
            try:
                nn.utils.remove_weight_norm(m)
            except ValueError:  # raised by modules without weight norm
                return

        self.apply(remove_weight_norm)
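

if __name__ == "__main__":
    # Illustrative smoke test (an addition, not part of the original module):
    # builds the model from a placeholder config and runs one forward pass.
    # Frame counts assume the upsample network consumes ``mel_frame_pad``
    # context frames on each side; run as ``python -m <package>.wavenet``
    # so the relative imports resolve.
    from types import SimpleNamespace

    import torch

    cfg = SimpleNamespace(
        VOCODER=SimpleNamespace(
            SCALAR_INPUT=False, OUT_CHANNELS=256, INPUT_DIM=80,
            RESIDUAL_CHANNELS=128, LAYERS=24, STACKS=4, GATE_CHANNELS=256,
            KERNEL_SIZE=3, SKIP_OUT_CHANNELS=128, DROPOUT=0.05,
            UPSAMPLE_SCALES=[4, 4, 4, 4], MEL_FRAME_PAD=2,
        )
    )
    model = WaveNet(cfg)
    print("receptive field:", model.receptive_field)

    x = torch.zeros(1, 256, 10 * 256)     # (B, C, T) one-hot audio
    mel = torch.zeros(1, 80, 10 + 2 * 2)  # 10 frames plus 2-frame pad per side
    y = model(x, mel, softmax=True)
    print("output shape:", tuple(y.shape))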