from typing import List, Literal, Optional, Sequence

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, layers, Sequential
from tensorflow.keras.optimizers import Optimizer
from tensorflow_addons.layers import GroupNormalization

from .discriminator.model import NLayerDiscriminator
from .losses.vqperceptual import VQLPIPSWithDiscriminator
# Shape (without the batch axis) of the images fed to the VQ-VAE; used to
# build the functional model and for Encoder.summary().
INPUT_SHAPE = (
    64,
    128,
    3,
)

# Shape (without the batch axis) used as the dummy input of Decoder.summary().
# NOTE(review): with the default Encoder (three stride-2 stages) a 64x128
# input would give an 8x16 feature map, not 8x8 -- confirm this is intended.
ENCODER_OUTPUT_SHAPE = (
    8,
    8,
    128,
)
@tf.function
def hinge_d_loss(logits_real, logits_fake):
    """Hinge loss for the discriminator.

    Args:
        logits_real: Discriminator logits for real samples.
        logits_fake: Discriminator logits for generated samples.

    Returns:
        Half the sum of ``mean(relu(1 - logits_real))`` and
        ``mean(relu(1 + logits_fake))``.
    """
    real_term = tf.reduce_mean(keras.activations.relu(1.0 - logits_real))
    fake_term = tf.reduce_mean(keras.activations.relu(1.0 + logits_fake))
    return 0.5 * (real_term + fake_term)
@tf.function
def vanilla_d_loss(logits_real, logits_fake):
    """Non-saturating (softplus) loss for the discriminator.

    Args:
        logits_real: Discriminator logits for real samples.
        logits_fake: Discriminator logits for generated samples.

    Returns:
        Half the sum of ``mean(softplus(-logits_real))`` and
        ``mean(softplus(logits_fake))``.
    """
    real_term = tf.reduce_mean(keras.activations.softplus(-logits_real))
    fake_term = tf.reduce_mean(keras.activations.softplus(logits_fake))
    return 0.5 * (real_term + fake_term)
