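# Gradio demo: binary "good"/"ill" image classifier built on a ConvNeXt-Base backbone,
# with weights restored from a local training checkpoint.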
import gradio as gr
import torch
import torchvision
import timm
import cv2
print(timm.__version__)
# Load the training checkpoint on CPU and extract the raw weights.
checkpoint = torch.load('v6-epoch=18-val_loss=0.0313-val_accuracy=0.9618.ckpt', map_location=torch.device('cpu'))
state_dict = checkpoint["state_dict"]
model_weights = state_dict
# Strip the "backbone." prefix added by the training wrapper so the keys match the timm model.
for key in list(model_weights):
    model_weights[key.replace("backbone.", "")] = model_weights.pop(key)
def get_model():
    # ConvNeXt-Base with a 2-class head; weights are loaded from the checkpoint below.
    model = timm.create_model('convnext_base.fb_in22k_ft_in1k', pretrained=False, num_classes=2)
    return model
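
# Instantiate the architecture, load the cleaned-up weights, and switch to inference mode.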
model = get_model()
model.load_state_dict(model_weights)
model.eval()
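
# Inference-time image preprocessing (albumentations).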
import requests
from PIL import Image
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
CROP = 224
SIZE = CROP + CROP // 8  # resize to 252 px before center-cropping to 224 px

# Resize then center-crop to the model's input resolution.
ho_trans_center = A.Compose([
    A.Resize(SIZE, SIZE, interpolation=cv2.INTER_AREA),
    A.CenterCrop(height=CROP, width=CROP, always_apply=True),
])
# ImageNet normalization followed by conversion to a CHW tensor.
topt = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])
# Human-readable labels for the two classes.
labels = ['good', 'ill']
def predict(inp):
    # Resize/crop the input, then normalize and convert to a batched tensor.
    crop = ho_trans_center(image=inp)['image']
    img = topt(image=crop)['image']
    img = img.unsqueeze(0)
    with torch.no_grad():
        prediction = model(img).softmax(1).numpy()
    confidences = {labels[i]: float(prediction[0][i]) for i in range(2)}
    # Return the class probabilities and the cropped image that was fed to the model.
    return confidences, crop
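
# Image in; predicted label and the cropped input out.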
gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=[gr.Label(num_top_classes=1), "image"],
).launch()