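# Author: datasciencedojo — commit 7bedfbc, "Update app.py" (app.py, 2.49 kB).
# (Repository page header; the "raw" / "history" / "blame" entries were page links, not code.)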
import functools
from collections import OrderedDict

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from keras.models import load_model
from PIL import Image, ImageOps
def create_plot(data):
    """Draw per-label confidence bars over a full-width "Total" backdrop.

    Args:
        data: DataFrame with "Labels", "Total" and "Confidence Score" columns.
            "Labels" goes on the y axis; the other two are plotted as
            horizontal bar lengths.

    Returns:
        A matplotlib Figure containing the chart.
    """
    # Build the Figure directly instead of through pyplot: figures created with
    # plt.subplots() are kept in pyplot's global registry until closed, and this
    # function never closes them, so every request would leak one figure.
    from matplotlib.figure import Figure

    sns.set_theme(style="whitegrid")
    f = Figure(figsize=(5, 5))
    ax = f.subplots()

    # Light ("pastel") blue bar for the 100% total, drawn first so that the
    # darker ("muted") blue confidence bar is drawn on top of it.
    sns.set_color_codes("pastel")
    sns.barplot(x="Total", y="Labels", data=data, label="Total", color="b", ax=ax)
    sns.set_color_codes("muted")
    sns.barplot(
        x="Confidence Score",
        y="Labels",
        data=data,
        label="Confidence Score",
        color="b",
        ax=ax,
    )

    ax.legend(ncol=2, loc="lower right", frameon=True)
    # Pass ax explicitly: without pyplot there is no implicit "current figure".
    sns.despine(ax=ax, left=True, bottom=True)
    return f
@functools.lru_cache(maxsize=1)
def _load_model_and_labels():
    """Load the Keras model and the raw lines of labels.txt once per process.

    Previously both files were re-read on every prediction (and the labels file
    handle was never closed); memoising avoids rebuilding the model per request.
    """
    model = load_model('keras_model.h5', compile=False)
    with open('labels.txt', 'r') as labels_file:
        class_names = labels_file.readlines()
    return model, class_names


def _preprocess(img, size=(224, 224)):
    """Convert an image array into a (1, size[0], size[1], 3) float32 model batch.

    The image is forced to RGB so grayscale or RGBA inputs still yield three
    channels, then cropped/resized to ``size`` with ``ImageOps.fit`` (LANCZOS).
    """
    image = Image.fromarray(img).convert('RGB')
    image = ImageOps.fit(image, size, Image.LANCZOS)
    image_array = np.asarray(image, dtype=np.float32)
    # NOTE(review): dividing by 127.0 maps 255 to ~1.008 rather than exactly 1.0;
    # confirm the divisor the model was trained with (127.5 gives exactly [-1, 1]).
    normalized = (image_array / 127.0) - 1
    return normalized[np.newaxis, ...]


def _parse_label(line):
    """Extract the class name from one labels.txt line.

    Assumes lines shaped like '<index> <name>' (e.g. '0 Normal'), which is what
    the original ``line[2:][:-1]`` slicing expected. Unlike that slicing, this
    does not chop the last character when the final line lacks a trailing
    newline, and it copes with multi-digit indices. A line without a numeric
    index prefix is returned stripped.
    """
    stripped = line.strip()
    parts = stripped.split(maxsplit=1)
    if len(parts) == 2 and parts[0].isdigit():
        return parts[1]
    return stripped


def predict_pneumonia(img):
    """Classify a chest X-ray and build a confidence bar chart.

    Args:
        img: image array from the ``gr.Image`` input (passed to
            ``PIL.Image.fromarray``), or None when no image was supplied.

    Returns:
        (message, figure): a human-readable verdict string and the chart from
        ``create_plot``. The figure is None when no image was given.
    """
    if img is None:
        # Submitting without an image previously crashed inside Image.fromarray.
        return "Please upload a chest X-ray image.", None

    model, class_names = _load_model_and_labels()
    prediction = model.predict(_preprocess(img))
    index = int(np.argmax(prediction))
    confidence_score = prediction[0][index]
    c_name = _parse_label(class_names[index])

    if c_name == "Normal":
        pneumonia_prediction = "Chest XRay is normal no signs of pneumonia"
        other_class = "Pneumonia"
    else:
        pneumonia_prediction = "Chest XRay shows signs of pneumonia"
        other_class = "Normal"

    # Assumes a two-class model whose scores sum to 1, so the other class's
    # confidence is the complement of the winning score.
    res = {
        "Labels": [c_name, other_class],
        "Confidence Score": [confidence_score * 100, (1 - confidence_score) * 100],
        "Total": 100,
    }
    data_for_plot = pd.DataFrame.from_dict(res)
    pneumonia_conf_plt = create_plot(data_for_plot)
    return pneumonia_prediction, pneumonia_conf_plt
# Gradio UI: image input on the left; verdict text, confidence plot and the
# Submit button on the right.
# NOTE(review): indentation was lost in this copy of the file; the component
# nesting below is a reconstruction — confirm it matches the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Row():
                imgInput = gr.Image()
        with gr.Column(scale=1):
            pneumonia = gr.Textbox(label='Presence of pneumonia')
            plot = gr.Plot(label="Plot")
            submit_button = gr.Button(value="Submit")

    # Event wiring and examples must be created inside the Blocks context.
    submit_button.click(
        fn=predict_pneumonia,
        inputs=[imgInput],
        outputs=[pneumonia, plot],
    )
    gr.Examples(
        examples=["normal_Sample.jpg", "pneumonia_sample.jpg"],
        inputs=imgInput,
        outputs=[pneumonia, plot],
        fn=predict_pneumonia,
        # Precompute outputs for the example images (Gradio Examples API).
        cache_examples=True,
    )

# Launch only when run as a script, so importing this module does not start a
# blocking server.
if __name__ == "__main__":
    demo.launch()