File size: 1,962 Bytes
b762e56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ImageDecoder(nn.Module):
    """Transposed-convolution decoder that turns a flat feature vector into an image.

    The input is reshaped to a 1x1 feature map and passed through a stack of
    stride-2 ``ConvTranspose2d`` -> norm -> ``ReLU`` stages, each of which
    doubles the spatial size, followed by a 7x7 convolution and a sigmoid.
    """

    def __init__(self, img_size, input_nc, output_nc, ngf=16, norm_layer=nn.LayerNorm):
        """Build the upsampling stack.

        Args:
            img_size: Side length of the square image to generate. The stack
                contains ``floor(log2(img_size))`` doubling stages starting
                from 1x1, so the generated side equals ``img_size`` only when
                it is a power of two; otherwise it is the next lower power.
            input_nc: Number of channels of the 1x1 decoder input, i.e. the
                size of the last dimension after ``forward`` concatenates
                ``latent_feat`` and ``trg_char``.
            output_nc: Number of channels of the generated image.
            ngf: Channel count of the last upsampling stage; earlier stages
                use ``ngf`` times a decreasing power of two.
            norm_layer: Callable invoked with a ``[channels, height, width]``
                list and returning a module (``nn.LayerNorm`` accepts this).

        Raises:
            ValueError: If ``img_size`` is smaller than 2, so that no
                upsampling stage could be built.
        """
        super().__init__()

        # math.log2 is exact for powers of two, whereas math.log(x, 2) divides
        # two rounded natural logs and may land a hair off the integer.
        n_upsampling = int(math.log2(img_size))
        if n_upsampling < 1:
            raise ValueError(
                "img_size must be at least 2 to build an upsampling stage, "
                "got {!r}".format(img_size)
            )

        # First third of the stages use 3x3 kernels, the rest 5x5.
        n_small = n_upsampling // 3
        ks_list = [3] * n_small + [5] * (n_upsampling - n_small)
        stride = 2
        # Channel widths halve at every stage and end at ngf.
        channels = [2 ** (n_upsampling - i - 1) * ngf for i in range(n_upsampling)]

        decoder = []
        chn_prev = input_nc
        for i, (ks, chn_next) in enumerate(zip(ks_list, channels)):
            # With padding=ks // 2 and output_padding=stride - 1 each stage
            # exactly doubles the side: 1 -> 2 -> 4 -> ...
            side = 2 ** (i + 1)
            decoder += [
                nn.ConvTranspose2d(chn_prev, chn_next,
                                   kernel_size=ks, stride=stride,
                                   padding=ks // 2, output_padding=stride - 1),
                norm_layer([chn_next, side, side]),
                nn.ReLU(True),
            ]
            chn_prev = chn_next

        decoder += [nn.Conv2d(channels[-1], output_nc, kernel_size=7, padding=7 // 2)]
        decoder += [nn.Sigmoid()]
        self.decode = nn.Sequential(*decoder)

    def forward(self, latent_feat, trg_char, trg_img=None):
        """Decode the concatenated features into an image.

        Args:
            latent_feat: Tensor concatenated with ``trg_char`` along the last
                dimension. The result is viewed as ``(size(0), size(1), 1, 1)``,
                so both inputs are expected to be 2-D ``(batch, features)``.
            trg_char: Tensor concatenated after ``latent_feat``.
            trg_img: Optional target image; when given, the mean L1 distance
                to the generated image is also returned.

        Returns:
            dict with ``'gen_imgs'`` (the sigmoid output of the decoder) and,
            only when ``trg_img`` is not ``None``, ``'img_l1loss'``.
        """
        dec_input = torch.cat((latent_feat, trg_char), -1)
        # Treat the feature vector as a 1x1 map with one channel per feature.
        dec_input = dec_input.view(dec_input.size(0), dec_input.size(1), 1, 1)
        dec_out = self.decode(dec_input)

        output = {'gen_imgs': dec_out}
        if trg_img is not None:
            output['img_l1loss'] = F.l1_loss(dec_out, trg_img)

        return output