import cv2
import gradio as gr
import timm
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

print(timm.__version__)  # log the timm version for debugging

# checkpoint = torch.load('v6-epoch=18-val_loss=0.0313-val_accuracy=0.9618.ckpt', map_location=torch.device('cpu'))
checkpoint = torch.load('v7-epoch=39-val_loss=0.0222-val_accuracy=0.9806.ckpt', map_location=torch.device('cpu'))
# The checkpoint's state_dict keys carry a "backbone." prefix (from the training
# module's attribute name); strip it so the keys match the bare timm model.
model_weights = checkpoint["state_dict"]
for key in list(model_weights):
    model_weights[key.replace("backbone.", "")] = model_weights.pop(key)


def get_model():
    # ConvNeXt-Base from timm, head resized to 2 classes; weights come from the checkpoint above.
    model = timm.create_model('convnext_base.fb_in22k_ft_in1k', pretrained=False, num_classes=2)
    return model

model = get_model()

model.load_state_dict(model_weights)
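# Note: load_state_dict is strict by default, so any leftover key mismatch
# (e.g. an unstripped prefix) raises immediately. A hedged way to inspect
# mismatches instead of raising:
# missing, unexpected = model.load_state_dict(model_weights, strict=False)
# print(missing, unexpected)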
model.eval()

CROP = 224
SIZE = CROP + CROP // 8  # resize to 252px, then center-crop back to 224

ho_trans_center = A.Compose([
    A.Resize(SIZE, SIZE, interpolation=cv2.INTER_AREA),
    A.CenterCrop(height=CROP, width=CROP, always_apply=True),
])
topt = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet stats
    ToTensorV2(),
])
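
# Sketch of the preprocessing contract (assumed input: uint8 HxWx3 RGB array,
# as gr.Image() delivers by default):
#   ho_trans_center: resize to 252x252, center-crop -> 224x224x3 uint8
#   topt: normalize + ToTensorV2 -> float32 tensor of shape (3, 224, 224)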

# Human-readable labels for the two classes.
labels = ['good', 'ill']

def predict(inp):
    # Hook Grad-CAM on the last layer before the classifier head.
    target_layers = [model.norm_pre]
    cam = GradCAM(model=model, target_layers=target_layers)
    img = ho_trans_center(image=inp)['image']
    rgb_img = img.copy() / 255.  # float image in [0, 1] for the CAM overlay
    img = topt(image=img)['image']
    img = img.unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        prediction = model(img).softmax(1).numpy()
        confidences = {labels[i]: float(prediction[0][i]) for i in range(2)}
    # targets=None makes Grad-CAM explain the highest-scoring class.
    grad = cam(input_tensor=img, targets=None)
    grad = grad[0, :]
    # print(rgb_img.shape, rgb_img.dtype)  # debug
    # print(grad.shape, grad.dtype)  # debug
    cam_image = show_cam_on_image(rgb_img, grad)
    return confidences, cam_image
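
# Optional local smoke test, a minimal sketch (not part of the app; `dummy` is a
# hypothetical random image standing in for a real upload):
# import numpy as np
# dummy = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
# confidences, overlay = predict(dummy)
# print(confidences)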

gr.Interface(fn=predict,
             inputs=gr.Image(),
             outputs=[gr.Label(num_top_classes=1), "image"],
             ).launch()
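
# launch() serves on localhost by default; launch(share=True) is the standard
# Gradio option for a temporary public link (left off here).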