from ML_SLRC import * |
import os |
import numpy as np |
import pandas as pd |
from torch.utils.data import DataLoader |
from torch.optim import Adam |
import gc |
from torchmetrics import functional as fn |
import random |
from tqdm import tqdm |
from sklearn.metrics import confusion_matrix |
from sklearn.metrics import roc_curve, auc |
import ipywidgets as widgets |
from IPython.display import display, clear_output |
import matplotlib.pyplot as plt |
import warnings |
import torch |
import time |
from sklearn.manifold import TSNE |
from copy import deepcopy |
import seaborn as sns |
import matplotlib.pylab as plt |
import json |
from pathlib import Path |
import re |
from collections import defaultdict |
def random_seed(value):
    """Seed every random number generator this module relies on.

    Turns on deterministic cuDNN kernels and then seeds, in order, the
    torch CPU generator, the torch CUDA generator, NumPy's global generator
    and Python's ``random`` module with the same value.

    Args:
        value: Integer seed shared by all generators.
    """
    torch.backends.cudnn.deterministic = True
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seed_with in seeders:
        seed_with(value)
def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
    """Yield the tasks of ``taskset`` grouped into lists of ``batch_size``.

    Args:
        taskset: Indexable collection of tasks supporting ``len()``.
        is_shuffle: When true, the visiting order is shuffled with
            ``random.shuffle`` before batching.
        batch_size: Maximum number of tasks per yielded list; the last list
            may be shorter.

    Yields:
        Lists of tasks taken from ``taskset``.
    """
    order = list(range(len(taskset)))
    if is_shuffle:
        random.shuffle(order)
    for start in range(0, len(order), batch_size):
        chunk = order[start:start + batch_size]
        yield [taskset[position] for position in chunk]
def prepare_data(data, batch_size, tokenizer, max_seq_length,
                 input = 'text', output = 'label',
                 train_size_per_class = 5, global_datasets = False,
                 treat_text_fun = None):
    """Split a labelled frame into a few-shot train loader and a test loader.

    ``train_size_per_class`` rows are sampled without replacement from every
    class of the ``output`` column to form the training set; the test set is
    the whole frame (training rows included).

    Args:
        data: DataFrame holding at least the ``input`` and ``output`` columns.
        batch_size: Mapping with the keys ``'train'`` and ``'test'``.
        tokenizer: Tokenizer forwarded to ``SLR_DataSet``.
        max_seq_length: Maximum sequence length forwarded to ``SLR_DataSet``.
        input: Name of the text column.
        output: Name of the label column; also used to group the sampling.
        train_size_per_class: Rows sampled per class for training.
        global_datasets: Kept for backward compatibility. It has never had an
            effect: see the note on the ``global`` statement below.
        treat_text_fun: Optional text pre-processing callable forwarded to
            ``SLR_DataSet``.

    Returns:
        ``(data_train_loader, data_test_loader, data_train, data_test)``.
    """
    # NOTE: a `global` statement applies to the whole function regardless of
    # where it is written, so the original `if global_datasets: global ...`
    # always rebound these module-level names. The declaration is kept (now
    # explicit) because other code in this module reads the globals.
    global data_train, data_test

    data = data.reset_index().drop("index", axis=1)

    # Few-shot training sample: the same number of rows from every class.
    data_train = data.groupby(output).sample(train_size_per_class, replace=False)

    # The set to be labelled is the complete frame.
    data_test = data

    dataset_train = SLR_DataSet(
        data=data_train.sample(frac=1),
        input=input,
        output=output,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        treat_text=treat_text_fun)

    dataset_test = SLR_DataSet(
        data=data_test,
        input=input,
        output=output,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        treat_text=treat_text_fun)

    data_train_loader = DataLoader(dataset_train,
                                   shuffle=True,
                                   batch_size=batch_size['train'])

    # A trailing batch holding a single sample is dropped; every other
    # remainder size is kept.
    drop_single_tail = len(dataset_test) % batch_size['test'] == 1
    data_test_loader = DataLoader(dataset_test,
                                  batch_size=batch_size['test'],
                                  drop_last=drop_single_tail)

    return data_train_loader, data_test_loader, data_train, data_test
