import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ImageDecoder(nn.Module):
    """Decodes a 1x1 latent feature map into an image via stacked transposed convolutions."""

    def __init__(self, img_size, input_nc, output_nc, ngf=16, norm_layer=nn.LayerNorm):
        super(ImageDecoder, self).__init__()
        # Number of 2x upsampling steps needed to go from 1x1 to img_size x img_size
        # (img_size is assumed to be a power of two).
        n_upsampling = int(math.log2(img_size))
        # Kernel sizes per layer: 3 for the first third of the layers, 5 for the rest.
        ks_list = [3] * (n_upsampling // 3) + [5] * (n_upsampling - n_upsampling // 3)
        stride_list = [2] * n_upsampling

        # Channel multipliers halve at each upsampling step: 2^(n-1), ..., 2, 1.
        chn_mult = [2 ** (n_upsampling - i - 1) for i in range(n_upsampling)]

        decoder = []
        # First upsampling layer: (input_nc, 1, 1) -> (chn_mult[0] * ngf, 2, 2).
        decoder += [nn.ConvTranspose2d(input_nc, chn_mult[0] * ngf,
                                       kernel_size=ks_list[0], stride=stride_list[0],
                                       padding=ks_list[0] // 2,
                                       output_padding=stride_list[0] - 1),
                    norm_layer([chn_mult[0] * ngf, 2, 2]),
                    nn.ReLU(True)]
        # Remaining upsampling layers: spatial size doubles and channels halve each step.
        for i in range(1, n_upsampling):
            chn_prev = chn_mult[i - 1] * ngf
            chn_next = chn_mult[i] * ngf
            decoder += [nn.ConvTranspose2d(chn_prev, chn_next,
                                           kernel_size=ks_list[i], stride=stride_list[i],
                                           padding=ks_list[i] // 2,
                                           output_padding=stride_list[i] - 1),
                        norm_layer([chn_next, 2 ** (i + 1), 2 ** (i + 1)]),
                        nn.ReLU(True)]
        # Project to the output channel count and squash pixel values to [0, 1].
        decoder += [nn.Conv2d(chn_mult[-1] * ngf, output_nc, kernel_size=7, padding=7 // 2),
                    nn.Sigmoid()]
        self.decode = nn.Sequential(*decoder)

    def forward(self, latent_feat, trg_char, trg_img=None):
        """Standard forward"""
        # Concatenate the latent feature with the target-character embedding and
        # reshape to a (B, C, 1, 1) map so it can seed the convolutional decoder.
        dec_input = torch.cat((latent_feat, trg_char), -1)
        dec_input = dec_input.view(dec_input.size(0), dec_input.size(1), 1, 1)
        dec_out = self.decode(dec_input)

        output = {'gen_imgs': dec_out}
        # When the target image is provided, also return the L1 reconstruction loss.
        if trg_img is not None:
            output['img_l1loss'] = F.l1_loss(dec_out, trg_img)
        return output
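

# Minimal usage sketch (not part of the original module). The sizes below are
# illustrative assumptions: a 512-dim latent feature, a 52-dim one-hot
# target-character code, and 64x64 single-channel output images; input_nc must
# equal the concatenated width of latent_feat and trg_char.
if __name__ == '__main__':
    decoder = ImageDecoder(img_size=64, input_nc=512 + 52, output_nc=1)
    latent_feat = torch.randn(4, 512)    # batch of 4 latent vectors (assumed size)
    trg_char = torch.zeros(4, 52)        # one-hot target-character codes (assumed size)
    trg_char[:, 0] = 1.0
    trg_img = torch.rand(4, 1, 64, 64)   # optional ground-truth images for the L1 loss
    out = decoder(latent_feat, trg_char, trg_img)
    print(out['gen_imgs'].shape)         # expected: torch.Size([4, 1, 64, 64])
    print(out['img_l1loss'])             # scalar L1 reconstruction loss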