import torch
from torch import nn
from typing import Callable, Optional
class BatchNormConv1d(nn.Module):
    """An nn.Conv1d followed by an optional activation function and nn.BatchNorm1d."""

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        kernel_size: int,
        stride: int,
        padding: int,
        activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ):
        super().__init__()
        # Bias is disabled: BatchNorm1d's learned shift makes a conv bias redundant.
        self.conv1d = nn.Conv1d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm1d(out_dim)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1d(x)
        if self.activation is not None:
            x = self.activation(x)
        return self.bn(x)
class LinearNorm(torch.nn.Module):
    """A torch.nn.Linear layer with Xavier-uniform weight initialization."""

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super().__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)
class ConvNorm(torch.nn.Module):
    """A torch.nn.Conv1d layer with Xavier-uniform weight initialization."""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super().__init__()
        if padding is None:
            # Default to "same" padding, which requires an odd kernel size.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        return self.conv(signal)
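
# The smoke test below is a minimal usage sketch, not part of the original
# module; the batch size, channel counts, and sequence length are arbitrary
# assumptions chosen only to illustrate the expected (N, C, T) input layout.
if __name__ == "__main__":
    x = torch.randn(8, 80, 100)  # (batch, channels, time)

    # Conv -> ReLU -> BatchNorm; padding=1 with kernel_size=3 preserves T.
    bn_conv = BatchNormConv1d(80, 256, kernel_size=3, stride=1, padding=1,
                              activation=torch.relu)
    print(bn_conv(x).shape)  # torch.Size([8, 256, 100])

    # ConvNorm computes "same" padding itself for odd kernel sizes.
    conv = ConvNorm(80, 256, kernel_size=5)
    print(conv(x).shape)  # torch.Size([8, 256, 100])

    # LinearNorm acts on the last dimension, so transpose to (N, T, C) first.
    linear = LinearNorm(80, 128)
    print(linear(x.transpose(1, 2)).shape)  # torch.Size([8, 100, 128])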