import os

import numpy as np
import skorch
import torch
from sklearn.metrics import confusion_matrix, make_scorer
from skorch.callbacks import BatchScoring
from skorch.callbacks.scoring import ScoringBase, _cache_net_forward_iter
from skorch.callbacks.training import Checkpoint

from .LRCallback import LearningRateDecayCallback

# Bound to the TensorBoard writer by get_callbacks when cfg["tensorboard"] is set.
writer = None


def accuracy_score(y_true, y_pred: torch.Tensor, task: str = None, mirna_flag: bool = False):
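    """Batch accuracy for the ``premirna`` and ``sncrna`` tasks.

    The last column of ``y_pred`` is dropped before taking the argmax over
    class scores. For ``premirna`` the number of correctly classified samples
    whose label equals ``mirna_flag`` is returned (averaged per epoch by
    ``BatchScoringPremirna``); for ``sncrna`` the per-batch accuracy is
    returned.
    """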
    if task == "premirna":
        # The last column of y_pred is dropped before taking the class argmax.
        y_pred = y_pred[:, :-1]
        miRNA_idx = np.where(y_true.squeeze() == mirna_flag)
        correct = torch.max(y_pred, 1).indices.cpu().numpy()[miRNA_idx] == mirna_flag
        # Number of correct predictions for this class; the epoch-level average
        # is computed later by BatchScoringPremirna.
        return sum(correct)

    if task == "sncrna":
        y_pred = y_pred[:, :-1]
        correct = torch.max(y_pred, 1).indices.cpu().numpy() == y_true.squeeze()
        # Plain per-batch accuracy.
        return sum(correct) / y_pred.shape[0]


def accuracy_score_tcga(y_true, y_pred):
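    """Weighted balanced accuracy for the ``tcga`` task.

    The last column of ``y_pred`` is used as a per-sample weight for the
    confusion matrix, the remaining columns are reduced to class predictions
    via argmax, and the mean per-class recall (ignoring classes absent from
    ``y_true``) is returned.
    """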
    if torch.is_tensor(y_pred):
        y_pred = y_pred.clone().detach().cpu().numpy()
    if torch.is_tensor(y_true):
        y_true = y_true.clone().detach().cpu().numpy()

    # The last column of y_pred holds the per-sample weights; the remaining
    # columns are the class scores.
    sample_weight = y_pred[:, -1]
    y_pred = np.argmax(y_pred[:, :-1], axis=1)

    C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
    with np.errstate(divide="ignore", invalid="ignore"):
        per_class = np.diag(C) / C.sum(axis=1)
    if np.any(np.isnan(per_class)):
        # Ignore classes that never occur in y_true.
        per_class = per_class[~np.isnan(per_class)]
    score = np.mean(per_class)
    return score


def score_callbacks(cfg):
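    """Build the task-specific scoring callbacks.

    For ``premirna``, ``BatchScoringPremirna`` callbacks track accuracy on the
    miRNA and non-miRNA samples separately for training and, if a train split
    is configured, for validation. For ``sncrna`` and ``tcga``, plain
    ``BatchScoring`` callbacks record ``train_acc`` and ``val_acc``.
    """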
    acc_scorer = make_scorer(accuracy_score, task=cfg["task"])
    if cfg["task"] == "tcga":
        acc_scorer = make_scorer(accuracy_score_tcga)

    if cfg["task"] == "premirna":
        acc_scorer_mirna = make_scorer(accuracy_score, task=cfg["task"], mirna_flag=True)

        val_score_callback_mirna = BatchScoringPremirna(
            mirna_flag=True,
            scoring=acc_scorer_mirna,
            lower_is_better=False,
            name="val_acc_mirna",
        )
        train_score_callback_mirna = BatchScoringPremirna(
            mirna_flag=True,
            scoring=acc_scorer_mirna,
            on_train=True,
            lower_is_better=False,
            name="train_acc_mirna",
        )
        val_score_callback = BatchScoringPremirna(
            mirna_flag=False,
            scoring=acc_scorer,
            lower_is_better=False,
            name="val_acc",
        )
        train_score_callback = BatchScoringPremirna(
            mirna_flag=False,
            scoring=acc_scorer,
            on_train=True,
            lower_is_better=False,
            name="train_acc",
        )

        scoring_callbacks = [train_score_callback, train_score_callback_mirna]
        if cfg["train_split"]:
            scoring_callbacks.extend([val_score_callback_mirna, val_score_callback])

    if cfg["task"] in ["sncrna", "tcga"]:
        val_score_callback = BatchScoring(acc_scorer, lower_is_better=False, name="val_acc")
        train_score_callback = BatchScoring(
            acc_scorer, on_train=True, lower_is_better=False, name="train_acc"
        )
        scoring_callbacks = [train_score_callback]

        # The tcga task is always evaluated on a validation split.
        if cfg["train_split"] or cfg["task"] == "tcga":
            scoring_callbacks.append(val_score_callback)

    return scoring_callbacks


def get_callbacks(path, cfg):
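    """Assemble the callback list for the skorch net.

    The list combines the learning-rate decay callback, the scoring callbacks
    from ``score_callbacks``, optional TensorBoard logging, and, when not
    running inference, a ``Checkpoint`` that saves the best parameters under
    ``<path>/ckpt/``.
    """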
    callback_list = [("lrcallback", LearningRateDecayCallback)]
    if cfg["tensorboard"]:
        # Rebind the module-level writer so that MetricsVizualization, which
        # reads the global, sees the TensorBoard writer instead of None.
        global writer
        from .tbWriter import writer

        # skorch instantiates callback classes that are passed uninitialized.
        callback_list.append(MetricsVizualization)

    if (cfg["train_split"] or cfg["task"] == "tcga") and not cfg["inference"]:
        monitor = "val_acc_best"
        if cfg["trained_on"] == "full":
            monitor = "train_acc_best"
        ckpt_path = path + "/ckpt/"
        os.makedirs(ckpt_path, exist_ok=True)
        model_name = f'model_params_{cfg["task"]}.pt'
        callback_list.append(Checkpoint(monitor=monitor, dirname=ckpt_path, f_params=model_name))

    scoring_callbacks = score_callbacks(cfg)

    # Splice the scoring callbacks in right after the LR callback so that their
    # accuracies are already recorded in the history when MetricsVizualization
    # and Checkpoint are notified.
    callback_list[1:1] = scoring_callbacks

    return callback_list


class MetricsVizualization(skorch.callbacks.Callback):
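    """Log per-batch accuracy, loss and learning rate to TensorBoard.

    Uses the module-level ``writer``, which ``get_callbacks`` binds when
    ``cfg["tensorboard"]`` is enabled.
    """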
    def __init__(self, batch_idx=0) -> None:
        super().__init__()
        self.batch_idx = batch_idx

    def on_batch_end(self, net, training, **kwargs):
        if not training:
            writer.add_scalar(
                "Accuracy/val_acc",
                net.history[-1, "batches", -1, "val_acc"],
                self.batch_idx,
            )
            writer.add_scalar(
                "Loss/val_loss",
                net.history[-1, "batches", -1, "valid_loss"],
                self.batch_idx,
            )
        else:
            writer.add_scalar("Metrics/lr", net.lr, self.batch_idx)
            writer.add_scalar(
                "Accuracy/train_acc",
                net.history[-1, "batches", -1, "train_acc"],
                self.batch_idx,
            )
            writer.add_scalar(
                "Loss/train_loss",
                net.history[-1, "batches", -1, "train_loss"],
                self.batch_idx,
            )
            # The step counter only advances on training batches; validation
            # batches are logged at the current training step.
            self.batch_idx += 1


class BatchScoringPremirna(ScoringBase):
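    """``BatchScoring`` variant for the ``premirna`` task.

    The per-batch scores recorded by the scorer are counts of correct
    predictions for the class selected by ``mirna_flag``; at the end of each
    epoch they are averaged over the total number of samples of that class,
    counted during the first epoch.
    """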
    def __init__(self, mirna_flag: bool = False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.total_num_samples = 0
        self.mirna_flag = mirna_flag
        self.first_batch_flag = True

    def on_batch_end(self, net, batch, training, **kwargs):
        if training != self.on_train:
            return

        X, y = batch
        y_preds = [kwargs["y_pred"]]

        # Count the samples of the monitored class once, during the first
        # epoch; get_avg_score divides by this total.
        if self.first_batch_flag:
            self.total_num_samples += (y == self.mirna_flag).sum().item()

        with _cache_net_forward_iter(net, self.use_caching, y_preds) as cached_net:
            y = None if y is None else self.target_extractor(y)
            try:
                score = self._scoring(cached_net, X, y)
                cached_net.history.record_batch(self.name_, score)
            except KeyError:
                pass

    def get_avg_score(self, history):
        if self.on_train:
            bs_key = "train_batch_size"
        else:
            bs_key = "valid_batch_size"

        weights, scores = list(zip(*history[-1, "batches", :, [bs_key, self.name_]]))

        # The recorded batch scores are counts of correct predictions, so the
        # epoch score is their sum divided by the number of samples of the
        # monitored class.
        score_avg = sum(scores) / self.total_num_samples
        return score_avg

    def on_epoch_end(self, net, **kwargs):
        self.first_batch_flag = False
        history = net.history
        try:
            history[-1, "batches", :, self.name_]
        except KeyError:
            return

        score_avg = self.get_avg_score(history)
        is_best = self._is_best_score(score_avg)
        if is_best:
            self.best_score_ = score_avg

        history.record(self.name_, score_avg)
        if is_best is not None:
            history.record(self.name_ + "_best", bool(is_best))
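

# Usage sketch (illustration only, not part of this module): the callback list
# returned by get_callbacks is meant to be handed to a skorch net, e.g.
#
#     net = skorch.NeuralNetClassifier(
#         SomeModule,                        # placeholder model class
#         callbacks=get_callbacks(out_path, cfg),
#     )
#
# SomeModule and out_path are placeholders; cfg must provide the keys read
# above ("task", "train_split", "tensorboard", "inference", "trained_on").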