import numpy as np
import cv2
import streamlit as st
import tensorflow as tf
from PIL import Image
from sklearn.preprocessing import LabelEncoder

# Side length (pixels) of the square RGB input fed to the CNN; the batch
# built below has shape (1, IMAGE_SIZE, IMAGE_SIZE, 3).
IMAGE_SIZE = 224


@st.cache_resource
def _load_model():
    """Load the pre-trained Keras model once and reuse it across reruns.

    Streamlit re-executes this whole script on every widget interaction
    (file upload, button click), so without caching the ``.h5`` file would
    be re-read from disk each time.
    """
    # Make sure this matches the TensorFlow/Keras version used during training.
    return tf.keras.models.load_model('dementia_cnn_model.h5')


model = _load_model()

# Example class labels (update this list with your actual class labels).
# NOTE(review): LabelEncoder stores its classes in *sorted* order, so model
# output index i maps to sorted(class_labels)[i] -- i.e. 'Mild Dementia',
# 'Moderate Dementia', 'Non Demented', 'Very mild Dementia' -- and NOT to the
# order written below. That is only correct if the training labels were
# encoded the same way; TODO confirm against the training pipeline.
class_labels = ['Non Demented', 'Very mild Dementia', 'Mild Dementia', 'Moderate Dementia']
label_encoder = LabelEncoder()
label_encoder.fit(class_labels)  # Fit the label encoder with your class labels


def load_and_preprocess_image(uploaded_file):
    """Turn an uploaded image file into a one-image batch for the model.

    Args:
        uploaded_file: A file-like object (e.g. Streamlit's ``UploadedFile``)
            containing an image PIL can open.

    Returns:
        A float32 array of shape ``(1, IMAGE_SIZE, IMAGE_SIZE, 3)`` with
        pixel values scaled to [0, 1].
    """
    with Image.open(uploaded_file) as pil_img:
        # Grayscale / RGBA / palette images are coerced to 3-channel RGB.
        img = np.array(pil_img.convert("RGB"))
    img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
    img = img.astype(np.float32) / 255.0  # Normalize pixel values
    return np.expand_dims(img, axis=0)  # Add the batch dimension


def predict_image(img):
    """Return the index of the highest-scoring class for a one-image batch."""
    # verbose=0 keeps Keras from printing a progress bar to the server log.
    predictions = model.predict(img, verbose=0)
    return int(np.argmax(predictions[0]))


def get_class_label(predicted_class_index):
    """Map a model output index back to its class-label string."""
    return label_encoder.inverse_transform([predicted_class_index])[0]


# Streamlit App UI
st.title("Alzheimer Detection using CNN 🧠")
st.write("Upload a brain scan (JPG format), and the model will predict its class.")

# "jpeg" is listed alongside "jpg" so files saved with either extension are accepted.
uploaded_file = st.file_uploader("Choose a JPG image...", type=["jpg", "jpeg"])

if uploaded_file is not None:
    # Wider left column shows the uploaded image; the right one holds the controls/result.
    col1, col2 = st.columns([2, 1])
    with col1:
        st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
    with col2:
        if st.button("Detect"):
            st.write("Detecting...")
            processed_image = load_and_preprocess_image(uploaded_file)
            predicted_class_index = predict_image(processed_image)
            predicted_class_label = get_class_label(predicted_class_index)
            # Center the result. The label comes from the fixed class_labels
            # list above, so interpolating it into raw HTML is safe here.
            st.markdown(
                f"<div style='text-align: center;'>"
                f"<h3>The Prediction is : {predicted_class_label}</h3>"
                f"</div>",
                unsafe_allow_html=True,
            )