File size: 3,027 Bytes
c08e521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#! python3
# -*- encoding: utf-8 -*-

import time
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_
from transformers import AutoModel

def cal_ner_acc(y, y_hat, max_label=7):
    """Count correct predictions over positions whose labels are both in range.

    A position ``i`` is scored only when ``y[i] <= max_label`` and
    ``y_hat[i] <= max_label``; every other position is ignored.

    Args:
        y: 1-D tensor of gold label ids.
        y_hat: 1-D tensor of predicted label ids, aligned with ``y``. Only its
            first ``len(y)`` entries are read.
        max_label: Largest label id (inclusive) that is scored. Defaults to 7,
            the value this function previously hard-coded.

    Returns:
        ``(acc_cnt, len_cnt)`` as Python ints: the number of scored positions
        where prediction equals gold, and the number of scored positions.
        For an empty ``y`` it returns ``(0, 1)`` (presumably so a caller
        computing ``acc_cnt / len_cnt`` never divides by zero).
    """
    if len(y) == 0:
        return 0, 1
    y = y.cpu()
    # Truncate like the former index loop, which never read past len(y).
    y_hat = y_hat.cpu()[: len(y)]

    scored = (y <= max_label) & (y_hat <= max_label)
    len_cnt = int(scored.sum())
    acc_cnt = int((y[scored] == y_hat[scored]).sum())
    return acc_cnt, len_cnt


class NER_model(nn.Module):
    """Token-classification (NER) model: a pretrained G2PTL encoder plus a linear head.

    The encoder is loaded from the ``'Cainiao-AI/G2PTL'`` checkpoint through
    ``transformers.AutoModel`` with ``trust_remote_code=True``. The head maps
    each 768-dimensional encoder output vector to ``vocab_size`` logits.
    """

    def __init__(self, vocab_size, max_load_attempts=5, retry_delay=1.0):
        """Load the encoder and build the classification head.

        Args:
            vocab_size: Number of output labels (width of the final logits).
            max_load_attempts: How many times to try loading the pretrained
                encoder before giving up. Must be >= 1.
            retry_delay: Seconds to sleep between failed load attempts.

        Raises:
            ValueError: If ``max_load_attempts`` is less than 1.
            RuntimeError: If every load attempt fails (the last error is chained).
        """
        super().__init__()
        self.num_labels = vocab_size
        self.g2ptl = self._load_backbone(max_load_attempts, retry_delay)

        # NER head.
        self.dropout = nn.Dropout(p=0.1, inplace=False)
        self.linear1 = nn.Linear(in_features=768, out_features=128, bias=True)
        self.linear2 = nn.Linear(in_features=128, out_features=vocab_size, bias=True)
        # NOTE(review): _reset_parameters() is not called here, so the head keeps
        # nn.Linear's default initialisation.

    @staticmethod
    def _load_backbone(max_attempts, retry_delay):
        """Load the pretrained G2PTL encoder, retrying a bounded number of times."""
        if max_attempts < 1:
            raise ValueError(f"max_load_attempts must be >= 1, got {max_attempts}")
        last_err = None
        for attempt in range(1, max_attempts + 1):
            try:
                return AutoModel.from_pretrained('Cainiao-AI/G2PTL', trust_remote_code=True)
            except Exception as err:  # not bare: lets KeyboardInterrupt/SystemExit through
                last_err = err
                if attempt < max_attempts:
                    time.sleep(retry_delay)
        raise RuntimeError(
            f"Failed to load 'Cainiao-AI/G2PTL' after {max_attempts} attempt(s)"
        ) from last_err

    def forward(self,
                input_ids,
                attention_mask,
                token_type_ids,
                node_position_ids, spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input,
                prov_city_mask: Optional[torch.Tensor] = None,
                sequence_len=6,
                labels: Optional[torch.Tensor] = None
                ):
        """Run the encoder and the NER head.

        Args:
            input_ids, attention_mask, token_type_ids, node_position_ids,
            spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input:
                Passed positionally, unchanged, to the G2PTL encoder.
            prov_city_mask: Unused; kept for interface compatibility.
            sequence_len: Unused; kept for interface compatibility.
            labels: Optional gold label ids. When given, a cross-entropy loss is
                computed over all positions after flattening.

        Returns:
            ``[logits, loss]``: ``logits`` has ``vocab_size`` entries in its last
            dimension; ``loss`` is a scalar tensor, or ``None`` when ``labels``
            is ``None``.
        """
        output = self.g2ptl(input_ids, attention_mask, token_type_ids, node_position_ids,
                            spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input)

        # The head consumes the encoder's token-sequence hidden states, not a
        # pooled embedding.
        # NOTE(review): squeeze() drops *every* size-1 dim, including the batch
        # dim when batch size is 1 - confirm callers expect that shape.
        sequence_output = self.dropout(output.final_hidden_state.squeeze())
        hidden = self.linear1(sequence_output)
        # NOTE(review): no non-linearity between linear1 and linear2, so the head
        # composes to a single linear map - confirm this is intended.
        logits = self.linear2(self.dropout(hidden))

        loss = None
        if labels is not None:
            # Flatten to (N, C) logits vs (N,) targets as CrossEntropyLoss expects;
            # logits.size(-1) equals vocab_size by construction of linear2.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return [logits, loss]

    def _reset_parameters(self):
        """Re-initialise every parameter with more than one dim using Xavier-uniform."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

    def save_weights(self, path):
        """Save this module's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.state_dict(), path)

    def load_weights(self, path):
        """Load a ``state_dict`` from ``path`` onto CPU, non-strictly.

        With ``strict=False`` missing/unexpected keys are not errors; they are
        reported in the returned object so callers can inspect them.

        NOTE(review): ``torch.load`` unpickles the file and can execute arbitrary
        code - only load checkpoints from trusted sources.
        """
        state = torch.load(path, map_location=torch.device('cpu'))
        return self.load_state_dict(state, strict=False)