import tensorflow as tf
import tensorflow_addons as tfa
from ganime.configs.model_configs import GPTConfig, ModelConfig
from ganime.model.vqgan_clean.experimental.transformer import Transformer
from ganime.model.vqgan_clean.losses.losses import Losses
from ganime.model.vqgan_clean.vqgan import VQGAN
from ganime.trainer.warmup.base import create_warmup_scheduler
from ganime.visualization.images import unnormalize_if_necessary
from tensorflow import keras
from tensorflow.keras import Model


class Net2Net(Model):
    def __init__(
        self,
        transformer_config: GPTConfig,
        first_stage_config: ModelConfig,
        trainer_config,
        num_replicas: int = 1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.first_stage_model = VQGAN(**first_stage_config)
        self.transformer = Transformer(transformer_config)

        if "checkpoint_path" in transformer_config:
            print(f"Restoring weights from {transformer_config['checkpoint_path']}")
            self.load_weights(transformer_config["checkpoint_path"])

        self.scheduled_lrs = create_warmup_scheduler(
            trainer_config, num_devices=num_replicas
        )

        optimizer = tfa.optimizers.AdamW(
            learning_rate=self.scheduled_lrs, weight_decay=1e-4
        )

        self.compile(
            optimizer=optimizer,
            # loss=self.loss_fn,
            # run_eagerly=True,
        )

        self.n_frames_before = trainer_config["n_frames_before"]

        # Gradient accumulation: one buffer per transformer variable, summed
        # across sub-batches and applied once per train step.
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
            for v in self.transformer.trainable_variables
        ]
        self.accumulation_size = trainer_config["accumulation_size"]

        # Losses
        self.perceptual_loss_weight = trainer_config["perceptual_loss_weight"]
        losses = Losses(num_replicas=num_replicas)
        self.scce_loss = losses.scce_loss
        self.perceptual_loss = losses.perceptual_loss

        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.scce_loss_tracker = keras.metrics.Mean(name="scce_loss")
        self.perceptual_loss_tracker = keras.metrics.Mean(name="perceptual_loss")

        self.epoch = 0
        self.stop_ground_truth_after_epoch = trainer_config[
            "stop_ground_truth_after_epoch"
        ]

    def apply_accu_gradients(self):
        # Apply the accumulated gradients in a single optimizer update
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.transformer.trainable_variables)
        )

        # Reset the accumulation buffers to zero
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign(
                tf.zeros_like(self.transformer.trainable_variables[i], dtype=tf.float32)
            )
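    # A worked example of the accumulation above (illustrative numbers, not
    # from the original code): with accumulation_size = 4, train_step runs 4
    # forward/backward passes and assign_add's each pass's gradients into
    # `gradient_accumulation`; apply_accu_gradients then applies the summed
    # gradients in a single optimizer update, emulating a 4x larger batch.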
    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [
            self.total_loss_tracker,
            self.scce_loss_tracker,
            self.perceptual_loss_tracker,
        ]

    @tf.function()
    def encode_to_z(self, x):
        quant_z, indices, quantized_loss = self.first_stage_model.encode(x)
        batch_size = tf.shape(quant_z)[0]
        indices = tf.reshape(indices, shape=(batch_size, -1))
        return quant_z, indices

    def call(self, inputs, training=False, mask=None, return_losses=False):
        return self.predict_video(inputs, training, return_losses)

    def predict(self, data, sample=False, temperature=1.0):
        # Note: shadows `keras.Model.predict` with a different signature.
        video = self.predict_video(
            data,
            training=False,
            return_losses=False,
            sample=sample,
            temperature=temperature,
        )
        video = unnormalize_if_necessary(video)
        return video

    def get_remaining_frames(self, inputs):
        if "remaining_frames" in inputs:
            remaining_frames = inputs["remaining_frames"]
        else:
            raise NotImplementedError
        remaining_frames = tf.cast(remaining_frames, tf.int64)
        return remaining_frames

    # @tf.function()
    def predict_video(
        self, inputs, training=False, return_losses=False, sample=False, temperature=1.0
    ):
        first_frame = inputs["first_frame"]
        last_frame = inputs["last_frame"]
        n_frames = tf.reduce_max(inputs["n_frames"])
        remaining_frames = self.get_remaining_frames(inputs)
        # Fall back to None when no ground truth is provided (e.g. inference).
        # A dict lookup raises KeyError, not AttributeError.
        try:
            ground_truth = inputs["y"]
        except KeyError:
            ground_truth = None

        previous_frames = tf.expand_dims(first_frame, axis=1)

        predictions = tf.TensorArray(
            tf.float32, size=0, dynamic_size=True, clear_after_read=False
        )

        quant_last, indices_last = self.encode_to_z(last_frame)

        total_loss = tf.constant(0.0)
        scce_loss = tf.constant(0.0)
        perceptual_loss = tf.constant(0.0)

        current_frame_index = tf.constant(1)
        while tf.less(current_frame_index, n_frames):
            tf.autograph.experimental.set_loop_options(
                shape_invariants=[
                    (previous_frames, tf.TensorShape([None, None, None, None, 3]))
                ],
            )
            if ground_truth is not None:
                target_frame = ground_truth[:, current_frame_index]
            else:
                target_frame = None

            y_pred, losses = self.predict_next_frame(
                remaining_frames[:, current_frame_index],
                previous_frames,
                last_frame,
                indices_last,
                quant_last,
                target_frame=target_frame,
                training=training,
                sample=sample,
                temperature=temperature,
            )

            predictions = predictions.write(current_frame_index, y_pred)

            if training and self.epoch < self.stop_ground_truth_after_epoch:
                # Teacher forcing: condition on ground-truth frames.
                start_index = tf.math.maximum(
                    0, current_frame_index - self.n_frames_before
                )
                previous_frames = ground_truth[
                    :, start_index + 1 : current_frame_index + 1
                ]
            else:
                # Condition on the model's own predictions.
                previous_frames = predictions.stack()
                previous_frames = tf.transpose(previous_frames, (1, 0, 2, 3, 4))
                previous_frames = previous_frames[:, -self.n_frames_before :]

            current_frame_index = tf.add(current_frame_index, 1)

            total_loss = tf.add(total_loss, losses[0])
            scce_loss = tf.add(scce_loss, losses[1])
            perceptual_loss = tf.add(perceptual_loss, losses[2])

        predictions = predictions.stack()
        predictions = tf.transpose(predictions, (1, 0, 2, 3, 4))

        total_loss = tf.divide(total_loss, tf.cast(n_frames, tf.float32))
        scce_loss = tf.divide(scce_loss, tf.cast(n_frames, tf.float32))
        perceptual_loss = tf.divide(perceptual_loss, tf.cast(n_frames, tf.float32))

        if return_losses:
            return predictions, total_loss, scce_loss, perceptual_loss
        else:
            return predictions
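    # Shape walkthrough for predict_video above (H, W and the token count n_z
    # depend on the VQGAN config, so the names here are assumptions): each
    # iteration encodes the conditioning frames to codebook indices, the
    # transformer predicts the index logits of the next frame, and
    # convert_logits_to_image decodes them back to pixels, so the stacked
    # predictions come out as (batch, n_frames, H, W, 3). Note that the loop
    # starts at frame 1, so index 0 of the TensorArray is never written.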
    def predict_next_frame(
        self,
        remaining_frames,
        previous_frames,
        last_frame,
        indices_last,
        quant_last,
        target_frame=None,
        training=False,
        sample=False,
        temperature=1.0,
    ):
        # previous_frames is of shape (batch_size, n_frames, height, width, 3)
        previous_frames = tf.transpose(previous_frames, (1, 0, 2, 3, 4))
        # previous_frames is now of shape (n_frames, batch_size, height, width, 3)
        indices_previous = tf.map_fn(
            lambda x: self.encode_to_z(x)[1],
            previous_frames,
            fn_output_signature=tf.int64,
        )
        # indices is of shape (n_frames, batch_size, n_z)
        indices_previous = tf.transpose(indices_previous, (1, 0, 2))
        # indices is now of shape (batch_size, n_frames, n_z)
        batch_size, n_frames, n_z = (
            tf.shape(indices_previous)[0],
            tf.shape(indices_previous)[1],
            tf.shape(indices_previous)[2],
        )
        indices_previous = tf.reshape(
            indices_previous, shape=(batch_size, n_frames * n_z)
        )

        if target_frame is not None:
            _, target_indices = self.encode_to_z(target_frame)
        else:
            target_indices = None

        if training:
            next_frame, losses = self.train_predict_next_frame(
                remaining_frames,
                indices_last,
                indices_previous,
                target_indices=target_indices,
                target_frame=target_frame,
                quant_shape=tf.shape(quant_last),
                indices_shape=tf.shape(indices_last),
            )
        else:
            next_frame, losses = self.predict_next_frame_body(
                remaining_frames,
                indices_last,
                indices_previous,
                target_indices=target_indices,
                target_frame=target_frame,
                quant_shape=tf.shape(quant_last),
                indices_shape=tf.shape(indices_last),
                sample=sample,
                temperature=temperature,
            )
        return next_frame, losses

    def predict_next_frame_body(
        self,
        remaining_frames,
        last_frame_indices,
        previous_frame_indices,
        quant_shape,
        indices_shape,
        target_indices=None,
        target_frame=None,
        sample=False,
        temperature=1.0,
    ):
        logits = self.transformer(
            (remaining_frames, last_frame_indices, previous_frame_indices)
        )
        next_frame = self.convert_logits_to_image(
            logits,
            quant_shape=quant_shape,
            indices_shape=indices_shape,
            sample=sample,
            temperature=temperature,
        )
        if target_indices is not None:
            scce_loss = self.scce_loss(target_indices, logits)
        else:
            scce_loss = 0.0

        if target_frame is not None:
            # Apply the perceptual loss weight stored in __init__ (the
            # original multiplied by a hardcoded 1.0 and left it unused).
            perceptual_loss = self.perceptual_loss_weight * self.perceptual_loss(
                target_frame, next_frame
            )
        else:
            perceptual_loss = 0.0

        frame_loss = scce_loss + perceptual_loss

        # self.total_loss_tracker.update_state(frame_loss)
        # self.scce_loss_tracker.update_state(scce_loss)
        # self.perceptual_loss_tracker.update_state(perceptual_loss)
        return next_frame, (frame_loss, scce_loss, perceptual_loss)

    def train_predict_next_frame(
        self,
        remaining_frames,
        last_frame_indices,
        previous_frame_indices,
        quant_shape,
        indices_shape,
        target_indices,
        target_frame,
    ):
        with tf.GradientTape() as tape:
            next_frame, losses = self.predict_next_frame_body(
                remaining_frames=remaining_frames,
                last_frame_indices=last_frame_indices,
                previous_frame_indices=previous_frame_indices,
                target_indices=target_indices,
                quant_shape=quant_shape,
                indices_shape=indices_shape,
                target_frame=target_frame,
                sample=False,
            )
            frame_loss = losses[0]

        # Calculate batch gradients
        gradients = tape.gradient(frame_loss, self.transformer.trainable_variables)

        # Accumulate batch gradients
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign_add(tf.cast(gradients[i], tf.float32))

        return next_frame, losses

    def convert_logits_to_image(
        self, logits, quant_shape, indices_shape, sample=False, temperature=1.0
    ):
        if sample:
            # Temperature-scaled top-k (k=50) sampling, one token at a time.
            array = []
            for i in range(logits.shape[1]):
                sub_logits = logits[:, i]
                sub_logits = sub_logits / temperature
                # sub_logits, _ = tf.math.top_k(sub_logits, k=1)
                probs = tf.keras.activations.softmax(sub_logits)
                probs, probs_index = tf.math.top_k(probs, k=50)
                selection_index = tf.random.categorical(
                    tf.math.log(probs), num_samples=1
                )
                ix = tf.gather_nd(probs_index, selection_index, batch_dims=1)
                ix = tf.reshape(ix, (-1, 1))
                array.append(ix)
            generated_indices = tf.concat(array, axis=-1)
        else:
            # Greedy decoding: take the most likely codebook index.
            probs = tf.keras.activations.softmax(logits)
            _, generated_indices = tf.math.top_k(probs)
        generated_indices = tf.reshape(
            generated_indices,
            indices_shape,
        )
        quant = self.first_stage_model.quantize.get_codebook_entry(
            generated_indices, shape=quant_shape
        )
        return self.first_stage_model.decode(quant)
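    # Top-k sampling in convert_logits_to_image, in miniature (illustrative
    # values, not from the original): given logits [2.0, 1.0, 0.5, -1.0] and
    # k=2, top_k keeps the probabilities of the two best tokens, and
    # tf.random.categorical(tf.math.log(probs), 1) draws among them; no
    # explicit renormalization is needed because categorical accepts
    # unnormalized log-probabilities.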
    def train_step(self, data):
        batch_total_loss, batch_scce_loss, batch_perceptual_loss = 0.0, 0.0, 0.0
        for i in range(self.accumulation_size):
            # NOTE: this slicing assumes the global batch contains
            # `accumulation_size` sub-batches of `accumulation_size` samples
            # each (i.e. batch_size == accumulation_size ** 2).
            sub_data = {
                key: value[
                    self.accumulation_size * i : self.accumulation_size * (i + 1)
                ]
                for key, value in data.items()
            }
            _, total_loss, scce_loss, perceptual_loss = self(
                sub_data, training=True, return_losses=True
            )
            batch_total_loss += total_loss
            batch_scce_loss += scce_loss
            batch_perceptual_loss += perceptual_loss

        self.apply_accu_gradients()

        self.total_loss_tracker.update_state(batch_total_loss)
        self.scce_loss_tracker.update_state(batch_scce_loss)
        self.perceptual_loss_tracker.update_state(batch_perceptual_loss)

        # NOTE: despite its name, this counter is incremented once per train
        # step, so `stop_ground_truth_after_epoch` is effectively measured in
        # steps unless it is reset externally at epoch boundaries.
        self.epoch += 1
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        _, total_loss, scce_loss, perceptual_loss = self(
            data, training=False, return_losses=True
        )

        self.total_loss_tracker.update_state(total_loss)
        self.scce_loss_tracker.update_state(scce_loss)
        self.perceptual_loss_tracker.update_state(perceptual_loss)
        return {m.name: m.result() for m in self.metrics}
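

if __name__ == "__main__":
    # Minimal usage sketch, assuming hypothetical config values: the real
    # GPTConfig / ModelConfig keys live in ganime.configs.model_configs, so
    # model construction is left commented out. The trainer_config keys and
    # the input dictionary layout are inferred from the methods above.
    import numpy as np

    trainer_config = {
        "n_frames_before": 2,  # conditioning context length in predict_video
        "accumulation_size": 2,  # gradient accumulation sub-batches
        "perceptual_loss_weight": 1.0,
        "stop_ground_truth_after_epoch": 10,
        # ...plus whatever create_warmup_scheduler expects (assumption).
    }

    # model = Net2Net(transformer_config, first_stage_config, trainer_config)

    batch, n_frames, h, w = 2, 5, 64, 96  # illustrative sizes
    inputs = {
        "first_frame": np.zeros((batch, h, w, 3), np.float32),
        "last_frame": np.zeros((batch, h, w, 3), np.float32),
        "y": np.zeros((batch, n_frames, h, w, 3), np.float32),  # ground truth
        "n_frames": np.full((batch,), n_frames, np.int64),
        "remaining_frames": np.tile(
            np.arange(n_frames - 1, -1, -1), (batch, 1)
        ),  # frames left to generate at each step (assumption)
    }
    # video = model.predict(inputs, sample=True, temperature=1.0)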