def meta_train(data, model, device, Info,
               print_epoch = True,
               Test_resource = None,
               treat_text_fun = None):
    """Run the outer (meta) training loop over freshly sampled task sets.

    For every meta epoch a new ``MetaTask`` set is drawn from ``data`` and fed
    to the ``Learner`` in batches of ``Info['outer_batch_size']`` tasks. When
    ``print_epoch`` is true the training accuracy of every step is printed
    and, if ``Test_resource`` is a DataFrame, the learner is evaluated on a
    fixed test task set at the first step of every fourth epoch.

    Args:
        data: DataFrame used to sample the training tasks.
        model: Model wrapped by ``Learner``.
        device: Device forwarded to ``Learner``.
        Info: Mapping of hyper-parameters; this function reads ``meta_epoch``,
            ``num_task_train``, ``k_spt``, ``k_qry`` and ``outer_batch_size``
            and forwards the whole mapping to ``Learner`` and ``MetaTask``.
        print_epoch: Print per-step accuracy and run the periodic evaluation.
        Test_resource: Optional DataFrame used to build the evaluation tasks.
        treat_text_fun: Optional text pre-processing callable.
    """
    learner = Learner(model=model, device=device, **Info)

    has_test = isinstance(Test_resource, pd.DataFrame)
    if has_test:
        test = MetaTask(Test_resource, num_task=0, k_support=10, k_query=10,
                        training=False, treat_text=treat_text_fun, **Info)

    torch.clear_autocast_cache()
    gc.collect()
    torch.cuda.empty_cache()

    for epoch in tqdm(range(Info['meta_epoch']), desc="Meta epoch ", ncols=80):
        # NOTE(review): the original wired k_support to Info['k_qry'] and
        # k_query to Info['k_spt']; going by the names that was a swap, so
        # each size now goes to its matching argument. Confirm against
        # MetaTask if the previous pairing was intentional.
        train = MetaTask(data,
                         num_task=Info['num_task_train'],
                         k_support=Info['k_spt'],
                         k_query=Info['k_qry'],
                         treat_text=treat_text_fun, **Info)

        db = create_batch_of_tasks(train, is_shuffle=True,
                                   batch_size=Info["outer_batch_size"])

        if print_epoch:
            for step, task_batch in enumerate(db):
                print("\n-----------------Training Mode", "Meta_epoch:", epoch, "-----------------\n")
                acc = learner(task_batch, valid_train=print_epoch)
                print('Step:', step, '\ttraining Acc:', acc)

                # Evaluate once per fourth epoch, on its first step only.
                if has_test and (epoch + 1) % 4 == 0 and step == 0:
                    # Fixed seed so every evaluation sees the same conditions.
                    random_seed(123)
                    print("\n-----------------Testing Mode-----------------\n")
                    db_test = create_batch_of_tasks(test, is_shuffle=False, batch_size=1)
                    acc_all_test = [learner(test_batch, training=False)
                                    for test_batch in db_test]
                    print('Test acc:', np.mean(acc_all_test))
                    del acc_all_test, db_test
                    # Back to a time-dependent seed for the rest of training.
                    random_seed(int(time.time() % 10))
        else:
            for task_batch in db:
                # NOTE(review): the original also passed `print_epoch` (False
                # in this branch) as the second positional argument. The other
                # calls use `training=` / `valid_train=` keywords, so that slot
                # is presumably `training` and the silent mode was not training
                # at all. Now called exactly like the verbose branch.
                learner(task_batch, valid_train=print_epoch)

        torch.clear_autocast_cache()
        gc.collect()
        torch.cuda.empty_cache()
