|
import numpy as np |
|
import tensorflow as tf |
|
import tensorflow_addons as tfa |
|
from ganime.configs.model_configs import GPTConfig, ModelConfig |
|
from ganime.model.vqgan_clean.vqgan import VQGAN |
|
from ganime.trainer.warmup.cosine import WarmUpCosine |
|
from tensorflow import keras |
|
from tensorflow.keras import Model, layers |
|
from transformers import TFGPT2Model, GPT2Config |
|
from tensorflow.keras import mixed_precision |
|
|
|
|
|
class Net2Net(Model):
    """Frame-by-frame video model: a VQGAN tokenizer plus a GPT-2 backbone.

    ``first_stage_model`` (a ``VQGAN``) encodes frames into codebook indices and
    decodes indices back into images. ``transformer`` (a pretrained
    ``TFGPT2Model``, "gpt2-medium") is conditioned on the last frame's indices
    followed by the previous frame's indices and produces per-token outputs
    that are treated as logits for the next frame's indices.

    The model compiles itself in ``__init__`` with an AdamW optimizer driven by
    a ``WarmUpCosine`` learning-rate schedule. ``train_step`` sums the
    transformer gradients of every target frame into non-trainable buffers and
    applies them with a single optimizer step per batch.
    """

    def __init__(
        self,
        transformer_config: GPTConfig,
        first_stage_config: ModelConfig,
        trainer_config,
        **kwargs,
    ):
        """Build the sub-models, loss, optimizer and gradient buffers, then compile.

        Args:
            transformer_config: Transformer settings. Only the optional
                ``"checkpoint_path"`` key is read here; when present, weights
                of the whole model are restored from that path.
            first_stage_config: Keyword arguments forwarded to ``VQGAN``.
            trainer_config: Mapping read by ``create_warmup_scheduler``
                (``len_x_train``, ``batch_size``, ``n_epochs``,
                ``warmup_epoch_percentage``, ``lr_start``, ``lr_max``).
            **kwargs: Forwarded to ``keras.Model``.
        """
        super().__init__(**kwargs)
        self.first_stage_model = VQGAN(**first_stage_config)
        self.transformer = TFGPT2Model.from_pretrained("gpt2-medium")

        if "checkpoint_path" in transformer_config:
            print(f"Restoring weights from {transformer_config['checkpoint_path']}")
            self.load_weights(transformer_config["checkpoint_path"])

        # Per-token losses are kept (Reduction.NONE); the train/test loops
        # average them explicitly with tf.reduce_mean.
        self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        self.loss_tracker = keras.metrics.Mean(name="loss")

        self.scheduled_lrs = self.create_warmup_scheduler(trainer_config)
        optimizer = tfa.optimizers.AdamW(
            learning_rate=self.scheduled_lrs, weight_decay=1e-4
        )
        self.compile(
            optimizer=optimizer,
            loss=self.loss_fn,
        )

        # One float32 buffer per transformer trainable variable; the gradients
        # of every frame in a batch are summed here before one optimizer step.
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
            for v in self.transformer.trainable_variables
        ]

    def create_warmup_scheduler(self, trainer_config):
        """Return a ``WarmUpCosine`` schedule sized from ``trainer_config``.

        ``total_steps`` is ``len_x_train / batch_size * n_epochs`` truncated to
        an int; ``warmup_steps`` is ``warmup_epoch_percentage`` of that,
        also truncated.
        """
        len_x_train = trainer_config["len_x_train"]
        batch_size = trainer_config["batch_size"]
        n_epochs = trainer_config["n_epochs"]

        # NOTE(review): true division then int() truncation. If the dataset
        # keeps a partial last batch, the real number of steps per epoch is
        # ceil(len_x_train / batch_size) — confirm which is intended.
        total_steps = int(len_x_train / batch_size * n_epochs)
        warmup_epoch_percentage = trainer_config["warmup_epoch_percentage"]
        warmup_steps = int(total_steps * warmup_epoch_percentage)

        return WarmUpCosine(
            lr_start=trainer_config["lr_start"],
            lr_max=trainer_config["lr_max"],
            warmup_steps=warmup_steps,
            total_steps=total_steps,
        )

    def apply_accu_gradients(self):
        """Apply the accumulated gradients to the transformer, then zero the buffers."""
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.transformer.trainable_variables)
        )

        for accumulator, variable in zip(
            self.gradient_accumulation, self.transformer.trainable_variables
        ):
            accumulator.assign(tf.zeros_like(variable, dtype=tf.float32))

    @property
    def metrics(self):
        """Metrics tracked by this model (only the mean loss)."""
        return [
            self.loss_tracker,
        ]

    @tf.function()
    def encode_to_z(self, x):
        """Encode images with the VQGAN and flatten the codebook indices per sample.

        Returns:
            ``(quant_z, indices)`` where ``indices`` is reshaped to
            ``(batch_size, -1)`` with ``batch_size = tf.shape(quant_z)[0]``.
            The quantization loss returned by ``encode`` is discarded.
        """
        quant_z, indices, _ = self.first_stage_model.encode(x)
        batch_size = tf.shape(quant_z)[0]
        indices = tf.reshape(indices, shape=(batch_size, -1))
        return quant_z, indices

    def call(self, inputs, training=None, mask=None):
        """Generate videos from ``inputs["first_frame"|"last_frame"|"n_frames"]``."""
        first_frame = inputs["first_frame"]
        last_frame = inputs["last_frame"]
        n_frames = inputs["n_frames"]
        return self.generate_video(first_frame, last_frame, n_frames)

    @tf.function()
    def predict_next_indices(self, inputs, example_indices):
        """Run the transformer and drop the outputs at conditioning positions.

        Args:
            inputs: Token indices fed to the transformer.
            example_indices: Conditioning indices; only its length along axis 1
                is used, to choose where the kept outputs start.

        Returns:
            float32 outputs from position ``tf.shape(example_indices)[1] - 1``
            onwards along axis 1.
        """
        # NOTE(review): these "logits" are GPT-2's last_hidden_state; there is
        # no projection head, so the class dimension equals the transformer's
        # hidden size. Presumably the VQGAN codebook size must not exceed it.
        outputs = self.transformer(inputs)
        logits = tf.cast(outputs.last_hidden_state, dtype=tf.float32)
        return logits[:, tf.shape(example_indices)[1] - 1 :]

    @tf.function()
    def body(self, total_loss, frames, index, last_frame_indices):
        """One ``tf.while_loop`` iteration of ``train_step`` (target frame ``index``).

        Conditions on the last frame's indices followed by the ground-truth
        previous frame's indices, computes the mean cross-entropy against the
        current frame's indices, and adds the transformer gradients into
        ``self.gradient_accumulation``.

        Returns:
            Updated loop variables
            ``(total_loss + frame_loss, frames, index + 1, last_frame_indices)``.
        """
        previous_frame_indices = self.encode_to_z(frames[:, index - 1, ...])[1]
        cz_indices = tf.concat((last_frame_indices, previous_frame_indices), axis=1)
        target_indices = self.encode_to_z(frames[:, index, ...])[1]

        with tf.GradientTape() as tape:
            logits = self.predict_next_indices(
                cz_indices[:, :-1], last_frame_indices
            )
            frame_loss = tf.cast(
                tf.reduce_mean(self.loss_fn(target_indices, logits)),
                dtype=tf.float32,
            )

        gradients = tape.gradient(frame_loss, self.transformer.trainable_variables)

        for accumulator, gradient in zip(self.gradient_accumulation, gradients):
            # A variable that does not influence the loss gets a None gradient;
            # it contributes nothing, so skip it instead of failing in tf.cast.
            if gradient is None:
                continue
            accumulator.assign_add(tf.cast(gradient, tf.float32))

        index = tf.add(index, 1)
        total_loss = tf.add(total_loss, frame_loss)
        return total_loss, frames, index, last_frame_indices

    def cond(self, total_loss, frames, index, last_frame_indices):
        """Keep the training loop running while ``index`` is a valid frame position."""
        return tf.less(index, tf.shape(frames)[1])

    def train_step(self, data):
        """Accumulate gradients over frames ``1..T-1`` of ``data["y"]``; apply once.

        Frame 0 of ``data["y"]`` only serves as the "previous frame" for
        frame 1. The tracked loss is the sum of the per-frame mean losses.
        """
        frames = data["y"]
        last_frame_indices = self.encode_to_z(data["last_frame"])[1]

        total_loss, _, _, _ = tf.while_loop(
            cond=self.cond,
            body=self.body,
            loop_vars=(tf.constant(0.0), frames, tf.constant(1), last_frame_indices),
        )

        self.apply_accu_gradients()
        self.loss_tracker.update_state(total_loss)
        return {m.name: m.result() for m in self.metrics}

    def cond_test_step(self, total_loss, frames, index, last_frame_indices):
        """Keep the evaluation loop running while ``index`` is valid for both inputs.

        In ``test_step`` the fourth loop variable is the stacked predicted
        logits, despite the parameter name. The loop also stops at the length
        of those logits, so ``predicted_logits[index]`` in ``body_test_step``
        never slices past the predicted frames when ``frames`` has more time
        steps than ``max(n_frames)``.
        """
        predicted_logits = last_frame_indices
        return tf.logical_and(
            tf.less(index, tf.shape(frames)[1]),
            tf.less(index, tf.shape(predicted_logits)[0]),
        )

    @tf.function()
    def body_test_step(self, total_loss, frames, index, predicted_logits):
        """One evaluation-loop iteration: add the mean loss of frame ``index``."""
        target_indices = self.encode_to_z(frames[:, index, ...])[1]
        logits = predicted_logits[index]

        frame_loss = tf.cast(
            tf.reduce_mean(self.loss_fn(target_indices, logits)),
            dtype=tf.float32,
        )

        index = tf.add(index, 1)
        total_loss = tf.add(total_loss, frame_loss)
        return total_loss, frames, index, predicted_logits

    def test_step(self, data):
        """Score autoregressive predictions against the ground-truth frames.

        Unlike ``train_step`` (which conditions on ground-truth previous
        frames), the logits come from ``predict_logits``, which feeds its own
        top-1 predictions back in. The tracked loss is the sum of the
        per-frame mean losses.
        """
        first_frame = data["first_frame"]
        last_frame = data["last_frame"]
        frames = data["y"]
        n_frames = data["n_frames"]

        predicted_logits, _, _ = self.predict_logits(first_frame, last_frame, n_frames)

        total_loss, _, _, _ = tf.while_loop(
            cond=self.cond_test_step,
            body=self.body_test_step,
            loop_vars=(tf.constant(0.0), frames, tf.constant(1), predicted_logits),
        )

        self.loss_tracker.update_state(total_loss)
        return {m.name: m.result() for m in self.metrics}

    @tf.function()
    def convert_logits_to_indices(self, logits, shape):
        """Pick the top-1 index along the last axis and reshape it to ``shape``."""
        # softmax is monotonic, so top-1 over probabilities equals top-1 over logits.
        probs = tf.keras.activations.softmax(logits)
        _, generated_indices = tf.math.top_k(probs)
        return tf.reshape(generated_indices, shape)

    @tf.function()
    def predict_logits(self, first_frame, last_frame, n_frames):
        """Autoregressively predict logits for frames ``1..max(n_frames)-1``.

        Each step conditions on the last frame's indices followed by the
        previous frame's indices, starting from ``first_frame``. It then feeds
        the top-1 indices of the new logits back in as the next "previous frame".

        Returns:
            ``(stacked_logits, tf.shape(quant_first), tf.shape(indices_first))``.
            Slot 0 of ``stacked_logits`` is never written by this loop.
        """
        quant_first, indices_first = self.encode_to_z(first_frame)
        _, indices_last = self.encode_to_z(last_frame)

        indices_previous = indices_first

        predicted_logits = tf.TensorArray(
            tf.float32, size=0, dynamic_size=True, clear_after_read=False
        )

        index = tf.constant(1)
        while tf.less(index, tf.reduce_max(n_frames)):
            # indices_previous comes from a reshape inside the loop, so its
            # static shape may differ from the first iteration's.
            tf.autograph.experimental.set_loop_options(
                shape_invariants=[(indices_previous, tf.TensorShape([None, None]))]
            )
            cz_indices = tf.concat((indices_last, indices_previous), axis=1)
            logits = self.predict_next_indices(cz_indices[:, :-1], indices_last)

            predicted_logits = predicted_logits.write(index, logits)
            indices_previous = self.convert_logits_to_indices(
                logits, tf.shape(indices_first)
            )
            index = tf.add(index, 1)

        return predicted_logits.stack(), tf.shape(quant_first), tf.shape(indices_first)

    @tf.function()
    def generate_video(self, first_frame, last_frame, n_frames):
        """Generate frames ``0..max(n_frames)-1`` between two key frames.

        Frame 0 is ``first_frame`` itself. Each later frame is decoded by the
        VQGAN from the top-1 indices of the matching ``predict_logits`` output.

        Returns:
            The stacked frames with axes 0 and 1 swapped via
            ``tf.transpose(..., (1, 0, 2, 3, 4))`` (batch first, then frames).
        """
        predicted_logits, quant_shape, indices_shape = self.predict_logits(
            first_frame, last_frame, n_frames
        )

        generated_images = tf.TensorArray(
            tf.float32, size=0, dynamic_size=True, clear_after_read=False
        )
        generated_images = generated_images.write(0, first_frame)

        index = tf.constant(1)
        while tf.less(index, tf.reduce_max(n_frames)):
            indices = self.convert_logits_to_indices(
                predicted_logits[index], indices_shape
            )
            quant = self.first_stage_model.quantize.get_codebook_entry(
                indices,
                shape=quant_shape,
            )
            decoded = self.first_stage_model.decode(quant)
            generated_images = generated_images.write(index, decoded)
            index = tf.add(index, 1)

        stacked_images = generated_images.stack()
        # (frames, batch, ...) -> (batch, frames, ...); the remaining three
        # axes are presumably H, W, C — TODO confirm against VQGAN.decode.
        videos = tf.transpose(stacked_images, (1, 0, 2, 3, 4))
        return videos
|
|