import math
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from ganime.configs.model_configs import GPTConfig, ModelConfig
from ganime.model.vqgan_clean.transformer.mingpt import GPT
from ganime.model.vqgan_clean.vqgan import VQGAN
from tensorflow import keras
from tensorflow.keras import Model, layers
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    """A LearningRateSchedule with linear warmup followed by cosine decay.

    The rate rises linearly from ``lr_start`` to ``lr_max`` over the first
    ``warmup_steps`` steps, then follows a cosine curve from ``lr_max`` down
    to 0 at ``total_steps``. Past ``total_steps`` the schedule returns 0.
    """

    def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
        """
        Args:
            lr_start: The initial learning rate.
            lr_max: The maximum learning rate, reached at the end of warmup.
            warmup_steps: The number of steps for which the model warms up.
            total_steps: The total number of steps for the model training.
        """
        super().__init__()
        self.lr_start = lr_start
        self.lr_max = lr_max
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        # The cosine phase below divides by (total_steps - warmup_steps),
        # so warmup must not exceed the total step count.
        if self.total_steps < self.warmup_steps:
            raise ValueError(
                f"Total number of steps {self.total_steps} must be "
                f"larger or equal to warmup steps {self.warmup_steps}."
            )

        # `cos_annealed_lr` equals 1 at the end of warmup and decays to -1
        # at the final step.
        cos_annealed_lr = tf.cos(
            self.pi
            * (tf.cast(step, tf.float32) - self.warmup_steps)
            / tf.cast(self.total_steps - self.warmup_steps, tf.float32)
        )
        # Shift/scale the cosine from [-1, 1] to [0, lr_max].
        learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)

        if self.warmup_steps > 0:
            # Warmup only makes sense when the rate increases toward lr_max.
            if self.lr_max < self.lr_start:
                raise ValueError(
                    f"lr_start {self.lr_start} must be smaller or "
                    f"equal to lr_max {self.lr_max}."
                )
            # Straight line y = m*x + c rising from lr_start to lr_max over
            # the warmup steps; slope m = (lr_max - lr_start) / warmup_steps.
            slope = (self.lr_max - self.lr_start) / self.warmup_steps
            warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start
            # Before the warmup boundary use the line, after it the cosine.
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )
        # Past the end of training the learning rate is pinned to 0.
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )
# Training-schedule constants used to size the warmup-cosine LR schedule.
LEN_X_TRAIN = 8000  # number of training samples
BATCH_SIZE = 16
N_EPOCHS = 500
WARMUP_EPOCH_PERCENTAGE = 0.15  # fraction of training spent warming up

