# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import itertools
import logging
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
import torch
from fvcore.common.param_scheduler import (
CosineParamScheduler,
MultiStepParamScheduler,
StepWithFixedGammaParamScheduler,
)
from detectron2.config import CfgNode
from detectron2.utils.env import TORCH_VERSION
from .lr_scheduler import LRMultiplier, LRScheduler, WarmupParamScheduler

_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
_GradientClipper = Callable[[_GradientClipperInput], None]


class GradientClipType(Enum):
VALUE = "value"
NORM = "norm"


def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
"""
    Creates a gradient-clipping closure that clips by value or by norm,
    according to the provided config.
"""
cfg = copy.deepcopy(cfg)
def clip_grad_norm(p: _GradientClipperInput):
torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE)
def clip_grad_value(p: _GradientClipperInput):
torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE)
_GRADIENT_CLIP_TYPE_TO_CLIPPER = {
GradientClipType.VALUE: clip_grad_value,
GradientClipType.NORM: clip_grad_norm,
}
return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)]
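# Illustrative sketch (not part of the library API): given a config fragment
# such as CfgNode({"CLIP_TYPE": "norm", "CLIP_VALUE": 1.0, "NORM_TYPE": 2.0}),
# the returned closure simply forwards to torch.nn.utils.clip_grad_norm_
# (or clip_grad_value_ when CLIP_TYPE is "value"), e.g.:
#
#   clipper = _create_gradient_clipper(clip_cfg)
#   clipper(model.parameters())  # clips the total L2 grad norm to 1.0
#
# where `model` is any nn.Module whose gradients have already been populated.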


def _generate_optimizer_class_with_gradient_clipping(
optimizer: Type[torch.optim.Optimizer],
*,
per_param_clipper: Optional[_GradientClipper] = None,
global_clipper: Optional[_GradientClipper] = None,
) -> Type[torch.optim.Optimizer]:
"""
    Dynamically creates a new type that inherits the given optimizer type
    and overrides the `step` method to add gradient clipping.
"""
assert (
per_param_clipper is None or global_clipper is None
), "Not allowed to use both per-parameter clipping and global clipping"
def optimizer_wgc_step(self, closure=None):
if per_param_clipper is not None:
for group in self.param_groups:
for p in group["params"]:
per_param_clipper(p)
else:
# global clipper for future use with detr
# (https://github.com/facebookresearch/detr/pull/287)
all_params = itertools.chain(*[g["params"] for g in self.param_groups])
global_clipper(all_params)
super(type(self), self).step(closure)
OptimizerWithGradientClip = type(
optimizer.__name__ + "WithGradientClip",
(optimizer,),
{"step": optimizer_wgc_step},
)
return OptimizerWithGradientClip
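# Illustrative sketch (hypothetical clipper and model): the generated class
# behaves exactly like the wrapped optimizer, except that every step() call
# clips gradients first:
#
#   clip = lambda p: torch.nn.utils.clip_grad_norm_(p, 1.0)
#   SGDWithClip = _generate_optimizer_class_with_gradient_clipping(
#       torch.optim.SGD, per_param_clipper=clip
#   )
#   opt = SGDWithClip(model.parameters(), lr=0.1)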


def maybe_add_gradient_clipping(
cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
) -> Type[torch.optim.Optimizer]:
"""
    If gradient clipping is enabled through config options, wraps the given
    optimizer type in a new dynamically created class OptimizerWithGradientClip
    that inherits it and overrides the `step` method to include gradient
    clipping.
Args:
cfg: CfgNode, configuration options
optimizer: type. A subclass of torch.optim.Optimizer
Return:
type: either the input `optimizer` (if gradient clipping is disabled), or
a subclass of it with gradient clipping included in the `step` method.
"""
if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
return optimizer
if isinstance(optimizer, torch.optim.Optimizer):
optimizer_type = type(optimizer)
else:
assert issubclass(optimizer, torch.optim.Optimizer), optimizer
optimizer_type = optimizer
grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
optimizer_type, per_param_clipper=grad_clipper
)
if isinstance(optimizer, torch.optim.Optimizer):
optimizer.__class__ = OptimizerWithGradientClip # a bit hacky, not recommended
return optimizer
else:
return OptimizerWithGradientClip
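# Illustrative sketch: with cfg.SOLVER.CLIP_GRADIENTS.ENABLED set to True,
# either form below yields an optimizer whose step() clips gradients; with it
# set to False the input is returned unchanged (`model` is assumed to be an
# already-built nn.Module):
#
#   opt = maybe_add_gradient_clipping(cfg, torch.optim.SGD(model.parameters(), lr=0.1))
#   AdamWWithClip = maybe_add_gradient_clipping(cfg, torch.optim.AdamW)
#   opt = AdamWWithClip(model.parameters(), lr=1e-4)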


def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
"""
Build an optimizer from config.
"""
params = get_default_optimizer_params(
model,
base_lr=cfg.SOLVER.BASE_LR,
weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
)
sgd_args = {
"params": params,
"lr": cfg.SOLVER.BASE_LR,
"momentum": cfg.SOLVER.MOMENTUM,
"nesterov": cfg.SOLVER.NESTEROV,
"weight_decay": cfg.SOLVER.WEIGHT_DECAY,
}
if TORCH_VERSION >= (1, 12):
sgd_args["foreach"] = True
return maybe_add_gradient_clipping(cfg, torch.optim.SGD(**sgd_args))
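# Illustrative sketch (assumes a standard detectron2 config): building the
# default SGD optimizer only needs a config and a model, e.g.
#
#   from detectron2.config import get_cfg
#   cfg = get_cfg()
#   optimizer = build_optimizer(cfg, model)
#
# The "foreach" flag enables the multi-tensor SGD implementation on
# PyTorch >= 1.12, which is why reduce_param_groups below keeps the number of
# parameter groups small.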


def get_default_optimizer_params(
model: torch.nn.Module,
base_lr: Optional[float] = None,
weight_decay: Optional[float] = None,
weight_decay_norm: Optional[float] = None,
bias_lr_factor: Optional[float] = 1.0,
weight_decay_bias: Optional[float] = None,
lr_factor_func: Optional[Callable] = None,
overrides: Optional[Dict[str, Dict[str, float]]] = None,
) -> List[Dict[str, Any]]:
"""
    Get the default parameter list for an optimizer, with support for a few types of
    overrides. If no overrides are needed, this is equivalent to `model.parameters()`.
Args:
base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
weight_decay: weight decay for every group by default. Can be omitted to use the one
in optimizer.
weight_decay_norm: override weight decay for params in normalization layers
bias_lr_factor: multiplier of lr for bias parameters.
weight_decay_bias: override weight decay for bias parameters.
        lr_factor_func: function that maps a full parameter name to a multiplicative
            factor applied to that parameter's lr. Note that setting this option
            requires also setting ``base_lr``.
overrides: if not `None`, provides values for optimizer hyperparameters
(LR, weight decay) for module parameters with a given name; e.g.
``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
weight decay values for all module parameters named `embedding`.
    For common detection models, ``weight_decay_norm`` is the only option
    that needs to be set. ``bias_lr_factor`` and ``weight_decay_bias`` are legacy
    settings from Detectron1 that have not been found useful.
Example:
::
torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
lr=0.01, weight_decay=1e-4, momentum=0.9)
"""
if overrides is None:
overrides = {}
defaults = {}
if base_lr is not None:
defaults["lr"] = base_lr
if weight_decay is not None:
defaults["weight_decay"] = weight_decay
bias_overrides = {}
if bias_lr_factor is not None and bias_lr_factor != 1.0:
# NOTE: unlike Detectron v1, we now by default make bias hyperparameters
# exactly the same as regular weights.
if base_lr is None:
raise ValueError("bias_lr_factor requires base_lr")
bias_overrides["lr"] = base_lr * bias_lr_factor
if weight_decay_bias is not None:
bias_overrides["weight_decay"] = weight_decay_bias
if len(bias_overrides):
if "bias" in overrides:
raise ValueError("Conflicting overrides for 'bias'")
overrides["bias"] = bias_overrides
if lr_factor_func is not None:
if base_lr is None:
raise ValueError("lr_factor_func requires base_lr")
norm_module_types = (
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.SyncBatchNorm,
# NaiveSyncBatchNorm inherits from BatchNorm2d
torch.nn.GroupNorm,
torch.nn.InstanceNorm1d,
torch.nn.InstanceNorm2d,
torch.nn.InstanceNorm3d,
torch.nn.LayerNorm,
torch.nn.LocalResponseNorm,
)
params: List[Dict[str, Any]] = []
memo: Set[torch.nn.parameter.Parameter] = set()
for module_name, module in model.named_modules():
for module_param_name, value in module.named_parameters(recurse=False):
if not value.requires_grad:
continue
# Avoid duplicating parameters
if value in memo:
continue
memo.add(value)
hyperparams = copy.copy(defaults)
if isinstance(module, norm_module_types) and weight_decay_norm is not None:
hyperparams["weight_decay"] = weight_decay_norm
if lr_factor_func is not None:
hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}")
hyperparams.update(overrides.get(module_param_name, {}))
params.append({"params": [value], **hyperparams})
return reduce_param_groups(params)
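# Illustrative sketch (hypothetical parameter named "embedding"): combining the
# defaults above with an explicit override,
#
#   groups = get_default_optimizer_params(
#       model,
#       base_lr=0.01,
#       weight_decay_norm=0.0,
#       overrides={"embedding": {"lr": 0.001, "weight_decay": 0.1}},
#   )
#   opt = torch.optim.SGD(groups, lr=0.01, weight_decay=1e-4, momentum=0.9)
#
# Parameters literally named "embedding" (whatever module they live in) get the
# overridden lr/weight_decay, normalization-layer weights get zero weight decay,
# and everything else uses ``base_lr`` plus the weight_decay passed to SGD.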


def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# Transform parameter groups into per-parameter structure.
# Later items in `params` can overwrite parameters set in previous items.
ret = defaultdict(dict)
for item in params:
assert "params" in item
cur_params = {x: y for x, y in item.items() if x != "params" and x != "param_names"}
if "param_names" in item:
for param_name, param in zip(item["param_names"], item["params"]):
ret[param].update({"param_names": [param_name], "params": [param], **cur_params})
else:
for param in item["params"]:
ret[param].update({"params": [param], **cur_params})
return list(ret.values())
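# Illustrative sketch (hypothetical tensors w1, w2): _expand_param_groups turns
#
#   [{"params": [w1, w2], "lr": 0.1}, {"params": [w2], "weight_decay": 0.0}]
#
# into one entry per parameter, with later items overriding earlier ones:
#
#   [{"params": [w1], "lr": 0.1},
#    {"params": [w2], "lr": 0.1, "weight_decay": 0.0}]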


def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Reorganize the parameter groups and merge duplicated groups.
    # The number of parameter groups needs to be as small as possible in order
    # to use the PyTorch multi-tensor optimizer efficiently. Therefore, instead
    # of creating one parameter group per single parameter, we merge groups that
    # share the same hyperparameters. This speeds up multi-tensor optimizers
    # significantly.
params = _expand_param_groups(params)
groups = defaultdict(list) # re-group all parameter groups by their hyperparams
for item in params:
cur_params = tuple((x, y) for x, y in item.items() if x != "params" and x != "param_names")
groups[cur_params].append({"params": item["params"]})
if "param_names" in item:
groups[cur_params][-1]["param_names"] = item["param_names"]
ret = []
for param_keys, param_values in groups.items():
cur = {kv[0]: kv[1] for kv in param_keys}
cur["params"] = list(
itertools.chain.from_iterable([params["params"] for params in param_values])
)
if len(param_values) > 0 and "param_names" in param_values[0]:
cur["param_names"] = list(
itertools.chain.from_iterable([params["param_names"] for params in param_values])
)
ret.append(cur)
return ret
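# Illustrative sketch (hypothetical tensors w1..w3): after expansion, parameters
# sharing the same hyperparameters are merged back into a single group, so
#
#   [{"params": [w1], "lr": 0.1}, {"params": [w2], "lr": 0.1}, {"params": [w3], "lr": 0.01}]
#
# reduces to
#
#   [{"params": [w1, w2], "lr": 0.1}, {"params": [w3], "lr": 0.01}]
#
# which is what lets foreach/multi-tensor optimizers batch their updates.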


def build_lr_scheduler(cfg: CfgNode, optimizer: torch.optim.Optimizer) -> LRScheduler:
"""
    Build an LR scheduler from config.
"""
name = cfg.SOLVER.LR_SCHEDULER_NAME
if name == "WarmupMultiStepLR":
steps = [x for x in cfg.SOLVER.STEPS if x <= cfg.SOLVER.MAX_ITER]
if len(steps) != len(cfg.SOLVER.STEPS):
logger = logging.getLogger(__name__)
logger.warning(
"SOLVER.STEPS contains values larger than SOLVER.MAX_ITER. "
"These values will be ignored."
)
sched = MultiStepParamScheduler(
values=[cfg.SOLVER.GAMMA**k for k in range(len(steps) + 1)],
milestones=steps,
num_updates=cfg.SOLVER.MAX_ITER,
)
elif name == "WarmupCosineLR":
end_value = cfg.SOLVER.BASE_LR_END / cfg.SOLVER.BASE_LR
assert end_value >= 0.0 and end_value <= 1.0, end_value
sched = CosineParamScheduler(1, end_value)
elif name == "WarmupStepWithFixedGammaLR":
sched = StepWithFixedGammaParamScheduler(
base_value=1.0,
gamma=cfg.SOLVER.GAMMA,
num_decays=cfg.SOLVER.NUM_DECAYS,
num_updates=cfg.SOLVER.MAX_ITER,
)
else:
raise ValueError("Unknown LR scheduler: {}".format(name))
sched = WarmupParamScheduler(
sched,
cfg.SOLVER.WARMUP_FACTOR,
min(cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER, 1.0),
cfg.SOLVER.WARMUP_METHOD,
cfg.SOLVER.RESCALE_INTERVAL,
)
return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)
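# Illustrative sketch (hypothetical solver values): the returned LRMultiplier
# rescales each param group's lr by the composed warmup + decay schedule, e.g.
#
#   cfg.SOLVER.BASE_LR = 0.02
#   cfg.SOLVER.STEPS = (60000, 80000)
#   cfg.SOLVER.MAX_ITER = 90000
#   cfg.SOLVER.GAMMA = 0.1
#   scheduler = build_lr_scheduler(cfg, optimizer)
#   # inside the training loop: optimizer.step(); scheduler.step()
#
# warms the lr up over the first WARMUP_ITERS iterations and then multiplies it
# by 0.1 at iterations 60000 and 80000.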