"""Streamlit app: segment the optic disc with a SegFormer model, crop a square
window around the segmented region, and classify the crop as healthy/glaucoma
with a fine-tuned EfficientNetV2B0."""

import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from huggingface_hub import cached_download, hf_hub_url
from tensorflow.keras.applications import EfficientNetV2B0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import img_to_array
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation

# --- Segmentation model ------------------------------------------------------
# Image processor config comes from the base checkpoint; the trained
# segmentation weights come from the fine-tuned repo.
model_id_seg = "nvidia/mit-b0"
image_processor = AutoImageProcessor.from_pretrained(model_id_seg, size=(128, 128))
repo_id_seg = "ferferefer/segformer"
model_seg = SegformerForSemanticSegmentation.from_pretrained(repo_id_seg)


def obtener_predicciones(model, sample_batch):
    """Run the segmentation model on a batch of PIL images.

    Args:
        model: a ``SegformerForSemanticSegmentation`` instance.
        sample_batch: non-empty list of PIL images (assumed the same size).

    Returns:
        ``LongTensor`` of shape (batch, H, W) holding the argmax class id per
        pixel, upsampled back to the input image resolution.
    """
    processed_batch = image_processor(sample_batch, return_tensors="pt")
    outputs = model(pixel_values=processed_batch.pixel_values)
    # Logits come out at a reduced resolution; interpolate back up.
    # PIL's .size is (W, H), interpolate wants (H, W) -> reverse it.
    upsampled_logits = nn.functional.interpolate(
        outputs.logits,
        size=sample_batch[0].size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    return upsampled_logits.argmax(dim=1)


def calcular_centro_imagen(masks):
    """Compute the centroid of the foreground (class == 1) pixels of each mask.

    Args:
        masks: iterable of (H, W) prediction tensors.

    Returns:
        Tuple ``(centroid_list, coords_list)``: per mask, its (row, col)
        centroid and the (N, 2) array of foreground pixel coordinates.

    NOTE(review): a mask with no foreground pixels raises ZeroDivisionError,
    as in the original code — confirm whether that case can occur upstream.
    """
    centroid_list = []
    imagenes_transformadas = []
    for mask in masks:
        # argwhere yields one (row, col) pair per foreground pixel.  The
        # original transposed this array, which broke both the centroid math
        # and the downstream `for x, y in ...` unpacking.
        coords = np.argwhere(mask.cpu() == 1)
        rows = [p[0] for p in coords]
        cols = [p[1] for p in coords]
        centroid = (sum(rows) / len(coords), sum(cols) / len(coords))
        centroid_list.append(centroid)
        imagenes_transformadas.append(coords)
    return centroid_list, imagenes_transformadas


def _recorte_con_relleno(image, centro_fila, centro_col, radio):
    """Crop a (2*radio)-sided square centred on (centro_fila, centro_col),
    zero-padding any part of the window that falls outside the image.

    Replaces four near-identical branches of the original, which also sliced
    with raw negative indices (Python wrap-around) when a window edge fell
    outside the image on a side the active branch did not clamp.
    """
    height, width, _ = image.shape
    crop_img = image[
        max(centro_fila - radio, 0):min(centro_fila + radio, height),
        max(centro_col - radio, 0):min(centro_col + radio, width),
    ]
    # Amount of the window lost on each side, restored as zero padding.
    pad_top = max(0, radio - centro_fila)
    pad_bottom = max(0, radio + centro_fila - height)
    pad_left = max(0, radio - centro_col)
    pad_right = max(0, radio + centro_col - width)
    if pad_top or pad_bottom or pad_left or pad_right:
        crop_img = np.pad(
            crop_img,
            ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
            mode="constant",
            constant_values=0,
        )
    return torch.from_numpy(np.ascontiguousarray(crop_img))


def recortar_imagen(centroids, mascara_final, images):
    """Crop each image around its mask centroid.

    The half-side of the square window is twice the distance from the
    centroid to the farthest foreground pixel of the mask.

    Args:
        centroids: list of (row, col) centroids, one per image.
        mascara_final: list of (N, 2) foreground coordinate arrays.
        images: list of (H, W, C) torch tensors.

    Returns:
        List of cropped (and zero-padded where needed) torch tensors.
    """
    lista_img_recortadas = []
    for counter, image in enumerate(images):
        centro = centroids[counter]
        # Farthest foreground pixel from the centroid sets the crop radius.
        max_distance = 0.0
        for fila, col in mascara_final[counter]:
            distance = np.sqrt((fila - centro[0]) ** 2 + (col - centro[1]) ** 2)
            if distance > max_distance:
                max_distance = distance
        recorte = _recorte_con_relleno(
            image.cpu().numpy(),
            int(centro[0]),
            int(centro[1]),
            int(max_distance) * 2,
        )
        lista_img_recortadas.append(recorte)
    return lista_img_recortadas


# --- Classification model ----------------------------------------------------
img_shape = (224, 224, 3)
_base = EfficientNetV2B0(include_top=False, input_shape=img_shape)
_pooled = GlobalAveragePooling2D()(_base.output)
_salida = Dense(1, activation='sigmoid')(_pooled)
model_efficientnet = Model(inputs=_base.inputs, outputs=_salida)
model_efficientnet.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss="BinaryCrossentropy",
    metrics=["accuracy"],
)

# NOTE(review): cached_download is deprecated in recent huggingface_hub
# releases in favour of hf_hub_download — kept as-is to avoid changing the
# pinned dependency surface; migrate when upgrading the hub client.
repo_id = "ferferefer/PAPILA"
filename = "EfficientNetV2B0_checkpoint.h5"
model_file = cached_download(hf_hub_url(repo_id, filename))
model_efficientnet.load_weights(model_file)

# --- Streamlit UI ------------------------------------------------------------
st.title('Glaucoma PAPILA Image Classifier')

uploaded_image = st.file_uploader('Upload image', type=['jpg', 'jpeg', 'png'])
if uploaded_image is not None:
    # The uploader yields a file-like object; the pipeline needs a PIL image
    # (for the processor) and an HWC tensor (for cropping).
    pil_image = Image.open(uploaded_image).convert("RGB")
    predictions_papila = obtener_predicciones(model_seg, [pil_image])
    centroids, imagenes_transformadas = calcular_centro_imagen(predictions_papila)
    recortes = recortar_imagen(
        centroids,
        imagenes_transformadas,
        [torch.from_numpy(np.array(pil_image))],
    )
    imagen_final_recortada = Image.fromarray(recortes[0].numpy())

    # Display the cropped optic-disc region.
    st.image(imagen_final_recortada, use_column_width=True)

    if st.button('PREDICT'):
        # Resize to the classifier input; PIL resize takes (W, H) — square
        # here, so img_shape[:2] is safe either way.
        predict = imagen_final_recortada.resize(img_shape[:2])
        predict_modified = np.expand_dims(img_to_array(predict), axis=0)
        result = model_efficientnet.predict(predict_modified)
        # Sigmoid output: < 0.5 -> healthy, otherwise glaucoma.
        if result < 0.5:
            probability = 1 - result[0][0]
            st.write(f"Healthy with {probability*100:.2f}%")
        else:
            probability = result[0][0]
            st.write(f"Glaucoma with {probability*100:.2f}%")
        st.image(
            imagen_final_recortada,
            caption='Uploaded Image',
            use_column_width=True,
            clamp=True,
        )