|
from PIL import Image |
|
import numpy as np |
|
import timm |
|
import einops |
|
import torch |
|
from torch import nn |
|
from toolkit.dtransform import create_transforms_inference, create_transforms_inference1,\ |
|
create_transforms_inference2,\ |
|
create_transforms_inference3,\ |
|
create_transforms_inference4,\ |
|
create_transforms_inference5 |
|
from toolkit.chelper import load_model |
|
import torch.nn.functional as F |
|
|
|
|
|
def extract_model_from_pth(params_path, net_model, map_location='cpu'):
    """Load the weights stored in a checkpoint file into ``net_model``.

    Args:
        params_path: Path to a file written by ``torch.save``. The loaded
            object must be a mapping with a ``'state_dict'`` entry.
        net_model: The ``nn.Module`` to populate; it is modified in place.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint saved on a GPU can still be opened on a CPU-only
            machine; ``load_state_dict`` then copies the values onto whatever
            device ``net_model`` lives on. Pass ``None`` to keep the devices
            recorded in the file (``torch.load``'s own default).

    Returns:
        ``net_model`` itself, with the checkpoint weights loaded.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
        RuntimeError: If the state dict does not match the model exactly
            (``strict=True``).
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(params_path, map_location=map_location)
    state_dict = checkpoint['state_dict']
    # strict=True: fail loudly on any missing/unexpected key instead of
    # silently running with partially initialised weights.
    net_model.load_state_dict(state_dict, strict=True)
    return net_model
|
|
|
|
|
class SRMConv2d_simple(nn.Module):
    """Fixed (non-learnable) bank of three 5x5 high-pass filters plus clipping.

    Each filter is applied across all ``inc`` input channels (the convolution
    sums the per-channel responses), giving a 3-channel residual map with the
    same spatial size as the input. The map is then clipped to ``[-3, 3]`` by
    ``nn.Hardtanh``. Every kernel sums to zero, so a constant region yields a
    zero response. The name suggests these are the SRM (steganalysis
    rich-model) residual kernels -- TODO confirm provenance.

    Args:
        inc: Number of input channels the kernels are replicated over.
    """

    def __init__(self, inc=3):
        super(SRMConv2d_simple, self).__init__()
        self.truc = nn.Hardtanh(-3, 3)
        # Non-persistent buffer: it follows ``.to()`` / ``.cuda()`` / ``.double()``
        # together with the rest of the module, yet stays out of ``state_dict``
        # because it is rebuilt deterministically here.
        self.register_buffer(
            'kernel',
            torch.from_numpy(self._build_kernel(inc)).float(),
            persistent=False,
        )

    def forward(self, x):
        """Filter ``x`` and return the clipped residual map.

        Expects ``x`` shaped ``(N, inc, H, W)``; returns ``(N, 3, H, W)``.
        """
        # Align the kernel with the input so callers need not move the module
        # themselves; ``.to`` returns the buffer unchanged when they already match.
        kernel = self.kernel.to(device=x.device, dtype=x.dtype)
        # padding=2 preserves H and W for the 5x5 kernels.
        out = F.conv2d(x, kernel, stride=1, padding=2)
        return self.truc(out)

    def _build_kernel(self, inc):
        """Return the float64 filter bank as an array of shape ``(3, inc, 5, 5)``."""
        # 3x3 second-order kernel embedded in the centre of a 5x5 grid.
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]

        # Full 5x5 kernel.
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]

        # Horizontal second-order difference [1, -2, 1] on the middle row.
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]

        # Scale each kernel so its centre coefficient has magnitude 1.
        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.

        filters = np.array([[filter1], [filter2], [filter3]])  # (3, 1, 5, 5)
        # Replicate each kernel over every input channel -> (3, inc, 5, 5).
        return np.repeat(filters, inc, axis=1)
|
|
|
|
|
class INFER_API:
    """Process-wide singleton that loads the model once and scores image files.

    The first ``INFER_API()`` call builds the preprocessing pipelines, the SRM
    filter and the model; later calls return the same instance without
    reloading anything.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = super(INFER_API, cls).__new__(cls)
            instance.initialize()
            # Publish the singleton only after initialisation succeeded, so a
            # failed load (e.g. missing checkpoint) can be retried instead of
            # leaving a half-built instance cached forever.
            cls._instance = instance
        return cls._instance

    def initialize(self, model_path='./final_model_csv/final_model.pth'):
        """Build the preprocessing pipelines and load the model weights.

        Args:
            model_path: Checkpoint file passed to ``extract_model_from_pth``.
        """
        # Six preprocessing variants at 512x512; ``test`` stacks their outputs
        # along a new axis before calling the model.
        self.transformer_ = [create_transforms_inference(h=512, w=512),
                             create_transforms_inference1(h=512, w=512),
                             create_transforms_inference2(h=512, w=512),
                             create_transforms_inference3(h=512, w=512),
                             create_transforms_inference4(h=512, w=512),
                             create_transforms_inference5(h=512, w=512)]
        self.srm = SRMConv2d_simple()

        self.model = load_model('all', 2)
        self.model = extract_model_from_pth(model_path, self.model)

        # Remember the chosen device so ``test`` sends its input to the same
        # place (falls back to CPU when CUDA is unavailable).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()

    def _add_new_channels_worker(self, image):
        """Return the SRM residual of a channels-last image tensor.

        Args:
            image: Tensor shaped ``(h, w, c)``.

        Returns:
            A float32 tensor shaped ``(h, w, k)``, where ``k`` is the number of
            output channels of ``self.srm``. The input is normalised with the
            timm ImageNet mean/std before filtering.
        """
        image = einops.rearrange(image, "h w c -> c h w")
        mean = torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)
        std = torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1)
        image = (image - mean) / std
        srm = self.srm(image.unsqueeze(0)).squeeze(0)
        # Back to channels-last; contiguous() + float() yields an owned float32
        # tensor, which is what the former NumPy round-trip produced.
        return einops.rearrange(srm, "c h w -> h w c").contiguous().float()

    def add_new_channels(self, images):
        """Append the SRM residual channels to a ``(c, h, w)`` image tensor.

        Returns a ``(c + k, h, w)`` tensor: the original channels followed by
        the ``k`` residual channels from ``_add_new_channels_worker``.
        """
        images_hwc = einops.rearrange(images, "c h w -> h w c")
        new_channels = self._add_new_channels_worker(images_hwc)
        # torch.cat is the long-standing name; torch.concatenate is only a
        # newer alias and is missing from older PyTorch releases.
        combined = torch.cat([images_hwc, new_channels], dim=-1)
        return einops.rearrange(combined, "h w c -> c h w")

    def test(self, img_path):
        """Score one image file and return the model output as a Python float.

        Every pipeline in ``self.transformer_`` is applied to the RGB image,
        SRM channels are appended to each result, and the results are stacked
        into a ``(1, n_pipelines, c, h, w)`` batch for the model. The model
        output must contain exactly one element (it is converted with
        ``float``).
        """
        # Close the file handle promptly; convert() returns an independent,
        # fully loaded image.
        with Image.open(img_path) as img:
            img_data = img.convert('RGB')

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            all_data = [self.add_new_channels(transform(img_data))
                        for transform in self.transformer_]
            # Use the device chosen in initialize() instead of an unconditional
            # .cuda(), which crashed on machines without a GPU.
            img_tensor = torch.stack(all_data, dim=0).unsqueeze(0).to(self.device)
            preds = self.model(img_tensor)

        # Rounding kept from the original for output compatibility.
        return round(float(preds), 20)
|
|
|
|
|
def main(img_path='51aa9b8d0da890cd1d0c5029e3d89e3c.jpg'):
    """Score a single image with the shared ``INFER_API`` instance and print it.

    Args:
        img_path: Image file to score; defaults to the previously hard-coded
            sample filename so ``main()`` behaves as before.
    """
    infer_api = INFER_API()
    print(infer_api.test(img_path))
|
|
|
|
|
# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()