""" |
|
This file implements: |
|
Ghazvininejad, Marjan, et al. |
|
"Constant-time machine translation with conditional masked language models." |
|
arXiv preprint arXiv:1904.09324 (2019). |
|
""" |

from fairseq.models import register_model, register_model_architecture
from fairseq.models.nat import NATransformerModel
from fairseq.utils import new_arange

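# CMLM translation runs in two stages: the model first predicts the target
# length from the encoder output, then "mask-predict" decoding fills in all
# masked tokens in parallel and repeatedly re-masks the least confident ones.
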
def _skeptical_unmasking(output_scores, output_masks, p):
    # Rank every position by prediction score, ascending (least confident first).
    sorted_index = output_scores.sort(-1)[1]
    # Number of tokens to re-mask: a fraction p of the valid target length
    # (minus 2 so the special boundary tokens are never counted).
    boundary_len = (
        (output_masks.sum(1, keepdim=True).type_as(output_scores) - 2) * p
    ).long()
    # Mark the `boundary_len` lowest-scoring ranks, then scatter those marks
    # back from rank order to the original token positions.
    skeptical_mask = new_arange(output_masks) < boundary_len
    return skeptical_mask.scatter(1, sorted_index, skeptical_mask)

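# Illustrative sketch (not part of the original module): a tiny worked example
# of _skeptical_unmasking. The scores are made-up values; with p = 0.5 over
# 5 unmasked positions, boundary_len = int((5 - 2) * 0.5) = 1, so exactly the
# single least-confident position is re-masked:
#
#   >>> import torch
#   >>> scores = torch.tensor([[0.9, 0.1, 0.5, 0.2, 0.8]])
#   >>> _skeptical_unmasking(scores, scores.ne(0.0), p=0.5)
#   tensor([[False,  True, False, False, False]])
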
@register_model("cmlm_transformer")
class CMLMNATransformerModel(NATransformerModel):
    @staticmethod
    def add_args(parser):
        NATransformerModel.add_args(parser)

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
    ):
        assert not self.decoder.src_embedding_copy, "do not support embedding copy."

        # encoding
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)

        # length prediction
        length_out = self.decoder.forward_length(
            normalize=False, encoder_out=encoder_out
        )
        length_tgt = self.decoder.forward_length_prediction(
            length_out, encoder_out, tgt_tokens
        )

        # decoding: predict the words at the masked (<unk>) positions only
        word_ins_out = self.decoder(
            normalize=False,
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
        )
        word_ins_mask = prev_output_tokens.eq(self.unk)

        return {
            "word_ins": {
                "out": word_ins_out,
                "tgt": tgt_tokens,
                "mask": word_ins_mask,
                "ls": self.args.label_smoothing,
                "nll_loss": True,
            },
            "length": {
                "out": length_out,
                "tgt": length_tgt,
                "factor": self.decoder.length_loss_factor,
            },
        }

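    # How the masked inputs above are produced is outside this file: fairseq's
    # translation_lev task injects the noise. A simplified, illustrative
    # sketch of uniform random masking (the helper name is hypothetical):
    #
    #   >>> import torch
    #   >>> def random_mask(tgt, pad, bos, eos, unk):
    #   ...     can_mask = tgt.ne(pad) & tgt.ne(bos) & tgt.ne(eos)
    #   ...     # sample a per-sentence number of tokens to mask, at least 1
    #   ...     n_mask = (can_mask.sum(1).float() * torch.rand(tgt.size(0))).long() + 1
    #   ...     score = torch.rand(tgt.shape).masked_fill(~can_mask, 2.0)
    #   ...     rank = score.sort(1)[1]
    #   ...     cutoff = new_arange(rank) < n_mask[:, None]
    #   ...     return tgt.masked_fill(cutoff.scatter(1, rank, cutoff), unk)
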
    def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
        step = decoder_out.step
        max_step = decoder_out.max_step

        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        history = decoder_out.history

        # execute the decoder and fill every currently masked (<unk>) position
        # with its argmax prediction, keeping the prediction scores
        output_masks = output_tokens.eq(self.unk)
        _scores, _tokens = self.decoder(
            normalize=True,
            prev_output_tokens=output_tokens,
            encoder_out=encoder_out,
        ).max(-1)
        output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
        output_scores.masked_scatter_(output_masks, _scores[output_masks])

        if history is not None:
            history.append(output_tokens.clone())

        # skeptical decoding: unless this is the last step, re-mask the
        # least-confident tokens; the re-masked fraction decays linearly
        # with the step
        if (step + 1) < max_step:
            skeptical_mask = _skeptical_unmasking(
                output_scores, output_tokens.ne(self.pad), 1 - (step + 1) / max_step
            )

            output_tokens.masked_fill_(skeptical_mask, self.unk)
            output_scores.masked_fill_(skeptical_mask, 0.0)

            if history is not None:
                history.append(output_tokens.clone())

        return decoder_out._replace(
            output_tokens=output_tokens,
            output_scores=output_scores,
            attn=None,
            history=history,
        )
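
    # Usage sketch (assumption, not part of the original class): at inference
    # time fairseq's IterativeRefinementGenerator drives forward_decoder,
    # starting from a fully <unk>-masked hypothesis of the predicted length,
    # roughly:
    #
    #   >>> for step in range(decoder_out.max_step):
    #   ...     decoder_out = decoder_out._replace(step=step)
    #   ...     decoder_out = model.forward_decoder(decoder_out, encoder_out)
    #   >>> translation = decoder_out.output_tokens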


@register_model_architecture("cmlm_transformer", "cmlm_transformer")
def cmlm_base_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.apply_bert_init = getattr(args, "apply_bert_init", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    # NAT-specific arguments (length prediction and decoder input copying)
    args.sg_length_pred = getattr(args, "sg_length_pred", False)
    args.pred_length_offset = getattr(args, "pred_length_offset", False)
    args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
    args.ngram_predictor = getattr(args, "ngram_predictor", 1)
    args.src_embedding_copy = getattr(args, "src_embedding_copy", False)


@register_model_architecture("cmlm_transformer", "cmlm_transformer_wmt_en_de")
def cmlm_wmt_en_de(args):
    cmlm_base_architecture(args)
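

# Usage sketch (assumption, not part of the original file): following
# fairseq's nonautoregressive_translation examples, a CMLM model is trained
# with the Levenshtein-style NAT task and decoded by iterative refinement,
# roughly (the data path is illustrative):
#
#   fairseq-train data-bin/wmt14.en-de.distill \
#       --arch cmlm_transformer --task translation_lev --noise random_mask \
#       --criterion nat_loss --label-smoothing 0.1 ...
#
#   fairseq-generate data-bin/wmt14.en-de.distill --task translation_lev \
#       --iter-decode-max-iter 9 --gen-subset test ...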