|
import numpy as np |
|
import tensorflow as tf |
|
import tensorflow_addons as tfa |
|
from ganime.configs.model_configs import GPTConfig, ModelConfig |
|
from ganime.model.vqgan_clean.experimental.transformer import Transformer |
|
from ganime.model.vqgan_clean.vqgan import VQGAN |
|
from ganime.trainer.warmup.cosine import WarmUpCosine |
|
from tensorflow import keras |
|
from tensorflow.keras import Model, layers |
|
from ganime.model.vqgan_clean.losses.losses import Losses |
|
from ganime.trainer.warmup.base import create_warmup_scheduler |
|
from ganime.visualization.images import unnormalize_if_necessary |
|
|
|
|
|
class Net2Net(Model):
    """Autoregressive video prediction in a learned discrete latent space.

    A VQGAN first stage encodes frames into grids of codebook indices; a
    GPT-style Transformer then predicts the index grid of each successive
    frame, conditioned on the last (target) frame, the preceding frames and
    the number of frames still to generate. Decoding the predicted indices
    through the VQGAN decoder yields the output frames.

    Training uses manual gradient accumulation: gradients from
    ``accumulation_size`` sub-batches are summed into per-variable
    accumulators and applied in a single optimizer step per ``train_step``.
    """

    def __init__(
        self,
        transformer_config: GPTConfig,
        first_stage_config: ModelConfig,
        trainer_config,
        num_replicas: int = 1,
        **kwargs,
    ):
        """Build the two-stage model, optimizer, scheduler, and metrics.

        Args:
            transformer_config: GPT configuration; may contain an optional
                ``checkpoint_path`` entry from which weights are restored.
            first_stage_config: keyword arguments for the ``VQGAN`` first
                stage.
            trainer_config: mapping providing at least ``n_frames_before``,
                ``accumulation_size``, ``perceptual_loss_weight`` and
                ``stop_ground_truth_after_epoch``, plus whatever
                ``create_warmup_scheduler`` expects.
            num_replicas: number of devices; forwarded to the LR scheduler
                and the loss helpers for distributed scaling.
        """
        super().__init__(**kwargs)
        self.first_stage_model = VQGAN(**first_stage_config)
        self.transformer = Transformer(transformer_config)

        if "checkpoint_path" in transformer_config:
            # NOTE(review): load_weights is called before any forward pass
            # has built the model's variables; this relies on Keras deferred
            # restoration — confirm the checkpoint is actually applied.
            print(f"Restoring weights from {transformer_config['checkpoint_path']}")
            self.load_weights(transformer_config["checkpoint_path"])

        self.scheduled_lrs = create_warmup_scheduler(
            trainer_config, num_devices=num_replicas
        )

        optimizer = tfa.optimizers.AdamW(
            learning_rate=self.scheduled_lrs, weight_decay=1e-4
        )
        self.compile(
            optimizer=optimizer,
        )

        # Number of already-available frames fed back as context when
        # predicting the next one.
        self.n_frames_before = trainer_config["n_frames_before"]

        # One float32 accumulator per transformer variable; gradients are
        # summed here across sub-batches and applied once per train_step.
        # NOTE(review): assumes self.transformer.trainable_variables is
        # already populated at construction time — confirm the Transformer
        # builds its weights in __init__.
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
            for v in self.transformer.trainable_variables
        ]
        self.accumulation_size = trainer_config["accumulation_size"]

        # Weight applied to the perceptual term of the frame loss (see
        # predict_next_frame_body).
        self.perceptual_loss_weight = trainer_config["perceptual_loss_weight"]
        losses = Losses(num_replicas=num_replicas)
        self.scce_loss = losses.scce_loss
        self.perceptual_loss = losses.perceptual_loss

        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.scce_loss_tracker = keras.metrics.Mean(name="scce_loss")
        self.perceptual_loss_tracker = keras.metrics.Mean(name="perceptual_loss")

        # NOTE(review): this counter is incremented once per train_step (see
        # train_step), so despite its name it counts optimizer steps, not
        # epochs — confirm the intended unit of
        # stop_ground_truth_after_epoch.
        self.epoch = 0
        self.stop_ground_truth_after_epoch = trainer_config[
            "stop_ground_truth_after_epoch"
        ]

    def apply_accu_gradients(self):
        """Apply the accumulated gradients once, then reset the accumulators."""
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.transformer.trainable_variables)
        )

        # Zero every accumulator for the next accumulation window.
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign(
                tf.zeros_like(self.transformer.trainable_variables[i], dtype=tf.float32)
            )

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset their state
        # automatically at the start of each epoch and of evaluate().
        return [
            self.total_loss_tracker,
            self.scce_loss_tracker,
            self.perceptual_loss_tracker,
        ]

    @tf.function()
    def encode_to_z(self, x):
        """Encode a batch of frames to (quantized latents, flat code indices).

        The VQGAN quantization loss is intentionally discarded here; only the
        quantized latents and their codebook indices (flattened to one
        sequence per sample) are returned.
        """
        quant_z, indices, quantized_loss = self.first_stage_model.encode(x)

        batch_size = tf.shape(quant_z)[0]

        # Flatten the spatial grid of indices into a single sequence.
        indices = tf.reshape(indices, shape=(batch_size, -1))
        return quant_z, indices

    def call(self, inputs, training=False, mask=None, return_losses=False):
        """Forward pass: delegate to the frame-by-frame generation loop."""
        return self.predict_video(inputs, training, return_losses)

    def predict(self, data, sample=False, temperature=1.0):
        """Generate a video and un-normalize it for display.

        NOTE(review): this shadows ``keras.Model.predict`` with an
        incompatible signature — confirm no caller relies on the Keras
        batched-prediction behavior.
        """
        video = self.predict_video(
            data,
            training=False,
            return_losses=False,
            sample=sample,
            temperature=temperature,
        )
        video = unnormalize_if_necessary(video)
        return video

    def get_remaining_frames(self, inputs):
        """Return the per-sample remaining-frame counts as int64.

        Raises:
            NotImplementedError: if ``inputs`` does not provide a
                ``"remaining_frames"`` entry (no fallback is implemented).
        """
        if "remaining_frames" in inputs:
            remaining_frames = inputs["remaining_frames"]
        else:
            raise NotImplementedError
        remaining_frames = tf.cast(remaining_frames, tf.int64)
        return remaining_frames

    def predict_video(
        self, inputs, training=False, return_losses=False, sample=False, temperature=1.0
    ):
        """Generate a video frame by frame.

        Args:
            inputs: mapping with ``first_frame``, ``last_frame``,
                ``n_frames``, ``remaining_frames`` and (during training)
                ``y`` holding the ground-truth frames.
            training: when True, teacher forcing is used until
                ``stop_ground_truth_after_epoch`` and gradients are
                accumulated per frame.
            return_losses: also return the per-video mean losses.
            sample: sample indices from the logits instead of argmax.
            temperature: softmax temperature used when ``sample`` is True.

        Returns:
            The predicted frames of shape (batch, time, H, W, C), optionally
            followed by (total_loss, scce_loss, perceptual_loss).
        """
        first_frame = inputs["first_frame"]
        last_frame = inputs["last_frame"]
        n_frames = tf.reduce_max(inputs["n_frames"])
        remaining_frames = self.get_remaining_frames(inputs)

        # Fix: ``inputs`` is a dict, so a missing "y" raises KeyError, which
        # the previous ``except AttributeError`` did not catch — inference
        # without ground truth crashed instead of falling back to None.
        try:
            ground_truth = inputs["y"]
        except (AttributeError, KeyError):
            ground_truth = None

        previous_frames = tf.expand_dims(first_frame, axis=1)

        # NOTE(review): index 0 of this TensorArray is never written (the
        # loop starts at frame 1), so stack() yields zeros for frame 0 —
        # confirm that is intended.
        predictions = tf.TensorArray(
            tf.float32, size=0, dynamic_size=True, clear_after_read=False
        )

        quant_last, indices_last = self.encode_to_z(last_frame)

        total_loss = tf.constant(0.0)
        scce_loss = tf.constant(0.0)
        perceptual_loss = tf.constant(0.0)

        current_frame_index = tf.constant(1)
        while tf.less(current_frame_index, n_frames):
            # previous_frames' time dimension varies across iterations, so
            # relax its shape for AutoGraph.
            tf.autograph.experimental.set_loop_options(
                shape_invariants=[
                    (previous_frames, tf.TensorShape([None, None, None, None, 3]))
                ],
            )

            if ground_truth is not None:
                target_frame = ground_truth[:, current_frame_index]
            else:
                target_frame = None

            y_pred, losses = self.predict_next_frame(
                remaining_frames[:, current_frame_index],
                previous_frames,
                last_frame,
                indices_last,
                quant_last,
                target_frame=target_frame,
                training=training,
                sample=sample,
                temperature=temperature,
            )
            predictions = predictions.write(current_frame_index, y_pred)

            if training and self.epoch < self.stop_ground_truth_after_epoch:
                # Teacher forcing: condition on the ground-truth frames.
                start_index = tf.math.maximum(
                    0, current_frame_index - self.n_frames_before
                )
                previous_frames = ground_truth[
                    :, start_index + 1 : current_frame_index + 1
                ]
            else:
                # Free-running: condition on the model's own last
                # n_frames_before predictions.
                previous_frames = predictions.stack()
                previous_frames = tf.transpose(previous_frames, (1, 0, 2, 3, 4))
                previous_frames = previous_frames[:, -self.n_frames_before :]

            current_frame_index = tf.add(current_frame_index, 1)
            total_loss = tf.add(total_loss, losses[0])
            scce_loss = tf.add(scce_loss, losses[1])
            perceptual_loss = tf.add(perceptual_loss, losses[2])

        # (time, batch, H, W, C) -> (batch, time, H, W, C)
        predictions = predictions.stack()
        predictions = tf.transpose(predictions, (1, 0, 2, 3, 4))

        # Average the per-frame losses over the video length.
        total_loss = tf.divide(total_loss, tf.cast(n_frames, tf.float32))
        scce_loss = tf.divide(scce_loss, tf.cast(n_frames, tf.float32))
        perceptual_loss = tf.divide(perceptual_loss, tf.cast(n_frames, tf.float32))

        if return_losses:
            return predictions, total_loss, scce_loss, perceptual_loss
        else:
            return predictions

    def predict_next_frame(
        self,
        remaining_frames,
        previous_frames,
        last_frame,
        indices_last,
        quant_last,
        target_frame=None,
        training=False,
        sample=False,
        temperature=1.0,
    ):
        """Encode the context frames and predict the next frame.

        Dispatches to the gradient-accumulating training path or to the
        plain inference path, returning ``(next_frame, (total, scce,
        perceptual))``.
        """
        # (batch, time, ...) -> (time, batch, ...) so map_fn encodes one
        # time step (a full batch of frames) per iteration.
        previous_frames = tf.transpose(previous_frames, (1, 0, 2, 3, 4))

        indices_previous = tf.map_fn(
            lambda x: self.encode_to_z(x)[1],
            previous_frames,
            fn_output_signature=tf.int64,
        )

        # Back to (batch, time, n_z).
        indices_previous = tf.transpose(indices_previous, (1, 0, 2))

        batch_size, n_frames, n_z = (
            tf.shape(indices_previous)[0],
            tf.shape(indices_previous)[1],
            tf.shape(indices_previous)[2],
        )
        # Concatenate all context frames into one token sequence per sample.
        indices_previous = tf.reshape(
            indices_previous, shape=(batch_size, n_frames * n_z)
        )

        if target_frame is not None:
            _, target_indices = self.encode_to_z(target_frame)
        else:
            target_indices = None

        if training:
            next_frame, losses = self.train_predict_next_frame(
                remaining_frames,
                indices_last,
                indices_previous,
                target_indices=target_indices,
                target_frame=target_frame,
                quant_shape=tf.shape(quant_last),
                indices_shape=tf.shape(indices_last),
            )
        else:
            next_frame, losses = self.predict_next_frame_body(
                remaining_frames,
                indices_last,
                indices_previous,
                target_indices=target_indices,
                target_frame=target_frame,
                quant_shape=tf.shape(quant_last),
                indices_shape=tf.shape(indices_last),
                sample=sample,
                temperature=temperature,
            )

        return next_frame, losses

    def predict_next_frame_body(
        self,
        remaining_frames,
        last_frame_indices,
        previous_frame_indices,
        quant_shape,
        indices_shape,
        target_indices=None,
        target_frame=None,
        sample=False,
        temperature=1.0,
    ):
        """Run the transformer, decode the next frame, and compute losses.

        Losses are 0.0 for any target (indices / frame) that is not
        provided, e.g. during pure inference.
        """
        logits = self.transformer(
            (remaining_frames, last_frame_indices, previous_frame_indices)
        )
        next_frame = self.convert_logits_to_image(
            logits,
            quant_shape=quant_shape,
            indices_shape=indices_shape,
            sample=sample,
            temperature=temperature,
        )
        if target_indices is not None:
            scce_loss = self.scce_loss(target_indices, logits)
        else:
            scce_loss = 0.0

        if target_frame is not None:
            # Fix: use the configured weight instead of a hard-coded 1.0 —
            # perceptual_loss_weight was read from trainer_config in
            # __init__ but never applied.
            perceptual_loss = self.perceptual_loss_weight * self.perceptual_loss(
                target_frame, next_frame
            )
        else:
            perceptual_loss = 0.0

        frame_loss = scce_loss + perceptual_loss

        return next_frame, (frame_loss, scce_loss, perceptual_loss)

    def train_predict_next_frame(
        self,
        remaining_frames,
        last_frame_indices,
        previous_frame_indices,
        quant_shape,
        indices_shape,
        target_indices,
        target_frame,
    ):
        """Training path: predict the next frame and accumulate gradients.

        Gradients of the combined frame loss w.r.t. the transformer's
        variables are added into ``self.gradient_accumulation``; they are
        applied later by ``apply_accu_gradients``.
        """
        with tf.GradientTape() as tape:
            next_frame, losses = self.predict_next_frame_body(
                remaining_frames=remaining_frames,
                last_frame_indices=last_frame_indices,
                previous_frame_indices=previous_frame_indices,
                target_indices=target_indices,
                quant_shape=quant_shape,
                indices_shape=indices_shape,
                target_frame=target_frame,
                sample=False,
            )
            frame_loss = losses[0]

        gradients = tape.gradient(frame_loss, self.transformer.trainable_variables)

        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign_add(tf.cast(gradients[i], tf.float32))

        return next_frame, losses

    def convert_logits_to_image(
        self, logits, quant_shape, indices_shape, sample=False, temperature=1.0
    ):
        """Turn transformer logits into a decoded frame.

        When ``sample`` is True, each position is drawn from the
        temperature-scaled top-50 distribution; otherwise the argmax index
        is taken. The chosen indices are mapped back through the VQGAN
        codebook and decoded to pixels.
        """
        if sample:
            array = []
            for i in range(logits.shape[1]):
                sub_logits = logits[:, i]
                sub_logits = sub_logits / temperature

                # Top-k filtering, then categorical sampling among the kept
                # candidates; gather recovers the original codebook index.
                probs = tf.keras.activations.softmax(sub_logits)
                probs, probs_index = tf.math.top_k(probs, k=50)
                selection_index = tf.random.categorical(
                    tf.math.log(probs), num_samples=1
                )
                ix = tf.gather_nd(probs_index, selection_index, batch_dims=1)
                ix = tf.reshape(ix, (-1, 1))
                array.append(ix)
            generated_indices = tf.concat(array, axis=-1)
        else:
            # Greedy decoding: take the most likely code at each position.
            probs = tf.keras.activations.softmax(logits)
            _, generated_indices = tf.math.top_k(probs)

        generated_indices = tf.reshape(
            generated_indices,
            indices_shape,
        )
        quant = self.first_stage_model.quantize.get_codebook_entry(
            generated_indices, shape=quant_shape
        )

        return self.first_stage_model.decode(quant)

    def train_step(self, data):
        """One optimizer step over ``accumulation_size`` sub-batches."""
        batch_total_loss, batch_scce_loss, batch_perceptual_loss = 0.0, 0.0, 0.0
        for i in range(self.accumulation_size):
            # NOTE(review): the slice stride equals accumulation_size, so
            # each sub-batch has accumulation_size samples and the full
            # batch must hold accumulation_size**2 samples — confirm this
            # matches the data pipeline.
            sub_data = {
                key: value[
                    self.accumulation_size * i : self.accumulation_size * (i + 1)
                ]
                for key, value in data.items()
            }
            _, total_loss, scce_loss, perceptual_loss = self(
                sub_data, training=True, return_losses=True
            )
            batch_total_loss += total_loss
            batch_scce_loss += scce_loss
            batch_perceptual_loss += perceptual_loss

        self.apply_accu_gradients()
        self.total_loss_tracker.update_state(batch_total_loss)
        self.scce_loss_tracker.update_state(batch_scce_loss)
        self.perceptual_loss_tracker.update_state(batch_perceptual_loss)
        # NOTE(review): incremented per step, not per epoch — see __init__.
        self.epoch += 1
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluation step: forward pass only, update the loss trackers."""
        _, total_loss, scce_loss, perceptual_loss = self(
            data, training=False, return_losses=True
        )

        self.total_loss_tracker.update_state(total_loss)
        self.scce_loss_tracker.update_state(scce_loss)
        self.perceptual_loss_tracker.update_state(perceptual_loss)
        return {m.name: m.result() for m in self.metrics}
|
|