import torch
import torch.nn as nn
##########################################################################
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    layer = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias, stride=stride)
    return layer


def conv3x3(in_chn, out_chn, bias=True):
    layer = nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias)
    return layer


def conv_down(in_chn, out_chn, bias=False):
    layer = nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias)
    return layer
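
# A minimal shape sketch (my assumption about intended use, not from the original):
# `conv` and `conv3x3` preserve spatial size via same-padding, while `conv_down`
# halves it (kernel 4, stride 2, padding 1).
#   x = torch.randn(1, 64, 128, 128)
#   conv(64, 64, 3)(x).shape     # -> torch.Size([1, 64, 128, 128])
#   conv_down(64, 128)(x).shape  # -> torch.Size([1, 128, 64, 64])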
##########################################################################
## Supervised Attention Module (SAM)
class SAM(nn.Module):
    def __init__(self, n_feat, kernel_size, bias):
        super(SAM, self).__init__()
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
        self.conv3 = conv(3, n_feat, kernel_size, bias=bias)

    def forward(self, x, x_img):
        x1 = self.conv1(x)
        img = self.conv2(x) + x_img  # intermediate restored image (used for supervision)
        x2 = torch.sigmoid(self.conv3(img))  # attention map derived from the restored image
        x1 = x1 * x2
        x1 = x1 + x
        return x1, img
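
# Usage sketch (shapes are my assumptions, not from the original): SAM takes
# n_feat-channel features plus the 3-channel input image and returns attended
# features together with the intermediate restored image.
#   sam = SAM(n_feat=64, kernel_size=3, bias=False)
#   feats, restored = sam(torch.randn(1, 64, 128, 128), torch.randn(1, 3, 128, 128))
#   # feats: (1, 64, 128, 128), restored: (1, 3, 128, 128)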
##########################################################################
## Spatial Attention
class SALayer(nn.Module):
    def __init__(self, kernel_size=7):
        super(SALayer, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        y = torch.cat([avg_out, max_out], dim=1)
        y = self.conv1(y)
        y = self.sigmoid(y)
        return x * y
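
# Shape sketch (input shape assumed for illustration): the channel-wise mean
# and max each give a single-channel map; the conv fuses them into one HxW gate
# in [0, 1] that is broadcast across all channels.
#   sa = SALayer(kernel_size=7)
#   y = sa(torch.randn(2, 64, 32, 32))  # -> (2, 64, 32, 32), same as the input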
# Spatial Attention Block (SAB)
class SAB(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super(SAB, self).__init__()
        modules_body = [conv(n_feat, n_feat, kernel_size, bias=bias), act, conv(n_feat, n_feat, kernel_size, bias=bias)]
        self.body = nn.Sequential(*modules_body)
        self.SA = SALayer(kernel_size=7)

    def forward(self, x):
        res = self.body(x)
        res = self.SA(res)
        res += x
        return res
##########################################################################
## Pixel Attention
class PALayer(nn.Module):
    def __init__(self, channel, reduction=16, bias=False):
        super(PALayer, self).__init__()
        self.pa = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),  # some variants map to a single channel here instead of `channel`
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.pa(x)
        return x * y
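
# Usage sketch (shapes assumed): the 1x1 bottleneck convs keep spatial size, so
# the gate is per-pixel and per-channel, with the same shape as the input.
#   pa = PALayer(channel=64, reduction=16)
#   y = pa(torch.randn(1, 64, 32, 32))  # -> (1, 64, 32, 32)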
## Pixel Attention Block (PAB)
class PAB(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super(PAB, self).__init__()
        modules_body = [conv(n_feat, n_feat, kernel_size, bias=bias), act, conv(n_feat, n_feat, kernel_size, bias=bias)]
        self.PA = PALayer(n_feat, reduction, bias=bias)
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res = self.PA(res)
        res += x
        return res
##########################################################################
## Channel Attention Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16, bias=False):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y
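
# Usage sketch (shapes assumed): global average pooling squeezes each feature
# map to 1x1, so the gate carries one weight per channel, broadcast over H and W.
#   ca = CALayer(channel=64, reduction=16)
#   y = ca(torch.randn(1, 64, 32, 32))  # -> (1, 64, 32, 32)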
## Channel Attention Block (CAB)
class CAB(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super(CAB, self).__init__()
        modules_body = [conv(n_feat, n_feat, kernel_size, bias=bias), act, conv(n_feat, n_feat, kernel_size, bias=bias)]
        self.CA = CALayer(n_feat, reduction, bias=bias)
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res = self.CA(res)
        res += x
        return res
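
# Note (my reading of the code above): SAB, PAB, and CAB share one residual
# pattern, conv -> act -> conv, then attention, then a skip from the input,
# so all three preserve shape and are interchangeable as drop-in blocks.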
if __name__ == "__main__":
    import time
    from thop import profile

    # layer = CAB(64, 3, 4, False, nn.PReLU())
    layer = PAB(64, 3, 4, False, nn.PReLU())
    # layer = SAB(64, 3, 4, False, nn.PReLU())
    for idx, m in enumerate(layer.modules()):
        print(idx, "-", m)

    rgb = torch.ones(1, 64, 256, 256, dtype=torch.float, requires_grad=False)  # dummy 64-channel feature map
    s = time.time()
    out = layer(rgb)
    elapsed = time.time() - s  # time only the forward pass, before profiling
    flops, params = profile(layer, inputs=(rgb,))
    print('parameters:', params)
    print('flops:', flops)
    print('time: {:.4f}ms'.format(elapsed * 1000))  # seconds -> milliseconds