# Listing metadata captured with this file: Space status "Paused";
# file size 6,144 bytes; revisions 784616c, 5b1ae50.
import functools
import torch
import torch.nn as nn
from videoretalking.models.transformer import RETURNX, Transformer
from videoretalking.models.base_blocks import Conv2d, LayerNorm2d, FirstBlock2d, DownBlock2d, UpBlock2d, \
FFCADAINResBlocks, Jump, FinalBlock2d
class Visual_Encoder(nn.Module):
    """Two-stream encoder for a masked target image and a reference image.

    Each stream starts with its own ``FirstBlock2d`` and then runs through
    ``layers`` ``DownBlock2d`` stages.  After every stage the target features
    are combined with the reference features by a per-stage ``ca`` layer:
    ``RETURNX`` for the first two stages and a ``Transformer`` afterwards.

    Sub-modules are registered under the names ``first_inp``, ``first_ref``,
    ``ca<i>``, ``ref_down<i>`` and ``inp_down<i>``; these names (and the order
    in which the modules are created) are kept stable so existing checkpoints
    keep loading.

    Args:
        image_nc: Channel count of each input image.
        ngf: Base feature width; stage ``i`` outputs ``min(ngf * 2**(i+1), img_f)``
            channels.
        img_f: Upper bound on the feature width of any stage.
        layers: Number of down-sampling stages (must be at least 1).
        norm_layer: Normalisation layer factory forwarded to the blocks.
        nonlinearity: Activation module forwarded to the blocks.
        use_spect: Forwarded to the blocks (presumably toggles spectral
            normalisation -- confirm in ``base_blocks``).

    Attributes:
        output_nc: Channel count of the last feature map returned by
            ``forward`` (target and reference features concatenated).

    Raises:
        ValueError: If ``layers`` is smaller than 1.
    """

    def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(Visual_Encoder, self).__init__()
        if layers < 1:
            # Without at least one stage there is no final feature width to
            # report and nothing to concatenate in ``forward``.
            raise ValueError('Visual_Encoder requires layers >= 1, got {}'.format(layers))
        self.layers = layers

        self.first_inp = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
        self.first_ref = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)

        for i in range(layers):
            in_channels = min(ngf * (2 ** i), img_f)
            out_channels = min(ngf * (2 ** (i + 1)), img_f)

            model_ref = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            model_inp = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)

            if i < 2:
                ca_layer = RETURNX()
            else:
                # The layer consumes the outputs of the down blocks above, so
                # its width has to follow the img_f-capped ``out_channels``
                # rather than the uncapped ``2**(i+1) * ngf``.
                # NOTE(review): assumes Transformer's first argument is the
                # feature width -- confirm against models/transformer.py.
                ca_layer = Transformer(out_channels, 2, 4, ngf, ngf * 4)

            setattr(self, 'ca' + str(i), ca_layer)
            setattr(self, 'ref_down' + str(i), model_ref)
            setattr(self, 'inp_down' + str(i), model_inp)

        self.output_nc = out_channels * 2

    def forward(self, maskGT, ref):
        """Encode both inputs and return the list of target-stream features.

        Args:
            maskGT: Input for the target stream.
            ref: Input for the reference stream.

        Returns:
            A list with ``layers + 1`` entries: the output of ``first_inp``,
            then the combined target features of every stage.  The last entry
            additionally has the reference features concatenated along dim 1.
        """
        x_maskGT = self.first_inp(maskGT)
        x_ref = self.first_ref(ref)
        out = [x_maskGT]

        last_stage = self.layers - 1
        for i in range(self.layers):
            x_maskGT = getattr(self, 'inp_down' + str(i))(x_maskGT)
            x_ref = getattr(self, 'ref_down' + str(i))(x_ref)
            x_maskGT = getattr(self, 'ca' + str(i))(x_maskGT, x_ref)

            if i < last_stage:
                out.append(x_maskGT)
            else:
                # Deepest stage: hand the reference features to the decoder too.
                out.append(torch.cat([x_maskGT, x_ref], dim=1))
        return out
