import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Each sample is a clip of 20 frames of 64x64 grayscale video.
input_shape = (20, 64, 64, 1)


class VectorQuantizer(layers.Layer):
def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
super().__init__(**kwargs)
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
        # Per the VQ-VAE paper, beta is best kept in [0.25, 2].
        self.beta = beta
        # Initialize the codebook embeddings that inputs will be quantized to.
w_init = tf.random_uniform_initializer()
self.embeddings = tf.Variable(
initial_value=w_init(
shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
),
trainable=True,
name="embeddings_vqvae",
)

    def call(self, x):
        # Record the input shape, then flatten the inputs while keeping
        # `embedding_dim` intact so each row can be matched to a code.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])
# Quantization.
encoding_indices = self.get_code_indices(flattened)
encodings = tf.one_hot(encoding_indices, self.num_embeddings)
quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
quantized = tf.reshape(quantized, input_shape)
        # Compute the vector-quantization losses and register them on the layer
        # (see https://keras.io/guides/making_new_layers_and_models_via_subclassing/
        # for adding losses to layers). The formulation follows the original
        # VQ-VAE paper.
commitment_loss = self.beta * tf.reduce_mean(
(tf.stop_gradient(quantized) - x) ** 2
)
codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
self.add_loss(commitment_loss + codebook_loss)
        # Straight-through estimator: the forward pass uses the quantized
        # values, but gradients flow back to `x` as if quantization were the
        # identity function.
        quantized = x + tf.stop_gradient(quantized - x)
return quantized

    def get_code_indices(self, flattened_inputs):
        # Compute squared L2 distances between the inputs and the codes,
        # expanded as ||x||^2 + ||e||^2 - 2<x, e> so it runs as matmuls.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
distances = (
tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
+ tf.reduce_sum(self.embeddings ** 2, axis=0)
- 2 * similarity
)
# Derive the indices for minimum distances.
encoding_indices = tf.argmin(distances, axis=1)
return encoding_indices
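

# Illustrative check (not part of the original file): quantizing a batch of
# random vectors returns a tensor of the same shape, with every row snapped to
# its nearest codebook entry. For example:
#
#   vq = VectorQuantizer(num_embeddings=8, embedding_dim=4)
#   out = vq(tf.random.normal((16, 4)))   # out.shape == (16, 4)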


class VQVAE(keras.Model):
    def __init__(self, train_variance: float, latent_dim: int = 32, num_embeddings: int = 128, **kwargs):
        super().__init__(**kwargs)
        # Variance of the training data; used to scale the reconstruction loss.
        self.train_variance = train_variance
self.latent_dim = latent_dim
self.num_embeddings = num_embeddings
self.vqvae = self.get_vqvae()
self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
self.reconstruction_loss_tracker = keras.metrics.Mean(
name="reconstruction_loss"
)
self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")

    def get_encoder(self):
        # Per-frame 2D convolutions applied across time with TimeDistributed:
        # two stride-2 convs downsample each 64x64 frame to 16x16, then a 1x1
        # conv projects to `latent_dim` channels for the quantizer.
        encoder_inputs = keras.Input(shape=input_shape)
        x = layers.TimeDistributed(
            layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")
        )(encoder_inputs)
        x = layers.TimeDistributed(
            layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")
        )(x)
        encoder_outputs = layers.TimeDistributed(
            layers.Conv2D(self.latent_dim, 1, padding="same")
        )(x)
        return keras.Model(encoder_inputs, encoder_outputs, name="encoder")

    def get_decoder(self):
        # Mirror of the encoder: two stride-2 transposed convs upsample each
        # frame back to 64x64, and a final conv emits one output channel.
        # (A throwaway encoder is built here only to query its output shape.)
        latent_inputs = keras.Input(shape=self.get_encoder().output.shape[1:])
        x = layers.TimeDistributed(
            layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")
        )(latent_inputs)
        x = layers.TimeDistributed(
            layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")
        )(x)
        decoder_outputs = layers.TimeDistributed(
            layers.Conv2DTranspose(1, 3, padding="same")
        )(x)
        return keras.Model(latent_inputs, decoder_outputs, name="decoder")

    def get_vqvae(self):
        # Wire encoder -> vector quantizer -> decoder into a single model.
        # After the encoder, each frame is a 16x16 grid of `latent_dim`-
        # dimensional vectors, which the quantizer snaps to codebook entries.
        self.vq_layer = VectorQuantizer(
            self.num_embeddings, self.latent_dim, name="vector_quantizer"
        )
        self.encoder = self.get_encoder()
        self.decoder = self.get_decoder()
        inputs = keras.Input(shape=input_shape)
        encoder_outputs = self.encoder(inputs)
        quantized_latents = self.vq_layer(encoder_outputs)
        reconstructions = self.decoder(quantized_latents)
        return keras.Model(inputs, reconstructions, name="vq_vae")

    def train_step(self, data):
        # `data` is an (input, target) pair; for an autoencoder the target is
        # typically the input clip itself.
        x, y = data
with tf.GradientTape() as tape:
# Outputs from the VQ-VAE.
reconstructions = self.vqvae(x)
# Calculate the losses.
reconstruction_loss = (
tf.reduce_mean((y - reconstructions) ** 2) / self.train_variance
)
total_loss = reconstruction_loss + sum(self.vqvae.losses)
# Backpropagation.
grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))
# Loss tracking.
self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.vq_loss_tracker.update_state(sum(self.vqvae.losses))
# Log results.
return {
"loss": self.total_loss_tracker.result(),
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"vqvae_loss": self.vq_loss_tracker.result(),
}
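
    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at the start of each
        # epoch (as in the keras.io VQ-VAE example); without this property the
        # running means would accumulate across epochs.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.vq_loss_tracker,
        ]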

    def call(self, inputs, training=False, mask=None):
        return self.vqvae(inputs)
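

# --- Minimal usage sketch (illustrative, not part of the original file) ---
# Assumes clips shaped (batch, 20, 64, 64, 1) with the model trained to
# reconstruct its own input; `train_variance` would normally come from the
# real training set, e.g. float(np.var(train_clips)).
if __name__ == "__main__":
    dummy_clips = np.random.rand(4, 20, 64, 64, 1).astype("float32")
    data_variance = float(np.var(dummy_clips))

    vqvae = VQVAE(train_variance=data_variance, latent_dim=32, num_embeddings=128)
    vqvae.compile(optimizer=keras.optimizers.Adam())
    vqvae.fit(dummy_clips, dummy_clips, epochs=1, batch_size=2)

    # Reconstructions keep the input clip shape.
    reconstructions = vqvae.predict(dummy_clips)
    print(reconstructions.shape)  # (4, 20, 64, 64, 1)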