|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from utils.general import bbox_iou, bbox_alpha_iou, box_iou, box_giou, box_diou, box_ciou, xywh2xyxy |
|
from utils.torch_utils import is_parallel |
|
|
|
|
|
def smooth_BCE(eps=0.1):
    """Return the label-smoothed (positive, negative) BCE targets.

    With smoothing factor ``eps`` the positive target becomes ``1 - eps / 2``
    and the negative target becomes ``eps / 2``.

    Args:
        eps: label-smoothing factor; ``0.0`` disables smoothing.

    Returns:
        Tuple ``(positive_target, negative_target)``.
    """
    half_eps = 0.5 * eps
    positive_target = 1.0 - half_eps
    return positive_target, half_eps
|
|
|
|
|
class BCEBlurWithLogitsLoss(nn.Module):
    """BCE-with-logits loss attenuated where the prediction overshoots the target.

    Each element's BCE is multiplied by
    ``1 - exp((sigmoid(pred) - true - 1) / (alpha + 1e-4))``. This factor
    approaches 0 as ``sigmoid(pred) - true`` approaches 1. The result is averaged.
    """

    def __init__(self, alpha=0.05):
        super(BCEBlurWithLogitsLoss, self).__init__()
        # Element-wise loss so the blur factor can be applied per element.
        self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none')
        self.alpha = alpha

    def forward(self, pred, true):
        """Return the mean blurred BCE between logits ``pred`` and targets ``true``."""
        elementwise = self.loss_fcn(pred, true)
        overshoot = torch.sigmoid(pred) - true
        blur = 1 - torch.exp((overshoot - 1) / (self.alpha + 1e-4))
        return (elementwise * blur).mean()
|
|
|
|
|
class SigmoidBin(nn.Module):
    """Bin classification plus residual regression for a scalar bounded to [min, max].

    The range ``[min, max]`` is split into ``bin_count`` equal-width bins. Their
    centres are stored in the ``bins`` buffer. A prediction vector has last
    dimension ``bin_count + 1``: channel 0 is a regression offset and channels
    ``1..bin_count`` score the bins.
    """

    stride = None
    export = False

    def __init__(self, bin_count=10, min=0.0, max=1.0, reg_scale=2.0, use_loss_regression=True, use_fw_regression=True, BCE_weight=1.0, smooth_eps=0.0):
        super(SigmoidBin, self).__init__()

        self.bin_count = bin_count
        self.length = bin_count + 1
        self.min = min
        self.max = max
        self.scale = float(max - min)
        self.shift = self.scale / 2.0

        self.use_loss_regression = use_loss_regression
        self.use_fw_regression = use_fw_regression
        self.reg_scale = reg_scale
        self.BCE_weight = BCE_weight

        # Centres of the first and last bins, and the width of a single bin.
        start = min + (self.scale / 2.0) / self.bin_count
        end = max - (self.scale / 2.0) / self.bin_count
        step = self.scale / self.bin_count
        self.step = step

        # torch.range is deprecated and its float-step end point is fragile
        # (the original needed an `end + 0.0001` fudge). linspace produces
        # exactly `bin_count` evenly spaced centres from start to end inclusive.
        bins = torch.linspace(start, end, self.bin_count).float()
        self.register_buffer('bins', bins)

        # Label-smoothed positive / negative targets for the per-bin BCE.
        self.cp = 1.0 - 0.5 * smooth_eps
        self.cn = 0.5 * smooth_eps

        self.BCEbins = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([BCE_weight]))
        self.MSELoss = nn.MSELoss()

    def get_length(self):
        """Return the expected size of the last prediction dimension (``bin_count + 1``)."""
        return self.length

    def forward(self, pred):
        """Decode ``pred`` into values clamped to ``[min, max]``.

        The highest-scoring bin's centre is the base value. If
        ``use_fw_regression`` is set, the offset
        ``(pred[..., 0] * reg_scale - reg_scale / 2) * step`` is added.
        NOTE(review): unlike ``training_loss``, no sigmoid is applied to
        channel 0 here. The ComputeLossBinOTA caller passes already-sigmoided input.
        """
        assert pred.shape[-1] == self.length, 'pred.shape[-1]=%d is not equal to self.length=%d' % (pred.shape[-1], self.length)

        pred_reg = (pred[..., 0] * self.reg_scale - self.reg_scale / 2.0) * self.step
        pred_bin = pred[..., 1:(1 + self.bin_count)]

        _, bin_idx = torch.max(pred_bin, dim=-1)
        bin_bias = self.bins[bin_idx]

        if self.use_fw_regression:
            result = pred_reg + bin_bias
        else:
            result = bin_bias
        result = result.clamp(min=self.min, max=self.max)

        return result

    def training_loss(self, pred, target):
        """Return ``(loss, decoded_prediction)`` for logits ``pred`` against scalar ``target``.

        The target bin is the bin whose centre is nearest to ``target``. The
        loss is the BCE between the bin logits and a (smoothed) one-hot of that
        bin. If ``use_loss_regression`` is set, it also adds the MSE between
        ``target`` and the nearest centre plus the sigmoid-mapped regression offset.
        The indexing ``target_bins[range(n), bin_idx]`` assumes ``pred`` is
        2-D ``(n, bin_count + 1)`` and ``target`` is 1-D ``(n,)``.
        """
        assert pred.shape[-1] == self.length, 'pred.shape[-1]=%d is not equal to self.length=%d' % (pred.shape[-1], self.length)
        assert pred.shape[0] == target.shape[0], 'pred.shape=%d is not equal to the target.shape=%d' % (pred.shape[0], target.shape[0])
        device = pred.device

        pred_reg = (pred[..., 0].sigmoid() * self.reg_scale - self.reg_scale / 2.0) * self.step
        pred_bin = pred[..., 1:(1 + self.bin_count)]

        # Nearest bin centre to each target value.
        diff_bin_target = torch.abs(target[..., None] - self.bins)
        _, bin_idx = torch.min(diff_bin_target, dim=-1)

        bin_bias = self.bins[bin_idx]
        bin_bias.requires_grad = False
        result = pred_reg + bin_bias

        # Smoothed one-hot target over the bins.
        target_bins = torch.full_like(pred_bin, self.cn, device=device)
        n = pred.shape[0]
        target_bins[range(n), bin_idx] = self.cp

        loss_bin = self.BCEbins(pred_bin, target_bins)

        if self.use_loss_regression:
            loss_regression = self.MSELoss(result, target)
            loss = loss_bin + loss_regression
        else:
            loss = loss_bin

        out_result = result.clamp(min=self.min, max=self.max)

        return loss, out_result
