# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import copy
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from imaginaire.layers import (Conv2dBlock, HyperConv2dBlock, HyperRes2dBlock,
                               LinearBlock, Res2dBlock)
from imaginaire.model_utils.fs_vid2vid import (extract_valid_pose_labels,
                                               pick_image, resample)
from imaginaire.utils.data import (get_paired_input_image_channel_number,
                                   get_paired_input_label_channel_number)
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.init_weight import weights_init
from imaginaire.utils.misc import get_and_setattr, get_nested_attr


class Generator(nn.Module):
    r"""Few-shot vid2vid generator constructor.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        self.gen_cfg = gen_cfg
        self.data_cfg = data_cfg
        self.num_frames_G = data_cfg.num_frames_G
        self.flow_cfg = flow_cfg = gen_cfg.flow

        # For pose dataset.
        self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
        if self.is_pose_data:
            pose_cfg = data_cfg.for_pose_dataset
            self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
            self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
                                              False)

        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        self.num_downsamples = num_downsamples = \
            get_and_setattr(gen_cfg, 'num_downsamples', 5)
        conv_kernel_size = get_and_setattr(gen_cfg, 'kernel_size', 3)
        num_filters = get_and_setattr(gen_cfg, 'num_filters', 32)
        max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024)
        self.max_num_filters = gen_cfg.max_num_filters = \
            min(max_num_filters, num_filters * (2 ** num_downsamples))
        # Get the number of filters at each layer in the main branch.
        num_filters_each_layer = [min(self.max_num_filters,
                                      num_filters * (2 ** i))
                                  for i in range(num_downsamples + 2)]

        # Hyper normalization / convolution.
        hyper_cfg = gen_cfg.hyper
        # Use adaptive weight generation for SPADE.
        self.use_hyper_spade = hyper_cfg.is_hyper_spade
        # Use adaptive weight generation for convolutional layers in the
        # main branch.
        self.use_hyper_conv = hyper_cfg.is_hyper_conv
        # Number of hyper layers.
        self.num_hyper_layers = getattr(hyper_cfg, 'num_hyper_layers', 4)
        if self.num_hyper_layers == -1:
            self.num_hyper_layers = num_downsamples
        gen_cfg.hyper.num_hyper_layers = self.num_hyper_layers
        # Network weight generator.
        self.weight_generator = WeightGenerator(gen_cfg, data_cfg)

        # Number of layers to perform multi-spade combine.
        self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
                                              'num_layers', 3)
        # Whether to generate raw output for additional losses.
        self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output',
                                           False)

        # Main branch image generation.
        padding = conv_kernel_size // 2
        activation_norm_type = get_and_setattr(gen_cfg,
                                               'activation_norm_type',
                                               'sync_batch')
        weight_norm_type = get_and_setattr(gen_cfg, 'weight_norm_type',
                                           'spectral')
        activation_norm_params = get_and_setattr(gen_cfg,
                                                 'activation_norm_params',
                                                 None)
        spade_in_channels = []  # Input channel size in SPADE module.
        for i in range(num_downsamples + 1):
            spade_in_channels += [[num_filters_each_layer[i]]] \
                if i >= self.num_multi_spade_layers \
                else [[num_filters_each_layer[i]] * 3]

        order = getattr(gen_cfg.hyper, 'hyper_block_order', 'NAC')
        for i in reversed(range(num_downsamples + 1)):
            activation_norm_params.cond_dims = spade_in_channels[i]
            is_hyper_conv = self.use_hyper_conv and i < self.num_hyper_layers
            is_hyper_norm = self.use_hyper_spade and i < self.num_hyper_layers
            setattr(self, 'up_%d' % i, HyperRes2dBlock(
                num_filters_each_layer[i + 1], num_filters_each_layer[i],
                conv_kernel_size, padding=padding,
                weight_norm_type=weight_norm_type,
                activation_norm_type=activation_norm_type,
                activation_norm_params=activation_norm_params,
                order=order * 2,
                is_hyper_conv=is_hyper_conv, is_hyper_norm=is_hyper_norm))

        self.conv_img = Conv2dBlock(num_filters, num_img_channels,
                                    conv_kernel_size, padding=padding,
                                    nonlinearity='leakyrelu', order='AC')
        self.upsample = partial(F.interpolate, scale_factor=2)

        # Flow estimation module.
        # Whether to warp the reference image and combine it with the
        # synthesized output.
        self.warp_ref = getattr(flow_cfg, 'warp_ref', True)
        if self.warp_ref:
            self.flow_network_ref = FlowGenerator(flow_cfg, data_cfg, 2)
            self.ref_image_embedding = \
                LabelEmbedder(flow_cfg.multi_spade_combine.embed,
                              num_img_channels + 1)
        # At the beginning of training, only train an image generator.
        self.temporal_initialized = False
        if getattr(gen_cfg, 'init_temporal', True):
            self.init_temporal_network()

    def forward(self, data):
        r"""Few-shot vid2vid generator forward.

        Args:
            data (dict) : Dictionary of input data.
        Returns:
            output (dict) : Dictionary of output data.
        """
        label = data['label']
        ref_labels, ref_images = data['ref_labels'], data['ref_images']
        prev_labels, prev_images = data['prev_labels'], data['prev_images']
        is_first_frame = prev_labels is None

        if self.is_pose_data:
            label, prev_labels = extract_valid_pose_labels(
                [label, prev_labels], self.pose_type, self.remove_face_labels)
            ref_labels = extract_valid_pose_labels(
                ref_labels, self.pose_type, self.remove_face_labels,
                do_remove=False)

        # Weight generation.
        x, encoded_label, conv_weights, norm_weights, atn, atn_vis, ref_idx = \
            self.weight_generator(ref_images, ref_labels, label,
                                  is_first_frame)

        # Flow estimation.
        flow, flow_mask, img_warp, cond_inputs = \
            self.flow_generation(label, ref_labels, ref_images,
                                 prev_labels, prev_images, ref_idx)

        for i in range(len(encoded_label)):
            encoded_label[i] = [encoded_label[i]]
        if self.generate_raw_output:
            encoded_label_raw = [encoded_label[i] for i in
                                 range(self.num_multi_spade_layers)]
            x_raw = None
        encoded_label = self.SPADE_combine(encoded_label, cond_inputs)

        # Main branch image generation.
        for i in range(self.num_downsamples, -1, -1):
            conv_weight = norm_weight = [None] * 3
            if self.use_hyper_conv and i < self.num_hyper_layers:
                conv_weight = conv_weights[i]
            if self.use_hyper_spade and i < self.num_hyper_layers:
                norm_weight = norm_weights[i]

            # Main branch residual blocks.
            x = self.one_up_conv_layer(x, encoded_label, conv_weight,
                                       norm_weight, i)

            # For raw output generation.
            if self.generate_raw_output and i < self.num_multi_spade_layers:
                x_raw = self.one_up_conv_layer(x_raw, encoded_label_raw,
                                               conv_weight, norm_weight, i)
            else:
                x_raw = x

        # Final conv layer.
        if self.generate_raw_output:
            img_raw = torch.tanh(self.conv_img(x_raw))
        else:
            img_raw = None
        img_final = torch.tanh(self.conv_img(x))

        output = dict()
        output['fake_images'] = img_final
        output['fake_flow_maps'] = flow
        output['fake_occlusion_masks'] = flow_mask
        output['fake_raw_images'] = img_raw
        output['warped_images'] = img_warp
        output['attention_visualization'] = atn_vis
        output['ref_idx'] = ref_idx
        return output

    def one_up_conv_layer(self, x, encoded_label, conv_weight, norm_weight, i):
        r"""One residual block layer in the main branch.

        Args:
            x (4D tensor) : Current feature map.
            encoded_label (list of tensors) : Encoded input label maps.
            conv_weight (list of tensors) : Hyper conv weights.
            norm_weight (list of tensors) : Hyper norm weights.
            i (int) : Layer index.
        Returns:
            x (4D tensor) : Output feature map.
        """
        layer = getattr(self, 'up_' + str(i))
        x = layer(x, *encoded_label[i], conv_weights=conv_weight,
                  norm_weights=norm_weight)
        if i != 0:
            x = self.upsample(x)
        return x

    def init_temporal_network(self, cfg_init=None):
        r"""When starting to train on multiple frames, initialize the flow
        network.

        Args:
            cfg_init (dict) : Weight initialization config.
        """
        flow_cfg = self.flow_cfg
        emb_cfg = flow_cfg.multi_spade_combine.embed
        num_frames_G = self.num_frames_G
        self.temporal_initialized = True

        self.sep_prev_flownet = flow_cfg.sep_prev_flow or \
            (num_frames_G != 2) or not flow_cfg.warp_ref
        if self.sep_prev_flownet:
            self.flow_network_temp = FlowGenerator(flow_cfg, self.data_cfg,
                                                   num_frames_G)
            if cfg_init is not None:
                self.flow_network_temp.apply(weights_init(cfg_init.type,
                                                          cfg_init.gain))
        else:
            self.flow_network_temp = self.flow_network_ref

        self.sep_prev_embedding = emb_cfg.sep_warp_embed or \
            not flow_cfg.warp_ref
        if self.sep_prev_embedding:
            num_img_channels = get_paired_input_image_channel_number(
                self.data_cfg)
            self.prev_image_embedding = \
                LabelEmbedder(emb_cfg, num_img_channels + 1)
            if cfg_init is not None:
                self.prev_image_embedding.apply(
                    weights_init(cfg_init.type, cfg_init.gain))
        else:
            self.prev_image_embedding = self.ref_image_embedding

        if self.warp_ref:
            if self.sep_prev_flownet:
                self.init_network_weights(self.flow_network_ref,
                                          self.flow_network_temp)
                print('Initialized temporal flow network with the reference '
                      'one.')
            if self.sep_prev_embedding:
                self.init_network_weights(self.ref_image_embedding,
                                          self.prev_image_embedding)
                print('Initialized temporal embedding network with the '
                      'reference one.')

        self.flow_temp_is_initalized = True

    def init_network_weights(self, net_src, net_dst):
        r"""Initialize weights in net_dst with those in net_src."""
        source_weights = net_src.state_dict()
        target_weights = net_dst.state_dict()

        for k, v in source_weights.items():
            if k in target_weights and target_weights[k].size() == v.size():
                target_weights[k] = v
        net_dst.load_state_dict(target_weights)

    def load_pretrained_network(self, pretrained_dict, prefix='module.'):
        r"""Load a pretrained network into this network.

        Args:
            pretrained_dict (dict): Pretrained network weights.
            prefix (str): Prefix to the network weights name.
