minchul commited on
Commit
f83ed69
·
verified ·
1 Parent(s): 46a3b5e

Upload directory

Browse files
models/iresnet_insightface/__init__.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..base import BaseModel
2
+ from .model import iresnet100, iresnet50, iresnet18
3
+ from torchvision import transforms
4
+
5
+
6
+ class IResNetModel(BaseModel):
7
+
8
+ """
9
+ A class representing a model for IResNet architectures. It supports creating
10
+ models with specific configurations such as IR_50 and IR_101.
11
+
12
+ Attributes:
13
+ net (torch.nn.Module): The IResNet network (either IR_50 or IR_101).
14
+ config (object): The configuration object with model specifications.
15
+ """
16
+
17
+
18
+ def __init__(self, net, config):
19
+ super(IResNetModel, self).__init__(config)
20
+ self.net = net
21
+ self.config = config
22
+
23
+
24
+ @classmethod
25
+ def from_config(cls, config):
26
+ if config.name == 'ir50':
27
+ net = iresnet50(input_size=(112,112), output_dim=config.output_dim)
28
+ elif config.name == 'ir101':
29
+ net = iresnet100()
30
+ elif config.name == 'ir18':
31
+ net = iresnet18(input_size=(112,112), output_dim=config.output_dim)
32
+ else:
33
+ raise NotImplementedError
34
+
35
+ model = cls(net, config)
36
+ model.eval()
37
+ return model
38
+
39
+ def forward(self, x):
40
+ if self.input_color_flip:
41
+ x = x.flip(1)
42
+ return self.net(x)
43
+
44
+ def make_train_transform(self):
45
+ transform = transforms.Compose([
46
+ transforms.ToTensor(),
47
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
48
+ ])
49
+ return transform
50
+
51
+ def make_test_transform(self):
52
+ transform = transforms.Compose([
53
+ transforms.ToTensor(),
54
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
55
+ ])
56
+ return transform
57
+
58
+ def load_model(model_config):
59
+ model = IResNetModel.from_config(model_config)
60
+ return model