DawnC commited on
Commit
85265af
·
verified ·
1 Parent(s): be75ea9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -16
app.py CHANGED
@@ -4,26 +4,26 @@ from PIL import Image
4
  import gradio as gr
5
  import os
6
 
7
- # 使用 CPU
8
  device = torch.device('cpu')
9
 
10
- # 定義 ResNet-50 模型架構
11
- model = models.resnet50(weights=None) # 不使用預訓練權重,僅構建模型架構
12
 
13
- # 加載模型權重到模型架構
14
- model.load_state_dict(torch.load('resnet50_model_weights.pth', map_location=device))
 
 
 
15
 
16
- # 設置模型為評估模式
17
  model.eval()
18
 
19
- # 定義影像預處理
20
  transform = transforms.Compose([
21
  transforms.Resize((224, 224)),
22
  transforms.ToTensor(),
23
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
24
  ])
25
 
26
- # 定義類別名稱
27
  class_names = ['Abyssinian (阿比西尼亞貓)', 'American Bulldog (美國鬥牛犬)', 'American Pit Bull Terrier (美國比特鬥牛梗)',
28
  'Basset Hound (巴吉度獵犬)', 'Beagle (米格魯)', 'Bengal (孟加拉貓)', 'Birman (緬甸貓)', 'Bombay (孟買貓)',
29
  'Boxer (拳師犬)', 'British Shorthair (英國短毛貓)', 'Chihuahua (吉娃娃)', 'Egyptian Mau (埃及貓)',
@@ -35,19 +35,27 @@ class_names = ['Abyssinian (阿比西尼亞貓)', 'American Bulldog (美國鬥
35
  'Siamese (暹羅貓)', 'Sphynx (無毛貓)', 'Staffordshire Bull Terrier (史塔福郡鬥牛犬)',
36
  'Wheaten Terrier (小麥色梗)', 'Yorkshire Terrier (約克夏犬)']
37
 
38
- # 定義預測函數
39
  def classify_image(image):
40
- image = transform(image).unsqueeze(0).to(device) # 確保影像資料處理在 CPU
41
  with torch.no_grad():
42
  outputs = model(image)
43
- probabilities, indices = torch.topk(outputs, k=3) # 取得前3個預測
44
- probabilities = torch.nn.functional.softmax(probabilities, dim=1) # 將結果轉換為機率
45
  predictions = [(class_names[idx], prob.item()) for idx, prob in zip(indices[0], probabilities[0])]
46
  return {class_name: f"{prob * 100:.2f}%" for class_name, prob in predictions}
47
 
48
- # Gradio 介面
49
- examples = [["examples/" + img] for img in os.listdir('examples')]
 
 
 
 
 
 
 
50
 
 
51
  demo = gr.Interface(fn=classify_image,
52
  inputs=gr.Image(type="pil"),
53
  outputs=[gr.Label(num_top_classes=3, label="Top 3 Predictions")],
@@ -56,5 +64,4 @@ demo = gr.Interface(fn=classify_image,
56
  description='A ResNet50-based model for classifying 37 different pet breeds.',
57
  article='https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/The%20Oxford-IIIT%20Pet%20Project')
58
 
59
- # 啟動 Gradio demo
60
- demo.launch()
 
4
  import gradio as gr
5
  import os
6
 
7
+ # Use CPU
8
  device = torch.device('cpu')
9
 
10
+ # define ResNet-50 architecture
11
+ model = models.resnet50(weights=None)
12
 
13
+ # change the output to 37 (num_classes)
14
+ model.fc = torch.nn.Linear(2048, 37)
15
+
16
+ # Load model weights
17
+ model.load_state_dict(torch.load('/content/Oxford_Pet_classifier/resnet50_model_weights.pth', map_location=device))
18
 
 
19
  model.eval()
20
 
 
21
  transform = transforms.Compose([
22
  transforms.Resize((224, 224)),
23
  transforms.ToTensor(),
24
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
25
  ])
26
 
 
27
  class_names = ['Abyssinian (阿比西尼亞貓)', 'American Bulldog (美國鬥牛犬)', 'American Pit Bull Terrier (美國比特鬥牛梗)',
28
  'Basset Hound (巴吉度獵犬)', 'Beagle (米格魯)', 'Bengal (孟加拉貓)', 'Birman (緬甸貓)', 'Bombay (孟買貓)',
29
  'Boxer (拳師犬)', 'British Shorthair (英國短毛貓)', 'Chihuahua (吉娃娃)', 'Egyptian Mau (埃及貓)',
 
35
  'Siamese (暹羅貓)', 'Sphynx (無毛貓)', 'Staffordshire Bull Terrier (史塔福郡鬥牛犬)',
36
  'Wheaten Terrier (小麥色梗)', 'Yorkshire Terrier (約克夏犬)']
37
 
38
+ # predict function
39
  def classify_image(image):
40
+ image = transform(image).unsqueeze(0).to(device) # make sure on the cpu
41
  with torch.no_grad():
42
  outputs = model(image)
43
+ probabilities, indices = torch.topk(outputs, k=3) # top 3 predictions
44
+ probabilities = torch.nn.functional.softmax(probabilities, dim=1) # turn into probabilities
45
  predictions = [(class_names[idx], prob.item()) for idx, prob in zip(indices[0], probabilities[0])]
46
  return {class_name: f"{prob * 100:.2f}%" for class_name, prob in predictions}
47
 
48
+ # define examples_path
49
+ examples_path = '/content/Oxford_Pet_classifier/examples'
50
+
51
+ # make sure examples is exists
52
+ if os.path.exists(examples_path):
53
+ examples = [[examples_path + "/" + img] for img in os.listdir(examples_path)]
54
+ else:
55
+ print(f"[ERROR] Examples folder not found at {examples_path}")
56
+ examples = []
57
 
58
+ # Gradio Interface
59
  demo = gr.Interface(fn=classify_image,
60
  inputs=gr.Image(type="pil"),
61
  outputs=[gr.Label(num_top_classes=3, label="Top 3 Predictions")],
 
64
  description='A ResNet50-based model for classifying 37 different pet breeds.',
65
  article='https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/The%20Oxford-IIIT%20Pet%20Project')
66
 
67
+ demo.launch()