from typing import Optional, List

import torch
from matplotlib import pyplot
from torch import Tensor
from torch.nn import Module, Sequential, Tanh, Sigmoid

from tha3.nn.image_processing_util import GridChangeApplier, apply_color_change
from tha3.nn.common.resize_conv_unet import ResizeConvUNet, ResizeConvUNetArgs
from tha3.util import numpy_linear_to_srgb
from tha3.module.module_factory import ModuleFactory
from tha3.nn.conv import create_conv3_from_block_args, create_conv3
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.util import BlockArgs


class Editor07Args:
    """Hyperparameters for :class:`Editor07`.

    Args:
        image_size: Height/width the pose vector is broadcast to in
            ``Editor07.forward``; also passed to the UNet as its image size.
        image_channels: Channels per input image. Also the output channel
            count of the color-change and alpha heads.
        num_pose_params: Length of the per-sample pose vector.
        start_channels: Passed to the UNet. It is also the input channel
            count of the three output heads.
        bottleneck_image_size, num_bottleneck_blocks, max_channels,
        use_separable_convolution: Passed unchanged to ``ResizeConvUNetArgs``.
        upsampling_mode: Passed to ``ResizeConvUNetArgs`` as ``upsample_mode``.
        block_args: Conv block settings shared by the UNet and the
            color/alpha heads. When ``None``, a ``BlockArgs`` with
            ``InstanceNorm2dFactory`` and ``ReLUFactory(inplace=False)`` is used.
    """

    def __init__(self,
                 image_size: int = 512,
                 image_channels: int = 4,
                 num_pose_params: int = 6,
                 start_channels: int = 32,
                 bottleneck_image_size: int = 32,
                 num_bottleneck_blocks: int = 6,
                 max_channels: int = 512,
                 upsampling_mode: str = 'nearest',
                 block_args: Optional[BlockArgs] = None,
                 use_separable_convolution: bool = False):
        if block_args is None:
            # Built inside the body (not as a default value) so separate
            # Args instances never share one BlockArgs object.
            block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=False))
        self.block_args = block_args
        self.upsampling_mode = upsampling_mode
        self.max_channels = max_channels
        self.num_bottleneck_blocks = num_bottleneck_blocks
        self.bottleneck_image_size = bottleneck_image_size
        self.start_channels = start_channels
        self.num_pose_params = num_pose_params
        self.image_channels = image_channels
        self.image_size = image_size
        self.use_separable_convolution = use_separable_convolution


class Editor07(Module):
    """Refine a warped image with a residual grid change and a color change.

    The original image, the warped image, the incoming grid change and the
    pose (broadcast over every pixel) are concatenated channel-wise and fed
    through a ``ResizeConvUNet``. Three heads read the UNet's last output:

    * ``grid_change_creator``: a 2-channel residual, zero-initialized, added
      to the incoming grid change.
    * ``color_change_creator``: an ``image_channels`` image squashed by
      ``Tanh``.
    * ``alpha_creator``: an ``image_channels`` map squashed by ``Sigmoid``.
    """

    def __init__(self, args: Editor07Args):
        super().__init__()
        self.args = args

        # Input = original image + warped image + 2-channel grid change +
        # broadcast pose; must match the torch.cat order in forward().
        self.body = ResizeConvUNet(ResizeConvUNetArgs(
            image_size=args.image_size,
            input_channels=2 * args.image_channels + args.num_pose_params + 2,
            start_channels=args.start_channels,
            bottleneck_image_size=args.bottleneck_image_size,
            num_bottleneck_blocks=args.num_bottleneck_blocks,
            max_channels=args.max_channels,
            upsample_mode=args.upsampling_mode,
            block_args=args.block_args,
            use_separable_convolution=args.use_separable_convolution))
        self.color_change_creator = Sequential(
            create_conv3_from_block_args(
                in_channels=self.args.start_channels,
                out_channels=self.args.image_channels,
                bias=True,
                block_args=self.args.block_args),
            Tanh())
        self.alpha_creator = Sequential(
            create_conv3_from_block_args(
                in_channels=self.args.start_channels,
                out_channels=self.args.image_channels,
                bias=True,
                block_args=self.args.block_args),
            Sigmoid())
        # Zero-initialized, so at the start of training the output grid
        # change equals the input grid change.
        self.grid_change_creator = create_conv3(
            in_channels=self.args.start_channels,
            out_channels=2,
            bias=False,
            initialization_method='zero',
            use_spectral_norm=False)
        self.grid_change_applier = GridChangeApplier()

    def forward(self,
                input_original_image: Tensor,
                input_warped_image: Tensor,
                input_grid_change: Tensor,
                pose: Tensor,
                *args) -> List[Tensor]:
        """Run the editor.

        Args:
            input_original_image: Image that the output grid change is applied
                to. Assumed (N, image_channels, image_size, image_size).
            input_warped_image: Previously warped image. Assumed to have the
                same shape as ``input_original_image``.
            input_grid_change: Incoming grid change with 2 channels at
                ``image_size`` x ``image_size`` resolution.
            pose: (N, num_pose_params) pose vectors.
            *args: Ignored. Accepted so callers may pass extra inputs.

        Returns:
            A list of ``OUTPUT_LENGTH`` tensors, indexed by the ``*_INDEX``
            class constants.
        """
        n, c = pose.shape
        # Broadcast the pose to every pixel. expand() makes a view, and
        # torch.cat copies it, so no separate tiled tensor is allocated
        # (the result is identical to repeat()).
        pose = pose.view(n, c, 1, 1).expand(n, c, self.args.image_size, self.args.image_size)
        feature = torch.cat([input_original_image, input_warped_image, input_grid_change, pose], dim=1)
        # Call the module (not .forward) so nn.Module hooks are honored.
        # The UNet returns a list; its last entry has start_channels
        # channels, which is what the three heads below consume.
        feature = self.body(feature)[-1]

        output_grid_change = input_grid_change + self.grid_change_creator(feature)

        output_color_change = self.color_change_creator(feature)
        output_color_change_alpha = self.alpha_creator(feature)
        # NOTE(review): presumably resamples the ORIGINAL image with the
        # refined grid change; see GridChangeApplier.apply to confirm.
        output_warped_image = self.grid_change_applier.apply(output_grid_change, input_original_image)
        # NOTE(review): presumably alpha-blends the color change over the
        # warped image; see apply_color_change to confirm.
        output_color_changed = apply_color_change(output_color_change_alpha, output_color_change, output_warped_image)

        return [
            output_color_changed,
            output_color_change_alpha,
            output_color_change,
            output_warped_image,
            output_grid_change,
        ]

    # Positions of each tensor in the list returned by forward().
    COLOR_CHANGED_IMAGE_INDEX = 0
    COLOR_CHANGE_ALPHA_INDEX = 1
    COLOR_CHANGE_IMAGE_INDEX = 2
    WARPED_IMAGE_INDEX = 3
    GRID_CHANGE_INDEX = 4
    OUTPUT_LENGTH = 5


