minchul committed
Commit 92065fc · verified · 1 Parent(s): f83ed69

Upload directory

Files changed (1)
  1. models/iresnet_insightface/model.py +184 -0
models/iresnet_insightface/model.py ADDED
@@ -0,0 +1,184 @@
import torch
from torch import nn

__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Pre-activation residual unit: BN -> 3x3 conv -> BN -> PReLU -> 3x3 conv -> BN
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out


class IResNet(nn.Module):
    # Spatial size of the final feature map for a 112x112 input (112 / 2**4 = 7)
    fc_scale = 7 * 7

    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None):
        super(IResNet, self).__init__()
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        # Final 1D BN over the embedding; its scale is fixed at 1 and not trained.
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn2(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        x = self.features(x)
        return x


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError('pretrained weights are not bundled with this model definition')
    return model


def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
                    progress, **kwargs)


def iresnet34(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
                    progress, **kwargs)


def iresnet50(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
                    progress, **kwargs)


def iresnet100(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
                    progress, **kwargs)


def iresnet200(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
                    progress, **kwargs)
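
For reference, a minimal usage sketch (not part of the committed file): it assumes a 112x112 RGB face crop, which is what fc_scale = 7 * 7 implies given the four stride-2 stages, and imports the module via the committed path. The weight-file path in the commented lines is a hypothetical placeholder; the constructors do not download pretrained weights.

import torch
from models.iresnet_insightface.model import iresnet100

# Build the backbone with a 512-dimensional embedding.
model = iresnet100(num_features=512)
# state_dict = torch.load("path/to/backbone.pth", map_location="cpu")  # hypothetical path
# model.load_state_dict(state_dict)
model.eval()

# One aligned 112x112 RGB crop -> one 512-d face embedding.
dummy = torch.randn(1, 3, 112, 112)
with torch.no_grad():
    embedding = model(dummy)
print(embedding.shape)  # torch.Size([1, 512])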