File size: 2,594 Bytes
f2c28c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Copyright (C) 2023, Xu Sun.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import torch
import numpy as np

import matplotlib.pyplot as plt
import streamlit as st

from PIL import Image
from lib.glaucoma import GlaucomaModel

# Inference device: the first CUDA GPU when CUDA is available, otherwise the CPU.
run_device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def _show_image(container, image):
    """Draw ``image`` with axes hidden and render it into a Streamlit ``container``.

    Each call uses its own matplotlib figure and closes it once rendered, so a
    later image is never drawn on top of an earlier one and pyplot does not
    accumulate open figures across Streamlit reruns.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(image)
        ax.axis('off')
        container.pyplot(fig)
    finally:
        plt.close(fig)


def _disable_uploader_encoding_warning():
    """Best-effort: turn off the legacy file-uploader encoding deprecation warning.

    Streamlit releases that no longer have the
    ``deprecation.showfileUploaderEncoding`` option raise when it is set; there
    is then no warning to disable, so that error is ignored.
    """
    from streamlit.errors import StreamlitAPIException

    try:
        st.set_option('deprecation.showfileUploaderEncoding', False)
    except StreamlitAPIException:
        pass


def main():
    """Run the glaucoma-screening Streamlit page.

    The uploaded fundus image is shown in the left column. When the sidebar
    "Analyze image" button is pressed, the image is passed to
    ``GlaucomaModel.process``. The returned class activation map is shown in
    the right column, and ``model.id2label[disease_idx]`` is written below it
    as the screening result.
    """
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("Glaucoma Screening from Retinal Fundus Images")
    # For newline
    st.write('\n')
    # Author info
    st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')
    # For newline
    st.write('\n')
    # Instructions
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns. `st.beta_columns` was promoted to `st.columns` and later
    # removed, so prefer the stable name and fall back on old Streamlit releases.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    cols = make_columns((1, 1))
    cols[0].subheader("Input image")
    cols[1].subheader("Class activation map")

    # Sidebar
    # File selection
    st.sidebar.title("Image selection")
    # Disabling warning (no-op on Streamlit versions without this option)
    _disable_uploader_encoding_warning()
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])
    image = None
    if uploaded_file is not None:
        # Read the uploaded image as an RGB uint8 array
        image = np.array(Image.open(uploaded_file).convert('RGB')).astype(np.uint8)
        _show_image(cols[0], image)

    # For newline
    st.sidebar.write('\n')

    # Actions
    if not st.sidebar.button("Analyze image"):
        return
    if image is None:
        st.sidebar.write("Please upload an image")
        return

    # NOTE(review): the model is rebuilt on every click; consider caching it
    # (e.g. `st.cache_resource` on Streamlit versions that provide it).
    with st.spinner('Loading model...'):
        model = GlaucomaModel(device=run_device)

    with st.spinner('Analyzing...'):
        # Forward the image to the model and get results
        disease_idx, cam = model.process(image)
        # Render the class activation map in its own figure so it is not
        # stacked on top of the input image's axes.
        _show_image(cols[1], cam)

        # Display the predicted label
        st.subheader("  Screening results:")
        st.write('\n')
        st.markdown(f"{model.id2label[disease_idx]}")


if __name__ == "__main__":
    # Entry point when this file is executed as a script.
    main()