File size: 4,821 Bytes
e8b0040
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from PIL import Image
import numpy as np
import timm
import einops
import torch
from torch import nn
from toolkit.dtransform import create_transforms_inference, create_transforms_inference1,\
                    create_transforms_inference2,\
                    create_transforms_inference3,\
                    create_transforms_inference4,\
                    create_transforms_inference5
from toolkit.chelper import load_model
import torch.nn.functional as F


def extract_model_from_pth(params_path, net_model):
    """Load the weights stored under 'state_dict' in a checkpoint into net_model.

    Args:
        params_path: path to a .pth checkpoint file containing a 'state_dict' key.
        net_model: the nn.Module instance to populate (modified in place).

    Returns:
        The same net_model instance with weights loaded (strict key matching,
        so any missing/unexpected key raises).
    """
    # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts;
    # the caller moves the model to the target device afterwards.
    checkpoint = torch.load(params_path, map_location='cpu')
    state_dict = checkpoint['state_dict']

    net_model.load_state_dict(state_dict, strict=True)

    return net_model


class SRMConv2d_simple(nn.Module):
    """Fixed (non-learnable) SRM noise-residual filter bank.

    Applies three classic steganalysis high-pass filters (KB, KV and a
    second-order horizontal kernel) across all input channels and clamps
    the responses to [-3, 3] with a hard tanh.
    """

    def __init__(self, inc=3):
        """
        Args:
            inc: number of input channels the fixed 5x5 kernels span.
        """
        super(SRMConv2d_simple, self).__init__()
        self.truc = nn.Hardtanh(-3, 3)
        # Register the fixed kernel as a buffer (not a plain attribute) so it
        # follows the module across .to(device)/.cuda() calls; as a buffer it
        # is still excluded from gradient updates.
        self.register_buffer(
            'kernel', torch.from_numpy(self._build_kernel(inc)).float())

    def forward(self, x):
        """Convolve x of shape (N, inc, H, W) with the fixed SRM kernels.

        Returns:
            (N, 3, H, W) tensor of truncated noise residuals; padding=2 with
            the 5x5 kernels preserves spatial size.
        """
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        out = self.truc(out)

        return out

    def _build_kernel(self, inc):
        """Build the (3, inc, 5, 5) SRM filter stack as a numpy array."""
        # filter1: KB (second-order "square" residual)
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2: KV (5x5 "spam" residual)
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3: horizontal second-order residual
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]

        # Normalize each filter by the magnitude of its center weight.
        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # Stack the filters into conv2d weight layout (out_c, in_c, kH, kW),
        # repeating each filter over the `inc` input channels.
        filters = [[filter1],  # , filter1, filter1],
                   [filter2],  # , filter2, filter2],
                   [filter3]]  # , filter3, filter3]]
        filters = np.array(filters)
        filters = np.repeat(filters, inc, axis=1)
        return filters


class INFER_API:
    """Singleton inference wrapper: loads the model once, then scores images.

    Usage: ``INFER_API().test(path)`` -> float prediction score. The heavy
    setup in initialize() runs only on the first construction.
    """

    _instance = None

    def __new__(cls):
        # Classic singleton: reuse the one instance after first construction.
        if cls._instance is None:
            cls._instance = super(INFER_API, cls).__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def initialize(self):
        """One-time setup: TTA transforms, SRM filter, model load and device."""
        self.transformer_ = [create_transforms_inference(h=512, w=512),
                        create_transforms_inference1(h=512, w=512),
                        create_transforms_inference2(h=512, w=512),
                        create_transforms_inference3(h=512, w=512),
                        create_transforms_inference4(h=512, w=512),
                        create_transforms_inference5(h=512, w=512)]
        self.srm = SRMConv2d_simple()

        # model init
        self.model = load_model('all', 2)
        model_path = './final_model_csv/final_model.pth'
        self.model = extract_model_from_pth(model_path, self.model)

        # Remember the device so test() can reuse it instead of assuming CUDA.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        self.model.eval()

    def _add_new_channels_worker(self, image):
        """Compute SRM noise-residual channels for one HWC float image tensor."""
        new_channels = []

        image = einops.rearrange(image, "h w c -> c h w")
        # Normalize with ImageNet statistics before applying the SRM filter.
        image = (image - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor(
            timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1)
        srm = self.srm(image.unsqueeze(0)).squeeze(0)
        new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy())

        new_channels = np.concatenate(new_channels, axis=2)
        return torch.from_numpy(new_channels).float()

    def add_new_channels(self, images):
        """Append SRM residual channels to a CHW tensor -> (c+3, h, w)."""
        images_copied = einops.rearrange(images, "c h w -> h w c")
        new_channels = self._add_new_channels_worker(images_copied)
        images_copied = torch.concatenate([images_copied, new_channels], dim=-1)
        images_copied = einops.rearrange(images_copied, "h w c -> c h w")

        return images_copied

    def test(self, img_path):
        """Score one image file.

        Args:
            img_path: path to an image readable by PIL.

        Returns:
            The model's scalar prediction as a float.
        """
        # img load
        img_data = Image.open(img_path).convert('RGB')

        # transform: one augmented view per TTA transform, each with SRM channels
        all_data = []
        for transform in self.transformer_:
            current_data = transform(img_data)
            current_data = self.add_new_channels(current_data)
            all_data.append(current_data)
        # Bug fix: use the device chosen in initialize() instead of a
        # hard-coded .cuda(), so inference also works on CPU-only machines.
        img_tensor = torch.stack(all_data, dim=0).unsqueeze(0).to(self.device)

        # Inference only: disable autograd to save memory and time.
        with torch.no_grad():
            preds = self.model(img_tensor)

        return round(float(preds), 20)


def main():
    """Demo entry point: score one sample image and print the result."""
    image_path = '51aa9b8d0da890cd1d0c5029e3d89e3c.jpg'
    api = INFER_API()
    score = api.test(image_path)
    print(score)


if __name__ == '__main__':
    main()