|
|
|
""" |
|
This module converts a transformers MistralForCausalLM checkpoint into a brrr (nanotron) model checkpoint.
|
|
|
Command: |
|
torchrun --nproc_per_node=1 convert_trfrs_to_brrr.py \ |
|
--model_name mistralai/Mistral-7B-v0.1 \ |
|
--save_path ./pretrained/Mistral-7B-v0.1 |
|
""" |
|
import argparse |
|
import sys |
|
from dataclasses import asdict |
|
from pathlib import Path |
|
from typing import Dict, List |
|
|
|
import torch |
|
|
|
from brrr.trainer import DistributedTrainer |
|
|
|
sys.path.append(Path(__file__).parent.parent.as_posix()) |
|
import os |
|
|
|
from nanotron.parallel.parameters import NanotronParameter, sanity_check |
|
from nanotron.parallel.pipeline_parallel.engine import ( |
|
AllForwardAllBackwardPipelineEngine, |
|
) |
|
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode |
|
from transformers import MistralConfig as MistralConfig_trfs, MistralForCausalLM |
|
|
|
import nanotron.distributed as dist |
|
from nanotron.config import ParallelismArgs, RecomputeGranularity |
|
from nanotron.parallel.context import ParallelContext |
|
from nanotron.models import build_model |
|
from nanotron.trainer import mark_tied_parameters |
|
from nanotron.serialize import save_meta, save_weights, save |
|
|
|
from modeling_mistral import MistralForTraining |
|
from config_mistral_7b import PARALLELISM as PARALLELISM_BRRR, CONFIG as CONFIG_BRRR |
|
|
|
|
|
def get_args():
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace with ``model_name`` (passed to
        ``MistralForCausalLM.from_pretrained``), ``save_path`` (output folder) and the
        data / pipeline / tensor parallel sizes ``dp``, ``pp`` and ``tp`` (all default 1).
    """
    cli = argparse.ArgumentParser(description="Convert transformers weights to brrr weights")
    string_options = (
        ("--model_name", "mistralai/Mistral-7B-v0.1"),
        ("--save_path", "pretrained/Mistral-7B-v0.1"),
    )
    for flag, fallback in string_options:
        cli.add_argument(flag, type=str, default=fallback)
    # Parallelism layout the converted checkpoint is built for.
    for flag in ("--dp", "--pp", "--tp"):
        cli.add_argument(flag, type=int, default=1)
    return cli.parse_args()
|
|
|
|
|
def permute_for_rotary(tensor, num_heads, per_head_hidden_size, hidden_size):
    """Reorder the rows of a q/k projection weight, head by head.

    ``tensor`` is viewed as ``(num_heads, 2, per_head_hidden_size // 2, hidden_size)`` and the
    two middle axes are swapped. Within each head this interleaves the first and second half
    of its rows: rows ``0, h, 1, h + 1, ...`` with ``h = per_head_hidden_size // 2``.
    NOTE(review): presumably this matches the rotary-embedding channel layout expected by the
    brrr model versus transformers' layout — confirm against the brrr rotary implementation.

    Returns:
        A contiguous tensor of shape ``(num_heads * per_head_hidden_size, hidden_size)``.
    """
    half = per_head_hidden_size // 2
    split_heads = tensor.view(num_heads, 2, half, hidden_size)
    interleaved = split_heads.transpose(1, 2).contiguous()
    return interleaved.view(num_heads * per_head_hidden_size, hidden_size)
|
|
|
|
|
def get_transformers_weight(
    name: str, ref_module_state_dict: Dict[str, torch.Tensor], ref_module: MistralForCausalLM, get_grad: bool = False
) -> torch.Tensor:
    """From our brrr implementation, get the equivalent tensor in the transformers implementation.

    Args:
        name: Fully-qualified brrr parameter name. It must start with ``"model."`` and
            contain a ``"pp_block"`` component.
        ref_module_state_dict: State dict of ``ref_module``; read when ``get_grad`` is False.
        ref_module: Reference transformers model. Its ``config`` supplies the head counts;
            ``get_parameter`` is used to read gradients when ``get_grad`` is True.
        get_grad: If True, return the ``.grad`` of the matching reference parameter(s)
            instead of the weights.

    Returns:
        The matching tensor. Fused brrr parameters are rebuilt by concatenating the
        transformers tensors along dim 0:
        - ``qkv_proj`` = cat(q, k, v), with q and k passed through ``permute_for_rotary``;
        - ``gate_up_proj`` = cat(gate, up).

    Raises:
        AssertionError: if ``name`` does not start with ``"model."``.
        ValueError: if ``name`` has no ``pp_block`` component or no known transformers
            equivalent.
    """
    config = ref_module.config

    brrr_prefix = "model."
    assert name.startswith(brrr_prefix)
    name = name[len(brrr_prefix) :]

    # brrr nests every pipeline-stage module under a `pp_block` attribute that has no
    # transformers counterpart; drop it before mapping names.
    path = name.split(".")
    if "pp_block" not in path:
        raise ValueError(f"Expected a 'pp_block' component in brrr parameter name {name!r}")
    path.remove("pp_block")
    name = ".".join(path)

    def get_tensor(tensor_name: str):
        # Gradients are read from the live parameters; weights come from the state dict.
        if get_grad:
            return ref_module.get_parameter(tensor_name).grad
        return ref_module_state_dict[tensor_name]

    def get_tensors(tensor_names: List[str]):
        return [get_tensor(tensor_name) for tensor_name in tensor_names]

    # Non-decoder parameters map one-to-one.
    top_level_names = {
        "token_position_embeddings.token_embedding.weight": "model.embed_tokens.weight",
        "lm_head.weight": "lm_head.weight",
        "final_layer_norm.weight": "model.norm.weight",
    }
    if name in top_level_names:
        return get_tensor(top_level_names[name])

    if path[0] != "decoder" or len(path) < 3:
        raise ValueError(f"Couldn't find transformer equivalent of {name}")

    # decoder.<i>.<module>.<rest> -> model.layers.<i>.<module'>.<rest>, where brrr's
    # `attn` / `ff` modules are named `self_attn` / `mlp` in transformers.
    module_name = {"attn": "self_attn", "ff": "mlp"}.get(path[2], path[2])
    layer_prefix = ["model", "layers", path[1], module_name]
    leaf = path[3] if len(path) > 3 else None

    def fused_part_names(part_names: List[str]) -> List[str]:
        # Substitute each transformers projection name for the fused brrr projection name.
        return [".".join(layer_prefix + [part] + path[4:]) for part in part_names]

    if leaf == "qkv_proj":
        q, k, v = get_tensors(fused_part_names(["q_proj", "k_proj", "v_proj"]))

        # Some MistralConfig versions expose `head_dim`, which need not equal
        # hidden_size // num_attention_heads; fall back to that ratio when it is absent.
        per_head_hidden_size = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        print(f"Permuting q {q.shape}")
        q = permute_for_rotary(
            tensor=q,
            num_heads=config.num_attention_heads,
            per_head_hidden_size=per_head_hidden_size,
            hidden_size=config.hidden_size,
        )

        print(f"Permuting k {k.shape}")
        # k has config.num_key_value_heads heads, which may be fewer than the query heads.
        k = permute_for_rotary(
            tensor=k,
            num_heads=config.num_key_value_heads,
            per_head_hidden_size=per_head_hidden_size,
            hidden_size=config.hidden_size,
        )
        return torch.cat([q, k, v], dim=0)

    if leaf == "gate_up_proj":
        return torch.cat(get_tensors(fused_part_names(["gate_proj", "up_proj"])), dim=0)

    # Any other decoder parameter keeps its remaining path unchanged.
    return get_tensor(".".join(layer_prefix + path[3:]))
|
|
|
|
|
def _copy_reference_weight(name, param, ref_param, parallel_context):
    """Copy the full (unsharded) reference tensor ``ref_param`` into the brrr parameter ``param``.

    If ``param`` is sharded over the tensor-parallel group, only this rank's slices are copied.

    Raises:
        ValueError: on a shape mismatch. ``copy_`` would otherwise broadcast silently, and an
            ``assert`` would disappear under ``python -O``.
    """
    param_is_tp_sharded = (
        isinstance(param, NanotronParameter)
        and param.is_sharded
        and parallel_context.world_ranks_to_pg[param.get_sharded_info().global_ranks] == parallel_context.tp_pg
    )

    with torch.no_grad():
        if not param_is_tp_sharded:
            if ref_param.shape != param.shape:
                raise ValueError(f"Parameter shape don't match for {name}\n{ref_param.shape} != {param.shape}")
            param.copy_(ref_param)
            return

        for slices_pair in param.get_sharded_info().local_global_slices_pairs:
            local_view = param[slices_pair.local_slices]
            global_chunk = ref_param[slices_pair.global_slices]
            if local_view.shape != global_chunk.shape:
                raise ValueError(
                    f"Shard shape don't match for {name}\n{global_chunk.shape} != {local_view.shape}"
                )
            local_view.copy_(global_chunk)


def convert_trfrs_to_brrr(dp, pp, tp, model_name="huggyllama/llama-7b", save_path="pretrained/llama-7b"):
    """Load a transformers Mistral checkpoint, copy its weights into a brrr model and save it.

    NOTE(review): the default ``model_name``/``save_path`` point at a Llama checkpoint, while
    this function loads ``MistralForCausalLM`` and builds ``MistralForTraining``. ``main``
    always passes both explicitly; the defaults are kept for backward compatibility.

    Args:
        dp, pp, tp: Data / pipeline / tensor parallel sizes. They are written in place into
            the module-level ``PARALLELISM_BRRR`` object, which is then used to build the model.
        model_name: Id or path passed to ``MistralForCausalLM.from_pretrained``.
        save_path: Output folder for ``save`` and for ``model_config.json``.

    Raises:
        ValueError: if a reference tensor does not match the shape of the brrr parameter
            (or parameter shard) it is copied into.
    """
    import json

    save_path = Path(save_path)

    # NOTE: mutates the shared PARALLELISM_BRRR object in place (no copy is made).
    parallel_config = PARALLELISM_BRRR
    parallel_config.dp = dp
    parallel_config.pp = pp
    parallel_config.tp = tp

    parallel_context = ParallelContext(
        data_parallel_size=parallel_config.dp,
        pipeline_parallel_size=parallel_config.pp,
        tensor_parallel_size=parallel_config.tp,
    )

    dtype = torch.bfloat16
    model_config_brrr = CONFIG_BRRR.model.model_config

    # The brrr model is built on CPU; the reference weights are copied into it below.
    model = build_model(
        model_builder=lambda: MistralForTraining(
            config=model_config_brrr,
            parallel_context=parallel_context,
            parallel_config=parallel_config,
            random_states=None,
        ),
        dtype=dtype,
        parallel_context=parallel_context,
        device=torch.device("cpu"),
    )

    # Only materialize the reference submodules owned by this pipeline rank; everything
    # else is mapped to the "meta" device.
    # NOTE(review): the non-meta entries are pipeline-rank ints, which transformers treats as
    # device indices — confirm this placement is intended.
    current_pp_rank = dist.get_rank(group=parallel_context.pp_pg)

    def placement(block):
        return block.rank if current_pp_rank == block.rank else "meta"

    device_map = {"model.embed_tokens": placement(model.model.token_position_embeddings)}
    for i in range(model_config_brrr.num_hidden_layers):
        device_map[f"model.layers.{i}"] = placement(model.model.decoder[i])
    device_map["model.norm"] = placement(model.model.final_layer_norm)
    device_map["lm_head"] = placement(model.model.lm_head)
    model_ref = MistralForCausalLM.from_pretrained(model_name, torch_dtype=dtype, device_map=device_map)

    ref_state_dict = model_ref.state_dict()
    for name, param in model.named_parameters():
        print(f"Syncing {name}")
        _copy_reference_weight(
            name=name,
            param=param,
            ref_param=get_transformers_weight(name=name, ref_module_state_dict=ref_state_dict, ref_module=model_ref),
            parallel_context=parallel_context,
        )

    # The reference model is no longer needed; release it before saving to lower peak memory.
    del ref_state_dict, model_ref

    mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)
    sanity_check(root_module=model)

    checkpoint_metadata = {
        "last_train_step": 0,
        "consumed_train_samples": 0,
    }
    save(
        config=CONFIG_BRRR,
        model=model,
        optimizer=None,
        lr_scheduler=None,
        parallel_context=parallel_context,
        root_folder=save_path,
        should_save_optimizer=False,
        should_save_lr_scheduler=False,
        checkpoint_metadata=checkpoint_metadata,
        sanity_checks=False,
    )

    # Only global rank 0 writes the standalone model config next to the checkpoint.
    if dist.get_rank(parallel_context.world_pg) == 0:
        print(save_path)
        with open(save_path / "model_config.json", mode="w") as fo:
            fo.write(json.dumps(asdict(model_config_brrr), indent=4))
|
|
|
|
|
def main():
    """Command-line entry point: parse the options and run the conversion with them."""
    convert_trfrs_to_brrr(**vars(get_args()))


if __name__ == "__main__":
    main()
|
|