|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import numpy as np |
|
|
|
import matplotlib.pyplot as plt |
|
import streamlit as st |
|
|
|
from PIL import Image |
|
from glaucoma import GlaucomaModel |
|
|
|
# Device used for model inference: the first CUDA GPU when PyTorch can see
# one, otherwise the CPU.
if torch.cuda.is_available():
    run_device = torch.device("cuda:0")
else:
    run_device = torch.device("cpu")
|
|
|
|
|
def _show_image(container, img):
    """Render ``img`` with hidden axes into the Streamlit ``container``.

    Each call draws on its own Matplotlib figure and closes it afterwards, so
    separate images never share axes and figures do not pile up in pyplot's
    global registry across Streamlit reruns.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        ax.axis('off')
        container.pyplot(fig)
    finally:
        plt.close(fig)


def main():
    """Build the glaucoma-screening Streamlit page.

    The sidebar offers an image uploader and an "Analyze image" button. The
    main area has two columns: the uploaded image, and the class activation
    map returned by ``GlaucomaModel.process``. The predicted label, looked up
    through ``model.id2label``, is printed below the columns.
    """
    # Streamlit expects page configuration before other st.* output.
    st.set_page_config(layout="wide")

    st.title("Glaucoma Screening from Retinal Fundus Images")
    st.write('\n')
    st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')
    st.write('\n')
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")

    # ``st.beta_columns`` was renamed to ``st.columns`` and later removed;
    # use the current name and fall back only for old Streamlit releases.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    cols = make_columns((1, 1))
    cols[0].subheader("Input image")
    cols[1].subheader("Class activation map")

    st.sidebar.title("Image selection")

    # This option only exists in older Streamlit releases; newer releases
    # reject it with StreamlitAPIException, so setting it is best-effort.
    from streamlit.errors import StreamlitAPIException
    try:
        st.set_option('deprecation.showfileUploaderEncoding', False)
    except StreamlitAPIException:
        pass

    uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])
    if uploaded_file is not None:
        # Force 3-channel RGB so every upload reaches the model in one layout.
        image = np.array(Image.open(uploaded_file).convert('RGB')).astype(np.uint8)
        _show_image(cols[0], image)

    st.sidebar.write('\n')

    if st.sidebar.button("Analyze image"):
        if uploaded_file is None:
            st.sidebar.write("Please upload an image")
        else:
            with st.spinner('Loading model...'):
                model = GlaucomaModel(device=run_device)

            with st.spinner('Analyzing...'):
                disease_idx, cam = model.process(image)

            # Own figure for the CAM so it is not drawn over the input image.
            _show_image(cols[1], cam)

            st.subheader(" Screening results:")
            st.write('\n')
            st.markdown(f"{model.id2label[disease_idx]}")
|
|
|
|
|
if __name__ == "__main__":
    # Launch the app only when this file is executed as a script,
    # not when it is imported as a module.
    main()