def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr = 1, print_info = True, name = 'name', weight_decay = 1):
    """Fine-tune a copy of ``model`` and score the test loader with it.

    A deep copy of ``model`` is trained for ``epoch`` passes over
    ``data_train_loader`` with Adam, then switched to eval mode and run over
    ``data_test_loader`` without gradients. The collected embeddings, labels
    and logits are handed to ``map_feature_tsne``.

    The original body returned only the logits, left the ``features`` and
    ``labels`` accumulators empty and printed the mean of an always-empty
    accuracy list (``nan``). The callers in this module unpack four values,
    which is exactly what ``map_feature_tsne`` returns, so the accumulators
    are now filled and its result is returned.

    Args:
        data_train_loader: Loader yielding
            ``(input_ids, attention_mask, token_type_ids, label_id)`` batches.
        data_test_loader: Loader with the same batch layout for inference.
        model: Model to copy; the original object is not modified.
        device: Device used for training and inference.
        epoch: Number of passes over the training loader.
        lr: Adam learning rate.
        print_info: Print the mean training loss on every second epoch.
        name: Label appended to the progress-bar description.
        weight_decay: Adam weight decay.

    Returns:
        ``(logits, X_embedded, labels, features)`` as produced by
        ``map_feature_tsne``.
    """
    model_meta = deepcopy(model)
    optimizer = Adam(model_meta.parameters(), lr=lr, weight_decay=weight_decay)

    model_meta.to(device)
    model_meta.train()

    for i in range(epoch):
        all_loss = []
        for batch in data_train_loader:
            input_ids, attention_mask, q_token_type_ids, label_id = (
                t.to(device) for t in batch)

            loss, _, _ = model_meta(input_ids, attention_mask, q_token_type_ids,
                                    labels=label_id.squeeze())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            all_loss.append(loss.item())

        if i % 2 == 0 and print_info:
            print("Loss: ", np.mean(all_loss))

    model_meta.eval()
    features = []
    labels = []
    predi_logit = []

    with torch.no_grad():
        for batch in tqdm(data_test_loader,
                          desc="Test validation | " + name,
                          ncols=80):
            input_ids, attention_mask, q_token_type_ids, label_id = (
                t.to(device) for t in batch)

            _, feature, _ = model_meta(input_ids, attention_mask, q_token_type_ids,
                                       labels=label_id.squeeze())

            # feature[1] is the logit (as in the original code).
            # NOTE(review): feature[0] is assumed to be the embedding that
            # feeds the classifier head - TODO confirm against the model.
            features.append(feature[0].detach().cpu().numpy())
            predi_logit.append(feature[1].detach().cpu().numpy())
            # Flattened so batches of any size concatenate into one vector.
            labels.append(label_id.detach().cpu().numpy().reshape(-1))

    # Release the fine-tuned copy before the t-SNE projection.
    model_meta.to('cpu')
    del model_meta, optimizer
    gc.collect()
    torch.cuda.empty_cache()

    return map_feature_tsne(features, labels, predi_logit)
def map_feature_tsne(features, labels, predi_logit):
    """Merge per-batch outputs and project the features to 2-D with t-SNE.

    Args:
        features: Sequence of per-batch feature arrays.
        labels: Sequence of per-batch label arrays.
        predi_logit: Sequence of per-batch logit arrays.

    Returns:
        ``(logits, X_embedded, labels, features)`` where ``logits`` and
        ``features`` are float32 tensors, ``labels`` is an integer tensor and
        ``X_embedded`` is the array returned by ``TSNE.fit_transform``.
    """
    def merge(batches, dtype):
        # Join the per-batch arrays along the first axis, cast, and wrap the
        # result in an independent tensor.
        joined = np.concatenate(np.array(batches, dtype=object))
        return torch.tensor(joined.astype(dtype)).detach().clone()

    feature_tensor = merge(features, np.float32)
    label_tensor = merge(labels, int)
    logit_tensor = merge(predi_logit, np.float32)

    projector = TSNE(n_components=2, learning_rate='auto', init='random')
    X_embedded = projector.fit_transform(feature_tensor.detach().clone())

    return (logit_tensor.detach().clone(),
            X_embedded,
            label_tensor.detach().clone(),
            feature_tensor.detach().clone())
