HivisionIDPhotos / hivisionai /hycv /matting_tools.py
TheEeeeLin's picture
update files
d5d20be verified
raw
history blame
1.27 kB
import numpy as np
from PIL import Image
import cv2
import onnxruntime
from .tensor2numpy import NNormalize, NTo_Tensor, NUnsqueeze
from .vision import image2bgr
def read_modnet_image(input_image, ref_size=512):
    """Preprocess an image into the input expected by the MODNet session.

    The image is cast to ``uint8``, passed through ``image2bgr``, resized to
    a ``ref_size`` x ``ref_size`` square, normalized with mean 0.5 / std 0.5
    per channel, and finally wrapped by ``NTo_Tensor`` and ``NUnsqueeze``.

    Args:
        input_image: Array-like image; it is cast to ``uint8`` and must be
            something ``PIL.Image.fromarray`` accepts.
        ref_size: Side length (in pixels) of the square the image is resized
            to before normalization.

    Returns:
        A tuple ``(tensor, width, length)`` where ``tensor`` is the
        preprocessed model input and ``width`` / ``length`` are the original
        image width and height as reported by PIL's ``Image.size``.
    """
    pil_image = Image.fromarray(np.uint8(input_image))
    # PIL reports size as (width, height); both are returned so the caller
    # can scale the predicted matte back to the source resolution.
    width, length = pil_image.size

    pixels = image2bgr(np.asarray(pil_image))
    # The aspect ratio is not preserved: the image is squashed to a square.
    resized = cv2.resize(
        pixels,
        (ref_size, ref_size),
        interpolation=cv2.INTER_AREA,
    )

    channel_mean = np.array([0.5, 0.5, 0.5])
    channel_std = np.array([0.5, 0.5, 0.5])
    normalized = NNormalize(resized, mean=channel_mean, std=channel_std)

    # NOTE(review): presumably NTo_Tensor reorders to channel-first and
    # NUnsqueeze adds a batch axis — confirm against tensor2numpy.
    tensor = NUnsqueeze(NTo_Tensor(normalized))
    return tensor, width, length
def get_modnet_matting(input_image, checkpoint_path="./test.onnx", ref_size=512):
    """Predict a matte with a MODNet ONNX model and attach it as a 4th channel.

    Args:
        input_image: Array-like image with exactly three channels (it is
            split into three planes below); it is cast to ``uint8`` before
            use.
        checkpoint_path: Path to the ONNX model file loaded by onnxruntime.
        ref_size: Side length of the square the image is resized to before
            inference (forwarded to ``read_modnet_image``).

    Returns:
        A four-channel ``uint8`` array: the three channels of
        ``input_image`` followed by the predicted matte, resized back to the
        input's width and height.
    """
    print("checkpoint_path:", checkpoint_path)
    # NOTE(review): a new inference session is created on every call; callers
    # processing many images may want to cache it — left unchanged here.
    sess = onnxruntime.InferenceSession(checkpoint_path)
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    im, width, height = read_modnet_image(input_image=input_image, ref_size=ref_size)
    matte = sess.run([output_name], {input_name: im})[0]

    # Clamp before the uint8 cast: a value even slightly outside [0, 255]
    # would otherwise wrap around (or be undefined) and punch holes in the
    # mask. In-range values are unaffected.
    # NOTE(review): assumes the model emits the matte in [0, 1] — confirm.
    matte = np.clip(matte * 255, 0, 255).astype('uint8')
    matte = np.squeeze(matte)

    # cv2.resize takes the target size as (width, height).
    mask = cv2.resize(matte, (width, height), interpolation=cv2.INTER_AREA)
    b, g, r = cv2.split(np.uint8(input_image))
    output_image = cv2.merge((b, g, r, mask))
    return output_image