minchul committed on
Commit
4b4ca75
·
verified ·
1 Parent(s): ea689fe

Upload directory

Browse files
Files changed (1) hide show
  1. models/iresnet/model.py +340 -0
models/iresnet/model.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ from torch.nn import Dropout
3
+ from torch.nn import MaxPool2d
4
+ from torch.nn import Sequential
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn import Conv2d, Linear
8
+ from torch.nn import BatchNorm1d, BatchNorm2d
9
+ from torch.nn import ReLU, Sigmoid
10
+ from torch.nn import Module
11
+ from torch.nn import PReLU
12
+ from fvcore.nn import flop_count
13
+ import numpy as np
14
+
15
+
16
def initialize_weights(modules):
    """Re-initialise conv, linear and 2-D batch-norm layers in place.

    ``Conv2d`` and ``Linear`` weights receive Kaiming-normal initialisation
    (``mode='fan_out'``, ``nonlinearity='relu'``) and their biases, when
    present, are zeroed. ``BatchNorm2d`` affine parameters are reset to
    weight 1 / bias 0. Every other module type is left untouched.

    Args:
        modules: Iterable of ``nn.Module`` objects, e.g. ``model.modules()``.
    """
    for m in modules:
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            # A BatchNorm2d built with affine=False has weight and bias set
            # to None; skip those instead of failing on ``None.data``.
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()
33
+
34
+
35
class Flatten(Module):
    """Collapse every dimension after the first into a single dimension."""

    def forward(self, input):
        # Dim 0 is preserved; view() infers the flattened length from -1.
        leading = input.size(0)
        return input.view(leading, -1)
