File size: 4,509 Bytes
5e9bd47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from .tokenizer import PAD_ID, MASK, MASK_ID
class LabelSmoothingLoss(nn.Module):
    """Label-smoothed classification loss, computed as a KL divergence.

    The smoothed target distribution q puts ``1 - label_smoothing`` on the gold
    class. It spreads ``label_smoothing`` uniformly over every other class
    except ``ignore_index``, when that is a real class index. The loss is
    ``KL(q || p)``, where ``p = softmax(output)``. It is summed over classes
    and divided by the number of rows (``reduction='batchmean'``), including
    ignored rows.
    """

    def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
        """
        :param label_smoothing: probability mass moved off the gold class; must be in (0, 1].
        :param tgt_vocab_size: number of classes (size of the last dim of ``output``).
        :param ignore_index: target value whose rows contribute no loss. If it is a
            valid class index (0 <= ignore_index < tgt_vocab_size), that class also
            receives no smoothing mass.
        :raises ValueError: if ``label_smoothing`` is outside (0, 1], or the
            vocabulary leaves no class to receive smoothing mass.
        """
        super(LabelSmoothingLoss, self).__init__()
        if not 0.0 < label_smoothing <= 1.0:
            raise ValueError(
                f"label_smoothing must be in (0, 1], got {label_smoothing}")
        self.ignore_index = ignore_index
        # A negative or out-of-range ignore_index (e.g. the default -100) is not a
        # class, so it must not zero a vocabulary slot.
        ignore_is_class = 0 <= ignore_index < tgt_vocab_size
        # Classes that do not share the smoothing mass: the gold class (it gets
        # ``confidence``), plus the ignore class when it is a real index.
        n_smoothed = tgt_vocab_size - (2 if ignore_is_class else 1)
        if n_smoothed <= 0:
            raise ValueError(
                f"tgt_vocab_size={tgt_vocab_size} leaves no class to smooth over")
        one_hot = torch.full((tgt_vocab_size,), label_smoothing / n_smoothed)
        if ignore_is_class:
            one_hot[ignore_index] = 0
        # Stored as a buffer so it follows the module across .to(device) calls.
        self.register_buffer('one_hot', one_hot.unsqueeze(0))
        self.confidence = 1.0 - label_smoothing

    def forward(self, output, target):
        """
        :param output: raw logits, FloatTensor [batch_size, n_classes]
        :param target: class indices, LongTensor [batch_size]
        :return: scalar KL-divergence loss
        """
        log_probs = F.log_softmax(output, dim=-1)
        ignored = target == self.ignore_index
        # scatter_ needs in-range indices. Ignored rows are zeroed right after,
        # so any valid index (0) works as a placeholder for them.
        safe_target = target.masked_fill(ignored, 0)
        model_prob = self.one_hot.repeat(target.size(0), 1)
        model_prob.scatter_(1, safe_target.unsqueeze(1), self.confidence)
        model_prob.masked_fill_(ignored.unsqueeze(1), 0)
        return F.kl_div(log_probs, model_prob, reduction='batchmean')
class SequenceLoss(nn.Module):
    """Token-level loss over a batch of sequences.

    Flattens ``[..., vocab]`` logits and the matching targets, then applies
    either ``nn.CrossEntropyLoss`` (when ``label_smoothing == 0``) or
    ``LabelSmoothingLoss``. Every id in ``ignore_indices`` is remapped to a
    single ignore id before the criterion runs, so all of them contribute no
    loss.
    """

    def __init__(self, label_smoothing, vocab_size, ignore_index=-100, ignore_indices=None):
        """
        :param label_smoothing: 0 selects plain cross-entropy; any other value selects LabelSmoothingLoss.
        :param vocab_size: number of classes (used only by LabelSmoothingLoss).
        :param ignore_index: target id that contributes no loss.
        :param ignore_indices: optional sequence of ids to ignore. When non-empty,
            its first element overrides ``ignore_index``.
        """
        super(SequenceLoss, self).__init__()
        # Copy into a fresh list so no caller-owned or shared default list is retained.
        ignore_indices = list(ignore_indices) if ignore_indices else []
        if ignore_indices:
            ignore_index = ignore_indices[0]
        self.ignore_index = ignore_index
        self.ignore_indices = ignore_indices
        if label_smoothing == 0:
            self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='mean')
        else:
            self.criterion = LabelSmoothingLoss(label_smoothing, vocab_size, ignore_index)

    def forward(self, output, target):
        """
        :param output: [batch, len, vocab] logits
        :param target: [batch, len] token ids (not modified)
        :return: scalar loss from the underlying criterion
        """
        vocab_size = output.size(-1)
        output = output.reshape(-1, vocab_size)
        target = target.reshape(-1)
        for idx in self.ignore_indices:
            if idx != self.ignore_index:
                # Out-of-place on purpose: reshape() usually returns a view of the
                # caller's tensor, so an in-place fill would rewrite their targets.
                target = target.masked_fill(target == idx, self.ignore_index)
        return self.criterion(output, target)
