File size: 37,953 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 |
import torch
import math
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
from typing import Union, Iterable, Tuple, Callable, List
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pdb
import numpy as np
import copy
import random
# Module-level alias for positive infinity (same value as ``math.inf``).
inf = float('inf')
def calculate_grad_norm(model: torch.nn.Module, norm_type=2) -> float:
    """
    Overview:
        Calculate the total gradient norm over the parameters of ``model`` whose ``grad`` is not None, treating all \
        those gradients as one concatenated vector.
    Arguments:
        - model (:obj:`torch.nn.Module`): the model whose gradients are measured
        - norm_type (:obj:`int`, :obj:`float` or :obj:`str`): order of the norm; the infinity norm can be \
            requested either with the string ``'inf'`` or with ``math.inf``
    Returns:
        - total_norm (:obj:`float`): the gradient norm, or ``0`` when no parameter has a gradient
    """
    parameters = [p for p in model.parameters() if p.grad is not None]
    if not parameters:
        return 0
    # Accept both spellings: with ``math.inf`` the generic branch below would compute ``x ** inf`` and
    # ``x ** (1 / inf) == x ** 0 == 1.0``, silently returning a wrong value.
    if norm_type == 'inf' or norm_type == math.inf:
        return float(max(p.grad.data.abs().max() for p in parameters))
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    return float(total_norm ** (1. / norm_type))
def calculate_grad_norm_without_bias_two_norm(model: torch.nn.Module) -> float:
    """
    Overview:
        Calculate the 2-norm of the gradients of every trainable parameter of ``model`` whose qualified name does \
        not contain ``'bias'``.
    Arguments:
        - model (:obj:`torch.nn.Module`): the model whose gradients are measured
    Returns:
        - total_norm (:obj:`float`): the 2-norm of the selected gradients; ``0`` as soon as one selected \
            parameter has no gradient
    """
    squared_norms = []
    for qualified_name, weight in model.named_parameters():
        if 'bias' in qualified_name or not weight.requires_grad:
            continue
        if weight.grad is None:
            # a trainable non-bias parameter without gradient short-circuits the whole computation
            return 0
        squared_norms.append(weight.grad.data.norm(2).item() ** 2)
    return float(sum(squared_norms) ** 0.5)
def grad_ignore_norm(parameters, max_norm, norm_type=2):
    """
    Overview:
        Ignore (zero out) the gradients of an iterable of parameters when their total norm exceeds ``max_norm``. \
        Unlike norm clipping, the gradients are not rescaled: the whole update is dropped.
    Arguments:
        - parameters (:obj:`Iterable` or :obj:`torch.Tensor`): an iterable of torch.Tensor, or a single tensor
        - max_norm (:obj:`float`): the max norm of the gradients
        - norm_type (:obj:`float`): 2.0 means use norm2; ``inf`` means use the infinity norm
    Returns:
        - total_norm (:obj:`float`): the total gradient norm measured before any zeroing (``0.0`` if no \
            parameter has a gradient)
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if not parameters:
        # nothing to measure; ``max()`` below would raise on an empty sequence
        return 0.0
    if norm_type == math.inf:
        total_norm = float(max(p.grad.data.abs().max() for p in parameters))
    else:
        total_norm = 0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)
    ratio = max_norm / (total_norm + 1e-6)
    if ratio < 1:
        for p in parameters:
            p.grad.zero_()
    return total_norm
def grad_ignore_value(parameters, clip_value):
    """
    Overview:
        Ignore (zero out) the gradients of all given parameters if any gradient element has an absolute value \
        greater than or equal to ``clip_value``.
    Arguments:
        - parameters (:obj:`Iterable` or :obj:`torch.Tensor`): an iterable of torch.Tensor, or a single tensor
        - clip_value (:obj:`float`): the value to start ignoring
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    # Materialize once: a generator such as ``model.parameters()`` would otherwise be (partly) consumed by the
    # detection pass, and the zeroing pass would skip some or all gradients.
    parameters = [p for p in parameters if p.grad is not None]
    clip_value = float(clip_value)
    if any(p.grad.data.abs().max() >= clip_value for p in parameters):
        for p in parameters:
            p.grad.data.zero_()
