import os
import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image
from torchvision.ops import box_convert
from detectron2.config import LazyConfig, instantiate
from detectron2.checkpoint import DetectionCheckpointer
from segment_anything import sam_model_registry, SamPredictor
import groundingdino.datasets.transforms as T
from groundingdino.util.inference import load_model as dino_load_model, predict as dino_predict, annotate as dino_annotate

# SAM checkpoints, keyed by model type.
models = {
    'vit_h': './pretrained/sam_vit_h_4b8939.pth',
    'vit_b': './pretrained/sam_vit_b_01ec64.pth'
}

# ViTMatte checkpoints and their LazyConfig files, keyed by model type.
vitmatte_models = {
    'vit_b': './pretrained/ViTMatte_B_DIS.pth',
}

vitmatte_config = {
    'vit_b': './configs/matte_anything.py',
}

# GroundingDINO config / weights used to detect transparent objects.
grounding_dino = {
    'config': './GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py',
    'weight': './pretrained/groundingdino_swint_ogc.pth'
}

# Runtime settings read by the helpers below. They live at module level so the
# helpers also work when this file is imported instead of run as a script.
device = 'cuda'
# Marker colour (RGB) and cv2 marker type id, both indexed by point label
# (0 = background point, 1 = foreground point).
colors = [(255, 0, 0), (0, 255, 0)]
markers = [1, 5]

# Text prompt for GroundingDINO; detected regions get an "unknown" trimap value.
TRANSPARENT_CAPTION = "glass, lens, crystal, diamond, bubble, bulb, web, grid"
# Replacement backgrounds shown in the "New Background" tabs.
NEW_BACKGROUNDS = ('figs/sea.jpg', 'figs/forest.jpg', 'figs/sunny.jpg')

