DawnC commited on
Commit
9e437f8
1 Parent(s): 3ca6446

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -7,14 +7,14 @@ import os
7
  # Use CPU
8
  device = torch.device('cpu')
9
 
10
- # define ResNet-50 architecture
11
  model = models.resnet50(weights=None)
12
 
13
- # change the output to 37 (num_classes)
14
  model.fc = torch.nn.Linear(2048, 37)
15
 
16
- # Load model weights
17
- model.load_state_dict(torch.load('/content/Oxford_Pet_classifier/resnet50_model_weights.pth', map_location=device))
18
 
19
  model.eval()
20
 
@@ -37,25 +37,25 @@ class_names = ['Abyssinian (阿比西尼亞貓)', 'American Bulldog (美國鬥
37
 
38
  # predict function
39
  def classify_image(image):
40
- image = transform(image).unsqueeze(0).to(device) # make sure on the cpu
41
  with torch.no_grad():
42
  outputs = model(image)
43
  probabilities, indices = torch.topk(outputs, k=3) # top 3 predictions
44
- probabilities = torch.nn.functional.softmax(probabilities, dim=1) # turn into probabilities
45
  predictions = [(class_names[idx], prob.item()) for idx, prob in zip(indices[0], probabilities[0])]
46
  return {class_name: f"{prob * 100:.2f}%" for class_name, prob in predictions}
47
 
48
- # define examples_path
49
- examples_path = '/content/Oxford_Pet_classifier/examples'
50
 
51
- # make sure examples is exists
 
52
  if os.path.exists(examples_path):
53
- examples = [[examples_path + "/" + img] for img in os.listdir(examples_path)]
54
  else:
55
  print(f"[ERROR] Examples folder not found at {examples_path}")
56
- examples = []
57
 
58
  # Gradio Interface
 
 
59
  demo = gr.Interface(fn=classify_image,
60
  inputs=gr.Image(type="pil"),
61
  outputs=[gr.Label(num_top_classes=3, label="Top 3 Predictions")],
@@ -64,4 +64,4 @@ demo = gr.Interface(fn=classify_image,
64
  description='A ResNet50-based model for classifying 37 different pet breeds.',
65
  article='https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/The%20Oxford-IIIT%20Pet%20Project')
66
 
67
- demo.launch()
 
7
  # Use CPU
8
  device = torch.device('cpu')
9
 
10
+ # Define ResNet-50 Architecture
11
  model = models.resnet50(weights=None)
12
 
13
+ # revised full connected layer to 37 (num_classes)
14
  model.fc = torch.nn.Linear(2048, 37)
15
 
16
+ # Load Model weights
17
+ model.load_state_dict(torch.load('./resnet50_model_weights.pth', map_location=device))
18
 
19
  model.eval()
20
 
 
37
 
38
  # predict function
39
  def classify_image(image):
40
+ image = transform(image).unsqueeze(0).to(device) # make sure prediction on cpu
41
  with torch.no_grad():
42
  outputs = model(image)
43
  probabilities, indices = torch.topk(outputs, k=3) # top 3 predictions
44
+ probabilities = torch.nn.functional.softmax(probabilities, dim=1)
45
  predictions = [(class_names[idx], prob.item()) for idx, prob in zip(indices[0], probabilities[0])]
46
  return {class_name: f"{prob * 100:.2f}%" for class_name, prob in predictions}
47
 
 
 
48
 
49
+ examples_path = './examples'
50
+
51
  if os.path.exists(examples_path):
52
+ print(f"[INFO] Found examples folder at {examples_path}")
53
  else:
54
  print(f"[ERROR] Examples folder not found at {examples_path}")
 
55
 
56
  # Gradio Interface
57
+ examples = [[examples_path + "/" + img] for img in os.listdir(examples_path)]
58
+
59
  demo = gr.Interface(fn=classify_image,
60
  inputs=gr.Image(type="pil"),
61
  outputs=[gr.Label(num_top_classes=3, label="Top 3 Predictions")],
 
64
  description='A ResNet50-based model for classifying 37 different pet breeds.',
65
  article='https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/The%20Oxford-IIIT%20Pet%20Project')
66
 
67
+ demo.launch()