File size: 3,491 Bytes
581e4ed
 
 
 
 
 
 
 
 
e7092fe
3cc6267
d61e631
 
 
8b0d3c7
97d2179
7365db2
 
97d2179
7365db2
97d2179
 
d61e631
97d2179
 
7365db2
 
 
 
97d2179
7365db2
97d2179
 
 
d61e631
7365db2
97d2179
7365db2
d61e631
7365db2
 
 
 
 
d61e631
7365db2
 
 
97d2179
7365db2
 
 
 
 
 
 
97d2179
d61e631
97d2179
 
7365db2
97d2179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7365db2
97d2179
7365db2
 
97d2179
7365db2
97d2179
7365db2
97d2179
 
 
d61e631
97d2179
 
 
 
d61e631
97d2179
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import sys
import subprocess

def install_base_requirements(requirements_file="requirements_base.txt"):
    """Install the packages listed in *requirements_file* with pip.

    pip is invoked as ``sys.executable -m pip`` so the packages are installed
    into the same interpreter/environment that is running this script. A
    relative *requirements_file* is resolved against the current working
    directory.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-r", requirements_file]
    )

import os


def _pip(*args):
    """Run ``python -m pip <args>`` with the current interpreter; raise on failure."""
    subprocess.check_call([sys.executable, "-m", "pip", *args])


# Streamlit re-executes this script on every widget interaction, but the
# server process (and therefore os.environ) survives between reruns. The flag
# limits these installs to once per process instead of once per rerun. It is
# set only after every install succeeds (check_call raises otherwise).
if not os.environ.get("UNSAM_DEPS_INSTALLED"):
    install_base_requirements()
    _pip("install", "--upgrade", "pip")
    _pip("install", "torch")
    os.environ["UNSAM_DEPS_INSTALLED"] = "1"
import streamlit as st
import cv2
import numpy as np
import torch
from PIL import Image
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.engine import DefaultPredictor
from mask2former import add_maskformer2_config
from detectron2.utils.colormap import random_color
import os

@st.cache_resource
def setup_config(
    weights_path,
    config_file="configs/maskformer2_R50_bs16_50ep.yaml",
    device="cpu",
):
    """Build a frozen detectron2 config for the Mask2Former-based model.

    Cached with ``st.cache_resource``, so each distinct argument combination
    is built once and the same object is reused afterwards.

    Args:
        weights_path: checkpoint file stored in ``cfg.MODEL.WEIGHTS``.
        config_file: YAML config merged on top of the DeepLab/Mask2Former
            defaults.
        device: value for ``cfg.MODEL.DEVICE``; defaults to CPU inference.

    Returns:
        The frozen config node.
    """
    cfg = get_cfg()
    # Let merge_from_file accept keys that are not in the base config.
    cfg.set_new_allowed(True)
    add_deeplab_config(cfg)
    add_maskformer2_config(cfg)
    cfg.merge_from_file(config_file)
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.DEVICE = device
    # The cached object is shared by every caller, so make it immutable.
    cfg.freeze()
    return cfg

def area(mask):
    """Return the fraction of elements of *mask* that are non-zero.

    An empty mask (``mask.size == 0``) yields ``0`` instead of dividing by zero.
    """
    n_elements = mask.size
    return np.count_nonzero(mask) / n_elements if n_elements else 0

def vis_mask(input, mask, mask_color):
    """Return a PIL image of *input* with *mask_color* blended into the mask.

    Every pixel where ``mask > 0.5`` becomes ``0.5 * pixel + 0.5 * mask_color``
    (truncated to uint8); all other pixels are unchanged. *input* itself is
    not modified, because the blending is done on a copy.
    """
    canvas = np.copy(input)
    selected = mask > 0.5
    half_tint = np.array(mask_color) * 0.5
    blended = canvas[selected] * 0.5 + half_tint
    canvas[selected] = blended.astype(np.uint8)
    return Image.fromarray(canvas)

def show_image(I, pool):
    """Overlay each mask in *pool* on image *I* with a random colour.

    Masks are painted in the order given, so later masks end up on top.
    Where a mask overlaps pixels that an earlier mask already tinted, those
    pixels are first restored from the original image. Each visible pixel is
    therefore blended with at most one colour instead of compounding tints.

    Args:
        I: the image to draw on (ndarray or PIL image), 3-channel RGB.
        pool: iterable of H x W masks; a pixel counts as foreground when its
            value is > 0.5 (the same test ``vis_mask`` uses).

    Returns:
        A PIL image. When *pool* is empty this is an unmodified copy of *I*.
    """
    original = np.asarray(I)
    canvas = original.copy()
    painted = np.zeros(original.shape[:2], dtype=bool)
    for mask in pool:
        fg = np.asarray(mask) > 0.5
        overlap = painted & fg
        if overlap.any():
            # Undo earlier tints so the new colour blends with clean pixels.
            canvas[overlap] = original[overlap]
        painted |= fg
        # np.array (not asarray) keeps the canvas writable for the next pass.
        canvas = np.array(vis_mask(canvas, mask, random_color(rgb=True)))
    return Image.fromarray(canvas)
import gdown

# (Google Drive URL, local filename) for each model checkpoint.
_CHECKPOINTS = (
    (
        "https://drive.google.com/uc?id=1sCZM5j2pQr34-scSEkgG7VmUaHJc8n4d",
        "unsam_plus_sa1b_1perc_ckpt_50k.pth",
    ),
    (
        "https://drive.google.com/uc?id=1qUdZ2ELU_5SNTsmx3Q0wSA87u4SebiO4",
        "unsam_sa1b_4perc_ckpt_200k.pth",
    ),
)

# Streamlit re-executes this whole script on every widget interaction.
# Downloading unconditionally would refetch both checkpoints each time, so
# only files that are not already on disk are fetched.
for _url, _output in _CHECKPOINTS:
    if not os.path.exists(_output):
        gdown.download(_url, _output, quiet=False)

@st.cache_resource
def _load_predictor(weights_path):
    """Build the DefaultPredictor for *weights_path* once and reuse it.

    Building the predictor loads the checkpoint, so it is cached as a shared
    resource instead of being rebuilt for every new image.
    NOTE(review): the cached predictor is shared by all sessions; confirm
    concurrent inference calls are safe for this model.
    """
    return DefaultPredictor(setup_config(weights_path))


@st.cache_data
def process_image(image, model_type, score_threshold=0.5):
    """Segment *image* with the selected UNSAM model and return an overlay.

    Args:
        image: the uploaded image. A PIL image is converted to RGB first, so
            RGBA, palette and grayscale uploads are accepted.
        model_type: ``"UNSAM+"`` selects the UNSAM+ checkpoint; any other
            value falls back to the UNSAM checkpoint.
        score_threshold: predicted instances scoring below this are dropped.

    Returns:
        A PIL image with the kept masks painted largest first, so smaller
        masks are drawn on top of larger ones.
    """
    if model_type == "UNSAM+":
        weights_path = "unsam_plus_sa1b_1perc_ckpt_50k.pth"
    else:  # UNSAM
        weights_path = "unsam_sa1b_4perc_ckpt_200k.pth"
    predictor = _load_predictor(weights_path)

    # PNG uploads may be RGBA/LA/P and JPEGs may be grayscale. The RGB->BGR
    # conversion and the 3-channel colour overlay both need plain RGB.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    rgb = np.array(image)
    instances = predictor(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))["instances"]

    masks = [
        mask.cpu().numpy()
        for score, mask in zip(instances.scores, instances.pred_masks)
        if score >= score_threshold
    ]
    # Largest first: show_image paints in order, so small masks end up on top.
    masks.sort(key=area, reverse=True)
    return show_image(rgb, masks)

st.title("UNSAM and UNSAM+ Image Segmentation")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is None:
    st.write("Please upload an image to see the segmentation results.")
else:
    image = Image.open(uploaded_file)
    original_col, *model_cols = st.columns(3)

    with original_col:
        st.header("Original Image")
        st.image(image, use_column_width=True)

    # Remaining columns, left to right: UNSAM+ output, then UNSAM output.
    panels = (
        ("UNSAM+ Output", "UNSAM+"),
        ("UNSAM Output", "UNSAM"),
    )
    for column, (heading, model_type) in zip(model_cols, panels):
        with column:
            st.header(heading)
            st.image(process_image(image, model_type), use_column_width=True)