class VQGAN(keras.Model):
    """Vector-quantised autoencoder trained jointly with a discriminator.

    The generator is ``Encoder -> quant_conv -> VectorQuantizer ->
    post_quant_conv -> Decoder`` wrapped in a functional model
    (``self.vqvae``). ``train_step`` alternates one generator update
    (perceptual/reconstruction loss + adaptively weighted adversarial loss +
    codebook loss) with one discriminator update.

    Args:
        train_variance: Stored on the instance; not used inside this class.
        num_embeddings: Number of codebook entries.
        embedding_dim: Dimensionality of each codebook entry.
        beta: Commitment-loss weight passed to ``VectorQuantizer``.
        z_channels: Channel count of the latent produced by the encoder and
            restored by ``post_quant_conv`` before decoding.
        codebook_weight: Multiplier of the quantiser loss in the total loss.
        disc_num_layers: ``n_layers`` of the discriminator.
        disc_factor: Base weight of the adversarial/discriminator losses.
        disc_iter_start: Optimizer step before which ``disc_factor`` is
            replaced by 0.
        disc_conditional: Stored on the instance; not used inside this class.
        disc_in_channels: ``input_channels`` of the discriminator.
        disc_weight: Scale applied to the adaptive adversarial weight.
        disc_filters: ``filters`` of the discriminator.
        disc_loss: Either ``"hinge"`` or ``"vanilla"``.
        **kwargs: Forwarded to ``keras.Model``.

    Raises:
        ValueError: If ``disc_loss`` is not one of the supported names.
    """

    def __init__(
        self,
        train_variance: float,
        num_embeddings: int,
        embedding_dim: int,
        beta: float = 0.25,
        z_channels: int = 128,
        codebook_weight: float = 1.0,
        disc_num_layers: int = 3,
        disc_factor: float = 1.0,
        disc_iter_start: int = 0,
        disc_conditional: bool = False,
        disc_in_channels: int = 3,
        disc_weight: float = 0.3,
        disc_filters: int = 64,
        disc_loss: Literal["hinge", "vanilla"] = "hinge",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.train_variance = train_variance
        self.codebook_weight = codebook_weight

        # Forward z_channels so the encoder output width agrees with what
        # post_quant_conv produces for the decoder (identical to the previous
        # behaviour for the default value of 128).
        self.encoder = Encoder(z_channels=z_channels)
        self.decoder = Decoder(z_channels=z_channels)
        self.quantize = VectorQuantizer(num_embeddings, embedding_dim, beta=beta)
        self.quant_conv = layers.Conv2D(embedding_dim, kernel_size=1)
        self.post_quant_conv = layers.Conv2D(z_channels, kernel_size=1)
        self.vqvae = self.get_vqvae()

        self.perceptual_loss = VQLPIPSWithDiscriminator(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.discriminator = NLayerDiscriminator(
            input_channels=disc_in_channels,
            filters=disc_filters,
            n_layers=disc_num_layers,
        )
        self.discriminator_iter_start = disc_iter_start

        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")

        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

        # Running means reported from train_step.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss")

        # Set by compile().
        self.gen_optimizer: Optional[Optimizer] = None
        self.disc_optimizer: Optional[Optimizer] = None

    def get_vqvae(self):
        """Build the functional encode -> quantise -> decode model."""
        inputs = keras.Input(shape=INPUT_SHAPE)
        quant = self.encode(inputs)
        reconstructed = self.decode(quant)
        return keras.Model(inputs, reconstructed, name="vq_vae")

    def encode(self, x):
        """Encode ``x`` and return the quantised latent."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return self.quantize(h)

    def decode(self, quant):
        """Decode a quantised latent back to image space."""
        quant = self.post_quant_conv(quant)
        return self.decoder(quant)

    def call(self, inputs, training=True, mask=None):
        """Reconstruct ``inputs`` through the VQ-VAE.

        ``training`` is forwarded to the inner model so that
        training-dependent layers (e.g. dropout) see the right mode.
        ``mask`` is accepted for Keras compatibility and ignored.
        """
        return self.vqvae(inputs, training=training)

    def calculate_adaptive_weight(
        self, nll_loss, g_loss, tape, trainable_vars, discriminator_weight
    ):
        """Balance the adversarial loss against the reconstruction loss.

        Computes the ratio of the gradient norms of ``nll_loss`` and
        ``g_loss`` with respect to the first variable in ``trainable_vars``,
        clips it to ``[0, 1e4]``, blocks gradients through it and scales it
        by ``discriminator_weight``.

        Args:
            nll_loss: Reconstruction/perceptual loss tensor.
            g_loss: Generator adversarial loss tensor.
            tape: Gradient tape that recorded both losses; ``gradient`` is
                called on it twice, so it must be persistent.
            trainable_vars: Variables to differentiate with respect to; only
                the gradient of element 0 is used.
            discriminator_weight: Final multiplier of the adaptive weight.

        Returns:
            The scaled adaptive weight.
        """
        nll_grads = tape.gradient(nll_loss, trainable_vars)[0]
        g_grads = tape.gradient(g_loss, trainable_vars)[0]

        d_weight = tf.norm(nll_grads) / (tf.norm(g_grads) + 1e-4)
        d_weight = tf.stop_gradient(tf.clip_by_value(d_weight, 0.0, 1e4))
        return d_weight * discriminator_weight

    @tf.function
    def adopt_weight(self, weight, global_step, threshold=0, value=0.0):
        """Return ``value`` while ``global_step < threshold``, else ``weight``."""
        if global_step < threshold:
            weight = value
        return weight

    def get_global_step(self, optimizer):
        """Return the iteration counter of ``optimizer``."""
        return optimizer.iterations

    def compile(
        self,
        gen_optimizer,
        disc_optimizer,
        **kwargs,
    ):
        """Configure the model with separate generator/discriminator optimizers.

        Args:
            gen_optimizer: Optimizer applied to the VQ-VAE variables.
            disc_optimizer: Optimizer applied to the discriminator variables.
            **kwargs: Extra options forwarded to ``keras.Model.compile``.
        """
        super().compile(**kwargs)
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer

    def train_step(self, data):
        """Run one generator update followed by one discriminator update.

        Args:
            data: ``(x, y)`` pair; ``x`` is fed through the model and ``y``
                is the reconstruction target and the discriminator's "real"
                batch.

        Returns:
            Dict with the running means of the total, reconstruction,
            quantiser and discriminator losses.
        """
        x, y = data

        # ---- Generator (autoencoder) update ----
        with tf.GradientTape() as tape:
            # Persistent inner tape: calculate_adaptive_weight differentiates
            # two different losses on it.
            with tf.GradientTape(persistent=True) as adaptive_tape:
                reconstructions = self(x, training=True)
                logits_fake = self.discriminator(reconstructions, training=False)
                g_loss = -tf.reduce_mean(logits_fake)
                nll_loss = self.perceptual_loss(y, reconstructions)

            d_weight = self.calculate_adaptive_weight(
                nll_loss,
                g_loss,
                adaptive_tape,
                self.decoder.conv_out.trainable_variables,
                self.discriminator_weight,
            )
            del adaptive_tape

            disc_factor = self.adopt_weight(
                weight=self.disc_factor,
                global_step=self.get_global_step(self.gen_optimizer),
                threshold=self.discriminator_iter_start,
            )

            # Summed once (inside the tape so gradients flow) and reused for
            # the metric below.
            vq_loss = sum(self.vqvae.losses)
            total_loss = (
                nll_loss
                + d_weight * disc_factor * g_loss
                + self.codebook_weight * vq_loss
            )

        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
        self.gen_optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))

        # ---- Discriminator update ----
        with tf.GradientTape() as disc_tape:
            logits_real = self.discriminator(y, training=True)
            logits_fake = self.discriminator(reconstructions, training=True)

            disc_factor = self.adopt_weight(
                weight=self.disc_factor,
                global_step=self.get_global_step(self.disc_optimizer),
                threshold=self.discriminator_iter_start,
            )
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

        disc_grads = disc_tape.gradient(d_loss, self.discriminator.trainable_variables)
        self.disc_optimizer.apply_gradients(
            zip(disc_grads, self.discriminator.trainable_variables)
        )

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(nll_loss)
        self.vq_loss_tracker.update_state(vq_loss)
        self.disc_loss_tracker.update_state(d_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vqvae_loss": self.vq_loss_tracker.result(),
            "disc_loss": self.disc_loss_tracker.result(),
        }
class VectorQuantizer(layers.Layer):
    """Nearest-neighbour codebook quantisation with a straight-through gradient.

    Args:
        num_embeddings: Number of codebook vectors.
        embedding_dim: Size of each codebook vector; the input is flattened
            to rows of this length.
        beta: Weight of the commitment term in the added loss.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta

        # Codebook stored column-wise: one column per code.
        codebook_shape = (self.embedding_dim, self.num_embeddings)
        initializer = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=initializer(shape=codebook_shape),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        """Quantise ``x`` and register the commitment + codebook loss."""
        original_shape = tf.shape(x)
        flat_inputs = tf.reshape(x, [-1, self.embedding_dim])

        # Look up the nearest code for every row via a one-hot matmul.
        indices = self.get_code_indices(flat_inputs)
        one_hot = tf.one_hot(indices, self.num_embeddings)
        codes = tf.matmul(one_hot, self.embeddings, transpose_b=True)
        codes = tf.reshape(codes, original_shape)

        # Commitment term trains the input side, codebook term trains the
        # codes; each stops the gradient on the other operand.
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(codes) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((codes - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator: forward value is the code, gradient
        # flows to x unchanged.
        return x + tf.stop_gradient(codes - x)

    def get_code_indices(self, flattened_inputs):
        """Return, for each input row, the index of the closest codebook column."""
        dot_products = tf.matmul(flattened_inputs, self.embeddings)
        input_sq_norms = tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True)
        code_sq_norms = tf.reduce_sum(self.embeddings**2, axis=0)

        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
        distances = input_sq_norms + code_sq_norms - 2 * dot_products
        return tf.argmin(distances, axis=1)
class Encoder(Model):
    """Convolutional encoder made of residual, attention and downsample stages.

    Args:
        channels: Base channel count; each level uses
            ``channels * channels_multiplier[level]``.
        output_channels: Accepted for signature parity with ``Decoder``; not
            used by the encoder.
        channels_multiplier: Per-level channel multipliers. A ``Downsample``
            follows every level except the last.
        num_res_blocks: Number of ``ResnetBlock`` layers per level.
        attention_resolution: Resolutions at which an ``AttentionBlock`` is
            appended after each residual block.
        resolution: Nominal input resolution; only used to track the current
            resolution for the ``attention_resolution`` check.
        z_channels: Channel count of the output convolution.
        dropout: Dropout rate passed to every ``ResnetBlock``.
        double_z: If true, the output convolution has ``2 * z_channels``
            filters.
        resamp_with_conv: Passed to ``Downsample`` as ``with_conv``.
    """

    def __init__(
        self,
        *,
        channels: int = 128,
        output_channels: int = 3,
        channels_multiplier: Sequence[int] = (1, 1, 2, 2),
        num_res_blocks: int = 1,
        attention_resolution: Sequence[int] = (16,),
        resolution: int = 64,
        z_channels=128,
        dropout=0.0,
        double_z=False,
        resamp_with_conv=True,
    ):
        super().__init__()
        self.channels = channels
        self.timestep_embeddings_channel = 0
        self.num_resolutions = len(channels_multiplier)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution

        self.conv_in = layers.Conv2D(
            self.channels, kernel_size=3, strides=1, padding="same"
        )

        current_resolution = resolution
        # Level i takes the channel count produced by level i-1 (or the base
        # count for the first level).
        in_channels_multiplier = (1,) + tuple(channels_multiplier)

        self.downsampling_list = []
        for i_level in range(self.num_resolutions):
            block_in = channels * in_channels_multiplier[i_level]
            block_out = channels * channels_multiplier[i_level]
            for _ in range(self.num_res_blocks):
                self.downsampling_list.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        timestep_embedding_channels=self.timestep_embeddings_channel,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if current_resolution in attention_resolution:
                    self.downsampling_list.append(AttentionBlock(block_in))

            if i_level != self.num_resolutions - 1:
                self.downsampling_list.append(Downsample(block_in, resamp_with_conv))
                # Keep the tracked resolution in step with the feature map so
                # the attention_resolution check above sees the right value
                # (previously it stayed at the input resolution forever).
                current_resolution //= 2

        # Bottleneck: resnet -> attention -> resnet.
        self.mid = {}
        self.mid["block_1"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            timestep_embedding_channels=self.timestep_embeddings_channel,
            dropout=dropout,
        )
        self.mid["attn_1"] = AttentionBlock(block_in)
        self.mid["block_2"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            timestep_embedding_channels=self.timestep_embeddings_channel,
            dropout=dropout,
        )

        self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv_out = layers.Conv2D(
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            strides=1,
            padding="same",
        )

    def summary(self):
        """Print a summary of the encoder traced on an ``INPUT_SHAPE`` input."""
        x = layers.Input(shape=INPUT_SHAPE)
        model = Model(inputs=[x], outputs=self.call(x))
        return model.summary()

    def call(self, inputs, training=True, mask=None):
        """Encode ``inputs``; ``training`` and ``mask`` are not used here."""
        h = self.conv_in(inputs)

        for downsampling in self.downsampling_list:
            h = downsampling(h)

        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)

        h = self.norm_out(h)
        h = keras.activations.swish(h)
        h = self.conv_out(h)
        return h
class Decoder(Model):
    """Convolutional decoder mirroring ``Encoder`` with upsampling stages.

    Args:
        channels: Base channel count; each level uses
            ``channels * channels_multiplier[level]``.
        output_channels: Filters of the final (sigmoid) convolution.
        channels_multiplier: Per-level channel multipliers, traversed from
            last to first. An ``Upsample`` follows every level except
            level 0.
        num_res_blocks: Each level gets ``num_res_blocks + 1`` residual
            blocks.
        attention_resolution: Resolutions at which an ``AttentionBlock`` is
            appended after each residual block.
        resolution: Nominal output resolution; the starting resolution is
            ``resolution // 2 ** (len(channels_multiplier) - 1)``.
        z_channels: Latent channel count; only used for the reported
            ``z_shape``.
        dropout: Dropout rate passed to every ``ResnetBlock``.
        give_pre_end: If true, ``call`` returns the features before the
            final norm/activation/convolution.
        resamp_with_conv: Passed to ``Upsample`` as ``with_conv``.
    """

    def __init__(
        self,
        *,
        channels: int = 128,
        output_channels: int = 3,
        channels_multiplier: Sequence[int] = (1, 1, 2, 2),
        num_res_blocks: int = 1,
        attention_resolution: Sequence[int] = (16,),
        resolution: int = 64,
        z_channels=128,
        dropout=0.0,
        give_pre_end=False,
        resamp_with_conv=True,
    ):
        super().__init__()
        self.channels = channels
        self.timestep_embeddings_channel = 0
        self.num_resolutions = len(channels_multiplier)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.give_pre_end = give_pre_end

        # Start from the channel count and resolution of the lowest level.
        block_in = channels * channels_multiplier[-1]
        current_resolution = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, current_resolution, current_resolution)
        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        self.conv_in = layers.Conv2D(block_in, kernel_size=3, strides=1, padding="same")

        # Bottleneck: resnet -> attention -> resnet.
        self.mid = {}
        self.mid["block_1"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            timestep_embedding_channels=self.timestep_embeddings_channel,
            dropout=dropout,
        )
        self.mid["attn_1"] = AttentionBlock(block_in)
        self.mid["block_2"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            timestep_embedding_channels=self.timestep_embeddings_channel,
            dropout=dropout,
        )

        self.upsampling_list = []
        for i_level in reversed(range(self.num_resolutions)):
            block_out = channels * channels_multiplier[i_level]
            for _ in range(self.num_res_blocks + 1):
                self.upsampling_list.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        timestep_embedding_channels=self.timestep_embeddings_channel,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if current_resolution in attention_resolution:
                    self.upsampling_list.append(AttentionBlock(block_in))

            if i_level != 0:
                self.upsampling_list.append(Upsample(block_in, resamp_with_conv))
                current_resolution *= 2

        self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv_out = layers.Conv2D(
            output_channels,
            kernel_size=3,
            strides=1,
            activation="sigmoid",
            padding="same",
        )

    def summary(self):
        """Print a summary of the decoder traced on an ``ENCODER_OUTPUT_SHAPE`` input."""
        x = layers.Input(shape=ENCODER_OUTPUT_SHAPE)
        model = Model(inputs=[x], outputs=self.call(x))
        return model.summary()

    def call(self, inputs, training=True, mask=None):
        """Decode ``inputs``; ``training`` and ``mask`` are not used here."""
        h = self.conv_in(inputs)

        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)

        for upsampling in self.upsampling_list:
            h = upsampling(h)

        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = keras.activations.swish(h)
        h = self.conv_out(h)
        return h
