import json

import numpy as np
import torch

from configuration_gpt1 import GPT1Config
from modeling_gpt1 import GPT1ForCausalLM, GPT1Model

# Register the custom classes so the saved checkpoint records them in its
# auto_map and can later be loaded through the Auto* classes.
GPT1Config.register_for_auto_class()
GPT1Model.register_for_auto_class('AutoModel')
GPT1ForCausalLM.register_for_auto_class('AutoModelForCausalLM')


def lists_are_equal(list1, list2):
    if len(list1) != len(list2):
        return False
    for i, j in zip(list1, list2):
        if i != j:
            return False
    return True


# Load the original GPT-1 weights from the params_{n}.npy shards, copy them
# into the Hugging Face model, and save the converted checkpoint.
def get_weights_from_tf_model():
    # The original checkpoint is stored as 10 flat .npy shards plus a JSON
    # file listing the shape of every parameter, in order.
    shapes = json.load(open('original_gpt1_params/params_shapes.json'))
    offsets = np.cumsum([np.prod(shape) for shape in shapes])
    init_params = [np.load('original_gpt1_params/params_{}.npy'.format(n)) for n in range(10)]
    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]

    config = GPT1Config()
    model = GPT1ForCausalLM(config)

    # print(shapes[:15])
    # print([k for k, v in model.named_parameters()][:10])

    # token embeddings
    model.model.embs.weight.data = torch.from_numpy(init_params[1])
    # positional embeddings
    model.model.pos_emb.weight.data = torch.from_numpy(init_params[0])

    layers = model.model.layers
    # Each of the 12 transformer blocks owns 12 consecutive parameters,
    # starting right after the two embedding matrices.
    for i in range(12):
        idx = 12 * i + 2

        # attention q, k, v projections: the original stores one fused
        # (1, 768, 2304) conv1d kernel; squeeze it, split it into three
        # 768-wide pieces, and transpose each one because nn.Linear keeps
        # weights as (out_features, in_features)
        init_params[idx] = np.squeeze(init_params[idx], axis=0)
        q, k, v = torch.split(torch.tensor(init_params[idx]), 768, dim=-1)
        layers[i].attention.q_proj.weight.data = q.detach().clone().transpose(-1, -2).contiguous()
        layers[i].attention.k_proj.weight.data = k.detach().clone().transpose(-1, -2).contiguous()
        layers[i].attention.v_proj.weight.data = v.detach().clone().transpose(-1, -2).contiguous()

        # attention q, k, v biases
        q_bias, k_bias, v_bias = torch.split(torch.tensor(init_params[idx + 1]), 768, dim=-1)
        layers[i].attention.q_proj.bias.data = q_bias.detach().clone().contiguous()
        layers[i].attention.k_proj.bias.data = k_bias.detach().clone().contiguous()
        layers[i].attention.v_proj.bias.data = v_bias.detach().clone().contiguous()

        # attention output projection + bias
        init_params[idx + 2] = np.squeeze(init_params[idx + 2], axis=0)
        layers[i].attention.o_proj.weight.data = torch.from_numpy(init_params[idx + 2]).transpose(-1, -2).contiguous()
        layers[i].attention.o_proj.bias.data = torch.from_numpy(init_params[idx + 3])

        # post-attention layer norm (gain + bias)
        layers[i].attention_norm.weight.data = torch.from_numpy(init_params[idx + 4])
        layers[i].attention_norm.bias.data = torch.from_numpy(init_params[idx + 5])

        # MLP: fc1 (768 -> 3072) and fc2 (3072 -> 768), also stored as conv1d kernels
        init_params[idx + 6] = np.squeeze(init_params[idx + 6], axis=0)
        layers[i].mlp.fc1.weight.data = torch.from_numpy(init_params[idx + 6]).transpose(-1, -2).contiguous()
        layers[i].mlp.fc1.bias.data = torch.from_numpy(init_params[idx + 7])
        init_params[idx + 8] = np.squeeze(init_params[idx + 8], axis=0)
        layers[i].mlp.fc2.weight.data = torch.from_numpy(init_params[idx + 8]).transpose(-1, -2).contiguous()
        layers[i].mlp.fc2.bias.data = torch.from_numpy(init_params[idx + 9])

        # post-MLP layer norm (gain + bias)
        layers[i].mlp_norm.weight.data = torch.from_numpy(init_params[idx + 10])
        layers[i].mlp_norm.bias.data = torch.from_numpy(init_params[idx + 11])

    # NOTE: nothing is copied into the LM head here; this assumes
    # GPT1ForCausalLM ties the LM head to the token embeddings.
    model.save_pretrained('gpt1-converted-weights/')


get_weights_from_tf_model()
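
# Optional sanity check (a minimal sketch, not part of the conversion itself):
# reload the converted checkpoint and run a forward pass. This assumes
# GPT1ForCausalLM follows the standard Hugging Face causal-LM interface
# (from_pretrained / an output object with a .logits field); adjust if the
# custom class differs.
#
# converted = GPT1ForCausalLM.from_pretrained('gpt1-converted-weights/')
# converted.eval()
# with torch.no_grad():
#     out = converted(input_ids=torch.tensor([[1, 2, 3, 4]]))
# print(out.logits.shape)  # expected: (1, 4, vocab_size)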