# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from imaginaire.generators.fs_vid2vid import LabelEmbedder
from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
from imaginaire.model_utils.fs_vid2vid import (extract_valid_pose_labels,
resample)
from imaginaire.utils.data import (get_paired_input_image_channel_number,
get_paired_input_label_channel_number)
from imaginaire.utils.init_weight import weights_init


class BaseNetwork(nn.Module):
    r"""Base class for the vid2vid generator networks."""

    def __init__(self):
        super().__init__()

    def get_num_filters(self, num_downsamples):
r"""Get the number of filters at current layer.
Args:
            num_downsamples (int) : Number of downsamplings at the current layer.
Returns:
output (int) : Number of filters.
"""
return min(self.max_num_filters,
self.num_filters * (2 ** num_downsamples))


class Generator(BaseNetwork):
r"""vid2vid generator constructor.
Args:
gen_cfg (obj): Generator definition part of the yaml config file.
data_cfg (obj): Data definition part of the yaml config file.
"""

    def __init__(self, gen_cfg, data_cfg):
super().__init__()
self.gen_cfg = gen_cfg
self.data_cfg = data_cfg
self.num_frames_G = data_cfg.num_frames_G
# Number of residual blocks in generator.
self.num_layers = num_layers = getattr(gen_cfg, 'num_layers', 7)
# Number of downsamplings for previous frame.
self.num_downsamples_img = getattr(gen_cfg, 'num_downsamples_img', 4)
# Number of filters in the first layer.
self.num_filters = num_filters = getattr(gen_cfg, 'num_filters', 32)
self.max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024)
self.kernel_size = kernel_size = getattr(gen_cfg, 'kernel_size', 3)
padding = kernel_size // 2
# For pose dataset.
self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
if self.is_pose_data:
pose_cfg = data_cfg.for_pose_dataset
self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
False)
# Input data params.
        num_input_channels = get_paired_input_label_channel_number(data_cfg)
        # Stored for get_cond_dims() when no embedding network is used.
        self.num_input_channels = num_input_channels
        num_img_channels = get_paired_input_image_channel_number(data_cfg)
aug_cfg = data_cfg.val.augmentations
if hasattr(aug_cfg, 'center_crop_h_w'):
crop_h_w = aug_cfg.center_crop_h_w
elif hasattr(aug_cfg, 'resize_h_w'):
crop_h_w = aug_cfg.resize_h_w
else:
raise ValueError('Need to specify output size.')
crop_h, crop_w = crop_h_w.split(',')
crop_h, crop_w = int(crop_h), int(crop_w)
# Spatial size at the bottle neck of generator.
self.sh = crop_h // (2 ** num_layers)
self.sw = crop_w // (2 ** num_layers)
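        # For example, a 512x256 (H x W) crop with num_layers=7 gives a
        # 4x2 bottleneck.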
# Noise vector dimension.
self.z_dim = getattr(gen_cfg, 'style_dims', 256)
self.use_segmap_as_input = \
getattr(gen_cfg, 'use_segmap_as_input', False)
# Label / image embedding network.
self.emb_cfg = emb_cfg = getattr(gen_cfg, 'embed', None)
        self.use_embed = getattr(emb_cfg, 'use_embed', True)
self.num_downsamples_embed = getattr(emb_cfg, 'num_downsamples', 5)
if self.use_embed:
self.label_embedding = LabelEmbedder(emb_cfg, num_input_channels)
# Flow network.
self.flow_cfg = flow_cfg = gen_cfg.flow
        # Use SPADE (instead of a linear combination) to combine the warped
        # and hallucinated frames.
        self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
# Number of layers to perform multi-spade combine.
self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
'num_layers', 3)
        # At the beginning of training, only the image generator is trained.
self.temporal_initialized = False
        # Whether to also output the hallucinated frame (when training the
        # temporal network) for an additional loss.
self.generate_raw_output = False
# Image generation network.
weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
activation_norm_type = gen_cfg.activation_norm_type
activation_norm_params = gen_cfg.activation_norm_params
if self.use_embed and \
not hasattr(activation_norm_params, 'num_filters'):
activation_norm_params.num_filters = 0
nonlinearity = 'leakyrelu'
self.base_res_block = base_res_block = partial(
Res2dBlock, kernel_size=kernel_size, padding=padding,
weight_norm_type=weight_norm_type,
activation_norm_type=activation_norm_type,
activation_norm_params=activation_norm_params,
nonlinearity=nonlinearity, order='NACNAC')
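        # order='NACNAC' runs norm -> activation -> conv twice per block,
        # i.e. a pre-activation style residual block.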
# Upsampling residual blocks.
for i in range(num_layers, -1, -1):
activation_norm_params.cond_dims = self.get_cond_dims(i)
            activation_norm_params.partial = \
                self.get_partial(i) if hasattr(self, 'get_partial') else False
layer = base_res_block(self.get_num_filters(i + 1),
self.get_num_filters(i))
setattr(self, 'up_%d' % i, layer)
# Final conv layer.
self.conv_img = Conv2dBlock(num_filters, num_img_channels,
kernel_size, padding=padding,
nonlinearity=nonlinearity, order='AC')
num_filters = min(self.max_num_filters,
num_filters * (2 ** (self.num_layers + 1)))
if self.use_segmap_as_input:
self.fc = Conv2dBlock(num_input_channels, num_filters,
kernel_size=3, padding=1)
else:
self.fc = LinearBlock(self.z_dim, num_filters * self.sh * self.sw)
# Misc.
self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
self.upsample = partial(F.interpolate, scale_factor=2)
self.init_temporal_network()

    def forward(self, data):
r"""vid2vid generator forward.
Args:
data (dict) : Dictionary of input data.
Returns:
output (dict) : Dictionary of output data.
"""
label = data['label']
label_prev, img_prev = data['prev_labels'], data['prev_images']
is_first_frame = img_prev is None
        z = data.get('z', None)
bs, _, h, w = label.size()
if self.is_pose_data:
label, label_prev = extract_valid_pose_labels(
[label, label_prev], self.pose_type, self.remove_face_labels)
# Get SPADE conditional maps by embedding current label input.
cond_maps_now = self.get_cond_maps(label, self.label_embedding)
        # The generator input is either noise / the segmentation map (first
        # frame) or the encoded previous frame (subsequent frames).
if is_first_frame:
# First frame in the sequence, start from scratch.
if self.use_segmap_as_input:
x_img = F.interpolate(label, size=(self.sh, self.sw))
x_img = self.fc(x_img)
else:
                if z is None:
                    # Default latent code is all zeros.
                    z = torch.zeros(bs, self.z_dim, dtype=label.dtype,
                                    device=label.device)
x_img = self.fc(z).view(bs, -1, self.sh, self.sw)
# Upsampling layers.
for i in range(self.num_layers, self.num_downsamples_img, -1):
j = min(self.num_downsamples_embed, i)
x_img = getattr(self, 'up_' + str(i))(x_img, *cond_maps_now[j])
x_img = self.upsample(x_img)
else:
            # Not the first frame: encode the previous frame and feed it to
            # the generator.
x_img = self.down_first(img_prev[:, -1])
# Get label embedding for the previous frame.
cond_maps_prev = self.get_cond_maps(label_prev[:, -1],
self.label_embedding)
# Downsampling layers.
for i in range(self.num_downsamples_img + 1):
j = min(self.num_downsamples_embed, i)
x_img = getattr(self, 'down_' + str(i))(x_img,
*cond_maps_prev[j])
if i != self.num_downsamples_img:
x_img = self.downsample(x_img)
# Resnet blocks.
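            # The first half of the residual blocks is conditioned on the
            # previous label, the second half on the current label, easing
            # the transition between frames.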
j = min(self.num_downsamples_embed, self.num_downsamples_img + 1)
for i in range(self.num_res_blocks):
cond_maps = cond_maps_prev[j] if i < self.num_res_blocks // 2 \
else cond_maps_now[j]
x_img = getattr(self, 'res_' + str(i))(x_img, *cond_maps)
flow = mask = img_warp = None
num_frames_G = self.num_frames_G
# Whether to warp the previous frame or not.
warp_prev = self.temporal_initialized and not is_first_frame and \
label_prev.shape[1] == num_frames_G - 1
if warp_prev:
# Estimate flow & mask.
label_concat = torch.cat([label_prev.view(bs, -1, h, w),
label], dim=1)
img_prev_concat = img_prev.view(bs, -1, h, w)
flow, mask = self.flow_network_temp(label_concat, img_prev_concat)
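            # Backward-warp the most recent previous frame into the current
            # frame using the estimated flow.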
img_warp = resample(img_prev[:, -1], flow)
if self.spade_combine:
                # If using SPADE combine, integrate the warped image (and
                # occlusion mask) into the conditional inputs for SPADE.
img_embed = torch.cat([img_warp, mask], dim=1)
cond_maps_img = self.get_cond_maps(img_embed,
self.img_prev_embedding)
x_raw_img = None
# Main image generation branch.
for i in range(self.num_downsamples_img, -1, -1):
# Get SPADE conditional inputs.
j = min(i, self.num_downsamples_embed)
cond_maps = cond_maps_now[j]
# For raw output generation.
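            # The raw branch shares features with the main branch until the
            # multi-SPADE layers begin, then upsamples without the
            # warped-image conditioning.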
if self.generate_raw_output:
if i >= self.num_multi_spade_layers - 1:
x_raw_img = x_img
if i < self.num_multi_spade_layers:
x_raw_img = self.one_up_conv_layer(x_raw_img, cond_maps, i)
# For final output.
            if warp_prev and i < self.num_multi_spade_layers:
                # Avoid in-place += here, which would mutate cond_maps_now.
                cond_maps = cond_maps + cond_maps_img[j]
x_img = self.one_up_conv_layer(x_img, cond_maps, i)
# Final conv layer.
img_final = torch.tanh(self.conv_img(x_img))
img_raw = None
if self.spade_combine and self.generate_raw_output:
img_raw = torch.tanh(self.conv_img(x_raw_img))
if warp_prev and not self.spade_combine:
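            # Linearly combine the hallucinated and warped frames using the
            # predicted occlusion mask (mask=1 keeps the hallucinated pixel).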
img_raw = img_final
img_final = img_final * mask + img_warp * (1 - mask)
output = dict()
output['fake_images'] = img_final
output['fake_flow_maps'] = flow
output['fake_occlusion_masks'] = mask
output['fake_raw_images'] = img_raw
output['warped_images'] = img_warp
return output

    def one_up_conv_layer(self, x, encoded_label, i):
r"""One residual block layer in the main branch.
Args:
x (4D tensor) : Current feature map.
encoded_label (list of tensors) : Encoded input label maps.
i (int) : Layer index.
Returns:
x (4D tensor) : Output feature map.
"""
layer = getattr(self, 'up_' + str(i))
x = layer(x, *encoded_label)
if i != 0:
x = self.upsample(x)
return x

    def init_temporal_network(self, cfg_init=None):
        r"""When training switches to multiple frames, initialize the
        downsampling network and the flow network.
Args:
cfg_init (dict) : Weight initialization config.
"""
# Number of image downsamplings for the previous frame.
num_downsamples_img = self.num_downsamples_img
# Number of residual blocks for the previous frame.
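        # This rounds (num_layers - num_downsamples_img) up to an even
        # number, so the blocks split evenly between previous- and
        # current-label conditioning in forward().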
self.num_res_blocks = int(
np.ceil((self.num_layers - num_downsamples_img) / 2.0) * 2)
# First conv layer.
num_img_channels = get_paired_input_image_channel_number(self.data_cfg)
self.down_first = \
Conv2dBlock(num_img_channels,
self.num_filters, self.kernel_size,
padding=self.kernel_size // 2)
if cfg_init is not None:
self.down_first.apply(weights_init(cfg_init.type, cfg_init.gain))
# Downsampling residual blocks.
activation_norm_params = self.gen_cfg.activation_norm_params
for i in range(num_downsamples_img + 1):
activation_norm_params.cond_dims = self.get_cond_dims(i)
layer = self.base_res_block(self.get_num_filters(i),
self.get_num_filters(i + 1))
if cfg_init is not None:
layer.apply(weights_init(cfg_init.type, cfg_init.gain))
setattr(self, 'down_%d' % i, layer)
# Additional residual blocks.
res_ch = self.get_num_filters(num_downsamples_img + 1)
activation_norm_params.cond_dims = \
self.get_cond_dims(num_downsamples_img + 1)
for i in range(self.num_res_blocks):
layer = self.base_res_block(res_ch, res_ch)
if cfg_init is not None:
layer.apply(weights_init(cfg_init.type, cfg_init.gain))
setattr(self, 'res_%d' % i, layer)
# Flow network.
flow_cfg = self.flow_cfg
self.temporal_initialized = True
self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output',
False) and self.spade_combine
self.flow_network_temp = FlowGenerator(flow_cfg, self.data_cfg)
if cfg_init is not None:
self.flow_network_temp.apply(weights_init(cfg_init.type,
cfg_init.gain))
self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
if self.spade_combine:
emb_cfg = flow_cfg.multi_spade_combine.embed
num_img_channels = get_paired_input_image_channel_number(
self.data_cfg)
self.img_prev_embedding = LabelEmbedder(emb_cfg,
num_img_channels + 1)
if cfg_init is not None:
self.img_prev_embedding.apply(weights_init(cfg_init.type,
cfg_init.gain))

    def get_cond_dims(self, num_downs=0):
r"""Get the dimensions of conditional inputs.
Args:
            num_downs (int) : Number of downsamplings at the current layer.
Returns:
ch (list) : List of dimensions.
"""
if not self.use_embed:
ch = [self.num_input_channels]
else:
num_filters = getattr(self.emb_cfg, 'num_filters', 32)
num_downs = min(num_downs, self.num_downsamples_embed)
ch = [min(self.max_num_filters, num_filters * (2 ** num_downs))]
            if num_downs < self.num_multi_spade_layers:
                # Multi-SPADE layers also receive the warped-image
                # embedding, so they expect twice as many conditional inputs.
                ch = ch * 2
return ch

    def get_cond_maps(self, label, embedder):
r"""Get the conditional inputs.
Args:
label (4D tensor) : Input label tensor.
embedder (obj) : Embedding network.
Returns:
cond_maps (list) : List of conditional inputs.
"""
        if not self.use_embed:
            # Callers unpack each entry with *, so wrap the raw label in a
            # list for every layer.
            return [[label]] * (self.num_layers + 1)
embedded_label = embedder(label)
cond_maps = [embedded_label]
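        # Transpose so that each entry gathers the maps from all embedding
        # networks at one resolution.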
cond_maps = [[m[i] for m in cond_maps] for i in
range(len(cond_maps[0]))]
return cond_maps


class FlowGenerator(BaseNetwork):
r"""Flow generator constructor.
Args:
flow_cfg (obj): Flow definition part of the yaml config file.
data_cfg (obj): Data definition part of the yaml config file.
"""

    def __init__(self, flow_cfg, data_cfg):
super().__init__()
num_input_channels = get_paired_input_label_channel_number(data_cfg)
num_prev_img_channels = get_paired_input_image_channel_number(data_cfg)
num_frames = data_cfg.num_frames_G # Num. of input frames.
self.num_filters = num_filters = getattr(flow_cfg, 'num_filters', 32)
self.max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024)
num_downsamples = getattr(flow_cfg, 'num_downsamples', 5)
kernel_size = getattr(flow_cfg, 'kernel_size', 3)
padding = kernel_size // 2
self.num_res_blocks = getattr(flow_cfg, 'num_res_blocks', 6)
# Multiplier on the flow output.
self.flow_output_multiplier = getattr(flow_cfg,
'flow_output_multiplier', 20)
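        # Scales the raw conv output up to plausible pixel displacements.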
activation_norm_type = getattr(flow_cfg, 'activation_norm_type',
'sync_batch')
weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral')
base_conv_block = partial(Conv2dBlock, kernel_size=kernel_size,
padding=padding,
weight_norm_type=weight_norm_type,
activation_norm_type=activation_norm_type,
nonlinearity='leakyrelu')
        # Downsample the labels and previous frames separately, then combine.
down_lbl = [base_conv_block(num_input_channels * num_frames,
num_filters)]
down_img = [base_conv_block(num_prev_img_channels * (num_frames - 1),
num_filters)]
for i in range(num_downsamples):
down_lbl += [base_conv_block(self.get_num_filters(i),
self.get_num_filters(i + 1),
stride=2)]
down_img += [base_conv_block(self.get_num_filters(i),
self.get_num_filters(i + 1),
stride=2)]
# Resnet blocks.
res_flow = []
ch = self.get_num_filters(num_downsamples)
for i in range(self.num_res_blocks):
res_flow += [
Res2dBlock(ch, ch, kernel_size, padding=padding,
weight_norm_type=weight_norm_type,
activation_norm_type=activation_norm_type,
order='CNACN')]
# Upsample.
up_flow = []
for i in reversed(range(num_downsamples)):
up_flow += [nn.Upsample(scale_factor=2),
base_conv_block(self.get_num_filters(i + 1),
self.get_num_filters(i))]
conv_flow = [Conv2dBlock(num_filters, 2, kernel_size, padding=padding)]
conv_mask = [Conv2dBlock(num_filters, 1, kernel_size, padding=padding,
nonlinearity='sigmoid')]
self.down_lbl = nn.Sequential(*down_lbl)
self.down_img = nn.Sequential(*down_img)
self.res_flow = nn.Sequential(*res_flow)
self.up_flow = nn.Sequential(*up_flow)
self.conv_flow = nn.Sequential(*conv_flow)
self.conv_mask = nn.Sequential(*conv_mask)

    def forward(self, label, img_prev):
r"""Flow generator forward.
Args:
label (4D tensor) : Input label tensor.
img_prev (4D tensor) : Previously generated image tensors.
Returns:
(tuple):
- flow (4D tensor) : Generated flow map.
- mask (4D tensor) : Generated occlusion mask.
"""
downsample = self.down_lbl(label) + self.down_img(img_prev)
res = self.res_flow(downsample)
flow_feat = self.up_flow(res)
flow = self.conv_flow(flow_feat) * self.flow_output_multiplier
mask = self.conv_mask(flow_feat)
return flow, mask
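

# A minimal usage sketch (illustrative only; assumes a parsed imaginaire
# yaml config `cfg` whose `gen` and `data` sections match this network):
#
#     net_G = Generator(cfg.gen, cfg.data).eval()
#     data = {'label': label,        # (B, C, H, W) input label map
#             'prev_labels': None,   # None => first frame in the sequence
#             'prev_images': None}
#     with torch.no_grad():
#         output = net_G(data)
#     frame = output['fake_images']  # generated frame in [-1, 1]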