DawnC commited on
Commit
75c78ca
1 Parent(s): 37f6bf3

fixed bugs

Browse files
Files changed (1) hide show
  1. app.py +37 -34
app.py CHANGED
@@ -1,57 +1,60 @@
1
  import torch
2
- from torchvision import transforms
3
- from torchvision import models
4
  from PIL import Image
5
  import gradio as gr
6
  import os
7
 
8
- # Use CPU
9
  device = torch.device('cpu')
10
 
11
- # Load the model ResNet-50 model architecture
12
- model = models.resnet50(pretrained=False)
13
 
14
- # Load model's weight to CPU
15
- model = torch.load('resnet50_model_weights.pth', map_location=device)
 
 
16
  model.eval()
17
 
18
- # Define the image preprocessing
19
  transform = transforms.Compose([
20
  transforms.Resize((224, 224)),
21
  transforms.ToTensor(),
22
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
23
  ])
24
 
25
- # Define the class names
26
- class_names = ['Abyssinian', 'American Bulldog', 'American Pit Bull Terrier', 'Basset Hound', 'Beagle', 'Bengal', 'Birman', 'Bombay',
27
- 'Boxer', 'British Shorthair', 'Chihuahua', 'Egyptian Mau', 'English Cocker Spaniel', 'English Setter', 'German Shorthaired',
28
- 'Great Pyrenees', 'Havanese', 'Japanese Chin', 'Keeshond', 'Leonberger', 'Maine Coon', 'Miniature Pinscher', 'Newfoundland',
29
- 'Persian', 'Pomeranian', 'Pug', 'Ragdoll', 'Russian Blue', 'Saint Bernard', 'Samoyed', 'Scottish Terrier', 'Shiba Inu',
30
- 'Siamese', 'Sphynx', 'Staffordshire Bull Terrier', 'Wheaten Terrier', 'Yorkshire Terrier']
31
-
32
- # Define the predict function
 
 
 
 
 
33
  def classify_image(image):
34
- image = transform(image).unsqueeze(0).to(device) # Ensure image data is processed on CPU
35
  with torch.no_grad():
36
  outputs = model(image)
37
- _, predicted = torch.max(outputs, 1)
38
- return class_names[predicted.item()]
39
-
40
- # Custom Gradio interface title, description, and article
41
- title = 'Oxford Pet 🐈🐕'
42
- description = 'A ResNet50-based computer vision model for classifying images of pets from the Oxford-IIIT Pet Dataset. The model can recognize 37 different pet breeds, including cats and dogs.'
43
- article = 'https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/The%20Oxford-IIIT%20Pet%20Project'
44
 
45
- # Gradio interface
46
  examples = [["examples/" + img] for img in os.listdir('examples')]
47
 
48
- demo = gr.Interface(fn=classify_image, # Map input to output function
49
- inputs=gr.Image(type="pil"), # Image input
50
- outputs=[gr.Label(num_top_classes=1, label="Predictions")], # Predicted label
51
- examples=examples, # Example images
52
- title=title,
53
- description=description,
54
- article=article)
55
 
56
- # Launch the demo
57
- demo.launch()
 
1
  import torch
2
+ from torchvision import transforms, models
 
3
  from PIL import Image
4
  import gradio as gr
5
  import os
6
 
7
+ # 使用 CPU
8
  device = torch.device('cpu')
9
 
10
+ # 定義 ResNet-50 模型架構
11
+ model = models.resnet50(weights=None) # 不使用預訓練權重,僅構建模型架構
12
 
13
+ # 加載模型權重到模型架構
14
+ model.load_state_dict(torch.load('resnet50_model_weights.pth', map_location=device))
15
+
16
+ # 設置模型為評估模式
17
  model.eval()
18
 
19
+ # 定義影像預處理
20
  transform = transforms.Compose([
21
  transforms.Resize((224, 224)),
22
  transforms.ToTensor(),
23
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
24
  ])
25
 
26
+ # 定義類別名稱
27
+ class_names = ['Abyssinian (阿比西尼亞貓)', 'American Bulldog (美國鬥牛犬)', 'American Pit Bull Terrier (美國比特鬥牛梗)',
28
+ 'Basset Hound (巴吉度獵犬)', 'Beagle (米格魯)', 'Bengal (孟加拉貓)', 'Birman (緬甸貓)', 'Bombay (孟買貓)',
29
+ 'Boxer (拳師犬)', 'British Shorthair (英國短毛貓)', 'Chihuahua (吉娃娃)', 'Egyptian Mau (埃及貓)',
30
+ 'English Cocker Spaniel (英國可卡犬)', 'English Setter (英國設得蘭犬)', 'German Shorthaired (德國短毛犬)',
31
+ 'Great Pyrenees (大白熊犬)', 'Havanese (哈瓦那犬)', 'Japanese Chin (日本狆)', 'Keeshond (荷蘭毛獅犬)',
32
+ 'Leonberger (萊昂貝格犬)', 'Maine Coon (緬因貓)', 'Miniature Pinscher (迷你品犬)', 'Newfoundland (紐芬蘭犬)',
33
+ 'Persian (波斯貓)', 'Pomeranian (博美犬)', 'Pug (哈巴狗)', 'Ragdoll (布偶貓)', 'Russian Blue (俄羅斯藍貓)',
34
+ 'Saint Bernard (聖伯納犬)', 'Samoyed (薩摩耶)', 'Scottish Terrier (蘇格蘭梗)', 'Shiba Inu (柴犬)',
35
+ 'Siamese (暹羅貓)', 'Sphynx (無毛貓)', 'Staffordshire Bull Terrier (史塔福郡鬥牛犬)',
36
+ 'Wheaten Terrier (小麥色梗)', 'Yorkshire Terrier (約克夏犬)']
37
+
38
+ # 定義預測函數
39
  def classify_image(image):
40
+ image = transform(image).unsqueeze(0).to(device) # 確保影像資料處理在 CPU
41
  with torch.no_grad():
42
  outputs = model(image)
43
+ probabilities, indices = torch.topk(outputs, k=3) # 取得前3個預測
44
+ probabilities = torch.nn.functional.softmax(probabilities, dim=1) # 將結果轉換為機率
45
+ predictions = [(class_names[idx], prob.item()) for idx, prob in zip(indices[0], probabilities[0])]
46
+ return {class_name: f"{prob * 100:.2f}%" for class_name, prob in predictions}
 
 
 
47
 
48
+ # Gradio 介面
49
  examples = [["examples/" + img] for img in os.listdir('examples')]
50
 
51
+ demo = gr.Interface(fn=classify_image,
52
+ inputs=gr.Image(type="pil"),
53
+ outputs=[gr.Label(num_top_classes=3, label="Top 3 Predictions")],
54
+ examples=examples,
55
+ title='Oxford Pet 🐈🐕',
56
+ description='A ResNet50-based model for classifying 37 different pet breeds.',
57
+ article='https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/The%20Oxford-IIIT%20Pet%20Project')
58
 
59
+ # 啟動 Gradio demo
60
+ demo.launch()