File size: 6,964 Bytes
29a525e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import tensorflow as tf
def Downsampling(x, filters, kernel_size = (5,5), padding = 'same', stride = 2, multires = False):
    '''
    Downsampling Block: strided Conv2D -> BatchNormalization -> LeakyReLU(0.2).

    Arguments:
        x           : input layer (tf.keras.layer), channels_last
        filters     : total number of output channels (int)
        kernel_size : kernel dimensions of the main convolution (tuple or int), default (5,5)
        padding     : padding type for convolution (string), default same
        stride      : stride for convolution (tuple or int), default 2
        multires    : if truthy, run three parallel strided convolutions
                      (kernel_size, 3x3 and 7x7) and concatenate them along the
                      channel axis. The string 'True' is also accepted (and any
                      other string counts as False), matching the string flags
                      used elsewhere in this file. Default False.
    Returns:
        output : output layer (tf.keras.layer) with exactly `filters` channels
    Raises:
        ValueError : if multires is enabled and filters < 4 (a branch would get 0 filters)
    '''
    # The rest of this file passes flags as the strings 'True'/'False'; the
    # original `== True` / `== False` tests left `conv` unbound for those.
    if isinstance(multires, str):
        multires = multires == 'True'

    def conv2d(n_filters, ksize):
        # All branches share stride/padding so their spatial outputs line up for concatenation.
        return tf.keras.layers.Conv2D(kernel_size = ksize, filters = n_filters, strides = stride,
                                      padding = padding, data_format = "channels_last")(x)

    if not multires:
        conv = conv2d(filters, kernel_size)
    else:
        if filters < 4:
            raise ValueError(
                "multires Downsampling needs filters >= 4, got {!r}".format(filters))
        # Split channels roughly 1/2 (main kernel), 1/4 (3x3), 1/4 (7x7). The 7x7
        # branch takes the remainder so the concatenation always has `filters` channels
        # (identical to the plain 1/2, 1/4, 1/4 split when filters is divisible by 4).
        main_filters = filters // 2
        filters_3 = filters // 4
        filters_7 = filters - main_filters - filters_3
        conv = tf.keras.layers.Concatenate()([
            conv2d(main_filters, kernel_size),
            conv2d(filters_3, (3,3)),
            conv2d(filters_7, (7,7)),
        ])
    bn = tf.keras.layers.BatchNormalization()(conv)
    output = tf.keras.layers.LeakyReLU(0.2)(bn)
    return output
def Upsampling(x , y, filters, res_filts, kernel_size = (5,5), padding = 'same', stride = 2, dropout = 'False', resblock = True, se_block = False):
    '''
    Upsampling Block: Conv2DTranspose -> ReLU -> BatchNormalization [-> Dropout(0.5)],
    then optional skip concatenation and optional Squeeze-and-Excitation.

    Arguments:
        x           : input layer (tf.keras.layer)
        y           : skip-connection layer (tf.keras.layer), or None for no skip
        filters     : number of filters of the transposed convolution (int)
        res_filts   : number of filters of the ResBlock applied to y (int);
                      only used when y is not None and resblock is True
        kernel_size : kernel dimensions (tuple or int), default (5,5)
        padding     : padding type for convolution (string), default same
        stride      : stride for convolution (tuple or int), default 2
        dropout     : apply Dropout with a fixed rate of 0.5. Accepts a bool or the
                      strings 'True'/'False' (any other string counts as False).
                      Default 'False'.
        resblock    : if True, pass y through ResBlock(depth=2) before concatenation
        se_block    : if True, apply SE_Block(r=16) to the block output
    Returns:
        output : output layer (tf.keras.layer); when y is given, y's channels come
                 first, followed by the upsampled channels
    '''
    # Historically this flag was the string 'True'/'False'; a real bool (as the
    # docstring always advertised) was silently ignored. Accept both forms.
    if isinstance(dropout, str):
        use_dropout = dropout == 'True'
    else:
        use_dropout = bool(dropout)

    conv = tf.keras.layers.Conv2DTranspose(kernel_size = kernel_size, filters = filters, strides = stride, padding = padding, data_format = "channels_last")(x)
    act = tf.keras.layers.ReLU()(conv)
    output = tf.keras.layers.BatchNormalization()(act)
    if use_dropout:
        output = tf.keras.layers.Dropout(0.5)(output)
    if y is not None:
        if resblock is True:
            y = ResBlock(y, depth = 2, filters = res_filts)
        output = tf.keras.layers.Concatenate()([y, output])
    if se_block is True:
        output = SE_Block(output, r = 16)
    return output
