import logging
import os
import pickle

import skorch
import torch
from skorch.setter import optimizer_setter
from skorch.utils import to_device

logger = logging.getLogger(__name__)


class Net(skorch.NeuralNet):
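    '''
    skorch ``NeuralNet`` subclass with gradient clipping, per-sample loss
    weighting, optional saving of gene and second-input embeddings,
    benchmark checkpointing, and manual early stopping for the tcga task.
    '''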
    def __init__(
        self,
        clip=0.25,
        top_k=1,
        correct=0,
        save_embedding=False,
        gene_embedds=[],
        second_input_embedd=[],
        confidence_threshold=0.95,
        *args,
        **kwargs
    ):
        self.clip = clip
        self.curr_epoch = 0
        super(Net, self).__init__(*args, **kwargs)
        self.correct = correct
        self.save_embedding = save_embedding
        self.gene_embedds = gene_embedds
        self.second_input_embedds = second_input_embedd
        self.main_config = kwargs["module__main_config"]
        self.train_config = self.main_config["train_config"]
        self.top_k = self.train_config.top_k
        self.num_classes = self.main_config["model_config"].num_classes
        self.labels_mapping_path = self.train_config.labels_mapping_path
        if self.labels_mapping_path:
            with open(self.labels_mapping_path, 'rb') as handle:
                self.labels_mapping_dict = pickle.load(handle)
        self.confidence_threshold = confidence_threshold
        self.max_epochs = kwargs["max_epochs"]
        self.task = ''  # set in utils.instantiate_predictor
        self.log_tb = False

    def set_save_epoch(self):
        '''
        Set the epoch at which the benchmark model is saved. When no
        validation split is used, the best train epoch is scaled by the
        validation size.
        '''
        if self.task != 'tcga':
            if self.train_split:
                self.save_epoch = self.main_config["train_config"].train_epoch
            else:
                self.save_epoch = round(
                    self.main_config["train_config"].train_epoch
                    * (1 + self.main_config["valid_size"])
                )

    def save_benchmark_model(self):
        '''
        Save a benchmark checkpoint; used when train_split is None.
        '''
        ckpt_dir = os.path.join(os.getcwd(), "ckpt")
        os.makedirs(ckpt_dir, exist_ok=True)
        self.save_params(
            f_params=os.path.join(ckpt_dir, f'model_params_{self.main_config["task"]}.pt')
        )


    def fit(self, X, y=None, valid_ds=None, **fit_params):
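        '''
        Initialize the net (unless warm-starting) and run the fit loop.
        An external validation dataset can be passed via ``valid_ds``.
        '''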
        # all sequence lengths are stored so that per-class median lengths can be computed
        self.all_lengths = [[] for i in range(self.num_classes)]
        self.median_lengths = []

        if not self.warm_start or not self.initialized_:
            self.initialize()

        if valid_ds:
            self.validation_dataset = valid_ds
        else:
            self.validation_dataset = None

        self.partial_fit(X, y, **fit_params)
        return self

    def fit_loop(self, X, y=None, epochs=None, **fit_params):
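        '''
        Custom fit loop: uses the external validation dataset when provided,
        saves the benchmark model at the scaled save epoch, and early-stops
        the tcga task once train accuracy saturates.
        '''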
        # if trained on 'id', require stricter rounding (train longer);
        # otherwise stop once train accuracy rounds to 1.0 at two decimals
        rounding_digits = 3
        if self.main_config['trained_on'] == 'full':
            rounding_digits = 2
        self.check_data(X, y)
        epochs = epochs if epochs is not None else self.max_epochs

        dataset_train, dataset_valid = self.get_split_datasets(X, y, **fit_params)

        if self.validation_dataset is not None:
            dataset_valid = self.validation_dataset.keywords["valid_ds"]

        on_epoch_kwargs = {
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
        }

        iterator_train = self.get_iterator(dataset_train, training=True)
        iterator_valid = None
        if dataset_valid is not None:
            iterator_valid = self.get_iterator(dataset_valid, training=False)

        self.set_save_epoch()

        for epoch_no in range(epochs):
            self.curr_epoch = epoch_no
            # save the benchmark model at the scaled save epoch, but only when
            # training on both the train and validation sets (train_split is None)
            if self.task != 'tcga' and epoch_no == self.save_epoch and self.train_split is None:
                self.save_benchmark_model()
                
            self.notify("on_epoch_begin", **on_epoch_kwargs)

            self.run_single_epoch(
                iterator_train,
                training=True,
                prefix="train",
                step_fn=self.train_step,
                **fit_params
            )

            if dataset_valid is not None:
                self.run_single_epoch(
                    iterator_valid,
                    training=False,
                    prefix="valid",
                    step_fn=self.validation_step,
                    **fit_params
                )

            self.notify("on_epoch_end", **on_epoch_kwargs)

            # manual early stopping for tcga
            if self.task == 'tcga':
                train_acc = round(self.history[:, 'train_acc'][-1], rounding_digits)
                if train_acc == 1:
                    break

        return self

    def train_step(self, X, y=None):
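        '''
        Single training step. ``X`` is an (inputs, labels) pair from the
        iterator; the last input column carries per-sample weights used to
        weight the loss before aggregation.
        '''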
        X, y = X[0], X[1]
        sample_weights = X[:, -1]
        if self.device == 'cuda':
            sample_weights = sample_weights.to(self.train_config.device)
        self.module_.train()
        self.module_.zero_grad()
        gene_embedd, second_input_embedd, activations, _, _ = self.module_(X[:, :-1], train=True)
        # curr_epoch is passed to the loss so it can switch the criterion
        # from unsupervised to supervised at the right epoch
        loss = self.get_loss([gene_embedd, second_input_embedd, activations, self.curr_epoch], y)

        # the supervised loss is weighted per sample and then aggregated
        loss = loss * sample_weights
        loss = loss.mean()

        loss.backward()

        # TODO: clip only some parameters
        torch.nn.utils.clip_grad_norm_(self.module_.parameters(), self.clip)
        self.optimizer_.step()

        return {
            "X": X,
            "y": y,
            "loss": loss,
            "y_pred": [gene_embedd, second_input_embedd, activations],
        }

    def validation_step(self, X, y=None):
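        '''
        Single validation step; mirrors ``train_step`` without the backward
        pass and optimizer update.
        '''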
        X, y = X[0], X[1]
        sample_weights = X[:, -1]
        if self.device == 'cuda':
            sample_weights = sample_weights.to(self.train_config.device)
        self.module_.eval()
        with torch.no_grad():
            gene_embedd, second_input_embedd, activations, _, _ = self.module_(X[:, :-1])
            loss = self.get_loss([gene_embedd, second_input_embedd, activations, self.curr_epoch], y)

        # the supervised loss is weighted per sample and then aggregated
        loss = loss * sample_weights
        loss = loss.mean()

        return {
            "X": X,
            "y": y,
            "loss": loss,
            "y_pred": [gene_embedd, second_input_embedd, activations],
        }

    def get_attention_scores(self, X, y=None):
        '''
        Return the attention scores of the model for a given input.
        '''
        self.module_.eval()
        with torch.no_grad():
            _, _, _, attn_scores_first, attn_scores_second = self.module_(X[:, :-1])

        attn_scores_first = attn_scores_first.detach().cpu().numpy()
        if attn_scores_second is not None:
            attn_scores_second = attn_scores_second.detach().cpu().numpy()
        return attn_scores_first, attn_scores_second

    def predict(self, X):
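        '''
        Return the activations concatenated with the per-sample weights.
        Embeddings are optionally accumulated when ``save_embedding`` is set.
        '''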
        self.module_.train(False)
        embedds = self.module_(X[:, :-1])
        sample_weights = X[:, -1]
        if self.device == 'cuda':
            sample_weights = sample_weights.to(self.train_config.device)

        gene_embedd, second_input_embedd, activations, _, _ = embedds
        if self.save_embedding:
            self.gene_embedds.append(gene_embedd.detach().cpu())
            # when only a single transformer is deployed, second_input_embedd is
            # None and therefore cannot be detached
            if second_input_embedd is not None:
                self.second_input_embedds.append(second_input_embedd.detach().cpu())

        predictions = torch.cat([activations, sample_weights[:, None]], dim=1)
        return predictions


    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
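        '''
        Log weight and gradient histograms to TensorBoard when ``log_tb``
        is enabled.
        '''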
        # log weight and gradient histograms
        for _, m in self.module_.named_modules():
            for pn, p in m.named_parameters():
                if pn.endswith("weight") and pn.find("norm") < 0:
                    if p.grad is not None and self.log_tb:
                        from ..callbacks.tbWriter import writer
                        writer.add_histogram("weights/" + pn, p, len(net.history))
                        writer.add_histogram(
                            "gradients/" + pn, p.grad.data, len(net.history)
                        )

        return

    def configure_opt(self, l2_weight_decay):
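        '''
        Build optimizer parameter groups: biases and LayerNorm weights are
        excluded from L2 weight decay.
        '''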
        no_decay = ["bias", "LayerNorm.weight"]
        params_decay = [
            p
            for n, p in self.module_.named_parameters()
            if not any(nd in n for nd in no_decay)
        ]
        params_nodecay = [
            p
            for n, p in self.module_.named_parameters()
            if any(nd in n for nd in no_decay)
        ]
        optim_groups = [
            {"params": params_decay, "weight_decay": l2_weight_decay},
            {"params": params_nodecay, "weight_decay": 0.0},
        ]
        return optim_groups

    def initialize_optimizer(self, triggered_directly=True):
        """Initialize the model optimizer. If ``self.optimizer__lr``
        is not set, use ``self.lr`` instead.

        Parameters
        ----------
        triggered_directly : bool (default=True)
          Only relevant when optimizer is re-initialized.
          Initialization of the optimizer can be triggered directly
          (e.g. when lr was changed) or indirectly (e.g. when the
          module was re-initialized). If and only if the former
          happens, the user should receive a message informing them
          about the parameters that caused the re-initialization.

        """
        # get learning rate from train config
        optimizer_params = self.main_config["train_config"]
        kwargs = {}
        kwargs["lr"] = optimizer_params.learning_rate
        # get l2 weight decay to init opt params
        args = self.configure_opt(optimizer_params.l2_weight_decay)

        if self.initialized_ and self.verbose:
            msg = self._format_reinit_msg(
                "optimizer", kwargs, triggered_directly=triggered_directly
            )
            logger.info(msg)

        self.optimizer_ = self.optimizer(args, lr=kwargs["lr"])

        self._register_virtual_param(
            ["optimizer__param_groups__*__*", "optimizer__*", "lr"],
            optimizer_setter,
        )

    def initialize_criterion(self):
        """Initializes the criterion."""
        # the criterion takes the main config (which contains train_config and
        # model_config) as input; we get it from the module parameters
        self.criterion_ = self.criterion(self.main_config)
        if isinstance(self.criterion_, torch.nn.Module):
            self.criterion_ = to_device(self.criterion_, self.device)
        return self

    def initialize_callbacks(self):
        """Initializes all callbacks and save the result in the
        ``callbacks_`` attribute.

        Both ``default_callbacks`` and ``callbacks`` are used (in that
        order). Callbacks may either be initialized or not, and if
        they don't have a name, the name is inferred from the class
        name. The ``initialize`` method is called on all callbacks.

        The final result will be a list of tuples, where each tuple
        consists of a name and an initialized callback. If names are
        not unique, a ValueError is raised.

        """
        if self.callbacks == "disable":
            self.callbacks_ = []
            return self

        callbacks_ = []

        class Dummy:
            # We cannot use None as dummy value since None is a
            # legitimate value to be set.
            pass

        for name, cb in self._uniquely_named_callbacks():
            # check if callback itself is changed
            param_callback = getattr(self, "callbacks__" + name, Dummy)
            if param_callback is not Dummy:  # callback itself was set
                cb = param_callback

            # below: check for callback params
            # don't set a parameter for a non-existing callback
            params = self.get_params_for("callbacks__{}".format(name))

            # if the callback is lrcallback, initialize it with the train config,
            # which is an input to the module
            if name == "lrcallback":
                params["config"] = self.main_config["train_config"]
            if (cb is None) and params:
                raise ValueError(
                    "Trying to set a parameter for callback {} "
                    "which does not exist.".format(name)
                )
            if cb is None:
                continue

            if isinstance(cb, type):  # uninitialized:
                cb = cb(**params)
            else:
                cb.set_params(**params)
            cb.initialize()
            callbacks_.append((name, cb))

        self.callbacks_ = callbacks_

        return self
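

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the module, criterion, and config objects
# below are hypothetical placeholders. From this class alone, ``Net`` expects
# the ``module__main_config`` and ``max_epochs`` kwargs, and ``main_config``
# must expose at least "train_config", "model_config", "task", "trained_on",
# and "valid_size".
#
#   net = Net(
#       module=MyModule,                      # hypothetical torch.nn.Module
#       module__main_config=main_config,      # config described above
#       criterion=MyCriterion,                # callable taking main_config
#       optimizer=torch.optim.AdamW,          # any optimizer accepting param groups
#       max_epochs=30,
#       device="cuda",
#   )
#   net.fit(X_train, y_train)
#   predictions = net.predict(X_test)         # activations + sample weights
# ---------------------------------------------------------------------------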