"""
Modify model's embedding and softmax layers according to the vocabulary.
"""
import argparse
import collections
import os
import sys

import numpy as np
import torch

# Make the tencentpretrain package importable when the script is run from its own directory.
tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)

from tencentpretrain.utils.vocab import Vocab


def adapter(old_model, old_vocab, new_vocab):
    new_model = collections.OrderedDict()
    embedding_key = "embedding.word.embedding.weight"
    softmax_key = "target.mlm.linear_2.weight"
    softmax_bias_key = "target.mlm.linear_2.bias"

    # Copy parameters that do not depend on the vocabulary size.
    tensor_names = []
    for k, v in old_model.items():
        tensor_names.append(k)
        if k not in [embedding_key, softmax_key, softmax_bias_key]:
            new_model[k] = v
    has_mlm_target = softmax_key in tensor_names

    # Get word embedding, MLM softmax, and MLM softmax bias parameters.
    old_embedding = old_model.get(embedding_key).data.numpy()
    if has_mlm_target:
        old_softmax = old_model.get(softmax_key).data.numpy()
        old_softmax_bias = old_model.get(softmax_bias_key).data.numpy()

    # Randomly initialize the vocabulary-dependent parameters.
    new_embedding = np.random.normal(0, 0.02, [len(new_vocab), old_embedding.shape[1]])
    if has_mlm_target:
        new_softmax = np.random.normal(0, 0.02, [len(new_vocab), old_softmax.shape[1]])
        new_softmax_bias = np.random.normal(0, 0.02, [len(new_vocab)])

    # Copy rows of tokens shared by both vocabularies into the new parameters.
    for i, w in enumerate(new_vocab.i2w):
        if w in old_vocab.w2i:
            old_w_index = old_vocab.w2i[w]
            new_embedding[i] = old_embedding[old_w_index]
            if has_mlm_target:
                new_softmax[i] = old_softmax[old_w_index]
                new_softmax_bias[i] = old_softmax_bias[old_w_index]

    new_model[embedding_key] = torch.tensor(new_embedding, dtype=torch.float32)
    if has_mlm_target:
        new_model[softmax_key] = torch.tensor(new_softmax, dtype=torch.float32)
        new_model[softmax_bias_key] = torch.tensor(new_softmax_bias, dtype=torch.float32)

    return new_model
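
# Quick sanity check (illustrative sketch only; assumes the token "hello" exists in
# both vocabularies and that new_model was produced by adapter() as above):
#
#   i_new, i_old = new_vocab.w2i["hello"], old_vocab.w2i["hello"]
#   emb_key = "embedding.word.embedding.weight"
#   assert torch.allclose(new_model[emb_key][i_new], old_model[emb_key][i_old].float())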


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Input options.
    parser.add_argument("--old_model_path", type=str)
    parser.add_argument("--old_vocab_path", type=str)
    parser.add_argument("--new_vocab_path", type=str)
    # Output options.
    parser.add_argument("--new_model_path", type=str)

    args = parser.parse_args()

    old_vocab = Vocab()
    old_vocab.load(args.old_vocab_path)
    new_vocab = Vocab()
    new_vocab.load(args.new_vocab_path)

    old_model = torch.load(args.old_model_path, map_location="cpu")
    new_model = adapter(old_model, old_vocab, new_vocab)

    print("Output adapted new model.")
    torch.save(new_model, args.new_model_path)