def ResBlock(x, filters, depth = 2, kernel_size = (5,5), padding = 'same', method = 'concat', se_block = False):
    '''
    ResNet-style Block: a stack of (Conv2D -> ReLU -> BatchNormalization) merged
    with the block input, followed by ReLU and an optional SE block.

    Arguments:
        x           : input layer (tf.keras.layer)
        filters     : number of filters of every convolution in the stack (int)
        depth       : number of conv layers in the stack; at least one conv is
                      always built, even if depth < 1. Default 2
        kernel_size : kernel dimensions (tuple or int), default (5,5)
        padding     : padding type for convolution (string), default same
        method      : how the stack is merged with x: 'concat' (default) or 'add'.
                      'add' needs x to already have `filters` channels.
        se_block    : if True, apply SE_Block(r=16) to the output
    Returns:
        output : output layer (tf.keras.layer)
    Raises:
        ValueError : if method is not 'add' or 'concat'
    '''
    # Validate up front; previously an unknown method fell through both branches
    # and failed later with an UnboundLocalError on `output`.
    if method not in ('add', 'concat'):
        raise ValueError(
            "ResBlock method must be 'add' or 'concat', got {!r}".format(method))

    conv = tf.keras.layers.Conv2D(kernel_size = kernel_size, filters = filters, padding = padding, data_format = "channels_last")(x)
    conv = tf.keras.layers.ReLU()(conv)
    conv = tf.keras.layers.BatchNormalization()(conv)
    for _ in range(depth - 1):
        conv = tf.keras.layers.Conv2D(kernel_size = kernel_size, filters = filters, padding = padding, data_format = "channels_last")(conv)
        conv = tf.keras.layers.ReLU()(conv)
        conv = tf.keras.layers.BatchNormalization()(conv)

    if method == 'add':
        output = tf.keras.layers.Add()([x, conv])
    else:
        output = tf.keras.layers.Concatenate()([x, conv])
    output = tf.keras.layers.ReLU()(output)
    if se_block is True:
        output = SE_Block(output, r = 16)
    return output
def SE_Block(x, r = 16):
    '''
    Squeeze and Excitation Block (channels_last).

    Global-average-pools x per channel, passes the result through a
    Dense(C/r) -> ReLU -> Dense(C) -> sigmoid bottleneck and rescales each
    channel of x by the resulting weight.

    Arguments:
        x : input layer (tf.keras.layer), channels_last with a known channel count C
        r : reduction ratio for the first FC layer; its width is int(C / r),
            clamped to at least 1
    Returns:
        output : output layer (tf.keras.layer), same shape as x
    Raises:
        ValueError : if the channel dimension of x is unknown (None)
    '''
    filters = x.shape[-1]
    if filters is None:
        raise ValueError("SE_Block needs a statically known channel dimension")
    # With fewer than r channels int(filters / r) is 0, which would create a
    # degenerate 0-unit bottleneck; keep at least one unit.
    squeeze_units = max(1, int(filters / r))

    pool = tf.keras.layers.GlobalAveragePooling2D(data_format='channels_last')(x)
    fc1 = tf.keras.layers.Dense(squeeze_units)(pool)
    fc1 = tf.keras.layers.ReLU()(fc1)
    fc2 = tf.keras.layers.Dense(filters)(fc1)
    fc2 = tf.keras.layers.Activation('sigmoid')(fc2)
    # Reshape the per-channel weights to (1, 1, C) so they broadcast over height/width.
    output = tf.keras.layers.Reshape([1,1,filters])(fc2)
    output = tf.keras.layers.Multiply()([x,output])
    return output
def Steminator(input_shape = (256,128,1), kernel_size = (5,5), feature_maps = 8, multires = True, resblock = True, se_block = True):
    '''
    MultiResUnet Network Builder - Steminator

    Builds a U-Net-style encoder/decoder with six stride-2 Downsampling stages
    (feature_maps*2 ... feature_maps*64 channels) and six Upsampling stages
    (feature_maps*32 ... feature_maps channels). The first five decoder stages
    take a skip connection from the mirrored encoder stage, and the first three
    decoder stages use dropout. A final 1x1 Conv2D with ReLU produces a
    single-channel non-negative mask, and the model output is input * mask.

    Because of the six stride-2 stages, the skip concatenations only line up when
    the input height and width are divisible by 64.

    Arguments:
        input_shape  : input shape without batch dimension (tuple), default (256,128,1)
        kernel_size  : kernel dimensions of the encoder convolutions and decoder
                       transposed convolutions (tuple or int), default (5,5)
        feature_maps : base number of filters (int), default 8
        multires     : use multi-res Downsampling blocks (boolean), default True
        resblock     : pass skip connections through a ResBlock (boolean), default True
        se_block     : NOTE: currently unused; every decoder stage is built without
                       Squeeze-and-Excitation blocks regardless of this value
    Returns:
        model : tf.keras Neural net model (tf.keras.Model)
    '''
    n_stages = 6
    cqt_input = tf.keras.Input(shape=input_shape)

    # Encoder: stage i outputs feature_maps * 2**(i+1) channels at 1/2**(i+1) resolution.
    # kernel_size used to be accepted but silently ignored; it is now forwarded
    # (the default (5,5) matches the blocks' own default, so default models are unchanged).
    encoder = []
    features = cqt_input
    for level in range(n_stages):
        features = Downsampling(features, filters = feature_maps * 2 ** (level + 1),
                                kernel_size = kernel_size, multires = multires)
        encoder.append(features)

    # Decoder: stage i mirrors encoder stage (n_stages - 2 - i); the last stage has no skip.
    for level in range(n_stages):
        skip = encoder[n_stages - 2 - level] if level < n_stages - 1 else None
        features = Upsampling(features, skip,
                              filters = feature_maps * 2 ** (n_stages - 1 - level),
                              res_filts = feature_maps * 2 ** level,
                              kernel_size = kernel_size,
                              # string flag, as expected by Upsampling
                              dropout = 'True' if level < 3 else 'False',
                              resblock = resblock,
                              se_block = False)

    # 1x1 conv as in the original network; ReLU keeps the mask non-negative.
    mask = tf.keras.layers.Conv2D(kernel_size = (1,1), filters = 1, activation='relu', padding = 'same', data_format="channels_last")(features)
    outputs = tf.keras.layers.Multiply()([cqt_input, mask])
    model = tf.keras.Model(inputs = cqt_input, outputs = outputs, name='Steminator')
    return model