|
|
|
import logging |
|
import math |
|
import os |
|
import random |
|
from pathlib import Path |
|
from random import randint |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
from hydra._internal.utils import _locate |
|
from hydra.utils import instantiate |
|
from omegaconf import DictConfig, OmegaConf |
|
from scipy.stats import entropy |
|
from sklearn.model_selection import train_test_split |
|
from sklearn.utils.class_weight import (compute_class_weight, |
|
compute_sample_weight) |
|
from skorch.dataset import Dataset |
|
from skorch.helper import predefined_split |
|
|
|
from ..callbacks.metrics import get_callbacks |
|
from ..score.score import infer_from_model |
|
from .energy import * |
|
from .file import load |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
def update_config_with_inference_params(config:DictConfig,mc_or_sc:str='sub_class',trained_on:str = 'id',path_to_models:str = 'models/tcga/') -> DictConfig:
    """Clone ``config`` and augment it with everything needed for inference.

    Points the inference settings at the pre-trained checkpoint selected by
    ``mc_or_sc`` and ``trained_on``, then merges in the train/model configs
    that were persisted next to that checkpoint.
    """
    cfg = config.copy()

    # checkpoint folders capitalise each dash-separated word of the model name
    model_dir = "-".join(part.capitalize() for part in config['model_name'].split("-"))
    model_family = "TransfoRNA_FULL" if trained_on == "full" else "TransfoRNA_ID"

    cfg['inference_settings']["model_path"] = \
        f'{path_to_models}{model_family}/{mc_or_sc}/{model_dir}/ckpt/model_params_tcga.pt'
    cfg["inference"] = True
    cfg["log_logits"] = False

    cfg = DictConfig(cfg)

    # recover the train/model configs stored alongside the checkpoint
    train_config = OmegaConf.structured(instantiate(get_hp_setting(cfg, "train_config")))
    model_config = OmegaConf.structured(instantiate(get_hp_setting(cfg, "model_config")))

    model_config["model_input"] = cfg["model_name"]
    return OmegaConf.merge({"train_config": train_config, "model_config": model_config}, cfg)
|
|
|
def update_config_with_dataset_params_benchmark(train_data_df,configs):
    '''
    Propagate dataset-derived dimensions into the configs after tokenization;
    these values are read later by sub-modules.
    '''
    model_cfg = configs["model_config"]
    train_cfg = configs["train_config"]

    # token widths come straight from the tokenized dataframe
    model_cfg.second_input_token_len = train_data_df["second_input"].shape[1]
    model_cfg.tokens_len = train_data_df["tokens_id"].shape[1]

    # NOTE: plain division — may be fractional; kept as-is for downstream behavior
    train_cfg.batch_per_epoch = train_data_df["tokens_id"].shape[0] / train_cfg.batch_size
    return
|
|
|
def update_config_with_dataset_params_tcga(dataset_class,all_data_df,configs):
    """Copy vocabulary sizes and token lengths from the tokenized dataset into
    the model config."""
    model_cfg = configs["model_config"]

    model_cfg.ff_input_dim = all_data_df['second_input'].shape[1]
    model_cfg.vocab_size = len(dataset_class.seq_tokens_ids_dict)
    model_cfg.second_input_vocab_size = len(dataset_class.second_input_tokens_ids_dict)

    model_cfg.tokens_len = dataset_class.tokens_len
    model_cfg.second_input_token_len = dataset_class.tokens_len

    # seq-seq feeds the sequence to both branches, each seeing half the tokens
    if configs["model_name"] == "seq-seq":
        half = math.ceil(dataset_class.tokens_len / 2)
        model_cfg.tokens_len = half
        model_cfg.second_input_token_len = half
|
|
|
|
|
def update_dataclass_inference(cfg,dataset_class):
    """Overwrite a dataset instance's tokenization state with the state
    persisted next to the inference checkpoint, so sequences are tokenized
    exactly as at training time. Returns the mutated instance."""
    seq_dict, ss_dict = get_tokenization_dicts(cfg)
    dataset_class.seq_tokens_ids_dict = seq_dict
    dataset_class.second_input_tokens_ids_dict = ss_dict
    dataset_class.tokens_len = cfg["model_config"].tokens_len
    dataset_class.max_length = get_hp_setting(cfg, 'max_length')
    dataset_class.min_length = get_hp_setting(cfg, 'min_length')
    return dataset_class
