|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
from tensorflow import keras |
|
from tensorflow.keras import layers |
|
import tensorflow as tf |
|
|
|
# Shape of a single model input, excluding the batch axis. It is passed to
# ``keras.Input`` when building the encoder and the full VQ-VAE graph. The
# encoder wraps each Conv2D in ``TimeDistributed``, so axis 0 (20) is iterated
# over and each (64, 64, 1) slice is convolved as one image.
# Presumably (frames, height, width, channels) with Keras' default
# channels_last layout — TODO confirm against the data pipeline.
input_shape = (20, 64, 64, 1)
|
|
|
class VectorQuantizer(layers.Layer):
    """Vector-quantization bottleneck for a VQ-VAE.

    The input is reshaped into rows of ``embedding_dim`` values. Each row is
    replaced by its nearest codebook vector (squared Euclidean distance), and
    the result is reshaped back to the input's shape. The commitment and
    codebook losses are registered with ``add_loss``. Gradients flow through
    the non-differentiable lookup via a straight-through estimator.

    Args:
        num_embeddings: Number of vectors in the codebook.
        embedding_dim: Length of each codebook vector. The input's total size
            must be divisible by it; presumably it equals the size of the last
            input axis — TODO confirm for non-standard callers.
        beta: Weight of the commitment loss term.
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta

        # Codebook stored column-wise: shape (embedding_dim, num_embeddings).
        # Created through ``add_weight`` rather than a bare ``tf.Variable`` so
        # Keras tracks it as a trainable weight. Keras 3 does not track raw
        # ``tf.Variable`` attributes, which would leave the codebook out of
        # ``trainable_variables`` and never update it.
        self.embeddings = self.add_weight(
            name="embeddings_vqvae",
            shape=(self.embedding_dim, self.num_embeddings),
            initializer=tf.random_uniform_initializer(),
            trainable=True,
            dtype="float32",
        )

    def call(self, x):
        """Quantize ``x`` and register the VQ losses.

        Returns a tensor with the same shape as ``x`` whose forward value is the
        quantized input but whose gradient w.r.t. ``x`` is the identity.
        """
        # Keep the dynamic shape so the quantized result can be restored to it.
        original_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Nearest-code lookup expressed as one-hot @ codebook^T.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, original_shape)

        # Commitment loss: moves x toward its code (codebook held fixed).
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        # Codebook loss: moves the codes toward x (x held fixed).
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator: value == quantized, d(out)/dx == identity.
        return x + tf.stop_gradient(quantized - x)

    def get_code_indices(self, flattened_inputs):
        """Return, for each row of ``flattened_inputs``, the nearest code index.

        Squared distances are computed as ``|a|^2 + |e|^2 - 2 a.e`` so that all
        row/code pairs are evaluated with a single matmul.

        Args:
            flattened_inputs: 2-D tensor of shape (num_rows, embedding_dim).

        Returns:
            int64 tensor of shape (num_rows,) with values in
            ``[0, num_embeddings)``.
        """
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings ** 2, axis=0)
            - 2 * similarity
        )
        return tf.argmin(distances, axis=1)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "num_embeddings": self.num_embeddings,
                "embedding_dim": self.embedding_dim,
                "beta": self.beta,
            }
        )
        return config
|
|
|
|
|
class VQVAE(keras.Model):
    """VQ-VAE over frame sequences with custom ``train_step``/``test_step``.

    ``self.vqvae`` is a functional ``encoder -> VectorQuantizer -> decoder``
    graph whose input shape is the module-level ``input_shape``. The encoder
    and decoder wrap every (transposed) 2-D convolution in ``TimeDistributed``.

    Args:
        train_variance: Divisor applied to the mean squared reconstruction
            error; must be positive.
        latent_dim: Channel count of the encoder output and length of each
            codebook vector.
        num_embeddings: Number of codebook vectors in the quantizer.
        **kwargs: Forwarded to ``keras.Model``.

    Raises:
        ValueError: If ``train_variance`` is not positive.
    """

    def __init__(
        self,
        train_variance: float,
        latent_dim: int = 32,
        num_embeddings: int = 128,
        **kwargs,
    ):
        # train_variance is used as a divisor in the reconstruction loss.
        if train_variance <= 0:
            raise ValueError(
                f"train_variance must be positive, got {train_variance!r}"
            )
        super().__init__(**kwargs)
        self.train_variance = train_variance
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings

        # Also sets self.vq_layer, self.encoder and self.decoder.
        self.vqvae = self.get_vqvae()

        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")

    @property
    def metrics(self):
        """Trackers Keras resets at the start of each epoch and evaluation."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.vq_loss_tracker,
        ]

    def get_encoder(self):
        """Build the encoder model.

        Two stride-2 ReLU convolutions, then a 1x1 projection to
        ``latent_dim`` channels; each layer is applied per step of axis 1.
        """
        encoder_inputs = keras.Input(shape=input_shape)
        x = layers.TimeDistributed(
            layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")
        )(encoder_inputs)
        x = layers.TimeDistributed(
            layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")
        )(x)
        encoder_outputs = layers.TimeDistributed(
            layers.Conv2D(self.latent_dim, 1, padding="same")
        )(x)
        return keras.Model(encoder_inputs, encoder_outputs, name="encoder")

    def get_decoder(self):
        """Build the decoder model mirroring the encoder.

        Two stride-2 ReLU transposed convolutions, then a transposed
        convolution to 1 output channel; each applied per step of axis 1.
        """
        # Read the latent shape from the encoder already built by
        # ``get_vqvae`` instead of constructing a throwaway encoder model.
        encoder = getattr(self, "encoder", None)
        if encoder is None:
            encoder = self.get_encoder()
        latent_inputs = keras.Input(shape=encoder.output.shape[1:])
        x = layers.TimeDistributed(
            layers.Conv2DTranspose(
                64, 3, activation="relu", strides=2, padding="same"
            )
        )(latent_inputs)
        x = layers.TimeDistributed(
            layers.Conv2DTranspose(
                32, 3, activation="relu", strides=2, padding="same"
            )
        )(x)
        decoder_outputs = layers.TimeDistributed(
            layers.Conv2DTranspose(1, 3, padding="same")
        )(x)
        return keras.Model(latent_inputs, decoder_outputs, name="decoder")

    def get_vqvae(self):
        """Create the quantizer, encoder and decoder and wire them together.

        Side effect: stores them on ``self.vq_layer``, ``self.encoder`` and
        ``self.decoder``.
        """
        self.vq_layer = VectorQuantizer(
            self.num_embeddings, self.latent_dim, name="vector_quantizer"
        )
        self.encoder = self.get_encoder()
        self.decoder = self.get_decoder()
        inputs = keras.Input(shape=input_shape)
        encoder_outputs = self.encoder(inputs)
        quantized_latents = self.vq_layer(encoder_outputs)
        reconstructions = self.decoder(quantized_latents)
        return keras.Model(inputs, reconstructions, name="vq_vae")

    def _compute_losses(self, x, y, training):
        """Run the autoencoder on ``x``; return (total, reconstruction, vq)."""
        reconstructions = self.vqvae(x, training=training)
        # Mean squared error scaled by 1 / train_variance.
        reconstruction_loss = (
            tf.reduce_mean((y - reconstructions) ** 2) / self.train_variance
        )
        # Commitment + codebook losses added by the VectorQuantizer this call.
        vq_loss = sum(self.vqvae.losses)
        return reconstruction_loss + vq_loss, reconstruction_loss, vq_loss

    def _update_trackers(self, total_loss, reconstruction_loss, vq_loss):
        """Accumulate batch losses into the trackers and return the logs."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.vq_loss_tracker.update_state(vq_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vqvae_loss": self.vq_loss_tracker.result(),
        }

    def train_step(self, data):
        """One optimization step; ``data`` must be an (inputs, targets) pair."""
        x, y = data
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, vq_loss = self._compute_losses(
                x, y, training=True
            )
        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))
        return self._update_trackers(total_loss, reconstruction_loss, vq_loss)

    def test_step(self, data):
        """Evaluation step reporting the same losses as ``train_step``."""
        x, y = data
        total_loss, reconstruction_loss, vq_loss = self._compute_losses(
            x, y, training=False
        )
        return self._update_trackers(total_loss, reconstruction_loss, vq_loss)

    def call(self, inputs, training=False, mask=None):
        """Return reconstructions of ``inputs``; ``mask`` is ignored."""
        return self.vqvae(inputs, training=training)
|
|