# Source: ITO-Master / networks/network_utils.py
# (repository page residue: "jhtonyKoo's picture", commit message "revise", revision 6fc042a)
"""
Utility File
containing functions for neural networks
"""
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch
import torchaudio
# 2-dimensional convolutional layer
# in the order of conv -> norm -> activation
class Conv2d_layer(nn.Module):
    """2-D (transposed) convolution followed by optional norm and activation.

    The sub-modules are registered on ``self.conv2d`` (an ``nn.Sequential``)
    in the order: padding -> conv -> norm -> activation.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: int, or a (height, width) pair.
        stride: int, or a (height, width) pair.
        padding: only used when ``mode == "conv"``; the padding is applied
            with ``nn.ReflectionPad2d`` and the convolution itself uses
            ``padding=0``. Accepted values:

            * ``"SAME"``  - pad ``(kernel - 1) * dilation`` samples per axis,
              split between both sides (the odd sample goes to the
              right / bottom side),
            * ``"VALID"`` - no padding,
            * int         - pad every side by that amount,
            * (height, width) pair - symmetric padding per axis,
            * 4-sequence  - forwarded as ``(left, right, top, bottom)``.

            In ``"deconv"`` mode this argument is ignored and the padding is
            derived from the kernel size.
        dilation: int, or a (height, width) pair.
        bias: whether the convolution has a bias term.
        norm: ``"batch"`` appends ``nn.BatchNorm2d``; any other value appends
            no normalization.
        activation: ``"relu"`` or ``"lrelu"``; any other value appends no
            activation.
        mode: ``"conv"`` (``nn.Conv2d``) or ``"deconv"``
            (``nn.ConvTranspose2d``).

    Raises:
        ValueError: if ``mode`` or ``padding`` is not one of the supported
            values.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1,
                 padding="SAME", dilation=(1, 1), bias=True,
                 norm="batch", activation="relu",
                 mode="conv"):
        super(Conv2d_layer, self).__init__()
        if mode not in ("conv", "deconv"):
            # Previously an unknown mode silently produced a layer made of
            # normalization/activation only.
            raise ValueError(f"unsupported mode for Conv2d_layer: {mode!r}")

        self.conv2d = nn.Sequential()

        kernel_size = self._as_pair(kernel_size)
        stride = self._as_pair(stride)
        dilation = self._as_pair(dilation)

        ''' convolutional layer (with its padding) '''
        if mode == "deconv":
            # NOTE(review): unlike Conv1d_layer, this padding does not take
            # the dilation into account - kept as in the original code.
            deconv_padding = tuple((k - 1) // 2 for k in kernel_size)
            out_padding = tuple(0 if s == 1 else 1 for s in stride)
            self.conv2d.add_module("deconv2d", nn.ConvTranspose2d(
                in_channels, out_channels,
                (kernel_size[0], kernel_size[1]),
                stride=stride,
                padding=deconv_padding, output_padding=out_padding,
                dilation=dilation,
                bias=bias))
        else:
            padding_area = self._reflection_padding(padding, kernel_size, dilation)
            # Module names are part of the state_dict keys - do not rename.
            self.conv2d.add_module(f"{mode}2d_pad", nn.ReflectionPad2d(padding_area))
            self.conv2d.add_module(f"{mode}2d", nn.Conv2d(
                in_channels, out_channels,
                (kernel_size[0], kernel_size[1]),
                stride=stride,
                padding=0,
                dilation=dilation,
                bias=bias))

        ''' normalization '''
        if norm == "batch":
            self.conv2d.add_module("batch_norm", nn.BatchNorm2d(out_channels))

        ''' activation '''
        if activation == "relu":
            self.conv2d.add_module("relu", nn.ReLU())
        elif activation == "lrelu":
            self.conv2d.add_module("lrelu", nn.LeakyReLU())

    @staticmethod
    def _as_pair(value):
        """Expand an int to a 2-tuple; convert any other sequence to a tuple."""
        if isinstance(value, int):
            return (value, value)
        return tuple(value)

    @staticmethod
    def _reflection_padding(padding, kernel_size, dilation):
        """Return the (left, right, top, bottom) tuple for nn.ReflectionPad2d."""
        if isinstance(padding, str):
            if padding == "SAME":
                h_pad = (kernel_size[0] - 1) * dilation[0]
                w_pad = (kernel_size[1] - 1) * dilation[1]
                w_left = w_pad // 2
                h_top = h_pad // 2
                # ReflectionPad2d expects the last dimension first.
                return (w_left, w_pad - w_left, h_top, h_pad - h_top)
            if padding == "VALID":
                return (0, 0, 0, 0)
            raise ValueError(f"unsupported padding for Conv2d_layer: {padding!r}")
        if isinstance(padding, int):
            return (padding, padding, padding, padding)
        padding = tuple(padding)
        if len(padding) == 2:
            pad_h, pad_w = padding
            return (pad_w, pad_w, pad_h, pad_h)
        if len(padding) == 4:
            return padding
        raise ValueError(f"unsupported padding for Conv2d_layer: {padding!r}")

    def forward(self, input):
        # input shape should be : batch x channel x height x width
        output = self.conv2d(input)
        return output
# 1-dimensional convolutional layer
# in the order of conv -> norm -> activation
class Conv1d_layer(nn.Module):
    """1-D (transposed / alias-free) convolution with optional norm and activation.

    The sub-modules are registered on ``self.conv1d`` (an ``nn.Sequential``).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: int kernel length.
        stride: int stride. In the alias-free modes the convolution itself
            runs with stride 1 and ``stride`` sets the resampling factors.
        padding: used in every mode except ``"deconv"``; applied with
            ``nn.ReflectionPad1d`` while the convolution uses ``padding=0``.
            Accepted values:

            * ``"SAME"``  - pad ``(kernel_size - 1) * dilation`` samples,
              split between both sides (the odd sample goes to the right),
            * ``"VALID"`` - no padding,
            * int         - pad both sides by that amount,
            * (left, right) pair - forwarded as is.

            In ``"deconv"`` mode this argument is ignored and the padding is
            ``int(dilation * (kernel_size - 1) / 2)``.
        dilation: int dilation.
        bias: whether the convolution has a bias term.
        norm: ``"batch"`` appends ``nn.BatchNorm1d``; any other value appends
            no normalization.
        activation: ``"relu"`` or ``"lrelu"``; any other value appends no
            activation. Ignored in the alias-free modes, which always use a
            LeakyReLU between the two resampling stages.
        mode: ``"conv"``, ``"deconv"``, or a string containing
            ``"alias_free"`` together with ``"up"`` or ``"down"``.

    Raises:
        ValueError: if ``mode`` or ``padding`` is not supported, or if an
            alias-free mode names neither ``"up"`` nor ``"down"``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1,
                 padding="SAME", dilation=1, bias=True,
                 norm="batch", activation="relu",
                 mode="conv"):
        super(Conv1d_layer, self).__init__()
        alias_free = "alias_free" in mode
        if mode not in ("conv", "deconv") and not alias_free:
            # Previously an unknown mode silently produced a layer made of
            # normalization/activation only.
            raise ValueError(f"unsupported mode for Conv1d_layer: {mode!r}")

        self.conv1d = nn.Sequential()

        ''' convolutional layer (with its padding) '''
        # Module names below are part of the state_dict keys - do not rename.
        if mode == "deconv":
            deconv_padding = int(dilation * (kernel_size - 1) / 2)
            out_padding = 0 if stride == 1 else 1
            self.conv1d.add_module("deconv1d", nn.ConvTranspose1d(
                in_channels, out_channels, kernel_size,
                stride=stride, padding=deconv_padding, output_padding=out_padding,
                dilation=dilation,
                bias=bias))
        elif mode == "conv":
            padding_area = self._reflection_padding(padding, kernel_size, dilation)
            self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area))
            self.conv1d.add_module(f"{mode}1d", nn.Conv1d(
                in_channels, out_channels, kernel_size,
                stride=stride, padding=0,
                dilation=dilation,
                bias=bias))
        else:
            padding_area = self._reflection_padding(padding, kernel_size, dilation)
            if "up" in mode:
                up_factor = stride * 2
                down_factor = 2
            elif "down" in mode:
                up_factor = 2
                down_factor = stride * 2
            else:
                raise ValueError("choose alias-free method : 'up' or 'down'")
            # procedure : conv -> upsample -> lrelu -> low-pass filter -> downsample
            # the torchaudio.transforms.Resample's default resampling_method is 'sinc_interpolation' which performs low-pass filter during the process
            # details at https://pytorch.org/audio/stable/transforms.html
            self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area))
            self.conv1d.add_module(f"{mode}1d", nn.Conv1d(
                in_channels, out_channels, kernel_size,
                stride=1, padding=0,
                dilation=dilation,
                bias=bias))
            self.conv1d.add_module(f"{mode}upsample", torchaudio.transforms.Resample(orig_freq=1, new_freq=up_factor))
            self.conv1d.add_module(f"{mode}lrelu", nn.LeakyReLU())
            self.conv1d.add_module(f"{mode}downsample", torchaudio.transforms.Resample(orig_freq=down_factor, new_freq=1))

        ''' normalization '''
        if norm == "batch":
            self.conv1d.add_module("batch_norm", nn.BatchNorm1d(out_channels))

        ''' activation '''
        # The alias-free path already contains its own LeakyReLU.
        if not alias_free:
            if activation == "relu":
                self.conv1d.add_module("relu", nn.ReLU())
            elif activation == "lrelu":
                self.conv1d.add_module("lrelu", nn.LeakyReLU())

    @staticmethod
    def _reflection_padding(padding, kernel_size, dilation):
        """Return the (left, right) tuple for nn.ReflectionPad1d."""
        if isinstance(padding, str):
            if padding == "SAME":
                total = int((kernel_size - 1) * dilation)
                left = total // 2
                return (left, total - left)
            if padding == "VALID":
                return (0, 0)
            raise ValueError(f"unsupported padding for Conv1d_layer: {padding!r}")
        if isinstance(padding, int):
            return (padding, padding)
        padding = tuple(padding)
        if len(padding) == 2:
            return padding
        raise ValueError(f"unsupported padding for Conv1d_layer: {padding!r}")

    def forward(self, input):
        # input shape should be : batch x channel x length
        output = self.conv1d(input)
        return output
