import math
from typing import Optional, List

import torch
from torch import Tensor
from torch.nn import ModuleList, Module

from tha3.nn.common.poser_args import PoserArgs00
from tha3.nn.conv import create_conv3_block_from_block_args, create_downsample_block_from_block_args, \
    create_upsample_block_from_block_args
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.resnet_block import ResnetBlock
from tha3.nn.util import BlockArgs


class PoserEncoderDecoder00Args(PoserArgs00):
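    """
    Hyperparameters for PoserEncoderDecoder00: the square image size, input and
    output image channel counts, the number of pose parameters, the starting and
    maximum channel counts, the bottleneck spatial size (must be greater than 1),
    and the number of bottleneck blocks. When block_args is None, instance
    normalization with in-place ReLU is used.
    """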
    def __init__(self,
                 image_size: int,
                 input_image_channels: int,
                 output_image_channels: int,
                 num_pose_params: int,
                 start_channels: int,
                 bottleneck_image_size: int,
                 num_bottleneck_blocks: int,
                 max_channels: int,
                 block_args: Optional[BlockArgs] = None):
        super().__init__(
            image_size, input_image_channels, output_image_channels, start_channels, num_pose_params, block_args)
        self.max_channels = max_channels
        self.num_bottleneck_blocks = num_bottleneck_blocks
        self.bottleneck_image_size = bottleneck_image_size
        assert bottleneck_image_size > 1

        if block_args is None:
            self.block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=True))
        else:
            self.block_args = block_args


class PoserEncoderDecoder00(Module):
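    """
    An encoder-decoder module: downsampling blocks shrink the input image to a
    bottleneck resolution, the pose vector (if any) is injected there by
    channel-wise concatenation, a stack of ResNet blocks processes the
    bottleneck, and upsampling blocks restore the original resolution.
    forward() returns every intermediate feature map.
    """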
    def __init__(self, args: PoserEncoderDecoder00Args):
        super().__init__()
        self.args = args

        # One encoder level per halving from image_size down to
        # bottleneck_image_size, plus one for the initial full-resolution block.
        self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1

        # Encoder: a full-resolution conv block followed by downsampling blocks
        # that halve the resolution until the bottleneck size is reached.
        self.downsample_blocks = ModuleList()
        self.downsample_blocks.append(
            create_conv3_block_from_block_args(
                args.input_image_channels,
                args.start_channels,
                args.block_args))
        current_image_size = args.image_size
        current_num_channels = args.start_channels
        while current_image_size > args.bottleneck_image_size:
            next_image_size = current_image_size // 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.downsample_blocks.append(create_downsample_block_from_block_args(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                is_output_1x1=False,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels
        assert len(self.downsample_blocks) == self.num_levels

        # Bottleneck: the first block takes the encoder output concatenated with
        # the pose vector; the remaining blocks are ResNet blocks.
        self.bottleneck_blocks = ModuleList()
        self.bottleneck_blocks.append(create_conv3_block_from_block_args(
            in_channels=current_num_channels + args.num_pose_params,
            out_channels=current_num_channels,
            block_args=args.block_args))
        for i in range(1, args.num_bottleneck_blocks):
            self.bottleneck_blocks.append(
                ResnetBlock.create(
                    num_channels=current_num_channels,
                    is1x1=False,
                    block_args=args.block_args))

        # Decoder: upsampling blocks that double the resolution back to image_size.
        self.upsample_blocks = ModuleList()
        while current_image_size < args.image_size:
            next_image_size = current_image_size * 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.upsample_blocks.append(create_upsample_block_from_block_args(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels

    def get_num_output_channels_from_level(self, level: int):
        return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))

    def get_num_output_channels_from_image_size(self, image_size: int):
        # Channels double each time the resolution halves, capped at max_channels.
        return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)

    def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]:
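        """
        Encodes the image down to the bottleneck, concatenates the pose vector
        (required when num_pose_params > 0, and must be None otherwise), runs
        the bottleneck blocks, and decodes back to full resolution. Returns all
        intermediate feature maps, ordered from the final full-resolution output
        down to the original input image.
        """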
        if self.args.num_pose_params != 0:
            assert pose is not None
        else:
            assert pose is None
        outputs = []
        feature = image
        outputs.append(feature)
        for block in self.downsample_blocks:
            feature = block(feature)
            outputs.append(feature)
        if pose is not None:
            n, c = pose.shape
            pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.bottleneck_image_size, self.args.bottleneck_image_size)
            feature = torch.cat([feature, pose], dim=1)
        for block in self.bottleneck_blocks:
            feature = block(feature)
            outputs.append(feature)
        for block in self.upsample_blocks:
            feature = block(feature)
            outputs.append(feature)
        # Reverse so the final full-resolution feature comes first and the
        # original input image comes last.
        outputs.reverse()
        return outputs
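

# A minimal usage sketch. The hyperparameter values below (256x256 images,
# 4 input/output channels, 6 pose parameters, a 32x32 bottleneck) are
# illustrative assumptions, not values taken from the project.
if __name__ == "__main__":
    args = PoserEncoderDecoder00Args(
        image_size=256,
        input_image_channels=4,
        output_image_channels=4,
        num_pose_params=6,
        start_channels=32,
        bottleneck_image_size=32,
        num_bottleneck_blocks=6,
        max_channels=512)
    module = PoserEncoderDecoder00(args)
    image = torch.zeros(1, 4, 256, 256)
    pose = torch.zeros(1, 6)
    outputs = module(image, pose)
    # outputs[0] is the final full-resolution feature map;
    # outputs[-1] is the input image itself.
    print([tuple(o.shape) for o in outputs])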