File size: 2,594 Bytes
f2c28c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
# Copyright (C) 2023, Xu Sun.
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
import torch
import numpy as np
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
from lib.glaucoma import GlaucomaModel
# Device handed to the model: the first CUDA GPU when PyTorch can see one,
# otherwise the CPU.
run_device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def _plot_image(container, img):
    """Render ``img`` with matplotlib (axes hidden) into a Streamlit container.

    A fresh figure is created for every call and closed once it has been
    handed to Streamlit. Streamlit re-runs the whole script on each
    interaction, and pyplot keeps every figure it creates until it is
    explicitly closed, so unclosed figures would pile up across reruns.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        ax.axis('off')
        container.pyplot(fig)
    finally:
        plt.close(fig)


def main():
    """Build the Streamlit glaucoma-screening page.

    Lays out two columns ("Input image" / "Class activation map") and lets the
    user upload a PNG/JPEG image from the sidebar. When "Analyze image" is
    pressed, the page builds a ``GlaucomaModel`` on ``run_device`` and calls
    ``model.process(image)``. It then shows the returned class activation map
    and the label ``model.id2label[disease_idx]``.
    """
    # StreamlitAPIException ships with streamlit itself; imported locally so the
    # file-level import block stays untouched.
    from streamlit.errors import StreamlitAPIException

    # Wide mode (must be the first Streamlit command of the script run)
    st.set_page_config(layout="wide")
    # Designing the interface
    st.title("Glaucoma Screening from Retinal Fundus Images")
    # For newline
    st.write('\n')
    # Author info
    st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')
    # For newline
    st.write('\n')
    # Instructions
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")

    # Set the columns. ``st.beta_columns`` became ``st.columns`` and the beta
    # name was later removed; use it only on releases that predate ``columns``.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    cols = make_columns((1, 1))
    cols[0].subheader("Input image")
    cols[1].subheader("Class activation map")

    # Sidebar
    # File selection
    st.sidebar.title("Image selection")
    # Disabling warning. Newer Streamlit releases dropped this option and
    # reject unknown options, in which case there is nothing to disable.
    try:
        st.set_option('deprecation.showfileUploaderEncoding', False)
    except StreamlitAPIException:
        pass

    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])
    image = None
    if uploaded_file is not None:
        # read the uploaded image as an 8-bit RGB array (alpha/palette dropped)
        image = np.array(Image.open(uploaded_file).convert('RGB')).astype(np.uint8)
        _plot_image(cols[0], image)

    # For newline
    st.sidebar.write('\n')
    # actions
    if st.sidebar.button("Analyze image"):
        if uploaded_file is None:
            st.sidebar.write("Please upload an image")
        else:
            # NOTE(review): the model is rebuilt on every click; consider
            # caching it if the installed Streamlit offers a resource cache.
            with st.spinner('Loading model...'):
                # load model
                model = GlaucomaModel(device=run_device)
            with st.spinner('Analyzing...'):
                # Forward the image to the model and get results
                disease_idx, cam = model.process(image)
                # visualize the class activation map in its own figure rather
                # than on top of the input image's axes
                _plot_image(cols[1], cam)
                # Display the predicted label
                st.subheader(" Screening results:")
                st.write('\n')
                st.markdown(f"{model.id2label[disease_idx]}")
# Launch the app only when executed as a script (e.g. via `streamlit run`),
# not when imported.
if __name__ == '__main__':
    main()