# Residual Block
# the input is added after the first convolutional layer, retaining its original channel size
# therefore, the second convolutional layer's output channel may differ
class Res_ConvBlock(nn.Module):
    """Two-layer convolutional block with a residual connection around the first layer.

    ``conv1`` maps ``in_channels -> in_channels`` (default stride and mode)
    and its output is summed with the block input; ``conv2`` then maps
    ``in_channels -> out_channels`` and is the only layer that receives
    ``stride`` and ``mode``.

    Args:
        dimension: 1 to build the block from ``Conv1d_layer``, 2 for
            ``Conv2d_layer``.
        in_channels: channels of the input (and of the first layer's output).
        out_channels: channels produced by the second layer.
        kernel_size: kernel size forwarded to both layers.
        stride: stride of the second layer.
        padding: padding forwarded to both layers. The residual sum requires
            the first layer to keep the input size, so a size-changing
            padding will make ``forward`` fail on the addition.
        dilation: dilation forwarded to both layers.
        bias: bias flag forwarded to both layers.
        norm: normalization forwarded to both layers.
        activation: activation of the first layer.
        last_activation: activation of the second layer.
        mode: mode of the second layer.

    Raises:
        ValueError: if ``dimension`` is neither 1 nor 2.
    """

    def __init__(self, dimension,
                 in_channels, out_channels,
                 kernel_size,
                 stride=1, padding="SAME",
                 dilation=1,
                 bias=True,
                 norm="batch",
                 activation="relu", last_activation="relu",
                 mode="conv"):
        super(Res_ConvBlock, self).__init__()
        if dimension not in (1, 2):
            # Without this check the block was created without conv1/conv2
            # and only failed later, inside forward().
            raise ValueError(f"dimension must be 1 or 2, got {dimension!r}")

        layer_cls = Conv1d_layer if dimension == 1 else Conv2d_layer
        shared = dict(padding=padding, dilation=dilation, bias=bias, norm=norm)
        self.conv1 = layer_cls(in_channels, in_channels, kernel_size,
                               activation=activation, **shared)
        self.conv2 = layer_cls(in_channels, out_channels, kernel_size,
                               stride=stride, activation=last_activation,
                               mode=mode, **shared)

    def forward(self, input):
        # residual connection around the first layer only
        c1_out = self.conv1(input) + input
        c2_out = self.conv2(c1_out)
        return c2_out
