Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
25.1 kB
from typing import List, Literal
import numpy as np
import tensorflow as tf
from .discriminator.model import NLayerDiscriminator
from .losses.vqperceptual import VQLPIPSWithDiscriminator
from tensorflow import keras
from tensorflow.keras import Model, layers, Sequential
from tensorflow.keras.optimizers import Optimizer
from tensorflow_addons.layers import GroupNormalization
INPUT_SHAPE = (64, 128, 3)
ENCODER_OUTPUT_SHAPE = (8, 8, 128)
@tf.function
def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss.

    Penalises real logits below +1 and fake logits above -1; the two
    batch means are averaged (factor 0.5).
    """
    real_term = keras.activations.relu(1.0 - logits_real)
    fake_term = keras.activations.relu(1.0 + logits_fake)
    return 0.5 * (tf.reduce_mean(real_term) + tf.reduce_mean(fake_term))
@tf.function
def vanilla_d_loss(logits_real, logits_fake):
    """Vanilla (softplus / non-saturating) discriminator loss."""
    real_term = tf.reduce_mean(keras.activations.softplus(-logits_real))
    fake_term = tf.reduce_mean(keras.activations.softplus(logits_fake))
    return 0.5 * (real_term + fake_term)
class VQGAN(keras.Model):
    """VQ-VAE autoencoder trained adversarially (VQGAN).

    Wires an Encoder -> VectorQuantizer -> Decoder pipeline together with a
    patch discriminator (NLayerDiscriminator) and a perceptual (LPIPS)
    reconstruction loss. `train_step` alternates one generator update and
    one discriminator update per batch, using separate optimizers supplied
    via `compile()`.
    """

    def __init__(
        self,
        train_variance: float,
        num_embeddings: int,
        embedding_dim: int,
        beta: float = 0.25,
        z_channels: int = 128,  # 256,
        codebook_weight: float = 1.0,
        disc_num_layers: int = 3,
        disc_factor: float = 1.0,
        disc_iter_start: int = 0,
        disc_conditional: bool = False,
        disc_in_channels: int = 3,
        disc_weight: float = 0.3,
        disc_filters: int = 64,
        disc_loss: Literal["hinge", "vanilla"] = "hinge",
        **kwargs,
    ):
        """Build the full model.

        Args:
            train_variance: Data variance. NOTE(review): only referenced by
                the commented-out MSE reconstruction loss in `train_step`.
            num_embeddings: Codebook size.
            embedding_dim: Dimensionality of each codebook vector.
            beta: Commitment-loss weight passed to VectorQuantizer.
            z_channels: Channels of the latent fed to the decoder.
            codebook_weight: Scale of the VQ (codebook + commitment) loss.
            disc_num_layers / disc_filters / disc_in_channels: Discriminator
                architecture parameters.
            disc_factor: Base scale of the adversarial loss terms.
            disc_iter_start: Optimizer step before which the adversarial
                terms are zeroed (see `adopt_weight`).
            disc_conditional: Stored but not used in this file.
            disc_weight: Multiplier applied to the adaptive generator weight.
            disc_loss: "hinge" or "vanilla" discriminator loss.

        Raises:
            ValueError: If `disc_loss` is not one of the supported names.
        """
        super().__init__(**kwargs)
        self.train_variance = train_variance
        self.codebook_weight = codebook_weight
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.quantize = VectorQuantizer(num_embeddings, embedding_dim, beta=beta)
        # 1x1 convs mapping encoder features -> codebook dim, and codebook
        # output -> decoder input channels.
        self.quant_conv = layers.Conv2D(embedding_dim, kernel_size=1)
        self.post_quant_conv = layers.Conv2D(z_channels, kernel_size=1)
        self.vqvae = self.get_vqvae()
        self.perceptual_loss = VQLPIPSWithDiscriminator(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.discriminator = NLayerDiscriminator(
            input_channels=disc_in_channels,
            filters=disc_filters,
            n_layers=disc_num_layers,
        )
        self.discriminator_iter_start = disc_iter_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        # Running means reported from train_step.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss")
        # Populated by compile(); None until then.
        self.gen_optimizer: Optimizer = None
        self.disc_optimizer: Optimizer = None

    def get_vqvae(self):
        """Build the functional encode->quantize->decode model."""
        inputs = keras.Input(shape=INPUT_SHAPE)
        quant = self.encode(inputs)
        reconstructed = self.decode(quant)
        return keras.Model(inputs, reconstructed, name="vq_vae")

    def encode(self, x):
        """Encode images to quantized latents (with straight-through grads)."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return self.quantize(h)

    def decode(self, quant):
        """Decode quantized latents back to image space."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def call(self, inputs, training=True, mask=None):
        """Full autoencoding pass through the VQ-VAE."""
        return self.vqvae(inputs)

    def calculate_adaptive_weight(
        self, nll_loss, g_loss, tape, trainable_vars, discriminator_weight
    ):
        """Balance adversarial vs. reconstruction gradients (VQGAN trick).

        Computes |grad(nll)| / |grad(g_loss)| w.r.t. the first variable in
        `trainable_vars` (in train_step this is the decoder's output conv),
        clips to [0, 1e4], stops gradients, and scales by
        `discriminator_weight`. Requires `tape` to be persistent since
        gradient() is called twice.
        """
        nll_grads = tape.gradient(nll_loss, trainable_vars)[0]
        g_grads = tape.gradient(g_loss, trainable_vars)[0]
        d_weight = tf.norm(nll_grads) / (tf.norm(g_grads) + 1e-4)
        d_weight = tf.stop_gradient(tf.clip_by_value(d_weight, 0.0, 1e4))
        return d_weight * discriminator_weight

    @tf.function
    def adopt_weight(self, weight, global_step, threshold=0, value=0.0):
        """Return `weight`, or `value` while global_step < threshold.

        Used to disable adversarial terms until the discriminator warm-up
        step is reached.
        """
        if global_step < threshold:
            weight = value
        return weight

    def get_global_step(self, optimizer):
        """Current update count of `optimizer` (tf.Variable)."""
        return optimizer.iterations

    def compile(
        self,
        gen_optimizer,
        disc_optimizer,
    ):
        """Attach separate optimizers for generator and discriminator."""
        super().compile()
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer

    def train_step(self, data):
        """One generator update followed by one discriminator update.

        `data` is an (x, y) pair; x is the input image, y the target used
        for both the perceptual loss and as the discriminator's "real"
        sample.
        """
        x, y = data
        # Autoencode
        # Inner tape is persistent because calculate_adaptive_weight calls
        # gradient() on it twice (once for nll_loss, once for g_loss).
        with tf.GradientTape() as tape:
            with tf.GradientTape(persistent=True) as adaptive_tape:
                reconstructions = self(x, training=True)
                # Calculate the losses.
                # reconstruction_loss = (
                #     tf.reduce_mean((y - reconstructions) ** 2) / self.train_variance
                # )
                # Discriminator is frozen during the generator step.
                logits_fake = self.discriminator(reconstructions, training=False)
                g_loss = -tf.reduce_mean(logits_fake)
                nll_loss = self.perceptual_loss(y, reconstructions)
            d_weight = self.calculate_adaptive_weight(
                nll_loss,
                g_loss,
                adaptive_tape,
                self.decoder.conv_out.trainable_variables,
                self.discriminator_weight,
            )
            # Free the persistent tape's resources as soon as possible.
            del adaptive_tape
            disc_factor = self.adopt_weight(
                weight=self.disc_factor,
                global_step=self.get_global_step(self.gen_optimizer),
                threshold=self.discriminator_iter_start,
            )
            # total_loss = reconstruction_loss + sum(self.vqvae.losses)
            # self.vqvae.losses holds the VQ commitment+codebook loss added
            # by the VectorQuantizer layer via add_loss().
            total_loss = (
                nll_loss
                + d_weight * disc_factor * g_loss
                # + self.codebook_weight * tf.reduce_mean(self.vqvae.losses)
                + self.codebook_weight * sum(self.vqvae.losses)
            )
        # Backpropagation.
        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
        self.gen_optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))
        # Discriminator
        # `reconstructions` is reused outside the generator tapes, so no
        # generator gradients flow here.
        with tf.GradientTape() as disc_tape:
            logits_real = self.discriminator(y, training=True)
            logits_fake = self.discriminator(reconstructions, training=True)
            disc_factor = self.adopt_weight(
                weight=self.disc_factor,
                global_step=self.get_global_step(self.disc_optimizer),
                threshold=self.discriminator_iter_start,
            )
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
        disc_grads = disc_tape.gradient(d_loss, self.discriminator.trainable_variables)
        self.disc_optimizer.apply_gradients(
            zip(disc_grads, self.discriminator.trainable_variables)
        )
        # Loss tracking.
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(nll_loss)
        self.vq_loss_tracker.update_state(sum(self.vqvae.losses))
        self.disc_loss_tracker.update_state(d_loss)
        # Log results.
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vqvae_loss": self.vq_loss_tracker.result(),
            "disc_loss": self.disc_loss_tracker.result(),
        }
class VectorQuantizer(layers.Layer):
    """Codebook layer for VQ-VAE/VQGAN.

    Maps each `embedding_dim`-sized feature vector to its nearest codebook
    entry, adds the commitment and codebook losses via `add_loss`, and
    passes gradients straight through the quantization step.
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        # Commitment weight; the VQ-VAE paper recommends keeping it in [0.25, 2].
        self.beta = beta
        # Codebook matrix, shape (embedding_dim, num_embeddings).
        initializer = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=initializer(
                shape=(self.embedding_dim, self.num_embeddings)
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        """Quantize `x` (last axis = embedding_dim) to codebook vectors."""
        original_shape = tf.shape(x)
        flat = tf.reshape(x, [-1, self.embedding_dim])
        # Nearest-codebook lookup via one-hot matmul.
        indices = self.get_code_indices(flat)
        one_hot = tf.one_hot(indices, self.num_embeddings)
        quantized = tf.matmul(one_hot, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, original_shape)
        # Commitment loss pulls the encoder toward the (frozen) codes;
        # codebook loss pulls the codes toward the (frozen) encoder output.
        # See https://keras.io/guides/making_new_layers_and_models_via_subclassing/
        # for how add_loss surfaces these to the enclosing model.
        commitment = self.beta * tf.reduce_mean(
            tf.square(tf.stop_gradient(quantized) - x)
        )
        codebook = tf.reduce_mean(tf.square(quantized - tf.stop_gradient(x)))
        self.add_loss(commitment + codebook)
        # Straight-through estimator: forward pass yields `quantized`,
        # backward pass copies gradients to `x` unchanged.
        return x + tf.stop_gradient(quantized - x)

    def get_code_indices(self, flattened_inputs):
        """Index of the nearest codebook column for each input row (L2)."""
        dot = tf.matmul(flattened_inputs, self.embeddings)
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, computed without an
        # explicit pairwise difference tensor.
        distances = (
            tf.reduce_sum(tf.square(flattened_inputs), axis=1, keepdims=True)
            + tf.reduce_sum(tf.square(self.embeddings), axis=0)
            - 2 * dot
        )
        return tf.argmin(distances, axis=1)
class Encoder(Model):
    """Convolutional encoder: image -> spatial latent with `z_channels`.

    Ported from the taming-transformers encoder (channels-last). Blocks are
    stored in a flat `downsampling_list` and applied sequentially; the
    per-level dict structure of the original port is kept below as
    commented-out reference.
    """

    def __init__(
        self,
        *,
        channels: int = 128,
        output_channels: int = 3,
        channels_multiplier: List[int] = [1, 1, 2, 2],  # [1, 1, 2, 2, 4],
        num_res_blocks: int = 1,  # 2,
        attention_resolution: List[int] = [16],
        resolution: int = 64,  # 256,
        z_channels=128,  # 256,
        dropout=0.0,
        double_z=False,
        resamp_with_conv=True,
    ):
        """Build the encoder stack.

        Args:
            channels: Base channel count; scaled per level by
                `channels_multiplier`.
            output_channels: Unused here (decoder symmetry).
            channels_multiplier: Per-level channel multipliers; the number
                of levels is its length.
            num_res_blocks: ResnetBlocks per level.
            attention_resolution: Spatial resolutions at which an
                AttentionBlock would be inserted (see review note below).
            resolution: Input spatial resolution.
            z_channels: Latent channels (doubled if `double_z`).
            dropout: Dropout rate inside ResnetBlocks.
            double_z: Output 2*z_channels (e.g. for mean/logvar heads).
            resamp_with_conv: Strided conv (True) vs average pooling for
                downsampling.
        """
        super().__init__()
        self.channels = channels
        # Timestep embeddings are disabled in this port (0 => no projection
        # layer is created inside ResnetBlock).
        self.timestep_embeddings_channel = 0
        self.num_resolutions = len(channels_multiplier)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.conv_in = layers.Conv2D(
            self.channels, kernel_size=3, strides=1, padding="same"
        )
        current_resolution = resolution
        in_channels_multiplier = (1,) + tuple(channels_multiplier)
        # Flat sequence of ResnetBlock / AttentionBlock / Downsample layers.
        self.downsampling_list = []
        for i_level in range(self.num_resolutions):
            block_in = channels * in_channels_multiplier[i_level]
            block_out = channels * channels_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                self.downsampling_list.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        timestep_embedding_channels=self.timestep_embeddings_channel,
                        dropout=dropout,
                    )
                )
                block_in = block_out
            # NOTE(review): `current_resolution` is never halved in this loop
            # (the Decoder *does* update its counterpart), so with
            # resolution=64 and attention_resolution=[16] this condition is
            # always False and no AttentionBlock is ever added on the way
            # down — confirm this is intentional.
            if current_resolution in attention_resolution:
                # attentions.append(layers.Attention())
                self.downsampling_list.append(AttentionBlock(block_in))
            if i_level != self.num_resolutions - 1:
                self.downsampling_list.append(Downsample(block_in, resamp_with_conv))
        # self.downsampling = []
        # for i_level in range(self.num_resolutions):
        #     block = []
        #     attentions = []
        #     block_in = channels * in_channels_multiplier[i_level]
        #     block_out = channels * channels_multiplier[i_level]
        #     for i_block in range(self.num_res_blocks):
        #         block.append(
        #             ResnetBlock(
        #                 in_channels=block_in,
        #                 out_channels=block_out,
        #                 timestep_embedding_channels=self.timestep_embeddings_channel,
        #                 dropout=dropout,
        #             )
        #         )
        #         block_in = block_out
        #         if current_resolution in attention_resolution:
        #             # attentions.append(layers.Attention())
        #             attentions.append(AttentionBlock(block_in))
        #     down = {}
        #     down["block"] = block
        #     down["attention"] = attentions
        #     if i_level != self.num_resolutions - 1:
        #         down["downsample"] = Downsample(block_in, resamp_with_conv)
        #     self.downsampling.append(down)
        # middle: resnet -> attention -> resnet at the bottleneck.
        self.mid = {}
        self.mid["block_1"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            timestep_embedding_channels=self.timestep_embeddings_channel,
            dropout=dropout,
        )
        self.mid["attn_1"] = AttentionBlock(block_in)
        self.mid["block_2"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            timestep_embedding_channels=self.timestep_embeddings_channel,
            dropout=dropout,
        )
        # end: norm -> swish -> projection to the latent channel count.
        self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv_out = layers.Conv2D(
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            strides=1,
            padding="same",
        )

    def summary(self):
        """Summary of the unrolled encoder built on a concrete input shape."""
        x = layers.Input(shape=INPUT_SHAPE)
        model = Model(inputs=[x], outputs=self.call(x))
        return model.summary()

    def call(self, inputs, training=True, mask=None):
        """Encode `inputs` to the latent feature map."""
        h = self.conv_in(inputs)
        for downsampling in self.downsampling_list:
            h = downsampling(h)
        # for i_level in range(self.num_resolutions):
        #     for i_block in range(self.num_res_blocks):
        #         h = self.downsampling[i_level]["block"][i_block](hs[-1])
        #         if len(self.downsampling[i_level]["attention"]) > 0:
        #             h = self.downsampling[i_level]["attention"][i_block](h)
        #         hs.append(h)
        #     if i_level != self.num_resolutions - 1:
        #         hs.append(self.downsampling[i_level]["downsample"](hs[-1]))
        # h = hs[-1]
        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)
        # end
        h = self.norm_out(h)
        h = keras.activations.swish(h)
        h = self.conv_out(h)
        return h
