File size: 1,152 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from torch import nn


class MGPLoss(nn.Module):
    """Cross-entropy loss over one or three prediction heads.

    With ``only_char=True`` only the character head is scored; otherwise the
    three heads (named ``char``, ``dpe`` and ``wp`` in this code) are each
    scored and their losses added together.

    Target entries equal to 0 do not contribute to the loss
    (``ignore_index=0``) -- presumably 0 is the padding id; confirm against
    the label encoder.
    """

    def __init__(self, only_char=False, **kwargs):
        """Build the criterion.

        Args:
            only_char: When true, ``forward`` scores only the character head.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.only_char = only_char
        # A single criterion instance is shared by every head.
        self.ce = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)

    def _head_loss(self, feats, flat_target):
        """Merge the first two dims of ``feats`` and score against the target."""
        return self.ce(feats.flatten(0, 1), flat_target)

    def forward(self, pred, batch):
        """Return the loss dict for ``pred`` given the targets in ``batch``.

        Args:
            pred: The character-head output when ``only_char`` is set,
                otherwise a 3-item sequence handed on to ``forward_all``.
            batch: Indexable container; ``batch[1]`` holds the character
                targets.  Assumes predictions are (batch, seq, classes) and
                targets (batch, seq) -- TODO confirm with the data pipeline.

        Returns:
            ``{'loss': ...}`` in char-only mode, otherwise whatever
            ``forward_all`` returns.
        """
        if not self.only_char:
            return self.forward_all(pred, batch)

        target = batch[1].flatten(0, 1)
        return {'loss': self._head_loss(pred, target)}

    def forward_all(self, pred, batch):
        """Score all three heads and return the total and per-head losses.

        Args:
            pred: Sequence unpacked as ``(char_feats, dpe_feats, wp_feats)``.
            batch: Indexable container; items 1, 2 and 3 are the targets for
                the char, dpe and wp heads respectively.

        Returns:
            Dict with keys ``'loss'`` (sum of the three), ``'char_loss'``,
            ``'dpe_loss'`` and ``'wp_loss'``.
        """
        char_feats, dpe_feats, wp_feats = pred
        # Flatten every target up front, before any loss is evaluated.
        char_tgt, dpe_tgt, wp_tgt = (
            batch[idx].flatten(0, 1) for idx in (1, 2, 3)
        )

        char_loss = self._head_loss(char_feats, char_tgt)
        dpe_loss = self._head_loss(dpe_feats, dpe_tgt)
        wp_loss = self._head_loss(wp_feats, wp_tgt)

        losses = {'loss': char_loss + dpe_loss + wp_loss}
        losses['char_loss'] = char_loss
        losses['dpe_loss'] = dpe_loss
        losses['wp_loss'] = wp_loss
        return losses