# Total optimizer steps over the whole run, and the warmup prefix thereof.
TOTAL_STEPS = int(LEN_X_TRAIN / BATCH_SIZE * N_EPOCHS)
WARMUP_STEPS = int(TOTAL_STEPS * WARMUP_EPOCH_PERCENTAGE)
class Net2Net(Model):
    """Two-stage video generation model.

    Combines a VQGAN first stage (image tokenizer/decoder) with a GPT
    transformer that autoregressively predicts the next frame's codebook
    indices, conditioned on the previous frame and the last frame.
    """

    def __init__(
        self,
        transformer_config: GPTConfig,
        first_stage_config: ModelConfig,
        cond_stage_config: ModelConfig,
    ):
        """
        Args:
            transformer_config: Keyword arguments for the GPT transformer.
            first_stage_config: Keyword arguments for the first-stage VQGAN.
            cond_stage_config: Currently unused — the conditioning stage
                shares weights with the first stage (see note below).
        """
        super().__init__()
        self.transformer = GPT(**transformer_config)
        self.first_stage_model = VQGAN(**first_stage_config)
        # NOTE: the conditioning stage reuses the first-stage VQGAN rather
        # than building a separate model from `cond_stage_config`.
        self.cond_stage_model = self.first_stage_model

        # Per-token cross-entropy on codebook indices; reduction is NONE so
        # we control the averaging ourselves in `process_image`.
        self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        self.loss_tracker = keras.metrics.Mean(name="loss")

        # Warmup-cosine LR schedule sized from the module-level constants.
        self.scheduled_lrs = WarmUpCosine(
            lr_start=1e-5,
            lr_max=2.5e-4,
            warmup_steps=WARMUP_STEPS,
            total_steps=TOTAL_STEPS,
        )
        self.compile(
            optimizer=tfa.optimizers.AdamW(
                learning_rate=self.scheduled_lrs, weight_decay=1e-4
            ),
            loss=[self.loss_fn, None],
        )

    @property
    def metrics(self):
        # Listing our `Metric` objects here lets Keras call `reset_states()`
        # automatically at the start of each epoch and of `evaluate()`.
        return [self.loss_tracker]

    def encode_to_z(self, x):
        """Encode images `x` with the first stage.

        Returns:
            Tuple of (quantized latents, codebook indices flattened to
            shape (batch, -1)). The quantization loss is discarded here.
        """
        quant_z, indices, _ = self.first_stage_model.encode(x)
        batch_size = tf.shape(quant_z)[0]
        indices = tf.reshape(indices, shape=(batch_size, -1))
        return quant_z, indices

    def encode_to_c(self, c):
        """Encode conditioning images `c` (same contract as `encode_to_z`)."""
        quant_c, indices, _ = self.cond_stage_model.encode(c)
        batch_size = tf.shape(quant_c)[0]
        indices = tf.reshape(indices, shape=(batch_size, -1))
        return quant_c, indices

    def call(self, inputs, training=None, mask=None):
        """Generate a video from `inputs`.

        Args:
            inputs: Either a (first_last_frame, target_video) tuple or the
                first_last_frame tensor alone.
            training: Unused; the train/inference split is driven by whether
                a target video is supplied.
            mask: Unused.

        Returns:
            See `process_video`.
        """
        if isinstance(inputs, tuple) and len(inputs) == 2:
            first_last_frame, y = inputs
        else:
            first_last_frame, y = inputs, None
        return self.process_video(first_last_frame, y)

    @tf.function()
    def process_image(self, x, c, target_image=None):
        """Predict the next frame from frame `x`, conditioned on `c`.

        Args:
            x: Current frame batch.
            c: Conditioning frame batch (the last frame of the video).
            target_image: Optional ground-truth next frame; when given, the
                returned loss is the mean token cross-entropy against it.

        Returns:
            Tuple of (decoded next-frame image, frame loss — 0 when no
            target is supplied).
        """
        frame_loss = 0
        # One step to produce the logits.
        quant_z, z_indices = self.encode_to_z(x)
        _, c_indices = self.encode_to_c(c)
        # Transformer input: conditioning indices followed by frame indices,
        # dropping the final token for next-token prediction.
        cz_indices = tf.concat((c_indices, z_indices), axis=1)
        logits = self.transformer(cz_indices[:, :-1])
        # Remove the conditioned prefix (the -1 mirrors the shift above).
        logits = logits[:, tf.shape(c_indices)[1] - 1 :]
        logits = tf.reshape(logits, shape=(-1, logits.shape[-1]))
        if target_image is not None:
            _, target_indices = self.encode_to_z(target_image)
            target_indices = tf.reshape(target_indices, shape=(-1,))
            frame_loss = tf.reduce_mean(
                self.loss_fn(y_true=target_indices, y_pred=logits)
            )
        image = self.get_image(logits, tf.shape(quant_z))
        return image, frame_loss

    # Not wrapped in @tf.function: the per-frame optimizer steps below run
    # eagerly (as in the original implementation).
    def process_video(self, first_last_frame, target_video=None, n_frames=19):
        """Generate a video frame-by-frame from its first and last frames.

        Args:
            first_last_frame: Tensor whose second axis holds the first and
                last frames — assumed (batch, 2, H, W, C); TODO confirm.
            target_video: Optional ground-truth video. When given, one
                optimizer step is taken on the transformer per generated
                frame and the summed loss is returned.
            n_frames: Number of frames to generate after the first one
                (default 19, matching the original hard-coded TODO value).

        Returns:
            (generated video, total loss) when `target_video` is given,
            otherwise the generated video alone; the video has
            ``n_frames + 1`` frames (the first frame plus the generations).
        """
        x = first_last_frame[:, 0]  # previous frame, starts at frame 0
        c = first_last_frame[:, -1]  # conditioning: the last frame
        total_loss = 0
        generated_video = [x]
        for i in range(n_frames):
            target = target_video[:, i, ...] if target_video is not None else None
            if target is not None:
                generated_image, frame_loss = self._train_frame(x, c, target)
                total_loss += frame_loss
            else:
                generated_image, _ = self.process_image(x, c, target_image=None)
            x = generated_image
            generated_video.append(generated_image)
        video = tf.stack(generated_video, axis=1)
        if target_video is not None:
            return video, total_loss
        return video

    def _train_frame(self, x, c, target):
        """Generate one frame and take one optimizer step on its loss."""
        with tf.GradientTape() as tape:
            generated_image, frame_loss = self.process_image(
                x, c, target_image=target
            )
        grads = tape.gradient(frame_loss, self.transformer.trainable_variables)
        # BUGFIX: gradients are taken w.r.t. the transformer's variables, so
        # they must be zipped against that same list. The original zipped
        # them against `self.trainable_variables` (which also contains the
        # VQGAN weights), misaligning gradients and variables.
        self.optimizer.apply_gradients(
            zip(grads, self.transformer.trainable_variables)
        )
        return generated_image, frame_loss

    def train_step(self, data):
        first_last_frame, y = data
        # `process_video` performs the per-frame optimizer steps itself.
        generated_video, loss = self.process_video(first_last_frame, y)
        self.loss_tracker.update_state(loss)
        # Log results.
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        first_last_frame, y = data
        # NOTE(review): this reuses `process_video` with targets, which also
        # applies gradients — evaluation therefore updates the transformer,
        # matching the original behavior; confirm whether that is intended.
        generated_video, loss = self.process_video(first_last_frame, y)
        self.loss_tracker.update_state(loss)
        # Log results.
        return {m.name: m.result() for m in self.metrics}

    def get_image(self, logits, shape):
        """Decode per-token logits into an image via top-1 codebook lookup."""
        probs = tf.keras.activations.softmax(logits)
        # top_k with default k=1 is an argmax over the codebook dimension.
        _, generated_indices = tf.math.top_k(probs)
        generated_indices = tf.reshape(generated_indices, (-1,))
        quant = self.first_stage_model.quantize.get_codebook_entry(
            generated_indices, shape=shape
        )
        return self.first_stage_model.decode(quant)

    def decode_to_img(self, index, zshape):
        """Decode flat codebook indices `index` into an image of latent shape `zshape`."""
        quant_z = self.first_stage_model.quantize.get_codebook_entry(
            tf.reshape(index, -1), shape=zshape
        )
        x = self.first_stage_model.decode(quant_z)
        return x