class Editor07Factory(ModuleFactory):
    """``ModuleFactory`` that builds an :class:`Editor07` from fixed args."""

    def __init__(self, args: Editor07Args):
        super().__init__()
        self.args = args

    def create(self) -> Module:
        """Return a new :class:`Editor07` built from the stored args."""
        return Editor07(self.args)


def show_image(pytorch_image: Tensor) -> None:
    """Display a single image tensor with matplotlib (debugging helper).

    The tensor is mapped with ``(x + 1) / 2``, its leading batch dimension is
    squeezed, and its first three channels are converted linear -> sRGB.
    Assumes a (1, C, H, W) tensor in [-1, 1] whose first three channels are
    linear RGB. TODO: confirm with callers.
    """
    # Tensor.numpy() raises for CUDA tensors and tensors that require grad,
    # so detach and move to the CPU first. The arithmetic creates a new
    # tensor, so the in-place write below never touches the caller's data.
    numpy_image = ((pytorch_image.detach().cpu() + 1.0) / 2.0).squeeze(0).numpy()
    numpy_image[0:3, :, :] = numpy_linear_to_srgb(numpy_image[0:3, :, :])
    # CHW -> HWC, the channel-last layout pyplot.imshow expects.
    pyplot.imshow(numpy_image.transpose(1, 2, 0))
    pyplot.show()


if __name__ == "__main__":
    # Micro-benchmark: average GPU time of one Editor07 forward pass.
    if not torch.cuda.is_available():
        raise SystemExit("This benchmark requires a CUDA device.")
    cuda = torch.device('cuda')

    image_size = 512
    image_channels = 4
    num_pose_params = 6
    args = Editor07Args(
        image_size=image_size,
        image_channels=image_channels,
        start_channels=32,
        num_pose_params=num_pose_params,
        bottleneck_image_size=32,
        num_bottleneck_blocks=6,
        max_channels=512,
        upsampling_mode='nearest',
        block_args=BlockArgs(
            initialization_method='he',
            use_spectral_norm=False,
            normalization_layer_factory=InstanceNorm2dFactory(),
            nonlinearity_factory=ReLUFactory(inplace=False)))
    module = Editor07(args).to(cuda)

    image_count = 1
    input_image = torch.zeros(image_count, image_channels, image_size, image_size, device=cuda)
    warped_image = torch.zeros(image_count, image_channels, image_size, image_size, device=cuda)
    grid_change = torch.zeros(image_count, 2, image_size, image_size, device=cuda)
    pose = torch.zeros(image_count, num_pose_params, device=cuda)

    warmup = 2  # untimed iterations (CUDA context / kernel selection)
    repeat = 100
    acc = 0.0
    # Inference benchmark: skip autograd bookkeeping so it is not timed.
    with torch.no_grad():
        for i in range(repeat + warmup):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            module(input_image, warped_image, grid_change, pose)
            end.record()
            # Events are asynchronous; wait before reading elapsed_time.
            torch.cuda.synchronize()
            if i >= warmup:
                elapsed_time = start.elapsed_time(end)
                print("%d:" % i, elapsed_time)
                acc = acc + elapsed_time

    print("average:", acc / repeat)