class ResnetBlock(layers.Layer):
    """Pre-activation residual block: two (GroupNorm -> swish -> conv) stages.

    Args:
        in_channels: Channel count of the input.
        dropout: Dropout rate applied before the second convolution.
        out_channels: Channel count of the output; defaults to
            ``in_channels``.
        conv_shortcut: When the channel counts differ, project the skip path
            with a 3x3 convolution instead of a 1x1 convolution.
        timestep_embedding_channels: If positive, a Dense projection layer is
            created; ``call`` does not use it.
    """

    def __init__(
        self,
        *,
        in_channels,
        dropout=0.0,
        out_channels=None,
        conv_shortcut=False,
        timestep_embedding_channels=512,
    ):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        def conv3x3():
            return layers.Conv2D(
                out_channels, kernel_size=3, strides=1, padding="same"
            )

        # First stage.
        self.norm1 = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv1 = conv3x3()

        if timestep_embedding_channels > 0:
            self.timestep_embedding_projection = layers.Dense(out_channels)

        # Second stage.
        self.norm2 = GroupNormalization(groups=32, epsilon=1e-6)
        self.dropout = layers.Dropout(dropout)
        self.conv2 = conv3x3()

        # Skip-path projection, only needed when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = conv3x3()
            else:
                self.nin_shortcut = layers.Conv2D(
                    out_channels, kernel_size=1, strides=1, padding="valid"
                )

    def call(self, x):
        """Return ``shortcut(x) + residual(x)``."""
        residual = self.conv1(keras.activations.swish(self.norm1(x)))
        residual = keras.activations.swish(self.norm2(residual))
        residual = self.conv2(self.dropout(residual))

        if self.in_channels == self.out_channels:
            return x + residual

        project = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
        return project(x) + residual
