File size: 1,626 Bytes
0b11a42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging
import os
import sys
import warnings
from typing import Any

import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from omegaconf import DictConfig

from transforna import compute_cv, infer_benchmark, infer_tcga, train

# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Blanket filter installed at import time: every warning issued afterwards,
# of any category and from any module, is ignored process-wide.
warnings.filterwarnings("ignore")

def add_config_to_sys_path(hydra_cfg=None):
    """Append the first file-based Hydra config directory to ``sys.path``.

    Presumably this makes Python modules stored next to the YAML configs
    importable, so that ``instantiate`` can resolve their ``_target_``
    entries (confirm against the config files).

    Args:
        hydra_cfg: The Hydra runtime config, i.e. the object returned by
            ``HydraConfig.get()``. When omitted, ``HydraConfig.get()`` is
            called.

    Returns:
        The config directory that is now on ``sys.path``.

    Raises:
        RuntimeError: If ``hydra_cfg.runtime.config_sources`` contains no
            source whose schema is ``"file"``.
    """
    if hydra_cfg is None:
        hydra_cfg = HydraConfig.get()

    # The first file-schema source wins; the remaining sources are ignored.
    config_path = next(
        (
            source["path"]
            for source in hydra_cfg.runtime.config_sources
            if source["schema"] == "file"
        ),
        None,
    )
    if config_path is None:
        raise RuntimeError(
            "No file-based config source found in Hydra's runtime config; "
            "cannot add the config directory to sys.path."
        )

    # Do not grow sys.path when this runs more than once in the same
    # process (e.g. several jobs executed sequentially).
    if config_path not in sys.path:
        sys.path.append(config_path)
    return config_path

# transforna can be launched from any working directory, e.g.:
#   python -m transforna --config-dir=/path/to/configs
# (write "--config-dir=/path" or "--config-dir /path"; with spaces around "="
# the "=" itself would be taken as the directory.)
@hydra.main(config_path='../conf', config_name="main_config")
def main(cfg: DictConfig) -> Any:
    """Build the train/model configs, then run inference or training.

    Steps:
        1. Put the Hydra config directory on ``sys.path``.
        2. Replace ``cfg['train_config']`` and ``cfg['model_config']`` with
           the ``__dict__`` of the objects ``instantiate`` builds from them.
        3. Copy ``cfg['model_name']`` into
           ``cfg['model_config']['model_input']``.
        4. Dispatch on ``cfg['inference']``, ``cfg['task']`` and
           ``cfg['cross_val']``.

    Args:
        cfg: The composed Hydra config (``main_config`` from ``../conf`` by
            default). Keys read: ``train_config``, ``model_config``,
            ``model_name``, ``inference``; ``task`` only in inference mode
            and ``cross_val`` only in training mode.

    Returns:
        In inference mode, the value returned by ``infer_tcga`` (task
        ``'tcga'``) or ``infer_benchmark`` (any other task); otherwise
        ``None``.
    """
    add_config_to_sys_path()

    # Output directory Hydra assigned to this run.
    output_dir = HydraConfig.get().runtime.output_dir

    # Working directory at job start, passed to the helpers as `path`.
    # NOTE(review): if Hydra changed into the run directory, this may equal
    # output_dir -- confirm which directory the helpers expect.
    path = os.getcwd()

    # Turn the train/model config nodes into objects and keep their attributes
    # as plain dict entries in cfg.
    cfg['train_config'] = instantiate(cfg['train_config']).__dict__
    cfg['model_config'] = instantiate(cfg['model_config']).__dict__

    # Record the model name inside the model config.
    cfg['model_config']["model_input"] = cfg["model_name"]

    if cfg["inference"]:
        logger.info("Started inference on %s", cfg['task'])
        if cfg['task'] == 'tcga':
            return infer_tcga(cfg, path=path)
        return infer_benchmark(cfg, path=path)

    if cfg["cross_val"]:
        compute_cv(cfg, path, output_dir=output_dir)
    else:
        train(cfg, path=path, output_dir=output_dir)
    return None


if __name__ == "__main__":
    main()