from typing import Optional, List

import torch
from torch import Tensor
from torch.nn import ModuleList, Module, Upsample

from tha3.nn.common.conv_block_factory import ConvBlockFactory
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.util import BlockArgs
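
# ResizeConvUNet overview (added commentary, derived from the code below):
# - Encoder: downsample blocks halve the spatial size until
#   bottleneck_image_size is reached, doubling the channel count up to
#   max_channels.
# - Bottleneck: num_bottleneck_blocks ResNet blocks at the coarsest scale.
# - Decoder: "resize-convolution" upsampling (nn.Upsample followed by a 3x3
#   conv block) with skip connections concatenated from the encoder, instead
#   of transposed convolutions.
# forward() returns the feature maps of every decoder scale, coarsest first.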


class ResizeConvUNetArgs:
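    """Configuration for ResizeConvUNet.

    image_size: spatial size of the input feature map (and of the finest output).
    input_channels: number of channels of the input feature map.
    start_channels: channels produced by the first conv block; doubled at every
        downsampling step, capped at max_channels.
    bottleneck_image_size: resolution at which downsampling stops.
    num_bottleneck_blocks: number of ResNet blocks applied at the bottleneck.
    max_channels: upper bound on the channel count at any scale.
    upsample_mode: mode passed to torch.nn.Upsample in the decoder.
    block_args: normalization/nonlinearity settings; defaults to instance
        normalization with non-inplace ReLU when omitted.
    use_separable_convolution: forwarded to ConvBlockFactory.
    """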
    def __init__(self,
                 image_size: int,
                 input_channels: int,
                 start_channels: int,
                 bottleneck_image_size: int,
                 num_bottleneck_blocks: int,
                 max_channels: int,
                 upsample_mode: str = 'bilinear',
                 block_args: Optional[BlockArgs] = None,
                 use_separable_convolution: bool = False):
        if block_args is None:
            block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=False))

        self.use_separable_convolution = use_separable_convolution
        self.block_args = block_args
        self.upsample_mode = upsample_mode
        self.max_channels = max_channels
        self.num_bottleneck_blocks = num_bottleneck_blocks
        self.bottleneck_image_size = bottleneck_image_size
        self.input_channels = input_channels
        self.start_channels = start_channels
        self.image_size = image_size


class ResizeConvUNet(Module):
    def __init__(self, args: ResizeConvUNetArgs):
        super().__init__()
        self.args = args
        conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution)

        # Encoder: an initial 3x3 conv block at full resolution, followed by
        # downsample blocks that halve the spatial size and grow the channel
        # count (capped at max_channels) until the bottleneck resolution.
        self.downsample_blocks = ModuleList()
        self.downsample_blocks.append(conv_block_factory.create_conv3_block(
            self.args.input_channels,
            self.args.start_channels))
        current_channels = self.args.start_channels
        current_size = self.args.image_size

        # Record the channel count at every scale so the decoder can size its
        # conv blocks to match the concatenated skip connections.
        size_to_channel = {
            current_size: current_channels
        }
        while current_size > self.args.bottleneck_image_size:
            next_size = current_size // 2
            next_channels = min(self.args.max_channels, current_channels * 2)
            self.downsample_blocks.append(conv_block_factory.create_downsample_block(
                current_channels,
                next_channels,
                is_output_1x1=False))
            current_size = next_size
            current_channels = next_channels
            size_to_channel[current_size] = current_channels

        # Bottleneck: a stack of ResNet blocks at the coarsest resolution.
        self.bottleneck_blocks = ModuleList()
        for i in range(self.args.num_bottleneck_blocks):
            self.bottleneck_blocks.append(
                conv_block_factory.create_resnet_block(current_channels, is_1x1=False))

        # Decoder: resize-convolution upsampling. Each step doubles the
        # resolution, concatenates the encoder feature map of the same size,
        # and fuses them with a 3x3 conv block. The size and channel count of
        # every decoder output are recorded for downstream modules.
        self.output_image_sizes = [current_size]
        self.output_num_channels = [current_channels]
        self.upsample_blocks = ModuleList()
        while current_size < self.args.image_size:
            next_size = current_size * 2
            next_channels = size_to_channel[next_size]
            self.upsample_blocks.append(conv_block_factory.create_conv3_block(
                current_channels + next_channels,
                next_channels))
            current_size = next_size
            current_channels = next_channels
            self.output_image_sizes.append(current_size)
            self.output_num_channels.append(current_channels)

        # align_corners is only valid for interpolating modes such as
        # 'bilinear'; 'nearest' requires it to be None.
        align_corners = None if args.upsample_mode == 'nearest' else False
        self.double_resolution = Upsample(scale_factor=2, mode=args.upsample_mode, align_corners=align_corners)

    def forward(self, feature: Tensor) -> List[Tensor]:
        # Encoder: keep the feature map at every scale for the skip connections.
        downsampled_features = []
        for block in self.downsample_blocks:
            feature = block(feature)
            downsampled_features.append(feature)

        # Bottleneck: ResNet blocks at the coarsest resolution.
        for block in self.bottleneck_blocks:
            feature = block(feature)

        # Decoder: double the resolution, concatenate the encoder feature map
        # at the new resolution, and fuse with a 3x3 conv block. Collect the
        # output of every scale, coarsest first.
        outputs = [feature]
        for i in range(len(self.upsample_blocks)):
            feature = self.double_resolution(feature)
            feature = torch.cat([feature, downsampled_features[-i - 2]], dim=1)
            feature = self.upsample_blocks[i](feature)
            outputs.append(feature)

        return outputs


if __name__ == "__main__":
    device = torch.device('cuda')

    args = ResizeConvUNetArgs(
        image_size=512,
        input_channels=10,
        start_channels=32,
        bottleneck_image_size=32,
        num_bottleneck_blocks=6,
        max_channels=512,
        upsample_mode='nearest',
        use_separable_convolution=False,
        block_args=BlockArgs(
            initialization_method='he',
            use_spectral_norm=False,
            normalization_layer_factory=InstanceNorm2dFactory(),
            nonlinearity_factory=ReLUFactory(inplace=False)))
    module = ResizeConvUNet(args).to(device)

    # Run a batch of zero images through the network and print the shape of
    # every decoder output.
    image_count = 8
    image = torch.zeros(image_count, args.input_channels, args.image_size, args.image_size, device=device)
    outputs = module(image)
    for output in outputs:
        print(output.shape)
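    # Assuming each downsample block halves the resolution (as the bookkeeping
    # in __init__ expects), the printed shapes should be, coarsest to finest:
    # (8, 512, 32, 32), (8, 256, 64, 64), (8, 128, 128, 128),
    # (8, 64, 256, 256), (8, 32, 512, 512).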


    # Crude benchmark: two warm-up iterations, then the average GPU time of
    # `repeat` forward passes measured with CUDA events.
    if True:
        repeat = 100
        acc = 0.0
        for i in range(repeat + 2):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            module(image)
            end.record()
            torch.cuda.synchronize()
            if i >= 2:
                elapsed_time = start.elapsed_time(end)
                print("%d:" % i, elapsed_time)
                acc = acc + elapsed_time

        print("average:", acc / repeat)