from collections import OrderedDict

from fairseq import utils
from fairseq.models import (
    FairseqMultiModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    Embedding,
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
    base_architecture,
)
from fairseq.utils import safe_hasattr


@register_model("multilingual_transformer")
class MultilingualTransformerModel(FairseqMultiModel):
    """Train Transformer models for multiple language pairs simultaneously.

    Requires `--task multilingual_translation`.

    We inherit all arguments from TransformerModel and assume that all language
    pairs use a single Transformer architecture. In addition, we provide several
    options that are specific to the multilingual setting.
|
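    For example, a typical invocation passes flags along these lines
    (``<data-bin>`` is a placeholder; dataset and optimization flags are
    omitted and depend on your setup)::

        fairseq-train <data-bin>
            --task multilingual_translation --lang-pairs de-en,fr-en
            --arch multilingual_transformer
            --share-encoder-embeddings --share-decoder-embeddings
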
    Args:
        --share-encoder-embeddings: share encoder embeddings across all source languages
        --share-decoder-embeddings: share decoder embeddings across all target languages
        --share-encoders: share all encoder params (incl. embeddings) across all source languages
        --share-decoders: share all decoder params (incl. embeddings) across all target languages
    """
|
    def __init__(self, encoders, decoders):
        super().__init__(encoders, decoders)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--share-encoder-embeddings",
            action="store_true",
            help="share encoder embeddings across languages",
        )
        parser.add_argument(
            "--share-decoder-embeddings",
            action="store_true",
            help="share decoder embeddings across languages",
        )
        parser.add_argument(
            "--share-encoders",
            action="store_true",
            help="share encoders across languages",
        )
        parser.add_argument(
            "--share-decoders",
            action="store_true",
            help="share decoders across languages",
        )
|
    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        from fairseq.tasks.multilingual_translation import MultilingualTranslationTask

        assert isinstance(task, MultilingualTranslationTask)
|
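        # make sure all arguments are present in older models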
        base_multilingual_architecture(args)

        if not safe_hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not safe_hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024
|
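        # split each language pair, e.g. "de-en", into its source and target language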
        src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs]
        tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs]
|
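        # sharing a full encoder/decoder implies sharing its embeddings as well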
        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True
|
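        # helper that builds an embedding table for a dictionary, optionally
        # initialized from a pretrained embedding file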
        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
|
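            # if a pretrained embedding path is provided, load it into the table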
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb
|
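        # build shared embeddings (if applicable)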
        shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=task.langs,
                embed_dim=args.encoder_embed_dim,
                build_embedding=build_embedding,
                pretrained_embed_path=args.encoder_embed_path,
            )
            shared_decoder_embed_tokens = shared_encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            if args.share_encoder_embeddings:
                shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=src_langs,
                    embed_dim=args.encoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.encoder_embed_path,
                )
            if args.share_decoder_embeddings:
                shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=tgt_langs,
                    embed_dim=args.decoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.decoder_embed_path,
                )
|
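        # encoders/decoders for each language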
        lang_encoders, lang_decoders = {}, {}
|
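        # lazily build one encoder/decoder per language and cache it, reusing
        # the shared embedding tables when they were built above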
        def get_encoder(lang):
            if lang not in lang_encoders:
                if shared_encoder_embed_tokens is not None:
                    encoder_embed_tokens = shared_encoder_embed_tokens
                else:
                    encoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.encoder_embed_dim,
                        args.encoder_embed_path,
                    )
                lang_encoders[lang] = cls._get_module_class(
                    True, args, task.dicts[lang], encoder_embed_tokens, src_langs
                )
            return lang_encoders[lang]

        def get_decoder(lang):
            if lang not in lang_decoders:
                if shared_decoder_embed_tokens is not None:
                    decoder_embed_tokens = shared_decoder_embed_tokens
                else:
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.decoder_embed_dim,
                        args.decoder_embed_path,
                    )
                lang_decoders[lang] = cls._get_module_class(
                    False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs
                )
            return lang_decoders[lang]
|
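        # shared encoders/decoders (if applicable)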
        shared_encoder, shared_decoder = None, None
        if args.share_encoders:
            shared_encoder = get_encoder(src_langs[0])
        if args.share_decoders:
            shared_decoder = get_decoder(tgt_langs[0])
|
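        # build one (encoder, decoder) pair per language pair, using the shared
        # modules when --share-encoders/--share-decoders is set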
        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
            encoders[lang_pair] = (
                shared_encoder if shared_encoder is not None else get_encoder(src)
            )
            decoders[lang_pair] = (
                shared_decoder if shared_decoder is not None else get_decoder(tgt)
            )

        return MultilingualTransformerModel(encoders, decoders)
|
    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        module_class = TransformerEncoder if is_encoder else TransformerDecoder
        return module_class(args, lang_dict, embed_tokens)
|
    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        state_dict_subset = state_dict.copy()
|
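        # drop checkpoint entries for language pairs that are not part of this
        # model; keys are expected to look like "models.<src>-<tgt>.<param>"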
        for k, _ in state_dict.items():
            assert k.startswith("models.")
            lang_pair = k.split(".")[1]
            if lang_pair not in self.models:
                del state_dict_subset[k]
        super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg)
|
|
@register_model_architecture("multilingual_transformer", "multilingual_transformer")
def base_multilingual_architecture(args):
    base_architecture(args)
    args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", False)
    args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", False)
    args.share_encoders = getattr(args, "share_encoders", False)
    args.share_decoders = getattr(args, "share_decoders", False)
|
|
@register_model_architecture(
    "multilingual_transformer", "multilingual_transformer_iwslt_de_en"
)
def multilingual_transformer_iwslt_de_en(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    base_multilingual_architecture(args)