from torchvision import transforms

from ..base import BaseModel
from .model import iresnet18, iresnet50, iresnet100


class IResNetModel(BaseModel):
    """
    Wrapper around the IResNet backbones.

    The backbone is selected by ``config.name``; the supported values are
    ``'ir18'``, ``'ir50'`` and ``'ir101'``.

    Attributes:
        net (torch.nn.Module): The IResNet backbone (iresnet18, iresnet50 or iresnet100).
        config (object): The configuration object with model specifications.
    """

    def __init__(self, net, config):
        super().__init__(config)
        self.net = net
        self.config = config


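    @classmethod
    def from_config(cls, config):
        """Build the backbone named by ``config.name`` and return the model in eval mode."""
        if config.name == 'ir50':
            net = iresnet50(input_size=(112, 112), output_dim=config.output_dim)
        elif config.name == 'ir101':
            # NOTE: unlike the other variants, this call relies on the defaults of
            # iresnet100 and does not forward input_size or config.output_dim.
            net = iresnet100()
        elif config.name == 'ir18':
            net = iresnet18(input_size=(112, 112), output_dim=config.output_dim)
        else:
            raise NotImplementedError(f"Unsupported IResNet variant: {config.name!r}")

        model = cls(net, config)
        model.eval()
        return model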

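    def forward(self, x):
        """Run the backbone on a batch ``x``, optionally reversing the channel order first."""
        # input_color_flip is not set in this class; it is expected to come from BaseModel.
        if self.input_color_flip:
            # Reverse dimension 1 (the channel axis for NCHW input), e.g. RGB <-> BGR.
            x = x.flip(1)
        return self.net(x)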

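    def make_train_transform(self):
        """Return the training preprocessing (no augmentation is applied here)."""
        transform = transforms.Compose([
            # ToTensor scales pixel values to [0, 1]; Normalize then maps them to [-1, 1].
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        return transform

    def make_test_transform(self):
        """Return the evaluation preprocessing (identical to the training transform)."""
        transform = transforms.Compose([
            # ToTensor scales pixel values to [0, 1]; Normalize then maps them to [-1, 1].
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        return transform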

def load_model(model_config):
    """Module-level entry point: build an IResNetModel from ``model_config``."""
    return IResNetModel.from_config(model_config)