""" # print(pretrained_dict.keys()) model_dict = self.state_dict() print('Pretrained network has fewer layers; The following are ' 'not initialized:') not_initialized = set() for k, v in model_dict.items(): kp = prefix + k if kp in pretrained_dict and v.size() == pretrained_dict[kp].size(): model_dict[k] = pretrained_dict[kp] else: not_initialized.add('.'.join(k.split('.')[:2])) print(sorted(not_initialized)) self.load_state_dict(model_dict) def reset(self): r"""Reset the network at the beginning of a sequence.""" self.weight_generator.reset() def flow_generation(self, label, ref_labels, ref_images, prev_labels, prev_images, ref_idx): r"""Generates flows and masks for warping reference / previous images. Args: label (NxCxHxW tensor): Target label map. ref_labels (NxKxCxHxW tensor): Reference label maps. ref_images (NxKx3xHxW tensor): Reference images. prev_labels (NxTxCxHxW tensor): Previous label maps. prev_images (NxTx3xHxW tensor): Previous images. ref_idx (Nx1 tensor): Index for which image to use from the reference images. Returns: (tuple): - flow (list of Nx2xHxW tensor): Optical flows. - occ_mask (list of Nx1xHxW tensor): Occlusion masks. - img_warp (list of Nx3xHxW tensor): Warped reference / previous images. - cond_inputs (list of Nx4xHxW tensor): Conditional inputs for SPADE combination. """ # Pick an image in the reference images using ref_idx. ref_label, ref_image = pick_image([ref_labels, ref_images], ref_idx) # Only start using prev frames when enough prev frames are generated. has_prev = prev_labels is not None and \ prev_labels.shape[1] == (self.num_frames_G - 1) flow, occ_mask, img_warp, cond_inputs = [None] * 2, [None] * 2, \ [None] * 2, [None] * 2 if self.warp_ref: # Generate flows/masks for warping the reference image. flow_ref, occ_mask_ref = \ self.flow_network_ref(label, ref_label, ref_image) ref_image_warp = resample(ref_image, flow_ref) flow[0], occ_mask[0], img_warp[0] = \ flow_ref, occ_mask_ref, ref_image_warp[:, :3] # Concat warped image and occlusion mask to form the conditional # input. cond_inputs[0] = torch.cat([img_warp[0], occ_mask[0]], dim=1) if self.temporal_initialized and has_prev: # Generate flows/masks for warping the previous image. b, t, c, h, w = prev_labels.shape prev_labels_concat = prev_labels.view(b, -1, h, w) prev_images_concat = prev_images.view(b, -1, h, w) flow_prev, occ_mask_prev = \ self.flow_network_temp(label, prev_labels_concat, prev_images_concat) img_prev_warp = resample(prev_images[:, -1], flow_prev) flow[1], occ_mask[1], img_warp[1] = \ flow_prev, occ_mask_prev, img_prev_warp cond_inputs[1] = torch.cat([img_warp[1], occ_mask[1]], dim=1) return flow, occ_mask, img_warp, cond_inputs def SPADE_combine(self, encoded_label, cond_inputs): r"""Using Multi-SPADE to combine raw synthesized image with warped images. Args: encoded_label (list of tensors): Original label map embeddings. cond_inputs (list of tensors): New SPADE conditional inputs from the warped images. Returns: encoded_label (list of tensors): Combined conditional inputs. """ # Generate the conditional embeddings from inputs. embedded_img_feat = [None, None] if cond_inputs[0] is not None: embedded_img_feat[0] = self.ref_image_embedding(cond_inputs[0]) if cond_inputs[1] is not None: embedded_img_feat[1] = self.prev_image_embedding(cond_inputs[1]) # Combine the original encoded label maps with new conditional # embeddings. 
        for i in range(self.num_multi_spade_layers):
            encoded_label[i] += [w[i] if w is not None else None
                                 for w in embedded_img_feat]
        return encoded_label

    def custom_init(self):
        r"""This function is for dealing with the numerical issue that might
        occur when doing mixed precision training.
        """
        print('Use custom initialization for the generator.')
        for k, m in self.named_modules():
            if 'weight_generator.ref_label_' in k and 'norm' in k:
                m.eps = 1e-1


class WeightGenerator(nn.Module):
    r"""Weight generator constructor.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        self.data_cfg = data_cfg
        self.embed_cfg = embed_cfg = gen_cfg.embed
        self.embed_arch = embed_cfg.arch

        num_filters = gen_cfg.num_filters
        self.max_num_filters = gen_cfg.max_num_filters
        self.num_downsamples = num_downsamples = gen_cfg.num_downsamples
        self.num_filters_each_layer = num_filters_each_layer = \
            [min(self.max_num_filters, num_filters * (2 ** i))
             for i in range(num_downsamples + 2)]
        if getattr(embed_cfg, 'num_filters', 32) != num_filters:
            raise ValueError('Embedding network must have the same number of '
                             'filters as the generator.')

        # Normalization params.
        hyper_cfg = gen_cfg.hyper
        kernel_size = getattr(hyper_cfg, 'kernel_size', 3)
        activation_norm_type = getattr(hyper_cfg, 'activation_norm_type',
                                       'sync_batch')
        weight_norm_type = getattr(hyper_cfg, 'weight_norm_type', 'spectral')
        # Conv kernel size in the main branch.
        self.conv_kernel_size = conv_kernel_size = gen_cfg.kernel_size
        # Conv kernel size in the embedding network.
        self.embed_kernel_size = embed_kernel_size = \
            getattr(gen_cfg.embed, 'kernel_size', 3)
        # Conv kernel size in SPADE.
        self.kernel_size = kernel_size = \
            getattr(gen_cfg.activation_norm_params, 'kernel_size', 1)
        # Input channel size in SPADE module.
        self.spade_in_channels = []
        for i in range(num_downsamples + 1):
            self.spade_in_channels += [num_filters_each_layer[i]]

        # Hyper normalization / convolution.
        # Use adaptive weight generation for SPADE.
        self.use_hyper_spade = hyper_cfg.is_hyper_spade
        # Use adaptive weight generation for the label embedding network.
        self.use_hyper_embed = hyper_cfg.is_hyper_embed
        # Use adaptive weight generation for convolutional layers in the
        # main branch.
        self.use_hyper_conv = hyper_cfg.is_hyper_conv
        # Number of hyper layers.
        self.num_hyper_layers = hyper_cfg.num_hyper_layers
        # Order of operations in the conv block.
        order = getattr(gen_cfg.hyper, 'hyper_block_order', 'NAC')
        self.conv_before_norm = order.find('C') < order.find('N')

        # For reference image encoding.
        # How to utilize the reference label map: concat | mul.
        self.concat_ref_label = 'concat' in hyper_cfg.method_to_use_ref_labels
        self.mul_ref_label = 'mul' in hyper_cfg.method_to_use_ref_labels
        # Output spatial size for the adaptive pooling layer.
        self.sh_fix = self.sw_fix = 32
        # Number of fc layers in weight generation.
        self.num_fc_layers = getattr(hyper_cfg, 'num_fc_layers', 2)

        # Reference image encoding network.
        num_input_channels = get_paired_input_label_channel_number(data_cfg)
        if num_input_channels == 0:
            num_input_channels = getattr(data_cfg, 'label_channels', 1)
        elif get_nested_attr(data_cfg, 'for_pose_dataset.pose_type',
                             'both') == 'open':
            num_input_channels -= 3
        data_cfg.num_input_channels = num_input_channels
        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        num_ref_channels = num_img_channels + (
            num_input_channels if self.concat_ref_label else 0)
        conv_2d_block = partial(
            Conv2dBlock, kernel_size=kernel_size,
            padding=(kernel_size // 2),
            weight_norm_type=weight_norm_type,
            activation_norm_type=activation_norm_type,
            nonlinearity='leakyrelu')

        self.ref_img_first = conv_2d_block(num_ref_channels, num_filters)
        if self.mul_ref_label:
            self.ref_label_first = conv_2d_block(num_input_channels,
                                                 num_filters)

        for i in range(num_downsamples):
            in_ch, out_ch = num_filters_each_layer[i], \
                num_filters_each_layer[i + 1]
            setattr(self, 'ref_img_down_%d' % i,
                    conv_2d_block(in_ch, out_ch, stride=2))
            setattr(self, 'ref_img_up_%d' % i, conv_2d_block(out_ch, in_ch))
            if self.mul_ref_label:
                setattr(self, 'ref_label_down_%d' % i,
                        conv_2d_block(in_ch, out_ch, stride=2))
                setattr(self, 'ref_label_up_%d' % i,
                        conv_2d_block(out_ch, in_ch))

        # Normalization / main branch conv weight generation.
        if self.use_hyper_spade or self.use_hyper_conv:
            for i in range(self.num_hyper_layers):
                ch_in, ch_out = num_filters_each_layer[i], \
                    num_filters_each_layer[i + 1]
                conv_ks2 = conv_kernel_size ** 2
                embed_ks2 = embed_kernel_size ** 2
                spade_ks2 = kernel_size ** 2
                spade_in_ch = self.spade_in_channels[i]

                fc_names, fc_ins, fc_outs = [], [], []
                if self.use_hyper_spade:
                    fc0_out = fcs_out = (spade_in_ch * spade_ks2 + 1) * (
                        1 if self.conv_before_norm else 2)
                    fc1_out = (spade_in_ch * spade_ks2 + 1) * (
                        1 if ch_in != ch_out else 2)
                    fc_names += ['fc_spade_0', 'fc_spade_1', 'fc_spade_s']
                    fc_ins += [ch_out] * 3
                    fc_outs += [fc0_out, fc1_out, fcs_out]
                    if self.use_hyper_embed:
                        fc_names += ['fc_spade_e']
                        fc_ins += [ch_out]
                        fc_outs += [ch_in * embed_ks2 + 1]
                if self.use_hyper_conv:
                    fc0_out = ch_out * conv_ks2 + 1
                    fc1_out = ch_in * conv_ks2 + 1
                    fcs_out = ch_out + 1
                    fc_names += ['fc_conv_0', 'fc_conv_1', 'fc_conv_s']
                    fc_ins += [ch_in] * 3
                    fc_outs += [fc0_out, fc1_out, fcs_out]

                linear_block = partial(LinearBlock,
                                       weight_norm_type='spectral',
                                       nonlinearity='leakyrelu')
                for n, l in enumerate(fc_names):
                    fc_in = fc_ins[n] if self.mul_ref_label \
                        else self.sh_fix * self.sw_fix
                    fc_layer = [linear_block(fc_in, ch_out)]
                    for k in range(1, self.num_fc_layers):
                        fc_layer += [linear_block(ch_out, ch_out)]
                    fc_layer += [LinearBlock(ch_out, fc_outs[n],
                                             weight_norm_type='spectral')]
                    setattr(self, '%s_%d' % (l, i), nn.Sequential(*fc_layer))

        # Label embedding network.
        num_hyper_layers = self.num_hyper_layers \
            if self.use_hyper_embed else 0
        self.label_embedding = LabelEmbedder(
            self.embed_cfg, num_input_channels,
            num_hyper_layers=num_hyper_layers)

        # For multiple reference images.
        if hasattr(hyper_cfg, 'attention'):
            self.num_downsample_atn = get_and_setattr(hyper_cfg.attention,
                                                      'num_downsamples', 2)
            if data_cfg.initial_few_shot_K > 1:
                self.attention_module = AttentionModule(
                    hyper_cfg, data_cfg, conv_2d_block,
                    num_filters_each_layer)
        else:
            self.num_downsample_atn = 0

    def forward(self, ref_image, ref_label, label, is_first_frame):
        r"""Generate network weights based on the reference images.

        Args:
            ref_image (NxKx3xHxW tensor): Reference images.
            ref_label (NxKxCxHxW tensor): Reference labels.
            label (NxCxHxW tensor): Target label.
            is_first_frame (bool): Whether the current frame is the first
                frame.
        Returns:
            (tuple):
              - x (NxC2xH2xW2 tensor): Encoded features from reference images
                for the main branch (as input to the decoder).
              - encoded_label (list of tensors): Encoded target label map for
                SPADE.
              - conv_weights (list of tensors): Network weights for conv
                layers in the main network.
              - norm_weights (list of tensors): Network weights for SPADE
                layers in the main network.
              - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
              - atn_vis (1x1xH1xW1 tensor): Visualization for attention
                scores.
              - ref_idx (Nx1 tensor): Index for which image to use from the
                reference images.
        """
        b, k, c, h, w = ref_image.size()
        ref_image = ref_image.view(b * k, -1, h, w)
        if ref_label is not None:
            ref_label = ref_label.view(b * k, -1, h, w)

        # Encode the reference images to get the features.
        x, encoded_ref, atn, atn_vis, ref_idx = \
            self.encode_reference(ref_image, ref_label, label, k)

        # If the reference image has changed, recompute the network weights.
        if self.training or is_first_frame or k > 1:
            embedding_weights, norm_weights, conv_weights = [], [], []
            for i in range(self.num_hyper_layers):
                if self.use_hyper_spade:
                    feat = encoded_ref[min(len(encoded_ref) - 1, i + 1)]
                    embedding_weight, norm_weight = \
                        self.get_norm_weights(feat, i)
                    embedding_weights.append(embedding_weight)
                    norm_weights.append(norm_weight)
                if self.use_hyper_conv:
                    feat = encoded_ref[min(len(encoded_ref) - 1, i)]
                    conv_weights.append(self.get_conv_weights(feat, i))

            if not self.training:
                self.embedding_weights, self.conv_weights, self.norm_weights \
                    = embedding_weights, conv_weights, norm_weights
        else:
            # print('Reusing network weights.')
            embedding_weights, conv_weights, norm_weights \
                = self.embedding_weights, self.conv_weights, self.norm_weights

        # Encode the target label to get the encoded features.
        encoded_label = self.label_embedding(
            label, weights=(embedding_weights if self.use_hyper_embed
                            else None))

        return x, encoded_label, conv_weights, norm_weights, \
            atn, atn_vis, ref_idx

    def encode_reference(self, ref_image, ref_label, label, k):
        r"""Encode the reference image to get features for weight generation.

        Args:
            ref_image ((NxK)x3xHxW tensor): Reference images.
            ref_label ((NxK)xCxHxW tensor): Reference labels.
            label (NxCxHxW tensor): Target label.
            k (int): Number of reference images.
        Returns:
            (tuple):
              - x (NxC2xH2xW2 tensor): Encoded features from reference images
                for the main branch (as input to the decoder).
              - encoded_ref (list of tensors): Encoded features from reference
                images for the weight generation branch.
              - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
              - atn_vis (1x1xH1xW1 tensor): Visualization for attention
                scores.
              - ref_idx (Nx1 tensor): Index for which image to use from the
                reference images.
        """
        if self.concat_ref_label:
            # Concat the reference label map and image together for encoding.
            concat_ref = torch.cat([ref_image, ref_label], dim=1)
            x = self.ref_img_first(concat_ref)
        elif self.mul_ref_label:
            # Apply conv to both the reference label and image, then multiply
            # them together for encoding.
            x = self.ref_img_first(ref_image)
            x_label = self.ref_label_first(ref_label)
        else:
            x = self.ref_img_first(ref_image)

        # Attention map and the index of the most similar reference image.
        atn = atn_vis = ref_idx = None
        for i in range(self.num_downsamples):
            x = getattr(self, 'ref_img_down_' + str(i))(x)
            if self.mul_ref_label:
                x_label = getattr(self, 'ref_label_down_' + str(i))(x_label)

            # Combine different reference images at a particular layer.
            if k > 1 and i == self.num_downsample_atn - 1:
                x, atn, atn_vis = self.attention_module(x, label, ref_label)
                if self.mul_ref_label:
                    x_label, _, _ = self.attention_module(x_label, None, None,
                                                          atn)

                atn_sum = atn.view(label.shape[0], k, -1).sum(2)
                ref_idx = torch.argmax(atn_sum, dim=1)

        # Get all corresponding layers in the encoder output for generating
        # weights in corresponding layers.
        encoded_image_ref = [x]
        if self.mul_ref_label:
            encoded_ref_label = [x_label]

        for i in reversed(range(self.num_downsamples)):
            conv = getattr(self, 'ref_img_up_' + str(i))(
                encoded_image_ref[-1])
            encoded_image_ref.append(conv)
            if self.mul_ref_label:
                conv_label = getattr(self, 'ref_label_up_' + str(i))(
                    encoded_ref_label[-1])
                encoded_ref_label.append(conv_label)

        if self.mul_ref_label:
            encoded_ref = []
            for i in range(len(encoded_image_ref)):
                conv, conv_label \
                    = encoded_image_ref[i], encoded_ref_label[i]
                b, c, h, w = conv.size()
                conv_label = nn.Softmax(dim=1)(conv_label)
                conv_prod = (conv.view(b, c, 1, h * w) *
                             conv_label.view(b, 1, c, h * w)).sum(
                    3, keepdim=True)
                encoded_ref.append(conv_prod)
        else:
            encoded_ref = encoded_image_ref
        encoded_ref = encoded_ref[::-1]

        return x, encoded_ref, atn, atn_vis, ref_idx

    def get_norm_weights(self, x, i):
        r"""Adaptively generate weights for SPADE in layer i of the generator.

        Args:
            x (NxCxHxW tensor): Input features.
            i (int): Layer index.
        Returns:
            (tuple):
              - embedding_weights (list of tensors): Weights for the label
                embedding network.
              - norm_weights (list of tensors): Weights for the SPADE layers.
        """
        if not self.mul_ref_label:
            # Get fixed output size for fc layers.
            x = nn.AdaptiveAvgPool2d((self.sh_fix, self.sw_fix))(x)

        in_ch = self.num_filters_each_layer[i]
        out_ch = self.num_filters_each_layer[i + 1]
        spade_ch = self.spade_in_channels[i]
        eks, sks = self.embed_kernel_size, self.kernel_size

        b = x.size(0)
        weight_reshaper = WeightReshaper()
        x = weight_reshaper.reshape_embed_input(x)

        # Weights for the label embedding network.
        embedding_weights = None
        if self.use_hyper_embed:
            fc_e = getattr(self, 'fc_spade_e_' + str(i))(x).view(b, -1)
            if 'decoder' in self.embed_arch:
                weight_shape = [in_ch, out_ch, eks, eks]
                fc_e = fc_e[:, :-in_ch]
            else:
                weight_shape = [out_ch, in_ch, eks, eks]
            embedding_weights = weight_reshaper.reshape_weight(fc_e,
                                                               weight_shape)

        # Weights for the 3 layers in the SPADE module: conv_0, conv_1,
        # and shortcut.
        fc_0 = getattr(self, 'fc_spade_0_' + str(i))(x).view(b, -1)
        fc_1 = getattr(self, 'fc_spade_1_' + str(i))(x).view(b, -1)
        fc_s = getattr(self, 'fc_spade_s_' + str(i))(x).view(b, -1)
        if self.conv_before_norm:
            out_ch = in_ch
        weight_0 = weight_reshaper.reshape_weight(fc_0, [out_ch * 2, spade_ch,
                                                         sks, sks])
        weight_1 = weight_reshaper.reshape_weight(fc_1, [in_ch * 2, spade_ch,
                                                         sks, sks])
        weight_s = weight_reshaper.reshape_weight(fc_s, [out_ch * 2, spade_ch,
                                                         sks, sks])
        norm_weights = [weight_0, weight_1, weight_s]

        return embedding_weights, norm_weights

    def get_conv_weights(self, x, i):
        r"""Adaptively generate weights for layer i in the main branch
        convolutions.

        Args:
            x (NxCxHxW tensor): Input features.
            i (int): Layer index.
        Returns:
            (tuple):
              - conv_weights (list of tensors): Weights for the conv layers in
                the main branch.
""" if not self.mul_ref_label: x = nn.AdaptiveAvgPool2d((self.sh_fix, self.sw_fix))(x) in_ch = self.num_filters_each_layer[i] out_ch = self.num_filters_each_layer[i + 1] cks = self.conv_kernel_size b = x.size()[0] weight_reshaper = WeightReshaper() x = weight_reshaper.reshape_embed_input(x) fc_0 = getattr(self, 'fc_conv_0_' + str(i))(x).view(b, -1) fc_1 = getattr(self, 'fc_conv_1_' + str(i))(x).view(b, -1) fc_s = getattr(self, 'fc_conv_s_' + str(i))(x).view(b, -1) weight_0 = weight_reshaper.reshape_weight(fc_0, [in_ch, out_ch, cks, cks]) weight_1 = weight_reshaper.reshape_weight(fc_1, [in_ch, in_ch, cks, cks]) weight_s = weight_reshaper.reshape_weight(fc_s, [in_ch, out_ch, 1, 1]) return [weight_0, weight_1, weight_s] def reset(self): r"""Reset the network at the beginning of a sequence.""" self.embedding_weights = self.conv_weights = self.norm_weights = None class WeightReshaper(): r"""Handles all weight reshape related tasks.""" def reshape_weight(self, x, weight_shape): r"""Reshape input x to the desired weight shape. Args: x (tensor or list of tensors): Input features. weight_shape (list of int): Desired shape of the weight. Returns: (tuple): - weight (tensor): Network weights - bias (tensor): Network bias. """ # If desired shape is a list, first divide x into the target list of # features. if type(weight_shape[0]) == list and type(x) != list: x = self.split_weights(x, self.sum_mul(weight_shape)) if type(x) == list: return [self.reshape_weight(xi, wi) for xi, wi in zip(x, weight_shape)] # Get output shape, and divide x into either weight + bias or # just weight. weight_shape = [x.size(0)] + weight_shape bias_size = weight_shape[1] try: weight = x[:, :-bias_size].view(weight_shape) bias = x[:, -bias_size:] except Exception: weight = x.view(weight_shape) bias = None return [weight, bias] def split_weights(self, weight, sizes): r"""When the desired shape is a list, first divide the input to each corresponding weight shape in the list. Args: weight (tensor): Input weight. sizes (int or list of int): Target sizes. Returns: weight (list of tensors): Divided weights. """ if isinstance(sizes, list): weights = [] cur_size = 0 for i in range(len(sizes)): # For each target size in sizes, get the number of elements # needed. next_size = cur_size + self.sum(sizes[i]) # Recursively divide the weights. weights.append(self.split_weights( weight[:, cur_size:next_size], sizes[i])) cur_size = next_size assert (next_size == weight.size(1)) return weights return weight def reshape_embed_input(self, x): r"""Reshape input to be (B x C) X H X W. Args: x (tensor or list of tensors): Input features. Returns: x (tensor or list of tensors): Reshaped features. """ if isinstance(x, list): return [self.reshape_embed_input(xi) for xi in zip(x)] b, c, _, _ = x.size() x = x.view(b * c, -1) return x def sum(self, x): r"""Sum all elements recursively in a nested list. Args: x (nested list of int): Input list of elements. Returns: out (int): Sum of all elements. """ if type(x) != list: return x return sum([self.sum(xi) for xi in x]) def sum_mul(self, x): r"""Given a weight shape, compute the number of elements needed for weight + bias. If input is a list of shapes, sum all the elements. Args: x (list of int): Input list of elements. Returns: out (int or list of int): Summed number of elements. """ assert (type(x) == list) if type(x[0]) != list: return np.prod(x) + x[0] # x[0] accounts for bias. return [self.sum_mul(xi) for xi in x] class AttentionModule(nn.Module): r"""Attention module constructor. 
    Args:
        atn_cfg (obj): Attention definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
        conv_2d_block: Conv2dBlock constructor.
        num_filters_each_layer (list of int): The number of filters in each
            layer.
    """

    def __init__(self, atn_cfg, data_cfg, conv_2d_block,
                 num_filters_each_layer):
        super().__init__()
        self.initial_few_shot_K = data_cfg.initial_few_shot_K
        num_input_channels = data_cfg.num_input_channels
        num_filters = getattr(atn_cfg, 'num_filters', 32)

        self.num_downsample_atn = getattr(atn_cfg, 'num_downsamples', 2)
        self.atn_query_first = conv_2d_block(num_input_channels, num_filters)
        self.atn_key_first = conv_2d_block(num_input_channels, num_filters)
        for i in range(self.num_downsample_atn):
            f_in, f_out = num_filters_each_layer[i], \
                num_filters_each_layer[i + 1]
            setattr(self, 'atn_key_%d' % i,
                    conv_2d_block(f_in, f_out, stride=2))
            setattr(self, 'atn_query_%d' % i,
                    conv_2d_block(f_in, f_out, stride=2))

    def forward(self, in_features, label, ref_label, attention=None):
        r"""Get the attention map to combine multiple image features in the
        case of multiple reference images.

        Args:
            in_features ((NxK)xC1xH1xW1 tensor): Input features.
            label (NxC2xH2xW2 tensor): Target label.
            ref_label (NxC2xH2xW2 tensor): Reference label.
            attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
        Returns:
            (tuple):
              - out_features (NxC1xH1xW1 tensor): Attention-combined features.
              - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
              - atn_vis (1x1xH1xW1 tensor): Visualization for attention
                scores.
        """
        b, c, h, w = in_features.size()
        k = self.initial_few_shot_K
        b = b // k

        if attention is None:
            # Compute the attention map by encoding ref_label and label as
            # key and query. The map represents how much energy the k-th
            # reference map at location (h_i, w_j) contributes to the final
            # map at location (h_i2, w_j2).
            atn_key = self.attention_encode(ref_label, 'atn_key')
            atn_query = self.attention_encode(label, 'atn_query')

            atn_key = atn_key.view(b, k, c, -1).permute(
                0, 1, 3, 2).contiguous().view(b, -1, c)  # B x KHW x C
            atn_query = atn_query.view(b, c, -1)  # B x C x HW
            energy = torch.bmm(atn_key, atn_query)  # B x KHW x HW
            attention = nn.Softmax(dim=1)(energy)

        # Combine the K features from different reference images into one by
        # using the attention map.
        in_features = in_features.view(b, k, c, h * w).permute(
            0, 2, 1, 3).contiguous().view(b, c, -1)  # B x C x KHW
        out_features = torch.bmm(in_features, attention).view(b, c, h, w)

        # Get a slice of the attention map for visualization.
        atn_vis = attention.view(b, k, h * w, h * w).sum(2).view(b, k, h, w)
        return out_features, attention, atn_vis[-1:, 0:1]

    def attention_encode(self, img, net_name):
        r"""Encode the input image to get the attention map.

        Args:
            img (NxCxHxW tensor): Input image.
            net_name (str): Name for attention network.
        Returns:
            x (NxC2xH2xW2 tensor): Encoded feature.
        """
        x = getattr(self, net_name + '_first')(img)
        for i in range(self.num_downsample_atn):
            x = getattr(self, net_name + '_' + str(i))(x)
        return x


class FlowGenerator(nn.Module):
    r"""Flow generator constructor.

    Args:
        flow_cfg (obj): Flow definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
        num_frames (int): Number of input frames.
""" def __init__(self, flow_cfg, data_cfg, num_frames): super().__init__() num_input_channels = data_cfg.num_input_channels if num_input_channels == 0: num_input_channels = 1 num_prev_img_channels = get_paired_input_image_channel_number(data_cfg) num_downsamples = getattr(flow_cfg, 'num_downsamples', 3) kernel_size = getattr(flow_cfg, 'kernel_size', 3) padding = kernel_size // 2 num_blocks = getattr(flow_cfg, 'num_blocks', 6) num_filters = getattr(flow_cfg, 'num_filters', 32) max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024) num_filters_each_layer = [min(max_num_filters, num_filters * (2 ** i)) for i in range(num_downsamples + 1)] self.flow_output_multiplier = getattr(flow_cfg, 'flow_output_multiplier', 20) self.sep_up_mask = getattr(flow_cfg, 'sep_up_mask', False) activation_norm_type = getattr(flow_cfg, 'activation_norm_type', 'sync_batch') weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral') base_conv_block = partial(Conv2dBlock, kernel_size=kernel_size, padding=padding, weight_norm_type=weight_norm_type, activation_norm_type=activation_norm_type, nonlinearity='leakyrelu') num_input_channels = num_input_channels * num_frames + \ num_prev_img_channels * (num_frames - 1) # First layer. down_flow = [base_conv_block(num_input_channels, num_filters)] # Downsamples. for i in range(num_downsamples): down_flow += [base_conv_block(num_filters_each_layer[i], num_filters_each_layer[i + 1], stride=2)] # Resnet blocks. res_flow = [] ch = num_filters_each_layer[num_downsamples] for i in range(num_blocks): res_flow += [ Res2dBlock(ch, ch, kernel_size, padding=padding, weight_norm_type=weight_norm_type, activation_norm_type=activation_norm_type, order='NACNAC')] # Upsamples. up_flow = [] for i in reversed(range(num_downsamples)): up_flow += [nn.Upsample(scale_factor=2), base_conv_block(num_filters_each_layer[i + 1], num_filters_each_layer[i])] conv_flow = [Conv2dBlock(num_filters, 2, kernel_size, padding=padding)] conv_mask = [Conv2dBlock(num_filters, 1, kernel_size, padding=padding, nonlinearity='sigmoid')] self.down_flow = nn.Sequential(*down_flow) self.res_flow = nn.Sequential(*res_flow) self.up_flow = nn.Sequential(*up_flow) if self.sep_up_mask: self.up_mask = nn.Sequential(*copy.deepcopy(up_flow)) self.conv_flow = nn.Sequential(*conv_flow) self.conv_mask = nn.Sequential(*conv_mask) def forward(self, label, ref_label, ref_image): r"""Flow generator forward. Args: label (4D tensor) : Input label tensor. ref_label (4D tensor) : Reference label tensors. ref_image (4D tensor) : Reference image tensors. Returns: (tuple): - flow (4D tensor) : Generated flow map. - mask (4D tensor) : Generated occlusion mask. """ label_concat = torch.cat([label, ref_label, ref_image], dim=1) downsample = self.down_flow(label_concat) res = self.res_flow(downsample) flow_feat = self.up_flow(res) flow = self.conv_flow(flow_feat) * self.flow_output_multiplier mask_feat = self.up_mask(res) if self.sep_up_mask else flow_feat mask = self.conv_mask(mask_feat) return flow, mask class LabelEmbedder(nn.Module): r"""Embed the input label map to get embedded features. Args: emb_cfg (obj): Embed network configuration. num_input_channels (int): Number of input channels. num_hyper_layers (int): Number of hyper layers. 
""" def __init__(self, emb_cfg, num_input_channels, num_hyper_layers=0): super().__init__() num_filters = getattr(emb_cfg, 'num_filters', 32) max_num_filters = getattr(emb_cfg, 'max_num_filters', 1024) self.arch = getattr(emb_cfg, 'arch', 'encoderdecoder') self.num_downsamples = num_downsamples = \ getattr(emb_cfg, 'num_downsamples', 5) kernel_size = getattr(emb_cfg, 'kernel_size', 3) weight_norm_type = getattr(emb_cfg, 'weight_norm_type', 'spectral') activation_norm_type = getattr(emb_cfg, 'activation_norm_type', 'none') self.unet = 'unet' in self.arch self.has_decoder = 'decoder' in self.arch or self.unet self.num_hyper_layers = num_hyper_layers \ if num_hyper_layers != -1 else num_downsamples base_conv_block = partial(HyperConv2dBlock, kernel_size=kernel_size, padding=(kernel_size // 2), weight_norm_type=weight_norm_type, activation_norm_type=activation_norm_type, nonlinearity='leakyrelu') ch = [min(max_num_filters, num_filters * (2 ** i)) for i in range(num_downsamples + 1)] self.conv_first = base_conv_block(num_input_channels, num_filters, activation_norm_type='none') # Downsample. for i in range(num_downsamples): is_hyper_conv = (i < num_hyper_layers) and not self.has_decoder setattr(self, 'down_%d' % i, base_conv_block(ch[i], ch[i + 1], stride=2, is_hyper_conv=is_hyper_conv)) # Upsample. if self.has_decoder: self.upsample = nn.Upsample(scale_factor=2) for i in reversed(range(num_downsamples)): ch_i = ch[i + 1] * ( 2 if self.unet and i != num_downsamples - 1 else 1) setattr(self, 'up_%d' % i, base_conv_block(ch_i, ch[i], is_hyper_conv=(i < num_hyper_layers))) def forward(self, input, weights=None): r"""Embedding network forward. Args: input (NxCxHxW tensor): Network input. weights (list of tensors): Conv weights if using hyper network. Returns: output (list of tensors): Network outputs at different layers. """ if input is None: return None output = [self.conv_first(input)] for i in range(self.num_downsamples): layer = getattr(self, 'down_%d' % i) # For hyper networks, the hyper layers are at the last few layers # of decoder (if the network has a decoder). Otherwise, the hyper # layers will be at the first few layers of the network. if i >= self.num_hyper_layers or self.has_decoder: conv = layer(output[-1]) else: conv = layer(output[-1], conv_weights=weights[i]) # We will use outputs from different layers as input to different # SPADE layers in the main branch. output.append(conv) if not self.has_decoder: return output # If the network has a decoder, will use outputs from the decoder # layers instead of the encoding layers. if not self.unet: output = [output[-1]] for i in reversed(range(self.num_downsamples)): input_i = output[-1] if self.unet and i != self.num_downsamples - 1: input_i = torch.cat([input_i, output[i + 1]], dim=1) input_i = self.upsample(input_i) layer = getattr(self, 'up_%d' % i) # The last few layers will be hyper layers if necessary. if i >= self.num_hyper_layers: conv = layer(input_i) else: conv = layer(input_i, conv_weights=weights[i]) output.append(conv) if self.unet: output = output[self.num_downsamples:] return output[::-1]