# TODO import functools import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from videoretalking.utils import flow_util from videoretalking.models.base_blocks import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder # DNet class DNet(nn.Module): def __init__(self): super(DNet, self).__init__() self.mapping_net = MappingNet() self.warpping_net = WarpingNet() self.editing_net = EditingNet() def forward(self, input_image, driving_source, stage=None): if stage == 'warp': descriptor = self.mapping_net(driving_source) output = self.warpping_net(input_image, descriptor) else: descriptor = self.mapping_net(driving_source) output = self.warpping_net(input_image, descriptor) output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor) return output class MappingNet(nn.Module): def __init__(self, coeff_nc=73, descriptor_nc=256, layer=3): super( MappingNet, self).__init__() self.layer = layer nonlinearity = nn.LeakyReLU(0.1) self.first = nn.Sequential( torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) for i in range(layer): net = nn.Sequential(nonlinearity, torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) setattr(self, 'encoder' + str(i), net) self.pooling = nn.AdaptiveAvgPool1d(1) self.output_nc = descriptor_nc def forward(self, input_3dmm): out = self.first(input_3dmm) for i in range(self.layer): model = getattr(self, 'encoder' + str(i)) out = model(out) + out[:,:,3:-3] out = self.pooling(out) return out class WarpingNet(nn.Module): def __init__( self, image_nc=3, descriptor_nc=256, base_nc=32, max_nc=256, encoder_layer=5, decoder_layer=3, use_spect=False ): super( WarpingNet, self).__init__() nonlinearity = nn.LeakyReLU(0.1) norm_layer = functools.partial(LayerNorm2d, affine=True) kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect} self.descriptor_nc = descriptor_nc self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc, max_nc, encoder_layer, decoder_layer, **kwargs) self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc), nonlinearity, nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3)) self.pool = nn.AdaptiveAvgPool2d(1) def forward(self, input_image, descriptor): final_output={} output = self.hourglass(input_image, descriptor) final_output['flow_field'] = self.flow_out(output) deformation = flow_util.convert_flow_to_deformation(final_output['flow_field']) final_output['warp_image'] = flow_util.warp_image(input_image, deformation) return final_output class EditingNet(nn.Module): def __init__( self, image_nc=3, descriptor_nc=256, layer=3, base_nc=64, max_nc=256, num_res_blocks=2, use_spect=False): super(EditingNet, self).__init__() nonlinearity = nn.LeakyReLU(0.1) norm_layer = functools.partial(LayerNorm2d, affine=True) kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect} self.descriptor_nc = descriptor_nc # encoder part self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs) self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs) def forward(self, input_image, warp_image, descriptor): x = torch.cat([input_image, warp_image], 1) x = self.encoder(x) gen_image = self.decoder(x, descriptor) return gen_image