import os
import torch
import re
from collections import deque
from onmt.utils.logging import logger
from onmt.inputters.inputter import vocabs_to_dict
from onmt.modules.lora import lora_state_dict


def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
    save_model_path = os.path.abspath(opt.save_model)
    os.makedirs(os.path.dirname(save_model_path), exist_ok=True)

    model_saver = ModelSaver(
        opt.save_model,
        model,
        model_opt,
        vocabs,
        optim,
        opt.keep_checkpoint,
        opt.save_format,
        device_id,
    )
    return model_saver


def load_checkpoint(ckpt_path):
    """Load checkpoint from `ckpt_path` if any else return `None`."""
    checkpoint = None
    if ckpt_path:
        logger.info("Loading checkpoint from %s" % ckpt_path)
        checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
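        # Backward compatibility: older checkpoints stored LayerNorm
        # parameters as a_2 / b_2; rename them to weight / bias and infer
        # from the keys whether feed-forward biases were saved.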
        if "model" in checkpoint.keys():

            def fix_key(s):
                s = re.sub(
                    r"(.*)\.layer_norm((_\d+)?)\.b_2", r"\1.layer_norm\2.bias", s
                )
                s = re.sub(
                    r"(.*)\.layer_norm((_\d+)?)\.a_2", r"\1.layer_norm\2.weight", s
                )
                return s

            checkpoint["model"] = {
                fix_key(k): v for k, v in checkpoint["model"].items()
            }

            for key in checkpoint["model"].keys():
                if "w_1.bias" in key:
                    checkpoint["opt"].add_ffnbias = True
        if not hasattr(checkpoint["opt"], "num_kv"):
            checkpoint["opt"].num_kv = 0
        if not hasattr(checkpoint["opt"], "add_ffnbias"):
            checkpoint["opt"].add_ffnbias = False
        if not hasattr(checkpoint["opt"], "parallel_residual"):
            checkpoint["opt"].parallel_residual = False
        if not hasattr(checkpoint["opt"], "shared_layer_norm"):
            checkpoint["opt"].shared_layer_norm = False
        if not hasattr(checkpoint["opt"], "use_ckpting"):
            checkpoint["opt"].use_ckpting = []
        if not hasattr(checkpoint["opt"], "relative_positions_buckets"):
            checkpoint["opt"].relative_positions_buckets = 0
        if not hasattr(checkpoint["opt"], "parallel_mode"):
            checkpoint["opt"].parallel_mode = "data_parallel"
        if not hasattr(checkpoint["opt"], "norm_eps"):
            checkpoint["opt"].norm_eps = 1e-6
        if "generator" in checkpoint.keys() and checkpoint["generator"]:
            if "0.weight" in checkpoint["generator"]:
                checkpoint["generator"]["weight"] = checkpoint["generator"].pop(
                    "0.weight"
                )
            if "0.bias" in checkpoint["generator"]:
                checkpoint["generator"]["bias"] = checkpoint["generator"].pop("0.bias")

    return checkpoint


class ModelSaverBase(object):
    """Base class for model saving operations

    Inherited classes must implement private methods:
    * `_save`
    * `_rm_checkpoint`
    """

    def __init__(
        self,
        base_path,
        model,
        model_opt,
        vocabs,
        optim,
        keep_checkpoint=-1,
        save_format="pytorch",
        device_id=0,
    ):
        self.base_path = base_path
        self.model = model
        self.model_opt = model_opt
        self.vocabs = vocabs
        self.optim = optim
        self.last_saved_step = None
        self.keep_checkpoint = keep_checkpoint
        self.save_format = save_format
        self.device_id = device_id

        if keep_checkpoint > 0:
            self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
            if save_format == "safetensors":
                self.model_queue = deque([], maxlen=keep_checkpoint)

    def save(self, step, moving_average=None):
        """Main entry point for model saver

        It wraps the `_save` method with checks and applies the
        `keep_checkpoint` related logic
        """

        if self.keep_checkpoint == 0 or step == self.last_saved_step:
            return
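        # If a moving average is provided, temporarily swap the averaged
        # weights into the model for saving; the original parameters are
        # restored below.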
        save_model = self.model
        if moving_average:
            model_params_data = []
            for avg, param in zip(moving_average, save_model.parameters()):
                model_params_data.append(param.data)
                param.data = avg.data

        if self.save_format == "pytorch":
            ckpt_path, _ = self._save(step, save_model)
        elif self.save_format == "safetensors":
            ckpt_path, model_path = self._st_save(step, save_model)

        self.last_saved_step = step

        if moving_average:
            for param_data, param in zip(model_params_data, save_model.parameters()):
                param.data = param_data

        if ckpt_path is not None:
            if self.keep_checkpoint > 0:
                if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
                    todel = self.checkpoint_queue.popleft()
                    self._rm_checkpoint(todel)
                    if self.save_format == "safetensors":
                        todel = self.model_queue.popleft()
                        self._rm_checkpoint(todel)
                self.checkpoint_queue.append(ckpt_path)
                if self.save_format == "safetensors":
                    self.model_queue.append(model_path)

    def _save(self, step, model):
        """Save a resumable checkpoint.

        Args:
            step (int): step number
            model (nn.Module): torch model to save

        Returns:
            (str, str):

            * checkpoint_name: name (or path) of the saved checkpoint
            * model_name: name (or path) of the saved safetensors weights,
              if applicable
        """

        raise NotImplementedError()

    def _rm_checkpoint(self, name):
        """Remove a checkpoint

        Args:
            name(str): name that identifies the checkpoint
                (it may be a filepath)
        """

        raise NotImplementedError()


class ModelSaver(ModelSaverBase):
    """Simple model saver to filesystem"""

    def _save(self, step, model):
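        # When LoRA layers or LoRA embeddings are enabled, only the LoRA
        # state dict is checkpointed (bias="lora_only"); otherwise the full
        # state dict is saved, with the generator stored separately.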
        if (
            hasattr(self.model_opt, "lora_layers")
            and len(self.model_opt.lora_layers) > 0
        ) or (
            hasattr(self.model_opt, "lora_embedding") and self.model_opt.lora_embedding
        ):
            model_state_dict = lora_state_dict(model, bias="lora_only")
            generator_state_dict = None
        else:
            model_state_dict = model.state_dict()
            model_state_dict = {
                k: v for k, v in model_state_dict.items() if "generator" not in k
            }
            generator_state_dict = model.generator.state_dict()

        if torch.distributed.is_initialized():
            ws = torch.distributed.get_world_size()
        else:
            ws = 1
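        # Multi-rank training: gather each rank's (possibly sharded) state
        # dict and rebuild full tensors. Shards of linear_keys /
        # linear_values / linear_query / w_1 / w_3 are concatenated along
        # dim 0 and those of final_linear / w_2 along dim 1; LoRA A/B
        # factors are concatenated along their sharded dimension or
        # averaged when they are replicated across ranks.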
        if ws > 1:
            full_model = [None for _ in range(ws)]
            for key, value in model_state_dict.items():
                model_state_dict[key] = value.cpu()
            torch.distributed.all_gather_object(full_model, model_state_dict)
            fm_sd = {}
            for key in full_model[0].keys():
                if key.split(".")[-1] == "lora_A":
                    if key.split(".")[-2] in [
                        "linear_keys",
                        "linear_values",
                        "linear_query",
                        "w_1",
                        "w_3",
                    ]:
                        fm_sd[key] = (
                            sum([full_model[i][key].cpu() for i in range(ws)]) / ws
                        )
                    elif key.split(".")[-2] in ["final_linear", "w_2"]:
                        fm_sd[key] = torch.cat(
                            [full_model[i][key].cpu() for i in range(ws)], 1
                        )
                elif key.split(".")[-1] == "lora_B":
                    if key.split(".")[-2] in [
                        "linear_keys",
                        "linear_values",
                        "linear_query",
                        "w_1",
                        "w_3",
                    ]:
                        fm_sd[key] = torch.cat(
                            [full_model[i][key].cpu() for i in range(ws)], 0
                        )
                    elif key.split(".")[-2] in ["final_linear", "w_2"]:
                        fm_sd[key] = (
                            sum([full_model[i][key].cpu() for i in range(ws)]) / ws
                        )
                elif key.split(".")[-1] in [
                    "linear_keys",
                    "linear_values",
                    "linear_query",
                    "w_1",
                    "w_3",
                ]:
                    fm_sd[key] = torch.cat(
                        [full_model[i][key].cpu() for i in range(ws)], 0
                    )
                elif key.split(".")[-1] in ["final_linear", "w_2"]:
                    fm_sd[key] = torch.cat(
                        [full_model[i][key].cpu() for i in range(ws)], 1
                    )
                else:
                    fm_sd[key] = full_model[0][key]
            model_state_dict = fm_sd

        checkpoint = {
            "model": model_state_dict,
            "generator": generator_state_dict,
            "vocab": vocabs_to_dict(self.vocabs),
            "opt": self.model_opt,
            "optim": self.optim.state_dict(),
        }
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
            ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
            torch.save(checkpoint, ckpt_path)
        else:
            ckpt_path = None
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        return ckpt_path, None

    def _st_save(self, step, model):
        try:
            from safetensors.torch import save_file
        except ImportError:
            raise ImportError(
                "run: pip install safetensors to use the safetensors format"
            )
        if (
            hasattr(self.model_opt, "lora_layers")
            and len(self.model_opt.lora_layers) > 0
        ) or (
            hasattr(self.model_opt, "lora_embedding") and self.model_opt.lora_embedding
        ):
            model_state_dict = lora_state_dict(model, bias="lora_only")
        else:
            model_state_dict = model.state_dict()

        if torch.distributed.is_initialized():
            ws = torch.distributed.get_world_size()
        else:
            ws = 1
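        # Same shard-merging logic as in _save() above.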
        if ws > 1:
            full_model = [None for _ in range(ws)]
            for key, value in model_state_dict.items():
                model_state_dict[key] = value.cpu()
            torch.distributed.all_gather_object(full_model, model_state_dict)
            fm_sd = {}
            for key in full_model[0].keys():
                if key.split(".")[-1] == "lora_A":
                    if key.split(".")[-2] in [
                        "linear_keys",
                        "linear_values",
                        "linear_query",
                        "w_1",
                        "w_3",
                    ]:
                        fm_sd[key] = (
                            sum([full_model[i][key].cpu() for i in range(ws)]) / ws
                        )
                    elif key.split(".")[-2] in ["final_linear", "w_2"]:
                        fm_sd[key] = torch.cat(
                            [full_model[i][key].cpu() for i in range(ws)], 1
                        )
                elif key.split(".")[-1] == "lora_B":
                    if key.split(".")[-2] in [
                        "linear_keys",
                        "linear_values",
                        "linear_query",
                        "w_1",
                        "w_3",
                    ]:
                        fm_sd[key] = torch.cat(
                            [full_model[i][key].cpu() for i in range(ws)], 0
                        )
                    elif key.split(".")[-2] in ["final_linear", "w_2"]:
                        fm_sd[key] = (
                            sum([full_model[i][key].cpu() for i in range(ws)]) / ws
                        )
                elif key.split(".")[-1] in [
                    "linear_keys",
                    "linear_values",
                    "linear_query",
                    "w_1",
                    "w_3",
                ]:
                    fm_sd[key] = torch.cat(
                        [full_model[i][key].cpu() for i in range(ws)], 0
                    )
                elif key.split(".")[-1] in ["final_linear", "w_2"]:
                    fm_sd[key] = torch.cat(
                        [full_model[i][key].cpu() for i in range(ws)], 1
                    )
                else:
                    fm_sd[key] = full_model[0][key]
            model_state_dict = fm_sd
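        # In safetensors mode the .pt checkpoint only carries the vocab,
        # options and optimizer state; the weights themselves are written
        # to the .safetensors file below.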
        checkpoint = {
            "vocab": vocabs_to_dict(self.vocabs),
            "opt": self.model_opt,
            "optim": self.optim.state_dict(),
        }

        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
            ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
            torch.save(checkpoint, ckpt_path)
            logger.info(
                "Saving safetensors %s_step_%d.safetensors" % (self.base_path, step)
            )
            model_path = "%s_step_%d.safetensors" % (self.base_path, step)
            save_file(model_state_dict, model_path)
        else:
            ckpt_path = None
            model_path = None
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        return ckpt_path, model_path

    def _rm_checkpoint(self, name):
        if os.path.exists(name):
            os.remove(name)