File size: 1,231 Bytes
88b0dcb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
"""
@Date: 2021/08/12
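@description: Loss modules: ObjectLoss (combined object loss, forward not yet implemented) and HeatmapLoss (focal-style heatmap loss).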
"""
import torch
import torch.nn as nn
from loss.grad_loss import GradLoss
class ObjectLoss(nn.Module):
    """Container for the loss terms used on object predictions.

    Holds a heatmap loss and a smooth-L1 loss as sub-modules. The forward
    pass is not implemented yet and currently returns the constant ``0``.
    """

    def __init__(self):
        super(ObjectLoss, self).__init__()
        # Heatmap term; a FocalLoss(reduction='mean') was the noted alternative.
        self.heat_map_loss = HeatmapLoss(reduction='mean')
        # Smooth-L1 term with torch's default settings.
        self.l1_loss = nn.SmoothL1Loss()

    def forward(self, gt, dt):
        """Placeholder forward pass.

        Args:
            gt: ground-truth input (currently unused).
            dt: predicted/detected input (currently unused).

        Returns:
            The integer ``0``; neither sub-loss is evaluated yet.
        """
        # TODO: combine self.heat_map_loss and self.l1_loss over gt / dt.
        return 0
class HeatmapLoss(nn.Module):
    """Focal-style loss between a target heatmap and a predicted heatmap.

    Elements whose target equals exactly ``1.0`` are treated as centers and
    contribute ``-(1 - p) ** alpha * log(p)``. Every other element
    contributes ``-(1 - y) ** beta * p ** alpha * log(1 - p)``, where ``y``
    is the target value and ``p`` the predicted value.

    Args:
        weight: accepted for interface compatibility; not used by this loss.
        alpha: exponent applied to the prediction-dependent modulating factor.
        beta: exponent applied to ``(1 - targets)`` for non-center elements.
        reduction: ``'mean'`` divides the summed loss by the size of the
            first dimension, ``'sum'`` returns the summed loss, and any
            other value returns the unreduced element-wise loss.
    """

    def __init__(self, weight=None, alpha=2, beta=4, reduction='mean'):
        super(HeatmapLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.reduction = reduction

    def forward(self, targets, inputs):
        """Compute the loss.

        Note the argument order: targets first, predictions second.

        Args:
            targets: target heatmap tensor; ``1.0`` marks a center.
            inputs: predicted heatmap tensor, same shape as ``targets``.
                Presumably probabilities in [0, 1] (the logs of ``inputs``
                and ``1 - inputs`` are taken) -- confirm against the caller.

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'``; otherwise a tensor
            with the same shape as the inputs.
        """
        center_id = (targets == 1.0).float()
        other_id = (targets != 1.0).float()

        # The small constant keeps log() finite when a prediction hits 0 or 1.
        center_loss = (
            -center_id
            * (1.0 - inputs) ** self.alpha
            * torch.log(inputs + 1e-14)
        )
        other_loss = (
            -other_id
            * (1 - targets) ** self.beta
            * inputs ** self.alpha
            * torch.log(1.0 - inputs + 1e-14)
        )
        loss = center_loss + other_loss

        if self.reduction == 'mean':
            # "Mean" here is per sample: total loss over the first dimension.
            return torch.sum(loss) / loss.size(0)
        if self.reduction == 'sum':
            # Plain total; previously this was also divided by the batch
            # size, which made 'sum' indistinguishable from 'mean'.
            return torch.sum(loss)
        return loss
|