class DynamicLossScaler(object):
    """Dynamically scales the loss for mixed-precision (fp16) training.

    The scale is raised by ``scale_factor`` after ``scale_window`` iterations
    without overflow, and lowered again when gradient overflows (inf/NaN) are
    detected, keeping gradients representable in half precision.
    """

    def __init__(
        self,
        init_scale=2.0**15,
        scale_factor=2.0,
        scale_window=2000,
        tolerance=0.0,
        threshold=None,
        min_loss_scale=1e-4,
    ):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.tolerance = tolerance
        self.threshold = threshold
        self._iter = 0
        self._last_overflow_iter = -1
        self._last_rescale_iter = -1
        self._overflows_since_rescale = 0
        self.min_loss_scale = min_loss_scale

    def scale(self, outputs):
        return self.loss_scale * outputs

    def update(self):
        # Grow the loss scale after ``scale_window`` iterations without overflow.
        if (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
            self._last_rescale_iter = self._iter
        self._iter += 1

    def _decrease_loss_scale(self):
        self.loss_scale /= self.scale_factor
        if self.threshold is not None:
            # Never shrink below the configured floor.
            self.loss_scale = max(self.loss_scale, self.threshold)

    def check_overflow(self, grad_norm):
        # Detect inf/NaN; NaN is the only value that compares unequal to itself.
        if grad_norm == float("inf") or grad_norm != grad_norm:
            prev_scale = self.loss_scale
            iter_since_rescale = self._iter - self._last_rescale_iter

            self._last_overflow_iter = self._iter
            self._overflows_since_rescale += 1
            pct_overflow = self._overflows_since_rescale / float(iter_since_rescale)
            if pct_overflow >= self.tolerance:
                self._decrease_loss_scale()
                self._last_rescale_iter = self._iter
                self._overflows_since_rescale = 0

            if self.loss_scale <= self.min_loss_scale:
                # Restore the previous scale and raise FloatingPointError so the
                # caller can stop training instead of looping on a useless scale.
                self.loss_scale = prev_scale
                raise FloatingPointError(
                    (
                        "Minimum loss scale reached ({}). Your loss is probably exploding. "
                        "Try lowering the learning rate, using gradient clipping or "
                        "increasing the batch size."
                    ).format(self.min_loss_scale)
                )

            self._iter += 1
            raise OverflowError("setting loss scale to: " + str(self.loss_scale))
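

# A minimal, hedged usage sketch (not part of the class above and not the
# library's actual optimizer integration): it shows one plausible way to drive
# DynamicLossScaler from a training loop, run here in fp32 for simplicity.
# The ``model``, ``optimizer``, and random data below are placeholders invented
# for illustration; only the DynamicLossScaler calls mirror the class in this file.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = DynamicLossScaler()

    for step in range(10):
        inputs = torch.randn(4, 8)
        targets = torch.randn(4, 1)

        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)

        # Scale the loss before backward so small gradients would not
        # underflow to zero in fp16.
        scaler.scale(loss).backward()

        # Undo the scaling on the gradients before clipping and stepping.
        for p in model.parameters():
            if p.grad is not None:
                p.grad.div_(scaler.loss_scale)

        grad_norm = float(torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0))

        try:
            # Raises OverflowError if grad_norm is inf/NaN, after already
            # lowering the loss scale for the next attempt.
            scaler.check_overflow(grad_norm)
        except OverflowError:
            continue  # skip this batch and retry with the smaller scale

        optimizer.step()
        scaler.update()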