|
from typing import List |
|
import numpy as np |
|
import tensorflow as tf |
|
from tensorflow import keras |
|
from tensorflow.keras import Model, Sequential |
|
from tensorflow.keras import layers |
|
|
|
|
|
class NLayerDiscriminator(Model):
    """PatchGAN discriminator in the style of Pix2Pix.

    Adapted from
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py

    Layer stack, in order (all convolutions use a 4x4 kernel, stride 1 and
    "same" padding, so only the pooling layers change the spatial size):

    1. ``Conv2D(filters)`` -> ``LeakyReLU(0.2)`` (no normalization here).
    2. One downsampling block per exponent ``e`` in ``1, 2, ..., n_layers``
       (a single block with ``e == 0`` when ``n_layers == 0``):
       ``AveragePooling2D(2)`` -> bias-free ``Conv2D(filters * min(2**e, 8))``
       -> ``BatchNormalization`` -> ``LeakyReLU(0.2)``.
    3. ``Conv2D(1)`` with no activation, i.e. one raw score (logit) per
       output location ("patch").

    Downsampling is done by 2x average pooling followed by stride-1
    convolutions rather than by strided convolutions. Each pooling layer
    halves (floors) height and width, so the output spatial size is about
    ``input_size / 2**max(n_layers, 1)``.

    Args:
        filters: Number of filters of the first convolution; later
            convolutions use ``filters * min(2**e, 8)``.
        n_layers: Number of downsampling blocks (at least one block is
            always created).
        **kwargs: Forwarded to ``keras.Model``.

    Raises:
        ValueError: If ``filters < 1`` or ``n_layers < 0``.
    """

    def __init__(self, filters: int = 64, n_layers: int = 3, **kwargs):
        # Negative n_layers would give fractional channel multipliers (2**-1).
        if filters < 1:
            raise ValueError(f"filters must be >= 1, got {filters!r}")
        if n_layers < 0:
            raise ValueError(f"n_layers must be >= 0, got {n_layers!r}")

        super().__init__(**kwargs)

        self.filters = filters
        self.n_layers = n_layers

        kernel_size = 4

        sequence = [
            layers.Conv2D(filters, kernel_size=kernel_size, strides=1, padding="same"),
            layers.LeakyReLU(alpha=0.2),
        ]

        # Channel-multiplier exponents 1 .. n_layers. Written as a range plus a
        # trailing element so that n_layers == 0 still yields a single block
        # (exponent 0).
        exponents = [*range(1, n_layers), n_layers]
        for exponent in exponents:
            filters_mult = min(2**exponent, 8)  # cap channel growth at 8x
            sequence += self._downsample_block(filters * filters_mult, kernel_size)

        # Final 1-channel projection; no activation, so outputs are logits.
        sequence.append(
            layers.Conv2D(1, kernel_size=kernel_size, strides=1, padding="same")
        )

        # Assigned once; attribute name and layer order are kept stable because
        # Keras tracks weights (and checkpoint paths) through this list.
        self.sequence = sequence

    @staticmethod
    def _downsample_block(n_filters, kernel_size):
        """Return the layers [2x avg-pool, bias-free conv, batch-norm, leaky ReLU]."""
        return [
            layers.AveragePooling2D(pool_size=2),
            layers.Conv2D(
                n_filters,
                kernel_size=kernel_size,
                strides=1,
                padding="same",
                # No bias: the BatchNormalization right after supplies its own
                # learned offset.
                use_bias=False,
            ),
            layers.BatchNormalization(),
            layers.LeakyReLU(alpha=0.2),
        ]

    def call(self, inputs, training=True, mask=None):
        """Run ``inputs`` through the layer stack and return per-patch logits.

        Args:
            inputs: Image batch; presumably ``(batch, height, width, channels)``
                (channels-last, the Keras default) - TODO confirm with callers.
            training: Forwarded explicitly to every ``BatchNormalization``
                layer, so batch-statistics vs. moving-average mode follows this
                flag even when ``call`` is invoked directly. NOTE: the default is
                ``True``; unless the caller (or an enclosing Keras call)
                supplies a value, BatchNormalization presumably runs in training
                mode, so pass ``training=False`` for inference.
            mask: Unused; accepted for Keras API compatibility.

        Returns:
            A single-channel tensor of unactivated patch scores.
        """
        h = inputs
        for layer in self.sequence:
            if isinstance(layer, layers.BatchNormalization):
                h = layer(h, training=training)
            else:
                h = layer(h)
        return h

    def get_config(self):
        """Return the base Keras config extended with the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "n_layers": self.n_layers,
            }
        )
        return config
|
|