Upload directory
Browse files
models/iresnet_insightface/__init__.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..base import BaseModel
|
2 |
+
from .model import iresnet100, iresnet50, iresnet18
|
3 |
+
from torchvision import transforms
|
4 |
+
|
5 |
+
|
6 |
+
class IResNetModel(BaseModel):
|
7 |
+
|
8 |
+
"""
|
9 |
+
A class representing a model for IResNet architectures. It supports creating
|
10 |
+
models with specific configurations such as IR_50 and IR_101.
|
11 |
+
|
12 |
+
Attributes:
|
13 |
+
net (torch.nn.Module): The IResNet network (either IR_50 or IR_101).
|
14 |
+
config (object): The configuration object with model specifications.
|
15 |
+
"""
|
16 |
+
|
17 |
+
|
18 |
+
def __init__(self, net, config):
|
19 |
+
super(IResNetModel, self).__init__(config)
|
20 |
+
self.net = net
|
21 |
+
self.config = config
|
22 |
+
|
23 |
+
|
24 |
+
@classmethod
|
25 |
+
def from_config(cls, config):
|
26 |
+
if config.name == 'ir50':
|
27 |
+
net = iresnet50(input_size=(112,112), output_dim=config.output_dim)
|
28 |
+
elif config.name == 'ir101':
|
29 |
+
net = iresnet100()
|
30 |
+
elif config.name == 'ir18':
|
31 |
+
net = iresnet18(input_size=(112,112), output_dim=config.output_dim)
|
32 |
+
else:
|
33 |
+
raise NotImplementedError
|
34 |
+
|
35 |
+
model = cls(net, config)
|
36 |
+
model.eval()
|
37 |
+
return model
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
if self.input_color_flip:
|
41 |
+
x = x.flip(1)
|
42 |
+
return self.net(x)
|
43 |
+
|
44 |
+
def make_train_transform(self):
|
45 |
+
transform = transforms.Compose([
|
46 |
+
transforms.ToTensor(),
|
47 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
48 |
+
])
|
49 |
+
return transform
|
50 |
+
|
51 |
+
def make_test_transform(self):
|
52 |
+
transform = transforms.Compose([
|
53 |
+
transforms.ToTensor(),
|
54 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
55 |
+
])
|
56 |
+
return transform
|
57 |
+
|
58 |
+
def load_model(model_config):
|
59 |
+
model = IResNetModel.from_config(model_config)
|
60 |
+
return model
|