"""Streamlit demo: classify an uploaded or sample image with the Animal-10 model."""
import os
import sys

current = os.path.dirname(os.path.realpath(__file__))
parent = os.path.dirname(current)
# Make the parent directory importable so the project-local `model` module resolves.
sys.path.append(parent)

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image

from model import Classifier

CHECKPOINT_PATH = "./models/checkpoint_old.ckpt"
# Number of predictions shown by default.
TOP_K = 3


def _load_model():
    """Load the classifier checkpoint and switch it to evaluation mode."""
    loaded = Classifier.load_from_checkpoint(CHECKPOINT_PATH)
    loaded.eval()
    return loaded


# Streamlit re-executes this whole script on every widget interaction. Caching
# the loader (when the installed Streamlit provides `st.cache_resource`) avoids
# re-reading the checkpoint from disk on every rerun; older versions fall back
# to loading once per script run, as before.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_model = _cache_resource(_load_model)

# Load the model
model = _load_model()

# Class names, indexed by the model's output position.
labels = [
    "dog",
    "horse",
    "elephant",
    "butterfly",
    "chicken",
    "cat",
    "cow",
    "sheep",
    "spider",
    "squirrel",
]

# Preprocessing pipeline, built once: resize to 224x224, normalize each of the
# 3 channels with mean 0.5 / std 0.5, then convert to a torch tensor.
_TRANSFORM = A.Compose(
    [
        A.Resize(224, 224),
        A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ToTensorV2(),
    ]
)


def preprocess(image):
    """Turn an image into a normalized, unbatched model-input tensor.

    Args:
        image: A PIL image (or anything `np.array` accepts as an HxWx3 array).

    Returns:
        The tensor produced by the albumentations pipeline.
    """
    # The normalizer expects exactly 3 channels; RGBA/palette PNGs and
    # grayscale uploads would otherwise fail here.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    return _TRANSFORM(image=np.array(image))["image"]


# Define the sample images
sample_images = {
    "butterfly": "./test_images/butterfly.jpg",
    "cat": "./test_images/cat.jpg",
    "dog": "./test_images/dog.jpeg",
    "squirrel": "./test_images/squirrel.jpeg",
    "horse": "./test_images/horse.jpeg",
}


def predict(image, k=TOP_K):
    """Return the top-`k` predictions for `image`.

    Args:
        image: Image accepted by `preprocess`.
        k: Maximum number of predictions to return (defaults to 3).

    Returns:
        A list of `(probability, label_index)` tuples, highest probability
        first, or an empty list if preprocessing or inference fails (the
        error is printed).
    """
    try:
        batch = preprocess(image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            output = model(batch)
        # NOTE(review): assumes the model returns per-class logits of shape
        # (batch, num_classes) — confirm against Classifier.forward.
        probabilities = torch.nn.functional.softmax(output[0], dim=-1)
        top_prob, top_label = torch.topk(probabilities, min(k, probabilities.numel()))
    except Exception as e:
        print(f"Error predicting image: {e}")
        return []
    return [(prob.item(), label.item()) for prob, label in zip(top_prob, top_label)]


def app():
    """Render the UI: take an uploaded or sample image and show its top predictions."""
    st.title("Animal-10 Image Classification")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    sample = st.selectbox("Or choose from sample images:", list(sample_images.keys()))

    # Initialized up front so the display code never reads an unbound name
    # when neither an upload nor a sample is available.
    predictions = []
    if uploaded_file is not None:
        # An uploaded file takes precedence over the sample selection.
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image.", use_column_width=True)
        predictions = predict(image)
    elif sample:
        image = Image.open(sample_images[sample])
        st.image(image, caption=sample.capitalize() + " Image.", use_column_width=True)
        predictions = predict(image)

    if not predictions:
        st.write("No predictions.")
        return

    st.write("Top 3 predictions:")
    for rank, (prob, label) in enumerate(predictions, start=1):
        st.write(f"{rank}. {labels[label]} ({prob*100:.2f}%)")
        # NOTE(review): this markdown snippet is whitespace only as seen here;
        # it presumably once held custom styling for the progress bar — confirm.
        st.markdown(
            """ """,
            unsafe_allow_html=True,
        )
        st.progress(prob)


# Run the app
if __name__ == "__main__":
    app()