# Spaces:
# Sleeping
# Sleeping
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.image_encoder import ImageEncoder
from models.image_decoder import ImageDecoder
from models.modality_fusion import ModalityFusion
from models.vgg_perceptual_loss import VGGPerceptualLoss
from models.transformers import *
from torch.autograd import Variable
# Device used by every method below: CUDA when PyTorch sees a GPU, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class ModelMain(nn.Module):
    """Glyph generation model combining an image branch and a sequence branch.

    The stacked reference glyph images are encoded by ``img_encoder`` and,
    one image per sample, by ``transformer_main``. ``modality_fusion`` merges
    both encodings into a latent ``z`` and a latent sequence. ``img_decoder``
    renders the target glyph image from ``z`` and the target character
    one-hot. ``transformer_seqdec`` decodes the target glyph's
    command/argument sequence from the latent sequence.
    """

    def __init__(self, opts, mode='train'):
        """Build the sub-modules from ``opts``.

        Args:
            opts: options namespace. It is stored as ``self.opts`` and read
                by every method of this class.
            mode: accepted for call compatibility but unused here. The fusion
                module is configured from ``opts.mode`` instead.
        """
        super().__init__()
        self.opts = opts
        self.img_encoder = ImageEncoder(
            img_size=opts.img_size, input_nc=opts.ref_nshot, ngf=opts.ngf, norm_layer=nn.LayerNorm
        )
        self.img_decoder = ImageDecoder(
            img_size=opts.img_size,
            input_nc=opts.bottleneck_bits + opts.char_num,
            output_nc=1,
            ngf=opts.ngf,
            norm_layer=nn.LayerNorm,
        )
        self.vggptlossfunc = VGGPerceptualLoss()
        self.modality_fusion = ModalityFusion(
            img_size=opts.img_size,
            ref_nshot=opts.ref_nshot,
            bottleneck_bits=opts.bottleneck_bits,
            ngf=opts.ngf,
            mode=opts.mode,
        )
        self.transformer_main = Transformer(
            input_channels=1,
            input_axis=2,  # number of axes of the input data (2 for images, 3 for video)
            num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
            max_freq=10.,  # maximum frequency, hyperparameter depending on how fine the data is
            depth=6,  # depth of net; the final attention stack is
            # depth * (cross attention -> self_per_cross_attn * self attention)
            num_latents=256,  # number of latents (a.k.a. induced set points / centroids)
            latent_dim=opts.dim_seq_latent,  # latent dimension
            cross_heads=1,  # number of heads for cross attention (paper uses 1)
            latent_heads=8,  # number of heads for latent self attention
            cross_dim_head=64,  # number of dimensions per cross attention head
            latent_dim_head=64,  # number of dimensions per latent self attention head
            num_classes=1000,  # output number of classes
            attn_dropout=0.,
            ff_dropout=0.,
            weight_tie_layers=False,  # whether to weight-tie layers
            fourier_encode_data=True,  # auto-Fourier-encode along input_axis; disable if already encoded
            self_per_cross_attn=2,  # number of self attention blocks per cross attention
        )
        self.transformer_seqdec = Transformer_decoder()

    def forward(self, data, mode='train'):
        """Run the model on one batch.

        Args:
            data: batch dict; ``fetch_data`` documents the keys read.
            mode: 'train' or 'val' decodes the target sequence from the
                shifted ground truth and computes losses. Any other value
                runs greedy step-by-step decoding followed by one
                ``parallel_decoder`` refinement pass.

        Returns:
            (ret_dict, loss_dict). ``ret_dict['img']`` always holds the
            generated ('out'), reference ('ref') and target ('trg') images.
            In train/val mode ``loss_dict`` holds 'img', 'kl', 'svg' and
            'svg_para' entries. Otherwise ``loss_dict`` is empty, and
            ``ret_dict['svg']`` holds the two decoded sequences
            ('sampled_1', 'sampled_2') and the ground truth ('trg').
        """
        imgs, seqs, scalars = self.fetch_data(data, mode)
        ref_img, trg_img = imgs
        ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt, trg_seq_shifted, trg_pts_aux = seqs
        trg_char_onehot, trg_cls, trg_seqlen = scalars

        # Image encoding of the stacked reference glyphs.
        img_encoder_out = self.img_encoder(ref_img)
        img_feat = img_encoder_out['img_feat']  # bs, ngf * (2 ** 6)

        # Sequence-branch encoding: each reference image becomes its own sample
        # with a trailing channel axis, i.e. [n_bs * n_ref, H, W, 1].
        ref_img_ = ref_img.view(
            ref_img.size(0) * ref_img.size(1), ref_img.size(2), ref_img.size(3)
        ).unsqueeze(-1)
        seq_feat, _ = self.transformer_main(ref_img_, ref_seq_cat, mask=ref_pad_mask)

        # Modality fusion.
        mf_output, latent_feat_seq = self.modality_fusion(seq_feat, img_feat, ref_pad_mask=ref_pad_mask)
        latent_feat_seq = self.transformer_main.att_residual(latent_feat_seq)  # [n_bs, max_seq_len + 1, bottleneck_bits]
        z = mf_output['latent']
        kl_loss = mf_output['kl_loss']

        # Image decoding.
        img_decoder_out = self.img_decoder(z, trg_char_onehot, trg_img)

        ret_dict = {'img': {'out': img_decoder_out['gen_imgs'], 'ref': ref_img, 'trg': trg_img}}
        loss_dict = {}

        if mode in {'train', 'val'}:
            # Decode the shifted ground truth under a causal (subsequent) mask.
            tgt_mask = (
                subsequent_mask(self.opts.max_seq_len)
                .type_as(ref_pad_mask.data)
                .unsqueeze(0)
                .expand(z.size(0), -1, -1, -1)
                .to(device)
                .float()
            )
            command_logits, args_logits, _ = self.transformer_seqdec(
                x=trg_seq_shifted, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask
            )
            # The refinement pass sees a detached memory, so its loss does not
            # backpropagate into the latent sequence.
            command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(
                command_logits, args_logits, memory=latent_feat_seq.detach(), trg_char=trg_cls
            )
            total_loss = self.transformer_main.loss(command_logits, args_logits, trg_seq, trg_seqlen, trg_pts_aux)
            total_loss_parallel = self.transformer_main.loss(
                command_logits_2, args_logits_2, trg_seq, trg_seqlen, trg_pts_aux
            )
            vggpt_loss = self.vggptlossfunc(img_decoder_out['gen_imgs'], trg_img)

            loss_svg_items = ['total', 'cmd', 'args', 'smt', 'aux']
            loss_dict['img'] = {'l1': img_decoder_out['img_l1loss'], 'vggpt': vggpt_loss['pt_c_loss']}
            loss_dict['kl'] = kl_loss
            loss_dict['svg'] = {item: total_loss[f'loss_{item}'] for item in loss_svg_items}
            loss_dict['svg_para'] = {item: total_loss_parallel[f'loss_{item}'] for item in loss_svg_items}
        else:
            # Greedy step-by-step decoding. sampled_svg is time-major
            # [steps, bs, dim_seq_short] and starts with one all-zero step.
            trg_len = trg_seq_shifted.size(0)
            sampled_svg = torch.zeros(1, trg_seq.size(1), self.opts.dim_seq_short).to(device)
            for _ in range(trg_len):
                tgt_mask = (
                    subsequent_mask(sampled_svg.size(0))
                    .type_as(ref_seq_cat.data)
                    .unsqueeze(0)
                    .expand(sampled_svg.size(1), -1, -1, -1)
                    .to(device)
                    .float()
                )
                command_logits, args_logits, _ = self.transformer_seqdec(
                    x=sampled_svg, memory=latent_feat_seq, trg_char=trg_cls, tgt_mask=tgt_mask
                )
                # argmax of the raw logits equals argmax of their softmax.
                next_command = torch.argmax(command_logits[:, -1, :], -1).unsqueeze(-1)
                next_args = torch.argmax(args_logits[:, -1, :], -1)
                predict_tmp = torch.cat((next_command, next_args), -1).unsqueeze(1).transpose(0, 1)
                sampled_svg = torch.cat((sampled_svg, predict_tmp), dim=0)
            sampled_svg = sampled_svg[1:]  # drop the all-zero start step

            # One parallel refinement pass over the greedy result.
            cmd2 = sampled_svg[:, :, 0].unsqueeze(-1)
            arg2 = sampled_svg[:, :, 1:]
            command_logits_2, args_logits_2 = self.transformer_seqdec.parallel_decoder(
                cmd_logits=cmd2, args_logits=arg2, memory=latent_feat_seq, trg_char=trg_cls
            )
            update_command = torch.argmax(command_logits_2, -1).unsqueeze(-1)
            update_args = torch.argmax(args_logits_2, -1)
            sampled_svg_parallel = torch.cat((update_command, update_args), -1).transpose(0, 1)

            ret_dict['svg'] = {
                'sampled_1': self._to_output_svg(sampled_svg),
                'sampled_2': self._to_output_svg(sampled_svg_parallel),
                'trg': trg_seq_gt,
            }
        return ret_dict, loss_dict

    @staticmethod
    def _to_output_svg(seq):
        """Convert decoded tokens to batch-first CPU tensors for output.

        ``seq`` is expected time-major, [steps, bs, 1 + n_args] — TODO confirm
        for the refined sequence. Column 0 becomes a 4-way one-hot command.
        The remaining columns go through ``denumericalize``, and their first
        two output columns are dropped.
        """
        commands = F.one_hot(seq[:, :, :1].long(), 4).squeeze().transpose(0, 1)
        args = denumericalize(seq[:, :, 1:]).transpose(0, 1)
        return torch.cat([commands.cpu().detach(), args[:, :, 2:].cpu().detach()], dim=-1)

    @staticmethod
    def _select_classes(opts, batch_size, mode, dev):
        """Choose reference and target character-class indices.

        Args:
            opts: options namespace. Reads ``char_num`` and ``ref_nshot``, and
                ``ref_char_ids`` (comma-separated ints) outside train/val.
            batch_size: number of fonts in the incoming batch.
            mode: 'train' (random refs), 'val' (the first ``ref_nshot``
                classes), or anything else (refs from ``ref_char_ids``).
            dev: device the index tensors are moved to. Random draws are made
                on the CPU generator first, as before.

        Returns:
            (ref_cls, trg_cls). In train/val these are [batch_size, ref_nshot]
            and [batch_size, 1]. Otherwise they are [char_num, ref_nshot] and
            [char_num, 1], with row i targeting class i.

        Raises:
            ValueError: outside train/val, if ``ref_char_ids`` does not list
                exactly ``ref_nshot`` ids.
        """
        # 52 references is the English (26 upper + 26 lower) -> Thai setup.
        eng_to_th = opts.ref_nshot == 52

        if mode == 'train':
            if eng_to_th:
                half = opts.ref_nshot // 2
                ref_upper = torch.randint(0, 26, (batch_size, half)).to(dev)
                ref_lower = torch.randint(26, 52, (batch_size, half)).to(dev)
                ref_cls = torch.cat((ref_upper, ref_lower), -1)
            else:
                ref_cls = torch.randint(0, opts.char_num, (batch_size, opts.ref_nshot)).to(dev)
        elif mode == 'val':
            ref_cls = torch.arange(0, opts.ref_nshot, 1).to(dev).unsqueeze(0).expand(batch_size, -1)
        else:
            ref_ids = [int(i) for i in opts.ref_char_ids.split(',')]
            if len(ref_ids) != opts.ref_nshot:
                raise ValueError(
                    f'ref_char_ids lists {len(ref_ids)} ids but ref_nshot is {opts.ref_nshot}'
                )
            ref_cls = torch.tensor(ref_ids).to(dev).unsqueeze(0).expand(opts.char_num, -1)

        if mode in {'train', 'val'}:
            # For ENG -> TH, targets come only from classes 52 .. char_num - 1.
            low = 52 if eng_to_th else 0
            trg_cls = torch.randint(low, opts.char_num, (batch_size, 1)).to(dev)
        else:
            # One target per character class, matching fetch_data's expansion
            # of the input to char_num rows. When ref_nshot == 52, the
            # non-reference classes are rows 52 onward. A random
            # [batch_size, 1] draw cannot be viewed as [char_num, 1].
            trg_cls = torch.arange(0, opts.char_num).to(dev).view(opts.char_num, 1)
        return ref_cls, trg_cls

    @staticmethod
    def _make_pad_mask(seqlens, max_seq_len):
        """Return a float [N, max_seq_len] mask: 1 where position < length, else 0.

        Args:
            seqlens: integer tensor with one length per row, shaped [N] or [N, 1].
            max_seq_len: mask width.
        """
        positions = torch.arange(max_seq_len, device=seqlens.device)
        lengths = seqlens.reshape(seqlens.size(0), 1)
        return (positions.unsqueeze(0) < lengths).float()

    def fetch_data(self, data, mode):
        """Unpack a batch dict and gather the reference / target glyphs.

        Reads ``data['rendered']``, ``data['sequence']``, ``data['seq_len']``
        and ``data['pts_aux']``. Outside train/val the single input font is
        replicated once per character class. ``Tensor.expand`` requires its
        batch dimension to be 1 or ``char_num``.

        Returns:
            ([ref_img, trg_img],
             [ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt,
              trg_seq_shifted, trg_pts_aux],
             [trg_char_onehot, trg_cls, trg_seqlen])
        """
        input_image = data['rendered']  # [bs, opts.char_num, opts.img_size, opts.img_size]
        input_sequence = data['sequence']  # [bs, opts.char_num, opts.max_seq_len, 4 command channels + args]
        input_seqlen = data['seq_len'] + 1
        input_pts_aux = data['pts_aux']

        # Re-encode each step as [command class, quantized args].
        arg_quant = numericalize(input_sequence[:, :, :, 4:])
        cmd_cls = torch.argmax(input_sequence[:, :, :, :4], dim=-1).unsqueeze(-1)
        input_sequence = torch.cat([cmd_cls, arg_quant], dim=-1)  # 1 + 8 = 9 dimension

        ref_cls, trg_cls = self._select_classes(self.opts, input_image.size(0), mode, device)
        if mode not in {'train', 'val'}:
            n_cls = self.opts.char_num
            input_image = input_image.expand(n_cls, -1, -1, -1)
            input_sequence = input_sequence.expand(n_cls, -1, -1, -1)
            input_pts_aux = input_pts_aux.expand(n_cls, -1, -1, -1)
            input_seqlen = input_seqlen.expand(n_cls, -1, -1)

        # Gather reference / target glyph images and sequences.
        ref_img = util_funcs.select_imgs(input_image, ref_cls, self.opts)
        trg_img = util_funcs.select_imgs(input_image, trg_cls, self.opts)
        ref_seq = util_funcs.select_seqs(input_sequence, ref_cls, self.opts, self.opts.dim_seq_short)  # [opts.batch_size, opts.ref_nshot, opts.max_seq_len, opts.dim_seq_nmr]
        trg_seq = util_funcs.select_seqs(input_sequence, trg_cls, self.opts, self.opts.dim_seq_short).squeeze(1)
        trg_pts_aux = util_funcs.select_seqs(input_pts_aux, trg_cls, self.opts, self.opts.n_aux_pts).squeeze(1)
        trg_char_onehot = util_funcs.trgcls_to_onehot(trg_cls, self.opts)

        # Ground truth keeps feature 0 and features 3 onward; features 1-2 are dropped.
        trg_seq_gt = trg_seq.clone().detach()
        trg_seq_gt = torch.cat((trg_seq_gt[:, :, :1], trg_seq_gt[:, :, 3:]), -1)
        # Decoder-side sequences are time-major.
        trg_seq = trg_seq.transpose(0, 1)
        trg_seq_shifted = util_funcs.shift_right(trg_seq)

        ref_seq_cat = ref_seq.view(
            ref_seq.size(0) * ref_seq.size(1), ref_seq.size(2), ref_seq.size(3)
        ).transpose(0, 1)
        ref_seqlen = util_funcs.select_seqlens(input_seqlen, ref_cls, self.opts)
        ref_seqlen_cat = ref_seqlen.view(ref_seqlen.size(0) * ref_seqlen.size(1), ref_seqlen.size(2))
        # 1 at positions inside each reference sequence's length, 0 beyond it.
        ref_pad_mask = self._make_pad_mask(ref_seqlen_cat, self.opts.max_seq_len).to(device).unsqueeze(1)

        trg_seqlen = util_funcs.select_seqlens(input_seqlen, trg_cls, self.opts).squeeze()
        return (
            [ref_img, trg_img],
            [ref_seq, ref_seq_cat, ref_pad_mask, trg_seq, trg_seq_gt, trg_seq_shifted, trg_pts_aux],
            [trg_char_onehot, trg_cls, trg_seqlen],
        )