class Decoder(Model):
    """Convolutional decoder: latent -> image, mirror of the Encoder.

    Ported from the taming-transformers decoder (channels-last). Blocks
    live in a flat `upsampling_list` applied sequentially; the per-level
    dict structure of the original port is kept as commented-out reference.
    """

    def __init__(
        self,
        *,
        channels: int = 128,
        output_channels: int = 3,
        channels_multiplier: List[int] = [1, 1, 2, 2],  # [1, 1, 2, 2, 4],
        num_res_blocks: int = 1,  # 2,
        attention_resolution: List[int] = [16],
        resolution: int = 64,  # 256,
        z_channels=128,  # 256,
        dropout=0.0,
        give_pre_end=False,
        resamp_with_conv=True,
    ):
        """Build the decoder stack.

        Args:
            channels: Base channel count, scaled by `channels_multiplier`.
            output_channels: Channels of the reconstructed image.
            channels_multiplier: Per-level multipliers (traversed in
                reverse relative to the encoder).
            num_res_blocks: ResnetBlocks per level (decoder uses +1).
            attention_resolution: Resolutions at which AttentionBlocks are
                inserted while upsampling.
            resolution: Target output resolution.
            z_channels: Channels of the incoming latent.
            dropout: Dropout rate inside ResnetBlocks.
            give_pre_end: If True, skip the final norm/activation/conv and
                return raw features.
            resamp_with_conv: Forwarded to Upsample (currently ignored
                there).
        """
        super().__init__()
        self.channels = channels
        # Timestep embeddings disabled in this port.
        self.timestep_embeddings_channel = 0
        self.num_resolutions = len(channels_multiplier)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.give_pre_end = give_pre_end
        in_channels_multiplier = (1,) + tuple(channels_multiplier)
        block_in = channels * channels_multiplier[-1]
        # Bottleneck resolution after the encoder's downsampling levels.
        current_resolution = resolution // 2 ** (self.num_resolutions - 1)
        # NOTE(review): z_shape is printed in (1, C, H, W) order while the
        # data layout used throughout is channels-last — informational only.
        self.z_shape = (1, z_channels, current_resolution, current_resolution)
        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )
        self.conv_in = layers.Conv2D(block_in, kernel_size=3, strides=1, padding="same")
        # middle: resnet -> attention -> resnet at the bottleneck.
        self.mid = {}
        self.mid["block_1"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            timestep_embedding_channels=self.timestep_embeddings_channel,
            dropout=dropout,
        )
        self.mid["attn_1"] = AttentionBlock(block_in)
        self.mid["block_2"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            timestep_embedding_channels=self.timestep_embeddings_channel,
            dropout=dropout,
        )
        # upsampling: levels are traversed highest-to-lowest multiplier;
        # resolution doubles after every level except the last.
        self.upsampling_list = []
        for i_level in reversed(range(self.num_resolutions)):
            block_out = channels * channels_multiplier[i_level]
            for i_block in range(self.num_res_blocks + 1):
                self.upsampling_list.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        timestep_embedding_channels=self.timestep_embeddings_channel,
                        dropout=dropout,
                    )
                )
                block_in = block_out
            if current_resolution in attention_resolution:
                # attentions.append(layers.Attention())
                self.upsampling_list.append(AttentionBlock(block_in))
            if i_level != 0:
                self.upsampling_list.append(Upsample(block_in, resamp_with_conv))
                current_resolution *= 2
        # self.upsampling.insert(0, upsampling)
        # self.upsampling = []
        # for i_level in reversed(range(self.num_resolutions)):
        #     block = []
        #     attentions = []
        #     block_out = channels * channels_multiplier[i_level]
        #     for i_block in range(self.num_res_blocks + 1):
        #         block.append(
        #             ResnetBlock(
        #                 in_channels=block_in,
        #                 out_channels=block_out,
        #                 timestep_embedding_channels=self.timestep_embeddings_channel,
        #                 dropout=dropout,
        #             )
        #         )
        #         block_in = block_out
        #         if current_resolution in attention_resolution:
        #             # attentions.append(layers.Attention())
        #             attentions.append(AttentionBlock(block_in))
        #     upsampling = {}
        #     upsampling["block"] = block
        #     upsampling["attention"] = attentions
        #     if i_level != 0:
        #         upsampling["upsample"] = Upsample(block_in, resamp_with_conv)
        #         current_resolution *= 2
        #     self.upsampling.insert(0, upsampling)
        # end: norm -> swish -> conv with sigmoid, so outputs land in [0, 1].
        self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv_out = layers.Conv2D(
            output_channels,
            kernel_size=3,
            strides=1,
            activation="sigmoid",
            padding="same",
        )

    def summary(self):
        """Summary of the unrolled decoder built on the encoder's output shape."""
        x = layers.Input(shape=ENCODER_OUTPUT_SHAPE)
        model = Model(inputs=[x], outputs=self.call(x))
        return model.summary()

    def call(self, inputs, training=True, mask=None):
        """Decode latent `inputs` to an image (or raw features if give_pre_end)."""
        h = self.conv_in(inputs)
        # middle
        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)
        for upsampling in self.upsampling_list:
            h = upsampling(h)
        # for i_level in reversed(range(self.num_resolutions)):
        #     for i_block in range(self.num_res_blocks + 1):
        #         h = self.upsampling[i_level]["block"][i_block](h)
        #         if len(self.upsampling[i_level]["attention"]) > 0:
        #             h = self.upsampling[i_level]["attention"][i_block](h)
        #     if i_level != 0:
        #         h = self.upsampling[i_level]["upsample"](h)
        # end
        if self.give_pre_end:
            return h
        h = self.norm_out(h)
        h = keras.activations.swish(h)
        h = self.conv_out(h)
        return h
