|
|
|
|
|
|
|
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.init import xavier_uniform_
from typing import Optional
from transformers import AutoModel
|
|
|
def cal_ner_acc(y, y_hat):
    """Count token-level NER hits.

    Returns (acc_cnt, len_cnt): the number of correctly predicted tags and the
    number of positions scored, restricted to tag ids <= 7 (the entity-tag range).
    """
    if len(y) == 0:
        return 0, 1
    y, y_hat = y.cpu().numpy(), y_hat.cpu().numpy()

    acc_cnt, len_cnt = 0, 0
    for i in range(len(y)):
        # Only score positions where both the gold and predicted tag ids
        # fall in the entity-tag range (<= 7).
        if y[i] <= 7 and y_hat[i] <= 7:
            len_cnt += 1
            if y[i] == y_hat[i]:
                acc_cnt += 1

    return acc_cnt, len_cnt
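
# Usage sketch (an assumption about the surrounding training loop, not part of this
# module): the per-batch counts from cal_ner_acc are summed over an epoch and the
# token-level accuracy is acc_total / max(len_total, 1), e.g.:
#
#   preds = logits.argmax(dim=-1).view(-1)              # hypothetical flattened predictions
#   acc_cnt, len_cnt = cal_ner_acc(labels.view(-1), preds)
#   acc_total += acc_cnt
#   len_total += len_cnt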
|
|
|
|
|
class NER_model(nn.Module):
    def __init__(self, vocab_size):
        super(NER_model, self).__init__()
        self.num_labels = vocab_size

        # Load the pretrained G2PTL backbone, retrying a few times on transient
        # download errors instead of looping forever.
        for attempt in range(5):
            try:
                self.g2ptl = AutoModel.from_pretrained('Cainiao-AI/G2PTL', trust_remote_code=True)
                break
            except Exception:
                if attempt == 4:
                    raise
                continue

        # NER head: dropout + two linear layers projecting the 768-dim backbone
        # hidden states down to the tag vocabulary.
        self.dropout = nn.Dropout(p=0.1, inplace=False)
        self.linear1 = nn.Linear(in_features=768, out_features=128, bias=True)
        self.linear2 = nn.Linear(in_features=128, out_features=vocab_size, bias=True)
|
    def forward(self,
                input_ids,
                attention_mask,
                token_type_ids,
                node_position_ids,
                spatial_pos,
                in_degree,
                out_degree,
                edge_type_matrix,
                edge_input,
                prov_city_mask: Optional[torch.Tensor] = None,
                sequence_len=6,
                labels: Optional[torch.Tensor] = None
                ):
        # Encode the inputs with the G2PTL backbone.
        output = self.g2ptl(input_ids, attention_mask, token_type_ids, node_position_ids,
                            spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input)

        # Token-level representations from the backbone's final hidden state.
        final_hidden_state = output.final_hidden_state
        sequence_output = final_hidden_state.squeeze()

        # NER head: dropout -> linear -> dropout -> linear over the tag vocabulary.
        sequence_output = self.dropout(sequence_output)
        linear_out = self.linear1(sequence_output)
        logits = self.linear2(self.dropout(linear_out))

        # Token-classification loss when gold labels are provided.
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return [logits, loss]
|
|
|
    def _reset_parameters(self):
        # Xavier-initialize all weight matrices; biases and other 1-D parameters are left as-is.
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

    def save_weights(self, path):
        torch.save(self.state_dict(), path)

    def load_weights(self, path):
        # strict=False tolerates missing/unexpected keys (e.g. backbone-only checkpoints).
        self.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)
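
# Minimal usage sketch (assumptions: a 9-tag NER vocabulary and a `batch` dict that
# already contains the G2PTL graph inputs and labels expected by forward(); neither
# is defined in this module):
#
#   model = NER_model(vocab_size=9)
#   logits, loss = model(**batch)
#   model.save_weights('g2ptl_ner.pt')        # hypothetical checkpoint path
#   model.load_weights('g2ptl_ner.pt')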
|
|
|
|
|
|