reeteshmukul commited on
Commit
ab44973
·
1 Parent(s): 2ef33f0

removing unnecessary files for modelcard

Browse files
README.md CHANGED
@@ -1,10 +1,4 @@
1
  ---
2
  title: Saliency Estimation
3
- emoji: 🌖
4
- colorFrom: indigo
5
- colorTo: green
6
- sdk: gradio
7
- app_file: app.py
8
- pinned: false
9
  ---
10
 
 
1
  ---
2
  title: Saliency Estimation
 
 
 
 
 
 
3
  ---
4
 
app.py DELETED
@@ -1,30 +0,0 @@
1
- from u2net.u2net_inference import get_u2net_model, get_saliency_mask
2
-
3
- import numpy as np
4
- from PIL import Image
5
- import matplotlib.pyplot as plt
6
-
7
- from pathlib import Path
8
- import matplotlib.pyplot as plt
9
- import numpy as np
10
- import gradio as gr
11
-
12
- print('Loading model...')
13
- model = get_u2net_model()
14
- print('Successfully loaded model...')
15
- examples = ['examples/1.jpg', 'examples/6.jpg']
16
-
17
-
18
- def infer(image):
19
- image_out = get_saliency_mask(model, image)
20
- return image_out
21
-
22
-
23
- iface = gr.Interface(
24
- fn=infer,
25
- title="U^2Net Based Saliency Estimatiion",
26
- description = "U^2Net Saliency Estimation",
27
- inputs=[gr.Image(label="image", type="numpy", shape=(640, 480))],
28
- outputs="image",
29
- cache_examples=True,
30
- examples=examples).launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt DELETED
@@ -1 +0,0 @@
1
- torch
 
 
samples/1.jpg DELETED
Binary file (62.7 kB)
 
samples/6.jpg DELETED
Binary file (105 kB)
 
