# mistral-nanotron / convert_trfrs_to_brrr.py
# Source: Hugging Face file page (author: thomwolf, HF staff),
# commit 0c6f487 "add pretrained model", file size 9.72 kB.
# ruff: noqa: E402
"""
This module converts a transformers MistralForCausalLM checkpoint to a brrr model
Command:
torchrun --nproc_per_node=1 convert_trfrs_to_brrr.py \
--model_name mistralai/Mistral-7B-v0.1 \
--save_path ./pretrained/Mistral-7B-v0.1
"""
import argparse
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Dict, List
import torch
from brrr.trainer import DistributedTrainer
sys.path.append(Path(__file__).parent.parent.as_posix())
import os
from nanotron.parallel.parameters import NanotronParameter, sanity_check
from nanotron.parallel.pipeline_parallel.engine import (
AllForwardAllBackwardPipelineEngine,
)
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
from transformers import MistralConfig as MistralConfig_trfs, MistralForCausalLM
import nanotron.distributed as dist
from nanotron.config import ParallelismArgs, RecomputeGranularity
from nanotron.parallel.context import ParallelContext
from nanotron.models import build_model
from nanotron.trainer import mark_tied_parameters
from nanotron.serialize import save_meta, save_weights, save
from modeling_mistral import MistralForTraining
from config_mistral_7b import PARALLELISM as PARALLELISM_BRRR, CONFIG as CONFIG_BRRR
def get_args():
    """Parse the conversion script's command-line options.

    Returns:
        argparse.Namespace with ``model_name`` and ``save_path`` (str) and the
        ``dp``/``pp``/``tp`` parallelism sizes (int, each defaulting to 1).
    """
    cli = argparse.ArgumentParser(description="Convert transformers weights to brrr weights")
    cli.add_argument("--model_name", type=str, default="mistralai/Mistral-7B-v0.1")
    cli.add_argument("--save_path", type=str, default="pretrained/Mistral-7B-v0.1")
    # Data-, pipeline- and tensor-parallel sizes share the same definition.
    for axis in ("dp", "pp", "tp"):
        cli.add_argument(f"--{axis}", type=int, default=1)
    return cli.parse_args()
def permute_for_rotary(tensor, num_heads, per_head_hidden_size, hidden_size):
    """Interleave the two halves of each head's rows in a projection weight.

    Within every head, the rows are split into a first and a second half and
    re-ordered so that row ``i`` of the first half is immediately followed by
    row ``i`` of the second half. In this script it is applied to the
    transformers q and k projection weights before they are fused into
    brrr's ``qkv_proj``.

    Args:
        tensor: Weight with ``num_heads * per_head_hidden_size`` rows and
            ``hidden_size`` columns.
        num_heads: Number of heads stacked along the first dimension.
        per_head_hidden_size: Rows per head (must be even for the split).
        hidden_size: Size of the second dimension.

    Returns:
        Tensor of shape ``(num_heads * per_head_hidden_size, hidden_size)``.
    """
    half = per_head_hidden_size // 2
    # (heads, half-index, row-within-half, hidden)
    split_heads = tensor.view(num_heads, 2, half, hidden_size)
    # Swap the half-index and row-within-half axes to interleave the halves.
    interleaved = split_heads.permute(0, 2, 1, 3)
    return interleaved.reshape(num_heads * per_head_hidden_size, hidden_size)
