from ML_SLRC import *

import os
import gc
import re
import json
import time
import random
import warnings
from copy import deepcopy
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchmetrics import functional as fn

from sklearn.metrics import confusion_matrix, roc_curve, auc
from sklearn.manifold import TSNE

import matplotlib.pyplot as plt
import seaborn as sns

import ipywidgets as widgets
from IPython.display import display, clear_output

from tqdm import tqdm


def random_seed(value):
    """Fix all relevant random seeds for reproducibility."""
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(value)
    torch.cuda.manual_seed(value)
    np.random.seed(value)
    random.seed(value)


def create_batch_of_tasks(taskset, is_shuffle=True, batch_size=4):
    """Yield batches of tasks from `taskset`, optionally in shuffled order."""
    idxs = list(range(len(taskset)))
    if is_shuffle:
        random.shuffle(idxs)
    for start in range(0, len(idxs), batch_size):
        yield [taskset[idxs[j]] for j in range(start, min(start + batch_size, len(idxs)))]


def prepare_data(data, batch_size, tokenizer, max_seq_length,
                 input='text', output='label',
                 train_size_per_class=5, global_datasets=False,
                 treat_text_fun=None):
    """Split `data` into a small per-class training sample and an evaluation pool,
    and wrap both in DataLoaders built on SLR_DataSet."""
    data = data.reset_index().drop("index", axis=1)

    if global_datasets:
        global data_train, data_test

    # Few-shot training sample: `train_size_per_class` rows per class, without replacement.
    data_train = data.groupby('label').sample(train_size_per_class, replace=False)

    # The whole pool (including the sampled training rows) is kept for evaluation.
    data_test = data

    dataset_train = SLR_DataSet(
        data=data_train.sample(frac=1),
        input=input,
        output=output,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        treat_text=treat_text_fun)

    dataset_test = SLR_DataSet(
        data=data_test,
        input=input,
        output=output,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        treat_text=treat_text_fun)

    data_train_loader = DataLoader(dataset_train,
                                   shuffle=True,
                                   batch_size=batch_size['train'])

    # Drop the last test batch only when it would contain a single example,
    # which breaks the .squeeze() calls downstream.
    drop_last = (len(dataset_test) % batch_size['test'] == 1)
    data_test_loader = DataLoader(dataset_test,
                                  batch_size=batch_size['test'],
                                  drop_last=drop_last)

    return data_train_loader, data_test_loader, data_train, data_test


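# Usage sketch (hypothetical, not part of the original pipeline): build loaders for one
# review domain with 5 labelled examples per class. `Valid_resource` and `Info` are assumed
# to be the validation DataFrame and configuration dict used by `pipeline_simulation` below;
# `SLR_DataSet` comes from ML_SLRC.
def _example_prepare_data(Valid_resource, Info, domain_name="example_domain"):
    data = Valid_resource[Valid_resource['domain'] == domain_name]
    train_loader, test_loader, d_train, d_test = prepare_data(
        data,
        batch_size={'train': Info['inner_batch_size'], 'test': 100},
        tokenizer=Info['tokenizer'],
        max_seq_length=Info['max_seq_length'],
        train_size_per_class=5)
    return train_loader, test_loader

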
def meta_train(data, model, device, Info,
               print_epoch=True,
               Test_resource=None,
               treat_text_fun=None):
    """Meta-train `model` using the Learner/MetaTask machinery from ML_SLRC."""

    learner = Learner(model=model, device=device, **Info)

    if isinstance(Test_resource, pd.DataFrame):
        test = MetaTask(Test_resource, num_task=0, k_support=10, k_query=10,
                        training=False, treat_text=treat_text_fun, **Info)

    torch.clear_autocast_cache()
    gc.collect()
    torch.cuda.empty_cache()

    for epoch in tqdm(range(Info['meta_epoch']), desc="Meta epoch ", ncols=80):

        # Sample a fresh set of meta-tasks each epoch.
        # Note: k_support/k_query are taken from k_spt/k_qry respectively
        # (the original call passed them swapped).
        train = MetaTask(data,
                         num_task=Info['num_task_train'],
                         k_support=Info['k_spt'],
                         k_query=Info['k_qry'],
                         treat_text=treat_text_fun, **Info)

        db = create_batch_of_tasks(train, is_shuffle=True, batch_size=Info["outer_batch_size"])

        if print_epoch:

            for step, task_batch in enumerate(db):
                print("\n-----------------Training Mode", "Meta_epoch:", epoch, "-----------------\n")

                acc = learner(task_batch, valid_train=print_epoch)
                print('Step:', step, '\ttraining Acc:', acc)

                # Evaluate on the held-out resource every fourth epoch, at the first step.
                if isinstance(Test_resource, pd.DataFrame):

                    if ((epoch + 1) % 4) + step == 0:
                        random_seed(123)
                        print("\n-----------------Testing Mode-----------------\n")

                        db_test = create_batch_of_tasks(test, is_shuffle=False, batch_size=1)
                        acc_all_test = []

                        for test_batch in db_test:
                            acc = learner(test_batch, training=False)
                            acc_all_test.append(acc)

                        print('Test acc:', np.mean(acc_all_test))
                        del acc_all_test, db_test

                        # Restore a (weakly) random seed after the deterministic test pass.
                        random_seed(int(time.time() % 10))

        else:
            for step, task_batch in enumerate(db):
                acc = learner(task_batch, print_epoch, valid_train=print_epoch)

        torch.clear_autocast_cache()
        gc.collect()
        torch.cuda.empty_cache()


def train_loop(data_train_loader, data_test_loader, model, device, epoch=4, lr=1,
               print_info=True, name='name', weight_decay=1):
    """Fine-tune a copy of `model` on the few-shot loader, then evaluate on the test loader.

    Returns the tuple produced by `map_feature_tsne`: (logits, X_embedded, labels, features),
    which is what the callers below unpack."""

    model_meta = deepcopy(model)
    optimizer = Adam(model_meta.parameters(), lr=lr, weight_decay=weight_decay)

    model_meta.to(device)
    model_meta.train()

    for i in range(epoch):
        all_loss = []

        for inner_step, batch in enumerate(data_train_loader):
            batch = tuple(t.to(device) for t in batch)
            input_ids, attention_mask, q_token_type_ids, label_id = batch

            loss, _, _ = model_meta(input_ids, attention_mask, q_token_type_ids,
                                    labels=label_id.squeeze())

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            all_loss.append(loss.item())

        if (i % 2 == 0) and print_info:
            print("Loss: ", np.mean(all_loss))

    # Evaluation pass: collect encoder features, labels and logits for every test batch.
    model_meta.eval()
    all_acc = []
    features = []
    labels = []
    predi_logit = []

    with torch.no_grad():

        for inner_step, batch in enumerate(tqdm(data_test_loader,
                                                desc="Test validation | " + name,
                                                ncols=80)):
            batch = tuple(t.to(device) for t in batch)
            input_ids, attention_mask, q_token_type_ids, label_id = batch

            _, feature, _ = model_meta(input_ids, attention_mask, q_token_type_ids,
                                       labels=label_id.squeeze())

            # `feature` is assumed to be (encoded representation, logits), as used elsewhere
            # in this module (feature[1] holds the classification logits).
            vector = feature[0].detach().cpu()
            logit = feature[1].detach().cpu()

            features.append(vector.numpy())
            labels.append(label_id.detach().cpu().numpy())
            predi_logit.append(logit.numpy())

            # Simple 0.5-threshold accuracy, just for progress reporting.
            preds = (torch.sigmoid(logit).squeeze() >= 0.5).int()
            all_acc.append((preds == label_id.squeeze().detach().cpu().int()).float().mean().item())

            del input_ids, attention_mask, label_id, batch

    if print_info:
        print("acc:", np.mean(all_acc))

    model_meta.to('cpu')
    gc.collect()
    torch.cuda.empty_cache()

    del model_meta, optimizer

    return map_feature_tsne(features, labels, predi_logit)


def map_feature_tsne(features, labels, predi_logit):
    """Stack the per-batch outputs and project the encoder features to 2-D with t-SNE."""

    features = np.concatenate(np.array(features, dtype=object))
    features = torch.tensor(features.astype(np.float32)).detach().clone()

    labels = np.concatenate(np.array(labels, dtype=object))
    labels = torch.tensor(labels.astype(int)).detach().clone()

    logits = np.concatenate(np.array(predi_logit, dtype=object))
    logits = torch.tensor(logits.astype(np.float32)).detach().clone()

    X_embedded = TSNE(n_components=2, learning_rate='auto',
                      init='random').fit_transform(features.detach().clone())

    return logits.detach().clone(), X_embedded, labels.detach().clone(), features.detach().clone()


def wss_calc(logit, labels, trsh=0.5):
    """Compute Work Saved over Sampling (WSS) and its adjusted variant (AWSS)
    at the given probability threshold."""

    predict_trash = torch.sigmoid(logit).squeeze() >= trsh

    CM = confusion_matrix(labels, predict_trash.to(int))
    tn, fp, fne, tp = CM.ravel()

    P = (tp + fne)          # positives
    N = (tn + fp)           # negatives
    recall = tp / (tp + fne)

    # WSS: fraction of documents that would not need manual screening,
    # penalised by the recall lost at this threshold.
    wss = (tn + fne) / len(labels) - (1 - recall)

    # Adjusted WSS: true-negative rate minus false-negative rate.
    awss = (tn / N - fne / P)

    return {
        "wss": round(wss, 4),
        "awss": round(awss, 4),
        "R": round(recall, 4),
        "CM": CM
    }


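# Worked example (illustrative numbers only): with a confusion matrix of
# tn=800, fp=50, fn=5, tp=145 over 1000 documents,
#   recall = 145 / 150                           = 0.9667
#   WSS    = (800 + 5) / 1000 - (1 - 0.9667)     = 0.805 - 0.0333 = 0.7717
#   AWSS   = 800 / 850 - 5 / 150                 = 0.9412 - 0.0333 = 0.9078
# i.e. roughly 77% of the screening effort is saved while keeping ~97% recall.

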
def plot(logits, X_embedded, labels, threshold, show=True,
         namefig="plot", make_plot=True, print_stats=True, save=True):
    """Report WSS metrics at 95% recall and at a fixed threshold, and optionally
    draw the t-SNE scatter plots and the ROC curve."""
    col = pd.MultiIndex.from_tuples([
        ("Predict", "0"),
        ("Predict", "1")
    ])
    index = pd.MultiIndex.from_tuples([
        ("Real", "0"),
        ("Real", "1")
    ])

    predict = torch.sigmoid(logits).detach().clone()

    fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())

    # Smallest threshold that reaches 95% recall (the WSS@95 operating point).
    idx_wss95 = sum(tpr < 0.95)
    thresholds95 = thresholds[idx_wss95]

    wss95_info = wss_calc(logits, labels, thresholds95)
    acc_wss95 = fn.accuracy(predict, labels, threshold=thresholds95)
    f1_wss95 = fn.f1_score(predict, labels, threshold=thresholds95)

    # Metrics at the user-supplied threshold (WSS@R).
    wss_info = wss_calc(logits, labels, threshold)
    acc_wssR = fn.accuracy(predict, labels, threshold=threshold)
    f1_wssR = fn.f1_score(predict, labels, threshold=threshold)

    metrics = {
        "WSS@95": wss95_info['wss'],
        "AWSS@95": wss95_info['awss'],
        "WSS@R": wss_info['wss'],
        "AWSS@R": wss_info['awss'],

        "Recall_WSS@95": wss95_info['R'],
        "Recall_WSS@R": wss_info['R'],

        "acc@95": acc_wss95.item(),
        "acc@R": acc_wssR.item(),

        "f1@95": f1_wss95.item(),
        "f1@R": f1_wssR.item(),

        "threshold@95": thresholds95
    }

    if print_stats:
        print(f"WSS@95:{wss95_info['wss']}, R: {wss95_info['R']}")
        print(f"AWSS@95:{wss95_info['awss']}")
        print('Acc.:', round(acc_wss95.item(), 4))
        print('F1-score:', round(f1_wss95.item(), 4))
        print(f"threshold to wss95: {round(thresholds95, 4)}")
        cm = pd.DataFrame(wss95_info['CM'],
                          index=index,
                          columns=col)

        print("\nConfusion matrix:")
        print(cm)
        print("\n---Metrics with threshold:", threshold, "----\n")
        print(f"WSS@R:{wss_info['wss']}, R: {wss_info['R']}")
        print(f"AWSS@R:{wss_info['awss']}")
        print('Acc.:', round(acc_wssR.item(), 4))
        print('F1-score:', round(f1_wssR.item(), 4))
        cm = pd.DataFrame(wss_info['CM'],
                          index=index,
                          columns=col)

        print("\nConfusion matrix:")
        print(cm)

    if make_plot:

        fig, axes = plt.subplots(1, 4, figsize=(25, 10))
        alpha = torch.squeeze(predict).numpy()

        # All predictions, shaded by predicted probability.
        sns.scatterplot(x=X_embedded[:, 0],
                        y=X_embedded[:, 1],
                        hue=labels,
                        alpha=alpha, ax=axes[0]).set_title('Predictions-TSNE', size=20)

        # Documents kept at the WSS@95 threshold.
        t_wss = predict >= thresholds95
        t_wss = t_wss.squeeze().numpy()
        sns.scatterplot(x=X_embedded[t_wss, 0],
                        y=X_embedded[t_wss, 1],
                        hue=labels[t_wss],
                        alpha=alpha[t_wss], ax=axes[1]).set_title('WSS@95', size=20)

        # Documents kept at the user-supplied threshold.
        t = predict >= threshold
        t = t.squeeze().numpy()
        sns.scatterplot(x=X_embedded[t, 0],
                        y=X_embedded[t, 1],
                        hue=labels[t],
                        alpha=alpha[t], ax=axes[2]).set_title(f'Predictions-threshold {threshold}', size=20)

        # ROC curve with the 95% recall line highlighted.
        roc_auc = auc(fpr, tpr)
        lw = 2
        axes[3].plot(
            fpr,
            tpr,
            color="darkorange",
            lw=lw,
            label="ROC curve (area = %0.2f)" % roc_auc)
        axes[3].plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
        axes[3].axhline(y=0.95, color='r', linestyle='-')

        axes[3].legend(loc="lower right")
        axes[3].set_title(label="ROC", size=20)
        axes[3].set_ylabel("True Positive Rate", fontsize=15)
        axes[3].set_xlabel("False Positive Rate", fontsize=15)

        if show:
            plt.show()

        if save:
            fig.savefig(namefig, dpi=fig.dpi)

    return metrics


def auc_plot(logits, labels, color="darkorange", label="test"):
    """Add a ROC curve for `logits` to the current matplotlib figure."""
    predict = torch.sigmoid(logits).detach().clone()
    fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
    roc_auc = auc(fpr, tpr)
    lw = 2

    label = label + str(round(roc_auc, 2))

    plt.plot(
        fpr,
        tpr,
        color=color,
        lw=lw,
        label=label
    )
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    plt.axhline(y=0.95, color='r', linestyle='-')


class diagnosis():
    """Small ipywidgets front-end to inspect one validation domain at a time:
    'Next' selects a domain, 'Train' fine-tunes on a few labelled examples and
    'Evaluation' reports the WSS metrics and plots."""

    def __init__(self, names, Valid_resource, batch_size_test,
                 model, Info, device, treat_text_fun=None, start=0):
        self.names = names
        self.Valid_resource = Valid_resource
        self.batch_size_test = batch_size_test
        self.model = model
        self.start = start
        self.Info = Info
        self.device = device
        self.treat_text_fun = treat_text_fun

        self.value_trash = widgets.FloatText(
            value=0.95,
            description='threshold',
            disabled=False
        )
        self.valueb = widgets.IntText(
            value=10,
            description='size',
            disabled=False
        )

        self.train_b = widgets.Button(description="Train")
        self.next_b = widgets.Button(description="Next")
        self.eval_b = widgets.Button(description="Evaluation")

        self.hbox = widgets.HBox([self.train_b, self.valueb])

        self.next_b.on_click(self.Next_button)
        self.train_b.on_click(self.Train_button)
        self.eval_b.on_click(self.Evaluation_button)

    def Next_button(self, p):
        clear_output()
        self.i = self.i + 1

        self.domain = self.names[self.i]
        self.data = self.Valid_resource[self.Valid_resource['domain'] == self.domain]

        print("Name:", self.domain)
        print(self.data['label'].value_counts())
        display(self.hbox)
        display(self.next_b)

    def Train_button(self, y):
        clear_output()
        print(self.domain)

        self.data_train_loader, self.data_test_loader, self.data_train, self.data_test = prepare_data(
            self.data,
            train_size_per_class=self.valueb.value,
            batch_size={'train': self.Info['inner_batch_size'],
                        'test': self.batch_size_test},
            max_seq_length=self.Info['max_seq_length'],
            tokenizer=self.Info['tokenizer'],
            input="text",
            output="label",
            treat_text_fun=self.treat_text_fun)

        self.logits, self.X_embedded, self.labels, self.features = train_loop(
            self.data_train_loader, self.data_test_loader,
            self.model, self.device,
            epoch=self.Info['inner_update_step'],
            lr=self.Info['inner_update_lr'],
            print_info=True,
            name=self.domain)

        tresh_box = widgets.HBox([self.eval_b, self.value_trash])
        display(self.hbox)
        display(tresh_box)
        display(self.next_b)

    def Evaluation_button(self, te):
        clear_output()
        tresh_box = widgets.HBox([self.eval_b, self.value_trash])

        print(self.domain)

        print("-------Train data-------")
        print(self.data_train['label'].value_counts())
        print("-------Test data-------")
        print(self.data_test['label'].value_counts())

        display(self.next_b)
        display(tresh_box)
        display(self.hbox)

        metrics = plot(self.logits, self.X_embedded, self.labels,
                       threshold=self.Info['threshold'], show=True,
                       namefig='test',
                       make_plot=True,
                       print_stats=True,
                       save=False)

    def __call__(self):
        self.i = self.start - 1
        clear_output()
        display(self.next_b)


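# Usage sketch (hypothetical): run the widget in a notebook cell. `Valid_resource` is a
# DataFrame with 'text', 'label' and 'domain' columns, `names` the list of domains to inspect.
#
#   diag = diagnosis(names, Valid_resource, batch_size_test=100,
#                    model=model, Info=Info, device=device)
#   diag()   # shows the "Next" button; then use "Train" and "Evaluation"

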
def pipeline_simulation(Valid_resource, names_to_valid, path_save,
                        model, Info, device, initializer_model,
                        treat_text_fun=None):
    """For every validation domain, repeat the few-shot fine-tuning `n_attempt` times
    and save the metrics, ROC statistics and plots under `path_save`."""
    n_attempt = 5
    batch_test = 100

    # One output folder (with an img/ subfolder) per domain.
    for name in names_to_valid:
        name = re.sub(r"\.csv", "", name)
        Path(path_save + name + "/img").mkdir(parents=True, exist_ok=True)

    roc_stats = defaultdict(lambda: defaultdict(
        lambda: defaultdict(
            list
        )
    )
    )

    all_metrics = []

    for name in names_to_valid:

        data = Valid_resource[Valid_resource['domain'] == name].reset_index().drop("index", axis=1)

        for attempt in range(n_attempt):
            print("---" * 4, "attempt", attempt, "---" * 4)

            data_train_loader, data_test_loader, _, _ = prepare_data(
                data,
                train_size_per_class=Info['k_spt'],
                batch_size={'train': Info['inner_batch_size'],
                            'test': batch_test},
                max_seq_length=Info['max_seq_length'],
                tokenizer=Info['tokenizer'],
                input="text",
                output="label",
                treat_text_fun=treat_text_fun)

            logits, X_embedded, labels, features = train_loop(
                data_train_loader, data_test_loader,
                model, device,
                epoch=Info['inner_update_step'],
                lr=Info['inner_update_lr'],
                print_info=False,
                name=name)

            name_domain = re.sub(r"\.csv", "", name)

            metrics = plot(logits, X_embedded, labels,
                           threshold=Info['threshold'], show=False,
                           namefig=path_save + name_domain + "/img/" + str(attempt) + 'plots',
                           make_plot=True, print_stats=False, save=True)

            fpr, tpr, _ = roc_curve(labels, torch.sigmoid(logits).squeeze())

            metrics['name'] = name_domain
            metrics['layer_size'] = Info['bert_layers']
            metrics['attempt'] = attempt
            roc_stats[name_domain][str(Info['bert_layers'])]['fpr'].append(fpr.tolist())
            roc_stats[name_domain][str(Info['bert_layers'])]['tpr'].append(tpr.tolist())
            all_metrics.append(metrics)

            # Persist partial results after every attempt.
            pd.DataFrame(all_metrics).to_csv(path_save + "metrics.csv")
            roc_path = path_save + "roc_stats.json"
            with open(roc_path, 'w') as fp:
                json.dump(roc_stats, fp)

            del fpr, tpr, logits, X_embedded, labels
            del features, metrics, _

    # Save the configuration used for the simulation (the tokenizer object is not serialisable).
    save_info = Info.copy()
    save_info['model'] = initializer_model.tokenizer.name_or_path
    save_info.pop("tokenizer")
    save_info.pop("bert_layers")

    info_path = path_save + "info.json"
    with open(info_path, 'w') as fp:
        json.dump(save_info, fp)


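# Usage sketch (hypothetical values): run the full simulation over every validation domain
# and write results under ./results/. `Info`, `model`, `device` and `initializer_model`
# are assumed to be built elsewhere (e.g. from ML_SLRC) before this call.
#
#   names_to_valid = Valid_resource['domain'].unique().tolist()
#   pipeline_simulation(Valid_resource, names_to_valid, "./results/",
#                       model, Info, device, initializer_model)

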
def load_data_statistics(paths, names):
    """Summarise each CSV in `paths`: total size and number of positive/negative labels."""
    size = []
    pos = []
    neg = []
    for p in paths:
        data = pd.read_csv(p)
        data = data.dropna()

        size.append(len(data))
        pos.append(data['labels'].value_counts()[1])
        neg.append(data['labels'].value_counts()[0])
        del data

    info_load = pd.DataFrame({
        "size": size,
        "pos": pos,
        "neg": neg,
        "names": names,
        "paths": paths})
    return info_load


def load_data(train_info_load):
    """Load and concatenate the CSVs listed in `train_info_load`, returning a DataFrame
    with 'text', 'domain' and 'label' columns."""

    col = ['abstract', 'title', 'labels', 'domain']

    data_train = pd.DataFrame(columns=col)
    for p in train_info_load['paths']:
        data_temp = pd.read_csv(p).loc[:, ['labels', 'title', 'abstract']]
        data_temp['domain'] = os.path.basename(p)
        data_train = pd.concat([data_train, data_temp])

    data_train['text'] = data_train['title'] + data_train['abstract'].replace(np.nan, '')

    return (data_train
            .replace({"labels": {0: "negative", 1: 'positive'}})
            .rename({"labels": "label"}, axis=1)
            .loc[:, ("text", "domain", "label")])
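

# Usage sketch (hypothetical file names): build the training pool from a folder of
# labelled CSVs, each with 'title', 'abstract' and 'labels' (0/1) columns.
#
#   paths = ["data/review_a.csv", "data/review_b.csv"]
#   names = [os.path.basename(p) for p in paths]
#   info_load = load_data_statistics(paths, names)
#   data_train = load_data(info_load)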