# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from fairseq.models.nat import (
    _apply_del_words,
    _apply_ins_masks,
    _apply_ins_words,
    _fill,
    _skip,
    _skip_encoder_out,
)


class _EnsembleModelEncoder(object):
    """Encoder facade that fans ``reorder_encoder_out`` out to every model."""

    def __init__(self, models):
        self.models = models

    def reorder_encoder_out(self, encoder_outs, new_order):
        """Reorder each encoder output using the matching model's encoder.

        Args:
            encoder_outs: one encoder output per model, in model order.
            new_order: forwarded unchanged to each model's
                ``encoder.reorder_encoder_out``.

        Returns:
            A list with one reordered encoder output per model.
        """
        encoder_outs = [
            model.encoder.reorder_encoder_out(encoder_out, new_order)
            for model, encoder_out in zip(self.models, encoder_outs)
        ]
        return encoder_outs


class BasicEnsembleModel(torch.nn.Module):
    """A wrapper around an ensemble of models.

    Special symbol indices (bos/eos/pad/unk) are read from the decoder
    dictionary of the first model.
    """

    def __init__(self, models):
        super().__init__()
        self.models = torch.nn.ModuleList(models)
        self.bos = self.models[0].decoder.dictionary.bos()
        self.eos = self.models[0].decoder.dictionary.eos()
        self.pad = self.models[0].decoder.dictionary.pad()
        self.unk = self.models[0].decoder.dictionary.unk()
        self.encoder = _EnsembleModelEncoder(self.models)

    def has_encoder(self):
        """Return True if the first model exposes an ``encoder`` attribute."""
        return hasattr(self.models[0], "encoder")

    def max_decoder_positions(self):
        """Return the smallest ``max_decoder_positions()`` over all models."""
        return min(m.max_decoder_positions() for m in self.models)

    @torch.no_grad()
    def forward_encoder(self, encoder_input):
        """Run every model's ``forward_encoder`` on the same input.

        Returns:
            A list with one encoder output per model, or ``None`` when the
            models have no encoder.
        """
        if not self.has_encoder():
            return None
        return [model.forward_encoder(encoder_input) for model in self.models]

    @torch.no_grad()
    def forward_decoder(self, *inputs):
        """Decode one step; must be implemented by subclasses."""
        raise NotImplementedError

    def initialize_output_tokens(self, *inputs):
        """Build the initial decoder state; must be implemented by subclasses."""
        raise NotImplementedError


