Spaces:
Build error
Build error
File size: 3,253 Bytes
d61e631 8b0d3c7 97d2179 7365db2 97d2179 7365db2 97d2179 d61e631 97d2179 7365db2 97d2179 7365db2 97d2179 d61e631 7365db2 97d2179 7365db2 d61e631 7365db2 d61e631 7365db2 97d2179 7365db2 97d2179 d61e631 97d2179 7365db2 97d2179 7365db2 97d2179 7365db2 97d2179 7365db2 97d2179 7365db2 97d2179 d61e631 97d2179 d61e631 97d2179 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import streamlit as st
import cv2
import numpy as np
import torch
from PIL import Image
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.engine import DefaultPredictor
from mask2former import add_maskformer2_config
from detectron2.utils.colormap import random_color
import os
@st.cache_resource
def setup_config(weights_path):
    """Build a frozen Mask2Former-based config for CPU inference.

    Decorated with ``st.cache_resource`` so Streamlit reuses the result for a
    given ``weights_path`` instead of rebuilding it on every rerun.

    Args:
        weights_path: Path of the checkpoint assigned to ``MODEL.WEIGHTS``.

    Returns:
        The frozen detectron2 config node.
    """
    config = get_cfg()
    # Permit keys that the base detectron2 schema does not define.
    config.set_new_allowed(True)
    # Register the DeepLab, then Mask2Former, extension keys before the YAML merge.
    for extend_config in (add_deeplab_config, add_maskformer2_config):
        extend_config(config)
    config.merge_from_file("configs/maskformer2_R50_bs16_50ep.yaml")
    config.MODEL.WEIGHTS = weights_path
    config.MODEL.DEVICE = "cpu"  # inference runs on the CPU
    config.freeze()
    return config
def area(mask):
    """Return the fraction of elements in ``mask`` that are non-zero.

    An empty mask yields ``0`` instead of dividing by zero.
    """
    total = mask.size
    return np.count_nonzero(mask) / total if total else 0
def vis_mask(input, mask, mask_color):
    """Tint the pixels of ``input`` where ``mask`` exceeds 0.5.

    Each selected pixel becomes an equal (50/50) blend of its original value
    and ``mask_color``, truncated to ``uint8``.

    Args:
        input: The image to tint (anything ``np.copy`` accepts, e.g. an array
            or a PIL image); it is copied, not modified.
        mask: Array whose shape matches the leading (H, W) dims of ``input``.
        mask_color: Color with one component per image channel.

    Returns:
        A new ``PIL.Image`` containing the tinted copy.
    """
    canvas = np.copy(input)
    selected = mask > 0.5
    blended = 0.5 * canvas[selected] + 0.5 * np.array(mask_color)
    canvas[selected] = blended.astype(np.uint8)
    return Image.fromarray(canvas)
def show_image(I, pool):
    """Overlay every mask in ``pool`` onto image ``I``, each in a random color.

    Masks are painted in iteration order, so later masks land on top (the
    caller passes them sorted by decreasing area). When a new mask covers
    pixels that an earlier mask already tinted, those pixels are first
    restored from the original image, so colors never stack.

    Args:
        I: The original image (array or PIL image), shaped (H, W, C).
        pool: Iterable of (H, W) masks; values > 0.5 count as foreground,
            matching the test ``vis_mask`` uses to choose pixels to tint.

    Returns:
        A ``PIL.Image`` with all masks overlaid. This is a PIL image even
        when ``pool`` is empty.
    """
    base = np.array(I)  # untouched original, reused for every restore
    # Per-pixel count of masks painted there; brought back to <= 1 each step.
    already_painted = np.zeros(base.shape[:2])
    canvas = base.copy()
    for mask in pool:
        # Same foreground test as vis_mask, so the bookkeeping tracks exactly
        # the pixels that actually get tinted.
        already_painted += mask > 0.5
        overlap = already_painted == 2
        if overlap.any():
            # Revert doubly-covered pixels to the original before re-tinting.
            canvas = np.where(overlap[:, :, np.newaxis], base, np.asarray(canvas))
            already_painted -= overlap
        canvas = vis_mask(canvas, mask, random_color(rgb=True))
    if isinstance(canvas, np.ndarray):
        # Empty pool: return a PIL image, like the non-empty case does.
        canvas = Image.fromarray(canvas)
    return canvas
