import cv2
import timm
import torch
import gradio as gr
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Log the timm version (pretrained model names can shift between releases).
print(timm.__version__)

# Load the checkpoint on CPU (the filename pattern suggests a PyTorch
# Lightning ModelCheckpoint) and pull out its state dict.
checkpoint = torch.load(
    'v6-epoch=18-val_loss=0.0313-val_accuracy=0.9618.ckpt',
    map_location=torch.device('cpu'),
)
model_weights = checkpoint["state_dict"]

# The weights were saved under a `backbone.` prefix (the timm model's
# attribute name in the training module), so strip it to match timm's keys.
for key in list(model_weights):
    model_weights[key.replace("backbone.", "")] = model_weights.pop(key)
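# Optional sanity check: after renaming, no key should still carry the
# prefix; if this fires, the checkpoint layout differs from what this
# script assumes.
assert not any(k.startswith("backbone.") for k in model_weights)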


def get_model():
    # Build the same ConvNeXt architecture the checkpoint was trained with;
    # pretrained=False because the weights come from the checkpoint, not timm.
    return timm.create_model('convnext_base.fb_in22k_ft_in1k', pretrained=False, num_classes=2)

model = get_model()

model.load_state_dict(model_weights)
model.eval()
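
# Optional smoke test: a dummy batch at the 224x224 input size assumed by the
# transforms below should yield one logit per class.
with torch.no_grad():
    assert model(torch.zeros(1, 3, 224, 224)).shape == (1, 2)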


# Resize to SIZE x SIZE, then center-crop to the model's 224x224 input.
CROP = 224
SIZE = CROP + CROP // 8  # 252

ho_trans_center = A.Compose([
    A.Resize(SIZE, SIZE, interpolation=cv2.INTER_AREA),
    A.CenterCrop(height=CROP, width=CROP),
])

# ImageNet normalization, then HWC uint8 -> CHW float tensor.
topt = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])
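
# Quick shape check for the preprocessing path, a minimal sketch on dummy
# data: any HWC uint8 frame should come out as a normalized 3x224x224 tensor.
import numpy as np
_dummy = np.zeros((300, 400, 3), dtype=np.uint8)
_tensor = topt(image=ho_trans_center(image=_dummy)['image'])['image']
assert _tensor.shape == (3, 224, 224)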

# Human-readable labels for the two classes.
labels = ['good', 'ill']

def predict(inp):
    # Gradio passes the image as an HWC uint8 numpy array.
    crop = ho_trans_center(image=inp)['image']
    img = topt(image=crop)['image'].unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        prediction = model(img).softmax(dim=1).numpy()
    confidences = {labels[i]: float(prediction[0][i]) for i in range(len(labels))}
    # Return the un-normalized crop so Gradio can display what the model saw;
    # the normalized tensor itself is not a renderable image.
    return confidences, crop
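
# Usage sketch outside Gradio (hypothetical image path; PIL yields HWC uint8
# RGB, which is what the transforms expect):
#   from PIL import Image
#   scores, shown = predict(np.array(Image.open('sample.jpg').convert('RGB')))
#   print(scores)  # e.g. {'good': 0.97, 'ill': 0.03}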

gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=[gr.Label(num_top_classes=1), "image"],
).launch()