|
from typing import List, Literal |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
from .discriminator.model import NLayerDiscriminator |
|
from .losses.vqperceptual import VQLPIPSWithDiscriminator |
|
from tensorflow import keras |
|
from tensorflow.keras import Model, layers, Sequential |
|
from tensorflow.keras.optimizers import Optimizer |
|
from tensorflow_addons.layers import GroupNormalization |
|
|
|
INPUT_SHAPE = (64, 128, 3) |
|
ENCODER_OUTPUT_SHAPE = (8, 16, 128)  # (64, 128) spatial dims after three 2x downsamplings
|
|
|
|
|
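# Hinge discriminator loss: penalize real logits below +1 and fake logits above -1.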
@tf.function |
|
def hinge_d_loss(logits_real, logits_fake): |
|
loss_real = tf.reduce_mean(keras.activations.relu(1.0 - logits_real)) |
|
loss_fake = tf.reduce_mean(keras.activations.relu(1.0 + logits_fake)) |
|
d_loss = 0.5 * (loss_real + loss_fake) |
|
return d_loss |
|
|
|
|
|
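# Non-saturating "vanilla" GAN discriminator loss, written with softplus.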
@tf.function |
|
def vanilla_d_loss(logits_real, logits_fake): |
|
d_loss = 0.5 * ( |
|
tf.reduce_mean(keras.activations.softplus(-logits_real)) |
|
+ tf.reduce_mean(keras.activations.softplus(logits_fake)) |
|
) |
|
return d_loss |
|
|
|
|
|
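# VQ-VAE generator trained against a PatchGAN discriminator with an LPIPS
# perceptual loss, following the VQGAN training recipe (Esser et al., 2021).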
class VQGAN(keras.Model): |
|
def __init__( |
|
self, |
|
train_variance: float, |
|
num_embeddings: int, |
|
embedding_dim: int, |
|
beta: float = 0.25, |
|
z_channels: int = 128, |
|
codebook_weight: float = 1.0, |
|
disc_num_layers: int = 3, |
|
disc_factor: float = 1.0, |
|
disc_iter_start: int = 0, |
|
disc_conditional: bool = False, |
|
disc_in_channels: int = 3, |
|
disc_weight: float = 0.3, |
|
disc_filters: int = 64, |
|
disc_loss: Literal["hinge", "vanilla"] = "hinge", |
|
**kwargs, |
|
): |
|
super().__init__(**kwargs) |
|
self.train_variance = train_variance |
|
self.codebook_weight = codebook_weight |
|
|
|
self.encoder = Encoder() |
|
self.decoder = Decoder() |
|
self.quantize = VectorQuantizer(num_embeddings, embedding_dim, beta=beta) |
|
|
|
self.quant_conv = layers.Conv2D(embedding_dim, kernel_size=1) |
|
self.post_quant_conv = layers.Conv2D(z_channels, kernel_size=1) |
|
|
|
self.vqvae = self.get_vqvae() |
|
|
|
self.perceptual_loss = VQLPIPSWithDiscriminator( |
|
reduction=tf.keras.losses.Reduction.NONE |
|
) |
|
|
|
self.discriminator = NLayerDiscriminator( |
|
input_channels=disc_in_channels, |
|
filters=disc_filters, |
|
n_layers=disc_num_layers, |
|
) |
|
self.discriminator_iter_start = disc_iter_start |
|
|
|
if disc_loss == "hinge": |
|
self.disc_loss = hinge_d_loss |
|
elif disc_loss == "vanilla": |
|
self.disc_loss = vanilla_d_loss |
|
else: |
|
raise ValueError(f"Unknown GAN loss '{disc_loss}'.") |
|
|
|
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") |
|
self.disc_factor = disc_factor |
|
self.discriminator_weight = disc_weight |
|
self.disc_conditional = disc_conditional |
|
|
|
self.total_loss_tracker = keras.metrics.Mean(name="total_loss") |
|
self.reconstruction_loss_tracker = keras.metrics.Mean( |
|
name="reconstruction_loss" |
|
) |
|
self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss") |
|
self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss") |
|
|
|
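        # Populated in compile(); training requires both optimizers.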
self.gen_optimizer: Optimizer = None |
|
self.disc_optimizer: Optimizer = None |
|
|
|
def get_vqvae(self): |
|
inputs = keras.Input(shape=INPUT_SHAPE) |
|
quant = self.encode(inputs) |
|
reconstructed = self.decode(quant) |
|
return keras.Model(inputs, reconstructed, name="vq_vae") |
|
|
|
def encode(self, x): |
|
h = self.encoder(x) |
|
h = self.quant_conv(h) |
|
return self.quantize(h) |
|
|
|
def decode(self, quant): |
|
quant = self.post_quant_conv(quant) |
|
dec = self.decoder(quant) |
|
return dec |
|
|
|
def call(self, inputs, training=True, mask=None): |
|
return self.vqvae(inputs) |
|
|
|
def calculate_adaptive_weight( |
|
self, nll_loss, g_loss, tape, trainable_vars, discriminator_weight |
|
): |
|
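        # Adaptive discriminator weight, as in taming-transformers: scale the GAN
        # loss so its gradient norm at the decoder's last layer matches that of
        # the reconstruction loss ([0] selects the kernel of that layer).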
nll_grads = tape.gradient(nll_loss, trainable_vars)[0] |
|
g_grads = tape.gradient(g_loss, trainable_vars)[0] |
|
|
|
d_weight = tf.norm(nll_grads) / (tf.norm(g_grads) + 1e-4) |
|
d_weight = tf.stop_gradient(tf.clip_by_value(d_weight, 0.0, 1e4)) |
|
return d_weight * discriminator_weight |
|
|
|
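    # Returns `value` (typically 0) until `threshold` steps have elapsed, so the
    # autoencoder can warm up before the adversarial terms switch on.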
@tf.function |
|
def adopt_weight(self, weight, global_step, threshold=0, value=0.0): |
|
if global_step < threshold: |
|
weight = value |
|
return weight |
|
|
|
def get_global_step(self, optimizer): |
|
return optimizer.iterations |
|
|
|
def compile( |
|
self, |
|
gen_optimizer, |
|
disc_optimizer, |
|
): |
|
super().compile() |
|
self.gen_optimizer = gen_optimizer |
|
self.disc_optimizer = disc_optimizer |
|
|
|
def train_step(self, data): |
|
x, y = data |
|
|
|
|
|
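        # Generator / VQ-VAE update; the inner persistent tape exists only to
        # compute the adaptive discriminator weight.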
with tf.GradientTape() as tape: |
|
with tf.GradientTape(persistent=True) as adaptive_tape: |
|
reconstructions = self(x, training=True) |
|
|
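                # Adversarial term: score the reconstructions with the
                # discriminator, kept frozen during the generator step.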
logits_fake = self.discriminator(reconstructions, training=False) |
|
|
|
g_loss = -tf.reduce_mean(logits_fake) |
|
nll_loss = self.perceptual_loss(y, reconstructions) |
|
|
|
d_weight = self.calculate_adaptive_weight( |
|
nll_loss, |
|
g_loss, |
|
adaptive_tape, |
|
self.decoder.conv_out.trainable_variables, |
|
self.discriminator_weight, |
|
) |
|
del adaptive_tape |
|
|
|
disc_factor = self.adopt_weight( |
|
weight=self.disc_factor, |
|
global_step=self.get_global_step(self.gen_optimizer), |
|
threshold=self.discriminator_iter_start, |
|
) |
|
|
|
|
|
total_loss = ( |
|
nll_loss |
|
+ d_weight * disc_factor * g_loss |
+ self.codebook_weight * sum(self.vqvae.losses) |
|
) |
|
|
|
|
|
grads = tape.gradient(total_loss, self.vqvae.trainable_variables) |
|
self.gen_optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables)) |
|
|
|
|
|
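        # Discriminator update on the same batch of reconstructions.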
with tf.GradientTape() as disc_tape: |
|
logits_real = self.discriminator(y, training=True) |
|
logits_fake = self.discriminator(reconstructions, training=True) |
|
|
|
disc_factor = self.adopt_weight( |
|
weight=self.disc_factor, |
|
global_step=self.get_global_step(self.disc_optimizer), |
|
threshold=self.discriminator_iter_start, |
|
) |
|
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) |
|
|
|
disc_grads = disc_tape.gradient(d_loss, self.discriminator.trainable_variables) |
|
self.disc_optimizer.apply_gradients( |
|
zip(disc_grads, self.discriminator.trainable_variables) |
|
) |
|
|
|
|
|
self.total_loss_tracker.update_state(total_loss) |
|
self.reconstruction_loss_tracker.update_state(nll_loss) |
|
self.vq_loss_tracker.update_state(sum(self.vqvae.losses)) |
|
self.disc_loss_tracker.update_state(d_loss) |
|
|
|
|
|
return { |
|
"loss": self.total_loss_tracker.result(), |
|
"reconstruction_loss": self.reconstruction_loss_tracker.result(), |
|
"vqvae_loss": self.vq_loss_tracker.result(), |
|
"disc_loss": self.disc_loss_tracker.result(), |
|
} |
|
|
|
|
|
class VectorQuantizer(layers.Layer): |
|
def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs): |
|
super().__init__(**kwargs) |
|
self.embedding_dim = embedding_dim |
|
self.num_embeddings = num_embeddings |
|
        self.beta = beta  # weight of the commitment term
|
|
|
|
|
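        # Codebook: columns are the code vectors, shape (embedding_dim, num_embeddings).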
w_init = tf.random_uniform_initializer() |
|
self.embeddings = tf.Variable( |
|
initial_value=w_init( |
|
shape=(self.embedding_dim, self.num_embeddings) |
|
), |
|
trainable=True, |
|
name="embeddings_vqvae", |
|
) |
|
|
|
def call(self, x): |
|
|
|
|
|
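        # Flatten to (N, embedding_dim) so every spatial vector is quantized independently.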
input_shape = tf.shape(x) |
|
flattened = tf.reshape(x, [-1, self.embedding_dim]) |
|
|
|
|
|
encoding_indices = self.get_code_indices(flattened) |
|
encodings = tf.one_hot(encoding_indices, self.num_embeddings) |
|
quantized = tf.matmul(encodings, self.embeddings, transpose_b=True) |
|
quantized = tf.reshape(quantized, input_shape) |
|
|
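        # Commitment loss pulls encoder outputs toward their codes; the codebook
        # loss moves the codes toward the encoder outputs (van den Oord et al., 2017).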
commitment_loss = self.beta * tf.reduce_mean( |
|
(tf.stop_gradient(quantized) - x) ** 2 |
|
) |
|
codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2) |
|
self.add_loss(commitment_loss + codebook_loss) |
|
|
|
|
|
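        # Straight-through estimator: gradients skip the quantization step.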
quantized = x + tf.stop_gradient(quantized - x) |
|
return quantized |
|
|
|
def get_code_indices(self, flattened_inputs): |
|
|
|
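        # Squared L2 distance, expanded as ||x||^2 + ||e||^2 - 2<x, e>.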
similarity = tf.matmul(flattened_inputs, self.embeddings) |
|
distances = ( |
|
tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True) |
|
+ tf.reduce_sum(self.embeddings**2, axis=0) |
|
- 2 * similarity |
|
) |
|
|
|
|
|
encoding_indices = tf.argmin(distances, axis=1) |
|
return encoding_indices |
|
|
|
|
|
class Encoder(Model): |
|
def __init__( |
|
self, |
|
*, |
|
channels: int = 128, |
|
output_channels: int = 3, |
|
channels_multiplier: List[int] = [1, 1, 2, 2], |
|
num_res_blocks: int = 1, |
|
attention_resolution: List[int] = [16], |
|
resolution: int = 64, |
|
z_channels=128, |
|
dropout=0.0, |
|
double_z=False, |
|
resamp_with_conv=True, |
|
): |
|
super().__init__() |
|
|
|
self.channels = channels |
|
self.timestep_embeddings_channel = 0 |
|
self.num_resolutions = len(channels_multiplier) |
|
self.num_res_blocks = num_res_blocks |
|
self.resolution = resolution |
|
|
|
self.conv_in = layers.Conv2D( |
|
self.channels, kernel_size=3, strides=1, padding="same" |
|
) |
|
|
|
current_resolution = resolution |
|
|
|
in_channels_multiplier = (1,) + tuple(channels_multiplier) |
|
|
|
self.downsampling_list = [] |
|
|
|
for i_level in range(self.num_resolutions): |
|
block_in = channels * in_channels_multiplier[i_level] |
|
block_out = channels * channels_multiplier[i_level] |
|
for i_block in range(self.num_res_blocks): |
|
self.downsampling_list.append( |
|
ResnetBlock( |
|
in_channels=block_in, |
|
out_channels=block_out, |
|
timestep_embedding_channels=self.timestep_embeddings_channel, |
|
dropout=dropout, |
|
) |
|
) |
|
block_in = block_out |
|
|
|
if current_resolution in attention_resolution: |
|
|
|
self.downsampling_list.append(AttentionBlock(block_in)) |
|
|
|
            if i_level != self.num_resolutions - 1:
                self.downsampling_list.append(Downsample(block_in, resamp_with_conv))
                current_resolution //= 2  # keep the tracked resolution in sync with the attention check
|
|
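        # Bottleneck: ResNet block -> self-attention -> ResNet block.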
self.mid = {} |
|
self.mid["block_1"] = ResnetBlock( |
|
in_channels=block_in, |
|
out_channels=block_in, |
|
timestep_embedding_channels=self.timestep_embeddings_channel, |
|
dropout=dropout, |
|
) |
|
self.mid["attn_1"] = AttentionBlock(block_in) |
|
self.mid["block_2"] = ResnetBlock( |
|
in_channels=block_in, |
|
out_channels=block_in, |
|
timestep_embedding_channels=self.timestep_embeddings_channel, |
|
dropout=dropout, |
|
) |
|
|
|
|
|
self.norm_out = GroupNormalization(groups=32, epsilon=1e-6) |
|
self.conv_out = layers.Conv2D( |
|
2 * z_channels if double_z else z_channels, |
|
kernel_size=3, |
|
strides=1, |
|
padding="same", |
|
) |
|
|
|
def summary(self): |
|
x = layers.Input(shape=INPUT_SHAPE) |
|
model = Model(inputs=[x], outputs=self.call(x)) |
|
return model.summary() |
|
|
|
def call(self, inputs, training=True, mask=None): |
|
h = self.conv_in(inputs) |
|
for downsampling in self.downsampling_list: |
|
h = downsampling(h) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
h = self.mid["block_1"](h) |
|
h = self.mid["attn_1"](h) |
|
h = self.mid["block_2"](h) |
|
|
|
|
|
h = self.norm_out(h) |
|
h = keras.activations.swish(h) |
|
h = self.conv_out(h) |
|
return h |
|
|
|
|
|
class Decoder(Model): |
|
def __init__( |
|
self, |
|
*, |
|
channels: int = 128, |
|
output_channels: int = 3, |
|
channels_multiplier: List[int] = [1, 1, 2, 2], |
|
num_res_blocks: int = 1, |
|
attention_resolution: List[int] = [16], |
|
resolution: int = 64, |
|
z_channels=128, |
|
dropout=0.0, |
|
give_pre_end=False, |
|
resamp_with_conv=True, |
|
): |
|
super().__init__() |
|
|
|
self.channels = channels |
|
self.timestep_embeddings_channel = 0 |
|
self.num_resolutions = len(channels_multiplier) |
|
self.num_res_blocks = num_res_blocks |
|
self.resolution = resolution |
|
self.give_pre_end = give_pre_end |
|
|
|
in_channels_multiplier = (1,) + tuple(channels_multiplier) |
|
block_in = channels * channels_multiplier[-1] |
|
current_resolution = resolution // 2 ** (self.num_resolutions - 1) |
|
        self.z_shape = (1, current_resolution, current_resolution, z_channels)  # NHWC
|
|
|
print( |
|
"Working with z of shape {} = {} dimensions.".format( |
|
self.z_shape, np.prod(self.z_shape) |
|
) |
|
) |
|
|
|
self.conv_in = layers.Conv2D(block_in, kernel_size=3, strides=1, padding="same") |
|
|
|
|
|
self.mid = {} |
|
self.mid["block_1"] = ResnetBlock( |
|
in_channels=block_in, |
|
out_channels=block_in, |
|
timestep_embedding_channels=self.timestep_embeddings_channel, |
|
dropout=dropout, |
|
) |
|
self.mid["attn_1"] = AttentionBlock(block_in) |
|
self.mid["block_2"] = ResnetBlock( |
|
in_channels=block_in, |
|
out_channels=block_in, |
|
timestep_embedding_channels=self.timestep_embeddings_channel, |
|
dropout=dropout, |
|
) |
|
|
|
|
|
|
|
self.upsampling_list = [] |
|
|
|
for i_level in reversed(range(self.num_resolutions)): |
|
block_out = channels * channels_multiplier[i_level] |
|
for i_block in range(self.num_res_blocks + 1): |
|
self.upsampling_list.append( |
|
ResnetBlock( |
|
in_channels=block_in, |
|
out_channels=block_out, |
|
timestep_embedding_channels=self.timestep_embeddings_channel, |
|
dropout=dropout, |
|
) |
|
) |
|
block_in = block_out |
|
|
|
if current_resolution in attention_resolution: |
|
|
|
self.upsampling_list.append(AttentionBlock(block_in)) |
|
|
|
if i_level != 0: |
|
self.upsampling_list.append(Upsample(block_in, resamp_with_conv)) |
|
current_resolution *= 2 |
|
|
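        # Output head: GroupNorm, swish, then a sigmoid projection to image channels.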
self.norm_out = GroupNormalization(groups=32, epsilon=1e-6) |
|
self.conv_out = layers.Conv2D( |
|
output_channels, |
|
kernel_size=3, |
|
strides=1, |
|
activation="sigmoid", |
|
padding="same", |
|
) |
|
|
|
def summary(self): |
|
x = layers.Input(shape=ENCODER_OUTPUT_SHAPE) |
|
model = Model(inputs=[x], outputs=self.call(x)) |
|
return model.summary() |
|
|
|
def call(self, inputs, training=True, mask=None): |
|
|
|
h = self.conv_in(inputs) |
|
|
|
|
|
h = self.mid["block_1"](h) |
|
h = self.mid["attn_1"](h) |
|
h = self.mid["block_2"](h) |
|
|
|
for upsampling in self.upsampling_list: |
|
h = upsampling(h) |
|
|
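        # Optionally return the features before the output head.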
if self.give_pre_end: |
|
return h |
|
|
|
h = self.norm_out(h) |
|
h = keras.activations.swish(h) |
|
h = self.conv_out(h) |
|
return h |
|
|
|
|
|
class ResnetBlock(layers.Layer): |
|
def __init__( |
|
self, |
|
*, |
|
in_channels, |
|
dropout=0.0, |
|
out_channels=None, |
|
conv_shortcut=False, |
|
timestep_embedding_channels=512, |
|
): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
out_channels = in_channels if out_channels is None else out_channels |
|
self.out_channels = out_channels |
|
self.use_conv_shortcut = conv_shortcut |
|
|
|
self.norm1 = GroupNormalization(groups=32, epsilon=1e-6) |
|
|
|
self.conv1 = layers.Conv2D( |
|
out_channels, kernel_size=3, strides=1, padding="same" |
|
) |
|
|
|
if timestep_embedding_channels > 0: |
|
self.timestep_embedding_projection = layers.Dense(out_channels) |
|
|
|
self.norm2 = GroupNormalization(groups=32, epsilon=1e-6) |
|
self.dropout = layers.Dropout(dropout) |
|
|
|
self.conv2 = layers.Conv2D( |
|
out_channels, kernel_size=3, strides=1, padding="same" |
|
) |
|
|
|
if self.in_channels != self.out_channels: |
|
if self.use_conv_shortcut: |
|
self.conv_shortcut = layers.Conv2D( |
|
out_channels, kernel_size=3, strides=1, padding="same" |
|
) |
|
else: |
|
self.nin_shortcut = layers.Conv2D( |
|
out_channels, kernel_size=1, strides=1, padding="valid" |
|
) |
|
|
|
def call(self, x): |
|
h = x |
|
h = self.norm1(h) |
|
h = keras.activations.swish(h) |
|
h = self.conv1(h) |
|
|
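        # The timestep-embedding branch is omitted here: every block in this file
        # is built with timestep_embedding_channels=0.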
h = self.norm2(h) |
|
h = keras.activations.swish(h) |
|
h = self.dropout(h) |
|
h = self.conv2(h) |
|
|
|
if self.in_channels != self.out_channels: |
|
if self.use_conv_shortcut: |
|
x = self.conv_shortcut(x) |
|
else: |
|
x = self.nin_shortcut(x) |
|
|
|
return x + h |
|
|
|
|
|
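# Non-local self-attention block: 1x1-conv query/key/value projections over spatial positions.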
class AttentionBlock(layers.Layer): |
|
def __init__(self, channels): |
|
super().__init__() |
|
|
|
self.norm = GroupNormalization(groups=32, epsilon=1e-6) |
|
self.q = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid") |
|
self.k = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid") |
|
self.v = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid") |
|
self.proj_out = layers.Conv2D( |
|
channels, kernel_size=1, strides=1, padding="valid" |
|
) |
|
|
|
def call(self, x): |
|
h_ = x |
|
h_ = self.norm(h_) |
|
q = self.q(h_) |
|
k = self.k(h_) |
|
v = self.v(h_) |
|
|
|
|
|
        b, h, w, c = q.shape  # static shapes; b may be None until the batch size is known
|
if b is None: |
|
b = -1 |
|
q = tf.reshape(q, [b, h * w, c]) |
|
k = tf.reshape(k, [b, h * w, c]) |
|
        w_ = tf.matmul(q, k, transpose_b=True)  # [b, hw, hw] attention logits
|
w_ = w_ * (int(c) ** (-0.5)) |
|
w_ = keras.activations.softmax(w_) |
|
|
|
|
|
v = tf.reshape(v, [b, h * w, c]) |
|
|
|
        h_ = tf.matmul(w_, v)  # [b, hw, c]: attention-weighted sum of the values
|
|
|
h_ = tf.reshape(h_, [b, h, w, c]) |
|
|
|
h_ = self.proj_out(h_) |
|
|
|
return x + h_ |
|
|
|
|
|
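# 2x spatial downsampling: strided 3x3 conv when with_conv, else average pooling.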
class Downsample(layers.Layer): |
|
def __init__(self, channels, with_conv=True): |
|
super().__init__() |
|
self.with_conv = with_conv |
|
if self.with_conv: |
|
|
|
self.down_sample = layers.Conv2D( |
|
channels, kernel_size=3, strides=2, padding="same" |
|
) |
|
else: |
|
self.down_sample = layers.AveragePooling2D(pool_size=2, strides=2) |
|
|
|
def call(self, x): |
|
x = self.down_sample(x) |
|
return x |
|
|
|
|
|
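# 2x nearest-neighbor upsampling, optionally refined by a 3x3 conv.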
class Upsample(layers.Layer): |
|
    def __init__(self, channels, with_conv=True):
        super().__init__()
        self.with_conv = with_conv
        # Nearest-neighbor upsampling; a 3x3 conv refines the result when with_conv is set.
        if self.with_conv:
            self.up_sample = Sequential(
                [
                    layers.UpSampling2D(size=2, interpolation="nearest"),
                    layers.Conv2D(channels, kernel_size=3, strides=1, padding="same"),
                ]
            )
        else:
            self.up_sample = layers.UpSampling2D(size=2, interpolation="nearest")
|
|
|
def call(self, x): |
|
x = self.up_sample(x) |
|
return x |
|
|
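
# Minimal usage sketch (illustrative only; the dataset, optimizer settings, and
# hyperparameters below are placeholders, not values prescribed by this module):
#
#   vqgan = VQGAN(train_variance=0.05, num_embeddings=512, embedding_dim=128)
#   vqgan.compile(
#       gen_optimizer=keras.optimizers.Adam(1e-4),
#       disc_optimizer=keras.optimizers.Adam(1e-4),
#   )
#   vqgan.fit(dataset.map(lambda x: (x, x)), epochs=10)  # train_step expects (input, target)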