#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/9/5 21:21
@File: human_matting.py
@IDE: pycharm
@Description: Portrait matting (human foreground extraction).
"""
import numpy as np
from PIL import Image
import onnxruntime
from .tensor2numpy import NNormalize, NTo_Tensor, NUnsqueeze
from .context import Context
import cv2
import os

# Paths of the bundled ONNX checkpoints, resolved relative to this file.
WEIGHTS = {
    "hivision_modnet": os.path.join(
        os.path.dirname(__file__), "weights", "hivision_modnet.onnx"
    ),
    "modnet_photographic_portrait_matting": os.path.join(
        os.path.dirname(__file__),
        "weights",
        "modnet_photographic_portrait_matting.onnx",
    ),
}


def _extract_with_weights(ctx: Context, weights_key: str) -> None:
    """
    Run matting with the checkpoint registered under ``weights_key`` and
    store the repaired result on the context.

    :param ctx: processing context; ``ctx.processing_image`` is read as the
        input and overwritten with the result, and ``ctx.matting_image``
        receives an independent copy of that result
    :param weights_key: key into ``WEIGHTS`` selecting the ONNX checkpoint
    """
    # Matting: predict an alpha channel for the current image.
    matting_image = get_modnet_matting(ctx.processing_image, WEIGHTS[weights_key])
    # Repair holes the model may have left inside the subject.
    ctx.processing_image = hollow_out_fix(matting_image)
    ctx.matting_image = ctx.processing_image.copy()


def extract_human(ctx: Context):
    """
    Portrait matting using the ``hivision_modnet`` checkpoint.

    :param ctx: processing context (updated in place, see
        ``_extract_with_weights``)
    """
    _extract_with_weights(ctx, "hivision_modnet")


def extract_human_modnet_photographic_portrait_matting(ctx: Context):
    """
    Portrait matting using the ``modnet_photographic_portrait_matting``
    checkpoint.

    :param ctx: processing context (updated in place, see
        ``_extract_with_weights``)
    """
    _extract_with_weights(ctx, "modnet_photographic_portrait_matting")


def hollow_out_fix(src: np.ndarray) -> np.ndarray:
    """
    Patch holes in a matting result, compensating for limited model accuracy.

    The alpha channel is thresholded and eroded, the largest external contour
    of the result is outlined, and every pixel enclosed by that outline is
    forced to full opacity.

    :param src: 4-channel image whose last channel is the alpha matte
    :return: new 4-channel image with the same colour channels and the
        repaired alpha channel; if no foreground survives thresholding and
        erosion, an unmodified copy of ``src`` is returned
    """
    pad = 10  # width of the zero border added around the alpha channel

    b, g, r, a = cv2.split(src)
    src_bgr = cv2.merge((b, g, r))

    # Pad with zeros so that pixel (0, 0) is always background: the flood
    # fill below is seeded there.
    add_area = np.zeros((pad, a.shape[1]), np.uint8)
    a = np.vstack((add_area, a, add_area))
    add_area = np.zeros((a.shape[0], pad), np.uint8)
    a = np.hstack((add_area, a, add_area))

    _, a_threshold = cv2.threshold(a, 127, 255, cv2.THRESH_BINARY)
    a_erode = cv2.erode(
        a_threshold,
        kernel=cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)),
        iterations=3,
    )
    contours, _hierarchy = cv2.findContours(
        a_erode, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    if len(contours) == 0:
        # Nothing left after thresholding/erosion: there is no subject
        # outline to fill, so leave the matte as it is.
        return src.copy()

    # First contour with the largest area (same pick as a stable
    # descending sort followed by taking element 0).
    largest = max(contours, key=cv2.contourArea)

    # NOTE(review): the contour's point array is passed directly rather than
    # wrapped in a list, exactly as before; this appears to draw each
    # boundary point individually with thickness 2 - confirm before changing.
    a_contour = cv2.drawContours(np.zeros(a.shape, np.uint8), largest, -1, 255, 2)

    h, w = a.shape[:2]
    # floodFill requires a single-channel uint8 mask that is 2 pixels larger
    # than the image in both dimensions.
    mask = np.zeros([h + 2, w + 2], np.uint8)
    # Fill the outside of the outline with 255; only the enclosed interior
    # stays 0, so ``255 - a_contour`` is 255 exactly on the interior.
    cv2.floodFill(a_contour, mask=mask, seedPoint=(0, 0), newVal=255)
    a = cv2.add(a, 255 - a_contour)  # saturating add: interior becomes opaque

    # Strip the padding again before re-attaching the alpha channel.
    return cv2.merge((src_bgr, a[pad:-pad, pad:-pad]))


def image2bgr(input_image):
    """
    Convert an image array to three channels.

    2-D and single-channel inputs are replicated to three channels, 4-channel
    inputs keep only their first three channels, and anything else is
    returned as-is. Channel order is never changed.

    :param input_image: array of shape (H, W) or (H, W, C)
    :return: the 3-channel array (or the input itself for other channel
        counts)
    """
    if len(input_image.shape) == 2:
        input_image = input_image[:, :, None]

    channels = input_image.shape[2]
    if channels == 1:
        return np.repeat(input_image, 3, axis=2)
    if channels == 4:
        return input_image[:, :, 0:3]
    return input_image


def read_modnet_image(input_image, ref_size=512):
    """
    Prepare an image for MODNet inference.

    The image is converted to uint8, reduced to three channels, resized to
    ``ref_size`` x ``ref_size``, normalised with mean 0.5 / std 0.5 per
    channel and passed through ``NTo_Tensor`` and ``NUnsqueeze``.

    :param input_image: source image (array-like convertible to uint8)
    :param ref_size: side length of the square network input
    :return: tuple ``(im, width, length)`` where ``im`` is the prepared
        network input and ``width`` / ``length`` are the original image
        width and height as reported by PIL
    """
    im = Image.fromarray(np.uint8(input_image))
    width, length = im.size[0], im.size[1]
    im = np.asarray(im)
    im = image2bgr(im)
    im = cv2.resize(im, (ref_size, ref_size), interpolation=cv2.INTER_AREA)
    im = NNormalize(im, mean=np.array([0.5, 0.5, 0.5]), std=np.array([0.5, 0.5, 0.5]))
    im = NUnsqueeze(NTo_Tensor(im))
    return im, width, length


def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
    """
    Predict an alpha matte with a MODNet ONNX model and attach it to the image.

    :param input_image: source image (array-like convertible to uint8)
    :param checkpoint_path: path of the ONNX checkpoint to load
    :param ref_size: side length of the square network input
    :return: 4-channel uint8 image: the three colour channels of the input
        followed by the predicted matte resized to the input size
    """
    # NOTE(review): a new session is built on every call; a module-level
    # cache existed only as commented-out code and is not enabled here.
    sess = onnxruntime.InferenceSession(checkpoint_path)
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    im, width, length = read_modnet_image(input_image=input_image, ref_size=ref_size)

    matte = sess.run([output_name], {input_name: im})
    matte = (matte[0] * 255).astype("uint8")
    matte = np.squeeze(matte)
    mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)

    # Use the same channel normalisation as the preprocessing step so that
    # grayscale or 4-channel inputs do not fail when unpacking three channels.
    colour = np.ascontiguousarray(image2bgr(np.uint8(input_image)))
    b, g, r = cv2.split(colour)
    output_image = cv2.merge((b, g, r, mask))
    return output_image