|
|
|
|
|
class FocalLoss(nn.Module):
    """Focal-loss wrapper around an existing BCE-with-logits criterion.

    The wrapped criterion is switched to element-wise output. Its original
    ``reduction`` ('mean', 'sum', or anything else meaning "none") is applied
    after the focal re-weighting
    ``alpha_t * (1 - p_t) ** gamma``.
    """

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.loss_fcn = loss_fcn
        self.gamma = gamma
        self.alpha = alpha
        # Remember the caller's reduction, then force element-wise output so
        # the focal weights can be applied per element.
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'

    def forward(self, pred, true):
        """Return the focal-weighted loss for logits ``pred`` and targets ``true``."""
        elementwise = self.loss_fcn(pred, true)

        prob = torch.sigmoid(pred)
        # Probability the model assigns to the target outcome.
        p_t = true * prob + (1 - true) * (1 - prob)
        class_weight = true * self.alpha + (1 - true) * (1 - self.alpha)
        focus = (1.0 - p_t) ** self.gamma
        weighted = elementwise * (class_weight * focus)

        if self.reduction == 'mean':
            return weighted.mean()
        if self.reduction == 'sum':
            return weighted.sum()
        return weighted
|
|
|
|
|
class QFocalLoss(nn.Module):
    """Quality-focal-loss wrapper around an existing BCE-with-logits criterion.

    Each element's loss is scaled by ``alpha_t * |true - sigmoid(pred)| ** gamma``,
    i.e. by the distance between the predicted probability and the target.
    The wrapped criterion's original ``reduction`` is applied afterwards.
    """

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super(QFocalLoss, self).__init__()
        self.loss_fcn = loss_fcn
        self.gamma = gamma
        self.alpha = alpha
        # Keep the requested reduction and make the inner loss element-wise.
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'

    def forward(self, pred, true):
        """Return the quality-focal-weighted loss for logits ``pred`` and targets ``true``."""
        elementwise = self.loss_fcn(pred, true)

        prob = torch.sigmoid(pred)
        class_weight = true * self.alpha + (1 - true) * (1 - self.alpha)
        focus = torch.abs(true - prob) ** self.gamma
        weighted = elementwise * (class_weight * focus)

        if self.reduction == 'mean':
            return weighted.mean()
        if self.reduction == 'sum':
            return weighted.sum()
        return weighted
|
|
|
class RankSort(torch.autograd.Function):
    """Rank & Sort loss implemented as a custom autograd Function.

    ``forward`` returns the mean ranking error and mean sorting error over the
    foreground entries (``targets > 0``). It also precomputes the gradient with
    respect to ``logits``. ``backward`` scales that stored gradient by the
    incoming gradient of the first (ranking-error) output only.
    """

    @staticmethod
    def forward(ctx, logits, targets, delta_RS=0.50, eps=1e-10):
        """Return ``(mean ranking error, mean sorting error)``.

        Args:
            logits: prediction scores, indexed element-wise together with ``targets``.
            targets: same shape as ``logits``. Entries > 0 are foreground and
                their value is the sorting target. Entries == 0 are background.
            delta_RS: half-width of the piecewise-linear step used to compare
                logits. Values ``<= 0`` select a hard step.
            eps: denominators at or below this value are treated as zero.
        """
        device = logits.device
        # Allocate on the input's device rather than hard-coding .cuda(), so
        # the loss also works on CPU and on non-default GPUs.
        classification_grads = torch.zeros(logits.shape, device=device)

        fg_labels = (targets > 0.)
        fg_logits = logits[fg_labels]
        fg_targets = targets[fg_labels]
        fg_num = len(fg_logits)

        if fg_num == 0:
            # No positives: torch.min below would raise on an empty tensor.
            # Report zero errors and a zero gradient instead.
            ctx.save_for_backward(classification_grads)
            return torch.zeros((), device=device), torch.zeros((), device=device)

        # Only background entries whose logit is at least (lowest positive
        # logit - delta_RS) take part in the comparisons below.
        threshold_logit = torch.min(fg_logits) - delta_RS
        relevant_bg_labels = ((targets == 0) & (logits >= threshold_logit))

        relevant_bg_logits = logits[relevant_bg_labels]
        relevant_bg_grad = torch.zeros(len(relevant_bg_logits), device=device)
        sorting_error = torch.zeros(fg_num, device=device)
        ranking_error = torch.zeros(fg_num, device=device)
        fg_grad = torch.zeros(fg_num, device=device)

        # Visit positives in ascending logit order.
        order = torch.argsort(fg_logits)

        for ii in order:
            # Logit differences of every positive / relevant negative to positive ii.
            fg_relations = fg_logits - fg_logits[ii]
            bg_relations = relevant_bg_logits - fg_logits[ii]

            if delta_RS > 0:
                # Smoothed step: 0 below -delta_RS, 1 above +delta_RS, linear between.
                fg_relations = torch.clamp(fg_relations / (2 * delta_RS) + 0.5, min=0, max=1)
                bg_relations = torch.clamp(bg_relations / (2 * delta_RS) + 0.5, min=0, max=1)
            else:
                fg_relations = (fg_relations >= 0).float()
                bg_relations = (bg_relations >= 0).float()

            # Soft count of positives (including ii's own self-comparison term)
            # and of relevant negatives ranked at or above positive ii.
            rank_pos = torch.sum(fg_relations)
            FP_num = torch.sum(bg_relations)

            rank = rank_pos + FP_num

            ranking_error[ii] = FP_num / rank

            # Current sorting error among the positives ranked above ii.
            current_sorting_error = torch.sum(fg_relations * (1 - fg_targets)) / rank_pos

            # Positives whose target is >= ii's target and which rank above ii.
            iou_relations = (fg_targets >= fg_targets[ii])
            target_sorted_order = iou_relations * fg_relations

            rank_pos_target = torch.sum(target_sorted_order)

            # Sorting error computed over the target-consistent subset only.
            target_sorting_error = torch.sum(target_sorted_order * (1 - fg_targets)) / rank_pos_target

            sorting_error[ii] = current_sorting_error - target_sorting_error

            if FP_num > eps:
                # Move ii's ranking error onto the negatives ranked above it.
                fg_grad[ii] -= ranking_error[ii]
                relevant_bg_grad += (bg_relations * (ranking_error[ii] / FP_num))

            # Positives ranked above ii that have a lower target than ii.
            missorted_examples = (~ iou_relations) * fg_relations

            sorting_pmf_denom = torch.sum(missorted_examples)

            if sorting_pmf_denom > eps:
                fg_grad[ii] -= sorting_error[ii]
                fg_grad += (missorted_examples * (sorting_error[ii] / sorting_pmf_denom))

        # Normalise the gradients by the number of positives.
        classification_grads[fg_labels] = (fg_grad / fg_num)
        classification_grads[relevant_bg_labels] = (relevant_bg_grad / fg_num)

        ctx.save_for_backward(classification_grads)

        return ranking_error.mean(), sorting_error.mean()

    @staticmethod
    def backward(ctx, out_grad1, out_grad2):
        """Scale the precomputed logit gradient by the ranking-error output's gradient."""
        g1, = ctx.saved_tensors
        return g1 * out_grad1, None, None, None
