Segmentation / app.py
maliahson's picture
Create app.py
2768cff verified
raw
history blame
1.91 kB
import streamlit as st
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
# Set up the Streamlit app
st.title("Optic Disc and Cup Segmentation")
st.write("Upload an image to segment the optic disc and cup:")

# Create a file uploader
uploaded_file = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])

# Hugging Face Hub checkpoint used for both the image processor and the model.
MODEL_ID = "pamixsun/segformer_for_optic_disc_cup_segmentation"


@st.cache_resource
def _load_processor_and_model(model_id):
    """Load the SegFormer image processor and segmentation model for ``model_id``.

    Streamlit re-executes this whole script on every widget interaction.
    Caching with ``st.cache_resource`` keeps a single shared instance per
    ``model_id``, so the weights are loaded once and reused across reruns
    instead of being rebuilt each time the user uploads a file.

    Args:
        model_id: Hugging Face Hub identifier of the checkpoint.

    Returns:
        A ``(processor, model)`` tuple.
    """
    loaded_processor = AutoImageProcessor.from_pretrained(model_id)
    loaded_model = SegformerForSemanticSegmentation.from_pretrained(model_id)
    return loaded_processor, loaded_model


# Load the processor and model (cached across reruns)
processor, model = _load_processor_and_model(MODEL_ID)
# Define a function to process the image
def process_image(image):
    """Segment the optic disc and cup in a BGR image and plot the result.

    Args:
        image: A decoded image in OpenCV's BGR channel order (for example the
            output of ``cv2.imdecode(..., cv2.IMREAD_COLOR)``).

    Returns:
        A ``matplotlib.figure.Figure`` with two panels: the RGB input image
        and the per-pixel predicted class map, upsampled to the input's
        height and width.
    """
    # Local import: build the figure without pyplot so it is not added to
    # pyplot's global figure registry. Otherwise every Streamlit rerun would
    # leave one more open figure behind that is never freed.
    from matplotlib.figure import Figure

    # OpenCV decodes to BGR; convert to RGB for the processor and display.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Preprocess and run inference without tracking gradients.
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits.cpu()

    # Upsample the logits to the input image's (height, width) so the
    # predicted mask aligns pixel-for-pixel with the displayed image.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.shape[:2],
        mode="bilinear",
        align_corners=False,
    )

    # Per-pixel class index for the single image in the batch.
    pred_disc_cup = upsampled_logits.argmax(dim=1)[0].numpy().astype(np.uint8)

    # The logits' channel dimension is the number of classes. Pinning the
    # colour scale to the full label range keeps each class at the same grey
    # level no matter which classes happen to appear in this prediction
    # (plain imshow would autoscale to the mask's own min/max).
    num_classes = upsampled_logits.shape[1]
    class_vmax = max(num_classes - 1, 1)

    # Display the input image and the segmented output side by side.
    fig = Figure(figsize=(12, 6))
    axes = fig.subplots(1, 2)
    axes[0].imshow(image)
    axes[0].set_title('Input Image')
    axes[0].axis('off')
    axes[1].imshow(pred_disc_cup, cmap='gray', vmin=0, vmax=class_vmax)
    axes[1].set_title('Segmented Output')
    axes[1].axis('off')
    fig.tight_layout()
    return fig
# Display the output
if uploaded_file is not None:
    # getvalue() returns the whole upload regardless of the stream position,
    # whereas read() returns b"" once the buffer has already been consumed.
    file_bytes = uploaded_file.getvalue()
    image = None
    if file_bytes:
        # Skip decoding for an empty upload; imdecode rejects an empty buffer.
        image = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imdecode returns None for data it cannot decode (e.g. a corrupt
        # or mislabelled file); report it instead of crashing downstream.
        st.error("Could not decode the uploaded file as an image.")
    else:
        output_fig = process_image(image)
        st.pyplot(output_fig)
        # Close the figure after rendering so pyplot does not accumulate one
        # open figure per rerun.
        plt.close(output_fig)