# Snapshot of a Hugging Face Space (Space status at capture time: Sleeping).
# File size: 1,907 bytes; revision 2768cff.
import streamlit as st
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
# Set up the Streamlit app
st.title("Optic Disc and Cup Segmentation")
st.write("Upload an image to segment the optic disc and cup:")

# Create a file uploader
uploaded_file = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])

# Hugging Face Hub checkpoint that supplies both the preprocessing config
# and the model weights.
MODEL_ID = "pamixsun/segformer_for_optic_disc_cup_segmentation"


@st.cache_resource(show_spinner="Loading segmentation model...")
def _load_segmentation_model(model_id=MODEL_ID):
    """Load and return the ``(processor, model)`` pair for *model_id*.

    Streamlit re-executes this script from the top on every user
    interaction. Without caching, both objects would be rebuilt from the
    checkpoint on each rerun; ``st.cache_resource`` loads them once and
    hands back the same objects on later reruns and sessions.
    """
    return (
        AutoImageProcessor.from_pretrained(model_id),
        SegformerForSemanticSegmentation.from_pretrained(model_id),
    )


# Load the processor and model (cached across reruns)
processor, model = _load_segmentation_model()
def process_image(image):
    """Segment a fundus image and return a side-by-side matplotlib figure.

    The image is converted from OpenCV's BGR channel order to RGB and run
    through the module-level SegFormer ``processor``/``model``. The class
    logits are then upsampled back to the input resolution before the
    arg-max class is taken for each pixel.

    Args:
        image: A decoded colour image in BGR channel order, as produced by
            ``cv2.imdecode(..., cv2.IMREAD_COLOR)`` in this app.

    Returns:
        A matplotlib ``Figure`` with two panels. The left panel shows the
        RGB input; the right panel shows the predicted label map in
        greyscale. The caller owns the figure and should close it once it
        has been rendered.
    """
    # OpenCV decodes to BGR; the processor and matplotlib expect RGB.
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    inputs = processor(rgb, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()

    # The model's logits are not at the input resolution. Resize the
    # logits (not the hard labels) to the original (H, W) so the arg-max
    # is taken per input pixel.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=rgb.shape[:2],
        mode="bilinear",
        align_corners=False,
    )
    # Batch of one: keep the first (only) label map as a uint8 array.
    pred_disc_cup = upsampled_logits.argmax(dim=1)[0].numpy().astype(np.uint8)

    # Pin the colour scale to the full label range. Otherwise imshow
    # autoscales to the labels present in this prediction, and the same
    # class could render at different grey levels on different images.
    max_label = model.config.num_labels - 1

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(rgb)
    axes[0].set_title('Input Image')
    axes[0].axis('off')
    # 'nearest' keeps discrete labels from being blended into intermediate
    # grey levels when the figure is resampled for display.
    axes[1].imshow(
        pred_disc_cup,
        cmap='gray',
        vmin=0,
        vmax=max_label,
        interpolation='nearest',
    )
    axes[1].set_title('Segmented Output')
    axes[1].axis('off')
    fig.tight_layout()
    return fig
# Display the output once the user has provided an image.
if uploaded_file is not None:
    # getvalue() returns the whole upload regardless of the current stream
    # position. read() returns b"" if the buffer was already consumed.
    file_bytes = np.frombuffer(uploaded_file.getvalue(), np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if image is None:
        # imdecode signals undecodable data by returning None instead of
        # raising, which would otherwise surface as an opaque cv2.error
        # inside process_image.
        st.error("Could not decode the uploaded file as an image.")
    else:
        output_fig = process_image(image)
        st.pyplot(output_fig)
        # pyplot keeps every figure alive until it is closed. Streamlit
        # reruns this script on each interaction, so close the figure to
        # avoid accumulating one per rerun.
        plt.close(output_fig)