def wss_calc(logit, labels, trsh = 0.5):
    """Compute work-saved-over-sampling statistics at a probability threshold.

    Predictions are ``sigmoid(logit) >= trsh``. From the binary confusion
    matrix the function derives recall, ``wss = (tn + fn) / n - (1 - recall)``
    and ``awss = tn / N - fn / P``.

    Args:
        logit: Tensor of raw scores, one per sample.
        labels: Binary (0/1) ground-truth labels, one per sample.
        trsh: Probability threshold for the positive class.

    Returns:
        Dict with the keys ``"wss"``, ``"awss"``, ``"R"`` (each rounded to
        four decimals) and ``"CM"`` (the 2x2 confusion matrix). Ratios whose
        denominator is zero (a class absent from ``labels``) are ``nan``.
    """
    # reshape(-1) instead of squeeze(): a single sample must remain a 1-D
    # vector, otherwise it collapses to a 0-d tensor.
    predict_trash = torch.sigmoid(logit).reshape(-1) >= trsh

    # Pinning both classes keeps the matrix 2x2 even when only one class
    # occurs, so the four-way unpacking below cannot fail.
    CM = confusion_matrix(labels, predict_trash.to(int), labels=[0, 1])
    tn, fp, fne, tp = CM.ravel()

    P = tp + fne
    N = tn + fp
    nan = float("nan")

    recall = tp / P if P else nan
    wss = (tn + fne) / len(labels) - (1 - recall)
    awss = (tn / N - fne / P) if (P and N) else nan

    return {
        "wss": round(wss, 4),
        "awss": round(awss, 4),
        "R": round(recall, 4),
        "CM": CM
    }
