File size: 1,907 Bytes
2768cff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import streamlit as st
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation

# Set up the Streamlit app
st.title("Optic Disc and Cup Segmentation")
st.write("Upload an image to segment the optic disc and cup:")

# Create a file uploader
uploaded_file = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])

# Hugging Face Hub id of the checkpoint used for both the processor and the model.
MODEL_ID = "pamixsun/segformer_for_optic_disc_cup_segmentation"


@st.cache_resource
def _load_processor_and_model(model_id):
    """Load the image processor and SegFormer segmentation model for *model_id*.

    Streamlit re-runs this script top to bottom on every user interaction;
    ``st.cache_resource`` keeps a single shared copy of the loaded objects
    instead of reloading the weights on each rerun.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = SegformerForSemanticSegmentation.from_pretrained(model_id)
    return processor, model


# Load the processor and model (cached across reruns)
processor, model = _load_processor_and_model(MODEL_ID)

# Define a function to process the image
def process_image(image):
    """Segment the optic disc and cup in *image* and return a side-by-side figure.

    Args:
        image: An OpenCV-style image array: BGR colour (H, W, 3), as produced by
            ``cv2.imdecode(..., cv2.IMREAD_COLOR)``; BGRA (H, W, 4); or
            single-channel grayscale (H, W) / (H, W, 1).

    Returns:
        A ``matplotlib.figure.Figure`` with the RGB input on the left and the
        predicted per-pixel class map on the right.
    """
    # A bare Figure is not registered with pyplot's global figure manager, so
    # repeated calls (one per Streamlit rerun) do not accumulate open figures.
    from matplotlib.figure import Figure

    # Convert to RGB; the same array is fed to the model and displayed.
    if image.ndim == 2 or image.shape[2] == 1:
        rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
    else:
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    pred_disc_cup = _segment_disc_cup(rgb)

    # Display the input image and the segmented output
    fig = Figure(figsize=(12, 6))
    axes = fig.subplots(1, 2)
    axes[0].imshow(rgb)
    axes[0].set_title('Input Image')
    axes[0].axis('off')
    axes[1].imshow(pred_disc_cup, cmap='gray')
    axes[1].set_title('Segmented Output')
    axes[1].axis('off')
    fig.tight_layout()
    return fig


def _segment_disc_cup(rgb):
    """Run the module-level processor/model on *rgb*; return an (H, W) uint8 class map."""
    inputs = processor(rgb, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits.cpu()

    # Resize the logits (not the argmax labels) to the input's height/width so
    # the per-pixel class decision is taken on interpolated scores.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=rgb.shape[:2],
        mode="bilinear",
        align_corners=False,
    )

    # Class index with the highest score at each pixel, for the first (only) image.
    return upsampled_logits.argmax(dim=1)[0].numpy().astype(np.uint8)

# Display the output
if uploaded_file is not None:
    # getvalue() returns the whole upload regardless of the current stream
    # position, unlike read(), which yields b"" once the buffer is consumed.
    raw_bytes = uploaded_file.getvalue()
    image = (
        cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)
        if raw_bytes
        else None
    )
    if image is None:
        # cv2.imdecode returns None for data it cannot decode instead of raising.
        st.error("Could not decode the uploaded file as an image.")
    else:
        output_fig = process_image(image)
        st.pyplot(output_fig)
        # Release pyplot's reference (a no-op for figures pyplot does not manage)
        # so reruns do not accumulate open figures.
        plt.close(output_fig)