|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
|
import torch.nn as nn |
|
|
|
from modules.general.utils import Conv1d, zero_module |
|
from .residual_block import ResidualBlock |
|
|
|
|
|
class BiDilConv(nn.Module):
    r"""Dilated CNN architecture with residual connections, default diffusion decoder.

    The input is projected to ``base_channel`` channels, run through
    ``n_res_block`` residual blocks whose dilations cycle as
    ``2 ** (i % dilation_cycle_length)``, and the summed skip connections
    (scaled by ``1 / sqrt(n_res_block)``) are projected to
    ``output_channel`` channels. The final 1x1 convolution is
    zero-initialized via ``zero_module`` so the module outputs zeros at
    initialization — a common stabilization trick for diffusion decoders.

    Args:
        input_channel: The number of input channels.
        base_channel: The number of base channels.
        n_res_block: The number of residual blocks; must be >= 1.
        conv_kernel_size: The kernel size of convolutional layers.
        dilation_cycle_length: The cycle length of dilation.
        conditioner_size: The size of conditioner.
        output_channel: The number of output channels; any non-positive
            value (the default ``-1``) means "same as ``input_channel``".

    Raises:
        ValueError: If ``n_res_block`` is less than 1.
    """

    def __init__(
        self,
        input_channel,
        base_channel,
        n_res_block,
        conv_kernel_size,
        dilation_cycle_length,
        conditioner_size,
        output_channel: int = -1,
    ):
        super().__init__()

        # Guard: with zero blocks, forward() would leave `skip` as None and
        # then attempt `None / math.sqrt(0)` — fail early with a clear error.
        if n_res_block < 1:
            raise ValueError(f"n_res_block must be >= 1, got {n_res_block}")

        self.input_channel = input_channel
        self.base_channel = base_channel
        self.n_res_block = n_res_block
        self.conv_kernel_size = conv_kernel_size
        self.dilation_cycle_length = dilation_cycle_length
        self.conditioner_size = conditioner_size
        # Non-positive output_channel is a sentinel for "mirror the input".
        self.output_channel = output_channel if output_channel > 0 else input_channel

        # 1x1 projection from input channels into the residual trunk width.
        self.input = nn.Sequential(
            Conv1d(
                input_channel,
                base_channel,
                1,
            ),
            nn.ReLU(),
        )

        # Residual stack; dilations cycle through powers of two so the
        # receptive field grows without unbounded dilation.
        self.residual_blocks = nn.ModuleList(
            [
                ResidualBlock(
                    channels=base_channel,
                    kernel_size=conv_kernel_size,
                    dilation=2 ** (i % dilation_cycle_length),
                    d_context=conditioner_size,
                )
                for i in range(n_res_block)
            ]
        )

        # Output head; the last conv is zero-initialized (zero_module) so the
        # whole module initially maps any input to zeros.
        self.out_proj = nn.Sequential(
            Conv1d(
                base_channel,
                base_channel,
                1,
            ),
            nn.ReLU(),
            zero_module(
                Conv1d(
                    base_channel,
                    self.output_channel,
                    1,
                ),
            ),
        )

    def forward(self, x, y, context=None):
        """
        Args:
            x: Noisy mel-spectrogram [B x ``n_mel`` x L]
            y: FILM embeddings with the shape of (B, ``base_channel``)
            context: Context with the shape of [B x ``d_context`` x L], default to None.

        Returns:
            Tensor with the shape of [B x ``output_channel`` x L].
        """
        h = self.input(x)

        # Accumulate the skip connection emitted by every residual block.
        skip = None
        for block in self.residual_blocks:
            h, skip_connection = block(h, y, context)
            skip = skip_connection if skip is None else skip_connection + skip

        # Divide by sqrt(N) so the variance of the skip sum stays roughly
        # independent of the number of blocks.
        out = skip / math.sqrt(self.n_res_block)

        return self.out_proj(out)
|
|