import gradio as gr
import timm
import torch
from torch import nn
from torchvision import transforms

# Load the training checkpoint on CPU and keep only the model weights.
checkpoint = torch.load(
    'v4-epoch=19-val_loss=0.6964-val_accuracy=0.8964.ckpt',
    map_location=torch.device('cpu'),
)
state_dict = checkpoint["state_dict"]

# The checkpoint stores the network weights under a "backbone." prefix;
# strip it so the keys match the bare timm model built below.
model_weights = {key.replace("backbone.", ""): value for key, value in state_dict.items()}
def get_model():
    # EfficientNet-B1 backbone with a 2-class head and concatenated avg+max pooling.
    model = timm.create_model(
        'tf_efficientnet_b1',
        pretrained=True,
        num_classes=2,
        global_pool='catavgmax',
    )
    num_in_features = model.get_classifier().in_features
    # Extra fully connected head registered as `model.fc` so the matching
    # keys in the checkpoint have a place to load into.
    model.fc = nn.Sequential(
        nn.Linear(in_features=num_in_features, out_features=1024, bias=False),
        nn.ReLU(),
        nn.Linear(in_features=1024, out_features=2, bias=False),
    )
    return model
# Load the fine-tuned weights and switch to inference mode.
model = get_model()
model.load_state_dict(model_weights)
model.eval()
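# Optional sanity check (a small sketch added here, not part of the original app):
# a dummy batch confirms the weights loaded and the head returns two logits per image.
with torch.no_grad():
    assert model(torch.zeros(1, 3, 384, 384)).shape == (1, 2)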
# Human-readable labels for the two output classes.
labels = ['good', 'ill']
# Side length of the square center crop fed to the network.
CROP = 384
def predict(inp):
    # Resize to 800x800, take a central CROP x CROP patch, and add a batch dimension.
    img = transforms.ToTensor()(inp)
    img = transforms.Resize((800, 800))(img)
    img = transforms.CenterCrop(CROP)(img)
    img = img.unsqueeze(0)
    with torch.no_grad():
        # Softmax over the two classes; take the single image out of the batch.
        prediction = model(img).softmax(1)[0].numpy()
    confidences = {labels[i]: float(prediction[i]) for i in range(2)}
    return confidences
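# Quick local smoke test (a sketch, not part of the original Space; uncomment to run):
# a blank 800x800 RGB image stands in for a real upload, so the scores are meaningless,
# but it exercises the resize -> crop -> softmax pipeline end to end.
# from PIL import Image
# print(predict(Image.new("RGB", (800, 800))))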
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=1),
).launch()