"""IR / IR-SE convolutional backbones assembled from residual units.

The module defines the residual unit variants (`BasicBlockIR`,
`BottleneckIR` and their squeeze-and-excitation counterparts), the
per-depth stage layout (`get_blocks`), the `Backbone` network that stacks
them, and one factory function per supported depth/mode combination.
"""
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
from torch.nn import BatchNorm1d, BatchNorm2d
from torch.nn import Conv2d, Linear
from torch.nn import Dropout
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import PReLU
from torch.nn import ReLU, Sigmoid
from torch.nn import Sequential

try:
    from fvcore.nn import flop_count
except ImportError:
    # fvcore is only used by the FLOP report run under ``__main__``; the
    # model definitions themselves must stay importable without it.
    flop_count = None


def initialize_weights(modules):
    """Re-initialise the parameters of every module in *modules* in place.

    * ``Conv2d`` / ``Linear``: Kaiming-normal weights (``fan_out``, ReLU
      gain) and, when a bias exists, a zero bias.
    * ``BatchNorm2d``: weight set to one, bias set to zero.

    Any other module type is left untouched.

    Args:
        modules: iterable of ``nn.Module`` objects, e.g. ``model.modules()``.
    """
    for m in modules:
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()


class Flatten(Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        """Return *input* viewed as ``(input.size(0), -1)``."""
        batch = input.size(0)
        return input.view(batch, -1)


class LinearBlock(Module):
    """Bias-free convolution followed by batch normalisation (no activation).

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        kernel: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        groups: number of convolution groups.
    """

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1),
                 padding=(0, 0), groups=1):
        super().__init__()
        self.conv = Conv2d(in_c, out_c, kernel, stride, padding,
                           groups=groups, bias=False)
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        """Apply the convolution, then the batch norm."""
        return self.bn(self.conv(x))


class SEModule(Module):
    """Channel re-weighting gate.

    The input is globally average-pooled to 1x1, squeezed to
    ``channels // reduction`` channels by a 1x1 convolution, passed through
    ReLU, expanded back to ``channels`` by a second 1x1 convolution and
    squashed with a sigmoid.  The result scales the original input.

    Args:
        channels: number of input (and output) channels.
        reduction: divisor applied to ``channels`` for the squeezed width.
    """

    def __init__(self, channels, reduction):
        super().__init__()
        squeezed = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, squeezed,
                          kernel_size=1, padding=0, bias=False)
        # Only fc1 receives the explicit Xavier init; fc2 keeps the
        # framework default.
        nn.init.xavier_uniform_(self.fc1.weight.data)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(squeezed, channels,
                          kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        """Return *x* multiplied by its per-channel gate."""
        gate = self.avg_pool(x)
        gate = self.fc1(gate)
        gate = self.relu(gate)
        gate = self.fc2(gate)
        gate = self.sigmoid(gate)
        return x * gate


class BasicBlockIR(Module):
    """Residual unit built from two 3x3 convolutions.

    Residual path: BN -> 3x3 conv (stride 1) -> BN -> PReLU ->
    3x3 conv (``stride``) -> BN.

    Shortcut path: when ``in_channel == depth`` a ``MaxPool2d`` with kernel
    size 1 and the given stride (pure sub-sampling); otherwise a strided
    1x1 convolution followed by batch norm.

    Args:
        in_channel: number of input channels.
        depth: number of output channels.
        stride: stride applied by both the shortcut and the residual path.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(depth),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        )

    def forward(self, x):
        """Return ``res_layer(x) + shortcut_layer(x)``."""
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class BottleneckIR(Module):
    """Residual unit with a 1x1 -> 3x3 -> 1x1 bottleneck.

    The inner width is ``depth // 4``.  Residual path: BN -> 1x1 conv ->
    BN -> PReLU -> 3x3 conv -> BN -> PReLU -> 1x1 conv (``stride``) -> BN.
    The shortcut path follows the same rule as in ``BasicBlockIR``.

    Args:
        in_channel: number of input channels.
        depth: number of output channels.
        stride: stride applied by the shortcut and by the last 1x1 conv.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__()
        reduction_channel = depth // 4
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, reduction_channel, (1, 1), (1, 1), 0,
                   bias=False),
            BatchNorm2d(reduction_channel),
            PReLU(reduction_channel),
            Conv2d(reduction_channel, reduction_channel, (3, 3), (1, 1), 1,
                   bias=False),
            BatchNorm2d(reduction_channel),
            PReLU(reduction_channel),
            Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False),
            BatchNorm2d(depth),
        )

    def forward(self, x):
        """Return ``res_layer(x) + shortcut_layer(x)``."""
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class BasicBlockIRSE(BasicBlockIR):
    """``BasicBlockIR`` with an ``SEModule`` (reduction 16) appended to the
    residual path under the name ``se_block``."""

    def __init__(self, in_channel, depth, stride):
        super().__init__(in_channel, depth, stride)
        self.res_layer.add_module("se_block", SEModule(depth, 16))


