File size: 6,453 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This file implements:
Ghazvininejad, Marjan, et al.
"Constant-time machine translation with conditional masked language models."
arXiv preprint arXiv:1904.09324 (2019).
"""
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nat import NATransformerModel
from fairseq.utils import new_arange
def _skeptical_unmasking(output_scores, output_masks, p):
    """Pick the lowest-scoring positions of each row for re-masking.

    Args:
        output_scores: per-position scores; sorted along the last dimension.
            (assumed to be a 2-D ``(batch, length)`` tensor -- TODO confirm)
        output_masks: boolean-like tensor whose row sums give the number of
            positions considered for each row (the caller in this file passes
            ``tokens.ne(pad)``).
        p: fraction of the countable positions to select.

    Returns:
        A boolean tensor shaped like ``output_masks`` that is true at the
        ``(row_sum - 2) * p`` (truncated to an integer) lowest-scoring
        positions of each row.
    """
    # Indices that would order each row from lowest to highest score.
    _, ascending_order = output_scores.sort(-1)

    # Two positions per row are excluded from the budget
    # (presumably BOS/EOS -- verify against the dictionary layout).
    countable = output_masks.sum(1, keepdim=True).type_as(output_scores)
    num_to_select = ((countable - 2) * p).long()

    # True for the first `num_to_select` ranks of every row ...
    # (assumes new_arange yields 0..length-1 along the last dim -- TODO confirm)
    low_rank = new_arange(output_masks) < num_to_select

    # ... then move each rank flag to the position that holds that rank.
    return low_rank.scatter(1, ascending_order, low_rank)
@register_model("cmlm_transformer")
class CMLMNATransformerModel(NATransformerModel):
    """Conditional masked language model built on ``NATransformerModel``.

    Positions equal to ``self.unk`` are treated as masked: they are the ones
    scored during training and the ones filled in during decoding.
    """

    @staticmethod
    def add_args(parser):
        """Register model arguments; delegates entirely to the parent model."""
        NATransformerModel.add_args(parser)

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
    ):
        """Run encoder, length predictor and decoder for one training batch.

        Args:
            src_tokens: source tokens, passed to the encoder.
            src_lengths: source lengths, passed to the encoder.
            prev_output_tokens: decoder input; positions equal to ``self.unk``
                are the ones reported in the ``"mask"`` entry.
            tgt_tokens: reference target tokens.
            **kwargs: forwarded to the encoder.

        Returns:
            A dict with two entries, ``"word_ins"`` and ``"length"``, each a
            dict bundling the raw outputs (``"out"``), targets (``"tgt"``) and
            loss settings (presumably consumed by the training criterion --
            verify against the caller).
        """
        assert not self.decoder.src_embedding_copy, "do not support embedding copy."

        # Encode the source once; both heads below reuse the result.
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)

        # Length head: unnormalized output and its training target.
        length_logits = self.decoder.forward_length(
            normalize=False, encoder_out=encoder_out
        )
        length_target = self.decoder.forward_length_prediction(
            length_logits, encoder_out, tgt_tokens
        )

        # Token head: unnormalized output for every target position.
        word_logits = self.decoder(
            normalize=False,
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
        )
        # Only the masked (unk) input positions are flagged.
        masked_positions = prev_output_tokens.eq(self.unk)

        word_ins = {
            "out": word_logits,
            "tgt": tgt_tokens,
            "mask": masked_positions,
            "ls": self.args.label_smoothing,
            "nll_loss": True,
        }
        length = {
            "out": length_logits,
            "tgt": length_target,
            "factor": self.decoder.length_loss_factor,
        }
        return {"word_ins": word_ins, "length": length}

    def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
        """Perform one refinement step of iterative decoding.

        Masked (unk) positions are filled with the decoder's top prediction
        and, unless this is the final step, the lowest-scoring positions are
        masked again for the next step.

        Args:
            decoder_out: state carrying ``step``, ``max_step``,
                ``output_tokens``, ``output_scores`` and ``history``. The
                token and score tensors are updated in place.
            encoder_out: encoder output passed through to the decoder.
            decoding_format: unused here.
            **kwargs: unused here.

        Returns:
            ``decoder_out`` with ``output_tokens``, ``output_scores`` and
            ``history`` replaced and ``attn`` set to ``None``.
        """
        step, max_step, tokens, scores, history = (
            decoder_out.step,
            decoder_out.max_step,
            decoder_out.output_tokens,
            decoder_out.output_scores,
            decoder_out.history,
        )

        # Predict a token for every position, keeping the best per position.
        still_masked = tokens.eq(self.unk)
        decoder_scores = self.decoder(
            normalize=True,
            prev_output_tokens=tokens,
            encoder_out=encoder_out,
        )
        best_scores, best_tokens = decoder_scores.max(-1)

        # Overwrite only the masked slots; already-decided tokens are kept.
        tokens.masked_scatter_(still_masked, best_tokens[still_masked])
        scores.masked_scatter_(still_masked, best_scores[still_masked])

        if history is not None:
            history.append(tokens.clone())

        # Skeptical decoding: the share of positions that is re-masked
        # shrinks as the step count approaches max_step; the last step
        # leaves every prediction in place.
        next_step = step + 1
        if next_step < max_step:
            remask = _skeptical_unmasking(
                scores, tokens.ne(self.pad), 1 - next_step / max_step
            )
            tokens.masked_fill_(remask, self.unk)
            scores.masked_fill_(remask, 0.0)

            if history is not None:
                history.append(tokens.clone())

        return decoder_out._replace(
            output_tokens=tokens,
            output_scores=scores,
            attn=None,
            history=history,
        )
@register_model_architecture("cmlm_transformer", "cmlm_transformer")
def cmlm_base_architecture(args):
    """Populate missing architecture options on ``args`` with defaults.

    An attribute that already exists on ``args`` keeps its value; a missing
    one is set to the default listed below. Some decoder defaults are taken
    from the corresponding encoder value, so the order of the calls matters.
    ``args`` is modified in place and nothing is returned.
    """

    def fill(name, fallback):
        # Keep an existing value, otherwise install the fallback.
        setattr(args, name, getattr(args, name, fallback))

    # --- encoder ---
    fill("encoder_embed_path", None)
    fill("encoder_embed_dim", 512)
    fill("encoder_ffn_embed_dim", 2048)
    fill("encoder_layers", 6)
    fill("encoder_attention_heads", 8)
    fill("encoder_normalize_before", False)
    fill("encoder_learned_pos", False)

    # --- decoder (dimensions fall back to the encoder's) ---
    fill("decoder_embed_path", None)
    fill("decoder_embed_dim", args.encoder_embed_dim)
    fill("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    fill("decoder_layers", 6)
    fill("decoder_attention_heads", 8)
    fill("decoder_normalize_before", False)
    fill("decoder_learned_pos", False)

    # --- regularization / activation ---
    fill("attention_dropout", 0.0)
    fill("activation_dropout", 0.0)
    fill("activation_fn", "relu")
    fill("dropout", 0.1)
    fill("adaptive_softmax_cutoff", None)
    fill("adaptive_softmax_dropout", 0)

    # --- embeddings ---
    fill("share_decoder_input_output_embed", False)
    fill("share_all_embeddings", True)
    fill("no_token_positional_embeddings", False)
    fill("adaptive_input", False)
    fill("apply_bert_init", False)

    # Decoder I/O widths fall back to the decoder embedding width.
    fill("decoder_output_dim", args.decoder_embed_dim)
    fill("decoder_input_dim", args.decoder_embed_dim)

    # --- special arguments ---
    fill("sg_length_pred", False)
    fill("pred_length_offset", False)
    fill("length_loss_factor", 0.1)
    fill("ngram_predictor", 1)
    fill("src_embedding_copy", False)
@register_model_architecture("cmlm_transformer", "cmlm_transformer_wmt_en_de")
def cmlm_wmt_en_de(args):
    """Architecture preset registered as ``cmlm_transformer_wmt_en_de``.

    Applies the base architecture defaults to ``args`` and adds no overrides
    of its own.
    """
    cmlm_base_architecture(args)
|