# GroundingDINO preprocessing pipeline; built once instead of on every request.
_DINO_TRANSFORM = T.Compose(
    [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def generate_checkerboard_image(height, width, num_squares):
    """Return an RGB uint8 checkerboard image of shape (height, width, 3).

    The board has ``num_squares`` rows of square cells. Cell size is derived
    from ``height`` and the column count from ``width``. The board is then
    resized to exactly (height, width). Cells alternate between grey levels
    255 and 200, starting with 255 in the top-left corner.
    """
    # Clamp so very small images still get at least one 1-pixel cell and one
    # column (the unclamped version divided by zero / resized an empty array).
    square = max(1, height // num_squares)
    num_cols = max(1, width // square)
    rows, cols = np.indices((num_squares, num_cols))
    cells = np.where((rows + cols) % 2 == 0, 255, 200).astype(np.uint8)
    board = np.repeat(np.repeat(cells, square, axis=0), square, axis=1)
    board = cv2.resize(board, (width, height))
    return cv2.cvtColor(board, cv2.COLOR_GRAY2RGB)


def init_segment_anything(model_type):
    """Build a ``SamPredictor`` for ``model_type`` on ``device``.

    ``model_type`` must be a key of ``models`` (currently 'vit_h' or 'vit_b').
    """
    sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
    return SamPredictor(sam)


def init_vitmatte(model_type):
    """Instantiate ViTMatte from its LazyConfig, move it to ``device`` in eval mode and load weights.

    ``model_type`` must be a key of both ``vitmatte_config`` and ``vitmatte_models``
    (currently only 'vit_b').
    """
    cfg = LazyConfig.load(vitmatte_config[model_type])
    vitmatte = instantiate(cfg.model)
    vitmatte.to(device)
    vitmatte.eval()
    DetectionCheckpointer(vitmatte).load(vitmatte_models[model_type])
    return vitmatte


def generate_trimap(mask, erode_kernel_size=10, dilate_kernel_size=10):
    """Derive a trimap from a binary uint8 mask (255 = object).

    Returns an array shaped like ``mask``:
    - 255 where the mask survives 5 erosions (definite foreground),
    - 128 inside the 5x dilated mask but outside the eroded one (unknown band),
    - 0 elsewhere (definite background).
    Kernel sizes are cast to int because they may arrive from a UI slider.
    """
    erode_kernel = np.ones((int(erode_kernel_size), int(erode_kernel_size)), np.uint8)
    dilate_kernel = np.ones((int(dilate_kernel_size), int(dilate_kernel_size)), np.uint8)
    eroded = cv2.erode(mask, erode_kernel, iterations=5)
    dilated = cv2.dilate(mask, dilate_kernel, iterations=5)
    trimap = np.zeros_like(mask)
    trimap[dilated == 255] = 128
    trimap[eroded == 255] = 255
    return trimap


def _draw_points(img, sel_pix):
    """Draw a marker for every ``(point, label)`` in ``sel_pix`` onto ``img`` in place; return ``img``."""
    for point, label in sel_pix:
        cv2.drawMarker(img, tuple(int(v) for v in point), colors[label],
                       markerType=markers[label], markerSize=20, thickness=5)
    return img


# user click the image to get points, and show the points on the image
def get_point(img, sel_pix, point_type, evt: gr.SelectData):
    """Record the clicked pixel in ``sel_pix`` and redraw all markers on ``img``.

    ``point_type`` 'background_point' yields label 0; anything else (including
    no radio selection) yields label 1 (foreground). The image coming from
    ``gr.Image(type="numpy")`` is already RGB, so no channel swap is done.
    """
    label = 0 if point_type == 'background_point' else 1
    sel_pix.append((evt.index, label))
    return _draw_points(img, sel_pix)


# undo the selected point
def undo_points(orig_img, sel_pix):
    """Remove the most recent point (if any) and redraw the rest on a copy of the original image.

    Returns None when no image has been uploaded yet.
    """
    if orig_img is None:
        return None
    if sel_pix:
        sel_pix.pop()
    return _draw_points(orig_img.copy(), sel_pix)


# once user upload an image, the original image is stored in `original_image`
def store_img(img):
    """Keep the freshly uploaded image and reset the point list."""
    return img, []  # when new image is uploaded, `selected_points` should be empty


def convert_pixels(gray_image, boxes):
    """Return a copy of ``gray_image`` where value 1 inside each xyxy box becomes 0.5.

    Used to mark definite-foreground trimap pixels inside detected transparent
    regions as "unknown". Box corners are clamped at 0 because normalized
    cxcywh boxes can extend past the image edge, and a negative start index
    would wrap the slice around the array.
    """
    converted_image = np.copy(gray_image)
    for x1, y1, x2, y2 in boxes:
        x1, y1 = max(int(x1), 0), max(int(y1), 0)
        x2, y2 = int(x2), int(y2)
        region = converted_image[y1:y2, x1:x2]  # view; writes land in converted_image
        region[region == 1] = 0.5
    return converted_image


def _composite(foreground, background, alpha):
    """Alpha-blend two uint8 RGB images of the same size.

    ``alpha`` is an (H, W) weight for ``foreground``. Returns a float RGB image
    scaled to [0, 1] and clipped to that range.
    """
    weight = np.asarray(alpha)[..., None]
    blended = (foreground * weight + background * (1 - weight)) / 255
    return np.clip(blended, 0, 1)


def _load_background(path, height, width):
    """Read ``path`` and resize it to (height, width); return it as RGB. Raise if it cannot be read."""
    background = cv2.imread(path)
    if background is None:
        raise FileNotFoundError(f"Background image not found or unreadable: {path}")
    background = cv2.resize(background, (width, height))
    return cv2.cvtColor(background, cv2.COLOR_BGR2RGB)


if __name__ == "__main__":
    sam_model = 'vit_h'
    vitmatte_model = 'vit_b'

    print('Initializing models... Please wait...')

    predictor = init_segment_anything(sam_model)
    vitmatte = init_vitmatte(vitmatte_model)
    grounding_dino_model = dino_load_model(grounding_dino['config'], grounding_dino['weight'])

    def run_inference(input_x, selected_points, erode_kernel_size, dilate_kernel_size):
        """Segment with SAM from clicked points, build a trimap, refine it with ViTMatte and render previews.

        Returns ``(mask, alpha, foreground_mask, foreground_alpha, new_bg_1, new_bg_2, new_bg_3)``:
        - ``mask`` is the uint8 SAM mask (0/255);
        - ``alpha`` is the ViTMatte alpha matte;
        - the remaining values are float RGB composites in [0, 1].
        Raises ``gr.Error`` when no image is uploaded or no point was clicked.
        """
        if input_x is None:
            raise gr.Error('Please upload an image first.')
        if not selected_points:
            raise gr.Error('Please click at least one point on the image.')

        height, width = input_x.shape[:2]

        # SAM: all clicked points form a single prompt of shape (1, N, 2) / (1, N).
        predictor.set_image(input_x)
        points = torch.tensor([[p for p, _ in selected_points]], dtype=torch.float32, device=device)
        labels = torch.tensor([[int(l) for _, l in selected_points]], dtype=torch.float32, device=device)
        transformed_points = predictor.transform.apply_coords_torch(points, (height, width))
        masks, _, _ = predictor.predict_torch(
            point_coords=transformed_points,
            point_labels=labels,
            boxes=None,
            multimask_output=False,
        )
        mask = masks[0, 0].cpu().numpy().astype(np.uint8) * 255
        torch.cuda.empty_cache()

        # Trimap in ViTMatte's convention: 0 background, 0.5 unknown, 1 foreground.
        trimap = generate_trimap(mask, erode_kernel_size, dilate_kernel_size).astype(np.float32)
        trimap[trimap == 128] = 0.5
        trimap[trimap == 255] = 1

        # Transparent objects: definite-foreground pixels inside GroundingDINO
        # detections are downgraded to "unknown" so ViTMatte estimates them.
        image_transformed, _ = _DINO_TRANSFORM(Image.fromarray(input_x), None)
        boxes, _, _ = dino_predict(
            model=grounding_dino_model,
            image=image_transformed,
            caption=TRANSPARENT_CAPTION,
            box_threshold=0.5,
            text_threshold=0.25,
        )
        if boxes.shape[0] > 0:
            boxes = boxes * torch.Tensor([width, height, width, height])
            xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
            trimap = convert_pixels(trimap, xyxy)

        # ViTMatte alpha matte; no autograd graph is needed for inference.
        matte_input = {
            "image": torch.from_numpy(input_x).permute(2, 0, 1).unsqueeze(0) / 255,
            "trimap": torch.from_numpy(trimap).unsqueeze(0).unsqueeze(0),
        }
        torch.cuda.empty_cache()
        with torch.no_grad():
            # NOTE(review): assumes 'phas' is (1, 1, H, W) so this yields (H, W) — confirm.
            alpha = vitmatte(matte_input)['phas'].flatten(0, 2)
        alpha = alpha.cpu().numpy()

        # Previews over a checkerboard (transparency grid) and new backgrounds.
        checkerboard = generate_checkerboard_image(height, width, 8)
        foreground_alpha = _composite(input_x, checkerboard, alpha)
        foreground_mask = _composite(input_x, checkerboard, mask / 255)
        new_bgs = [_composite(input_x, _load_background(path, height, width), alpha)
                   for path in NEW_BACKGROUNDS]

        return (mask, alpha, foreground_mask, foreground_alpha, *new_bgs)

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Matte Anything🐒 !
            """
        )
        with gr.Row().style(equal_height=True):
            with gr.Column():
                # input image
                original_image = gr.State(value=None)  # original image without drawn points, default None
                input_image = gr.Image(type="numpy")
                # point prompt
                with gr.Column():
                    selected_points = gr.State([])  # clicked (point, label) pairs
                    with gr.Row():
                        undo_button = gr.Button('Remove Points')
                    radio = gr.Radio(['foreground_point', 'background_point'], label='point labels')
                # run button
                button = gr.Button("Start!")
                erode_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="erode_kernel_size")
                dilate_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="dilate_kernel_size")
            # show the image with mask
            with gr.Tab(label='SAM Mask'):
                mask = gr.Image(type='numpy')
            with gr.Tab(label='Alpha Matte'):
                alpha = gr.Image(type='numpy')
            # show only mask
            with gr.Tab(label='Foreground by SAM Mask'):
                foreground_by_sam_mask = gr.Image(type='numpy')
            with gr.Tab(label='Refined by ViTMatte'):
                refined_by_vitmatte = gr.Image(type='numpy')
            with gr.Tab(label='New Background 1'):
                new_bg_1 = gr.Image(type='numpy')
            with gr.Tab(label='New Background 2'):
                new_bg_2 = gr.Image(type='numpy')
            with gr.Tab(label='New Background 3'):
                new_bg_3 = gr.Image(type='numpy')

        input_image.upload(
            store_img,
            [input_image],
            [original_image, selected_points]
        )
        input_image.select(
            get_point,
            [input_image, selected_points, radio],
            [input_image],
        )
        undo_button.click(
            undo_points,
            [original_image, selected_points],
            [input_image]
        )
        button.click(
            run_inference,
            inputs=[original_image, selected_points, erode_kernel_size, dilate_kernel_size],
            outputs=[mask, alpha, foreground_by_sam_mask, refined_by_vitmatte,
                     new_bg_1, new_bg_2, new_bg_3],
        )

    demo.launch(share=True)