u2net.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
u2net/__init__.py DELETED
File without changes
u2net/data_loader.py DELETED
@@ -1,266 +0,0 @@
1
- # data loader
2
- from __future__ import print_function, division
3
- import glob
4
- import torch
5
- from skimage import io, transform, color
6
- import numpy as np
7
- import random
8
- import math
9
- import matplotlib.pyplot as plt
10
- from torch.utils.data import Dataset, DataLoader
11
- from torchvision import transforms, utils
12
- from PIL import Image
13
-
14
- #==========================dataset load==========================
15
- class RescaleT(object):
16
-
17
- def __init__(self,output_size):
18
- assert isinstance(output_size,(int,tuple))
19
- self.output_size = output_size
20
-
21
- def __call__(self,sample):
22
- imidx, image, label = sample['imidx'], sample['image'],sample['label']
23
-
24
- h, w = image.shape[:2]
25
-
26
- if isinstance(self.output_size,int):
27
- if h > w:
28
- new_h, new_w = self.output_size*h/w,self.output_size
29
- else:
30
- new_h, new_w = self.output_size,self.output_size*w/h
31
- else:
32
- new_h, new_w = self.output_size
33
-
34
- new_h, new_w = int(new_h), int(new_w)
35
-
36
- # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
37
- # img = transform.resize(image,(new_h,new_w),mode='constant')
38
- # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
39
-
40
- img = transform.resize(image,(self.output_size,self.output_size),mode='constant')
41
- lbl = transform.resize(label,(self.output_size,self.output_size),mode='constant', order=0, preserve_range=True)
42
-
43
- return {'imidx':imidx, 'image':img,'label':lbl}
44
-
45
- class Rescale(object):
46
-
47
- def __init__(self,output_size):
48
- assert isinstance(output_size,(int,tuple))
49
- self.output_size = output_size
50
-
51
- def __call__(self,sample):
52
- imidx, image, label = sample['imidx'], sample['image'],sample['label']
53
-
54
- if random.random() >= 0.5:
55
- image = image[::-1]
56
- label = label[::-1]
57
-
58
- h, w = image.shape[:2]
59
-
60
- if isinstance(self.output_size,int):
61
- if h > w:
62
- new_h, new_w = self.output_size*h/w,self.output_size
63
- else:
64
- new_h, new_w = self.output_size,self.output_size*w/h
65
- else:
66
- new_h, new_w = self.output_size
67
-
68
- new_h, new_w = int(new_h), int(new_w)
69
-
70
- # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
71
- img = transform.resize(image,(new_h,new_w),mode='constant')
72
- lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
73
-
74
- return {'imidx':imidx, 'image':img,'label':lbl}
75
-
76
- class RandomCrop(object):
77
-
78
- def __init__(self,output_size):
79
- assert isinstance(output_size, (int, tuple))
80
- if isinstance(output_size, int):
81
- self.output_size = (output_size, output_size)
82
- else:
83
- assert len(output_size) == 2
84
- self.output_size = output_size
85
- def __call__(self,sample):
86
- imidx, image, label = sample['imidx'], sample['image'], sample['label']
87
-
88
- if random.random() >= 0.5:
89
- image = image[::-1]
90
- label = label[::-1]
91
-
92
- h, w = image.shape[:2]
93
- new_h, new_w = self.output_size
94
-
95
- top = np.random.randint(0, h - new_h)
96
- left = np.random.randint(0, w - new_w)
97
-
98
- image = image[top: top + new_h, left: left + new_w]
99
- label = label[top: top + new_h, left: left + new_w]
100
-
101
- return {'imidx':imidx,'image':image, 'label':label}
102
-
103
- class ToTensor(object):
104
- """Convert ndarrays in sample to Tensors."""
105
-
106
- def __call__(self, sample):
107
-
108
- imidx, image, label = sample['imidx'], sample['image'], sample['label']
109
-
110
- tmpImg = np.zeros((image.shape[0],image.shape[1],3))
111
- tmpLbl = np.zeros(label.shape)
112
-
113
- image = image/np.max(image)
114
- if(np.max(label)<1e-6):
115
- label = label
116
- else:
117
- label = label/np.max(label)
118
-
119
- if image.shape[2]==1:
120
- tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
121
- tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
122
- tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
123
- else:
124
- tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
125
- tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
126
- tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
127
-
128
- tmpLbl[:,:,0] = label[:,:,0]
129
-
130
-
131
- tmpImg = tmpImg.transpose((2, 0, 1))
132
- tmpLbl = label.transpose((2, 0, 1))
133
-
134
- return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
135
-
136
- class ToTensorLab(object):
137
- """Convert ndarrays in sample to Tensors."""
138
- def __init__(self,flag=0):
139
- self.flag = flag
140
-
141
- def __call__(self, sample):
142
-
143
- imidx, image, label =sample['imidx'], sample['image'], sample['label']
144
-
145
- tmpLbl = np.zeros(label.shape)
146
-
147
- if(np.max(label)<1e-6):
148
- label = label
149
- else:
150
- label = label/np.max(label)
151
-
152
- # change the color space
153
- if self.flag == 2: # with rgb and Lab colors
154
- tmpImg = np.zeros((image.shape[0],image.shape[1],6))
155
- tmpImgt = np.zeros((image.shape[0],image.shape[1],3))
156
- if image.shape[2]==1:
157
- tmpImgt[:,:,0] = image[:,:,0]
158
- tmpImgt[:,:,1] = image[:,:,0]
159
- tmpImgt[:,:,2] = image[:,:,0]
160
- else:
161
- tmpImgt = image
162
- tmpImgtl = color.rgb2lab(tmpImgt)
163
-
164
- # nomalize image to range [0,1]
165
- tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0]))
166
- tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1]))
167
- tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2]))
168
- tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0]))
169
- tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1]))
170
- tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2]))
171
-
172
- # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
173
-
174
- tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
175
- tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
176
- tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
177
- tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3])
178
- tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4])
179
- tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5])
180
-
181
- elif self.flag == 1: #with Lab color
182
- tmpImg = np.zeros((image.shape[0],image.shape[1],3))
183
-
184
- if image.shape[2]==1:
185
- tmpImg[:,:,0] = image[:,:,0]
186
- tmpImg[:,:,1] = image[:,:,0]
187
- tmpImg[:,:,2] = image[:,:,0]
188
- else:
189
- tmpImg = image
190
-
191
- tmpImg = color.rgb2lab(tmpImg)
192
-
193
- # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
194
-
195
- tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0]))
196
- tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1]))
197
- tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2]))
198
-
199
- tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
200
- tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
201
- tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
202
-
203
- else: # with rgb color
204
- tmpImg = np.zeros((image.shape[0],image.shape[1],3))
205
- image = image/np.max(image)
206
- if image.shape[2]==1:
207
- tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
208
- tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
209
- tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
210
- else:
211
- tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
212
- tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
213
- tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
214
-
215
- tmpLbl[:,:,0] = label[:,:,0]
216
-
217
-
218
- tmpImg = tmpImg.transpose((2, 0, 1))
219
- tmpLbl = label.transpose((2, 0, 1))
220
-
221
- return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
222
-
223
- class SalObjDataset(Dataset):
224
- def __init__(self,img_name_list,lbl_name_list,transform=None):
225
- # self.root_dir = root_dir
226
- # self.image_name_list = glob.glob(image_dir+'*.png')
227
- # self.label_name_list = glob.glob(label_dir+'*.png')
228
- self.image_name_list = img_name_list
229
- self.label_name_list = lbl_name_list
230
- self.transform = transform
231
-
232
- def __len__(self):
233
- return len(self.image_name_list)
234
-
235
- def __getitem__(self,idx):
236
-
237
- # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
238
- # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
239
-
240
- image = io.imread(self.image_name_list[idx])
241
- imname = self.image_name_list[idx]
242
- imidx = np.array([idx])
243
-
244
- if(0==len(self.label_name_list)):
245
- label_3 = np.zeros(image.shape)
246
- else:
247
- label_3 = io.imread(self.label_name_list[idx])
248
-
249
- label = np.zeros(label_3.shape[0:2])
250
- if(3==len(label_3.shape)):
251
- label = label_3[:,:,0]
252
- elif(2==len(label_3.shape)):
253
- label = label_3
254
-
255
- if(3==len(image.shape) and 2==len(label.shape)):
256
- label = label[:,:,np.newaxis]
257
- elif(2==len(image.shape) and 2==len(label.shape)):
258
- image = image[:,:,np.newaxis]
259
- label = label[:,:,np.newaxis]
260
-
261
- sample = {'imidx':imidx, 'image':image, 'label':label}
262
-
263
- if self.transform:
264
- sample = self.transform(sample)
265
-
266
- return sample
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
u2net/u2net.py DELETED
@@ -1,525 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- class REBNCONV(nn.Module):
6
- def __init__(self,in_ch=3,out_ch=3,dirate=1):
7
- super(REBNCONV,self).__init__()
8
-
9
- self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
10
- self.bn_s1 = nn.BatchNorm2d(out_ch)
11
- self.relu_s1 = nn.ReLU(inplace=True)
12
-
13
- def forward(self,x):
14
-
15
- hx = x
16
- xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
17
-
18
- return xout
19
-
20
- ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
21
- def _upsample_like(src,tar):
22
-
23
- src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
24
-
25
- return src
26
-
27
-
28
- ### RSU-7 ###
29
- class RSU7(nn.Module):#UNet07DRES(nn.Module):
30
-
31
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
32
- super(RSU7,self).__init__()
33
-
34
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
35
-
36
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
37
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
38
-
39
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
40
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
41
-
42
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
43
- self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
44
-
45
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
46
- self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
47
-
48
- self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
49
- self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
50
-
51
- self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
52
-
53
- self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
54
-
55
- self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
56
- self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
57
- self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
58
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
59
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
60
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
61
-
62
- def forward(self,x):
63
-
64
- hx = x
65
- hxin = self.rebnconvin(hx)
66
-
67
- hx1 = self.rebnconv1(hxin)
68
- hx = self.pool1(hx1)
69
-
70
- hx2 = self.rebnconv2(hx)
71
- hx = self.pool2(hx2)
72
-
73
- hx3 = self.rebnconv3(hx)
74
- hx = self.pool3(hx3)
75
-
76
- hx4 = self.rebnconv4(hx)
77
- hx = self.pool4(hx4)
78
-
79
- hx5 = self.rebnconv5(hx)
80
- hx = self.pool5(hx5)
81
-
82
- hx6 = self.rebnconv6(hx)
83
-
84
- hx7 = self.rebnconv7(hx6)
85
-
86
- hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
87
- hx6dup = _upsample_like(hx6d,hx5)
88
-
89
- hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
90
- hx5dup = _upsample_like(hx5d,hx4)
91
-
92
- hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
93
- hx4dup = _upsample_like(hx4d,hx3)
94
-
95
- hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
96
- hx3dup = _upsample_like(hx3d,hx2)
97
-
98
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
99
- hx2dup = _upsample_like(hx2d,hx1)
100
-
101
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
102
-
103
- return hx1d + hxin
104
-
105
- ### RSU-6 ###
106
- class RSU6(nn.Module):#UNet06DRES(nn.Module):
107
-
108
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
109
- super(RSU6,self).__init__()
110
-
111
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
112
-
113
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
114
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
115
-
116
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
117
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
118
-
119
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
120
- self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
121
-
122
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
123
- self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
124
-
125
- self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
126
-
127
- self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
128
-
129
- self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
130
- self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
131
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
132
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
133
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
134
-
135
- def forward(self,x):
136
-
137
- hx = x
138
-
139
- hxin = self.rebnconvin(hx)
140
-
141
- hx1 = self.rebnconv1(hxin)
142
- hx = self.pool1(hx1)
143
-
144
- hx2 = self.rebnconv2(hx)
145
- hx = self.pool2(hx2)
146
-
147
- hx3 = self.rebnconv3(hx)
148
- hx = self.pool3(hx3)
149
-
150
- hx4 = self.rebnconv4(hx)
151
- hx = self.pool4(hx4)
152
-
153
- hx5 = self.rebnconv5(hx)
154
-
155
- hx6 = self.rebnconv6(hx5)
156
-
157
-
158
- hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
159
- hx5dup = _upsample_like(hx5d,hx4)
160
-
161
- hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
162
- hx4dup = _upsample_like(hx4d,hx3)
163
-
164
- hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
165
- hx3dup = _upsample_like(hx3d,hx2)
166
-
167
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
168
- hx2dup = _upsample_like(hx2d,hx1)
169
-
170
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
171
-
172
- return hx1d + hxin
173
-
174
- ### RSU-5 ###
175
- class RSU5(nn.Module):#UNet05DRES(nn.Module):
176
-
177
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
178
- super(RSU5,self).__init__()
179
-
180
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
181
-
182
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
183
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
184
-
185
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
186
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
187
-
188
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
189
- self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
190
-
191
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
192
-
193
- self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
194
-
195
- self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
196
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
197
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
198
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
199
-
200
- def forward(self,x):
201
-
202
- hx = x
203
-
204
- hxin = self.rebnconvin(hx)
205
-
206
- hx1 = self.rebnconv1(hxin)
207
- hx = self.pool1(hx1)
208
-
209
- hx2 = self.rebnconv2(hx)
210
- hx = self.pool2(hx2)
211
-
212
- hx3 = self.rebnconv3(hx)
213
- hx = self.pool3(hx3)
214
-
215
- hx4 = self.rebnconv4(hx)
216
-
217
- hx5 = self.rebnconv5(hx4)
218
-
219
- hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
220
- hx4dup = _upsample_like(hx4d,hx3)
221
-
222
- hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
223
- hx3dup = _upsample_like(hx3d,hx2)
224
-
225
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
226
- hx2dup = _upsample_like(hx2d,hx1)
227
-
228
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
229
-
230
- return hx1d + hxin
231
-
232
- ### RSU-4 ###
233
- class RSU4(nn.Module):#UNet04DRES(nn.Module):
234
-
235
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
236
- super(RSU4,self).__init__()
237
-
238
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
239
-
240
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
241
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
242
-
243
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
244
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
245
-
246
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
247
-
248
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
249
-
250
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
251
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
252
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
253
-
254
- def forward(self,x):
255
-
256
- hx = x
257
-
258
- hxin = self.rebnconvin(hx)
259
-
260
- hx1 = self.rebnconv1(hxin)
261
- hx = self.pool1(hx1)
262
-
263
- hx2 = self.rebnconv2(hx)
264
- hx = self.pool2(hx2)
265
-
266
- hx3 = self.rebnconv3(hx)
267
-
268
- hx4 = self.rebnconv4(hx3)
269
-
270
- hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
271
- hx3dup = _upsample_like(hx3d,hx2)
272
-
273
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
274
- hx2dup = _upsample_like(hx2d,hx1)
275
-
276
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
277
-
278
- return hx1d + hxin
279
-
280
- ### RSU-4F ###
281
- class RSU4F(nn.Module):#UNet04FRES(nn.Module):
282
-
283
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
284
- super(RSU4F,self).__init__()
285
-
286
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
287
-
288
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
289
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
290
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
291
-
292
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
293
-
294
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
295
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
296
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
297
-
298
- def forward(self,x):
299
-
300
- hx = x
301
-
302
- hxin = self.rebnconvin(hx)
303
-
304
- hx1 = self.rebnconv1(hxin)
305
- hx2 = self.rebnconv2(hx1)
306
- hx3 = self.rebnconv3(hx2)
307
-
308
- hx4 = self.rebnconv4(hx3)
309
-
310
- hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
311
- hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
312
- hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
313
-
314
- return hx1d + hxin
315
-
316
-
317
- ##### U^2-Net ####
318
- class U2NET(nn.Module):
319
-
320
- def __init__(self,in_ch=3,out_ch=1):
321
- super(U2NET,self).__init__()
322
-
323
- self.stage1 = RSU7(in_ch,32,64)
324
- self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
325
-
326
- self.stage2 = RSU6(64,32,128)
327
- self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
328
-
329
- self.stage3 = RSU5(128,64,256)
330
- self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
331
-
332
- self.stage4 = RSU4(256,128,512)
333
- self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
334
-
335
- self.stage5 = RSU4F(512,256,512)
336
- self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
337
-
338
- self.stage6 = RSU4F(512,256,512)
339
-
340
- # decoder
341
- self.stage5d = RSU4F(1024,256,512)
342
- self.stage4d = RSU4(1024,128,256)
343
- self.stage3d = RSU5(512,64,128)
344
- self.stage2d = RSU6(256,32,64)
345
- self.stage1d = RSU7(128,16,64)
346
-
347
- self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
348
- self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
349
- self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
350
- self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
351
- self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
352
- self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
353
-
354
- self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
355
-
356
- def forward(self,x):
357
-
358
- hx = x
359
-
360
- #stage 1
361
- hx1 = self.stage1(hx)
362
- hx = self.pool12(hx1)
363
-
364
- #stage 2
365
- hx2 = self.stage2(hx)
366
- hx = self.pool23(hx2)
367
-
368
- #stage 3
369
- hx3 = self.stage3(hx)
370
- hx = self.pool34(hx3)
371
-
372
- #stage 4
373
- hx4 = self.stage4(hx)
374
- hx = self.pool45(hx4)
375
-
376
- #stage 5
377
- hx5 = self.stage5(hx)
378
- hx = self.pool56(hx5)
379
-
380
- #stage 6
381
- hx6 = self.stage6(hx)
382
- hx6up = _upsample_like(hx6,hx5)
383
-
384
- #-------------------- decoder --------------------
385
- hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
386
- hx5dup = _upsample_like(hx5d,hx4)
387
-
388
- hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
389
- hx4dup = _upsample_like(hx4d,hx3)
390
-
391
- hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
392
- hx3dup = _upsample_like(hx3d,hx2)
393
-
394
- hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
395
- hx2dup = _upsample_like(hx2d,hx1)
396
-
397
- hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
398
-
399
-
400
- #side output
401
- d1 = self.side1(hx1d)
402
-
403
- d2 = self.side2(hx2d)
404
- d2 = _upsample_like(d2,d1)
405
-
406
- d3 = self.side3(hx3d)
407
- d3 = _upsample_like(d3,d1)
408
-
409
- d4 = self.side4(hx4d)
410
- d4 = _upsample_like(d4,d1)
411
-
412
- d5 = self.side5(hx5d)
413
- d5 = _upsample_like(d5,d1)
414
-
415
- d6 = self.side6(hx6)
416
- d6 = _upsample_like(d6,d1)
417
-
418
- d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
419
-
420
- return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
421
-
422
- ### U^2-Net small ###
423
- class U2NETP(nn.Module):
424
-
425
- def __init__(self,in_ch=3,out_ch=1):
426
- super(U2NETP,self).__init__()
427
-
428
- self.stage1 = RSU7(in_ch,16,64)
429
- self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
430
-
431
- self.stage2 = RSU6(64,16,64)
432
- self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
433
-
434
- self.stage3 = RSU5(64,16,64)
435
- self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
436
-
437
- self.stage4 = RSU4(64,16,64)
438
- self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
439
-
440
- self.stage5 = RSU4F(64,16,64)
441
- self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
442
-
443
- self.stage6 = RSU4F(64,16,64)
444
-
445
- # decoder
446
- self.stage5d = RSU4F(128,16,64)
447
- self.stage4d = RSU4(128,16,64)
448
- self.stage3d = RSU5(128,16,64)
449
- self.stage2d = RSU6(128,16,64)
450
- self.stage1d = RSU7(128,16,64)
451
-
452
- self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
453
- self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
454
- self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
455
- self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
456
- self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
457
- self.side6 = nn.Conv2d(64,out_ch,3,padding=1)
458
-
459
- self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
460
-
461
- def forward(self,x):
462
-
463
- hx = x
464
-
465
- #stage 1
466
- hx1 = self.stage1(hx)
467
- hx = self.pool12(hx1)
468
-
469
- #stage 2
470
- hx2 = self.stage2(hx)
471
- hx = self.pool23(hx2)
472
-
473
- #stage 3
474
- hx3 = self.stage3(hx)
475
- hx = self.pool34(hx3)
476
-
477
- #stage 4
478
- hx4 = self.stage4(hx)
479
- hx = self.pool45(hx4)
480
-
481
- #stage 5
482
- hx5 = self.stage5(hx)
483
- hx = self.pool56(hx5)
484
-
485
- #stage 6
486
- hx6 = self.stage6(hx)
487
- hx6up = _upsample_like(hx6,hx5)
488
-
489
- #decoder
490
- hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
491
- hx5dup = _upsample_like(hx5d,hx4)
492
-
493
- hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
494
- hx4dup = _upsample_like(hx4d,hx3)
495
-
496
- hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
497
- hx3dup = _upsample_like(hx3d,hx2)
498
-
499
- hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
500
- hx2dup = _upsample_like(hx2d,hx1)
501
-
502
- hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
503
-
504
-
505
- #side output
506
- d1 = self.side1(hx1d)
507
-
508
- d2 = self.side2(hx2d)
509
- d2 = _upsample_like(d2,d1)
510
-
511
- d3 = self.side3(hx3d)
512
- d3 = _upsample_like(d3,d1)
513
-
514
- d4 = self.side4(hx4d)
515
- d4 = _upsample_like(d4,d1)
516
-
517
- d5 = self.side5(hx5d)
518
- d5 = _upsample_like(d5,d1)
519
-
520
- d6 = self.side6(hx6)
521
- d6 = _upsample_like(d6,d1)
522
-
523
- d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
524
-
525
- return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
u2net/u2net_inference.py DELETED
@@ -1,100 +0,0 @@
1
- import os
2
- from typing import Union
3
- from skimage import io, transform
4
- import torch
5
- import torchvision
6
- from torch.autograd import Variable
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from torch.utils.data import Dataset, DataLoader
10
- from torchvision import transforms#, utils
11
- # import torch.optim as optim
12
-
13
- import numpy as np
14
- from PIL import Image
15
- import glob
16
-
17
- from .data_loader import RescaleT
18
- from .data_loader import ToTensor
19
- from .data_loader import ToTensorLab
20
- from .data_loader import SalObjDataset
21
-
22
- from .u2net import U2NET # full size version 173.6 MB
23
- from .u2net import U2NETP # small version u2net 4.7 MB
24
-
25
-
26
- # normalize the predicted SOD probability map
27
- def normPRED(d):
28
- ma = torch.max(d)
29
- mi = torch.min(d)
30
-
31
- dn = (d-mi)/(ma-mi)
32
-
33
- return dn
34
-
35
- def save_output(image_name,pred,d_dir):
36
-
37
- predict = pred
38
- predict = predict.squeeze()
39
- predict_np = predict.cpu().data.numpy()
40
-
41
- im = Image.fromarray(predict_np*255).convert('RGB')
42
- img_name = image_name.split(os.sep)[-1]
43
- image = io.imread(image_name)
44
- imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
45
-
46
- pb_np = np.array(imo)
47
-
48
- aaa = img_name.split(".")
49
- bbb = aaa[0:-1]
50
- imidx = bbb[0]
51
- for i in range(1,len(bbb)):
52
- imidx = imidx + "." + bbb[i]
53
-
54
- imo.save(d_dir+imidx+'.png')
55
-
56
-
57
- def get_u2net_model():
58
- model_pth = "models/u2net.pth"
59
- net = U2NET(3,1)
60
-
61
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
62
- net.load_state_dict(torch.load(model_pth, map_location=device))
63
- net.eval()
64
-
65
- return net
66
-
67
-
68
- def get_saliency_mask(model, image_or_image_path : Union[str, np.array]):
69
-
70
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
71
-
72
- if isinstance(image_or_image_path, str):
73
- image = io.imread(image_or_image_path)
74
- else:
75
- image = image_or_image_path
76
-
77
- transform = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
78
- sample = transform({
79
- 'imidx' : np.array([0]),
80
- 'image' : image,
81
- 'label' : np.expand_dims(np.zeros(image.shape[:-1]), -1)
82
- })
83
-
84
- input_test = sample["image"].unsqueeze(0).type(torch.FloatTensor).to(device)
85
-
86
- d1,d2,d3,d4,d5,d6,d7= model(input_test)
87
-
88
- pred = d1[:,0,:,:]
89
- pred = normPRED(pred)
90
-
91
- pred = pred.squeeze()
92
- predict_np = pred.cpu().data.numpy()
93
-
94
- rescaled = predict_np
95
- rescaled = rescaled - np.min(rescaled)
96
- rescaled = rescaled / np.max(rescaled)
97
-
98
- im = Image.fromarray(rescaled * 255).convert("RGB")
99
-
100
- return im