def plot(logits, X_embedded, labels, threshold, show = True,
         namefig = "plot", make_plot = True, print_stats = True, save = True):
    """Report metrics at 95% recall and at a fixed threshold, optionally plotting.

    Two operating points are evaluated: the first ROC threshold whose true
    positive rate reaches 0.95, and the caller-supplied ``threshold``.

    Args:
        logits: Tensor of raw scores, one per sample.
        X_embedded: 2-D embedding of the samples (two columns are used).
        labels: Binary ground-truth labels.
        threshold: Probability threshold for the second operating point.
        show: Call ``plt.show()`` after drawing.
        namefig: Path given to ``savefig`` when ``save`` is true.
        make_plot: Draw the three scatter plots and the ROC curve.
        print_stats: Print the metrics and confusion matrices.
        save: Save the figure (only when ``make_plot`` is true).

    Returns:
        Dict of metrics keyed by ``"WSS@95"``, ``"AWSS@95"``, ``"WSS@R"``,
        ``"AWSS@R"``, ``"Recall_WSS@95"``, ``"Recall_WSS@R"``, ``"acc@95"``,
        ``"acc@R"``, ``"f1@95"``, ``"f1@R"`` and ``"threshold@95"``.
    """
    col = pd.MultiIndex.from_tuples([("Predict", "0"), ("Predict", "1")])
    index = pd.MultiIndex.from_tuples([("Real", "0"), ("Real", "1")])

    predict = torch.sigmoid(logits).detach().clone()

    fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())

    # Number of ROC points below 95% recall == index of the first point
    # that reaches it.
    idx_wss95 = int(np.sum(tpr < 0.95))
    thresholds95 = thresholds[idx_wss95]

    wss95_info = wss_calc(logits, labels, thresholds95)
    acc_wss95 = fn.accuracy(predict, labels, threshold=thresholds95)
    f1_wss95 = fn.f1_score(predict, labels, threshold=thresholds95)

    wss_info = wss_calc(logits, labels, threshold)
    acc_wssR = fn.accuracy(predict, labels, threshold=threshold)
    f1_wssR = fn.f1_score(predict, labels, threshold=threshold)

    metrics = {
        "WSS@95": wss95_info['wss'],
        "AWSS@95": wss95_info['awss'],
        "WSS@R": wss_info['wss'],
        "AWSS@R": wss_info['awss'],
        "Recall_WSS@95": wss95_info['R'],
        "Recall_WSS@R": wss_info['R'],
        "acc@95": acc_wss95.item(),
        "acc@R": acc_wssR.item(),
        "f1@95": f1_wss95.item(),
        "f1@R": f1_wssR.item(),
        "threshold@95": thresholds95
    }

    if print_stats:
        wss95 = f"WSS@95:{wss95_info['wss']}, R: {wss95_info['R']}"
        # Label fixed: it used to read "ASSWSS@95", unlike "AWSS@R" below.
        wss95_adj = f"AWSS@95:{wss95_info['awss']}"

        print(wss95)
        print(wss95_adj)
        print('Acc.:', round(acc_wss95.item(), 4))
        print('F1-score:', round(f1_wss95.item(), 4))
        print(f"threshold to wss95: {round(thresholds95, 4)}")
        cm = pd.DataFrame(wss95_info['CM'],
                          index=index,
                          columns=col)
        print("\nConfusion matrix:")
        print(cm)

        print("\n---Metrics with threshold:", threshold, "----\n")
        wss = f"WSS@R:{wss_info['wss']}, R: {wss_info['R']}"
        print(wss)
        wss_adj = f"AWSS@R:{wss_info['awss']}"
        print(wss_adj)
        print('Acc.:', round(acc_wssR.item(), 4))
        print('F1-score:', round(f1_wssR.item(), 4))
        cm = pd.DataFrame(wss_info['CM'],
                          index=index,
                          columns=col)
        print("\nConfusion matrix:")
        print(cm)

    if make_plot:
        fig, axes = plt.subplots(1, 4, figsize=(25, 10))
        # Point opacity encodes the predicted probability.
        alpha = torch.squeeze(predict).numpy()

        sns.scatterplot(x=X_embedded[:, 0],
                        y=X_embedded[:, 1],
                        hue=labels,
                        alpha=alpha, ax=axes[0]).set_title('Predictions-TSNE', size=20)

        # Samples kept at the 95%-recall threshold.
        t_wss = predict >= thresholds95
        t_wss = t_wss.squeeze().numpy()
        sns.scatterplot(x=X_embedded[t_wss, 0],
                        y=X_embedded[t_wss, 1],
                        hue=labels[t_wss],
                        alpha=alpha[t_wss], ax=axes[1]).set_title('WSS@95', size=20)

        # Samples kept at the caller-supplied threshold.
        t = predict >= threshold
        t = t.squeeze().numpy()
        sns.scatterplot(x=X_embedded[t, 0],
                        y=X_embedded[t, 1],
                        hue=labels[t],
                        alpha=alpha[t], ax=axes[2]).set_title(f'Predictions-threshold {threshold}', size=20)

        roc_auc = auc(fpr, tpr)
        lw = 2
        axes[3].plot(
            fpr,
            tpr,
            color="darkorange",
            lw=lw,
            label="ROC curve (area = %0.2f)" % roc_auc)
        axes[3].plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
        axes[3].axhline(y=0.95, color='r', linestyle='-')
        axes[3].legend(loc="lower right")
        axes[3].set_title(label="ROC", size=20)
        axes[3].set_ylabel("True Positive Rate", fontsize=15)
        axes[3].set_xlabel("False Positive Rate", fontsize=15)

        if show:
            plt.show()

        if save:
            fig.savefig(namefig, dpi=fig.dpi)

        # Release the figure: repeated calls used to leave every figure open
        # and accumulate memory.
        plt.close(fig)

    return metrics
