import argparse
import logging

import torch.nn as nn
import fairseq.checkpoint_utils
from fairseq.models import (
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import TransformerDecoder
from fairseq.models.roberta import model as roberta

logger = logging.getLogger(__name__)


@register_model("roberta_enc_dec")
class RobertaEncDecModel(FairseqEncoderDecoderModel):
    @staticmethod
    def add_args(parser):
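        """Add model-specific arguments to the parser."""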
        parser.add_argument(
            "--pretrained-mlm-checkpoint",
            default=None,
            type=str,
            metavar="PRETRAINED",
            help="path to pretrained mlm checkpoint",
        )
        parser.add_argument(
            "--pretrained-decoder", action="store_true", help="reload decoder"
        )
        parser.add_argument(
            "--hack-layernorm-embedding",
            action="store_true",
            help="hack to reload old models trained with encoder-normalize-before=False (no equivalent to encoder-normalize-before=False and layernorm_embedding=False",
        )
        parser.add_argument(
            "--share-decoder-input-output-embed",
            action="store_true",
            help="share decoder input and output embeddings",
        )
        parser.add_argument(
            "--share-all-embeddings",
            action="store_true",
            help="share encoder, decoder and output embeddings"
            " (requires shared dictionary and embed dim)",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present
        base_enc_dec_architecture(args)
        if args.pretrained_mlm_checkpoint:
            arg_overrides = None
            if args.hack_layernorm_embedding:
                arg_overrides = {"layernorm_embedding": False}
            loaded = fairseq.checkpoint_utils.load_model_ensemble_and_task(
                [args.pretrained_mlm_checkpoint], arg_overrides=arg_overrides
            )
            ([roberta_enc], _cfg, _task) = loaded
        else:
            # Do we need to edit untie_weights here?
            share_in_out = (
                args.share_decoder_input_output_embed or args.share_all_embeddings
            )
            args.untie_weights_roberta = not share_in_out
            if args.hack_layernorm_embedding:
                args.layernorm_embedding = False
                args.encoder_normalize_before = False
            roberta_enc = roberta.RobertaModel.build_model(args, task)

        return cls.from_roberta(roberta_enc, args, task.source_dictionary)

    @staticmethod
    def from_roberta(roberta_enc: roberta.RobertaModel, args, dictionary):
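        """Wrap a RoBERTa model into an encoder-decoder model.

        The RoBERTa sentence encoder is reused as the encoder, and a
        TransformerDecoder is built whose input embeddings and output projection
        can share weights with the RoBERTa LM head, depending on the --share-* flags.
        """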
        encoder = roberta_enc.encoder.sentence_encoder
        vocab_size, embed_dim = encoder.embed_tokens.weight.shape

        if args.share_all_embeddings:
            lm_head = roberta_enc.encoder.lm_head
            assert encoder.embed_tokens.weight is lm_head.weight, (
                "Can't use --share-all-embeddings with a model "
                "that was pretraiend with --untie-weights-roberta_enc"
            )
        else:
            lm_head = roberta.RobertaLMHead(
                embed_dim, vocab_size, roberta_enc.args.activation_fn
            )

        dec_embs = nn.Embedding(vocab_size, embed_dim, dictionary.pad())
        if args.share_all_embeddings or args.share_decoder_input_output_embed:
            # Note: I wasn't able to use the Embedding _weight parameter to achieve this sharing.
            dec_embs.weight = lm_head.weight

        decoder = TransformerDecoder(
            RobertaEncDecModel.read_args_from_roberta(roberta_enc.args),
            dictionary,
            dec_embs,
            no_encoder_attn=False,
            output_projection=lm_head,
        )
        if getattr(args, "pretrained_decoder", False):
            decoder_dict = encoder.state_dict()

            # TODO: hide setting "encoder_attn" layers behind a flag.
            for k, w in list(decoder_dict.items()):
                if ".self_attn" in k:
                    k_enc_attn = k.replace(".self_attn", ".encoder_attn")
                    decoder_dict[k_enc_attn] = w.detach().clone()

            for k, w in lm_head.state_dict().items():
                decoder_dict["output_projection." + k] = w

            missing_keys, unexpected_keys = decoder.load_state_dict(
                decoder_dict, strict=False
            )
            # missing_keys = [m for m in missing_keys if ".encoder_attn" not in m]
            assert not missing_keys and not unexpected_keys, (
                "Failed to load state dict. "
                f"Missing keys: {missing_keys}. "
                f"Unexpected keys: {unexpected_keys}."
            )

        if args.share_all_embeddings:
            assert decoder.output_projection.weight is decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is decoder.embed_tokens.weight
        elif args.share_decoder_input_output_embed:
            assert decoder.output_projection.weight is decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight
        else:
            assert decoder.output_projection.weight is not decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight

        return RobertaEncDecModel(encoder, decoder)

    @staticmethod
    def read_args_from_roberta(roberta_args: argparse.Namespace):
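        """Copy encoder hyperparameters from a RoBERTa argument namespace onto the
        decoder-side attribute names expected by TransformerDecoder."""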
        # TODO: this would become easier if encoder/decoder were using a similar
        # TransformerConfig object
        args = argparse.Namespace(**vars(roberta_args))
        attr_map = [
            ("encoder_attention_heads", "decoder_attention_heads"),
            ("encoder_embed_dim", "decoder_embed_dim"),
            ("encoder_embed_dim", "decoder_output_dim"),
            ("encoder_normalize_before", "decoder_normalize_before"),
            ("encoder_layers_to_keep", "decoder_layers_to_keep"),
            ("encoder_ffn_embed_dim", "decoder_ffn_embed_dim"),
            ("encoder_layerdrop", "decoder_layerdrop"),
            ("encoder_layers", "decoder_layers"),
            ("encoder_learned_pos", "decoder_learned_pos"),
            # should this be set from here?
            ("max_positions", "max_target_positions"),
        ]
        for k1, k2 in attr_map:
            setattr(args, k2, getattr(roberta_args, k1))

        args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
        args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
        args.share_decoder_input_output_embed = not roberta_args.untie_weights_roberta
        return args

    def upgrade_state_dict_named(self, state_dict, name):
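        """Rename keys of RoBERTa-style checkpoints to match this model's layout."""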
        prefix = name + "." if name != "" else ""
        super().upgrade_state_dict_named(state_dict, name)
        old_keys = list(state_dict.keys())

        # remap keys from the RoBERTa checkpoint: drop the encoder's LM head,
        # strip the ".sentence_encoder." prefix, and rename the decoder's
        # "lm_head" to "output_projection"
        for k in old_keys:
            if k.startswith(prefix + "encoder.lm_head"):
                state_dict.pop(k)
                continue
            new_k = k
            new_k = new_k.replace(".sentence_encoder.", ".")
            new_k = new_k.replace("decoder.lm_head.", "decoder.output_projection.")
            if k == new_k:
                continue
            # print(k, "->", new_k)
            state_dict[new_k] = state_dict.pop(k)


@register_model_architecture("roberta_enc_dec", "roberta_enc_dec")
def base_enc_dec_architecture(args):
    args.hack_layernorm_embedding = getattr(args, "hack_layernorm_embedding", False)
    args.pretrained_mlm_checkpoint = getattr(args, "pretrained_mlm_checkpoint", None)
    args.pretrained_decoder = getattr(args, "pretrained_decoder", None)
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )

    roberta.base_architecture(args)
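

# Example invocation (illustrative sketch; the data path, task and any optimization
# flags are placeholders, only --arch and the model-specific flags shown here are
# defined in this file):
#
#   fairseq-train data-bin/my-dataset \
#       --task translation \
#       --arch roberta_enc_dec \
#       --pretrained-mlm-checkpoint /path/to/roberta/model.pt \
#       --share-all-embeddings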