|
|
|
def set_seed_and_device(seed:int = 0,device_no:int=0):
    """Seed every RNG in play (python, torch CPU/CUDA, numpy) and select the
    CUDA device. cudnn determinism trades speed for reproducibility."""
    torch.backends.cudnn.deterministic = True
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
        seeder(seed)
    torch.cuda.set_device(device_no)
|
|
|
|
|
def sync_skorch_with_config(skorch_cfg: DictConfig,cfg:DictConfig): |
|
''' |
|
skorch config contains duplicate params to the train and model configs |
|
values of skorch config should be populated by those in the trian and |
|
model config |
|
''' |
|
|
|
|
|
for key in skorch_cfg: |
|
if key in cfg["train_config"]: |
|
skorch_cfg[key] = cfg["train_config"][key] |
|
if key in cfg["model_config"]: |
|
skorch_cfg[key] = cfg["model_config"][key] |
|
|
|
return |
|
|
|
def instantiate_predictor(skorch_cfg: DictConfig,cfg:DictConfig,path: str=None):
    """Build and initialize the skorch network described by ``skorch_cfg``.

    :param skorch_cfg: skorch-specific config (already synced with train/model configs)
    :param cfg: full experiment config, forwarded to the module
    :param path: output path handed to the callbacks
    :return: an instantiated, callback-initialized skorch net
    """
    predictor_config = OmegaConf.to_container(skorch_cfg)

    # keep the device value exactly as configured
    if "device" in predictor_config:
        predictor_config["device"] = skorch_cfg["device"]

    # Resolve dotted paths (e.g. "torch.optim.Adam") to the actual objects;
    # plain values that are not importable paths are left untouched.
    # (was a bare `except:` which also swallowed KeyboardInterrupt/SystemExit)
    for key, val in predictor_config.items():
        try:
            predictor_config[key] = _locate(val)
        except Exception:
            continue

    predictor_config["callbacks"] = get_callbacks(path,cfg)

    # NOTE(review): callbacks are detached during `instantiate` and re-attached
    # afterwards — presumably to keep hydra from recursing into them; confirm.
    callbacks_list = predictor_config["callbacks"]
    predictor_config["callbacks"] = "disable"

    # forward the whole config (minus the model node) to the torch module
    predictor_config["module__main_config"] = \
        {key:cfg[key] for key in cfg if key not in ["model"]}

    if 'dataset' in predictor_config['module__main_config']:
        del predictor_config['module__main_config']['dataset']

    if not cfg['train_split']:
        predictor_config['train_split'] = False
    net = instantiate(predictor_config)

    net.callbacks = callbacks_list
    net.task = cfg['task']
    net.initialize_callbacks()

    # mark as initialized so params can be loaded without re-initialization
    net.initialized_=True
    return net
|
|
|
def get_fused_seqs(seqs,num_sequences:int=1,max_len:int=30): |
|
''' |
|
fuse num_sequences sequences from seqs |
|
''' |
|
fused_seqs = [] |
|
while len(fused_seqs) < num_sequences: |
|
|
|
seq1 = random.choice(seqs)[:max_len] |
|
seq2 = random.choice(seqs)[:max_len] |
|
|
|
|
|
idx = random.randint(math.floor(len(seq1)*0.3),math.floor(len(seq1)*0.7)) |
|
len_to_be_added_from_seq2 = len(seq1)-idx |
|
|
|
seq1 = seq1[:idx] |
|
|
|
seq2 = seq2[:len_to_be_added_from_seq2] |
|
|
|
fused_seq = seq1+seq2 |
|
|
|
if fused_seq not in fused_seqs and fused_seq not in seqs: |
|
fused_seqs.append(fused_seq) |
|
|
|
return fused_seqs |
|
|
|
def revert_seq_tokenization(tokenized_seqs,configs):
    """Reconstruct raw sequences from overlap-tokenized dataframes.

    Overlap tokens have length ``window`` and stride 1 (see
    ``chunkstring_overlap``), so the original sequence is the first character
    of every token plus the tail of the last token. For any other tokenizer
    the input is returned unchanged after logging an error.
    """
    window = configs["model_config"].window
    if configs["model_config"].tokenizer != "overlap":
        logger.error("Sequences are not reverse tokenized")
        return tokenized_seqs

    seqs_concat = []
    for seq in tokenized_seqs.values:
        tokens = seq[seq != 'pad']
        # joined tokens have length n*window, so [::window] picks each token's
        # first char; the last token contributes its remaining chars.
        # (was tokens[-1][window-1], which dropped characters for window > 2)
        seqs_concat.append(''.join(tokens)[::window] + tokens[-1][1:])

    return pd.DataFrame(seqs_concat,columns=["Sequences"])
