"""Ensemble decoding.

Decodes using multiple models simultaneously,
combining their prediction distributions by averaging.
All models in the ensemble must share a target vocabulary.
"""
import torch
import torch.nn as nn

from onmt.encoders.encoder import EncoderBase
from onmt.decoders.decoder import DecoderBase
from onmt.models import NMTModel
import onmt.model_builder


class EnsembleDecoderOutput(object):
    """Wrapper around multiple decoder final hidden states.

    Holds one decoder output per ensemble member, in member order, as an
    immutable tuple.
    """

    def __init__(self, model_dec_outs):
        # Freeze into a tuple so any iterable (list, generator, zip result)
        # can be passed in and indexed afterwards.
        self.model_dec_outs = tuple(model_dec_outs)

    def squeeze(self, dim=None):
        """Delegate squeeze to avoid modifying
        :func:`onmt.translate.translator.Translator.translate_batch()`

        Args:
            dim: dimension to squeeze on every wrapped output. When ``None``
                (the default) each output's argument-less ``squeeze()`` is
                called instead of forwarding ``None``.

        Returns:
            A new :class:`EnsembleDecoderOutput` wrapping the squeezed outputs.
        """
        # ``Tensor.squeeze(None)`` is not a valid overload in torch, so the
        # default must map onto the no-argument call rather than being
        # forwarded verbatim.
        if dim is None:
            return EnsembleDecoderOutput([x.squeeze() for x in self.model_dec_outs])
        return EnsembleDecoderOutput([x.squeeze(dim) for x in self.model_dec_outs])

    def __getitem__(self, index):
        """Return the decoder output of the ensemble member at ``index``."""
        return self.model_dec_outs[index]


class EnsembleEncoder(EncoderBase):
    """Dummy Encoder that delegates to individual real Encoders."""

    def __init__(self, model_encoders):
        super(EnsembleEncoder, self).__init__()
        self.model_encoders = nn.ModuleList(model_encoders)

    def forward(self, src, src_len=None):
        """Run every wrapped encoder on the same ``src`` / ``src_len``.

        Returns:
            A 3-tuple ``(enc_out, enc_final_hs, src_len)`` where the first two
            items are tuples holding one entry per wrapped encoder (in order)
            and ``src_len`` is the argument passed in, unchanged.
        """
        # Each encoder returns a 3-tuple; its third element is discarded and
        # the caller's ``src_len`` is returned in its place.
        enc_out, enc_final_hs, _ = zip(
            *[model_encoder(src, src_len) for model_encoder in self.model_encoders]
        )
        return enc_out, enc_final_hs, src_len


class EnsembleDecoder(DecoderBase):
    """Dummy Decoder that delegates to individual real Decoders."""

    def __init__(self, model_decoders):
        # Build the ModuleList as a local first: it is needed to compute
        # ``attentional`` for the base-class constructor, and it can only be
        # assigned as an attribute after ``super().__init__`` has run.
        model_decoders = nn.ModuleList(model_decoders)
        attentional = any(dec.attentional for dec in model_decoders)
        super(EnsembleDecoder, self).__init__(attentional)
        self.model_decoders = model_decoders

    def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
        """See :func:`onmt.decoders.decoder.DecoderBase.forward()`.

        ``enc_out`` is indexed per ensemble member (``enc_out[i]`` is fed to
        the i-th decoder); ``tgt``, ``src_len``, ``step`` and ``kwargs`` are
        shared by all members.

        Returns:
            ``(EnsembleDecoderOutput, dict)``: the wrapped per-member decoder
            outputs and the per-key averaged attentions.
        """
        # src_len is a single tensor shared between all models.
        # This assumption will not hold if Translator is modified
        # to calculate src_len as something other than the length
        # of the input.
        dec_outs, attns = zip(
            *[
                model_decoder(tgt, enc_out[i], src_len=src_len, step=step, **kwargs)
                for i, model_decoder in enumerate(self.model_decoders)
            ]
        )
        mean_attns = self.combine_attns(attns)
        return EnsembleDecoderOutput(dec_outs), mean_attns

    def combine_attns(self, attns):
        """Average attention values across ensemble members, key by key.

        Args:
            attns: sequence of per-member attention dicts. The key set is
                taken from the first dict.

        Returns:
            dict mapping each key of ``attns[0]`` to the element-wise mean of
            the non-``None`` values reported for that key, or to ``None`` when
            no member reported a value for it.
        """
        result = {}
        for key in attns[0]:
            # Members may report None (or omit the key entirely); only real
            # values take part in the average.
            present = [attn[key] for attn in attns if attn.get(key) is not None]
            # torch.stack raises on an empty list, so a key that no member
            # populated is passed through as None instead.
            result[key] = torch.stack(present).mean(0) if present else None
        return result

    def init_state(self, src, enc_out, enc_hidden):
        """See :obj:`RNNDecoderBase.init_state()`

        ``enc_out`` and ``enc_hidden`` are indexed per ensemble member.
        """
        for i, model_decoder in enumerate(self.model_decoders):
            model_decoder.init_state(src, enc_out[i], enc_hidden[i])

    def map_state(self, fn):
        """Apply ``fn`` to the state of every wrapped decoder via its
        own ``map_state``."""
        for model_decoder in self.model_decoders:
            model_decoder.map_state(fn)