def get_transformers_weight(
    name: str,
    ref_module_state_dict: Dict[str, torch.Tensor],
    ref_module: "MistralForCausalLM",
    get_grad: bool = False,
) -> torch.Tensor:
    """From our brrr implementation, get the equivalent tensor in the transformers implementation.

    Args:
        name: Fully-qualified brrr parameter name, e.g.
            ``model.decoder.0.pp_block.attn.qkv_proj.weight``. It must start with
            ``model.`` and contain a ``pp_block`` component.
        ref_module_state_dict: State dict of ``ref_module``; read when
            ``get_grad`` is false.
        ref_module: The transformers reference model. Its ``config`` is always
            read; with ``get_grad`` its parameters' ``.grad`` are returned instead.
        get_grad: If true, return the gradient(s) of the matching transformers
            parameter(s) instead of the weight(s).

    Returns:
        The matching tensor. brrr's fused ``qkv_proj`` and ``gate_up_proj`` are
        rebuilt by concatenating the separate transformers projections along
        dim 0; q and k are first reordered with ``permute_for_rotary``.

    Raises:
        ValueError: If ``name`` is malformed or has no transformers equivalent.
    """
    config = ref_module.config
    brrr_prefix = "model."
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not name.startswith(brrr_prefix):
        raise ValueError(f"Expected a brrr parameter name starting with {brrr_prefix!r}, got {name!r}")
    path = name[len(brrr_prefix) :].split(".")
    if "pp_block" not in path:
        raise ValueError(f"Expected a 'pp_block' component in brrr parameter name {name!r}")
    path.remove("pp_block")
    name = ".".join(path)

    def get_tensor(trfs_name: str) -> torch.Tensor:
        # Either the stored weight or the gradient of the live parameter.
        if not get_grad:
            return ref_module_state_dict[trfs_name]
        return ref_module.get_parameter(trfs_name).grad

    def get_tensors(trfs_names: List[str]) -> List[torch.Tensor]:
        return [get_tensor(trfs_name) for trfs_name in trfs_names]

    # Top-level (non-decoder) parameters map one-to-one.
    if name == "token_position_embeddings.token_embedding.weight":
        return get_tensor("model.embed_tokens.weight")
    elif name == "lm_head.weight":
        # This is only used when weights are not shared
        return get_tensor("lm_head.weight")
    elif name == "final_layer_norm.weight":
        return get_tensor("model.norm.weight")

    if path[0] != "decoder":
        raise ValueError(f"Couldn't find transformer equivalent of {name}")

    transformer_path = ["model", "layers", path[1]]
    # brrr submodule names -> transformers submodule names
    if path[2] == "attn":
        path[2] = "self_attn"
    if path[2] == "ff":
        path[2] = "mlp"

    if path[3] == "qkv_proj":
        tensor_list = get_tensors(
            [
                ".".join(transformer_path + path[2:3] + [proj_name] + path[4:])
                for proj_name in ("q_proj", "k_proj", "v_proj")
            ]
        )
        # Some transformers configs expose an explicit `head_dim`; otherwise derive it.
        per_head_hidden_size = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        # Permute q
        print(f"Permuting q {tensor_list[0].shape}")
        tensor_list[0] = permute_for_rotary(
            tensor=tensor_list[0],
            num_heads=config.num_attention_heads,
            per_head_hidden_size=per_head_hidden_size,
            hidden_size=config.hidden_size,
        )
        # Permute k: it has config.num_key_value_heads heads, not num_attention_heads
        print(f"Permuting k {tensor_list[1].shape}")
        tensor_list[1] = permute_for_rotary(
            tensor=tensor_list[1],
            num_heads=config.num_key_value_heads,
            per_head_hidden_size=per_head_hidden_size,
            hidden_size=config.hidden_size,
        )
        return torch.cat(tensor_list, dim=0)

    if path[3] == "gate_up_proj":
        tensor_list = get_tensors(
            [
                ".".join(transformer_path + path[2:3] + [proj_name] + path[4:])
                for proj_name in ("gate_proj", "up_proj")
            ]
        )
        return torch.cat(tensor_list, dim=0)

    # Everything else (layer norms, o_proj, down_proj, ...) maps one-to-one.
    return get_tensor(".".join(transformer_path + path[2:]))