class Adam(torch.optim.Adam):
    """
    Overview:
        Rewritten Adam optimizer to support more features: decoupled weight decay (``optim_type='adamw'``), \
        several gradient clipping strategies and several gradient ignoring strategies. Clipping and ignoring \
        modify ``param.grad`` in place right before ``torch.optim.Adam.step`` is called.
    Interfaces:
        ``__init__``, ``step``, ``_state_init``, ``get_grad``
    """

    def __init__(
            self,
            params: Iterable,
            lr: float = 1e-3,
            betas: Tuple[float, float] = (0.9, 0.999),
            eps: float = 1e-8,
            weight_decay: float = 0,
            amsgrad: bool = False,
            optim_type: str = 'adam',
            grad_clip_type: str = None,
            clip_value: Union[float, None] = None,
            clip_coef: float = 5,
            clip_norm_type: float = 2.0,
            clip_momentum_timestep: int = 100,
            grad_norm_type: str = None,
            grad_ignore_type: str = None,
            ignore_value: Union[float, None] = None,
            ignore_coef: float = 5,
            ignore_norm_type: float = 2.0,
            ignore_momentum_timestep: int = 100,
    ):
        """
        Overview:
            init method of refactored Adam class
        Arguments:
            - params (:obj:`iterable`): an iterable of torch.Tensor s or dict s. \
                Specifies what Tensors should be optimized
            - lr (:obj:`float`): learning rate, default set to 1e-3
            - betas (:obj:`Tuple[float, float]`): coefficients used for computing running averages of gradient and \
                its square, default set to (0.9, 0.999)
            - eps (:obj:`float`): term added to the denominator to improve numerical stability, default set to 1e-8
            - weight_decay (:obj:`float`): weight decay coefficient, default set to 0. With ``optim_type='adamw'`` \
                it is applied as decoupled weight decay in ``step`` instead of being passed to torch's Adam
            - amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper \
                On the Convergence of Adam and Beyond <https://arxiv.org/abs/1904.09237>
            - optim_type (:obj:`str`): support ["adam", "adamw"]
            - grad_clip_type (:obj:`str`): support [None, 'clip_momentum', 'clip_value', 'clip_norm', \
                'clip_momentum_norm']
            - clip_value (:obj:`float`): the value to start clipping; required when ``grad_clip_type`` is set
            - clip_coef (:obj:`float`): the clipping coefficient used by the momentum-based strategies
            - clip_norm_type (:obj:`float`): 2.0 means use norm2 to clip
            - clip_momentum_timestep (:obj:`int`): after how many steps the momentum clipping starts
            - grad_norm_type (:obj:`str`): support [None]
            - grad_ignore_type (:obj:`str`): support [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', \
                'ignore_momentum_norm']
            - ignore_value (:obj:`float`): the value to start ignoring; required when ``grad_ignore_type`` is set
            - ignore_coef (:obj:`float`): the ignoring coefficient used by the momentum-based strategies
            - ignore_norm_type (:obj:`float`): 2.0 means use norm2 to ignore
            - ignore_momentum_timestep (:obj:`int`): after how many steps the momentum ignoring starts
        """
        self._support_type = {
            'optim': ['adam', 'adamw'],
            'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
            'grad_norm': [None],
            'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
        }
        assert optim_type in self._support_type['optim']
        assert grad_clip_type in self._support_type['grad_clip']
        assert grad_norm_type in self._support_type['grad_norm']
        assert grad_ignore_type in self._support_type['grad_ignore']
        if grad_clip_type:
            assert clip_value is not None
        if grad_ignore_type:
            assert ignore_value is not None
        self._optim_type = optim_type
        self._grad_clip_type = grad_clip_type
        self._grad_norm_type = grad_norm_type
        self._grad_ignore_type = grad_ignore_type
        self._clip_value = clip_value
        self._clip_norm_type = clip_norm_type
        self._clip_coef = clip_coef
        self._ignore_value = ignore_value
        self._ignore_norm_type = ignore_norm_type
        self._ignore_coef = ignore_coef
        self._clip_momentum_timestep = clip_momentum_timestep
        self._ignore_momentum_timestep = ignore_momentum_timestep
        if self._optim_type == 'adamw':
            # weight decay is applied manually (decoupled) in ``step``, so torch's coupled version is disabled
            self._weight_decay = weight_decay
            super(Adam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=0, amsgrad=amsgrad)
        elif self._optim_type == 'adam':
            super(Adam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        else:
            raise NotImplementedError(
                "optimizer type {} is not implemented, support type is {}".format(
                    self._optim_type, self._support_type['optim']
                )
            )

    def _state_init(self, p, amsgrad):
        """
        Overview:
            Initialize the state of the optimizer for ``p``: the extra ``thre_exp_avg_sq`` buffer used by the \
            momentum-based strategies plus the buffers ``torch.optim.Adam`` reads.
        Arguments:
            - p (:obj:`torch.Tensor`): the parameter to be optimized
            - amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper \
                On the Convergence of Adam and Beyond <https://arxiv.org/abs/1904.09237>
        """
        state = self.state[p]
        state['thre_exp_avg_sq'] = torch.zeros_like(p.data, device=p.data.device)
        # Feature-detect rather than compare version strings: ``torch.__version__ < "1.12.0"`` is lexicographic on
        # plain-string versions (e.g. "1.9.0" > "1.12.0") and would pick the tensor branch on old torch, where
        # ``self.defaults['capturable']`` does not exist. The tensor-valued step counter is tied to 'capturable'.
        if 'capturable' in self.defaults:
            state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
                if self.defaults['capturable'] else torch.tensor(0.)
        else:
            state['step'] = 0
        state['exp_avg'] = torch.zeros_like(p.data)
        # Exponential moving average of squared gradient values
        state['exp_avg_sq'] = torch.zeros_like(p.data)
        if amsgrad:
            # Maintains max of all exp. moving avg. of sq. grad. values
            state['max_exp_avg_sq'] = torch.zeros_like(p.data)

    def _update_thre_exp_avg_sq(self, p, group):
        """
        Overview:
            Lazily initialize the state of ``p`` and update ``thre_exp_avg_sq``, the running estimate of the second \
            moment of the (unclipped) gradient used by the momentum-based clip/ignore strategies.
        Arguments:
            - p (:obj:`torch.Tensor`): a parameter whose ``grad`` is not None
            - group (:obj:`dict`): the param group ``p`` belongs to
        Returns:
            - state (:obj:`dict`): the optimizer state of ``p``
            - bias_correction2: ``1 - beta2 ** state['step']``; it is 0 while ``state['step']`` is 0, which makes \
                the derived threshold inf/nan until the step counter advances
        """
        state = self.state[p]
        if len(state) == 0:
            self._state_init(p, group['amsgrad'])
        grad = p.grad.data
        # should we use same beta group?
        _, beta2 = group['betas']
        bias_correction2 = 1 - beta2 ** state['step']
        state['thre_exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        return state, bias_correction2

    @staticmethod
    def _momentum_threshold(state, bias_correction2, coef):
        """
        Overview:
            Bias-corrected square root of ``thre_exp_avg_sq`` scaled by ``coef``, i.e. the ``coef * sqrt(v)`` bound.
        """
        return (state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) * coef

    def _group_norms(self, group, coef, norm_type):
        """
        Overview:
            Update the threshold statistics of every parameter with a gradient in ``group`` and measure the group.
        Returns:
            - total_norm (:obj:`float`): ``norm_type``-norm of the group's gradients
            - total_momentum_norm (:obj:`float`): ``norm_type``-norm of the group's momentum thresholds
            - step: smallest step counter in the group (``math.inf`` when no parameter has a gradient)
        """
        total_norm = 0
        total_momentum_norm = 0
        step = math.inf
        for p in group['params']:
            if p.grad is None:
                continue
            state, bias_correction2 = self._update_thre_exp_avg_sq(p, group)
            total_norm += p.grad.data.norm(norm_type).item() ** norm_type
            momentum = self._momentum_threshold(state, bias_correction2, coef).norm(norm_type)
            total_momentum_norm += momentum.item() ** norm_type
            step = min(step, state['step'])
        return total_norm ** (1. / norm_type), total_momentum_norm ** (1. / norm_type), step

    def _any_grad_exceeds_momentum(self) -> bool:
        """
        Overview:
            Return True as soon as one gradient element exceeds its ``ignore_coef * sqrt(v)`` bound. Parameters \
            after the first offending one are not visited, so their statistics are not updated on that step.
        """
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state, bias_correction2 = self._update_thre_exp_avg_sq(p, group)
                if state['step'] >= self._ignore_momentum_timestep:  # initial value is inaccurate
                    threshold = self._momentum_threshold(state, bias_correction2, self._ignore_coef)
                    # ``.any()``: a multi-element bool tensor has no truth value of its own
                    if (p.grad.data.abs() > threshold).any():
                        return True
        return False

    def _clip_grad(self, new_params):
        """
        Overview:
            Apply the configured ``grad_clip_type`` to the gradients in place.
        Arguments:
            - new_params (:obj:`list`): parameters that require grad and have a gradient
        """
        if self._grad_clip_type == 'clip_value':
            clip_grad_value_(new_params, self._clip_value)
        elif self._grad_clip_type == 'clip_norm':
            clip_grad_norm_(new_params, self._clip_value, self._clip_norm_type)
        elif self._grad_clip_type == 'clip_momentum':
            '''
            This implementation mimics the clip used in OPENAI, quote:
            'Gradients are additionally clipped per parameter to be within between ±5√v
            where v is the running estimate of the second moment of the (unclipped) gradient'
            '''
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state, bias_correction2 = self._update_thre_exp_avg_sq(p, group)
                    if state['step'] >= self._clip_momentum_timestep:  # initial value is inaccurate
                        grad = p.grad.data
                        threshold = self._momentum_threshold(state, bias_correction2, self._clip_coef)
                        # clip into [-threshold, threshold], keeping each element's sign; ``where`` also keeps
                        # unclipped elements untouched even when the threshold is inf/nan
                        grad.copy_(torch.where(grad.abs() > threshold, threshold * grad.sign(), grad))
        elif self._grad_clip_type == 'clip_momentum_norm':
            # might have multi param_group, we should calculate each group differently.
            for group in self.param_groups:
                total_norm, total_momentum_norm, step = self._group_norms(
                    group, self._clip_coef, self._clip_norm_type
                )
                if step > self._clip_momentum_timestep:
                    clip_coef = total_momentum_norm / (total_norm + 1e-6)
                    if clip_coef < 1:
                        for p in group['params']:
                            if p.grad is not None:
                                p.grad.data.mul_(clip_coef)

    def _ignore_grad(self, new_params):
        """
        Overview:
            Apply the configured ``grad_ignore_type``: zero the gradients when they look abnormal.
        Arguments:
            - new_params (:obj:`list`): parameters that require grad and have a gradient
        """
        if self._grad_ignore_type == 'ignore_value':
            grad_ignore_value(new_params, self._ignore_value)
        elif self._grad_ignore_type == 'ignore_norm':
            grad_ignore_norm(new_params, self._ignore_value, self._ignore_norm_type)
        elif self._grad_ignore_type == 'ignore_momentum':
            if self._any_grad_exceeds_momentum():
                for group in self.param_groups:
                    for p in group['params']:
                        if p.grad is not None:
                            p.grad.zero_()
        elif self._grad_ignore_type == 'ignore_momentum_norm':
            # might have multi param_group, we should calculate each group differently.
            for group in self.param_groups:
                total_norm, total_momentum_norm, step = self._group_norms(
                    group, self._ignore_coef, self._ignore_norm_type
                )
                if step > self._ignore_momentum_timestep:
                    ignore_coef = total_momentum_norm / (total_norm + 1e-6)
                    if ignore_coef < 1:
                        for p in group['params']:
                            if p.grad is not None:
                                p.grad.zero_()

    def step(self, closure: Union[Callable, None] = None):
        """
        Overview:
            Performs a single optimization step: clipping, then ignoring, then (for ``adamw``) decoupled weight \
            decay, then ``torch.optim.Adam.step``.
        Arguments:
            - closure (:obj:`callable`): A closure that reevaluates the model and returns the loss, default set to None
        """
        new_params = [
            t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None
        ]
        self._clip_grad(new_params)
        self._ignore_grad(new_params)
        if self._optim_type == 'adamw':
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    # in place: p <- p - weight_decay * lr * p (same arithmetic as before, without rebinding p.data)
                    p.data.add_(p.data, alpha=-self._weight_decay * group['lr'])
        return super().step(closure=closure)

    def get_grad(self) -> float:
        """
        Overview:
            Sum ``||grad||_p ** p`` (``p = clip_norm_type``) over the parameters that require grad and have one. \
            The ``1 / p`` root is not taken, e.g. for ``p = 2`` this is the squared 2-norm.
        """
        total_norm = 0.
        params = [t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None]
        for p in params:
            param_norm = p.grad.data.norm(self._clip_norm_type)
            total_norm += param_norm.item() ** self._clip_norm_type
        return total_norm
class RMSprop(torch.optim.RMSprop):
    r"""
    Overview:
        Rewritten RMSprop optimizer to support more features: several gradient clipping strategies and several \
        gradient ignoring strategies, which modify ``param.grad`` in place right before \
        ``torch.optim.RMSprop.step`` is called.
    Interfaces:
        ``__init__``, ``step``, ``_state_init``, ``get_grad``
    """

    def __init__(
            self,
            params: Iterable,
            lr: float = 1e-2,
            alpha: float = 0.99,
            eps: float = 1e-8,
            weight_decay: float = 0,
            momentum: float = 0,
            centered: bool = False,
            grad_clip_type: str = None,
            clip_value: Union[float, None] = None,
            clip_coef: float = 5,
            clip_norm_type: float = 2.0,
            clip_momentum_timestep: int = 100,
            grad_norm_type: str = None,
            grad_ignore_type: str = None,
            ignore_value: Union[float, None] = None,
            ignore_coef: float = 5,
            ignore_norm_type: float = 2.0,
            ignore_momentum_timestep: int = 100,
    ):
        """
        Overview:
            init method of refactored RMSprop class
        Arguments:
            - params (:obj:`iterable`): an iterable of torch.Tensor s or dict s. \
                Specifies what Tensors should be optimized
            - lr (:obj:`float`): learning rate, default set to 1e-2
            - alpha (:obj:`float`): smoothing constant, default set to 0.99
            - eps (:obj:`float`): term added to the denominator to improve numerical stability, default set to 1e-8
            - weight_decay (:obj:`float`): weight decay coefficient, default set to 0
            - momentum (:obj:`float`): momentum coefficient, default set to 0
            - centered (:obj:`bool`): if True, compute the centered RMSprop, \
                the gradient is normalized by an estimation of its variance
            - grad_clip_type (:obj:`str`): support [None, 'clip_momentum', 'clip_value', 'clip_norm', \
                'clip_momentum_norm']
            - clip_value (:obj:`float`): the value to start clipping; required when ``grad_clip_type`` is set
            - clip_coef (:obj:`float`): the clipping coefficient used by the momentum-based strategies
            - clip_norm_type (:obj:`float`): 2.0 means use norm2 to clip
            - clip_momentum_timestep (:obj:`int`): after how many steps the momentum clipping starts
            - grad_norm_type (:obj:`str`): support [None]
            - grad_ignore_type (:obj:`str`): support [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', \
                'ignore_momentum_norm']
            - ignore_value (:obj:`float`): the value to start ignoring; required when ``grad_ignore_type`` is set
            - ignore_coef (:obj:`float`): the ignoring coefficient used by the momentum-based strategies
            - ignore_norm_type (:obj:`float`): 2.0 means use norm2 to ignore
            - ignore_momentum_timestep (:obj:`int`): after how many steps the momentum ignoring starts
        """
        self._support_type = {
            'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
            'grad_norm': [None],
            'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
        }
        assert grad_clip_type in self._support_type['grad_clip']
        assert grad_norm_type in self._support_type['grad_norm']
        assert grad_ignore_type in self._support_type['grad_ignore']
        if grad_clip_type:
            assert clip_value is not None
        if grad_ignore_type:
            assert ignore_value is not None
        self._grad_clip_type = grad_clip_type
        self._grad_norm_type = grad_norm_type
        self._grad_ignore_type = grad_ignore_type
        self._clip_value = clip_value
        self._clip_norm_type = clip_norm_type
        self._clip_coef = clip_coef
        self._ignore_value = ignore_value
        self._ignore_norm_type = ignore_norm_type
        self._ignore_coef = ignore_coef
        self._clip_momentum_timestep = clip_momentum_timestep
        self._ignore_momentum_timestep = ignore_momentum_timestep
        super(RMSprop, self).__init__(
            params, lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay, momentum=momentum, centered=centered
        )

    def _state_init(self, p, momentum, centered):
        """
        Overview:
            Initialize the state of the optimizer for ``p``: the extra ``thre_square_avg`` buffer used by the \
            momentum-based strategies plus the buffers ``torch.optim.RMSprop`` reads.
        Arguments:
            - p (:obj:`torch.Tensor`): the parameter to be optimized
            - momentum (:obj:`float`): the momentum coefficient
            - centered (:obj:`bool`): if True, compute the centered RMSprop, \
                the gradient is normalized by an estimation of its variance
        """
        state = self.state[p]
        # NOTE(review): assumes torch versions exposing 'capturable' keep the step counter as a tensor that
        # ``super().step`` increments in place; an int there would presumably never advance, leaving the
        # step-gated momentum strategies disabled forever. Confirm against the torch versions in use.
        if 'capturable' in self.defaults:
            state['step'] = torch.zeros((), dtype=torch.float, device=p.device) \
                if self.defaults['capturable'] else torch.tensor(0.)
        else:
            state['step'] = 0
        state['thre_square_avg'] = torch.zeros_like(p.data, device=p.data.device)
        state['square_avg'] = torch.zeros_like(p.data, device=p.data.device)
        if momentum:
            state['momentum_buffer'] = torch.zeros_like(p.data, device=p.data.device)
        if centered:
            state['grad_avg'] = torch.zeros_like(p.data, device=p.data.device)

    def _update_thre_square_avg(self, p, group):
        """
        Overview:
            Lazily initialize the state of ``p`` and update ``thre_square_avg``, the running average of the squared \
            (unclipped) gradient used by the momentum-based clip/ignore strategies.
        Arguments:
            - p (:obj:`torch.Tensor`): a parameter whose ``grad`` is not None
            - group (:obj:`dict`): the param group ``p`` belongs to
        Returns:
            - state (:obj:`dict`): the optimizer state of ``p``
        """
        state = self.state[p]
        if len(state) == 0:
            self._state_init(p, group['momentum'], group['centered'])
        grad = p.grad.data
        alpha = group['alpha']
        state['thre_square_avg'].mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
        return state

    def _group_norms(self, group, coef, norm_type):
        """
        Overview:
            Update the threshold statistics of every parameter with a gradient in ``group`` and measure the group.
        Returns:
            - total_norm (:obj:`float`): ``norm_type``-norm of the group's gradients
            - total_momentum_norm (:obj:`float`): ``norm_type``-norm of the group's ``coef * sqrt(v)`` thresholds
            - step: smallest step counter in the group (``math.inf`` when no parameter has a gradient)
        """
        total_norm = 0
        total_momentum_norm = 0
        step = math.inf
        for p in group['params']:
            if p.grad is None:
                continue
            state = self._update_thre_square_avg(p, group)
            total_norm += p.grad.data.norm(norm_type).item() ** norm_type
            momentum = (state['thre_square_avg'].sqrt() * coef).norm(norm_type)
            total_momentum_norm += momentum.item() ** norm_type
            step = min(step, state['step'])
        return total_norm ** (1. / norm_type), total_momentum_norm ** (1. / norm_type), step

    def _any_grad_exceeds_momentum(self) -> bool:
        """
        Overview:
            Return True as soon as one gradient element exceeds its ``ignore_coef * sqrt(v)`` bound. Parameters \
            after the first offending one are not visited, so their statistics are not updated on that step.
        """
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self._update_thre_square_avg(p, group)
                if state['step'] >= self._ignore_momentum_timestep:  # initial value is inaccurate
                    # ``.any()``: a multi-element bool tensor has no truth value of its own
                    if (p.grad.data.abs() > state['thre_square_avg'].sqrt() * self._ignore_coef).any():
                        return True
        return False

    def _clip_grad(self, new_params):
        """
        Overview:
            Apply the configured ``grad_clip_type`` to the gradients in place.
        Arguments:
            - new_params (:obj:`list`): parameters that require grad and have a gradient
        """
        if self._grad_clip_type == 'clip_value':
            clip_grad_value_(new_params, self._clip_value)
        elif self._grad_clip_type == 'clip_norm':
            clip_grad_norm_(new_params, self._clip_value, self._clip_norm_type)
        elif self._grad_clip_type == 'clip_momentum':
            '''
            This implementation mimics the clip used in OPENAI, quote:
            'Gradients are additionally clipped per parameter to be within between ±5√v
            where v is the running estimate of the second moment of the (unclipped) gradient'
            '''
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self._update_thre_square_avg(p, group)
                    if state['step'] >= self._clip_momentum_timestep:  # initial value is inaccurate
                        grad = p.grad.data
                        threshold = state['thre_square_avg'].sqrt() * self._clip_coef
                        # clip into [-threshold, threshold], keeping each element's sign
                        grad.copy_(torch.where(grad.abs() > threshold, threshold * grad.sign(), grad))
        elif self._grad_clip_type == 'clip_momentum_norm':
            # might have multi param_group, we should calculate each group differently.
            for group in self.param_groups:
                total_norm, total_momentum_norm, step = self._group_norms(
                    group, self._clip_coef, self._clip_norm_type
                )
                if step > self._clip_momentum_timestep:
                    clip_coef = total_momentum_norm / (total_norm + 1e-6)
                    if clip_coef < 1:
                        for p in group['params']:
                            if p.grad is not None:
                                p.grad.data.mul_(clip_coef)

    def _ignore_grad(self, new_params):
        """
        Overview:
            Apply the configured ``grad_ignore_type``: zero the gradients when they look abnormal.
        Arguments:
            - new_params (:obj:`list`): parameters that require grad and have a gradient
        """
        if self._grad_ignore_type == 'ignore_value':
            grad_ignore_value(new_params, self._ignore_value)
        elif self._grad_ignore_type == 'ignore_norm':
            grad_ignore_norm(new_params, self._ignore_value, self._ignore_norm_type)
        elif self._grad_ignore_type == 'ignore_momentum':
            if self._any_grad_exceeds_momentum():
                for group in self.param_groups:
                    for p in group['params']:
                        if p.grad is not None:
                            p.grad.zero_()
        elif self._grad_ignore_type == 'ignore_momentum_norm':
            # might have multi param_group, we should calculate each group differently.
            for group in self.param_groups:
                total_norm, total_momentum_norm, step = self._group_norms(
                    group, self._ignore_coef, self._ignore_norm_type
                )
                if step > self._ignore_momentum_timestep:
                    ignore_coef = total_momentum_norm / (total_norm + 1e-6)
                    if ignore_coef < 1:
                        for p in group['params']:
                            if p.grad is not None:
                                p.grad.zero_()

    def step(self, closure: Union[Callable, None] = None):
        """
        Overview:
            Performs a single optimization step: clipping, then ignoring, then ``torch.optim.RMSprop.step``.
        Arguments:
            - closure (:obj:`callable`): A closure that reevaluates the model and returns the loss, default set to None
        """
        new_params = [
            t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None
        ]
        self._clip_grad(new_params)
        self._ignore_grad(new_params)
        return super().step(closure=closure)

    def get_grad(self) -> float:
        """
        Overview:
            Sum ``||grad||_p ** p`` (``p = clip_norm_type``) over the parameters that require grad and have one. \
            The ``1 / p`` root is not taken, e.g. for ``p = 2`` this is the squared 2-norm.
        """
        total_norm = 0.
        params = [t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None]
        for p in params:
            param_norm = p.grad.data.norm(self._clip_norm_type)
            total_norm += param_norm.item() ** self._clip_norm_type
        return total_norm
class PCGrad():
    """
    Overview:
        PCGrad optimizer wrapper to support multi-task learning: per-task gradients that conflict (negative dot \
        product) are projected onto each other's normal plane before being merged into ``param.grad``.
        You can view the paper in the following link https://arxiv.org/pdf/2001.06782.pdf
    Interfaces:
        ``__init__``, ``zero_grad``, ``step``, ``pc_backward``
    Properties:
        - optimizer (:obj:`torch.optim`): the optimizer to be used
    """

    def __init__(self, optimizer, reduction='mean'):
        """
        Overview:
            Initialization of PCGrad optimizer
        Arguments:
            - optimizer (:obj:`torch.optim`): the optimizer to be used
            - reduction (:obj:`str`): how gradients of parameters shared by all tasks are merged, \
                support ['mean', 'sum']; any other value raises ``KeyError`` in ``pc_backward``
        """
        self._optim, self._reduction = optimizer, reduction

    @property
    def optimizer(self):
        """
        Overview:
            get the optimizer
        """
        return self._optim

    def zero_grad(self):
        """
        Overview:
            clear the gradient of the parameters
        """
        return self._optim.zero_grad(set_to_none=True)

    def step(self):
        """
        Overview:
            update the parameters with the gradient
        """
        return self._optim.step()

    def pc_backward(self, objectives):
        """
        Overview:
            Back-propagate every objective separately, project conflicting gradients and write the merged result \
            into ``param.grad`` of every parameter of the wrapped optimizer.
        Arguments:
            - objectives: a list of objectives (scalar losses)
        """
        grads, shapes, has_grads = self._pack_grad(objectives)
        pc_grad = self._project_conflicting(grads, has_grads)
        pc_grad = self._unflatten_grad(pc_grad, shapes[0])
        self._set_grad(pc_grad)
        return

    def _project_conflicting(self, grads, has_grads, shapes=None):
        """
        Overview:
            project the conflicting gradient to the orthogonal space
        Arguments:
            - grads (:obj:`list`): a list of flattened gradients, one per objective (shuffled in place)
            - has_grads (:obj:`list`): a list of mask represent whether the parameter has gradient
            - shapes (:obj:`list`): unused, kept for interface compatibility
        Returns:
            - merged_grad (:obj:`torch.Tensor`): the flattened merged gradient
        """
        # an element is "shared" when every objective produced a gradient for it
        shared = torch.stack(has_grads).prod(0).bool()
        # independent copies: projections are applied to these, always against the original gradients
        pc_grad = [g.clone() for g in grads]
        for g_i in pc_grad:
            random.shuffle(grads)
            for g_j in grads:
                g_i_g_j = torch.dot(g_i, g_j)
                if g_i_g_j < 0:
                    g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)
        merged_grad = torch.zeros_like(grads[0])
        # compare explicitly: any non-empty string is truthy, which previously made 'sum' unreachable
        if self._reduction == 'mean':
            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
        elif self._reduction == 'sum':
            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).sum(dim=0)
        else:
            raise KeyError("invalid reduction method")
        merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)
        return merged_grad

    def _set_grad(self, grads):
        """
        Overview:
            set the modified gradients to the network
        Arguments:
            - grads (:obj:`list`): one gradient per parameter, in ``param_groups`` order
        """
        idx = 0
        for group in self._optim.param_groups:
            for p in group['params']:
                p.grad = grads[idx]
                idx += 1
        return

    def _pack_grad(self, objectives):
        """
        Overview:
            pack the gradient of the parameters of the network for each objective
        Arguments:
            - objectives: a list of objectives
        Returns:
            - grad: a list of the flattened gradient of the parameters, one per objective
            - shape: a list of the shape of the parameters, one list per objective
            - has_grad: a list of flattened masks representing whether the parameter has gradient
        """
        grads, shapes, has_grads = [], [], []
        for obj in objectives:
            self._optim.zero_grad(set_to_none=True)
            obj.backward(retain_graph=True)
            grad, shape, has_grad = self._retrieve_grad()
            grads.append(self._flatten_grad(grad, shape))
            has_grads.append(self._flatten_grad(has_grad, shape))
            shapes.append(shape)
        return grads, shapes, has_grads

    def _unflatten_grad(self, grads, shapes):
        """
        Overview:
            unflatten the gradient of the parameters of the network
        Arguments:
            - grads (:obj:`torch.Tensor`): the flattened gradient
            - shapes (:obj:`list`): a list of the shape of the parameters
        Returns:
            - unflatten_grad (:obj:`list`): one gradient tensor per parameter
        """
        unflatten_grad, idx = [], 0
        for shape in shapes:
            # ``np.prod(())`` is the float 1.0 for 0-dim parameters; slicing needs an int
            length = int(np.prod(shape))
            unflatten_grad.append(grads[idx:idx + length].view(shape).clone())
            idx += length
        return unflatten_grad

    def _flatten_grad(self, grads, shapes):
        """
        Overview:
            flatten the gradient of the parameters of the network
        Arguments:
            - grads (:obj:`list`): a list of the gradient of the parameters
            - shapes (:obj:`list`): a list of the shape of the parameters (unused)
        """
        flatten_grad = torch.cat([g.flatten() for g in grads])
        return flatten_grad

    def _retrieve_grad(self):
        """
        Overview:
            get the gradient of the parameters of the network with specific objective
        Returns:
            - grad: a list of the gradient of the parameters
            - shape: a list of the shape of the parameters
            - has_grad: a list of mask represent whether the parameter has gradient
        """
        grad, shape, has_grad = [], [], []
        for group in self._optim.param_groups:
            for p in group['params']:
                # tackle the multi-head scenario: parameters untouched by this objective get zero gradients
                if p.grad is None:
                    shape.append(p.shape)
                    grad.append(torch.zeros_like(p).to(p.device))
                    has_grad.append(torch.zeros_like(p).to(p.device))
                    continue
                shape.append(p.grad.shape)
                grad.append(p.grad.clone())
                has_grad.append(torch.ones_like(p).to(p.device))
        return grad, shape, has_grad
