import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.util import mydownres2Dblock, AntiAliasInterpolation2d, make_coordinate_grid
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d


class PositionalEncoding(nn.Module):

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        # Not a parameter
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        # TODO: make it with torch instead of numpy
        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, winsize):
        # Return the first `winsize` positional encodings, detached from the graph.
        return self.pos_table[:, :winsize].clone().detach()


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=True):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, opt, src, query_embed, pos_embed):
        # (bs, seq, d_model) -> (seq, bs, d_model), as expected by nn.MultiheadAttention
        src = src.permute(1, 0, 2)
        pos_embed = pos_embed.permute(1, 0, 2)
        query_embed = query_embed.permute(1, 0, 2)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, pos=pos_embed)
        hs = self.decoder(tgt, memory, pos=pos_embed, query_pos=query_embed)
        return hs


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None, pos=None):
        # Positional encodings are added to the input once, before the layer stack.
        output = src + pos

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None,
                pos=None, query_pos=None):
        # Positional and query embeddings are added to the target once, up front.
        output = tgt + pos + query_pos

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        # q = k = self.with_pos_embed(src, pos)  -- not used: pos is already added in TransformerEncoder
        src2 = self.self_attn(src, src, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        src2 = self.norm1(src)
        # q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None,
                     memory_key_padding_mask=None, pos=None, query_pos=None):
        # q = k = self.with_pos_embed(tgt, query_pos)  -- not used: embeddings are already added in TransformerDecoder
        tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=tgt, key=memory, value=memory,
                                   attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory, tgt_mask=None, memory_mask=None,
                    tgt_key_padding_mask=None, memory_key_padding_mask=None,
                    pos=None, query_pos=None):
        tgt2 = self.norm1(tgt)
        # q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=tgt2, key=memory, value=memory,
                                   attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None,
                pos=None, query_pos=None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


class Audio2kpTransformer(nn.Module):

    def __init__(self, opt):
        super(Audio2kpTransformer, self).__init__()
        self.opt = opt

        # 41 phoneme classes; embedding_dim is expected to be 256 so the embedding
        # can be reshaped to a 16x16 map in forward().
        self.embedding = nn.Embedding(41, opt.embedding_dim)
        self.pos_enc = PositionalEncoding(512, 20)
        self.down_pose = AntiAliasInterpolation2d(1, 0.25)
        input_dim = 2
        self.feature_extract = nn.Sequential(mydownres2Dblock(input_dim, 32),
                                             mydownres2Dblock(32, 64),
                                             mydownres2Dblock(64, 128),
                                             mydownres2Dblock(128, 256),
                                             mydownres2Dblock(256, 512),
                                             nn.AvgPool2d(2))

        self.decode_dim = 70
        self.audio_embedding = nn.Sequential(
            nn.ConvTranspose2d(1, 8, (29, 14), stride=(1, 1), padding=(0, 11)),
            BatchNorm2d(8),
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 35, (13, 13), stride=(1, 1), padding=(6, 6)))
        self.decodefeature_extract = nn.Sequential(mydownres2Dblock(self.decode_dim, 32),
                                                   mydownres2Dblock(32, 64),
                                                   mydownres2Dblock(64, 128),
                                                   mydownres2Dblock(128, 256),
                                                   mydownres2Dblock(256, 512),
                                                   nn.AvgPool2d(2))

        self.transformer = Transformer()
        self.kp = nn.Linear(512, opt.num_kp * 2)
        self.jacobian = nn.Linear(512, opt.num_kp * 4)
        # Initialize the jacobian head to predict identity transforms.
        self.jacobian.weight.data.zero_()
        self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.opt.num_kp, dtype=torch.float))

        self.criterion = nn.L1Loss()

    def create_sparse_motions(self, source_image, kp_source):
        """
        Eq. 4 in the paper: T_{s<-d}(z).
        """
        bs, _, h, w = source_image.shape
        identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
        identity_grid = identity_grid.view(1, 1, h, w, 2)
        coordinate_grid = identity_grid
        if 'jacobian' in kp_source:
            jacobian = kp_source['jacobian']
            jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
            jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
            coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
            coordinate_grid = coordinate_grid.squeeze(-1)

        driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.opt.num_kp, 1, 1, 2)

        # adding background feature
        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)
        return sparse_motions.permute(0, 1, 4, 2, 3).reshape(bs, (self.opt.num_kp + 1) * 2, 64, 64)

    def forward(self, x, initial_kp=None):
        bs, seqlen = x["ph"].shape
        ph = x["ph"].reshape(bs * seqlen, 1)
        pose = x["pose"].reshape(bs * seqlen, 1, 256, 256)
        # Downsample the pose map to 64x64 and stack it with the upsampled phoneme embedding.
        input_feature = self.down_pose(pose)

        phoneme_embedding = self.embedding(ph.long())
        phoneme_embedding = phoneme_embedding.reshape(bs * seqlen, 1, 16, 16)
        phoneme_embedding = F.interpolate(phoneme_embedding, scale_factor=4)
        input_feature = torch.cat((input_feature, phoneme_embedding), dim=1)

        input_feature = self.feature_extract(input_feature).unsqueeze(-1).reshape(bs, seqlen, 512)

        # Audio features become the decoder queries, fused with the source feature map.
        audio = x["audio"].reshape(bs * seqlen, 1, 4, 41)
        decoder_feature = self.audio_embedding(audio)
        decoder_feature = F.interpolate(decoder_feature, scale_factor=2)
        decoder_feature = self.decodefeature_extract(torch.cat(
            (decoder_feature,
             initial_kp["feature_map"].unsqueeze(1).repeat(1, seqlen, 1, 1, 1).reshape(bs * seqlen, 35, 64, 64)),
            dim=1)).unsqueeze(-1).reshape(bs, seqlen, 512)

        posi_em = self.pos_enc(self.opt.num_w * 2 + 1)

        out = {}

        # Take the last decoder layer's output at the center frame (index num_w).
        output_feature = self.transformer(self.opt, input_feature, decoder_feature, posi_em)[-1, self.opt.num_w]

        out["value"] = self.kp(output_feature).reshape(bs, self.opt.num_kp, 2)
        out["jacobian"] = self.jacobian(output_feature).reshape(bs, self.opt.num_kp, 2, 2)

        return out
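

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It assumes an
# `opt` namespace exposing the fields read above -- embedding_dim (256, so the
# phoneme embedding reshapes to 16x16), num_kp and num_w -- plus dummy inputs
# with the shapes consumed by Audio2kpTransformer.forward(): ph (bs, seqlen),
# pose (bs, seqlen, 256, 256), audio (bs, seqlen, 4, 41) with
# seqlen == 2 * num_w + 1, and initial_kp["feature_map"] of shape (bs, 35, 64, 64).
# The concrete values below (num_kp=10, num_w=3, bs=2) are illustrative only,
# and the sketch relies on models.util and sync_batchnorm behaving as used above.
if __name__ == "__main__":
    from types import SimpleNamespace

    opt = SimpleNamespace(embedding_dim=256, num_kp=10, num_w=3)  # hypothetical config
    model = Audio2kpTransformer(opt).eval()

    bs, seqlen = 2, opt.num_w * 2 + 1
    x = {
        "ph": torch.randint(0, 41, (bs, seqlen)),       # phoneme ids (41 classes)
        "pose": torch.randn(bs, seqlen, 256, 256),      # per-frame pose maps
        "audio": torch.randn(bs, seqlen, 4, 41),        # per-frame audio features
    }
    initial_kp = {"feature_map": torch.randn(bs, 35, 64, 64)}  # assumed source feature map

    with torch.no_grad():
        out = model(x, initial_kp)
    print(out["value"].shape)     # expected: (bs, num_kp, 2)
    print(out["jacobian"].shape)  # expected: (bs, num_kp, 2, 2)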