import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras import layers
from tensorflow.keras.initializers import RandomNormal

class NLayerDiscriminator(Model):
    """PatchGAN discriminator as used in Pix2Pix.

    See https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, filters: int = 64, n_layers: int = 3, **kwargs):
        super().__init__(**kwargs)

        # Weights are drawn from N(0, 0.02), as in the Pix2Pix paper.
        init = RandomNormal(stddev=0.02)
        self.filters = filters
        self.n_layers = n_layers

        kernel_size = 4

        # The discriminator sees the input image and the target (or generated)
        # image concatenated along the channel axis.
        inp = tf.keras.layers.Input(shape=[256, 512, 3], name="input_image")
        tar = tf.keras.layers.Input(shape=[256, 512, 3], name="target_image")

        x = tf.keras.layers.concatenate([inp, tar])

        # First downsampling block: no batch normalisation on the first layer.
        x = layers.Conv2D(
            filters,
            kernel_size=kernel_size,
            strides=2,
            padding="same",
            kernel_initializer=init,
        )(x)
        x = layers.LeakyReLU(alpha=0.2)(x)

        # Downsample by 2 per layer while widening the feature maps,
        # with the filter multiplier capped at 8x.
        filters_mult = 1
        for n in range(1, n_layers):
            filters_mult = min(2**n, 8)

            x = layers.Conv2D(
                filters * filters_mult,
                kernel_size=kernel_size,
                strides=2,
                padding="same",
                use_bias=False,
                kernel_initializer=init,
            )(x)
            x = layers.BatchNormalization()(x)
            x = layers.LeakyReLU(alpha=0.2)(x)

        # One more convolution at stride 1 before the prediction layer.
        filters_mult = min(2**n_layers, 8)
        x = layers.Conv2D(
            filters * filters_mult,
            kernel_size=kernel_size,
            strides=1,
            padding="same",
            use_bias=False,
            kernel_initializer=init,
        )(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(alpha=0.2)(x)

        # Final 1-channel convolution: each output "pixel" is a real/fake
        # logit for one patch of the input (no activation here).
        x = layers.Conv2D(
            1,
            kernel_size=kernel_size,
            strides=1,
            padding="same",
            kernel_initializer=init,
        )(x)

        self.model = tf.keras.Model(inputs=[inp, tar], outputs=x)

    def call(self, inputs, training=True, mask=None):
        # Forward the training flag so BatchNormalization switches correctly
        # between training and inference behaviour.
        return self.model(inputs, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "n_layers": self.n_layers,
            }
        )
        return config
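

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): instantiate the
# discriminator and inspect the PatchGAN output grid for a 256x512 RGB pair.
# The zero tensors below are placeholders; only the shapes matter.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    disc = NLayerDiscriminator(filters=64, n_layers=3)

    inp = tf.zeros((1, 256, 512, 3))
    tar = tf.zeros((1, 256, 512, 3))

    # Three stride-2 convolutions reduce 256x512 to 32x64, so each of the
    # 32x64 output logits scores one patch of the input pair.
    logits = disc([inp, tar], training=False)
    print(logits.shape)  # (1, 32, 64, 1)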