|
|
|
class aLRPLoss(torch.autograd.Function):
    """aLRP classification loss implemented as a custom autograd Function.

    ``forward`` returns ``(1 - mean precision, rank, order)`` over positives
    (``targets == 1``) and precomputes the gradient with respect to
    ``logits``. ``backward`` scales it by the incoming gradient of the first output.
    """

    @staticmethod
    def forward(ctx, logits, targets, regression_losses, delta=1., eps=1e-5):
        """Return ``(cls_loss, rank, order)``.

        Args:
            logits: prediction scores, indexed element-wise together with ``targets``.
            targets: same shape as ``logits``. 1 marks positives, 0 negatives.
            regression_losses: per-positive values, multiplied element-wise with
                the positive-vs-positive relations (length must match the
                number of positives).
            delta: half-width of the piecewise-linear step used to compare logits.
            eps: negative counts at or below this value are treated as zero.

        Returns:
            ``cls_loss`` (scalar), ``rank`` (per-positive soft rank), and
            ``order`` (argsort of the positive logits).
        """
        device = logits.device
        # Allocate on the input's device rather than hard-coding .cuda().
        classification_grads = torch.zeros(logits.shape, device=device)

        fg_labels = (targets == 1)
        fg_logits = logits[fg_labels]
        fg_num = len(fg_logits)

        if fg_num == 0:
            # No positives: torch.min would raise and the final division by
            # fg_num would produce NaNs. Return a zero loss and zero gradient.
            ctx.save_for_backward(classification_grads)
            empty_rank = torch.zeros(0, device=device)
            empty_order = torch.zeros(0, dtype=torch.long, device=device)
            return torch.zeros((), device=device), empty_rank, empty_order

        # Only negatives with logit >= (lowest positive logit - delta) are compared.
        threshold_logit = torch.min(fg_logits) - delta

        relevant_bg_labels = ((targets == 0) & (logits >= threshold_logit))
        relevant_bg_logits = logits[relevant_bg_labels]
        relevant_bg_grad = torch.zeros(len(relevant_bg_logits), device=device)
        rank = torch.zeros(fg_num, device=device)
        prec = torch.zeros(fg_num, device=device)
        fg_grad = torch.zeros(fg_num, device=device)

        order = torch.argsort(fg_logits)

        for ii in order:
            # Smoothed step of every positive relative to positive ii,
            # excluding ii itself (its own contribution is the "1 +" below).
            fg_relations = fg_logits - fg_logits[ii]
            fg_relations = torch.clamp(fg_relations / (2 * delta) + 0.5, min=0, max=1)
            fg_relations[ii] = 0

            # Smoothed step of every relevant negative relative to positive ii.
            bg_relations = relevant_bg_logits - fg_logits[ii]
            bg_relations = torch.clamp(bg_relations / (2 * delta) + 0.5, min=0, max=1)

            rank_pos = 1 + torch.sum(fg_relations)
            FP_num = torch.sum(bg_relations)

            rank[ii] = rank_pos + FP_num

            # Soft precision at positive ii.
            prec[ii] = rank_pos / rank[ii]

            if FP_num > eps:
                fg_grad[ii] = -(torch.sum(fg_relations * regression_losses) + FP_num) / rank[ii]
                relevant_bg_grad += (bg_relations * (-fg_grad[ii] / FP_num))

        classification_grads[fg_labels] = fg_grad
        classification_grads[relevant_bg_labels] = relevant_bg_grad

        classification_grads /= (fg_num)

        cls_loss = 1 - prec.mean()
        ctx.save_for_backward(classification_grads)

        return cls_loss, rank, order

    @staticmethod
    def backward(ctx, out_grad1, out_grad2, out_grad3):
        """Scale the precomputed logit gradient by the cls_loss output's gradient."""
        g1, = ctx.saved_tensors
        return g1 * out_grad1, None, None, None, None
|
|
|
|
|
class APLoss(torch.autograd.Function):
    """Average-precision classification loss as a custom autograd Function.

    ``forward`` returns ``1 - mean(interpolated precision)`` over positives
    (``targets == 1``) and precomputes the gradient with respect to
    ``logits``. ``backward`` scales it by the incoming gradient.
    """

    @staticmethod
    def forward(ctx, logits, targets, delta=1.):
        """Return the scalar AP loss.

        Args:
            logits: prediction scores, indexed element-wise together with ``targets``.
            targets: same shape as ``logits``. 1 marks positives, 0 negatives.
            delta: half-width of the piecewise-linear step used to compare logits.
        """
        device = logits.device
        # Allocate on the input's device rather than hard-coding .cuda().
        classification_grads = torch.zeros(logits.shape, device=device)

        fg_labels = (targets == 1)
        fg_logits = logits[fg_labels]
        fg_num = len(fg_logits)

        if fg_num == 0:
            # No positives: torch.min would raise and dividing by fg_num
            # would produce NaNs. Return a zero loss and zero gradient.
            ctx.save_for_backward(classification_grads)
            return torch.zeros((), device=device)

        # Only negatives with logit >= (lowest positive logit - delta) are compared.
        threshold_logit = torch.min(fg_logits) - delta

        relevant_bg_labels = ((targets == 0) & (logits >= threshold_logit))
        relevant_bg_logits = logits[relevant_bg_labels]
        relevant_bg_grad = torch.zeros(len(relevant_bg_logits), device=device)
        rank = torch.zeros(fg_num, device=device)
        prec = torch.zeros(fg_num, device=device)
        fg_grad = torch.zeros(fg_num, device=device)

        # Running maximum precision; positives are visited in ascending logit
        # order, so this is the interpolated precision.
        max_prec = 0

        order = torch.argsort(fg_logits)

        for ii in order:
            # Smoothed step of every other positive relative to positive ii.
            fg_relations = fg_logits - fg_logits[ii]
            fg_relations = torch.clamp(fg_relations / (2 * delta) + 0.5, min=0, max=1)
            fg_relations[ii] = 0

            # Smoothed step of every relevant negative relative to positive ii.
            bg_relations = relevant_bg_logits - fg_logits[ii]
            bg_relations = torch.clamp(bg_relations / (2 * delta) + 0.5, min=0, max=1)

            rank_pos = 1 + torch.sum(fg_relations)
            FP_num = torch.sum(bg_relations)

            rank[ii] = rank_pos + FP_num

            current_prec = rank_pos / rank[ii]

            if (max_prec <= current_prec):
                max_prec = current_prec
                relevant_bg_grad += (bg_relations / rank[ii])
            else:
                # Here max_prec > current_prec, so current_prec < 1 and the
                # division is safe.
                relevant_bg_grad += (bg_relations / rank[ii]) * (((1 - max_prec) / (1 - current_prec)))

            fg_grad[ii] = -(1 - max_prec)
            prec[ii] = max_prec

        classification_grads[fg_labels] = fg_grad
        classification_grads[relevant_bg_labels] = relevant_bg_grad

        classification_grads /= fg_num

        cls_loss = 1 - prec.mean()
        ctx.save_for_backward(classification_grads)

        return cls_loss

    @staticmethod
    def backward(ctx, out_grad1):
        """Scale the precomputed logit gradient by the loss output's gradient."""
        g1, = ctx.saved_tensors
        return g1 * out_grad1, None, None
