|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import os |
|
import torch |
|
import yaml |
|
import shutil |
|
from tqdm import auto as tqdm_lib |
|
|
|
|
|
# Padded vocabulary size of the 20B model; written back into the config and
# per-rank args as `make_vocab_size_divisible_by` so the merged embedding
# matrices keep exactly the shape they were trained with.
VOCAB_SIZE = 50432

# Keys in each mp_rank_*_model_states.pt that only matter for resuming
# training (optimizer state and RNG snapshots); they are removed from the
# merged, inference-oriented checkpoint.
IGNORED_MODEL_STATE_KEYS = [

    "optimizer",

    "random_rng_state",

    "np_rng_state",

    "torch_rng_state",

    "cuda_rng_state",

    "rng_tracker_states",

]
|
|
|
|
|
def modify_config(input_config_path, output_config_path, output_dir):
    """Rewrite a NeoX YAML config so the merged checkpoint runs on one GPU.

    Args:
        input_config_path: path to the original (tensor/pipeline-parallel)
            YAML config.
        output_config_path: where to write the modified config.
        output_dir: directory holding the merged checkpoint; used for the
            load/save paths and the tokenizer location.
    """
    with open(input_config_path) as f:
        config = yaml.full_load(f)

    config.update(
        {
            # Collapse tensor and pipeline parallelism to a single rank.
            "model_parallel_size": 1,
            "pipe_parallel_size": 1,
            # Point checkpoint load/save at the merged output directory.
            "load": output_dir,
            "save": output_dir,
            "vocab_file": os.path.join(output_dir, "20B_tokenizer.json"),
            "log_dir": "./logs",
            # Keep the padded vocab size fixed so the merged embedding
            # matrices keep the exact shape they were trained with.
            "make_vocab_size_divisible_by": VOCAB_SIZE,
        }
    )

    # Disable ZeRO optimizer-state partitioning for the single-rank setup.
    config["zero_optimization"]["stage"] = 0

    with open(output_config_path, "w") as f:
        yaml.dump(config, f)
|
|
|
|
|
def modify_model_states(input_model_state_path, output_model_state_path):
    """Strip training-only state from one mp_rank_*_model_states.pt file.

    Loads the per-rank state dict, drops the optimizer/RNG entries listed in
    IGNORED_MODEL_STATE_KEYS, rewrites the parallelism metadata to describe a
    single-GPU (mp=1, dp=1) run, and saves the result.

    Args:
        input_model_state_path: path to the original per-rank state file.
        output_model_state_path: where to write the modified state dict.
    """
    # map_location="cpu" so GPU-saved checkpoints load on CPU-only hosts.
    model_state = torch.load(input_model_state_path, map_location="cpu")
    for key in IGNORED_MODEL_STATE_KEYS:
        # pop() rather than del: tolerate checkpoints already missing a key.
        model_state.pop(key, None)
    model_state["mp_world_size"] = 1
    model_state["dp_world_size"] = 1
    model_state["args"]["model_parallel_size"] = 1
    # Pin the padded vocab size so embedding shapes still match after merging.
    model_state["args"]["make_vocab_size_divisible_by"] = VOCAB_SIZE
    torch.save(model_state, output_model_state_path)
|
|
|
|
|
def _merge_cat(key, tp1, tp2, dim):
    """Concatenate one tensor from the two tensor-parallel shards along `dim`."""
    return torch.cat([tp1[key], tp2[key]], dim=dim)


