|
|
|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import pandas as pd |
|
import sys |
|
import os |
|
|
|
|
|
from transformers.utils.hub import cached_file |
|
|
|
# Per-level loss weights for the five HTC code levels (they sum to 1.0);
# HTCLoss scales each level's cross-entropy by the matching entry.
htc_weights = [0.067, 0.133, 0.2, 0.267, 0.333]

# Local path of the hierarchy mask table in the ``JunhongLou/G2PTL`` Hub
# repository; ``cached_file`` downloads it on first use and reuses the local
# cache afterwards.
resolved_module_file = cached_file('JunhongLou/G2PTL', 'htc_mask_dict.pkl')

# Maps a zero-padded string of an already-decoded code prefix to the class
# indices admissible at the next level (see how HTCLoss builds its lookup
# keys); HTCLoss converts the values to tensors on its device.
# NOTE(security): this unpickles a file fetched from a remote repository at
# import time, which can execute arbitrary code if that repository is ever
# compromised. Consider pinning ``revision=`` or moving to a non-pickle format.
htc_mask_dict = pd.read_pickle(resolved_module_file)
|
import numpy as np |
|
import operator |
|
def calculate_multi_htc_acc_batch(predicted_htc, y, sequence_len=6, num_levels=5):
    """Count, per hierarchy level, how many predicted code prefixes match the target.

    A code counts as correct at level ``k`` when levels ``0..k`` of the
    prediction all equal the target. Counting for a code stops at the first
    target level equal to ``-100`` (ignored), so that level and every deeper
    level contribute to neither count.

    Args:
        predicted_htc: flat sequence (list, array or CPU tensor) of predicted
            level values, reshapeable to ``(-1, sequence_len, num_levels)``.
            Presumably the flat list produced by ``HTCLoss`` decoding —
            confirm with callers.
        y: target tensor or NumPy array with the same number of elements;
            ``-100`` marks ignored levels. Non-contiguous tensors are accepted.
        sequence_len: number of codes per sample.
        num_levels: number of levels per code.

    Returns:
        ``(acc_cnt, total_cnt)``: integer NumPy arrays of length ``num_levels``.
        They hold the number of exactly matched prefixes and the number of
        counted (non-ignored) prefixes at each level.

    Raises:
        ValueError: if predictions and targets describe a different number of
            samples.
    """
    # reshape (unlike view) also works for non-contiguous tensors and ndarrays.
    targets = y.reshape(-1, sequence_len, num_levels).tolist()
    predicted = np.asarray(predicted_htc).reshape(-1, sequence_len, num_levels).tolist()
    if len(targets) != len(predicted):
        raise ValueError(
            f"predicted_htc describes {len(predicted)} samples but y describes {len(targets)}"
        )

    acc_cnt = np.zeros(num_levels, dtype=int)
    total_cnt = np.zeros(num_levels, dtype=int)
    for target_seq, pred_seq in zip(targets, predicted):
        for target_code, pred_code in zip(target_seq, pred_seq):
            # Tracks whether every level seen so far matched, so each level
            # is compared once instead of re-slicing the whole prefix.
            prefix_ok = True
            for level, (expected, got) in enumerate(zip(target_code, pred_code)):
                if expected == -100:
                    break
                prefix_ok = prefix_ok and expected == got
                if prefix_ok:
                    acc_cnt[level] += 1
                total_cnt[level] += 1
    return acc_cnt, total_cnt
