Yak-hbdx's picture
uploaded TransfoRNA repo
0b11a42 verified
raw
history blame
13.2 kB
import logging
import os
import pickle
import skorch
import torch
from skorch.dataset import Dataset, ValidSplit
from skorch.setter import optimizer_setter
from skorch.utils import is_dataset, to_device
logger = logging.getLogger(__name__)
#from ..tbWriter import writer
class Net(skorch.NeuralNet):
def __init__(
self,
clip=0.25,
top_k=1,
correct=0,
save_embedding=False,
gene_embedds=[],
second_input_embedd=[],
confidence_threshold = 0.95,
*args,
**kwargs
):
self.clip = clip
self.curr_epoch = 0
super(Net, self).__init__(*args, **kwargs)
self.correct = correct
self.save_embedding = save_embedding
self.gene_embedds = gene_embedds
self.second_input_embedds = second_input_embedd
self.main_config = kwargs["module__main_config"]
self.train_config = self.main_config["train_config"]
self.top_k = self.train_config.top_k
self.num_classes = self.main_config["model_config"].num_classes
self.labels_mapping_path = self.train_config.labels_mapping_path
if self.labels_mapping_path:
with open(self.labels_mapping_path, 'rb') as handle:
self.labels_mapping_dict = pickle.load(handle)
self.confidence_threshold = confidence_threshold
self.max_epochs = kwargs["max_epochs"]
self.task = '' #is set in utils.instantiate_predictor
self.log_tb = False
def set_save_epoch(self):
'''
scale best train epoch by valid size
'''
if self.task !='tcga':
if self.train_split:
self.save_epoch = self.main_config["train_config"].train_epoch
else:
self.save_epoch = round(self.main_config["train_config"].train_epoch*\
(1+self.main_config["valid_size"]))
def save_benchmark_model(self):
'''
saves benchmark epochs when train_split is none
'''
try:
os.mkdir("ckpt")
except:
pass
cwd = os.getcwd()+"/ckpt/"
self.save_params(f_params= f'{cwd}/model_params_{self.main_config["task"]}.pt')
def fit(self, X, y=None, valid_ds=None,**fit_params):
#all sequence lengths should be saved to compute the median based
self.all_lengths = [[] for i in range(self.num_classes)]
self.median_lengths = []
if not self.warm_start or not self.initialized_:
self.initialize()
if valid_ds:
self.validation_dataset = valid_ds
else:
self.validation_dataset = None
self.partial_fit(X, y, **fit_params)
return self
def fit_loop(self, X, y=None, epochs=None, **fit_params):
#if id then train longer otherwise stop at 0.99
rounding_digits = 3
if self.main_config['trained_on'] == 'full':
rounding_digits = 2
self.check_data(X, y)
epochs = epochs if epochs is not None else self.max_epochs
dataset_train, dataset_valid = self.get_split_datasets(X, y, **fit_params)
if self.validation_dataset is not None:
dataset_valid = self.validation_dataset.keywords["valid_ds"]
on_epoch_kwargs = {
"dataset_train": dataset_train,
"dataset_valid": dataset_valid,
}
iterator_train = self.get_iterator(dataset_train, training=True)
iterator_valid = None
if dataset_valid is not None:
iterator_valid = self.get_iterator(dataset_valid, training=False)
self.set_save_epoch()
for epoch_no in range(epochs):
#save model if training only on test set
self.curr_epoch = epoch_no
#save epoch is scaled by best train epoch
#save benchmark only when training on boith train and val sets
if self.task != 'tcga' and epoch_no == self.save_epoch and self.train_split == None:
self.save_benchmark_model()
self.notify("on_epoch_begin", **on_epoch_kwargs)
self.run_single_epoch(
iterator_train,
training=True,
prefix="train",
step_fn=self.train_step,
**fit_params
)
if dataset_valid is not None:
self.run_single_epoch(
iterator_valid,
training=False,
prefix="valid",
step_fn=self.validation_step,
**fit_params
)
self.notify("on_epoch_end", **on_epoch_kwargs)
#manual early stopping for tcga
if self.task == 'tcga':
train_acc = round(self.history[:,'train_acc'][-1],rounding_digits)
if train_acc == 1:
break
return self
def train_step(self, X, y=None):
y = X[1]
X = X[0]
sample_weights = X[:,-1]
if self.device == 'cuda':
sample_weights = sample_weights.to(self.train_config.device)
self.module_.train()
self.module_.zero_grad()
gene_embedd, second_input_embedd, activations,_,_ = self.module_(X[:,:-1],train=True)
#curr_epoch is passed to loss as it is used to switch loss criteria from unsup. -> sup
loss = self.get_loss([gene_embedd,second_input_embedd,activations,self.curr_epoch], y)
###sup loss should be X with samples weight and aggregated
loss = loss*sample_weights
loss = loss.mean()
loss.backward()
# TODO: clip only some parameters
torch.nn.utils.clip_grad_norm_(self.module_.parameters(), self.clip)
self.optimizer_.step()
return {"X":X,"y":y,"loss": loss, "y_pred": [gene_embedd,second_input_embedd,activations]}
def validation_step(self, X, y=None):
y = X[1]
X = X[0]
sample_weights = X[:,-1]
if self.device == 'cuda':
sample_weights = sample_weights.to(self.train_config.device)
self.module_.eval()
with torch.no_grad():
gene_embedd, second_input_embedd, activations,_,_ = self.module_(X[:,:-1])
loss = self.get_loss([gene_embedd,second_input_embedd,activations,self.curr_epoch], y)
###sup loss should be X with samples weight and aggregated
loss = loss*sample_weights
loss = loss.mean()
return {"X":X,"y":y,"loss": loss, "y_pred": [gene_embedd,second_input_embedd,activations]}
def get_attention_scores(self, X, y=None):
'''
returns attention scores for a given input
'''
self.module_.eval()
with torch.no_grad():
_, _, _,attn_scores_first,attn_scores_second = self.module_(X[:,:-1])
attn_scores_first = attn_scores_first.detach().cpu().numpy()
if attn_scores_second is not None:
attn_scores_second = attn_scores_second.detach().cpu().numpy()
return attn_scores_first,attn_scores_second
def predict(self, X):
self.module_.train(False)
embedds = self.module_(X[:,:-1])
sample_weights = X[:,-1]
if self.device == 'cuda':
sample_weights = sample_weights.to(self.train_config.device)
gene_embedd, second_input_embedd, activations,_,_ = embedds
if self.save_embedding:
self.gene_embedds.append(gene_embedd.detach().cpu())
#in case only a single transformer is deployed, then second_input_embedd are None. thus have no detach()
if second_input_embedd is not None:
self.second_input_embedds.append(second_input_embedd.detach().cpu())
predictions = torch.cat([activations,sample_weights[:,None]],dim=1)
return predictions
def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
# log gradients and weights
for _, m in self.module_.named_modules():
for pn, p in m.named_parameters():
if pn.endswith("weight") and pn.find("norm") < 0:
if p.grad != None:
if self.log_tb:
from ..callbacks.tbWriter import writer
writer.add_histogram("weights/" + pn, p, len(net.history))
writer.add_histogram(
"gradients/" + pn, p.grad.data, len(net.history)
)
return
def configure_opt(self, l2_weight_decay):
no_decay = ["bias", "LayerNorm.weight"]
params_decay = [
p
for n, p in self.module_.named_parameters()
if not any(nd in n for nd in no_decay)
]
params_nodecay = [
p
for n, p in self.module_.named_parameters()
if any(nd in n for nd in no_decay)
]
optim_groups = [
{"params": params_decay, "weight_decay": l2_weight_decay},
{"params": params_nodecay, "weight_decay": 0.0},
]
return optim_groups
def initialize_optimizer(self, triggered_directly=True):
"""Initialize the model optimizer. If ``self.optimizer__lr``
is not set, use ``self.lr`` instead.
Parameters
----------
triggered_directly : bool (default=True)
Only relevant when optimizer is re-initialized.
Initialization of the optimizer can be triggered directly
(e.g. when lr was changed) or indirectly (e.g. when the
module was re-initialized). If and only if the former
happens, the user should receive a message informing them
about the parameters that caused the re-initialization.
"""
# get learning rate from train config
optimizer_params = self.main_config["train_config"]
kwargs = {}
kwargs["lr"] = optimizer_params.learning_rate
# get l2 weight decay to init opt params
args = self.configure_opt(optimizer_params.l2_weight_decay)
if self.initialized_ and self.verbose:
msg = self._format_reinit_msg(
"optimizer", kwargs, triggered_directly=triggered_directly
)
logger.info(msg)
self.optimizer_ = self.optimizer(args, lr=kwargs["lr"])
self._register_virtual_param(
["optimizer__param_groups__*__*", "optimizer__*", "lr"],
optimizer_setter,
)
def initialize_criterion(self):
"""Initializes the criterion."""
# critereon takes train_config and model_config as an input.
# we get both from the module parameters
self.criterion_ = self.criterion(
self.main_config
)
if isinstance(self.criterion_, torch.nn.Module):
self.criterion_ = to_device(self.criterion_, self.device)
return self
def initialize_callbacks(self):
"""Initializes all callbacks and save the result in the
``callbacks_`` attribute.
Both ``default_callbacks`` and ``callbacks`` are used (in that
order). Callbacks may either be initialized or not, and if
they don't have a name, the name is inferred from the class
name. The ``initialize`` method is called on all callbacks.
The final result will be a list of tuples, where each tuple
consists of a name and an initialized callback. If names are
not unique, a ValueError is raised.
"""
if self.callbacks == "disable":
self.callbacks_ = []
return self
callbacks_ = []
class Dummy:
# We cannot use None as dummy value since None is a
# legitimate value to be set.
pass
for name, cb in self._uniquely_named_callbacks():
# check if callback itself is changed
param_callback = getattr(self, "callbacks__" + name, Dummy)
if param_callback is not Dummy: # callback itself was set
cb = param_callback
# below: check for callback params
# don't set a parameter for non-existing callback
# if the callback is lrcallback then initializa it with the train config,
# which is an input to the module
if name == "lrcallback":
params["config"] = self.main_config["train_config"]
else:
params = self.get_params_for("callbacks__{}".format(name))
if (cb is None) and params:
raise ValueError(
"Trying to set a parameter for callback {} "
"which does not exist.".format(name)
)
if cb is None:
continue
if isinstance(cb, type): # uninitialized:
cb = cb(**params)
else:
cb.set_params(**params)
cb.initialize()
callbacks_.append((name, cb))
self.callbacks_ = callbacks_
return self