|
|
|
from pathlib import Path |
|
from typing import Dict |
|
|
|
from ..callbacks.metrics import accuracy_score |
|
from ..processing.seq_tokenizer import SeqTokenizer |
|
from ..score.score import infer_from_model, infer_testset |
|
from ..utils.file import load, save |
|
from ..utils.utils import * |
|
|
|
|
|
def infer_benchmark(cfg: Dict = None, path: str = None):
    """Run benchmark inference with a previously trained model and save the results.

    The function seeds the run, loads the train and test datasets referenced by
    ``cfg["train_config"]``, builds a ``SeqTokenizer`` from the training data,
    prepares the benchmark data, restores the predictor from
    ``cfg["inference_settings"]["model_path"]``, optionally scores the original
    test set, and finally runs inference on ``all_data["infere_data"]`` and
    saves ``all_data["infere_rna_seq"]``.

    Args:
        cfg: Configuration mapping. Keys read here: ``tensorboard``,
            ``model_name``, ``task``, ``seed``, ``device_number``,
            ``train_config`` (attributes ``dataset_path_train`` and
            ``dataset_path_test``), ``model["skorch_model"]`` and
            ``inference_settings`` (``model_path``,
            ``infere_original_testset``). Although the default is ``None``,
            the function indexes ``cfg`` immediately, so a real mapping is
            required.
        path: Forwarded unchanged to ``instantiate_predictor``; its meaning is
            defined there.

    Returns:
        None. The results are written next to the package root as
        ``inference_results_<model_name>_<task>``.
    """
    # The writer is imported lazily so that the tbWriter module is only
    # touched when TensorBoard logging is enabled in the config.
    writer = None
    if cfg['tensorboard']:
        from ..callbacks.tbWriter import writer

    try:
        model = cfg["model_name"] + '_' + cfg['task']

        set_seed_and_device(cfg["seed"], cfg["device_number"])

        # The training data is loaded only for its `.var` attribute, which the
        # tokenizer is constructed from.
        ad = load(cfg["train_config"].dataset_path_train)
        dataset_class = SeqTokenizer(ad.var, cfg)

        test_data = load(cfg["train_config"].dataset_path_test)
        all_data = prepare_data_benchmark(dataset_class, test_data, cfg)

        # NOTE(review): presumably this copies train/model settings from cfg
        # into the skorch model config before the predictor is built - confirm
        # against sync_skorch_with_config.
        sync_skorch_with_config(cfg["model"]["skorch_model"], cfg)

        net = instantiate_predictor(cfg["model"]["skorch_model"], cfg, path)
        net.initialize()
        net.load_params(f_params=f'{cfg["inference_settings"]["model_path"]}')

        # Optional evaluation on the task's original test set.
        if cfg["inference_settings"]["infere_original_testset"]:
            infer_testset(net, cfg, all_data, accuracy_score)

        # Inference on the data to be annotated; the last two return values of
        # infer_from_model are not needed here.
        predicted_labels, logits, _, _ = infer_from_model(net, all_data["infere_data"])
        prepare_inference_results_benchmarck(net, cfg, predicted_labels, logits, all_data)

        # Results are stored two directory levels above this file.
        save(
            path=Path(__file__).parent.parent.absolute() / f'inference_results_{model}',
            data=all_data["infere_rna_seq"],
        )
    finally:
        # Close the writer even when inference fails, and key the decision on
        # whether it was actually imported rather than re-reading cfg, which
        # the helpers above receive and could have modified.
        if writer is not None:
            writer.close()