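"""Net2Net: transformer-based video generation on top of a VQGAN.

A VQGAN encodes frames into discrete codebook indices and a GPT-style
transformer autoregressively predicts the indices of each next frame,
conditioned on the previous frames, the target last frame, and the number
of frames still to generate. Only the transformer is updated during
training here.
"""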
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import Model

from ganime.configs.model_configs import GPTConfig, ModelConfig
from ganime.model.vqgan_clean.experimental.transformer import Transformer
from ganime.model.vqgan_clean.losses.losses import Losses
from ganime.model.vqgan_clean.vqgan import VQGAN
from ganime.trainer.warmup.base import create_warmup_scheduler
from ganime.trainer.warmup.cosine import WarmUpCosine
from ganime.visualization.images import unnormalize_if_necessary


class Net2Net(Model):
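    """Video generation model combining a VQGAN first stage with a
    transformer prior.

    Training uses gradient accumulation, a warmed-up AdamW optimizer, and a
    combined sparse-categorical cross-entropy + perceptual loss; teacher
    forcing on the conditioning frames is disabled after
    ``stop_ground_truth_after_epoch``.
    """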
def __init__(
self,
transformer_config: GPTConfig,
first_stage_config: ModelConfig,
trainer_config,
num_replicas: int = 1,
**kwargs,
):
super().__init__(**kwargs)
self.first_stage_model = VQGAN(**first_stage_config)
self.transformer = Transformer(transformer_config)
if "checkpoint_path" in transformer_config:
print(f"Restoring weights from {transformer_config['checkpoint_path']}")
self.load_weights(transformer_config["checkpoint_path"])
self.scheduled_lrs = create_warmup_scheduler(
trainer_config, num_devices=num_replicas
)
optimizer = tfa.optimizers.AdamW(
learning_rate=self.scheduled_lrs, weight_decay=1e-4
)
self.compile(
optimizer=optimizer,
# loss=self.loss_fn,
# run_eagerly=True,
)
self.n_frames_before = trainer_config["n_frames_before"]
# Gradient accumulation
self.gradient_accumulation = [
tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
for v in self.transformer.trainable_variables
]
self.accumulation_size = trainer_config["accumulation_size"]
# Losses
self.perceptual_loss_weight = trainer_config["perceptual_loss_weight"]
losses = Losses(num_replicas=num_replicas)
self.scce_loss = losses.scce_loss
self.perceptual_loss = losses.perceptual_loss
self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
self.scce_loss_tracker = keras.metrics.Mean(name="scce_loss")
self.perceptual_loss_tracker = keras.metrics.Mean(name="perceptual_loss")
self.epoch = 0
self.stop_ground_truth_after_epoch = trainer_config[
"stop_ground_truth_after_epoch"
]
def apply_accu_gradients(self):
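        """Apply the accumulated gradients to the transformer weights, then
        reset the accumulators to zero for the next accumulation window."""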
# apply accumulated gradients
self.optimizer.apply_gradients(
zip(self.gradient_accumulation, self.transformer.trainable_variables)
)
# reset
for i in range(len(self.gradient_accumulation)):
self.gradient_accumulation[i].assign(
tf.zeros_like(self.transformer.trainable_variables[i], dtype=tf.float32)
)
@property
def metrics(self):
# We list our `Metric` objects here so that `reset_states()` can be
# called automatically at the start of each epoch
# or at the start of `evaluate()`.
# If you don't implement this property, you have to call
# `reset_states()` yourself at the time of your choosing.
return [
self.total_loss_tracker,
self.scce_loss_tracker,
self.perceptual_loss_tracker,
]
@tf.function()
def encode_to_z(self, x):
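        """Encode images through the VQGAN and return the quantized latents
        together with their codebook indices flattened to (batch_size, -1).
        The quantization loss is discarded here."""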
quant_z, indices, quantized_loss = self.first_stage_model.encode(x)
batch_size = tf.shape(quant_z)[0]
indices = tf.reshape(indices, shape=(batch_size, -1))
return quant_z, indices
def call(self, inputs, training=False, mask=None, return_losses=False):
return self.predict_video(inputs, training, return_losses)
def predict(self, data, sample=False, temperature=1.0):
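        """Generate a video and unnormalize it for display.

        Note: this shadows ``keras.Model.predict`` with a different
        signature.
        """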
video = self.predict_video(
data,
training=False,
return_losses=False,
sample=sample,
temperature=temperature,
)
video = unnormalize_if_necessary(video)
return video
    def get_remaining_frames(self, inputs):
        if "remaining_frames" in inputs:
            remaining_frames = inputs["remaining_frames"]
        else:
            raise NotImplementedError(
                "inputs must contain a 'remaining_frames' entry"
            )
        remaining_frames = tf.cast(remaining_frames, tf.int64)
        return remaining_frames
# @tf.function()
def predict_video(
self, inputs, training=False, return_losses=False, sample=False, temperature=1.0
):
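        """Autoregressively generate the frames between the first and last
        frame, one frame at a time.

        While ``self.epoch`` is below ``stop_ground_truth_after_epoch`` the
        conditioning window is taken from the ground-truth frames (teacher
        forcing); afterwards the model conditions on its own predictions.
        Returns a (batch, n_frames, height, width, 3) tensor, plus the
        losses averaged over frames when ``return_losses`` is True.
        """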
first_frame = inputs["first_frame"]
last_frame = inputs["last_frame"]
n_frames = tf.reduce_max(inputs["n_frames"])
remaining_frames = self.get_remaining_frames(inputs)
        # "y" (the ground-truth frames) is only present during training;
        # a missing dict key raises KeyError, not AttributeError.
        try:
            ground_truth = inputs["y"]
        except KeyError:
            ground_truth = None
previous_frames = tf.expand_dims(first_frame, axis=1)
predictions = tf.TensorArray(
tf.float32, size=0, dynamic_size=True, clear_after_read=False
)
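        # The loop below starts at index 1, so slot 0 of the TensorArray is
        # never written and stacks back as an all-zero frame.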
quant_last, indices_last = self.encode_to_z(last_frame)
total_loss = tf.constant(0.0)
scce_loss = tf.constant(0.0)
perceptual_loss = tf.constant(0.0)
current_frame_index = tf.constant(1)
while tf.less(current_frame_index, n_frames):
tf.autograph.experimental.set_loop_options(
shape_invariants=[
(previous_frames, tf.TensorShape([None, None, None, None, 3]))
],
)
if ground_truth is not None:
target_frame = ground_truth[:, current_frame_index]
else:
target_frame = None
y_pred, losses = self.predict_next_frame(
remaining_frames[:, current_frame_index],
previous_frames,
last_frame,
indices_last,
quant_last,
target_frame=target_frame,
training=training,
sample=sample,
temperature=temperature,
)
predictions = predictions.write(current_frame_index, y_pred)
if training and self.epoch < self.stop_ground_truth_after_epoch:
start_index = tf.math.maximum(
0, current_frame_index - self.n_frames_before
)
previous_frames = ground_truth[
:, start_index + 1 : current_frame_index + 1
]
else:
previous_frames = predictions.stack()
previous_frames = tf.transpose(previous_frames, (1, 0, 2, 3, 4))
previous_frames = previous_frames[:, -self.n_frames_before :]
current_frame_index = tf.add(current_frame_index, 1)
total_loss = tf.add(total_loss, losses[0])
scce_loss = tf.add(scce_loss, losses[1])
perceptual_loss = tf.add(perceptual_loss, losses[2])
predictions = predictions.stack()
predictions = tf.transpose(predictions, (1, 0, 2, 3, 4))
total_loss = tf.divide(total_loss, tf.cast(n_frames, tf.float32))
scce_loss = tf.divide(scce_loss, tf.cast(n_frames, tf.float32))
perceptual_loss = tf.divide(perceptual_loss, tf.cast(n_frames, tf.float32))
if return_losses:
return predictions, total_loss, scce_loss, perceptual_loss
else:
return predictions
def predict_next_frame(
self,
remaining_frames,
previous_frames,
last_frame,
indices_last,
quant_last,
target_frame=None,
training=False,
sample=False,
temperature=1.0,
):
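        """Predict a single next frame.

        Every previous frame is encoded into codebook indices via the
        VQGAN, the indices are flattened into one conditioning sequence of
        shape (batch_size, n_frames * n_z), and the call is dispatched to
        the training or inference body.
        """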
# previous frames is of shape (batch_size, n_frames, height, width, 3)
previous_frames = tf.transpose(previous_frames, (1, 0, 2, 3, 4))
# previous frames is now of shape (n_frames, batch_size, height, width, 3)
indices_previous = tf.map_fn(
lambda x: self.encode_to_z(x)[1],
previous_frames,
fn_output_signature=tf.int64,
)
# indices is of shape (n_frames, batch_size, n_z)
indices_previous = tf.transpose(indices_previous, (1, 0, 2))
# indices is now of shape (batch_size, n_frames, n_z)
batch_size, n_frames, n_z = (
tf.shape(indices_previous)[0],
tf.shape(indices_previous)[1],
tf.shape(indices_previous)[2],
)
indices_previous = tf.reshape(
indices_previous, shape=(batch_size, n_frames * n_z)
)
if target_frame is not None:
_, target_indices = self.encode_to_z(target_frame)
else:
target_indices = None
if training:
next_frame, losses = self.train_predict_next_frame(
remaining_frames,
indices_last,
indices_previous,
target_indices=target_indices,
target_frame=target_frame,
quant_shape=tf.shape(quant_last),
indices_shape=tf.shape(indices_last),
)
else:
next_frame, losses = self.predict_next_frame_body(
remaining_frames,
indices_last,
indices_previous,
target_indices=target_indices,
target_frame=target_frame,
quant_shape=tf.shape(quant_last),
indices_shape=tf.shape(indices_last),
sample=sample,
temperature=temperature,
)
return next_frame, losses
def predict_next_frame_body(
self,
remaining_frames,
last_frame_indices,
previous_frame_indices,
quant_shape,
indices_shape,
target_indices=None,
target_frame=None,
sample=False,
temperature=1.0,
):
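        """Run the transformer on the conditioning sequence, decode the
        predicted indices back into an image, and compute the
        sparse-categorical and perceptual losses when targets are
        provided."""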
logits = self.transformer(
(remaining_frames, last_frame_indices, previous_frame_indices)
)
next_frame = self.convert_logits_to_image(
logits,
quant_shape=quant_shape,
indices_shape=indices_shape,
sample=sample,
temperature=temperature,
)
if target_indices is not None:
scce_loss = self.scce_loss(target_indices, logits)
else:
scce_loss = 0.0
        if target_frame is not None:
            # Weight the perceptual term with the configured coefficient
            perceptual_loss = self.perceptual_loss_weight * self.perceptual_loss(
                target_frame, next_frame
            )
        else:
            perceptual_loss = 0.0
frame_loss = scce_loss + perceptual_loss
# self.total_loss_tracker.update_state(frame_loss)
# self.scce_loss_tracker.update_state(scce_loss)
# self.perceptual_loss_tracker.update_state(perceptual_loss)
return next_frame, (frame_loss, scce_loss, perceptual_loss)
def train_predict_next_frame(
self,
remaining_frames,
last_frame_indices,
previous_frame_indices,
quant_shape,
indices_shape,
target_indices,
target_frame,
):
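        """Forward pass under a ``GradientTape``; the resulting gradients
        are accumulated into ``self.gradient_accumulation`` instead of
        being applied immediately."""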
with tf.GradientTape() as tape:
next_frame, losses = self.predict_next_frame_body(
remaining_frames=remaining_frames,
last_frame_indices=last_frame_indices,
previous_frame_indices=previous_frame_indices,
target_indices=target_indices,
quant_shape=quant_shape,
indices_shape=indices_shape,
target_frame=target_frame,
sample=False,
)
frame_loss = losses[0]
# Calculate batch gradients
gradients = tape.gradient(frame_loss, self.transformer.trainable_variables)
# Accumulate batch gradients
for i in range(len(self.gradient_accumulation)):
self.gradient_accumulation[i].assign_add(tf.cast(gradients[i], tf.float32))
return next_frame, losses
def convert_logits_to_image(
self, logits, quant_shape, indices_shape, sample=False, temperature=1.0
):
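        """Turn transformer logits into a decoded image.

        With ``sample=True``, logits are temperature-scaled and a token is
        drawn per position from the top-k (k=50) probabilities; otherwise
        the greedy argmax is taken. The resulting indices are looked up in
        the VQGAN codebook and decoded into pixels.
        """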
if sample:
array = []
for i in range(logits.shape[1]):
sub_logits = logits[:, i]
sub_logits = sub_logits / temperature
# sub_logits, _ = tf.math.top_k(sub_logits, k=1)
probs = tf.keras.activations.softmax(sub_logits)
probs, probs_index = tf.math.top_k(probs, k=50)
selection_index = tf.random.categorical(
tf.math.log(probs), num_samples=1
)
ix = tf.gather_nd(probs_index, selection_index, batch_dims=1)
ix = tf.reshape(ix, (-1, 1))
array.append(ix)
generated_indices = tf.concat(array, axis=-1)
else:
probs = tf.keras.activations.softmax(logits)
_, generated_indices = tf.math.top_k(probs)
generated_indices = tf.reshape(
generated_indices,
indices_shape,
)
quant = self.first_stage_model.quantize.get_codebook_entry(
generated_indices, shape=quant_shape
)
return self.first_stage_model.decode(quant)
def train_step(self, data):
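        """Split the incoming batch into sub-batches, accumulate the
        gradients over all of them, then apply them in a single optimizer
        step."""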
batch_total_loss, batch_scce_loss, batch_perceptual_loss = 0.0, 0.0, 0.0
for i in range(self.accumulation_size):
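            # Each sub-batch holds `accumulation_size` samples, so the
            # incoming batch is expected to contain
            # `accumulation_size ** 2` samples in total.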
sub_data = {
key: value[
self.accumulation_size * i : self.accumulation_size * (i + 1)
]
for key, value in data.items()
}
_, total_loss, scce_loss, perceptual_loss = self(
sub_data, training=True, return_losses=True
)
batch_total_loss += total_loss
batch_scce_loss += scce_loss
batch_perceptual_loss += perceptual_loss
self.apply_accu_gradients()
self.total_loss_tracker.update_state(batch_total_loss)
self.scce_loss_tracker.update_state(batch_scce_loss)
self.perceptual_loss_tracker.update_state(batch_perceptual_loss)
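        # NOTE: ``train_step`` runs once per batch, so this counter advances
        # per optimizer step, and ``stop_ground_truth_after_epoch`` is
        # effectively measured in steps.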
self.epoch += 1
return {m.name: m.result() for m in self.metrics}
def test_step(self, data):
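        """Evaluate on a full batch without teacher forcing and update the
        loss trackers."""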
_, total_loss, scce_loss, perceptual_loss = self(
data, training=False, return_losses=True
)
self.total_loss_tracker.update_state(total_loss)
self.scce_loss_tracker.update_state(scce_loss)
self.perceptual_loss_tracker.update_state(perceptual_loss)
return {m.name: m.result() for m in self.metrics}
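

if __name__ == "__main__":
    # Minimal smoke-test sketch of the input format expected by
    # `predict_video`. The shapes below (64x64 RGB, 5 frames, batch of 2)
    # are illustrative assumptions, and the config objects needed to build
    # the model are project-specific, so the construction is left commented.
    batch_size, n_frames, height, width = 2, 5, 64, 64
    dummy_inputs = {
        "first_frame": tf.random.uniform((batch_size, height, width, 3)),
        "last_frame": tf.random.uniform((batch_size, height, width, 3)),
        "n_frames": tf.fill((batch_size,), n_frames),
        # One countdown of remaining frames per sample
        "remaining_frames": tf.stack(
            [tf.range(n_frames - 1, -1, -1, dtype=tf.int64)] * batch_size
        ),
        "y": tf.random.uniform((batch_size, n_frames, height, width, 3)),
    }
    # model = Net2Net(transformer_config, first_stage_config, trainer_config)
    # video = model.predict(dummy_inputs)  # -> (batch, n_frames, H, W, 3)
    print({k: v.shape for k, v in dummy_inputs.items()})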