def convert_trfrs_to_brrr(dp, pp, tp, model_name="huggyllama/llama-7b", save_path="pretrained/llama-7b"):
    """Convert a transformers Mistral checkpoint into a brrr checkpoint.

    Steps:
    1. Build a ``MistralForTraining`` model from ``CONFIG_BRRR`` on CPU, using the
       requested parallelism.
    2. Load ``model_name`` with ``MistralForCausalLM.from_pretrained``.
    3. Copy every local brrr parameter from its transformers equivalent, slicing
       TP-sharded parameters.
    4. Save the checkpoint and a ``model_config.json`` under ``save_path``.

    Args:
        dp: Data-parallel size.
        pp: Pipeline-parallel size.
        tp: Tensor-parallel size.
        model_name: Name or path passed to ``MistralForCausalLM.from_pretrained``.
            NOTE(review): these defaults point at a Llama-7B checkpoint, while
            ``get_args`` defaults to ``mistralai/Mistral-7B-v0.1``. ``main`` always
            passes explicit values, but direct callers should too.
        save_path: Output directory for the checkpoint.
    """
    # check save_path doesn't exist or is empty (check currently disabled)
    save_path = Path(save_path)
    # assert not save_path.exists() or len(list(save_path.iterdir())) == 0, f"save_path {save_path} is not empty"
    # NOTE(review): this mutates the module-level PARALLELISM_BRRR object in place. That is
    # presumably intentional, so that CONFIG_BRRR (saved below) records the requested
    # parallelism. TODO confirm CONFIG_BRRR references this same object.
    parallel_config = PARALLELISM_BRRR
    parallel_config.dp = dp
    parallel_config.pp = pp
    parallel_config.tp = tp
    # Initialise all process groups
    parallel_context = ParallelContext(
        data_parallel_size=parallel_config.dp,
        pipeline_parallel_size=parallel_config.pp,
        tensor_parallel_size=parallel_config.tp,
    )
    # params
    dtype = torch.bfloat16  # Flash attention doesn't support fp32
    # Initialise brrr model on CPU from the imported model config
    model_config_brrr = CONFIG_BRRR.model.model_config
    model = build_model(
        model_builder=lambda: MistralForTraining(
            config=model_config_brrr,
            parallel_context=parallel_context,
            parallel_config=parallel_config,
            random_states=None,
        ),
        dtype=dtype,
        parallel_context=parallel_context,
        device=torch.device("cpu"),
    )
    # Initialise transformers model.
    # Each transformers module is mapped to the rank of the matching brrr block when that
    # block lives on this process's pipeline rank, and to "meta" otherwise.
    # NOTE(review): the integer values are pipeline ranks; transformers appears to treat an
    # int device_map value as a device index, so confirm this matches the local GPU layout.
    device_map = {}
    current_pp_rank = dist.get_rank(group=parallel_context.pp_pg)
    device_map["model.embed_tokens"] = (
        model.model.token_position_embeddings.rank
        if current_pp_rank == model.model.token_position_embeddings.rank
        else "meta"
    )
    for i in range(model_config_brrr.num_hidden_layers):
        device_map[f"model.layers.{i}"] = (
            model.model.decoder[i].rank if current_pp_rank == model.model.decoder[i].rank else "meta"
        )
    device_map["model.norm"] = (
        model.model.final_layer_norm.rank if current_pp_rank == model.model.final_layer_norm.rank else "meta"
    )
    device_map["lm_head"] = model.model.lm_head.rank if current_pp_rank == model.model.lm_head.rank else "meta"
    model_ref = MistralForCausalLM.from_pretrained(model_name, torch_dtype=dtype, device_map=device_map)
    # Copy weights from trfrs to brrr
    ref_state_dict = model_ref.state_dict()
    for name, param in model.named_parameters():
        print(f"Syncing {name}")
        ref_param = get_transformers_weight(name=name, ref_module_state_dict=ref_state_dict, ref_module=model_ref)
        # Only parameters sharded across the TP process group need slicing; others are copied whole.
        param_is_tp_sharded = (
            isinstance(param, NanotronParameter)
            and param.is_sharded
            and parallel_context.world_ranks_to_pg[param.get_sharded_info().global_ranks] == parallel_context.tp_pg
        )
        if param_is_tp_sharded:
            sharded_info = param.get_sharded_info()
            # copy param data (not just the reference)
            with torch.no_grad():
                # Fill each local slice from the corresponding slice of the full transformers tensor.
                for local_global_slices_pair in sharded_info.local_global_slices_pairs:
                    local_slices = local_global_slices_pair.local_slices
                    global_slices = local_global_slices_pair.global_slices
                    param[local_slices].copy_(ref_param[global_slices])
        else:
            assert (
                ref_param.shape == param.shape
            ), f"Parameter shape don't match for {name}\n{ref_param.shape} != {param.shape}"
            # copy param data (not just the reference)
            with torch.no_grad():
                param.copy_(ref_param)
        # Drop this loop's reference (ref_state_dict still holds the tensor)
        ref_param = None
        # torch.cuda.empty_cache()
    # TODO @nouamanetazi: assert weights are the same
    # Marks parameters as NanotronParameters
    mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)
    sanity_check(root_module=model)
    checkpoint_metadata = {
        "last_train_step": 0,
        "consumed_train_samples": 0,
    }
    # Save model weights plus CONFIG_BRRR; no optimizer or LR scheduler state is written.
    save(config=CONFIG_BRRR, model=model, optimizer=None, lr_scheduler=None, parallel_context=parallel_context, root_folder=save_path,
         should_save_optimizer=False, should_save_lr_scheduler=False, checkpoint_metadata=checkpoint_metadata,
         sanity_checks=False)
    # save_weights(model=model, parallel_context=parallel_context, root_folder=save_path)
    # save_meta(root_folder=save_path, parallel_context=parallel_context, checkpoint_metadata=checkpoint_metadata)
    # Only world rank 0 writes the standalone model config JSON.
    if dist.get_rank(parallel_context.world_pg) == 0:
        print(save_path)
        import json
        with open(save_path / "model_config.json", mode="w") as fo:
            fo.write(json.dumps(asdict(CONFIG_BRRR.model.model_config), indent=4))
def main():
    """Script entry point: parse the CLI options and run the conversion with them."""
    cli_options = vars(get_args())
    convert_trfrs_to_brrr(**cli_options)


if __name__ == "__main__":
    main()