Awiny's picture
first version submission
c3a1897
raw
history blame
3.37 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/custom_solver.py
import itertools
from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union
import torch
from detectron2.config import CfgNode
from detectron2.solver.build import maybe_add_gradient_clipping
def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an optimizer from config, with one parameter group per trainable parameter.

    Args:
        cfg (CfgNode): config node. Reads SOLVER.OPTIMIZER, SOLVER.BASE_LR,
            SOLVER.WEIGHT_DECAY, SOLVER.VIT_LAYER_DECAY, SOLVER.VIT_LAYER_DECAY_RATE,
            SOLVER.MOMENTUM, SOLVER.NESTEROV, SOLVER.CLIP_GRADIENTS.* and
            MODEL.VIT_LAYERS.
        model (torch.nn.Module): model whose parameters are optimized. Parameters
            with ``requires_grad == False`` are skipped.
    Returns:
        torch.optim.Optimizer: an SGD or AdamW optimizer. When
        SOLVER.CLIP_GRADIENTS.CLIP_TYPE is "full_model" (and clipping is enabled
        with a positive CLIP_VALUE) the gradient norm over all parameters is
        clipped before every step; for any other CLIP_TYPE the optimizer is
        passed through detectron2's ``maybe_add_gradient_clipping``.
    Raises:
        NotImplementedError: if SOLVER.OPTIMIZER is neither 'SGD' nor 'ADAMW'.
    """
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    optimizer_type = cfg.SOLVER.OPTIMIZER
    for key, value in model.named_parameters(recurse=True):
        if not value.requires_grad:
            continue
        # Avoid duplicating parameters
        if value in memo:
            continue
        memo.add(value)
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if cfg.SOLVER.VIT_LAYER_DECAY:
            # Scale the base lr by the layer-wise decay factor for this parameter.
            lr = lr * get_vit_lr_decay_rate(
                key, cfg.SOLVER.VIT_LAYER_DECAY_RATE, cfg.MODEL.VIT_LAYERS
            )
        param = {"params": [value], "lr": lr}
        # For AdamW the weight decay is passed once to the optimizer constructor
        # below, so it is only stored per group for the other optimizer types.
        if optimizer_type != 'ADAMW':
            param['weight_decay'] = weight_decay
        params.append(param)

    def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
        # detectron2 doesn't have full model gradient clipping now
        clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
        enable = (
            cfg.SOLVER.CLIP_GRADIENTS.ENABLED
            and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
            and clip_norm_val > 0.0
        )

        class FullModelGradientClippingOptimizer(optim):
            def step(self, closure=None):
                # Clip the gradient norm computed jointly over every parameter group.
                all_params = itertools.chain.from_iterable(
                    group["params"] for group in self.param_groups
                )
                torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                # Propagate the closure's loss, as Optimizer.step() callers expect.
                return super().step(closure=closure)

        return FullModelGradientClippingOptimizer if enable else optim

    if optimizer_type == 'SGD':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
            params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV
        )
    elif optimizer_type == 'ADAMW':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
            params, cfg.SOLVER.BASE_LR,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY
        )
    else:
        raise NotImplementedError(f"no optimizer type {optimizer_type}")
    # Any non-"full_model" clip type is delegated to detectron2's own wrapper.
    if cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE != "full_model":
        optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
    """
    Compute the layer-wise learning-rate multiplier for a single parameter.

    Parameters whose name starts with "backbone" are assigned a layer index:
    0 for position / patch embeddings, ``block_index + 1`` for parameters
    inside ".blocks.<block_index>." (unless the name contains ".residual.").
    Everything else is treated as layer ``num_layers + 1`` and therefore
    receives a multiplier of ``lr_decay_rate ** 0``.

    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
    Returns:
        ``lr_decay_rate`` raised to ``num_layers + 1 - layer index``.
    """
    top_layer = num_layers + 1
    assigned_layer = top_layer

    in_backbone = name.startswith("backbone")
    is_embedding = ".pos_embed" in name or ".patch_embed" in name

    if in_backbone and is_embedding:
        assigned_layer = 0
    elif in_backbone and ".blocks." in name and ".residual." not in name:
        # The token right after the first ".blocks." is the block index.
        after_blocks = name.partition(".blocks.")[2]
        block_index = after_blocks.split(".")[0]
        assigned_layer = int(block_index) + 1

    exponent = top_layer - assigned_layer
    return lr_decay_rate ** exponent