wahaha commited on
Commit
4a57abf
1 Parent(s): dfb75bd
README.md CHANGED
@@ -1,4 +1,5 @@
1
  ---
 
2
  title: U2net_portrait
3
  emoji: 🦀
4
  colorFrom: indigo
 
1
  ---
2
+ python_version: 3.7
3
  title: U2net_portrait
4
  emoji: 🦀
5
  colorFrom: indigo
U-2-Net/utils/face_seg.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from tensorflow.python.platform import gfile
6
+
7
+
8
+ curPath = os.path.abspath(os.path.dirname(__file__))
9
+
10
+
11
+ class FaceSeg:
12
+ def __init__(self, model_path=os.path.join(curPath, 'seg_model_384.pb')):
13
+ config = tf.compat.v1.ConfigProto()
14
+ config.gpu_options.allow_growth = True
15
+ self._graph = tf.Graph()
16
+ self._sess = tf.compat.v1.Session(config=config, graph=self._graph)
17
+
18
+ self.pb_file_path = model_path
19
+ self._restore_from_pb()
20
+ self.input_op = self._sess.graph.get_tensor_by_name('input_1:0')
21
+ self.output_op = self._sess.graph.get_tensor_by_name('sigmoid/Sigmoid:0')
22
+
23
+ def _restore_from_pb(self):
24
+ with self._sess.as_default():
25
+ with self._graph.as_default():
26
+ with gfile.FastGFile(self.pb_file_path, 'rb') as f:
27
+ graph_def = tf.compat.v1.GraphDef()
28
+ graph_def.ParseFromString(f.read())
29
+ tf.import_graph_def(graph_def, name='')
30
+
31
+ def input_transform(self, image):
32
+ image = cv2.resize(image, (384, 384), interpolation=cv2.INTER_AREA)
33
+ image_input = (image / 255.)[np.newaxis, :, :, :]
34
+ return image_input
35
+
36
+ def output_transform(self, output, shape):
37
+ output = cv2.resize(output, (shape[1], shape[0]))
38
+ image_output = (output * 255).astype(np.uint8)
39
+ return image_output
40
+
41
+ def get_mask(self, image):
42
+ image_input = self.input_transform(image)
43
+ output = self._sess.run(self.output_op, feed_dict={self.input_op: image_input})[0]
44
+ return self.output_transform(output, shape=image.shape[:2])
U-2-Net/utils/seg_model_384.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66a04bc2032b54013d2ae994b34d22518144276f1cbdd2d8cbb1a4a28f50285f
3
+ size 32477258
app.py CHANGED
@@ -27,8 +27,9 @@ from data_loader import SalObjDataset
27
  from model import U2NET # full size version 173.6 MB
28
  from model import U2NETP # small version u2net 4.7 MB
29
 
30
- from modnet import ModNet
31
  import huggingface_hub
 
 
32
 
33
  # normalize the predicted SOD probability map
34
  def normPRED(d):
@@ -59,17 +60,11 @@ def save_output(image_name,pred,d_dir):
59
  imo.save(d_dir+'/'+imidx+'.png')
60
  return d_dir+'/'+imidx+'.png'
61
 
62
-
63
-
64
- modnet_path = huggingface_hub.hf_hub_download('hylee/apdrawing_model',
65
- 'modnet.onnx',
66
- force_filename='modnet.onnx')
67
- modnet = ModNet(modnet_path)
68
 
69
  # --------- 1. get image path and name ---------
70
  model_name='u2net_portrait'#u2netp
71
 
72
-
73
  image_dir = 'portrait_im'
74
  prediction_dir = 'portrait_results'
75
  if(not os.path.exists(prediction_dir)):
@@ -90,9 +85,18 @@ net.eval()
90
 
91
 
92
  def process(im):
93
- #image = modnet.segment(im.name)
94
- #im_path = os.path.abspath(os.path.basename(im.name))
95
- #Image.fromarray(np.uint8(image)).save(im_path)
 
 
 
 
 
 
 
 
 
96
 
97
  img_name_list = [im.name]
98
  print("Number of images: ", len(img_name_list))
@@ -135,7 +139,7 @@ def process(im):
135
 
136
  print(results)
137
 
138
- return Image.open(results[0])
139
 
140
  title = "U-2-Net"
141
  description = "Gradio demo for U-2-Net, https://github.com/xuebinqin/U-2-Net"
@@ -145,7 +149,7 @@ gr.Interface(
145
  process,
146
  [gr.inputs.Image(type="file", label="Input")
147
  ],
148
- [gr.outputs.Image(type="pil", label="Output")],
149
  title=title,
150
  description=description,
151
  article=article,
 
27
  from model import U2NET # full size version 173.6 MB
28
  from model import U2NETP # small version u2net 4.7 MB
29
 
 
30
  import huggingface_hub
31
+ from utils.face_seg import FaceSeg
32
+ import cv2
33
 
34
  # normalize the predicted SOD probability map
35
  def normPRED(d):
 
60
  imo.save(d_dir+'/'+imidx+'.png')
61
  return d_dir+'/'+imidx+'.png'
62
 
63
+ segment = FaceSeg()
 
 
 
 
 
64
 
65
  # --------- 1. get image path and name ---------
66
  model_name='u2net_portrait'#u2netp
67
 
 
68
  image_dir = 'portrait_im'
69
  prediction_dir = 'portrait_results'
70
  if(not os.path.exists(prediction_dir)):
 
85
 
86
 
87
  def process(im):
88
+ image = cv2.imread(im.name)
89
+ matte = self.segment.get_mask(face)
90
+
91
+ if len(image.shape) == 2:
92
+ image = image[:, :, None]
93
+ if image.shape[2] == 1:
94
+ image = np.repeat(image, 3, axis=2)
95
+ elif image.shape[2] == 4:
96
+ image = image[:, :, 0:3]
97
+ matte = np.repeat(np.asarray(matte)[:, :, None], 3, axis=2) / 255
98
+ foreground = image * matte + np.full(image.shape, 255) * (1 - matte)
99
+ cv2.imwrite(im.name, foreground)
100
 
101
  img_name_list = [im.name]
102
  print("Number of images: ", len(img_name_list))
 
139
 
140
  print(results)
141
 
142
+ return Image.open(results[0]), Image.open(im.name)
143
 
144
  title = "U-2-Net"
145
  description = "Gradio demo for U-2-Net, https://github.com/xuebinqin/U-2-Net"
 
149
  process,
150
  [gr.inputs.Image(type="file", label="Input")
151
  ],
152
+ [gr.outputs.Image(type="pil", label="Output"), gr.outputs.Image(type="pil", label="Output")],
153
  title=title,
154
  description=description,
155
  article=article,
requirements.txt CHANGED
@@ -4,5 +4,4 @@ torch
4
  torchvision
5
  pillow
6
  opencv-python-headless
7
- onnx==1.8.1
8
- onnxruntime==1.6.0
 
4
  torchvision
5
  pillow
6
  opencv-python-headless
7
+ tensorflow-gpu==1.14.0