#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import operator
import os
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from transformers.utils.hub import cached_file
resolved_module_file = cached_file(
    'Cainiao-AI/TAAS',
    'htc_mask_dict_old.pkl'
)
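# Per-level loss weights for the 5 HTC hierarchy levels, and the mask dictionary
# mapping a predicted code prefix (e.g. 'aa', 'aabb', ...) to the class indices
# that are valid at the next level (key format inferred from the lookups below).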
htc_weights = [0.067, 0.133, 0.2, 0.267, 0.333]
htc_mask_dict = pd.read_pickle(resolved_module_file)
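# Batched accuracy helper: for every position it checks, level by level, whether
# the predicted prefix up to that level matches the label prefix, skipping
# positions padded with -100. Returns per-level correct and total counts.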
def calculate_multi_htc_acc_batch(predicted_htc, y, sequence_len=6):
    acc_cnt = np.array([0, 0, 0, 0, 0])
    total_cnt = np.array([0, 0, 0, 0, 0])
    y = y.view(-1, sequence_len, 5).tolist()
    predicted = np.array(predicted_htc).reshape(-1, sequence_len, 5).tolist()
    batch_size = len(y)
    for batch_i in range(batch_size):
        for index, _ in enumerate(y[batch_i]):
            for i in range(5):
                y_l10 = y[batch_i][index][:i + 1]
                p_l10 = predicted[batch_i][index][:i + 1]
                if -100 in y_l10:
                    break
                if operator.eq(y_l10, p_l10):
                    acc_cnt[i] += 1
                total_cnt[i] += 1
    return acc_cnt, total_cnt
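# Hierarchy-aware cross-entropy: when using_htc is enabled, the logits of each
# level are masked so that only children of the previously predicted level
# remain candidates; non-candidate logits are pushed down to a constant (-5)
# before the softmax.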
class HTCLoss(torch.nn.Module):

    def __init__(self, device, reduction='mean', using_htc=True):
        super(HTCLoss, self).__init__()
        self.reduction = reduction
        self.htc_weights = htc_weights
        self.device = device
        self.using_htc = using_htc
        self.htc_mask_dict = htc_mask_dict
        for key, value in self.htc_mask_dict.items():
            # Pre-convert the mask index lists to tensors on the target device.
            self.htc_mask_dict[key] = torch.tensor(value).clone().detach().to(self.device)
    def forward(self, logits, target):
        # logits: [bs, num_class]; CE = q * -log(p) with p = softmax(logits).
        # target-related tensors are expected to already live on the CUDA device.
        target = target.reshape(-1, 1)
        target_mask = target != -100
        target_mask = target_mask.squeeze()
        target_mask_idx = torch.where(target == -100)
        target_new = target.clone()
        target_new[target_mask_idx] = 0
        predict_res = []
        if not self.using_htc:
            log_pro = -1.0 * F.log_softmax(logits, dim=1)
            # one_hot = torch.zeros(logits.shape[0], logits.shape[1]).to(self.device)  # .cuda()
            # one_hot = one_hot.scatter_(1, target_new, 1)
            # loss = torch.mul(log_pro, one_hot).sum(dim=1)
            # loss = loss * target_mask
        else:
            # _, predicted = torch.max(logits[:, :32], 1)
            logits_reshaped = logits.clone()
            logits_reshaped = logits_reshaped.reshape(-1, 5, 100)
            # Level 1: pick the best first-level class (indices 1..31).
            _, aa_predicted = torch.max(logits_reshaped[:, 0, 1:32], 1)
            aa_predicted += 1
            logits_new = -5 * torch.ones_like(logits_reshaped).to(self.device)
            logits_new[:, 0, 1:32] = logits_reshaped[:, 0, 1:32]
            for sample_idx, aa in enumerate(aa_predicted):
                # Levels 2-5: restrict each level's candidates to the children
                # of the code predicted at the previous level.
                bb_idx = htc_mask_dict['{:02d}'.format(aa)]
                _, bb_idy = torch.max(logits_reshaped[sample_idx, 1, bb_idx], 0)
                bb = bb_idx[bb_idy]
                logits_new[sample_idx, 1, bb_idx] = logits_reshaped[sample_idx, 1, bb_idx]
                cc_idx = htc_mask_dict['{:02d}{:02d}'.format(aa, bb)]
                _, cc_idy = torch.max(logits_reshaped[sample_idx, 2, cc_idx], 0)
                logits_new[sample_idx, 2, cc_idx] = logits_reshaped[sample_idx, 2, cc_idx]
                cc = cc_idx[cc_idy]
                d_idx = htc_mask_dict['{:02d}{:02d}{:02d}'.format(aa, bb, cc)]
                _, d_idy = torch.max(logits_reshaped[sample_idx, 3, d_idx], 0)
                logits_new[sample_idx, 3, d_idx] = logits_reshaped[sample_idx, 3, d_idx]
                d = d_idx[d_idy]
                ee_idx = htc_mask_dict['{:02d}{:02d}{:02d}{:01d}'.format(aa, bb, cc, d)]
                _, ee_idy = torch.max(logits_reshaped[sample_idx, 4, ee_idx], 0)
                logits_new[sample_idx, 4, ee_idx] = logits_reshaped[sample_idx, 4, ee_idx]
                ee = ee_idx[ee_idy]
                predict_res.extend([aa.item(), bb.item(), cc.item(), d.item(), ee.item()])
            # predicted = predicted.reshape(-1, 5)
            # aa = predicted[:, 0]
            # aa = ['{:02d}'.format(i) for i in aa]
            # bb_activate = [htc_mask_dict[i] for i in aa]
            logits_new = logits_new.reshape(-1, 100)
            log_pro = -1.0 * F.log_softmax(logits_new, dim=1)
        logits = logits.contiguous().view(-1, 100)
        one_hot = torch.zeros(logits.shape[0], logits.shape[1]).to(self.device)  # .cuda()
        one_hot = one_hot.scatter_(1, target_new, 1)
        loss = torch.mul(log_pro, one_hot).sum(dim=1)
        loss = loss * target_mask
        # Weight each of the 5 hierarchy levels; masked (-100) positions were
        # zeroed above and are excluded from the reduction.
        bs = int(loss.shape[0] / 5)
        w_loss = []
        for i in range(bs):
            w_loss.extend(self.htc_weights)
        w_loss = torch.FloatTensor(w_loss).to(self.device)
        loss = loss.mul(w_loss) * 5
        if self.reduction == 'mean':
            loss = loss[torch.where(loss > 0)].mean()
        elif self.reduction == 'sum':
            loss = loss[torch.where(loss > 0)].sum()
        return loss, predict_res
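    # Greedy top-down decoding of the 5-level HTC code, applying the same
    # parent-child masking as in forward() but without computing a loss.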
    def get_htc_code(self, logits):
        # logits: [bs, num_class], reshaped to [-1, 5, 100] (5 levels x 100 classes).
        logits_reshaped = logits.clone()
        logits_reshaped = logits_reshaped.reshape(-1, 5, 100)
        _, aa_predicted = torch.max(logits_reshaped[:, 0, 1:32], 1)
        aa_predicted += 1
        logits_new = -5 * torch.ones_like(logits_reshaped).to(self.device)
        logits_new[:, 0, 1:32] = logits_reshaped[:, 0, 1:32]
        predict_res = []
        for sample_idx, aa in enumerate(aa_predicted):
            bb_idx = htc_mask_dict['{:02d}'.format(aa)]
            _, bb_idy = torch.max(logits_reshaped[sample_idx, 1, bb_idx], 0)
            bb = bb_idx[bb_idy]
            logits_new[sample_idx, 1, bb_idx] = logits_reshaped[sample_idx, 1, bb_idx]
            cc_idx = htc_mask_dict['{:02d}{:02d}'.format(aa, bb)]
            _, cc_idy = torch.max(logits_reshaped[sample_idx, 2, cc_idx], 0)
            logits_new[sample_idx, 2, cc_idx] = logits_reshaped[sample_idx, 2, cc_idx]
            cc = cc_idx[cc_idy]
            d_idx = htc_mask_dict['{:02d}{:02d}{:02d}'.format(aa, bb, cc)]
            _, d_idy = torch.max(logits_reshaped[sample_idx, 3, d_idx], 0)
            logits_new[sample_idx, 3, d_idx] = logits_reshaped[sample_idx, 3, d_idx]
            d = d_idx[d_idy]
            ee_idx = htc_mask_dict['{:02d}{:02d}{:02d}{:01d}'.format(aa, bb, cc, d)]
            _, ee_idy = torch.max(logits_reshaped[sample_idx, 4, ee_idx], 0)
            logits_new[sample_idx, 4, ee_idx] = logits_reshaped[sample_idx, 4, ee_idx]
            ee = ee_idx[ee_idy]
            predict_res.extend([aa.item(), bb.item(), cc.item(), d.item(), ee.item()])
        return predict_res
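
# Usage sketch (shapes are assumptions read off the reshapes above: logits of
# size [N, 500] = 5 levels x 100 classes per position, targets of size [N, 5]
# with -100 marking padded levels):
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   criterion = HTCLoss(device=device, reduction='mean', using_htc=True)
#   loss, htc_codes = criterion(logits.to(device), targets.to(device))
#   acc_cnt, total_cnt = calculate_multi_htc_acc_batch(htc_codes, targets, sequence_len=6)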