Spaces:
Running
Running
Commit
·
4be6b70
1
Parent(s):
59bff44
fix huggingface
Browse files- app.py +1 -1
- hivision/creator/human_matting.py +30 -6
app.py
CHANGED
|
@@ -618,4 +618,4 @@ if __name__ == "__main__":
|
|
| 618 |
],
|
| 619 |
)
|
| 620 |
|
| 621 |
-
demo.launch(
|
|
|
|
| 618 |
],
|
| 619 |
)
|
| 620 |
|
| 621 |
+
demo.launch()
|
hivision/creator/human_matting.py
CHANGED
|
@@ -15,7 +15,17 @@ from .context import Context
|
|
| 15 |
import cv2
|
| 16 |
import os
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
def extract_human(ctx: Context):
|
|
@@ -24,7 +34,21 @@ def extract_human(ctx: Context):
|
|
| 24 |
:param ctx: 上下文
|
| 25 |
"""
|
| 26 |
# 抠图
|
| 27 |
-
matting_image = get_modnet_matting(ctx.processing_image,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
# 修复抠图
|
| 29 |
ctx.processing_image = hollow_out_fix(matting_image)
|
| 30 |
ctx.matting_image = ctx.processing_image.copy()
|
|
@@ -92,13 +116,13 @@ def read_modnet_image(input_image, ref_size=512):
|
|
| 92 |
return im, width, length
|
| 93 |
|
| 94 |
|
| 95 |
-
sess = None
|
| 96 |
|
| 97 |
|
| 98 |
def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
|
| 99 |
-
global sess
|
| 100 |
-
if sess is None:
|
| 101 |
-
|
| 102 |
|
| 103 |
input_name = sess.get_inputs()[0].name
|
| 104 |
output_name = sess.get_outputs()[0].name
|
|
|
|
| 15 |
import cv2
|
| 16 |
import os
|
| 17 |
|
| 18 |
+
|
| 19 |
+
WEIGHTS = {
|
| 20 |
+
"hivision_modnet": os.path.join(
|
| 21 |
+
os.path.dirname(__file__), "weights", "hivision_modnet.onnx"
|
| 22 |
+
),
|
| 23 |
+
"modnet_photographic_portrait_matting": os.path.join(
|
| 24 |
+
os.path.dirname(__file__),
|
| 25 |
+
"weights",
|
| 26 |
+
"modnet_photographic_portrait_matting.onnx",
|
| 27 |
+
),
|
| 28 |
+
}
|
| 29 |
|
| 30 |
|
| 31 |
def extract_human(ctx: Context):
|
|
|
|
| 34 |
:param ctx: 上下文
|
| 35 |
"""
|
| 36 |
# 抠图
|
| 37 |
+
matting_image = get_modnet_matting(ctx.processing_image, WEIGHTS["hivision_modnet"])
|
| 38 |
+
# 修复抠图
|
| 39 |
+
ctx.processing_image = hollow_out_fix(matting_image)
|
| 40 |
+
ctx.matting_image = ctx.processing_image.copy()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def extract_human_modnet_photographic_portrait_matting(ctx: Context):
|
| 44 |
+
"""
|
| 45 |
+
人像抠图
|
| 46 |
+
:param ctx: 上下文
|
| 47 |
+
"""
|
| 48 |
+
# 抠图
|
| 49 |
+
matting_image = get_modnet_matting(
|
| 50 |
+
ctx.processing_image, WEIGHTS["modnet_photographic_portrait_matting"]
|
| 51 |
+
)
|
| 52 |
# 修复抠图
|
| 53 |
ctx.processing_image = hollow_out_fix(matting_image)
|
| 54 |
ctx.matting_image = ctx.processing_image.copy()
|
|
|
|
| 116 |
return im, width, length
|
| 117 |
|
| 118 |
|
| 119 |
+
# sess = None
|
| 120 |
|
| 121 |
|
| 122 |
def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
|
| 123 |
+
# global sess
|
| 124 |
+
# if sess is None:
|
| 125 |
+
sess = onnxruntime.InferenceSession(checkpoint_path)
|
| 126 |
|
| 127 |
input_name = sess.get_inputs()[0].name
|
| 128 |
output_name = sess.get_outputs()[0].name
|