def auc_plot(logits, labels, color = "darkorange", label = "test"):
    """Draw one ROC curve on the current matplotlib axes.

    Besides the curve, the chance diagonal and a horizontal line at a true
    positive rate of 0.95 are drawn. The curve's legend entry is ``label``
    followed by the AUC rounded to two decimals.

    Args:
        logits: Tensor of raw scores, one per sample.
        labels: Binary ground-truth labels.
        color: Line colour of the ROC curve.
        label: Prefix of the legend entry.
    """
    probabilities = torch.sigmoid(logits).detach().clone()
    fpr, tpr, _ = roc_curve(labels, probabilities.squeeze())

    legend_entry = label + str(round(auc(fpr, tpr), 2))
    plt.plot(fpr, tpr, color=color, lw=2, label=legend_entry)

    # Reference lines: chance diagonal and the 95% recall level.
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    plt.axhline(y=0.95, color='r', linestyle='-')
class diagnosis():
    """Notebook widget panel to fine-tune and evaluate the model per domain.

    Calling the instance shows a "Next" button. "Next" moves to the following
    entry of ``names`` and shows its label counts; "Train" samples a few-shot
    training set and runs ``train_loop``; "Evaluation" prints the class
    counts and calls ``plot`` with the threshold typed in the widget.
    """

    def __init__(self, names, Valid_resource, batch_size_test,
                 model, Info, device, treat_text_fun=None, start=0):
        """Store the resources and build the widgets.

        Args:
            names: Sequence of domain names to iterate over.
            Valid_resource: DataFrame with a ``domain`` and a ``label`` column.
            batch_size_test: Batch size of the test loader.
            model: Model handed to ``train_loop``.
            Info: Mapping of hyper-parameters (``inner_batch_size``,
                ``max_seq_length``, ``tokenizer``, ``inner_update_step``,
                ``inner_update_lr`` and ``threshold`` are read here).
            device: Device handed to ``train_loop``.
            treat_text_fun: Optional text pre-processing callable.
            start: Index in ``names`` of the first domain to show.
        """
        self.names = names
        self.Valid_resource = Valid_resource
        self.batch_size_test = batch_size_test
        self.model = model
        self.start = start
        self.Info = Info
        self.device = device
        self.treat_text_fun = treat_text_fun
        # Position of the current domain; "Next" advances before reading it.
        self.i = start - 1

        # Seeded with the configured threshold so the first evaluation matches
        # Info['threshold']; the typed value is what Evaluation_button uses.
        self.value_trash = widgets.FloatText(
            value=Info.get('threshold', 0.95),
            description='threshold',
            disabled=False
        )
        self.valueb = widgets.IntText(
            value=10,
            description='size',
            disabled=False
        )

        self.train_b = widgets.Button(description="Train")
        self.next_b = widgets.Button(description="Next")
        self.eval_b = widgets.Button(description="Evaluation")

        self.hbox = widgets.HBox([self.train_b, self.valueb])

        self.next_b.on_click(self.Next_button)
        self.train_b.on_click(self.Train_button)
        self.eval_b.on_click(self.Evaluation_button)

    def Next_button(self, p):
        """Advance to the next domain and show its label distribution."""
        clear_output()

        # Clicking past the last domain used to raise IndexError.
        if self.i + 1 >= len(self.names):
            print("No more domains to review.")
            return

        self.i = self.i + 1
        self.domain = self.names[self.i]
        self.data = self.Valid_resource[self.Valid_resource['domain'] == self.domain]

        print("Name:", self.domain)
        print(self.data['label'].value_counts())
        display(self.hbox)
        display(self.next_b)

    def Train_button(self, y):
        """Sample a training set of the requested size and fine-tune on it."""
        clear_output()
        print(self.domain)

        self.data_train_loader, self.data_test_loader, self.data_train, self.data_test = prepare_data(
            self.data,
            train_size_per_class=self.valueb.value,
            batch_size={'train': self.Info['inner_batch_size'],
                        'test': self.batch_size_test},
            max_seq_length=self.Info['max_seq_length'],
            tokenizer=self.Info['tokenizer'],
            input="text",
            output="label",
            treat_text_fun=self.treat_text_fun)

        self.logits, self.X_embedded, self.labels, self.features = train_loop(
            self.data_train_loader, self.data_test_loader,
            self.model, self.device,
            epoch=self.Info['inner_update_step'],
            lr=self.Info['inner_update_lr'],
            print_info=True,
            name=self.domain)

        tresh_box = widgets.HBox([self.eval_b, self.value_trash])
        display(self.hbox)
        display(tresh_box)
        display(self.next_b)

    def Evaluation_button(self, te):
        """Print the class counts and plot the metrics of the last training."""
        clear_output()
        tresh_box = widgets.HBox([self.eval_b, self.value_trash])

        print(self.domain)

        # The frames stored by Train_button are used rather than module-level
        # globals, so the panel reports the data it actually trained on.
        print("-------Train data-------")
        print(self.data_train['label'].value_counts())
        print("-------Test data-------")
        print(self.data_test['label'].value_counts())

        display(self.next_b)
        display(tresh_box)
        display(self.hbox)

        # The threshold widget was displayed but never read; it now drives
        # the evaluation.
        metrics = plot(self.logits, self.X_embedded, self.labels,
                       threshold=self.value_trash.value, show=True,
                       namefig='test',
                       make_plot=True,
                       print_stats=True,
                       save=False)

    def __call__(self):
        """Reset to the starting domain and show the "Next" button."""
        self.i = self.start - 1
        clear_output()
        display(self.next_b)
