"""Streamlit demo that compares UNSAM and UNSAM+ segmentations of an image.

The user uploads an image; the script runs both checkpoints on it and shows
the original next to each model's colour-coded mask overlay.
"""

import os

import streamlit as st
import cv2
import gdown
import numpy as np
import torch
from PIL import Image
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.engine import DefaultPredictor
from mask2former import add_maskformer2_config
from detectron2.utils.colormap import random_color

# Local checkpoint file names and the Google Drive URLs they are fetched from.
UNSAM_PLUS_WEIGHTS = "unsam_plus_sa1b_1perc_ckpt_50k.pth"
UNSAM_PLUS_URL = "https://drive.google.com/uc?id=1sCZM5j2pQr34-scSEkgG7VmUaHJc8n4d"
UNSAM_WEIGHTS = "unsam_sa1b_4perc_ckpt_200k.pth"
UNSAM_URL = "https://drive.google.com/uc?id=1qUdZ2ELU_5SNTsmx3Q0wSA87u4SebiO4"

# Predicted instances scoring below this value are not drawn.
SCORE_THRESHOLD = 0.5


@st.cache_resource
def setup_config(weights_path):
    """Build the frozen detectron2 config for one checkpoint.

    Args:
        weights_path: Path of the checkpoint stored in ``cfg.MODEL.WEIGHTS``.

    Returns:
        The frozen config, with the DeepLab and MaskFormer2 keys added and
        ``configs/maskformer2_R50_bs16_50ep.yaml`` merged in.
    """
    cfg = get_cfg()
    cfg.set_new_allowed(True)
    add_deeplab_config(cfg)
    add_maskformer2_config(cfg)
    cfg.merge_from_file("configs/maskformer2_R50_bs16_50ep.yaml")
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.DEVICE = "cpu"  # Use CPU for inference
    cfg.freeze()
    return cfg


@st.cache_resource
def _load_predictor(weights_path):
    """Create the predictor for *weights_path* once and reuse it afterwards."""
    return DefaultPredictor(setup_config(weights_path))


def _ensure_weights(url, path):
    """Download *url* to *path* unless a file already exists at *path*."""
    if not os.path.exists(path):
        gdown.download(url, path, quiet=False)


def area(mask):
    """Return the fraction of non-zero elements in *mask* (0 if it is empty)."""
    if mask.size == 0:
        return 0
    return np.count_nonzero(mask) / mask.size


def _blend_mask(rgb, mask, mask_color):
    """Return a copy of *rgb* with *mask_color* mixed 50/50 into masked pixels."""
    fg = mask > 0.5
    blended = np.copy(rgb)
    blended[fg] = (blended[fg] * 0.5 + np.array(mask_color) * 0.5).astype(np.uint8)
    return blended


def vis_mask(input, mask, mask_color):
    """Overlay one mask on an image.

    Args:
        input: Image to draw on (array or PIL image); it is not modified.
        mask: Array with the same height and width; values above 0.5 count
            as foreground.
        mask_color: Colour blended at 50% opacity into the foreground.

    Returns:
        A new ``PIL.Image.Image`` holding the blended result.
    """
    return Image.fromarray(_blend_mask(np.asarray(input), mask, mask_color))


def show_image(I, pool):
    """Draw every mask in *pool* over image *I*, each in a random colour.

    Masks are painted in the order given. Where a mask covers pixels that an
    earlier mask already coloured, those pixels are first reset to the
    original image so the later mask is blended with the photo rather than
    with the earlier colour.

    Args:
        I: Source image (array or PIL image).
        pool: Iterable of masks matching the height and width of *I*.

    Returns:
        A ``PIL.Image.Image``; a plain copy of *I* when *pool* is empty.
    """
    original = np.array(I)
    canvas = original.copy()
    painted = np.zeros(original.shape[:2], dtype=bool)
    for mask in pool:
        fg = mask > 0.5
        overlap = painted & fg
        if overlap.any():
            canvas[overlap] = original[overlap]
        painted |= fg
        canvas = _blend_mask(canvas, mask, random_color(rgb=True))
    return Image.fromarray(canvas)


# Streamlit re-executes this script on every interaction, so only fetch the
# checkpoints that are not on disk yet.
_ensure_weights(UNSAM_PLUS_URL, UNSAM_PLUS_WEIGHTS)
_ensure_weights(UNSAM_URL, UNSAM_WEIGHTS)


@st.cache_data
def process_image(image, model_type):
    """Segment *image* with the chosen model and return the mask overlay.

    Args:
        image: The uploaded picture. PIL images of any mode are converted to
            RGB before inference; other inputs are passed to ``np.array``.
        model_type: ``"UNSAM+"`` selects the UNSAM+ checkpoint; any other
            value selects the UNSAM checkpoint.

    Returns:
        A ``PIL.Image.Image`` with all instances scoring at least
        ``SCORE_THRESHOLD`` drawn over the image, largest mask first.
    """
    weights_path = UNSAM_PLUS_WEIGHTS if model_type == "UNSAM+" else UNSAM_WEIGHTS
    predictor = _load_predictor(weights_path)

    # The colour conversion below needs exactly three channels.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    rgb = np.array(image)

    instances = predictor(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))["instances"]
    masks = [
        mask.cpu().numpy()
        for score, mask in zip(instances.scores, instances.pred_masks)
        if score >= SCORE_THRESHOLD
    ]

    # Paint large masks first so smaller ones end up on top of them.
    masks.sort(key=area, reverse=True)
    return show_image(rgb, masks)


st.title("UNSAM and UNSAM+ Image Segmentation")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file)

    col1, col2, col3 = st.columns(3)

    with col1:
        st.header("Original Image")
        st.image(image, use_column_width=True)

    with col2:
        st.header("UNSAM+ Output")
        unsam_plus_output = process_image(image, "UNSAM+")
        st.image(unsam_plus_output, use_column_width=True)

    with col3:
        st.header("UNSAM Output")
        unsam_output = process_image(image, "UNSAM")
        st.image(unsam_output, use_column_width=True)
else:
    st.write("Please upload an image to see the segmentation results.")