"""
Modify model's embedding and softmax layers according to the vocabulary.
"""
import argparse
import os
import collections
import sys
import numpy as np
import torch

tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)

from tencentpretrain.utils.vocab import Vocab


def adapter(old_model, old_vocab, new_vocab):
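    """
    Build a new state dict whose word embedding and MLM output layers match the
    new vocabulary: rows for words present in both vocabularies are copied from
    the old model, while rows for new words are randomly initialized.
    """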
    new_model = collections.OrderedDict()

    embedding_key = "embedding.word.embedding.weight"
    softmax_key = "target.mlm.linear_2.weight"
    softmax_bias_key = "target.mlm.linear_2.bias"

    # Copy over parameters that do not depend on the vocabulary size.
    for k, v in old_model.items():
        if k not in [embedding_key, softmax_key, softmax_bias_key]:
            new_model[k] = v
    # The MLM output layer only exists in models trained with an MLM target.
    has_mlm_target = softmax_key in old_model

    # Get word embedding, MLM weight, and MLM bias variables.
    old_embedding = old_model.get(embedding_key).data.numpy()
    if has_mlm_target:
        old_softmax = old_model.get(softmax_key).data.numpy()
        old_softmax_bias = old_model.get(softmax_bias_key).data.numpy()

    # Randomly initialize parameters sized to the new vocabulary.
    new_embedding = np.random.normal(0, 0.02, [len(new_vocab), old_embedding.shape[1]])
    if has_mlm_target:
        new_softmax = np.random.normal(0, 0.02, [len(new_vocab), old_softmax.shape[1]])
        new_softmax_bias = np.random.normal(0, 0.02, [len(new_vocab)])

    # Copy parameters of words shared by both vocabularies into the new model.
    for i, w in enumerate(new_vocab.i2w):
        if w in old_vocab.w2i:
            old_w_index = old_vocab.w2i[w]
            new_embedding[i] = old_embedding[old_w_index]
            if has_mlm_target:
                new_softmax[i] = old_softmax[old_w_index]
                new_softmax_bias[i] = old_softmax_bias[old_w_index]

    new_model[embedding_key] = torch.tensor(new_embedding, dtype=torch.float32)
    if has_mlm_target:
        new_model[softmax_key] = torch.tensor(new_softmax, dtype=torch.float32)
        new_model[softmax_bias_key] = torch.tensor(new_softmax_bias, dtype=torch.float32)
    return new_model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Input options.
    parser.add_argument("--old_model_path", type=str, required=True,
                        help="Path of the input model file.")
    parser.add_argument("--old_vocab_path", type=str, required=True,
                        help="Path of the vocabulary file used by the input model.")
    parser.add_argument("--new_vocab_path", type=str, required=True,
                        help="Path of the new vocabulary file.")

    # Output options.
    parser.add_argument("--new_model_path", type=str, required=True,
                        help="Path of the output model adapted to the new vocabulary.")

    args = parser.parse_args()
    old_vocab = Vocab()
    old_vocab.load(args.old_vocab_path)
    new_vocab = Vocab()
    new_vocab.load(args.new_vocab_path)

    old_model = torch.load(args.old_model_path, map_location="cpu")

    new_model = adapter(old_model, old_vocab, new_vocab)
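    # Optional sanity check: the adapted embedding should have one row per word
    # in the new vocabulary.
    assert new_model["embedding.word.embedding.weight"].size(0) == len(new_vocab)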
    print("Output adapted new model.")
    torch.save(new_model, args.new_model_path)