dd
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
|
|
2 |
|
3 |
|
4 |
import torch
|
|
|
5 |
|
6 |
model = torch.load('v4-epoch=19-val_loss=0.6964-val_accuracy=0.8964.ckpt', map_location=torch.device('cpu'))
|
7 |
|
@@ -11,9 +12,10 @@ from torchvision import transforms
|
|
11 |
|
12 |
# Download human-readable labels for ImageNet.
|
13 |
labels = ['good', 'ill']
|
|
|
14 |
|
15 |
def predict(inp):
|
16 |
-
img = transforms.ToTensor()(inp)
|
17 |
img = torchvision.transforms.Resize((800, 800))(img)
|
18 |
img = torchvision.transforms.CenterCrop(CROP)(img)
|
19 |
img = img.unsqueeze(0)
|
@@ -26,7 +28,7 @@ import gradio as gr
|
|
26 |
|
27 |
gr.Interface(fn=predict,
|
28 |
inputs=gr.Image(type="pil"),
|
29 |
-
outputs=gr.Label(num_top_classes=
|
30 |
).launch()
|
31 |
|
32 |
|
|
|
2 |
|
3 |
|
4 |
import torch
|
5 |
+
import torchvision
|
6 |
|
7 |
model = torch.load('v4-epoch=19-val_loss=0.6964-val_accuracy=0.8964.ckpt', map_location=torch.device('cpu'))
|
8 |
|
|
|
12 |
|
13 |
# Download human-readable labels for ImageNet.
|
14 |
labels = ['good', 'ill']
|
15 |
+
CROP=384
|
16 |
|
17 |
def predict(inp):
|
18 |
+
img = torchvision.transforms.ToTensor()(inp)
|
19 |
img = torchvision.transforms.Resize((800, 800))(img)
|
20 |
img = torchvision.transforms.CenterCrop(CROP)(img)
|
21 |
img = img.unsqueeze(0)
|
|
|
28 |
|
29 |
gr.Interface(fn=predict,
|
30 |
inputs=gr.Image(type="pil"),
|
31 |
+
outputs=gr.Label(num_top_classes=1),
|
32 |
).launch()
|
33 |
|
34 |
|