def pipeline_simulation(Valid_resource, names_to_valid, path_save,
                        model, Info, device, initializer_model,
                        treat_text_fun=None, n_attempt=5, batch_test=100):
    """Repeat the few-shot train/evaluate cycle per domain and store results.

    For every name in ``names_to_valid`` the rows of that domain are trained
    on and evaluated ``n_attempt`` times. After each attempt the accumulated
    metrics are written to ``metrics.csv`` and the ROC points to
    ``roc_stats.json`` under ``path_save``; plots go to
    ``<path_save><name>/img/``. At the end the run configuration is written
    to ``info.json``.

    Args:
        Valid_resource: DataFrame with ``domain``, ``text`` and ``label``.
        names_to_valid: Domain names (a ``.csv`` in the name is removed when
            building output paths).
        path_save: Output prefix; it is concatenated as a string, so it is
            expected to end with a path separator.
        model: Model handed to ``train_loop``.
        Info: Mapping of hyper-parameters (``k_spt``, ``inner_batch_size``,
            ``max_seq_length``, ``tokenizer``, ``inner_update_step``,
            ``inner_update_lr``, ``threshold`` and ``bert_layers`` are read).
        device: Device handed to ``train_loop``.
        initializer_model: Object whose ``tokenizer.name_or_path`` is recorded
            as the model name.
        treat_text_fun: Optional text pre-processing callable.
        n_attempt: Number of repetitions per domain.
        batch_test: Batch size of the test loader.
    """
    # Raw string: "\." in a normal string literal is an invalid escape.
    csv_suffix = r"\.csv"

    # Create the output folder of every domain up front.
    for name in names_to_valid:
        name = re.sub(csv_suffix, "", name)
        Path(path_save + name + "/img").mkdir(parents=True, exist_ok=True)

    # roc_stats[domain][layer size]["fpr" | "tpr"] -> one list per attempt.
    roc_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    all_metrics = []

    for name in names_to_valid:
        data = Valid_resource[Valid_resource['domain'] == name].reset_index().drop("index", axis=1)
        name_domain = re.sub(csv_suffix, "", name)
        layer_key = str(Info['bert_layers'])

        for attempt in range(n_attempt):
            print("---" * 4, "attempt", attempt, "---" * 4)

            data_train_loader, data_test_loader, _, _ = prepare_data(
                data,
                train_size_per_class=Info['k_spt'],
                batch_size={'train': Info['inner_batch_size'],
                            'test': batch_test},
                max_seq_length=Info['max_seq_length'],
                tokenizer=Info['tokenizer'],
                input="text",
                output="label",
                treat_text_fun=treat_text_fun)

            logits, X_embedded, labels, features = train_loop(
                data_train_loader, data_test_loader,
                model, device,
                epoch=Info['inner_update_step'],
                lr=Info['inner_update_lr'],
                print_info=False,
                name=name)

            metrics = plot(logits, X_embedded, labels,
                           threshold=Info['threshold'], show=False,
                           namefig=path_save + name_domain + "/img/" + str(attempt) + 'plots',
                           make_plot=True, print_stats=False, save=True)

            fpr, tpr, _ = roc_curve(labels, torch.sigmoid(logits).squeeze())

            metrics['name'] = name_domain
            metrics['layer_size'] = Info['bert_layers']
            metrics['attempt'] = attempt
            roc_stats[name_domain][layer_key]['fpr'].append(fpr.tolist())
            roc_stats[name_domain][layer_key]['tpr'].append(tpr.tolist())
            all_metrics.append(metrics)

            # Written after every attempt so partial results survive a crash.
            pd.DataFrame(all_metrics).to_csv(path_save + "metrics.csv")
            roc_path = path_save + "roc_stats.json"
            with open(roc_path, 'w') as fp:
                json.dump(roc_stats, fp)

            # Drop the per-attempt tensors before the next run allocates new
            # ones. Done inside the loop so it never touches unbound names.
            del fpr, tpr, logits, X_embedded, labels, features, metrics

    save_info = Info.copy()
    save_info['model'] = initializer_model.tokenizer.name_or_path
    save_info.pop("tokenizer")
    save_info.pop("bert_layers")

    info_path = path_save + "info.json"
    with open(info_path, 'w') as fp:
        json.dump(save_info, fp)
