Spaces:
Build error
Build error
import os
import subprocess
import sys

# Streamlit re-executes this script from the top on every user interaction.
# Environment variables live for the whole server process, so this marker makes
# the (slow) pip bootstrap run once per process instead of once per rerun.
_DEPS_MARKER = "UNSAM_APP_DEPS_INSTALLED"


def _pip_install(*args):
    """Run ``pip install <args>`` with the current interpreter.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    subprocess.check_call([sys.executable, "-m", "pip", "install", *args])


def install_base_requirements():
    """Install the packages listed in ``requirements_base.txt``.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    _pip_install("-r", "requirements_base.txt")


# The original used shell/notebook syntax here (``pip install ...`` and
# ``!pip install torch``), which is not valid Python and fails at import time.
# NOTE(review): on Hugging Face Spaces these dependencies would normally belong in
# requirements.txt; runtime installation is kept to preserve existing behavior.
if not os.environ.get(_DEPS_MARKER):
    _pip_install("--upgrade", "pip")
    # torch goes first: presumably requirements_base.txt contains packages
    # (e.g. detectron2) that need torch at build time — TODO confirm.
    _pip_install("torch")
    install_base_requirements()
    os.environ[_DEPS_MARKER] = "1"
import streamlit as st | |
import cv2 | |
import numpy as np | |
import torch | |
from PIL import Image | |
from detectron2.config import get_cfg | |
from detectron2.projects.deeplab import add_deeplab_config | |
from detectron2.engine import DefaultPredictor | |
from mask2former import add_maskformer2_config | |
from detectron2.utils.colormap import random_color | |
import os | |
def setup_config(
    weights_path,
    *,
    config_file="configs/maskformer2_R50_bs16_50ep.yaml",
    device="cpu",
):
    """Build a frozen detectron2 config for Mask2Former-based inference.

    Args:
        weights_path: Path of the checkpoint the predictor will load.
        config_file: YAML model config merged on top of the detectron2,
            DeepLab and Mask2Former defaults.
        device: Device used for inference (``"cpu"`` by default).

    Returns:
        A frozen detectron2 config node.
    """
    cfg = get_cfg()
    # Permit keys that are not part of the default schema when merging the YAML.
    cfg.set_new_allowed(True)
    add_deeplab_config(cfg)
    add_maskformer2_config(cfg)
    cfg.merge_from_file(config_file)
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.DEVICE = device
    # Freeze so later code cannot accidentally mutate the config.
    cfg.freeze()
    return cfg
def area(mask):
    """Return the fraction of *mask* elements that are non-zero.

    An empty mask (``size == 0``) yields 0 rather than dividing by zero.
    """
    total = mask.size
    return 0 if total == 0 else np.count_nonzero(mask) / total
def vis_mask(input, mask, mask_color):
    """Blend *mask_color* at 50% opacity over the pixels where *mask* > 0.5.

    Args:
        input: Image data accepted by ``np.copy``; it is copied, not modified.
        mask: Per-pixel mask; values above 0.5 count as foreground.
        mask_color: Colour broadcastable against the selected pixel values.

    Returns:
        A new ``PIL.Image`` containing the blended result.
    """
    canvas = np.copy(input)
    selected = mask > 0.5
    blended = 0.5 * canvas[selected] + 0.5 * np.asarray(mask_color)
    canvas[selected] = blended.astype(np.uint8)
    return Image.fromarray(canvas)
def show_image(I, pool):
    """Paint each mask in *pool* onto image *I*, each in a random colour.

    Masks are painted in iteration order, so later masks end up on top. Where a
    mask covers a pixel an earlier mask already coloured, that pixel is first
    restored from *I*, so colours are replaced rather than stacked.

    Args:
        I: Source image array, presumably (H, W, C) uint8 — TODO confirm.
        pool: Iterable of per-pixel masks matching I's height and width.

    Returns:
        The painted image (a ``PIL.Image``), or a copy of *I* if *pool* is empty.
    """
    painted_count = np.zeros(np.array(I).shape[:2])
    canvas = I.copy()
    for mask in pool:
        painted_count += mask.astype(np.uint8)
        clash = painted_count == 2
        if clash.any():
            # Take the original pixels where masks clash, current canvas elsewhere.
            from_source = clash[:, :, np.newaxis] * np.copy(I)
            from_canvas = np.logical_not(clash)[:, :, np.newaxis] * np.copy(canvas)
            canvas = Image.fromarray(from_source + from_canvas)
            painted_count -= clash
        canvas = vis_mask(canvas, mask, random_color(rgb=True))
    return canvas
