""" @date: 2021/7/19 @description: """ import torch import loss from utils.misc import tensor2np def build_criterion(config, logger): criterion = {} device = config.TRAIN.DEVICE for k in config.TRAIN.CRITERION.keys(): sc = config.TRAIN.CRITERION[k] if sc.WEIGHT is None or float(sc.WEIGHT) == 0: continue criterion[sc.NAME] = { 'loss': getattr(loss, sc.LOSS)(), 'weight': float(sc.WEIGHT), 'sub_weights': sc.WEIGHTS, 'need_all': sc.NEED_ALL } criterion[sc.NAME]['loss'] = criterion[sc.NAME]['loss'].to(device) if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device: criterion[sc.NAME]['loss'] = criterion[sc.NAME]['loss'].type(torch.float16) # logger.info(f"Build criterion:{sc.WEIGHT}_{sc.NAME}_{sc.LOSS}_{sc.WEIGHTS}") return criterion def calc_criterion(criterion, gt, dt, epoch_loss_d): loss = None postfix_d = {} for k in criterion.keys(): if criterion[k]['need_all']: single_loss = criterion[k]['loss'](gt, dt) ws_loss = None for i, sub_weight in enumerate(criterion[k]['sub_weights']): if sub_weight == 0: continue if ws_loss is None: ws_loss = single_loss[i] * sub_weight else: ws_loss = ws_loss + single_loss[i] * sub_weight single_loss = ws_loss if ws_loss is not None else single_loss else: assert k in gt.keys(), "ground label is None:" + k assert k in dt.keys(), "detection key is None:" + k if k == 'ratio' and gt[k].shape[-1] != dt[k].shape[-1]: gt[k] = gt[k].repeat(1, dt[k].shape[-1]) single_loss = criterion[k]['loss'](gt[k], dt[k]) postfix_d[k] = tensor2np(single_loss) if k not in epoch_loss_d.keys(): epoch_loss_d[k] = [] epoch_loss_d[k].append(postfix_d[k]) single_loss = single_loss * criterion[k]['weight'] if loss is None: loss = single_loss else: loss = loss + single_loss k = 'loss' postfix_d[k] = tensor2np(loss) if k not in epoch_loss_d.keys(): epoch_loss_d[k] = [] epoch_loss_d[k].append(postfix_d[k]) return loss, postfix_d, epoch_loss_d