DawnC commited on
Commit
91b16a5
1 Parent(s): d81b8e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -14
app.py CHANGED
@@ -4,29 +4,27 @@ from PIL import Image
4
  import gradio as gr
5
  import os
6
 
7
- # 使用 CPU
8
  device = torch.device('cpu')
9
 
10
- # 定義 ResNet-50 模型架構(不使用預訓練權重)
11
  model = models.resnet50(weights=None)
12
 
13
- # 修改模型的全連接層,輸出 37 個類別
14
  model.fc = torch.nn.Linear(2048, 37)
15
 
16
- # 加載模型權重
17
  model.load_state_dict(torch.load('./resnet50_model_weights.pth', map_location=device))
18
 
19
- # 設置模型為評估模式
20
  model.eval()
21
 
22
- # 定義影像預處理
23
  transform = transforms.Compose([
24
  transforms.Resize((224, 224)),
25
  transforms.ToTensor(),
26
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
27
  ])
28
 
29
- # 定義類別名稱
30
  class_names = ['Abyssinian (阿比西尼亞貓)', 'American Bulldog (美國鬥牛犬)', 'American Pit Bull Terrier (美國比特鬥牛梗)',
31
  'Basset Hound (巴吉度獵犬)', 'Beagle (米格魯)', 'Bengal (孟加拉貓)', 'Birman (緬甸貓)', 'Bombay (孟買貓)',
32
  'Boxer (拳師犬)', 'British Shorthair (英國短毛貓)', 'Chihuahua (吉娃娃)', 'Egyptian Mau (埃及貓)',
@@ -38,7 +36,7 @@ class_names = ['Abyssinian (阿比西尼亞貓)', 'American Bulldog (美國鬥
38
  'Siamese (暹羅貓)', 'Sphynx (無毛貓)', 'Staffordshire Bull Terrier (史塔福郡鬥牛犬)',
39
  'Wheaten Terrier (小麥色梗)', 'Yorkshire Terrier (約克夏犬)']
40
 
41
- # 定義預測函數
42
  def classify_image(image):
43
  image = transform(image).unsqueeze(0).to(device)
44
  with torch.no_grad():
@@ -48,7 +46,6 @@ def classify_image(image):
48
  predictions = [(class_names[idx], prob.item()) for idx, prob in zip(indices[0], probabilities[0])]
49
  return {class_name: f"{prob:.2f}" for class_name, prob in predictions}
50
 
51
- # 設定 examples 路徑
52
  examples_path = './examples'
53
 
54
  if os.path.exists(examples_path):
@@ -56,10 +53,9 @@ if os.path.exists(examples_path):
56
  else:
57
  print(f"[ERROR] Examples folder not found at {examples_path}")
58
 
59
- # 設定範例圖片
60
  examples = [[examples_path + "/" + img] for img in os.listdir(examples_path)]
61
 
62
- # 用戶可參考的品種列表
63
  breed_list_text = """
64
  ### Recognizable Breeds:
65
 
@@ -73,14 +69,14 @@ breed_list_text = """
73
  """
74
 
75
 
76
- # Gradio 介面
77
  demo = gr.Interface(
78
  fn=classify_image,
79
- inputs=gr.Image(type="pil"), # 只需要圖片輸入
80
  outputs=[gr.Label(num_top_classes=3, label="Top 3 Predictions")],
81
  examples=examples,
82
  title='Oxford Pet 🐈🐕',
83
- description=f'A ResNet50-based model for classifying 37 different pet breeds.\n\n{breed_list_text}', # 直接把品種放到描述中
84
  article='[Oxford Project](https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/The%20Oxford-IIIT%20Pet%20Project)'
85
  )
86
 
 
4
  import gradio as gr
5
  import os
6
 
7
+ # Use CPU
8
  device = torch.device('cpu')
9
 
10
+ # Define ResNet-50 Architecture
11
  model = models.resnet50(weights=None)
12
 
13
+ # Chanege model ouputs to fit this data (num_classes=37)
14
  model.fc = torch.nn.Linear(2048, 37)
15
 
16
+ # Load model's weight
17
  model.load_state_dict(torch.load('./resnet50_model_weights.pth', map_location=device))
18
 
 
19
  model.eval()
20
 
 
21
  transform = transforms.Compose([
22
  transforms.Resize((224, 224)),
23
  transforms.ToTensor(),
24
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
25
  ])
26
 
27
+
28
  class_names = ['Abyssinian (阿比西尼亞貓)', 'American Bulldog (美國鬥牛犬)', 'American Pit Bull Terrier (美國比特鬥牛梗)',
29
  'Basset Hound (巴吉度獵犬)', 'Beagle (米格魯)', 'Bengal (孟加拉貓)', 'Birman (緬甸貓)', 'Bombay (孟買貓)',
30
  'Boxer (拳師犬)', 'British Shorthair (英國短毛貓)', 'Chihuahua (吉娃娃)', 'Egyptian Mau (埃及貓)',
 
36
  'Siamese (暹羅貓)', 'Sphynx (無毛貓)', 'Staffordshire Bull Terrier (史塔福郡鬥牛犬)',
37
  'Wheaten Terrier (小麥色梗)', 'Yorkshire Terrier (約克夏犬)']
38
 
39
+ # define predict images function
40
  def classify_image(image):
41
  image = transform(image).unsqueeze(0).to(device)
42
  with torch.no_grad():
 
46
  predictions = [(class_names[idx], prob.item()) for idx, prob in zip(indices[0], probabilities[0])]
47
  return {class_name: f"{prob:.2f}" for class_name, prob in predictions}
48
 
 
49
  examples_path = './examples'
50
 
51
  if os.path.exists(examples_path):
 
53
  else:
54
  print(f"[ERROR] Examples folder not found at {examples_path}")
55
 
 
56
  examples = [[examples_path + "/" + img] for img in os.listdir(examples_path)]
57
 
58
+ # Create the reference list
59
  breed_list_text = """
60
  ### Recognizable Breeds:
61
 
 
69
  """
70
 
71
 
72
+ # Gradio Interface
73
  demo = gr.Interface(
74
  fn=classify_image,
75
+ inputs=gr.Image(type="pil"),
76
  outputs=[gr.Label(num_top_classes=3, label="Top 3 Predictions")],
77
  examples=examples,
78
  title='Oxford Pet 🐈🐕',
79
+ description=f'A ResNet50-based model for classifying 37 different pet breeds.\n\n{breed_list_text}',
80
  article='[Oxford Project](https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/The%20Oxford-IIIT%20Pet%20Project)'
81
  )
82