File size: 6,795 Bytes
6a62ffb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from typing import Optional, List

import torch
from matplotlib import pyplot
from torch import Tensor
from torch.nn import Module, Sequential, Tanh, Sigmoid

from tha3.nn.image_processing_util import GridChangeApplier, apply_color_change
from tha3.nn.common.resize_conv_unet import ResizeConvUNet, ResizeConvUNetArgs
from tha3.util import numpy_linear_to_srgb
from tha3.module.module_factory import ModuleFactory
from tha3.nn.conv import create_conv3_from_block_args, create_conv3
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.util import BlockArgs


class Editor07Args:
    """Hyper-parameters for the Editor07 network.

    Each constructor argument is stored unchanged as an attribute of the same
    name. When ``block_args`` is omitted, a default ``BlockArgs`` built from an
    ``InstanceNorm2dFactory`` and a non-inplace ``ReLUFactory`` is used.
    """

    def __init__(self,
                 image_size: int = 512,
                 image_channels: int = 4,
                 num_pose_params: int = 6,
                 start_channels: int = 32,
                 bottleneck_image_size=32,
                 num_bottleneck_blocks=6,
                 max_channels: int = 512,
                 upsampling_mode: str = 'nearest',
                 block_args: Optional[BlockArgs] = None,
                 use_separable_convolution: bool = False):
        # Image / pose layout.
        self.image_size = image_size
        self.image_channels = image_channels
        self.num_pose_params = num_pose_params

        # U-Net capacity and upsampling settings.
        self.start_channels = start_channels
        self.bottleneck_image_size = bottleneck_image_size
        self.num_bottleneck_blocks = num_bottleneck_blocks
        self.max_channels = max_channels
        self.upsampling_mode = upsampling_mode
        self.use_separable_convolution = use_separable_convolution

        # Default conv-block configuration: instance norm followed by a
        # non-inplace ReLU.
        if block_args is not None:
            self.block_args = block_args
        else:
            self.block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=False))


class Editor07(Module):
    """U-Net editor that refines a warp grid and applies a color correction.

    The network input is the channel-wise concatenation of the original image,
    a warped image, a 2-channel grid change and the pose vector broadcast to
    every pixel. From the last feature map of the U-Net body it predicts:

    * a residual added to ``input_grid_change`` (the refined grid change),
    * a color-change image (``Tanh`` output),
    * a per-channel blending alpha (``Sigmoid`` output).

    The original image is re-warped with the refined grid change, and the
    result is combined with the color change through ``apply_color_change``.
    """

    def __init__(self, args: Editor07Args):
        super().__init__()
        self.args = args

        # Input channels: original image + warped image + 2-channel grid
        # change + pose vector (broadcast spatially in forward()).
        self.body = ResizeConvUNet(ResizeConvUNetArgs(
            image_size=args.image_size,
            input_channels=2 * args.image_channels + args.num_pose_params + 2,
            start_channels=args.start_channels,
            bottleneck_image_size=args.bottleneck_image_size,
            num_bottleneck_blocks=args.num_bottleneck_blocks,
            max_channels=args.max_channels,
            upsample_mode=args.upsampling_mode,
            block_args=args.block_args,
            use_separable_convolution=args.use_separable_convolution))
        # Color-change image; Tanh bounds it to [-1, 1].
        self.color_change_creator = Sequential(
            create_conv3_from_block_args(
                in_channels=self.args.start_channels,
                out_channels=self.args.image_channels,
                bias=True,
                block_args=self.args.block_args),
            Tanh())
        # Blending weight, one per image channel; Sigmoid bounds it to [0, 1].
        self.alpha_creator = Sequential(
            create_conv3_from_block_args(
                in_channels=self.args.start_channels,
                out_channels=self.args.image_channels,
                bias=True,
                block_args=self.args.block_args),
            Sigmoid())
        # Residual for the 2-channel grid change. It is created with
        # initialization_method='zero' and no bias, presumably so a freshly
        # initialized network passes input_grid_change through unchanged.
        self.grid_change_creator = create_conv3(
            in_channels=self.args.start_channels,
            out_channels=2,
            bias=False,
            initialization_method='zero',
            use_spectral_norm=False)
        self.grid_change_applier = GridChangeApplier()

    def forward(self,
                input_original_image: Tensor,
                input_warped_image: Tensor,
                input_grid_change: Tensor,
                pose: Tensor,
                *args) -> List[Tensor]:
        """Predict the refined grid change and the color-corrected image.

        Args:
            input_original_image: image tensor concatenated along dim 1 with
                the other inputs; its last two dims give the spatial size.
            input_warped_image: warped image with the same batch and spatial
                size as ``input_original_image``.
            input_grid_change: 2-channel grid change, same batch and spatial
                size.
            pose: ``(N, num_pose_params)`` pose vector.
            *args: extra positional arguments; ignored.

        Returns:
            A list indexed by the ``*_INDEX`` class constants:
            color-changed image, color-change alpha, color-change image,
            warped image, refined grid change.
        """
        n, c = pose.shape
        height, width = input_original_image.shape[-2:]
        # Broadcast the pose to every pixel. expand() only creates a view;
        # torch.cat copies it into the output, so repeat()'s extra
        # allocation is unnecessary. The spatial size comes from the input
        # itself rather than from args.image_size.
        pose = pose.view(n, c, 1, 1).expand(n, c, height, width)
        feature = torch.cat([input_original_image, input_warped_image, input_grid_change, pose], dim=1)

        # Call the module (not .forward) so registered hooks run; keep the
        # last (highest-resolution) output of the U-Net.
        feature = self.body(feature)[-1]
        output_grid_change = input_grid_change + self.grid_change_creator(feature)

        output_color_change = self.color_change_creator(feature)
        output_color_change_alpha = self.alpha_creator(feature)
        output_warped_image = self.grid_change_applier.apply(output_grid_change, input_original_image)
        output_color_changed = apply_color_change(output_color_change_alpha, output_color_change, output_warped_image)

        return [
            output_color_changed,
            output_color_change_alpha,
            output_color_change,
            output_warped_image,
            output_grid_change,
        ]

    # Positions of the tensors in the list returned by forward().
    COLOR_CHANGED_IMAGE_INDEX = 0
    COLOR_CHANGE_ALPHA_INDEX = 1
    COLOR_CHANGE_IMAGE_INDEX = 2
    WARPED_IMAGE_INDEX = 3
    GRID_CHANGE_INDEX = 4
    OUTPUT_LENGTH = 5


