File size: 4,040 Bytes
438403f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66fb450
438403f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import streamlit as st
import numpy as np
import pandas as pd
from PIL import Image
from pathlib import Path
import joblib

import numpy as np
import cv2
import onnxruntime as ort
import imutils
# import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px



def onnx_segment_membrane(input_image, threshold):
    ort_session = ort.InferenceSession('onnx_models/membrane_segmentor.onnx')
    img = Image.fromarray(np.uint8(input_image))
    resized = img.resize((256, 256), Image.NEAREST)
    img_unsqueeze = expand_dims_twice(resized)
    onnx_outputs = ort_session.run(None, {'input': img_unsqueeze.astype('float32')}) 
    binarized = 1.0 * (onnx_outputs[0][0][0] > threshold)

    resized_ret = Image.fromarray(binarized.astype(np.uint8) ).resize((356, 256), Image.NEAREST)#.convert("L")
    centroid_img = generate_centroid_image(np.array(onnx_outputs[0][0][0])) *255
    resized_centroid_img = Image.fromarray(centroid_img.astype(np.uint8)).resize((356, 256), Image.NEAREST)
    return(resized_ret, resized_centroid_img)


def generate_centroid_image(thresh):
    thresh = cv2.blur(thresh, (5,5))
    thresh = thresh.astype(np.uint8)
    centroid_image = np.zeros(thresh.shape)
    cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
    cnts = imutils.grab_contours(cnts)
    centroids = []
    for c in cnts:
        try:
            # compute the center of the contour
            M = cv2.moments(c)
            cX = int(M["m10"] / M["m00"])
            cY = int(M["m01"] / M["m00"])
            # draw the contour and center of the shape on the image
            # cv2.drawContours(centroid_image, [c], -1, (255, 255, 255), 2)
            cv2.circle(centroid_image, (cX, cY), 2, (1, 1, 1), -1)
            centroids.append((cX, cY))
        except:
            pass
    return(centroid_image)



def expand_dims_twice(arr):
    norm=(arr-np.min(arr))/(np.max(arr)-np.min(arr))
    ret = np.expand_dims(np.expand_dims(norm, axis=0), axis=0)
    return(ret)



def cell_membrane_segmentation():
    selected_box2 = st.sidebar.selectbox(
    'Choose Example Input',
    ('Example_1.png','Example_2.png')
    )

    st.title('Cell Membrane Segmentation')
    instructions = """
        Segment Cell Membrane from C. elegans embryo imaging data \n
        Either upload your own image or select from the sidebar to get a preconfigured image. 
        The image you select or upload will be fed through the Deep Neural Network in real-time 
        and the output will be displayed to the screen.
        """
    st.text(instructions)
    file = st.file_uploader('Upload an image or choose an example')
    example_image = Image.open('./images/cell_membrane_segmentation_examples/'+selected_box2)
    threshold = st.sidebar.slider("Select Threshold (Applied on model output)", 0.0, 1.0, 0.1)
    col1, col2, col3 = st.columns(3)

    if file:
        input = Image.open(file)
        fig1 = px.imshow(input, binary_string=True, labels=dict(x="Input Image"))
        fig1.update(layout_coloraxis_showscale=False)
        fig1.update_layout(margin=dict(l=0, r=0, b=0, t=0))
        col1.plotly_chart(fig1, use_container_width=True)

    else:
        input = example_image
        fig1 = px.imshow(input, binary_string=True, labels=dict(x="Input Image"))
        fig1.update(layout_coloraxis_showscale=False)
        fig1.update_layout(margin=dict(l=0, r=0, b=0, t=0))
        col1.plotly_chart(fig1, use_container_width=True)

    pressed = st.button('Run')
    if pressed:
        st.empty()
        model_output = onnx_segment_membrane(np.array(input), threshold)

        fig2 = px.imshow(model_output[0], binary_string=True, labels=dict(x="Segmentation Map"))
        fig2.update_layout(margin=dict(l=0, r=0, b=0, t=0))
        col2.plotly_chart(fig2, use_container_width=True)

        fig3 = px.imshow(model_output[1], binary_string=True, labels=dict(x="Centroid Map"))
        fig3.update_layout(margin=dict(l=0, r=0, b=0, t=0))
        col3.plotly_chart(fig3, use_container_width=True)