|
|
|
|
|
class ComputeLoss:
    """YOLO loss: CIoU box loss + objectness BCE + class BCE, with static target assignment.

    Hyper-parameters are read from ``model.hyp``. Anchor and layer metadata
    (``na``, ``nc``, ``nl``, ``anchors``) are copied from the model's last layer.
    """

    def __init__(self, model, autobalance=False):
        super(ComputeLoss, self).__init__()
        device = next(model.parameters()).device
        h = model.hyp

        # Classification / objectness criteria with configurable positive weights.
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))

        # Label-smoothed positive / negative class targets.
        self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))

        # Wrap both criteria in FocalLoss when fl_gamma > 0.
        g = h['fl_gamma']
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

        det = model.module.model[-1] if is_parallel(model) else model.model[-1]
        # Per-layer objectness weights: a 3-entry list for 3 layers, else a 5-entry list.
        self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02])
        # Index of the stride-16 layer, used as the reference when autobalancing.
        self.ssi = list(det.stride).index(16) if autobalance else 0
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
        for k in 'na', 'nc', 'nl', 'anchors':
            setattr(self, k, getattr(det, k))

    def __call__(self, p, targets):
        """Return ``(loss * bs, detached [lbox, lobj, lcls, loss])``.

        Args:
            p: per-layer prediction tensors indexed as ``pi[b, a, gj, gi]``,
                with box channels 0-3, objectness 4 and class logits from 5.
            targets: rows of ``(image, class, x, y, w, h)``. The xywh values
                are presumably normalised to [0, 1] — confirm against the loader.
        """
        device = targets.device
        lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
        tcls, tbox, indices, anchors = self.build_targets(p, targets)

        for i, pi in enumerate(p):
            b, a, gj, gi = indices[i]
            tobj = torch.zeros_like(pi[..., 0], device=device)

            n = b.shape[0]
            if n:
                ps = pi[b, a, gj, gi]

                # Decoded box: xy offset in (-0.5, 1.5), wh in (0, 4) x anchor.
                pxy = ps[:, :2].sigmoid() * 2. - 0.5
                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1)
                iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)
                lbox += (1.0 - iou).mean()

                # Objectness target blends 1 and the detached IoU via gr.
                tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)

                # Classification loss is skipped for single-class models.
                if self.nc > 1:
                    t = torch.full_like(ps[:, 5:], self.cn, device=device)
                    t[range(n), tcls[i]] = self.cp
                    lcls += self.BCEcls(ps[:, 5:], t)

            obji = self.BCEobj(pi[..., 4], tobj)
            lobj += obji * self.balance[i]
            if self.autobalance:
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()

        if self.autobalance:
            self.balance = [x / self.balance[self.ssi] for x in self.balance]
        lbox *= self.hyp['box']
        lobj *= self.hyp['obj']
        lcls *= self.hyp['cls']
        bs = tobj.shape[0]

        loss = lbox + lobj + lcls
        return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()

    def build_targets(self, p, targets):
        """Assign targets to anchors and grid cells on every detection layer.

        Returns per-layer lists ``tcls`` (class ids), ``tbox`` (cell-relative
        xy plus wh, in grid units), ``indices`` (``b, a, gj, gi``) and ``anch``
        (matched anchors).
        """
        na, nt = self.na, targets.shape[0]
        tcls, tbox, indices, anch = [], [], [], []
        gain = torch.ones(7, device=targets.device)
        # Append the anchor index as a 7th column -> shape (na, nt, 7).
        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)

        g = 0.5  # neighbour-cell offset bias
        off = torch.tensor([[0, 0],
                            [1, 0], [0, 1], [-1, 0], [0, -1],
                            ], device=targets.device).float() * g

        for i in range(self.nl):
            anchors = self.anchors[i]
            shape = p[i].shape
            # Scale the xywh columns to this layer's grid size.
            gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]

            t = targets * gain
            if nt:
                # Keep target/anchor pairs whose wh ratio is within anchor_t.
                r = t[:, :, 4:6] / anchors[:, None]
                j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t']
                t = t[j]

                # Also assign each target to neighbouring cells on the side(s)
                # of its cell that its centre lies closest to.
                gxy = t[:, 2:4]
                gxi = gain[[2, 3]] - gxy
                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                l, m = ((gxi % 1. < g) & (gxi > 1.)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0

            b, c = t[:, :2].long().T
            gxy = t[:, 2:4]
            gwh = t[:, 4:6]
            gij = (gxy - offsets).long()
            gi, gj = gij.T

            a = t[:, 6].long()
            # Clamp with Python ints: clamping an int64 tensor in place with a
            # float-tensor bound (the old `gain[3] - 1`) raises under newer
            # PyTorch type promotion. gi/gj are views of gij, so tbox below
            # uses the clamped cell indices.
            indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))
            tbox.append(torch.cat((gxy - gij, gwh), 1))
            anch.append(anchors[a])
            tcls.append(c)

        return tcls, tbox, indices, anch
