init
- README.md +1 -0
- U-2-Net/utils/face_seg.py +44 -0
- U-2-Net/utils/seg_model_384.pb +3 -0
- app.py +17 -13
- requirements.txt +1 -2
README.md
CHANGED
@@ -1,4 +1,5 @@
 ---
+python_version: 3.7
 title: U2net_portrait
 emoji: 🦀
 colorFrom: indigo
U-2-Net/utils/face_seg.py
ADDED
@@ -0,0 +1,44 @@
+import os
+import cv2
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.platform import gfile
+
+
+curPath = os.path.abspath(os.path.dirname(__file__))
+
+
+class FaceSeg:
+    def __init__(self, model_path=os.path.join(curPath, 'seg_model_384.pb')):
+        config = tf.compat.v1.ConfigProto()
+        config.gpu_options.allow_growth = True
+        self._graph = tf.Graph()
+        self._sess = tf.compat.v1.Session(config=config, graph=self._graph)
+
+        self.pb_file_path = model_path
+        self._restore_from_pb()
+        self.input_op = self._sess.graph.get_tensor_by_name('input_1:0')
+        self.output_op = self._sess.graph.get_tensor_by_name('sigmoid/Sigmoid:0')
+
+    def _restore_from_pb(self):
+        with self._sess.as_default():
+            with self._graph.as_default():
+                with gfile.FastGFile(self.pb_file_path, 'rb') as f:
+                    graph_def = tf.compat.v1.GraphDef()
+                    graph_def.ParseFromString(f.read())
+                    tf.import_graph_def(graph_def, name='')
+
+    def input_transform(self, image):
+        image = cv2.resize(image, (384, 384), interpolation=cv2.INTER_AREA)
+        image_input = (image / 255.)[np.newaxis, :, :, :]
+        return image_input
+
+    def output_transform(self, output, shape):
+        output = cv2.resize(output, (shape[1], shape[0]))
+        image_output = (output * 255).astype(np.uint8)
+        return image_output
+
+    def get_mask(self, image):
+        image_input = self.input_transform(image)
+        output = self._sess.run(self.output_op, feed_dict={self.input_op: image_input})[0]
+        return self.output_transform(output, shape=image.shape[:2])
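For reference, a minimal sketch of driving the new `FaceSeg` class on its own, assuming a 3-channel input image at a hypothetical path `portrait.jpg`:

```python
import cv2
from utils.face_seg import FaceSeg

segmenter = FaceSeg()                   # loads seg_model_384.pb from the utils directory
image = cv2.imread('portrait.jpg')      # hypothetical input path; cv2 returns an HxWx3 BGR array
mask = segmenter.get_mask(image)        # uint8 mask resized back to the input's height/width
cv2.imwrite('portrait_mask.png', mask)  # hypothetical output path
```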
U-2-Net/utils/seg_model_384.pb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66a04bc2032b54013d2ae994b34d22518144276f1cbdd2d8cbb1a4a28f50285f
+size 32477258
app.py
CHANGED
@@ -27,8 +27,9 @@ from data_loader import SalObjDataset
 from model import U2NET # full size version 173.6 MB
 from model import U2NETP # small version u2net 4.7 MB
 
-from modnet import ModNet
 import huggingface_hub
+from utils.face_seg import FaceSeg
+import cv2
 
 # normalize the predicted SOD probability map
 def normPRED(d):
@@ -59,17 +60,11 @@ def save_output(image_name,pred,d_dir):
     imo.save(d_dir+'/'+imidx+'.png')
     return d_dir+'/'+imidx+'.png'
 
-
-
-modnet_path = huggingface_hub.hf_hub_download('hylee/apdrawing_model',
-                                              'modnet.onnx',
-                                              force_filename='modnet.onnx')
-modnet = ModNet(modnet_path)
+segment = FaceSeg()
 
 # --------- 1. get image path and name ---------
 model_name='u2net_portrait'#u2netp
 
-
 image_dir = 'portrait_im'
 prediction_dir = 'portrait_results'
 if(not os.path.exists(prediction_dir)):
@@ -90,9 +85,18 @@ net.eval()
 
 
 def process(im):
-
-
-
+    image = cv2.imread(im.name)
+    matte = segment.get_mask(image)
+
+    if len(image.shape) == 2:
+        image = image[:, :, None]
+    if image.shape[2] == 1:
+        image = np.repeat(image, 3, axis=2)
+    elif image.shape[2] == 4:
+        image = image[:, :, 0:3]
+    matte = np.repeat(np.asarray(matte)[:, :, None], 3, axis=2) / 255
+    foreground = image * matte + np.full(image.shape, 255) * (1 - matte)
+    cv2.imwrite(im.name, foreground)
 
     img_name_list = [im.name]
     print("Number of images: ", len(img_name_list))
@@ -135,7 +139,7 @@ def process(im):
 
     print(results)
 
-    return Image.open(results[0])
+    return Image.open(results[0]), Image.open(im.name)
 
 title = "U-2-Net"
 description = "Gradio demo for U-2-Net, https://github.com/xuebinqin/U-2-Net"
@@ -145,7 +149,7 @@ gr.Interface(
     process,
     [gr.inputs.Image(type="file", label="Input")
     ],
-    [gr.outputs.Image(type="pil", label="Output")],
+    [gr.outputs.Image(type="pil", label="Output"), gr.outputs.Image(type="pil", label="Output")],
     title=title,
     description=description,
     article=article,
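The compositing step added to `process()` blends the original pixels against a white background, weighted by the matte normalized to [0, 1]. A standalone sketch of that operation, using illustrative names not taken from `app.py`:

```python
import numpy as np

def composite_on_white(image: np.ndarray, matte: np.ndarray) -> np.ndarray:
    """Blend a 3-channel image over a white background using a single-channel uint8 matte."""
    alpha = np.repeat(matte[:, :, None], 3, axis=2) / 255.0  # HxW -> HxWx3, scaled to [0, 1]
    white = np.full(image.shape, 255.0)                      # all-white background
    return (image * alpha + white * (1.0 - alpha)).astype(np.uint8)
```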
requirements.txt
CHANGED
@@ -4,5 +4,4 @@ torch
 torchvision
 pillow
 opencv-python-headless
-
-onnxruntime==1.6.0
+tensorflow-gpu==1.14.0