def merge_model_weights(input_checkpoint_path, output_checkpoint_path):
    """Merge the 2-way tensor-parallel layer_* weight files into TP=1 files.

    Reads each layer's two shard files (model_00 / model_01), combines the
    partitioned tensors, and writes the merged tensors to the model_00
    filename in the output checkpoint directory.

    Args:
        input_checkpoint_path: directory with the sharded layer_* files.
        output_checkpoint_path: directory to write the merged files to.
    """
    pbar = tqdm_lib.tqdm(total=47)  # 44 transformer layers + 3 extra files

    for layer_i in range(44):
        pbar.set_description(f"Merging layer {layer_i}")
        # Transformer layer i lives at file index i + 2 (layer_00 holds the
        # input embedding; layer_47/layer_48 hold the final norm and head).
        filename_tp1 = f"layer_{layer_i + 2:02d}-model_00-model_states.pt"
        filename_tp2 = f"layer_{layer_i + 2:02d}-model_01-model_states.pt"
        # map_location="cpu": merge on CPU regardless of where shards were saved.
        loaded_tp1 = torch.load(
            os.path.join(input_checkpoint_path, filename_tp1), map_location="cpu"
        )
        loaded_tp2 = torch.load(
            os.path.join(input_checkpoint_path, filename_tp2), map_location="cpu"
        )

        merged = {}

        # These weights are partitioned across ranks along dim 1
        # (presumably the row-parallel linear layers).
        merged["mlp.dense_4h_to_h.weight"] = _merge_cat(
            "mlp.dense_4h_to_h.weight", loaded_tp1, loaded_tp2, dim=1
        )
        merged["attention.dense.weight"] = _merge_cat(
            "attention.dense.weight", loaded_tp1, loaded_tp2, dim=1
        )
        # Their biases are combined by summation. NOTE(review): assumes each
        # rank stores a partial bias contribution — confirm against upstream.
        merged["mlp.dense_4h_to_h.bias"] = (
            loaded_tp1["mlp.dense_4h_to_h.bias"] + loaded_tp2["mlp.dense_4h_to_h.bias"]
        )
        merged["attention.dense.bias"] = (
            loaded_tp1["attention.dense.bias"] + loaded_tp2["attention.dense.bias"]
        )

        # LayerNorm parameters are averaged across the two ranks (they should
        # hold identical copies; averaging tolerates numeric drift).
        for key in (
            "input_layernorm.weight",
            "input_layernorm.bias",
            "post_attention_layernorm.weight",
            "post_attention_layernorm.bias",
        ):
            merged[key] = (loaded_tp1[key] + loaded_tp2[key]) / 2

        # These tensors are partitioned along dim 0
        # (presumably the column-parallel linear layers).
        for key in (
            "mlp.dense_h_to_4h.weight",
            "mlp.dense_h_to_4h.bias",
            "attention.query_key_value.weight",
            "attention.query_key_value.bias",
        ):
            merged[key] = _merge_cat(key, loaded_tp1, loaded_tp2, dim=0)

        # Rotary inverse frequencies are taken from rank 0 only (not
        # partitioned; both ranks should hold the same values).
        merged["attention.rotary_emb.inv_freq"] = loaded_tp1[
            "attention.rotary_emb.inv_freq"
        ]

        torch.save(merged, os.path.join(output_checkpoint_path, filename_tp1))
        # Release the shard tensors before loading the next layer.
        del loaded_tp1
        del loaded_tp2
        pbar.update(1)

    pbar.set_description("Merging input embedding")
    loaded_tp1 = torch.load(
        os.path.join(input_checkpoint_path, "layer_00-model_00-model_states.pt"),
        map_location="cpu",
    )
    loaded_tp2 = torch.load(
        os.path.join(input_checkpoint_path, "layer_00-model_01-model_states.pt"),
        map_location="cpu",
    )
    # The embedding table is partitioned along the vocab (row) dimension.
    merged = {
        "word_embeddings.weight": _merge_cat(
            "word_embeddings.weight", loaded_tp1, loaded_tp2, dim=0
        )
    }
    torch.save(
        merged,
        os.path.join(output_checkpoint_path, "layer_00-model_00-model_states.pt"),
    )
    del loaded_tp1
    del loaded_tp2
    pbar.update(1)

    pbar.set_description("Merging final layer norm")
    loaded_tp1 = torch.load(
        os.path.join(input_checkpoint_path, "layer_47-model_00-model_states.pt"),
        map_location="cpu",
    )
    loaded_tp2 = torch.load(
        os.path.join(input_checkpoint_path, "layer_47-model_01-model_states.pt"),
        map_location="cpu",
    )
    # Final norm parameters are averaged, like the per-layer norms above.
    merged = {
        "norm.weight": (loaded_tp1["norm.weight"] + loaded_tp2["norm.weight"]) / 2,
        "norm.bias": (loaded_tp1["norm.bias"] + loaded_tp2["norm.bias"]) / 2,
    }
    torch.save(
        merged,
        os.path.join(output_checkpoint_path, "layer_47-model_00-model_states.pt"),
    )
    del loaded_tp1
    del loaded_tp2
    pbar.update(1)

    pbar.set_description("Merging output embedding")
    loaded_tp1 = torch.load(
        os.path.join(input_checkpoint_path, "layer_48-model_00-model_states.pt"),
        map_location="cpu",
    )
    loaded_tp2 = torch.load(
        os.path.join(input_checkpoint_path, "layer_48-model_01-model_states.pt"),
        map_location="cpu",
    )
    # The output projection is partitioned along the vocab (row) dimension.
    merged = {
        "final_linear.weight": _merge_cat(
            "final_linear.weight", loaded_tp1, loaded_tp2, dim=0
        ),
    }
    torch.save(
        merged,
        os.path.join(output_checkpoint_path, "layer_48-model_00-model_states.pt"),
    )
    del loaded_tp1
    del loaded_tp2
    pbar.update(1)
    pbar.set_description("Done.")
