File size: 1,231 Bytes
88b0dcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
""" 
@Date: 2021/08/12
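@description: Loss modules: HeatmapLoss (focal-style loss on predicted heatmaps) and ObjectLoss (placeholder, not yet implemented).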
"""
import torch
import torch.nn as nn
from loss.grad_loss import GradLoss


class ObjectLoss(nn.Module):
    """Top-level object loss -- currently only a placeholder.

    The constructor builds a mean-reduced ``HeatmapLoss`` and a
    ``nn.SmoothL1Loss`` for later use, but ``forward`` does not use them yet
    and simply returns 0.
    """

    def __init__(self):
        super(ObjectLoss, self).__init__()
        # Alternative kept from the original author: FocalLoss(reduction='mean')
        self.heat_map_loss = HeatmapLoss(reduction='mean')
        self.l1_loss = nn.SmoothL1Loss()

    def forward(self, gt, dt):
        """Return the loss between ``gt`` and ``dt``.

        TODO: not implemented; both arguments are ignored and the integer 0
        is returned.
        """
        del gt, dt  # unused until the loss is implemented
        return 0


class HeatmapLoss(nn.Module):
    """Focal-style loss between a target heatmap and a predicted heatmap.

    Elements whose target is exactly 1.0 are treated as centers and are
    penalised by ``-(1 - p) ** alpha * log(p)``. Every other element is
    penalised by ``-(1 - t) ** beta * p ** alpha * log(1 - p)``. This means
    negatives close to a center (target near 1) are down-weighted.

    Args:
        weight: Unused; accepted only for signature compatibility.
        alpha: Exponent on the prediction-error factor.
        beta: Exponent of the ``(1 - target)`` down-weighting of negatives.
        reduction: One of:
            ``'mean'`` -- sum over all elements divided by the batch size
            (``loss.size(0)``), not by the element or positive count.
            ``'sum'`` -- plain sum over all elements.
            ``'none'`` -- element-wise loss with the inputs' shape.

    Raises:
        ValueError: If ``reduction`` is not one of the values above.
    """

    _REDUCTIONS = ('none', 'mean', 'sum')

    def __init__(self, weight=None, alpha=2, beta=4, reduction='mean'):
        super(HeatmapLoss, self).__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.alpha = alpha
        self.beta = beta
        self.reduction = reduction

    def forward(self, targets, inputs):
        """Compute the heatmap loss.

        Args:
            targets: Ground-truth heatmap. Presumably in [0, 1] with exactly
                1.0 at centers -- TODO confirm with the data pipeline.
            inputs: Predicted heatmap with the same shape as ``targets``.
                Assumed to already be probabilities in [0, 1] (e.g. after a
                sigmoid) -- not checked here.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, or the element-wise loss
            tensor for ``'none'``.
        """
        eps = 1e-14  # keeps log() finite (in float32) for predictions of exactly 0 or 1
        center_mask = (targets == 1.0).float()
        other_mask = (targets != 1.0).float()
        center_loss = -center_mask * (1.0 - inputs) ** self.alpha * torch.log(inputs + eps)
        other_loss = (-other_mask * (1 - targets) ** self.beta
                      * inputs ** self.alpha * torch.log(1.0 - inputs + eps))
        loss = center_loss + other_loss

        if self.reduction == 'mean':
            # Normalised by batch size (dim 0), not by the number of elements.
            return torch.sum(loss) / loss.size(0)
        if self.reduction == 'sum':
            # Previously also divided by batch size, making 'sum' identical to 'mean'.
            return torch.sum(loss)
        return loss