import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from options import get_parser_main_model

opts = get_parser_main_model().parse_args()


def init_weights(m):
    # Uniformly initialize every parameter of the module in [-0.08, 0.08].
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)
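
# Usage sketch (assumption; the call site is not part of this file): the whole
# model is typically initialized once after construction, e.g.
#
#     model = ModalityFusion(...)
#     init_weights(model)
#
# which draws every parameter uniformly from [-0.08, 0.08].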


class ModalityFusion(nn.Module):
    """Fuses the sequence (CLS) feature with the image feature into a VAE latent."""

    def __init__(self, img_size=64, ref_nshot=4, bottleneck_bits=512, ngf=32, seq_latent_dim=512, mode='train'):
        super().__init__()
        self.mode = mode
        self.bottleneck_bits = bottleneck_bits
        self.ref_nshot = ref_nshot
        # Merge the per-reference sequence features into a single feature vector.
        self.fc_merge = nn.Linear(seq_latent_dim * opts.ref_nshot, 512)
        n_downsampling = int(math.log(img_size, 2))
        # The maximum channel multiplier of the image feature is 2 ** n_downsampling.
        mult_max = 2 ** n_downsampling
        # Project the concatenated image + sequence feature to the VAE distribution
        # parameters (mu and log_sigma), hence 2 * bottleneck_bits outputs.
        self.fc_fusion = nn.Linear(ngf * mult_max + seq_latent_dim, opts.bottleneck_bits * 2, bias=True)

    def forward(self, seq_feat, img_feat, ref_pad_mask=None):
        # Prepend a "one" entry to the padding mask so the CLS position is never masked.
        cls_one_pad = torch.ones((1, 1, 1)).to(seq_feat.device).repeat(seq_feat.size(0), 1, 1)
        ref_pad_mask = torch.cat([cls_one_pad, ref_pad_mask], dim=-1)
        # Zero out the padded positions of the sequence features.
        seq_feat = seq_feat * ref_pad_mask.transpose(1, 2)
        # Group the ref_nshot reference glyphs of each sample and merge them along the feature dim.
        seq_feat_ = seq_feat.view(seq_feat.size(0) // self.ref_nshot, self.ref_nshot, seq_feat.size(-2), seq_feat.size(-1))
        seq_feat_ = seq_feat_.transpose(1, 2)
        seq_feat_ = seq_feat_.contiguous().view(seq_feat_.size(0), seq_feat_.size(1), seq_feat_.size(2) * seq_feat_.size(3))
        seq_feat_ = self.fc_merge(seq_feat_)
        # The CLS token (position 0) summarizes the merged sequence features.
        seq_feat_cls = seq_feat_[:, 0]
        # Fuse the image feature with the sequence CLS feature and predict mu / log_sigma.
        feat_cat = torch.cat((img_feat, seq_feat_cls), -1)
        dist_param = self.fc_fusion(feat_cat)
        output = {}
        mu = dist_param[..., :self.bottleneck_bits]
        log_sigma = dist_param[..., self.bottleneck_bits:]
        if self.mode == 'train':
            # Reparameterize the latent code and compute the KL divergence to N(0, I).
            epsilon = torch.randn(*mu.size(), device=mu.device)
            z = mu + torch.exp(log_sigma / 2) * epsilon
            kl = 0.5 * torch.mean(torch.exp(log_sigma) + torch.square(mu) - 1. - log_sigma)
            output['latent'] = z
            output['kl_loss'] = kl
            # Replace the CLS position with the sampled latent code.
            seq_feat_[:, 0] = z
            latent_feat_seq = seq_feat_
        else:
            # At inference time, use the mean directly and skip the KL term.
            output['latent'] = mu
            output['kl_loss'] = 0.0
            seq_feat_[:, 0] = mu
            latent_feat_seq = seq_feat_
        return output, latent_feat_seq
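

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (an assumption, not part of the original module).
# It assumes the parsed options keep ref_nshot=4 and bottleneck_bits=512,
# matching the constructor defaults, and uses arbitrary batch / sequence sizes
# purely to illustrate the expected tensor shapes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, ref_nshot, seq_len, seq_dim = 2, 4, 21, 512
    fusion = ModalityFusion(img_size=64, ref_nshot=ref_nshot,
                            bottleneck_bits=512, ngf=32,
                            seq_latent_dim=seq_dim, mode='train')

    # One sequence feature per reference glyph; position 0 is the CLS token.
    seq_feat = torch.randn(batch_size * ref_nshot, seq_len, seq_dim)
    # Padding mask for the non-CLS positions (1 = keep, 0 = padded).
    ref_pad_mask = torch.ones(batch_size * ref_nshot, 1, seq_len - 1)
    # Image feature from the CNN encoder: ngf * 2 ** log2(img_size) = 32 * 64 = 2048 dims.
    img_feat = torch.randn(batch_size, 32 * 64)

    output, latent_feat_seq = fusion(seq_feat, img_feat, ref_pad_mask)
    print(output['latent'].shape)    # (batch_size, bottleneck_bits)
    print(output['kl_loss'])         # scalar KL term
    print(latent_feat_seq.shape)     # (batch_size, seq_len, 512)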