ESD / app.py
Satyajithchary's picture
Update app.py
581e4ed verified
raw
history blame
3.49 kB
import sys
import subprocess
def install_base_requirements(requirements_file="requirements_base.txt"):
    """Install the packages listed in *requirements_file* with pip.

    pip is invoked as ``<current interpreter> -m pip`` so the packages are
    installed into the same environment that is running this script.

    Args:
        requirements_file: Path of the pip requirements file. Defaults to
            ``requirements_base.txt``, the file previously hard-coded here.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-r", requirements_file]
    )
install_base_requirements()
# `pip install ...` and `!pip install ...` are shell / notebook syntax and make
# this .py file fail to parse. Run pip through the current interpreter instead,
# in the same order as before.
subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "pip"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
import streamlit as st
import cv2
import numpy as np
import torch
from PIL import Image
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.engine import DefaultPredictor
from mask2former import add_maskformer2_config
from detectron2.utils.colormap import random_color
import os
@st.cache_resource
def setup_config(weights_path):
    """Build a frozen detectron2 config for the Mask2Former-based model.

    Decorated with ``st.cache_resource``, so Streamlit reuses the result for
    repeated calls with the same ``weights_path``.

    Args:
        weights_path: Checkpoint file assigned to ``MODEL.WEIGHTS``.

    Returns:
        The frozen config, set up for CPU inference.
    """
    config = get_cfg()
    # Permit keys that are not in the base schema when merging below.
    config.set_new_allowed(True)
    for extend_config in (add_deeplab_config, add_maskformer2_config):
        extend_config(config)
    config.merge_from_file("configs/maskformer2_R50_bs16_50ep.yaml")
    model_section = config.MODEL
    model_section.WEIGHTS = weights_path
    model_section.DEVICE = "cpu"  # inference runs on CPU only
    config.freeze()
    return config
def area(mask):
    """Return the fraction of elements of *mask* that are non-zero.

    An empty mask yields 0 rather than raising a division-by-zero error.
    """
    total = mask.size
    return np.count_nonzero(mask) / total if total else 0
def vis_mask(input, mask, mask_color):
    """Blend *mask_color* 50/50 into the pixels of *input* selected by *mask*.

    Args:
        input: Image data accepted by ``np.copy`` (an ndarray or a PIL image).
            It is copied, never modified in place.
        mask: Per-pixel mask; pixels whose value is > 0.5 are tinted.
        mask_color: Colour blended into the selected pixels; it must broadcast
            against them (e.g. one value per channel).

    Returns:
        A ``PIL.Image`` built from the blended uint8 pixels.
    """
    canvas = np.copy(input)
    selected = mask > 0.5
    tint = np.array(mask_color) * 0.5
    canvas[selected] = (canvas[selected] * 0.5 + tint).astype(np.uint8)
    return Image.fromarray(canvas)
def show_image(I, pool):
    """Paint each mask in *pool* over image *I* with a random colour.

    Masks are painted in iteration order. Before a mask is painted, any
    pixels it shares with earlier masks are restored from the untouched *I*.
    As a result, the later mask's tint replaces the earlier one instead of
    stacking on top of it.

    Args:
        I: Source image; its first two dimensions give the mask height/width.
            It is not modified.
        pool: Iterable of masks with the same height/width as *I*.

    Returns:
        The painted image: a ``PIL.Image`` once any mask has been painted,
        otherwise ``I.copy()``.
    """
    # Per-pixel count of masks currently painted (kept at 0 or 1 between masks).
    painted = np.zeros(np.array(I).shape[:2])
    canvas = I.copy()
    for mask in pool:
        painted += mask.astype(np.uint8)
        overlap = painted == 2
        if overlap.any():
            # Take the original pixels where masks overlap, the current canvas elsewhere.
            restore_here = overlap[:, :, np.newaxis]
            canvas = Image.fromarray(np.where(restore_here, np.copy(I), np.copy(canvas)))
            painted -= overlap
        canvas = vis_mask(canvas, mask, random_color(rgb=True))
    return canvas
import gdown

# Checkpoint files loaded by process_image, mapped to their Google Drive URLs.
_CHECKPOINTS = {
    "unsam_plus_sa1b_1perc_ckpt_50k.pth": "https://drive.google.com/uc?id=1sCZM5j2pQr34-scSEkgG7VmUaHJc8n4d",
    "unsam_sa1b_4perc_ckpt_200k.pth": "https://drive.google.com/uc?id=1qUdZ2ELU_5SNTsmx3Q0wSA87u4SebiO4",
}

# Streamlit re-executes this whole script on every user interaction, so only
# download a checkpoint when it is not already on disk.
for _filename, _url in _CHECKPOINTS.items():
    if not os.path.exists(_filename):
        gdown.download(_url, _filename, quiet=False)
@st.cache_resource
def _load_predictor(weights_path):
    """Build the detectron2 predictor for *weights_path* once and reuse it.

    Building a predictor loads the checkpoint, which is expensive. Caching it
    per weights file avoids doing that again for every new image.
    """
    return DefaultPredictor(setup_config(weights_path))


@st.cache_data
def process_image(image, model_type, score_threshold=0.5):
    """Segment *image* with the selected model and return a colour overlay.

    Args:
        image: Image to segment. PIL images are converted to 3-channel RGB
            first; other inputs are passed to ``np.array`` unchanged.
        model_type: ``"UNSAM+"`` selects the UNSAM+ checkpoint; any other
            value selects the UNSAM checkpoint.
        score_threshold: Predicted instances scoring below this are dropped
            (defaults to the previously hard-coded 0.5).

    Returns:
        The overlay produced by ``show_image``.
    """
    if model_type == "UNSAM+":
        weights_path = "unsam_plus_sa1b_1perc_ckpt_50k.pth"
    else:  # UNSAM
        weights_path = "unsam_sa1b_4perc_ckpt_200k.pth"
    predictor = _load_predictor(weights_path)

    # Grayscale or palette images give a 2-D array that the RGB->BGR
    # conversion rejects. An alpha channel would not match the colour
    # blending in vis_mask. So normalise PIL inputs to RGB once.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    rgb = np.array(image)
    instances = predictor(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))["instances"]

    masks = [
        mask.cpu().numpy()
        for score, mask in zip(instances.scores, instances.pred_masks)
        if score >= score_threshold
    ]
    # Paint larger masks first, so smaller masks painted later win where they overlap.
    masks.sort(key=area, reverse=True)
    return show_image(rgb, masks)
st.title("UNSAM and UNSAM+ Image Segmentation")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Grayscale, palette, or RGBA uploads (common for PNG) break the RGB->BGR
    # conversion and the colour blending downstream, so normalise to RGB here.
    image = Image.open(uploaded_file).convert("RGB")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.header("Original Image")
        st.image(image, use_column_width=True)
    with col2:
        st.header("UNSAM+ Output")
        unsam_plus_output = process_image(image, "UNSAM+")
        st.image(unsam_plus_output, use_column_width=True)
    with col3:
        st.header("UNSAM Output")
        unsam_output = process_image(image, "UNSAM")
        st.image(unsam_output, use_column_width=True)
else:
    st.write("Please upload an image to see the segmentation results.")