# Spaces: Runtime error
import argparse
import warnings

import cv2
import numpy as np
import torch
from PIL import Image

class BackgroundRemoval:
    """Background remover built on carvekit's ``HiInterface``.

    The carvekit pipeline is constructed once in ``__init__``.
    Calling the instance on an RGB array returns the array carvekit
    produces for that image; ``process`` in this file treats it as an
    ``[H, W, 4]`` RGBA cut-out.
    """

    def __init__(self, device='cuda', *, fp16=None):
        """Build the carvekit pipeline.

        Args:
            device: Torch device string used for inference. If a CUDA
                device is requested but ``torch.cuda.is_available()`` is
                False, inference falls back to ``'cpu'`` with a
                ``RuntimeWarning`` instead of failing at model load.
            fp16: Half-precision inference. ``None`` (default) enables it
                only when the resolved device is CUDA. carvekit's README
                notes fp16 is not supported on CPU. Pass True/False to
                force a value.
        """
        # Imported inside __init__, so carvekit is only required once an
        # instance is actually created.
        from carvekit.api.high import HiInterface

        device, fp16 = self._resolve_device(device, fp16)
        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device=device,
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=fp16,
        )

    @staticmethod
    def _resolve_device(device, fp16):
        """Return the ``(device, fp16)`` pair to pass to ``HiInterface``.

        A CUDA request on a machine without CUDA is downgraded to
        ``'cpu'``. ``fp16=None`` becomes True on CUDA and False elsewhere;
        an explicit value is kept.
        """
        device = str(device)
        if device.startswith('cuda') and not torch.cuda.is_available():
            warnings.warn(
                f"CUDA device {device!r} requested but CUDA is not available; "
                "falling back to CPU.",
                RuntimeWarning,
                stacklevel=3,  # point at the code that constructed the instance
            )
            device = 'cpu'
        if fp16 is None:
            fp16 = device.startswith('cuda')
        return device, bool(fp16)

    def __call__(self, image):
        """Remove the background from a single image.

        Args:
            image: ``[H, W, 3]`` RGB array with values in ``[0, 255]``.
                Presumably it must be ``uint8`` so ``Image.fromarray``
                builds an RGB image (TODO: confirm for other dtypes).

        Returns:
            ``np.ndarray`` of the first (only) image carvekit returns.
            ``process`` in this file assumes it is ``[H, W, 4]`` RGBA.
        """
        result = self.interface([Image.fromarray(image)])[0]
        return np.array(result)


def _load_rgb(image_path):
    """Read *image_path* with OpenCV and return an ``[H, W, 3]`` RGB array.

    Grayscale and BGRA inputs are converted to RGB; any alpha channel is
    dropped. 16-bit images are rescaled to ``uint8``.

    Raises:
        OSError: if ``cv2.imread`` cannot read the file. It returns
            ``None`` rather than raising for a missing, unreadable or
            unsupported file.
    """
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        raise OSError(f"cv2.imread could not read image {image_path!r}")
    # IMREAD_UNCHANGED preserves 16-bit depth, but PIL's fromarray cannot
    # build an RGB image from uint16 data; map 0..65535 onto 0..255.
    if image.dtype == np.uint16:
        image = (image // 257).astype(np.uint8)
    # Grayscale files come back 2-D, so shape[-1] would be the width,
    # not a channel count.
    if image.ndim == 2 or image.shape[2] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def process(image_path, mask_path):
    """Remove the background of *image_path* and save the result to *mask_path*.

    Despite its name, *mask_path* receives the full four-channel array
    returned by ``BackgroundRemoval`` (converted to OpenCV's BGRA order),
    not a single-channel mask.

    Raises:
        OSError: if the input cannot be read, or ``cv2.imwrite`` reports
            failure by returning False.
    """
    # Read the input first so a bad path fails before the slow model load.
    image = _load_rgb(image_path)
    mask_predictor = BackgroundRemoval()
    rgba = mask_predictor(image)  # [H, W, 4]
    if not cv2.imwrite(mask_path, cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA)):
        raise OSError(
            f"cv2.imwrite could not write {mask_path!r} "
            "(check that the directory exists and is writable)"
        )


if __name__ == '__main__':
    # CLI: both the source image and the destination path are mandatory.
    cli = argparse.ArgumentParser()
    for flag in ('--input', '--output'):
        cli.add_argument(flag, required=True, type=str)
    args = cli.parse_args()
    process(args.input, args.output)