def load_data_statistics(paths, names):
    """Summarise the size and class balance of several CSV files.

    Every file is read with ``pandas.read_csv``, rows containing a missing
    value are dropped, and the remaining rows are counted in total and per
    value of the ``labels`` column (1 = positive, 0 = negative). A class that
    does not occur in a file is reported as 0 instead of raising ``KeyError``.

    Args:
        paths: Sequence of CSV file paths.
        names: Sequence of names, one per path.

    Returns:
        DataFrame with the columns ``size``, ``pos``, ``neg``, ``names`` and
        ``paths``, one row per file.
    """
    size = []
    pos = []
    neg = []

    for p in paths:
        data = pd.read_csv(p).dropna()
        # Count once; .get() tolerates a class that is absent from the file.
        counts = data['labels'].value_counts()
        size.append(len(data))
        pos.append(int(counts.get(1, 0)))
        neg.append(int(counts.get(0, 0)))

    return pd.DataFrame({
        "size": size,
        "pos": pos,
        "neg": neg,
        "names": names,
        "paths": paths})
def load_data(train_info_load):
    """Load the listed CSV files into one frame of text, domain and label.

    From every file the ``labels``, ``title`` and ``abstract`` columns are
    read; ``domain`` is set to the file's base name. ``text`` is the title
    followed directly by the abstract (a missing abstract counts as an empty
    string). Labels 0 and 1 are mapped to ``"negative"`` and ``"positive"``.

    Each file is read once (the original read it twice) and all frames are
    concatenated in a single call.

    Args:
        train_info_load: Mapping or DataFrame whose ``'paths'`` entry iterates
            over CSV file paths.

    Returns:
        DataFrame with the columns ``text``, ``domain`` and ``label``. Row
        indexes are those of the individual files, so they may repeat.
    """
    frames = []
    for p in train_info_load['paths']:
        data_temp = pd.read_csv(p).loc[:, ['labels', 'title', 'abstract']]
        data_temp['domain'] = os.path.basename(p)
        frames.append(data_temp)

    if not frames:
        # pd.concat rejects an empty list; return the empty result directly.
        return pd.DataFrame(columns=["text", "domain", "label"])

    data_train = pd.concat(frames)
    data_train['text'] = data_train['title'] + data_train['abstract'].replace(np.nan, '')

    return (
        data_train
        .replace({"labels": {0: "negative", 1: 'positive'}})
        .rename({"labels": "label"}, axis=1)
        .loc[:, ["text", "domain", "label"]]
    )