ESD / app.py
Satyajithchary's picture
Update app.py
7365db2 verified
raw
history blame
3.11 kB
import streamlit as st
import cv2
import numpy as np
from PIL import Image
import torch
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.colormap import random_color
from mask2former import add_maskformer2_config
from tqdm import tqdm
def setup_predictor(config_file, weights_path, device='cpu'):
    """Build a detectron2 ``DefaultPredictor`` for a Mask2Former-style model.

    Args:
        config_file: Path to the YAML config merged on top of the defaults.
        weights_path: Checkpoint path stored in ``MODEL.WEIGHTS``.
        device: Value stored in ``MODEL.DEVICE`` (defaults to ``'cpu'``).

    Returns:
        A ``DefaultPredictor`` constructed from the assembled config.
    """
    config = get_cfg()
    # Permit keys that the stock detectron2 defaults do not declare.
    config.set_new_allowed(True)
    # Register the DeepLab and Mask2Former config extensions, in that order,
    # before the YAML file is merged.
    for extend_config in (add_deeplab_config, add_maskformer2_config):
        extend_config(config)
    config.merge_from_file(config_file)
    config.MODEL.WEIGHTS = weights_path
    config.MODEL.DEVICE = device
    return DefaultPredictor(config)
def area(mask):
    """Return the fraction of non-zero elements in ``mask``.

    Args:
        mask: Array or array-like (e.g. nested lists) of any shape. Any
            non-zero / truthy element counts as covered.

    Returns:
        A float in ``[0.0, 1.0]``: the non-zero count divided by the total
        element count, or ``0.0`` when ``mask`` has no elements.
    """
    # Coerce so that plain sequences work as well as ndarrays.
    mask = np.asarray(mask)
    if mask.size == 0:
        # Guard the division below; return a float to match the normal path.
        return 0.0
    return np.count_nonzero(mask) / mask.size
def vis_mask(input, mask, mask_color):
    """Tint the masked region of an image with a colour at 50% opacity.

    Args:
        input: Image data accepted by ``np.copy`` (array or PIL image); it is
            copied, not modified.
        mask: Array indexed over the image pixels; entries greater than 0.5
            are treated as foreground.
        mask_color: Colour sequence blended into the foreground pixels.
            NOTE(review): its length must broadcast against the image's
            channel axis -- presumably 3 for RGB; confirm against callers.

    Returns:
        A new ``PIL.Image.Image`` with the foreground pixels blended.
    """
    blended = np.copy(input)
    selected = mask > 0.5
    # Half of the colour contribution is the same for every selected pixel.
    tint = np.array(mask_color) * 0.5
    blended[selected] = (blended[selected] * 0.5 + tint).astype(np.uint8)
    return Image.fromarray(blended)
def show_image(I, pool):
    """Overlay every mask in ``pool`` on image ``I`` using random colours.

    A per-pixel counter tracks how many masks have been painted. When a new
    mask pushes a pixel's count to 2, that pixel is first reset to its value
    in the original image and the counter is decremented, so the new tint is
    blended with the original pixel rather than stacked on an earlier tint.

    Args:
        I: Source image; must support ``.copy()`` and conversion through
            ``np.array`` (the caller in this file passes an ndarray).
        pool: Iterable of per-pixel masks with the same height/width as ``I``.

    Returns:
        The image with all masks drawn: a PIL image as produced by
        ``vis_mask``, or an untouched copy of ``I`` when ``pool`` is empty.
    """
    paint_count = np.zeros(np.array(I).shape[:2])
    canvas = I.copy()
    for mask in tqdm(pool):
        paint_count += mask.astype(np.uint8)
        painted_twice = paint_count == 2
        if painted_twice.any():
            # Take original pixels where two masks overlap, and the current
            # canvas everywhere else.
            use_original = painted_twice[:, :, np.newaxis]
            use_canvas = np.logical_not(painted_twice)[:, :, np.newaxis]
            restored = use_original * np.copy(I) + use_canvas * np.copy(canvas)
            canvas = Image.fromarray(restored)
            paint_count -= painted_twice
        canvas = vis_mask(canvas, mask, random_color(rgb=True))
    return canvas
# Load UnSAM and UnSAM+ predictors
#
# Streamlit re-executes this script on every user interaction, so the
# predictors are built through a cached loader: each checkpoint is read
# once per process instead of once per rerun.
@st.cache_resource
def _load_predictor(config_file, weights_path):
    """Build (once per distinct argument pair) and return a predictor."""
    return setup_predictor(config_file, weights_path)


def _segment_and_render(predictor, rgb_image):
    """Run ``predictor`` on an RGB image and return the mask overlay.

    Args:
        predictor: Callable returned by ``setup_predictor``.
        rgb_image: ``(H, W, 3)`` uint8 array in RGB channel order.

    Returns:
        The result of ``show_image`` for the predicted masks, painted from
        the largest mask to the smallest.
    """
    # DefaultPredictor documents its input as BGR, so reverse the channel
    # axis; the overlay below keeps using the RGB image.
    bgr_image = np.ascontiguousarray(rgb_image[:, :, ::-1])
    instances = predictor(bgr_image)['instances']
    masks = [mask.cpu().numpy() for mask in instances.pred_masks]
    # Largest first, so smaller masks are drawn on top of bigger ones.
    masks.sort(key=area, reverse=True)
    return show_image(rgb_image, masks)


unsam_predictor = _load_predictor(
    "/kaggle/working/UnSAM/whole_image_segmentation/configs/maskformer2_R50_bs16_50ep.yaml",
    "/kaggle/working/Mask2Former/unsam_sa1b_4perc_ckpt_200k.pth"
)
unsam_plus_predictor = _load_predictor(
    "/kaggle/working/UnSAM/whole_image_segmentation/configs/maskformer2_R50_bs16_50ep.yaml",
    "/kaggle/working/Mask2Former/unsam_plus_sa1b_1perc_ckpt_50k.pth"
)

st.title("Image Segmentation with UnSAM and UnSAM+")

# Upload image
uploaded_file = st.file_uploader("Choose an image...", type="png")

if uploaded_file is not None:
    # Read the image. PNG files may be RGBA, palette or grayscale; force
    # 3-channel RGB so the 3-component colour blend in vis_mask applies.
    image = np.array(Image.open(uploaded_file).convert("RGB"))

    # Display the original image
    st.image(image, caption='Original Image', use_column_width=True)

    # Run predictions for UnSAM+
    unsam_plus_image = _segment_and_render(unsam_plus_predictor, image)

    # Run predictions for UnSAM
    unsam_image = _segment_and_render(unsam_predictor, image)

    # Display the images side by side
    st.image([image, unsam_plus_image, unsam_image], caption=['Original Image', 'UnSAM+ Output', 'UnSAM Output'], use_column_width=True)