ferferefer committed
Commit 6a88d6c
1 Parent(s): 4708345

Upload app.py

Files changed (1)
1. app.py +224 -0
app.py ADDED
@@ -0,0 +1,224 @@
import streamlit as st
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from transformers import SegformerForSemanticSegmentation
from transformers import AutoImageProcessor
from huggingface_hub import hf_hub_download
from tensorflow.keras.applications import EfficientNetV2B0
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import img_to_array

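# Pipeline overview: SegFormer segments the optic disc region of a fundus
# photograph, the photo is cropped around the detected ring, and an
# EfficientNetV2-B0 binary head classifies the crop as healthy vs glaucoma.
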
# Load the SegFormer image processor and the fine-tuned segmentation model
# (labels: 0 = "na", 1 = "anillo", 2 = "nervio").
model_id_seg = "nvidia/mit-b0"
image_processor = AutoImageProcessor.from_pretrained(model_id_seg, size={"height": 128, "width": 128})

repo_id_seg = "ferferefer/segformer"
model_seg = SegformerForSemanticSegmentation.from_pretrained(repo_id_seg)

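# The processor rescales inputs to 128 x 128 for the network; predictions
# are upsampled back to the original resolution in obtener_predicciones below.
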
# Preprocess a batch of PIL images and return per-pixel class predictions
# from the SegFormer model.
def obtener_predicciones(model, sample_batch):
    processed_batch = image_processor(sample_batch, return_tensors="pt")
    pixel_values = processed_batch.pixel_values
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
    logits = outputs.logits

    # Upsample the low-resolution logits back to the original image size;
    # PIL's .size is (width, height), so reverse it to get (height, width).
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=sample_batch[0].size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)
    return pred_seg

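# For a single H x W input image, `pred_seg` has shape (1, H, W) and holds
# integer class ids, so `pred_seg[0]` is the full-resolution label map.
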
# For each predicted mask, compute the centroid of the pixels labelled as
# class 1 (the "anillo" ring) and keep their coordinates for later use.
def calcular_centro_imagen(masks):
    centroid_list = []
    imagenes_transformadas = []

    for mask in masks:
        # (N, 2) array of (row, col) coordinates of the class-1 pixels.
        coords = np.argwhere(mask.cpu().numpy() == 1)
        centroid = (coords[:, 0].mean(), coords[:, 1].mean())
        centroid_list.append(centroid)
        imagenes_transformadas.append(coords)

    return centroid_list, imagenes_transformadas

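# Each centroid is a (row, col) pair in pixel coordinates; the coordinate
# arrays are reused below to measure how far the mask extends from it.
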
# Crop each image to a square centred on its mask centroid. The square's
# half-side is twice the largest centroid-to-mask-pixel distance, and the
# crop is zero-padded wherever it extends past the image border.
def recortar_imagen(centroids, mascara_final, images):
    lista_img_recortadas = []

    for counter, image in enumerate(images):
        # Largest distance from the centroid to any mask pixel.
        max_distance = 0
        for x, y in mascara_final[counter]:
            distance = np.sqrt((x - centroids[counter][0]) ** 2 + (y - centroids[counter][1]) ** 2)
            if distance > max_distance:
                max_distance = distance

        centroid_cero = int(centroids[counter][0])  # centroid row
        centroid_uno = int(centroids[counter][1])   # centroid column
        pad_size = int(max_distance * 2)

        image = np.array(image)
        height, width, _ = image.shape

        # Clamp the crop window to the image bounds.
        top = max(0, centroid_cero - pad_size)
        bottom = min(height, centroid_cero + pad_size)
        left = max(0, centroid_uno - pad_size)
        right = min(width, centroid_uno + pad_size)
        crop_img = image[top:bottom, left:right]

        # Zero-pad whatever part of the window fell outside the image so the
        # result is always a (2 * pad_size) x (2 * pad_size) square.
        pad_top = max(0, pad_size - centroid_cero)
        pad_bottom = max(0, centroid_cero + pad_size - height)
        pad_left = max(0, pad_size - centroid_uno)
        pad_right = max(0, centroid_uno + pad_size - width)
        padded_img = np.pad(crop_img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
                            mode='constant', constant_values=0)

        lista_img_recortadas.append(torch.from_numpy(padded_img))

    return lista_img_recortadas

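# Each returned crop has side 2 * pad_size, i.e. four times the maximum
# centroid-to-mask distance, so the whole ring plus surrounding context
# is preserved.
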
# Build the EfficientNetV2-B0 classifier: the convolutional backbone topped
# with global average pooling and a single sigmoid output unit.
img_shape = (224, 224, 3)
model_efficientnet = EfficientNetV2B0(include_top=False, input_shape=img_shape)
flat_1 = GlobalAveragePooling2D()(model_efficientnet.output)
capa_3 = Dense(1, activation='sigmoid')(flat_1)
model_efficientnet = Model(inputs=model_efficientnet.inputs, outputs=capa_3)
model_efficientnet.compile(optimizer=Adam(learning_rate=1e-4), loss="binary_crossentropy", metrics=["accuracy"])

# Download the fine-tuned weights from the Hub and load them.
repo_id = "ferferefer/PAPILA"
filename = "EfficientNetV2B0_checkpoint.h5"
model_file = hf_hub_download(repo_id=repo_id, filename=filename)
model_efficientnet.load_weights(model_file)

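# The single sigmoid output is read as P(glaucoma): scores below 0.5 are
# reported as healthy, scores of 0.5 or above as glaucoma.
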
# Streamlit app
st.title('Glaucoma PAPILA Image Classifier')

# Main Streamlit app logic
uploaded_image = st.file_uploader('Upload image', type=['jpg', 'jpeg', 'png'])

if uploaded_image is not None:
    # The pipeline works on batches of PIL images, so read the uploaded
    # file into a single-image batch.
    image_pil = Image.open(uploaded_image).convert("RGB")

    # Segment the image, locate the ring centroid and crop a square region
    # around the optic disc.
    predictions_papila = obtener_predicciones(model_seg, [image_pil])
    centroids, imagenes_transformadas = calcular_centro_imagen(predictions_papila)
    imagen_final_recortada = recortar_imagen(centroids, imagenes_transformadas, [image_pil])
    imagen_final_recortada = Image.fromarray(imagen_final_recortada[0].numpy())

    # Display cropped image
    st.image(imagen_final_recortada, use_column_width=True, clamp=True)

    # Button to trigger prediction
    if st.button('PREDICT'):
        # Resize the crop to the classifier's 224 x 224 input and add a
        # batch dimension before predicting.
        predict = imagen_final_recortada.resize(img_shape[:2])
        predict_modified = img_to_array(predict)
        predict_modified = np.expand_dims(predict_modified, axis=0)
        result = model_efficientnet.predict(predict_modified)
        if result[0][0] < 0.5:
            probability = 1 - result[0][0]
            st.write(f"Healthy with {probability*100:.2f}%")
        else:
            probability = result[0][0]
            st.write(f"Glaucoma with {probability*100:.2f}%")