# app.py — last change by pamixsun: "Update app.py" (commit 8dcbbbd, 2.49 kB).
# Copyright (C) 2023, Xu Sun.
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
import torch
import numpy as np
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
from glaucoma import GlaucomaModel
# Torch device used for inference: the first CUDA GPU when one is
# available, otherwise the CPU.
run_device = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)
def _show_image(container, img):
    """Render *img* with matplotlib (axes hidden) into a Streamlit *container*.

    A fresh figure is created for every call and closed once Streamlit has
    rendered it, so separate images never share an axes and figures do not
    pile up in pyplot's registry across script reruns.
    """
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.axis('off')
    container.pyplot(fig)
    # Streamlit re-executes the whole script on every interaction; without an
    # explicit close each rerun would leave another open pyplot figure behind.
    plt.close(fig)


def main():
    """Run the glaucoma-screening Streamlit page.

    The sidebar lets the user upload a PNG/JPEG fundus image, which is shown
    in the left column. Pressing "Analyze image" builds a ``GlaucomaModel`` on
    ``run_device``, calls its ``process`` method on the image, shows the
    returned class activation map in the right column, and writes the label
    looked up in ``model.id2label`` for the returned class index.
    """
    from PIL import UnidentifiedImageError

    # Wide mode; set_page_config must be the first Streamlit call.
    st.set_page_config(layout="wide")
    # Designing the interface
    st.title("Glaucoma Screening from Retinal Fundus Images")
    # For newline
    st.write('\n')
    # Author info
    st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')
    # For newline
    st.write('\n')
    # Instructions
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")

    # Two equal-width columns: input image on the left, CAM on the right.
    # (st.beta_columns was removed from Streamlit; st.columns is the stable API.)
    cols = st.columns((1, 1))
    cols[0].subheader("Input image")
    cols[1].subheader("Class activation map")

    # Sidebar: file selection. The old
    # st.set_option('deprecation.showfileUploaderEncoding', False) call is
    # gone because that config option no longer exists in Streamlit.
    st.sidebar.title("Image selection")
    uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])

    # RGB uint8 array of the uploaded image; stays None when nothing usable
    # was uploaded.
    image = None
    if uploaded_file is not None:
        try:
            pil_image = Image.open(uploaded_file).convert('RGB')
        except UnidentifiedImageError:
            # Report a bad upload instead of crashing the whole page.
            st.sidebar.error("The uploaded file could not be read as an image.")
        else:
            image = np.array(pil_image, dtype=np.uint8)
            _show_image(cols[0], image)

    # For newline
    st.sidebar.write('\n')

    # Actions
    if st.sidebar.button("Analyze image"):
        if image is None:
            st.sidebar.write("Please upload an image")
        else:
            with st.spinner('Loading model...'):
                # NOTE(review): the model is rebuilt on every click; caching it
                # would avoid repeated load cost.
                model = GlaucomaModel(device=run_device)
            with st.spinner('Analyzing...'):
                # Forward the image; process() returns a class index and a
                # class activation map that is displayable with imshow.
                disease_idx, cam = model.process(image)
            # Plot the class activation map in its own figure so it is not
            # stacked on top of the input image.
            _show_image(cols[1], cam)
            # Display screening results
            st.subheader(" Screening results:")
            st.write('\n')
            # presumably id2label maps class indices to human-readable labels
            st.markdown(f"{model.id2label[disease_idx]}")
# Entry point when this file is executed as a script (e.g. by `streamlit run`);
# importing the module does not start the app.
if __name__ == "__main__":
    main()