# File size: 3,904 Bytes
# 6a62ffb
from typing import List, Optional

import torch
from torch import Tensor
from torch.nn import Module

from tha3.nn.common.poser_encoder_decoder_00 import PoserEncoderDecoder00Args
from tha3.nn.common.poser_encoder_decoder_00_separable import PoserEncoderDecoder00Separable
from tha3.nn.image_processing_util import apply_color_change
from tha3.module.module_factory import ModuleFactory
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.util import BlockArgs


class EyebrowDecomposer03Args(PoserEncoderDecoder00Args):
    """Construction arguments for ``EyebrowDecomposer03``.

    Every setting is forwarded positionally to ``PoserEncoderDecoder00Args``.
    ``image_channels`` is supplied twice and the fourth positional argument is
    the constant ``0``.
    """

    def __init__(
            self,
            image_size: int = 128,
            image_channels: int = 4,
            start_channels: int = 64,
            bottleneck_image_size: int = 16,
            num_bottleneck_blocks: int = 6,
            max_channels: int = 512,
            block_args: Optional[BlockArgs] = None):
        # NOTE(review): the parent signature is not visible here. The repeated
        # ``image_channels`` presumably sets both the input and the output
        # channel count, and the literal 0 presumably means "no pose
        # parameters" -- confirm against PoserEncoderDecoder00Args.
        super().__init__(
            image_size,
            image_channels, image_channels,
            0,
            start_channels,
            bottleneck_image_size, num_bottleneck_blocks,
            max_channels,
            block_args)


class EyebrowDecomposer03(Module):
    """Module producing an "eyebrow" layer and a "background" layer from an image.

    A ``PoserEncoderDecoder00Separable`` body processes the input image. From
    the first element of its output, two alpha blocks and two color-change
    blocks (all created by ``args``) are evaluated, and each alpha/color-change
    pair is combined with the input image through ``apply_color_change``.
    """

    # Positions of the tensors inside the list returned by ``forward``.
    EYEBROW_LAYER_INDEX = 0
    EYEBROW_LAYER_ALPHA_INDEX = 1
    EYEBROW_LAYER_COLOR_CHANGE_INDEX = 2
    BACKGROUND_LAYER_INDEX = 3
    BACKGROUND_LAYER_ALPHA_INDEX = 4
    BACKGROUND_LAYER_COLOR_CHANGE_INDEX = 5
    OUTPUT_LENGTH = 6

    def __init__(self, args: EyebrowDecomposer03Args):
        """Build the shared body and the four per-layer output blocks.

        Args:
            args: Configuration object; it is handed to the body and also
                supplies ``create_alpha_block`` / ``create_color_change_block``.
        """
        super().__init__()
        self.args = args
        self.body = PoserEncoderDecoder00Separable(args)
        # The assignment order below determines the order in which the
        # submodules are registered, so it is kept exactly as it was:
        # background blocks first, then eyebrow blocks.
        self.background_layer_alpha = args.create_alpha_block()
        self.background_layer_color_change = args.create_color_change_block()
        self.eyebrow_layer_alpha = args.create_alpha_block()
        self.eyebrow_layer_color_change = args.create_color_change_block()

    def forward(self, image: Tensor, *args) -> List[Tensor]:
        """Compute both layers together with their alpha and color-change maps.

        Args:
            image: Input tensor given to the body and to ``apply_color_change``.
            *args: Accepted but not used.

        Returns:
            A list of ``OUTPUT_LENGTH`` tensors ordered according to the
            ``*_INDEX`` class constants: eyebrow layer, eyebrow alpha, eyebrow
            color change, background layer, background alpha, background
            color change.
        """
        # Only the first element of the body's output is used.
        features = self.body(image)[0]

        bg_alpha = self.background_layer_alpha(features)
        bg_color_change = self.background_layer_color_change(features)
        bg_layer = apply_color_change(bg_alpha, bg_color_change, image)

        brow_alpha = self.eyebrow_layer_alpha(features)
        brow_color_change = self.eyebrow_layer_color_change(features)
        # NOTE(review): unlike the background call above, ``image`` is passed
        # as the second argument and the color change as the third. Left
        # unchanged; confirm against apply_color_change that this is intended.
        brow_layer = apply_color_change(brow_alpha, image, brow_color_change)

        outputs = [
            brow_layer,
            brow_alpha,
            brow_color_change,
            bg_layer,
            bg_alpha,
            bg_color_change,
        ]
        return outputs


class EyebrowDecomposer03Factory(ModuleFactory):
    """``ModuleFactory`` that creates ``EyebrowDecomposer03`` modules."""

    def __init__(self, args: EyebrowDecomposer03Args):
        """Store the arguments used for every module this factory creates."""
        super().__init__()
        self.args = args

    def create(self) -> Module:
        """Return a new ``EyebrowDecomposer03`` built from ``self.args``."""
        decomposer = EyebrowDecomposer03(self.args)
        return decomposer


if __name__ == "__main__":
    # Manual smoke test: build the module with a sample configuration and
    # list every entry of its state_dict together with its shape.

    # Use CUDA when present, otherwise fall back to the CPU so the script does
    # not crash on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args = EyebrowDecomposer03Args(
        image_size=128,
        image_channels=4,
        start_channels=64,
        bottleneck_image_size=16,
        num_bottleneck_blocks=6,
        block_args=BlockArgs(
            initialization_method='xavier',
            use_spectral_norm=False,
            normalization_layer_factory=InstanceNorm2dFactory(),
            nonlinearity_factory=ReLUFactory(inplace=True)))
    decomposer = EyebrowDecomposer03(args).to(device)

    # Optional forward-pass check; uncomment to print each output's shape.
    # image = torch.randn(8, 4, 128, 128, device=device)
    # outputs = decomposer.forward(image)
    # for i, output in enumerate(outputs):
    #     print(i, output.shape)

    state_dict = decomposer.state_dict()
    for index, key in enumerate(state_dict):
        print(f"[{index}]", key, state_dict[key].shape)