class EnsembleGenerator(nn.Module):
    """
    Dummy Generator that delegates to individual real Generators,
    and then averages the resulting target distributions.
    """

    def __init__(self, model_generators, raw_probs=False):
        super(EnsembleGenerator, self).__init__()
        self.model_generators = nn.ModuleList(model_generators)
        self._raw_probs = raw_probs

    def forward(self, hidden, attn=None, src_map=None):
        """
        Compute a distribution over the target dictionary
        by averaging distributions from models in the ensemble.
        All models in the ensemble must share a target vocabulary.

        Args:
            hidden: per-member hidden states, paired positionally with the
                wrapped generators.
            attn: when not ``None``, each generator is called as
                ``generator(h, attn, src_map)`` instead of ``generator(h)``.
            src_map: forwarded to the generators together with ``attn``.

        Returns:
            The mean over members of the stacked generator outputs. With
            ``raw_probs`` set, the outputs are exponentiated before averaging
            and the log of the mean is returned (presumably the generators
            emit log-probabilities -- confirm against the generator in use).
        """
        distributions = torch.stack(
            [
                mg(h) if attn is None else mg(h, attn, src_map)
                for h, mg in zip(hidden, self.model_generators)
            ]
        )
        if self._raw_probs:
            return torch.log(torch.exp(distributions).mean(0))
        else:
            return distributions.mean(0)


class EnsembleModel(NMTModel):
    """Dummy NMTModel wrapping individual real NMTModels."""

    def __init__(self, models, raw_probs=False):
        encoder = EnsembleEncoder([model.encoder for model in models])
        decoder = EnsembleDecoder([model.decoder for model in models])
        super(EnsembleModel, self).__init__(encoder, decoder)
        self.generator = EnsembleGenerator(
            [model.generator for model in models], raw_probs
        )
        # Keep the original models registered as submodules as well.
        self.models = nn.ModuleList(models)


def load_test_model(opt, device_id=0):
    """Read in multiple models for ensemble.

    Loads every path in ``opt.models`` through
    :func:`onmt.model_builder.load_test_model` and wraps the results in an
    :class:`EnsembleModel`.

    Args:
        opt: options object; ``opt.models`` (checkpoint paths) and
            ``opt.avg_raw_probs`` are read here, and ``opt`` is forwarded to
            the per-model loader.
        device_id: forwarded to the per-model loader.

    Returns:
        ``(vocabs, ensemble_model, model_opt)`` where ``vocabs`` and
        ``model_opt`` are those of the first model loaded.

    Raises:
        AssertionError: if a later model's vocabulary differs from the first
            model's.
    """
    shared_vocabs = None
    shared_model_opt = None
    models = []
    for model_path in opt.models:
        vocabs, model, model_opt = onmt.model_builder.load_test_model(
            opt, device_id, model_path=model_path
        )
        if shared_vocabs is None:
            shared_vocabs = vocabs
        else:
            assert (
                shared_vocabs["src"].tokens_to_ids == vocabs["src"].tokens_to_ids
            ), "Ensemble models must use the same vocabs "
            # The averaged distributions are over the *target* vocabulary, so
            # it must match too. Guarded on key presence because the vocab
            # schema is not defined in this module.
            # NOTE(review): assumes the "tgt" entry exposes ``tokens_to_ids``
            # like the "src" entry does -- confirm.
            if "tgt" in shared_vocabs and "tgt" in vocabs:
                assert (
                    shared_vocabs["tgt"].tokens_to_ids == vocabs["tgt"].tokens_to_ids
                ), "Ensemble models must use the same vocabs "
        models.append(model)
        if shared_model_opt is None:
            shared_model_opt = model_opt
    ensemble_model = EnsembleModel(models, opt.avg_raw_probs)
    return shared_vocabs, ensemble_model, shared_model_opt