# G2PTL / htc_loss.py
# (Header captured from the Hugging Face file page: author jinyan218,
#  commit "G2PTL update" 42e1290, file size 6.15 kB.)
#! python3
# -*- encoding: utf-8 -*-
import torch
import torch.nn.functional as F
import pandas as pd
import sys
import os
from transformers.utils.hub import cached_file
# Resolve 'htc_mask_dict.pkl' from the Hub repository 'Cainiao-AI/G2PTL' via
# transformers' cached_file (presumably reusing the local cache when present;
# a first run likely needs network access). This happens at import time.
resolved_module_file = cached_file(
    'Cainiao-AI/G2PTL',
    'htc_mask_dict.pkl',
)
# Per-level loss weights for the 5 hierarchy levels, index 0 = first decoded
# level. The values are ~k/15 for k = 1..5 and sum to 1.0, so deeper levels
# weigh more.
htc_weights = [0.067, 0.133, 0.2, 0.267, 0.333]
# Maps a decoded code prefix (zero-padded string key such as '01', '0103',
# '010305', '0103051') to the candidate class ids allowed at the next level;
# this is how HTCLoss indexes it.
# NOTE(review): pd.read_pickle unpickles a file downloaded from the Hub, and
# unpickling can execute arbitrary code -- only load from a trusted repo/revision.
htc_mask_dict = pd.read_pickle(resolved_module_file)
import numpy as np
import operator
def calculate_multi_htc_acc_batch(predicted_htc, y, sequence_len=6, num_levels=5):
    """Count prefix-exact matches of hierarchical codes at every depth.

    A prediction is counted as correct at depth ``k`` (1..num_levels) when its
    first ``k`` levels equal the target's first ``k`` levels. Scanning a code
    stops at the first depth whose target prefix contains the ignore value
    -100, so ignored levels are neither evaluated nor scored.

    Args:
        predicted_htc: flat sequence/array of predicted class ids with
            ``num_levels`` entries per code (the same flat layout as the
            ``predict_res`` list built by ``HTCLoss``).
        y: target tensor; reshaped with ``view(-1, sequence_len, num_levels)``.
        sequence_len: number of codes per sample.
        num_levels: number of hierarchy levels per code (default 5).

    Returns:
        ``(acc_cnt, total_cnt)``: integer numpy arrays of length ``num_levels``
        holding, per depth, the number of correct and of evaluated codes.

    Raises:
        ValueError: if ``predicted_htc`` and ``y`` describe a different number
            of samples.
    """
    targets = y.view(-1, sequence_len, num_levels).tolist()
    predicted = np.array(predicted_htc).reshape(-1, sequence_len, num_levels).tolist()
    if len(predicted) != len(targets):
        raise ValueError(
            'predicted_htc holds {} samples but y holds {}'.format(len(predicted), len(targets))
        )

    acc_cnt = np.zeros(num_levels, dtype=int)
    total_cnt = np.zeros(num_levels, dtype=int)
    for sample_targets, sample_preds in zip(targets, predicted):
        for target_code, pred_code in zip(sample_targets, sample_preds):
            for depth in range(num_levels):
                target_prefix = target_code[:depth + 1]
                # Once an ignored level appears, every deeper prefix contains it too.
                if -100 in target_prefix:
                    break
                if target_prefix == pred_code[:depth + 1]:
                    acc_cnt[depth] += 1
                total_cnt[depth] += 1
    return acc_cnt, total_cnt