class BottleneckIRSE(BottleneckIR):
    """``BottleneckIR`` with an ``SEModule`` (reduction 16) appended to the
    residual path under the name ``se_block``."""

    def __init__(self, in_channel, depth, stride):
        super().__init__(in_channel, depth, stride)
        self.res_layer.add_module("se_block", SEModule(depth, 16))


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """Configuration record for one residual unit."""


def get_block(in_channel, depth, num_units, stride=2):
    """Describe one stage made of ``num_units`` residual units.

    The first unit maps ``in_channel`` to ``depth`` with the given stride;
    every following unit maps ``depth`` to ``depth`` with stride 1.

    Returns:
        A list of ``Bottleneck`` records.
    """
    first = Bottleneck(in_channel, depth, stride)
    rest = [Bottleneck(depth, depth, 1) for _ in range(num_units - 1)]
    return [first] + rest


# Number of residual units in each of the four stages, keyed by depth.
_STAGE_UNITS = {
    18: (2, 2, 2, 2),
    34: (3, 4, 6, 3),
    50: (3, 4, 14, 3),
    100: (3, 13, 30, 3),
    152: (3, 8, 36, 3),
    200: (3, 24, 36, 3),
}

# Channel width entering stage 1 followed by the width leaving each stage.
_BASIC_CHANNELS = (64, 64, 128, 256, 512)           # depths up to 100
_BOTTLENECK_CHANNELS = (64, 256, 512, 1024, 2048)   # depths 152 and 200


def get_blocks(num_layers):
    """Return the four-stage unit layout for a network of ``num_layers``.

    Args:
        num_layers: one of 18, 34, 50, 100, 152 or 200.

    Returns:
        A list with one entry per stage; each entry is the list of
        ``Bottleneck`` records produced by ``get_block``.

    Raises:
        ValueError: if ``num_layers`` is not a supported depth.
    """
    if num_layers not in _STAGE_UNITS:
        raise ValueError(
            "num_layers should be one of {}, got {!r}".format(
                sorted(_STAGE_UNITS), num_layers))
    channels = _BASIC_CHANNELS if num_layers <= 100 else _BOTTLENECK_CHANNELS
    return [
        get_block(in_channel=c_in, depth=c_out, num_units=units)
        for c_in, c_out, units in zip(channels, channels[1:],
                                      _STAGE_UNITS[num_layers])
    ]


class Backbone(Module):
    """Stack of residual units mapping an image batch to an embedding.

    Layout: a stride-1 3x3 stem (``input_layer``), four stages of residual
    units (``body``) whose first unit has stride 2, and a head
    (``output_layer``) of BN -> Dropout(0.4) -> Flatten -> Linear ->
    BatchNorm1d(affine=False).

    Depths up to 100 use the basic unit (512 final channels); deeper
    networks use the bottleneck unit (2048 final channels).

    Args:
        input_size: sequence whose first element is 112 or 224.
        num_layers: one of 18, 34, 50, 100, 152 or 200.
        mode: ``'ir'`` for plain units, ``'ir_se'`` for units with an
            ``SEModule``.
        flip: when true, ``forward`` reverses dimension 1 of the input.
        output_dim: size of the produced embedding.
    """

    def __init__(self, input_size, num_layers, mode='ir', flip=False,
                 output_dim=512):
        super().__init__()
        assert input_size[0] in [112, 224], \
            "input_size should be [112, 112] or [224, 224]"
        assert num_layers in [18, 34, 50, 100, 152, 200], \
            "num_layers should be 18, 34, 50, 100, 152 or 200"
        assert mode in ['ir', 'ir_se'], \
            "mode should be ir or ir_se"

        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False),
            BatchNorm2d(64),
            PReLU(64),
        )
        blocks = get_blocks(num_layers)

        basic_unit, bottleneck_unit = {
            'ir': (BasicBlockIR, BottleneckIR),
            'ir_se': (BasicBlockIRSE, BottleneckIRSE),
        }[mode]
        if num_layers <= 100:
            unit_module, output_channel = basic_unit, 512
        else:
            unit_module, output_channel = bottleneck_unit, 2048

        # Four stride-2 stages divide the side length by 16:
        # 112 -> 7 and 224 -> 14.
        side = 7 if input_size[0] == 112 else 14
        self.output_layer = Sequential(
            BatchNorm2d(output_channel),
            Dropout(0.4),
            Flatten(),
            Linear(output_channel * side * side, output_dim),
            BatchNorm1d(output_dim, affine=False),
        )

        self.body = Sequential(*[
            unit_module(unit.in_channel, unit.depth, unit.stride)
            for stage in blocks
            for unit in stage
        ])

        # Applied last so it overrides the default init of every
        # Conv2d / Linear / BatchNorm2d registered above.
        initialize_weights(self.modules())
        self.flip = flip

    def forward(self, x):
        """Run stem, body and head on *x* and return the embedding."""
        if self.flip:
            x = x.flip(1)  # color channel flip
        x = self.input_layer(x)
        for module in self.body:
            x = module(x)
        return self.output_layer(x)


