jamino30 commited on
Commit
e6f200a
1 Parent(s): cc9f69c

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. app.py +1 -1
  2. u2net/model.py +151 -0
  3. u2net/train.py +1 -0
  4. vgg/vgg16.py +72 -0
  5. vgg/vgg19.py +78 -0
app.py CHANGED
@@ -11,7 +11,7 @@ import gradio as gr
11
  from gradio_imageslider import ImageSlider
12
 
13
  from utils import preprocess_img, preprocess_img_from_path, postprocess_img
14
- from vgg19 import VGG_19
15
  from inference import inference
16
 
17
  if torch.cuda.is_available(): device = 'cuda'
 
11
  from gradio_imageslider import ImageSlider
12
 
13
  from utils import preprocess_img, preprocess_img_from_path, postprocess_img
14
+ from vgg.vgg19 import VGG_19
15
  from inference import inference
16
 
17
  if torch.cuda.is_available(): device = 'cuda'
u2net/model.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ def init_weight(layer):
6
+ nn.init.xavier_uniform_(layer.weight)
7
+ if layer.bias is not None:
8
+ nn.init.constant_(layer.bias, 0)
9
+
10
+
11
+ class ConvBlock(nn.Module):
12
+ def __init__(self, in_channel, out_channel, dilation=1):
13
+ super(ConvBlock, self).__init__()
14
+ self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=dilation, dilation=dilation)
15
+ self.bn = nn.BatchNorm2d(out_channel)
16
+ self.relu = nn.ReLU(inplace=True)
17
+ init_weight(self.conv)
18
+
19
+ def forward(self, x):
20
+ x = self.conv(x)
21
+ x = self.bn(x)
22
+ x = self.relu(x)
23
+ return x
24
+
25
+
26
+ class RSU(nn.Module):
27
+ def __init__(self, L, C_in, C_out, M):
28
+ super(RSU, self).__init__()
29
+ self.conv = ConvBlock(C_in, C_out)
30
+
31
+ self.enc = nn.ModuleList([ConvBlock(C_out, M)])
32
+ for i in range(L-2):
33
+ self.enc.append(ConvBlock(M, M))
34
+
35
+ self.mid = ConvBlock(M, M, dilation=2)
36
+
37
+ self.dec = nn.ModuleList([ConvBlock(2*M, M) for _ in range(L-2)])
38
+ self.dec.append(ConvBlock(2*M, C_out))
39
+
40
+ self.downsample = nn.MaxPool2d(2, stride=2)
41
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
42
+
43
+ def forward(self, x):
44
+ x = self.conv(x)
45
+
46
+ out = []
47
+ for i, enc in enumerate(self.enc):
48
+ if i == 0: out.append(enc(x))
49
+ else: out.append(enc(self.downsample(out[i-1])))
50
+
51
+ y = self.mid(out[-1])
52
+
53
+ for i, dec in enumerate(self.dec):
54
+ if i > 0: y = self.upsample(y)
55
+ y = dec(torch.cat((out[len(self.dec)-i-1], y), dim=1))
56
+
57
+ return x + y
58
+
59
+
60
+ class RSU4F(nn.Module):
61
+ def __init__(self, C_in, C_out, M):
62
+ super(RSU4F, self).__init__()
63
+ self.conv = ConvBlock(C_in, C_out)
64
+
65
+ self.enc = nn.ModuleList([
66
+ ConvBlock(C_out, M),
67
+ ConvBlock(M, M, dilation=2),
68
+ ConvBlock(M, M, dilation=4)
69
+ ])
70
+
71
+ self.mid = ConvBlock(M, M, dilation=8)
72
+
73
+ self.dec = nn.ModuleList([
74
+ ConvBlock(2*M, M, dilation=4),
75
+ ConvBlock(2*M, M, dilation=2),
76
+ ConvBlock(2*M, C_out)
77
+ ])
78
+
79
+ def forward(self, x):
80
+ x = self.conv(x)
81
+
82
+ out = []
83
+ for i, enc in enumerate(self.enc):
84
+ if i == 0: out.append(enc(x))
85
+ else: out.append(enc(out[i-1]))
86
+
87
+ y = self.mid(out[-1])
88
+
89
+ for i, dec in enumerate(self.dec):
90
+ y = dec(torch.cat((out[len(self.dec)-i-1], y), dim=1))
91
+
92
+ return x + y
93
+
94
+
95
+ class U2Net(nn.Module):
96
+ def __init__(self):
97
+ super(U2Net, self).__init__()
98
+ self.enc = nn.ModuleList([
99
+ RSU(L=7, C_in=3, C_out=64, M=32),
100
+ RSU(L=6, C_in=64, C_out=128, M=32),
101
+ RSU(L=5, C_in=128, C_out=256, M=64),
102
+ RSU(L=4, C_in=256, C_out=512, M=128),
103
+ RSU4F(C_in=512, C_out=512, M=256),
104
+ RSU4F(C_in=512, C_out=512, M=256)
105
+ ])
106
+
107
+ self.dec = nn.ModuleList([
108
+ RSU4F(C_in=1024, C_out=512, M=256),
109
+ RSU(L=4, C_in=1024, C_out=256, M=128),
110
+ RSU(L=5, C_in=512, C_out=128, M=64),
111
+ RSU(L=6, C_in=256, C_out=64, M=32),
112
+ RSU(L=7, C_in=128, C_out=64, M=16)
113
+ ])
114
+
115
+ self.convs = nn.ModuleList([
116
+ nn.Conv2d(64, 1, 3, padding=1),
117
+ nn.Conv2d(64, 1, 3, padding=1),
118
+ nn.Conv2d(128, 1, 3, padding=1),
119
+ nn.Conv2d(256, 1, 3, padding=1),
120
+ nn.Conv2d(512, 1, 3, padding=1),
121
+ nn.Conv2d(512, 1, 3, padding=1)
122
+ ])
123
+
124
+ self.lastconv = nn.Conv2d(6, 1, 1)
125
+ self.downsample = nn.MaxPool2d(2, stride=2)
126
+
127
+ init_weight(self.lastconv)
128
+ for conv in self.convs:
129
+ init_weight(conv)
130
+
131
+ def upsample(self, x, target):
132
+ return F.interpolate(x, size=target.shape[2:], mode='bilinear')
133
+
134
+ def forward(self, x):
135
+ enc_out = []
136
+ for i, enc in enumerate(self.enc):
137
+ if i == 0: enc_out.append(enc(x))
138
+ else: enc_out.append(enc(self.downsample(enc_out[i-1])))
139
+
140
+ dec_out = [enc_out[-1]]
141
+ for i, dec in enumerate(self.dec):
142
+ dec_out.append(dec(torch.cat((self.upsample(dec_out[i], enc_out[4-i]), enc_out[4-i]), dim=1)))
143
+
144
+ side_out = []
145
+ for i, conv in enumerate(self.convs):
146
+ if i == 0: side_out.append(conv(dec_out[5]))
147
+ else: side_out.append(self.upsample(conv(dec_out[5-i]), side_out[0]))
148
+
149
+ side_out.append(self.lastconv(torch.cat(side_out, dim=1)))
150
+
151
+ return [torch.sigmoid(s.squeeze(1)) for s in side_out]
u2net/train.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # for training u2net
vgg/vgg16.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torchvision.models as models
3
+
4
+ """ VGG_16 Architecture
5
+ VGG(
6
+ (features): Sequential(
7
+ (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
8
+ (1): ReLU(inplace=True)
9
+ (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
10
+ (3): ReLU(inplace=True)
11
+ (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
12
+ (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
13
+ (6): ReLU(inplace=True)
14
+ (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
15
+ (8): ReLU(inplace=True)
16
+ (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
17
+ (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
18
+ (11): ReLU(inplace=True)
19
+ (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
20
+ (13): ReLU(inplace=True)
21
+ (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
22
+ (15): ReLU(inplace=True)
23
+ (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
24
+ (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
25
+ (18): ReLU(inplace=True)
26
+ (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
27
+ (20): ReLU(inplace=True)
28
+ (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
29
+ (22): ReLU(inplace=True)
30
+ (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
31
+ (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
32
+ (25): ReLU(inplace=True)
33
+ (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
34
+ (27): ReLU(inplace=True)
35
+ (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
36
+ (29): ReLU(inplace=True)
37
+ (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
38
+ )
39
+ (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
40
+ (classifier): Sequential(
41
+ (0): Linear(in_features=25088, out_features=4096, bias=True)
42
+ (1): ReLU(inplace=True)
43
+ (2): Dropout(p=0.5, inplace=False)
44
+ (3): Linear(in_features=4096, out_features=4096, bias=True)
45
+ (4): ReLU(inplace=True)
46
+ (5): Dropout(p=0.5, inplace=False)
47
+ (6): Linear(in_features=4096, out_features=1000, bias=True)
48
+ )
49
+ )
50
+ """
51
+
52
+ class VGG_16(nn.Module):
53
+ def __init__(self):
54
+ super(VGG_16, self).__init__()
55
+ self.model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features[:30]
56
+
57
+ for i, _ in enumerate(self.model):
58
+ if i in [4, 9, 16, 23]:
59
+ self.model[i] = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
60
+
61
+ def forward(self, x):
62
+ features = []
63
+ for i, layer in enumerate(self.model):
64
+ x = layer(x)
65
+ if i in [0, 5, 10, 17, 24]:
66
+ features.append(x)
67
+ return features
68
+
69
+
70
+ if __name__ == '__main__':
71
+ model = VGG_16()
72
+ print(model)
vgg/vgg19.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torchvision.models as models
3
+
4
+ """ VGG_19 Architecture
5
+ VGG(
6
+ (features): Sequential(
7
+ (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
8
+ (1): ReLU(inplace=True)
9
+ (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
10
+ (3): ReLU(inplace=True)
11
+ (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
12
+ (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
13
+ (6): ReLU(inplace=True)
14
+ (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
15
+ (8): ReLU(inplace=True)
16
+ (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
17
+ (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
18
+ (11): ReLU(inplace=True)
19
+ (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
20
+ (13): ReLU(inplace=True)
21
+ (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
22
+ (15): ReLU(inplace=True)
23
+ (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
24
+ (17): ReLU(inplace=True)
25
+ (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
26
+ (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
27
+ (20): ReLU(inplace=True)
28
+ (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
29
+ (22): ReLU(inplace=True)
30
+ (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
31
+ (24): ReLU(inplace=True)
32
+ (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
33
+ (26): ReLU(inplace=True)
34
+ (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
35
+ (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
36
+ (29): ReLU(inplace=True)
37
+ (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
38
+ (31): ReLU(inplace=True)
39
+ (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
40
+ (33): ReLU(inplace=True)
41
+ (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
42
+ (35): ReLU(inplace=True)
43
+ (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
44
+ )
45
+ (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
46
+ (classifier): Sequential(
47
+ (0): Linear(in_features=25088, out_features=4096, bias=True)
48
+ (1): ReLU(inplace=True)
49
+ (2): Dropout(p=0.5, inplace=False)
50
+ (3): Linear(in_features=4096, out_features=4096, bias=True)
51
+ (4): ReLU(inplace=True)
52
+ (5): Dropout(p=0.5, inplace=False)
53
+ (6): Linear(in_features=4096, out_features=1000, bias=True)
54
+ )
55
+ )
56
+ """
57
+
58
+ class VGG_19(nn.Module):
59
+ def __init__(self):
60
+ super(VGG_19, self).__init__()
61
+ self.model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features[:30]
62
+
63
+ for i, _ in enumerate(self.model):
64
+ if i in [4, 9, 18, 27]:
65
+ self.model[i] = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
66
+
67
+ def forward(self, x):
68
+ features = []
69
+ for i, layer in enumerate(self.model):
70
+ x = layer(x)
71
+ if i in [0, 5, 10, 19, 28]:
72
+ features.append(x)
73
+ return features
74
+
75
+
76
+ if __name__ == '__main__':
77
+ model = VGG_19()
78
+ print(model)