|
|
|
def introduce_mismatches(seq, n_mismatches):
    """Return ``seq`` with ``n_mismatches`` randomly placed point mutations.

    Each chosen position is replaced by a *different* nucleotide — the
    original code could redraw the same base, yielding fewer mismatches than
    requested. Positions are drawn independently, so two draws may still hit
    the same position.
    """
    seq = list(seq)
    for _ in range(n_mismatches):
        pos = randint(0, len(seq) - 1)
        alternatives = [nt for nt in ('A', 'G', 'C', 'T') if nt != seq[pos]]
        seq[pos] = alternatives[randint(0, len(alternatives) - 1)]
    return ''.join(seq)
|
|
|
def prepare_split(split_data_df,configs):
    '''
    Build model-ready tensors for one dataframe split: a float tensor holding
    token ids, second input and sequence lengths (with a balanced sample
    weight appended as the last column), the reverse-tokenized sequences, and
    a long tensor of labels.
    '''
    feature_cols = ['tokens_id','second_input','seqs_length']

    features = torch.tensor(
        np.array(split_data_df[feature_cols].values, dtype=float),
        dtype=torch.float,
    )
    # per-sample balanced weights travel with the data as an extra column
    weights = torch.tensor(compute_sample_weight('balanced',split_data_df['Labels']))
    features = torch.cat([features, weights[:, None]], dim=1)

    rna_seqs = revert_seq_tokenization(split_data_df["tokens"],configs)

    labels = torch.tensor(
        np.array(split_data_df["Labels"], dtype=int),
        dtype=torch.long,
    )
    return features, rna_seqs, labels
|
|
|
def prepare_model_inference(cfg,path):
    """Instantiate the skorch model, load the trained weights from
    ``cfg['inference_settings']['model_path']`` and attach the class-name to
    class-id mapping. Enables embedding capture when ``cfg['log_embedds']``."""
    net = instantiate_predictor(cfg["model"]["skorch_model"], cfg, path)
    net.initialize()

    model_path = cfg["inference_settings"]["model_path"]
    logger.info(f"Model loaded from {model_path}")
    net.load_params(f_params=f'{model_path}')

    class_ids = list(np.arange(cfg["model_config"].num_classes))
    net.labels_mapping_dict = dict(zip(cfg["model_config"].class_mappings, class_ids))

    if cfg['log_embedds']:
        net.save_embedding = True
        net.gene_embedds = []
        net.second_input_embedds = []
    return net
|
|
|
def prepare_data_benchmark(tokenizer,test_ad, configs):
    """
    Receive the tokenizer and the test anndata and prepare the data in a
    format suitable for training.

    Also sets defaults in the config that cannot be known until preprocessing
    is done (token lengths, vocab sizes, class weights). The tokenized
    dataframes are hierarchical pandas frames (accessed per token, e.g.
    [AA, AT, ..., AC]).

    Returns a dict with train/valid/test tensors, reverse-tokenized sequences
    and numeric labels; for the premirna task an additional generalization
    test set is appended, and inference tensors are added when inference is
    configured.
    """
    train_data_df = tokenizer.get_tokenized_data()

    # propagate token widths / batches-per-epoch into the configs
    update_config_with_dataset_params_benchmark(train_data_df,configs)

    test_data_df = tokenize_set(tokenizer,test_ad.var)

    train_data, train_rna_seq, train_labels = prepare_split(train_data_df,configs)
    test_data, test_rna_seq, test_labels = prepare_split(test_data_df,configs)

    class_weights = compute_class_weight(class_weight='balanced',classes=np.unique(train_labels.flatten()),y=train_labels.flatten().numpy())

    # stored as strings — presumably for clean config serialization; confirm
    configs['model_config'].class_weights = [str(x) for x in list(class_weights)]

    if configs["train_split"]:
        # carve a stratified validation split out of the training data
        train_data,valid_data,train_labels,valid_labels = stratify(train_data,train_labels,configs["valid_size"])
        valid_ds = Dataset(valid_data,valid_labels)
        valid_ds=predefined_split(valid_ds)
    else:
        valid_ds = None

    all_data= {"train_data":train_data,
               "valid_ds":valid_ds,
               "test_data":test_data,
               "train_rna_seq":train_rna_seq,
               "test_rna_seq":test_rna_seq,
               "train_labels_numeric":train_labels,
               "test_labels_numeric":test_labels}

    if configs["task"] == "premirna":
        generalization_test_set = get_add_test_set(tokenizer,\
            dataset_path=configs["train_config"].datset_path_additional_testset)

    # vocab sizes are only known after tokenization
    configs["model_config"].vocab_size = len(tokenizer.seq_tokens_ids_dict.keys())
    configs["model_config"].second_input_vocab_size = len(tokenizer.second_input_tokens_ids_dict.keys())
    configs["model_config"].tokens_mapping_dict = tokenizer.seq_tokens_ids_dict

    if configs["task"] == "premirna":
        generalization_test_data = []
        for test_df in generalization_test_set:
            # only the tensor part is kept for the extra test sets
            test_data_extra, _, _ = prepare_split(test_df,configs)
            generalization_test_data.append(test_data_extra)
        all_data["additional_testset"] = generalization_test_data

    # attach inference tensors when inference mode is configured
    get_inference_data(configs,tokenizer,all_data)

    return all_data
