Spaces:
Runtime error
Runtime error
from typing import List, Optional

import torch
from matplotlib import pyplot
from torch import Tensor
from torch.nn import Module, Sequential, Tanh, Sigmoid

from tha3.module.module_factory import ModuleFactory
from tha3.nn.common.resize_conv_unet import ResizeConvUNet, ResizeConvUNetArgs
from tha3.nn.conv import create_conv3_from_block_args, create_conv3
from tha3.nn.image_processing_util import GridChangeApplier, apply_color_change
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.util import BlockArgs
from tha3.util import numpy_linear_to_srgb
class Editor07Args:
    """Configuration for ``Editor07``.

    Every constructor argument is stored unchanged as an attribute of the same
    name. When ``block_args`` is ``None``, a default ``BlockArgs`` is built that
    uses ``InstanceNorm2dFactory()`` for normalization and
    ``ReLUFactory(inplace=False)`` for the nonlinearity.
    """

    def __init__(self,
                 image_size: int = 512,
                 image_channels: int = 4,
                 num_pose_params: int = 6,
                 start_channels: int = 32,
                 bottleneck_image_size=32,
                 num_bottleneck_blocks=6,
                 max_channels: int = 512,
                 upsampling_mode: str = 'nearest',
                 block_args: Optional[BlockArgs] = None,
                 use_separable_convolution: bool = False):
        # Input geometry.
        self.image_size = image_size
        self.image_channels = image_channels
        self.num_pose_params = num_pose_params
        # Settings forwarded to the ResizeConvUNet body.
        self.start_channels = start_channels
        self.bottleneck_image_size = bottleneck_image_size
        self.num_bottleneck_blocks = num_bottleneck_blocks
        self.max_channels = max_channels
        self.upsampling_mode = upsampling_mode
        self.use_separable_convolution = use_separable_convolution
        # Conv-block settings; default to InstanceNorm + non-inplace ReLU.
        self.block_args = (
            block_args
            if block_args is not None
            else BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=False)))
class Editor07(Module):
    """Refine a warped image by updating its sampling grid and blending in a color change.

    The network receives the original image, a pre-warped image, the current
    grid change and the pose vector (broadcast to every pixel). It produces:

    * an updated grid change (the input grid change plus a learned residual),
    * the original image warped with that updated grid,
    * a color-change image (``Tanh`` output, in (-1, 1)),
    * an alpha mask (``Sigmoid`` output, in (0, 1)),
    * the warped image with the color change applied by ``apply_color_change``.

    ``forward`` returns these as a list. Use the ``*_INDEX`` class constants to
    select entries from it.
    """

    def __init__(self, args: Editor07Args):
        super().__init__()
        self.args = args
        # Input channels are: original image + warped image (image_channels each),
        # the pose broadcast to a per-pixel map (num_pose_params), and the
        # 2-channel grid change.
        self.body = ResizeConvUNet(ResizeConvUNetArgs(
            image_size=args.image_size,
            input_channels=2 * args.image_channels + args.num_pose_params + 2,
            start_channels=args.start_channels,
            bottleneck_image_size=args.bottleneck_image_size,
            num_bottleneck_blocks=args.num_bottleneck_blocks,
            max_channels=args.max_channels,
            upsample_mode=args.upsampling_mode,
            block_args=args.block_args,
            use_separable_convolution=args.use_separable_convolution))
        # The heads below read a start_channels-wide feature map (the last body output).
        self.color_change_creator = Sequential(
            create_conv3_from_block_args(
                in_channels=self.args.start_channels,
                out_channels=self.args.image_channels,
                bias=True,
                block_args=self.args.block_args),
            Tanh())
        self.alpha_creator = Sequential(
            create_conv3_from_block_args(
                in_channels=self.args.start_channels,
                out_channels=self.args.image_channels,
                bias=True,
                block_args=self.args.block_args),
            Sigmoid())
        # Zero-initialized, bias-free residual head, so an untrained model
        # initially passes the input grid change through unchanged.
        self.grid_change_creator = create_conv3(
            in_channels=self.args.start_channels,
            out_channels=2,
            bias=False,
            initialization_method='zero',
            use_spectral_norm=False)
        self.grid_change_applier = GridChangeApplier()

    def forward(self,
                input_original_image: Tensor,
                input_warped_image: Tensor,
                input_grid_change: Tensor,
                pose: Tensor,
                *args) -> List[Tensor]:
        """Run the editor on a batch.

        Args:
            input_original_image: source image, presumably (N, image_channels, H, W);
                TODO confirm.
            input_warped_image: pre-warped image with the same layout as
                ``input_original_image``.
            input_grid_change: 2-channel grid change, presumably (N, 2, H, W).
            pose: (N, num_pose_params) pose vector.
            *args: accepted and ignored (presumably so a generic caller can pass
                extra inputs).

        Returns:
            A list of length ``OUTPUT_LENGTH``, ordered as the ``*_INDEX`` constants.
        """
        n, c = pose.shape
        # Tile the pose over the actual spatial size of the images instead of
        # args.image_size, so the channel-wise concatenation below always lines up.
        # expand() is a view; torch.cat materializes the result anyway.
        height, width = input_original_image.shape[-2:]
        pose_map = pose.view(n, c, 1, 1).expand(n, c, height, width)
        feature = torch.cat(
            [input_original_image, input_warped_image, input_grid_change, pose_map], dim=1)
        # Call the module (not .forward) so registered hooks still run.
        feature = self.body(feature)[-1]

        output_grid_change = input_grid_change + self.grid_change_creator(feature)
        output_color_change = self.color_change_creator(feature)
        output_color_change_alpha = self.alpha_creator(feature)
        # The updated grid is applied to the ORIGINAL image, not the pre-warped one.
        output_warped_image = self.grid_change_applier.apply(output_grid_change, input_original_image)
        output_color_changed = apply_color_change(
            output_color_change_alpha, output_color_change, output_warped_image)

        return [
            output_color_changed,
            output_color_change_alpha,
            output_color_change,
            output_warped_image,
            output_grid_change,
        ]

    # Positions of each tensor in the list returned by forward().
    COLOR_CHANGED_IMAGE_INDEX = 0
    COLOR_CHANGE_ALPHA_INDEX = 1
    COLOR_CHANGE_IMAGE_INDEX = 2
    WARPED_IMAGE_INDEX = 3
    GRID_CHANGE_INDEX = 4
    OUTPUT_LENGTH = 5
