blocks.py (new file)
import torch
import torch.nn as nn
from timm.models.layers import DropPath
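# Note (not in the original file): newer timm releases also expose DropPath as
# `timm.layers.DropPath`; the `timm.models.layers` path used above still works
# there but is deprecated.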


# ResMLP-style affine normalization: a learnable per-channel scale and shift.
class Aff(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # learnable scale (alpha) and shift (beta), broadcast over batch and tokens
        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))

    def forward(self, x):
        x = x * self.alpha + self.beta
        return x

# Color normalization: the affine transform above combined with a learnable
# channel-mixing ("color") matrix, initialized to the identity.
class Aff_channel(nn.Module):
    def __init__(self, dim, channel_first=True):
        super().__init__()
        # learnable
        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))
        self.color = nn.Parameter(torch.eye(dim))
        self.channel_first = channel_first

    def forward(self, x):
        if self.channel_first:
            x1 = torch.tensordot(x, self.color, dims=[[-1], [-1]])
            x2 = x1 * self.alpha + self.beta
        else:
            x1 = x * self.alpha + self.beta
            x2 = torch.tensordot(x1, self.color, dims=[[-1], [-1]])
        return x2

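# Shape note (illustrative, not in the original file): for x of shape
# (B, N, C), tensordot with the (C, C) color matrix contracts the channel
# axis, so both branches map (B, N, C) -> (B, N, C); channel_first only
# chooses whether channel mixing happens before or after the alpha/beta
# affine step.
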
class Mlp(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class CMlp(nn.Module):
    # adapted from the timm Mlp above, with 1x1 convolutions in place of
    # the linear layers so it runs directly on (B, C, H, W) feature maps
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

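# Usage note (illustrative, not in the original file): Mlp expects a token
# layout such as (B, N, C) and mixes channels with nn.Linear, while CMlp
# keeps the (B, C, H, W) map layout and does the same channel mixing with
# 1x1 convolutions; the two are numerically equivalent up to layout.
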
class CBlock_ln(nn.Module):
    def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=Aff_channel, init_values=1e-4):
        # qkv_bias, qk_scale and attn_drop are accepted for interface
        # compatibility with attention blocks but are unused here
        super().__init__()
        # depthwise 3x3 convolution used as a convolutional positional encoding
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        # depthwise 5x5 convolution standing in for spatial attention
        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        # per-channel layer-scale weights for the two residual branches
        self.gamma_1 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True)
        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, C, H, W = x.shape
        # the norm layers act on a (B, N, C) token layout: flatten, normalize, restore
        norm_x = x.flatten(2).transpose(1, 2)
        norm_x = self.norm1(norm_x)
        norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = x + self.drop_path(self.gamma_1 * self.conv2(self.attn(self.conv1(norm_x))))

        norm_x = x.flatten(2).transpose(1, 2)
        norm_x = self.norm2(norm_x)
        norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = x + self.drop_path(self.gamma_2 * self.mlp(norm_x))
        return x
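
# ---------------------------------------------------------------------------
# Minimal smoke test (an illustrative sketch, not part of the original file):
# run a dummy feature map through one CBlock_ln and confirm the block is
# shape-preserving; dim=16 and the 8x8 spatial size are arbitrary choices.
if __name__ == "__main__":
    block = CBlock_ln(dim=16)
    dummy = torch.randn(2, 16, 8, 8)   # (B, C, H, W)
    out = block(dummy)
    assert out.shape == dummy.shape    # residual connections keep the shape
    print(out.shape)                   # -> torch.Size([2, 16, 8, 8])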