import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.util import AntiAliasInterpolation2d, make_coordinate_grid, mydownres2Dblock
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding, precomputed for n_position steps."""

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()
        # Not a parameter: registered as a buffer so it moves with the module but is never trained.
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # even dims: sine
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # odd dims: cosine

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, winsize):
        # Return the encodings for the first `winsize` positions, detached from the graph.
        return self.pos_table[:, :winsize].clone().detach()


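# A quick illustration of the table above (an illustrative addition, not part of
# the original model code): pos_table has shape (1, n_position, d_hid), and
# forward(winsize) slices the first `winsize` positions, so the result
# broadcasts against a (batch, winsize, d_hid) feature tensor.
def _positional_encoding_demo():
    pe = PositionalEncoding(d_hid=512, n_position=20)
    table = pe(7)                      # (1, 7, 512)
    features = torch.randn(2, 7, 512)  # (batch, winsize, d_hid)
    return (features + table).shape    # broadcasts over the batch dimension

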
def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=True):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, opt, src, query_embed, pos_embed):
        # nn.MultiheadAttention expects (seq_len, batch, d_model), so permute the
        # (batch, seq_len, d_model) inputs first. `opt` is unused here but kept
        # for interface compatibility with the caller.
        src = src.permute(1, 0, 2)
        pos_embed = pos_embed.permute(1, 0, 2)
        query_embed = query_embed.permute(1, 0, 2)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, pos=pos_embed)

        hs = self.decoder(tgt, memory,
                          pos=pos_embed, query_pos=query_embed)
        return hs


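# A small shape sketch for the wrapper above (an illustrative addition; the demo
# function and its argument values are not part of the training pipeline). With
# return_intermediate_dec=True the decoder stacks one output per decoder layer,
# which is why the caller can index the last layer with [-1].
def _transformer_shape_demo():
    model = Transformer(d_model=512, nhead=8, num_encoder_layers=2, num_decoder_layers=2)
    bs, winsize = 2, 7
    src = torch.randn(bs, winsize, 512)         # encoder input features
    query = torch.randn(bs, winsize, 512)       # decoder query features
    pos = PositionalEncoding(512, 20)(winsize)  # (1, winsize, 512), broadcasts over batch
    hs = model(None, src, query, pos)           # the opt argument is unused
    return hs.shape                             # (num_decoder_layers, winsize, bs, 512)

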
class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None, pos=None):
        output = src + pos

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None,
                pos=None, query_pos=None):
        output = tgt + pos + query_pos

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            # One entry per decoder layer, stacked along a new leading dimension.
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        src2 = self.self_attn(src, src, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        src2 = self.norm1(src)
        src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


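# The only difference between the two paths above is where LayerNorm sits:
# forward_post normalizes after each residual addition, forward_pre normalizes
# before each sub-layer. A brief shape-only comparison (an illustrative
# addition, not used elsewhere in the repository):
def _encoder_layer_norm_demo():
    x = torch.randn(7, 2, 512)  # (seq_len, batch, d_model), as nn.MultiheadAttention expects
    post = TransformerEncoderLayer(512, 8, normalize_before=False)
    pre = TransformerEncoderLayer(512, 8, normalize_before=True)
    return post(x).shape, pre(x).shape  # both (7, 2, 512)

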
class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask=None, memory_mask=None,
                     tgt_key_padding_mask=None, memory_key_padding_mask=None,
                     pos=None, query_pos=None):
        tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=tgt,
                                   key=memory,
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask=None, memory_mask=None,
                    tgt_key_padding_mask=None, memory_key_padding_mask=None,
                    pos=None, query_pos=None):
        tgt2 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=tgt2,
                                   key=memory,
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None,
                pos=None, query_pos=None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


class Audio2kpTransformer(nn.Module):
    def __init__(self, opt):
        super(Audio2kpTransformer, self).__init__()
        self.opt = opt

        # Phoneme ids (41 classes) are embedded and later reshaped to a 16x16 map,
        # so opt.embedding_dim is expected to be 256.
        self.embedding = nn.Embedding(41, opt.embedding_dim)
        self.pos_enc = PositionalEncoding(512, 20)
        self.down_pose = AntiAliasInterpolation2d(1, 0.25)
        input_dim = 2
        self.feature_extract = nn.Sequential(mydownres2Dblock(input_dim, 32),
                                             mydownres2Dblock(32, 64),
                                             mydownres2Dblock(64, 128),
                                             mydownres2Dblock(128, 256),
                                             mydownres2Dblock(256, 512),
                                             nn.AvgPool2d(2))

        self.decode_dim = 70
        self.audio_embedding = nn.Sequential(nn.ConvTranspose2d(1, 8, (29, 14), stride=(1, 1), padding=(0, 11)),
                                             BatchNorm2d(8),
                                             nn.ReLU(inplace=True),
                                             nn.Conv2d(8, 35, (13, 13), stride=(1, 1), padding=(6, 6)))
        self.decodefeature_extract = nn.Sequential(mydownres2Dblock(self.decode_dim, 32),
                                                   mydownres2Dblock(32, 64),
                                                   mydownres2Dblock(64, 128),
                                                   mydownres2Dblock(128, 256),
                                                   mydownres2Dblock(256, 512),
                                                   nn.AvgPool2d(2))

        self.transformer = Transformer()
        self.kp = nn.Linear(512, opt.num_kp * 2)
        self.jacobian = nn.Linear(512, opt.num_kp * 4)
        # Initialize the Jacobian head to predict identity 2x2 matrices.
        self.jacobian.weight.data.zero_()
        self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.opt.num_kp, dtype=torch.float))
        self.criterion = nn.L1Loss()

    def create_sparse_motions(self, source_image, kp_source):
        """
        Eq. 4 in the paper: T_{s<-d}(z)
        """
        bs, _, h, w = source_image.shape
        identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
        identity_grid = identity_grid.view(1, 1, h, w, 2)
        coordinate_grid = identity_grid
        if 'jacobian' in kp_source:
            jacobian = kp_source['jacobian']
            jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
            jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
            coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
            coordinate_grid = coordinate_grid.squeeze(-1)

        driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.opt.num_kp, 1, 1, 2)

        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)

        return sparse_motions.permute(0, 1, 4, 2, 3).reshape(bs, (self.opt.num_kp + 1) * 2, 64, 64)

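    # Worked shape example for create_sparse_motions (illustrative numbers):
    # with bs=2, num_kp=10 and a 64x64 source feature map, identity_grid is
    # (1, 1, 64, 64, 2), driving_to_source broadcasts to (2, 10, 64, 64, 2)
    # after adding kp_source['value'], the concatenated sparse_motions tensor
    # is (2, 11, 64, 64, 2), and the final permute/reshape returns (2, 22, 64, 64).
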
    def forward(self, x, initial_kp=None):
        bs, seqlen = x["ph"].shape
        ph = x["ph"].reshape(bs * seqlen, 1)
        pose = x["pose"].reshape(bs * seqlen, 1, 256, 256)
        input_feature = self.down_pose(pose)  # (bs*seqlen, 1, 64, 64)

        # Render the phoneme embedding as a 16x16 map, upsample to 64x64 and
        # stack it with the downsampled pose map.
        phoneme_embedding = self.embedding(ph.long())
        phoneme_embedding = phoneme_embedding.reshape(bs * seqlen, 1, 16, 16)
        phoneme_embedding = F.interpolate(phoneme_embedding, scale_factor=4)
        input_feature = torch.cat((input_feature, phoneme_embedding), dim=1)

        input_feature = self.feature_extract(input_feature).unsqueeze(-1).reshape(bs, seqlen, 512)

        # Lift the audio features to 35 channels at 64x64 and concatenate them
        # with the initial_kp feature map repeated across the window.
        audio = x["audio"].reshape(bs * seqlen, 1, 4, 41)
        decoder_feature = self.audio_embedding(audio)
        decoder_feature = F.interpolate(decoder_feature, scale_factor=2)
        decoder_feature = self.decodefeature_extract(torch.cat(
            (decoder_feature,
             initial_kp["feature_map"].unsqueeze(1).repeat(1, seqlen, 1, 1, 1).reshape(bs * seqlen, 35, 64, 64)),
            dim=1)).unsqueeze(-1).reshape(bs, seqlen, 512)

        posi_em = self.pos_enc(self.opt.num_w * 2 + 1)

        out = {}

        # Take the last decoder layer's output at the window centre (index num_w).
        output_feature = self.transformer(self.opt, input_feature, decoder_feature, posi_em)[-1, self.opt.num_w]

        out["value"] = self.kp(output_feature).reshape(bs, self.opt.num_kp, 2)
        out["jacobian"] = self.jacobian(output_feature).reshape(bs, self.opt.num_kp, 2, 2)

        return out
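

# Minimal smoke-test sketch (an illustrative addition, not part of the original
# training code). The `opt` values are assumptions: embedding_dim=256 is forced
# by the 16x16 phoneme reshape in forward(), while num_kp and num_w are
# placeholders for whatever the experiment config provides; the input window
# length must equal opt.num_w * 2 + 1.
if __name__ == "__main__":
    from types import SimpleNamespace

    opt = SimpleNamespace(embedding_dim=256, num_kp=10, num_w=3)  # assumed values
    model = Audio2kpTransformer(opt)

    bs, seqlen = 2, opt.num_w * 2 + 1
    x = {
        "ph": torch.randint(0, 41, (bs, seqlen)),      # phoneme ids in [0, 40]
        "pose": torch.randn(bs, seqlen, 1, 256, 256),  # pose maps
        "audio": torch.randn(bs, seqlen, 4, 41),       # audio features
    }
    initial_kp = {"feature_map": torch.randn(bs, 35, 64, 64)}

    out = model(x, initial_kp)
    print(out["value"].shape)     # (bs, num_kp, 2)
    print(out["jacobian"].shape)  # (bs, num_kp, 2, 2)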