from typing import List

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
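# NOTE: tensorflow_addons is in maintenance mode; on Keras >= 2.11 the
# built-in keras.layers.GroupNormalization is a drop-in replacement here.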
from tensorflow_addons.layers import GroupNormalization

from .layers import AttentionBlock, ResnetBlock, Upsample


# @tf.keras.utils.register_keras_serializable()
class Decoder(layers.Layer):
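    """Decoder half of a VQGAN-style autoencoder.

    Maps a latent with ``z_channels`` channels back to an image through a
    stem convolution, a middle resnet/attention/resnet stack, a chain of
    upsampling resnet blocks, and a final group norm + tanh convolution.
    """
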
    def __init__(
        self,
        *,
        channels: int,
        output_channels: int = 3,
        channels_multiplier: List[int],
        num_res_blocks: int,
        attention_resolution: List[int],
        resolution: int,
        z_channels: int,
        dropout: float,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.channels = channels
        self.output_channels = output_channels
        self.channels_multiplier = channels_multiplier
        self.num_resolutions = len(channels_multiplier)
        self.num_res_blocks = num_res_blocks
        self.attention_resolution = attention_resolution
        self.resolution = resolution
        self.z_channels = z_channels
        self.dropout = dropout

        # Decoding starts at the coarsest level: the widest channel count and
        # the smallest spatial resolution, resolution // 2**(num_levels - 1).
        block_in = channels * channels_multiplier[-1]
        current_resolution = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, current_resolution, current_resolution)
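        # z_shape is only used for the log line below; note it is written in
        # (N, C, H, W) order, while the Conv2D layers here expect channels-last.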

        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        self.conv_in = layers.Conv2D(block_in, kernel_size=3, strides=1, padding="same")

        # middle
        self.mid = {}
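        # (layers stored in this plain dict are still auto-tracked by Keras)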
        self.mid["block_1"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid["attn_1"] = AttentionBlock(block_in)
        self.mid["block_2"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )

        # upsampling: walk the channel multipliers in reverse, appending
        # num_res_blocks + 1 resnet blocks (plus optional attention) per
        # level and doubling the spatial resolution after every level but
        # the last
        self.upsampling_list = []

        for i_level in reversed(range(self.num_resolutions)):
            block_out = channels * channels_multiplier[i_level]
            for i_block in range(self.num_res_blocks + 1):
                self.upsampling_list.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out

                if current_resolution in attention_resolution:
                    self.upsampling_list.append(AttentionBlock(block_in))

            if i_level != 0:
                self.upsampling_list.append(Upsample(block_in))
                current_resolution *= 2

        # end: final group norm and tanh-activated output convolution
        self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv_out = layers.Conv2D(
            output_channels,
            kernel_size=3,
            strides=1,
            activation="tanh",
            padding="same",
        )

    def call(self, inputs, training=None, mask=None):
        # `training` defaults to None so Keras resolves it from the calling
        # context; a hard-coded True would keep the resnet blocks' dropout
        # active at inference time.
        h = self.conv_in(inputs)

        # middle
        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)

        for upsampling in self.upsampling_list:
            h = upsampling(h)

        # end
        h = self.norm_out(h)
        h = keras.activations.swish(h)
        h = self.conv_out(h)
        return h
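

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the argument values below are
# assumptions in the spirit of a VQGAN f=4 configuration, not values taken
# from this repo). The latent is channels-last, matching Conv2D's default
# data format.
#
# decoder = Decoder(
#     channels=128,
#     output_channels=3,
#     channels_multiplier=[1, 2, 4],
#     num_res_blocks=2,
#     attention_resolution=[],
#     resolution=256,
#     z_channels=3,
#     dropout=0.0,
# )
# z = tf.random.normal((1, 64, 64, 3))  # 64 = 256 // 2**(num_levels - 1)
# x = decoder(z, training=False)        # -> (1, 256, 256, 3), tanh range [-1, 1]
# ---------------------------------------------------------------------------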