File size: 4,006 Bytes
6a62ffb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import math
from typing import Optional, List

import torch
from torch import Tensor
from torch.nn import ModuleList, Module

from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args
from tha3.nn.resnet_block_seperable import ResnetBlockSeparable
from tha3.nn.separable_conv import create_separable_conv3_block, create_separable_downsample_block, \
    create_separable_upsample_block


class PoserEncoderDecoder00Separable(Module):
    """Encoder-decoder assembled from the ``tha3`` separable block factories, with optional pose input.

    The module holds three ``ModuleList`` stages, all configured by ``args``:

    * ``downsample_blocks``: one ``create_separable_conv3_block`` mapping
      ``args.input_image_channels`` to ``args.start_channels``, then one
      ``create_separable_downsample_block`` per halving of the image size until
      ``args.bottleneck_image_size`` is reached.
    * ``bottleneck_blocks``: one ``create_separable_conv3_block`` whose input has
      ``args.num_pose_params`` extra channels (the tiled pose vector), followed by
      ``args.num_bottleneck_blocks - 1`` ``ResnetBlockSeparable`` blocks.
    * ``upsample_blocks``: one ``create_separable_upsample_block`` per doubling of the
      image size until ``args.image_size`` is reached again.

    The channel count used at a given image size is
    ``min(args.start_channels * (args.image_size // size), args.max_channels)``.
    """

    def __init__(self, args: PoserEncoderDecoder00Args):
        """Build the downsample, bottleneck and upsample stages described by ``args``.

        Raises:
            AssertionError: if the number of downsample blocks built does not equal
                ``num_levels`` (the value derived from ``log2`` of the size ratio).
        """
        super().__init__()
        self.args = args

        # One level for the full-resolution block plus one per halving down to the bottleneck.
        self.num_levels = int(math.log2(args.image_size // args.bottleneck_image_size)) + 1

        # Encoder: full-resolution block first, then halve until the bottleneck size.
        self.downsample_blocks = ModuleList()
        self.downsample_blocks.append(
            create_separable_conv3_block(
                args.input_image_channels,
                args.start_channels,
                args.block_args))
        current_image_size = args.image_size
        current_num_channels = args.start_channels
        while current_image_size > args.bottleneck_image_size:
            next_image_size = current_image_size // 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.downsample_blocks.append(create_separable_downsample_block(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                is_output_1x1=False,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels
        assert len(self.downsample_blocks) == self.num_levels, (
            "number of downsample blocks does not match num_levels; "
            "image_size should be bottleneck_image_size times a power of two")

        # Bottleneck: the first block absorbs the pose channels concatenated in forward().
        self.bottleneck_blocks = ModuleList()
        self.bottleneck_blocks.append(create_separable_conv3_block(
            in_channels=current_num_channels + args.num_pose_params,
            out_channels=current_num_channels,
            block_args=args.block_args))
        # The block above counts as the first of the num_bottleneck_blocks.
        for _ in range(1, args.num_bottleneck_blocks):
            self.bottleneck_blocks.append(
                ResnetBlockSeparable.create(
                    num_channels=current_num_channels,
                    is1x1=False,
                    block_args=args.block_args))

        # Decoder: double the size until the configured image size is reached again.
        self.upsample_blocks = ModuleList()
        while current_image_size < args.image_size:
            next_image_size = current_image_size * 2
            next_num_channels = self.get_num_output_channels_from_image_size(next_image_size)
            self.upsample_blocks.append(create_separable_upsample_block(
                in_channels=current_num_channels,
                out_channels=next_num_channels,
                block_args=args.block_args))
            current_image_size = next_image_size
            current_num_channels = next_num_channels

    def get_num_output_channels_from_level(self, level: int):
        """Return the channel count at ``level``, where each level halves ``args.image_size``."""
        return self.get_num_output_channels_from_image_size(self.args.image_size // (2 ** level))

    def get_num_output_channels_from_image_size(self, image_size: int):
        """Return the channel count for ``image_size``.

        The count is ``args.start_channels`` scaled by how many times smaller ``image_size``
        is than ``args.image_size`` (integer division), capped at ``args.max_channels``.
        """
        return min(self.args.start_channels * (self.args.image_size // image_size), self.args.max_channels)

    def forward(self, image: Tensor, pose: Optional[Tensor] = None) -> List[Tensor]:
        """Run all stages and return every intermediate feature.

        Args:
            image: input tensor fed to the first downsample block.
            pose: 2-D tensor (it is unpacked as ``n, c = pose.shape``) that is tiled over
                the spatial extent of the bottleneck feature and concatenated to it along
                dim 1. Must be given exactly when ``args.num_pose_params != 0``.

        Returns:
            A list holding the input image followed by the output of every downsample,
            bottleneck and upsample block, in reversed order: index 0 is the output of the
            last block that ran and the last element is ``image`` itself.

        Raises:
            AssertionError: if ``pose`` is missing although ``args.num_pose_params != 0``,
                or supplied although ``args.num_pose_params == 0``.
        """
        if self.args.num_pose_params != 0:
            assert pose is not None, "pose is required when args.num_pose_params != 0"
        else:
            assert pose is None, "pose must be None when args.num_pose_params == 0"

        outputs = [image]
        feature = image
        for block in self.downsample_blocks:
            feature = block(feature)
            outputs.append(feature)

        if pose is not None:
            n, c = pose.shape
            # Tile the pose over the actual bottleneck feature map instead of the configured
            # bottleneck size, so inputs of another resolution or aspect ratio still
            # concatenate. expand() gives the same values as repeat() without a copy.
            height, width = feature.shape[-2:]
            pose = pose.view(n, c, 1, 1).expand(n, c, height, width)
            feature = torch.cat([feature, pose], dim=1)

        for block in self.bottleneck_blocks:
            feature = block(feature)
            outputs.append(feature)
        for block in self.upsample_blocks:
            feature = block(feature)
            outputs.append(feature)

        # Callers receive the features deepest-stage first.
        outputs.reverse()
        return outputs