|
|
|
|
|
class ComputeLossOTA:
    """YOLO loss with OTA-style dynamic positive-sample assignment.

    Candidate cells come from ``find_3_positive``. ``build_targets`` then keeps,
    per ground truth, the lowest-cost candidates under a cost of class BCE plus
    ``3 * -log(IoU)``.
    """

    def __init__(self, model, autobalance=False):
        super(ComputeLossOTA, self).__init__()
        device = next(model.parameters()).device
        h = model.hyp

        # Classification / objectness criteria with configurable positive weights.
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))

        # Label-smoothed positive / negative class targets.
        self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))

        # Wrap both criteria in FocalLoss when fl_gamma > 0.
        g = h['fl_gamma']
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

        det = model.module.model[-1] if is_parallel(model) else model.model[-1]
        # Per-layer objectness weights: a 3-entry list for 3 layers, else a 5-entry list.
        self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02])
        # Index of the stride-16 layer, used as the reference when autobalancing.
        self.ssi = list(det.stride).index(16) if autobalance else 0
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
        for k in 'na', 'nc', 'nl', 'anchors', 'stride':
            setattr(self, k, getattr(det, k))

    def __call__(self, p, targets, imgs):
        """Return ``(loss * bs, detached [lbox, lobj, lcls, loss])``.

        Args:
            p: per-layer prediction tensors indexed as ``pi[b, a, gj, gi]``,
                with box channels 0-3, objectness 4 and class logits from 5.
            targets: rows of ``(image, class, x, y, w, h)``, presumably
                normalised xywh — confirm against the data loader.
            imgs: batch of images. Only ``imgs[i].shape[1]`` is read.
        """
        device = targets.device
        lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
        bs, as_, gjs, gis, targets, anchors = self.build_targets(p, targets, imgs)
        # Per-layer (w, h, w, h) grid sizes for scaling normalised target boxes.
        pre_gen_gains = [torch.tensor(pp.shape, device=device)[[3, 2, 3, 2]] for pp in p]

        for i, pi in enumerate(p):
            b, a, gj, gi = bs[i], as_[i], gjs[i], gis[i]
            tobj = torch.zeros_like(pi[..., 0], device=device)

            n = b.shape[0]
            if n:
                ps = pi[b, a, gj, gi]

                # Decoded box: xy offset in (-0.5, 1.5), wh in (0, 4) x anchor.
                grid = torch.stack([gi, gj], dim=1)
                pxy = ps[:, :2].sigmoid() * 2. - 0.5
                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1)
                # Target box in grid units, xy relative to the assigned cell.
                selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
                selected_tbox[:, :2] -= grid
                iou = bbox_iou(pbox.T, selected_tbox, x1y1x2y2=False, CIoU=True)
                lbox += (1.0 - iou).mean()

                # Objectness target blends 1 and the detached IoU via gr.
                tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)

                # Classification loss is skipped for single-class models.
                selected_tcls = targets[i][:, 1].long()
                if self.nc > 1:
                    t = torch.full_like(ps[:, 5:], self.cn, device=device)
                    t[range(n), selected_tcls] = self.cp
                    lcls += self.BCEcls(ps[:, 5:], t)

            obji = self.BCEobj(pi[..., 4], tobj)
            lobj += obji * self.balance[i]
            if self.autobalance:
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()

        if self.autobalance:
            self.balance = [x / self.balance[self.ssi] for x in self.balance]
        lbox *= self.hyp['box']
        lobj *= self.hyp['obj']
        lcls *= self.hyp['cls']
        bs = tobj.shape[0]

        loss = lbox + lobj + lcls
        return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()

    def build_targets(self, p, targets, imgs):
        """Select positive samples per layer by dynamic cost-based matching.

        For each image, candidates from ``find_3_positive`` are scored by
        ``cls BCE + 3 * -log(IoU)``. Each GT keeps its ``dynamic_k`` cheapest
        candidates, where ``dynamic_k`` is the sum of its top-10 IoUs (at least 1).
        A candidate claimed by several GTs goes to the cheapest one.

        Returns:
            Per-layer ``(matching_bs, matching_as, matching_gjs, matching_gis,
            matching_targets, matching_anchs)``. A layer with no matches gets
            empty tensors.
        """
        indices, anch = self.find_3_positive(p, targets)

        matching_bs = [[] for pp in p]
        matching_as = [[] for pp in p]
        matching_gjs = [[] for pp in p]
        matching_gis = [[] for pp in p]
        matching_targets = [[] for pp in p]
        matching_anchs = [[] for pp in p]

        nl = len(p)

        for batch_idx in range(p[0].shape[0]):

            b_idx = targets[:, 0] == batch_idx
            this_target = targets[b_idx]
            if this_target.shape[0] == 0:
                continue

            # Target boxes in pixels. NOTE(review): x, y, w and h are all scaled
            # by imgs[batch_idx].shape[1] only, which assumes square inputs — confirm.
            txywh = this_target[:, 2:6] * imgs[batch_idx].shape[1]
            txyxy = xywh2xyxy(txywh)

            pxyxys = []
            p_cls = []
            p_obj = []
            from_which_layer = []
            all_b = []
            all_a = []
            all_gj = []
            all_gi = []
            all_anch = []

            for i, pi in enumerate(p):

                b, a, gj, gi = indices[i]
                idx = (b == batch_idx)
                b, a, gj, gi = b[idx], a[idx], gj[idx], gi[idx]
                all_b.append(b)
                all_a.append(a)
                all_gj.append(gj)
                all_gi.append(gi)
                all_anch.append(anch[i][idx])
                # Create the layer tag on the indices' device: a CPU tensor here
                # fails later when masked with the on-device fg_mask_inboxes.
                from_which_layer.append(torch.ones(size=(len(b),), device=b.device) * i)

                fg_pred = pi[b, a, gj, gi]
                p_obj.append(fg_pred[:, 4:5])
                p_cls.append(fg_pred[:, 5:])

                # Decode candidate boxes to pixels (grid units times stride).
                grid = torch.stack([gi, gj], dim=1)
                pxy = (fg_pred[:, :2].sigmoid() * 2. - 0.5 + grid) * self.stride[i]
                pwh = (fg_pred[:, 2:4].sigmoid() * 2) ** 2 * anch[i][idx] * self.stride[i]
                pxywh = torch.cat([pxy, pwh], dim=-1)
                pxyxy = xywh2xyxy(pxywh)
                pxyxys.append(pxyxy)

            pxyxys = torch.cat(pxyxys, dim=0)
            if pxyxys.shape[0] == 0:
                continue
            p_obj = torch.cat(p_obj, dim=0)
            p_cls = torch.cat(p_cls, dim=0)
            from_which_layer = torch.cat(from_which_layer, dim=0)
            all_b = torch.cat(all_b, dim=0)
            all_a = torch.cat(all_a, dim=0)
            all_gj = torch.cat(all_gj, dim=0)
            all_gi = torch.cat(all_gi, dim=0)
            all_anch = torch.cat(all_anch, dim=0)

            pair_wise_iou = box_iou(txyxy, pxyxys)

            pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8)

            # Per-GT number of positives: sum of its top-10 IoUs, at least 1.
            top_k, _ = torch.topk(pair_wise_iou, min(10, pair_wise_iou.shape[1]), dim=1)
            dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1)

            gt_cls_per_image = (
                F.one_hot(this_target[:, 1].to(torch.int64), self.nc)
                .float()
                .unsqueeze(1)
                .repeat(1, pxyxys.shape[0], 1)
            )

            num_gt = this_target.shape[0]
            cls_preds_ = (
                p_cls.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
                * p_obj.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
            )

            # BCE on sqrt(cls_prob * obj_prob), mapped back to logits via log(y / (1 - y)).
            y = cls_preds_.sqrt_()
            pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
                torch.log(y / (1 - y)), gt_cls_per_image, reduction="none"
            ).sum(-1)
            del cls_preds_

            cost = (
                pair_wise_cls_loss
                + 3.0 * pair_wise_iou_loss
            )

            matching_matrix = torch.zeros_like(cost)

            for gt_idx in range(num_gt):
                _, pos_idx = torch.topk(
                    cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False
                )
                matching_matrix[gt_idx][pos_idx] = 1.0

            del top_k, dynamic_ks
            # A candidate matched to several GTs is kept only for the cheapest GT.
            anchor_matching_gt = matching_matrix.sum(0)
            if (anchor_matching_gt > 1).sum() > 0:
                _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
                matching_matrix[:, anchor_matching_gt > 1] *= 0.0
                matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
            fg_mask_inboxes = matching_matrix.sum(0) > 0.0
            matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)

            from_which_layer = from_which_layer[fg_mask_inboxes]
            all_b = all_b[fg_mask_inboxes]
            all_a = all_a[fg_mask_inboxes]
            all_gj = all_gj[fg_mask_inboxes]
            all_gi = all_gi[fg_mask_inboxes]
            all_anch = all_anch[fg_mask_inboxes]

            this_target = this_target[matched_gt_inds]

            for i in range(nl):
                layer_idx = from_which_layer == i
                matching_bs[i].append(all_b[layer_idx])
                matching_as[i].append(all_a[layer_idx])
                matching_gjs[i].append(all_gj[layer_idx])
                matching_gis[i].append(all_gi[layer_idx])
                matching_targets[i].append(this_target[layer_idx])
                matching_anchs[i].append(all_anch[layer_idx])

        for i in range(nl):
            if matching_bs[i]:
                matching_bs[i] = torch.cat(matching_bs[i], dim=0)
                matching_as[i] = torch.cat(matching_as[i], dim=0)
                matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
                matching_gis[i] = torch.cat(matching_gis[i], dim=0)
                matching_targets[i] = torch.cat(matching_targets[i], dim=0)
                matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
            else:
                # No image contributed matches (e.g. a batch without labels):
                # torch.cat([]) would raise, so return empty tensors instead.
                matching_bs[i] = torch.zeros(0, dtype=torch.long, device=targets.device)
                matching_as[i] = torch.zeros(0, dtype=torch.long, device=targets.device)
                matching_gjs[i] = torch.zeros(0, dtype=torch.long, device=targets.device)
                matching_gis[i] = torch.zeros(0, dtype=torch.long, device=targets.device)
                matching_targets[i] = targets.new_zeros((0, targets.shape[1]))
                matching_anchs[i] = targets.new_zeros((0, 2))

        return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs

    def find_3_positive(self, p, targets):
        """Return candidate ``indices`` (``b, a, gj, gi``) and anchors per layer.

        Target/anchor pairs are kept when their wh ratio is within
        ``hyp['anchor_t']``. Each kept target is assigned to its own cell plus
        neighbouring cells on the side(s) its centre lies closest to.
        """
        na, nt = self.na, targets.shape[0]
        indices, anch = [], []
        gain = torch.ones(7, device=targets.device)
        # Append the anchor index as a 7th column -> shape (na, nt, 7).
        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)

        g = 0.5  # neighbour-cell offset bias
        off = torch.tensor([[0, 0],
                            [1, 0], [0, 1], [-1, 0], [0, -1],
                            ], device=targets.device).float() * g

        for i in range(self.nl):
            anchors = self.anchors[i]
            shape = p[i].shape
            # Scale the xywh columns to this layer's grid size.
            gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]

            t = targets * gain
            if nt:
                r = t[:, :, 4:6] / anchors[:, None]
                j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t']
                t = t[j]

                gxy = t[:, 2:4]
                gxi = gain[[2, 3]] - gxy
                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                l, m = ((gxi % 1. < g) & (gxi > 1.)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0

            b, c = t[:, :2].long().T
            gxy = t[:, 2:4]
            gwh = t[:, 4:6]
            gij = (gxy - offsets).long()
            gi, gj = gij.T

            a = t[:, 6].long()
            # Clamp with Python ints: clamping an int64 tensor in place with a
            # float-tensor bound raises under newer PyTorch type promotion.
            indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))
            anch.append(anchors[a])

        return indices, anch
