import torch
import os
import math
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
# from torch.utils.serialization import load_lua
from lib.nn import SynchronizedBatchNorm2d as SynBN2d

###############################################################################
# Functions
###############################################################################


def pad_tensor(input):
    height_org, width_org = input.shape[2], input.shape[3]
    divide = 16

    if width_org % divide != 0 or height_org % divide != 0:
        width_res = width_org % divide
        height_res = height_org % divide
        if width_res != 0:
            width_div = divide - width_res
            pad_left = int(width_div / 2)
            pad_right = int(width_div - pad_left)
        else:
            pad_left = 0
            pad_right = 0

        if height_res != 0:
            height_div = divide - height_res
            pad_top = int(height_div / 2)
            pad_bottom = int(height_div - pad_top)
        else:
            pad_top = 0
            pad_bottom = 0

        padding = nn.ReflectionPad2d((pad_left, pad_right, pad_top, pad_bottom))
        input = padding(input)
    else:
        pad_left = 0
        pad_right = 0
        pad_top = 0
        pad_bottom = 0

    height, width = input.data.shape[2], input.data.shape[3]
    assert width % divide == 0, 'width cannot be divided by stride'
    assert height % divide == 0, 'height cannot be divided by stride'

    return input, pad_left, pad_right, pad_top, pad_bottom


def pad_tensor_back(input, pad_left, pad_right, pad_top, pad_bottom):
    height, width = input.shape[2], input.shape[3]
    return input[:, :, pad_top: height - pad_bottom, pad_left: width - pad_right]
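
# Usage sketch (illustrative, not part of the original pipeline): pad_tensor and
# pad_tensor_back are meant to be paired around a forward pass, so that inputs of
# arbitrary size are reflection-padded up to multiples of 16 and the result is
# cropped back afterwards. `run_padded` and `net` are hypothetical names introduced
# here only for illustration.
def run_padded(net, x):
    x, pl, pr, pt, pb = pad_tensor(x)            # pad H and W up to multiples of 16
    y = net(x)                                   # the network sees divisible spatial dims
    return pad_tensor_back(y, pl, pr, pt, pb)    # crop the output back to the input size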
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    elif norm_type == 'synBN':
        norm_layer = functools.partial(SynBN2d, affine=True)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False,
             gpu_ids=[], skip=False, opt=None):
    netG = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert(torch.cuda.is_available())

    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer,
                               use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer,
                               use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_128':
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer,
                             use_dropout=use_dropout, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer,
                             use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'unet_512':
        netG = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer,
                             use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'sid_unet':
        netG = Unet(opt, skip)
    elif which_model_netG == 'sid_unet_shuffle':
        netG = Unet_pixelshuffle(opt, skip)
    elif which_model_netG == 'sid_unet_resize':
        netG = Unet_resize_conv(opt, skip)
    elif which_model_netG == 'DnCNN':
        netG = DnCNN(opt, depth=17, n_channels=64, image_channels=1, use_bnorm=True,
                     kernel_size=3)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    if use_gpu:  # the previous `len(gpu_ids) >= 0` was always true and crashed on an empty gpu list
        netG.cuda(device=gpu_ids[0])
        netG = torch.nn.DataParallel(netG, gpu_ids)
    netG.apply(weights_init)
    return netG


def define_D(input_nc, ndf, which_model_netD, n_layers_D=3, norm='batch',
             use_sigmoid=False, gpu_ids=[], patch=False):
    netD = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert(torch.cuda.is_available())
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer,
                                   use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer,
                                   use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'no_norm':
        netD = NoNormDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid,
                                   gpu_ids=gpu_ids)
    elif which_model_netD == 'no_norm_4':
        netD = NoNormDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid,
                                   gpu_ids=gpu_ids)
    elif which_model_netD == 'no_patchgan':
        netD = FCDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid,
                               gpu_ids=gpu_ids, patch=patch)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)
    if use_gpu:
        netD.cuda(device=gpu_ids[0])
        netD = torch.nn.DataParallel(netD, gpu_ids)
    netD.apply(weights_init)
    return netD


def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)
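
# Usage sketch (illustrative; the argument values below are assumptions, only the
# define_G / define_D factories themselves come from this module):
#
#   netG = define_G(input_nc=3, output_nc=3, ngf=64, which_model_netG='sid_unet_resize',
#                   norm='instance', gpu_ids=[0], skip=True, opt=opt)
#   netD = define_D(input_nc=3, ndf=64, which_model_netD='no_norm', n_layers_D=5,
#                   use_sigmoid=False, gpu_ids=[0])
#   print_network(netG)   # prints the module tree and the total parameter count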
##############################################################################
# Classes
##############################################################################


# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically the same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input.
class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)


class DiscLossWGANGP():
    def __init__(self):
        self.LAMBDA = 10

    def name(self):
        return 'DiscLossWGAN-GP'

    def initialize(self, opt, tensor):
        # DiscLossLS.initialize(self, opt, tensor)
        self.LAMBDA = 10

    # def get_g_loss(self, net, realA, fakeB):
    #     # First, G(A) should fake the discriminator
    #     self.D_fake = net.forward(fakeB)
    #     return -self.D_fake.mean()

    def calc_gradient_penalty(self, netD, real_data, fake_data):
        alpha = torch.rand(1, 1)
        alpha = alpha.expand(real_data.size())
        alpha = alpha.cuda()

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        interpolates = interpolates.cuda()
        interpolates = Variable(interpolates, requires_grad=True)

        disc_interpolates = netD.forward(interpolates)

        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                        grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
                                        create_graph=True, retain_graph=True, only_inputs=True)[0]

        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
        return gradient_penalty
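
# Usage sketch (illustrative; not taken from the training code): a typical discriminator
# update that combines GANLoss with the WGAN-GP gradient penalty above. `netD`, `real`,
# `fake`, and `optimizer_D` are assumed to exist in the surrounding training loop.
#
#   criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)
#   wgan_gp = DiscLossWGANGP()
#
#   pred_real = netD(real)
#   pred_fake = netD(fake.detach())
#   loss_D = 0.5 * (criterionGAN(pred_real, True) + criterionGAN(pred_fake, False))
#   loss_D = loss_D + wgan_gp.calc_gradient_penalty(netD, real.data, fake.data)
#   optimizer_D.zero_grad()
#   loss_D.backward()
#   optimizer_D.step()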
# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, n_blocks=6, gpu_ids=[], padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = gpu_ids

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type,
                                  norm_layer=norm_layer, use_dropout=use_dropout)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


# Define a resnet block
class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim),
                       nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out
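
# Shape sketch (illustrative; the values below are assumptions for the demo): with the
# default ngf=64 and two stride-2 downsamplings, ResnetGenerator maps an
# N x input_nc x H x W tensor back to N x output_nc x H x W as long as H and W are
# divisible by 4.
#
#   G = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)
#   y = G(torch.randn(1, 3, 256, 256))   # -> torch.Size([1, 3, 256, 256])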
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck.
class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[], skip=False, opt=None):
        super(UnetGenerator, self).__init__()
        self.gpu_ids = gpu_ids
        self.opt = opt
        # currently support only input_nc == output_nc
        assert(input_nc == output_nc)

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer,
                                             innermost=True, opt=opt)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block,
                                                 norm_layer=norm_layer,
                                                 use_dropout=use_dropout, opt=opt)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block,
                                             norm_layer=norm_layer, opt=opt)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block,
                                             norm_layer=norm_layer, opt=opt)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block,
                                             norm_layer=norm_layer, opt=opt)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block,
                                             outermost=True, norm_layer=norm_layer, opt=opt)

        if skip:
            skipmodule = SkipModule(unet_block, opt)
            self.model = skipmodule
        else:
            self.model = unet_block

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


class SkipModule(nn.Module):
    def __init__(self, submodule, opt):
        super(SkipModule, self).__init__()
        self.submodule = submodule
        self.opt = opt

    def forward(self, x):
        latent = self.submodule(x)
        return self.opt.skip * x + latent, latent


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc,
                 submodule=None, outermost=False, innermost=False,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, opt=None):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if opt.use_norm == 0:
            if outermost:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                            kernel_size=4, stride=2, padding=1)
                down = [downconv]
                up = [uprelu, upconv, nn.Tanh()]
                model = down + [submodule] + up
            elif innermost:
                upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                            kernel_size=4, stride=2, padding=1)
                down = [downrelu, downconv]
                up = [uprelu, upconv]
                model = down + up
            else:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                            kernel_size=4, stride=2, padding=1)
                down = [downrelu, downconv]
                up = [uprelu, upconv]

                if use_dropout:
                    model = down + [submodule] + up + [nn.Dropout(0.5)]
                else:
                    model = down + [submodule] + up
        else:
            if outermost:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                            kernel_size=4, stride=2, padding=1)
                down = [downconv]
                up = [uprelu, upconv, nn.Tanh()]
                model = down + [submodule] + up
            elif innermost:
                upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                            kernel_size=4, stride=2, padding=1)
                down = [downrelu, downconv]
                up = [uprelu, upconv, upnorm]
                model = down + up
            else:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                            kernel_size=4, stride=2, padding=1)
                down = [downrelu, downconv, downnorm]
                up = [uprelu, upconv, upnorm]

                if use_dropout:
                    model = down + [submodule] + up + [nn.Dropout(0.5)]
                else:
                    model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([self.model(x), x], 1)


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw = 4
        padw = int(np.ceil((kw - 1) / 2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        # if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
        #     return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        # else:
        return self.model(input)


class NoNormDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]):
        super(NoNormDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw = 4
        padw = int(np.ceil((kw - 1) / 2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        # if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
        #     return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        # else:
        return self.model(input)
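
# Output-shape sketch (illustrative): the PatchGAN-style discriminators above return a
# spatial map of patch scores rather than a single scalar, and GANLoss broadcasts its
# target to that map. With the assumed settings below, a 256x256 input yields one score
# per overlapping receptive-field patch (a 35x35 map with padw=2 as computed above under
# Python 3 division; the classic pix2pix variant with padding 1 gives 30x30).
#
#   D = NoNormDiscriminator(input_nc=3, ndf=64, n_layers=3, use_sigmoid=False)
#   scores = D(torch.randn(1, 3, 256, 256))   # shape (1, 1, h', w'), one score per patch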
class FCDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[], patch=False):
        super(FCDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids
        self.use_sigmoid = use_sigmoid

        kw = 4
        padw = int(np.ceil((kw - 1) / 2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if patch:
            self.linear = nn.Linear(7 * 7, 1)
        else:
            self.linear = nn.Linear(13 * 13, 1)

        if use_sigmoid:
            self.sigmoid = nn.Sigmoid()

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        batchsize = input.size()[0]
        output = self.model(input)
        output = output.view(batchsize, -1)
        # print(output.size())
        output = self.linear(output)
        if self.use_sigmoid:
            print("sigmoid")
            output = self.sigmoid(output)
        return output


class Unet_resize_conv(nn.Module):
    def __init__(self, opt, skip):
        super(Unet_resize_conv, self).__init__()

        self.opt = opt
        self.skip = skip
        p = 1
        # self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p)
        if opt.self_attention:
            self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p)
            # self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)
            self.downsample_1 = nn.MaxPool2d(2)
            self.downsample_2 = nn.MaxPool2d(2)
            self.downsample_3 = nn.MaxPool2d(2)
            self.downsample_4 = nn.MaxPool2d(2)
        else:
            self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)
        self.LReLU1_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn1_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.conv1_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU1_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn1_2 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.max_pool1 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv2_1 = nn.Conv2d(32, 64, 3, padding=p)
        self.LReLU2_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn2_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.conv2_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU2_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn2_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.max_pool2 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv3_1 = nn.Conv2d(64, 128, 3, padding=p)
        self.LReLU3_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn3_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.conv3_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU3_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn3_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.max_pool3 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv4_1 = nn.Conv2d(128, 256, 3, padding=p)
        self.LReLU4_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn4_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.conv4_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU4_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn4_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.max_pool4 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv5_1 = nn.Conv2d(256, 512, 3, padding=p)
        self.LReLU5_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn5_1 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=p)
        self.LReLU5_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn5_2 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512)

        # self.deconv5 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.deconv5 = nn.Conv2d(512, 256, 3, padding=p)
        self.conv6_1 = nn.Conv2d(512, 256, 3, padding=p)
        self.LReLU6_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn6_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.conv6_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU6_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn6_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)

        # self.deconv6 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.deconv6 = nn.Conv2d(256, 128, 3, padding=p)
        self.conv7_1 = nn.Conv2d(256, 128, 3, padding=p)
        self.LReLU7_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn7_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.conv7_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU7_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn7_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)

        # self.deconv7 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.deconv7 = nn.Conv2d(128, 64, 3, padding=p)
        self.conv8_1 = nn.Conv2d(128, 64, 3, padding=p)
        self.LReLU8_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn8_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.conv8_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU8_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn8_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)

        # self.deconv8 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.deconv8 = nn.Conv2d(64, 32, 3, padding=p)
        self.conv9_1 = nn.Conv2d(64, 32, 3, padding=p)
        self.LReLU9_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn9_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.conv9_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU9_2 = nn.LeakyReLU(0.2, inplace=True)

        self.conv10 = nn.Conv2d(32, 3, 1)
        if self.opt.tanh:
            self.tanh = nn.Tanh()

    def depth_to_space(self, input, block_size):
        block_size_sq = block_size * block_size
        output = input.permute(0, 2, 3, 1)
        (batch_size, d_height, d_width, d_depth) = output.size()
        s_depth = int(d_depth / block_size_sq)
        s_width = int(d_width * block_size)
        s_height = int(d_height * block_size)
        t_1 = output.resize(batch_size, d_height, d_width, block_size_sq, s_depth)
        spl = t_1.split(block_size, 3)
        stack = [t_t.resize(batch_size, d_height, s_width, s_depth) for t_t in spl]
        output = torch.stack(stack, 0).transpose(0, 1).permute(0, 2, 1, 3, 4).resize(
            batch_size, s_height, s_width, s_depth)
        output = output.permute(0, 3, 1, 2)
        return output

    def forward(self, input, gray):
        flag = 0
        if input.size()[3] > 2200:
            # very large inputs are halved first to bound memory, then upsampled back at the end
            avg = nn.AvgPool2d(2)
            input = avg(input)
            gray = avg(gray)
            flag = 1
        input, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(input)
        gray, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(gray)

        if self.opt.self_attention:
            gray_2 = self.downsample_1(gray)
            gray_3 = self.downsample_2(gray_2)
            gray_4 = self.downsample_3(gray_3)
            gray_5 = self.downsample_4(gray_4)

        if self.opt.use_norm == 1:
            if self.opt.self_attention:
                x = self.bn1_1(self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1))))
                # x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
            else:
                x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
            conv1 = self.bn1_2(self.LReLU1_2(self.conv1_2(x)))
            x = self.max_pool1(conv1)

            x = self.bn2_1(self.LReLU2_1(self.conv2_1(x)))
            conv2 = self.bn2_2(self.LReLU2_2(self.conv2_2(x)))
            x = self.max_pool2(conv2)

            x = self.bn3_1(self.LReLU3_1(self.conv3_1(x)))
            conv3 = self.bn3_2(self.LReLU3_2(self.conv3_2(x)))
            x = self.max_pool3(conv3)

            x = self.bn4_1(self.LReLU4_1(self.conv4_1(x)))
            conv4 = self.bn4_2(self.LReLU4_2(self.conv4_2(x)))
            x = self.max_pool4(conv4)

            x = self.bn5_1(self.LReLU5_1(self.conv5_1(x)))
            x = x * gray_5 if self.opt.self_attention else x
            conv5 = self.bn5_2(self.LReLU5_2(self.conv5_2(x)))

            conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear')
            conv4 = conv4 * gray_4 if self.opt.self_attention else conv4
            up6 = torch.cat([self.deconv5(conv5), conv4], 1)
            x = self.bn6_1(self.LReLU6_1(self.conv6_1(up6)))
            conv6 = self.bn6_2(self.LReLU6_2(self.conv6_2(x)))

            conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear')
            conv3 = conv3 * gray_3 if self.opt.self_attention else conv3
            up7 = torch.cat([self.deconv6(conv6), conv3], 1)
            x = self.bn7_1(self.LReLU7_1(self.conv7_1(up7)))
            conv7 = self.bn7_2(self.LReLU7_2(self.conv7_2(x)))

            conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear')
            conv2 = conv2 * gray_2 if self.opt.self_attention else conv2
            up8 = torch.cat([self.deconv7(conv7), conv2], 1)
            x = self.bn8_1(self.LReLU8_1(self.conv8_1(up8)))
            conv8 = self.bn8_2(self.LReLU8_2(self.conv8_2(x)))

            conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear')
            conv1 = conv1 * gray if self.opt.self_attention else conv1
            up9 = torch.cat([self.deconv8(conv8), conv1], 1)
            x = self.bn9_1(self.LReLU9_1(self.conv9_1(up9)))
            conv9 = self.LReLU9_2(self.conv9_2(x))

            latent = self.conv10(conv9)

            if self.opt.times_residual:
                latent = latent * gray

            # output = self.depth_to_space(conv10, 2)
            if self.opt.tanh:
                latent = self.tanh(latent)
            if self.skip:
                if self.opt.linear_add:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent)) / (torch.max(latent) - torch.min(latent))
                    input = (input - torch.min(input)) / (torch.max(input) - torch.min(input))
                    output = latent + input * self.opt.skip
                    output = output * 2 - 1
                else:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent)) / (torch.max(latent) - torch.min(latent))
                    output = latent + input * self.opt.skip
            else:
                output = latent

            if self.opt.linear:
                output = output / torch.max(torch.abs(output))

        elif self.opt.use_norm == 0:
            if self.opt.self_attention:
                x = self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1)))
            else:
                x = self.LReLU1_1(self.conv1_1(input))
            conv1 = self.LReLU1_2(self.conv1_2(x))
            x = self.max_pool1(conv1)

            x = self.LReLU2_1(self.conv2_1(x))
            conv2 = self.LReLU2_2(self.conv2_2(x))
            x = self.max_pool2(conv2)

            x = self.LReLU3_1(self.conv3_1(x))
            conv3 = self.LReLU3_2(self.conv3_2(x))
            x = self.max_pool3(conv3)

            x = self.LReLU4_1(self.conv4_1(x))
            conv4 = self.LReLU4_2(self.conv4_2(x))
            x = self.max_pool4(conv4)

            x = self.LReLU5_1(self.conv5_1(x))
            x = x * gray_5 if self.opt.self_attention else x
            conv5 = self.LReLU5_2(self.conv5_2(x))

            conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear')
            conv4 = conv4 * gray_4 if self.opt.self_attention else conv4
            up6 = torch.cat([self.deconv5(conv5), conv4], 1)
            x = self.LReLU6_1(self.conv6_1(up6))
            conv6 = self.LReLU6_2(self.conv6_2(x))

            conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear')
            conv3 = conv3 * gray_3 if self.opt.self_attention else conv3
            up7 = torch.cat([self.deconv6(conv6), conv3], 1)
            x = self.LReLU7_1(self.conv7_1(up7))
            conv7 = self.LReLU7_2(self.conv7_2(x))

            conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear')
            conv2 = conv2 * gray_2 if self.opt.self_attention else conv2
            up8 = torch.cat([self.deconv7(conv7), conv2], 1)
            x = self.LReLU8_1(self.conv8_1(up8))
            conv8 = self.LReLU8_2(self.conv8_2(x))

            conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear')
            conv1 = conv1 * gray if self.opt.self_attention else conv1
            up9 = torch.cat([self.deconv8(conv8), conv1], 1)
            x = self.LReLU9_1(self.conv9_1(up9))
            conv9 = self.LReLU9_2(self.conv9_2(x))

            latent = self.conv10(conv9)

            if self.opt.times_residual:
                latent = latent * gray

            if self.opt.tanh:
                latent = self.tanh(latent)
            if self.skip:
                if self.opt.linear_add:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent)) / (torch.max(latent) - torch.min(latent))
                    input = (input - torch.min(input)) / (torch.max(input) - torch.min(input))
                    output = latent + input * self.opt.skip
                    output = output * 2 - 1
                else:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent)) / (torch.max(latent) - torch.min(latent))
                    output = latent + input * self.opt.skip
            else:
                output = latent

            if self.opt.linear:
                output = output / torch.max(torch.abs(output))

        output = pad_tensor_back(output, pad_left, pad_right, pad_top, pad_bottom)
        latent = pad_tensor_back(latent, pad_left, pad_right, pad_top, pad_bottom)
        gray = pad_tensor_back(gray, pad_left, pad_right, pad_top, pad_bottom)
        if flag == 1:
            output = F.upsample(output, scale_factor=2, mode='bilinear')
            gray = F.upsample(gray, scale_factor=2, mode='bilinear')

        if self.skip:
            return output, latent
        else:
            return output


class DnCNN(nn.Module):
    def __init__(self, opt=None, depth=17, n_channels=64, image_channels=1,
                 use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []

        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels,
                                kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth - 2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels,
                                    kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels,
                                kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x)
        return y + out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)


class Vgg16(nn.Module):
    def __init__(self):
        super(Vgg16, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

    def forward(self, X, opt):
        h = F.relu(self.conv1_1(X), inplace=True)
        h = F.relu(self.conv1_2(h), inplace=True)
        # relu1_2 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv2_1(h), inplace=True)
        h = F.relu(self.conv2_2(h), inplace=True)
        # relu2_2 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv3_1(h), inplace=True)
        h = F.relu(self.conv3_2(h), inplace=True)
        h = F.relu(self.conv3_3(h), inplace=True)
        # relu3_3 = h
        if opt.vgg_choose != "no_maxpool":
            h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv4_1(h), inplace=True)
        relu4_1 = h
        h = F.relu(self.conv4_2(h), inplace=True)
        relu4_2 = h
        conv4_3 = self.conv4_3(h)
        h = F.relu(conv4_3, inplace=True)
        relu4_3 = h

        if opt.vgg_choose != "no_maxpool":
            if opt.vgg_maxpooling:
                h = F.max_pool2d(h, kernel_size=2, stride=2)

        relu5_1 = F.relu(self.conv5_1(h), inplace=True)
        relu5_2 = F.relu(self.conv5_2(relu5_1), inplace=True)
        conv5_3 = self.conv5_3(relu5_2)
        h = F.relu(conv5_3, inplace=True)
        relu5_3 = h
        if opt.vgg_choose == "conv4_3":
            return conv4_3
        elif opt.vgg_choose == "relu4_2":
            return relu4_2
        elif opt.vgg_choose == "relu4_1":
            return relu4_1
        elif opt.vgg_choose == "relu4_3":
            return relu4_3
        elif opt.vgg_choose == "conv5_3":
            return conv5_3
        elif opt.vgg_choose == "relu5_1":
            return relu5_1
        elif opt.vgg_choose == "relu5_2":
            return relu5_2
        elif opt.vgg_choose == "relu5_3" or opt.vgg_choose == "maxpool":
            # the previous `== "relu5_3" or "maxpool"` was always truthy; compare both options explicitly
            return relu5_3


def vgg_preprocess(batch, opt):
    tensortype = type(batch.data)
    (r, g, b) = torch.chunk(batch, 3, dim=1)
    batch = torch.cat((b, g, r), dim=1)  # convert RGB to BGR
    batch = (batch + 1) * 255 * 0.5  # [-1, 1] -> [0, 255]
    if opt.vgg_mean:
        mean = tensortype(batch.data.size())
        mean[:, 0, :, :] = 103.939
        mean[:, 1, :, :] = 116.779
        mean[:, 2, :, :] = 123.680
        batch = batch.sub(Variable(mean))  # subtract mean
    return batch


class PerceptualLoss(nn.Module):
    def __init__(self, opt):
        super(PerceptualLoss, self).__init__()
        self.opt = opt
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img, self.opt)
        target_vgg = vgg_preprocess(target, self.opt)
        img_fea = vgg(img_vgg, self.opt)
        target_fea = vgg(target_vgg, self.opt)
        if self.opt.no_vgg_instance:
            return torch.mean((img_fea - target_fea) ** 2)
        else:
            return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)


def load_vgg16(model_dir, gpu_ids):
    """ Use the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py """
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)
    # if not os.path.exists(os.path.join(model_dir, 'vgg16.weight')):
    #     if not os.path.exists(os.path.join(model_dir, 'vgg16.t7')):
    #         os.system('wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + os.path.join(model_dir, 'vgg16.t7'))
    #     vgglua = load_lua(os.path.join(model_dir, 'vgg16.t7'))
    #     vgg = Vgg16()
    #     for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
    #         dst.data[:] = src
    #     torch.save(vgg.state_dict(), os.path.join(model_dir, 'vgg16.weight'))
    vgg = Vgg16()
    # vgg.cuda()
    vgg.cuda(device=gpu_ids[0])
    vgg.load_state_dict(torch.load(os.path.join(model_dir, 'vgg16.weight')))
    vgg = torch.nn.DataParallel(vgg, gpu_ids)
    return vgg


class FCN32s(nn.Module):
    def __init__(self, n_class=21):
        super(FCN32s, self).__init__()
        # conv1
        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100)
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/2

        # conv2
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/4

        # conv3
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_2 = nn.ReLU(inplace=True)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/8

        # conv4
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu4_2 = nn.ReLU(inplace=True)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu4_3 = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/16

        # conv5
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_2 = nn.ReLU(inplace=True)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_3 = nn.ReLU(inplace=True)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/32

        # fc6
        self.fc6 = nn.Conv2d(512, 4096, 7)
        self.relu6 = nn.ReLU(inplace=True)
        self.drop6 = nn.Dropout2d()

        # fc7
        self.fc7 = nn.Conv2d(4096, 4096, 1)
        self.relu7 = nn.ReLU(inplace=True)
        self.drop7 = nn.Dropout2d()

        self.score_fr = nn.Conv2d(4096, n_class, 1)
        self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32, bias=False)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.zero_()
                if m.bias is not None:
                    m.bias.data.zero_()
            if isinstance(m, nn.ConvTranspose2d):
                assert m.kernel_size[0] == m.kernel_size[1]
                # get_upsampling_weight is expected to be provided elsewhere; it is not defined in this file
                initial_weight = get_upsampling_weight(
                    m.in_channels, m.out_channels, m.kernel_size[0])
                m.weight.data.copy_(initial_weight)

    def forward(self, x):
        h = x
        h = self.relu1_1(self.conv1_1(h))
        h = self.relu1_2(self.conv1_2(h))
        h = self.pool1(h)

        h = self.relu2_1(self.conv2_1(h))
        h = self.relu2_2(self.conv2_2(h))
        h = self.pool2(h)

        h = self.relu3_1(self.conv3_1(h))
        h = self.relu3_2(self.conv3_2(h))
        h = self.relu3_3(self.conv3_3(h))
        h = self.pool3(h)

        h = self.relu4_1(self.conv4_1(h))
        h = self.relu4_2(self.conv4_2(h))
        h = self.relu4_3(self.conv4_3(h))
        h = self.pool4(h)

        h = self.relu5_1(self.conv5_1(h))
        h = self.relu5_2(self.conv5_2(h))
        h = self.relu5_3(self.conv5_3(h))
        h = self.pool5(h)

        h = self.relu6(self.fc6(h))
        h = self.drop6(h)

        h = self.relu7(self.fc7(h))
        h = self.drop7(h)

        h = self.score_fr(h)

        h = self.upscore(h)
        h = h[:, :, 19:19 + x.size()[2], 19:19 + x.size()[3]].contiguous()
        return h


def load_fcn(model_dir):
    fcn = FCN32s()
    fcn.load_state_dict(torch.load(os.path.join(model_dir, 'fcn32s_from_caffe.pth')))
    fcn.cuda()
    return fcn


class SemanticLoss(nn.Module):
    def __init__(self, opt):
        super(SemanticLoss, self).__init__()
        self.opt = opt
        self.instancenorm = nn.InstanceNorm2d(21, affine=False)

    def compute_fcn_loss(self, fcn, img, target):
        img_fcn = vgg_preprocess(img, self.opt)
        target_fcn = vgg_preprocess(target, self.opt)
        img_fea = fcn(img_fcn)
        target_fea = fcn(target_fcn)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)
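
# Usage sketch (illustrative; the path and the `opt`, `fake_B`, `real_A` names are
# assumptions): the VGG-based perceptual loss compares frozen VGG-16 features of the
# enhanced image and a reference image.
#
#   vgg = load_vgg16("./model", gpu_ids=[0])       # expects ./model/vgg16.weight to exist
#   vgg.eval()
#   for param in vgg.parameters():
#       param.requires_grad = False
#   criterion_perceptual = PerceptualLoss(opt)
#   loss_vgg = criterion_perceptual.compute_vgg_loss(vgg, fake_B, real_A)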