class Decoder(nn.Module):
    """Decoder that up-samples encoder features back to an image.

    For every level ``i`` (deepest first) it registers three sub-modules:
    ``res<i>`` (``FFCADAINResBlocks`` conditioned on a descriptor), ``up<i>``
    (``UpBlock2d``) and ``jump<i>`` (``Jump`` applied to the matching skip
    feature).  A ``FinalBlock2d`` named ``final`` produces the output image.
    The attribute names and creation order are kept stable so existing
    checkpoints keep loading.

    Args:
        image_nc: Channel count of the produced image.
        feature_nc: Channel count of the conditioning descriptor ``z``.
        ngf: Base feature width.
        img_f: Upper bound on the feature width of any level.
        layers: Number of up-sampling levels (must be at least 1).
        num_block: Number of residual blocks per level.
        norm_layer: Normalisation layer factory forwarded to the blocks.
        nonlinearity: Activation module forwarded to the blocks.
        use_spect: Forwarded to the blocks (presumably toggles spectral
            normalisation -- confirm in ``base_blocks``).

    Attributes:
        output_nc: Channel count fed into the final block.

    Raises:
        ValueError: If ``layers`` is smaller than 1.
    """

    def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(Decoder, self).__init__()
        if layers < 1:
            # With no level there is no width to build the final block from.
            raise ValueError('Decoder requires layers >= 1, got {}'.format(layers))
        self.layers = layers

        for i in reversed(range(layers)):
            in_channels = min(ngf * (2 ** (i + 1)), img_f)
            if i == layers - 1:
                # The deepest level receives the encoder's last feature map,
                # i.e. target and reference features concatenated, each of
                # them capped at img_f.  Doubling the capped width keeps the
                # decoder consistent with Visual_Encoder.output_nc.
                in_channels *= 2
            out_channels = min(ngf * (2 ** i), img_f)

            up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            res = FFCADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
            jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)

            setattr(self, 'up' + str(i), up)
            setattr(self, 'res' + str(i), res)
            setattr(self, 'jump' + str(i), jump)

        self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'sigmoid')
        self.output_nc = out_channels

    def forward(self, x, z):
        """Decode a stack of encoder features into an image.

        Args:
            x: Sequence of ``layers + 1`` feature maps ordered shallow to
                deep.  It is read from the end and is left unmodified.
            z: Conditioning descriptor passed to every ``res<i>`` block.

        Returns:
            The output of the final block.
        """
        # Work on a private copy so the caller's list is not emptied.
        skips = list(x)
        out = skips.pop()

        for i in reversed(range(self.layers)):
            out = getattr(self, 'res' + str(i))(out, z)
            out = getattr(self, 'up' + str(i))(out)
            out = getattr(self, 'jump' + str(i))(skips.pop()) + out

        return self.final(out)
class LNet(nn.Module):
def __init__(
self,
image_nc=3,
descriptor_nc=512,
layer=3,
base_nc=64,
max_nc=512,
num_res_blocks=9,
use_spect=True,
encoder=Visual_Encoder,
decoder=Decoder
):
super(LNet, self).__init__()
nonlinearity = nn.LeakyReLU(0.1)
norm_layer = functools.partial(LayerNorm2d, affine=True)
kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect}
self.descriptor_nc = descriptor_nc
self.encoder = encoder(image_nc, base_nc, max_nc, layer, **kwargs)
self.decoder = decoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
self.audio_encoder = nn.Sequential(
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
Conv2d(512, descriptor_nc, kernel_size=1, stride=1, padding=0),
)
def forward(self, audio_sequences, face_sequences):
B = audio_sequences.size(0)
input_dim_size = len(face_sequences.size())
if input_dim_size > 4:
audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
cropped, ref = torch.split(face_sequences, 3, dim=1)
vis_feat = self.encoder(cropped, ref)
audio_feat = self.audio_encoder(audio_sequences)
_outputs = self.decoder(vis_feat, audio_feat)
if input_dim_size > 4:
_outputs = torch.split(_outputs, B, dim=0)
outputs = torch.stack(_outputs, dim=2)
else:
outputs = _outputs
return outputs |