# cucs / app.py
# moshel's picture
# dd
# 858ba7e
# raw
# history blame
# 797 Bytes
import gradio as gr
import torch
import requests
from PIL import Image
from torchvision import transforms

# Load the trained classifier once at startup. map_location='cpu' lets a
# checkpoint that was saved from a GPU be loaded on a CPU-only host; predict()
# feeds CPU tensors, so the model has to live on the CPU anyway.
# NOTE(review): the "epoch=..-val_loss=..-val_accuracy=...ckpt" name looks like a
# PyTorch Lightning ModelCheckpoint file. If so, torch.load returns a dict
# (state_dict, hyper_parameters, ...) that cannot be called as a model, and the
# LightningModule's load_from_checkpoint() should be used instead — confirm.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects pickled
# nn.Module objects; pass weights_only=False only for trusted files.
model = torch.load('v4-epoch=19-val_loss=0.6964-val_accuracy=0.8964.ckpt', map_location='cpu')
if isinstance(model, torch.nn.Module):
    # Switch to inference behaviour (affects layers such as dropout and batch
    # norm, if the model has any) so predictions are deterministic.
    model.eval()

# Class names for the model's outputs: predict() maps output position i to
# labels[i]. NOTE(review): order must match the class order used in training.
labels = ['good', 'ill']
# Side length, in pixels, of the square center crop applied after resizing.
# CenterCrop needs an explicit size; 800 equals the resize target below, so the
# crop keeps the whole resized image.
# NOTE(review): set this to the crop size used during training — TODO confirm.
CROP = 800

# Preprocessing pipeline, built once rather than on every request. Only
# `transforms` (not the `torchvision` package name) is imported in this file,
# so the transforms are referenced through it.
# NOTE(review): ToTensor only scales pixels to [0, 1]; if training also applied
# transforms.Normalize, the same step must be added here — confirm.
_preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((800, 800)),
    transforms.CenterCrop(CROP),
])


def predict(inp):
    """Classify a single image and return per-class confidences.

    Args:
        inp: The uploaded image as a PIL image (the Gradio input is declared
            with ``type="pil"``), or None when no image was provided.

    Returns:
        A dict mapping each name in ``labels`` to its softmax probability, or
        None when ``inp`` is None (leaves the output widget empty).
    """
    if inp is None:
        # Gradio can pass None when the form is submitted without an upload.
        return None

    batch = _preprocess(inp).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(batch)[0]  # outputs for the single image in the batch
        probs = torch.nn.functional.softmax(logits, dim=0)

    # Pair each class name with the probability at the same output position.
    return {label: float(p) for label, p in zip(labels, probs)}
# Web UI: one image upload in, a label/confidence widget out. `gr` is already
# imported at the top of the file, so it is not re-imported here.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),  # predict() expects a PIL image
    # Show every class; there are only len(labels) of them.
    outputs=gr.Label(num_top_classes=len(labels)),
)
demo.launch()