# Yadvendra's picture
# Update app.py
# 0d7fabd verified
import streamlit as st
import numpy as np
import cv2
import tensorflow as tf
from PIL import Image
from sklearn.preprocessing import LabelEncoder
@st.cache_resource
def _load_model():
    """Load the pre-trained Keras model, cached across Streamlit reruns.

    Streamlit re-executes this script from the top on every widget
    interaction (for example each click on a button). Caching the loaded
    model as a resource avoids re-reading the .h5 file on every rerun.
    """
    # The saved model should match the TensorFlow/Keras version used during training.
    return tf.keras.models.load_model('brain_tumor_model.h5')


model = _load_model()

# Example class labels (update this list with your actual class labels).
class_labels = ['glioma', 'pituitary', 'meningioma', 'healthy']

# NOTE(review): LabelEncoder stores its classes in sorted order, so class index
# i decodes to sorted(class_labels)[i] (glioma, healthy, meningioma, pituitary),
# not to class_labels[i]. Presumably the training labels were encoded the same
# way -- confirm against the training code.
label_encoder = LabelEncoder()
label_encoder.fit(class_labels)
def load_and_preprocess_image(uploaded_file, target_size=(224, 224)):
    """Convert an uploaded image into a normalized single-image batch.

    Args:
        uploaded_file: A file path or file-like object that PIL can open.
        target_size: ``(width, height)`` the image is resized to. Defaults to
            ``(224, 224)``, the size previously hard-coded here.

    Returns:
        A float NumPy array of shape ``(1, height, width, 3)`` whose values
        are the RGB pixel intensities divided by 255.
    """
    width, height = target_size

    # convert() returns an independent image, so the source handle can be
    # released as soon as the RGB copy exists.
    with Image.open(uploaded_file) as source:
        rgb_image = source.convert("RGB")

    pixels = np.array(rgb_image)
    pixels = cv2.resize(pixels, (width, height))  # cv2 expects (width, height)
    pixels = pixels / 255.0  # scale 8-bit channel values by 1/255

    # Prepend the batch axis instead of reshaping to a hard-coded shape.
    return pixels[np.newaxis, ...]
def predict_image(img):
    """Run the model on a preprocessed batch and return the top class index.

    Args:
        img: Input batch passed straight to ``model.predict``.

    Returns:
        The index of the highest value in the first row of the model output.
    """
    first_row_scores = model.predict(img)[0]
    return np.argmax(first_row_scores)
def get_class_label(predicted_class_index):
    """Map a predicted class index back to its label via ``label_encoder``.

    Args:
        predicted_class_index: Integer class index produced by the model.

    Returns:
        The label that ``label_encoder`` associates with that index.
    """
    decoded_labels = label_encoder.inverse_transform([predicted_class_index])
    return decoded_labels[0]
# ---- Streamlit page: title, upload widget, preview and prediction ----
st.title("Brain Tumor using CNN 🧠")
st.write("Upload a brain scan (JPG format), and the model will predict its class.")

# The uploader is restricted to the "jpg" file type.
uploaded_file = st.file_uploader("Choose a JPG image...", type="jpg")

if uploaded_file is not None:
    # Two columns with a 2:1 width ratio: preview first, action/result second.
    preview_col, action_col = st.columns([2, 1])

    with preview_col:
        st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)

    with action_col:
        # Prediction only runs when the user presses the button.
        if st.button("Detect"):
            st.write("Detecting...")

            # uploaded file -> model input -> class index -> readable label
            predicted_class_label = get_class_label(
                predict_image(load_and_preprocess_image(uploaded_file))
            )

            # Centered, green heading showing the predicted label.
            result_html = (
                "<h3 style='color: #4CAF50; text-align: center;'>"
                f"The Prediction is : <strong>{predicted_class_label}</strong></h3>"
            )
            st.markdown(result_html, unsafe_allow_html=True)