|
import numpy as np |
|
import tensorflow as tf |
|
from tensorflow import keras |
|
|
|
|
|
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    """A LearningRateSchedule with a linear warmup followed by cosine decay.

    The learning rate rises linearly from ``lr_start`` to ``lr_max`` over
    ``warmup_steps``, then follows a cosine annealing curve from ``lr_max``
    down to 0 at ``total_steps``. Steps beyond ``total_steps`` yield 0.0.
    """

    def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
        """
        Args:
            lr_start: The initial learning rate.
            lr_max: The maximum learning rate to which lr should increase to
                in the warmup steps.
            warmup_steps: The number of steps for which the model warms up.
            total_steps: The total number of steps for the model training.
        """
        super().__init__()
        self.lr_start = lr_start
        self.lr_max = lr_max
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        # Cached once as a tensor so it is not re-created on every __call__.
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        """Return the learning rate for ``step`` (a scalar step tensor).

        Raises:
            ValueError: If ``total_steps < warmup_steps``, or if
                ``lr_max < lr_start`` while warmup is enabled.
        """
        if self.total_steps < self.warmup_steps:
            raise ValueError(
                f"Total number of steps {self.total_steps} must be "
                f"larger or equal to warmup steps {self.warmup_steps}."
            )

        # Cosine annealing: maps the post-warmup progress onto [0, pi] so the
        # cosine goes from +1 (at warmup end) to -1 (at total_steps).
        # NOTE(review): if total_steps == warmup_steps (allowed by the check
        # above) this denominator is zero and the result is NaN — confirm
        # callers never construct that configuration.
        cos_annealed_lr = tf.cos(
            self.pi
            * (tf.cast(step, tf.float32) - self.warmup_steps)
            / tf.cast(self.total_steps - self.warmup_steps, tf.float32)
        )
        learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)

        if self.warmup_steps > 0:
            if self.lr_max < self.lr_start:
                raise ValueError(
                    f"lr_start {self.lr_start} must be smaller or "
                    f"equal to lr_max {self.lr_max}."
                )

            # Linear ramp from lr_start to lr_max across the warmup phase.
            slope = (self.lr_max - self.lr_start) / self.warmup_steps
            warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start

            # While still warming up, use the ramp instead of the cosine curve.
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )

        # Past total_steps the schedule is exhausted: emit 0.0.
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )

    def get_config(self):
        """Return the constructor arguments so Keras can serialize the schedule."""
        return {
            "lr_start": self.lr_start,
            "lr_max": self.lr_max,
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
        }
|
|