# Convolutional Block
# consists of multiple (number of layer_num) convolutional layers
# only the final convolutional layer outputs the desired 'out_channels'
class ConvBlock(nn.Module):
    """Stack of ``layer_num`` convolutional layers.

    The first ``layer_num - 1`` layers map ``in_channels -> in_channels``
    with default stride and mode; the final layer maps
    ``in_channels -> out_channels`` and is the only one that receives
    ``stride``, ``last_activation`` and ``mode``.

    Args:
        dimension: 1 to build the block from ``Conv1d_layer``, 2 for
            ``Conv2d_layer``.
        layer_num: total number of layers; the final layer is always
            created, even when ``layer_num`` is smaller than 1.
        in_channels: channels of the input and of every intermediate layer.
        out_channels: channels produced by the final layer.
        kernel_size: kernel size forwarded to every layer.
        stride: stride of the final layer.
        padding: padding forwarded to every layer.
        dilation: dilation forwarded to every layer.
        bias: bias flag forwarded to every layer.
        norm: normalization forwarded to every layer.
        activation: activation of the intermediate layers.
        last_activation: activation of the final layer.
        mode: mode of the final layer.

    Raises:
        ValueError: if ``dimension`` is neither 1 nor 2.
    """

    def __init__(self, dimension, layer_num,
                 in_channels, out_channels,
                 kernel_size,
                 stride=1, padding="SAME",
                 dilation=1,
                 bias=True,
                 norm="batch",
                 activation="relu", last_activation="relu",
                 mode="conv"):
        super(ConvBlock, self).__init__()
        if dimension not in (1, 2):
            # Without this check the block silently became an empty
            # nn.Sequential, i.e. an identity mapping.
            raise ValueError(f"dimension must be 1 or 2, got {dimension!r}")

        layer_cls = Conv1d_layer if dimension == 1 else Conv2d_layer
        shared = dict(padding=padding, dilation=dilation, bias=bias, norm=norm)

        conv_block = [
            layer_cls(in_channels, in_channels, kernel_size,
                      activation=activation, **shared)
            for _ in range(layer_num - 1)
        ]
        conv_block.append(
            layer_cls(in_channels, out_channels, kernel_size,
                      stride=stride, activation=last_activation,
                      mode=mode, **shared))
        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, input):
        return self.conv_block(input)
