File size: 881 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch.nn.functional as F
from torch import nn


class LPVLoss(nn.Module):
    """Sum of cross-entropy losses over a sequence of prediction tensors.

    Every element of ``preds`` is scored against the same target slice taken
    from ``batch``; the per-element losses and their sum are returned in a
    dict.
    """

    def __init__(self, label_smoothing=0.0, **kwargs):
        """Store the label-smoothing factor.

        Args:
            label_smoothing: Forwarded to ``F.cross_entropy``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.label_smoothing = label_smoothing

    def forward(self, preds, batch):
        """Compute one cross-entropy loss per prediction plus their total.

        Args:
            preds: Iterable of tensors; each has its first two dims merged
                before being passed to ``F.cross_entropy``.
            batch: Indexable; ``batch[1]`` supplies the targets and
                ``batch[2]`` the values whose maximum bounds the target
                slice (presumably label ids and label lengths -- TODO
                confirm against the data loader).

        Returns:
            dict with keys ``'loss0'``, ``'loss1'``, ... (one per element of
            ``preds``) and ``'loss'`` holding their sum. With empty ``preds``
            the sum is the plain int ``0``.
        """
        longest = batch[2].max()
        # Skip column 0 and keep ``longest + 1`` columns, then merge the
        # first two dims so the targets line up with the flattened preds.
        targets = batch[1][:, 1:2 + longest].flatten(0, 1)

        losses = {}
        total = 0
        for idx, logits in enumerate(preds):
            flat_logits = logits.flatten(0, 1)
            # NOTE(review): the ignored target id is (size of dim 1) + 1;
            # presumably this is the padding id -- confirm against the
            # label encoder.
            ignored_id = flat_logits.shape[1] + 1
            step_loss = F.cross_entropy(
                flat_logits,
                targets,
                reduction='mean',
                label_smoothing=self.label_smoothing,
                ignore_index=ignored_id,
            )
            losses[f'loss{idx}'] = step_loss
            total += step_loss
        losses['loss'] = total
        return losses