Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ELECTRA https://github.com/google-research/electra
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import json | |
def param_groups_lrd(
    model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75
):
    """
    Build parameter groups with layer-wise learning-rate decay.

    Every trainable parameter of ``model`` is bucketed by its layer id and by
    whether weight decay applies to it. Each bucket is a dict with the keys
    ``lr_scale`` (``layer_decay ** (num_layers - layer_id)``), ``weight_decay``
    (``weight_decay`` or ``0.0``) and ``params`` (the parameter objects).
    The buckets are returned as a list, in order of first appearance.

    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    # Layer ids run from 0 to ``depth`` inclusive: one more than the block count.
    depth = len(model.blocks) + 1
    # scale_by_layer[i] is the lr multiplier for layer id i; the last id gets 1.0.
    scale_by_layer = [layer_decay ** (depth - idx) for idx in range(depth + 1)]

    # Two parallel registries keyed by group name: one collects parameter
    # names (handy to dump with json.dumps when debugging), the other
    # collects the parameter objects that are actually returned.
    names_by_group = {}
    tensors_by_group = {}

    # Frozen parameters never reach the optimizer.
    trainable = (
        (param_name, param)
        for param_name, param in model.named_parameters()
        if param.requires_grad
    )

    for param_name, param in trainable:
        # no decay: all 1D parameters and model specific ones
        skip_decay = param.ndim == 1 or param_name in no_weight_decay_list
        suffix, group_wd = (
            ("no_decay", 0.0) if skip_decay else ("decay", weight_decay)
        )

        layer_id = get_layer_id_for_vit(param_name, depth)
        key = "layer_%d_%s" % (layer_id, suffix)

        if key not in names_by_group:
            scale = scale_by_layer[layer_id]
            # Each registry gets its own fresh dict (and its own "params" list).
            for registry in (names_by_group, tensors_by_group):
                registry[key] = {
                    "lr_scale": scale,
                    "weight_decay": group_wd,
                    "params": [],
                }

        names_by_group[key]["params"].append(param_name)
        tensors_by_group[key]["params"].append(param)

    return list(tensors_by_group.values())
def get_layer_id_for_vit(name, num_layers):
    """
    Map a parameter name to its layer id.

    ``cls_token``, ``pos_embed`` and any name starting with ``patch_embed``
    map to 0; a name starting with ``blocks`` maps to the integer in its
    second dot-separated component plus one; every other name maps to
    ``num_layers``.

    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    if name.startswith("blocks"):
        # "blocks.<i>.<...>" -> i + 1, since id 0 is reserved for the embeddings
        block_index = name.split(".")[1]
        return int(block_index) + 1

    is_embedding = name in ("cls_token", "pos_embed") or name.startswith(
        "patch_embed"
    )
    return 0 if is_embedding else num_layers