"""Streamlit app that segments the optic disc and cup in an uploaded image.

Run with ``streamlit run <this file>``. Streamlit re-executes this whole script
on every user interaction, so the model is loaded through a cached loader
instead of at plain module level.
"""

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from torch import nn
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation

MODEL_ID = "pamixsun/segformer_for_optic_disc_cup_segmentation"


@st.cache_resource
def load_model():
    """Load the image processor and segmentation model once per server process.

    Returns:
        A ``(processor, model)`` tuple for ``MODEL_ID``.

    Without caching, both objects would be rebuilt by ``from_pretrained`` on
    every Streamlit rerun, i.e. on every widget interaction.
    """
    processor = AutoImageProcessor.from_pretrained(MODEL_ID)
    model = SegformerForSemanticSegmentation.from_pretrained(MODEL_ID)
    model.eval()  # inference only; make the mode explicit
    return processor, model


# Set up the Streamlit app
st.title("Optic Disc and Cup Segmentation")
st.write("Upload an image to segment the optic disc and cup:")

# Create a file uploader
uploaded_file = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])

# Load the processor and model (cached across reruns).
processor, model = load_model()


def process_image(image):
    """Segment *image* and return a side-by-side matplotlib figure.

    Args:
        image: Colour image as a ``uint8`` NumPy array in OpenCV's BGR channel
            order, e.g. the output of ``cv2.imdecode(..., cv2.IMREAD_COLOR)``.

    Returns:
        A ``matplotlib.figure.Figure`` showing the RGB input on the left and the
        per-pixel predicted class index (argmax over the model's logits) on the
        right. The caller owns the figure and should close it once displayed.
    """
    # OpenCV decodes to BGR; matplotlib (and presumably the HF processor)
    # expect RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    inputs = processor(image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits.cpu()

    # Upsample the logits to the input's (height, width) so the predicted mask
    # lines up pixel-for-pixel with the displayed image.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.shape[:2],
        mode="bilinear",
        align_corners=False,
    )

    # Per-pixel class index (one image in the batch -> take element 0).
    pred_disc_cup = upsampled_logits.argmax(dim=1)[0].numpy().astype(np.uint8)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(image)
    axes[0].set_title('Input Image')
    axes[0].axis('off')
    # Pin the colour range to the full label range so each class always maps to
    # the same grey level, even when some classes are absent from this image
    # (otherwise imshow would autoscale to whatever values happen to be present).
    axes[1].imshow(
        pred_disc_cup,
        cmap='gray',
        vmin=0,
        vmax=model.config.num_labels - 1,
    )
    axes[1].set_title('Segmented Output')
    axes[1].axis('off')
    fig.tight_layout()
    return fig


# Display the output
if uploaded_file is not None:
    # getvalue() returns the whole buffer regardless of the current stream
    # position, whereas read() returns b"" once the buffer has been consumed.
    data = uploaded_file.getvalue()
    # cv2.imdecode asserts on an empty buffer and returns None on data it
    # cannot decode, so guard both cases before using the result.
    image = (
        cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
        if data
        else None
    )
    if image is None:
        st.error("Could not decode the uploaded file as an image.")
    else:
        with st.spinner("Segmenting..."):
            output_fig = process_image(image)
        st.pyplot(output_fig)
        # pyplot keeps every figure alive until it is closed; close it so
        # figures do not accumulate across reruns.
        plt.close(output_fig)