Zheng-MJ commited on
Commit
9436731
·
verified ·
1 Parent(s): b9029bc

Create models/smfanet_arch.py

Browse files
Files changed (1) hide show
  1. models/smfanet_arch.py +102 -0
models/smfanet_arch.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import ops
5
+
6
+ class DMlp(nn.Module):
7
+ def __init__(self, dim, growth_rate=2.0):
8
+ super().__init__()
9
+ hidden_dim = int(dim * growth_rate)
10
+ self.conv_0 = nn.Sequential(
11
+ nn.Conv2d(dim,hidden_dim,3,1,1,groups=dim),
12
+ nn.Conv2d(hidden_dim,hidden_dim,1,1,0)
13
+ )
14
+ self.act =nn.GELU()
15
+ self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)
16
+
17
+ def forward(self, x):
18
+ x = self.conv_0(x)
19
+ x = self.act(x)
20
+ x = self.conv_1(x)
21
+ return x
22
+
23
+ class PCFN(nn.Module):
24
+ def __init__(self, dim, growth_rate=2.0, p_rate=0.25):
25
+ super().__init__()
26
+ hidden_dim = int(dim * growth_rate)
27
+ p_dim = int(hidden_dim * p_rate)
28
+ self.conv_0 = nn.Conv2d(dim,hidden_dim,1,1,0)
29
+ self.conv_1 = nn.Conv2d(p_dim, p_dim ,3,1,1)
30
+
31
+ self.act =nn.GELU()
32
+ self.conv_2 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)
33
+
34
+ self.p_dim = p_dim
35
+ self.hidden_dim = hidden_dim
36
+
37
+ def forward(self, x):
38
+ if self.training:
39
+ x = self.act(self.conv_0(x))
40
+ x1, x2 = torch.split(x,[self.p_dim,self.hidden_dim-self.p_dim],dim=1)
41
+ x1 = self.act(self.conv_1(x1))
42
+ x = self.conv_2(torch.cat([x1,x2], dim=1))
43
+ else:
44
+ x = self.act(self.conv_0(x))
45
+ x[:,:self.p_dim,:,:] = self.act(self.conv_1(x[:,:self.p_dim,:,:]))
46
+ x = self.conv_2(x)
47
+ return x
48
+
49
+ class SMFA(nn.Module):
50
+ def __init__(self, dim=36):
51
+ super(SMFA, self).__init__()
52
+ self.linear_0 = nn.Conv2d(dim,dim*2,1,1,0)
53
+ self.linear_1 = nn.Conv2d(dim,dim,1,1,0)
54
+ self.linear_2 = nn.Conv2d(dim,dim,1,1,0)
55
+
56
+ self.lde = DMlp(dim,2)
57
+
58
+ self.dw_conv = nn.Conv2d(dim,dim,3,1,1,groups=dim)
59
+
60
+ self.gelu = nn.GELU()
61
+ self.down_scale = 8
62
+
63
+ self.alpha = nn.Parameter(torch.ones((1,dim,1,1)))
64
+ self.belt = nn.Parameter(torch.zeros((1,dim,1,1)))
65
+
66
+ def forward(self, f):
67
+ _,_,h,w = f.shape
68
+ y, x = self.linear_0(f).chunk(2, dim=1)
69
+ x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale)))
70
+ x_v = torch.var(x, dim=(-2,-1), keepdim=True)
71
+ x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h,w), mode='nearest')
72
+ y_d = self.lde(y)
73
+ return self.linear_2(x_l + y_d)
74
+
75
+ class FMB(nn.Module):
76
+ def __init__(self, dim, ffn_scale=2.0):
77
+ super().__init__()
78
+
79
+ self.smfa = SMFA(dim)
80
+ self.pcfn = PCFN(dim, ffn_scale)
81
+
82
+ def forward(self, x):
83
+ x = self.smfa(F.normalize(x)) + x
84
+ x = self.pcfn(F.normalize(x)) + x
85
+ return x
86
+
87
+
88
+ class SMFANet(nn.Module):
89
+ def __init__(self, dim=36, n_blocks=8, ffn_scale=2, upscaling_factor=4):
90
+ super().__init__()
91
+ self.scale = upscaling_factor
92
+ self.to_feat = nn.Conv2d(3, dim, 3, 1, 1)
93
+ self.feats = nn.Sequential(*[FMB(dim, ffn_scale) for _ in range(n_blocks)])
94
+ self.to_img = nn.Sequential(
95
+ nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1),
96
+ nn.PixelShuffle(upscaling_factor)
97
+ )
98
+ def forward(self, x):
99
+ x = self.to_feat(x)
100
+ x = self.feats(x) + x
101
+ x = self.to_img(x)
102
+ return x