class Editor07Factory(ModuleFactory):
    """Factory that builds Editor07 modules from one fixed Editor07Args."""

    def __init__(self, args: Editor07Args):
        super().__init__()
        # The same args object is passed to every module this factory creates.
        self.args = args

    def create(self) -> Module:
        """Construct and return a new Editor07 from the stored args."""
        editor = Editor07(self.args)
        return editor


def show_image(pytorch_image):
    """Display an image tensor with matplotlib.

    The tensor is mapped from [-1, 1] to [0, 1] via ``(x + 1) / 2``. A
    leading batch dim of size 1 is squeezed away, and the first three
    channels are converted with ``numpy_linear_to_srgb``; any further channel
    (presumably alpha) is left as is. The result is shown with ``pyplot``.
    Expects a (1, C, H, W) or (C, H, W) tensor — TODO confirm against callers.
    """
    # detach()/cpu() so tensors that require grad or live on the GPU can be
    # converted; the arithmetic makes a new tensor, so the caller's tensor is
    # never modified by the in-place sRGB assignment below.
    numpy_image = ((pytorch_image.detach().cpu() + 1.0) / 2.0).squeeze(0).numpy()
    numpy_image[0:3, :, :] = numpy_linear_to_srgb(numpy_image[0:3, :, :])
    # CHW -> HWC, the layout imshow expects.
    numpy_image = numpy_image.transpose(1, 2, 0)
    pyplot.imshow(numpy_image)
    pyplot.show()


if __name__ == "__main__":
    # Micro-benchmark: average CUDA time of one Editor07 forward pass on
    # all-zero inputs.
    cuda = torch.device('cuda')

    image_size = 512
    image_channels = 4
    num_pose_params = 6
    args = Editor07Args(
        image_size=image_size,
        image_channels=image_channels,
        start_channels=32,
        num_pose_params=num_pose_params,
        bottleneck_image_size=32,
        num_bottleneck_blocks=6,
        max_channels=512,
        upsampling_mode='nearest',
        block_args=BlockArgs(
            initialization_method='he',
            use_spectral_norm=False,
            normalization_layer_factory=InstanceNorm2dFactory(),
            nonlinearity_factory=ReLUFactory(inplace=False)))
    module = Editor07(args).to(cuda)

    image_count = 1
    input_image = torch.zeros(image_count, image_channels, image_size, image_size, device=cuda)
    warped_image = torch.zeros(image_count, image_channels, image_size, image_size, device=cuda)
    grid_change = torch.zeros(image_count, 2, image_size, image_size, device=cuda)
    pose = torch.zeros(image_count, num_pose_params, device=cuda)

    warmup = 2  # untimed iterations, excluded from the average
    repeat = 100
    acc = 0.0
    # Outputs are discarded, so skip building an autograd graph.
    with torch.no_grad():
        for i in range(warmup + repeat):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            module(input_image, warped_image, grid_change, pose)
            end.record()
            # Wait for the GPU so elapsed_time() covers the whole pass.
            torch.cuda.synchronize()
            if i >= warmup:
                elapsed_time = start.elapsed_time(end)
                print("%d:" % i, elapsed_time)
                acc = acc + elapsed_time

    print("average:", acc / repeat)