#! python3
# -*- encoding: utf-8 -*-
"""NER head on top of the pretrained G2PTL encoder, plus a tag-accuracy helper."""
import logging
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_
from transformers import AutoModel

logger = logging.getLogger(__name__)


def cal_ner_acc(y, y_hat, max_label=7):
    """Count correct predictions among positions whose labels are in range.

    A position takes part in the count only when both the gold label and the
    predicted label are ``<= max_label``.

    NOTE(review): the original formatting of this function was lost; this
    assumes the equality check was nested inside the range check, so that
    ``acc_cnt <= len_cnt`` always holds. Confirm against the training script.

    Args:
        y: Tensor of gold label ids (indexed position by position in the
            original loop, so presumably 1-D).
        y_hat: Tensor of predicted label ids aligned with ``y``.
        max_label: Largest label id that is counted. Defaults to 7, the value
            that used to be hard-coded.

    Returns:
        ``(acc_cnt, len_cnt)`` as Python ints: the number of counted positions
        where prediction equals gold, and the number of counted positions.
        An empty ``y`` yields ``(0, 1)`` -- presumably so that a caller
        computing ``acc_cnt / len_cnt`` does not divide by zero. Note that a
        non-empty ``y`` with no counted position still yields ``(0, 0)``.
    """
    if len(y) == 0:
        return 0, 1
    gold, pred = y.cpu().numpy(), y_hat.cpu().numpy()
    counted = (gold <= max_label) & (pred <= max_label)
    len_cnt = int(counted.sum())
    acc_cnt = int((counted & (gold == pred)).sum())
    return acc_cnt, len_cnt


class NER_model(nn.Module):
    """Token classifier: G2PTL encoder followed by a two-layer linear head."""

    # Hugging Face hub id of the encoder loaded in ``__init__``.
    _BACKBONE_NAME = 'Cainiao-AI/G2PTL'

    def __init__(self, vocab_size, *, max_load_attempts: Optional[int] = None):
        """Load the encoder and build the classification head.

        Args:
            vocab_size: Number of output classes of the head (size of the last
                dimension of the logits).
            max_load_attempts: Maximum number of attempts to load the encoder.
                ``None`` (default) keeps retrying until loading succeeds.

        Raises:
            RuntimeError: If ``max_load_attempts`` is set and every attempt to
                load the encoder failed.
        """
        super().__init__()
        # Width of the classification layer; used to flatten logits for the loss.
        self.num_labels = vocab_size
        self.g2ptl = self._load_backbone(max_load_attempts)

        # NER head
        self.dropout = nn.Dropout(p=0.1, inplace=False)
        self.linear1 = nn.Linear(in_features=768, out_features=128, bias=True)
        self.linear2 = nn.Linear(in_features=128, out_features=vocab_size, bias=True)

    @classmethod
    def _load_backbone(cls, max_attempts: Optional[int]):
        """Load the pretrained encoder, retrying when loading fails."""
        attempt = 0
        while True:
            attempt += 1
            try:
                # NOTE: trust_remote_code=True runs Python code shipped with
                # the hub repository; only use with a repository you trust.
                return AutoModel.from_pretrained(cls._BACKBONE_NAME, trust_remote_code=True)
            except Exception as err:
                # ``Exception`` rather than a bare ``except`` so that Ctrl-C
                # (KeyboardInterrupt) and SystemExit still stop the process.
                if max_attempts is not None and attempt >= max_attempts:
                    raise RuntimeError(
                        f"could not load {cls._BACKBONE_NAME} after {attempt} attempt(s)"
                    ) from err
                logger.warning(
                    "loading %s failed (attempt %d), retrying",
                    cls._BACKBONE_NAME, attempt, exc_info=True,
                )

    def forward(self,
                input_ids,
                attention_mask,
                token_type_ids,
                node_position_ids,
                spatial_pos,
                in_degree,
                out_degree,
                edge_type_matrix,
                edge_input,
                prov_city_mask: Optional[torch.Tensor] = None,
                sequence_len=6,
                labels: Optional[torch.Tensor] = None
                ):
        """Run the encoder and the NER head, optionally computing the loss.

        The first nine arguments are passed positionally, in order, to the
        G2PTL encoder. ``prov_city_mask`` and ``sequence_len`` are accepted
        for interface compatibility but are not used here.

        Args:
            labels: Optional target class ids. When given, a cross-entropy
                loss over the flattened logits and labels is computed.

        Returns:
            ``[logits, loss]`` where ``loss`` is ``None`` if ``labels`` is
            ``None``.
        """
        output = self.g2ptl(input_ids,
                            attention_mask,
                            token_type_ids,
                            node_position_ids,
                            spatial_pos,
                            in_degree,
                            out_degree,
                            edge_type_matrix,
                            edge_input
                            )
        # The head consumes the encoder's per-token sequence embeddings, not a
        # pooled sentence embedding.
        token_embeddings = output.final_hidden_state
        # NOTE(review): squeeze() drops every size-1 dimension, which would
        # include a batch dimension of 1 -- confirm this is intended.
        sequence_output = token_embeddings.squeeze()

        sequence_output = self.dropout(sequence_output)
        linear_out = self.linear1(sequence_output)
        logits = self.linear2(self.dropout(linear_out))

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.num_labels), labels.reshape(-1))
        return [logits, loss]

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension.

        Not called from ``__init__``. NOTE(review): this iterates over
        ``self.parameters()``, which presumably includes the pretrained
        encoder assigned to ``self.g2ptl`` -- calling it would then overwrite
        the pretrained weights as well; confirm before use.
        """
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

    def save_weights(self, path):
        """Save this module's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load_weights(self, path):
        """Load a ``state_dict`` from ``path`` onto CPU, non-strictly.

        Keys that are missing from, or unexpected in, the checkpoint do not
        raise; they are reported with a warning instead.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source.
        state = torch.load(path, map_location=torch.device('cpu'))
        result = self.load_state_dict(state, strict=False)
        if result.missing_keys or result.unexpected_keys:
            logger.warning(
                "non-strict load of %s: missing keys %s, unexpected keys %s",
                path, list(result.missing_keys), list(result.unexpected_keys),
            )