#! python3
# -*- encoding: utf-8 -*-
"""Hierarchical (5-level) code loss and accuracy helpers.

Logits are handled as 100 scores per level for 5 levels (see the
``reshape(-1, 5, 100)`` calls). In hierarchical mode, decoding is greedy and
top-down: each level after the first is restricted to the candidates that
``htc_mask_dict`` lists for the prefix decoded so far.
"""
import operator
import os
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from transformers.utils.hub import cached_file

# Resolve a local path to the prefix -> candidates table stored in the
# 'JunhongLou/G2PTL' Hugging Face Hub repository.
resolved_module_file = cached_file(
    'JunhongLou/G2PTL',
    'htc_mask_dict.pkl',
)

# Per-level loss weights, approximately k/15 for levels k = 1..5. Deeper
# levels weigh more, and the weights sum to 1.0.
htc_weights = [0.067, 0.133, 0.2, 0.267, 0.333]

# Maps a code prefix to the class indices allowed at the next level. The
# prefix is the concatenated, zero-padded codes of the levels chosen so far,
# e.g. '07', '0712', '071203', '0712035'.
# NOTE(review): pd.read_pickle can execute arbitrary code while loading; this
# trusts the remote repository the file is fetched from.
htc_mask_dict = pd.read_pickle(resolved_module_file)

# str.format patterns that extend the lookup key with the code chosen at
# levels 0..3 (the level-4 code never becomes part of a key).
_HTC_KEY_FORMATS = ('{:02d}', '{:02d}', '{:02d}', '{:01d}')


def calculate_multi_htc_acc_batch(predicted_htc, y, sequence_len=6):
    """Count per-level prefix-accuracy hits for a batch of code predictions.

    Args:
        predicted_htc: flat sequence of predicted codes, 5 ints per code,
            reshapeable to (-1, sequence_len, 5).
        y: label tensor viewable as (-1, sequence_len, 5). A -100 marks an
            ignored level; that level and all deeper ones are skipped.
        sequence_len: number of codes per batch entry.

    Returns:
        (acc_cnt, total_cnt): integer numpy arrays of length 5.
        ``total_cnt[k]`` counts codes whose first k+1 labels contain no -100.
        ``acc_cnt[k]`` counts those codes whose first k+1 predicted levels
        all match the labels.
    """
    acc_cnt = np.array([0, 0, 0, 0, 0])
    total_cnt = np.array([0, 0, 0, 0, 0])
    y = y.view(-1, sequence_len, 5).tolist()
    predicted = np.array(predicted_htc).reshape(-1, sequence_len, 5).tolist()
    for batch_i, y_seq in enumerate(y):
        p_seq = predicted[batch_i]
        for y_code, p_code in zip(y_seq, p_seq):
            for level in range(5):
                y_prefix = y_code[:level + 1]
                if -100 in y_prefix:
                    break
                if y_prefix == p_code[:level + 1]:
                    acc_cnt[level] += 1
                total_cnt[level] += 1
    return acc_cnt, total_cnt


