import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras import layers
from tensorflow.keras.initializers import RandomNormal


class NLayerDiscriminator(Model):
    """PatchGAN discriminator as used in Pix2Pix.

    Reference implementation:
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py

    Args:
        filters: number of filters in the first convolution layer.
        n_layers: number of stride-2 downsampling convolution blocks.
    """

    def __init__(self, filters: int = 64, n_layers: int = 3, **kwargs):
        super().__init__(**kwargs)

        init = RandomNormal(stddev=0.02)
        self.filters = filters
        self.n_layers = n_layers

        kernel_size = 4

        # The discriminator sees the input image and the (real or generated)
        # target image concatenated along the channel axis.
        inp = layers.Input(shape=[256, 512, 3], name="input_image")
        tar = layers.Input(shape=[256, 512, 3], name="target_image")

        x = layers.concatenate([inp, tar])

        # First downsampling block; no batch normalization on the first layer.
        x = layers.Conv2D(
            filters,
            kernel_size=kernel_size,
            strides=2,
            padding="same",
            kernel_initializer=init,
        )(x)
        x = layers.LeakyReLU(alpha=0.2)(x)

        # Gradually increase the filter count (capped at 8x) while
        # downsampling with stride-2 convolutions.
        filters_mult = 1
        for n in range(1, n_layers):
            filters_mult = min(2**n, 8)

            x = layers.Conv2D(
                filters * filters_mult,
                kernel_size=kernel_size,
                strides=2,
                padding="same",
                use_bias=False,
                kernel_initializer=init,
            )(x)
            x = layers.BatchNormalization()(x)
            x = layers.LeakyReLU(alpha=0.2)(x)

        # One extra block at stride 1 before the final prediction layer.
        filters_mult = min(2**n_layers, 8)
        x = layers.Conv2D(
            filters * filters_mult,
            kernel_size=kernel_size,
            strides=1,
            padding="same",
            use_bias=False,
            kernel_initializer=init,
        )(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(alpha=0.2)(x)

        # Final 1-channel patch map. No sigmoid here: the output is raw logits,
        # intended for a loss such as BinaryCrossentropy(from_logits=True).
        x = layers.Conv2D(
            1,
            kernel_size=kernel_size,
            strides=1,
            padding="same",
            kernel_initializer=init,
        )(x)
        self.model = keras.Model(inputs=[inp, tar], outputs=x)

    def call(self, inputs, training=True, mask=None):
        # Forward the training flag so BatchNormalization layers behave
        # correctly in training vs. inference mode.
        return self.model(inputs, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "n_layers": self.n_layers,
            }
        )
        return config
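

# Minimal usage sketch (an assumption, not part of the original module):
# build the discriminator and run a dummy forward pass to inspect the
# PatchGAN output shape. Shapes follow the hard-coded 256x512x3 inputs above.
if __name__ == "__main__":
    disc = NLayerDiscriminator(filters=64, n_layers=3)

    # Random stand-ins for an input image and a (real or generated) target.
    dummy_inp = tf.random.normal([1, 256, 512, 3])
    dummy_tar = tf.random.normal([1, 256, 512, 3])

    patch_logits = disc([dummy_inp, dummy_tar], training=False)
    # With n_layers=3 and three stride-2 convolutions, each 256x512 pair maps
    # to a grid of per-patch logits of shape (1, 32, 64, 1).
    print(patch_logits.shape)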