|
|
|
def prepare_inference_results_benchmarck(net,cfg,predicted_labels,logits,all_data):
    """Attach the predicted labels (and, when ``cfg['log_logits']``, per-class
    logits) to ``all_data['infere_rna_seq']``, re-indexed by sequence.
    Mutates ``all_data`` in place; returns None."""
    seq_cols = pd.MultiIndex.from_product(
        [["Sequences"], np.arange(1, dtype=int)], names=["type of data", "indices"])
    rna_seqs_df = pd.DataFrame(
        columns=seq_cols,
        data=np.vstack(all_data["infere_rna_seq"]["Sequences"].values))

    logit_cols = pd.MultiIndex.from_product(
        [["Logits"], list(net.labels_mapping_dict.keys())],
        names=["type of data", "indices"])
    logits_df = pd.DataFrame(columns=logit_cols, data=np.array(logits))

    all_data["infere_rna_seq"]["Labels",'0'] = predicted_labels
    all_data["infere_rna_seq"].set_index("Sequences",inplace=True)

    if cfg["log_logits"]:
        # index logits by sequence so they line up with the predictions
        seq_logits_df = logits_df.join(rna_seqs_df).set_index(("Sequences",0))
        all_data["infere_rna_seq"] = all_data["infere_rna_seq"].join(seq_logits_df)
    else:
        all_data["infere_rna_seq"].columns = ['Labels']

    return
|
|
|
def prepare_inference_results_tcga(cfg,predicted_labels,logits,all_data,max_len):
    """Annotate ``all_data['infere_rna_seq']`` with the network prediction, a
    familiarity flag based on logits entropy, and optionally the raw logits.

    The entropy threshold comes from the logits classifier persisted in the
    model's ``analysis`` folder. Mutates ``all_data`` in place; returns None.
    """
    logits_clf = load('/'.join(cfg["inference_settings"]["model_path"].split('/')[:-2])\
        +'/analysis/logits_model_coef.yaml')
    threshold = round(logits_clf['Threshold'],2)

    iterables = [["Sequences"], np.arange(1, dtype=int)]
    index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"])
    rna_seqs_df = pd.DataFrame(columns=index, data=np.vstack(all_data["infere_rna_seq"]["Sequences"].values))

    iterables = [["Logits"], cfg['model_config'].class_mappings]
    index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"])
    logits_df = pd.DataFrame(columns=index, data=np.array(logits))

    all_data["infere_rna_seq"]["Net-Label"] = predicted_labels
    # low prediction entropy -> the sequence looks familiar (in-distribution)
    all_data["infere_rna_seq"]["Is Familiar?"] = entropy(logits,axis=1) <= threshold

    all_data["infere_rna_seq"].set_index("Sequences",inplace=True)

    if cfg["log_logits"]:
        seq_logits_df = logits_df.join(rna_seqs_df).set_index(("Sequences",0))
        all_data["infere_rna_seq"] = all_data["infere_rna_seq"].join(seq_logits_df)

    # was the literal 'Max Length=42,912' inside an f-string with no
    # placeholder (and max_len was unused) — interpolate the real value
    all_data["infere_rna_seq"].index.name = f'Sequences, Max Length={max_len}'

    return