class GraphLoss(nn.Module):
    """Graph losses: masked L1 on coordinates and class-weighted CE on edge labels."""

    def __init__(self, num_edge_classes=7, edge_weight=10):
        """
        :param num_edge_classes: number of edge label classes (default 7).
        :param edge_weight: cross-entropy weight for every class except class 0,
            which keeps weight 1. Presumably this offsets class 0 being the
            dominant "no edge" label — TODO confirm.
        """
        super(GraphLoss, self).__init__()
        weight = torch.ones(num_edge_classes) * edge_weight
        weight[0] = 1
        self.criterion = nn.CrossEntropyLoss(weight, ignore_index=-100)

    def forward(self, outputs, targets):
        """
        :param outputs: dict that may contain 'coords' and/or 'edges' predictions.
        :param targets: dict with the matching 'coords' / 'edges' ground truth. It
            is cropped to the prediction length along dim 1 (and dim 2 for edges).
        :return: dict with a scalar loss per key present in ``outputs``.
        """
        results = {}
        if 'coords' in outputs:
            pred = outputs['coords']
            max_len = pred.size(1)
            target = targets['coords'][:, :max_len]
            # Negative target entries mark positions excluded from the loss.
            mask = target.ge(0)
            loss = F.l1_loss(pred, target, reduction='none')
            # clamp(min=1): when no entry is valid the masked sum is 0, so
            # return 0 rather than a NaN from 0/0.
            results['coords'] = (loss * mask).sum() / mask.sum().clamp(min=1)
        if 'edges' in outputs:
            pred = outputs['edges']
            # Presumably pred is [batch, n_classes, len, len] logits, as
            # CrossEntropyLoss expects — TODO confirm against the decoder.
            max_len = pred.size(-1)
            target = targets['edges'][:, :max_len, :max_len]
            results['edges'] = self.criterion(pred, target)
        return results
class Criterion(nn.Module):
    """Holds one loss module per output format and evaluates all of them.

    The 'edges' format uses GraphLoss. Every other format uses a SequenceLoss
    that ignores PAD_ID, plus MASK_ID when that format's vocabulary contains
    the MASK token.
    """

    def __init__(self, args, tokenizer):
        """
        :param args: object providing ``formats`` (iterable of format names) and ``label_smoothing``.
        :param tokenizer: mapping from format name to a tokenizer exposing ``stoi`` and ``__len__``.
        """
        super(Criterion, self).__init__()
        criterion = {}
        for format_ in args.formats:
            if format_ == 'edges':
                criterion['edges'] = GraphLoss()
                continue
            if MASK in tokenizer[format_].stoi:
                ignore_indices = [PAD_ID, MASK_ID]
            else:
                ignore_indices = []
            criterion[format_] = SequenceLoss(args.label_smoothing, len(tokenizer[format_]),
                                              ignore_index=PAD_ID, ignore_indices=ignore_indices)
        self.criterion = nn.ModuleDict(criterion)

    def forward(self, results, refs):
        """
        :param results: dict mapping format name to a tuple whose first two items
            are (predictions, targets); any further items are ignored.
        :param refs: unused here; kept for interface compatibility.
        :return: dict of scalar losses. A criterion that returns a dict has its
            entries merged in under their own keys (a later entry overwrites an
            earlier one with the same key). Any other loss is stored under its
            format name and averaged when it has more than one element.
        """
        losses = {}
        for format_, (predictions, targets, *_) in results.items():
            loss_ = self.criterion[format_](predictions, targets)
            # isinstance (not ``type(...) is dict``) so dict subclasses such as
            # OrderedDict are merged instead of reaching ``.numel()``.
            if isinstance(loss_, dict):
                losses.update(loss_)
            elif loss_.numel() > 1:
                losses[format_] = loss_.mean()
            else:
                losses[format_] = loss_
        return losses
|