import torch
from torch import nn
from torchvision.models import resnet
from typing import Optional, Callable


class ConvBlock(nn.Module):
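    """Two 3x3 convolutions, each followed by normalization and a gating
    activation (ReLU by default). Spatial resolution is preserved."""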
    def __init__(self, in_channels, out_channels,
                 gate: Optional[Callable[..., nn.Module]] = None,
                 norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()
        if gate is None:
            self.gate = nn.ReLU(inplace=True)
        else:
            self.gate = gate
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.conv1 = resnet.conv3x3(in_channels, out_channels)
        self.bn1 = norm_layer(out_channels)
        self.conv2 = resnet.conv3x3(out_channels, out_channels)
        self.bn2 = norm_layer(out_channels)

    def forward(self, x):
        # conv -> norm -> gate, applied twice
        x = self.gate(self.bn1(self.conv1(x)))
        x = self.gate(self.bn2(self.conv2(x)))
        return x


class ResBlock(nn.Module):
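    """Basic residual block (two 3x3 convolutions plus an additive identity
    shortcut), following torchvision's BasicBlock but with a configurable
    gate activation."""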
    expansion: int = 1

    def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            groups: int = 1,
            base_width: int = 64,
            dilation: int = 1,
            gate: Optional[Callable[..., nn.Module]] = None,
            norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if gate is None:
            self.gate = nn.ReLU(inplace=True)
        else:
            self.gate = gate
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('ResBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in ResBlock")

        self.conv1 = resnet.conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.conv2 = resnet.conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.gate(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the identity when the shortcut changes channel count or stride.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.gate(out)

        return out


class ALNet(nn.Module):
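    """Four-stage backbone extracting features at 1, 1/2, 1/8, and 1/32 of the
    input resolution. Each stage is projected to dim // 4 channels, upsampled
    back to full resolution, and concatenated; a 1x1 head then produces a dense
    descriptor map (dim channels) and a keypoint score map (1 channel)."""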
    def __init__(self, c1: int = 32, c2: int = 64, c3: int = 128, c4: int = 128, dim: int = 128,
                 single_head: bool = True):
        super().__init__()

        self.gate = nn.ReLU(inplace=True)

        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool4 = nn.MaxPool2d(kernel_size=4, stride=4)

        self.block1 = ConvBlock(3, c1, self.gate, nn.BatchNorm2d)
        self.block2 = ResBlock(inplanes=c1, planes=c2, stride=1,
                               downsample=nn.Conv2d(c1, c2, 1),
                               gate=self.gate,
                               norm_layer=nn.BatchNorm2d)
        self.block3 = ResBlock(inplanes=c2, planes=c3, stride=1,
                               downsample=nn.Conv2d(c2, c3, 1),
                               gate=self.gate,
                               norm_layer=nn.BatchNorm2d)
        self.block4 = ResBlock(inplanes=c3, planes=c4, stride=1,
                               downsample=nn.Conv2d(c3, c4, 1),
                               gate=self.gate,
                               norm_layer=nn.BatchNorm2d)

        # Project each stage to dim // 4 channels so the fused map has dim channels.
        self.conv1 = resnet.conv1x1(c1, dim // 4)
        self.conv2 = resnet.conv1x1(c2, dim // 4)
        self.conv3 = resnet.conv1x1(c3, dim // 4)
        self.conv4 = resnet.conv1x1(c4, dim // 4)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        self.upsample32 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True)

        self.single_head = single_head
        if not self.single_head:
            self.convhead1 = resnet.conv1x1(dim, dim)
        self.convhead2 = resnet.conv1x1(dim, dim + 1)

    def forward(self, image):
        # Feature pyramid: x1 at full resolution, x2 at 1/2, x3 at 1/8, x4 at 1/32.
        x1 = self.block1(image)
        x2 = self.pool2(x1)
        x2 = self.block2(x2)
        x3 = self.pool4(x2)
        x3 = self.block3(x3)
        x4 = self.pool4(x3)
        x4 = self.block4(x4)

        # Project each scale to dim // 4 channels, upsample to full resolution, and fuse.
        x1 = self.gate(self.conv1(x1))
        x2 = self.gate(self.conv2(x2))
        x3 = self.gate(self.conv3(x3))
        x4 = self.gate(self.conv4(x4))
        x2_up = self.upsample2(x2)
        x3_up = self.upsample8(x3)
        x4_up = self.upsample32(x4)
        x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)

        if not self.single_head:
            x1234 = self.gate(self.convhead1(x1234))
        x = self.convhead2(x1234)

        # The last channel is the keypoint score map; the rest form the descriptor map.
        descriptor_map = x[:, :-1, :, :]
        scores_map = torch.sigmoid(x[:, -1, :, :]).unsqueeze(1)

        return scores_map, descriptor_map


if __name__ == '__main__':
    from thop import profile

    net = ALNet(c1=16, c2=32, c3=64, c4=128, dim=128, single_head=True)

    image = torch.randn(1, 3, 640, 480)
    flops, params = profile(net, inputs=(image,), verbose=False)
    print('{:<30} {:<8} GFLOPs'.format('Computational complexity: ', flops / 1e9))
    print('{:<30} {:<8} K params'.format('Number of parameters: ', params / 1e3))
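
    # Quick output-shape check: with single_head=True the network returns a
    # (B, 1, H, W) score map in [0, 1] and a (B, dim, H, W) descriptor map.
    scores_map, descriptor_map = net(image)
    print('scores_map:     ', tuple(scores_map.shape))      # (1, 1, 640, 480)
    print('descriptor_map: ', tuple(descriptor_map.shape))  # (1, 128, 640, 480)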