# Sefika's picture
# Upload 9 files
# 9ba9ac1
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Add, Activation, Dropout, Flatten, Dense
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.regularizers import l2
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import SGD
import warnings
from constraint import tight_frame
from convexity_constraint import convex_add
warnings.filterwarnings("ignore")
class ParsevalNetwork(Model):
    """Builder for a wide residual network with constrained convolutions.

    The instance only stores hyper-parameters; the actual Keras graph is
    produced by :meth:`create_wide_residual_network`, which returns a separate
    functional ``Model``. Every convolution is created with an orthogonal
    initializer, L2 weight decay and the project's ``tight_frame`` kernel
    constraint, and identity-shaped residual blocks are merged with the
    project's ``convex_add`` helper.
    """

    def __init__(
        self,
        input_dim,
        weight_decay,
        momentum,
        nb_classes=4,
        N=2,
        k=1,
        dropout=0.0,
        verbose=1,
    ):
        """Store the hyper-parameters of the network.

        Args:
            input_dim (tuple): shape handed to ``Input`` when the graph is built.
            weight_decay (float): L2 factor applied to every convolution
                kernel and to the final dense kernel.
            momentum (float): kept on the instance for callers; no layer built
                by this class reads it (presumably the optimizer momentum --
                TODO confirm against the training script).
            nb_classes (int, optional): units of the softmax layer. Defaults to 4.
            N (int, optional): each stage has one expanding block followed by
                ``N - 1`` residual blocks. Defaults to 2.
            k (int, optional): width multiplier applied to the base filter
                counts 16/32/64. Defaults to 1.
            dropout (float, optional): dropout rate inside the residual
                blocks; no ``Dropout`` layer is added when it is 0.
                Defaults to 0.0.
            verbose (int, optional): when truthy, print a one-line summary
                after the model is created. Defaults to 1.
        """
        # Keras requires the base initializer to run before a subclassed
        # Model is used; the original skipped it.
        super().__init__()
        self.weight_decay = weight_decay
        self.input_dim = input_dim
        # Previously accepted but silently discarded.
        self.momentum = momentum
        self.nb_classes = nb_classes
        self.N = N
        self.k = k
        self.dropout = dropout
        self.verbose = verbose

    @staticmethod
    def _bn_axis():
        """Return the channel axis for the active Keras image data format."""
        return 1 if K.image_data_format() == "channels_first" else -1

    def _batch_norm_relu(self, x):
        """Apply the network's standard BatchNormalization followed by ReLU."""
        # NOTE(review): in Keras `momentum` is the decay of the moving
        # statistics, so 0.1 lets the current batch dominate them. Value kept
        # as in the original -- confirm it is intended.
        x = BatchNormalization(
            axis=self._bn_axis(),
            momentum=0.1,
            epsilon=1e-5,
            gamma_initializer="uniform",
        )(x)
        return Activation("relu")(x)

    def _tight_frame_conv(self, x, filters, kernel_size=(3, 3), strides=(1, 1)):
        """Apply a bias-free 'same' convolution with the shared kernel settings."""
        return Convolution2D(
            filters,
            kernel_size,
            padding="same",
            strides=strides,
            kernel_initializer="orthogonal",
            kernel_regularizer=l2(self.weight_decay),
            kernel_constraint=tight_frame(0.001),
            use_bias=False,
        )(x)

    def _residual_block(self, inputs, filters, dropout):
        """Build BN-ReLU-conv twice and merge the result with ``inputs``.

        Shared implementation of ``conv1_block``/``conv2_block``/``conv3_block``,
        which differ only in ``filters``.
        """
        x = self._batch_norm_relu(inputs)
        x = self._tight_frame_conv(x, filters)
        if dropout > 0.0:
            x = Dropout(dropout)(x)
        x = self._batch_norm_relu(x)
        x = self._tight_frame_conv(x, filters)
        # convex_add is a project helper; it receives the block input and the
        # residual branch with a trainable parameter initialised to 0.5.
        return convex_add(inputs, x, initial_convex_par=0.5, trainable=True)

    def initial_conv(self, input):
        """Build the stem: a 16-filter 3x3 convolution, BatchNorm and ReLU.

        Args:
            input: Keras tensor to convolve.
        Returns:
            Keras tensor produced by the ReLU activation.
        """
        x = self._tight_frame_conv(input, 16)
        return self._batch_norm_relu(x)

    def expand_conv(self, init, base, k, strides=(1, 1)):
        """Build a block that changes the channel count to ``base * k``.

        The main branch is conv -> BatchNorm -> ReLU -> conv, the first
        convolution using ``strides``. The shortcut is a 1x1 convolution with
        the same ``strides`` so both branches have matching shapes before
        they are summed.

        Args:
            init: Keras tensor entering the block.
            base (int): base number of filters.
            k (int): width multiplier.
            strides (tuple, optional): strides of the first main-branch
                convolution and of the shortcut. Defaults to (1, 1).
        Returns:
            Keras tensor: element-wise sum of main branch and shortcut.
        """
        filters = base * k
        x = self._tight_frame_conv(init, filters, strides=strides)
        x = self._batch_norm_relu(x)
        x = self._tight_frame_conv(x, filters)
        # Shortcut is created after the main branch to keep the original
        # layer creation order (and thus auto-generated layer names).
        skip = self._tight_frame_conv(
            init, filters, kernel_size=(1, 1), strides=strides
        )
        return Add()([x, skip])

    def conv1_block(self, input, k=1, dropout=0.0):
        """Build a residual block with ``16 * k`` filters.

        Args:
            input: Keras tensor entering the block.
            k (int, optional): width multiplier. Defaults to 1.
            dropout (float, optional): dropout rate between the two
                convolutions; skipped when 0. Defaults to 0.0.
        Returns:
            Keras tensor returned by ``convex_add``.
        """
        return self._residual_block(input, 16 * k, dropout)

    def conv2_block(self, input, k=1, dropout=0.0):
        """Build a residual block with ``32 * k`` filters.

        Args:
            input: Keras tensor entering the block.
            k (int, optional): width multiplier. Defaults to 1.
            dropout (float, optional): dropout rate between the two
                convolutions; skipped when 0. Defaults to 0.0.
        Returns:
            Keras tensor returned by ``convex_add``.
        """
        return self._residual_block(input, 32 * k, dropout)

    def conv3_block(self, input, k=1, dropout=0.0):
        """Build a residual block with ``64 * k`` filters.

        Args:
            input: Keras tensor entering the block.
            k (int, optional): width multiplier. Defaults to 1.
            dropout (float, optional): dropout rate between the two
                convolutions; skipped when 0. Defaults to 0.0.
        Returns:
            Keras tensor returned by ``convex_add``.
        """
        return self._residual_block(input, 64 * k, dropout)

    def create_wide_residual_network(self):
        """Assemble the full network from the stored hyper-parameters.

        Layout: stem, then three stages (16, 32, 64 base filters; the last two
        downsample with stride 2). Each stage is one ``expand_conv`` followed
        by ``N - 1`` residual blocks and a BatchNorm/ReLU. The head is average
        pooling, flatten and a softmax dense layer.

        Returns:
            Model: functional Keras model mapping the input to class
            probabilities.
        """
        ip = Input(shape=self.input_dim)
        x = self.initial_conv(ip)
        # Running count used only for the summary message.
        nb_conv = 4

        stages = (
            (16, (1, 1), self.conv1_block),
            (32, (2, 2), self.conv2_block),
            (64, (2, 2), self.conv3_block),
        )
        for base, strides, block in stages:
            x = self.expand_conv(x, base, self.k, strides=strides)
            nb_conv += 2
            for _ in range(self.N - 1):
                x = block(x, self.k, self.dropout)
                nb_conv += 2
            x = self._batch_norm_relu(x)

        # NOTE(review): the 8x8 window is fixed; after the two stride-2 stages
        # it spans the whole feature map only for 32x32 inputs -- confirm the
        # expected input size.
        x = AveragePooling2D((8, 8))(x)
        x = Flatten()(x)
        x = Dense(
            self.nb_classes,
            kernel_regularizer=l2(self.weight_decay),
            activation="softmax",
        )(x)
        model = Model(ip, x)

        if self.verbose:
            print("Parseval Network-%d-%d created." % (nb_conv, self.k))
        return model