# Matte-Anything/matte_anything.py: interactive image matting demo (Gradio app)
import os
import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image
from torchvision.ops import box_convert
from detectron2.config import LazyConfig, instantiate
from detectron2.checkpoint import DetectionCheckpointer
from segment_anything import sam_model_registry, SamPredictor
import groundingdino.datasets.transforms as T
from groundingdino.util.inference import load_model as dino_load_model, predict as dino_predict, annotate as dino_annotate
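
# Pipeline overview:
#   1. SAM segments the user-clicked object into a binary mask.
#   2. The mask is eroded/dilated into a trimap (foreground / unknown / background).
#   3. GroundingDINO detects likely transparent objects (glass, lens, ...) and
#      their boxes are pushed into the trimap's unknown region.
#   4. ViTMatte refines the trimap into a soft alpha matte for compositing.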
models = {
'vit_h': './pretrained/sam_vit_h_4b8939.pth',
'vit_b': './pretrained/sam_vit_b_01ec64.pth'
}
vitmatte_models = {
'vit_b': './pretrained/ViTMatte_B_DIS.pth',
}
vitmatte_config = {
'vit_b': './configs/matte_anything.py',
}
grounding_dino = {
'config': './GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py',
'weight': './pretrained/groundingdino_swint_ogc.pth'
}
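# All weights above are loaded from local paths and are assumed to have been
# downloaded into ./pretrained/ beforehand, following the repository layout.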
def generate_checkerboard_image(height, width, num_squares):
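    """
    Build an RGB checkerboard of the requested size with `num_squares` squares
    along the height; used as a neutral compositing background.
    """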
num_squares_h = num_squares
square_size_h = height // num_squares_h
square_size_w = square_size_h
num_squares_w = width // square_size_w
new_height = num_squares_h * square_size_h
new_width = num_squares_w * square_size_w
image = np.zeros((new_height, new_width), dtype=np.uint8)
for i in range(num_squares_h):
for j in range(num_squares_w):
start_x = j * square_size_w
start_y = i * square_size_h
color = 255 if (i + j) % 2 == 0 else 200
image[start_y:start_y + square_size_h, start_x:start_x + square_size_w] = color
image = cv2.resize(image, (width, height))
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
return image
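
# e.g. generate_checkerboard_image(480, 640, 8) yields a (480, 640, 3) uint8
# RGB checkerboard (illustrative call, matching how it is used below)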
def init_segment_anything(model_type):
"""
Initialize the segmenting anything with model_type in ['vit_b', 'vit_l', 'vit_h']
"""
sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
predictor = SamPredictor(sam)
return predictor
def init_vitmatte(model_type):
"""
    Initialize ViTMatte with model_type in ['vit_b'] (the only config registered above)
"""
cfg = LazyConfig.load(vitmatte_config[model_type])
vitmatte = instantiate(cfg.model)
vitmatte.to(device)
vitmatte.eval()
DetectionCheckpointer(vitmatte).load(vitmatte_models[model_type])
return vitmatte
def generate_trimap(mask, erode_kernel_size=10, dilate_kernel_size=10):
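    """
    Convert a binary uint8 mask (0/255) into a trimap: 255 inside the eroded
    mask (sure foreground), 128 in the dilated-minus-eroded band (unknown), and
    0 elsewhere (background). Erosion/dilation run for 5 iterations, so the
    unknown band grows with both kernel size and iteration count.
    """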
erode_kernel = np.ones((erode_kernel_size, erode_kernel_size), np.uint8)
dilate_kernel = np.ones((dilate_kernel_size, dilate_kernel_size), np.uint8)
eroded = cv2.erode(mask, erode_kernel, iterations=5)
dilated = cv2.dilate(mask, dilate_kernel, iterations=5)
trimap = np.zeros_like(mask)
trimap[dilated==255] = 128
trimap[eroded==255] = 255
return trimap
# when the user clicks the image, record the point and draw all points on the image
def get_point(img, sel_pix, point_type, evt: gr.SelectData):
if point_type == 'foreground_point':
sel_pix.append((evt.index, 1)) # append the foreground_point
elif point_type == 'background_point':
sel_pix.append((evt.index, 0)) # append the background_point
else:
sel_pix.append((evt.index, 1)) # default foreground_point
# draw points
for point, label in sel_pix:
cv2.drawMarker(img, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
    # crude channel-order heuristic kept from the demo: if the top-left pixel's
    # first and third channels match, assume BGR and convert to RGB
    if img[..., 0][0, 0] == img[..., 2][0, 0]:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img if isinstance(img, np.ndarray) else np.array(img)
# undo the most recently selected point
def undo_points(orig_img, sel_pix):
temp = orig_img.copy()
# draw points
if len(sel_pix) != 0:
sel_pix.pop()
for point, label in sel_pix:
cv2.drawMarker(temp, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
    if temp[..., 0][0, 0] == temp[..., 2][0, 0]:  # same BGR -> RGB heuristic as in get_point
        temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
return temp if isinstance(temp, np.ndarray) else np.array(temp)
# once the user uploads an image, the original image is stored in `original_image`
def store_img(img):
return img, [] # when new image is uploaded, `selected_points` should be empty
def convert_pixels(gray_image, boxes):
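    """
    Inside each detected box, demote sure-foreground trimap pixels (value 1)
    to unknown (0.5) so ViTMatte re-estimates alpha for transparent regions.
    """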
converted_image = np.copy(gray_image)
for box in boxes:
x1, y1, x2, y2 = box
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
converted_image[y1:y2, x1:x2][converted_image[y1:y2, x1:x2] == 1] = 0.5
return converted_image
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam_model = 'vit_h'
vitmatte_model = 'vit_b'
    colors = [(255, 0, 0), (0, 255, 0)]  # red for background points, green for foreground
    markers = [1, 5]  # cv2.MARKER_TILTED_CROSS for background, cv2.MARKER_TRIANGLE_UP for foreground
print('Initializing models... Please wait...')
predictor = init_segment_anything(sam_model)
vitmatte = init_vitmatte(vitmatte_model)
grounding_dino = dino_load_model(grounding_dino['config'], grounding_dino['weight'])
    def run_inference(input_x, selected_points, erode_kernel_size, dilate_kernel_size):
        predictor.set_image(input_x)
        if len(selected_points) != 0:
            points = torch.Tensor([p for p, _ in selected_points]).to(device).unsqueeze(1)
            labels = torch.Tensor([int(l) for _, l in selected_points]).to(device).unsqueeze(1)
            transformed_points = predictor.transform.apply_coords_torch(points, input_x.shape[:2])
        else:
            transformed_points, labels = None, None
        # predict segmentation from the point prompts (guard against the
        # no-points case, where SAM receives no prompts at all)
        masks, scores, logits = predictor.predict_torch(
            point_coords=transformed_points.permute(1, 0, 2) if transformed_points is not None else None,
            point_labels=labels.permute(1, 0) if labels is not None else None,
            boxes=None,
            multimask_output=False,
        )
masks = masks.cpu().detach().numpy()
mask_all = np.ones((input_x.shape[0], input_x.shape[1], 3))
        for ann in masks:
            color_mask = np.random.random((1, 3)).tolist()[0]
            for i in range(3):
                mask_all[ann[0], i] = color_mask[i]
        img = input_x / 255 * 0.3 + mask_all * 0.7  # blended overlay (kept for debugging, not returned)
# generate alpha matte
torch.cuda.empty_cache()
mask = masks[0][0].astype(np.uint8)*255
trimap = generate_trimap(mask, erode_kernel_size, dilate_kernel_size).astype(np.float32)
        trimap[trimap==128] = 0.5  # unknown band
        trimap[trimap==255] = 1    # sure foreground
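        # standard GroundingDINO eval preprocessing: resize the short side to 800
        # (capped at 1333) and apply ImageNet mean/std normalization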
dino_transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
image_transformed, _ = dino_transform(Image.fromarray(input_x), None)
boxes, logits, phrases = dino_predict(
model=grounding_dino,
image=image_transformed,
caption="glass, lens, crystal, diamond, bubble, bulb, web, grid",
box_threshold=0.5,
text_threshold=0.25,
)
annotated_frame = dino_annotate(image_source=input_x, boxes=boxes, logits=logits, phrases=phrases)
        # convert annotated_frame to RGB
        annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
        if boxes.shape[0] > 0:
            # transparent object(s) detected: rescale GroundingDINO's normalized
            # cxcywh boxes to pixel xyxy and mark those regions as unknown
            h, w, _ = input_x.shape
            boxes = boxes * torch.Tensor([w, h, w, h])
            xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
            trimap = convert_pixels(trimap, xyxy)
        # ViTMatte expects a dict with normalized 'image' and 'trimap' tensors
        vitmatte_input = {
            "image": torch.from_numpy(input_x).permute(2, 0, 1).unsqueeze(0)/255,
            "trimap": torch.from_numpy(trimap).unsqueeze(0).unsqueeze(0),
        }
        torch.cuda.empty_cache()
        alpha = vitmatte(vitmatte_input)['phas'].flatten(0,2)
alpha = alpha.detach().cpu().numpy()
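        # alpha is a float matte in [0, 1] with the same height/width as the input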
        # get a checkerboard background for compositing
background = generate_checkerboard_image(input_x.shape[0], input_x.shape[1], 8)
# calculate foreground with alpha blending
foreground_alpha = input_x * np.expand_dims(alpha, axis=2).repeat(3,2)/255 + background * (1 - np.expand_dims(alpha, axis=2).repeat(3,2))/255
# calculate foreground with mask
foreground_mask = input_x * np.expand_dims(mask/255, axis=2).repeat(3,2)/255 + background * (1 - np.expand_dims(mask/255, axis=2).repeat(3,2))/255
foreground_alpha[foreground_alpha>1] = 1
foreground_mask[foreground_mask>1] = 1
# return img, mask_all
        trimap[trimap==1] = 0.999
# new background
background_1 = cv2.imread('figs/sea.jpg')
background_2 = cv2.imread('figs/forest.jpg')
background_3 = cv2.imread('figs/sunny.jpg')
background_1 = cv2.resize(background_1, (input_x.shape[1], input_x.shape[0]))
background_2 = cv2.resize(background_2, (input_x.shape[1], input_x.shape[0]))
background_3 = cv2.resize(background_3, (input_x.shape[1], input_x.shape[0]))
# to RGB
background_1 = cv2.cvtColor(background_1, cv2.COLOR_BGR2RGB)
background_2 = cv2.cvtColor(background_2, cv2.COLOR_BGR2RGB)
background_3 = cv2.cvtColor(background_3, cv2.COLOR_BGR2RGB)
        # use alpha blending to composite onto each new background
        alpha_3c = np.expand_dims(alpha, axis=2).repeat(3, 2)
        new_bg_1 = input_x * alpha_3c / 255 + background_1 * (1 - alpha_3c) / 255
        new_bg_2 = input_x * alpha_3c / 255 + background_2 * (1 - alpha_3c) / 255
        new_bg_3 = input_x * alpha_3c / 255 + background_3 * (1 - alpha_3c) / 255
return mask, alpha, foreground_mask, foreground_alpha, new_bg_1, new_bg_2, new_bg_3
with gr.Blocks() as demo:
gr.Markdown(
"""
            # <center>Matte Anything 🐒</center>
"""
)
        with gr.Row(equal_height=True):
with gr.Column():
# input image
original_image = gr.State(value=None) # store original image without points, default None
input_image = gr.Image(type="numpy")
# point prompt
with gr.Column():
selected_points = gr.State([]) # store points
with gr.Row():
undo_button = gr.Button('Remove Points')
radio = gr.Radio(['foreground_point', 'background_point'], label='point labels')
# run button
button = gr.Button("Start!")
                erode_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="erode_kernel_size")
                dilate_kernel_size = gr.Slider(minimum=1, maximum=30, step=1, value=10, label="dilate_kernel_size")
# show the image with mask
with gr.Tab(label='SAM Mask'):
mask = gr.Image(type='numpy')
# with gr.Tab(label='Trimap'):
# trimap = gr.Image(type='numpy')
with gr.Tab(label='Alpha Matte'):
alpha = gr.Image(type='numpy')
# show only mask
with gr.Tab(label='Foreground by SAM Mask'):
foreground_by_sam_mask = gr.Image(type='numpy')
with gr.Tab(label='Refined by ViTMatte'):
refined_by_vitmatte = gr.Image(type='numpy')
# with gr.Tab(label='Transparency Detection'):
# transparency = gr.Image(type='numpy')
with gr.Tab(label='New Background 1'):
new_bg_1 = gr.Image(type='numpy')
with gr.Tab(label='New Background 2'):
new_bg_2 = gr.Image(type='numpy')
with gr.Tab(label='New Background 3'):
new_bg_3 = gr.Image(type='numpy')
input_image.upload(
store_img,
[input_image],
[original_image, selected_points]
)
input_image.select(
get_point,
[input_image, selected_points, radio],
[input_image],
)
undo_button.click(
undo_points,
[original_image, selected_points],
[input_image]
)
        button.click(
            run_inference,
            inputs=[original_image, selected_points, erode_kernel_size, dilate_kernel_size],
            outputs=[mask, alpha, foreground_by_sam_mask, refined_by_vitmatte,
                     new_bg_1, new_bg_2, new_bg_3],
        )
with gr.Row():
with gr.Column():
background_image = gr.State(value=None)
demo.launch(share=True)