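"""Net2Net: two-stage video generation.

A first-stage VQGAN encodes frames into discrete codebook indices, and a
GPT-style Transformer predicts the index map of each frame, conditioned on
the indices of the previous frame and of the target last frame. Only the
Transformer's weights are updated by this training loop.
"""
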
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import Model

from ganime.configs.model_configs import GPTConfig, ModelConfig
from ganime.model.vqgan_clean.experimental.transformer import Transformer
from ganime.model.vqgan_clean.losses.losses import Losses
from ganime.model.vqgan_clean.vqgan import VQGAN
from ganime.trainer.warmup.base import create_warmup_scheduler

class Net2Net(Model):
def __init__(
self,
transformer_config: GPTConfig,
first_stage_config: ModelConfig,
trainer_config,
num_replicas: int = 1,
**kwargs,
):
super().__init__(**kwargs)
self.first_stage_model = VQGAN(**first_stage_config)
self.transformer = Transformer(transformer_config)
if "checkpoint_path" in transformer_config:
print(f"Restoring weights from {transformer_config['checkpoint_path']}")
self.load_weights(transformer_config["checkpoint_path"])
losses = Losses(num_replicas=num_replicas)
self.scce_loss = losses.scce_loss
self.scheduled_lrs = create_warmup_scheduler(
trainer_config, num_devices=num_replicas
)
optimizer = tfa.optimizers.AdamW(
learning_rate=self.scheduled_lrs, weight_decay=1e-4
)
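        # Note: tfa.optimizers.AdamW decouples weight decay from the loss, and
        # its docs recommend scheduling weight_decay alongside a scheduled
        # learning rate; here the decay is a fixed 1e-4.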
self.compile(
optimizer=optimizer,
loss=self.loss_fn,
# run_eagerly=True,
)
        # Gradient accumulation: one non-trainable buffer per transformer
        # variable. Per-frame gradients are summed into these buffers during
        # train_step and applied once per batch (this assumes the transformer's
        # variables already exist at construction time).
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
            for v in self.transformer.trainable_variables
        ]
self.loss_tracker = keras.metrics.Mean(name="loss")
def loss_fn(self, logits_true, logits_pred):
frame_loss = self.scce_loss(logits_true, logits_pred)
return frame_loss
def apply_accu_gradients(self):
# apply accumulated gradients
self.optimizer.apply_gradients(
zip(self.gradient_accumulation, self.transformer.trainable_variables)
)
# reset
for i in range(len(self.gradient_accumulation)):
self.gradient_accumulation[i].assign(
tf.zeros_like(self.transformer.trainable_variables[i], dtype=tf.float32)
)
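    # Note: the buffers hold a *sum* of per-frame gradients and are applied
    # without averaging, so the effective step size scales with the number of
    # frames accumulated per batch.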
@property
def metrics(self):
# We list our `Metric` objects here so that `reset_states()` can be
# called automatically at the start of each epoch
# or at the start of `evaluate()`.
# If you don't implement this property, you have to call
# `reset_states()` yourself at the time of your choosing.
return [
self.loss_tracker,
]
@tf.function()
def encode_to_z(self, x):
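        """Encode a batch of frames with the first-stage VQGAN and flatten the
        per-frame codebook indices to shape (batch_size, sequence_length).

        The VQGAN itself is not updated by this model's training loop.
        """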
quant_z, indices, quantized_loss = self.first_stage_model.encode(x)
batch_size = tf.shape(quant_z)[0]
indices = tf.reshape(indices, shape=(batch_size, -1))
return quant_z, indices
def call(self, inputs, training=None, mask=None):
first_frame = inputs["first_frame"]
last_frame = inputs["last_frame"]
n_frames = inputs["n_frames"]
return self.generate_video(first_frame, last_frame, n_frames)
@tf.function()
def predict_next_indices(self, inputs):
logits = self.transformer(inputs)
return logits
@tf.function()
    def body(self, total_loss, frames, index, last_frame_indices):
        # Teacher forcing: condition on the indices of the ground-truth
        # previous frame rather than on the model's own prediction.
        previous_frame_indices = self.encode_to_z(frames[:, index - 1, ...])[1]
        target_indices = self.encode_to_z(frames[:, index, ...])[1]
with tf.GradientTape() as tape:
logits = self.predict_next_indices(
(last_frame_indices, previous_frame_indices)
)
frame_loss = self.loss_fn(target_indices, logits)
# Calculate batch gradients
gradients = tape.gradient(frame_loss, self.transformer.trainable_variables)
# Accumulate batch gradients
for i in range(len(self.gradient_accumulation)):
self.gradient_accumulation[i].assign_add(tf.cast(gradients[i], tf.float32))
index = tf.add(index, 1)
total_loss = tf.add(total_loss, frame_loss)
return total_loss, frames, index, last_frame_indices
def cond(self, total_loss, frames, index, last_frame_indices):
return tf.less(index, tf.shape(frames)[1])
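    # train_step walks the clip with tf.while_loop: `cond` checks the frame
    # index against the clip length, and `body` adds one frame's loss and
    # gradients per iteration, threading (total_loss, frames, index,
    # last_frame_indices) through the loop variables.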
def train_step(self, data):
        last_frame = data["last_frame"]
        frames = data["y"]
        last_frame_indices = self.encode_to_z(last_frame)[1]
total_loss, _, _, _ = tf.while_loop(
cond=self.cond,
body=self.body,
loop_vars=(tf.constant(0.0), frames, tf.constant(1), last_frame_indices),
)
self.apply_accu_gradients()
self.loss_tracker.update_state(total_loss)
return {m.name: m.result() for m in self.metrics}
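    # The data pipeline is expected to yield dicts of batched tensors with
    # keys "first_frame", "last_frame", "y" (the full clip, shaped
    # (batch, time, height, width, channels)) and "n_frames". A minimal
    # sketch, assuming such a tf.data.Dataset exists:
    #
    #   model.fit(dataset, epochs=10)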
def cond_test_step(self, total_loss, frames, index, last_frame_indices):
return tf.less(index, tf.shape(frames)[1])
@tf.function()
def body_test_step(self, total_loss, frames, index, predicted_logits):
        target_indices = self.encode_to_z(frames[:, index, ...])[1]
        logits = predicted_logits[index]
frame_loss = self.loss_fn(target_indices, logits)
index = tf.add(index, 1)
total_loss = tf.add(total_loss, frame_loss)
return total_loss, frames, index, predicted_logits
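    # Unlike training, which teacher-forces on ground-truth previous frames,
    # evaluation scores logits produced autoregressively by predict_logits,
    # so each step is conditioned on the model's own previous prediction.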
def test_step(self, data):
first_frame = data["first_frame"]
last_frame = data["last_frame"]
frames = data["y"]
n_frames = data["n_frames"]
predicted_logits, _, _ = self.predict_logits(first_frame, last_frame, n_frames)
total_loss, _, _, _ = tf.while_loop(
cond=self.cond_test_step,
body=self.body_test_step,
loop_vars=(tf.constant(0.0), frames, tf.constant(1), predicted_logits),
)
self.loss_tracker.update_state(total_loss)
return {m.name: m.result() for m in self.metrics}
    @tf.function()
    def convert_logits_to_indices(self, logits, shape):
        probs = tf.keras.activations.softmax(logits)
        # top_k with the default k=1 performs greedy (argmax) decoding over
        # the codebook distribution.
        _, generated_indices = tf.math.top_k(probs)
        generated_indices = tf.reshape(generated_indices, shape)
        return generated_indices
@tf.function()
def predict_logits(self, first_frame, last_frame, n_frames):
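        """Autoregressively predict codebook logits for each frame after the
        first, conditioning every step on the last-frame indices and on the
        indices decoded from the previous step's logits.

        Returns the stacked per-step logits (slot 0 is an unwritten, zero
        placeholder), the shape of the quantized encoding, and the shape of
        the index map.
        """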
quant_first, indices_first = self.encode_to_z(first_frame)
        _, indices_last = self.encode_to_z(last_frame)
indices_previous = indices_first
predicted_logits = tf.TensorArray(
tf.float32, size=0, dynamic_size=True, clear_after_read=False
)
index = tf.constant(1)
while tf.less(index, tf.reduce_max(n_frames)):
tf.autograph.experimental.set_loop_options(
shape_invariants=[(indices_previous, tf.TensorShape([None, None]))]
)
logits = self.predict_next_indices((indices_last, indices_previous))
predicted_logits = predicted_logits.write(index, logits)
indices_previous = self.convert_logits_to_indices(
logits, tf.shape(indices_first)
)
index = tf.add(index, 1)
return predicted_logits.stack(), tf.shape(quant_first), tf.shape(indices_first)
@tf.function()
def generate_video(self, first_frame, last_frame, n_frames):
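        """Decode the predicted logits into a video: each step's greedy
        indices are mapped back to codebook entries, decoded by the VQGAN
        decoder, and stacked into shape (batch, time, height, width,
        channels), with the given first frame at time step 0.
        """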
predicted_logits, quant_shape, indices_shape = self.predict_logits(
first_frame, last_frame, n_frames
)
generated_images = tf.TensorArray(
tf.float32, size=0, dynamic_size=True, clear_after_read=False
)
generated_images = generated_images.write(0, first_frame)
index = tf.constant(1)
while tf.less(index, tf.reduce_max(n_frames)):
indices = self.convert_logits_to_indices(
predicted_logits[index], indices_shape
)
quant = self.first_stage_model.quantize.get_codebook_entry(
indices,
shape=quant_shape,
)
decoded = self.first_stage_model.decode(quant)
generated_images = generated_images.write(index, decoded)
index = tf.add(index, 1)
stacked_images = generated_images.stack()
videos = tf.transpose(stacked_images, (1, 0, 2, 3, 4))
return videos
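
# Example usage (a minimal sketch; the config objects, their fields, and the
# distribution strategy are assumptions based on the constructor signature,
# not values shipped with this repository):
#
#   model = Net2Net(
#       transformer_config=transformer_config,   # GPTConfig-compatible mapping
#       first_stage_config=first_stage_config,   # ModelConfig for the VQGAN
#       trainer_config=trainer_config,           # consumed by create_warmup_scheduler
#       num_replicas=strategy.num_replicas_in_sync,
#   )
#   model.fit(train_dataset, validation_data=val_dataset, epochs=epochs)
#   videos = model.generate_video(first_frame, last_frame, n_frames)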