File size: 14,186 Bytes
c668e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
import os
import torch
import re
from collections import deque
from onmt.utils.logging import logger
from onmt.inputters.inputter import vocabs_to_dict
from onmt.modules.lora import lora_state_dict


def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
    """Create the filesystem ``ModelSaver`` configured by the training options.

    The parent directory of ``opt.save_model`` is created beforehand (no error
    if it already exists) so later checkpoint writes have somewhere to go.

    Args:
        model_opt: model options stored inside every checkpoint.
        opt: training options; ``save_model``, ``keep_checkpoint`` and
            ``save_format`` are read from it.
        model: the model to save.
        vocabs: vocabularies stored inside every checkpoint.
        optim: optimizer whose state is stored inside every checkpoint.
        device_id: forwarded unchanged to ``ModelSaver``.

    Returns:
        ModelSaver: the configured saver.
    """
    # Equivalent of the old `_check_save_model_path` helper.
    target_dir = os.path.dirname(os.path.abspath(opt.save_model))
    os.makedirs(target_dir, exist_ok=True)

    return ModelSaver(
        opt.save_model,
        model,
        model_opt,
        vocabs,
        optim,
        opt.keep_checkpoint,
        opt.save_format,
        device_id,
    )


# Legacy custom layer norm stored its parameters as `a_2` (gain) / `b_2` (bias).
_LAYER_NORM_BIAS_RE = re.compile(r"(.*)\.layer_norm((_\d+)?)\.b_2")
_LAYER_NORM_WEIGHT_RE = re.compile(r"(.*)\.layer_norm((_\d+)?)\.a_2")


def _fix_layer_norm_key(key):
    """Rename legacy layer-norm parameter names in the state-dict `key`.

    ``<prefix>.layer_norm[_N].b_2`` becomes ``<prefix>.layer_norm[_N].bias`` and
    ``<prefix>.layer_norm[_N].a_2`` becomes ``<prefix>.layer_norm[_N].weight``.
    Keys without these patterns are returned unchanged.
    """
    key = _LAYER_NORM_BIAS_RE.sub(r"\1.layer_norm\2.bias", key)
    return _LAYER_NORM_WEIGHT_RE.sub(r"\1.layer_norm\2.weight", key)


def load_checkpoint(ckpt_path):
    """Load checkpoint from `ckpt_path` if any else return `None`.

    The checkpoint is loaded onto the CPU and patched in place so that files
    written by older versions keep working:

    * legacy layer-norm parameter names (``a_2``/``b_2``) are renamed to
      ``weight``/``bias``;
    * ``opt.add_ffnbias`` is forced to ``True`` if any model key contains
      ``w_1.bias``;
    * option attributes missing from ``checkpoint["opt"]`` get defaults;
    * v2-style generator keys (``0.weight``/``0.bias``) are renamed to
      ``weight``/``bias``.

    Args:
        ckpt_path (str): path to a file written by ``torch.save``. A falsy
            value means "no checkpoint".

    Returns:
        dict or None: the patched checkpoint, or ``None`` if `ckpt_path` is
        falsy.
    """
    if not ckpt_path:
        return None

    logger.info("Loading checkpoint from %s", ckpt_path)
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))

    if "model" in checkpoint:
        # This preserves backward-compat for models using customed layernorm
        checkpoint["model"] = {
            _fix_layer_norm_key(k): v for k, v in checkpoint["model"].items()
        }
        # Force add_ffnbias to True if bias found in model w_1 keys
        if any("w_1.bias" in key for key in checkpoint["model"]):
            checkpoint["opt"].add_ffnbias = True

    opt = checkpoint["opt"]
    # Defaults for options that older checkpoints may not carry. Built on
    # every call so the mutable list default is never shared across loads.
    legacy_defaults = {
        "num_kv": 0,
        "add_ffnbias": False,
        "parallel_residual": False,
        "shared_layer_norm": False,
        "use_ckpting": [],
        "relative_positions_buckets": 0,
        "parallel_mode": "data_parallel",
        "norm_eps": 1e-6,
    }
    for name, default in legacy_defaults.items():
        if not hasattr(opt, name):
            setattr(opt, name, default)

    # fix v2 compatibility: generator weights used to live under "0."
    generator = checkpoint.get("generator")
    if generator:
        for legacy_key, new_key in (("0.weight", "weight"), ("0.bias", "bias")):
            if legacy_key in generator:
                generator[new_key] = generator.pop(legacy_key)
    # end of patch for backward compatibility

    return checkpoint