class AttentionBlock(layers.Layer):
    """Single-head self-attention over the spatial positions of a feature map.

    The input is group-normalised, projected to queries/keys/values with 1x1
    convolutions, attended over the flattened ``h * w`` positions and added
    back to the input through a 1x1 output projection.

    Args:
        channels: Filters of the q/k/v/output projections.
    """

    def __init__(self, channels):
        super().__init__()
        self.norm = GroupNormalization(groups=32, epsilon=1e-6)
        self.q = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
        self.k = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
        self.v = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
        self.proj_out = layers.Conv2D(
            channels, kernel_size=1, strides=1, padding="valid"
        )

    def call(self, x):
        """Apply self-attention to ``x`` (read as ``[batch, h, w, channels]``).

        The spatial dimensions and channel count must be statically known;
        an unknown batch dimension is handled with ``-1`` in the reshapes.
        """
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        b, h, w, c = q.shape
        if b is None:
            b = -1  # let tf.reshape infer the batch dimension

        # Flatten the spatial grid: [b, h*w, c].
        q = tf.reshape(q, [b, h * w, c])
        k = tf.reshape(k, [b, h * w, c])
        v = tf.reshape(v, [b, h * w, c])

        # attn[b, i, j] = softmax_j(q_i . k_j / sqrt(c))
        attn = tf.matmul(q, k, transpose_b=True)
        attn = attn * (int(c) ** (-0.5))
        attn = keras.activations.softmax(attn)

        # Weighted sum of values per query position: [b, h*w, c]. The result
        # is already channels-last, so it can be reshaped to [b, h, w, c]
        # directly. (The earlier v^T @ attn product summed over queries and
        # produced a [b, c, h*w] tensor whose reshape scrambled channels and
        # positions.)
        out = tf.matmul(attn, v)
        out = tf.reshape(out, [b, h, w, c])

        out = self.proj_out(out)
        return x + out
