pytholic commited on
Commit
f21206e
β€’
1 Parent(s): 599e603

pushing app script

Browse files
Files changed (3) hide show
  1. README.md +0 -13
  2. app/app.py +132 -0
  3. requirements.txt +5 -0
README.md DELETED
@@ -1,13 +0,0 @@
1
- ---
2
- title: Streamlit Image Classification Demo
3
- emoji: πŸƒ
4
- colorFrom: red
5
- colorTo: yellow
6
- sdk: streamlit
7
- sdk_version: 1.17.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ current = os.path.dirname(os.path.realpath(__file__))
5
+
6
+ parent = os.path.dirname(current)
7
+
8
+ sys.path.append(parent)
9
+
10
+ import albumentations as A
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ import streamlit as st
14
+ import torch
15
+ from albumentations.pytorch import ToTensorV2
16
+ from PIL import Image
17
+
18
+ from model import Classifier
19
+
20
+ # Load the model
21
+ model = Classifier.load_from_checkpoint("./models/checkpoint_old.ckpt")
22
+ model.eval()
23
+
24
+ # Define labels
25
+ labels = [
26
+ "dog",
27
+ "horse",
28
+ "elephant",
29
+ "butterfly",
30
+ "chicken",
31
+ "cat",
32
+ "cow",
33
+ "sheep",
34
+ "spider",
35
+ "squirrel",
36
+ ]
37
+
38
+ # Preprocess function
39
+ def preprocess(image):
40
+ image = np.array(image)
41
+ resize = A.Resize(224, 224)
42
+ normalize = A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
43
+ to_tensor = ToTensorV2()
44
+ transform = A.Compose([resize, normalize, to_tensor])
45
+ image = transform(image=image)["image"]
46
+ return image
47
+
48
+
49
+ # Define the sample images
50
+ sample_images = {
51
+ "butterfly": "./test_images/butterfly.jpg",
52
+ "cat": "./test_images/cat.jpg",
53
+ "dog": "./test_images/dog.jpeg",
54
+ "squirrel": "./test_images/squirrel.jpeg",
55
+ "horse": "./test_images/horse.jpeg",
56
+ }
57
+
58
+ # Define the function to make predictions on an image
59
+ def predict(image):
60
+ try:
61
+ image = preprocess(image).unsqueeze(0)
62
+
63
+ # Prediction
64
+ # Make a prediction on the image
65
+ with torch.no_grad():
66
+ output = model(image)
67
+ # convert to probabilities
68
+ probabilities = torch.nn.functional.softmax(output[0])
69
+
70
+ topk_prob, topk_label = torch.topk(probabilities, 3)
71
+
72
+ # convert the predictions to a list
73
+ predictions = []
74
+ for i in range(topk_prob.size(0)):
75
+ prob = topk_prob[i].item()
76
+ label = topk_label[i].item()
77
+ predictions.append((prob, label))
78
+
79
+ return predictions
80
+ except Exception as e:
81
+ print(f"Error predicting image: {e}")
82
+ return []
83
+
84
+
85
+ # Define the Streamlit app
86
+ def app():
87
+ st.title("Animal-10 Image Classification")
88
+
89
+ # Add a file uploader
90
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
91
+
92
+ # # Add a selectbox to choose from sample images
93
+ sample = st.selectbox("Or choose from sample images:", list(sample_images.keys()))
94
+
95
+ # If an image is uploaded, make a prediction on it
96
+ if uploaded_file is not None:
97
+ image = Image.open(uploaded_file)
98
+ st.image(image, caption="Uploaded Image.", use_column_width=True)
99
+ predictions = predict(image)
100
+
101
+ # If a sample image is chosen, make a prediction on it
102
+ elif sample:
103
+ image = Image.open(sample_images[sample])
104
+ st.image(image, caption=sample.capitalize() + " Image.", use_column_width=True)
105
+ predictions = predict(image)
106
+
107
+ # Show the top 3 predictions with their probabilities
108
+ if predictions:
109
+ st.write("Top 3 predictions:")
110
+ for i, (prob, label) in enumerate(predictions):
111
+ st.write(f"{i+1}. {labels[label]} ({prob*100:.2f}%)")
112
+
113
+ # Show progress bar with probabilities
114
+ st.markdown(
115
+ """
116
+ <style>
117
+ .stProgress .st-b8 {
118
+ background-color: orange;
119
+ }
120
+ </style>
121
+ """,
122
+ unsafe_allow_html=True,
123
+ )
124
+ st.progress(prob)
125
+
126
+ else:
127
+ st.write("No predictions.")
128
+
129
+
130
+ # Run the app
131
+ if __name__ == "__main__":
132
+ app()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ pytorch
2
+ pytorch-lightning
3
+ simple-parsing
4
+ albumentations
5
+ matplotlib