class EnsembleLevT(BasicEnsembleModel):
    """Ensemble wrapper that decodes with several LevT-style models at once.

    Each refinement step (word deletion, placeholder insertion, word
    insertion) averages the member models' probabilities before the
    prediction for that step is made.
    """

    def __init__(self, models):
        super().__init__(models)

    @torch.no_grad()
    def forward_decoder(
        self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=None, **kwargs
    ):
        """Run one deletion / placeholder-insertion / word-insertion pass.

        Args:
            decoder_out: current decoder state; must provide
                ``output_tokens``, ``output_scores``, ``attn`` and
                ``_replace``.
            encoder_outs: one encoder output per model.
            eos_penalty: forwarded to :meth:`forward_mask_ins`.
            max_ratio: if given, each sentence is capped at
                ``max(10, src_len * max_ratio)`` tokens; otherwise the cap
                is 255 for every sentence.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            ``decoder_out`` with updated tokens, scores and attention, and
            ``history`` reset to ``None``.
        """
        # LevT ensembling
        # A pipeline of three steps: deletion, placeholder, and word insertion.
        # We need to average scores in each step in a pipeline way because of
        # dependence.
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        attn = decoder_out.attn

        bsz = output_tokens.size(0)
        if max_ratio is None:
            # One cap per sentence. This must be a length-``bsz`` vector so it
            # can be compared against per-sentence lengths below and indexed
            # with the per-sentence boolean mask in ``forward_mask_ins``.
            max_lens = output_tokens.new_full((bsz,), 255)
        else:
            if not encoder_outs[0]["encoder_padding_mask"]:
                # NOTE(review): this uses ``size(1)`` of the encoder output as
                # the source length; confirm that dimension 1 is the time
                # axis for the encoder layout in use.
                src_lens = (
                    encoder_outs[0]["encoder_out"][0]
                    .new(bsz)
                    .fill_(encoder_outs[0]["encoder_out"][0].size(1))
                )
            else:
                src_lens = (~encoder_outs[0]["encoder_padding_mask"][0]).sum(1)
            max_lens = (src_lens * max_ratio).clamp(min=10).long()

        # delete words
        # only sentences with more than two non-pad tokens are candidates
        # (presumably so that a bare <s> </s> pair is never deleted from)
        can_del_word = output_tokens.ne(self.pad).sum(1) > 2
        if can_del_word.sum() != 0:  # skip when nothing can be deleted
            output_tokens, output_scores, attn = self.forward_word_del(
                encoder_outs,
                output_tokens,
                output_scores,
                attn,
                can_del_word,
            )

        # insert placeholders
        can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
        if can_ins_mask.sum() != 0:
            output_tokens, output_scores = self.forward_mask_ins(
                encoder_outs,
                output_tokens,
                output_scores,
                can_ins_mask,
                eos_penalty,
                max_lens,
            )

        # insert words
        can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
        if can_ins_word.sum() != 0:
            output_tokens, output_scores, attn = self.forward_word_ins(
                encoder_outs,
                output_tokens,
                output_scores,
                attn,
                can_ins_word,
            )

        # delete some unnecessary paddings
        cut_off = output_tokens.ne(self.pad).sum(1).max()
        output_tokens = output_tokens[:, :cut_off]
        output_scores = output_scores[:, :cut_off]
        attn = None if attn is None else attn[:, :cut_off, :]
        return decoder_out._replace(
            output_tokens=output_tokens,
            output_scores=output_scores,
            attn=attn,
            history=None,
        )

    def forward_word_del(
        self, encoder_outs, output_tokens, output_scores, attn, can_del_word
    ):
        """Apply the ensemble-averaged word-deletion step.

        Only the rows selected by ``can_del_word`` are passed to the models;
        the results are written back into the full batch with ``_fill``.

        Returns:
            The updated ``(output_tokens, output_scores, attn)``.
        """
        word_del_score_avg = []
        word_del_attn_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            word_del_out, word_del_attn = model.decoder.forward_word_del(
                _skip(output_tokens, can_del_word),
                _skip_encoder_out(model.encoder, encoder_out, can_del_word),
            )
            word_del_score = F.log_softmax(word_del_out, 2)
            word_del_score_avg.append(word_del_score)
            word_del_attn_avg.append(word_del_attn)
        # log of the mean probability: logsumexp over models minus log(n)
        word_del_score_avg = torch.logsumexp(
            torch.stack(word_del_score_avg, dim=0), dim=0
        ) - math.log(len(self.models))
        word_del_pred = word_del_score_avg.max(-1)[1].bool()
        if word_del_attn_avg[0] is not None:
            # Reduce over the model axis so the result has the same shape as
            # a single model's attention; stacking alone would leave an
            # extra leading dimension.
            word_del_attn_avg = torch.stack(word_del_attn_avg, dim=0).mean(dim=0)
        else:
            word_del_attn_avg = None

        _tokens, _scores, _attn = _apply_del_words(
            output_tokens[can_del_word],
            output_scores[can_del_word],
            word_del_attn_avg,
            word_del_pred,
            self.pad,
            self.bos,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_del_word, _scores, 0)
        attn = _fill(attn, can_del_word, _attn, 0.0)
        return output_tokens, output_scores, attn

    def forward_mask_ins(
        self,
        encoder_outs,
        output_tokens,
        output_scores,
        can_ins_mask,
        eos_penalty,
        max_lens,
    ):
        """Apply the ensemble-averaged placeholder-insertion step.

        Args:
            eos_penalty: when positive, subtracted from every model's
                log-probability at index 0 of the last dimension
                (presumably the "insert nothing" class -- confirm against
                the decoder).
            max_lens: per-sentence upper bound used to clamp the predicted
                placeholder counts.

        Returns:
            The updated ``(output_tokens, output_scores)``.
        """
        mask_ins_score_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            mask_ins_out, _ = model.decoder.forward_mask_ins(
                _skip(output_tokens, can_ins_mask),
                _skip_encoder_out(model.encoder, encoder_out, can_ins_mask),
            )
            mask_ins_score = F.log_softmax(mask_ins_out, 2)
            if eos_penalty > 0.0:
                mask_ins_score[:, :, 0] -= eos_penalty
            mask_ins_score_avg.append(mask_ins_score)
        # log of the mean probability: logsumexp over models minus log(n)
        mask_ins_score_avg = torch.logsumexp(
            torch.stack(mask_ins_score_avg, dim=0), dim=0
        ) - math.log(len(self.models))
        mask_ins_pred = mask_ins_score_avg.max(-1)[1]
        mask_ins_pred = torch.min(
            mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
        )
        _tokens, _scores = _apply_ins_masks(
            output_tokens[can_ins_mask],
            output_scores[can_ins_mask],
            mask_ins_pred,
            self.pad,
            self.unk,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
        return output_tokens, output_scores

    def forward_word_ins(
        self, encoder_outs, output_tokens, output_scores, attn, can_ins_word
    ):
        """Apply the ensemble-averaged word-insertion step.

        Only the rows selected by ``can_ins_word`` are passed to the models;
        the results are written back into the full batch with ``_fill``.

        Returns:
            The updated ``(output_tokens, output_scores, attn)``.
        """
        word_ins_score_avg = []
        word_ins_attn_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            word_ins_out, word_ins_attn = model.decoder.forward_word_ins(
                _skip(output_tokens, can_ins_word),
                _skip_encoder_out(model.encoder, encoder_out, can_ins_word),
            )
            word_ins_score = F.log_softmax(word_ins_out, 2)
            word_ins_score_avg.append(word_ins_score)
            word_ins_attn_avg.append(word_ins_attn)
        # log of the mean probability: logsumexp over models minus log(n)
        word_ins_score_avg = torch.logsumexp(
            torch.stack(word_ins_score_avg, dim=0), dim=0
        ) - math.log(len(self.models))
        if word_ins_attn_avg[0] is not None:
            # Reduce over the model axis so the result has the same shape as
            # a single model's attention.
            word_ins_attn_avg = torch.stack(word_ins_attn_avg, dim=0).mean(dim=0)
        else:
            word_ins_attn_avg = None
        word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1)

        _tokens, _scores = _apply_ins_words(
            output_tokens[can_ins_word],
            output_scores[can_ins_word],
            word_ins_pred,
            word_ins_score_max,
            self.unk,
        )

        output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_word, _scores, 0)
        # Write back the ensemble average, not whichever model ran last.
        attn = _fill(attn, can_ins_word, word_ins_attn_avg, 0.0)
        return output_tokens, output_scores, attn

    def initialize_output_tokens(self, encoder_outs, src_tokens):
        """Delegate initial-state construction to the first model."""
        # LevT doesn't do length prediction.
        return self.models[0].initialize_output_tokens(encoder_outs[0], src_tokens)