ZitengCui committed on
Commit
726b27f
1 Parent(s): 81283de

Create new file

Files changed (1)
  1. blocks.py +114 -0
blocks.py ADDED
import os
import torch
import torch.nn as nn
import math
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath, to_2tuple


# ResMLP's normalization: a learnable per-channel affine transform
class Aff(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # learnable scale and shift
        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))

    def forward(self, x):
        x = x * self.alpha + self.beta
        return x


# Color Normalization: the affine transform above plus a learnable
# channel-mixing ("color") matrix, initialized to the identity
class Aff_channel(nn.Module):
    def __init__(self, dim, channel_first=True):
        super().__init__()
        # learnable
        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))
        self.color = nn.Parameter(torch.eye(dim))
        self.channel_first = channel_first

    def forward(self, x):
        if self.channel_first:
            # mix channels first, then apply the affine transform
            x1 = torch.tensordot(x, self.color, dims=[[-1], [-1]])
            x2 = x1 * self.alpha + self.beta
        else:
            # apply the affine transform first, then mix channels
            x1 = x * self.alpha + self.beta
            x2 = torch.tensordot(x1, self.color, dims=[[-1], [-1]])
        return x2


class Mlp(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CMlp(nn.Module):
    # convolutional counterpart of Mlp: 1x1 convolutions on (B, C, H, W) maps
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CBlock_ln(nn.Module):
    # convolutional block: a depthwise 3x3 positional encoding, a 5x5 depthwise
    # "attention" convolution between two 1x1 convolutions, and a CMlp, each
    # preceded by normalization and scaled by a learnable layer-scale gamma
    def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=Aff_channel, init_values=1e-4):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # NOTE: drop path for stochastic depth; we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.gamma_1 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True)
        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, C, H, W = x.shape

        # normalize over flattened tokens (B, H*W, C), then restore (B, C, H, W)
        norm_x = x.flatten(2).transpose(1, 2)
        norm_x = self.norm1(norm_x)
        norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = x + self.drop_path(self.gamma_1 * self.conv2(self.attn(self.conv1(norm_x))))

        norm_x = x.flatten(2).transpose(1, 2)
        norm_x = self.norm2(norm_x)
        norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = x + self.drop_path(self.gamma_2 * self.mlp(norm_x))
        return x
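For context, a minimal smoke test of the blocks above (not part of the commit; the channel count and spatial size are illustrative assumptions, and the import assumes blocks.py is on the Python path):

import torch
from blocks import Aff_channel, CBlock_ln  # hypothetical import of this file

# CBlock_ln operates on (B, C, H, W) feature maps and preserves their shape.
block = CBlock_ln(dim=16, drop_path=0.1)  # dim must match the input channel count
x = torch.randn(2, 16, 32, 32)
print(block(x).shape)  # torch.Size([2, 16, 32, 32])

# Aff_channel expects a token layout of shape (B, N, C), e.g. a flattened map.
aff = Aff_channel(dim=16, channel_first=True)
tokens = x.flatten(2).transpose(1, 2)  # (B, H*W, C) = (2, 1024, 16)
print(aff(tokens).shape)  # torch.Size([2, 1024, 16])

Note that with the default init_values=1e-4, gamma_1 and gamma_2 start each residual branch near zero, so a freshly constructed CBlock_ln behaves close to an identity mapping (plus the positional convolution), which is the usual motivation for layer scale in deep residual stacks.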