|
|
|
|
|
class ComputeLossBinOTA:
    """YOLO loss with OTA-style assignment and SigmoidBin-encoded width/height.

    Prediction channel layout per anchor: x, y, then a width block and a height
    block of ``wh_bin_sigmoid.get_length()`` channels each, then objectness at
    ``obj_idx = 2 * length + 2``, then class logits.
    """

    def __init__(self, model, autobalance=False):
        super(ComputeLossBinOTA, self).__init__()
        device = next(model.parameters()).device
        h = model.hyp

        # Classification / objectness criteria with configurable positive weights.
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))

        # Label-smoothed positive / negative class targets.
        self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))

        # Wrap both criteria in FocalLoss when fl_gamma > 0.
        g = h['fl_gamma']
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

        det = model.module.model[-1] if is_parallel(model) else model.model[-1]
        # Per-layer objectness weights: a 3-entry list for 3 layers, else a 5-entry list.
        self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02])
        # Index of the stride-16 layer, used as the reference when autobalancing.
        self.ssi = list(det.stride).index(16) if autobalance else 0
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
        for k in 'na', 'nc', 'nl', 'anchors', 'stride', 'bin_count':
            setattr(self, k, getattr(det, k))

        # Width/height are decoded as a multiple of the anchor in [0, 4].
        wh_bin_sigmoid = SigmoidBin(bin_count=self.bin_count, min=0.0, max=4.0, use_loss_regression=False).to(device)
        self.wh_bin_sigmoid = wh_bin_sigmoid

    def __call__(self, p, targets, imgs):
        """Return ``(loss * bs, detached [lbox, lobj, lcls, loss])``.

        Args:
            p: per-layer prediction tensors indexed as ``pi[b, a, gj, gi]``.
            targets: rows of ``(image, class, x, y, w, h)``, presumably
                normalised xywh — confirm against the data loader.
            imgs: batch of images. Only ``imgs[i].shape[1]`` is read.
        """
        device = targets.device
        lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
        bs, as_, gjs, gis, targets, anchors = self.build_targets(p, targets, imgs)
        # Per-layer (w, h, w, h) grid sizes for scaling normalised target boxes.
        pre_gen_gains = [torch.tensor(pp.shape, device=device)[[3, 2, 3, 2]] for pp in p]

        # Objectness channel index; it does not depend on the layer.
        obj_idx = self.wh_bin_sigmoid.get_length() * 2 + 2

        for i, pi in enumerate(p):
            b, a, gj, gi = bs[i], as_[i], gjs[i], gis[i]
            tobj = torch.zeros_like(pi[..., 0], device=device)

            n = b.shape[0]
            if n:
                ps = pi[b, a, gj, gi]

                # Target box in grid units, xy relative to the assigned cell.
                grid = torch.stack([gi, gj], dim=1)
                selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
                selected_tbox[:, :2] -= grid

                # Bin loss on w/h with targets expressed as multiples of the anchor.
                w_loss, pw = self.wh_bin_sigmoid.training_loss(ps[..., 2:(3 + self.bin_count)], selected_tbox[..., 2] / anchors[i][..., 0])
                h_loss, ph = self.wh_bin_sigmoid.training_loss(ps[..., (3 + self.bin_count):obj_idx], selected_tbox[..., 3] / anchors[i][..., 1])

                pw *= anchors[i][..., 0]
                ph *= anchors[i][..., 1]

                px = ps[:, 0].sigmoid() * 2. - 0.5
                py = ps[:, 1].sigmoid() * 2. - 0.5

                lbox += w_loss + h_loss

                pbox = torch.cat((px.unsqueeze(1), py.unsqueeze(1), pw.unsqueeze(1), ph.unsqueeze(1)), 1).to(device)

                iou = bbox_iou(pbox.T, selected_tbox, x1y1x2y2=False, CIoU=True)
                lbox += (1.0 - iou).mean()

                # Objectness target blends 1 and the detached IoU via gr.
                tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)

                # Classification loss is skipped for single-class models.
                selected_tcls = targets[i][:, 1].long()
                if self.nc > 1:
                    t = torch.full_like(ps[:, (1 + obj_idx):], self.cn, device=device)
                    t[range(n), selected_tcls] = self.cp
                    lcls += self.BCEcls(ps[:, (1 + obj_idx):], t)

            obji = self.BCEobj(pi[..., obj_idx], tobj)
            lobj += obji * self.balance[i]
            if self.autobalance:
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()

        if self.autobalance:
            self.balance = [x / self.balance[self.ssi] for x in self.balance]
        lbox *= self.hyp['box']
        lobj *= self.hyp['obj']
        lcls *= self.hyp['cls']
        bs = tobj.shape[0]

        loss = lbox + lobj + lcls
        return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()

    def build_targets(self, p, targets, imgs):
        """Select positive samples per layer by dynamic cost-based matching.

        Candidates from ``find_3_positive`` are scored per image by
        ``cls BCE + 3 * -log(IoU)``, with w/h decoded through ``wh_bin_sigmoid``.
        Each GT keeps its ``dynamic_k`` cheapest candidates (sum of its top-10
        IoUs, at least 1). A candidate claimed by several GTs goes to the
        cheapest one.

        Returns:
            Per-layer ``(matching_bs, matching_as, matching_gjs, matching_gis,
            matching_targets, matching_anchs)``. A layer with no matches gets
            empty tensors.
        """
        indices, anch = self.find_3_positive(p, targets)

        matching_bs = [[] for pp in p]
        matching_as = [[] for pp in p]
        matching_gjs = [[] for pp in p]
        matching_gis = [[] for pp in p]
        matching_targets = [[] for pp in p]
        matching_anchs = [[] for pp in p]

        nl = len(p)

        for batch_idx in range(p[0].shape[0]):

            b_idx = targets[:, 0] == batch_idx
            this_target = targets[b_idx]
            if this_target.shape[0] == 0:
                continue

            # Target boxes in pixels. NOTE(review): x, y, w and h are all scaled
            # by imgs[batch_idx].shape[1] only, which assumes square inputs — confirm.
            txywh = this_target[:, 2:6] * imgs[batch_idx].shape[1]
            txyxy = xywh2xyxy(txywh)

            pxyxys = []
            p_cls = []
            p_obj = []
            from_which_layer = []
            all_b = []
            all_a = []
            all_gj = []
            all_gi = []
            all_anch = []

            # Objectness channel index (same for every layer).
            obj_idx = self.wh_bin_sigmoid.get_length() * 2 + 2

            for i, pi in enumerate(p):

                b, a, gj, gi = indices[i]
                idx = (b == batch_idx)
                b, a, gj, gi = b[idx], a[idx], gj[idx], gi[idx]
                all_b.append(b)
                all_a.append(a)
                all_gj.append(gj)
                all_gi.append(gi)
                all_anch.append(anch[i][idx])
                # Create the layer tag on the indices' device: a CPU tensor here
                # fails later when masked with the on-device fg_mask_inboxes.
                from_which_layer.append(torch.ones(size=(len(b),), device=b.device) * i)

                fg_pred = pi[b, a, gj, gi]
                p_obj.append(fg_pred[:, obj_idx:(obj_idx + 1)])
                p_cls.append(fg_pred[:, (obj_idx + 1):])

                # Decode candidate boxes to pixels (grid units times stride).
                grid = torch.stack([gi, gj], dim=1)
                pxy = (fg_pred[:, :2].sigmoid() * 2. - 0.5 + grid) * self.stride[i]
                pw = self.wh_bin_sigmoid.forward(fg_pred[..., 2:(3 + self.bin_count)].sigmoid()) * anch[i][idx][:, 0] * self.stride[i]
                ph = self.wh_bin_sigmoid.forward(fg_pred[..., (3 + self.bin_count):obj_idx].sigmoid()) * anch[i][idx][:, 1] * self.stride[i]

                pxywh = torch.cat([pxy, pw.unsqueeze(1), ph.unsqueeze(1)], dim=-1)
                pxyxy = xywh2xyxy(pxywh)
                pxyxys.append(pxyxy)

            pxyxys = torch.cat(pxyxys, dim=0)
            if pxyxys.shape[0] == 0:
                continue
            p_obj = torch.cat(p_obj, dim=0)
            p_cls = torch.cat(p_cls, dim=0)
            from_which_layer = torch.cat(from_which_layer, dim=0)
            all_b = torch.cat(all_b, dim=0)
            all_a = torch.cat(all_a, dim=0)
            all_gj = torch.cat(all_gj, dim=0)
            all_gi = torch.cat(all_gi, dim=0)
            all_anch = torch.cat(all_anch, dim=0)

            pair_wise_iou = box_iou(txyxy, pxyxys)

            pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8)

            # Per-GT number of positives: sum of its top-10 IoUs, at least 1.
            top_k, _ = torch.topk(pair_wise_iou, min(10, pair_wise_iou.shape[1]), dim=1)
            dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1)

            gt_cls_per_image = (
                F.one_hot(this_target[:, 1].to(torch.int64), self.nc)
                .float()
                .unsqueeze(1)
                .repeat(1, pxyxys.shape[0], 1)
            )

            num_gt = this_target.shape[0]
            cls_preds_ = (
                p_cls.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
                * p_obj.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
            )

            # BCE on sqrt(cls_prob * obj_prob), mapped back to logits via log(y / (1 - y)).
            y = cls_preds_.sqrt_()
            pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
                torch.log(y / (1 - y)), gt_cls_per_image, reduction="none"
            ).sum(-1)
            del cls_preds_

            cost = (
                pair_wise_cls_loss
                + 3.0 * pair_wise_iou_loss
            )

            matching_matrix = torch.zeros_like(cost)

            for gt_idx in range(num_gt):
                _, pos_idx = torch.topk(
                    cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False
                )
                matching_matrix[gt_idx][pos_idx] = 1.0

            del top_k, dynamic_ks
            # A candidate matched to several GTs is kept only for the cheapest GT.
            anchor_matching_gt = matching_matrix.sum(0)
            if (anchor_matching_gt > 1).sum() > 0:
                _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
                matching_matrix[:, anchor_matching_gt > 1] *= 0.0
                matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
            fg_mask_inboxes = matching_matrix.sum(0) > 0.0
            matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)

            from_which_layer = from_which_layer[fg_mask_inboxes]
            all_b = all_b[fg_mask_inboxes]
            all_a = all_a[fg_mask_inboxes]
            all_gj = all_gj[fg_mask_inboxes]
            all_gi = all_gi[fg_mask_inboxes]
            all_anch = all_anch[fg_mask_inboxes]

            this_target = this_target[matched_gt_inds]

            for i in range(nl):
                layer_idx = from_which_layer == i
                matching_bs[i].append(all_b[layer_idx])
                matching_as[i].append(all_a[layer_idx])
                matching_gjs[i].append(all_gj[layer_idx])
                matching_gis[i].append(all_gi[layer_idx])
                matching_targets[i].append(this_target[layer_idx])
                matching_anchs[i].append(all_anch[layer_idx])

        for i in range(nl):
            if matching_bs[i]:
                matching_bs[i] = torch.cat(matching_bs[i], dim=0)
                matching_as[i] = torch.cat(matching_as[i], dim=0)
                matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
                matching_gis[i] = torch.cat(matching_gis[i], dim=0)
                matching_targets[i] = torch.cat(matching_targets[i], dim=0)
                matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
            else:
                # No image contributed matches (e.g. a batch without labels):
                # torch.cat([]) would raise, so return empty tensors instead.
                matching_bs[i] = torch.zeros(0, dtype=torch.long, device=targets.device)
                matching_as[i] = torch.zeros(0, dtype=torch.long, device=targets.device)
                matching_gjs[i] = torch.zeros(0, dtype=torch.long, device=targets.device)
                matching_gis[i] = torch.zeros(0, dtype=torch.long, device=targets.device)
                matching_targets[i] = targets.new_zeros((0, targets.shape[1]))
                matching_anchs[i] = targets.new_zeros((0, 2))

        return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs

    def find_3_positive(self, p, targets):
        """Return candidate ``indices`` (``b, a, gj, gi``) and anchors per layer.

        Target/anchor pairs are kept when their wh ratio is within
        ``hyp['anchor_t']``. Each kept target is assigned to its own cell plus
        neighbouring cells on the side(s) its centre lies closest to.
        """
        na, nt = self.na, targets.shape[0]
        indices, anch = [], []
        gain = torch.ones(7, device=targets.device)
        # Append the anchor index as a 7th column -> shape (na, nt, 7).
        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)

        g = 0.5  # neighbour-cell offset bias
        off = torch.tensor([[0, 0],
                            [1, 0], [0, 1], [-1, 0], [0, -1],
                            ], device=targets.device).float() * g

        for i in range(self.nl):
            anchors = self.anchors[i]
            shape = p[i].shape
            # Scale the xywh columns to this layer's grid size.
            gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]

            t = targets * gain
            if nt:
                r = t[:, :, 4:6] / anchors[:, None]
                j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t']
                t = t[j]

                gxy = t[:, 2:4]
                gxi = gain[[2, 3]] - gxy
                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                l, m = ((gxi % 1. < g) & (gxi > 1.)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0

            b, c = t[:, :2].long().T
            gxy = t[:, 2:4]
            gwh = t[:, 4:6]
            gij = (gxy - offsets).long()
            gi, gj = gij.T

            a = t[:, 6].long()
            # Clamp with Python ints: clamping an int64 tensor in place with a
            # float-tensor bound raises under newer PyTorch type promotion.
            indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))
            anch.append(anchors[a])

        return indices, anch
|
|