moshel commited on
Commit
e18dcea
·
1 Parent(s): 2fc24db
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
 
3
 
4
  import torch
 
5
 
6
  model = torch.load('v4-epoch=19-val_loss=0.6964-val_accuracy=0.8964.ckpt', map_location=torch.device('cpu'))
7
 
@@ -11,9 +12,10 @@ from torchvision import transforms
11
 
12
  # Download human-readable labels for ImageNet.
13
  labels = ['good', 'ill']
 
14
 
15
  def predict(inp):
16
- img = transforms.ToTensor()(inp)
17
  img = torchvision.transforms.Resize((800, 800))(img)
18
  img = torchvision.transforms.CenterCrop(CROP)(img)
19
  img = img.unsqueeze(0)
@@ -26,7 +28,7 @@ import gradio as gr
26
 
27
  gr.Interface(fn=predict,
28
  inputs=gr.Image(type="pil"),
29
- outputs=gr.Label(num_top_classes=3),
30
  ).launch()
31
 
32
 
 
2
 
3
 
4
  import torch
5
+ import torchvision
6
 
7
  model = torch.load('v4-epoch=19-val_loss=0.6964-val_accuracy=0.8964.ckpt', map_location=torch.device('cpu'))
8
 
 
12
 
13
  # Download human-readable labels for ImageNet.
14
  labels = ['good', 'ill']
15
+ CROP=384
16
 
17
  def predict(inp):
18
+ img = torchvision.transforms.ToTensor()(inp)
19
  img = torchvision.transforms.Resize((800, 800))(img)
20
  img = torchvision.transforms.CenterCrop(CROP)(img)
21
  img = img.unsqueeze(0)
 
28
 
29
  gr.Interface(fn=predict,
30
  inputs=gr.Image(type="pil"),
31
+ outputs=gr.Label(num_top_classes=1),
32
  ).launch()
33
 
34