# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn.functional as F
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nat import (
    FairseqNATModel,
    LevenshteinTransformerDecoder,
    LevenshteinTransformerModel,
    ensemble_decoder,
)
from fairseq.models.transformer import Linear
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.utils import new_arange


class NegativeDistanceScore(object):
    """Soft target weights that decay with distance from the middle of a span.

    Calling the instance as ``scorer(i, L, tau)`` returns the weight of the
    ``i``-th out of ``L`` candidates, computed as a softmax over negative
    distances divided by the temperature ``tau``.
    """

    def __init__(self):
        # pre-compute some values
        self.scores = {}
        self.scores[0.5] = self.compute_score_full(50, 0.5)
        self.scores[1.0] = self.compute_score_full(50, 1.0)
        self.scores[2.0] = self.compute_score_full(50, 2.0)

    def __call__(self, i, L, tau):
        """Return the weight of candidate ``i`` among ``L`` candidates.

        A missing or very large ``tau`` (> 1000) yields the uniform weight
        ``1 / L``. Cached tables are used for the pre-computed temperatures
        when ``L`` is small enough; otherwise the row is computed on the fly.
        """
        if (tau is None) or (tau > 1000):
            return 1 / L

        if tau in self.scores:
            if L < self.scores[tau].shape[0]:
                return self.scores[tau][L - 1, i]
        return self.compute_score(L, tau)[i]

    def compute_score(self, L, tau):
        """Return a length-``L`` array of normalized weights for one span."""
        # NOTE(review): this centers on L / 2 whereas the cached table built by
        # compute_score_full centers row L - 1 on (L - 1) / 2, so cached and
        # uncached results differ slightly -- confirm which one is intended.
        s = np.array([-abs(L / 2 - i) / tau for i in range(L)])
        # subtract the max before exponentiating for numerical stability
        s = np.exp(s - s.max())
        return s / s.sum()

    def compute_score_full(self, L, tau):
        """Return an ``(L - 1, L)`` table whose row ``r`` holds ``r + 1`` weights.

        Entries above the main diagonal are pushed to ``-inf`` before the
        softmax, so they end up as exact zeros in the returned table.
        """
        s = -abs(np.arange(0, L - 1)[:, None] / 2 - np.arange(L)[None, :]) / tau
        s = np.tril(s, 0) + np.triu(s - float("inf"), 1)
        s = np.exp(s - s.max(1, keepdims=True))
        return s / s.sum(1, keepdims=True)


neg_scorer = NegativeDistanceScore()


def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, tau=None):
    """Build dense insertion targets from ``in_tokens`` towards ``out_tokens``.

    Args:
        in_tokens: 2-D tensor of token ids; its two sizes are read as B and T.
        out_tokens: 2-D tensor of target token ids.
        padding_idx: id removed from both sequences before alignment.
        unk_idx: unused; kept for interface compatibility.
        vocab_size: size V of the last dimension of the returned tensor.
        tau: temperature forwarded to ``neg_scorer`` (``None`` means uniform).

    Returns:
        A float tensor of shape ``(B, T - 1, V)`` holding, for every slot
        between two consecutive input tokens, the weight of each vocabulary
        entry that should be inserted there. It is all zeros when the edit
        path contains no insertions.

    Raises:
        ImportError: if the ``libnat`` extension is not available.
    """
    try:
        from fairseq import libnat
    except ImportError as e:
        import sys

        sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
        raise e

    B = in_tokens.size(0)
    T = in_tokens.size(1)
    V = vocab_size

    with torch.cuda.device_of(in_tokens):
        in_tokens_list = [
            [t for t in s if t != padding_idx] for s in in_tokens.tolist()
        ]
        out_tokens_list = [
            [t for t in s if t != padding_idx] for s in out_tokens.tolist()
        ]

    full_labels = libnat.suggested_ed2_path(
        in_tokens_list, out_tokens_list, padding_idx
    )
    insert_labels = [a[:-1] for a in full_labels]

    # numericalize1
    # Allocate directly as float32 instead of creating an integer buffer and
    # converting it afterwards.
    insert_label_tensors = torch.zeros(
        B * (T - 1) * V, dtype=torch.float, device=in_tokens.device
    )
    # One (flat index, weight) pair per token to insert; the flat index encodes
    # (batch i, slot j, word w) in the flattened (B, T - 1, V) buffer.
    entries = [
        (w + (j + i * (T - 1)) * V, neg_scorer(k, len(label), tau))
        for i, labels in enumerate(insert_labels)
        for j, label in enumerate(labels[1:-1])
        for k, w in enumerate(label)
    ]  # HACK 1:-1

    # `zip(*entries)` cannot be unpacked when there is nothing to insert, so
    # only scatter when at least one entry exists.
    if entries:
        insert_index, insert_weights = zip(*entries)
        # Explicit dtypes: the scorer may return numpy float64 scalars, which
        # would otherwise produce a float64 source that does not match the
        # float32 destination of scatter_.
        index_tensor = torch.tensor(
            list(insert_index), dtype=torch.long, device=in_tokens.device
        )
        weight_tensor = torch.tensor(
            list(insert_weights),
            dtype=insert_label_tensors.dtype,
            device=in_tokens.device,
        )
        insert_label_tensors.scatter_(0, index_tensor, weight_tensor)

    insert_label_tensors = insert_label_tensors.view(B, T - 1, V)
    return insert_label_tensors


