import torch
import torch.nn as nn
import torch.nn.functional as F


# https://openaccess.thecvf.com/content_cvpr_2018/html/Hu_Squeeze-and-Excitation_Networks_CVPR_2018_paper.html
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block: channel-wise gating of a feature map.

    The input is globally average-pooled over its spatial dimensions, passed
    through a 1x1 convolution down to ``internal_neurons`` channels, a ReLU,
    a 1x1 convolution back up to ``input_channels`` channels and a sigmoid.
    The resulting per-channel gates scale the original input.

    Args:
        input_channels: Number of channels of the incoming feature map
            (also the number of gates produced).
        internal_neurons: Number of channels in the bottleneck between the
            two 1x1 convolutions.
    """

    def __init__(self, input_channels: int, internal_neurons: int) -> None:
        super().__init__()
        # 1x1 convolutions act as per-position fully connected layers on the
        # pooled (N, C, 1, 1) descriptor.
        self.down = nn.Conv2d(
            in_channels=input_channels,
            out_channels=internal_neurons,
            kernel_size=1,
            stride=1,
            bias=True,
        )
        self.up = nn.Conv2d(
            in_channels=internal_neurons,
            out_channels=input_channels,
            kernel_size=1,
            stride=1,
            bias=True,
        )
        self.input_channels = input_channels

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return ``inputs`` rescaled channel-wise by the learned gates.

        Args:
            inputs: Feature map of shape ``(N, input_channels, H, W)``.
                H and W may differ.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # Squeeze: pool over the *whole* spatial extent. Adaptive pooling is
        # used instead of avg_pool2d(kernel_size=W) because the latter is only
        # a global pool when H == W; on non-square inputs it either leaves
        # several rows behind (H > W) or raises (H < W).
        x = F.adaptive_avg_pool2d(inputs, 1)
        # Excite: bottleneck -> ReLU -> expand -> sigmoid gates in (0, 1).
        x = self.down(x)
        x = F.relu(x)
        x = self.up(x)
        x = torch.sigmoid(x)
        # x is (N, input_channels, 1, 1) and broadcasts over H and W.
        return inputs * x