import gdown

# (Google Drive URL, local filename) for every checkpoint process_image may load.
_CHECKPOINTS = (
    (
        "https://drive.google.com/uc?id=1sCZM5j2pQr34-scSEkgG7VmUaHJc8n4d",
        "unsam_plus_sa1b_1perc_ckpt_50k.pth",
    ),
    (
        "https://drive.google.com/uc?id=1qUdZ2ELU_5SNTsmx3Q0wSA87u4SebiO4",
        "unsam_sa1b_4perc_ckpt_200k.pth",
    ),
)


def download_checkpoints(checkpoints=_CHECKPOINTS):
    """Download each ``(url, path)`` checkpoint unless a non-empty file exists.

    Streamlit re-runs this script on every interaction; without the existence
    check, both large checkpoints were re-downloaded on every rerun.
    """
    for url, path in checkpoints:
        if os.path.isfile(path) and os.path.getsize(path) > 0:
            continue
        gdown.download(url, path, quiet=False)


download_checkpoints()
@st.cache_resource
def _load_predictor(weights_path):
    """Build a ``DefaultPredictor`` for *weights_path*, cached across reruns.

    NOTE(review): the cached predictor is shared by all sessions/threads of the
    server; confirm concurrent inference on it is safe for this deployment.
    """
    return DefaultPredictor(setup_config(weights_path))


def process_image(image, model_type, score_threshold=0.5):
    """Segment *image* with the selected UNSAM variant and return an overlay.

    Args:
        image: PIL image to segment (any mode; converted to RGB first).
        model_type: ``"UNSAM+"`` selects the UNSAM+ checkpoint; any other value
            selects the UNSAM checkpoint.
        score_threshold: Instances scoring below this value are discarded.

    Returns:
        The result of ``show_image``: a ``PIL.Image`` with masks painted on,
        or the plain RGB array when no instance passes the threshold.
    """
    if model_type == "UNSAM+":
        weights_path = "unsam_plus_sa1b_1perc_ckpt_50k.pth"
    else:  # UNSAM
        weights_path = "unsam_sa1b_4perc_ckpt_200k.pth"
    # Loading weights is expensive; reuse one predictor per checkpoint.
    predictor = _load_predictor(weights_path)

    # RGBA / palette / greyscale uploads would break the RGB->BGR conversion and
    # the 3-channel colour blend, so normalise to 3-channel RGB first.
    rgb = np.array(image.convert("RGB"))
    outputs = predictor(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))["instances"]

    masks = [
        mask.cpu().numpy()
        for score, mask in zip(outputs.scores, outputs.pred_masks)
        if score >= score_threshold
    ]
    # Largest masks are painted first so smaller ones stay visible on top.
    masks.sort(key=area, reverse=True)
    return show_image(rgb, masks)
# --- Streamlit page: upload an image and compare both model variants ---------
st.title("UNSAM and UNSAM+ Image Segmentation")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is None:
    st.write("Please upload an image to see the segmentation results.")
else:
    image = Image.open(uploaded_file)
    original_col, plus_col, base_col = st.columns(3)

    with original_col:
        st.header("Original Image")
        st.image(image, use_column_width=True)

    # One column per model variant, rendered left to right.
    for column, heading, variant in (
        (plus_col, "UNSAM+ Output", "UNSAM+"),
        (base_col, "UNSAM Output", "UNSAM"),
    ):
        with column:
            st.header(heading)
            st.image(process_image(image, variant), use_column_width=True)