File size: 6,964 Bytes
29a525e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import tensorflow as tf

def Downsampling(x, filters, kernel_size = (5,5), padding = 'same', stride = 2, multires = False):
    '''
    Downsampling Block

    Applies a strided convolution stage to x, followed by BatchNormalization
    and LeakyReLU(0.2).

    Arguments:
        x : input layer (tf.keras.layer)
        filters : number of filters (int)
        kernel_size : kernel dimensions (tuple or int), default (5,5)
        padding : padding type for convolution (string), default same
        stride : stride for convolution (tuple or int), default 2
        multires : use three parallel convolutions instead of one (boolean), default False.
            When True, x is convolved with kernel_size (filters//2 filters),
            (3,3) (filters//4 filters) and (7,7) (filters//4 filters) and the
            three results are concatenated. The concatenated channel count is
            filters//2 + 2*(filters//4), which equals filters only when
            filters is a multiple of 4.

    Returns:
        output : output layer (tf.keras.layer)

    Raises:
        ValueError : if multires is neither True nor False
    '''
    # Reject anything that is not a boolean-like flag up front; otherwise no
    # branch would run and `conv` would be undefined further down.
    if multires not in (True, False):
        raise ValueError("multires must be True or False, got %r" % (multires,))

    def strided_conv(size, n_filters):
        # All branches share stride, padding and data format; only the kernel
        # size and filter count differ.
        return tf.keras.layers.Conv2D(kernel_size = size, filters = n_filters, strides = stride, padding = padding, data_format = "channels_last")(x)

    if multires:
        # Branch creation order (kernel_size, 3x3, 7x7) is kept so layers are
        # created in the same sequence as before.
        branches = [
            strided_conv(kernel_size, filters//2),
            strided_conv((3,3), filters//4),
            strided_conv((7,7), filters//4),
        ]
        conv = tf.keras.layers.Concatenate()(branches)
    else:
        conv = strided_conv(kernel_size, filters)

    bn = tf.keras.layers.BatchNormalization()(conv)
    output = tf.keras.layers.LeakyReLU(0.2)(bn)

    return output

def Upsampling(x , y, filters, res_filts, kernel_size = (5,5), padding = 'same', stride = 2, dropout = 'False', resblock = True, se_block = False):
    '''
    Upsampling Block

    Applies a transposed convolution to x, then ReLU and BatchNormalization
    (in that order), optional Dropout(0.5), an optional concatenation with the
    skip connection y, and an optional Squeeze-and-Excitation block.

    Arguments:
        x : input layer (tf.keras.layer)
        y : residual connection layer (tf.keras.layer), or None for no skip connection
        filters : number of filters (int)
        res_filts : number of filters used by the ResBlock applied to y (int);
            only used when y is not None and resblock is True
        kernel_size : kernel dimensions (tuple or int), default (5,5)
        padding : padding type for convolution (string), default same
        stride : stride for convolution (tuple or int), default 2
        dropout : enable Dropout(0.5) after batch normalisation. Accepts the
            boolean True or the legacy string 'True'; any other value
            (including the default 'False') disables dropout.
        resblock : pass y through ResBlock(depth = 2) before concatenating
            (boolean), default True. Only the literal True enables it.
        se_block : apply SE_Block(r = 16) to the final output (boolean),
            default False. Only the literal True enables it.

    Returns:
        output : output layer (tf.keras.layer)
    '''

    conv = tf.keras.layers.Conv2DTranspose(kernel_size = kernel_size, filters = filters, strides = stride, padding = padding, data_format = "channels_last")(x)
    act = tf.keras.layers.ReLU()(conv)
    output = tf.keras.layers.BatchNormalization()(act)

    # The flag was historically a string; also honour a real boolean so that
    # dropout = True does what the documentation promises.
    use_dropout = dropout is True or dropout == 'True'
    if use_dropout:
        output = tf.keras.layers.Dropout(0.5)(output)

    if y is not None:
        if resblock is True:
            y = ResBlock(y, depth = 2, filters = res_filts)
        # Skip connection goes first in the concatenation, upsampled path second.
        output = tf.keras.layers.Concatenate()([y, output])

    if se_block is True:
        output = SE_Block(output, r = 16)
    return output

def ResBlock(x, filters, depth = 2, kernel_size = (5,5), padding = 'same', method = 'concat', se_block = False):
    '''
    ResNet Block

    Runs x through a stack of Conv2D -> ReLU -> BatchNormalization stages,
    merges the result with x, applies a final ReLU and optionally a
    Squeeze-and-Excitation block.

    Arguments:
        x : input layer (tf.keras.layer)
        filters : number of filters (int)
        depth : number of layers in ResBlock; at least one stage is always
            built, even for depth < 1
        kernel_size : kernel dimensions (tuple or int), default (5,5)
        padding : padding type for convolution (string), default same
        method : how x is merged with the convolved path, 'add' or 'concat',
            default 'concat'.
            NOTE(review): 'add' presumably requires x to already have
            `filters` channels - confirm against callers.
        se_block : apply SE_Block(r = 16) to the output (boolean), default
            False. Only the literal True enables it.

    Returns:
        output : output layer (tf.keras.layer)

    Raises:
        ValueError : if method is neither 'add' nor 'concat'
    '''
    # Validate before building any layers; an unknown method would otherwise
    # leave `output` undefined at the final activation.
    if method not in ('add', 'concat'):
        raise ValueError("method must be 'add' or 'concat', got %r" % (method,))

    conv = x
    # One stage is unconditional and depth - 1 further stages follow, so the
    # total number of stages is max(1, depth).
    for _ in range(max(1, depth)):
        conv = tf.keras.layers.Conv2D(kernel_size = kernel_size, filters = filters, padding = padding, data_format = "channels_last")(conv)
        conv = tf.keras.layers.ReLU()(conv)
        conv = tf.keras.layers.BatchNormalization()(conv)

    if method == 'add':
        output = tf.keras.layers.Add()([x, conv])
    else:
        output = tf.keras.layers.Concatenate()([x, conv])

    output = tf.keras.layers.ReLU()(output)

    if se_block is True:
        output = SE_Block(output, r = 16)

    return output

def SE_Block(x, r = 16):
    '''
    Squeeze and Excitation Block
    Assumes channel_last format

    Global-average-pools x, passes the pooled vector through two Dense layers
    (ReLU, then sigmoid), reshapes the result to [1, 1, channels] and
    multiplies it with x.

    Arguments:
        x : input layer (tf.keras.layer)
        r : reduction ratio for first FC layer

    Returns:
        output : output layer (tf.keras.layer)

    Raises:
        ValueError : if the last dimension of x.shape is None
    '''
    filters = x.shape[-1]
    if filters is None:
        # The Dense and Reshape layers below need a concrete channel count.
        raise ValueError("SE_Block requires the last dimension of x to be statically known")

    # Clamp to one unit: with fewer than r channels the plain division would
    # request a Dense layer with zero units.
    reduced = max(1, int(filters / r))

    pool = tf.keras.layers.GlobalAveragePooling2D(data_format='channels_last')(x)
    fc1 = tf.keras.layers.Dense(reduced)(pool)
    fc1 = tf.keras.layers.ReLU()(fc1)
    fc2 = tf.keras.layers.Dense(filters)(fc1)
    fc2 = tf.keras.layers.Activation('sigmoid')(fc2)
    output = tf.keras.layers.Reshape([1,1,filters])(fc2)

    output = tf.keras.layers.Multiply()([x,output])

    return output

def Steminator(input_shape = (256,128,1), kernel_size = (5,5), feature_maps = 8, multires = True, resblock = True, se_block = True):

    '''
    MultiResUnet Network Builder - Steminator

    Builds six Downsampling blocks followed by six Upsampling blocks with
    skip connections, then a 1x1 ReLU convolution producing a single-channel
    mask that is multiplied element-wise with the network input.

    Arguments:
        input_shape : input shape (tuple)
            NOTE(review): the skip concatenations presumably need height and
            width to be multiples of 2**6 - confirm.
        kernel_size : kernel dimensions (tuple or int) used by the
            Downsampling and Upsampling convolutions, default (5,5)
        feature_maps : number of initial filters (int)
        multires : use multi-res Unet (boolean), default True
        resblock : use resblock residual connections (boolean), default True
        se_block : currently unused; every Upsampling block is built with
            se_block False regardless of this value. Kept for interface
            compatibility.

    Returns:
        model : tf.keras Neural net model (tf.keras.Model)
    '''

    cqt_input = tf.keras.Input(shape=input_shape)

    # Encoder: the filter count doubles at every level.
    ds_0 = Downsampling(cqt_input, filters = feature_maps*2, kernel_size = kernel_size, multires = multires)
    ds_1 = Downsampling(ds_0, filters = feature_maps*4, kernel_size = kernel_size, multires = multires)
    ds_2 = Downsampling(ds_1, filters = feature_maps*8, kernel_size = kernel_size, multires = multires)
    ds_3 = Downsampling(ds_2, filters = feature_maps*16, kernel_size = kernel_size, multires = multires)
    ds_4 = Downsampling(ds_3, filters = feature_maps*32, kernel_size = kernel_size, multires = multires)
    ds_5 = Downsampling(ds_4, filters = feature_maps*64, kernel_size = kernel_size, multires = multires)

    # Decoder: the three deepest blocks use dropout; each block except the
    # last is concatenated with the encoder output of matching depth.
    us_0 = Upsampling(ds_5,ds_4,filters = feature_maps*32, res_filts = feature_maps, kernel_size = kernel_size, dropout = 'True', resblock = resblock)
    us_1 = Upsampling(us_0,ds_3,filters = feature_maps*16, res_filts = feature_maps*2, kernel_size = kernel_size, dropout = 'True', resblock = resblock)
    us_2 = Upsampling(us_1,ds_2,filters = feature_maps*8, res_filts = feature_maps*4, kernel_size = kernel_size, dropout = 'True', resblock = resblock)
    us_3 = Upsampling(us_2,ds_1,filters = feature_maps*4, res_filts = feature_maps*8, kernel_size = kernel_size, resblock = resblock)
    us_4 = Upsampling(us_3,ds_0,filters = feature_maps*2, res_filts = feature_maps*16, kernel_size = kernel_size, resblock = resblock, se_block = False)
    # No skip connection on the last block, so res_filts has no effect here.
    us_5 = Upsampling(us_4,None,filters = feature_maps, res_filts = feature_maps*32, kernel_size = kernel_size, resblock = resblock, se_block = False)


    mask = tf.keras.layers.Conv2D(kernel_size = (1,1), filters = 1,activation='relu', padding = 'same',data_format="channels_last")(us_5) #original network kernel_size = (1,1)

    # The predicted mask scales the input element-wise.
    outputs = tf.keras.layers.Multiply()([cqt_input,mask])

    model = tf.keras.Model(inputs = cqt_input, outputs = outputs, name='Steminator')

    return model