class ResnetBlock(layers.Layer):
    """Pre-activation residual block: (GroupNorm -> swish -> 3x3 conv) x2.

    When `in_channels != out_channels` the skip connection is projected,
    either with a 3x3 conv (`conv_shortcut=True`) or a 1x1 conv otherwise.
    Timestep-embedding injection is present in the constructor for parity
    with the diffusion original but is not applied in `call`.
    """

    def __init__(
        self,
        *,
        in_channels,
        dropout=0.0,
        out_channels=None,
        conv_shortcut=False,
        timestep_embedding_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        # Default to an identity-width block when out_channels is omitted.
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut
        self.norm1 = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv1 = layers.Conv2D(
            self.out_channels, kernel_size=3, strides=1, padding="same"
        )
        if timestep_embedding_channels > 0:
            self.timestep_embedding_projection = layers.Dense(self.out_channels)
        self.norm2 = GroupNormalization(groups=32, epsilon=1e-6)
        self.dropout = layers.Dropout(dropout)
        self.conv2 = layers.Conv2D(
            self.out_channels, kernel_size=3, strides=1, padding="same"
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = layers.Conv2D(
                    self.out_channels, kernel_size=3, strides=1, padding="same"
                )
            else:
                self.nin_shortcut = layers.Conv2D(
                    self.out_channels, kernel_size=1, strides=1, padding="valid"
                )

    def call(self, x):
        """Apply the residual branch and add the (possibly projected) skip."""
        h = self.conv1(keras.activations.swish(self.norm1(x)))
        # if timestamp_embedding is not None:
        #     h = h + self.timestep_embedding_projection(keras.activations.swish(timestamp_embedding))
        h = self.dropout(keras.activations.swish(self.norm2(h)))
        h = self.conv2(h)
        residual = x
        if self.in_channels != self.out_channels:
            residual = (
                self.conv_shortcut(x)
                if self.use_conv_shortcut
                else self.nin_shortcut(x)
            )
        return residual + h
class AttentionBlock(layers.Layer):
    """Single-head self-attention over spatial positions (channels-last).

    q, k, v are 1x1-conv projections of the GroupNorm'd input; attention is
    computed over the flattened h*w positions and the result is projected
    and added residually to the input.
    """

    def __init__(self, channels):
        super().__init__()
        self.norm = GroupNormalization(groups=32, epsilon=1e-6)
        self.q = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
        self.k = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
        self.v = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
        self.proj_out = layers.Conv2D(
            channels, kernel_size=1, strides=1, padding="valid"
        )

    def call(self, x):
        h_ = self.norm(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        # compute attention
        b, h, w, c = q.shape
        if b is None:
            b = -1  # dynamic batch dimension
        q = tf.reshape(q, [b, h * w, c])
        k = tf.reshape(k, [b, h * w, c])
        v = tf.reshape(v, [b, h * w, c])
        # w_[b, i, j] = <q_i, k_j> / sqrt(c); softmax over the key axis j.
        w_ = tf.matmul(q, k, transpose_b=True)  # [b, hw, hw]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = keras.activations.softmax(w_)
        # BUG FIX: the previous code did tf.matmul(v, w_, transpose_a=True),
        # i.e. h_[b,c,j] = sum_i v[b,i,c] * w_[b,i,j], which sums over the
        # *query* axis i even though softmax normalized over the *key* axis j
        # (the torch original's permute(0, 2, 1) was dropped in the port, so
        # the attention matrix was applied transposed). Correct aggregation
        # is h_[b,i,c] = sum_j w_[b,i,j] * v[b,j,c]:
        h_ = tf.matmul(w_, v)  # [b, hw, c]
        h_ = tf.reshape(h_, [b, h, w, c])
        h_ = self.proj_out(h_)
        return x + h_
class Downsample(layers.Layer):
    """Halve spatial resolution with a strided 3x3 conv or average pooling."""

    def __init__(self, channels, with_conv=True):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.down_sample = layers.Conv2D(
                channels, kernel_size=3, strides=2, padding="same"
            )
        else:
            self.down_sample = layers.AveragePooling2D(pool_size=2, strides=2)

    def call(self, x):
        return self.down_sample(x)
class Upsample(layers.Layer):
    """Double spatial resolution: nearest-neighbour upsampling + 3x3 conv.

    Note: `with_conv` is accepted and stored for interface compatibility but
    is currently ignored. The original code guarded a Conv2DTranspose path
    with `if False:` (deliberately disabled); that unreachable branch has
    been removed — behavior is unchanged, the nearest+conv path is always
    used.
    """

    def __init__(self, channels, with_conv=False):
        super().__init__()
        self.with_conv = with_conv
        self.up_sample = Sequential(
            [
                layers.UpSampling2D(size=2, interpolation="nearest"),
                layers.Conv2D(channels, kernel_size=3, strides=1, padding="same"),
            ]
        )

    def call(self, x):
        return self.up_sample(x)