maliahson commited on
Commit
2768cff
·
verified ·
1 Parent(s): 080e626

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from torch import nn
7
+ from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
8
+
9
+ # Set up the Streamlit app
10
+ st.title("Optic Disc and Cup Segmentation")
11
+ st.write("Upload an image to segment the optic disc and cup:")
12
+
13
+ # Create a file uploader
14
+ uploaded_file = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])
15
+
16
+ # Load the processor and model
17
+ processor = AutoImageProcessor.from_pretrained("pamixsun/segformer_for_optic_disc_cup_segmentation")
18
+ model = SegformerForSemanticSegmentation.from_pretrained("pamixsun/segformer_for_optic_disc_cup_segmentation")
19
+
20
+ # Define a function to process the image
21
+ def process_image(image):
22
+ # Convert the image to RGB
23
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
24
+
25
+ # Process the input image
26
+ inputs = processor(image, return_tensors="pt")
27
+
28
+ # Perform inference
29
+ with torch.no_grad():
30
+ outputs = model(**inputs)
31
+ logits = outputs.logits.cpu()
32
+
33
+ # Upsample the logits to match the input image size
34
+ upsampled_logits = nn.functional.interpolate(
35
+ logits,
36
+ size=image.shape[:2],
37
+ mode="bilinear",
38
+ align_corners=False,
39
+ )
40
+
41
+ # Get the predicted segmentation
42
+ pred_disc_cup = upsampled_logits.argmax(dim=1)[0].numpy().astype(np.uint8)
43
+
44
+ # Display the input image and the segmented output
45
+ fig, axes = plt.subplots(1, 2, figsize=(12, 6))
46
+ axes[0].imshow(image)
47
+ axes[0].set_title('Input Image')
48
+ axes[0].axis('off')
49
+ axes[1].imshow(pred_disc_cup, cmap='gray')
50
+ axes[1].set_title('Segmented Output')
51
+ axes[1].axis('off')
52
+ plt.tight_layout()
53
+ return fig
54
+
55
+ # Display the output
56
+ if uploaded_file:
57
+ image = cv2.imdecode(np.frombuffer(uploaded_file.read(), np.uint8), cv2.IMREAD_COLOR)
58
+ output_fig = process_image(image)
59
+ st.pyplot(output_fig)