def IR_18(input_size, output_dim=512):
    """Build an 18-layer ``Backbone`` in ``'ir'`` mode."""
    return Backbone(input_size, 18, 'ir', output_dim=output_dim)


def IR_34(input_size, output_dim=512):
    """Build a 34-layer ``Backbone`` in ``'ir'`` mode."""
    return Backbone(input_size, 34, 'ir', output_dim=output_dim)


def IR_50(input_size, output_dim=512):
    """Build a 50-layer ``Backbone`` in ``'ir'`` mode."""
    return Backbone(input_size, 50, 'ir', output_dim=output_dim)


def IR_101(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=100`` in ``'ir'`` mode."""
    return Backbone(input_size, 100, 'ir', output_dim=output_dim)


def IR_101_FLIP(input_size, output_dim=512):
    """Like ``IR_101`` but with ``flip=True``."""
    return Backbone(input_size, 100, 'ir', flip=True, output_dim=output_dim)


def IR_152(input_size, output_dim=512):
    """Build a 152-layer ``Backbone`` in ``'ir'`` mode."""
    return Backbone(input_size, 152, 'ir', output_dim=output_dim)


def IR_200(input_size, output_dim=512):
    """Build a 200-layer ``Backbone`` in ``'ir'`` mode."""
    return Backbone(input_size, 200, 'ir', output_dim=output_dim)


def IR_SE_50(input_size, output_dim=512):
    """Build a 50-layer ``Backbone`` in ``'ir_se'`` mode."""
    return Backbone(input_size, 50, 'ir_se', output_dim=output_dim)


def IR_SE_101(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=100`` in ``'ir_se'`` mode."""
    return Backbone(input_size, 100, 'ir_se', output_dim=output_dim)


def IR_SE_152(input_size, output_dim=512):
    """Build a 152-layer ``Backbone`` in ``'ir_se'`` mode."""
    return Backbone(input_size, 152, 'ir_se', output_dim=output_dim)


def IR_SE_200(input_size, output_dim=512):
    """Build a 200-layer ``Backbone`` in ``'ir_se'`` mode."""
    return Backbone(input_size, 200, 'ir_se', output_dim=output_dim)


def _report_complexity():
    """Print the FLOP count and trainable-parameter count of ``IR_50``.

    Raises:
        ImportError: if fvcore is not installed.
    """
    if flop_count is None:
        raise ImportError("fvcore is required to count FLOPs")
    inputs_shape = (1, 3, 112, 112)
    model = IR_50(input_size=(112, 112))
    model.eval()
    res = flop_count(model, inputs=torch.randn(inputs_shape),
                     supported_ops={})
    # fvcore's flop_count already reports per-operator GFLOPs, so the
    # total must not be divided by 1e9 a second time.
    fvcore_gflop = np.array(list(res[0].values())).sum()
    print('FLOPs: ', fvcore_gflop, 'G')
    num_params = sum(p.numel() for p in model.parameters()
                     if p.requires_grad)
    print('Num Params: ', num_params / 1e6, 'M')


if __name__ == '__main__':
    _report_complexity()