import warnings

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Add, Activation, Dropout, Flatten, Dense
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.regularizers import l2

# Silence every warning for the whole process as soon as this module is
# imported.
# NOTE(review): this filter is global, so it also hides warnings raised by
# unrelated code in the importing application; narrow it (category= / module=)
# if that turns out to be unwanted.
warnings.filterwarnings("ignore")


class WideResidualNetwork(object):
    """Builder for a Wide Residual Network (WRN) Keras model.

    The network is an initial 3x3 convolution followed by three stages with
    ``16 * k``, ``32 * k`` and ``64 * k`` filters.  Every stage opens with
    ``expand_conv`` and continues with ``N - 1`` residual blocks; the second
    and third stages downsample with stride 2.  Call
    ``create_wide_residual_network`` to obtain the (uncompiled) ``Model``.
    """

    # Settings shared by every BatchNormalization layer in the network.
    # NOTE(review): in Keras ``momentum`` is the decay of the moving averages
    # (library default 0.99); 0.1 looks like the PyTorch convention.  It is
    # left untouched because changing it alters the inference statistics of
    # the model - confirm the intent before editing.
    _BN_MOMENTUM = 0.1
    _BN_EPSILON = 1e-5

    def __init__(
        self,
        input_dim,
        weight_decay,
        momentum,
        nb_classes=100,
        N=2,
        k=1,
        dropout=0.0,
        verbose=1,
    ):
        """Store the hyper-parameters of the wide residual network.

        Args:
            input_dim (tuple): shape handed to ``Input`` (batch axis excluded).
            weight_decay (float): L2 factor applied to every convolution
                kernel and to the final dense layer.
            momentum (float): kept on the instance as ``self.momentum``.  It
                is not used while building layers (presumably the SGD
                momentum for whoever compiles the model - TODO confirm).
            nb_classes (int, optional): number of softmax outputs.
                Defaults to 100.
            N (int, optional): blocks per stage; each stage is one
                ``expand_conv`` plus ``N - 1`` residual blocks. Defaults to 2.
            k (int, optional): widening factor multiplying the base filter
                counts 16 / 32 / 64. Defaults to 1.
            dropout (float, optional): dropout rate used inside the residual
                blocks; values <= 0 disable dropout. Defaults to 0.0.
            verbose (int, optional): when truthy, a one-line summary is
                printed once the model has been built. Defaults to 1.
        """
        self.weight_decay = weight_decay
        self.input_dim = input_dim
        # Previously accepted but silently discarded; keep it so callers can
        # read it back from the instance.
        self.momentum = momentum
        self.nb_classes = nb_classes
        self.N = N
        self.k = k
        self.dropout = dropout
        self.verbose = verbose

    # ------------------------------------------------------------------
    # Private layer factories shared by all public building blocks.
    # ------------------------------------------------------------------
    @staticmethod
    def _channel_axis():
        """Return the channel axis for the active Keras image data format."""
        return 1 if K.image_data_format() == "channels_first" else -1

    def _conv(self, filters, kernel_size, strides=(1, 1)):
        """Create a bias-free, he_normal, L2-regularised convolution layer."""
        return Convolution2D(
            filters,
            kernel_size,
            padding="same",
            strides=strides,
            kernel_initializer="he_normal",
            kernel_regularizer=l2(self.weight_decay),
            use_bias=False,
        )

    def _bn_relu(self, x):
        """Apply batch normalisation followed by a ReLU activation to ``x``."""
        x = BatchNormalization(
            axis=self._channel_axis(),
            momentum=self._BN_MOMENTUM,
            epsilon=self._BN_EPSILON,
            gamma_initializer="uniform",
        )(x)
        return Activation("relu")(x)

    def _residual_block(self, input, filters, dropout):
        """Build a residual block: two BN-ReLU-conv3x3 units plus identity skip.

        Dropout is inserted between the two units only when ``dropout`` is
        greater than zero.  The block input is added back unchanged, so it
        must already have ``filters`` channels.
        """
        x = self._bn_relu(input)
        x = self._conv(filters, (3, 3))(x)
        if dropout > 0.0:
            x = Dropout(dropout)(x)
        x = self._bn_relu(x)
        x = self._conv(filters, (3, 3))(x)
        return Add()([input, x])

    # ------------------------------------------------------------------
    # Public building blocks (signatures kept for existing callers).
    # ------------------------------------------------------------------
    def initial_conv(self, input):
        """Apply the stem: a 16-filter 3x3 convolution, then BN and ReLU.

        Args:
            input: tensor to convolve (the model input in
                ``create_wide_residual_network``).

        Returns:
            The activated output tensor of the stem.
        """
        x = self._conv(16, (3, 3))(input)
        return self._bn_relu(x)

    def expand_conv(self, init, base, k, strides=(1, 1)):
        """Open a stage by changing the width (and optionally the resolution).

        The main path is conv3x3 (strided) -> BN -> ReLU -> conv3x3; the
        shortcut is a strided 1x1 convolution of ``init`` so both paths have
        ``base * k`` channels before they are added.

        Args:
            init: input tensor of the stage.
            base (int): base filter count of the stage.
            k (int): widening factor; the layers use ``base * k`` filters.
            strides (tuple, optional): stride of the first 3x3 convolution
                and of the 1x1 shortcut. Defaults to (1, 1).

        Returns:
            The sum of the main path and the projected shortcut.
        """
        filters = base * k
        x = self._conv(filters, (3, 3), strides=strides)(init)
        x = self._bn_relu(x)
        x = self._conv(filters, (3, 3))(x)
        skip = self._conv(filters, (1, 1), strides=strides)(init)
        return Add()([x, skip])

    def conv1_block(self, input, k=1, dropout=0.0):
        """Build a first-stage residual block with ``16 * k`` filters.

        Args:
            input: input tensor; it is added back to the block output.
            k (int, optional): widening factor. Defaults to 1.
            dropout (float, optional): dropout rate between the two
                convolutions; disabled when <= 0. Defaults to 0.0.

        Returns:
            The sum of ``input`` and the residual branch.
        """
        return self._residual_block(input, 16 * k, dropout)

    def conv2_block(self, input, k=1, dropout=0.0):
        """Build a second-stage residual block with ``32 * k`` filters.

        Args:
            input: input tensor; it is added back to the block output.
            k (int, optional): widening factor. Defaults to 1.
            dropout (float, optional): dropout rate between the two
                convolutions; disabled when <= 0. Defaults to 0.0.

        Returns:
            The sum of ``input`` and the residual branch.
        """
        return self._residual_block(input, 32 * k, dropout)

    def conv3_block(self, input, k=1, dropout=0.0):
        """Build a third-stage residual block with ``64 * k`` filters.

        Args:
            input: input tensor; it is added back to the block output.
            k (int, optional): widening factor. Defaults to 1.
            dropout (float, optional): dropout rate between the two
                convolutions; disabled when <= 0. Defaults to 0.0.

        Returns:
            The sum of ``input`` and the residual branch.
        """
        return self._residual_block(input, 64 * k, dropout)

    def create_wide_residual_network(self):
        """Assemble the full wide residual network.

        Returns:
            Model: Keras model mapping an ``input_dim`` tensor to
            ``nb_classes`` softmax probabilities.  The model is not compiled.
        """
        ip = Input(shape=self.input_dim)
        x = self.initial_conv(ip)

        # Running count used only for the label printed below; it starts at 4
        # and grows by 2 for each expand/residual block (6 * N + 4 for N >= 1).
        nb_conv = 4

        # (base filters, stride of the stage's first convolution, block builder)
        stages = (
            (16, (1, 1), self.conv1_block),
            (32, (2, 2), self.conv2_block),
            (64, (2, 2), self.conv3_block),
        )
        for base, strides, block in stages:
            x = self.expand_conv(x, base, self.k, strides=strides)
            nb_conv += 2

            for _ in range(self.N - 1):
                x = block(x, self.k, self.dropout)
                nb_conv += 2

            # Neither expand_conv nor the residual blocks activate their sum,
            # so every stage is closed with BN + ReLU.
            x = self._bn_relu(x)

        # Fixed 8x8 window: with two stride-2 stages and "same" padding this
        # spans the whole feature map for 32x32 inputs.
        x = AveragePooling2D((8, 8))(x)
        x = Flatten()(x)

        x = Dense(
            self.nb_classes,
            kernel_regularizer=l2(self.weight_decay),
            activation="softmax",
        )(x)

        model = Model(ip, x)

        if self.verbose:
            print("Wide Residual Network-%d-%d created." % (nb_conv, self.k))
        return model