import gdown

# Model checkpoints fetched from Google Drive: local filename -> download URL.
_CHECKPOINTS = {
    "unsam_plus_sa1b_1perc_ckpt_50k.pth": "https://drive.google.com/uc?id=1sCZM5j2pQr34-scSEkgG7VmUaHJc8n4d",
    "unsam_sa1b_4perc_ckpt_200k.pth": "https://drive.google.com/uc?id=1qUdZ2ELU_5SNTsmx3Q0wSA87u4SebiO4",
}

for _filename, _url in _CHECKPOINTS.items():
    # Streamlit re-executes this entire script on every user interaction.
    # Skip checkpoints already on disk so they are not re-downloaded each time.
    if not os.path.exists(_filename):
        gdown.download(_url, _filename, quiet=False)
@st.cache_resource
def _load_predictor(weights_path):
    """Return a ``DefaultPredictor`` for ``weights_path``, built once and cached.

    Building the predictor constructs the model and loads the checkpoint,
    which is too expensive to repeat for every uploaded image.
    """
    return DefaultPredictor(setup_config(weights_path))


@st.cache_data
def process_image(image, model_type, score_threshold=0.5):
    """Segment ``image`` with UNSAM or UNSAM+ and return the mask overlay.

    Args:
        image: Input image (PIL image or RGB array). PIL images are first
            converted to RGB, so RGBA, grayscale and palette uploads work.
        model_type: ``"UNSAM+"`` selects the UNSAM+ checkpoint; any other
            value selects the UNSAM checkpoint.
        score_threshold: Instances whose score is below this are dropped.

    Returns:
        The image with the kept instance masks overlaid by ``show_image``.
    """
    if model_type == "UNSAM+":
        weights_path = "unsam_plus_sa1b_1perc_ckpt_50k.pth"
    else:  # UNSAM
        weights_path = "unsam_sa1b_4perc_ckpt_200k.pth"
    predictor = _load_predictor(weights_path)

    if isinstance(image, Image.Image):
        # Other modes (RGBA/L/P) break the RGB->BGR conversion below or the
        # 3-component overlay colors used when painting masks.
        image = image.convert("RGB")
    rgb = np.array(image)

    # DefaultPredictor takes its input image in BGR channel order.
    instances = predictor(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))["instances"]
    masks = [
        mask.cpu().numpy()
        for score, mask in zip(instances.scores, instances.pred_masks)
        if score >= score_threshold
    ]
    # Largest first, so smaller masks are painted on top of larger ones.
    masks.sort(key=area, reverse=True)
    return show_image(rgb, masks)
# Streamlit page: upload an image and show it alongside its UNSAM+ and UNSAM
# segmentations.
st.title("UNSAM and UNSAM+ Image Segmentation")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is None:
    st.write("Please upload an image to see the segmentation results.")
else:
    try:
        # convert() forces a full decode, so a corrupt or truncated upload
        # fails here instead of deep inside inference. Converting to RGB also
        # normalizes RGBA, grayscale and palette PNGs for the pipeline.
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # OSError covers PIL.UnidentifiedImageError and truncated files.
        st.error("Could not read the uploaded file as an image.")
    else:
        col1, col2, col3 = st.columns(3)
        with col1:
            st.header("Original Image")
            st.image(image, use_column_width=True)
        with col2:
            st.header("UNSAM+ Output")
            unsam_plus_output = process_image(image, "UNSAM+")
            st.image(unsam_plus_output, use_column_width=True)
        with col3:
            st.header("UNSAM Output")
            unsam_output = process_image(image, "UNSAM")
            st.image(unsam_output, use_column_width=True)