TomatoCocotree
上传
6a62ffb
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nat import (
FairseqNATModel,
LevenshteinTransformerDecoder,
LevenshteinTransformerModel,
ensemble_decoder,
)
from fairseq.models.transformer import Linear
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.utils import new_arange
class NegativeDistanceScore(object):
def __init__(self):
# pre-compute some values
self.scores = {}
self.scores[0.5] = self.compute_score_full(50, 0.5)
self.scores[1.0] = self.compute_score_full(50, 1.0)
self.scores[2.0] = self.compute_score_full(50, 2.0)
def __call__(self, i, L, tau):
if (tau is None) or (tau > 1000):
return 1 / L
if tau in self.scores:
if L < self.scores[tau].shape[0]:
return self.scores[tau][L - 1, i]
return self.compute_score(L, tau)[i]
def compute_score(self, L, tau):
s = np.array([-abs(L / 2 - i) / tau for i in range(L)])
s = np.exp(s - s.max())
return s / s.sum()
def compute_score_full(self, L, tau):
s = -abs(np.arange(0, L - 1)[:, None] / 2 - np.arange(L)[None, :]) / tau
s = np.tril(s, 0) + np.triu(s - float("inf"), 1)
s = np.exp(s - s.max(1, keepdims=True))
return s / s.sum(1, keepdims=True)
neg_scorer = NegativeDistanceScore()
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, tau=None):
try:
from fairseq import libnat
except ImportError as e:
import sys
sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
raise e
B = in_tokens.size(0)
T = in_tokens.size(1)
V = vocab_size
with torch.cuda.device_of(in_tokens):
in_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
]
out_tokens_list = [
[t for t in s if t != padding_idx]
for i, s in enumerate(out_tokens.tolist())
]
full_labels = libnat.suggested_ed2_path(
in_tokens_list, out_tokens_list, padding_idx
)
insert_labels = [a[:-1] for a in full_labels]
# numericalize1
insert_label_tensors = in_tokens.new_zeros(B * (T - 1) * V).float()
insert_index, insert_labels = zip(
*[
(w + (j + i * (T - 1)) * V, neg_scorer(k, len(label), tau))
for i, labels in enumerate(insert_labels)
for j, label in enumerate(labels[1:-1])
for k, w in enumerate(label)
]
) # HACK 1:-1
insert_index, insert_labels = [
torch.tensor(list(a), device=in_tokens.device)
for a in [insert_index, insert_labels]
]
insert_label_tensors.scatter_(0, insert_index.long(), insert_labels)
insert_label_tensors = insert_label_tensors.view(B, T - 1, V)
return insert_label_tensors
def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, padding_idx):
padding_masks = in_tokens[:, 1:].eq(padding_idx)
word_ins_scores.masked_fill_(padding_masks, 0.0)
word_ins_pred.masked_fill_(padding_masks, padding_idx)
in_coords = new_arange(in_tokens).type_as(in_scores)
# shift all padding predictions to infinite
out_coords = (in_coords[:, 1:] - 0.5).masked_fill(
word_ins_pred.eq(padding_idx), float("inf")
)
out_coords = torch.cat([in_coords, out_coords], 1).sort(-1)[1]
out_tokens = torch.cat([in_tokens, word_ins_pred], 1).gather(1, out_coords)
out_scores = torch.cat([in_scores, word_ins_scores], 1).gather(1, out_coords)
return out_tokens, out_scores
@register_model("insertion_transformer")
class InsertionTransformerModel(LevenshteinTransformerModel):
def __init__(self, args, encoder, decoder):
super().__init__(args, encoder, decoder)
@staticmethod
def add_args(parser):
FairseqNATModel.add_args(parser)
parser.add_argument("--label-tau", default=None, type=float)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
decoder = InsertionTransformerDecoder(args, tgt_dict, embed_tokens)
if getattr(args, "apply_bert_init", False):
decoder.apply(init_bert_params)
return decoder
def forward(
self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
):
assert tgt_tokens is not None, "forward function only supports training."
# encoding
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
# generate training labels for insertion
word_ins_out = self.decoder.forward_word_ins(
normalize=False,
prev_output_tokens=prev_output_tokens,
encoder_out=encoder_out,
)
word_ins_tgt = _get_ins_targets(
prev_output_tokens,
tgt_tokens,
self.pad,
self.unk,
len(self.tgt_dict),
tau=self.decoder.label_tau,
).type_as(word_ins_out)
word_ins_masks = prev_output_tokens[:, 1:].ne(self.pad)
return {
"word_ins": {
"out": word_ins_out,
"tgt": word_ins_tgt,
"mask": word_ins_masks,
"ls": self.args.label_smoothing,
"nll_loss": True,
}
}
def forward_decoder(
self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
):
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
history = decoder_out.history
# TODO: decoding for InsertionTransformer
word_ins_score = self.decoder.forward_word_ins(
normalize=True, prev_output_tokens=output_tokens, encoder_out=encoder_out
)
if eos_penalty > 0.0:
word_ins_score[:, :, self.pad] -= eos_penalty
word_ins_score, word_ins_pred = word_ins_score.max(-1)
output_tokens, output_scores = _apply_ins_words(
output_tokens, output_scores, word_ins_pred, word_ins_score, self.pad
)
# delete some unnecessary paddings
cut_off = output_tokens.ne(self.pad).sum(1).max()
output_tokens = output_tokens[:, :cut_off]
output_scores = output_scores[:, :cut_off]
if history is not None:
history.append(output_tokens.clone())
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
history=history,
)
class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
# use the TransformerDecoder's __init__
super(LevenshteinTransformerDecoder, self).__init__(
args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
)
self.dictionary = dictionary
self.bos = dictionary.bos()
self.unk = dictionary.unk()
self.eos = dictionary.eos()
self.pool_out = Linear(self.output_embed_dim * 2, self.output_embed_dim)
self.label_tau = getattr(args, "label_tau", None)
@ensemble_decoder
def forward_word_ins(self, normalize, encoder_out, prev_output_tokens):
features = self.extract_features(prev_output_tokens, encoder_out=encoder_out)[0]
features = self.pool_out(
torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
)
decoder_out = self.output_layer(features)
return F.log_softmax(decoder_out, -1) if normalize else decoder_out
def forward_mask_ins(self, *args, **kwargs):
raise NotImplementedError
def forward_word_del(self, *args, **kwargs):
raise NotImplementedError
@register_model_architecture("insertion_transformer", "insertion_transformer")
def insertion_base_architecture(args):
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.dropout = getattr(args, "dropout", 0.1)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.apply_bert_init = getattr(args, "apply_bert_init", False)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
# special for insertion transformer
args.label_tau = getattr(args, "label_tau", None)