Jfink09 commited on
Commit
1569e58
·
1 Parent(s): ef9f2d9

Upload 6 files

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. model.py +1 -1
app.py CHANGED
@@ -9,13 +9,14 @@ from typing import Tuple, Dict
9
 
10
  # Setup class names
11
  class_names = ['CRVO',
 
12
  'Diabetic Retinopathy',
13
  'Laser Spots',
14
  'Macular Degeneration',
15
  'Myelinated Nerve Fiber',
16
  'Normal',
17
  'Pathological Mypoia',
18
- 'Retinitis Pigmentosa']
19
 
20
  ### 2. Model and transforms preparation ###
21
 
@@ -72,7 +73,7 @@ example_list = [["examples/" + example] for example in os.listdir("examples")]
72
  # Create the Gradio demo
73
  demo = gr.Interface(fn=predict, # mapping function from input to output
74
  inputs=gr.Image(type="pil"), # what are the inputs?
75
- outputs=[gr.Label(num_top_classes=8, label="Predictions"), # what are the outputs?
76
  gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
77
  # Create examples list from "examples/" directory
78
  examples=example_list,
 
9
 
10
  # Setup class names
11
  class_names = ['CRVO',
12
+ 'Choroidal Nevus',
13
  'Diabetic Retinopathy',
14
  'Laser Spots',
15
  'Macular Degeneration',
16
  'Myelinated Nerve Fiber',
17
  'Normal',
18
  'Pathological Mypoia',
19
+ 'Retinitis Pigmentosa'])
20
 
21
  ### 2. Model and transforms preparation ###
22
 
 
73
  # Create the Gradio demo
74
  demo = gr.Interface(fn=predict, # mapping function from input to output
75
  inputs=gr.Image(type="pil"), # what are the inputs?
76
+ outputs=[gr.Label(num_top_classes=9, label="Predictions"), # what are the outputs?
77
  gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
78
  # Create examples list from "examples/" directory
79
  examples=example_list,
model.py CHANGED
@@ -3,7 +3,7 @@ import torchvision
3
 
4
  from torch import nn
5
 
6
- def create_resnet50_model(num_classes:int=8, # 4
7
  seed:int=42):
8
  """Creates an ResNet50 feature extractor model and transforms.
9
 
 
3
 
4
  from torch import nn
5
 
6
+ def create_resnet50_model(num_classes:int=9, # 4
7
  seed:int=42):
8
  """Creates an ResNet50 feature extractor model and transforms.
9