class HTCLoss(torch.nn.Module):
    """Level-weighted cross-entropy over 5-level hierarchical codes.

    Logits are laid out as 5 consecutive groups of 100 classes per sample
    (one group per hierarchy level). Targets hold one class id per level;
    -100 marks positions to ignore.

    With ``using_htc=True`` each sample is greedily decoded level by level.
    Level 0 chooses among classes 1..31. Every deeper level may only choose
    from the candidates that the mask dict lists for the prefix decoded so
    far. All other logits are replaced by -5 before the softmax, so the loss
    is computed on these constrained logits.
    """

    def __init__(self, device, reduction='mean', using_htc=True):
        """
        Args:
            device: device for the candidate-index tensors, the constrained
                logits and the level weights.
            reduction: 'mean' or 'sum' over the strictly positive weighted
                losses; any other value returns the per-position losses.
            using_htc: enable constrained hierarchical decoding.
        """
        super().__init__()
        self.reduction = reduction
        self.htc_weights = list(htc_weights)
        self.device = device
        self.using_htc = using_htc
        # Build a per-instance copy rather than overwriting the shared
        # module-level dict in place. Mutating the global made every instance
        # share the same tensors (on whichever device was constructed last)
        # and re-wrapped existing tensors on each construction.
        self.htc_mask_dict = {
            key: torch.as_tensor(value).clone().detach().to(device)
            for key, value in htc_mask_dict.items()
        }

    def _decode_hierarchy(self, logits):
        """Greedily decode every sample level by level under the mask constraints.

        Returns ``(logits_new, predict_res)``:
        - ``logits_new`` has shape ``(num_samples, 5, 100)``. It keeps the
          original logits only at the allowed positions and holds -5 elsewhere.
        - ``predict_res`` is a flat list with 5 decoded class ids per sample.
        """
        logits_reshaped = logits.reshape(-1, 5, 100)
        # NOTE(review): -5 is a soft mask, not -inf -- an allowed class whose
        # logit is below -5 scores lower than the masked-out classes.
        logits_new = -5 * torch.ones_like(logits_reshaped).to(self.device)
        # Level 0 is only restricted to classes 1..31.
        logits_new[:, 0, 1:32] = logits_reshaped[:, 0, 1:32]
        _, aa_predicted = torch.max(logits_reshaped[:, 0, 1:32], 1)
        aa_predicted += 1  # slice-relative index -> class id

        predict_res = []
        for sample_idx, aa in enumerate(aa_predicted):
            code = [aa.item()]
            # Mask-dict keys are the decoded prefix: two digits each for the
            # first three levels, one digit for the fourth.
            key = '{:02d}'.format(code[0])
            for level, key_width in ((1, 2), (2, 2), (3, 1), (4, 0)):
                candidates = self.htc_mask_dict[key]
                level_logits = logits_reshaped[sample_idx, level, candidates]
                logits_new[sample_idx, level, candidates] = level_logits
                _, best = torch.max(level_logits, 0)
                choice = candidates[best].item()
                code.append(choice)
                if key_width:
                    key += '{:0{width}d}'.format(choice, width=key_width)
            predict_res.extend(code)
        return logits_new, predict_res

    def forward(self, logits, target):
        """Compute the weighted loss.

        Args:
            logits: per-sample logits; with ``using_htc`` they are reshaped to
                ``(-1, 5, 100)``, otherwise softmaxed along dim 1 as-is.
            target: class ids (integer dtype), one per level; -100 = ignore.

        Returns:
            ``(loss, predict_res)``. ``predict_res`` holds the decoded codes
            (flat, 5 per sample) when ``using_htc`` is set, else ``[]``.
        """
        target = target.reshape(-1, 1)
        valid = target != -100
        target_mask = valid.view(-1)
        # Ignored positions get dummy class 0 so gather() stays in range;
        # their loss is zeroed by target_mask below.
        target_new = target.masked_fill(~valid, 0)

        predict_res = []
        if self.using_htc:
            logits_new, predict_res = self._decode_hierarchy(logits)
            log_pro = -1.0 * F.log_softmax(logits_new.reshape(-1, 100), dim=1)
        else:
            log_pro = -1.0 * F.log_softmax(logits, dim=1)

        # Negative log-probability of the target class (equivalent to the
        # one-hot multiply-and-sum, without building the one-hot matrix).
        loss = log_pro.gather(1, target_new).squeeze(1) * target_mask

        # Tile the per-level weights over the samples. The * 5 rescales them;
        # with the default weights (which sum to 1.0) they then average to 1.
        bs = loss.shape[0] // 5
        w_loss = torch.tensor(self.htc_weights, dtype=torch.float32, device=self.device).repeat(bs)
        loss = loss.mul(w_loss) * 5

        # Ignored positions have zero loss and are excluded by the > 0 filter.
        if self.reduction == 'mean':
            positive = loss[loss > 0]
            # An all-ignored batch would give mean() of an empty tensor (NaN).
            loss = positive.mean() if positive.numel() > 0 else loss.sum() * 0.0
        elif self.reduction == 'sum':
            loss = loss[loss > 0].sum()
        return loss, predict_res

    def get_htc_code(self, logits):
        """Return the constrained greedy decoding as a flat list (5 ids per sample)."""
        _, predict_res = self._decode_hierarchy(logits)
        return predict_res