# Commit 6a62ffb by TomatoCocotree: "上传" ("upload").
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq.models.nat import (
_apply_del_words,
_apply_ins_masks,
_apply_ins_words,
_fill,
_skip,
_skip_encoder_out,
)
class _EnsembleModelEncoder:
    """Encoder-like facade over an ensemble.

    Stores the ensemble's models and forwards ``reorder_encoder_out`` to each
    model's own ``encoder``, so callers can treat the ensemble like a single
    model with one encoder.
    """

    def __init__(self, models):
        self.models = models

    def reorder_encoder_out(self, encoder_outs, new_order):
        """Delegate ``reorder_encoder_out`` to every member's encoder.

        ``encoder_outs[i]`` is paired with ``self.models[i]``; pairing stops
        at the shorter of the two sequences (``zip`` semantics). Returns a new
        list holding each encoder's result in model order.
        """
        reordered = []
        for member, member_out in zip(self.models, encoder_outs):
            reordered.append(
                member.encoder.reorder_encoder_out(member_out, new_order)
            )
        return reordered
class BasicEnsembleModel(torch.nn.Module):
    """Base wrapper presenting a list of models as a single ensemble.

    The bos/eos/pad/unk indices are read from the first model's decoder
    dictionary. Decoding itself (``forward_decoder`` and
    ``initialize_output_tokens``) is left to subclasses.
    """

    def __init__(self, models):
        super().__init__()
        self.models = torch.nn.ModuleList(models)
        vocab = self.models[0].decoder.dictionary
        self.bos = vocab.bos()
        self.eos = vocab.eos()
        self.pad = vocab.pad()
        self.unk = vocab.unk()
        # Lets callers call ``self.encoder.reorder_encoder_out`` on the ensemble.
        self.encoder = _EnsembleModelEncoder(self.models)

    def has_encoder(self):
        """Return True when the first model exposes an ``encoder`` attribute."""
        first = self.models[0]
        return hasattr(first, "encoder")

    def max_decoder_positions(self):
        """Return the smallest ``max_decoder_positions()`` among the members."""
        limits = [member.max_decoder_positions() for member in self.models]
        return min(limits)

    @torch.no_grad()
    def forward_encoder(self, encoder_input):
        """Encode ``encoder_input`` with every member, without gradients.

        Returns a list with one encoder output per model, or None when the
        ensemble has no encoder (see ``has_encoder``).
        """
        if self.has_encoder():
            return [member.forward_encoder(encoder_input) for member in self.models]
        return None

    @torch.no_grad()
    def forward_decoder(self, *inputs):
        """Run one decoding step; must be provided by a subclass."""
        raise NotImplementedError

    def initialize_output_tokens(self, *inputs):
        """Build the initial decoder state; must be provided by a subclass."""
        raise NotImplementedError
