File size: 5,555 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import math

import torch
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params
def ensemble_encoder(func):
    """Decorate an encoder ``forward`` so it can run every ensemble member.

    If ``self.ensemble_models`` is ``None`` or contains exactly one model,
    the wrapped function is called unchanged.  Otherwise ``func`` is invoked
    once per ensemble member, always with ``return_all_hiddens=True``, and
    the per-member results are stacked along a new trailing dimension.

    Args:
        func: the undecorated ``forward``; it is called as
            ``func(model, *args, **kwargs)`` and, in ensemble mode, must
            return a dict with ``"encoder_out"``, ``"encoder_embedding"``
            and ``"encoder_states"`` entries.

    Returns:
        The wrapping function.  In ensemble mode it returns a shallow copy of
        the first member's output dict in which the three entries above are
        replaced by their stacked versions; all other keys are taken from the
        first member unchanged.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        models = self.ensemble_models
        if models is None or len(models) == 1:
            return func(self, *args, **kwargs)

        # Override (rather than duplicate) the flag so a caller that already
        # passed ``return_all_hiddens=...`` does not trigger a
        # "multiple values for keyword argument" TypeError.
        member_kwargs = dict(kwargs, return_all_hiddens=True)
        encoder_outs = [func(model, *args, **member_kwargs) for model in models]
        merged = encoder_outs[0].copy()

        def stack(key):
            # Each entry is a one-element list; keep that list wrapping.
            firsts = [out[key][0] for out in encoder_outs]
            if firsts[0] is None:
                return [None]
            return [torch.stack(firsts, -1)]

        merged["encoder_out"] = stack("encoder_out")
        merged["encoder_embedding"] = stack("encoder_embedding")

        num_layers = len(merged["encoder_states"])
        if num_layers > 0:
            merged["encoder_states"] = [
                torch.stack([out["encoder_states"][i] for out in encoder_outs], -1)
                for i in range(num_layers)
            ]
        return merged

    return wrapper
def ensemble_decoder(func):
    """Decorate a decoder method so it can average over ensemble members.

    If ``self.ensemble_models`` is ``None`` or contains exactly one model,
    the wrapped function is called unchanged.  Otherwise ``func`` is invoked
    once per ensemble member; member ``i`` receives a shallow copy of
    ``encoder_out`` whose ``"encoder_out"`` entry is the ``i``-th slice of
    the trailing dimension of the stacked tensor.

    Args:
        func: the undecorated method; it is called as
            ``func(model, normalize=..., encoder_out=..., *args, **kwargs)``
            and may return a single value or a tuple of values.

    Returns:
        The wrapping function.  In ensemble mode each returned position is
        combined across members:

        * position 0 with ``normalize=True``: ``logsumexp`` over members
          minus ``log(number of members)``;
        * any other non-``None`` position: stacked on a new last dimension;
        * a position that is ``None`` for the first member: ``None``.

        A single combined value is returned bare; several are returned as a
        tuple.
    """

    @functools.wraps(func)
    def wrapper(self, normalize=False, encoder_out=None, *args, **kwargs):
        models = self.ensemble_models
        if models is None or len(models) == 1:
            return func(
                self, normalize=normalize, encoder_out=encoder_out, *args, **kwargs
            )

        stacked_encoder_out = encoder_out["encoder_out"][0]
        num_models = len(models)

        def _member_encoder_out(index):
            # Shallow copy so the caller's dict keeps the stacked tensor.
            member_out = encoder_out.copy()
            member_out["encoder_out"] = [stacked_encoder_out[:, :, :, index]]
            return member_out

        action_outs = [
            func(
                model,
                normalize=normalize,
                encoder_out=_member_encoder_out(index),
                *args,
                **kwargs
            )
            for index, model in enumerate(models)
        ]

        # Normalise every member's result to a list of positional outputs.
        if isinstance(action_outs[0], tuple):
            action_outs = [list(out) for out in action_outs]
        else:
            action_outs = [[out] for out in action_outs]

        ensembled_outs = []
        for position in range(len(action_outs[0])):
            if position == 0 and normalize:
                # Average in probability space while staying in log space.
                member_values = torch.stack(
                    [out[position] for out in action_outs], -1
                )
                ensembled_outs.append(
                    torch.logsumexp(member_values, dim=-1) - math.log(num_models)
                )
            elif action_outs[0][position] is not None:
                ensembled_outs.append(
                    torch.stack([out[position] for out in action_outs], -1)
                )
            else:
                ensembled_outs.append(None)

        if len(ensembled_outs) == 1:
            return ensembled_outs[0]
        return tuple(ensembled_outs)

    return wrapper
class FairseqNATModel(TransformerModel):
    """
    Abstract class for all nonautoregressive-based models

    Subclasses are expected to override :meth:`forward`,
    :meth:`forward_decoder` and :meth:`initialize_output_tokens`; the
    versions defined here raise ``NotImplementedError``.
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)
        # Cache the target dictionary and its special-symbol indices.
        self.tgt_dict = decoder.dictionary
        self.bos = decoder.dictionary.bos()
        self.eos = decoder.dictionary.eos()
        self.pad = decoder.dictionary.pad()
        self.unk = decoder.dictionary.unk()

        self.ensemble_models = None

    @property
    def allow_length_beam(self):
        """Always ``False`` for this base class."""
        return False

    @property
    def allow_ensemble(self):
        """Always ``True`` for this base class."""
        return True

    def enable_ensemble(self, models):
        """Hand the encoders/decoders of ``models`` to this model's modules.

        Args:
            models: iterable of models exposing ``encoder`` and ``decoder``
                attributes; these are stored as the ``ensemble_models`` lists
                of ``self.encoder`` and ``self.decoder`` respectively.
        """
        self.encoder.ensemble_models = [m.encoder for m in models]
        self.decoder.ensemble_models = [m.decoder for m in models]

    @staticmethod
    def add_args(parser):
        """Register the ``TransformerModel`` arguments plus ``--apply-bert-init``."""
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--apply-bert-init",
            action="store_true",
            help="use custom param initialization for BERT",
        )

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Build a ``FairseqNATDecoder``, applying ``init_bert_params`` if
        ``args.apply_bert_init`` is set."""
        decoder = FairseqNATDecoder(args, tgt_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            decoder.apply(init_bert_params)
        return decoder

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Build a ``FairseqNATEncoder``, applying ``init_bert_params`` if
        ``args.apply_bert_init`` is set."""
        encoder = FairseqNATEncoder(args, src_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            encoder.apply(init_bert_params)
        return encoder

    def forward_encoder(self, encoder_inputs):
        """Call ``self.encoder`` with ``encoder_inputs`` unpacked positionally."""
        return self.encoder(*encoder_inputs)

    def forward_decoder(self, *args, **kwargs):
        """Abstract; must be overridden.

        Raises:
            NotImplementedError: always.
        """
        # Raise instead of returning the exception class, so a missing
        # override fails loudly at the call site.
        raise NotImplementedError

    def initialize_output_tokens(self, *args, **kwargs):
        """Abstract; must be overridden.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        """Abstract; must be overridden.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
class FairseqNATEncoder(TransformerEncoder):
    """Transformer encoder whose ``forward`` is wrapped by ``ensemble_encoder``."""

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
        # No ensemble members until something assigns a list here.
        self.ensemble_models = None

    @ensemble_encoder
    def forward(self, *args, **kwargs):
        """Delegate to the parent ``forward`` with the arguments unchanged."""
        encoder_out = super().forward(*args, **kwargs)
        return encoder_out
class FairseqNATDecoder(TransformerDecoder):
    """Transformer decoder that additionally carries an ``ensemble_models`` slot."""

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(
            args,
            dictionary,
            embed_tokens,
            no_encoder_attn,
        )
        # No ensemble members until something assigns a list here.
        self.ensemble_models = None
|