import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image
from torchvision.ops import box_convert
from detectron2.config import LazyConfig, instantiate
from detectron2.checkpoint import DetectionCheckpointer
from segment_anything import sam_model_registry, SamPredictor
import groundingdino.datasets.transforms as T
from groundingdino.util.inference import (
    load_model as dino_load_model,
    predict as dino_predict,
    annotate as dino_annotate,
)
models = {
    'vit_h': './pretrained/sam_vit_h_4b8939.pth',
    'vit_b': './pretrained/sam_vit_b_01ec64.pth',
}

vitmatte_models = {
    'vit_b': './pretrained/ViTMatte_B_DIS.pth',
}

vitmatte_config = {
    'vit_b': './configs/matte_anything.py',
}

grounding_dino = {
    'config': './GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py',
    'weight': './pretrained/groundingdino_swint_ogc.pth',
}
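# Note: the SAM, ViTMatte, and GroundingDINO checkpoints above are assumed to be
# downloaded into ./pretrained beforehand; the filenames follow each project's
# official release naming.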
def generate_checkerboard_image(height, width, num_squares):
    """Generate an RGB checkerboard of the given size with `num_squares` rows of squares."""
    num_squares_h = num_squares
    square_size_h = height // num_squares_h
    square_size_w = square_size_h
    num_squares_w = width // square_size_w
    new_height = num_squares_h * square_size_h
    new_width = num_squares_w * square_size_w
    image = np.zeros((new_height, new_width), dtype=np.uint8)
    for i in range(num_squares_h):
        for j in range(num_squares_w):
            start_x = j * square_size_w
            start_y = i * square_size_h
            color = 255 if (i + j) % 2 == 0 else 200
            image[start_y:start_y + square_size_h, start_x:start_x + square_size_w] = color
    # resize back to the exact requested size and expand grayscale to RGB
    image = cv2.resize(image, (width, height))
    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    return image
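# Example usage (hypothetical shapes): an 8-row checkerboard backdrop for a
# 512x768 image, as used for compositing in run_inference below.
#   bg = generate_checkerboard_image(512, 768, 8)  # -> (512, 768, 3) uint8 RGB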
def init_segment_anything(model_type):
    """
    Initialize Segment Anything (SAM) with model_type in ['vit_b', 'vit_l', 'vit_h']
    """
    sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
    predictor = SamPredictor(sam)
    return predictor
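# A minimal sketch of how the predictor is driven below, assuming the standard
# segment_anything point-prompt API (point prompts only, no boxes):
#   predictor.set_image(rgb_array)                     # HxWx3 uint8
#   masks, scores, logits = predictor.predict_torch(
#       point_coords=coords, point_labels=labels,      # batched torch tensors
#       boxes=None, multimask_output=False)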
def init_vitmatte(model_type):
    """
    Initialize ViTMatte with model_type in ['vit_s', 'vit_b']
    """
    cfg = LazyConfig.load(vitmatte_config[model_type])
    vitmatte = instantiate(cfg.model)
    vitmatte.to(device)
    vitmatte.eval()
    DetectionCheckpointer(vitmatte).load(vitmatte_models[model_type])
    return vitmatte
def generate_trimap(mask, erode_kernel_size=10, dilate_kernel_size=10):
    """Turn a binary mask (0/255) into a trimap: 255 = foreground, 128 = unknown, 0 = background."""
    erode_kernel = np.ones((erode_kernel_size, erode_kernel_size), np.uint8)
    dilate_kernel = np.ones((dilate_kernel_size, dilate_kernel_size), np.uint8)
    eroded = cv2.erode(mask, erode_kernel, iterations=5)
    dilated = cv2.dilate(mask, dilate_kernel, iterations=5)
    trimap = np.zeros_like(mask)
    trimap[dilated == 255] = 128  # everything the dilated mask covers starts as unknown
    trimap[eroded == 255] = 255   # the eroded core is certain foreground
    return trimap
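# Toy example of the trimap bands on a 1-D mask (values assumed 0/255):
#   mask:   0  0  255 255 255  0  0
#   erosion shrinks the 255 run, dilation grows it, so the trimap reads
#   0 ... 128 (unknown ring) ... 255 (sure foreground) ... 128 ... 0.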
# user clicks the image to add points; draw the points on the image
def get_point(img, sel_pix, point_type, evt: gr.SelectData):
    if point_type == 'foreground_point':
        sel_pix.append((evt.index, 1))  # append a foreground point
    elif point_type == 'background_point':
        sel_pix.append((evt.index, 0))  # append a background point
    else:
        sel_pix.append((evt.index, 1))  # default to a foreground point
    # draw points
    for point, label in sel_pix:
        cv2.drawMarker(img, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
    # crude single-pixel heuristic kept from the original code:
    # if B and R agree at pixel (0, 0), assume the array is BGR and swap to RGB
    if img[..., 0][0, 0] == img[..., 2][0, 0]:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img if isinstance(img, np.ndarray) else np.array(img)
# undo the most recently selected point
def undo_points(orig_img, sel_pix):
    temp = orig_img.copy()
    # drop the last point and redraw the rest
    if len(sel_pix) != 0:
        sel_pix.pop()
        for point, label in sel_pix:
            cv2.drawMarker(temp, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
    # same BGR heuristic as in get_point
    if temp[..., 0][0, 0] == temp[..., 2][0, 0]:
        temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
    return temp if isinstance(temp, np.ndarray) else np.array(temp)
# once the user uploads an image, the original image is stored in `original_image`
def store_img(img):
    return img, []  # when a new image is uploaded, `selected_points` should be emptied
def convert_pixels(gray_image, boxes):
    """Inside each detected box, demote sure-foreground trimap pixels (1.0) to unknown (0.5)."""
    converted_image = np.copy(gray_image)
    for box in boxes:
        x1, y1, x2, y2 = map(int, box)
        converted_image[y1:y2, x1:x2][converted_image[y1:y2, x1:x2] == 1] = 0.5
    return converted_image
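# Example (hypothetical values): with a normalized trimap and one xyxy box,
#   trimap = np.array([[1.0, 1.0], [0.0, 1.0]], dtype=np.float32)
#   convert_pixels(trimap, [(0, 0, 2, 1)])  # -> [[0.5, 0.5], [0.0, 1.0]]
# i.e. transparent regions found by GroundingDINO are re-marked as unknown so
# ViTMatte can estimate their true opacity.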
if __name__ == "__main__":
    # fall back to CPU when CUDA is unavailable
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sam_model = 'vit_h'
    vitmatte_model = 'vit_b'
    colors = [(255, 0, 0), (0, 255, 0)]  # red = background, green = foreground
    markers = [1, 5]  # cv2 marker types: 1 = tilted cross, 5 = triangle up
    print('Initializing models... Please wait...')

    predictor = init_segment_anything(sam_model)
    vitmatte = init_vitmatte(vitmatte_model)
    grounding_dino = dino_load_model(grounding_dino['config'], grounding_dino['weight'])
    def run_inference(input_x, selected_points, erode_kernel_size, dilate_kernel_size):
        predictor.set_image(input_x)
        if len(selected_points) != 0:
            points = torch.Tensor([p for p, _ in selected_points]).to(device).unsqueeze(1)
            labels = torch.Tensor([int(l) for _, l in selected_points]).to(device).unsqueeze(1)
            transformed_points = predictor.transform.apply_coords_torch(points, input_x.shape[:2])
            point_coords = transformed_points.permute(1, 0, 2)
            point_labels = labels.permute(1, 0)
        else:
            # no points selected: pass no prompts instead of calling .permute on None
            point_coords, point_labels = None, None

        # predict segmentation according to the point prompts
        masks, scores, logits = predictor.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=None,
            multimask_output=False,
        )
        masks = masks.cpu().detach().numpy()
        mask_all = np.ones((input_x.shape[0], input_x.shape[1], 3))
        for ann in masks:
            color_mask = np.random.random((1, 3)).tolist()[0]
            for i in range(3):
                mask_all[ann[0], i] = color_mask[i]
        # visualization of the SAM masks (kept for the commented-out return below)
        img = input_x / 255 * 0.3 + mask_all * 0.7

        # generate the alpha matte
        torch.cuda.empty_cache()
        mask = masks[0][0].astype(np.uint8) * 255
        trimap = generate_trimap(mask, erode_kernel_size, dilate_kernel_size).astype(np.float32)
        # normalize the trimap to {0, 0.5, 1} as expected by ViTMatte
        trimap[trimap == 128] = 0.5
        trimap[trimap == 255] = 1
        # detect transparent objects with GroundingDINO
        dino_transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])
        image_transformed, _ = dino_transform(Image.fromarray(input_x), None)

        boxes, logits, phrases = dino_predict(
            model=grounding_dino,
            image=image_transformed,
            caption="glass, lens, crystal, diamond, bubble, bulb, web, grid",
            box_threshold=0.5,
            text_threshold=0.25,
        )
        annotated_frame = dino_annotate(image_source=input_x, boxes=boxes, logits=logits, phrases=phrases)
        # convert annotated_frame to RGB
        annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
        if boxes.shape[0] > 0:
            # transparent objects detected: mark their boxes as unknown in the trimap
            h, w, _ = input_x.shape
            boxes = boxes * torch.Tensor([w, h, w, h])
            xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
            trimap = convert_pixels(trimap, xyxy)

        # ViTMatte expects a dict with a normalized image and a {0, 0.5, 1} trimap
        # (renamed from `input` to avoid shadowing the builtin)
        vitmatte_input = {
            "image": torch.from_numpy(input_x).permute(2, 0, 1).unsqueeze(0) / 255,
            "trimap": torch.from_numpy(trimap).unsqueeze(0).unsqueeze(0),
        }
        torch.cuda.empty_cache()
        alpha = vitmatte(vitmatte_input)['phas'].flatten(0, 2)
        alpha = alpha.detach().cpu().numpy()

        # composite onto a checkerboard background
        background = generate_checkerboard_image(input_x.shape[0], input_x.shape[1], 8)

        # expand alpha/mask to 3 channels once, then alpha-blend
        alpha_3 = np.expand_dims(alpha, axis=2).repeat(3, 2)
        mask_3 = np.expand_dims(mask / 255, axis=2).repeat(3, 2)
        foreground_alpha = (input_x * alpha_3 + background * (1 - alpha_3)) / 255
        foreground_mask = (input_x * mask_3 + background * (1 - mask_3)) / 255
        foreground_alpha[foreground_alpha > 1] = 1
        foreground_mask[foreground_mask > 1] = 1
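        # The blend above is the standard matting equation, I = alpha * F + (1 - alpha) * B,
        # with the original pixels standing in for the foreground F (no separate foreground
        # estimation is performed here) and the result scaled into [0, 1] for display.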
        # return img, mask_all
        trimap[trimap == 1] = 0.999  # was `==` (a no-op comparison); assignment intended

        # composite onto new backgrounds
        background_1 = cv2.imread('figs/sea.jpg')
        background_2 = cv2.imread('figs/forest.jpg')
        background_3 = cv2.imread('figs/sunny.jpg')
        background_1 = cv2.resize(background_1, (input_x.shape[1], input_x.shape[0]))
        background_2 = cv2.resize(background_2, (input_x.shape[1], input_x.shape[0]))
        background_3 = cv2.resize(background_3, (input_x.shape[1], input_x.shape[0]))
        # to RGB
        background_1 = cv2.cvtColor(background_1, cv2.COLOR_BGR2RGB)
        background_2 = cv2.cvtColor(background_2, cv2.COLOR_BGR2RGB)
        background_3 = cv2.cvtColor(background_3, cv2.COLOR_BGR2RGB)
        # alpha-blend with each new background
        new_bg_1 = (input_x * alpha_3 + background_1 * (1 - alpha_3)) / 255
        new_bg_2 = (input_x * alpha_3 + background_2 * (1 - alpha_3)) / 255
        new_bg_3 = (input_x * alpha_3 + background_3 * (1 - alpha_3)) / 255

        return mask, alpha, foreground_mask, foreground_alpha, new_bg_1, new_bg_2, new_bg_3
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # <center>Matte Anything🐒!
            """
        )
        with gr.Row().style(equal_height=True):
            with gr.Column():
                # input image
                original_image = gr.State(value=None)  # store the original image without points, default None
                input_image = gr.Image(type="numpy")
                # point prompt
                with gr.Column():
                    selected_points = gr.State([])  # store points
                    with gr.Row():
                        undo_button = gr.Button('Remove Points')
                        radio = gr.Radio(['foreground_point', 'background_point'], label='point labels')
                # run button
                button = gr.Button("Start!")
                # gr.inputs.Slider is the deprecated Gradio 2 API; use gr.Slider with value=
                erode_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="erode_kernel_size")
                dilate_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="dilate_kernel_size")
            # show the image with mask
            with gr.Tab(label='SAM Mask'):
                mask = gr.Image(type='numpy')
            # with gr.Tab(label='Trimap'):
            #     trimap = gr.Image(type='numpy')
            with gr.Tab(label='Alpha Matte'):
                alpha = gr.Image(type='numpy')
            # show only the mask
            with gr.Tab(label='Foreground by SAM Mask'):
                foreground_by_sam_mask = gr.Image(type='numpy')
            with gr.Tab(label='Refined by ViTMatte'):
                refined_by_vitmatte = gr.Image(type='numpy')
            # with gr.Tab(label='Transparency Detection'):
            #     transparency = gr.Image(type='numpy')
            with gr.Tab(label='New Background 1'):
                new_bg_1 = gr.Image(type='numpy')
            with gr.Tab(label='New Background 2'):
                new_bg_2 = gr.Image(type='numpy')
            with gr.Tab(label='New Background 3'):
                new_bg_3 = gr.Image(type='numpy')

        input_image.upload(
            store_img,
            [input_image],
            [original_image, selected_points]
        )
        input_image.select(
            get_point,
            [input_image, selected_points, radio],
            [input_image],
        )
        undo_button.click(
            undo_points,
            [original_image, selected_points],
            [input_image]
        )
        button.click(
            run_inference,
            inputs=[original_image, selected_points, erode_kernel_size, dilate_kernel_size],
            outputs=[mask, alpha, foreground_by_sam_mask, refined_by_vitmatte, new_bg_1, new_bg_2, new_bg_3]
        )
        with gr.Row():
            with gr.Column():
                background_image = gr.State(value=None)  # unused placeholder state

    demo.launch(share=True)