Spaces:
Sleeping
Sleeping
File size: 5,218 Bytes
2d9a728 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
""" Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2020 Ross Wightman
"""
import re
import torch
from torch import optim as optim
from utils.distributed import is_main_process
import logging
logger = logging.getLogger(__name__)
try:
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
has_apex = True
except ImportError:
has_apex = False
def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
    """Pair every trainable parameter of ``model`` with its weight-decay value.

    Parameters whose ``requires_grad`` is False are omitted. A parameter is
    exempt (decay 0) when ``filter_bias_and_bn`` is set and it is 1-D or its
    name ends with ``".bias"``, or when its name is in ``no_decay_list``.
    All other parameters receive ``weight_decay``.

    Args:
        model: object providing ``named_parameters()`` (e.g. an ``nn.Module``).
        weight_decay: decay value for non-exempt parameters.
        no_decay_list: container of parameter names that must not be decayed.
        filter_bias_and_bn: if True, also exempt 1-D tensors and biases.

    Returns:
        List of ``[name, param, weight_decay]`` lists, in the order yielded
        by ``model.named_parameters()``.
    """
    result = []
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue  # frozen weights never reach the optimizer
        # Same evaluation order as a chained if/elif: shape/bias test first,
        # then the explicit no-decay names.
        exempt = (
            filter_bias_and_bn
            and (len(tensor.shape) == 1 or param_name.endswith(".bias"))
        ) or param_name in no_decay_list
        result.append([param_name, tensor, 0 if exempt else weight_decay])
    return result
def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr):
    """Attach a learning rate to every ``[name, param, weight_decay]`` entry.

    A parameter is given ``diff_lr`` when any pattern in ``diff_lr_names``
    matches its name with ``re.search`` (i.e. anywhere in the name), and
    ``default_lr`` otherwise.

    Args:
        named_param_tuples_or_model: iterable of ``[name, param, weight_decay]``
            entries, e.g. the output of ``add_weight_decay``. Despite the
            parameter name, a bare ``nn.Module`` is not supported: every item
            is unpacked into exactly three values.
        diff_lr_names: iterable of regex patterns (str). ``None`` or empty
            means no parameter uses ``diff_lr``.
        diff_lr: float, learning rate for parameters matching a pattern.
        default_lr: float, learning rate for all other parameters.

    Returns:
        named_param_tuples_with_lr: list of ``[name, param, weight_decay, lr]``.

    Raises:
        re.error: if a pattern in ``diff_lr_names`` is not a valid regex.
    """
    logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")
    # Compile once up front rather than re-searching raw patterns per parameter.
    patterns = [re.compile(pattern) for pattern in (diff_lr_names or ())]

    named_param_tuples_with_lr = []
    for name, p, wd in named_param_tuples_or_model:
        use_diff_lr = any(pattern.search(name) is not None for pattern in patterns)
        if use_diff_lr:
            logger.info(f"param {name} use different_lr: {diff_lr}")
        named_param_tuples_with_lr.append(
            [name, p, wd, diff_lr if use_diff_lr else default_lr]
        )

    if is_main_process():
        # Use a distinct loop name so the ``diff_lr`` argument is not rebound.
        for name, _, wd, lr in named_param_tuples_with_lr:
            logger.info(f"param {name}: wd: {wd}, lr: {lr}")
    return named_param_tuples_with_lr
def create_optimizer_params_group(named_param_tuples_with_lr):
    """Merge parameters sharing the same (weight_decay, lr) into param groups.

    Args:
        named_param_tuples_with_lr: list of ``[name, param, weight_decay, lr]``.

    Returns:
        List of dicts ``{"params": [...], "weight_decay": wd, "lr": lr}``, one
        per distinct (wd, lr) pair. Groups are ordered by the first appearance
        of each wd, and within a wd by the first appearance of each lr.
    """
    buckets = {}  # wd -> {lr -> [params]}
    for _, param, wd, lr in named_param_tuples_with_lr:
        buckets.setdefault(wd, {}).setdefault(lr, []).append(param)

    param_groups = []
    for wd, by_lr in buckets.items():
        for lr, params in by_lr.items():
            param_groups.append(dict(params=params, weight_decay=wd, lr=lr))
            logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(params)}")
    return param_groups
def create_optimizer(args, model, filter_bias_and_bn=True, return_group=False):
    """Build a torch optimizer, or only its parameter groups, from ``args``.

    Reads ``args.opt``, ``args.weight_decay`` and ``args.lr``. It also reads
    ``args.momentum`` for SGD variants, and optionally ``args.different_lr``
    (``enable``, ``module_names``, ``lr``), ``args.opt_eps``,
    ``args.opt_betas`` and ``args.opt_args``.

    Args:
        args: config namespace carrying the fields listed above.
        model: model whose trainable parameters are optimized. If it (or, for
            DDP, its ``.module``) defines ``no_weight_decay()``, the returned
            names are exempted from weight decay.
        filter_bias_and_bn: exempt 1-D tensors and biases from weight decay.
        return_group: if True, return the param-group list instead of an
            optimizer.

    Returns:
        A ``torch.optim`` optimizer, or the list of param-group dicts when
        ``return_group`` is True.

    Raises:
        ValueError: if ``args.opt`` does not name a supported optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay

    # Optional per-module learning rate, selected by regex on parameter names.
    if hasattr(args, "different_lr") and args.different_lr.enable:
        diff_lr_module_names = args.different_lr.module_names
        diff_lr = args.different_lr.lr
    else:
        diff_lr_module_names = []
        diff_lr = None

    # Parameter names the model asks to exempt from weight decay. For a DDP
    # wrapper, the names are prefixed with "module." before being matched.
    no_decay = set()
    if hasattr(model, 'no_weight_decay'):
        no_decay = model.no_weight_decay()
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        if hasattr(model.module, 'no_weight_decay'):
            no_decay = model.module.no_weight_decay()
        no_decay = {"module." + k for k in no_decay}

    named_param_tuples = add_weight_decay(
        model, weight_decay, no_decay, filter_bias_and_bn)
    named_param_tuples = add_different_lr(
        named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
    parameters = create_optimizer_params_group(named_param_tuples)
    if return_group:
        return parameters

    if 'fused' in opt_lower:
        # NOTE(review): this only checks availability; the optimizers built
        # below are always torch.optim ones, never the apex Fused* classes.
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    # Optimizer-level defaults; each param group passed in carries its own lr
    # and weight_decay, which take precedence over these.
    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if getattr(args, 'opt_eps', None) is not None:
        opt_args['eps'] = args.opt_eps
    if getattr(args, 'opt_betas', None) is not None:
        opt_args['betas'] = args.opt_betas
    if getattr(args, 'opt_args', None) is not None:
        opt_args.update(args.opt_args)

    # Only the last "_"-separated token selects the optimizer,
    # e.g. "fused_adam" -> "adam".
    opt_name = opt_lower.split('_')[-1]
    if opt_name in ('sgd', 'nesterov', 'momentum'):
        # torch.optim.SGD accepts neither eps nor betas.
        opt_args.pop('eps', None)
        opt_args.pop('betas', None)
        optimizer = optim.SGD(
            parameters,
            momentum=args.momentum,
            nesterov=(opt_name != 'momentum'),
            **opt_args,
        )
    elif opt_name == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_name == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    else:
        raise ValueError(f"Invalid optimizer: {args.opt}")
    return optimizer
|