|
|
|
def prepare_inference_data(cfg,infer_pd,dataset_class):
    """Tokenize ``infer_pd`` with the (already configured) dataset class and
    return a dict holding the model-ready tensor and the raw sequences."""
    tokenized = tokenize_set(dataset_class, infer_pd, inference=True)
    infer_tensor, infer_seqs, _ = prepare_split(tokenized, cfg)

    return {
        "infere_data": infer_tensor,
        "infere_rna_seq": infer_seqs,
    }
|
|
|
def get_inference_data(configs,dataset_class,all_data):
    """When inference is requested and a sequences file is configured, load
    the sequences, complete missing columns, tokenize them and add the
    resulting tensors to ``all_data`` (keys ``infere_data``/``infere_rna_seq``).
    No-op otherwise."""
    if configs["inference"] == True and configs["inference_settings"]["sequences_path"] is not None:
        seq_file = configs["inference_settings"]["sequences_path"]
        # sequences_path is relative to the package root
        seq_path = Path(__file__).parent.parent.parent.absolute() / f"{seq_file}"

        infer_data = load(seq_path)

        # fill in the optional columns expected downstream
        if "Secondary" not in infer_data:
            infer_data['Secondary'] = dataset_class.get_secondary_structure(infer_data["Sequences"])
        if "Labels" not in infer_data:
            infer_data["Labels"] = [0] * len(infer_data["Sequences"].values)

        dataset_class.seqs_dot_bracket_labels = infer_data

        # inference accepts any length: drop the minimum-length filter
        dataset_class.min_length = 0
        dataset_class.limit_seqs_to_range(logger)
        tokenized = dataset_class.get_tokenized_data(inference=True)
        infer_tensor, infer_seqs, _ = prepare_split(tokenized, configs)

        all_data["infere_data"] = infer_tensor
        all_data["infere_rna_seq"] = infer_seqs
|
|
|
def get_add_test_set(dataset_class,dataset_path):
    """Load every additional premirna test file (mirbase then mirgene folder),
    label all entries as positives, and return their tokenized dataframes."""
    mirbase_dir = dataset_path + "mirbase/"
    mirgene_dir = dataset_path + "mirgene/"
    files = [mirbase_dir + name for name in os.listdir(mirbase_dir)]
    files += [mirgene_dir + name for name in os.listdir(mirgene_dir)]

    tokenized_sets = []
    for path in files:
        test_pd = load(path)
        test_pd = test_pd.drop(columns='Unnamed: 0')  # stray csv index column
        test_pd["Sequences"] = test_pd["Sequences"].astype(object)
        test_pd["Secondary"] = test_pd["Secondary"].astype(object)

        # every sequence in these sets is a known pre-miRNA
        test_pd["Labels"] = 1

        dataset_class.seqs_dot_bracket_labels = test_pd
        dataset_class.limit_seqs_to_range()
        tokenized_sets.append(dataset_class.get_tokenized_data())
    return tokenized_sets
|
|
|
def get_tokenization_dicts(cfg):
    """Load the sequence and second-input token dictionaries persisted two
    levels above the model checkpoint."""
    base = '/'.join(cfg['inference_settings']['model_path'].split('/')[:-2])
    seq_dict = load(base + '/seq_tokens_ids_dict')
    ss_dict = load(base + '/second_input_tokens_ids_dict')
    return seq_dict, ss_dict
|
|
|
def get_hp_setting(cfg,hp_param):
    """Look up ``hp_param`` in the hp_settings.yaml persisted with the model.

    The parameter is searched at the top level first, then one level deep
    inside nested mappings. Raises ValueError when found nowhere.

    Fixes vs. original: ``is None`` checks instead of truthiness/``!= None``,
    so legitimately falsy values (0, False, '') are returned rather than
    triggering the nested search or a spurious ValueError; the bare
    ``except:`` is narrowed to non-mapping values.
    """
    model_parent_path = Path('/'.join(cfg['inference_settings']['model_path'].split('/')[:-2]))
    hp_settings = load(model_parent_path/'meta/hp_settings.yaml')

    hp_val = hp_settings.get(hp_param)
    if hp_val is None:
        for key in hp_settings.keys():
            try:
                hp_val = hp_settings[key].get(hp_param)
            except AttributeError:  # value at this key is not a mapping
                continue
            if hp_val is not None:
                break
    if hp_val is None:
        raise ValueError(f"hp_param {hp_param} not found in hp_settings")

    return hp_val