class HTCLoss(torch.nn.Module):
    """Weighted cross-entropy over the 5 levels of a hierarchical code.

    With ``using_htc=True``, the logits of levels 1-4 are masked to the
    candidates allowed by the greedily decoded prefix before the softmax.

    Args:
        device: device the candidate tables and the weight tensor live on.
        reduction: 'mean' or 'sum' over the positive per-row losses. Any
            other value returns the unreduced per-row loss.
        using_htc: apply hierarchical masking and return the decoded codes.
    """

    def __init__(self, device, reduction='mean', using_htc=True):
        super().__init__()
        self.reduction = reduction
        self.htc_weights = htc_weights
        self.device = device
        self.using_htc = using_htc
        # Build a per-instance tensor copy rather than converting the shared
        # module-level dict in place. In-place conversion made every new
        # instance re-wrap existing tensors and move them to *its* device,
        # silently changing the tables of previously built instances.
        self.htc_mask_dict = {
            key: torch.as_tensor(value, device=device)
            for key, value in htc_mask_dict.items()
        }

    def _htc_decode(self, logits):
        """Greedily decode one code per sample and build the masked logits.

        Level 0 picks the best class among indices 1..31. Each later level
        picks the best class among the candidates ``self.htc_mask_dict``
        lists for the prefix decoded so far.

        Returns:
            (logits_new, predict_res). ``logits_new`` has shape (-1, 5, 100);
            it holds the original logits at allowed positions and -5
            elsewhere. ``predict_res`` is a flat list of 5 ints per sample.
        """
        logits_reshaped = logits.reshape(-1, 5, 100)
        # Level 0 only competes over classes 1..31; shift the argmax back.
        _, aa_predicted = torch.max(logits_reshaped[:, 0, 1:32], 1)
        aa_predicted += 1
        # Disallowed classes get a constant logit of -5 (not -inf). These
        # entries are constants, so no gradient flows through them.
        logits_new = -5 * torch.ones_like(logits_reshaped).to(self.device)
        logits_new[:, 0, 1:32] = logits_reshaped[:, 0, 1:32]
        predict_res = []
        for sample_idx, aa in enumerate(aa_predicted):
            code = [aa]
            key = ''
            for level, key_format in enumerate(_HTC_KEY_FORMATS, start=1):
                # Using the mask table to get candidates for this level.
                key += key_format.format(code[-1])
                candidates = self.htc_mask_dict[key]
                level_logits = logits_reshaped[sample_idx, level, candidates]
                logits_new[sample_idx, level, candidates] = level_logits
                _, best = torch.max(level_logits, 0)
                code.append(candidates[best])
            predict_res.extend(part.item() for part in code)
        return logits_new, predict_res

    def forward(self, logits, target):
        """Compute the level-weighted loss.

        Args:
            logits: in hierarchical mode, anything reshapeable to
                (-1, 5, 100). Otherwise log_softmax is taken over dim 1, so
                presumably (N, 100) with rows ordered (sample, level).
                TODO: confirm with callers.
            target: one class index per (sample, level) row, reshaped to
                (-1, 1). -100 marks rows to ignore.

        Returns:
            (loss, predict_res). ``loss`` is reduced according to
            ``self.reduction``. ``predict_res`` holds the decoded codes in
            hierarchical mode, and is ``[]`` otherwise.
        """
        target = target.reshape(-1, 1)
        ignored = target == -100
        target_mask = (~ignored).squeeze()
        # Ignored rows still need a valid index for gather(); target_mask
        # zeroes their loss below.
        target_new = target.masked_fill(ignored, 0)
        predict_res = []
        if not self.using_htc:
            log_pro = -1.0 * F.log_softmax(logits, dim=1)
        else:
            logits_new, predict_res = self._htc_decode(logits)
            log_pro = -1.0 * F.log_softmax(logits_new.reshape(-1, 100), dim=1)
        # Negative log-likelihood of each row's target class.
        loss = log_pro.gather(1, target_new).squeeze(1)
        loss = loss * target_mask
        # Rows are (sample, level) pairs, so tile the 5 level weights per
        # sample. The weights sum to 1, so the factor 5 makes them average 1.
        num_samples = loss.shape[0] // 5
        w_loss = torch.tensor(
            self.htc_weights, dtype=torch.float32, device=self.device
        ).repeat(num_samples)
        loss = loss.mul(w_loss) * 5
        if self.reduction in ('mean', 'sum'):
            # Keep only positive entries. This drops the masked (zeroed) rows,
            # and also any row whose loss is exactly 0.
            positive = loss[loss > 0]
            if self.reduction == 'sum':
                loss = positive.sum()
            elif positive.numel() > 0:
                loss = positive.mean()
            else:
                # Every row was ignored: mean() of an empty tensor is NaN.
                # All entries are 0 here, and sum() keeps the autograd graph.
                loss = loss.sum()
        return loss, predict_res

    def get_htc_code(self, logits):
        """Return the greedily decoded codes (5 ints per sample) for ``logits``."""
        _, predict_res = self._htc_decode(logits)
        return predict_res