File size: 2,564 Bytes
f3581c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# This script takes the PyTorch weights produced by the paddle2torch_weights script and stacks
# the query, key and value projections of each encoder layer's self-attention (self_attn) into a
# single packed in_proj tensor, to make the layout match torch.nn.MultiheadAttention.

import os
import re

import torch
# torch.load unpickles the file, so only run this on a checkpoint you trust.
# map_location="cpu" lets the conversion run even on a machine without the
# device the tensors were originally saved from.
full_state_dict = torch.load("./pytorch_model.bin", map_location="cpu")
# Drop the first dot-separated component of every key (presumably the wrapping
# model's prefix -- confirm against the paddle2torch_weights output).
# A key that contains no dot maps to "".
full_state_dict = {k.partition(".")[2]: v for k, v in full_state_dict.items()}

def con_cat(kqv_dict):
    """Fuse the separate q/k/v projection tensors of one attention layer.

    Args:
        kqv_dict: the ``q_proj``, ``k_proj`` and ``v_proj`` entries of one
            parameter kind (all weights or all biases) for one layer. Keys must
            have the projection name as their fourth dot-separated component,
            e.g. ``layers.0.self_attn.q_proj.weight``. Only the first key is
            inspected to find the layer prefix; entry order does not matter.

    Returns:
        A one-entry dict mapping ``encoder.<first three key components>.in_proj_weight``
        (or ``in_proj_bias``) to the q, k and v tensors concatenated along
        dim 0, in that order.

    Raises:
        ValueError: if ``kqv_dict`` is empty, its first key has fewer than four
            dot-separated components, or that key mentions neither "weight"
            nor "bias".
        KeyError: if the q, k or v entry for that layer is missing.
    """
    if not kqv_dict:
        raise ValueError("con_cat() got no q/k/v parameters to fuse")

    first_key = next(iter(kqv_dict))
    parts = first_key.split(".")
    if len(parts) < 4:
        raise ValueError(f"unexpected q/k/v parameter name: {first_key!r}")

    # A key mentioning "weight" is treated as a weight, otherwise "bias" as a bias.
    if "weight" in first_key:
        fused_name = "in_proj_weight"
    elif "bias" in first_key:
        fused_name = "in_proj_bias"
    else:
        raise ValueError(f"{first_key!r} is neither a weight nor a bias")

    prefix, suffix = parts[:3], parts[4:]
    # Rebuild each sibling key component-wise so only the projection name
    # changes; the q, k, v order is what the fused in_proj layout expects.
    fused_value = torch.cat([
        kqv_dict[".".join(prefix + [proj] + suffix)]
        for proj in ("q_proj", "k_proj", "v_proj")
    ])
    return {"encoder." + ".".join(prefix + [fused_name]): fused_value}


mod_dict = {}

# Embedding weights: every parameter whose name mentions "embedding" or
# "layer_norm" is copied under an "embeddings." prefix.
for k, v in full_state_dict.items():
    if "embedding" in k or "layer_norm" in k:
        mod_dict[f"embeddings.{k}"] = v

# Encoder weights.
# Derive the layer count from the checkpoint instead of hard-coding 12, so a
# checkpoint with a different depth is not silently truncated.
LAYER_KEY_RE = re.compile(r"layers\.(\d+)\.")
layer_ids = set()
for k in full_state_dict:
    match = LAYER_KEY_RE.search(k)
    if match:
        layer_ids.add(int(match.group(1)))
if not layer_ids:
    raise ValueError("no 'layers.<i>.' parameters found in the loaded state dict")
num_layers = max(layer_ids) + 1

for i in range(num_layers):
    # The trailing dot matters: "layers.1" alone would also match "layers.10.*"
    # and "layers.11.*", pulling other layers' tensors into this layer.
    layer_tag = f"layers.{i}."
    kvq_weight = {}
    kvq_bias = {}
    for k, v in full_state_dict.items():
        if layer_tag not in k:
            continue
        if "self_attn" in k and "out_proj" not in k:
            # Separate q/k/v projections: collected here, fused by con_cat().
            if "weight" in k:
                kvq_weight[k] = v
            if "bias" in k:
                kvq_bias[k] = v
        else:
            # Everything else in the layer (including self_attn.out_proj)
            # is copied unchanged under an "encoder." prefix.
            mod_dict[f"encoder.{k}"] = v

    mod_dict.update(con_cat(kvq_weight))
    mod_dict.update(con_cat(kvq_bias))

# Pooler weights keep their names unchanged.
for k, v in full_state_dict.items():
    if "pooler" in k:
        mod_dict[k] = v


for k, v in mod_dict.items():
    print(k, v.size())

model_name = "ernie-m-base_pytorch"
PATH = f"./{model_name}/pytorch_model.bin"
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(mod_dict, PATH)