Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
10.7 kB
import math
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from ganime.configs.model_configs import GPTConfig, ModelConfig
from ganime.model.vqgan_clean.transformer.mingpt import GPT
from ganime.model.vqgan_clean.vqgan import VQGAN
from tensorflow import keras
from tensorflow.keras import Model, layers
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    """A LearningRateSchedule that uses a warmup cosine decay schedule.

    When ``warmup_steps > 0`` the rate rises linearly from ``lr_start`` to
    ``lr_max`` over the first ``warmup_steps`` steps. From there it follows
    half a cosine period from ``lr_max`` down to 0 at ``total_steps``. For
    any step beyond ``total_steps`` the rate is 0.
    """

    def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
        """
        Args:
            lr_start: The initial learning rate
            lr_max: The maximum learning rate to which lr should increase to
                in the warmup steps
            warmup_steps: The number of steps for which the model warms up
            total_steps: The total number of steps for the model training
        """
        super().__init__()
        self.lr_start = lr_start
        self.lr_max = lr_max
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        """Return the learning rate to use at ``step``.

        Args:
            step: The current optimizer step (Python number or tensor).

        Returns:
            A tensor named ``learning_rate`` holding the scheduled rate.

        Raises:
            ValueError: If ``total_steps < warmup_steps``, or if warmup is
                enabled and ``lr_max < lr_start``.
        """
        # The decay phase needs a non-negative length.
        if self.total_steps < self.warmup_steps:
            raise ValueError(
                f"Total number of steps {self.total_steps} must be "
                f"larger or equal to warmup steps {self.warmup_steps}."
            )

        step_f = tf.cast(step, tf.float32)

        # Length of the cosine phase. When warmup covers the whole run the
        # length is 0; use 1 instead so the single boundary step evaluates to
        # cos(0) = 1 (i.e. `lr_max`) rather than 0 / 0 = NaN.
        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps == 0:
            decay_steps = 1

        # `cos_annealed_lr` equals 1 at the warmup step and decays to -1 at
        # the final step.
        cos_annealed_lr = tf.cos(
            self.pi
            * (step_f - self.warmup_steps)
            / tf.cast(decay_steps, tf.float32)
        )

        # Shift the cosine from [-1, 1] to [0, 2], halve it to [0, 1], and
        # scale by `lr_max` so the decay runs from `lr_max` down to 0.
        learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)

        if self.warmup_steps > 0:
            # A warmup that "increases" towards a smaller value is a
            # configuration error.
            if self.lr_max < self.lr_start:
                raise ValueError(
                    f"lr_start {self.lr_start} must be smaller or "
                    f"equal to lr_max {self.lr_max}."
                )

            # Straight line y = m * x + c from `lr_start` to `lr_max`.
            slope = (self.lr_max - self.lr_start) / self.warmup_steps
            warmup_rate = slope * step_f + self.lr_start

            # Before the warmup boundary use the line, afterwards the cosine.
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )

        # Past the end of the schedule the rate is clamped to 0.
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialized."""
        return {
            "lr_start": self.lr_start,
            "lr_max": self.lr_max,
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
        }
# Hand-set sizing of the training run.
BATCH_SIZE = 16
N_EPOCHS = 500
LEN_X_TRAIN = 8000
# Fraction of the schedule spent in the linear warmup phase.
WARMUP_EPOCH_PERCENTAGE = 0.15

# Derived schedule lengths: (samples / batch size) * epochs, and the leading
# share of that reserved for warmup.
TOTAL_STEPS = int(N_EPOCHS * LEN_X_TRAIN / BATCH_SIZE)
WARMUP_STEPS = int(WARMUP_EPOCH_PERCENTAGE * TOTAL_STEPS)
class Net2Net(Model):
    """GPT transformer trained on VQGAN codebook indices to generate frames.

    Frames are encoded to codebook indices by ``first_stage_model``. The
    transformer receives the indices of a conditioning frame followed by the
    indices of the current frame, and its logits are decoded back into the
    next frame. ``process_video`` repeats this frame by frame, feeding each
    generated frame back in as the next current frame.
    """

    def __init__(
        self,
        transformer_config: GPTConfig,
        first_stage_config: ModelConfig,
        cond_stage_config: ModelConfig,
    ):
        """Build the sub-models and compile with AdamW on a warmup-cosine schedule.

        Args:
            transformer_config: Keyword arguments forwarded to ``GPT``.
            first_stage_config: Keyword arguments forwarded to ``VQGAN``.
            cond_stage_config: Currently unused; the conditioning stage
                shares the first-stage VQGAN instance.
        """
        super().__init__()
        self.transformer = GPT(**transformer_config)
        self.first_stage_model = VQGAN(**first_stage_config)
        # Conditioning frames go through the same VQGAN as the input frames.
        self.cond_stage_model = self.first_stage_model

        # Per-element losses; they are averaged explicitly in `process_image`.
        self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        self.loss_tracker = keras.metrics.Mean(name="loss")

        # NOTE(review): `process_video` applies gradients once per generated
        # frame, so the optimizer step counter advances several times per
        # batch while TOTAL_STEPS is derived per batch -- confirm the schedule
        # length is what is intended.
        self.scheduled_lrs = WarmUpCosine(
            lr_start=1e-5,
            lr_max=2.5e-4,
            warmup_steps=WARMUP_STEPS,
            total_steps=TOTAL_STEPS,
        )

        self.compile(
            optimizer=tfa.optimizers.AdamW(
                learning_rate=self.scheduled_lrs, weight_decay=1e-4
            ),
            loss=[self.loss_fn, None],
        )

    @property
    def metrics(self):
        """Metrics Keras resets automatically at each epoch / `evaluate()` start.

        Without this property `reset_states()` would have to be called
        manually.
        """
        return [
            self.loss_tracker,
        ]

    @staticmethod
    def _encode_to_indices(stage_model, frames):
        """Encode `frames` with `stage_model`; return (quantized, flat indices).

        The indices are reshaped to ``(batch, -1)``, where the batch size is
        taken from the first dimension of the quantized output.
        """
        quantized, indices, _ = stage_model.encode(frames)
        batch_size = tf.shape(quantized)[0]
        indices = tf.reshape(indices, shape=(batch_size, -1))
        return quantized, indices

    def encode_to_z(self, x):
        """Encode input frames `x` with the first-stage model.

        Returns:
            Tuple ``(quant_z, indices)`` with indices of shape ``(batch, -1)``.
        """
        return self._encode_to_indices(self.first_stage_model, x)

    def encode_to_c(self, c):
        """Encode conditioning frames `c` with the conditioning-stage model.

        Returns:
            Tuple ``(quant_c, indices)`` with indices of shape ``(batch, -1)``.
        """
        return self._encode_to_indices(self.cond_stage_model, c)

    def call(self, inputs, training=None, mask=None):
        """Generate a video from the first and last frame.

        Args:
            inputs: Either the first/last-frame tensor alone, or a 2-tuple
                ``(first_last_frame, target_video)``.
            training: Weights are only updated when this is truthy and a
                target video is supplied; a plain forward pass never applies
                gradients.
            mask: Unused.

        Returns:
            The generated video, or ``(generated_video, total_loss)`` when a
            target video was supplied.
        """
        if isinstance(inputs, tuple) and len(inputs) == 2:
            first_last_frame, y = inputs
        else:
            first_last_frame, y = inputs, None
        return self.process_video(first_last_frame, y, training=bool(training))

    @tf.function()
    def process_image(self, x, c, target_image=None):
        """Predict the next frame from current frame `x` and condition `c`.

        Args:
            x: Current frame(s).
            c: Conditioning frame(s).
            target_image: Optional ground-truth next frame. When given, the
                cross-entropy between the predicted logits and the target's
                codebook indices is returned as the loss.

        Returns:
            Tuple ``(image, frame_loss)``; ``frame_loss`` is 0 when no target
            is supplied.
        """
        frame_loss = 0

        quant_z, z_indices = self.encode_to_z(x)
        _, c_indices = self.encode_to_c(c)
        cz_indices = tf.concat((c_indices, z_indices), axis=1)

        # The last token is dropped from the transformer input; together with
        # the slice below this leaves one prediction per `z` token, offset by
        # one position (the output at position t is used as the prediction
        # for token t + 1).
        logits = self.transformer(cz_indices[:, :-1])

        # Remove the conditioned part. The start is `len(c) - 1` rather than
        # `len(c)` because of the one-token shift above.
        logits = logits[:, tf.shape(c_indices)[1] - 1 :]
        logits = tf.reshape(logits, shape=(-1, logits.shape[-1]))

        if target_image is not None:
            _, target_indices = self.encode_to_z(target_image)
            target_indices = tf.reshape(target_indices, shape=(-1,))
            frame_loss = tf.reduce_mean(
                self.loss_fn(y_true=target_indices, y_pred=logits)
            )

        image = self.get_image(logits, tf.shape(quant_z))
        return image, frame_loss

    def process_video(
        self,
        first_last_frame,
        target_video=None,
        training: bool = True,
        num_frames: int = 19,
    ):
        """Generate a video frame by frame, optionally training the transformer.

        Args:
            first_last_frame: Tensor indexed as ``[:, 0]`` for the first frame
                and ``[:, -1]`` for the last frame (used as the condition).
            target_video: Optional ground truth; frame ``i`` is read as
                ``target_video[:, i, ...]`` and must exist for every
                ``i < num_frames``.
            training: When true (the default) and `target_video` is given,
                the transformer's weights are updated after every frame.
                Pass ``False`` to compute the loss without touching weights.
            num_frames: Number of frames to generate after the first frame.
                TODO: derive this from the actual video length.

        Returns:
            The generated video stacked along axis 1 (first frame included),
            or ``(generated_video, total_loss)`` when `target_video` is given.
        """
        x = first_last_frame[:, 0]
        c = first_last_frame[:, -1]

        has_target = target_video is not None
        update_weights = bool(training) and has_target

        total_loss = 0
        generated_video = [x]

        for i in range(num_frames):
            target = target_video[:, i, ...] if has_target else None

            if update_weights:
                with tf.GradientTape() as tape:
                    generated_image, frame_loss = self.process_image(
                        x, c, target_image=target
                    )
                # Differentiate and update against the SAME variable list so
                # each gradient is paired with the variable it belongs to.
                variables = self.transformer.trainable_variables
                grads = tape.gradient(frame_loss, variables)
                self.optimizer.apply_gradients(zip(grads, variables))
            else:
                generated_image, frame_loss = self.process_image(
                    x, c, target_image=target
                )

            if has_target:
                total_loss += frame_loss

            # The generated frame becomes the input for the next step.
            x = generated_image
            generated_video.append(generated_image)

        video = tf.stack(generated_video, axis=1)
        if has_target:
            return video, total_loss
        return video

    def _video_step(self, data, training):
        """Shared body of `train_step` / `test_step`; returns the metric dict."""
        first_last_frame, y = data
        _, loss = self.process_video(first_last_frame, y, training=training)
        self.loss_tracker.update_state(loss)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        """Run one training step on ``(first_last_frame, target_video)``."""
        return self._video_step(data, training=True)

    def get_image(self, logits, shape):
        """Decode logits into an image via the most likely codebook entries.

        Args:
            logits: Per-position logits over the codebook.
            shape: Shape passed to ``quantize.get_codebook_entry``; callers in
                this class pass the shape of the quantized latent.

        Returns:
            The output of the first-stage decoder for the selected entries.
        """
        probs = tf.keras.activations.softmax(logits)
        # `top_k` with its default k keeps only the single best index.
        _, generated_indices = tf.math.top_k(probs)
        generated_indices = tf.reshape(generated_indices, (-1,))
        quant = self.first_stage_model.quantize.get_codebook_entry(
            generated_indices, shape=shape
        )
        return self.first_stage_model.decode(quant)

    def test_step(self, data):
        """Run one evaluation step; computes the loss without updating weights."""
        return self._video_step(data, training=False)

    def decode_to_img(self, index, zshape):
        """Decode codebook indices `index` into an image.

        Args:
            index: Codebook indices; flattened before lookup.
            zshape: Shape passed to ``quantize.get_codebook_entry``.
        """
        quant_z = self.first_stage_model.quantize.get_codebook_entry(
            tf.reshape(index, (-1,)), shape=zshape
        )
        return self.first_stage_model.decode(quant_z)