def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, padding_idx):
    """Merge predicted insertions into the current hypothesis.

    Predictions for slots whose right-hand input token is padding are masked
    out in place (``word_ins_scores`` and ``word_ins_pred`` are modified).
    Each remaining prediction is given the coordinate ``position - 0.5`` so
    that sorting interleaves it between its two neighbouring input tokens;
    padding predictions get an infinite coordinate and sort to the end.

    Returns:
        ``(out_tokens, out_scores)`` gathered in the merged order.
    """
    padding_masks = in_tokens[:, 1:].eq(padding_idx)
    word_ins_scores.masked_fill_(padding_masks, 0.0)
    word_ins_pred.masked_fill_(padding_masks, padding_idx)

    in_coords = new_arange(in_tokens).type_as(in_scores)

    # shift all padding predictions to infinite
    out_coords = (in_coords[:, 1:] - 0.5).masked_fill(
        word_ins_pred.eq(padding_idx), float("inf")
    )
    out_coords = torch.cat([in_coords, out_coords], 1).sort(-1)[1]
    out_tokens = torch.cat([in_tokens, word_ins_pred], 1).gather(1, out_coords)
    out_scores = torch.cat([in_scores, word_ins_scores], 1).gather(1, out_coords)
    return out_tokens, out_scores


@register_model("insertion_transformer")
class InsertionTransformerModel(LevenshteinTransformerModel):
    """Model registered as ``insertion_transformer``.

    Only a word-insertion head is used: training produces a single
    ``word_ins`` loss entry and decoding applies one insertion step per call.
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add the NAT base arguments plus ``--label-tau``."""
        FairseqNATModel.add_args(parser)
        parser.add_argument("--label-tau", default=None, type=float)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Create the decoder, optionally applying BERT-style initialization."""
        decoder = InsertionTransformerDecoder(args, tgt_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            decoder.apply(init_bert_params)
        return decoder

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
    ):
        """Compute insertion logits and their targets for training.

        Returns:
            A dict with a single ``"word_ins"`` entry holding the logits
            (``out``), the dense targets (``tgt``), the non-padding slot mask
            (``mask``), the label-smoothing value (``ls``) and
            ``nll_loss=True``.
        """
        assert tgt_tokens is not None, "forward function only supports training."

        # encoding
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)

        # generate training labels for insertion
        word_ins_out = self.decoder.forward_word_ins(
            normalize=False,
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
        )

        word_ins_tgt = _get_ins_targets(
            prev_output_tokens,
            tgt_tokens,
            self.pad,
            self.unk,
            len(self.tgt_dict),
            tau=self.decoder.label_tau,
        ).type_as(word_ins_out)
        word_ins_masks = prev_output_tokens[:, 1:].ne(self.pad)

        return {
            "word_ins": {
                "out": word_ins_out,
                "tgt": word_ins_tgt,
                "mask": word_ins_masks,
                "ls": self.args.label_smoothing,
                "nll_loss": True,
            }
        }

    def forward_decoder(
        self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
    ):
        """Run one insertion step on the current hypotheses.

        The best-scoring token for every slot is inserted (a padding
        prediction means "insert nothing"), trailing padding columns are
        trimmed, and the updated ``decoder_out`` is returned.
        """
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        history = decoder_out.history

        # TODO: decoding for InsertionTransformer
        word_ins_score = self.decoder.forward_word_ins(
            normalize=True, prev_output_tokens=output_tokens, encoder_out=encoder_out
        )

        if eos_penalty > 0.0:
            # lower the score of predicting padding in every slot
            word_ins_score[:, :, self.pad] -= eos_penalty
        word_ins_score, word_ins_pred = word_ins_score.max(-1)
        output_tokens, output_scores = _apply_ins_words(
            output_tokens, output_scores, word_ins_pred, word_ins_score, self.pad
        )

        # delete some unnecessary paddings
        cut_off = output_tokens.ne(self.pad).sum(1).max()
        output_tokens = output_tokens[:, :cut_off]
        output_scores = output_scores[:, :cut_off]

        if history is not None:
            history.append(output_tokens.clone())

        return decoder_out._replace(
            output_tokens=output_tokens,
            output_scores=output_scores,
            attn=None,
            history=history,
        )


class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
    """Decoder exposing only the word-insertion head.

    ``forward_mask_ins`` and ``forward_word_del`` raise
    ``NotImplementedError``.
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        # use the TransformerDecoder's __init__
        super(LevenshteinTransformerDecoder, self).__init__(
            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
        )

        self.dictionary = dictionary
        self.bos = dictionary.bos()
        self.unk = dictionary.unk()
        self.eos = dictionary.eos()
        # projects a pair of concatenated neighbouring features back to one
        self.pool_out = Linear(self.output_embed_dim * 2, self.output_embed_dim)

        self.label_tau = getattr(args, "label_tau", None)

    @ensemble_decoder
    def forward_word_ins(self, normalize, encoder_out, prev_output_tokens):
        """Score the vocabulary for every slot between consecutive tokens.

        Features of each adjacent pair of positions are concatenated, passed
        through ``pool_out`` and the output layer. Log-probabilities are
        returned when ``normalize`` is true, raw outputs otherwise.
        """
        features = self.extract_features(prev_output_tokens, encoder_out=encoder_out)[0]
        features = self.pool_out(
            torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
        )
        decoder_out = self.output_layer(features)
        return F.log_softmax(decoder_out, -1) if normalize else decoder_out

    def forward_mask_ins(self, *args, **kwargs):
        raise NotImplementedError

    def forward_word_del(self, *args, **kwargs):
        raise NotImplementedError


@register_model_architecture("insertion_transformer", "insertion_transformer")
def insertion_base_architecture(args):
    """Fill in defaults for every architecture option not already set on ``args``."""
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.apply_bert_init = getattr(args, "apply_bert_init", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    # special for insertion transformer
    args.label_tau = getattr(args, "label_tau", None)