import tensorflow as tf
from tensorflow.keras import layers


@tf.keras.utils.register_keras_serializable()
class VectorQuantizer(layers.Layer):
    """Vector-quantization layer for a VQ-VAE.

    Maps each `embedding_dim`-sized vector in the input to its nearest
    codebook entry and returns the quantized tensor, the flat code
    indices, and the VQ loss (codebook term plus commitment term).
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        # `beta` weights the commitment loss; 0.25 is the value suggested
        # by van den Oord et al. (2017).
        self.beta = beta

        # Codebook: one embedding per column, shape (embedding_dim, num_embeddings).
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embedding_dim": self.embedding_dim,
                "num_embeddings": self.num_embeddings,
                "beta": self.beta,
            }
        )
        return config

    def call(self, x):
        # Flatten all leading dimensions so each row is one vector to quantize.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Replace every vector with its nearest codebook entry.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Commitment loss pulls the encoder output toward its (frozen) code;
        # codebook loss pulls the code toward the (frozen) encoder output.
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        loss = commitment_loss + codebook_loss

        # Straight-through estimator: the forward pass emits the quantized
        # values, while gradients flow back to `x` as if quantization were
        # the identity.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized, encoding_indices, loss

    def get_code_indices(self, flattened_inputs):
        # Squared Euclidean distance to every codebook entry, expanded as
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b to stay fully vectorized.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings**2, axis=0)
            - 2 * similarity
        )

        # Index of the nearest codebook entry for each input vector.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices

    def get_codebook_entry(self, indices, shape):
        # Decode integer code indices back into embedding vectors and restore
        # the requested output shape.
        encodings = tf.one_hot(indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, shape)
        return quantized
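

# Minimal usage sketch (illustrative, not part of the original layer):
# quantize a random feature map, round-trip the indices through the codebook,
# and re-create the layer from its config. All shapes and hyperparameters
# below are made-up assumptions for the demo.
if __name__ == "__main__":
    vq = VectorQuantizer(num_embeddings=64, embedding_dim=16)

    # Fake encoder output: (batch, height, width, embedding_dim).
    x = tf.random.normal((2, 8, 8, 16))
    quantized, indices, vq_loss = vq(x)
    assert quantized.shape == x.shape

    # `indices` is flat (one code per spatial vector); decoding it through
    # the codebook should reproduce the quantized values up to float error.
    recovered = vq.get_codebook_entry(indices, tf.shape(x))
    tf.debugging.assert_near(recovered, quantized)

    # Config round-trip works because of get_config and the
    # register_keras_serializable decoration above.
    clone = VectorQuantizer.from_config(vq.get_config())
    assert clone.num_embeddings == vq.num_embeddings
    print("vq loss:", float(vq_loss))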