# SyncDreamer / foreground_segment.py
# Author: liuyuan-pal · commit 8bb8404 ("init") · 1.53 kB
import cv2
import argparse
import numpy as np
import torch
from PIL import Image
class BackgroundRemoval:
    """Foreground extraction backed by carvekit's high-level ``HiInterface``.

    An instance is called on an RGB image array and returns the array form of
    the image carvekit produces. Code in this file treats that result as an
    ``[H, W, 4]`` RGBA array.
    """

    def __init__(self, device='cuda'):
        """Build the carvekit segmentation + matting pipeline on ``device``."""
        # Imported here rather than at module level, so importing this module
        # does not require carvekit; it is only needed once an instance is built.
        from carvekit.api.high import HiInterface

        hi_kwargs = dict(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            # NOTE(review): half precision is requested unconditionally;
            # confirm carvekit handles this when device is 'cpu'.
            fp16=True,
        )
        self.interface = HiInterface(device=device, **hi_kwargs)

    @torch.no_grad()
    def __call__(self, image):
        """Remove the background from ``image``.

        Args:
            image: [H, W, 3] RGB array with values in [0, 255]; presumably uint8,
                since it is handed straight to ``PIL.Image.fromarray``.

        Returns:
            A numpy array converted from the single image carvekit returns
            for this input.
        """
        pil_in = Image.fromarray(image)
        # carvekit works on batches: send a one-element batch, keep its only result.
        pil_out = self.interface([pil_in])[0]
        return np.array(pil_out)
def _read_rgb(image_path):
    """Load ``image_path`` with OpenCV and return it as an [H, W, 3] uint8 RGB array.

    Raises:
        FileNotFoundError: if OpenCV cannot read/decode the file.
    """
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread signals a missing/undecodable file by returning None, not raising.
        raise FileNotFoundError(f"could not read image: {image_path}")
    if image.dtype == np.uint16:
        # IMREAD_UNCHANGED keeps 16-bit depth; PIL's fromarray (used downstream)
        # cannot build an RGB image from uint16 data, so rescale to 8 bits.
        image = (image // 257).astype(np.uint8)
    if image.ndim == 2 or image.shape[2] == 1:
        code = cv2.COLOR_GRAY2RGB
    elif image.shape[2] == 4:
        # Any existing alpha channel is discarded; the mask is re-estimated.
        code = cv2.COLOR_BGRA2RGB
    else:
        code = cv2.COLOR_BGR2RGB
    return cv2.cvtColor(image, code)


def process(image_path, mask_path, mask_predictor=None):
    """Segment the foreground of ``image_path`` and write it as RGBA to ``mask_path``.

    Args:
        image_path: input image path. Grayscale, BGR and BGRA images (8- or
            16-bit) are accepted; an input alpha channel is ignored.
        mask_path: output path. Use a format that stores alpha (e.g. ``.png``)
            so the predicted mask is kept.
        mask_predictor: optional callable with the ``BackgroundRemoval``
            interface (RGB array in, RGBA array out). Pass one to reuse a loaded
            model across many images; by default a new ``BackgroundRemoval()``
            is constructed for this call.

    Raises:
        FileNotFoundError: if the input image cannot be read.
        OSError: if OpenCV reports that writing the output failed.
    """
    # Decode first so a bad input path fails before the model is loaded.
    image = _read_rgb(image_path)
    if mask_predictor is None:
        mask_predictor = BackgroundRemoval()
    rgba = mask_predictor(image)  # [H, W, 4]
    # OpenCV writes channels in BGRA order.
    if not cv2.imwrite(mask_path, cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA)):
        raise OSError(f"failed to write image: {mask_path}")
if __name__ == '__main__':
    # Command-line entry point: segment one image and save the RGBA result.
    parser = argparse.ArgumentParser(
        description='Remove the background of an image and save the result with an alpha channel.')
    parser.add_argument('--input', required=True, type=str,
                        help='path of the input image')
    parser.add_argument('--output', required=True, type=str,
                        help='path of the output image; use a format that keeps alpha, e.g. .png')
    opt = parser.parse_args()
    process(opt.input, opt.output)