import os

import numpy as np
import skorch
import torch
from sklearn.metrics import confusion_matrix, make_scorer
from skorch.callbacks import BatchScoring
from skorch.callbacks.scoring import ScoringBase, _cache_net_forward_iter
from skorch.callbacks.training import Checkpoint

from .LRCallback import LearningRateDecayCallback

writer = None  # rebound by get_callbacks when cfg["tensorboard"] is set

def accuracy_score(y_true, y_pred: torch.Tensor, task: str = None, mirna_flag: bool = False):
    # premirna: return the raw count of correct predictions for the class
    # selected by mirna_flag; BatchScoringPremirna later divides this count
    # by the per-class sample total to obtain an accuracy.
    if task == "premirna":
        y_pred = y_pred[:, :-1]  # drop the trailing sample-weight column
        miRNA_idx = np.where(y_true.squeeze() == mirna_flag)
        correct = torch.max(y_pred, 1).indices.cpu().numpy()[miRNA_idx] == mirna_flag
        return sum(correct)

    # sncrna: plain top-1 accuracy over the batch
    if task == "sncrna":
        y_pred = y_pred[:, :-1]  # drop the trailing sample-weight column
        # correct has shape [samples]; an entry is True if the top-1
        # prediction matches the label
        correct = torch.max(y_pred, 1).indices.cpu().numpy() == y_true.squeeze()

        return sum(correct) / y_pred.shape[0]
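
# Usage sketch (hypothetical tensors): two sncRNA classes plus the trailing
# weight column; the second sample is misclassified, giving 0.5:
#   y_pred = torch.tensor([[2.0, 0.1, 1.0], [0.2, 1.5, 1.0]])
#   accuracy_score(np.array([0, 0]), y_pred, task="sncrna")  # -> 0.5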


def accuracy_score_tcga(y_true, y_pred):
    if torch.is_tensor(y_pred):
        y_pred = y_pred.clone().detach().cpu().numpy()
    if torch.is_tensor(y_true):
        y_true = y_true.clone().detach().cpu().numpy()

    # y_pred contains the class logits with the per-sample weight appended
    # as the last column
    sample_weight = y_pred[:, -1]
    y_pred = np.argmax(y_pred[:, :-1], axis=1)

    # balanced (macro-averaged) accuracy: mean of per-class recalls,
    # skipping classes that never occur in y_true
    C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
    with np.errstate(divide='ignore', invalid='ignore'):
        per_class = np.diag(C) / C.sum(axis=1)
    if np.any(np.isnan(per_class)):
        per_class = per_class[~np.isnan(per_class)]
    score = np.mean(per_class)
    return score
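
# Usage sketch (hypothetical values): logits for two classes plus the weight
# column; both samples are classified correctly, so balanced accuracy is 1.0:
#   y_pred = torch.tensor([[2.0, 0.1, 1.0], [0.2, 1.5, 1.0]])
#   accuracy_score_tcga(torch.tensor([0, 1]), y_pred)  # -> 1.0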

def score_callbacks(cfg):
    acc_scorer = make_scorer(accuracy_score, task=cfg["task"])
    if cfg["task"] == "tcga":
        acc_scorer = make_scorer(accuracy_score_tcga)

    if cfg["task"] == "premirna":
        acc_scorer_mirna = make_scorer(accuracy_score, task=cfg["task"], mirna_flag=True)

        val_score_callback_mirna = BatchScoringPremirna(
            mirna_flag=True, scoring=acc_scorer_mirna,
            lower_is_better=False, name="val_acc_mirna")

        train_score_callback_mirna = BatchScoringPremirna(
            mirna_flag=True, scoring=acc_scorer_mirna,
            on_train=True, lower_is_better=False, name="train_acc_mirna")

        val_score_callback = BatchScoringPremirna(
            mirna_flag=False, scoring=acc_scorer,
            lower_is_better=False, name="val_acc")

        train_score_callback = BatchScoringPremirna(
            mirna_flag=False, scoring=acc_scorer,
            on_train=True, lower_is_better=False, name="train_acc")

        scoring_callbacks = [
            train_score_callback,
            train_score_callback_mirna,
        ]
        if cfg["train_split"]:
            scoring_callbacks.extend([val_score_callback_mirna, val_score_callback])

    if cfg["task"] in ["sncrna", "tcga"]:
        val_score_callback = BatchScoring(acc_scorer, lower_is_better=False, name="val_acc")
        train_score_callback = BatchScoring(
            acc_scorer, on_train=True, lower_is_better=False, name="train_acc"
        )
        scoring_callbacks = [train_score_callback]

        # The tcga dataset has a predefined validation split, so train_split
        # is False, but a validation metric is still required.
        # TODO: remove the predefined validation split of tcga from prepare_data_tcga
        if cfg["train_split"] or cfg["task"] == "tcga":
            scoring_callbacks.append(val_score_callback)

    return scoring_callbacks
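
# Usage sketch (cfg keys taken from this module):
#   cfg = {"task": "tcga", "train_split": False}
#   score_callbacks(cfg)  # -> [train_acc BatchScoring, val_acc BatchScoring]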

def get_callbacks(path, cfg):
    callback_list = [("lrcallback", LearningRateDecayCallback)]
    if cfg["tensorboard"]:
        # rebind the module-level writer so MetricsVizualization, which
        # reads the global, sees the writer imported from .tbWriter
        # instead of None
        global writer
        from .tbWriter import writer
        callback_list.append(MetricsVizualization)

    if (cfg["train_split"] or cfg["task"] == "tcga") and not cfg["inference"]:
        monitor = "val_acc_best"
        if cfg["trained_on"] == "full":
            monitor = "train_acc_best"
        ckpt_path = path + "/ckpt/"
        os.makedirs(ckpt_path, exist_ok=True)
        model_name = f'model_params_{cfg["task"]}.pt'
        callback_list.append(Checkpoint(monitor=monitor, dirname=ckpt_path, f_params=model_name))

    scoring_callbacks = score_callbacks(cfg)
    # TODO: scoring callbacks have to be inserted before the checkpoint and
    # metrics-visualization callbacks, otherwise NeuralNet's notify function
    # throws an exception
    callback_list[1:1] = scoring_callbacks

    return callback_list
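
# Usage sketch, assuming a skorch net is constructed elsewhere in the
# project (MyModule is a hypothetical torch.nn.Module):
#   net = skorch.NeuralNet(MyModule, criterion=torch.nn.CrossEntropyLoss,
#                          callbacks=get_callbacks("results", cfg))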


class MetricsVizualization(skorch.callbacks.Callback):
    def __init__(self, batch_idx=0) -> None:
        super().__init__()
        self.batch_idx = batch_idx

    # TODO: change to display metrics at epoch end
    def on_batch_end(self, net, training, **kwargs):
        # validation batch
        if not training:
            # log val accuracy; history indexing is
            # [epoch, batches, last batch, column in batch]
            writer.add_scalar(
                "Accuracy/val_acc",
                net.history[-1, "batches", -1, "val_acc"],
                self.batch_idx,
            )
            # log val loss
            writer.add_scalar(
                "Loss/val_loss",
                net.history[-1, "batches", -1, "valid_loss"],
                self.batch_idx,
            )
        # train batch
        else:
            # log learning rate
            writer.add_scalar("Metrics/lr", net.lr, self.batch_idx)
            # log train accuracy
            writer.add_scalar(
                "Accuracy/train_acc",
                net.history[-1, "batches", -1, "train_acc"],
                self.batch_idx,
            )
            # log train loss
            writer.add_scalar(
                "Loss/train_loss",
                net.history[-1, "batches", -1, "train_loss"],
                self.batch_idx,
            )
            # the step counter advances only after a train batch, so the
            # matching validation point shares the same x-coordinate
            self.batch_idx += 1

class BatchScoringPremirna(ScoringBase):
    def __init__(self, mirna_flag: bool = False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.total_num_samples = 0
        self.mirna_flag = mirna_flag
        self.first_batch_flag = True

    def on_batch_end(self, net, X, y, training, **kwargs):
        if training != self.on_train:
            return

        y_preds = [kwargs["y_pred"]]
        # only during the first epoch: accumulate the number of samples
        # belonging to the class selected by mirna_flag
        if self.first_batch_flag:
            self.total_num_samples += sum(kwargs["batch"][1] == self.mirna_flag).detach().cpu().numpy()[0]

        with _cache_net_forward_iter(net, self.use_caching, y_preds) as cached_net:
            # In case of y=None we will not have gathered any samples.
            # We expect the scoring function to deal with y=None.
            y = None if y is None else self.target_extractor(y)
            try:
                score = self._scoring(cached_net, X, y)
                cached_net.history.record_batch(self.name_, score)
            except KeyError:
                pass

    def get_avg_score(self, history):
        if self.on_train:
            bs_key = 'train_batch_size'
        else:
            bs_key = 'valid_batch_size'

        weights, scores = list(zip(
            *history[-1, 'batches', :, [bs_key, self.name_]]))
        # the per-batch scores are raw counts of correct predictions, so
        # divide by the per-class sample total rather than taking a
        # batch-size-weighted average
        score_avg = sum(scores) / self.total_num_samples
        return score_avg

    # pylint: disable=unused-argument
    def on_epoch_end(self, net, **kwargs):
        self.first_batch_flag = False
        history = net.history
        try:  # don't raise if there is no valid data
            history[-1, 'batches', :, self.name_]
        except KeyError:
            return

        score_avg = self.get_avg_score(history)
        is_best = self._is_best_score(score_avg)
        if is_best:
            self.best_score_ = score_avg

        history.record(self.name_, score_avg)
        if is_best is not None:
            history.record(self.name_ + '_best', bool(is_best))
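
# Averaging sketch for BatchScoringPremirna (hypothetical numbers): if the
# tracked class has 100 samples in total and the per-batch correct counts
# recorded under name_ are [30, 25, 35], get_avg_score reports
# (30 + 25 + 35) / 100 = 0.9 for the epoch.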