wahaha committed
Commit 398654a
1 parent: 55cb778
Files changed (3):
  1. app.py +155 -0
  2. modnet.py +94 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,155 @@
+ import os
+ import sys
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ from PIL import Image
+ from skimage import io
+ from torch.utils.data import DataLoader
+ from torchvision import transforms
+
+ # Make the vendored U-2-Net repo importable.
+ sys.path.insert(0, 'U-2-Net')
+
+ from data_loader import RescaleT
+ from data_loader import ToTensorLab
+ from data_loader import SalObjDataset
+ from model import U2NET  # full-size version, 173.6 MB
+
+ from modnet import ModNet
+ import huggingface_hub
+
+
+ def normPRED(d):
+     # Normalize the predicted SOD probability map to [0, 1].
+     ma = torch.max(d)
+     mi = torch.min(d)
+     return (d - mi) / (ma - mi)
+
+
+ def save_output(image_name, pred, d_dir):
+     # Resize the prediction back to the input image's size and save it as a PNG.
+     predict_np = pred.squeeze().cpu().data.numpy()
+
+     im = Image.fromarray((predict_np * 255).astype(np.uint8)).convert('RGB')
+     img_name = image_name.split(os.sep)[-1]
+     image = io.imread(image_name)
+     imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)
+
+     imidx = ".".join(img_name.split(".")[:-1])
+     out_path = d_dir + '/' + imidx + '.png'
+     imo.save(out_path)
+     return out_path
+
+
+ # Download the MODNet portrait-matting model (ONNX) from the Hub.
+ modnet_path = huggingface_hub.hf_hub_download('hylee/apdrawing_model',
+                                               'modnet.onnx',
+                                               force_filename='modnet.onnx')
+ modnet = ModNet(modnet_path)
+
+ # --------- output directory ---------
+ prediction_dir = 'portrait_results'
+ if not os.path.exists(prediction_dir):
+     os.mkdir(prediction_dir)
+
+ model_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)),
+                          'U-2-Net/saved_models/u2net_portrait/u2net_portrait.pth')
+
+ # --------- model definition ---------
+ print("...load U2NET---173.6 MB")
+ net = U2NET(3, 1)
+ net.load_state_dict(torch.load(model_dir, map_location='cpu'))
+ net.eval()  # inference runs on CPU
+
+
+ def process(im):
+     # Matte the portrait onto a white background before sketching it.
+     image = modnet.segment(im.name)
+     im_path = os.path.abspath(os.path.basename(im.name))
+     Image.fromarray(np.uint8(image)).save(im_path)
+
+     img_name_list = [im_path]
+     print("Number of images: ", len(img_name_list))
+
+     # --------- dataloader ---------
+     test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
+                                         lbl_name_list=[],
+                                         transform=transforms.Compose([RescaleT(512),
+                                                                       ToTensorLab(flag=0)]))
+     test_salobj_dataloader = DataLoader(test_salobj_dataset,
+                                         batch_size=1,
+                                         shuffle=False,
+                                         num_workers=1)
+
+     results = []
+     # --------- inference for each image ---------
+     for i_test, data_test in enumerate(test_salobj_dataloader):
+         print("inferencing:", img_name_list[i_test].split(os.sep)[-1])
+
+         inputs_test = data_test['image'].type(torch.FloatTensor)
+
+         with torch.no_grad():
+             d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
+
+         # Invert and normalize the first side output so the sketch is dark on white.
+         pred = 1.0 - d1[:, 0, :, :]
+         pred = normPRED(pred)
+
+         results.append(save_output(img_name_list[i_test], pred, prediction_dir))
+
+         del d1, d2, d3, d4, d5, d6, d7
+
+     print(results)
+     return Image.open(results[0])
+
+
+ title = "U-2-Net"
+ description = "Gradio demo for U-2-Net, https://github.com/xuebinqin/U-2-Net"
+ article = ""
+
+ gr.Interface(
+     process,
+     [gr.inputs.Image(type="file", label="Input")],
+     [gr.outputs.Image(type="pil", label="Output")],
+     title=title,
+     description=description,
+     article=article,
+     examples=[],
+     allow_flagging=False,
+     allow_screenshot=False,
+ ).launch(enable_queue=True)
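
Note: process() expects Gradio's "file" input, i.e. any object exposing a .name path. A minimal smoke-test sketch for calling it outside the UI, assuming a hypothetical local portrait test.jpg (not part of this commit):

    # Hypothetical smoke test; 'test.jpg' is an assumed local file.
    from collections import namedtuple

    FileStub = namedtuple('FileStub', ['name'])  # stands in for gradio's file object
    sketch = process(FileStub(name='test.jpg'))  # returns a PIL.Image
    sketch.save('sketch.png')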
modnet.py ADDED
@@ -0,0 +1,94 @@
+ import cv2
+ import numpy as np
+ import onnxruntime
+
+
+ class ModNet:
+     # Thin wrapper around the MODNet portrait-matting ONNX model.
+
+     def __init__(self, model_path):
+         # Initialize the ONNX Runtime inference session.
+         self.session = onnxruntime.InferenceSession(model_path, None)
+
+     def get_scale_factor(self, im_h, im_w, ref_size):
+         # Get x/y scale factors that bring the shorter (or longer) side near
+         # ref_size, with both sides a multiple of 32 as the network requires.
+         if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
+             if im_w >= im_h:
+                 im_rh = ref_size
+                 im_rw = int(im_w / im_h * ref_size)
+             else:
+                 im_rw = ref_size
+                 im_rh = int(im_h / im_w * ref_size)
+         else:
+             im_rh = im_h
+             im_rw = im_w
+
+         im_rw = im_rw - im_rw % 32
+         im_rh = im_rh - im_rh % 32
+
+         return im_rw / im_w, im_rh / im_h
+
+     def segment(self, image_path):
+         # Run matting on the image at image_path and composite the predicted
+         # foreground onto a white background.
+         ref_size = 512
+
+         # Read the image as RGB.
+         im = cv2.imread(image_path)
+         im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+
+         # Unify image channels to 3.
+         if len(im.shape) == 2:
+             im = im[:, :, None]
+         if im.shape[2] == 1:
+             im = np.repeat(im, 3, axis=2)
+         elif im.shape[2] == 4:
+             im = im[:, :, 0:3]
+
+         # Normalize pixel values to [-1, 1].
+         im = (im - 127.5) / 127.5
+
+         im_h, im_w, im_c = im.shape
+         x, y = self.get_scale_factor(im_h, im_w, ref_size)
+
+         # Resize to the network's working resolution.
+         im = cv2.resize(im, None, fx=x, fy=y, interpolation=cv2.INTER_AREA)
+
+         # HWC -> NCHW float32 for the ONNX model.
+         im = np.transpose(im, (2, 0, 1))
+         im = np.expand_dims(im, axis=0).astype('float32')
+
+         input_name = self.session.get_inputs()[0].name
+         output_name = self.session.get_outputs()[0].name
+         result = self.session.run([output_name], {input_name: im})
+
+         # Resize the matte back to the original resolution.
+         matte = (np.squeeze(result[0]) * 255).astype('uint8')
+         matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation=cv2.INTER_AREA)
+
+         # Re-read the original image (RGB) for compositing.
+         image = cv2.imread(image_path)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         if len(image.shape) == 2:
+             image = image[:, :, None]
+         if image.shape[2] == 1:
+             image = np.repeat(image, 3, axis=2)
+         elif image.shape[2] == 4:
+             image = image[:, :, 0:3]
+
+         # Alpha-blend the foreground onto a white background.
+         matte = np.repeat(np.asarray(matte)[:, :, None], 3, axis=2) / 255
+         foreground = image * matte + np.full(image.shape, 255) * (1 - matte)
+
+         return foreground
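
Note: ModNet is also usable on its own. A minimal sketch, assuming modnet.onnx was already downloaded (e.g. via huggingface_hub as in app.py) and a local portrait.jpg exists; both paths are assumptions:

    # Hypothetical standalone usage of the ModNet wrapper.
    import cv2
    from modnet import ModNet

    net = ModNet('modnet.onnx')
    fg = net.segment('portrait.jpg')  # RGB float array, subject composited on white
    cv2.imwrite('foreground.png', cv2.cvtColor(fg.astype('uint8'), cv2.COLOR_RGB2BGR))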
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ numpy
+ scikit-image
+ torch
+ torchvision
+ pillow
+ opencv-python-headless
+ onnx==1.8.1
+ onnxruntime==1.6.0