# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from imaginaire.config import Config
from imaginaire.generators.vid2vid import Generator as Vid2VidGenerator
from imaginaire.model_utils.fs_vid2vid import resample
from imaginaire.model_utils.wc_vid2vid.render import SplatRenderer
from imaginaire.utils.trainer import (get_model_optimizer_and_scheduler,
                                      get_trainer)
from imaginaire.utils.visualization import tensor2im


class Generator(Vid2VidGenerator):
    r"""World-consistent vid2vid generator constructor.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        # Guidance options.
        self.guidance_cfg = gen_cfg.guidance
        self.guidance_only_with_flow = getattr(
            self.guidance_cfg, 'only_with_flow', False)
        self.guidance_partial_conv = getattr(
            self.guidance_cfg, 'partial_conv', False)
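        # `only_with_flow` restricts guidance conditioning to the layers that
        # also receive flow-warped features; `partial_conv` makes the guidance
        # convolutions partial convolutions driven by the guidance mask
        # (see get_cond_dims() and get_partial() below).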
        # Splatter for guidance.
        self.renderer = SplatRenderer()
        self.reset_renderer()
        # Single image model.
        self.single_image_model = None
        # Initialize the rest same as vid2vid.
        super().__init__(gen_cfg, data_cfg)

    def _init_single_image_model(self, load_weights=True):
        r"""Load single image model, if any."""
        if self.single_image_model is None and \
                hasattr(self.gen_cfg, 'single_image_model'):
            print('Using single image model...')
            single_image_cfg = Config(self.gen_cfg.single_image_model.config)
            # Init model.
            net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
                get_model_optimizer_and_scheduler(single_image_cfg)
            # Init trainer and load checkpoint.
            trainer = get_trainer(single_image_cfg, net_G, net_D,
                                  opt_G, opt_D,
                                  sch_G, sch_D,
                                  None, None)
            if load_weights:
                print('Loading single image model checkpoint')
                single_image_ckpt = self.gen_cfg.single_image_model.checkpoint
                trainer.load_checkpoint(single_image_cfg, single_image_ckpt)
                print('Loaded single image model checkpoint')
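            # net_G comes back wrapped by the trainer utilities (e.g. for data
            # parallelism), so keep only the underlying module.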
            self.single_image_model = net_G.module
            self.single_image_model_z = None

    def reset_renderer(self, is_flipped_input=False):
        r"""Reset the renderer.

        Args:
            is_flipped_input (bool): Is the input sequence left-right flipped?
        """
        self.renderer.reset()
        self.is_flipped_input = is_flipped_input
        self.renderer_num_forwards = 0
        self.single_image_model_z = None

    def renderer_update_point_cloud(self, image, point_info):
        r"""Update the renderer's color dictionary."""
        if point_info is None or len(point_info) == 0:
            return
        # print('Updating the renderer.')
        _, _, h, w = image.size()
        # Renderer expects (h, w, c) [0-255] RGB image.
        if isinstance(image, torch.Tensor):
            image = tensor2im(image.detach())[0]
        # Flip this image to correspond to SfM camera pose.
        if self.is_flipped_input:
            image = np.fliplr(image).copy()
        self.renderer.update_point_cloud(image, point_info)
        self.renderer_num_forwards += 1

    def get_guidance_images_and_masks(self, unprojection):
        r"""Render guidance images and masks from the unprojected point cloud.

        Args:
            unprojection (dict): Point cloud information keyed by resolution.
        Returns:
            - guidance_images_and_masks (4D tensor) : Guidance image and mask
              concatenated along the channel dimension.
            - point_info (dict) : Point cloud info at the rendering resolution.
        """
        resolution = 'w1024xh512'
        point_info = unprojection[resolution]
        w, h = resolution.split('x')
        w, h = int(w[1:]), int(h[1:])
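        # The resolution key encodes the render size as 'w{width}xh{height}',
        # i.e. the guidance is rendered at a fixed 1024 x 512 here.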
        # This returns guidance image in [0-255] RGB.
        # We will convert it into Tensor repr. below.
        guidance_image, guidance_mask = self.renderer.render_image(
            point_info, w, h, return_mask=True)
        # If mask is None, there is no guidance.
        # print(np.sum(guidance_mask), guidance_mask.size)
        # if np.sum(guidance_mask) == 0:
        #     return None, point_info
        # Flip guidance image and guidance mask if needed.
        if self.is_flipped_input:
            guidance_image = np.fliplr(guidance_image).copy()
            guidance_mask = np.fliplr(guidance_mask).copy()
        # Go from (h, w, c) to (1, c, h, w).
        # Convert guidance image to Tensor.
        guidance_image = (transforms.ToTensor()(guidance_image) - 0.5) * 2
        guidance_mask = transforms.ToTensor()(guidance_mask)
        guidance = torch.cat((guidance_image, guidance_mask), dim=0)
        guidance = guidance.unsqueeze(0).cuda()
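        # Resulting guidance tensor: 3 RGB channels in [-1, 1] plus a
        # one-channel coverage mask in [0, 1], shape (1, 4, h, w).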
        # Save guidance at all resolutions.
        guidance_images_and_masks = guidance
        return guidance_images_and_masks, point_info

    def forward(self, data):
        r"""vid2vid generator forward.

        Args:
            data (dict) : Dictionary of input data.
        Returns:
            output (dict) : Dictionary of output data.
        """
        self._init_single_image_model()

        label = data['label']
        unprojection = data['unprojection']
        label_prev, img_prev = data['prev_labels'], data['prev_images']
        is_first_frame = img_prev is None
        # data is a dict, so read the optional latent code with .get() rather
        # than getattr().
        z = data.get('z', None)
        bs, _, h, w = label.size()

        # Whether to warp the previous frame or not.
        flow = mask = img_warp = None
        warp_prev = self.temporal_initialized and not is_first_frame and \
            label_prev.shape[1] == self.num_frames_G - 1
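        # Warping needs the temporal (flow) part of the network to be
        # initialized and a full window of num_frames_G - 1 previous frames.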
        # Get guidance images and masks.
        guidance_images_and_masks, point_info = None, None
        if unprojection is not None:
            guidance_images_and_masks, point_info = \
                self.get_guidance_images_and_masks(unprojection)

        # Get SPADE conditional maps by embedding current label input.
        cond_maps_now = self.get_cond_maps(label, self.label_embedding)

        # Use single image model, if flow features are not available.
        # Guidance features are used whenever flow features are available.
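        # The pretrained single image model keeps one style code per sequence
        # (cached in self.single_image_model_z and cleared by reset_renderer),
        # so the frames it generates share the same appearance.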
        if self.single_image_model is not None and not warp_prev:
            # Get z vector for single image model.
            if self.single_image_model_z is None:
                bs = data['label'].size(0)
                z = torch.randn(bs, self.single_image_model.style_dims,
                                dtype=torch.float32).cuda()
                if data['label'].dtype == torch.float16:
                    z = z.half()
                self.single_image_model_z = z
            # Get output image.
            data['z'] = self.single_image_model_z
            self.single_image_model.eval()
            with torch.no_grad():
                output = self.single_image_model.spade_generator(data)
            img_final = output['fake_images'].detach()
            fake_images_source = 'pretrained'
        else:
            # Input to the generator will either be noise/segmentation map (for
            # first frame) or encoded previous frame (for subsequent frames).
            if is_first_frame:
                # First frame in the sequence, start from scratch.
                if self.use_segmap_as_input:
                    x_img = F.interpolate(label, size=(self.sh, self.sw))
                    x_img = self.fc(x_img)
                else:
                    if z is None:
                        z = torch.randn(bs, self.z_dim, dtype=label.dtype,
                                        device=label.get_device()).fill_(0)
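                    # z was zero-filled above, so the first frame starts from
                    # a fixed all-zero latent rather than a random one.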
                    x_img = self.fc(z).view(bs, -1, self.sh, self.sw)
                # Upsampling layers.
                for i in range(self.num_layers, self.num_downsamples_img, -1):
                    j = min(self.num_downsamples_embed, i)
                    x_img = getattr(self, 'up_' + str(i))(
                        x_img, *cond_maps_now[j])
                    x_img = self.upsample(x_img)
            else:
                # Not the first frame, will encode the previous frame and feed
                # to the generator.
                x_img = self.down_first(img_prev[:, -1])
                # Get label embedding for the previous frame.
                cond_maps_prev = self.get_cond_maps(label_prev[:, -1],
                                                    self.label_embedding)
                # Downsampling layers.
                for i in range(self.num_downsamples_img + 1):
                    j = min(self.num_downsamples_embed, i)
                    x_img = getattr(self, 'down_' + str(i))(
                        x_img, *cond_maps_prev[j])
                    if i != self.num_downsamples_img:
                        x_img = self.downsample(x_img)
                # Resnet blocks.
                j = min(self.num_downsamples_embed,
                        self.num_downsamples_img + 1)
                for i in range(self.num_res_blocks):
                    cond_maps = cond_maps_prev[j] if \
                        i < self.num_res_blocks // 2 else cond_maps_now[j]
                    x_img = getattr(self, 'res_' + str(i))(x_img, *cond_maps)

            # Optical flow warped image features.
            if warp_prev:
                # Estimate flow & mask.
                label_concat = torch.cat([label_prev.view(bs, -1, h, w),
                                          label], dim=1)
                img_prev_concat = img_prev.view(bs, -1, h, w)
                flow, mask = self.flow_network_temp(
                    label_concat, img_prev_concat)
                img_warp = resample(img_prev[:, -1], flow)
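                # img_warp is the last generated frame warped to the current
                # frame by the estimated flow; mask is the predicted occlusion
                # map telling the generator where the warped pixels are valid.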
                if self.spade_combine:
                    # if using SPADE combine, integrate the warped image (and
                    # occlusion mask) into conditional inputs for SPADE.
                    img_embed = torch.cat([img_warp, mask], dim=1)
                    cond_maps_img = self.get_cond_maps(img_embed,
                                                       self.img_prev_embedding)

            x_raw_img = None
            # Main image generation branch.
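            # At every upsampling layer the SPADE conditioning list is built
            # from (a) the current label embedding, (b) the flow-warped
            # previous frame embedding for the first num_multi_spade_layers
            # layers, and (c) the rendered guidance image and mask; (c) is
            # what carries the persistent point-cloud colors and keeps the
            # output world-consistent.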
            for i in range(self.num_downsamples_img, -1, -1):
                # Get SPADE conditional inputs.
                j = min(i, self.num_downsamples_embed)
                cond_maps = cond_maps_now[j]

                # For raw output generation.
                if self.generate_raw_output:
                    if i >= self.num_multi_spade_layers - 1:
                        x_raw_img = x_img
                    if i < self.num_multi_spade_layers:
                        x_raw_img = self.one_up_conv_layer(
                            x_raw_img, cond_maps, i)

                # Add flow and guidance features.
                if warp_prev:
                    if i < self.num_multi_spade_layers:
                        # Add flow.
                        cond_maps += cond_maps_img[j]
                        # Add guidance.
                        if guidance_images_and_masks is not None:
                            cond_maps += [guidance_images_and_masks]
                    elif not self.guidance_only_with_flow:
                        # Add guidance if it is to be applied to every layer.
                        if guidance_images_and_masks is not None:
                            cond_maps += [guidance_images_and_masks]

                x_img = self.one_up_conv_layer(x_img, cond_maps, i)

            # Final conv layer.
            img_final = torch.tanh(self.conv_img(x_img))
            fake_images_source = 'in_training'

        # Update the point cloud color dict of renderer.
        self.renderer_update_point_cloud(img_final, point_info)
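        # Splatting the newly generated frame back into the renderer's point
        # cloud colors keeps the same 3D points rendering with the same colors
        # in later frames.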
        output = dict()
        output['fake_images'] = img_final
        output['fake_flow_maps'] = flow
        output['fake_occlusion_masks'] = mask
        output['fake_raw_images'] = None
        output['warped_images'] = img_warp
        output['guidance_images_and_masks'] = guidance_images_and_masks
        output['fake_images_source'] = fake_images_source
        return output

    def get_cond_dims(self, num_downs=0):
        r"""Get the dimensions of conditional inputs.

        Args:
            num_downs (int) : How many downsamples at current layer.
        Returns:
            ch (list) : List of dimensions.
        """
        if not self.use_embed:
            ch = [self.num_input_channels]
        else:
            num_filters = getattr(self.emb_cfg, 'num_filters', 32)
            num_downs = min(num_downs, self.num_downsamples_embed)
            ch = [min(self.max_num_filters, num_filters * (2 ** num_downs))]
            if num_downs < self.num_multi_spade_layers:
                ch = ch * 2
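                # One copy of the embedding channels for the current label map
                # and one for the flow-warped previous frame fed to the
                # multi-SPADE layers.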
                # Also add guidance (RGB + mask = 4 channels, or 3 if partial).
                if self.guidance_partial_conv:
                    ch.append(3)
                else:
                    ch.append(4)
            elif not self.guidance_only_with_flow:
                if self.guidance_partial_conv:
                    ch.append(3)
                else:
                    ch.append(4)
        return ch

    def get_partial(self, num_downs=0):
        r"""Get whether each conv should be a partial conv or not.

        Args:
            num_downs (int) : How many downsamples at current layer.
        Returns:
            partial (list) : List of booleans, one per conditional input,
                indicating whether its conv is partial.
        """
        partial = [False]
        if num_downs < self.num_multi_spade_layers:
            partial = partial * 2
            # Also add an entry for the guidance input (partial if enabled).
            if self.guidance_partial_conv:
                partial.append(True)
            else:
                partial.append(False)
        elif not self.guidance_only_with_flow:
            if self.guidance_partial_conv:
                partial.append(True)
            else:
                partial.append(False)
        return partial

    def get_cond_maps(self, label, embedder):
        r"""Get the conditional inputs.

        Args:
            label (4D tensor) : Input label tensor.
            embedder (obj) : Embedding network.
        Returns:
            cond_maps (list) : List of conditional inputs.
        """
        if not self.use_embed:
            return [label] * (self.num_layers + 1)
        embedded_label = embedder(label)
        cond_maps = [embedded_label]
        cond_maps = [[m[i] for m in cond_maps] for i in
                     range(len(cond_maps[0]))]
        return cond_maps