import math

import torch
from torch import nn
from torch.nn import functional as F

from .conv import Conv1d as conv_Conv1d


def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
    """Weight-normalized Conv1d with Kaiming-normal initialization.

    Note: the ``dropout`` argument is currently unused.
    """
    m = conv_Conv1d(in_channels, out_channels, kernel_size, **kwargs)
    nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
    return nn.utils.weight_norm(m)


def Conv1d1x1(in_channels, out_channels, bias=True):
    """1x1 (pointwise) convolution."""
    return Conv1d(
        in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
    )


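# Usage sketch (the shapes and channel sizes here are illustrative, not taken
# from the original code): ``Conv1d1x1(80, 512)`` is a pointwise projection
# mapping a (B, 80, T) tensor to (B, 512, T), while the weight-normalized
# ``Conv1d`` factory above is called like ``torch.nn.Conv1d`` for dilated
# convolutions, e.g. ``Conv1d(64, 128, kernel_size=3, dilation=2, padding=4)``.

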
def _conv1x1_forward(conv, x, is_incremental):
    """Apply a 1x1 convolution in batch mode or incremental (step-by-step) mode."""
    if is_incremental:
        x = conv.incremental_forward(x)
    else:
        x = conv(x)
    return x


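# The residual block below implements a WaveNet-style gated activation,
#
#     z = tanh(W_f * x + V_f * c) * sigmoid(W_g * x + V_g * c),
#
# where W_f, W_g are the two halves of one dilated convolution over the input x
# and V_f, V_g are the two halves of a 1x1 convolution over the local
# conditioning features c (the symbol names are ours, stated here for reference).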
class ResidualConv1dGLU(nn.Module):
    """Residual dilated conv1d + gated linear unit (GLU).

    Args:
        residual_channels (int): Residual input / output channels.
        gate_channels (int): Gated activation channels.
        kernel_size (int): Kernel size of convolution layers.
        skip_out_channels (int): Skip connection channels. If None, set to the
            same value as ``residual_channels``.
        cin_channels (int): Local conditioning channels. If a negative value
            is given, local conditioning is disabled.
        dropout (float): Dropout probability.
        padding (int): Padding for convolution layers. If None, proper padding
            is computed from the kernel size and dilation.
        dilation (int): Dilation factor.
        causal (bool): If True, use causal (left-only) padding.
        bias (bool): If True, the convolution layers use bias terms.
    """

    def __init__(
        self,
        residual_channels,
        gate_channels,
        kernel_size,
        skip_out_channels=None,
        cin_channels=-1,
        dropout=1 - 0.95,
        padding=None,
        dilation=1,
        causal=True,
        bias=True,
        *args,
        **kwargs,
    ):
        super(ResidualConv1dGLU, self).__init__()
        self.dropout = dropout

        if skip_out_channels is None:
            skip_out_channels = residual_channels
        if padding is None:
            # Causal convolutions pad only on the left (the excess is trimmed
            # in ``_forward``); non-causal convolutions pad symmetrically.
            if causal:
                padding = (kernel_size - 1) * dilation
            else:
                padding = (kernel_size - 1) // 2 * dilation
        self.causal = causal

        # Dilated convolution producing both the filter and gate halves
        self.conv = Conv1d(
            residual_channels,
            gate_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias=bias,
            *args,
            **kwargs,
        )

        # 1x1 convolution for local conditioning; disabled when cin_channels
        # is negative, as documented in the class docstring.
        if cin_channels > 0:
            self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, bias=False)
        else:
            self.conv1x1c = None

        gate_out_channels = gate_channels // 2
        self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
        self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, bias=bias)

    def forward(self, x, c=None):
        return self._forward(x, c, False)

    def incremental_forward(self, x, c=None):
        return self._forward(x, c, True)

    def clear_buffer(self):
        """Clear the buffers used for incremental inference."""
        for c in [
            self.conv,
            self.conv1x1_out,
            self.conv1x1_skip,
            self.conv1x1c,
        ]:
            if c is not None:
                c.clear_buffer()

    def _forward(self, x, c, is_incremental):
        """Forward pass.

        Args:
            x (Tensor): Input of shape (B, C, T).
            c (Tensor): Mel conditioning features of shape (B, C, T), or None.
            is_incremental (bool): Whether to run in incremental mode.

        Returns:
            tuple: Residual output and skip-connection output.
        """
        residual = x
        x = F.dropout(x, p=self.dropout, training=self.training)
        if is_incremental:
            splitdim = -1
            x = self.conv.incremental_forward(x)
        else:
            splitdim = 1
            x = self.conv(x)
            # Remove the future time steps introduced by the causal padding.
            x = x[:, :, : residual.size(-1)] if self.causal else x

        # Split into the filter (tanh) and gate (sigmoid) halves.
        a, b = x.split(x.size(splitdim) // 2, dim=splitdim)

        # Local conditioning
        if c is not None:
            assert self.conv1x1c is not None
            c = _conv1x1_forward(self.conv1x1c, c, is_incremental)
            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
            a, b = a + ca, b + cb

        # Gated activation
        x = torch.tanh(a) * torch.sigmoid(b)

        # Skip connection
        s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental)

        # Residual connection, scaled to keep the output variance roughly constant
        x = _conv1x1_forward(self.conv1x1_out, x, is_incremental)
        x = (x + residual) * math.sqrt(0.5)
        return x, s


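# Minimal smoke test (illustrative sketch, not part of the original module);
# the channel sizes and shapes below are arbitrary example values.  Because of
# the relative import above, run it with ``python -m <package>.<module>``
# (package and module names here are placeholders).
if __name__ == "__main__":
    B, T = 2, 100  # example batch size and number of time steps
    block = ResidualConv1dGLU(
        residual_channels=64,
        gate_channels=128,
        kernel_size=3,
        skip_out_channels=64,
        cin_channels=80,
        dilation=2,
    )
    x = torch.randn(B, 64, T)  # residual-channel input
    c = torch.randn(B, 80, T)  # local conditioning features (e.g. a mel-spectrogram)
    out, skip = block(x, c)
    print(out.shape, skip.shape)  # expected: torch.Size([2, 64, 100]) for both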