File size: 3,108 Bytes
d61e631
 
 
 
 
7365db2
 
 
 
 
 
d61e631
7365db2
 
 
 
 
 
 
 
 
 
d61e631
7365db2
 
 
 
d61e631
7365db2
 
 
 
 
d61e631
7365db2
 
 
 
 
 
 
 
 
 
 
d61e631
7365db2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d61e631
7365db2
 
 
 
 
d61e631
7365db2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import streamlit as st
import cv2
import numpy as np
from PIL import Image
import torch
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.colormap import random_color
from mask2former import add_maskformer2_config
from tqdm import tqdm

def setup_predictor(config_file, weights_path, device='cpu'):
    """Build a detectron2 ``DefaultPredictor`` from a Mask2Former-style config.

    Args:
        config_file: path of the YAML config merged on top of detectron2's defaults.
        weights_path: checkpoint path stored in ``MODEL.WEIGHTS``.
        device: value stored in ``MODEL.DEVICE`` (``'cpu'`` unless overridden).

    Returns:
        A ``DefaultPredictor`` constructed from the assembled config.
    """
    config = get_cfg()
    # Accept keys the base detectron2 schema does not declare, so the YAML may
    # carry extra project-specific options.
    config.set_new_allowed(True)
    # Register the DeepLab and MaskFormer2 config extensions before merging the
    # file, so their keys exist when the YAML references them.
    for register_extension in (add_deeplab_config, add_maskformer2_config):
        register_extension(config)
    config.merge_from_file(config_file)
    config.MODEL.WEIGHTS, config.MODEL.DEVICE = weights_path, device
    return DefaultPredictor(config)

def area(mask):
    """Return the fraction of non-zero elements in *mask* (0 for an empty mask).

    Used as the sort key that orders masks largest-first before painting.
    """
    total = mask.size
    return 0 if total == 0 else np.count_nonzero(mask) / total

def vis_mask(input, mask, mask_color):
    """Return a PIL image with *mask*'s pixels blended 50/50 toward *mask_color*.

    Args:
        input: image accepted by ``np.copy`` (numpy array or PIL image). It is
            copied, not modified.
        mask: array indexed against the image's leading (H, W) dimensions;
            pixels whose value is greater than 0.5 are tinted.
        mask_color: colour whose components are blended with each selected
            pixel. Its length presumably must match the image's channel
            count (3 for RGB); TODO confirm.

    Returns:
        ``PIL.Image.Image`` built from the blended uint8 copy.
    """
    canvas = np.copy(input)
    selected = mask > 0.5
    tint = np.array(mask_color)
    blended = 0.5 * canvas[selected] + 0.5 * tint
    canvas[selected] = blended.astype(np.uint8)
    return Image.fromarray(canvas)

def show_image(I, pool):
    """Overlay every mask in *pool* on image *I*, each in a random colour.

    Masks are painted in iteration order. Where a mask covers pixels that an
    earlier mask already painted, those pixels are first restored to the
    original image. The new mask's colour is then blended with the original
    pixels instead of being stacked on the earlier tint. The script sorts masks
    by descending ``area`` before calling this, so smaller masks end up on top.

    Args:
        I: original image (numpy array or PIL image). It is not modified.
            Presumably (H, W, 3) uint8 RGB; TODO confirm for other channel
            layouts.
        pool: iterable of (H, W) masks. A pixel counts as covered when its
            value is greater than 0.5, the same threshold ``vis_mask`` uses.

    Returns:
        ``PIL.Image.Image`` with the overlay. If *pool* is empty, this is the
        unmodified image (previously a bare numpy array in that case).
    """
    original = np.array(I)
    canvas = original.copy()
    # Boolean "already painted" map. The previous float counter tracked the
    # same thing, but via uint8 truncation, which disagreed with vis_mask's
    # > 0.5 rule for fractional masks.
    painted = np.zeros(original.shape[:2], dtype=bool)
    for mask in tqdm(pool):
        fg = np.asarray(mask) > 0.5
        overlap = painted & fg
        if overlap.any():
            # Reset only the doubly-covered pixels to the original image.
            canvas[overlap] = original[overlap]
        painted |= fg
        canvas = np.array(vis_mask(canvas, mask, random_color(rgb=True)))
    return Image.fromarray(canvas)

# Load UnSAM and UnSAM+ predictors.
# Streamlit re-executes this whole script on every user interaction, so each
# predictor is built through st.cache_resource: it is constructed once per
# argument tuple and reused on later reruns instead of reloading checkpoints.
_UNSAM_CONFIG = "/kaggle/working/UnSAM/whole_image_segmentation/configs/maskformer2_R50_bs16_50ep.yaml"
_UNSAM_WEIGHTS = "/kaggle/working/Mask2Former/unsam_sa1b_4perc_ckpt_200k.pth"
_UNSAM_PLUS_WEIGHTS = "/kaggle/working/Mask2Former/unsam_plus_sa1b_1perc_ckpt_50k.pth"
# Use the GPU when torch can see one; otherwise fall back to the previous CPU default.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_predictor(config_file, weights_path, device):
    """Return a cached predictor built by ``setup_predictor`` for these arguments."""
    return setup_predictor(config_file, weights_path, device=device)


unsam_predictor = _load_predictor(_UNSAM_CONFIG, _UNSAM_WEIGHTS, _DEVICE)
unsam_plus_predictor = _load_predictor(_UNSAM_CONFIG, _UNSAM_PLUS_WEIGHTS, _DEVICE)

def _overlay_predictions(predictor, bgr_image, rgb_image):
    """Segment *bgr_image* with *predictor* and return the mask overlay on *rgb_image*.

    detectron2's ``DefaultPredictor`` takes its input in BGR channel order.
    The overlay is drawn on the RGB copy so colours display correctly in
    Streamlit. Masks are painted largest-first (by ``area``) so smaller masks
    stay visible on top.
    """
    instances = predictor(bgr_image)["instances"]
    # A single device->host copy of the whole mask stack instead of one per mask.
    masks = list(instances.pred_masks.cpu().numpy())
    masks.sort(key=area, reverse=True)
    return show_image(rgb_image, masks)


st.title("Image Segmentation with UnSAM and UnSAM+")

# Upload image
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    # Normalise to 3-channel RGB. PNG uploads are often RGBA, grayscale or
    # palette images, whose channel count would not match the 3-component
    # overlay colour used by vis_mask.
    image = np.array(Image.open(uploaded_file).convert("RGB"))
    # DefaultPredictor expects a BGR image. It flips the channels itself when
    # the model config asks for RGB input.
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Display the original image while the models run
    st.image(image, caption='Original Image', use_column_width=True)

    with st.spinner("Running UnSAM+ and UnSAM..."):
        unsam_plus_image = _overlay_predictions(unsam_plus_predictor, image_bgr, image)
        unsam_image = _overlay_predictions(unsam_predictor, image_bgr, image)

    # Display the images side by side: st.image with a list and
    # use_column_width=True would stack them vertically, so use one column each.
    panels = [image, unsam_plus_image, unsam_image]
    captions = ['Original Image', 'UnSAM+ Output', 'UnSAM Output']
    for column, panel, caption in zip(st.columns(len(panels)), panels, captions):
        column.image(panel, caption=caption, use_column_width=True)