import logging
import os
import pickle

import skorch
import torch
from skorch.dataset import Dataset, ValidSplit
from skorch.setter import optimizer_setter
from skorch.utils import is_dataset, to_device

logger = logging.getLogger(__name__)


class Net(skorch.NeuralNet):
    def __init__(
        self,
        clip=0.25,
        top_k=1,
        correct=0,
        save_embedding=False,
        gene_embedds=None,
        second_input_embedd=None,
        confidence_threshold=0.95,
        *args,
        **kwargs
    ):
        self.clip = clip
        self.curr_epoch = 0
        super().__init__(*args, **kwargs)
        self.correct = correct
        self.save_embedding = save_embedding
        # Default to None instead of a mutable list so instances do not share state.
        self.gene_embedds = gene_embedds if gene_embedds is not None else []
        self.second_input_embedds = second_input_embedd if second_input_embedd is not None else []
        self.main_config = kwargs["module__main_config"]
        self.train_config = self.main_config["train_config"]
        self.top_k = self.train_config.top_k
        self.num_classes = self.main_config["model_config"].num_classes
        self.labels_mapping_path = self.train_config.labels_mapping_path
        if self.labels_mapping_path:
            with open(self.labels_mapping_path, 'rb') as handle:
                self.labels_mapping_dict = pickle.load(handle)
        self.confidence_threshold = confidence_threshold
        self.max_epochs = kwargs["max_epochs"]
        self.task = ''
        self.log_tb = False

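    # The keys accessed above document what the wrapper expects from
    # ``module__main_config``: a "train_config" (with ``top_k``,
    # ``labels_mapping_path``, ``learning_rate``, ``l2_weight_decay``,
    # ``train_epoch``, ``device``), a "model_config" (with ``num_classes``),
    # plus "valid_size", "task" and "trained_on" entries. A rough construction
    # sketch; all names except ``Net`` are illustrative placeholders:
    #
    #     net = Net(
    #         module=MyModule,                 # hypothetical torch.nn.Module class
    #         criterion=MyPerSampleLoss,       # must return per-sample losses (see train_step)
    #         module__main_config=main_config, # dict-like with the keys listed above
    #         max_epochs=100,
    #         clip=0.25,
    #     )
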
    def set_save_epoch(self):
        '''
        Set the epoch at which the model is saved. When no validation split
        is used, the configured train epoch is scaled by (1 + valid_size).
        '''
        if self.task != 'tcga':
            if self.train_split:
                self.save_epoch = self.main_config["train_config"].train_epoch
            else:
                self.save_epoch = round(
                    self.main_config["train_config"].train_epoch
                    * (1 + self.main_config["valid_size"])
                )

    def save_benchmark_model(self):
        '''
        Save model parameters at the benchmark epoch; used when train_split is None.
        '''
        # Only tolerate the directory already existing instead of a bare except.
        os.makedirs("ckpt", exist_ok=True)
        ckpt_dir = os.path.join(os.getcwd(), "ckpt")
        self.save_params(
            f_params=os.path.join(ckpt_dir, f'model_params_{self.main_config["task"]}.pt')
        )

    def fit(self, X, y=None, valid_ds=None, **fit_params):
        self.all_lengths = [[] for _ in range(self.num_classes)]
        self.median_lengths = []

        if not self.warm_start or not self.initialized_:
            self.initialize()

        # An externally provided validation dataset takes precedence over train_split.
        if valid_ds:
            self.validation_dataset = valid_ds
        else:
            self.validation_dataset = None

        self.partial_fit(X, y, **fit_params)
        return self

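    # fit() mirrors skorch's interface but optionally accepts a pre-built
    # validation set. fit_loop() later unwraps it via ``.keywords["valid_ds"]``,
    # i.e. it expects a ``functools.partial``-like object carrying the dataset.
    # One way to build such an object (a sketch; ``valid_dataset`` is a
    # placeholder for a skorch-compatible Dataset):
    #
    #     from skorch.helper import predefined_split
    #     net.fit(X_train, y_train, valid_ds=predefined_split(valid_dataset))
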
    def fit_loop(self, X, y=None, epochs=None, **fit_params):
        rounding_digits = 3
        if self.main_config['trained_on'] == 'full':
            rounding_digits = 2

        self.check_data(X, y)
        epochs = epochs if epochs is not None else self.max_epochs

        dataset_train, dataset_valid = self.get_split_datasets(X, y, **fit_params)

        # If an external validation dataset was passed to fit(), use it instead
        # of the one produced by train_split.
        if self.validation_dataset is not None:
            dataset_valid = self.validation_dataset.keywords["valid_ds"]

        on_epoch_kwargs = {
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
        }

        iterator_train = self.get_iterator(dataset_train, training=True)
        iterator_valid = None
        if dataset_valid is not None:
            iterator_valid = self.get_iterator(dataset_valid, training=False)

        self.set_save_epoch()

        for epoch_no in range(epochs):
            self.curr_epoch = epoch_no

            # For benchmark tasks trained without a validation split, save the
            # model once the configured save epoch is reached.
            if self.task != 'tcga' and epoch_no == self.save_epoch and self.train_split is None:
                self.save_benchmark_model()

            self.notify("on_epoch_begin", **on_epoch_kwargs)

            self.run_single_epoch(
                iterator_train,
                training=True,
                prefix="train",
                step_fn=self.train_step,
                **fit_params
            )

            if dataset_valid is not None:
                self.run_single_epoch(
                    iterator_valid,
                    training=False,
                    prefix="valid",
                    step_fn=self.validation_step,
                    **fit_params
                )

            self.notify("on_epoch_end", **on_epoch_kwargs)

            # Early stopping for the tcga task: stop once the rounded train
            # accuracy reaches 1.
            if self.task == 'tcga':
                train_acc = round(self.history[:, 'train_acc'][-1], rounding_digits)
                if train_acc == 1:
                    break

        return self

    def train_step(self, X, y=None):
        # Batches arrive as (inputs, labels); the last input column holds
        # per-sample weights.
        y = X[1]
        X = X[0]
        sample_weights = X[:, -1]
        if self.device == 'cuda':
            sample_weights = sample_weights.to(self.train_config.device)

        self.module_.train()
        self.module_.zero_grad()
        gene_embedd, second_input_embedd, activations, _, _ = self.module_(X[:, :-1], train=True)

        loss = self.get_loss([gene_embedd, second_input_embedd, activations, self.curr_epoch], y)

        # The criterion returns per-sample losses, which are weighted and then averaged.
        loss = loss * sample_weights
        loss = loss.mean()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.module_.parameters(), self.clip)
        self.optimizer_.step()

        return {"X": X, "y": y, "loss": loss, "y_pred": [gene_embedd, second_input_embedd, activations]}

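    # train_step/validation_step multiply the loss by per-sample weights before
    # averaging, so the criterion is assumed to return one loss value per sample
    # (e.g. a torch loss built with ``reduction='none'``), not a scalar.
    # Minimal sketch of such a criterion (illustrative only):
    #
    #     class PerSampleCE(torch.nn.Module):
    #         def __init__(self, main_config):
    #             super().__init__()
    #             self.ce = torch.nn.CrossEntropyLoss(reduction='none')
    #
    #         def forward(self, y_pred, y_true):
    #             _, _, activations, _ = y_pred   # list passed to get_loss above
    #             return self.ce(activations, y_true)
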
    def validation_step(self, X, y=None):
        y = X[1]
        X = X[0]
        sample_weights = X[:, -1]
        if self.device == 'cuda':
            sample_weights = sample_weights.to(self.train_config.device)

        self.module_.eval()
        with torch.no_grad():
            gene_embedd, second_input_embedd, activations, _, _ = self.module_(X[:, :-1])
            loss = self.get_loss([gene_embedd, second_input_embedd, activations, self.curr_epoch], y)

            loss = loss * sample_weights
            loss = loss.mean()

        return {"X": X, "y": y, "loss": loss, "y_pred": [gene_embedd, second_input_embedd, activations]}

    def get_attention_scores(self, X, y=None):
        '''
        Return attention scores for a given input.
        '''
        self.module_.eval()
        with torch.no_grad():
            _, _, _, attn_scores_first, attn_scores_second = self.module_(X[:, :-1])

        attn_scores_first = attn_scores_first.detach().cpu().numpy()
        if attn_scores_second is not None:
            attn_scores_second = attn_scores_second.detach().cpu().numpy()
        return attn_scores_first, attn_scores_second

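    # get_attention_scores() expects the same input layout as training batches
    # (the sample-weight column last) and returns numpy arrays. Usage sketch
    # (``X_batch`` is a placeholder tensor):
    #
    #     first, second = net.get_attention_scores(X_batch)
    #     # ``second`` is None when the module has no second input branch.
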
    def predict(self, X):
        self.module_.train(False)
        embedds = self.module_(X[:, :-1])
        sample_weights = X[:, -1]
        if self.device == 'cuda':
            sample_weights = sample_weights.to(self.train_config.device)

        gene_embedd, second_input_embedd, activations, _, _ = embedds
        if self.save_embedding:
            self.gene_embedds.append(gene_embedd.detach().cpu())

            if second_input_embedd is not None:
                self.second_input_embedds.append(second_input_embedd.detach().cpu())

        # Per-sample weights are appended as the last column of the predictions.
        predictions = torch.cat([activations, sample_weights[:, None]], dim=1)
        return predictions

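    # predict() returns the class activations with the per-sample weights
    # appended as the last column, so downstream code is expected to split them
    # off again, roughly (names are illustrative):
    #
    #     preds = net.predict(X)                            # (n_samples, num_classes + 1)
    #     activations, weights = preds[:, :-1], preds[:, -1]
    #     labels = activations.argmax(dim=1)
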
    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        # Optionally log weight and gradient histograms to TensorBoard.
        for _, m in self.module_.named_modules():
            for pn, p in m.named_parameters():
                if pn.endswith("weight") and pn.find("norm") < 0:
                    if p.grad is not None:
                        if self.log_tb:
                            from ..callbacks.tbWriter import writer

                            writer.add_histogram("weights/" + pn, p, len(net.history))
                            writer.add_histogram(
                                "gradients/" + pn, p.grad.data, len(net.history)
                            )

        return

    def configure_opt(self, l2_weight_decay):
        # Exclude biases and LayerNorm weights from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        params_decay = [
            p
            for n, p in self.module_.named_parameters()
            if not any(nd in n for nd in no_decay)
        ]
        params_nodecay = [
            p
            for n, p in self.module_.named_parameters()
            if any(nd in n for nd in no_decay)
        ]
        optim_groups = [
            {"params": params_decay, "weight_decay": l2_weight_decay},
            {"params": params_nodecay, "weight_decay": 0.0},
        ]
        return optim_groups

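    # configure_opt() builds two torch parameter groups: decayed weights and
    # non-decayed biases/LayerNorm weights. They are consumed directly by the
    # optimizer in initialize_optimizer(), equivalent in spirit to:
    #
    #     groups = net.configure_opt(l2_weight_decay=0.01)    # value is illustrative
    #     optimizer = torch.optim.AdamW(groups, lr=1e-3)      # optimizer class is illustrative
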
    def initialize_optimizer(self, triggered_directly=True):
        """Initialize the model optimizer. If ``self.optimizer__lr``
        is not set, use ``self.lr`` instead.

        Parameters
        ----------
        triggered_directly : bool (default=True)
          Only relevant when the optimizer is re-initialized.
          Initialization of the optimizer can be triggered directly
          (e.g. when lr was changed) or indirectly (e.g. when the
          module was re-initialized). If and only if the former
          happens, the user should receive a message informing them
          about the parameters that caused the re-initialization.

        """
        optimizer_params = self.main_config["train_config"]
        kwargs = {}
        kwargs["lr"] = optimizer_params.learning_rate

        args = self.configure_opt(optimizer_params.l2_weight_decay)

        if self.initialized_ and self.verbose:
            msg = self._format_reinit_msg(
                "optimizer", kwargs, triggered_directly=triggered_directly
            )
            logger.info(msg)

        self.optimizer_ = self.optimizer(args, lr=kwargs["lr"])

        self._register_virtual_param(
            ["optimizer__param_groups__*__*", "optimizer__*", "lr"],
            optimizer_setter,
        )

    def initialize_criterion(self):
        """Initializes the criterion."""
        # The criterion receives the full main config rather than individual kwargs.
        self.criterion_ = self.criterion(self.main_config)
        if isinstance(self.criterion_, torch.nn.Module):
            self.criterion_ = to_device(self.criterion_, self.device)
        return self

    def initialize_callbacks(self):
        """Initializes all callbacks and saves the result in the
        ``callbacks_`` attribute.

        Both ``default_callbacks`` and ``callbacks`` are used (in that
        order). Callbacks may either be initialized or not, and if
        they don't have a name, the name is inferred from the class
        name. The ``initialize`` method is called on all callbacks.

        The final result will be a list of tuples, where each tuple
        consists of a name and an initialized callback. If names are
        not unique, a ValueError is raised.

        """
        if self.callbacks == "disable":
            self.callbacks_ = []
            return self

        callbacks_ = []

        class Dummy:
            pass

        for name, cb in self._uniquely_named_callbacks():
            param_callback = getattr(self, "callbacks__" + name, Dummy)
            if param_callback is not Dummy:  # the callback itself was set
                cb = param_callback

            # Resolve user-set parameters for this callback; the lr callback
            # additionally receives the train config.
            params = self.get_params_for("callbacks__{}".format(name))
            if name == "lrcallback":
                params["config"] = self.main_config["train_config"]

            if (cb is None) and params:
                raise ValueError(
                    "Trying to set a parameter for callback {} "
                    "which does not exist.".format(name)
                )
            if cb is None:
                continue

            if isinstance(cb, type):  # uninitialized class
                cb = cb(**params)
            else:
                cb.set_params(**params)
            cb.initialize()
            callbacks_.append((name, cb))

        self.callbacks_ = callbacks_

        return self
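

# A rough end-to-end usage sketch under the assumptions documented above;
# all names except ``Net`` are placeholders for project-specific objects:
#
#     net = Net(
#         module=MyModule,
#         criterion=PerSampleCE,
#         optimizer=torch.optim.AdamW,
#         module__main_config=main_config,
#         max_epochs=50,
#     )
#     net.task = main_config["task"]
#     net.fit(X_train, y_train)
#     predictions = net.predict(X_test)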