|
|
|
def get_model(cfg,path):
    """Prepare a trained model for inference: restore its model_config from
    the checkpoint's hp settings, sync the skorch config with it, and load the
    weights. Returns the updated cfg together with the ready network."""
    cfg["model_config"] = get_hp_setting(cfg, 'model_config')

    sync_skorch_with_config(cfg["model"]["skorch_model"], cfg)
    cfg['model_config']['model_input'] = cfg['model_name']

    return cfg, prepare_model_inference(cfg, path)
|
|
|
def stratify(train_data,train_labels,valid_size):
    """Stratified train/validation split preserving the label distribution."""
    split = train_test_split(
        train_data,
        train_labels,
        stratify=train_labels,
        test_size=valid_size,
    )
    return split
|
|
|
def tokenize_set(dataset_class,test_pd,inference:bool=False):
    """Feed ``test_pd`` into the dataset class, apply the length filter, and
    return the tokenized dataframe."""
    dataset_class.seqs_dot_bracket_labels = test_pd
    dataset_class.limit_seqs_to_range()
    return dataset_class.get_tokenized_data(inference)
|
|
|
def add_original_seqs_to_predictions(short_to_long_df,pred_df):
    """Join the trimming metadata (Trimmed flag + original sequence) onto the
    per-sequence predictions; duplicate sequence indices keep the first row.
    Note: re-indexes ``short_to_long_df`` in place."""
    short_to_long_df.set_index('Sequences',inplace=True)
    merged = pd.merge(
        pred_df,
        short_to_long_df[['Trimmed','Original_Sequence']],
        right_index=True,
        left_index=True,
        how='left',
    )
    return merged[~merged.index.duplicated(keep='first')]
|
|
|
def add_ss_and_labels(infer_data):
    """Ensure the inference dataframe has a 'Secondary' column (folded on
    demand) and a 'Labels' column (dummy zeros). Returns the mutated frame."""
    if "Secondary" not in infer_data:
        # fold at 37 degrees to obtain the dot-bracket structure
        structures = fold_sequences(infer_data["Sequences"].tolist())['structure_37'].values
        infer_data["Secondary"] = structures
    if "Labels" not in infer_data:
        infer_data["Labels"] = [0] * len(infer_data["Sequences"].values)
    return infer_data
|
|
|
def chunkstring_overlap(string, window):
    """Yield every overlapping substring of length ``window`` (stride 1);
    empty for strings shorter than ``window``."""
    last_start = len(string) - window
    return (string[i:i + window] for i in range(last_start + 1))
|
|
|
def create_short_seqs_from_long(df,max_len):
    """Split sequences longer than ``max_len`` into all overlapping windows of
    length ``max_len`` (stride 1), keeping short sequences untouched.

    The returned frame carries a ``Trimmed`` flag and the ``Original_Sequence``
    each window came from.
    """
    long_seqs = df['Sequences'][df['Sequences'].str.len()>max_len].values
    short_seqs_pd = df[df['Sequences'].str.len()<=max_len]
    feature_tokens_gen = list(
        chunkstring_overlap(feature, max_len)
        for feature in long_seqs
    )
    original_seqs = []
    shortened_seqs = []
    for i,feature_tokens in enumerate(feature_tokens_gen):
        curr_trimmed_seqs = [feature for feature in feature_tokens]
        shortened_seqs.extend(curr_trimmed_seqs)
        original_seqs.extend([long_seqs[i]]*len(curr_trimmed_seqs))
    short_to_long_dict = dict(zip(shortened_seqs,original_seqs))

    shortened_df = pd.DataFrame(data=shortened_seqs,columns=['Sequences'])
    # DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent
    df = pd.concat([shortened_df, short_seqs_pd]).reset_index(drop=True)

    # after reset_index the first len(shortened_df) rows are the trimmed ones
    df['Trimmed'] = False
    df.loc[shortened_df.index,'Trimmed'] = True
    df['Original_Sequence'] = df['Sequences']
    df.loc[shortened_df.index,'Original_Sequence'] = df.loc[shortened_df.index,'Sequences'].map(short_to_long_dict)
    return df
