Spaces:
Runtime error
Runtime error
File size: 3,548 Bytes
f21206e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
import sys
current = os.path.dirname(os.path.realpath(__file__))
parent = os.path.dirname(current)
sys.path.append(parent)
import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from model import Classifier
# Load the trained classifier once, at import time, and switch it to
# evaluation mode (affects layers such as dropout/batch-norm, if present).
# NOTE(review): the path is relative to the current working directory, not to
# this file. Streamlit also re-executes the script on each interaction, so this
# likely reloads the checkpoint every rerun — consider st.cache_resource if the
# installed Streamlit version provides it.
_CHECKPOINT_PATH = "./models/checkpoint_old.ckpt"
model = Classifier.load_from_checkpoint(_CHECKPOINT_PATH)
model.eval()
# Human-readable class names, indexed by the model's output position:
# a predicted class index i is displayed as labels[i].
# NOTE(review): this order must match the label encoding used at training time.
labels = [
    "dog", "horse", "elephant", "butterfly", "chicken",
    "cat", "cow", "sheep", "spider", "squirrel",
]
# Preprocess function
def preprocess(image):
    """Convert an image into the tensor the classifier consumes.

    The image is forced to 3-channel RGB (when it is a PIL image), resized to
    224x224, normalised with per-channel mean 0.5 and std 0.5, and converted
    to a tensor with albumentations' ``ToTensorV2``.

    Args:
        image: A PIL image, or an array-like that ``np.array`` turns into an
            ``(H, W, 3)`` array. Non-PIL inputs are NOT channel-converted.

    Returns:
        The transformed image tensor, without a batch dimension.
    """
    # Uploaded PNGs may be RGBA or palette ("P"), and other files may be
    # grayscale ("L") or CMYK. np.array() of those yields 4 channels or a 2-D
    # array, which does not match the 3-value mean/std below or the model's
    # input. Normalise everything to RGB first.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    pixels = np.array(image)
    transform = A.Compose(
        [
            A.Resize(224, 224),
            A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ToTensorV2(),
        ]
    )
    return transform(image=pixels)["image"]
# Example images offered in the "choose from sample images" selectbox, keyed
# by the option name shown to the user. Insertion order is the order the
# options appear in, because the selectbox receives list(sample_images.keys()).
# NOTE(review): paths are relative to the current working directory — confirm
# the app is always launched from the directory containing test_images/.
sample_images = dict(
    butterfly="./test_images/butterfly.jpg",
    cat="./test_images/cat.jpg",
    dog="./test_images/dog.jpeg",
    squirrel="./test_images/squirrel.jpeg",
    horse="./test_images/horse.jpeg",
)
# Define the function to make predictions on an image
def predict(image, k=3):
    """Classify *image* and return its ``k`` most likely classes.

    Args:
        image: An image accepted by ``preprocess`` (e.g. a PIL image).
        k: How many top predictions to return (default 3). Clamped to the
            number of class scores the model produces.

    Returns:
        A list of ``(probability, label_index)`` tuples, most likely first.
        ``probability`` is a Python float and ``label_index`` an int index
        into ``labels``. On any error the message is printed and an empty
        list is returned instead of raising.
    """
    try:
        # Add a leading batch dimension of size 1.
        batch = preprocess(image).unsqueeze(0)
        with torch.no_grad():
            output = model(batch)
        # output[0] is the score vector of the single sample in the batch.
        # NOTE(review): assumes the model returns (batch, num_classes) logits.
        # An explicit dim avoids PyTorch's implicit-dim deprecation warning and
        # equals the implicit choice for a 1-D tensor.
        probabilities = torch.nn.functional.softmax(output[0], dim=-1)
        top_prob, top_label = torch.topk(probabilities, min(k, probabilities.numel()))
        # tolist() yields plain Python floats/ints, like .item() per element.
        return list(zip(top_prob.tolist(), top_label.tolist()))
    except Exception as e:
        # Best-effort: the UI shows "No predictions." for an empty result.
        print(f"Error predicting image: {e}")
        return []
# Define the Streamlit app
def app():
    """Render the Streamlit page: pick an image, classify it, show the results.

    An uploaded file takes precedence over the sample-image selector. The top
    predictions are listed with their probability and a progress bar each.
    """
    st.title("Animal-10 Image Classification")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    sample = st.selectbox("Or choose from sample images:", list(sample_images.keys()))

    # Initialise up front so the display code below never reads an unbound
    # name when neither an upload nor a sample is available.
    predictions = []
    image = None
    caption = ""
    try:
        if uploaded_file is not None:
            image = Image.open(uploaded_file)
            caption = "Uploaded Image."
        elif sample:
            image = Image.open(sample_images[sample])
            caption = sample.capitalize() + " Image."
    except OSError as err:  # also covers PIL.UnidentifiedImageError / missing files
        st.error(f"Could not open image: {err}")
        image = None

    if image is not None:
        st.image(image, caption=caption, use_column_width=True)
        predictions = predict(image)

    if predictions:
        st.write(f"Top {len(predictions)} predictions:")
        # Inject the progress-bar colour once rather than once per prediction.
        # NOTE(review): ".st-b8" is an auto-generated Streamlit class name and
        # may not match in other Streamlit versions.
        st.markdown(
            """
            <style>
            .stProgress .st-b8 {
                background-color: orange;
            }
            </style>
            """,
            unsafe_allow_html=True,
        )
        for rank, (prob, label) in enumerate(predictions, start=1):
            st.write(f"{rank}. {labels[label]} ({prob*100:.2f}%)")
            # st.progress takes a fraction in [0, 1]; clamp float round-off.
            st.progress(min(max(prob, 0.0), 1.0))
    else:
        st.write("No predictions.")


# Run the app
if __name__ == "__main__":
    app()
|