# one-shot-talking-face / models/transformer.py
# Provenance: Hugging Face upload by DmitrMakeev ("Upload 7 files", commit 02cacbe, 15.9 kB).
import torch.nn as nn
import torch
from models.util import mydownres2Dblock
import numpy as np
from models.util import AntiAliasInterpolation2d,make_coordinate_grid
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
import torch.nn.functional as F
import copy
class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    A table of ``n_position`` rows and ``d_hid`` columns is precomputed once and
    stored as the buffer ``pos_table`` with shape ``(1, n_position, d_hid)``.
    Entry ``[0, p, j]`` is ``sin`` (even ``j``) or ``cos`` (odd ``j``) of
    ``p / 10000 ** (2 * (j // 2) / d_hid)``.
    """

    def __init__(self, d_hid, n_position=200):
        super().__init__()
        # A buffer rather than a Parameter: it moves with the module between
        # devices and is saved in the state dict, but is never optimised.
        table = self._get_sinusoid_encoding_table(n_position, d_hid)
        self.register_buffer('pos_table', table)

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Return the sinusoid table as a float32 tensor of shape (1, n_position, d_hid).

        The angles are computed in float64 with NumPy and cast to float32 at the end.
        """
        # One divisor per column; columns 2i and 2i+1 share a frequency.
        divisors = np.array(
            [np.power(10000, 2 * (col // 2) / d_hid) for col in range(d_hid)]
        )
        # angles[p, j] = p / divisors[j]
        angles = np.arange(n_position)[:, None] / divisors[None, :]
        angles[:, 0::2] = np.sin(angles[:, 0::2])  # dim 2i
        angles[:, 1::2] = np.cos(angles[:, 1::2])  # dim 2i+1
        return torch.FloatTensor(angles).unsqueeze(0)

    def forward(self, winsize):
        """Return a detached copy of the first ``winsize`` rows, shape (1, winsize, d_hid).

        Plain slicing semantics: asking for more than ``n_position`` rows yields
        only ``n_position`` rows.
        """
        window = self.pos_table[:, :winsize]
        return window.clone().detach()
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class Transformer(nn.Module):
    """Encoder-decoder transformer over batch-first 3-D inputs.

    ``forward`` swaps the first two axes of every input into the
    sequence-first layout expected by the layers' ``nn.MultiheadAttention``
    modules (built without ``batch_first``). It then encodes ``src`` and
    decodes from an all-zero target that has the shape of ``query_embed``.

    Args:
        d_model: feature width of every token.
        nhead: number of attention heads.
        num_encoder_layers: depth of the encoder stack.
        num_decoder_layers: depth of the decoder stack.
        dim_feedforward: hidden width of the feed-forward sub-blocks.
        dropout: dropout probability used inside the layers.
        activation: feed-forward activation name passed to the layers.
        normalize_before: pre-norm layers if True, post-norm otherwise; in
            pre-norm mode the encoder also gets a closing LayerNorm.
        return_intermediate_dec: forwarded to the decoder as
            ``return_intermediate``.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=True):
        super().__init__()
        layer_args = (d_model, nhead, dim_feedforward, dropout,
                      activation, normalize_before)

        # Pre-norm layers leave their output un-normalised, so only then does
        # the encoder need a final LayerNorm.
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(*layer_args),
            num_encoder_layers,
            nn.LayerNorm(d_model) if normalize_before else None,
        )
        # The decoder always normalises its output.
        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(*layer_args),
            num_decoder_layers,
            nn.LayerNorm(d_model),
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        matrices = (p for p in self.parameters() if p.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, opt, src, query_embed, pos_embed):
        """Encode ``src`` and decode with ``query_embed`` as query positions.

        Args:
            opt: unused; kept for call-site compatibility.
            src: encoder input, batch-first 3-D tensor.
            query_embed: decoder query positional term, batch-first 3-D tensor;
                the decoder's initial target is zeros of this shape.
            pos_embed: positional embedding, batch-first 3-D tensor, given to
                both the encoder and the decoder.

        Returns:
            Whatever ``self.decoder`` returns for the sequence-first inputs.
        """
        # (batch, seq, C) -> (seq, batch, C) for every input.
        src, pos_embed, query_embed = (
            t.permute(1, 0, 2) for t in (src, pos_embed, query_embed)
        )
        memory = self.encoder(src, pos=pos_embed)
        tgt = torch.zeros_like(query_embed)
        return self.decoder(tgt, memory, pos=pos_embed, query_pos=query_embed)
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep-copied encoder layers with an optional final norm.

    The positional embedding ``pos`` is added to the input once, before the
    first layer, and is also forwarded to every layer.

    Args:
        encoder_layer: template layer; it is deep-copied ``num_layers`` times.
        num_layers: number of layers in the stack.
        norm: optional module applied to the output of the last layer.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask = None, src_key_padding_mask = None, pos = None):
        """Run ``src`` (plus ``pos`` if given) through every layer, then ``norm``.

        Args:
            src: encoder input.
            mask: passed to each layer as ``src_mask``.
            src_key_padding_mask: passed to each layer unchanged.
            pos: optional positional embedding added to ``src``; ``None``
                means no embedding.

        Returns:
            The encoded tensor.
        """
        # ``pos`` defaults to None: treat a missing embedding as zero instead
        # of failing on ``src + None``.
        output = src if pos is None else src + pos
        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)
        if self.norm is not None:
            output = self.norm(output)
        return output
class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` deep-copied decoder layers with an optional final norm.

    ``pos`` and ``query_pos`` (when given) are added to ``tgt`` once, before
    the first layer, and are also forwarded to every layer.

    Args:
        decoder_layer: template layer; it is deep-copied ``num_layers`` times.
        num_layers: number of layers in the stack.
        norm: optional module applied to layer outputs.
        return_intermediate: if True, return every layer's (normalised) output
            stacked on a new leading axis instead of only the last one.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory, tgt_mask = None, memory_mask = None, tgt_key_padding_mask = None,
                memory_key_padding_mask = None,
                pos = None,
                query_pos = None):
        """Decode ``tgt`` against ``memory``.

        Returns:
            If ``return_intermediate``: ``torch.stack`` of each layer's output
            (normalised when ``norm`` is set). Otherwise, the final output
            (normalised when ``norm`` is set) with a leading axis of size 1.
        """
        # Both embeddings default to None; skip any that are missing instead
        # of failing on ``tensor + None``.
        output = tgt
        if pos is not None:
            output = output + pos
        if query_pos is not None:
            output = output + query_pos

        intermediate = []
        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                # Only normalise when a norm exists (previously crashed on None).
                intermediate.append(output if self.norm is None else self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # Make the last stacked entry the very tensor returned as output.
                intermediate[-1] = output

        if self.return_intermediate:
            return torch.stack(intermediate)
        return output.unsqueeze(0)
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention then a feed-forward MLP, each residual.

    With ``normalize_before=False`` (post-norm), LayerNorm follows each
    residual add. With ``normalize_before=True`` (pre-norm), it is applied to
    each sub-block's input instead.

    The ``pos`` argument of the forward methods is accepted for interface
    compatibility but is not used: query, key and value are all the layer input.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Attribute names and creation order are kept stable: they fix the
        # state_dict keys, parameter order and random-initialisation order.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward sub-block: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _attend(self, x, src_mask, src_key_padding_mask):
        # Self-attention with ``x`` as query, key and value; the attention
        # weights are discarded.
        attended, _ = self.self_attn(x, x, value=x, attn_mask=src_mask,
                                     key_padding_mask=src_key_padding_mask)
        return attended

    def _feed_forward(self, x):
        # linear1 -> activation -> dropout -> linear2
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.linear2(hidden)

    def forward_post(self,
                     src,
                     src_mask = None,
                     src_key_padding_mask = None,
                     pos = None):
        """Post-norm path: ``x = norm(x + dropout(sublayer(x)))`` per sub-block."""
        out = self.norm1(src + self.dropout1(
            self._attend(src, src_mask, src_key_padding_mask)))
        out = self.norm2(out + self.dropout2(self._feed_forward(out)))
        return out

    def forward_pre(self, src,
                    src_mask = None,
                    src_key_padding_mask = None,
                    pos = None):
        """Pre-norm path: ``x = x + dropout(sublayer(norm(x)))`` per sub-block."""
        out = src + self.dropout1(
            self._attend(self.norm1(src), src_mask, src_key_padding_mask))
        out = out + self.dropout2(self._feed_forward(self.norm2(out)))
        return out

    def forward(self, src,
                src_mask = None,
                src_key_padding_mask = None,
                pos = None):
        """Dispatch to the pre-norm or post-norm path per ``normalize_before``."""
        impl = self.forward_pre if self.normalize_before else self.forward_post
        return impl(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention over ``memory``, then
    a feed-forward MLP, each wrapped in a residual connection.

    With ``normalize_before=False`` (post-norm), LayerNorm follows each
    residual add. With ``normalize_before=True`` (pre-norm), it is applied to
    each sub-block's input instead.

    ``pos`` and ``query_pos`` are accepted for interface compatibility but are
    not used here: attention runs on the raw target and memory tensors.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Attribute names and creation order are kept stable: they fix the
        # state_dict keys, parameter order and random-initialisation order.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward sub-block: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _self_block(self, x, tgt_mask, tgt_key_padding_mask):
        # Self-attention over the target; attention weights are discarded.
        attended, _ = self.self_attn(x, x, value=x, attn_mask=tgt_mask,
                                     key_padding_mask=tgt_key_padding_mask)
        return attended

    def _cross_block(self, x, memory, memory_mask, memory_key_padding_mask):
        # Target queries attend to the encoder memory.
        attended, _ = self.multihead_attn(query=x, key=memory, value=memory,
                                          attn_mask=memory_mask,
                                          key_padding_mask=memory_key_padding_mask)
        return attended

    def _feed_forward(self, x):
        # linear1 -> activation -> dropout -> linear2
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.linear2(hidden)

    def forward_post(self, tgt, memory,
                     tgt_mask = None,
                     memory_mask = None,
                     tgt_key_padding_mask = None,
                     memory_key_padding_mask = None,
                     pos = None,
                     query_pos = None):
        """Post-norm path: ``x = norm(x + dropout(sublayer(x)))`` per sub-block."""
        out = self.norm1(tgt + self.dropout1(
            self._self_block(tgt, tgt_mask, tgt_key_padding_mask)))
        out = self.norm2(out + self.dropout2(
            self._cross_block(out, memory, memory_mask, memory_key_padding_mask)))
        out = self.norm3(out + self.dropout3(self._feed_forward(out)))
        return out

    def forward_pre(self, tgt, memory,
                    tgt_mask = None,
                    memory_mask = None,
                    tgt_key_padding_mask = None,
                    memory_key_padding_mask = None,
                    pos = None,
                    query_pos = None):
        """Pre-norm path: ``x = x + dropout(sublayer(norm(x)))`` per sub-block."""
        out = tgt + self.dropout1(
            self._self_block(self.norm1(tgt), tgt_mask, tgt_key_padding_mask))
        out = out + self.dropout2(
            self._cross_block(self.norm2(out), memory, memory_mask,
                              memory_key_padding_mask))
        out = out + self.dropout3(self._feed_forward(self.norm3(out)))
        return out

    def forward(self, tgt, memory,
                tgt_mask = None,
                memory_mask = None,
                tgt_key_padding_mask = None,
                memory_key_padding_mask = None,
                pos = None,
                query_pos = None):
        """Dispatch to the pre-norm or post-norm path per ``normalize_before``."""
        impl = self.forward_pre if self.normalize_before else self.forward_post
        return impl(tgt, memory, tgt_mask, memory_mask,
                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
class Audio2kpTransformer(nn.Module):
    """Predict keypoints and 2x2 Jacobians from phonemes, audio and pose frames.

    Two branches produce one 512-d token per frame:

    * Encoder branch: the downsampled pose image concatenated with a phoneme
      embedding map.
    * Decoder branch: an audio embedding concatenated with
      ``initial_kp["feature_map"]``.

    ``self.transformer`` mixes the window. The last decoder layer's output at
    position ``opt.num_w`` is projected to ``opt.num_kp`` keypoints and
    Jacobians.

    Reads from ``opt``: ``embedding_dim``, ``num_kp`` and ``num_w``.
    """

    def __init__(self,opt):
        super(Audio2kpTransformer, self).__init__()
        self.opt = opt
        # 41 phoneme ids, each embedded to opt.embedding_dim values.
        self.embedding = nn.Embedding(41, opt.embedding_dim)
        # Sinusoidal table with width 512 and at most 20 positions.
        self.pos_enc = PositionalEncoding(512,20)
        self.down_pose = AntiAliasInterpolation2d(1,0.25)
        # 1 pose channel + 1 phoneme-embedding channel (see forward).
        input_dim = 2
        self.feature_extract = nn.Sequential(mydownres2Dblock(input_dim,32),
                                             mydownres2Dblock(32,64),
                                             mydownres2Dblock(64,128),
                                             mydownres2Dblock(128,256),
                                             mydownres2Dblock(256,512),
                                             nn.AvgPool2d(2))
        # 35 audio-embedding channels + 35 feature-map channels.
        self.decode_dim = 70
        # A (4, 41) audio patch becomes (8, 32, 32) after the transposed conv,
        # then 35 channels at 32x32 (the 13x13 conv with padding 6 keeps size).
        self.audio_embedding = nn.Sequential(nn.ConvTranspose2d(1, 8, (29, 14), stride=(1, 1), padding=(0, 11)),
                                             BatchNorm2d(8),
                                             nn.ReLU(inplace=True),
                                             nn.Conv2d(8, 35, (13, 13), stride=(1, 1), padding=(6, 6)))
        self.decodefeature_extract = nn.Sequential(mydownres2Dblock(self.decode_dim,32),
                                                   mydownres2Dblock(32,64),
                                                   mydownres2Dblock(64,128),
                                                   mydownres2Dblock(128,256),
                                                   mydownres2Dblock(256,512),
                                                   nn.AvgPool2d(2))
        self.transformer = Transformer()
        self.kp = nn.Linear(512,opt.num_kp*2)
        # Jacobian head starts as a zero-weight layer whose bias is the flattened
        # 2x2 identity per keypoint, so initial Jacobians are identities.
        self.jacobian = nn.Linear(512,opt.num_kp*4)
        self.jacobian.weight.data.zero_()
        self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.opt.num_kp, dtype=torch.float))
        # Not used inside this class.
        self.criterion = nn.L1Loss()

    def create_sparse_motions(self, source_image, kp_source):
        """
        Eq 4. in the paper T_{s<-d}(z)

        Build one sampling grid per keypoint plus an identity (background) grid.

        Args:
            source_image: only its shape is used, unpacked as (bs, _, h, w).
            kp_source: dict with 'value' holding bs * num_kp * 2 keypoint
                coordinates, and optionally 'jacobian' (presumably
                (bs, num_kp, 2, 2) -- TODO confirm against callers).

        Returns:
            Tensor of shape (bs, (num_kp + 1) * 2, h, w): the identity grid
            followed by one displaced grid per keypoint, with the 2 coordinate
            components folded into the channel axis.
        """
        bs, _, h, w = source_image.shape
        identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
        identity_grid = identity_grid.view(1, 1, h, w, 2)
        coordinate_grid = identity_grid
        if 'jacobian' in kp_source:
            # Apply each keypoint's 2x2 Jacobian to every grid coordinate.
            jacobian = kp_source['jacobian']
            jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
            jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
            coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
            coordinate_grid = coordinate_grid.squeeze(-1)
        driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.opt.num_kp, 1, 1, 2)
        # adding background feature
        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)
        # (bs, num_kp + 1, h, w, 2) -> (bs, (num_kp + 1) * 2, h, w); use the real
        # grid size instead of a hard-coded 64x64.
        return sparse_motions.permute(0, 1, 4, 2, 3).reshape(bs, (self.opt.num_kp + 1) * 2, h, w)

    def forward(self, x, initial_kp = None):
        """Predict keypoints for the frame at window position ``opt.num_w``.

        Args:
            x: dict with
                "ph": phoneme ids, shape (bs, seqlen);
                "pose": bs * seqlen * 256 * 256 values (one 256x256 map per frame);
                "audio": bs * seqlen * 4 * 41 values (one 4x41 patch per frame).
            initial_kp: dict whose "feature_map" holds bs * 35 * 64 * 64 values
                (presumably (bs, 35, 64, 64) -- TODO confirm). Required despite
                the None default.

        Returns:
            dict with "value" of shape (bs, num_kp, 2) and "jacobian" of shape
            (bs, num_kp, 2, 2).

        Raises:
            TypeError: if ``initial_kp`` is None.
        """
        if initial_kp is None:
            raise TypeError("Audio2kpTransformer.forward requires initial_kp "
                            "(a dict with a 'feature_map' entry); got None")
        bs,seqlen = x["ph"].shape
        ph = x["ph"].reshape(bs*seqlen,1)
        pose = x["pose"].reshape(bs*seqlen,1,256,256)

        # Encoder branch: downsampled pose + phoneme embedding laid out as a
        # 16x16 map (requires opt.embedding_dim == 256), upsampled x4 to 64x64.
        input_feature = self.down_pose(pose)
        phoneme_embedding = self.embedding(ph.long())
        phoneme_embedding = phoneme_embedding.reshape(bs*seqlen, 1, 16, 16)
        phoneme_embedding = F.interpolate(phoneme_embedding, scale_factor=4)
        input_feature = torch.cat((input_feature, phoneme_embedding), dim=1)
        input_feature = self.feature_extract(input_feature).reshape(bs,seqlen,512)

        # Decoder branch: 35x32x32 audio embedding upsampled x2 to 64x64, joined
        # with the feature map repeated once per frame (35 + 35 = decode_dim).
        audio = x["audio"].reshape(bs * seqlen, 1, 4, 41)
        decoder_feature = self.audio_embedding(audio)
        decoder_feature = F.interpolate(decoder_feature, scale_factor=2)
        feature_map = initial_kp["feature_map"].unsqueeze(1).repeat(1, seqlen, 1, 1, 1).reshape(bs * seqlen, 35, 64, 64)
        decoder_feature = self.decodefeature_extract(
            torch.cat((decoder_feature, feature_map), dim=1)).reshape(bs, seqlen, 512)

        # Window of 2*num_w+1 positions; the transformer output is indexed
        # [last decoder layer, centre position num_w] -> (bs, 512).
        posi_em = self.pos_enc(self.opt.num_w*2+1)
        out = {}
        output_feature = self.transformer(self.opt,input_feature,decoder_feature,posi_em)[-1,self.opt.num_w]
        out["value"] = self.kp(output_feature).reshape(bs,self.opt.num_kp,2)
        out["jacobian"] = self.jacobian(output_feature).reshape(bs,self.opt.num_kp,2,2)
        return out