OriLib commited on
Commit
18baaa5
1 Parent(s): 5ed1996

Delete models/isnet.py

Browse files
Files changed (1) hide show
  1. models/isnet.py +0 -611
models/isnet.py DELETED
@@ -1,611 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torchvision import models
4
- import torch.nn.functional as F
5
-
6
-
7
- bce_loss = nn.BCELoss(size_average=True)
8
- def muti_loss_fusion(preds, target):
9
- loss0 = 0.0
10
- loss = 0.0
11
-
12
- for i in range(0,len(preds)):
13
- # print("i: ", i, preds[i].shape)
14
- if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
15
- # tmp_target = _upsample_like(target,preds[i])
16
- tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
17
- loss = loss + bce_loss(preds[i],tmp_target)
18
- else:
19
- loss = loss + bce_loss(preds[i],target)
20
- if(i==0):
21
- loss0 = loss
22
- return loss0, loss
23
-
24
- fea_loss = nn.MSELoss(size_average=True)
25
- kl_loss = nn.KLDivLoss(size_average=True)
26
- l1_loss = nn.L1Loss(size_average=True)
27
- smooth_l1_loss = nn.SmoothL1Loss(size_average=True)
28
- def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
29
- loss0 = 0.0
30
- loss = 0.0
31
-
32
- for i in range(0,len(preds)):
33
- # print("i: ", i, preds[i].shape)
34
- if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
35
- # tmp_target = _upsample_like(target,preds[i])
36
- tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
37
- loss = loss + bce_loss(preds[i],tmp_target)
38
- else:
39
- loss = loss + bce_loss(preds[i],target)
40
- if(i==0):
41
- loss0 = loss
42
-
43
- for i in range(0,len(dfs)):
44
- if(mode=='MSE'):
45
- loss = loss + fea_loss(dfs[i],fs[i]) ### add the mse loss of features as additional constraints
46
- # print("fea_loss: ", fea_loss(dfs[i],fs[i]).item())
47
- elif(mode=='KL'):
48
- loss = loss + kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1))
49
- # print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item())
50
- elif(mode=='MAE'):
51
- loss = loss + l1_loss(dfs[i],fs[i])
52
- # print("ls_loss: ", l1_loss(dfs[i],fs[i]))
53
- elif(mode=='SmoothL1'):
54
- loss = loss + smooth_l1_loss(dfs[i],fs[i])
55
- # print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item())
56
-
57
- return loss0, loss
58
-
59
- class REBNCONV(nn.Module):
60
- def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
61
- super(REBNCONV,self).__init__()
62
-
63
- self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
64
- self.bn_s1 = nn.BatchNorm2d(out_ch)
65
- self.relu_s1 = nn.ReLU(inplace=True)
66
-
67
- def forward(self,x):
68
-
69
- hx = x
70
- xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
71
-
72
- return xout
73
-
74
- ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
75
- def _upsample_like(src,tar):
76
-
77
- src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
78
-
79
- return src
80
-
81
-
82
- ### RSU-7 ###
83
- class RSU7(nn.Module):
84
-
85
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
86
- super(RSU7,self).__init__()
87
-
88
- self.in_ch = in_ch
89
- self.mid_ch = mid_ch
90
- self.out_ch = out_ch
91
-
92
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
93
-
94
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
95
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
96
-
97
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
98
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
99
-
100
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
101
- self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
102
-
103
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
104
- self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
105
-
106
- self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
107
- self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
108
-
109
- self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
110
-
111
- self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
112
-
113
- self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
114
- self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
115
- self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
116
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
117
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
118
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
119
-
120
- def forward(self,x):
121
- b, c, h, w = x.shape
122
-
123
- hx = x
124
- hxin = self.rebnconvin(hx)
125
-
126
- hx1 = self.rebnconv1(hxin)
127
- hx = self.pool1(hx1)
128
-
129
- hx2 = self.rebnconv2(hx)
130
- hx = self.pool2(hx2)
131
-
132
- hx3 = self.rebnconv3(hx)
133
- hx = self.pool3(hx3)
134
-
135
- hx4 = self.rebnconv4(hx)
136
- hx = self.pool4(hx4)
137
-
138
- hx5 = self.rebnconv5(hx)
139
- hx = self.pool5(hx5)
140
-
141
- hx6 = self.rebnconv6(hx)
142
-
143
- hx7 = self.rebnconv7(hx6)
144
-
145
- hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
146
- hx6dup = _upsample_like(hx6d,hx5)
147
-
148
- hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
149
- hx5dup = _upsample_like(hx5d,hx4)
150
-
151
- hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
152
- hx4dup = _upsample_like(hx4d,hx3)
153
-
154
- hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
155
- hx3dup = _upsample_like(hx3d,hx2)
156
-
157
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
158
- hx2dup = _upsample_like(hx2d,hx1)
159
-
160
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
161
-
162
- return hx1d + hxin
163
-
164
-
165
- ### RSU-6 ###
166
- class RSU6(nn.Module):
167
-
168
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
169
- super(RSU6,self).__init__()
170
-
171
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
172
-
173
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
174
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
175
-
176
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
177
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
178
-
179
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
180
- self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
181
-
182
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
183
- self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
184
-
185
- self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
186
-
187
- self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
188
-
189
- self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
190
- self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
191
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
192
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
193
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
194
-
195
- def forward(self,x):
196
-
197
- hx = x
198
-
199
- hxin = self.rebnconvin(hx)
200
-
201
- hx1 = self.rebnconv1(hxin)
202
- hx = self.pool1(hx1)
203
-
204
- hx2 = self.rebnconv2(hx)
205
- hx = self.pool2(hx2)
206
-
207
- hx3 = self.rebnconv3(hx)
208
- hx = self.pool3(hx3)
209
-
210
- hx4 = self.rebnconv4(hx)
211
- hx = self.pool4(hx4)
212
-
213
- hx5 = self.rebnconv5(hx)
214
-
215
- hx6 = self.rebnconv6(hx5)
216
-
217
-
218
- hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
219
- hx5dup = _upsample_like(hx5d,hx4)
220
-
221
- hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
222
- hx4dup = _upsample_like(hx4d,hx3)
223
-
224
- hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
225
- hx3dup = _upsample_like(hx3d,hx2)
226
-
227
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
228
- hx2dup = _upsample_like(hx2d,hx1)
229
-
230
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
231
-
232
- return hx1d + hxin
233
-
234
- ### RSU-5 ###
235
- class RSU5(nn.Module):
236
-
237
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
238
- super(RSU5,self).__init__()
239
-
240
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
241
-
242
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
243
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
244
-
245
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
246
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
247
-
248
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
249
- self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
250
-
251
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
252
-
253
- self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
254
-
255
- self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
256
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
257
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
258
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
259
-
260
- def forward(self,x):
261
-
262
- hx = x
263
-
264
- hxin = self.rebnconvin(hx)
265
-
266
- hx1 = self.rebnconv1(hxin)
267
- hx = self.pool1(hx1)
268
-
269
- hx2 = self.rebnconv2(hx)
270
- hx = self.pool2(hx2)
271
-
272
- hx3 = self.rebnconv3(hx)
273
- hx = self.pool3(hx3)
274
-
275
- hx4 = self.rebnconv4(hx)
276
-
277
- hx5 = self.rebnconv5(hx4)
278
-
279
- hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
280
- hx4dup = _upsample_like(hx4d,hx3)
281
-
282
- hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
283
- hx3dup = _upsample_like(hx3d,hx2)
284
-
285
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
286
- hx2dup = _upsample_like(hx2d,hx1)
287
-
288
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
289
-
290
- return hx1d + hxin
291
-
292
- ### RSU-4 ###
293
- class RSU4(nn.Module):
294
-
295
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
296
- super(RSU4,self).__init__()
297
-
298
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
299
-
300
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
301
- self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
302
-
303
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
304
- self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
305
-
306
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
307
-
308
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
309
-
310
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
311
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
312
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
313
-
314
- def forward(self,x):
315
-
316
- hx = x
317
-
318
- hxin = self.rebnconvin(hx)
319
-
320
- hx1 = self.rebnconv1(hxin)
321
- hx = self.pool1(hx1)
322
-
323
- hx2 = self.rebnconv2(hx)
324
- hx = self.pool2(hx2)
325
-
326
- hx3 = self.rebnconv3(hx)
327
-
328
- hx4 = self.rebnconv4(hx3)
329
-
330
- hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
331
- hx3dup = _upsample_like(hx3d,hx2)
332
-
333
- hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
334
- hx2dup = _upsample_like(hx2d,hx1)
335
-
336
- hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
337
-
338
- return hx1d + hxin
339
-
340
- ### RSU-4F ###
341
- class RSU4F(nn.Module):
342
-
343
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
344
- super(RSU4F,self).__init__()
345
-
346
- self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
347
-
348
- self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
349
- self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
350
- self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
351
-
352
- self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
353
-
354
- self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
355
- self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
356
- self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
357
-
358
- def forward(self,x):
359
-
360
- hx = x
361
-
362
- hxin = self.rebnconvin(hx)
363
-
364
- hx1 = self.rebnconv1(hxin)
365
- hx2 = self.rebnconv2(hx1)
366
- hx3 = self.rebnconv3(hx2)
367
-
368
- hx4 = self.rebnconv4(hx3)
369
-
370
- hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
371
- hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
372
- hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
373
-
374
- return hx1d + hxin
375
-
376
-
377
- class myrebnconv(nn.Module):
378
- def __init__(self, in_ch=3,
379
- out_ch=1,
380
- kernel_size=3,
381
- stride=1,
382
- padding=1,
383
- dilation=1,
384
- groups=1):
385
- super(myrebnconv,self).__init__()
386
-
387
- self.conv = nn.Conv2d(in_ch,
388
- out_ch,
389
- kernel_size=kernel_size,
390
- stride=stride,
391
- padding=padding,
392
- dilation=dilation,
393
- groups=groups)
394
- self.bn = nn.BatchNorm2d(out_ch)
395
- self.rl = nn.ReLU(inplace=True)
396
-
397
- def forward(self,x):
398
- return self.rl(self.bn(self.conv(x)))
399
-
400
-
401
- class ISNetGTEncoder(nn.Module):
402
-
403
- def __init__(self,in_ch=1,out_ch=1):
404
- super(ISNetGTEncoder,self).__init__()
405
-
406
- self.conv_in = myrebnconv(in_ch,16,3,stride=2,padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
407
-
408
- self.stage1 = RSU7(16,16,64)
409
- self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
410
-
411
- self.stage2 = RSU6(64,16,64)
412
- self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
413
-
414
- self.stage3 = RSU5(64,32,128)
415
- self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
416
-
417
- self.stage4 = RSU4(128,32,256)
418
- self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
419
-
420
- self.stage5 = RSU4F(256,64,512)
421
- self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
422
-
423
- self.stage6 = RSU4F(512,64,512)
424
-
425
-
426
- self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
427
- self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
428
- self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
429
- self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
430
- self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
431
- self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
432
-
433
- def compute_loss(self, preds, targets):
434
-
435
- return muti_loss_fusion(preds,targets)
436
-
437
- def forward(self,x):
438
-
439
- hx = x
440
-
441
- hxin = self.conv_in(hx)
442
- # hx = self.pool_in(hxin)
443
-
444
- #stage 1
445
- hx1 = self.stage1(hxin)
446
- hx = self.pool12(hx1)
447
-
448
- #stage 2
449
- hx2 = self.stage2(hx)
450
- hx = self.pool23(hx2)
451
-
452
- #stage 3
453
- hx3 = self.stage3(hx)
454
- hx = self.pool34(hx3)
455
-
456
- #stage 4
457
- hx4 = self.stage4(hx)
458
- hx = self.pool45(hx4)
459
-
460
- #stage 5
461
- hx5 = self.stage5(hx)
462
- hx = self.pool56(hx5)
463
-
464
- #stage 6
465
- hx6 = self.stage6(hx)
466
-
467
-
468
- #side output
469
- d1 = self.side1(hx1)
470
- d1 = _upsample_like(d1,x)
471
-
472
- d2 = self.side2(hx2)
473
- d2 = _upsample_like(d2,x)
474
-
475
- d3 = self.side3(hx3)
476
- d3 = _upsample_like(d3,x)
477
-
478
- d4 = self.side4(hx4)
479
- d4 = _upsample_like(d4,x)
480
-
481
- d5 = self.side5(hx5)
482
- d5 = _upsample_like(d5,x)
483
-
484
- d6 = self.side6(hx6)
485
- d6 = _upsample_like(d6,x)
486
-
487
- # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
488
-
489
- return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1,hx2,hx3,hx4,hx5,hx6]
490
-
491
- class ISNetDIS(nn.Module):
492
-
493
- def __init__(self,in_ch=3,out_ch=1):
494
- super(ISNetDIS,self).__init__()
495
-
496
- self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
497
- self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
498
-
499
- self.stage1 = RSU7(64,32,64)
500
- self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
501
-
502
- self.stage2 = RSU6(64,32,128)
503
- self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
504
-
505
- self.stage3 = RSU5(128,64,256)
506
- self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
507
-
508
- self.stage4 = RSU4(256,128,512)
509
- self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
510
-
511
- self.stage5 = RSU4F(512,256,512)
512
- self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
513
-
514
- self.stage6 = RSU4F(512,256,512)
515
-
516
- # decoder
517
- self.stage5d = RSU4F(1024,256,512)
518
- self.stage4d = RSU4(1024,128,256)
519
- self.stage3d = RSU5(512,64,128)
520
- self.stage2d = RSU6(256,32,64)
521
- self.stage1d = RSU7(128,16,64)
522
-
523
- self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
524
- self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
525
- self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
526
- self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
527
- self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
528
- self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
529
-
530
- # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
531
-
532
- def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
533
-
534
- # return muti_loss_fusion(preds,targets)
535
- return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
536
-
537
- def compute_loss(self, preds, targets):
538
-
539
- # return muti_loss_fusion(preds,targets)
540
- return muti_loss_fusion(preds, targets)
541
-
542
- def forward(self,x):
543
-
544
- hx = x
545
-
546
- hxin = self.conv_in(hx)
547
- #hx = self.pool_in(hxin)
548
-
549
- #stage 1
550
- hx1 = self.stage1(hxin)
551
- hx = self.pool12(hx1)
552
-
553
- #stage 2
554
- hx2 = self.stage2(hx)
555
- hx = self.pool23(hx2)
556
-
557
- #stage 3
558
- hx3 = self.stage3(hx)
559
- hx = self.pool34(hx3)
560
-
561
- #stage 4
562
- hx4 = self.stage4(hx)
563
- hx = self.pool45(hx4)
564
-
565
- #stage 5
566
- hx5 = self.stage5(hx)
567
- hx = self.pool56(hx5)
568
-
569
- #stage 6
570
- hx6 = self.stage6(hx)
571
- hx6up = _upsample_like(hx6,hx5)
572
-
573
- #-------------------- decoder --------------------
574
- hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
575
- hx5dup = _upsample_like(hx5d,hx4)
576
-
577
- hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
578
- hx4dup = _upsample_like(hx4d,hx3)
579
-
580
- hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
581
- hx3dup = _upsample_like(hx3d,hx2)
582
-
583
- hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
584
- hx2dup = _upsample_like(hx2d,hx1)
585
-
586
- hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
587
-
588
-
589
- #side output
590
- d1 = self.side1(hx1d)
591
- d1 = _upsample_like(d1,x)
592
-
593
- d2 = self.side2(hx2d)
594
- d2 = _upsample_like(d2,x)
595
-
596
- d3 = self.side3(hx3d)
597
- d3 = _upsample_like(d3,x)
598
-
599
- d4 = self.side4(hx4d)
600
- d4 = _upsample_like(d4,x)
601
-
602
- d5 = self.side5(hx5d)
603
- d5 = _upsample_like(d5,x)
604
-
605
- d6 = self.side6(hx6)
606
- d6 = _upsample_like(d6,x)
607
-
608
- # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
609
-
610
- return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
611
- # return F.sigmoid(d1)