|
import logging |
|
from typing import Dict |
|
|
|
from anndata import AnnData |
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
from ..callbacks.metrics import accuracy_score |
|
from ..novelty_prediction.id_vs_ood_entropy_clf import compute_entropies |
|
from ..novelty_prediction.id_vs_ood_nld_clf import compute_nlds |
|
from ..processing.augmentation import DataAugmenter |
|
from ..processing.seq_tokenizer import SeqTokenizer |
|
from ..processing.splitter import * |
|
from ..processing.splitter import DataSplitter |
|
from ..score.score import (compute_score_benchmark, compute_score_tcga, |
|
infere_additional_test_data) |
|
from ..utils.file import load, save |
|
from ..utils.utils import (instantiate_predictor, prepare_data_benchmark, |
|
set_seed_and_device, sync_skorch_with_config) |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
def compute_cv(cfg:DictConfig,path:str,output_dir:str):
    """Train ``cfg["num_replicates"]`` replicates and save a summary table.

    For each replicate index ``i`` the function sets ``cfg["seed"] = i``,
    calls :func:`train`, and records two values in row ``i`` of a DataFrame:

    * ``'B. Acc'`` -- the test score returned by :func:`train`;
    * ``'Dur'``    -- the summed ``'dur'`` history entries up to and including
      the last epoch whose ``'val_acc_best'`` history entry is truthy.

    The table is finally written with ``save`` to ``path + '/summary_pd'``.

    Args:
        cfg: Run configuration. ``cfg["num_replicates"]`` is read and
            ``cfg["seed"]`` is overwritten in place on every iteration.
        path: Output location; forwarded to :func:`train` and used as the
            prefix of the summary file.
        output_dir: Forwarded unchanged to :func:`train`.

    Returns:
        None. The result is the saved summary table.
    """
    # Imported explicitly: no visible module-level import provides ``np`` or
    # ``pd``; they would otherwise only be available if the wildcard import
    # from ``..processing.splitter`` happens to re-export them.
    import numpy as np
    import pandas as pd

    num_replicates = cfg["num_replicates"]
    summary_pd = pd.DataFrame(
        index=np.arange(num_replicates),
        columns=['B. Acc', 'Dur'],
    )

    for seed_no in range(num_replicates):
        logger.info(f"Currently training replicate {seed_no}")
        # The replicate index doubles as the random seed of the run.
        cfg["seed"] = seed_no
        test_score, net = train(cfg, path=path, output_dir=output_dir)

        # Index of the last epoch flagged in 'val_acc_best'.
        # NOTE(review): presumably this marks the epoch with the best
        # validation accuracy -- confirm against the callbacks of the net.
        convrg_epoch = np.where(net.history[:, 'val_acc_best'])[0][-1]
        # Accumulated per-epoch 'dur' values up to (and including) that epoch.
        convrg_dur = sum(net.history[:, 'dur'][:convrg_epoch + 1])

        summary_pd.iloc[seed_no] = [test_score, convrg_dur]

    save(path=path+'/summary_pd', data=summary_pd)
|
|
|
def _fit_and_score(cfg, path, output_dir):
    """Run the load -> tokenize -> fit -> score pipeline for one seed.

    Private helper of :func:`train`; takes the same arguments and returns the
    same ``(test_score, net)`` tuple.
    """
    set_seed_and_device(cfg["seed"], cfg["device_number"])

    dataset = load(cfg["train_config"].dataset_path_train)
    if isinstance(dataset, AnnData):
        # For AnnData input only the ``.var`` table is used downstream.
        dataset = dataset.var
    else:
        # Otherwise the loaded object is re-indexed by its 'sequence' column.
        dataset.set_index('sequence', inplace=True)

    if cfg["task"] in ["premirna", "sncrna"]:
        # Benchmark tasks: a separate test set is loaded from the config.
        tokenizer = SeqTokenizer(dataset, cfg)
        test_ad = load(cfg["train_config"].dataset_path_test)
        all_data = prepare_data_benchmark(tokenizer, test_ad, cfg)
    else:
        # Any other task: augment first, then let DataSplitter build the data.
        df = DataAugmenter(dataset, cfg).get_augmented_df()
        tokenizer = SeqTokenizer(df, cfg)
        all_data = DataSplitter(tokenizer, cfg).prepare_data_tcga()

    sync_skorch_with_config(cfg["model"]["skorch_model"], cfg)
    net = instantiate_predictor(cfg["model"]["skorch_model"], cfg, path)

    net.fit(
        all_data["train_data"],
        all_data["train_labels_numeric"],
        all_data["valid_ds"],
    )

    # Persist the fully resolved configuration next to the run outputs.
    save(
        data=OmegaConf.to_container(cfg, resolve=True),
        path=path+'/meta/hp_settings.yaml',
    )

    if cfg['task'] == 'tcga':
        test_score = compute_score_tcga(net, all_data, path, cfg)
        compute_nlds(output_dir)
        compute_entropies(output_dir)
    else:
        test_score = compute_score_benchmark(net, path, all_data, accuracy_score, cfg)

    # NOTE(review): the original indentation was lost; this check is assumed
    # to sit at function level (not inside the ``else`` above) -- confirm.
    if "additional_testset" in all_data:
        infere_additional_test_data(net, all_data["additional_testset"])

    return test_score, net


def train(cfg:Dict= None,path:str = None,output_dir:str = None):
    """Train one model according to ``cfg`` and return its test score.

    Steps (see :func:`_fit_and_score`): seed/device setup, loading the
    training data, tokenization (with augmentation and splitting for
    non-benchmark tasks), fitting the skorch predictor, saving the resolved
    configuration to ``path + '/meta/hp_settings.yaml'`` and scoring.

    Args:
        cfg: Run configuration. Keys read here and in the helper:
            ``tensorboard``, ``seed``, ``device_number``, ``train_config``,
            ``task`` and ``model``. Must not be ``None`` despite the default.
        path: Output location passed to the predictor, the scoring
            functions and used as prefix for the saved settings file.
        output_dir: Directory handed to ``compute_nlds`` and
            ``compute_entropies`` when ``cfg['task'] == 'tcga'``.

    Returns:
        Tuple ``(test_score, net)``: the computed score and the fitted
        predictor.
    """
    writer = None
    if cfg['tensorboard']:
        # Imported lazily so the writer module is only loaded when
        # tensorboard logging is enabled.
        from ..callbacks.tbWriter import writer

    try:
        return _fit_and_score(cfg, path, output_dir)
    finally:
        # Close the writer even when training or scoring raises, so the
        # tensorboard resources are not leaked on failure.
        if writer is not None:
            writer.close()
|
|