Spaces:
Build error
Build error
import streamlit as st | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import torch | |
from detectron2.engine import DefaultPredictor | |
from detectron2.config import get_cfg | |
from detectron2.projects.deeplab import add_deeplab_config | |
from detectron2.utils.colormap import random_color | |
from mask2former import add_maskformer2_config | |
from tqdm import tqdm | |
def setup_predictor(config_file, weights_path, device='cpu'):
    """Create a detectron2 ``DefaultPredictor`` for a Mask2Former/UnSAM model.

    Args:
        config_file: Path to the model's YAML config file.
        weights_path: Checkpoint path stored in ``MODEL.WEIGHTS``.
        device: Value stored in ``MODEL.DEVICE`` (defaults to ``'cpu'``).

    Returns:
        A ``DefaultPredictor`` built from the merged config.
    """
    cfg_node = get_cfg()
    # Let merge_from_file accept keys that are not in the base schema.
    cfg_node.set_new_allowed(True)
    # Extend the base config with the DeepLab and Mask2Former option sets
    # before the YAML is merged on top.
    for extend_config in (add_deeplab_config, add_maskformer2_config):
        extend_config(cfg_node)
    cfg_node.merge_from_file(config_file)
    cfg_node.MODEL.WEIGHTS = weights_path
    cfg_node.MODEL.DEVICE = device
    return DefaultPredictor(cfg_node)
def area(mask):
    """Return the fraction of elements of *mask* that are nonzero.

    An empty mask (``size == 0``) yields ``0`` instead of dividing by zero.
    """
    total = mask.size
    return np.count_nonzero(mask) / total if total else 0
def vis_mask(input, mask, mask_color):
    """Blend *mask_color* 50/50 into the foreground pixels of *input*.

    Args:
        input: Image as an array-like (numpy array or PIL Image) of shape
            (H, W, C). It is copied and never modified.
        mask: (H, W) array; pixels with value > 0.5 are treated as foreground.
        mask_color: Color with one entry per channel of *input*.

    Returns:
        A new PIL Image containing the blended overlay.
    """
    canvas = np.copy(input)
    selected = mask > 0.5
    blended = 0.5 * canvas[selected] + 0.5 * np.array(mask_color)
    canvas[selected] = blended.astype(np.uint8)
    return Image.fromarray(canvas)
def show_image(I, pool):
    """Overlay every mask in *pool* on image *I*, each with a random color.

    Masks are drawn in iteration order (the caller sorts them by area, largest
    first), so later masks end up on top. Where a mask overlaps pixels that an
    earlier mask already painted, those pixels are first reset to the
    original image. This way the later mask's color is blended with the raw
    pixels rather than stacked on a previous blend.

    Args:
        I: Original image, an (H, W, C) uint8 array or anything ``np.asarray``
            accepts (e.g. a PIL Image). It is not modified.
        pool: Iterable of (H, W) masks. Pixels > 0.5 count as foreground,
            the same threshold ``vis_mask`` uses.

    Returns:
        A PIL Image with all masks drawn. It is always a PIL Image, even when
        *pool* is empty.
    """
    original = np.asarray(I)
    canvas = np.array(original, copy=True)
    painted = np.zeros(original.shape[:2], dtype=bool)
    for mask in tqdm(pool):
        # Use vis_mask's threshold so overlap tracking and painting agree,
        # including for 0/255 or soft-valued masks.
        fg = np.asarray(mask) > 0.5
        overlap = painted & fg
        if overlap.any():
            canvas[overlap] = original[overlap]
        painted |= fg
        # np.array (not asarray) so the canvas stays writable for the next
        # iteration's overlap reset.
        canvas = np.array(vis_mask(canvas, mask, random_color(rgb=True)))
    return Image.fromarray(canvas)
# Load UnSAM and UnSAM+ predictors.
# Streamlit re-executes this whole script on every widget interaction. The
# predictors are therefore built through a cached factory, so each checkpoint
# is loaded once per server process instead of on every rerun.
_UNSAM_CONFIG = "/kaggle/working/UnSAM/whole_image_segmentation/configs/maskformer2_R50_bs16_50ep.yaml"


@st.cache_resource
def _load_predictor(config_file, weights_path):
    """Build (once) and return the predictor for *config_file* / *weights_path*."""
    return setup_predictor(config_file, weights_path)


unsam_predictor = _load_predictor(
    _UNSAM_CONFIG,
    "/kaggle/working/Mask2Former/unsam_sa1b_4perc_ckpt_200k.pth",
)
unsam_plus_predictor = _load_predictor(
    _UNSAM_CONFIG,
    "/kaggle/working/Mask2Former/unsam_plus_sa1b_1perc_ckpt_50k.pth",
)
def _segment(predictor, image_rgb):
    """Run *predictor* on an RGB image and render its masks over that image.

    detectron2's ``DefaultPredictor`` takes its input in BGR channel order
    and converts it to the model's ``INPUT.FORMAT`` itself. The RGB array
    decoded by PIL is therefore channel-flipped before inference. Masks are
    sorted by area, largest first, so smaller masks are drawn on top.
    """
    image_bgr = np.ascontiguousarray(image_rgb[:, :, ::-1])
    instances = predictor(image_bgr)['instances']
    masks = list(instances.pred_masks.cpu().numpy())
    masks.sort(key=area, reverse=True)
    return show_image(image_rgb, masks)


st.title("Image Segmentation with UnSAM and UnSAM+")
# Upload image
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    # Decode to 3-channel RGB. A PNG may be RGBA, palette ("P") or grayscale,
    # which would otherwise reach the predictor and the overlay code with the
    # wrong number of channels.
    image = np.array(Image.open(uploaded_file).convert("RGB"))
    # Display the original image
    st.image(image, caption='Original Image', use_column_width=True)
    # Run UnSAM+ first, then UnSAM (same order as before, so random colors
    # are drawn in the same sequence).
    unsam_plus_image = _segment(unsam_plus_predictor, image)
    unsam_image = _segment(unsam_predictor, image)
    # Display the images side by side
    st.image(
        [image, unsam_plus_image, unsam_image],
        caption=['Original Image', 'UnSAM+ Output', 'UnSAM Output'],
        use_column_width=True,
    )