# VISOR-GPT/train/scripts/dynamic_vocab_adapter.py
"""
Modify model's embedding and softmax layers according to the vocabulary.
"""
import argparse
import os
import collections
import sys
import numpy as np
import torch
tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)
from tencentpretrain.utils.vocab import Vocab
def adapter(old_model, old_vocab, new_vocab):
    """Adapt a model's state dict to a new vocabulary.

    Rows of the word embedding and (if present) the MLM softmax layer are
    re-indexed to follow ``new_vocab``: words shared with ``old_vocab`` keep
    their trained parameters, new words get N(0, 0.02) initialization.
    All vocabulary-independent parameters are copied through unchanged.

    Args:
        old_model: state dict (OrderedDict of tensors) of the old model.
        old_vocab: vocabulary with ``w2i``/``i2w`` matching ``old_model``.
        new_vocab: target vocabulary; defines the new row count and order.

    Returns:
        collections.OrderedDict: the adapted state dict (float32 tensors
        for the rebuilt vocabulary-dependent parameters).
    """
    embedding_key = "embedding.word.embedding.weight"
    softmax_key = "target.mlm.linear_2.weight"
    softmax_bias_key = "target.mlm.linear_2.bias"
    vocab_dependent_keys = (embedding_key, softmax_key, softmax_bias_key)

    # Copy over every parameter that does not depend on the vocabulary.
    new_model = collections.OrderedDict(
        (k, v) for k, v in old_model.items() if k not in vocab_dependent_keys
    )

    # Some checkpoints (e.g. models without an MLM target) lack the softmax
    # layer entirely; handle the embedding-only case gracefully.
    # (was: `bool = softmax_key in tensor_name` — shadowed the builtin and
    # did an O(n) scan of a list built only for this test)
    has_softmax = softmax_key in old_model

    # Get word embedding, mlm, and mlm bias variables.
    old_embedding = old_model[embedding_key].data.numpy()
    if has_softmax:
        old_softmax = old_model[softmax_key].data.numpy()
        old_softmax_bias = old_model[softmax_bias_key].data.numpy()

    # Initialize new rows from N(0, 0.02); shared rows are overwritten below.
    new_embedding = np.random.normal(0, 0.02, [len(new_vocab), old_embedding.shape[1]])
    if has_softmax:
        new_softmax = np.random.normal(0, 0.02, [len(new_vocab), old_softmax.shape[1]])
        new_softmax_bias = np.random.normal(0, 0.02, [len(new_vocab)])

    # Reuse trained parameters for words present in both vocabularies.
    for i, w in enumerate(new_vocab.i2w):
        if w in old_vocab.w2i:
            old_w_index = old_vocab.w2i[w]
            new_embedding[i] = old_embedding[old_w_index]
            if has_softmax:
                new_softmax[i] = old_softmax[old_w_index]
                new_softmax_bias[i] = old_softmax_bias[old_w_index]

    new_model[embedding_key] = torch.tensor(new_embedding, dtype=torch.float32)
    if has_softmax:
        new_model[softmax_key] = torch.tensor(new_softmax, dtype=torch.float32)
        new_model[softmax_bias_key] = torch.tensor(new_softmax_bias, dtype=torch.float32)
    return new_model
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Input options. All paths are mandatory: without required=True a
    # missing flag silently becomes None and fails later inside
    # Vocab.load / torch.load with an obscure traceback.
    parser.add_argument("--old_model_path", type=str, required=True,
                        help="Path of the input model checkpoint.")
    parser.add_argument("--old_vocab_path", type=str, required=True,
                        help="Path of the vocabulary matching the input model.")
    parser.add_argument("--new_vocab_path", type=str, required=True,
                        help="Path of the target vocabulary.")
    # Output options.
    parser.add_argument("--new_model_path", type=str, required=True,
                        help="Path to write the adapted model checkpoint.")
    args = parser.parse_args()

    old_vocab = Vocab()
    old_vocab.load(args.old_vocab_path)
    new_vocab = Vocab()
    new_vocab.load(args.new_vocab_path)

    # Load on CPU so the conversion needs no GPU.
    old_model = torch.load(args.old_model_path, map_location="cpu")
    new_model = adapter(old_model, old_vocab, new_vocab)
    print("Output adapted new model.")
    torch.save(new_model, args.new_model_path)