from typing import Optional, Tuple

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Optimizer

from ganime.configs.model_configs import (
    AutoencoderConfig,
    DiscriminatorConfig,
    LossConfig,
    VQVAEConfig,
)
from ganime.model.vqgan_clean.losses.losses import Losses

from .diffusion.decoder import Decoder
from .diffusion.encoder import Encoder
from .discriminator.model import NLayerDiscriminator
from .losses.vqperceptual import PerceptualLoss
from .vqvae.quantize import VectorQuantizer

@tf.function
def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss on raw (unbounded) discriminator logits."""
    loss_real = tf.reduce_mean(keras.activations.relu(1.0 - logits_real))
    loss_fake = tf.reduce_mean(keras.activations.relu(1.0 + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


@tf.function
def vanilla_d_loss(logits_real, logits_fake):
    """Non-saturating (softplus) discriminator loss on raw logits."""
    d_loss = 0.5 * (
        tf.reduce_mean(keras.activations.softplus(-logits_real))
        + tf.reduce_mean(keras.activations.softplus(logits_fake))
    )
    return d_loss

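# For reference, the two objectives above, with D(.) the discriminator and
# x / x_hat the real and reconstructed images:
#   hinge:   L_D = 0.5 * (E[relu(1 - D(x))] + E[relu(1 + D(x_hat))])
#   vanilla: L_D = 0.5 * (E[softplus(-D(x))] + E[softplus(D(x_hat))])
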
class VQGAN(keras.Model):
    def __init__(
        self,
        vqvae_config: VQVAEConfig,
        autoencoder_config: AutoencoderConfig,
        discriminator_config: DiscriminatorConfig,
        loss_config: LossConfig,
        checkpoint_path: Optional[str] = None,
        num_replicas: int = 1,
        **kwargs,
    ):
        """Create a VQ-GAN model.

        Args:
            vqvae_config (VQVAEConfig): The configuration of the VQ-VAE.
            autoencoder_config (AutoencoderConfig): The configuration of the autoencoder.
            discriminator_config (DiscriminatorConfig): The configuration of the discriminator.
            loss_config (LossConfig): The configuration of the losses.
            checkpoint_path (Optional[str]): If given, weights are restored from
                this checkpoint once the model is built. Defaults to None.
            num_replicas (int): Number of replicas when training under a
                distribution strategy; forwarded to the Losses helper. Defaults to 1.

        Raises:
            ValueError: The specified loss type is not supported.
        """
        super().__init__(**kwargs)
        self.perceptual_weight = loss_config.vqvae.perceptual_weight
        self.codebook_weight = loss_config.vqvae.codebook_weight
        self.vqvae_config = vqvae_config
        self.autoencoder_config = autoencoder_config
        self.discriminator_config = discriminator_config
        self.loss_config = loss_config
        self.num_replicas = num_replicas

        # Generator path: encoder -> pre-quant conv -> codebook -> post-quant conv -> decoder.
        self.encoder = Encoder(**autoencoder_config)
        self.quant_conv = layers.Conv2D(
            vqvae_config.embedding_dim, kernel_size=1, name="pre_quant_conv"
        )
        self.quantize = VectorQuantizer(
            vqvae_config.num_embeddings,
            vqvae_config.embedding_dim,
            beta=vqvae_config.beta,
        )
        self.post_quant_conv = layers.Conv2D(
            autoencoder_config.z_channels, kernel_size=1, name="post_quant_conv"
        )
        self.decoder = Decoder(**autoencoder_config)

        self.perceptual_loss = self.get_perceptual_loss(loss_config.perceptual_loss)

        # Discriminator and its schedule/weighting.
        self.discriminator = NLayerDiscriminator(
            filters=discriminator_config.filters,
            n_layers=discriminator_config.num_layers,
        )
        self.discriminator_iter_start = loss_config.discriminator.iter_start
        self.disc_loss = self._get_discriminator_loss(loss_config.discriminator.loss)
        self.disc_factor = loss_config.discriminator.factor
        self.discriminator_weight = loss_config.discriminator.weight

        # Metric trackers exposed through the `metrics` property below.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss")

        # Set in `compile`.
        self.gen_optimizer: Optimizer = None
        self.disc_optimizer: Optimizer = None

        self.checkpoint_path = checkpoint_path

        self.cross_entropy = Losses(self.num_replicas).bce_loss
        self.reconstruction_loss = self.get_reconstruction_loss("mae")

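    # Forward pipeline wired above:
    #   x -> encoder -> quant_conv -> quantize (codebook) -> post_quant_conv -> decoder
    # with NLayerDiscriminator providing the adversarial signal on reconstructions.
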
    def get_perceptual_loss(self, loss_type: str):
        if loss_type == "vgg16":
            return PerceptualLoss(reduction=tf.keras.losses.Reduction.NONE)
        elif loss_type == "vgg19":
            return Losses(self.num_replicas).vgg_loss
        elif loss_type == "style":
            return Losses(self.num_replicas).style_loss
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

    def get_reconstruction_loss(self, loss_type: str):
        if loss_type == "mse":
            return Losses(self.num_replicas).mse_loss
        elif loss_type == "mae":
            return Losses(self.num_replicas).mae_loss
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

    def load_from_checkpoint(self, path):
        self.load_weights(path)

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them automatically at the
        # start of each epoch (and of each `evaluate` call).
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.vq_loss_tracker,
            self.disc_loss_tracker,
        ]

    def _get_discriminator_loss(self, disc_loss):
        if disc_loss == "hinge":
            loss = hinge_d_loss
        elif disc_loss == "vanilla":
            loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")

        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        return loss

    def build(self, input_shape):
        super().build(input_shape)
        self.built = True
        # Weights can only be restored once the variables exist, so the
        # checkpoint is loaded here rather than in __init__.
        if self.checkpoint_path is not None:
            self.load_from_checkpoint(self.checkpoint_path)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return self.quantize(h)  # (quantized, encoding_indices, codebook loss)

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def call(self, inputs, training=True, mask=None):
        quantized, encoding_indices, loss = self.encode(inputs)
        reconstructed = self.decode(quantized)
        return reconstructed, loss

    def predict(self, inputs):
        output, loss = self(inputs)
        # Map the decoder output from [-1, 1] to [0, 1] (note 127.5 / 255 == 0.5).
        output = (output + 1.0) * 127.5 / 255
        return output

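    # Round-trip sketch (the shapes are assumptions inferred from the layer
    # definitions above, not guarantees of this repo):
    #
    #   quant, indices, codebook_loss = model.encode(images)  # quant: (B, h, w, embedding_dim)
    #   recon = model.decode(quant)                           # recon in [-1, 1]
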
    def calculate_adaptive_weight(
        self,
        nll_loss: tf.Tensor,
        g_loss: tf.Tensor,
        tape: tf.GradientTape,
        trainable_vars: list,
        discriminator_weight: float,
    ) -> tf.Tensor:
        """Calculate the adaptive discriminator weight, which helps prevent mode collapse (https://arxiv.org/abs/2012.03149).

        Args:
            nll_loss (tf.Tensor): Negative log likelihood loss (the reconstruction loss).
            g_loss (tf.Tensor): Generator loss (with respect to the discriminator).
            tape (tf.GradientTape): Gradient tape used to compute nll_loss and g_loss.
            trainable_vars (list): Trainable variables of the last layer (conv_out of the decoder).
            discriminator_weight (float): Weight of the discriminator.

        Returns:
            tf.Tensor: Weight applied to the adversarial term, balancing the
                generator and discriminator objectives to avoid mode collapse.
        """
        # [0] selects the gradients of the first variable (the conv kernel).
        nll_grads = tape.gradient(nll_loss, trainable_vars)[0]
        g_grads = tape.gradient(g_loss, trainable_vars)[0]

        d_weight = tf.norm(nll_grads) / (tf.norm(g_grads) + 1e-4)
        d_weight = tf.stop_gradient(tf.clip_by_value(d_weight, 0.0, 1e4))
        return d_weight * discriminator_weight

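    # The rule implemented above, following the VQ-GAN adaptive-weight scheme:
    #   d_weight = ||grad_L(nll_loss)|| / (||grad_L(g_loss)|| + 1e-4)
    # where L is the decoder's last layer; the result is clipped to [0, 1e4]
    # and detached from the graph via stop_gradient.
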
    @tf.function
    def adapt_weight(
        self, weight: float, global_step: int, threshold: int = 0, value: float = 0.0
    ) -> float:
        """Adapt the weight depending on the global step.

        If global_step is lower than threshold, the weight is set to `value`.
        Used to disable the discriminator terms during the first iterations.

        Args:
            weight (float): The weight to adapt.
            global_step (int): The global step of the optimizer.
            threshold (int, optional): The threshold under which the weight is
                set to `value`. Defaults to 0.
            value (float, optional): The value returned below the threshold.
                Defaults to 0.0.

        Returns:
            float: The adapted weight.
        """
        if global_step < threshold:
            weight = value
        return weight

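    # Example: with threshold=10_000 (a hypothetical discriminator_iter_start),
    # the adversarial terms contribute nothing for the first 10k steps:
    #   adapt_weight(weight=1.0, global_step=5_000, threshold=10_000)   -> 0.0
    #   adapt_weight(weight=1.0, global_step=15_000, threshold=10_000)  -> 1.0
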
    def _get_global_step(self, optimizer: Optimizer):
        """Get the global step of the optimizer."""
        return optimizer.iterations

    def compile(
        self,
        gen_optimizer,
        disc_optimizer,
    ):
        super().compile()
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer

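    # Usage sketch (hedged: the optimizer choice and learning rate are
    # illustrative, not prescribed by this repo):
    #
    #   model = VQGAN(vqvae_cfg, autoencoder_cfg, discriminator_cfg, loss_cfg)
    #   model.compile(
    #       gen_optimizer=keras.optimizers.Adam(1e-4),
    #       disc_optimizer=keras.optimizers.Adam(1e-4),
    #   )
    #   model.fit(train_ds, validation_data=val_ds, epochs=10)
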
    def get_vqvae_trainable_vars(self):
        return (
            self.encoder.trainable_variables
            + self.quant_conv.trainable_variables
            + self.quantize.trainable_variables
            + self.post_quant_conv.trainable_variables
            + self.decoder.trainable_variables
        )

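    # Note: the generator update in train_step touches only this VQ-VAE path
    # (encoder, quant convs, codebook, decoder); the discriminator's weights
    # are updated separately with disc_optimizer.
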
    def generator_loss(self, fake_output):
        return self.cross_entropy(tf.ones_like(fake_output), fake_output)

    def discriminator_loss(self, disc_real_output, disc_generated_output):
        real_loss = self.cross_entropy(
            tf.ones_like(disc_real_output), disc_real_output
        )
        generated_loss = self.cross_entropy(
            tf.zeros_like(disc_generated_output), disc_generated_output
        )
        total_disc_loss = real_loss + generated_loss
        return total_disc_loss

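    # train_step uses three tapes: gen_tape for the VQ-VAE/generator weights,
    # disc_tape for the discriminator, and a short-lived persistent tape whose
    # gradients feed calculate_adaptive_weight.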
    def train_step(self, data: Tuple[tf.Tensor, tf.Tensor]):
        x, y = data

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            with tf.GradientTape(persistent=True) as adaptive_tape:
                reconstructions, quantized_loss = self(x, training=True)

                # The discriminator is conditioned on the (resized) real frame:
                # it sees (real, real) vs (real, reconstruction) pairs.
                disc_real_input = tf.image.resize(
                    x, [256, 512], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
                )
                disc_gen_input = tf.image.resize(
                    reconstructions,
                    [256, 512],
                    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                )
                logits_real = self.discriminator(
                    (disc_real_input, disc_real_input),
                    training=True,
                )
                logits_fake = self.discriminator(
                    (disc_real_input, disc_gen_input),
                    training=True,
                )

                reconstruction_loss = self.reconstruction_loss(y, reconstructions)
                if self.perceptual_weight > 0.0:
                    perceptual_loss = self.perceptual_weight * self.perceptual_loss(
                        y, reconstructions
                    )
                else:
                    perceptual_loss = 0.0

                nll_loss = reconstruction_loss + perceptual_loss

                g_loss = -tf.reduce_mean(logits_fake)

            d_weight = self.calculate_adaptive_weight(
                nll_loss,
                g_loss,
                adaptive_tape,
                self.decoder.conv_out.trainable_variables,
                self.discriminator_weight,
            )
            del adaptive_tape  # persistent tape: release its resources explicitly

            # Zero out the adversarial terms until discriminator_iter_start.
            disc_factor = self.adapt_weight(
                weight=self.disc_factor,
                global_step=self._get_global_step(self.gen_optimizer),
                threshold=self.discriminator_iter_start,
            )

            total_loss = (
                nll_loss
                + d_weight * disc_factor * g_loss
                + self.codebook_weight * quantized_loss
            )

            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

        grads = gen_tape.gradient(total_loss, self.get_vqvae_trainable_vars())
        self.gen_optimizer.apply_gradients(zip(grads, self.get_vqvae_trainable_vars()))

        disc_grads = disc_tape.gradient(d_loss, self.discriminator.trainable_variables)
        self.disc_optimizer.apply_gradients(
            zip(disc_grads, self.discriminator.trainable_variables)
        )

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(nll_loss)
        self.vq_loss_tracker.update_state(quantized_loss)
        self.disc_loss_tracker.update_state(d_loss)

        return {m.name: m.result() for m in self.metrics}

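    # test_step mirrors train_step's loss computation but applies no gradient
    # updates; the adaptive-weight tape is still required so that the reported
    # total_loss is comparable to training.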
    def test_step(self, data: Tuple[tf.Tensor, tf.Tensor]):
        x, y = data

        with tf.GradientTape(persistent=True) as adaptive_tape:
            reconstructions, quantized_loss = self(x, training=False)

            disc_real_input = tf.image.resize(
                x, [256, 512], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
            )
            disc_gen_input = tf.image.resize(
                reconstructions,
                [256, 512],
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
            )
            logits_real = self.discriminator(
                (disc_real_input, disc_real_input),
                training=False,
            )
            logits_fake = self.discriminator(
                (disc_real_input, disc_gen_input),
                training=False,
            )

            reconstruction_loss = self.reconstruction_loss(y, reconstructions)
            if self.perceptual_weight > 0.0:
                perceptual_loss = self.perceptual_weight * self.perceptual_loss(
                    y, reconstructions
                )
            else:
                perceptual_loss = 0.0

            nll_loss = reconstruction_loss + perceptual_loss
            g_loss = -tf.reduce_mean(logits_fake)

        d_weight = self.calculate_adaptive_weight(
            nll_loss,
            g_loss,
            adaptive_tape,
            self.decoder.conv_out.trainable_variables,
            self.discriminator_weight,
        )
        del adaptive_tape

        disc_factor = self.adapt_weight(
            weight=self.disc_factor,
            global_step=self._get_global_step(self.gen_optimizer),
            threshold=self.discriminator_iter_start,
        )

        total_loss = (
            nll_loss
            + d_weight * disc_factor * g_loss
            + self.codebook_weight * quantized_loss
        )

        d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(nll_loss)
        self.vq_loss_tracker.update_state(quantized_loss)
        self.disc_loss_tracker.update_state(d_loss)

        return {m.name: m.result() for m in self.metrics}