Spaces:
Sleeping
Sleeping
sup
Browse files- .DS_Store +0 -0
- README.md +1 -1
- app.py +63 -7
- models/fashion_mnist_knn_model.pkl +3 -0
- models/fashion_mnist_lr_model.pkl +3 -0
- models/fashion_mnist_xgb_model.pkl +3 -0
- requirements.txt +2 -1
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 📷
|
|
4 |
colorFrom: gray
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
|
|
4 |
colorFrom: gray
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 5.1.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
app.py
CHANGED
@@ -34,6 +34,26 @@ SVM_PATH_RBF = "models/fashion_mnist_svm_model_rbf.pkl"
|
|
34 |
svm_clf_rbf = joblib.load(SVM_PATH_RBF)
|
35 |
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
LABELS = [
|
38 |
"T-shirt/top",
|
39 |
"Trouser",
|
@@ -56,40 +76,76 @@ def classify(img: str):
|
|
56 |
img = np.array(img).reshape(-1, 1, 28, 28)
|
57 |
img = torch.from_numpy(img).float()
|
58 |
|
|
|
59 |
cnn_output = simple_classifier(img)
|
60 |
cnn_output = torch.nn.functional.softmax(cnn_output, dim=1)
|
61 |
cnn_output_pred = torch.argmax(cnn_output)
|
62 |
cnn_output_label = LABELS[cnn_output_pred.item()]
|
63 |
-
|
|
|
|
|
|
|
64 |
|
|
|
65 |
feed_forward_output = feed_forward_classifier(img)
|
66 |
feed_forward_output = torch.nn.functional.softmax(feed_forward_output, dim=1)
|
67 |
feed_forward_output_pred = torch.argmax(feed_forward_output)
|
68 |
feed_forward_output_label = LABELS[feed_forward_output_pred.item()]
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
|
|
71 |
rf_output = rf_clf.predict(img.reshape(1, -1))
|
72 |
rf_output_label = LABELS[rf_output[0]]
|
73 |
-
rf_output_str = f"Random Forest: {rf_output_label}"
|
74 |
|
|
|
75 |
svm_output = svm_clf.predict(img.reshape(1, -1))
|
76 |
svm_output_label = LABELS[svm_output[0]]
|
77 |
-
svm_output_str = f"SVM with
|
78 |
|
|
|
79 |
svm_output_rbf = svm_clf_rbf.predict(img.reshape(1, -1))
|
80 |
svm_output_label_rbf = LABELS[svm_output_rbf[0]]
|
81 |
-
svm_output_str_rbf = f"SVM with RBF kernel: {svm_output_label_rbf}"
|
|
|
|
|
|
|
|
|
|
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
output = (
|
84 |
cnn_output_str
|
85 |
+ "\n"
|
86 |
+ feed_forward_output_str
|
87 |
+ "\n"
|
|
|
|
|
88 |
+ rf_output_str
|
89 |
+ "\n"
|
90 |
+ svm_output_str
|
91 |
+ "\n"
|
92 |
+ svm_output_str_rbf
|
|
|
|
|
|
|
|
|
93 |
)
|
94 |
return output
|
95 |
|
@@ -106,11 +162,11 @@ random.shuffle(examples)
|
|
106 |
iface = gr.Interface(
|
107 |
fn=classify,
|
108 |
title="Fashion MNIST Classifier - TAECAC @ FEUP",
|
109 |
-
description="
|
110 |
inputs=gr.Image(label="Image", type="filepath"),
|
111 |
outputs=gr.Textbox(label="Classification output"),
|
112 |
examples=examples,
|
113 |
-
examples_per_page=100,
|
114 |
theme=gr.themes.Soft(
|
115 |
primary_hue=gr.themes.colors.indigo,
|
116 |
secondary_hue=gr.themes.colors.gray,
|
|
|
34 |
svm_clf_rbf = joblib.load(SVM_PATH_RBF)
|
35 |
|
36 |
|
37 |
+
# Required for the classifier
|
38 |
+
from sklearn.linear_model import LogisticRegression
|
39 |
+
|
40 |
+
LR_PATH = "models/fashion_mnist_lr_model.pkl"
|
41 |
+
lr_clf = joblib.load(LR_PATH)
|
42 |
+
|
43 |
+
# Required for the classifier
|
44 |
+
|
45 |
+
from sklearn.neighbors import KNeighborsClassifier
|
46 |
+
|
47 |
+
KNN_PATH = "models/fashion_mnist_knn_model.pkl"
|
48 |
+
knn_clf = joblib.load(KNN_PATH)
|
49 |
+
|
50 |
+
# Required for the classifier
|
51 |
+
from xgboost import XGBClassifier
|
52 |
+
|
53 |
+
XGB_PATH = "models/fashion_mnist_xgb_model.pkl"
|
54 |
+
xgb_clf = joblib.load(XGB_PATH)
|
55 |
+
|
56 |
+
|
57 |
LABELS = [
|
58 |
"T-shirt/top",
|
59 |
"Trouser",
|
|
|
76 |
img = np.array(img).reshape(-1, 1, 28, 28)
|
77 |
img = torch.from_numpy(img).float()
|
78 |
|
79 |
+
# CNN Classifier
|
80 |
cnn_output = simple_classifier(img)
|
81 |
cnn_output = torch.nn.functional.softmax(cnn_output, dim=1)
|
82 |
cnn_output_pred = torch.argmax(cnn_output)
|
83 |
cnn_output_label = LABELS[cnn_output_pred.item()]
|
84 |
+
cnn_confidence = round(torch.max(cnn_output).item() * 100, 2)
|
85 |
+
cnn_output_str = (
|
86 |
+
f"{'CNN:':<35} {cnn_output_label:<15} with {cnn_confidence:.2f}% confidence"
|
87 |
+
)
|
88 |
|
89 |
+
# Feed Forward Classifier
|
90 |
feed_forward_output = feed_forward_classifier(img)
|
91 |
feed_forward_output = torch.nn.functional.softmax(feed_forward_output, dim=1)
|
92 |
feed_forward_output_pred = torch.argmax(feed_forward_output)
|
93 |
feed_forward_output_label = LABELS[feed_forward_output_pred.item()]
|
94 |
+
feed_forward_confidence = round(torch.max(feed_forward_output).item() * 100, 2)
|
95 |
+
feed_forward_output_str = f"{'Feed Forward:':<35} {feed_forward_output_label:<15} with {feed_forward_confidence:.2f}% confidence"
|
96 |
+
|
97 |
+
# XGBoost Classifier
|
98 |
+
xgb_output = xgb_clf.predict(img.reshape(1, -1))
|
99 |
+
xgb_output_label = LABELS[xgb_output[0]]
|
100 |
+
xgb_confidence = round(
|
101 |
+
float(np.max(xgb_clf.predict_proba(img.reshape(1, -1))[0])) * 100, 2
|
102 |
+
)
|
103 |
+
xgb_output_str = (
|
104 |
+
f"{'XGBoost:':<35} {xgb_output_label:<15} with {xgb_confidence:.2f}% confidence"
|
105 |
+
)
|
106 |
|
107 |
+
# Random Forest Classifier
|
108 |
rf_output = rf_clf.predict(img.reshape(1, -1))
|
109 |
rf_output_label = LABELS[rf_output[0]]
|
110 |
+
rf_output_str = f"{'Random Forest:':<35} {rf_output_label:<15}"
|
111 |
|
112 |
+
# SVM with Linear Kernel Classifier
|
113 |
svm_output = svm_clf.predict(img.reshape(1, -1))
|
114 |
svm_output_label = LABELS[svm_output[0]]
|
115 |
+
svm_output_str = f"{'SVM with Linear kernel:':<35} {svm_output_label:<15}"
|
116 |
|
117 |
+
# SVM with RBF Kernel Classifier
|
118 |
svm_output_rbf = svm_clf_rbf.predict(img.reshape(1, -1))
|
119 |
svm_output_label_rbf = LABELS[svm_output_rbf[0]]
|
120 |
+
svm_output_str_rbf = f"{'SVM with RBF kernel:':<35} {svm_output_label_rbf:<15}"
|
121 |
+
|
122 |
+
# Logistic Regression Classifier
|
123 |
+
lr_output = lr_clf.predict(img.reshape(1, -1))
|
124 |
+
lr_output_label = LABELS[lr_output[0]]
|
125 |
+
lr_output_str = f"{'Logistic Regression:':<35} {lr_output_label:<15}"
|
126 |
|
127 |
+
# KNN Classifier
|
128 |
+
knn_output = knn_clf.predict(img.reshape(1, -1))
|
129 |
+
knn_output_label = LABELS[knn_output[0]]
|
130 |
+
knn_output_str = f"{'KNN:':<35} {knn_output_label:<15}"
|
131 |
+
|
132 |
+
# Combine output
|
133 |
output = (
|
134 |
cnn_output_str
|
135 |
+ "\n"
|
136 |
+ feed_forward_output_str
|
137 |
+ "\n"
|
138 |
+
+ xgb_output_str
|
139 |
+
+ "\n"
|
140 |
+ rf_output_str
|
141 |
+ "\n"
|
142 |
+ svm_output_str
|
143 |
+ "\n"
|
144 |
+ svm_output_str_rbf
|
145 |
+
+ "\n"
|
146 |
+
+ lr_output_str
|
147 |
+
+ "\n"
|
148 |
+
+ knn_output_str
|
149 |
)
|
150 |
return output
|
151 |
|
|
|
162 |
iface = gr.Interface(
|
163 |
fn=classify,
|
164 |
title="Fashion MNIST Classifier - TAECAC @ FEUP",
|
165 |
+
description="Simple Proof of Concept.",
|
166 |
inputs=gr.Image(label="Image", type="filepath"),
|
167 |
outputs=gr.Textbox(label="Classification output"),
|
168 |
examples=examples,
|
169 |
+
examples_per_page=100,
|
170 |
theme=gr.themes.Soft(
|
171 |
primary_hue=gr.themes.colors.indigo,
|
172 |
secondary_hue=gr.themes.colors.gray,
|
models/fashion_mnist_knn_model.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cb976fbeb5cd6b96fa2e352cb76abc759fdb28bfa6a6e734154776dffda86b1d
|
3 |
+
size 188640868
|
models/fashion_mnist_lr_model.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e39ebb79140b7fad5f4b6aa4c84faa063299b9f99f9c0fb201ddcdb615e81432
|
3 |
+
size 32319
|
models/fashion_mnist_xgb_model.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:949b5f3cdf0ce1b64b609a2f70ce56f39dcdb6a8005a7475ddf96e45308d965c
|
3 |
+
size 9915808
|
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ watchdog
|
|
5 |
scikit-learn
|
6 |
opencv-python
|
7 |
isort==5.13.2
|
8 |
-
black==24.8.0
|
|
|
|
5 |
scikit-learn
|
6 |
opencv-python
|
7 |
isort==5.13.2
|
8 |
+
black==24.8.0
|
9 |
+
xgboost
|