|
|
|
|
|
class HTCLoss(torch.nn.Module):
    """Level-weighted cross-entropy over 5-level hierarchical (HTC) codes.

    Logits are read as ``NUM_LEVELS`` rows of ``NUM_CLASSES`` scores per code
    (they are reshaped to ``(-1, 5, 100)``). With ``using_htc=True`` the
    logits are first constrained to the hierarchy:
      * level 0 is restricted to classes 1..31;
      * each level is decoded greedily (argmax);
      * at the next level, only the classes listed in the mask dict under the
        decoded prefix keep their score, and every other class is set to ``-5``.

    The per-position cross-entropy is then multiplied by the per-level weight
    and by ``NUM_LEVELS``. Target positions equal to ``-100`` are ignored.
    """

    NUM_LEVELS = 5
    NUM_CLASSES = 100
    # Score assigned to classes excluded by the hierarchy mask.
    _MASKED_LOGIT = -5.0
    # Format spec used to append each decoded level's value to the lookup key
    # for the admissible classes of the following level.
    _KEY_FORMATS = ('02d', '02d', '02d', '01d')

    def __init__(self, device, reduction='mean', using_htc=True,
                 mask_dict=None, weights=None):
        """Create the loss.

        Args:
            device: device the mask tensors and level weights are placed on.
            reduction: 'mean' or 'sum' over the positive per-position losses;
                any other value returns the unreduced per-position loss.
            using_htc: constrain logits to the hierarchy before the softmax.
            mask_dict: optional prefix-key -> class-indices mapping; defaults
                to the module-level ``htc_mask_dict``.
            weights: optional per-level weights; defaults to the module-level
                ``htc_weights``.
        """
        super().__init__()
        self.reduction = reduction
        self.htc_weights = list(htc_weights if weights is None else weights)
        self.device = device
        self.using_htc = using_htc
        source = htc_mask_dict if mask_dict is None else mask_dict
        # Per-instance copy instead of rewriting the shared module-level dict
        # in place, so instances on different devices cannot interfere.
        self.htc_mask_dict = {
            key: torch.as_tensor(value).detach().clone().to(device)
            for key, value in source.items()
        }

    def _decode(self, logits):
        """Greedily decode codes level by level and build hierarchy-masked logits.

        Returns:
            ``(masked, codes)``. ``masked`` has shape
            ``(-1, NUM_LEVELS, NUM_CLASSES)`` and keeps the original score
            only for the classes admissible at each level given the decoded
            prefix; all other entries are ``-5``. ``codes`` is a flat list of
            ``NUM_LEVELS`` ints per code.
        """
        levels = logits.reshape(-1, self.NUM_LEVELS, self.NUM_CLASSES)
        masked = torch.full_like(levels, self._MASKED_LOGIT).to(self.device)
        # Level 0 may only take classes 1..31.
        masked[:, 0, 1:32] = levels[:, 0, 1:32]
        _, first_level = torch.max(levels[:, 0, 1:32], 1)
        first_level = (first_level + 1).tolist()

        codes = []
        for sample_idx, first in enumerate(first_level):
            code = [first]
            key = ''
            for level, spec in enumerate(self._KEY_FORMATS, start=1):
                # The key is the concatenation of the zero-padded prefix so far.
                key += format(code[-1], spec)
                allowed = self.htc_mask_dict[key]
                level_logits = levels[sample_idx, level, allowed]
                masked[sample_idx, level, allowed] = level_logits
                _, best = torch.max(level_logits, 0)
                code.append(int(allowed[best]))
            codes.extend(code)
        return masked, codes

    def forward(self, logits, target):
        """Compute the weighted loss and, when ``using_htc``, the decoded codes.

        Args:
            logits: scores reshapeable to ``(-1, NUM_CLASSES)``, with
                ``NUM_LEVELS`` consecutive rows per code.
            target: int64 class indices reshapeable to ``(-1, 1)``; ``-100``
                marks positions to ignore.

        Returns:
            ``(loss, predict_res)``. ``loss`` is the reduced loss, or the
            per-position loss for an unrecognised ``reduction``.
            ``predict_res`` is the flat list of decoded codes; it is empty
            when ``using_htc`` is False.
        """
        target = target.reshape(-1, 1)
        ignored = target == -100
        valid = ~ignored.squeeze(1)
        # gather() needs a real index at ignored positions; their loss is zeroed below.
        safe_target = target.masked_fill(ignored, 0)

        predict_res = []
        if self.using_htc:
            masked, predict_res = self._decode(logits)
            flat_logits = masked.reshape(-1, self.NUM_CLASSES)
        else:
            flat_logits = logits.reshape(-1, self.NUM_CLASSES)

        # Negative log-likelihood of the target class. gather() avoids the
        # inf * 0 = NaN a one-hot product yields when some logit is -inf.
        nll = -F.log_softmax(flat_logits, dim=1).gather(1, safe_target).squeeze(1)
        loss = nll * valid

        num_codes = loss.shape[0] // self.NUM_LEVELS
        level_weights = torch.tensor(
            self.htc_weights, dtype=torch.float32, device=self.device
        ).repeat(num_codes)
        loss = loss * level_weights * self.NUM_LEVELS

        if self.reduction == 'mean':
            positive = loss[loss > 0]
            # mean() over an empty selection (every target ignored) is NaN;
            # the empty sum is a graph-connected zero instead.
            loss = positive.mean() if positive.numel() else positive.sum()
        elif self.reduction == 'sum':
            loss = loss[loss > 0].sum()
        return loss, predict_res

    def get_htc_code(self, logits):
        """Return the greedily decoded codes as a flat list of ints (NUM_LEVELS per code)."""
        return self._decode(logits)[1]
|
|
|
|