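"""Single-image inference API.

Loads a trained classifier once (singleton INFER_API), expands the input image
into six 512x512 test-time views, appends SRM noise-residual channels to each
view, and returns the model's prediction as a float.
"""
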
from PIL import Image
import numpy as np
import timm
import einops
import torch
from torch import nn
import torch.nn.functional as F

from toolkit.dtransform import (create_transforms_inference,
                                create_transforms_inference1,
                                create_transforms_inference2,
                                create_transforms_inference3,
                                create_transforms_inference4,
                                create_transforms_inference5)
from toolkit.chelper import load_model


def extract_model_from_pth(params_path, net_model):
    # Load the trained weights; map to CPU first so the checkpoint also loads on
    # machines without a GPU (the model is moved to the chosen device later).
    checkpoint = torch.load(params_path, map_location='cpu')
    state_dict = checkpoint['state_dict']
    net_model.load_state_dict(state_dict, strict=True)
    return net_model


class SRMConv2d_simple(nn.Module):
    """Fixed (non-learnable) SRM high-pass filters that extract noise residuals."""

    def __init__(self, inc=3):
        super(SRMConv2d_simple, self).__init__()
        # Clamp the residuals to [-3, 3].
        self.trunc = nn.Hardtanh(-3, 3)
        self.kernel = torch.from_numpy(self._build_kernel(inc)).float()

    def forward(self, x):
        # 5x5 kernels with stride=1 and padding=2 keep the spatial size unchanged.
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        out = self.trunc(out)
        return out

    def _build_kernel(self, inc):
        # filter1: KB (3x3 second-order kernel embedded in a 5x5 window)
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2: KV (5x5 kernel)
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3: horizontal second-order kernel
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]
        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # Stack the filters into a (3, inc, 5, 5) weight tensor: three output
        # channels, each applied across all `inc` input channels.
        filters = [[filter1],
                   [filter2],
                   [filter3]]
        filters = np.array(filters)
        filters = np.repeat(filters, inc, axis=1)
        return filters
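
# Example (shape sketch): for a normalized 3xHxW input, SRMConv2d_simple
# returns a 3xHxW residual map clamped to [-3, 3], e.g.
#   srm = SRMConv2d_simple()
#   residual = srm(torch.randn(1, 3, 512, 512))  # -> torch.Size([1, 3, 512, 512])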


class INFER_API:
    """Singleton wrapper that loads the model once and exposes test()."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(INFER_API, cls).__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def initialize(self):
        # Six test-time augmentation transforms, all producing 512x512 views.
        self.transformer_ = [create_transforms_inference(h=512, w=512),
                             create_transforms_inference1(h=512, w=512),
                             create_transforms_inference2(h=512, w=512),
                             create_transforms_inference3(h=512, w=512),
                             create_transforms_inference4(h=512, w=512),
                             create_transforms_inference5(h=512, w=512)]
        self.srm = SRMConv2d_simple()

        # model init
        self.model = load_model('all', 2)
        model_path = './final_model_csv/final_model.pth'
        self.model = extract_model_from_pth(model_path, self.model)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()

    def _add_new_channels_worker(self, image):
        new_channels = []

        # Normalize with ImageNet statistics, then extract SRM noise residuals.
        image = einops.rearrange(image, "h w c -> c h w")
        image = (image - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor(
            timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1)
        srm = self.srm(image.unsqueeze(0)).squeeze(0)
        new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy())

        new_channels = np.concatenate(new_channels, axis=2)
        return torch.from_numpy(new_channels).float()

    def add_new_channels(self, images):
        # Append the SRM residual channels after the original image channels.
        images_copied = einops.rearrange(images, "c h w -> h w c")
        new_channels = self._add_new_channels_worker(images_copied)
        images_copied = torch.concatenate([images_copied, new_channels], dim=-1)
        images_copied = einops.rearrange(images_copied, "h w c -> c h w")
        return images_copied
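
    # Note: for a 3-channel 512x512 view, add_new_channels returns a 6x512x512
    # tensor (the original RGB channels plus the three SRM residual channels).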

    def test(self, img_path):
        # img load
        img_data = Image.open(img_path).convert('RGB')

        # Build all augmented views and append SRM channels to each one.
        all_data = []
        for transform in self.transformer_:
            current_data = transform(img_data)
            current_data = self.add_new_channels(current_data)
            all_data.append(current_data)
        # Resulting shape: (1, num_views, channels, 512, 512) on the chosen device.
        img_tensor = torch.stack(all_data, dim=0).unsqueeze(0).to(self.device)

        with torch.no_grad():
            preds = self.model(img_tensor)
        return round(float(preds), 20)


def main():
    img = '51aa9b8d0da890cd1d0c5029e3d89e3c.jpg'
    infer_api = INFER_API()
    print(infer_api.test(img))


if __name__ == '__main__':
    main()
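
# Usage note: running this file executes main(), which expects the checkpoint at
# ./final_model_csv/final_model.pth and the sample image in the working
# directory. Because INFER_API is a singleton, later INFER_API() calls reuse the
# already-loaded model and transforms.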