gartajackhats1985's picture
Upload 171 files
c37b2dd verified
# coding: utf-8
"""
Warping field estimator (W) defined in the paper. It generates a warping field from the implicit
keypoint representations x_s and x_d, and uses this flow field to warp the source feature volume f_s.
"""
from torch import nn
import torch.nn.functional as F
from .util import SameBlock2d
from .dense_motion import DenseMotionNetwork
class WarpingNetwork(nn.Module):
    """Warp a 3D source feature volume with a keypoint-driven motion field.

    When ``dense_motion_params`` is given, a ``DenseMotionNetwork`` turns the
    driving/source keypoints into a sampling grid (``deformation``) and,
    optionally, an occlusion map. The source volume is resampled with that
    grid. Its channel and depth axes are then folded together, and the result
    passes through ``third`` (a 3x3 ``SameBlock2d``) and ``fourth`` (a 1x1
    convolution). If an occlusion map is available and enabled, the output is
    multiplied by it.

    Without a dense motion network, the input volume is not warped. It still
    goes through the same fold / ``third`` / ``fourth`` pipeline, so the
    output shape is the same either way.

    Args:
        num_kp: Number of implicit keypoints, forwarded to ``DenseMotionNetwork``.
        block_expansion: Base channel width. ``third`` and ``fourth`` produce
            ``block_expansion * 2 ** num_down_blocks`` channels.
        max_features: First argument of ``third`` (presumably its input channel
            count). The folded warped volume has ``C * D`` channels, so this
            must match ``C * D`` of ``feature_3d``. TODO confirm against
            ``SameBlock2d``.
        num_down_blocks: Exponent used to size the output channel count.
        reshape_channel: Forwarded to ``DenseMotionNetwork`` as ``feature_channel``.
        estimate_occlusion_map: Forwarded to ``DenseMotionNetwork``; also stored.
        dense_motion_params: Extra keyword arguments for ``DenseMotionNetwork``.
            ``None`` disables motion estimation, and the input is then not warped.
        **kwargs: ``upscale`` (default 1; stored, not used in this class) and
            ``flag_use_occlusion_map`` (default True; whether to apply the
            occlusion map to the output).
    """

    def __init__(
        self,
        num_kp,
        block_expansion,
        max_features,
        num_down_blocks,
        reshape_channel,
        estimate_occlusion_map=False,
        dense_motion_params=None,
        **kwargs
    ):
        super().__init__()

        self.upscale = kwargs.get('upscale', 1)
        self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True)

        if dense_motion_params is not None:
            self.dense_motion_network = DenseMotionNetwork(
                num_kp=num_kp,
                feature_channel=reshape_channel,
                estimate_occlusion_map=estimate_occlusion_map,
                **dense_motion_params
            )
        else:
            self.dense_motion_network = None

        out_channels = block_expansion * (2 ** num_down_blocks)
        self.third = SameBlock2d(max_features, out_channels, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
        self.fourth = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, stride=1)

        self.estimate_occlusion_map = estimate_occlusion_map

    def deform_input(self, inp, deformation):
        """Resample ``inp`` at the normalized coordinates in ``deformation``.

        Thin wrapper over ``F.grid_sample`` with ``align_corners=False``. It
        uses grid_sample's default interpolation and padding modes.
        """
        return F.grid_sample(inp, deformation, align_corners=False)

    def forward(self, feature_3d, kp_driving, kp_source):
        """Warp ``feature_3d`` and project it to a 2D feature map.

        Args:
            feature_3d: 5D feature volume ``(B, C, D, H, W)``. The original
                comments give Bx32x16x64x64 as an example.
            kp_driving: Driving keypoints, passed to the dense motion network.
            kp_source: Source keypoints, passed to the dense motion network.

        Returns:
            dict with these keys:

            - ``'occlusion_map'``: the occlusion map, or ``None`` if the dense
              motion network returned none or is not configured.
            - ``'deformation'``: the sampling grid, or ``None`` without a dense
              motion network.
            - ``'out'``: the ``(B, C_out, H, W)`` feature map.
        """
        # Defaults for the "no dense motion network" case. Previously these
        # names were only bound inside the branch below, so building the
        # result dict raised UnboundLocalError when the network was absent.
        occlusion_map = None
        deformation = None
        out = feature_3d

        if self.dense_motion_network is not None:
            # Feature warper: transform the feature representation according
            # to the estimated deformation (and, later, occlusion).
            dense_motion = self.dense_motion_network(
                feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source
            )
            occlusion_map = dense_motion.get('occlusion_map')  # e.g. Bx1x64x64
            deformation = dense_motion['deformation']  # e.g. Bx16x64x64x3
            out = self.deform_input(feature_3d, deformation)  # e.g. Bx32x16x64x64

        # Fold depth into channels: (B, C, D, H, W) -> (B, C*D, H, W).
        # reshape (not view) so a non-contiguous unwarped input also works.
        bs, c, d, h, w = out.shape
        out = out.reshape(bs, c * d, h, w)  # e.g. -> Bx512x64x64
        out = self.third(out)  # e.g. -> Bx256x64x64
        out = self.fourth(out)  # e.g. -> Bx256x64x64

        if self.flag_use_occlusion_map and occlusion_map is not None:
            out = out * occlusion_map

        return {
            'occlusion_map': occlusion_map,
            'deformation': deformation,
            'out': out,
        }