class ModelSaverBase(object):
    """Base class for model saving operations

    Inherited classes must implement private methods:
    * `_save` (used when ``save_format == "pytorch"``)
    * `_st_save` (used when ``save_format == "safetensors"``)
    * `_rm_checkpoint`

    Args:
        base_path (str): prefix used by subclasses to name checkpoints.
        model: model whose parameters are saved.
        model_opt: model options stored alongside the weights.
        vocabs: vocabularies stored alongside the weights.
        optim: optimizer whose state is stored alongside the weights.
        keep_checkpoint (int): number of most recent checkpoints kept on
            disk. ``0`` disables saving entirely; a negative value keeps
            every checkpoint.
        save_format (str): ``"pytorch"`` or ``"safetensors"``.
        device_id (int): stored on the instance; not used by this class.
    """

    def __init__(
        self,
        base_path,
        model,
        model_opt,
        vocabs,
        optim,
        keep_checkpoint=-1,
        save_format="pytorch",
        device_id=0,
    ):
        self.base_path = base_path
        self.model = model
        self.model_opt = model_opt
        self.vocabs = vocabs
        self.optim = optim
        self.last_saved_step = None
        self.keep_checkpoint = keep_checkpoint
        self.save_format = save_format
        self.device_id = device_id

        # Rotation queues only exist when old checkpoints must be pruned.
        if keep_checkpoint > 0:
            self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
            if save_format == "safetensors":
                self.model_queue = deque([], maxlen=keep_checkpoint)

    def save(self, step, moving_average=None):
        """Main entry point for model saver

        It wraps the `_save` / `_st_save` methods with checks and applies the
        `keep_checkpoint` related logic.

        Args:
            step (int): training step; a repeated step is not saved twice.
            moving_average: optional sequence of tensors whose ``.data`` is
                temporarily swapped into the model parameters (in
                ``model.parameters()`` order) for the duration of the save.

        Raises:
            ValueError: if ``save_format`` is not a supported format.
        """
        if self.keep_checkpoint == 0 or step == self.last_saved_step:
            return

        # Resolve the writer first so an unsupported format fails before the
        # model is touched (previously this surfaced as UnboundLocalError).
        if self.save_format == "pytorch":
            save_fn = self._save
        elif self.save_format == "safetensors":
            save_fn = self._st_save
        else:
            raise ValueError("Unsupported save_format: %r" % (self.save_format,))

        save_model = self.model
        swapped = []  # (param, original data) pairs to restore afterwards
        try:
            if moving_average:
                for avg, param in zip(moving_average, save_model.parameters()):
                    swapped.append((param, param.data))
                    param.data = avg.data
            ckpt_path, model_path = save_fn(step, save_model)
        finally:
            # Always put the live training weights back, even if saving failed.
            for param, original_data in swapped:
                param.data = original_data

        self.last_saved_step = step

        # Subclasses return None on processes that did not write files
        # (non-zero ranks); there is nothing to rotate there.
        if ckpt_path is None or self.keep_checkpoint < 0:
            return
        if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
            self._rm_checkpoint(self.checkpoint_queue.popleft())
            if self.save_format == "safetensors":
                self._rm_checkpoint(self.model_queue.popleft())
        self.checkpoint_queue.append(ckpt_path)
        if self.save_format == "safetensors":
            self.model_queue.append(model_path)

    def _save(self, step, model):
        """Save a resumable checkpoint in the ``pytorch`` format.

        Args:
            step (int): step number
            model (nn.Module): torch model to save

        Returns:
            (str, str):

            * checkpoint_name: name (or path) of the saved checkpoint, or
              ``None`` if this process wrote nothing
            * model_name: name (or path) of the saved safetensors weights if
              applicable
        """

        raise NotImplementedError()

    def _st_save(self, step, model):
        """Save a resumable checkpoint in the ``safetensors`` format.

        Same arguments and return value as `_save`.
        """

        raise NotImplementedError()

    def _rm_checkpoint(self, name):
        """Remove a checkpoint

        Args:
            name(str): name that identifies the checkpoint
                (it may be a filepath)
        """

        raise NotImplementedError()


