import os
import numpy as np
import skorch
import torch
from sklearn.metrics import confusion_matrix, make_scorer
from skorch.callbacks import BatchScoring
from skorch.callbacks.scoring import ScoringBase, _cache_net_forward_iter
from skorch.callbacks.training import Checkpoint
from .LRCallback import LearningRateDecayCallback
# TensorBoard SummaryWriter; rebound by get_callbacks when cfg['tensorboard'] is enabled
writer = None
def accuracy_score(y_true, y_pred: torch.Tensor, task: str = None, mirna_flag: bool = False):
    # the last column appended to the class logits in y_pred is dropped before
    # the argmax is taken
    if task == "premirna":
        y_pred = y_pred[:, :-1]
        # restrict to the samples whose label equals mirna_flag and return the raw
        # number of correct predictions; BatchScoringPremirna normalizes this count
        # by the per-class sample total at epoch end
        miRNA_idx = np.where(y_true.squeeze() == mirna_flag)
        correct = torch.max(y_pred, 1).indices.cpu().numpy()[miRNA_idx] == mirna_flag
        return sum(correct)
    if task == "sncrna":
        y_pred = y_pred[:, :-1]
        # correct has shape [samples]; each entry is True if the top-1 prediction
        # matches the label
        correct = torch.max(y_pred, 1).indices.cpu().numpy() == y_true.squeeze()
        return sum(correct) / y_pred.shape[0]
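
# Illustrative call (toy values, not from the project's data): for task="sncrna",
# y_pred holds two class logits plus the trailing column that is stripped off
# before the argmax is taken.
#   y_true = np.array([[0], [1]])
#   y_pred = torch.tensor([[2.0, 0.1, 1.0],
#                          [0.3, 1.2, 1.0]])
#   accuracy_score(y_true, y_pred, task="sncrna")  # -> 1.0 (both argmaxes match)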
def accuracy_score_tcga(y_true, y_pred):
    # weighted balanced accuracy: mean per-class recall, with the sample weights
    # taken from the last column of y_pred (the remaining columns are class logits)
    if torch.is_tensor(y_pred):
        y_pred = y_pred.clone().detach().cpu().numpy()
    if torch.is_tensor(y_true):
        y_true = y_true.clone().detach().cpu().numpy()
    sample_weight = y_pred[:, -1]
    y_pred = np.argmax(y_pred[:, :-1], axis=1)
    C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
    with np.errstate(divide='ignore', invalid='ignore'):
        per_class = np.diag(C) / C.sum(axis=1)
    if np.any(np.isnan(per_class)):
        # classes absent from y_true produce NaN recalls; drop them from the mean
        per_class = per_class[~np.isnan(per_class)]
    score = np.mean(per_class)
    return score
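
# Illustrative check of the weighted balanced accuracy above (toy values, unit
# sample weights in the last column):
#   y_true = np.array([0, 1, 1])
#   y_pred = np.array([[2.0, 0.1, 1.0],
#                      [0.2, 1.5, 1.0],
#                      [1.7, 0.3, 1.0]])
#   accuracy_score_tcga(y_true, y_pred)  # per-class recalls (1.0, 0.5) -> 0.75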
def score_callbacks(cfg):
acc_scorer = make_scorer(accuracy_score,task=cfg["task"])
if cfg['task'] == 'tcga':
acc_scorer = make_scorer(accuracy_score_tcga)
if cfg["task"] == "premirna":
        acc_scorer_mirna = make_scorer(accuracy_score, task=cfg["task"], mirna_flag=True)

        val_score_callback_mirna = BatchScoringPremirna(
            mirna_flag=True, scoring=acc_scorer_mirna, lower_is_better=False, name="val_acc_mirna")
        train_score_callback_mirna = BatchScoringPremirna(
            mirna_flag=True, scoring=acc_scorer_mirna, on_train=True, lower_is_better=False, name="train_acc_mirna")
        val_score_callback = BatchScoringPremirna(
            mirna_flag=False, scoring=acc_scorer, lower_is_better=False, name="val_acc")
        train_score_callback = BatchScoringPremirna(
            mirna_flag=False, scoring=acc_scorer, on_train=True, lower_is_better=False, name="train_acc")
scoring_callbacks = [
train_score_callback,
train_score_callback_mirna
]
if cfg["train_split"]:
scoring_callbacks.extend([val_score_callback_mirna,val_score_callback])
if cfg["task"] in ["sncrna", "tcga"]:
val_score_callback = BatchScoring(acc_scorer, lower_is_better=False, name="val_acc")
train_score_callback = BatchScoring(
acc_scorer, on_train=True, lower_is_better=False, name="train_acc"
)
scoring_callbacks = [train_score_callback]
        # the tcga dataset has a predefined validation split, so train_split is False,
        # but a validation metric is still required
        # TODO: remove the predefined valid split for tcga from prepare_data_tcga
if cfg["train_split"] or cfg['task'] == 'tcga':
scoring_callbacks.append(val_score_callback)
return scoring_callbacks
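
# score_callbacks and get_callbacks read cfg by key; a minimal example of the keys
# they consume (the values shown here are hypothetical):
#   cfg = {"task": "tcga", "train_split": False, "tensorboard": False,
#          "inference": False, "trained_on": "full"}
#   score_callbacks(cfg)  # -> [train_acc BatchScoring, val_acc BatchScoring]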
def get_callbacks(path,cfg):
callback_list = [("lrcallback", LearningRateDecayCallback)]
    if cfg['tensorboard']:
        # rebind the module-level writer so that MetricsVizualization logs to the
        # SummaryWriter configured in .tbWriter
        global writer
        from .tbWriter import writer
        callback_list.append(MetricsVizualization)
    if (cfg["train_split"] or cfg['task'] == 'tcga') and not cfg['inference']:
monitor = "val_acc_best"
if cfg['trained_on'] == 'full':
monitor = 'train_acc_best'
        ckpt_path = os.path.join(path, "ckpt")
        os.makedirs(ckpt_path, exist_ok=True)
model_name = f'model_params_{cfg["task"]}.pt'
callback_list.append(Checkpoint(monitor=monitor, dirname=ckpt_path,f_params=model_name))
scoring_callbacks = score_callbacks(cfg)
    # TODO: for some reason the scoring callbacks have to be inserted before the
    # checkpoint and metrics visualization callbacks, otherwise NeuralNet's notify
    # function throws an exception
callback_list[1:1] = scoring_callbacks
return callback_list
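
# Sketch of how the callback list is meant to be consumed; MyModule, X and y are
# placeholders, not part of this module:
#   from skorch import NeuralNetClassifier
#   net = NeuralNetClassifier(MyModule, callbacks=get_callbacks(path, cfg))
#   net.fit(X, y)
# Callbacks appended as bare classes (e.g. MetricsVizualization) are instantiated
# by skorch when the net is initialized.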
class MetricsVizualization(skorch.callbacks.Callback):
def __init__(self, batch_idx=0) -> None:
super().__init__()
self.batch_idx = batch_idx
# TODO: Change to display metrics at epoch ends
def on_batch_end(self, net, training, **kwargs):
# validation batch
if not training:
            # log val accuracy; net.history is indexed as [epoch, "batches", last batch, column]
writer.add_scalar(
"Accuracy/val_acc",
net.history[-1, "batches", -1, "val_acc"],
self.batch_idx,
)
# log val loss
writer.add_scalar(
"Loss/val_loss",
net.history[-1, "batches", -1, "valid_loss"],
self.batch_idx,
)
# update batch idx after validation on batch is computed
# train batch
else:
# log lr
writer.add_scalar("Metrics/lr", net.lr, self.batch_idx)
# log train accuracy
writer.add_scalar(
"Accuracy/train_acc",
net.history[-1, "batches", -1, "train_acc"],
self.batch_idx,
)
# log train loss
writer.add_scalar(
"Loss/train_loss",
net.history[-1, "batches", -1, "train_loss"],
self.batch_idx,
)
self.batch_idx += 1
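
# The scalars above are logged once per batch with self.batch_idx as the global
# step, into whatever log directory the SummaryWriter in .tbWriter is configured
# with; they can then be inspected with the usual `tensorboard --logdir <that dir>`.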
class BatchScoringPremirna(ScoringBase):
    def __init__(self, mirna_flag: bool = False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        #self.total_num_samples = total_num_samples
        self.total_num_samples = 0
        self.mirna_flag = mirna_flag
        # True only during the first epoch, while total_num_samples is accumulated
        self.first_batch_flag = True
def on_batch_end(self, net, X, y, training, **kwargs):
if training != self.on_train:
return
y_preds = [kwargs['y_pred']]
        # during the first epoch only: count the samples belonging to the class
        # selected by mirna_flag; get_avg_score divides the summed correct counts
        # by this total
        if self.first_batch_flag:
            self.total_num_samples += sum(kwargs["batch"][1] == self.mirna_flag).detach().cpu().numpy()[0]
with _cache_net_forward_iter(net, self.use_caching, y_preds) as cached_net:
# In case of y=None we will not have gathered any samples.
# We expect the scoring function to deal with y=None.
y = None if y is None else self.target_extractor(y)
try:
score = self._scoring(cached_net, X, y)
cached_net.history.record_batch(self.name_, score)
except KeyError:
pass
    def get_avg_score(self, history):
        if self.on_train:
            bs_key = 'train_batch_size'
        else:
            bs_key = 'valid_batch_size'
        weights, scores = list(zip(
            *history[-1, 'batches', :, [bs_key, self.name_]]))
        #score_avg = np.average(scores, weights=weights)
        # each batch score is a count of correct predictions (accuracy_score with
        # task="premirna"), so the epoch score is their sum divided by the number
        # of samples of the monitored class
        score_avg = sum(scores) / self.total_num_samples
        return score_avg
# pylint: disable=unused-argument
def on_epoch_end(self, net, **kwargs):
self.first_batch_flag = False
history = net.history
try: # don't raise if there is no valid data
history[-1, 'batches', :, self.name_]
except KeyError:
return
score_avg = self.get_avg_score(history)
is_best = self._is_best_score(score_avg)
if is_best:
self.best_score_ = score_avg
history.record(self.name_, score_avg)
if is_best is not None:
history.record(self.name_ + '_best', bool(is_best))