def configure_weight_decay(model: nn.Module, weight_decay: float) -> List:
    """
    Overview:
        Separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layer-norm or embedding weights).
        A parameter shared by several modules (e.g. tied embedding/output weights) is reported once, and lands in
        the no-decay bucket if any of the modules holding it asks for no decay.
    Arguments:
        - model (:obj:`nn.Module`): the given PyTorch model.
        - weight_decay (:obj:`float`): weight decay value for optimizer.
    Returns:
        - optim groups (:obj:`List`): the parameter groups to be set in the latter optimizer.
    """
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, )
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    param_dict = {pn: p for pn, p in model.named_parameters()}
    # ``model.named_parameters()`` reports a shared tensor under a single name, while the per-module walk below
    # also reaches it through its aliases. Keying every tensor by that canonical name keeps both sets within
    # ``param_dict`` (aliases used to raise KeyError below). For unshared tensors it equals ``mn.pn``.
    canonical_name = {id(p): pn for pn, p in param_dict.items()}
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = canonical_name[id(p)]  # full param name
            # Because named_modules and named_parameters are recursive
            # we will see the same tensors p many times. But doing it this way
            # allows us to know which parent module any tensor p belongs to.
            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)
            else:
                decay.add(fpn)
    # no_decay wins when a tensor was classified both ways (seen from a parent module and from its owner)
    decay = decay - no_decay
    # validate that we considered every parameter
    union_params = decay | no_decay
    assert len(
        param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
        % (str(param_dict.keys() - union_params),)
    optim_groups = [
        {
            "params": [param_dict[pn] for pn in sorted(list(decay))],
            "weight_decay": weight_decay
        },
        {
            "params": [param_dict[pn] for pn in sorted(list(no_decay))],
            "weight_decay": 0.0
        },
    ]
    return optim_groups
|