import os

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.utils import ContextManagers

from m4.training.setup_vision_model import vision_model_name_to_model
from m4.training.utils import (
    deepspeed_zero_init_disabled_context_manager,
    is_deepspeed_zero_init_enabled,
    load_state_dict_into_model,
)


class VLOOMPreTrainedModelBase(PreTrainedModel):
    # The problem we are trying to solve: fetching the vision model via `from_pretrained(vision_model_name)`
    # triggers one zero.Init, and then overriding it with a second `from_pretrained(vision_model_name)` call
    # (as was done in the original) nests another zero.Init inside it - and nested zero.Init breaks
    # deepspeed zero3 w/ zero.Init.
    # So one solution is this:
    # a. replace `from_pretrained(vision_model_name)` with `from_config(vision_model_name)`, while disabling the zero.Init context
    # b. instead of directly replacing `model.vision_model` with `from_pretrained(vision_model_name)` when it
    #    gets updated, first run `from_pretrained(vision_model_name)` outside zero.Init, and then copy its
    #    weights into the existing, already zero.Init'ed (pre-sharded) model
    #
    # there are a few variations of getting the vision model from its config - all need to bypass zero.Init under zero3:
    # 1. one variant is to hack into accelerate's deepspeed_plugin and turn off zero.Init while loading the vision model
    # 2. the other variant is to override the `_from_config` method with our version that doesn't do zero.Init
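    #
    # For reference, a minimal sketch of what the imported `deepspeed_zero_init_disabled_context_manager`
    # plausibly looks like (an assumption about `m4.training.utils`, not a verbatim copy) - it returns a
    # list of context managers that disable zero.Init via accelerate's deepspeed plugin when zero3 is active:
    #
    #   def deepspeed_zero_init_disabled_context_manager():
    #       deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
    #       if deepspeed_plugin is None:
    #           return []
    #       return [deepspeed_plugin.zero3_init_context_manager(enable=False)]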

    @classmethod
    def override_vision_model(cls, model, vision_model_name, vision_model_params, torch_dtype):
        # 1. fetch the pretrained vision model w/o zero.Init
        with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
            vision_model = AutoModel.from_pretrained(vision_model_name, **vision_model_params, torch_dtype=torch_dtype)

        # this extracts the desired submodule if the part we want is nested (e.g. as in clip)
        real_vision_model = vision_model_name_to_model(vision_model_name, vision_model)

        # 2. now override the weights already sharded by zero.Init with the weights from the real_vision_model
        # by gradually gathering sharded weights and replacing with new weights
        if is_deepspeed_zero_init_enabled():
            state_dict = real_vision_model.state_dict()
            load_state_dict_into_model(model.vision_model, state_dict, start_prefix="")
        else:
            model.vision_model = real_vision_model

    @classmethod
    def from_config(cls, config, **kwargs):
        # torch_dtype is crucial for using the minimal amount of memory at load time
        torch_dtype = kwargs.get("torch_dtype", None)

        vision_model_name = config.vision_model_name
        # `vision_model_params` is stored in the config as a string holding a python dict literal
        vision_model_params = eval(config.vision_model_params)

        # 1. create an uninitialized vision_model to insert into the main model.
        # It has to be created outside lm's `from_pretrained` and w/o zero.Init so that zero3+zero.Init works
        with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
            vision_model_config = AutoConfig.from_pretrained(vision_model_name, **vision_model_params)
            vision_model_from_config = AutoModel.from_config(vision_model_config, torch_dtype=torch_dtype)
        # this extracts the desired submodule if the part we want is nested (e.g. as in clip)
        kwargs["vision_model"] = vision_model_name_to_model(vision_model_name, vision_model_from_config)

        # 2. create the main class's model, passing the uninitialized vision_model to it
        model = cls(config, **kwargs)

        return model

    @classmethod
    def from_pretrained_models(cls, *args, **kwargs):
        """
        Use this method when creating a new vloom model that hasn't been trained yet and will be
        composed of 2 pre-trained models - hence `from_pretrained_models`.
        """

        return cls.from_pretrained(*args, **kwargs, new_model=True)

    @classmethod
    def from_pretrained(cls, *model_args, is_resume=False, new_model=False, **kwargs):
        """
        Use this method when loading an already pretrained vloom model - either from a checkpoint or from the hub.
        For creating an untrained model use `from_pretrained_models` instead.
        """

        is_untrained_vloom_model = False
        is_pretrained_vloom_model_resumed = False
        is_pretrained_vloom_model_from_hub_or_path = False

        # we have 3 use cases:
        # 1. is_untrained_vloom_model - a totally new vloom model
        # 2. is_pretrained_vloom_model_resumed - a pretrained vloom model being resumed from a
        #    checkpoint (instantiate a random empty model in this case)
        # 3. is_pretrained_vloom_model_from_hub_or_path - a pretrained vloom model loaded from hub or local path
        if new_model:
            is_untrained_vloom_model = True
        elif is_resume:
            is_pretrained_vloom_model_resumed = True
        else:
            is_pretrained_vloom_model_from_hub_or_path = True

        # torch_dtype is crucial for using the minimal amount of memory at load time
        torch_dtype = kwargs.get("torch_dtype", None)

        # config is:
        # 1. either not passed, in which case we use the model's default config (used by tests)
        # 2. or passed, in which case it's one of:
        #   2a. a `PretrainedConfig` object (a new m4 model)
        #   2b. a path to a json config (an already pretrained m4 model, usually resumed training)
        config = kwargs.get("config", None)
        if config is None:
            config = cls.config_class.from_pretrained(*model_args, **kwargs, return_unused_kwargs=False)
        elif not isinstance(config, PretrainedConfig):
            # adapted from https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/modeling_utils.py#L1920
            assert isinstance(config, os.PathLike)
            config_path = str(config)
            config = cls.config_class.from_pretrained(
                config_path,
                return_unused_kwargs=False,
                **kwargs,
            )

        vision_model_name = config.vision_model_name
        vision_model_params = eval(config.vision_model_params)

        # 1. create an uninitialized vision_model to insert into the main model.
        # It has to be created outside lm's `from_pretrained` and w/o zero.Init so that zero3+zero.Init works
        with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
            vision_model_config = AutoConfig.from_pretrained(vision_model_name, **vision_model_params)
            vision_model_from_config = AutoModel.from_config(vision_model_config, torch_dtype=torch_dtype)
        # this extracts the desired submodule if the part we want is nested (e.g. as in clip)
        kwargs["vision_model"] = vision_model_name_to_model(vision_model_name, vision_model_from_config)

        # 2. create the vloom model
        if is_untrained_vloom_model or is_pretrained_vloom_model_from_hub_or_path:
            model = super().from_pretrained(*model_args, **kwargs)
        elif is_pretrained_vloom_model_resumed:
            # in the case of resume under deepspeed we create an empty model and get deepspeed
            # to load the weights from the checkpoint; the empty `state_dict` ensures no weights
            # get loaded here. `config` may or may not be present in kwargs, so pop it defensively
            # and pass the resolved config explicitly.
            _ = kwargs.pop("config", None)
            model = super().from_pretrained(None, config=config, state_dict={}, **kwargs)

        # 3. if is_untrained_vloom_model, now override the uninitialized vision_model with one carrying
        # the pretrained weights (`override_vision_model_wrapper` is expected to be provided by the
        # concrete model subclass, delegating to `override_vision_model` above)
        if is_untrained_vloom_model:
            cls.override_vision_model_wrapper(model, config, vision_model_name, vision_model_params, torch_dtype)

        return model
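
    # Usage sketch (hypothetical subclass and checkpoint names, for illustration only):
    #
    #   # a brand-new vloom model composed of 2 pretrained models (language model + vision model):
    #   model = MyVLoom.from_pretrained_models("path/to/lm", config=config, torch_dtype=torch.bfloat16)
    #
    #   # resuming a vloom checkpoint under deepspeed (weights are loaded by deepspeed, not here);
    #   # note that a config path must be an os.PathLike, e.g. pathlib.Path, not a plain str:
    #   model = MyVLoom.from_pretrained("path/to/checkpoint", is_resume=True, config=Path("path/to/config.json"))
    #
    #   # loading an already pretrained vloom model from the hub or a local path:
    #   model = MyVLoom.from_pretrained("org/model-name")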


class DecoupledEmbedding(nn.Embedding):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings.
    In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if
    `num_additional_embeddings > 0`, then it will create `num_additional_embeddings` additional,
    always-trained parameters.
    If `num_additional_embeddings=0`, the module defaults back to the regular behavior of `nn.Embedding`.
    """

    def __init__(
        self,
        num_embeddings,
        num_additional_embeddings,
        embedding_dim,
        partially_freeze=False,
        device=None,
        dtype=None,
        padding_idx=None,
        **kwargs,
    ) -> None:
        """
        num_additional_embeddings: int. Number of additional embeddings. Only useful when `partially_freeze=True`.
        partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_embedding.weight` is never frozen.

        Note: most other parameters of a standard `nn.Embedding`, such as `max_norm` or `norm_type`, are
        not supported here; `padding_idx` is validated and passed through.
        """
        if padding_idx is not None and padding_idx >= num_embeddings:
            raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")
        super().__init__(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
            padding_idx=padding_idx,
            **kwargs,
        )
        self.num_embeddings = num_embeddings
        self.padding_idx = padding_idx
        self.num_additional_embeddings = num_additional_embeddings
        self.partially_freeze = partially_freeze

        if partially_freeze:
            self.weight.requires_grad_(False)

        if self.num_additional_embeddings > 0:
            self.additional_embedding = nn.Embedding(
                num_embeddings=self.num_additional_embeddings,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )

    def forward(self, input_ids):
        """
        we have 2 embeddings, with different indices - one pretrained self.weight and another
        self.additional_embedding.weight that is being trained.

        in order to make a lookup of the input ids, we:
        1. find out the indices of the entries belonging to the 2nd embedding
        2. extract those values while subtracting the size of the first embedding (num_embeddings),
           since the 2nd embedding starts from 0 and not num_embeddings
        3. perform the 2nd embedding lookup
        4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
        5. perform the 1st embedding lookup
        6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup

        note: for the 1st embedding lookup we could have looked up only the low indices and not do
        the padding, but then we have to create a new tensor and populate it with 2 tensors that are
        spread out across various indices - i.e. not a simple concat - I haven't benchmarked the
        complex case if it's any faster, given that seqlens are usually relatively short it's
        probably not faster or if faster not by much - but might be a good idea to measure.

        """
        if self.num_additional_embeddings == 0:
            return F.embedding(input_ids, self.weight)

        # Clone so that we don't modify the original input_ids later on
        input_ids = input_ids.clone()
        additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
        input_ids_additional_vocab = input_ids[additional_vocab_indices]
        additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)

        # for a successful lookup, replace the out-of-range ids with 0 - the results of these
        # lookups will be discarded anyway
        input_ids[additional_vocab_indices] = 0
        full_vector = F.embedding(input_ids, self.weight)

        # overwrite the records with high indices
        full_vector[additional_vocab_indices] = additional_embeddings

        return full_vector

    def extra_repr(self) -> str:
        return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
            self.num_embeddings,
            self.num_additional_embeddings,
            self.embedding_dim,
            self.partially_freeze,
        )

    @classmethod
    def from_pretrained(cls, embeddings, freeze=True, **kwargs):
        raise NotImplementedError


class DecoupledLinear(nn.Linear):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters.
    In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if
    `out_additional_features > 0`, then it will create `out_additional_features * in_features` additional,
    always-trained parameters.
    If `out_additional_features=0`, the module defaults back to the regular behavior of `nn.Linear`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        out_additional_features: int = 0,
        bias: bool = True,
        partially_freeze: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        out_additional_features: int. Number of additional trainable dimensions. Only makes sense when `partially_freeze=True`.
        partially_freeze: bool. If True, the regular `weight` (and `bias`, if any) will be frozen and the extra parameters (if any) will be trainable. If False, the module defaults to the regular behavior of `nn.Linear`.
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        self.out_additional_features = out_additional_features
        self.partially_freeze = partially_freeze

        self.in_features = in_features
        self.out_features = out_features

        if partially_freeze:
            self.weight.requires_grad_(False)
            if bias:
                self.bias.requires_grad_(False)

        if out_additional_features > 0:
            self.additional_fc = nn.Linear(
                in_features=in_features,
                out_features=out_additional_features,
                bias=bias,
                device=device,
                dtype=dtype,
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = F.linear(input, self.weight, self.bias)

        if self.out_additional_features > 0:
            additional_features = F.linear(input, self.additional_fc.weight, self.additional_fc.bias)
            output = torch.cat((output, additional_features), -1)

        return output

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
            self.in_features,
            self.out_features,
            self.out_additional_features,
            self.bias is not None,
            self.partially_freeze,
        )


if __name__ == "__main__":
    emb = DecoupledEmbedding(num_embeddings=10, num_additional_embeddings=3, embedding_dim=5, partially_freeze=True)
    for n, p in emb.named_parameters():
        print(n, p.requires_grad)
    idx = torch.tensor([[11, 1, 3]])
    y = emb(idx)
    loss = y.sum()
    loss.backward()
    print(emb.weight, emb.weight.grad)
    print(emb.additional_embedding.weight, emb.additional_embedding.weight.grad)

    lin = DecoupledLinear(in_features=3, out_features=4, out_additional_features=2, bias=True, partially_freeze=True)
    for n, p in lin.named_parameters():
        print(n, p.requires_grad)
    x = torch.randn(12, 3)
    y = lin(x)
    loss = y.sum()
    loss.backward()
    print("Weight w and grad:", lin.weight, lin.weight.grad)
    print("bias w and grad:", lin.bias, lin.bias.grad)
    print("additional_fc.weight w and grad:", lin.additional_fc.weight, lin.additional_fc.weight.grad)
    print("additional_bias w and grad:", lin.additional_fc.bias, lin.additional_fc.bias.grad)