class EnsembleLevT(BasicEnsembleModel):
    """Ensemble of Levenshtein Transformer (LevT) models.

    One decoding iteration is a pipeline of three steps: word deletion,
    placeholder insertion and word insertion. Each step consumes the tokens
    produced by the previous one, so the members' scores are averaged inside
    every step rather than once at the end.
    """

    def __init__(self, models):
        super().__init__(models)

    @torch.no_grad()
    def forward_decoder(
        self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=None, **kwargs
    ):
        """Run one ensembled LevT refinement iteration.

        Args:
            decoder_out: decoder state exposing ``output_tokens``,
                ``output_scores`` and ``attn`` and supporting ``_replace``
                (presumably fairseq's ``DecoderOut`` namedtuple).
            encoder_outs: one encoder output per model, as produced by
                ``forward_encoder``.
            eos_penalty: when positive, subtracted from the log-probability at
                index 0 of the placeholder-insertion scores of every model.
            max_ratio: when given, each sentence's output length is capped at
                ``max(10, src_len * max_ratio)``; otherwise the cap is 255.
            **kwargs: ignored; accepted for interface compatibility.

        Returns:
            ``decoder_out`` with the new tokens, scores and attention. These
            are truncated to the largest per-row count of non-pad tokens, and
            ``history`` is set to None.
        """
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        attn = decoder_out.attn
        max_lens = self._ensemble_max_lens(encoder_outs, output_tokens, max_ratio)

        # 1) Delete words. A row holding only <s> and </s> (<= 2 non-pad
        #    tokens) has nothing deletable, so it is skipped.
        can_del_word = output_tokens.ne(self.pad).sum(1) > 2
        if can_del_word.sum() != 0:
            output_tokens, output_scores, attn = self.forward_word_del(
                encoder_outs,
                output_tokens,
                output_scores,
                attn,
                can_del_word,
            )

        # 2) Insert placeholders into rows still shorter than their length cap.
        can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
        if can_ins_mask.sum() != 0:
            output_tokens, output_scores = self.forward_mask_ins(
                encoder_outs,
                output_tokens,
                output_scores,
                can_ins_mask,
                eos_penalty,
                max_lens,
            )

        # 3) Replace placeholder (<unk>) tokens with predicted words.
        can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
        if can_ins_word.sum() != 0:
            output_tokens, output_scores, attn = self.forward_word_ins(
                encoder_outs,
                output_tokens,
                output_scores,
                attn,
                can_ins_word,
            )

        # Drop columns beyond the longest row (measured in non-pad tokens).
        cut_off = output_tokens.ne(self.pad).sum(1).max()
        output_tokens = output_tokens[:, :cut_off]
        output_scores = output_scores[:, :cut_off]
        attn = None if attn is None else attn[:, :cut_off, :]

        return decoder_out._replace(
            output_tokens=output_tokens,
            output_scores=output_scores,
            attn=attn,
            history=None,
        )

    def _ensemble_max_lens(self, encoder_outs, output_tokens, max_ratio):
        """Return a (bsz,) long tensor holding each sentence's length cap.

        The source lengths are taken from the first model's encoder output.
        """
        bsz = output_tokens.size(0)
        if max_ratio is None:
            # One entry per sentence; an empty ``new()`` tensor can neither be
            # compared row-wise against the batch nor indexed by a row mask.
            return output_tokens.new(bsz).fill_(255)

        first_out = encoder_outs[0]
        if not first_out["encoder_padding_mask"]:
            # No padding, so every source has the full length. Assumes the
            # fairseq encoder layout T x B x C: the source length is dim 0
            # (dim 1 is the batch size).
            enc = first_out["encoder_out"][0]
            src_lens = enc.new(bsz).fill_(enc.size(0))
        else:
            # Presumably a B x T mask that is True at padding positions.
            src_lens = (~first_out["encoder_padding_mask"][0]).sum(1)
        return (src_lens * max_ratio).clamp(min=10).long()

    def _average_log_probs(self, log_probs):
        """Average the members' log-probabilities in probability space.

        Computes ``log(sum_i exp(log_probs[i]) / num_models)`` stably via
        logsumexp over a stacked leading model axis.
        """
        stacked = torch.stack(log_probs, dim=0)
        return torch.logsumexp(stacked, dim=0) - math.log(len(self.models))

    @staticmethod
    def _average_attn(attns):
        """Element-wise mean of the members' attention tensors.

        Returns None when the first member returned no attention.
        """
        if attns[0] is None:
            return None
        # Reduce over the model axis so the result has a single member's shape.
        return torch.stack(attns, dim=0).mean(dim=0)

    def forward_word_del(
        self, encoder_outs, output_tokens, output_scores, attn, can_del_word
    ):
        """Ensembled word deletion on the rows selected by ``can_del_word``.

        The selected rows are scored by every member and the scores averaged.
        The averaged argmax is applied with ``_apply_del_words``, and the
        results are written back into the selected rows with ``_fill``.

        Returns:
            Updated ``(output_tokens, output_scores, attn)``.
        """
        scores, attns = [], []
        for model, encoder_out in zip(self.models, encoder_outs):
            word_del_out, word_del_attn = model.decoder.forward_word_del(
                _skip(output_tokens, can_del_word),
                _skip_encoder_out(model.encoder, encoder_out, can_del_word),
            )
            scores.append(F.log_softmax(word_del_out, 2))
            attns.append(word_del_attn)

        word_del_score_avg = self._average_log_probs(scores)
        word_del_pred = word_del_score_avg.max(-1)[1].bool()
        word_del_attn_avg = self._average_attn(attns)

        _tokens, _scores, _attn = _apply_del_words(
            output_tokens[can_del_word],
            output_scores[can_del_word],
            word_del_attn_avg,
            word_del_pred,
            self.pad,
            self.bos,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_del_word, _scores, 0)
        attn = _fill(attn, can_del_word, _attn, 0.0)
        return output_tokens, output_scores, attn

    def forward_mask_ins(
        self,
        encoder_outs,
        output_tokens,
        output_scores,
        can_ins_mask,
        eos_penalty,
        max_lens,
    ):
        """Ensembled placeholder insertion on the rows selected by ``can_ins_mask``.

        Each member's log-probabilities optionally have ``eos_penalty``
        subtracted at index 0. The scores are then averaged, and the argmax is
        capped at the row's ``max_lens`` entry before ``_apply_ins_masks``
        inserts ``<unk>`` placeholders.

        Returns:
            Updated ``(output_tokens, output_scores)``.
        """
        scores = []
        for model, encoder_out in zip(self.models, encoder_outs):
            mask_ins_out, _ = model.decoder.forward_mask_ins(
                _skip(output_tokens, can_ins_mask),
                _skip_encoder_out(model.encoder, encoder_out, can_ins_mask),
            )
            mask_ins_score = F.log_softmax(mask_ins_out, 2)
            if eos_penalty > 0.0:
                mask_ins_score[:, :, 0] -= eos_penalty
            scores.append(mask_ins_score)

        mask_ins_score_avg = self._average_log_probs(scores)
        mask_ins_pred = mask_ins_score_avg.max(-1)[1]
        # Never predict more than the row's length cap at any single slot.
        mask_ins_pred = torch.min(
            mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
        )

        _tokens, _scores = _apply_ins_masks(
            output_tokens[can_ins_mask],
            output_scores[can_ins_mask],
            mask_ins_pred,
            self.pad,
            self.unk,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
        return output_tokens, output_scores

    def forward_word_ins(
        self, encoder_outs, output_tokens, output_scores, attn, can_ins_word
    ):
        """Ensembled word insertion on the rows selected by ``can_ins_word``.

        ``<unk>`` placeholders are filled with the argmax of the averaged
        word distribution, and that argmax's averaged log-probability becomes
        the score. The averaged attention is written back as well.

        Returns:
            Updated ``(output_tokens, output_scores, attn)``.
        """
        scores, attns = [], []
        for model, encoder_out in zip(self.models, encoder_outs):
            word_ins_out, word_ins_attn = model.decoder.forward_word_ins(
                _skip(output_tokens, can_ins_word),
                _skip_encoder_out(model.encoder, encoder_out, can_ins_word),
            )
            scores.append(F.log_softmax(word_ins_out, 2))
            attns.append(word_ins_attn)

        word_ins_score_avg = self._average_log_probs(scores)
        word_ins_attn_avg = self._average_attn(attns)
        word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1)

        _tokens, _scores = _apply_ins_words(
            output_tokens[can_ins_word],
            output_scores[can_ins_word],
            word_ins_pred,
            word_ins_score_max,
            self.unk,
        )
        output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_word, _scores, 0)
        # Use the ensemble-averaged attention, not just the last member's.
        attn = _fill(attn, can_ins_word, word_ins_attn_avg, 0.0)
        return output_tokens, output_scores, attn

    def initialize_output_tokens(self, encoder_outs, src_tokens):
        """Build the initial decoder state using the first model only.

        LevT does no length prediction, so there is nothing to ensemble here.
        """
        return self.models[0].initialize_output_tokens(encoder_outs[0], src_tokens)