Spaces:
Running
Running
File size: 4,164 Bytes
ca46a75 4be6b70 ca46a75 4be6b70 ca46a75 4be6b70 ca46a75 4be6b70 ca46a75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/9/5 21:21
@File: human_matting.py
@IDE: pycharm
@Description:
人像抠图
"""
import functools
import os

import cv2
import numpy as np
import onnxruntime
from PIL import Image

from .tensor2numpy import NNormalize, NTo_Tensor, NUnsqueeze
from .context import Context
# Matting model name -> absolute path of its ONNX checkpoint inside this package's
# "weights" directory.
WEIGHTS = {
    model_name: os.path.join(os.path.dirname(__file__), "weights", file_name)
    for model_name, file_name in (
        ("hivision_modnet", "hivision_modnet.onnx"),
        (
            "modnet_photographic_portrait_matting",
            "modnet_photographic_portrait_matting.onnx",
        ),
    )
}
def _extract_human_with_weights(ctx: Context, weights_path: str) -> None:
    """
    Run MODNet person matting with the given ONNX weights and store the result on ``ctx``.

    Shared implementation of the public ``extract_human*`` entry points, which differ
    only in the checkpoint they use.

    :param ctx: context; ``ctx.processing_image`` is the input image and is replaced by
        the 4-channel (colour + alpha) matting result
    :param weights_path: path to the MODNet ONNX checkpoint
    """
    # Matting: the predicted matte is appended to the image as a fourth channel.
    matting_image = get_modnet_matting(ctx.processing_image, weights_path)
    # Fill holes the matting model leaves inside the subject.
    ctx.processing_image = hollow_out_fix(matting_image)
    # Store an independent copy of the repaired result.
    ctx.matting_image = ctx.processing_image.copy()


def extract_human(ctx: Context):
    """
    Person matting using the ``hivision_modnet`` weights.

    :param ctx: context; ``processing_image`` and ``matting_image`` are set to the result
    """
    _extract_human_with_weights(ctx, WEIGHTS["hivision_modnet"])


def extract_human_modnet_photographic_portrait_matting(ctx: Context):
    """
    Person matting using the ``modnet_photographic_portrait_matting`` weights.

    :param ctx: context; ``processing_image`` and ``matting_image`` are set to the result
    """
    _extract_human_with_weights(ctx, WEIGHTS["modnet_photographic_portrait_matting"])
def hollow_out_fix(src: np.ndarray) -> np.ndarray:
    """
    Fill holes inside the matted subject, compensating for an imprecise matting model.

    The alpha channel is binarized (threshold 127) and eroded. The largest external
    contour of the result is taken as the subject outline. Every pixel enclosed by that
    outline is raised to full opacity (255). Pixels on or outside the outline keep their
    original alpha.

    :param src: 4-channel image split as (b, g, r, alpha); alpha assumed uint8 — TODO
        confirm against callers
    :return: 4-channel image with the same colour channels and the hole-filled alpha.
        If no foreground survives thresholding and erosion, the alpha is returned
        unchanged instead of raising.
    """
    pad = 10
    b, g, r, a = cv2.split(src)
    src_bgr = cv2.merge((b, g, r))
    # Zero border: keeps the flood-fill seed (0, 0) in the background, and lets a subject
    # touching the image edge still produce a closed outline.
    a = np.pad(a, pad, mode="constant", constant_values=0)
    _, a_threshold = cv2.threshold(a, 127, 255, cv2.THRESH_BINARY)
    # Erode so the outline (and hence the filled interior) sits inside the thresholded
    # foreground, leaving the soft matte edge outside it untouched.
    a_erode = cv2.erode(
        a_threshold,
        kernel=cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)),
        iterations=3,
    )
    # [-2] is the contour list for both OpenCV 3 (3 return values) and OpenCV 4 (2).
    contours = cv2.findContours(a_erode, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
    if not contours:
        # No foreground left: there is no outline to fill inside.
        return cv2.merge((src_bgr, a[pad:-pad, pad:-pad]))
    # max() returns the first contour of maximal area, like the former stable sort.
    largest = max(contours, key=cv2.contourArea)
    outline = np.zeros(a.shape, np.uint8)
    # Wrap in a list: a bare contour array would be drawn as one contour per point.
    cv2.drawContours(outline, [largest], -1, 255, 2)
    h, w = a.shape[:2]
    # floodFill needs a single-channel uint8 mask 2 pixels larger than the image in
    # both dimensions.
    mask = np.zeros((h + 2, w + 2), np.uint8)
    # Paint everything reachable from the corner (the outside) white; only the enclosed
    # interior stays 0.
    cv2.floodFill(outline, mask=mask, seedPoint=(0, 0), newVal=255)
    # 255 - outline is 255 exactly on the enclosed interior; the saturating add makes
    # that region fully opaque and leaves every other pixel unchanged.
    a = cv2.add(a, 255 - outline)
    return cv2.merge((src_bgr, a[pad:-pad, pad:-pad]))
def image2bgr(input_image):
    """
    Coerce an image array to three channels.

    - ``(H, W)`` and ``(H, W, 1)``: the single channel is replicated three times.
    - ``(H, W, 4)``: the fourth channel is dropped (presumably alpha — confirm with callers).
    - Any other channel count: the input object itself is returned unchanged.

    :param input_image: numpy image array
    :return: array after applying the rules above
    """
    image = input_image[:, :, np.newaxis] if input_image.ndim == 2 else input_image
    channel_count = image.shape[2]
    if channel_count == 4:
        return image[:, :, :3]
    if channel_count == 1:
        return np.repeat(image, 3, axis=2)
    return image
def read_modnet_image(input_image, ref_size=512):
    """
    Turn an image array into the input tensor fed to the MODNet ONNX model.

    :param input_image: image array of shape (H, W) or (H, W, C); cast to uint8 first
    :param ref_size: side length of the square the image is resized to
    :return: ``(tensor, width, length)``. ``tensor`` is produced by the project helpers
        NNormalize / NTo_Tensor / NUnsqueeze. ``width`` and ``length`` are the original
        image width and height in pixels.
    """
    im = np.uint8(input_image)
    # Original size, used by the caller to resize the predicted matte back. It is read
    # straight from the array: the former PIL round-trip only supplied this size, and
    # PIL rejects (H, W, 1) arrays that image2bgr is written to handle.
    length, width = im.shape[:2]
    im = image2bgr(im)
    # Square resize; the aspect ratio is not preserved.
    im = cv2.resize(im, (ref_size, ref_size), interpolation=cv2.INTER_AREA)
    # Per-channel mean/std of 0.5; the exact scaling is defined by NNormalize in tensor2numpy.
    im = NNormalize(im, mean=np.array([0.5, 0.5, 0.5]), std=np.array([0.5, 0.5, 0.5]))
    # Presumably HWC -> CHW plus a leading batch axis — TODO confirm in tensor2numpy.
    im = NUnsqueeze(NTo_Tensor(im))
    return im, width, length
@functools.lru_cache(maxsize=8)
def _load_session(checkpoint_path):
    """
    Return an onnxruntime session for ``checkpoint_path``, built once and then reused.

    Constructing a session reads and initializes the model file. Caching it per path
    avoids repeating that work on every matting call. The key space is the small set of
    known checkpoints, so the cache is bounded.
    """
    return onnxruntime.InferenceSession(checkpoint_path)


def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
    """
    Predict an alpha matte with a MODNet ONNX model and attach it to the image.

    :param input_image: image array; normalized to 3 colour channels with image2bgr
        (grayscale is replicated, a 4th channel is dropped)
    :param checkpoint_path: path to the ONNX weights
    :param ref_size: side length of the square model input
    :return: uint8 array of shape (H, W, 4): the three colour channels of the input plus
        the predicted matte (0-255) resized back to the input size
    """
    sess = _load_session(checkpoint_path)
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name
    im, width, length = read_modnet_image(input_image=input_image, ref_size=ref_size)
    matte = sess.run([output_name], {input_name: im})[0]
    # Clip before the uint8 cast: out-of-range floats give platform-dependent
    # (typically wrapped-around) values otherwise.
    matte = np.squeeze(np.clip(matte * 255, 0, 255).astype(np.uint8))
    mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)
    b, g, r = cv2.split(image2bgr(np.uint8(input_image)))
    return cv2.merge((b, g, r, mask))
|