Spaces:
Runtime error
Runtime error
import os | |
import sys | |
current = os.path.dirname(os.path.realpath(__file__)) | |
parent = os.path.dirname(current) | |
sys.path.append(parent) | |
import albumentations as A | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import streamlit as st | |
import torch | |
from albumentations.pytorch import ToTensorV2 | |
from PIL import Image | |
from model import Classifier | |
# Load the model | |
model = Classifier.load_from_checkpoint("./models/checkpoint_old.ckpt") | |
model.eval() | |
# Define labels | |
labels = [ | |
"dog", | |
"horse", | |
"elephant", | |
"butterfly", | |
"chicken", | |
"cat", | |
"cow", | |
"sheep", | |
"spider", | |
"squirrel", | |
] | |
# Preprocess function | |
def preprocess(image): | |
image = np.array(image) | |
resize = A.Resize(224, 224) | |
normalize = A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) | |
to_tensor = ToTensorV2() | |
transform = A.Compose([resize, normalize, to_tensor]) | |
image = transform(image=image)["image"] | |
return image | |
# Define the sample images | |
sample_images = { | |
"butterfly": "./test_images/butterfly.jpg", | |
"cat": "./test_images/cat.jpg", | |
"dog": "./test_images/dog.jpeg", | |
"squirrel": "./test_images/squirrel.jpeg", | |
"horse": "./test_images/horse.jpeg", | |
} | |
# Define the function to make predictions on an image | |
def predict(image): | |
try: | |
image = preprocess(image).unsqueeze(0) | |
# Prediction | |
# Make a prediction on the image | |
with torch.no_grad(): | |
output = model(image) | |
# convert to probabilities | |
probabilities = torch.nn.functional.softmax(output[0]) | |
topk_prob, topk_label = torch.topk(probabilities, 3) | |
# convert the predictions to a list | |
predictions = [] | |
for i in range(topk_prob.size(0)): | |
prob = topk_prob[i].item() | |
label = topk_label[i].item() | |
predictions.append((prob, label)) | |
return predictions | |
except Exception as e: | |
print(f"Error predicting image: {e}") | |
return [] | |
# Define the Streamlit app | |
def app(): | |
st.title("Animal-10 Image Classification") | |
# Add a file uploader | |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) | |
# # Add a selectbox to choose from sample images | |
sample = st.selectbox("Or choose from sample images:", list(sample_images.keys())) | |
# If an image is uploaded, make a prediction on it | |
if uploaded_file is not None: | |
image = Image.open(uploaded_file) | |
st.image(image, caption="Uploaded Image.", use_column_width=True) | |
predictions = predict(image) | |
# If a sample image is chosen, make a prediction on it | |
elif sample: | |
image = Image.open(sample_images[sample]) | |
st.image(image, caption=sample.capitalize() + " Image.", use_column_width=True) | |
predictions = predict(image) | |
# Show the top 3 predictions with their probabilities | |
if predictions: | |
st.write("Top 3 predictions:") | |
for i, (prob, label) in enumerate(predictions): | |
st.write(f"{i+1}. {labels[label]} ({prob*100:.2f}%)") | |
# Show progress bar with probabilities | |
st.markdown( | |
""" | |
<style> | |
.stProgress .st-b8 { | |
background-color: orange; | |
} | |
</style> | |
""", | |
unsafe_allow_html=True, | |
) | |
st.progress(prob) | |
else: | |
st.write("No predictions.") | |
# Run the app | |
if __name__ == "__main__": | |
app() | |