import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from models.image_encoder import ImageEncoder
from models.image_decoder import ImageDecoder
from models.modality_fusion import ModalityFusion
from models.vgg_perceptual_loss import VGGPerceptualLoss
from models.transformers import *  # expected to provide Transformer, Transformer_decoder, subsequent_mask, numericalize, denumericalize
from models import util_funcs  # assumed module path; provides select_imgs, select_seqs, select_seqlens, trgcls_to_onehot, shift_right

class ModelMain(nn.Module):
def __init__(self, opts, mode='train'):
super().__init__()
self.opts = opts
self.img_encoder = ImageEncoder(img_size=opts.img_size, input_nc=opts.ref_nshot, ngf=opts.ngf, norm_layer=nn.LayerNorm)
self.img_decoder = ImageDecoder(img_size=opts.img_size, input_nc=opts.bottleneck_bits + opts.char_num, output_nc=1, ngf=opts.ngf, norm_layer=nn.LayerNorm)
self.vggptlossfunc = VGGPerceptualLoss()
self.modality_fusion = ModalityFusion(img_size=opts.img_size, ref_nshot=opts.ref_nshot, bottleneck_bits=opts.bottleneck_bits, ngf=opts.ngf, mode=opts.mode)
        self.transformer_main = Transformer(
            input_channels = 1,
            input_axis = 2,                      # number of input axes (2 for images, 3 for video)
            num_freq_bands = 6,                  # number of frequency bands; the encoding has 2 * K + 1 values per axis
            max_freq = 10.,                      # maximum frequency; a hyperparameter depending on how fine the data is
            depth = 6,                           # depth of the net; each layer is (cross attention -> self_per_cross_attn * self attention)
            num_latents = 256,                   # number of latents (a.k.a. induced set points or centroids)
            latent_dim = opts.dim_seq_latent,    # latent dimension
            cross_heads = 1,                     # number of heads for cross attention (the Perceiver paper uses 1)
            latent_heads = 8,                    # number of heads for latent self-attention
            cross_dim_head = 64,                 # dimensions per cross-attention head
            latent_dim_head = 64,                # dimensions per latent self-attention head
            num_classes = 1000,                  # number of output classes
            attn_dropout = 0.,
            ff_dropout = 0.,
            weight_tie_layers = False,           # whether to tie weights across layers
            fourier_encode_data = True,          # auto Fourier-encode the data using input_axis; turn off if the data is already Fourier-encoded
            self_per_cross_attn = 2              # number of self-attention blocks per cross attention
        )
self.transformer_seqdec = Transformer_decoder()
def forward(self, data, mode='train'):
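        """Run one forward pass.

        Encodes the reference glyph images and vector sequences, fuses the two
        modalities into a latent code, and decodes a target glyph image. In
        'train'/'val' mode the target SVG sequence is decoded with teacher
        forcing and the image/sequence/KL losses are returned; in test mode the
        SVG sequence is sampled autoregressively instead. Returns
        (ret_dict, loss_dict).
        """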
imgs, seqs, scalars = self.fetch_data(data, mode)
ref_img, trg_img = imgs
ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt, trg_seq_shifted, trg_pts_aux = seqs
trg_char_onehot, trg_cls, trg_seqlen = scalars
# image encoding
img_encoder_out = self.img_encoder(ref_img)
img_feat = img_encoder_out['img_feat'] # bs, ngf * (2 ** 6)
# seq encoding
        ref_img_ = ref_img.view(ref_img.size(0) * ref_img.size(1), ref_img.size(2), ref_img.size(3)).unsqueeze(-1)  # [bs * ref_nshot, img_size, img_size, 1]
seq_feat, _ = self.transformer_main(ref_img_, ref_seq_cat, mask=ref_pad_mask) # [n_bs * n_ref, max_seq_len + 1, 9]
        # modality fusion
mf_output, latent_feat_seq = self.modality_fusion(seq_feat, img_feat, ref_pad_mask=ref_pad_mask)
latent_feat_seq = self.transformer_main.att_residual(latent_feat_seq) # [n_bs, max_seq_len + 1, bottleneck_bits]
z = mf_output['latent']
kl_loss = mf_output['kl_loss']
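        # z: fused latent code conditioning the image decoder; kl_loss: KL regularizer
        # on the fusion latent (presumably a VAE-style bottleneck)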
# image decoding
img_decoder_out = self.img_decoder(z, trg_char_onehot, trg_img)
ret_dict = {}
loss_dict = {}
ret_dict['img'] = {}
ret_dict['img']['out'] = img_decoder_out['gen_imgs']
ret_dict['img']['ref'] = ref_img
ret_dict['img']['trg'] = trg_img
if mode in {'train', 'val'}:
# seq decoding (training or val mode)
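            # build a causal (subsequent-position) mask so each decoding step only attends to earlier steps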
tgt_mask = Variable(subsequent_mask(self.opts.max_seq_len).type_as(ref_pad_mask.data)).unsqueeze(0).expand(z.size(0), -1, -1, -1).cuda().float()
command_logits, args_logits, attn = self.transformer_seqdec(x=trg_seq_shifted, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(command_logits, args_logits, memory=latent_feat_seq.detach(), trg_char=trg_cls)
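            # second, parallel refinement pass over the draft logits; the fusion memory
            # is detached so this pass does not back-propagate into the latent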
            total_loss = self.transformer_main.loss(command_logits, args_logits, trg_seq, trg_seqlen, trg_pts_aux)
total_loss_parallel = self.transformer_main.loss(command_logits_2, args_logits_2, trg_seq, trg_seqlen, trg_pts_aux)
vggpt_loss = self.vggptlossfunc(img_decoder_out['gen_imgs'], trg_img)
# loss and output
loss_svg_items = ['total', 'cmd', 'args', 'smt', 'aux']
# for image
loss_dict['img'] = {}
loss_dict['img']['l1'] = img_decoder_out['img_l1loss']
loss_dict['img']['vggpt'] = vggpt_loss['pt_c_loss']
# for latent
loss_dict['kl'] = kl_loss
# for svg
loss_dict['svg'] = {}
loss_dict['svg_para'] = {}
for item in loss_svg_items:
loss_dict['svg'][item] = total_loss[f'loss_{item}']
loss_dict['svg_para'][item] = total_loss_parallel[f'loss_{item}']
else: # testing (inference)
trg_len = trg_seq_shifted.size(0)
sampled_svg = torch.zeros(1, trg_seq.size(1), self.opts.dim_seq_short).cuda()
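            # greedy autoregressive decoding: start from an all-zero seed token and
            # append the argmax command/argument prediction at every step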
for t in range(0, trg_len):
tgt_mask = Variable(subsequent_mask(sampled_svg.size(0)).type_as(ref_seq_cat.data)).unsqueeze(0).expand(sampled_svg.size(1), -1, -1, -1).cuda().float()
command_logits, args_logits, attn = self.transformer_seqdec(x=sampled_svg, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask)
                prob_command = F.softmax(command_logits[:, -1, :], -1)
                prob_args = F.softmax(args_logits[:, -1, :], -1)
                next_command = torch.argmax(prob_command, -1).unsqueeze(-1)
                next_args = torch.argmax(prob_args, -1)
                predict_tmp = torch.cat((next_command, next_args), -1).unsqueeze(1).transpose(0, 1)
                sampled_svg = torch.cat((sampled_svg, predict_tmp), dim=0)
sampled_svg = sampled_svg[1:]
cmd2 = sampled_svg[:,:,0].unsqueeze(-1)
arg2 = sampled_svg[:,:,1:]
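            # refine the autoregressive draft with a parallel decoding pass and re-take the argmax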
command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(cmd_logits=cmd2, args_logits=arg2, memory=latent_feat_seq, trg_char=trg_cls)
            prob_command = F.softmax(command_logits_2, -1)
            prob_args = F.softmax(args_logits_2, -1)
            update_command = torch.argmax(prob_command, -1).unsqueeze(-1)
            update_args = torch.argmax(prob_args, -1)
            sampled_svg_parallel = torch.cat((update_command, update_args), -1).transpose(0, 1)
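            # convert both decoded sequences to the output format: one-hot commands
            # plus de-quantized coordinate arguments (denumericalize)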
            commands1 = F.one_hot(sampled_svg[:, :, :1].long(), 4).squeeze().transpose(0, 1)
            args1 = denumericalize(sampled_svg[:, :, 1:]).transpose(0, 1)
            sampled_svg_1 = torch.cat([commands1.cpu().detach(), args1[:, :, 2:].cpu().detach()], dim=-1)
            commands2 = F.one_hot(sampled_svg_parallel[:, :, :1].long(), 4).squeeze().transpose(0, 1)
            args2 = denumericalize(sampled_svg_parallel[:, :, 1:]).transpose(0, 1)
            sampled_svg_2 = torch.cat([commands2.cpu().detach(), args2[:, :, 2:].cpu().detach()], dim=-1)
ret_dict['svg'] = {}
ret_dict['svg']['sampled_1'] = sampled_svg_1
ret_dict['svg']['sampled_2'] = sampled_svg_2
ret_dict['svg']['trg'] = trg_seq_gt
return ret_dict, loss_dict
def fetch_data(self, data, mode):
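        """Prepare one batch: select reference/target glyph images and vector
        sequences, quantize the coordinate arguments, build the reference
        padding mask, and return (images, sequences, scalars)."""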
input_image = data['rendered'] # [bs, opts.char_num, opts.img_size, opts.img_size]
input_sequence = data['sequence'] # [bs, opts.char_num, opts.max_seq_len]
input_seqlen = data['seq_len']
input_seqlen = input_seqlen + 1
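        # the +1 presumably reserves a step for the end-of-sequence token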
input_pts_aux = data['pts_aux']
arg_quant = numericalize(input_sequence[:, :, :, 4:])
cmd_cls = torch.argmax(input_sequence[:, :, :, :4], dim=-1).unsqueeze(-1)
input_sequence = torch.cat([cmd_cls, arg_quant], dim=-1) # 1 + 8 = 9 dimension
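        # each sequence step is now [command class (1)] + [quantized coordinate args (8)]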
# choose reference classes and target classes
        if mode == 'train':
            ref_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), self.opts.ref_nshot)).cuda()
            if self.opts.ref_nshot == 52:  # for English-to-Thai transfer: half uppercase, half lowercase references
                ref_cls_upper = torch.randint(0, 26, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
                ref_cls_lower = torch.randint(26, 52, (input_image.size(0), self.opts.ref_nshot // 2)).cuda()
                ref_cls = torch.cat((ref_cls_upper, ref_cls_lower), -1)
elif mode == 'val':
ref_cls = torch.arange(0, self.opts.ref_nshot, 1).cuda().unsqueeze(0).expand(input_image.size(0), -1)
else:
ref_ids = self.opts.ref_char_ids.split(',')
ref_ids = list(map(int, ref_ids))
assert len(ref_ids) == self.opts.ref_nshot
ref_cls = torch.tensor(ref_ids).cuda().unsqueeze(0).expand(self.opts.char_num, -1)
        if mode in {'train', 'val'}:
            trg_cls = torch.randint(0, self.opts.char_num, (input_image.size(0), 1)).cuda()
            if self.opts.ref_nshot == 52:
                trg_cls = torch.randint(52, self.opts.char_num, (input_image.size(0), 1)).cuda()
        else:
            trg_cls = torch.arange(0, self.opts.char_num).cuda()
            if self.opts.ref_nshot == 52:
                trg_cls = torch.randint(52, self.opts.char_num, (input_image.size(0), 1)).cuda()
            trg_cls = trg_cls.view(self.opts.char_num, 1)
            input_image = input_image.expand(self.opts.char_num, -1, -1, -1)
            input_sequence = input_sequence.expand(self.opts.char_num, -1, -1, -1)
            input_pts_aux = input_pts_aux.expand(self.opts.char_num, -1, -1, -1)
            input_seqlen = input_seqlen.expand(self.opts.char_num, -1, -1)
ref_img = util_funcs.select_imgs(input_image, ref_cls, self.opts)
# select a target glyph image
trg_img = util_funcs.select_imgs(input_image, trg_cls, self.opts)
# randomly select ref vector glyphs
ref_seq = util_funcs.select_seqs(input_sequence, ref_cls, self.opts, self.opts.dim_seq_short) # [opts.batch_size, opts.ref_nshot, opts.max_seq_len, opts.dim_seq_nmr]
# randomly select a target vector glyph
trg_seq = util_funcs.select_seqs(input_sequence, trg_cls, self.opts, self.opts.dim_seq_short)
trg_seq = trg_seq.squeeze(1)
        trg_pts_aux = util_funcs.select_seqs(input_pts_aux, trg_cls, self.opts, self.opts.n_aux_pts)
trg_pts_aux = trg_pts_aux.squeeze(1)
# the one-hot target char class
trg_char_onehot = util_funcs.trgcls_to_onehot(trg_cls, self.opts)
# shift target sequence
trg_seq_gt = trg_seq.clone().detach()
trg_seq_gt = torch.cat((trg_seq_gt[:, :, :1], trg_seq_gt[:, :, 3:]), -1)
trg_seq = trg_seq.transpose(0, 1)
trg_seq_shifted = util_funcs.shift_right(trg_seq)
ref_seq_cat = ref_seq.view(ref_seq.size(0) * ref_seq.size(1), ref_seq.size(2), ref_seq.size(3))
ref_seq_cat = ref_seq_cat.transpose(0,1)
ref_seqlen = util_funcs.select_seqlens(input_seqlen, ref_cls, self.opts)
ref_seqlen_cat = ref_seqlen.view(ref_seqlen.size(0) * ref_seqlen.size(1), ref_seqlen.size(2))
ref_pad_mask = torch.zeros(ref_seqlen_cat.size(0), self.opts.max_seq_len) # value = 1 means pos to be masked
for i in range(ref_seqlen_cat.size(0)):
ref_pad_mask[i,:ref_seqlen_cat[i]] = 1
ref_pad_mask = ref_pad_mask.cuda().float().unsqueeze(1)
trg_seqlen = util_funcs.select_seqlens(input_seqlen, trg_cls, self.opts)
trg_seqlen = trg_seqlen.squeeze()
        return [ref_img, trg_img], [ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt, trg_seq_shifted, trg_pts_aux], [trg_char_onehot, trg_cls, trg_seqlen]