""" |
Loss computation: :class:`LossCompute`, which wraps the token-level
criterion (cross entropy, sparsemax or copy-generator loss) and adds the
optional coverage, alignment and LM-prior loss terms.
""" |
import torch |
import torch.nn as nn |
import torch.nn.functional as F |
import onmt |
from onmt.modules.sparse_losses import SparsemaxLoss |
from onmt.modules.sparse_activations import LogSparsemax |
from onmt.constants import ModelTask, DefaultTokens |
from onmt.modules.copy_generator import collapse_copy_scores |
from onmt.model_builder import load_test_model |
try: |
import ctranslate2 |
except ImportError: |
pass |
class LossCompute(nn.Module):
    """Compute the loss of decoder outputs and the matching statistics.

    Wraps a token-level criterion and adds the optional auxiliary terms:
    coverage loss, guided-alignment loss and an LM-prior distillation term.
    :meth:`forward` returns the summed loss together with an
    :obj:`onmt.utils.Statistics` instance for logging.

    Args:
        criterion (nn.Module): token-level loss exposing ``ignore_index``
            (``nn.CrossEntropyLoss``, ``SparsemaxLoss`` or
            ``CopyGeneratorLoss`` as built by :meth:`from_opts`).
        generator (nn.Module): module that maps the output of the decoder to
            scores over the target vocabulary.
        copy_attn (bool): whether the copy attention mechanism is on.
        lambda_coverage (float): weight of the coverage loss (0 disables it).
        lambda_align (float): weight of the alignment loss (0 disables it).
        tgt_shift_index (int): 1 for NMT, 0 for LM.
        vocab: target vocab (token -> index lookup, also used for copy
            attention score calculation).
        lm_generator (:obj:`ctranslate2.Generator`): CTranslate2 LM prior.
        lm_prior_lambda (float): weight of the LM-prior term in the loss.
        lm_prior_tau (float): temperature applied to both distributions of
            the LM-prior term.
        lm_prior_model (nn.Module): OpenNMT LM prior, used when the prior is
            given as a ``.pt`` checkpoint.
    """

    def __init__(
        self,
        criterion,
        generator,
        copy_attn=False,
        lambda_coverage=0.0,
        lambda_align=0.0,
        tgt_shift_index=1,
        vocab=None,
        lm_generator=None,
        lm_prior_lambda=None,
        lm_prior_tau=None,
        lm_prior_model=None,
    ):
        super(LossCompute, self).__init__()
        self.criterion = criterion
        self.generator = generator
        self.lambda_coverage = lambda_coverage
        self.lambda_align = lambda_align
        self.tgt_shift_index = tgt_shift_index
        self.copy_attn = copy_attn
        self.vocab = vocab
        self.lm_generator = lm_generator
        self.lm_prior_lambda = lm_prior_lambda
        self.lm_prior_tau = lm_prior_tau
        self.lm_prior_model = lm_prior_model

    @classmethod
    def from_opts(cls, opt, model, vocab, train=True):
        """Build a :class:`LossCompute` instance from options.

        The criterion is chosen from the options (copy-generator loss,
        sparsemax loss or label-smoothed cross entropy). An LM prior is
        optionally loaded, either from a ``.pt`` checkpoint or as a
        CTranslate2 model. The result is moved to the training device.

        Args:
            opt: parsed options. Not modified by this method.
            model: the model whose ``generator`` scores decoder outputs.
            vocab: target vocabulary (token -> index lookup).
            train (bool): accepted for API compatibility; unused here.

        Returns:
            LossCompute: the configured loss module.

        Raises:
            AssertionError: if ``--lambda_coverage`` is non-zero while
                ``--coverage_attn`` is not set.
            ImportError: if a CTranslate2 LM prior is requested but
                ``ctranslate2`` cannot be imported.
        """
        device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu")

        padding_idx = vocab[DefaultTokens.PAD]
        unk_idx = vocab[DefaultTokens.UNK]

        if opt.lambda_coverage != 0:
            assert opt.coverage_attn, (
                "--coverage_attn needs to be set in "
                "order to use --lambda_coverage != 0"
            )

        tgt_shift_idx = 1 if opt.model_task == ModelTask.SEQ2SEQ else 0

        if opt.copy_attn:
            criterion = onmt.modules.CopyGeneratorLoss(
                len(vocab),
                opt.copy_attn_force,
                unk_index=unk_idx,
                ignore_index=padding_idx,
            )
        elif opt.generator_function == "sparsemax":
            criterion = SparsemaxLoss(ignore_index=padding_idx, reduction="sum")
        else:
            criterion = nn.CrossEntropyLoss(
                ignore_index=padding_idx,
                reduction="sum",
                label_smoothing=opt.label_smoothing,
            )

        lm_generator = None
        lm_prior_model = None
        if opt.lm_prior_model:
            if opt.lm_prior_model.endswith(".pt"):
                import copy

                # Apply the inference settings needed to load the prior on a
                # shallow copy, so the caller's options are left untouched.
                lm_opt = copy.copy(opt)
                lm_opt.gpu = 0
                lm_opt.fp32 = False
                lm_opt.int8 = False
                _, lm_prior_model, _ = load_test_model(
                    lm_opt, model_path=opt.lm_prior_model
                )
                lm_prior_model.to(torch.device("cuda", lm_opt.gpu))
                lm_prior_model.eval()
            else:
                try:
                    import ctranslate2
                except ImportError as err:
                    raise ImportError("Could not import ctranslate2") from err
                lm_generator = ctranslate2.Generator(
                    opt.lm_prior_model, device="cuda", compute_type="float16"
                )

        compute = cls(
            criterion,
            model.generator,
            copy_attn=opt.copy_attn,
            lambda_coverage=opt.lambda_coverage,
            lambda_align=opt.lambda_align,
            tgt_shift_index=tgt_shift_idx,
            vocab=vocab,
            lm_generator=lm_generator,
            lm_prior_lambda=opt.lm_prior_lambda,
            lm_prior_tau=opt.lm_prior_tau,
            lm_prior_model=lm_prior_model,
        )
        compute.to(device)
        return compute

    @property
    def padding_idx(self):
        """Target padding index, taken from the criterion's ``ignore_index``."""
        return self.criterion.ignore_index

    def _compute_coverage_loss(self, std_attn, cov_attn, tgt):
        """Compute the weighted coverage loss.

        For each target step ``t`` the per-step loss is
        ``sum_s min(std_attn[:, t, s], cov_attn[:, t - 1, s])``. The
        coverage before the first step is zero. Steps whose target is
        padding contribute nothing. The sum is scaled by ``lambda_coverage``.

        Args:
            std_attn: attention ``(batch, tgt_len, src_len)``.
            cov_attn: coverage ``(batch, tgt_len, src_len)``. Presumably step
                ``t`` already includes the attention of step ``t``, hence the
                one-step shift (decoder not visible here; confirm).
            tgt: flattened target indices ``(batch * tgt_len,)`` in
                batch-major order.

        Returns:
            Scalar tensor with the weighted coverage loss.
        """
        # Shift along the time axis (dim 1), not the batch axis, so each step
        # is compared with the coverage of the *same* example at the previous
        # step.
        first_step = cov_attn.new_zeros(cov_attn.size(0), 1, cov_attn.size(2))
        prev_cov = torch.cat((first_step, cov_attn[:, :-1]), dim=1)
        covloss = torch.min(std_attn, prev_cov).sum(dim=-1).view(-1)
        covloss = covloss.masked_fill(tgt == self.padding_idx, 0.0)
        return covloss.sum() * self.lambda_coverage

    def _compute_alignement_loss(self, align_head, ref_align):
        """Compute the weighted loss between 2 partial alignment matrices.

        Returns ``-lambda_align * sum(ref_align * log(align_head))``. The
        attention probabilities are clamped to 1e-18 to avoid ``log(0)``.
        """
        align_loss = -align_head.clamp(min=1e-18).log().mul(ref_align).sum()
        align_loss *= self.lambda_align
        return align_loss

    def _compute_copy_loss(self, batch, output, target, align, attns):
        """Compute the copy attention loss.

        Args:
            batch: the current batch (``batch["src_map"]`` is read).
            output: decoder output ``(batch, tgt_len, hidden)``.
            target: flattened gold target indices.
            align: flattened source-alignment indices of the targets.
            attns: dictionary of attention distributions; ``attns["copy"]``
                is ``(batch, tgt_len, src_len)``.

        Returns:
            A tuple with the summed loss and the raw scores.
        """
        scores = self.generator(
            self._bottle(output), self._bottle(attns["copy"]), batch["src_map"]
        )
        loss = self.criterion(scores, align, target).sum()
        return loss, scores

    def _lm_prior_src(self, target):
        """Build the LM-prior input tokens from the full target batch.

        EOS tokens are replaced by padding and the last step is dropped. The
        LM scores must line up row-for-row with the decoder outputs in
        :meth:`_lm_prior_kl`.

        Returns:
            ``(src, src_len)``: ``src`` keeps the feature axis of ``target``;
            ``src_len`` is the number of non-padding tokens per example.
        """
        src = target.detach().clone()
        src[src == self.vocab[DefaultTokens.EOS]] = self.padding_idx
        src = src[:, :-1, :]
        src_len = src[:, :, 0].ne(self.padding_idx).sum(1)
        return src, src_len

    def _lm_prior_kl(self, output, lm_logits, src):
        """KL(LM || MT) term shared by both LM-prior back-ends.

        Args:
            output: decoder output ``(batch, len, hidden)``.
            lm_logits: 2-D LM logits, one row per bottled decoder position.
            src: LM-prior input tokens from :meth:`_lm_prior_src`.

        Returns:
            Scalar tensor: the temperature-scaled KL summed over non-padding
            positions and multiplied by ``lm_prior_tau ** 2``.
        """
        scores = self.generator(self._bottle(output)) / self.lm_prior_tau
        scores = F.log_softmax(scores.to(torch.float32), dim=-1)
        lm_scores = F.log_softmax(
            (lm_logits / self.lm_prior_tau).to(torch.float32), dim=-1
        )
        # Discourage the prior from pushing mass onto UNK and (less) EOS.
        lm_scores[:, self.vocab[DefaultTokens.UNK]] = -50
        lm_scores[:, self.vocab[DefaultTokens.EOS]] -= 20
        # Both inputs are log-probabilities, hence log_target=True.
        lm_loss = F.kl_div(scores, lm_scores, reduction="none", log_target=True).sum(-1)
        # Mask on the token ids fed to the prior; the previous mask compared
        # float decoder states to the padding id and kept every position.
        non_padding = src[:, :, 0].ne(self.padding_idx).reshape(-1)
        lm_loss = lm_loss.masked_select(non_padding).sum()
        return lm_loss * (self.lm_prior_tau**2)

    def _compute_lm_loss_ct2(self, output, target):
        """
        Compute the loss between MT output and a CTranslate2 LM output.

        https://github.com/cbaziotis/lm-prior-for-nmt/blob/master
        /fairseq_extension/user/lm_prior/lm_prior.py#L131-L133
        """
        src, src_len = self._lm_prior_src(target)
        # CTranslate2 takes the padded token ids plus per-example lengths.
        lm_logits = self.lm_generator.forward_batch(
            ctranslate2.StorageView.from_array(src[:, :, 0].to(torch.int32)),
            ctranslate2.StorageView.from_array(src_len.to(torch.int32)),
            return_log_probs=False,
        )
        lm_logits = self._bottle(torch.as_tensor(lm_logits, device=output.device))
        return self._lm_prior_kl(output, lm_logits, src)

    def _compute_lm_loss(self, output, target):
        """
        Compute the loss between MT output and an OpenNMT LM output.

        https://github.com/cbaziotis/lm-prior-for-nmt/blob/master
        /fairseq_extension/user/lm_prior/lm_prior.py#L131-L133
        """
        src, src_len = self._lm_prior_src(target)
        # The prior's logits are only a fixed target (the original code
        # detached them), so skip building an autograd graph for them.
        with torch.no_grad():
            lm_outs, _ = self.lm_prior_model(src, None, src_len, with_align=False)
            lm_logits = self.lm_prior_model.generator(self._bottle(lm_outs))
        return self._lm_prior_kl(output, lm_logits, src)

    def _bottle(self, _v):
        """Flatten a 3-D tensor ``(a, b, dim)`` into ``(a * b, dim)``."""
        return _v.view(-1, _v.size(2))

    def _unbottle(self, _v, batch_size):
        """Reshape ``(n, dim)`` into ``(-1, batch_size, dim)``."""
        # NOTE(review): this gives a time-major layout while `_bottle` is
        # applied to batch-major tensors; confirm it matches what
        # `collapse_copy_scores` expects.
        return _v.view(-1, batch_size, _v.size(1))

    def ignore_prompt(self, batch):
        """
        Mask the prompt on the target side so that its loss is zero.

        For finetuning on specific tasks. The end of the prompt must be
        indicated by the ``DefaultTokens.MASK_BEFORE`` placeholder. This
        method treats the first padding id found in ``batch["src"]`` as that
        marker. Presumably the placeholder is mapped to the padding id
        upstream; confirm.

        Target positions before the marker are set to the padding index,
        which the loss criterion is expected to ignore (e.g.
        ``nn.CrossEntropyLoss(ignore_index=...)``).

        Args:
            batch: The current batch; ``batch["tgt"]`` is modified in place.

        Returns:
            The same ``batch`` object.
        """
        # keep == 0 before the first padding position of src, 1 from there on.
        # Clamping keeps the mask binary: a raw cumsum keeps growing at every
        # further padding position and would multiply real token ids.
        is_pad = batch["src"].squeeze(dim=2) == self.padding_idx
        keep = torch.cumsum(is_pad.int(), dim=1).clamp(max=1).unsqueeze(-1)
        # tgt * keep + pad * (1 - keep): keep the answer, pad the prompt.
        batch["tgt"] *= keep
        batch["tgt"] += self.padding_idx * (1 - keep)
        return batch

    def forward(self, batch, output, attns, trunc_start=0, trunc_size=None):
        """Compute the loss of a batch, optionally on a truncated window.

        Supports truncated BPTT for long sequences by taking a range in the
        decoder output sequence to back-propagate in. The range is
        ``(trunc_start, trunc_start + trunc_size)``. Truncation is an
        approximate efficiency trick to relieve the memory required in the
        RNN buffers.

        Args:
            batch (dict): batch of labeled examples. Reads ``"tgt"`` and
                ``"srclen"``. Depending on the options, also reads ``"src"``,
                ``"alignment"``, ``"src_map"`` and ``"align"``.
            output (FloatTensor): decoder output ``(batch, tgt_len, hidden)``.
            attns (dict): attention tensors ``(batch, tgt_len, src_len)``.
            trunc_start (int): starting position of the truncation window.
            trunc_size (int): length of the truncation window; defaults to
                the rest of the target.

        Returns:
            A tuple with the loss tensor and an :obj:`onmt.utils.Statistics`
            instance.
        """
        if trunc_size is None:
            trunc_size = batch["tgt"].size(1) - trunc_start
        # Targets are offset by tgt_shift_index from the outputs (with 1, the
        # output at step t is scored against tgt[t + 1]).
        trunc_range = (trunc_start + self.tgt_shift_index, trunc_start + trunc_size)
        target = batch["tgt"][:, trunc_range[0] : trunc_range[1], :]
        output = output[:, trunc_start : trunc_range[1], :].contiguous()
        flat_tgt = target[:, :, 0].contiguous().view(-1)

        if self.copy_attn:
            align = (
                batch["alignment"][:, trunc_range[0] : trunc_range[1]]
                .contiguous()
                .view(-1)
            )
            loss, scores = self._compute_copy_loss(
                batch, output, flat_tgt, align, attns
            )
            scores_data = collapse_copy_scores(
                self._unbottle(scores.clone(), len(batch["srclen"])),
                batch,
                self.vocab,
                None,
            )
            scores_data = self._bottle(scores_data)
            # UNK targets that have a source alignment are moved past the
            # vocabulary (+ len(vocab) + align) so they index the copy part
            # of the extended scores; presumably relies on the UNK id being 0.
            target_data = flat_tgt.clone()
            unk = self.criterion.unk_index
            correct_mask = (target_data == unk) & (align != unk)
            offset_align = align[correct_mask] + len(self.vocab)
            target_data[correct_mask] += offset_align
            scores = scores_data
            flat_tgt = target_data
        else:
            scores = self.generator(self._bottle(output))
            # Every criterion, SparsemaxLoss included, receives the raw
            # generator scores. The old ``LogSparsemax(scores, dim=-1)``
            # constructed the module class instead of applying it.
            # SparsemaxLoss presumably applies sparsemax internally; confirm.
            loss = self.criterion(scores.to(torch.float32), flat_tgt)

        if self.lambda_align != 0.0:
            align_head = attns["align"]
            if align_head.dtype != loss.dtype:
                align_head = align_head.to(loss.dtype)
            align_idx = batch["align"]
            batch_size, pad_tgt_size, _ = batch["tgt"].size()
            _, pad_src_size, _ = batch["src"].size()
            align_matrix_size = [batch_size, pad_tgt_size, pad_src_size]
            ref_align = onmt.utils.make_batch_align_matrix(
                align_idx, align_matrix_size, normalize=True
            )
            ref_align = ref_align[:, trunc_range[0] : trunc_range[1], :]
            if ref_align.dtype != loss.dtype:
                ref_align = ref_align.to(loss.dtype)
            align_loss = self._compute_alignement_loss(
                align_head=align_head, ref_align=ref_align
            )
            loss += align_loss

        if self.lambda_coverage != 0.0:
            coverage_loss = self._compute_coverage_loss(
                attns["std"], attns["coverage"], flat_tgt
            )
            loss += coverage_loss

        # The LM-prior terms use the full (untruncated) target batch.
        if self.lm_generator is not None:
            lm_loss = self._compute_lm_loss_ct2(output, batch["tgt"])
            loss = loss + lm_loss * self.lm_prior_lambda
        if self.lm_prior_model is not None:
            lm_loss = self._compute_lm_loss(output, batch["tgt"])
            loss = loss + lm_loss * self.lm_prior_lambda

        # Sentences are only counted once, in the first truncation window.
        n_sents = len(batch["srclen"]) if trunc_start == 0 else 0
        stats = self._stats(n_sents, loss.sum().item(), scores, flat_tgt)
        return loss, stats

    def _stats(self, bsz, loss, scores, target):
        """Build the statistics for one (possibly truncated) batch.

        Args:
            bsz (int): number of sentences to credit (0 for truncation
                windows after the first).
            loss (float): the summed loss value.
            scores (FloatTensor): ``(n_tokens, n_classes)`` scores.
            target (LongTensor): ``(n_tokens,)`` gold indices.

        Returns:
            :obj:`onmt.utils.Statistics` : statistics for this batch.
        """
        pred = scores.max(1)[1]
        non_padding = target.ne(self.padding_idx)
        num_correct = pred.eq(target).masked_select(non_padding).sum().item()
        num_non_padding = non_padding.sum().item()
        n_batchs = 1 if bsz else 0
        return onmt.utils.Statistics(
            loss=loss,
            n_batchs=n_batchs,
            n_sents=bsz,
            n_words=num_non_padding,
            n_correct=num_correct,
        )