# Upstream metadata (header of the page this file was copied from):
#   author: hhguo | commit: 37ced70 ("update") | size: 5.25 kB
import typing as tp
import torch
import torch.nn as nn
class ConvLayer(nn.Module):
    """Conv1d -> Dropout -> LayerNorm -> activation on channel-last sequences.

    The convolution pads ``(kernel_size - stride) // 2`` frames on each side,
    so with ``stride=1`` and an odd ``kernel_size`` the time length is kept.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        kernel_size: temporal kernel width.
        stride: temporal stride.
        activation: name of an activation class in ``torch.nn`` (e.g. "GELU"),
            instantiated with no arguments.
        dropout_rate: dropout probability applied right after the convolution.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int,
                 activation: str = "GELU",
                 dropout_rate: float = 0.0,
                 ):
        super().__init__()
        pad = (kernel_size - stride) // 2
        # Submodules are registered in the order conv, drop, norm, activ;
        # keep it so state_dict / parameter ordering stays stable.
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=pad)
        self.drop = nn.Dropout(dropout_rate)
        self.norm = nn.LayerNorm(out_channels)
        activation_cls = getattr(nn, activation)
        self.activ = activation_cls()

    def forward(self, x: torch.Tensor):
        """Apply conv, dropout, layer norm and activation.

        Args:
            x: (b, t, in_channels)
        Returns:
            (b, t', out_channels); t' follows from kernel_size/stride/padding.
        """
        # Conv1d is channel-first: go to (b, c, t) for the conv, then back.
        y = self.conv(x.transpose(2, 1)).transpose(2, 1)
        return self.activ(self.norm(self.drop(y)))
class ResidualConvLayer(nn.Module):
    """A stack of stride-1 ``ConvLayer``s wrapped in a single residual skip.

    Computes ``x + layers(x)``. The sum only works when every ConvLayer keeps
    the time length (stride 1 with an odd ``kernel_size``).

    Args:
        hidden_channels: channel count of the input, every layer and output.
        n_layers: number of stacked ConvLayers (0 gives ``x + x``).
        kernel_size: temporal kernel width of each ConvLayer.
        activation: ``torch.nn`` activation class name.
        dropout_rate: dropout probability inside each ConvLayer.
    """

    def __init__(self,
                 hidden_channels: int,
                 n_layers: int = 2,
                 kernel_size: int = 5,
                 activation: str = "GELU",
                 dropout_rate: float = 0.0,
                 ):
        super().__init__()
        stack = (
            ConvLayer(
                in_channels=hidden_channels,
                out_channels=hidden_channels,
                kernel_size=kernel_size,
                stride=1,
                activation=activation,
                dropout_rate=dropout_rate,
            )
            for _ in range(n_layers)
        )
        self.layers = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor):
        """Add the conv-stack output to its input.

        Args:
            x: (b, t, hidden_channels)
        Returns:
            tensor with the same shape as ``x``.
        """
        skip = x
        transformed = self.layers(x)
        return skip + transformed
class ResidualConvBlock(nn.Module):
    """Residual conv stack -> optional middle layer -> residual conv stack.

    Data flow in ``forward``::

        in_proj -> conv1 (n_blocks x ResidualConvLayer) -> middle_layer
                -> conv2 (n_blocks x ResidualConvLayer) -> out_proj

    ``in_proj`` / ``out_proj`` are Conv1d projections. They are replaced by
    ``nn.Identity`` when the channel counts already match. Their padding
    ``(kernel_size - 1) // 2`` keeps the time length only for odd kernel sizes.

    Args:
        in_channels: channels of the input ``x``.
        hidden_channels: channels used inside the block.
        out_channels: channels of the output.
        n_layers: ConvLayers per ResidualConvLayer.
        n_blocks: ResidualConvLayers in each of ``conv1`` and ``conv2``.
        middle_layer: module applied between the two stacks, or None for
            identity. Instances of the types in ``_CHANNEL_FIRST_MIDDLE``
            receive a (b, c, t) tensor. Any other module receives (b, t, c)
            plus the ``**middle_layer_kwargs`` given to ``forward``.
        kernel_size: temporal kernel width for projections and conv layers.
        activation: ``torch.nn`` activation class name.
        dropout_rate: dropout probability inside each ConvLayer.

    Raises:
        TypeError: if ``middle_layer`` is neither None nor an ``nn.Module``.
    """

    # Middle-layer types that follow torch's channel-first (b, c, t) layout.
    # AvgPool1d / ConvTranspose1d are listed alongside MaxPool1d / Conv1d so
    # that they resample the time axis instead of being applied to channels.
    _CHANNEL_FIRST_MIDDLE = (nn.MaxPool1d, nn.AvgPool1d, nn.Conv1d, nn.ConvTranspose1d)

    def __init__(self,
                 in_channels: int,
                 hidden_channels: int,
                 out_channels: int,
                 n_layers: int = 2,
                 n_blocks: int = 5,
                 middle_layer: tp.Optional[nn.Module] = None,
                 kernel_size: int = 5,
                 activation: str = "GELU",
                 dropout_rate: float = 0.0,
                 ):
        super().__init__()
        # Creation order (in_proj, conv1, middle, conv2, out_proj) is kept as-is:
        # it fixes parameter init RNG consumption and state_dict ordering.
        self.in_proj = nn.Conv1d(
            in_channels,
            hidden_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        ) if in_channels != hidden_channels else nn.Identity()
        self.conv1 = nn.Sequential(*[
            ResidualConvLayer(hidden_channels, n_layers, kernel_size, activation, dropout_rate)
            for _ in range(n_blocks)
        ])
        if middle_layer is None:
            self.middle_layer = nn.Identity()
        elif isinstance(middle_layer, nn.Module):
            self.middle_layer = middle_layer
        else:
            raise TypeError("unknown middle layer type:{}".format(type(middle_layer)))
        self.conv2 = nn.Sequential(*[
            ResidualConvLayer(hidden_channels, n_layers, kernel_size, activation, dropout_rate)
            for _ in range(n_blocks)
        ])
        self.out_proj = nn.Conv1d(
            hidden_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        ) if out_channels != hidden_channels else nn.Identity()

    def forward(self, x: torch.Tensor, **middle_layer_kwargs):
        """Run the block.

        Args:
            x: (b, t1, in_channels)
            **middle_layer_kwargs: extra inputs forwarded only to a custom
                middle layer (not to channel-first or identity middle layers).
        Returns:
            (b, t2, out_channels); t2 differs from t1 only if the middle
            layer (or an even kernel_size) changes the time length.
        """
        # Projections are Conv1d (or Identity): run them channel-first.
        x = self.in_proj(x.transpose(2, 1)).transpose(2, 1)
        x = self.conv1(x)
        if isinstance(self.middle_layer, self._CHANNEL_FIRST_MIDDLE):
            x = self.middle_layer(x.transpose(2, 1)).transpose(2, 1)
        elif isinstance(self.middle_layer, nn.Identity):
            x = self.middle_layer(x)
        else:
            # e.g. a phoneme-pooling layer that consumes extra keyword inputs
            x = self.middle_layer(x, **middle_layer_kwargs)
        x = self.conv2(x)
        x = self.out_proj(x.transpose(2, 1)).transpose(2, 1)
        return x
class MelReduceEncoder(nn.Module):
    """Encode a channel-last sequence while shrinking its time axis.

    Wraps a ResidualConvBlock whose middle layer is a Conv1d with
    ``kernel_size == stride == reduction_rate`` and no padding. Each output
    frame therefore covers ``reduction_rate`` consecutive frames. When the
    surrounding convolutions keep the length (odd ``kernel_size``) and
    ``t >= reduction_rate``, the output length is ``floor(t / reduction_rate)``.
    Trailing frames that do not fill a whole window are dropped.
    The input is presumably mel-spectrogram frames (per the class name); confirm
    against callers.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        hidden_channels: channels used inside the encoder.
        reduction_rate: time-downsampling factor (also stored on the module).
        n_layers: ConvLayers per residual layer.
        n_blocks: residual layers before and after the downsampling conv.
        kernel_size: temporal kernel width of the residual convolutions.
        activation: ``torch.nn`` activation class name.
        dropout: dropout probability inside each ConvLayer.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 hidden_channels: int = 384,
                 reduction_rate: int = 4,
                 n_layers: int = 2,
                 n_blocks: int = 5,
                 kernel_size: int = 3,
                 activation: str = "GELU",
                 dropout: float = 0.0,
                 ):
        super().__init__()
        self.reduction_rate = reduction_rate
        # Non-overlapping strided conv used as the block's downsampling layer.
        downsample = nn.Conv1d(hidden_channels, hidden_channels,
                               reduction_rate, reduction_rate, 0)
        self.encoder = ResidualConvBlock(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            n_layers=n_layers,
            n_blocks=n_blocks,
            middle_layer=downsample,
            kernel_size=kernel_size,
            activation=activation,
            dropout_rate=dropout,
        )

    def forward(self, x: torch.Tensor):
        """Encode ``x``.

        Args:
            x: (b, t, in_channels)
        Returns:
            (b, t_reduced, out_channels)
        """
        encoded = self.encoder(x)
        return encoded