# Feature-wise Linear Modulation
class FiLM(nn.Module):
    """Feature-wise Linear Modulation.

    A single linear layer maps the condition vector to ``2 * feature_len``
    values, which are split into a scale ``r`` and a shift ``b`` and applied
    as ``r * feature + b``.

    Args:
        condition_len: size of the condition vector.
        feature_len: number of feature channels to modulate.
    """

    def __init__(self, condition_len=2048, feature_len=1024):
        super(FiLM, self).__init__()
        self.film_fc = nn.Linear(condition_len, feature_len * 2)
        self.feat_len = feature_len

    def forward(self, feature, condition, sefa=None):
        """Modulate ``feature`` with scale/shift predicted from ``condition``.

        Args:
            feature: tensor to modulate; dim 1 must have ``feature_len``
                entries (assumes batch x feature_len x length - the scale
                and shift get one trailing singleton dimension and rely on
                broadcasting).
            condition: batch x condition_len tensor. It is never modified
                in place.
            sefa: optional ``(index, scale)`` pair. When truthy, the
                condition is shifted along a direction derived from the
                eigendecomposition of the normalized FiLM weight, scaled by
                ``eigenvalue[index] * scale``.

        Returns:
            ``r * feature + b``.
        """
        # SeFA
        if sefa:
            weight = self.film_fc.weight.T
            weight = weight / torch.linalg.norm((weight + 1e-07), dim=0, keepdim=True)
            # torch.eig was removed from PyTorch; torch.linalg.eig is its
            # replacement. weight @ weight.T is real and symmetric, so its
            # eigenpairs are real and only the real parts are kept.
            eigen_values, eigen_vectors = torch.linalg.eig(torch.matmul(weight, weight.T))
            eigen_values = eigen_values.real
            eigen_vectors = eigen_vectors.real
            ####### custom parameters #######
            chosen_eig_idx = sefa[0]
            alpha = eigen_values[chosen_eig_idx] * sefa[1]
            #################################
            # NOTE(review): eigenvectors are stored column-wise, but a ROW is
            # selected here. Kept unchanged for backward compatibility -
            # confirm whether eigen_vectors[:, chosen_eig_idx] was intended.
            An = eigen_vectors[chosen_eig_idx].repeat(condition.shape[0], 1)
            alpha_An = alpha * An
            # Out-of-place add: `condition += ...` would modify the caller's tensor.
            condition = condition + alpha_An
        film_factor = self.film_fc(condition).unsqueeze(-1)
        r, b = torch.split(film_factor, self.feat_len, dim=1)
        return r * feature + b