"""Remove the background from an image with CarveKit and save the RGBA result.

Usage:
    python <this_script> --input photo.jpg --output photo_rgba.png
"""
import argparse

import cv2
import numpy as np
import torch
from PIL import Image


class BackgroundRemoval:
    """Wrapper around CarveKit's ``HiInterface`` segmentation + matting pipeline."""

    def __init__(self, device='cuda', fp16=None):
        """Build the CarveKit pipeline.

        Args:
            device: Torch device string handed to CarveKit (e.g. ``'cuda'``, ``'cpu'``).
            fp16: Run in half precision. ``None`` (default) enables it only for
                CUDA devices, because CarveKit does not support fp16 on CPU.
                Passing ``True``/``False`` overrides the automatic choice.
        """
        # Imported here rather than at module level, so CarveKit is only
        # needed once the class is actually instantiated.
        from carvekit.api.high import HiInterface

        if fp16 is None:
            fp16 = str(device).startswith('cuda')

        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device=device,
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=fp16,
        )

    @torch.no_grad()
    def __call__(self, image):
        """Run background removal on a single image.

        Args:
            image: ``[H, W, 3]`` RGB array with values in ``[0, 255]``.

        Returns:
            ``np.ndarray`` converted from the PIL image returned by CarveKit.
            ``process`` treats it as ``[H, W, 4]`` RGBA.
        """
        pil_image = Image.fromarray(image)
        result = self.interface([pil_image])[0]
        return np.array(result)


def _load_rgb(image_path):
    """Read ``image_path`` from disk as an ``[H, W, 3]`` uint8 RGB array.

    Grayscale, BGR and BGRA inputs are accepted (any alpha channel is
    discarded), and 16-bit images are rescaled to 8 bits.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
        ValueError: if the image has an unsupported number of channels.
    """
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    if image.dtype == np.uint16:
        # IMREAD_UNCHANGED keeps 16-bit depth; the model expects [0, 255].
        image = np.round(image / 257.0).astype(np.uint8)

    if image.ndim == 2:
        # Single-channel grayscale images come back without a channel axis.
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    conversions = {
        1: cv2.COLOR_GRAY2RGB,
        3: cv2.COLOR_BGR2RGB,
        4: cv2.COLOR_BGRA2RGB,
    }
    code = conversions.get(image.shape[2])
    if code is None:
        raise ValueError(
            f"Unsupported channel count {image.shape[2]} in image: {image_path}"
        )
    return cv2.cvtColor(image, code)


def process(image_path, mask_path):
    """Remove the background of ``image_path`` and write the RGBA result to ``mask_path``.

    The output keeps an alpha channel, so ``mask_path`` should use a format
    that supports transparency (e.g. ``.png``).

    Raises:
        FileNotFoundError: if the input image cannot be read.
        ValueError: if the input image has an unsupported channel count.
        OSError: if OpenCV fails to write ``mask_path``.
    """
    # Read the input first so a bad path fails before the model is loaded.
    image = _load_rgb(image_path)
    mask_predictor = BackgroundRemoval()
    rgba = mask_predictor(image)  # [H, W, 4]
    # cv2.imwrite reports some failures (e.g. a missing directory) only via
    # its boolean return value.
    if not cv2.imwrite(mask_path, cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA)):
        raise OSError(f"Could not write image: {mask_path}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Remove the background from an image and save it as RGBA."
    )
    parser.add_argument('--input', required=True, type=str,
                        help="path to the input image")
    parser.add_argument('--output', required=True, type=str,
                        help="path for the RGBA output image (e.g. .png)")
    opt = parser.parse_args()
    process(opt.input, opt.output)