class Editor07Factory(ModuleFactory):
    """``ModuleFactory`` that builds ``Editor07`` modules from a stored ``Editor07Args``."""

    def __init__(self, args: Editor07Args):
        super().__init__()
        # Held by reference; every create() call passes this same object along.
        self.args = args

    def create(self) -> Module:
        """Construct and return a new ``Editor07`` from ``self.args``."""
        editor = Editor07(args=self.args)
        return editor
def show_image(pytorch_image):
    """Display an image tensor with matplotlib and block until the window closes.

    Args:
        pytorch_image: image tensor whose values are mapped to display range by
            ``(x + 1) / 2``, so presumably in [-1, 1]. A leading batch dimension
            of size 1 is removed by ``squeeze(0)``; the remainder is treated as
            (C, H, W). The first three channels are converted from linear to
            sRGB, so C is presumably 3 or 4 (RGB / RGBA); TODO confirm.
            The tensor may live on any device and may require grad.
    """
    # Tensor.numpy() raises for GPU tensors and for tensors that require grad,
    # so detach and move to the CPU first. The arithmetic creates a new tensor,
    # so the in-place edit below never touches the caller's data.
    image = ((pytorch_image.detach().cpu() + 1.0) / 2.0).squeeze(0).numpy()
    image[0:3, :, :] = numpy_linear_to_srgb(image[0:3, :, :])
    # CHW -> HWC, the layout pyplot.imshow expects for colour images.
    image = image.transpose(1, 2, 0)
    # imshow clips out-of-range float RGB(A) data anyway (with a warning);
    # clip explicitly so the warning is not emitted.
    pyplot.imshow(image.clip(0.0, 1.0))
    pyplot.show()
if __name__ == "__main__":
    # Micro-benchmark: time Editor07 forward passes on a CUDA device and print
    # each timing plus the average, in milliseconds (torch.cuda.Event.elapsed_time).
    cuda = torch.device('cuda')
    image_size = 512
    image_channels = 4
    num_pose_params = 6
    args = Editor07Args(
        image_size=image_size,
        image_channels=image_channels,
        start_channels=32,
        num_pose_params=num_pose_params,
        bottleneck_image_size=32,
        num_bottleneck_blocks=6,
        max_channels=512,
        upsampling_mode='nearest',
        block_args=BlockArgs(
            initialization_method='he',
            use_spectral_norm=False,
            normalization_layer_factory=InstanceNorm2dFactory(),
            nonlinearity_factory=ReLUFactory(inplace=False)))
    module = Editor07(args).to(cuda)

    image_count = 1
    input_image = torch.zeros(image_count, image_channels, image_size, image_size, device=cuda)
    warped_image = torch.zeros(image_count, image_channels, image_size, image_size, device=cuda)
    grid_change = torch.zeros(image_count, 2, image_size, image_size, device=cuda)
    pose = torch.zeros(image_count, num_pose_params, device=cuda)

    warmup = 2  # leading iterations excluded from the average (presumably warm-up)
    repeat = 100
    acc = 0.0
    # Inference-only timing: do not build an autograd graph.
    with torch.no_grad():
        for i in range(warmup + repeat):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            module(input_image, warped_image, grid_change, pose)
            end.record()
            # Both events must have completed before elapsed_time() can be read.
            torch.cuda.synchronize()
            if i >= warmup:
                elapsed_time = start.elapsed_time(end)
                print(f"{i}:", elapsed_time)
                acc += elapsed_time
    print("average:", acc / repeat)