|
|
|
|
|
def merge(input_dir, output_dir, step_dir="global_step150000", num_ranks=8):
    """Convert a 2-way tensor-parallel 20B checkpoint into a 1-GPU checkpoint.

    Args:
        input_dir: checkpoint root containing `step_dir`, `configs/20B.yml`
            and `20B_tokenizer.json`.
        output_dir: destination root; the same layout is written here.
        step_dir: name of the global-step folder inside both roots
            (default matches the released 20B checkpoint).
        num_ranks: number of mp_rank_*_model_states.pt files to rewrite.
    """
    input_checkpoint_path = os.path.join(input_dir, step_dir)
    output_checkpoint_path = os.path.join(output_dir, step_dir)
    os.makedirs(output_checkpoint_path, exist_ok=True)
    os.makedirs(os.path.join(output_dir, "configs"), exist_ok=True)
    # Rewrite every per-rank metadata file (optimizer/RNG stripped, mp=1).
    for i in range(num_ranks):
        modify_model_states(
            input_model_state_path=os.path.join(
                input_checkpoint_path, f"mp_rank_{i:02d}_model_states.pt"
            ),
            output_model_state_path=os.path.join(
                output_checkpoint_path, f"mp_rank_{i:02d}_model_states.pt"
            ),
        )
    modify_config(
        input_config_path=os.path.join(input_dir, "configs", "20B.yml"),
        output_config_path=os.path.join(output_dir, "configs", "20B.yml"),
        output_dir=output_dir,
    )
    merge_model_weights(
        input_checkpoint_path=input_checkpoint_path,
        output_checkpoint_path=output_checkpoint_path,
    )
    # The tokenizer is copied unchanged next to the merged weights.
    shutil.copyfile(
        os.path.join(input_dir, "20B_tokenizer.json"),
        os.path.join(output_dir, "20B_tokenizer.json"),
    )
    # "latest" marks which step folder to load (presumably read by the
    # DeepSpeed checkpoint loader — verify against the loading code).
    with open(os.path.join(output_dir, "latest"), "w") as f:
        f.write(step_dir)
|
|
|
|
|
def main():
    """CLI entry point: parse --input_dir/--output_dir and run the merge."""
    parser = argparse.ArgumentParser(description="Merge 20B checkpoint.")
    parser.add_argument(
        "--input_dir",
        type=str,
        # required: otherwise a missing flag becomes None and crashes later
        # inside os.path.join with a confusing TypeError.
        required=True,
        help='Checkpoint dir, which should contain (e.g. a folder named "global_step150000")',
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output dir, to save the 1-GPU weights configs",
    )
    args = parser.parse_args()
    merge(args.input_dir, args.output_dir)
|
|
|
|
|
# Script entry point: only run the merge when executed directly.
if __name__ == "__main__":

    main()
|
|