38
+
39
+
40
class LinearBlock(Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d``, with no activation.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        groups: Number of convolution groups.
    """

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super().__init__()
        self.conv = Conv2d(in_c, out_c, kernel_size=kernel, stride=stride,
                           padding=padding, groups=groups, bias=False)
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        return self.bn(self.conv(x))
50
+
51
class SEModule(Module):
    """Channel gate: global average pool -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid.

    The resulting per-channel gate is multiplied with the module input.

    Args:
        channels: Number of input (and output) channels.
        reduction: Divisor applied to ``channels`` for the hidden 1x1 conv.
    """

    def __init__(self, channels, reduction):
        super().__init__()
        squeezed = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, squeezed, kernel_size=1, padding=0, bias=False)

        # Only fc1 is Xavier-initialised here; fc2 keeps Conv2d's default init.
        nn.init.xavier_uniform_(self.fc1.weight.data)

        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(squeezed, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        gate = self.avg_pool(x)
        gate = self.relu(self.fc1(gate))
        gate = self.sigmoid(self.fc2(gate))
        return x * gate
75
+
76
+
77
+
78
class BasicBlockIR(Module):
    """Two-conv residual unit: BN -> 3x3 conv -> BN -> PReLU -> 3x3 conv -> BN.

    The shortcut is a ``MaxPool2d(1, stride)`` when ``in_channel == depth``
    and a strided 1x1 conv followed by batch norm otherwise.

    Args:
        in_channel: Number of input channels.
        depth: Number of output channels.
        stride: Stride of the second 3x3 conv and of the shortcut.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__()
        # shortcut_layer is registered before res_layer, as in the state dict.
        if in_channel != depth:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        else:
            self.shortcut_layer = MaxPool2d(1, stride)

        residual_layers = [
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(depth),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        ]
        self.res_layer = Sequential(*residual_layers)

    def forward(self, x):
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
100
+
101
+
102
class BottleneckIR(Module):
    """Bottleneck residual unit: 1x1 reduce -> 3x3 -> 1x1 expand, each with BN.

    The inner width is ``depth // 4``. The shortcut is a
    ``MaxPool2d(1, stride)`` when ``in_channel == depth`` and a strided 1x1
    conv followed by batch norm otherwise.

    Args:
        in_channel: Number of input channels.
        depth: Number of output channels.
        stride: Stride of the final 1x1 conv and of the shortcut.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__()
        mid = depth // 4
        # shortcut_layer is registered before res_layer, as in the state dict.
        if in_channel != depth:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        else:
            self.shortcut_layer = MaxPool2d(1, stride)

        residual_layers = [
            BatchNorm2d(in_channel),
            Conv2d(in_channel, mid, (1, 1), (1, 1), 0, bias=False),
            BatchNorm2d(mid),
            PReLU(mid),
            Conv2d(mid, mid, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(mid),
            PReLU(mid),
            Conv2d(mid, depth, (1, 1), stride, 0, bias=False),
            BatchNorm2d(depth),
        ]
        self.res_layer = Sequential(*residual_layers)

    def forward(self, x):
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
128
+
129
+
130
class BasicBlockIRSE(BasicBlockIR):
    """``BasicBlockIR`` whose residual branch ends with an ``SEModule``."""

    def __init__(self, in_channel, depth, stride):
        super().__init__(in_channel, depth, stride)
        # Appended after the last batch norm, under the key "se_block",
        # with a reduction ratio of 16.
        channel_gate = SEModule(depth, 16)
        self.res_layer.add_module("se_block", channel_gate)
134
+
135
+
136
class BottleneckIRSE(BottleneckIR):
    """``BottleneckIR`` whose residual branch ends with an ``SEModule``."""

    def __init__(self, in_channel, depth, stride):
        super().__init__(in_channel, depth, stride)
        # Appended after the last batch norm, under the key "se_block",
        # with a reduction ratio of 16.
        channel_gate = SEModule(depth, 16)
        self.res_layer.add_module("se_block", channel_gate)
140
+
141
+
142
class Bottleneck(namedtuple('Block', ('in_channel', 'depth', 'stride'))):
    """Immutable ``(in_channel, depth, stride)`` record describing one unit."""
144
+
145
+
146
def get_block(in_channel, depth, num_units, stride=2):
    """Describe one stage as a list of ``Bottleneck`` records.

    The first record maps ``in_channel`` to ``depth`` with the given
    ``stride``; the remaining ``num_units - 1`` records keep ``depth``
    channels with stride 1.

    Args:
        in_channel: Channel count entering the stage.
        depth: Channel count produced by every unit of the stage.
        num_units: Total number of units in the stage.
        stride: Stride of the first unit.

    Returns:
        List of ``Bottleneck`` records, first unit first.
    """
    head = Bottleneck(in_channel, depth, stride)
    tail = [Bottleneck(depth, depth, 1) for _ in range(num_units - 1)]
    return [head, *tail]
150
+
151
+
152
def get_blocks(num_layers):
    """Return the four-stage unit layout for a given network depth.

    Args:
        num_layers: One of 18, 34, 50, 100, 152 or 200.

    Returns:
        A list of four stages, each a list of ``Bottleneck`` records as
        produced by ``get_block``.

    Raises:
        ValueError: If ``num_layers`` is not a supported depth. Without this
            check the function would fail with ``UnboundLocalError`` on the
            final ``return``.
    """
    if num_layers == 18:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=2),
            get_block(in_channel=64, depth=128, num_units=2),
            get_block(in_channel=128, depth=256, num_units=2),
            get_block(in_channel=256, depth=512, num_units=2)
        ]
    elif num_layers == 34:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=6),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=256, num_units=3),
            get_block(in_channel=256, depth=512, num_units=8),
            get_block(in_channel=512, depth=1024, num_units=36),
            get_block(in_channel=1024, depth=2048, num_units=3)
        ]
    elif num_layers == 200:
        blocks = [
            get_block(in_channel=64, depth=256, num_units=3),
            get_block(in_channel=256, depth=512, num_units=24),
            get_block(in_channel=512, depth=1024, num_units=36),
            get_block(in_channel=1024, depth=2048, num_units=3)
        ]
    else:
        raise ValueError(
            "num_layers should be 18, 34, 50, 100, 152 or 200, got {!r}".format(num_layers))

    return blocks
197
+
198
+
199
class Backbone(Module):
    """Residual backbone: input stem, stack of IR / IR-SE units, embedding head.

    Args:
        input_size: Sequence whose first element is 112 or 224. Only
            ``input_size[0]`` is inspected.
        num_layers: One of 18, 34, 50, 100, 152 or 200. Depths up to 100 use
            the basic units, deeper ones the bottleneck units.
        mode: ``'ir'`` for plain units, ``'ir_se'`` for units with an
            ``SEModule`` appended.
        flip: If true, ``forward`` reverses dim 1 of its input first.
        output_dim: Size of the output embedding.
    """

    def __init__(self, input_size, num_layers, mode='ir', flip=False, output_dim=512):
        super(Backbone, self).__init__()
        # Kept as asserts so callers still see AssertionError on bad arguments.
        assert input_size[0] in [112, 224], \
            "input_size should be [112, 112] or [224, 224]"
        assert num_layers in [18, 34, 50, 100, 152, 200], \
            "num_layers should be 18, 34, 50, 100, 152 or 200"
        assert mode in ['ir', 'ir_se'], \
            "mode should be ir or ir_se"
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64), PReLU(64))
        blocks = get_blocks(num_layers)
        if num_layers <= 100:
            unit_module = {'ir': BasicBlockIR, 'ir_se': BasicBlockIRSE}[mode]
            output_channel = 512
        else:
            unit_module = {'ir': BottleneckIR, 'ir_se': BottleneckIRSE}[mode]
            output_channel = 2048

        # The four stages each start with a stride-2 unit, so the feature map
        # is 1/16 of the input side: 112 -> 7, 224 -> 14.
        feat_side = 7 if input_size[0] == 112 else 14
        # output_layer is registered before body to keep the original
        # parameter / state-dict ordering.
        self.output_layer = Sequential(
            BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
            Linear(output_channel * feat_side * feat_side, output_dim),
            BatchNorm1d(output_dim, affine=False))

        modules = [
            unit_module(unit.in_channel, unit.depth, unit.stride)
            for block in blocks
            for unit in block
        ]
        self.body = Sequential(*modules)

        # Runs over every submodule, so it also overrides the Xavier init
        # applied inside SEModule.
        initialize_weights(self.modules())

        self.flip = flip

    def forward(self, x):
        """Map an input batch to embeddings of size ``output_dim``."""
        if self.flip:
            x = x.flip(1)  # color channel flip

        x = self.input_layer(x)
        x = self.body(x)
        return self.output_layer(x)
260
+
261
+
262
+
263
def IR_18(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=18`` in ``'ir'`` mode.

    Args:
        input_size: Forwarded to ``Backbone`` unchanged.
        output_dim: Embedding size forwarded to ``Backbone``.
    """
    return Backbone(input_size, num_layers=18, mode='ir', output_dim=output_dim)
267
+
268
+
269
def IR_34(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=34`` in ``'ir'`` mode.

    Args:
        input_size: Forwarded to ``Backbone`` unchanged.
        output_dim: Embedding size forwarded to ``Backbone``.
    """
    return Backbone(input_size, num_layers=34, mode='ir', output_dim=output_dim)
273
+
274
+
275
def IR_50(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=50`` in ``'ir'`` mode.

    Args:
        input_size: Forwarded to ``Backbone`` unchanged.
        output_dim: Embedding size forwarded to ``Backbone``.
    """
    return Backbone(input_size, num_layers=50, mode='ir', output_dim=output_dim)
279
+
280
+
281
def IR_101(input_size, output_dim=512):
    """Build a ``Backbone`` in ``'ir'`` mode using the ``num_layers=100`` layout.

    Args:
        input_size: Forwarded to ``Backbone`` unchanged.
        output_dim: Embedding size forwarded to ``Backbone``.
    """
    # Despite the "101" in the name, the layout key passed on is 100.
    return Backbone(input_size, num_layers=100, mode='ir', output_dim=output_dim)
285
+
286
+
287
def IR_101_FLIP(input_size, output_dim=512):
    """Build the ``num_layers=100`` ``'ir'`` ``Backbone`` with ``flip=True``.

    Args:
        input_size: Forwarded to ``Backbone`` unchanged.
        output_dim: Embedding size forwarded to ``Backbone``.
    """
    # Despite the "101" in the name, the layout key passed on is 100.
    return Backbone(input_size, num_layers=100, mode='ir', flip=True,
                    output_dim=output_dim)
291
+
292
+
293
+
294
def IR_152(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=152`` in ``'ir'`` mode.

    Args:
        input_size: Forwarded to ``Backbone`` unchanged.
        output_dim: Embedding size forwarded to ``Backbone``.
    """
    return Backbone(input_size, num_layers=152, mode='ir', output_dim=output_dim)
298
+
299
+
300
def IR_200(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=200`` in ``'ir'`` mode.

    Args:
        input_size: Forwarded to ``Backbone`` unchanged.
        output_dim: Embedding size forwarded to ``Backbone``.
    """
    return Backbone(input_size, num_layers=200, mode='ir', output_dim=output_dim)
304
+
305
+
306
def IR_SE_50(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=50`` in ``'ir_se'`` mode.

    Args:
        input_size: Forwarded to ``Backbone`` unchanged.
        output_dim: Embedding size forwarded to ``Backbone``.
    """
    return Backbone(input_size, num_layers=50, mode='ir_se', output_dim=output_dim)
310
+
311
+
312
def IR_SE_101(input_size, output_dim=512):
    """Build a ``Backbone`` in ``'ir_se'`` mode using the ``num_layers=100`` layout.

    Args:
        input_size: Forwarded to ``Backbone`` unchanged.
        output_dim: Embedding size forwarded to ``Backbone``.
    """
    # Despite the "101" in the name, the layout key passed on is 100.
    return Backbone(input_size, num_layers=100, mode='ir_se', output_dim=output_dim)
316
+
317
+
318
def IR_SE_152(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=152`` in ``'ir_se'`` mode.

    Args:
        input_size: Forwarded to ``Backbone`` unchanged.
        output_dim: Embedding size forwarded to ``Backbone``.
    """
    return Backbone(input_size, num_layers=152, mode='ir_se', output_dim=output_dim)
322
+
323
+
324
def IR_SE_200(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=200`` in ``'ir_se'`` mode.

    Args:
        input_size: Forwarded to ``Backbone`` unchanged.
        output_dim: Embedding size forwarded to ``Backbone``.
    """
    return Backbone(input_size, num_layers=200, mode='ir_se', output_dim=output_dim)
328
+
329
+
330
if __name__ == '__main__':
    # Smoke test: print the FLOP count and trainable-parameter count of IR_50
    # for a single 112x112 three-channel input.
    inputs_shape = (1, 3, 112, 112)
    model = IR_50(input_size=(112, 112))
    model.eval()
    # fvcore documents `inputs` as a tuple of the model's positional arguments.
    res = flop_count(model, inputs=(torch.randn(inputs_shape),), supported_ops={})
    # flop_count already reports per-operator counts in GFLOPs, so the sum
    # must not be divided by 1e9 a second time.
    fvcore_gflop = np.array(list(res[0].values())).sum()
    print('FLOPs: ', fvcore_gflop, 'G')
    print('Num Params: ', sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6, 'M')
339
+
340
+