|
|
|
def infer_from_pd(cfg,net,infer_pd,DataClass,attention_flag:bool=False):
    """Run the loaded model over the sequences in ``infer_pd``.

    Long sequences are sliced into overlapping windows, missing Secondary/
    Labels columns are filled in, and the frame is tokenized with the
    checkpoint's tokenization state before inference.

    Returns (predicted_labels, logits, gene_embedds_df, attn_scores_df,
    all_data, max_len, net, infer_pd); the two dataframes are None unless
    ``cfg['log_embedds']`` / ``attention_flag`` are set.
    """
    # maximum sequence length supported by the model's positional encoding;
    # fall back to 30 when the module exposes no transformer layers
    try:
        max_len = net.module_.transformer_layers.pos_encoder.pe.shape[1]+1
    except:
        max_len = 30

    # seq-seq halves the tokens per branch, so it supports twice the length
    if cfg['model_name'] == 'seq-seq':
        max_len = max_len*2 - 1

    # slice over-long sequences into overlapping max_len windows
    if len(infer_pd['Sequences'][infer_pd['Sequences'].str.len()>max_len].values)>0:
        infer_pd = create_short_seqs_from_long(infer_pd,max_len)
    infer_pd = add_ss_and_labels(infer_pd)
    if cfg['model_name'] == 'seq-seq':
        # NOTE(review): token lengths are doubled in-place on cfg — repeated
        # calls with the same cfg would compound; confirm intended usage
        cfg['model_config']['tokens_len'] *=2
        cfg['model_config']['second_input_token_len'] *=2

    dataset_class = DataClass(infer_pd,cfg)
    # re-use the tokenization dictionaries persisted with the checkpoint
    dataset_class = update_dataclass_inference(cfg,dataset_class)

    all_data = prepare_inference_data(cfg,infer_pd,dataset_class)

    predicted_labels,logits,attn_scores_first_list,attn_scores_second_list = infer_from_model(net,all_data["infere_data"])
    if attention_flag:
        # single-branch models reuse the first attention output for both
        if not attn_scores_second_list:
            attn_scores_second_list = attn_scores_first_list

        attn_scores_first = np.array(attn_scores_first_list)
        seq_lengths = all_data['infere_rna_seq']['Sequences'].str.len().values

        # crop each attention map to the true sequence length and flatten
        attn_scores_list = [attn_scores_first[i,:seq_lengths[i],:seq_lengths[i]].flatten().tolist() for i in range(len(seq_lengths))]
        attn_scores_first_df = pd.DataFrame(data = {'attention_first':attn_scores_list})
        attn_scores_first_df.index = all_data['infere_rna_seq']['Sequences'].values

        attn_scores_second = np.array(attn_scores_second_list)
        attn_scores_list = [attn_scores_second[i,:seq_lengths[i],:seq_lengths[i]].flatten().tolist() for i in range(len(seq_lengths))]
        attn_scores_second_df = pd.DataFrame(data = {'attention_second':attn_scores_list})
        attn_scores_second_df.index = all_data['infere_rna_seq']['Sequences'].values

        attn_scores_df = attn_scores_first_df.join(attn_scores_second_df)
        attn_scores_df['Secondary'] = infer_pd["Secondary"].values
    else:
        attn_scores_df = None

    gene_embedds_df = None

    if cfg['log_embedds']:
        gene_embedds = np.vstack(net.gene_embedds)
        # two-input models also captured second-input embeddings
        if cfg['model_name'] not in ['baseline']:
            second_input_embedds = np.vstack(net.second_input_embedds)
            gene_embedds = np.concatenate((gene_embedds,second_input_embedds),axis=1)
        gene_embedds_df = pd.DataFrame(data=gene_embedds)
        gene_embedds_df.index = all_data['infere_rna_seq']['Sequences'].values
        gene_embedds_df.columns = ['gene_embedds_'+str(i) for i in range(gene_embedds_df.shape[1])]

    return predicted_labels,logits,gene_embedds_df,attn_scores_df,all_data,max_len,net,infer_pd
|
|
|
def log_embedds(cfg,net,seqs_df):
    """Collect the gene (and, for two-input models, second-input) embeddings
    accumulated on ``net`` and append them column-wise to ``seqs_df``."""
    embedds = np.vstack(net.gene_embedds)
    if cfg['model_name'] not in ['seq', 'baseline']:
        embedds = np.concatenate((embedds, np.vstack(net.second_input_embedds)), axis=1)

    return seqs_df.join(pd.DataFrame(data=embedds))
|
|