File size: 14,479 Bytes
5019d3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 |
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""
import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
import logging
import abc
import sys
logger = logging.getLogger(__name__)
# Choose an abstract base class usable on both old and new interpreters:
# from Python 3.4 on, abc.ABC is used directly; on earlier versions an
# empty class named 'ABC' is created through the ABCMeta metaclass.
ABC = abc.ABC if sys.version_info >= (3, 4) else abc.ABCMeta('ABC', (), {})
class _LRSchedule(ABC):
    """Common base class of every learning-rate schedule in this module.

    A schedule turns training progress (``step / t_total``) into a multiplier
    for the base learning rate. Subclasses implement ``get_lr_``.
    """

    # Subclasses set this to True when training past t_total steps makes no sense.
    warn_t_total = False

    def __init__(self, warmup=0.002, t_total=-1, **kw):
        """
        :param warmup: fraction of t_total steps used for linear warmup; -1 is accepted and treated as 0
        :param t_total: number of planned training steps (updates); a negative value disables the schedule
        :param kw: forwarded unchanged to the next ``__init__`` in the MRO
        :raises ValueError: if warmup is neither -1 nor inside [0.0, 1.0[
        """
        super(_LRSchedule, self).__init__(**kw)
        if t_total < 0:
            logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
        if not (0.0 <= warmup < 1.0 or warmup == -1):
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        # The -1 sentinel is clamped to 0., i.e. no warmup phase.
        warmup_fraction = float(max(warmup, 0.))
        total_steps = float(t_total)
        self.warmup = warmup_fraction
        self.t_total = total_steps
        # Largest progress value for which the "beyond t_total" warning was emitted.
        self.warned_for_t_total_at_progress = -1

    def get_lr(self, step, nowarn=False):
        """
        :param step: which of the t_total steps we are on
        :param nowarn: set to True to silence the warning about training beyond 't_total' steps
        :return: learning rate multiplier for the current update (always 1. when t_total is negative)
        """
        if self.t_total < 0:
            return 1.
        progress = float(step) / self.t_total
        multiplier = self.get_lr_(progress)
        # Warn about exceeding t_total; only schedules with warn_t_total=True do this.
        past_end = progress > 1. and progress > self.warned_for_t_total_at_progress
        if self.warn_t_total and not nowarn and past_end:
            logger.warning("Training beyond specified 't_total'. Learning rate multiplier set to {}. Please "
                           "set 't_total' of {} correctly.".format(multiplier, self.__class__.__name__))
            self.warned_for_t_total_at_progress = progress
        return multiplier

    @abc.abstractmethod
    def get_lr_(self, progress):
        """
        :param progress: training progress, between 0 and 1 unless training runs beyond t_total steps
        :return: learning rate multiplier for the current update
        """
        return 1.
class ConstantLR(_LRSchedule):
    """Schedule whose multiplier is the same at every point of training."""

    def get_lr_(self, progress):
        """
        :param progress: training progress; not used by this schedule
        :return: the constant multiplier 1.0
        """
        return 1.0
class WarmupCosineSchedule(_LRSchedule):
    """
    Linear warmup from 0 to 1 over the first `warmup` fraction of training steps,
    then a cosine curve over the remaining `1 - warmup` fraction.
    With the default `cycles=0.5` the multiplier decays from 1. to 0.; any other
    value makes it follow that many cosine cycles after warmup.
    """
    warn_t_total = True

    def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
        """
        :param warmup: see _LRSchedule
        :param t_total: see _LRSchedule
        :param cycles: number of cosine cycles after warmup. The default 0.5 gives a decay
                       from 1. at progress==warmup to 0. at progress==1.
        :param kw: forwarded to _LRSchedule
        """
        super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
        self.cycles = cycles

    def get_lr_(self, progress):
        """
        :param progress: overall training progress
        :return: learning rate multiplier for the current update
        """
        if progress < self.warmup:
            # Still on the linear ramp.
            return progress / self.warmup
        # Re-express progress relative to the post-warmup part of training.
        after_warmup = (progress - self.warmup) / (1 - self.warmup)
        cosine = math.cos(math.pi * self.cycles * 2 * after_warmup)
        return 0.5 * (1. + cosine)
