"""
  This script provides an example to wrap TencentPretrain for embedding extraction.
"""
import sys
import os
import argparse
import torch

# Make the tencentpretrain package importable by adding the repository root to sys.path.
tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)

from tencentpretrain.utils.vocab import Vocab


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--load_model_path", default=None, type=str,
                        help="Path of the input model.")
    parser.add_argument("--vocab_path", default=None, type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--spm_model_path", default=None, type=str,
                        help="Path of the sentence piece model.")
    parser.add_argument("--word_embedding_path", default=None, type=str,
                        help="Path of the output word embedding.")

    args = parser.parse_args()

    if args.spm_model_path:
        # Build the vocabulary from a SentencePiece model.
        try:
            import sentencepiece as spm
        except ImportError:
            raise ImportError("You need to install SentencePiece to load a sentence piece model "
                              "(https://github.com/google/sentencepiece): "
                              "pip install sentencepiece")
        sp_model = spm.SentencePieceProcessor()
        sp_model.Load(args.spm_model_path)
        vocab = Vocab()
        vocab.i2w = {i: sp_model.IdToPiece(i) for i in range(sp_model.GetPieceSize())}
    else:
        # Build the vocabulary from a plain text vocabulary file.
        vocab = Vocab()
        vocab.load(args.vocab_path)

    # Load the pretrained weights on CPU so a GPU is not required.
    pretrained_model = torch.load(args.load_model_path, map_location="cpu")
    embedding = pretrained_model["embedding.word.embedding.weight"]

    with open(args.word_embedding_path, mode="w", encoding="utf-8") as f:
        # Header line in the word2vec text format: vocabulary size and
        # embedding dimension separated by a space.
        f.write("{} {}\n".format(embedding.size(0), embedding.size(1)))

        # Write one "word v1 v2 ... vn" line per vocabulary entry. Index the
        # embedding matrix by row id directly: vocab.w2i is not populated in
        # the SentencePiece branch, so a reverse lookup would fail there.
        for i in range(len(vocab.i2w)):
            word = vocab.i2w[i]
            word_embedding = embedding[i, :].cpu().numpy().tolist()
            f.write(word + " " + " ".join(str(v) for v in word_embedding) + "\n")
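
# The output follows the word2vec text format (a "vocab_size dim" header line,
# then one "word v1 ... vn" line per token), so it can be loaded with any
# compatible tool. A minimal sketch, assuming gensim is installed and the
# embeddings were written to embedding.txt ("hello" is a placeholder query):
#
#   from gensim.models import KeyedVectors
#   vectors = KeyedVectors.load_word2vec_format("embedding.txt", binary=False)
#   print(vectors.most_similar("hello"))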