Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
10.6 kB
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from ganime.configs.model_configs import GPTConfig, ModelConfig
from ganime.model.vqgan_clean.vqgan import VQGAN
from ganime.trainer.warmup.cosine import WarmUpCosine
from tensorflow import keras
from tensorflow.keras import Model, layers
from transformers import TFGPT2Model, GPT2Config
from tensorflow.keras import mixed_precision
class Net2Net(Model):
    """GPT-2 over VQGAN codebook indices that predicts the frames of a video.

    A VQGAN (``first_stage_model``) turns frames into flattened codebook
    indices. A pretrained GPT-2 (``transformer``) receives the indices of the
    last frame followed by those of the previous frame, and predicts the
    indices of the next frame. Only the transformer's variables are updated
    by ``train_step``; gradients are accumulated over every frame of a batch
    and applied once per step.
    """

    def __init__(
        self,
        transformer_config: GPTConfig,
        first_stage_config: ModelConfig,
        trainer_config,
        **kwargs,
    ):
        """Build the VQGAN, the GPT-2 backbone, optimizer, metrics and accumulators.

        Args:
            transformer_config: Mapping queried with ``in`` / ``[]``; only the
                optional ``"checkpoint_path"`` key is read here, to restore
                previously saved weights of the whole model.
            first_stage_config: Keyword arguments forwarded to ``VQGAN``.
            trainer_config: Mapping with ``len_x_train``, ``batch_size``,
                ``n_epochs``, ``warmup_epoch_percentage``, ``lr_start`` and
                ``lr_max`` used to build the learning-rate schedule.
            **kwargs: Forwarded to ``keras.Model``.
        """
        super().__init__(**kwargs)
        self.first_stage_model = VQGAN(**first_stage_config)
        # NOTE(review): the pretrained "gpt2-medium" weights are always used;
        # transformer_config is not forwarded to the GPT-2 constructor.
        self.transformer = TFGPT2Model.from_pretrained("gpt2-medium")

        if "checkpoint_path" in transformer_config:
            print(f"Restoring weights from {transformer_config['checkpoint_path']}")
            self.load_weights(transformer_config["checkpoint_path"])

        # Per-token losses (no reduction); callers average them explicitly.
        self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        self.loss_tracker = keras.metrics.Mean(name="loss")

        self.scheduled_lrs = self.create_warmup_scheduler(trainer_config)
        optimizer = tfa.optimizers.AdamW(
            learning_rate=self.scheduled_lrs, weight_decay=1e-4
        )
        # Pass run_eagerly=True here to debug the step functions eagerly.
        self.compile(optimizer=optimizer, loss=self.loss_fn)

        # One float32 accumulator per transformer variable: per-frame gradients
        # are summed here and applied once per train step.
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
            for v in self.transformer.trainable_variables
        ]

    def create_warmup_scheduler(self, trainer_config):
        """Return a ``WarmUpCosine`` schedule sized from ``trainer_config``.

        ``total_steps`` is ``len_x_train / batch_size * n_epochs`` and
        ``warmup_steps`` is ``warmup_epoch_percentage`` of it; both are
        truncated to ``int``.
        """
        len_x_train = trainer_config["len_x_train"]
        batch_size = trainer_config["batch_size"]
        n_epochs = trainer_config["n_epochs"]
        total_steps = int(len_x_train / batch_size * n_epochs)
        warmup_epoch_percentage = trainer_config["warmup_epoch_percentage"]
        warmup_steps = int(total_steps * warmup_epoch_percentage)
        return WarmUpCosine(
            lr_start=trainer_config["lr_start"],
            lr_max=trainer_config["lr_max"],
            warmup_steps=warmup_steps,
            total_steps=total_steps,
        )

    def apply_accu_gradients(self):
        """Apply the accumulated gradients to the transformer, then zero them."""
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.transformer.trainable_variables)
        )
        for accumulator in self.gradient_accumulation:
            accumulator.assign(tf.zeros_like(accumulator))

    @property
    def metrics(self):
        """Metrics that Keras resets at the start of each epoch and of evaluate()."""
        return [
            self.loss_tracker,
        ]

    @tf.function()
    def encode_to_z(self, x):
        """Encode images with the VQGAN.

        Returns:
            ``(quant_z, indices)``, where ``indices`` is reshaped to
            ``(batch, -1)``. The VQGAN's quantization loss is discarded.
        """
        quant_z, indices, _ = self.first_stage_model.encode(x)
        batch_size = tf.shape(quant_z)[0]
        indices = tf.reshape(indices, shape=(batch_size, -1))
        return quant_z, indices

    def call(self, inputs, training=None, mask=None):
        """Generate videos from a dict with "first_frame", "last_frame" and "n_frames"."""
        first_frame = inputs["first_frame"]
        last_frame = inputs["last_frame"]
        n_frames = inputs["n_frames"]
        return self.generate_video(first_frame, last_frame, n_frames)

    @tf.function()
    def predict_next_indices(self, inputs, example_indices):
        """Run GPT-2 on ``inputs`` and keep only the outputs for the predicted frame.

        Args:
            inputs: ``concat(example_indices, previous_indices)[:, :-1]``,
                i.e. the conditioning sequence with its final token dropped.
            example_indices: Indices of the conditioning frame; only its length
                is used, to drop the outputs over the conditioning part.

        Returns:
            float32 per-position scores. With L conditioning tokens and P
            previous-frame tokens the input has L + P - 1 positions, so slicing
            from L - 1 leaves exactly P positions (one per predicted token).
        """
        outputs = self.transformer(inputs)
        # NOTE(review): `last_hidden_state` is the transformer's hidden state,
        # used here directly as class scores over the codebook; presumably the
        # hidden size is meant to equal the codebook size -- confirm.
        logits = tf.cast(outputs.last_hidden_state, dtype=tf.float32)
        return logits[:, tf.shape(example_indices)[1] - 1 :]

    @tf.function()
    def body(self, total_loss, frames, index, last_frame_indices):
        """One ``tf.while_loop`` iteration of teacher-forced training.

        Predicts frame ``index`` from the ground-truth frame ``index - 1`` and
        the last frame, adds this frame's gradients to the accumulators and its
        mean token loss to ``total_loss``.
        """
        previous_frame_indices = self.encode_to_z(frames[:, index - 1, ...])[1]
        cz_indices = tf.concat((last_frame_indices, previous_frame_indices), axis=1)
        target_indices = self.encode_to_z(frames[:, index, ...])[1]

        with tf.GradientTape() as tape:
            logits = self.predict_next_indices(cz_indices[:, :-1], last_frame_indices)
            frame_loss = tf.cast(
                tf.reduce_mean(self.loss_fn(target_indices, logits)),
                dtype=tf.float32,
            )

        gradients = tape.gradient(frame_loss, self.transformer.trainable_variables)
        for accumulator, gradient in zip(self.gradient_accumulation, gradients):
            # tape.gradient yields None for variables not connected to the
            # loss; skip those instead of failing on tf.cast(None).
            if gradient is not None:
                accumulator.assign_add(tf.cast(gradient, tf.float32))

        index = tf.add(index, 1)
        total_loss = tf.add(total_loss, frame_loss)
        return total_loss, frames, index, last_frame_indices

    def cond(self, total_loss, frames, index, last_frame_indices):
        """Loop while ``index`` is below the number of frames (axis 1 of ``frames``)."""
        return tf.less(index, tf.shape(frames)[1])

    def train_step(self, data):
        """Accumulate gradients over frames 1 .. T-1 of the batch, then apply them once.

        ``data`` must provide "last_frame" and "y" (frames stacked on axis 1).
        The tracked loss is the sum of the per-frame mean losses.
        """
        last_frame_indices = self.encode_to_z(data["last_frame"])[1]
        total_loss, _, _, _ = tf.while_loop(
            cond=self.cond,
            body=self.body,
            loop_vars=(tf.constant(0.0), data["y"], tf.constant(1), last_frame_indices),
        )
        self.apply_accu_gradients()
        self.loss_tracker.update_state(total_loss)
        return {m.name: m.result() for m in self.metrics}

    def cond_test_step(self, total_loss, frames, index, last_frame_indices):
        """Same stopping rule as ``cond``: loop over the frames on axis 1."""
        return self.cond(total_loss, frames, index, last_frame_indices)

    @tf.function()
    def body_test_step(self, total_loss, frames, index, predicted_logits):
        """Add the mean token loss of frame ``index`` against its prediction."""
        target_indices = self.encode_to_z(frames[:, index, ...])[1]
        logits = predicted_logits[index]
        frame_loss = tf.cast(
            tf.reduce_mean(self.loss_fn(target_indices, logits)),
            dtype=tf.float32,
        )
        index = tf.add(index, 1)
        total_loss = tf.add(total_loss, frame_loss)
        return total_loss, frames, index, predicted_logits

    def test_step(self, data):
        """Evaluate autoregressive (non teacher-forced) predictions against "y".

        NOTE(review): predictions cover indices below ``reduce_max(n_frames)``
        while the loop runs to ``tf.shape(data["y"])[1]``; these presumably
        match -- confirm.
        """
        first_frame = data["first_frame"]
        last_frame = data["last_frame"]
        frames = data["y"]
        n_frames = data["n_frames"]
        predicted_logits, _, _ = self.predict_logits(first_frame, last_frame, n_frames)
        total_loss, _, _, _ = tf.while_loop(
            cond=self.cond_test_step,
            body=self.body_test_step,
            loop_vars=(tf.constant(0.0), frames, tf.constant(1), predicted_logits),
        )
        self.loss_tracker.update_state(total_loss)
        return {m.name: m.result() for m in self.metrics}

    @tf.function()
    def convert_logits_to_indices(self, logits, shape):
        """Greedy decoding: take the top-1 class per position, reshaped to ``shape``."""
        # softmax is monotonic, so top-1 over the raw logits picks the same
        # class (ties resolved to the lower index) without computing probs.
        _, generated_indices = tf.math.top_k(logits)
        return tf.reshape(generated_indices, shape)

    @tf.function()
    def predict_logits(self, first_frame, last_frame, n_frames):
        """Autoregressively predict scores for frames 1 .. max(n_frames) - 1.

        Each step conditions on the last frame and on the indices greedily
        decoded at the previous step (the first frame's indices initially).

        Returns:
            ``(stacked_logits, shape of quant_first, shape of indices_first)``.
            ``stacked_logits[i]`` holds the prediction for frame ``i``.
            NOTE(review): slot 0 of the TensorArray is never written.
        """
        quant_first, indices_first = self.encode_to_z(first_frame)
        _, indices_last = self.encode_to_z(last_frame)
        indices_previous = indices_first
        predicted_logits = tf.TensorArray(
            tf.float32, size=0, dynamic_size=True, clear_after_read=False
        )
        index = tf.constant(1)
        while tf.less(index, tf.reduce_max(n_frames)):
            # The decoded indices may lose their static shape, so relax it.
            tf.autograph.experimental.set_loop_options(
                shape_invariants=[(indices_previous, tf.TensorShape([None, None]))]
            )
            cz_indices = tf.concat((indices_last, indices_previous), axis=1)
            logits = self.predict_next_indices(cz_indices[:, :-1], indices_last)
            predicted_logits = predicted_logits.write(index, logits)
            indices_previous = self.convert_logits_to_indices(
                logits, tf.shape(indices_first)
            )
            index = tf.add(index, 1)
        return predicted_logits.stack(), tf.shape(quant_first), tf.shape(indices_first)

    @tf.function()
    def generate_video(self, first_frame, last_frame, n_frames):
        """Decode the predicted indices into frames and return batch-major videos.

        Frame 0 is ``first_frame`` itself. Frames 1 .. max(n_frames) - 1 are
        decoded from the greedy indices through the VQGAN codebook and decoder.
        The stacked ``(time, batch, ...)`` tensor is transposed to
        ``(batch, time, ...)``.
        """
        predicted_logits, quant_shape, indices_shape = self.predict_logits(
            first_frame, last_frame, n_frames
        )
        generated_images = tf.TensorArray(
            tf.float32, size=0, dynamic_size=True, clear_after_read=False
        )
        generated_images = generated_images.write(0, first_frame)
        index = tf.constant(1)
        while tf.less(index, tf.reduce_max(n_frames)):
            indices = self.convert_logits_to_indices(predicted_logits[index], indices_shape)
            quant = self.first_stage_model.quantize.get_codebook_entry(
                indices,
                shape=quant_shape,
            )
            decoded = self.first_stage_model.decode(quant)
            generated_images = generated_images.write(index, decoded)
            index = tf.add(index, 1)
        stacked_images = generated_images.stack()
        return tf.transpose(stacked_images, (1, 0, 2, 3, 4))