class Downsample(layers.Layer):
    """Halve the spatial resolution.

    Args:
        channels: Filters of the strided convolution (used only when
            ``with_conv`` is true).
        with_conv: Use a 3x3 stride-2 convolution; otherwise use 2x2 average
            pooling.
    """

    def __init__(self, channels, with_conv=True):
        super().__init__()
        self.with_conv = with_conv
        self.down_sample = (
            layers.Conv2D(channels, kernel_size=3, strides=2, padding="same")
            if self.with_conv
            else layers.AveragePooling2D(pool_size=2, strides=2)
        )

    def call(self, x):
        """Return the downsampled feature map."""
        return self.down_sample(x)
class Upsample(layers.Layer):
    """Double the spatial resolution: nearest-neighbour upsampling then a 3x3 conv.

    Args:
        channels: Filters of the convolution that follows the upsampling.
        with_conv: Stored on the instance for reference. It does not change
            the layer: the upsample + convolution pair is always built (the
            previous transposed-convolution alternative was guarded by
            ``if False`` and therefore never used).
    """

    def __init__(self, channels, with_conv=False):
        super().__init__()
        self.with_conv = with_conv
        self.up_sample = Sequential(
            [
                layers.UpSampling2D(size=2, interpolation="nearest"),
                layers.Conv2D(channels, kernel_size=3, strides=1, padding="same"),
            ]
        )

    def call(self, x):
        """Return the upsampled feature map."""
        return self.up_sample(x)