rufimelo commited on
Commit
a6b3107
1 Parent(s): cdc6945

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.DS_Store ADDED
Binary file (6.15 kB). View file
 
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: FashionMNIST
3
- emoji: 🐨
4
  colorFrom: gray
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.1.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
1
  ---
2
  title: FashionMNIST
3
+ emoji: 📷
4
  colorFrom: gray
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 4.12.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import cv2
5
+ import gradio as gr
6
+ import joblib
7
+ import numpy as np
8
+ import torch
9
+
10
+ from models.cnn import Classifier
11
+ from models.feed_forward import FeedForwardClassifier
12
+
13
+ simple_classifier = Classifier()
14
+ CNN_PATH = "models/classifier_cnn.pth"
15
+ simple_classifier.load_state_dict(torch.load(CNN_PATH))
16
+
17
+ feed_forward_classifier = FeedForwardClassifier(784)
18
+ FF_PATH = "models/classifier.pth"
19
+ feed_forward_classifier.load_state_dict(torch.load(FF_PATH))
20
+
21
+ # Required for the classifier
22
+ from sklearn.ensemble import RandomForestClassifier
23
+
24
+ RF_PATH = "models/fashion_mnist_rf_model.pkl"
25
+ rf_clf = joblib.load(RF_PATH)
26
+
27
+
28
+ # Required for the classifier
29
+ from sklearn.svm import SVC
30
+
31
+ SVM_PATH = "models/fashion_mnist_svm_model.pkl"
32
+ svm_clf = joblib.load(SVM_PATH)
33
+ SVM_PATH_RBF = "models/fashion_mnist_svm_model_rbf.pkl"
34
+ svm_clf_rbf = joblib.load(SVM_PATH_RBF)
35
+
36
+
37
+ LABELS = [
38
+ "T-shirt/top",
39
+ "Trouser",
40
+ "Pullover",
41
+ "Dress",
42
+ "Coat",
43
+ "Sandal",
44
+ "Shirt",
45
+ "Sneaker",
46
+ "Bag",
47
+ "Ankle boot",
48
+ ]
49
+
50
+
51
+ def classify(img: str):
52
+ # read image
53
+ img = cv2.imread(img, cv2.IMREAD_GRAYSCALE)
54
+ img = cv2.resize(img, (28, 28))
55
+ img = img / 255.0
56
+ img = np.array(img).reshape(-1, 1, 28, 28)
57
+ img = torch.from_numpy(img).float()
58
+
59
+ cnn_output = simple_classifier(img)
60
+ cnn_output = torch.nn.functional.softmax(cnn_output, dim=1)
61
+ cnn_output_pred = torch.argmax(cnn_output)
62
+ cnn_output_label = LABELS[cnn_output_pred.item()]
63
+ cnn_output_str = f"Simple CNN: {cnn_output_label} with {round(torch.max(cnn_output).item()*100, 2)}% confidence"
64
+
65
+ feed_forward_output = feed_forward_classifier(img)
66
+ feed_forward_output = torch.nn.functional.softmax(feed_forward_output, dim=1)
67
+ feed_forward_output_pred = torch.argmax(feed_forward_output)
68
+ feed_forward_output_label = LABELS[feed_forward_output_pred.item()]
69
+ feed_forward_output_str = f"Feed Forward: {feed_forward_output_label} with {round(torch.max(feed_forward_output).item()*100, 2)}% confidence"
70
+
71
+ rf_output = rf_clf.predict(img.reshape(1, -1))
72
+ rf_output_label = LABELS[rf_output[0]]
73
+ rf_output_str = f"Random Forest: {rf_output_label}"
74
+
75
+ svm_output = svm_clf.predict(img.reshape(1, -1))
76
+ svm_output_label = LABELS[svm_output[0]]
77
+ svm_output_str = f"SVM with linear kernel: {svm_output_label}"
78
+
79
+ svm_output_rbf = svm_clf_rbf.predict(img.reshape(1, -1))
80
+ svm_output_label_rbf = LABELS[svm_output_rbf[0]]
81
+ svm_output_str_rbf = f"SVM with RBF kernel: {svm_output_label_rbf}"
82
+
83
+ output = (
84
+ cnn_output_str
85
+ + "\n"
86
+ + feed_forward_output_str
87
+ + "\n"
88
+ + rf_output_str
89
+ + "\n"
90
+ + svm_output_str
91
+ + "\n"
92
+ + svm_output_str_rbf
93
+ )
94
+ return output
95
+
96
+
97
+ folder = "./images"
98
+ examples = []
99
+ for filename in os.listdir(folder):
100
+ img_path = os.path.join(folder, filename)
101
+ examples.append([img_path])
102
+
103
+ random.shuffle(examples)
104
+
105
+
106
+ iface = gr.Interface(
107
+ fn=classify,
108
+ title="Fashion MNIST Classifier - TAECAC @ FEUP",
109
+ description="Use an image to classify using the different Fashion MNIST model.",
110
+ inputs=gr.Image(label="Image", type="filepath"),
111
+ outputs=gr.Textbox(label="Classification output"),
112
+ examples=examples,
113
+ examples_per_page=100,
114
+ theme=gr.themes.Soft(
115
+ primary_hue=gr.themes.colors.indigo,
116
+ secondary_hue=gr.themes.colors.gray,
117
+ neutral_hue=gr.themes.colors.slate,
118
+ font=["avenir"],
119
+ ),
120
+ )
121
+
122
+ iface.launch()
fmt.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ echo "Running fmt"
2
+ echo "Running isort"
3
+ isort . --profile black
4
+ echo "Running black"
5
+ black .
images/0.png ADDED
images/1.png ADDED
images/10.png ADDED
images/11.png ADDED
images/12.png ADDED
images/13.png ADDED
images/14.png ADDED
images/15.png ADDED
images/16.png ADDED
images/17.png ADDED
images/18.png ADDED
images/19.png ADDED
images/2.png ADDED
images/20.png ADDED
images/21.png ADDED
images/22.png ADDED
images/23.png ADDED
images/24.png ADDED
images/25.png ADDED
images/26.png ADDED
images/27.png ADDED
images/28.png ADDED
images/29.png ADDED
images/3.png ADDED
images/30.png ADDED
images/31.png ADDED
images/32.png ADDED
images/33.png ADDED
images/34.png ADDED
images/35.png ADDED
images/36.png ADDED
images/37.png ADDED
images/38.png ADDED
images/39.png ADDED
images/4.png ADDED
images/40.png ADDED
images/41.png ADDED
images/42.png ADDED
images/43.png ADDED
images/44.png ADDED
images/45.png ADDED
images/46.png ADDED
images/47.png ADDED
images/48.png ADDED
images/49.png ADDED
images/5.png ADDED