class ModelSaver(ModelSaverBase):
    """Simple model saver to filesystem

    Writes ``<base_path>_step_<step>.pt`` (and, for the safetensors format,
    ``<base_path>_step_<step>.safetensors``) from the rank-0 process only;
    other ranks take part in the collectives and return ``None`` paths.
    """

    # Sub-modules whose tensor-parallel shards are concatenated along dim 0
    # to rebuild the full tensor (both weight and bias).
    _TP_DIM0_MODULES = (
        "linear_keys",
        "linear_values",
        "linear_query",
        "w_1",
        "w_3",
    )
    # Sub-modules whose tensor-parallel weight shards are concatenated along
    # dim 1; their bias is taken from rank 0.
    # NOTE(review): assumes that bias is replicated on every rank — confirm.
    _TP_DIM1_MODULES = ("final_linear", "w_2")

    def _use_lora(self):
        """Return True when the model options enable LoRA adapters."""
        opt = self.model_opt
        if hasattr(opt, "lora_layers") and len(opt.lora_layers) > 0:
            return True
        return bool(getattr(opt, "lora_embedding", False))

    @staticmethod
    def _is_main_process():
        """Return True on the process that writes files (rank 0 or no dist)."""
        return (
            not torch.distributed.is_initialized()
            or torch.distributed.get_rank() == 0
        )

    @staticmethod
    def _barrier():
        """Synchronize all ranks when torch.distributed is initialized."""
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

    def _full_state_dict(self, model_state_dict):
        """Return `model_state_dict`, merged across ranks under tensor parallelism.

        Collective: under tensor parallelism every rank must call it.
        """
        if not torch.distributed.is_initialized():
            return model_state_dict
        world_size = torch.distributed.get_world_size()
        parallel_mode = getattr(self.model_opt, "parallel_mode", "data_parallel")
        # Only tensor-parallel ranks hold shards. Merging replicas from a
        # data-parallel run would concatenate identical full tensors.
        if world_size <= 1 or parallel_mode != "tensor_parallel":
            return model_state_dict
        # Move tensors to CPU (in place) before the object all-gather.
        for key, value in model_state_dict.items():
            model_state_dict[key] = value.cpu()
        shards = [None for _ in range(world_size)]
        torch.distributed.all_gather_object(shards, model_state_dict)
        return self._merge_tp_shards(shards)

    @classmethod
    def _merge_tp_shards(cls, shards):
        """Merge per-rank state dicts (ordered by rank) into one state dict.

        The rules are based on the sub-module that owns each parameter (the
        second-to-last key segment):

        * LoRA ``lora_A``: averaged for dim-0 modules, concatenated on dim 1
          for dim-1 modules.
        * LoRA ``lora_B``: concatenated on dim 0 for dim-0 modules, averaged
          for dim-1 modules.
        * ``weight``/``bias`` of dim-0 modules: concatenated on dim 0.
        * ``weight`` of dim-1 modules: concatenated on dim 1.
        * Anything else (including LoRA factors on other modules, which
          were previously dropped): the rank-0 tensor.
        """

        def _cat(key, dim):
            return torch.cat([shard[key].cpu() for shard in shards], dim)

        def _mean(key):
            return sum([shard[key].cpu() for shard in shards]) / len(shards)

        merged = {}
        for key in shards[0]:
            module_path, _, leaf = key.rpartition(".")
            module = module_path.rpartition(".")[2]
            dim0 = module in cls._TP_DIM0_MODULES
            dim1 = module in cls._TP_DIM1_MODULES
            if leaf == "lora_A" and dim0:
                merged[key] = _mean(key)
            elif leaf == "lora_A" and dim1:
                merged[key] = _cat(key, 1)
            elif leaf == "lora_B" and dim0:
                merged[key] = _cat(key, 0)
            elif leaf == "lora_B" and dim1:
                merged[key] = _mean(key)
            elif dim0 and leaf in ("weight", "bias"):
                merged[key] = _cat(key, 0)
            elif dim1 and leaf == "weight":
                merged[key] = _cat(key, 1)
            else:
                merged[key] = shards[0][key]
        return merged

    def _save(self, step, model):
        """Write weights, vocab, options and optimizer state to one ``.pt`` file.

        With LoRA enabled only the LoRA state dict is saved and the generator
        entry is ``None``. Otherwise generator weights are stored separately
        under ``"generator"``.

        Returns:
            (str or None, None): checkpoint path on rank 0, ``None`` elsewhere.
        """
        if self._use_lora():
            model_state_dict = lora_state_dict(model, bias="lora_only")
            generator_state_dict = None
        else:
            model_state_dict = {
                k: v for k, v in model.state_dict().items() if "generator" not in k
            }
            generator_state_dict = model.generator.state_dict()

        model_state_dict = self._full_state_dict(model_state_dict)

        checkpoint = {
            "model": model_state_dict,
            "generator": generator_state_dict,
            "vocab": vocabs_to_dict(self.vocabs),
            "opt": self.model_opt,
            "optim": self.optim.state_dict(),
        }
        if self._is_main_process():
            ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
            logger.info("Saving checkpoint %s", ckpt_path)
            torch.save(checkpoint, ckpt_path)
        else:
            ckpt_path = None
        self._barrier()
        return ckpt_path, None

    def _st_save(self, step, model):
        """Write weights to ``.safetensors`` and the rest to a ``.pt`` file.

        Returns:
            (str or None, str or None): ``.pt`` and ``.safetensors`` paths on
            rank 0, ``(None, None)`` elsewhere.

        Raises:
            ImportError: if the ``safetensors`` package is not installed.
        """
        try:
            from safetensors.torch import save_file
        except ImportError as err:
            raise ImportError(
                "run: pip install safetensors, to use safetensors"
            ) from err

        if self._use_lora():
            model_state_dict = lora_state_dict(model, bias="lora_only")
        else:
            model_state_dict = model.state_dict()

        model_state_dict = self._full_state_dict(model_state_dict)

        checkpoint = {
            "vocab": vocabs_to_dict(self.vocabs),
            "opt": self.model_opt,
            "optim": self.optim.state_dict(),
        }

        if self._is_main_process():
            ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
            logger.info("Saving checkpoint %s", ckpt_path)
            torch.save(checkpoint, ckpt_path)
            model_path = "%s_step_%d.safetensors" % (self.base_path, step)
            logger.info("Saving safetensors %s", model_path)
            save_file(model_state_dict, model_path)
        else:
            ckpt_path = None
            model_path = None
        self._barrier()

        return ckpt_path, model_path

    def _rm_checkpoint(self, name):
        """Delete the file `name`; a file that is already gone is ignored."""
        try:
            os.remove(name)
        except FileNotFoundError:
            pass