# TAAS / ner_model.py
# Last commit: c08e521 ("upload") by zy414775
#! python3
# -*- encoding: utf-8 -*-
import time
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_
from transformers import AutoModel
def cal_ner_acc(y, y_hat, max_tag=7):
    """Count correct tag predictions over the positions that are scored.

    A position is scored only when both the gold tag and the predicted tag
    are ``<= max_tag``; other positions are skipped entirely.

    Args:
        y: 1-D tensor of gold tag ids (moved to CPU and converted to NumPy).
        y_hat: 1-D tensor of predicted tag ids, same length as ``y``.
        max_tag: Largest tag id that is scored. Defaults to 7, the value that
            was previously hard-coded. NOTE(review): presumably ids above 7
            are special / non-entity tokens -- confirm against the tag vocab.

    Returns:
        ``(acc_cnt, len_cnt)``: the number of scored positions where the
        prediction matches, and the number of scored positions. An empty
        ``y`` returns ``(0, 1)`` (presumably so a caller's division is safe);
        note that a non-empty input with no scored positions returns ``(0, 0)``.

    Raises:
        ValueError: if ``y`` is non-empty and ``y`` and ``y_hat`` differ in length.
    """
    if len(y) == 0:
        return 0, 1
    gold, pred = y.cpu().numpy(), y_hat.cpu().numpy()
    # Explicit check: a length mismatch previously surfaced as an IndexError
    # (short y_hat) or was silently ignored (long y_hat).
    if len(gold) != len(pred):
        raise ValueError(
            "y and y_hat must have the same length, got %d and %d" % (len(gold), len(pred))
        )
    acc_cnt, len_cnt = 0, 0
    for g, p in zip(gold, pred):
        if g <= max_tag and p <= max_tag:
            len_cnt += 1
            if g == p:
                acc_cnt += 1
    return acc_cnt, len_cnt
class NER_model(nn.Module):
    """NER model: a pretrained G2PTL encoder followed by a two-layer linear tagging head.

    Args:
        vocab_size: Number of output tags (size of the head's final layer).
        max_load_attempts: How many times to try loading the pretrained
            backbone before giving up. Previously this retried forever.
        retry_delay: Seconds to sleep between failed load attempts.
    """

    def __init__(self, vocab_size, max_load_attempts=10, retry_delay=1.0):
        super().__init__()
        self.g2ptl = self._load_backbone(max_load_attempts, retry_delay)
        # NER head: dropout -> Linear(768, 128) -> dropout -> Linear(128, vocab_size).
        # NOTE(review): 768 is hard-coded; presumably it must equal the backbone's
        # hidden size -- confirm against the G2PTL config.
        self.num_labels = vocab_size
        self.dropout = nn.Dropout(p=0.1, inplace=False)
        self.linear1 = nn.Linear(in_features=768, out_features=128, bias=True)
        self.linear2 = nn.Linear(in_features=128, out_features=vocab_size, bias=True)

    @staticmethod
    def _load_backbone(max_attempts, retry_delay):
        """Return the pretrained 'Cainiao-AI/G2PTL' model, retrying failed loads.

        Raises:
            RuntimeError: if every attempt fails; the last error is chained as the cause.
        """
        attempts = max(1, int(max_attempts))
        last_error = None
        for attempt in range(1, attempts + 1):
            try:
                return AutoModel.from_pretrained('Cainiao-AI/G2PTL', trust_remote_code=True)
            except Exception as err:
                # Exception (not a bare except) so KeyboardInterrupt / SystemExit
                # still stop the process instead of being retried forever.
                last_error = err
                if attempt < attempts:
                    time.sleep(retry_delay)
        raise RuntimeError(
            "Failed to load 'Cainiao-AI/G2PTL' after %d attempt(s)" % attempts
        ) from last_error

    def forward(self,
                input_ids,
                attention_mask,
                token_type_ids,
                node_position_ids, spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input,
                prov_city_mask: Optional[torch.Tensor] = None,
                sequence_len=6,
                labels: Optional[torch.Tensor] = None
                ):
        """Run the G2PTL backbone and the NER head.

        The first nine arguments are forwarded positionally, unchanged, to the
        backbone. ``prov_city_mask`` and ``sequence_len`` are accepted for
        interface compatibility but are not used here.

        Args:
            labels: Optional gold tag ids. When given, cross-entropy is computed
                with logits flattened to ``(-1, vocab_size)`` and labels flattened
                to ``(-1,)``; entries equal to -100 are ignored (the
                ``nn.CrossEntropyLoss`` default ``ignore_index``).

        Returns:
            ``[logits, loss]``; ``loss`` is None when ``labels`` is None.
        """
        output = self.g2ptl(input_ids, attention_mask, token_type_ids, node_position_ids,
                            spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input)
        # The head consumes the per-token sequence embeddings produced by the
        # encoder, not a pooled sentence embedding.
        token_embeddings = output.final_hidden_state
        # NOTE(review): squeeze() without a dim drops *every* size-1 axis, including
        # the batch axis when batch size is 1; presumably it targets one specific
        # singleton axis of the backbone output -- confirm and pass dim=.
        sequence_output = self.dropout(token_embeddings.squeeze())
        logits = self.linear2(self.dropout(self.linear1(sequence_output)))
        loss = None
        if labels is not None:
            # nn.CrossEntropyLoss: the bare name CrossEntropyLoss was never imported.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        return [logits, loss]

    def _reset_parameters(self):
        """Xavier-uniform re-initialize every parameter with more than one dimension.

        WARNING: ``self.parameters()`` includes the pretrained backbone, so this
        discards the G2PTL weights as well as the head's. ``__init__`` does not
        call it.
        """
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

    def save_weights(self, path):
        """Save this module's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.state_dict(), path)

    def load_weights(self, path):
        """Load a state dict from ``path`` (mapped to CPU) and apply it non-strictly.

        Returns the result of ``load_state_dict`` (missing / unexpected keys),
        which was previously discarded.
        NOTE(review): ``torch.load`` may unpickle the file depending on the torch
        version's ``weights_only`` default -- only load trusted checkpoints.
        """
        state_dict = torch.load(path, map_location=torch.device('cpu'))
        return self.load_state_dict(state_dict, strict=False)