class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
    """
    Linear warmup from 0 to 1 over the first `warmup` fraction of training steps.
    Afterwards the multiplier runs through `cycles` (default=1.) cosine decays,
    jumping back to 1. at the start of each one (hard restarts).
    """

    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
        """
        :param warmup: see _LRSchedule
        :param t_total: see _LRSchedule
        :param cycles: number of cosine decays after warmup; asserted to be >= 1.
        :param kw: forwarded to the parent schedule
        """
        super(WarmupCosineWithHardRestartsSchedule, self).__init__(
            warmup=warmup, t_total=t_total, cycles=cycles, **kw)
        assert(cycles >= 1.)

    def get_lr_(self, progress):
        """
        :param progress: overall training progress
        :return: learning rate multiplier for the current update
        """
        if progress < self.warmup:
            return progress / self.warmup
        after_warmup = (progress - self.warmup) / (1 - self.warmup)
        # The modulo restarts the cosine at the beginning of every cycle.
        position_in_cycle = (self.cycles * after_warmup) % 1
        return 0.5 * (1. + math.cos(math.pi * position_in_cycle))
class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
    """
    Splits training progress into `cycles` (default=1.) parts of equal length.
    Each part starts with a linear warmup from 0. to 1. and then decays
    from 1. to 0. along a cosine curve.
    """

    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
        """
        :param warmup: fraction of all training steps spent on warmup within each cycle;
                       asserted to satisfy warmup * cycles < 1.
        :param t_total: see _LRSchedule
        :param cycles: number of equally long warmup + decay cycles
        :param kw: forwarded to the parent schedule
        """
        assert(warmup * cycles < 1.)
        if warmup >= 0:
            # Store the warmup as a fraction of a single cycle; negative values pass through untouched.
            warmup = warmup * cycles
        super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(
            warmup=warmup, t_total=t_total, cycles=cycles, **kw)

    def get_lr_(self, progress):
        """
        :param progress: overall training progress
        :return: learning rate multiplier for the current update
        """
        # Position inside the current cycle, in [0, 1[.
        cycle_progress = progress * self.cycles % 1.
        if cycle_progress < self.warmup:
            return cycle_progress / self.warmup
        after_warmup = (cycle_progress - self.warmup) / (1 - self.warmup)
        return 0.5 * (1. + math.cos(math.pi * after_warmup))
class WarmupConstantSchedule(_LRSchedule):
    """
    Linear warmup from 0 to 1 over the first `warmup` fraction of training steps,
    after which the multiplier stays at 1.
    """

    def get_lr_(self, progress):
        """
        :param progress: overall training progress
        :return: learning rate multiplier for the current update
        """
        return progress / self.warmup if progress < self.warmup else 1.
class WarmupLinearSchedule(_LRSchedule):
    """
    Linear warmup from 0 to 1 over the first `warmup` fraction of training steps,
    followed by a linear decay from 1. to 0. over the remaining `1 - warmup` fraction.
    """
    warn_t_total = True

    def get_lr_(self, progress):
        """
        :param progress: overall training progress
        :return: learning rate multiplier for the current update, never below 0.
        """
        in_warmup = progress < self.warmup
        if in_warmup:
            return progress / self.warmup
        # Straight line through (warmup, 1) and (1, 0); clamped so it stays at 0 past the end.
        decayed = (progress - 1.) / (self.warmup - 1.)
        return max(decayed, 0.)
