File size: 3,158 Bytes
3be620b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import tensorflow as tf
from tensorflow.keras import layers
@tf.keras.utils.register_keras_serializable()
class VectorQuantizer(layers.Layer):
    """Vector-quantization layer that maps inputs onto their nearest codebook vector.

    The codebook is stored in ``self.embeddings`` with shape
    ``(embedding_dim, num_embeddings)``, i.e. one code per column. Calling the
    layer returns the quantized tensor, the chosen code indices and the
    vector-quantization loss; the loss is returned to the caller and is NOT
    registered on the layer via ``add_loss``.

    Args:
        num_embeddings: Number of codes in the codebook.
        embedding_dim: Size of each code vector.
        beta: Weight applied to the commitment term of the loss.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        # This parameter is best kept between [0.25, 2] as per the paper.
        self.beta = beta

        # Initialize the embeddings which we will quantize.
        # NOTE(review): the codebook is a raw `tf.Variable` rather than one
        # created through `self.add_weight` -- confirm it is tracked as a
        # trainable weight by the Keras version in use.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def get_config(self):
        """Return the layer config, including the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "embedding_dim": self.embedding_dim,
                "num_embeddings": self.num_embeddings,
                "beta": self.beta,
            }
        )
        return config

    def call(self, x):
        """Quantize ``x`` and compute the vector-quantization loss.

        Args:
            x: Input tensor. It is flattened to ``(-1, embedding_dim)``, so it
                is expected to carry ``embedding_dim`` on its last axis.

        Returns:
            A tuple ``(quantized, encoding_indices, loss)``:
            ``quantized`` has the shape of ``x``; ``encoding_indices`` is the
            flat vector of selected code indices (one per flattened row);
            ``loss`` is the scalar ``beta * commitment_loss + codebook_loss``.
        """
        # Remember the dynamic input shape, then flatten the inputs keeping
        # `embedding_dim` intact so every row is compared against the codebook.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization: pick the nearest code per row and look it up.
        encoding_indices = self.get_code_indices(flattened)
        quantized = self.get_codebook_entry(encoding_indices, input_shape)

        # Vector-quantization loss. The commitment term only sends gradients
        # to `x` (the codes are held fixed by stop_gradient); the codebook
        # term only sends gradients to the codes. The sum is returned to the
        # caller instead of being attached with `self.add_loss`.
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        loss = commitment_loss + codebook_loss

        # Straight-through estimator: the forward value is `quantized`, while
        # the gradient flows to `x` as if quantization were the identity.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized, encoding_indices, loss

    def get_code_indices(self, flattened_inputs):
        """Return, for each row, the index of the closest codebook vector.

        Args:
            flattened_inputs: 2-D tensor of shape ``(N, embedding_dim)``.

        Returns:
            1-D tensor with ``N`` indices into the codebook.
        """
        # Squared Euclidean distance between every input row and every code,
        # expanded as ||a||^2 + ||e||^2 - 2 * a.e so that it needs a single
        # matmul; the two norm terms broadcast to (N, num_embeddings).
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings**2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices

    def get_codebook_entry(self, indices, shape):
        """Look up codebook vectors for ``indices`` and reshape them to ``shape``.

        Args:
            indices: Integer tensor of code indices.
            shape: Target shape passed to ``tf.reshape`` for the result.

        Returns:
            Tensor of the selected code vectors reshaped to ``shape``.
        """
        # One-hot rows times the transposed codebook select one code per index.
        encodings = tf.one_hot(indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, shape)
        return quantized