"""Streamlit demo comparing UnSAM and UnSAM+ whole-image segmentation."""
import streamlit as st
import cv2
import numpy as np
from PIL import Image
import torch
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.colormap import random_color
from mask2former import add_maskformer2_config
from tqdm import tqdm

CONFIG_FILE = (
    "/kaggle/working/UnSAM/whole_image_segmentation/configs/"
    "maskformer2_R50_bs16_50ep.yaml"
)
UNSAM_WEIGHTS = "/kaggle/working/Mask2Former/unsam_sa1b_4perc_ckpt_200k.pth"
UNSAM_PLUS_WEIGHTS = "/kaggle/working/Mask2Former/unsam_plus_sa1b_1perc_ckpt_50k.pth"


def setup_predictor(config_file, weights_path, device='cpu'):
    """Build a detectron2 ``DefaultPredictor`` for a Mask2Former-style checkpoint.

    Args:
        config_file: Path to the YAML model config to merge.
        weights_path: Path to the checkpoint, written to ``cfg.MODEL.WEIGHTS``.
        device: Device string written to ``cfg.MODEL.DEVICE`` (default ``'cpu'``).

    Returns:
        A ``DefaultPredictor`` built from the merged config.
    """
    cfg = get_cfg()
    # Allow keys in the YAML that are not declared in the registered schema.
    cfg.set_new_allowed(True)
    add_deeplab_config(cfg)
    add_maskformer2_config(cfg)
    cfg.merge_from_file(config_file)
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.DEVICE = device
    return DefaultPredictor(cfg)


def area(mask):
    """Return the fraction of non-zero elements in ``mask`` (0 for an empty mask)."""
    if mask.size == 0:
        return 0
    return np.count_nonzero(mask) / mask.size


def vis_mask(input, mask, mask_color):
    """Blend ``mask_color`` 50/50 into the pixels of ``input`` where ``mask > 0.5``.

    Args:
        input: H x W x 3 image (numpy array or PIL image); it is copied, not modified.
        mask: H x W mask; values above 0.5 are treated as foreground.
        mask_color: 3-element colour to blend in.

    Returns:
        A new PIL image with the mask overlaid.
    """
    fg = mask > 0.5
    rgb = np.copy(input)
    rgb[fg] = (rgb[fg] * 0.5 + np.array(mask_color) * 0.5).astype(np.uint8)
    return Image.fromarray(rgb)


def show_image(I, pool):
    """Overlay every mask in ``pool`` on image ``I``, each in a random colour.

    Masks are painted in iteration order. Where a mask overlaps pixels already
    painted by an earlier mask, those pixels are first restored from ``I`` so
    the later mask's colour replaces the earlier one instead of stacking on it.

    Args:
        I: H x W x 3 uint8 image (numpy array or PIL image).
        pool: Iterable of H x W masks; values above 0.5 are foreground
            (same threshold as ``vis_mask``).

    Returns:
        A PIL image with all masks overlaid (the plain image if ``pool`` is empty).
    """
    base = np.asarray(I)
    # Per-pixel count of masks currently painted; kept at <= 1 after each step.
    painted = np.zeros(base.shape[:2], dtype=np.int32)
    canvas = Image.fromarray(np.copy(base))
    for mask in tqdm(pool):
        painted += mask > 0.5
        overlap = painted == 2
        if overlap.any():
            # Undo the earlier colour on overlapping pixels before repainting.
            restored = np.array(canvas)
            restored[overlap] = base[overlap]
            canvas = Image.fromarray(restored)
            painted[overlap] -= 1
        canvas = vis_mask(canvas, mask, random_color(rgb=True))
    return canvas


@st.cache_resource
def load_predictors():
    """Build the (UnSAM, UnSAM+) predictors, cached across Streamlit reruns.

    Streamlit re-executes this script on every widget interaction. Without
    caching, both models would be rebuilt and their checkpoints re-read on
    every rerun.
    """
    unsam = setup_predictor(CONFIG_FILE, UNSAM_WEIGHTS)
    unsam_plus = setup_predictor(CONFIG_FILE, UNSAM_PLUS_WEIGHTS)
    return unsam, unsam_plus


def _segment(predictor, image_rgb):
    """Run ``predictor`` on an RGB uint8 image and return its mask overlay (PIL)."""
    # DefaultPredictor takes images in BGR order (OpenCV convention) and flips
    # them itself when cfg.INPUT.FORMAT is "RGB".
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    instances = predictor(image_bgr)["instances"]
    masks = list(instances.pred_masks.cpu().numpy())
    # Paint largest masks first so smaller ones stay visible on top.
    masks.sort(key=area, reverse=True)
    return show_image(image_rgb, masks)


# Load UnSAM and UnSAM+ predictors (cached; built once, reused on reruns).
unsam_predictor, unsam_plus_predictor = load_predictors()

st.title("Image Segmentation with UnSAM and UnSAM+")

# Upload image
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    # Force 3-channel RGB: PNGs may be RGBA, greyscale or palette images,
    # which would break both the colour blending and the BGR conversion.
    image = np.array(Image.open(uploaded_file).convert("RGB"))

    # Display the original image
    st.image(image, caption='Original Image', use_column_width=True)

    unsam_plus_image = _segment(unsam_plus_predictor, image)
    unsam_image = _segment(unsam_predictor, image)

    # Display the images side by side
    st.image(
        [image, unsam_plus_image, unsam_image],
        caption=['Original Image', 'UnSAM+ Output', 'UnSAM Output'],
        use_column_width=True,
    )