# Lookup table from the `schedule` argument of BertAdam to the schedule class
# that is instantiated for it. Both None and the string "none" select the
# constant schedule.
SCHEDULES = {
    None: ConstantLR,
    "none": ConstantLR,
    "warmup_cosine": WarmupCosineSchedule,
    "warmup_constant": WarmupConstantSchedule,
    "warmup_linear": WarmupLinearSchedule
}
class EMA(object):
    """ Exponential Moving Average for model parameters.

    Keeps a "shadow" copy of every registered tensor, updates it as a moving
    average of the live parameters, and can temporarily swap the averaged
    values into a model (``assign``) and swap the trained values back (``resume``).

    references:
    [1] https://github.com/BangLiu/QANet-PyTorch/blob/master/model/modules/ema.py
    [2] https://github.com/hengruo/QANet-pytorch/blob/e2de07cd2c711d525f5ffee35c3764335d4b501d/main.py"""

    def __init__(self, decay):
        """
        :param decay: upper bound for the decay factor used when averaging
        """
        self.decay = decay
        self.shadow = {}    # name -> moving average of the parameter
        self.original = {}  # name -> trained values saved by assign()

    def register(self, name, val):
        """Start tracking `val` under `name`; a copy is stored as the initial average."""
        self.shadow[name] = val.clone()

    def __call__(self, model, step):
        """Update the moving averages from the current parameters of `model`.

        :param model: object exposing ``named_parameters()``
        :param step: current step number; the effective decay is
                     ``min(decay, (1 + step) / (10 + step))``, so early steps use a smaller decay
        """
        decay = min(self.decay, (1 + step) / (10.0 + step))
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                # The arithmetic already produces a new tensor, so no extra copy is needed.
                self.shadow[name] = (1.0 - decay) * param.data + decay * self.shadow[name]

    def assign(self, model):
        """Save the current parameters of `model` and load the averaged values into it."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.original[name] = param.data.clone()
                # Install a copy: handing over the shadow tensor itself would let any
                # in-place change to the parameter silently alter the moving average.
                param.data = self.shadow[name].clone()

    def resume(self, model):
        """Put back the parameter values that were saved by the last ``assign`` call."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                param.data = self.original[name]
class BertAdam(Optimizer):
    """Implements BERT version of Adam algorithm with weight decay fix.

    Params:
        lr: learning rate
        warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
        t_total: total number of training steps for the learning
            rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting).
            Default: -1
        schedule: schedule to use for the warmup.
            Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a
            `_LRSchedule` object (defined earlier in this module).
            If `None` or `'none'`, learning rate is always kept constant.
            Default : `'warmup_linear'`
        b1: Adams b1. Default: 0.9
        b2: Adams b2. Default: 0.999
        e: Adams epsilon. Default: 1e-6
        weight_decay: Weight decay. Default: 0.01
        max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
        kwargs: accepted for call compatibility and otherwise ignored
    """

    def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
        """Validate the hyper-parameters and build the schedule object.

        Raises:
            ValueError: if lr, schedule, b1, b2 or e is outside its valid range.
        """
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
        if not e >= 0.0:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
        # initialize schedule object
        if not isinstance(schedule, _LRSchedule):
            schedule_type = SCHEDULES[schedule]
            schedule = schedule_type(warmup=warmup, t_total=t_total)
        else:
            # A ready-made schedule carries its own warmup/t_total; the optimizer-level values are unused.
            if warmup != -1 or t_total != -1:
                logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is "
                               "provided as schedule. Please specify custom warmup and t_total in _LRSchedule object.")
        defaults = dict(lr=lr, schedule=schedule,
                        b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(BertAdam, self).__init__(params, defaults)

    def get_lr(self):
        """Return the scheduled learning rate of every parameter that has been updated.

        Parameters without optimizer state (for example ones that never received
        a gradient) are skipped, so a single frozen parameter no longer makes the
        result collapse to ``[0]``.

        Returns:
            list: one scheduled learning rate per parameter with state, or ``[0]``
            when no parameter has been stepped yet.
        """
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if len(state) == 0:
                    # Never updated by step(); it has no step counter to evaluate.
                    continue
                lr_scheduled = group['lr']
                lr_scheduled *= group['schedule'].get_lr(state['step'])
                lr.append(lr_scheduled)
        return lr if lr else [0]

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by `closure`, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)

                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']

                # Add grad clipping (applied to this single parameter's gradient)
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # No bias correction is applied to the moment estimates.
                update = next_m / (next_v.sqrt() + group['e'])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay'] > 0.0:
                    update += group['weight_decay'] * p.data

                lr_scheduled = group['lr']
                lr_scheduled *= group['schedule'].get_lr(state['step'])

                update_with_lr = lr_scheduled * update
                p.data.add_(-update_with_lr)

                state['step'] += 1

        return loss
|