import logging

log = logging.getLogger()


def get_parameter_groups(model, stage_cfg, print_log=False):
""" | |
Assign different weight decays and learning rates to different parameters. | |
Returns a parameter group which can be passed to the optimizer. | |
""" | |
    weight_decay = stage_cfg.weight_decay
    embed_weight_decay = stage_cfg.embed_weight_decay
    backbone_lr_ratio = stage_cfg.backbone_lr_ratio
    base_lr = stage_cfg.learning_rate

    backbone_params = []
    embed_params = []
    other_params = []

    embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe']
    embedding_names = [e + '.weight' for e in embedding_names]

    # inspired by detectron2
    memo = set()
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Avoid duplicating parameters
        if param in memo:
            continue
        memo.add(param)

        if name.startswith('module'):
            name = name[7:]

        inserted = False
        if name.startswith('pixel_encoder.'):
            backbone_params.append(param)
            inserted = True
            if print_log:
                log.info(f'{name} counted as a backbone parameter.')
        else:
            for e in embedding_names:
                if name.endswith(e):
                    embed_params.append(param)
                    inserted = True
                    if print_log:
                        log.info(f'{name} counted as an embedding parameter.')
                    break

        if not inserted:
            other_params.append(param)

    parameter_groups = [
        {
            'params': backbone_params,
            'lr': base_lr * backbone_lr_ratio,
            'weight_decay': weight_decay
        },
        {
            'params': embed_params,
            'lr': base_lr,
            'weight_decay': embed_weight_decay
        },
        {
            'params': other_params,
            'lr': base_lr,
            'weight_decay': weight_decay
        },
    ]
    return parameter_groups
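

# Usage sketch: a minimal example of feeding the returned groups to a standard
# PyTorch optimizer. The toy model and the stage_cfg values below are placeholder
# assumptions; the real model and stage config come from the surrounding training
# setup. torch.optim.AdamW accepts a list of per-group dicts, so each group keeps
# its own 'lr' and 'weight_decay'.
if __name__ == '__main__':
    from types import SimpleNamespace

    import torch
    from torch import nn

    # Placeholder model and config for demonstration only.
    demo_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    demo_cfg = SimpleNamespace(weight_decay=1e-2,
                               embed_weight_decay=0.0,
                               backbone_lr_ratio=0.1,
                               learning_rate=1e-4)

    groups = get_parameter_groups(demo_model, demo_cfg, print_log=True)
    optimizer = torch.optim.AdamW(groups)
    print(optimizer)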