import logging
import warnings
from argparse import ArgumentParser
from contextlib import redirect_stdout
from datetime import datetime
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
from hydra.utils import instantiate
from omegaconf import OmegaConf
from sklearn.preprocessing import StandardScaler
from umap import UMAP

from ..novelty_prediction.id_vs_ood_nld_clf import get_closest_ngbr_per_split
from ..processing.seq_tokenizer import SeqTokenizer
from ..utils.file import load
from ..utils.tcga_post_analysis_utils import Results_Handler
from ..utils.utils import (get_model, infer_from_pd,
                           prepare_inference_results_tcga,
                           update_config_with_inference_params)

logger = logging.getLogger(__name__)

warnings.filterwarnings("ignore")


def aggregate_ensemble_model(lev_dist_df: pd.DataFrame):
    '''
    This function aggregates the predictions of the ensemble model by choosing, per query
    sequence, between the model with the lowest NLD and the model with the highest NLD.
    If the lowest NLD is below the Novelty Threshold, the model with the lowest NLD is
    chosen as the ensemble prediction. Otherwise, the model with the highest NLD is chosen.
    '''
    #for every sequence: if at least one model scores an NLD < Novelty Threshold,
    #take the model with the lowest NLD as the ensemble prediction; otherwise take the highest NLD.
    #remove the baseline model
    baseline_df = lev_dist_df[lev_dist_df['Model'] == 'Baseline'].reset_index(drop=True)
    lev_dist_df = lev_dist_df[lev_dist_df['Model'] != 'Baseline'].reset_index(drop=True)
    #get the minimum NLD per query sequence
    min_lev_dist_df = lev_dist_df.iloc[lev_dist_df.groupby('Sequence')['NLD'].idxmin().values]
    #get the maximum NLD per query sequence
    max_lev_dist_df = lev_dist_df.iloc[lev_dist_df.groupby('Sequence')['NLD'].idxmax().values]
    #choose between each row in min_lev_dist_df and max_lev_dist_df based on the Novelty Threshold
    novel_mask_df = min_lev_dist_df['NLD'] > min_lev_dist_df['Novelty Threshold']
    #keep the rows where the NLD is lower than the Novelty Threshold
    min_lev_dist_df = min_lev_dist_df[~novel_mask_df.values]
    #keep the rows where the NLD is higher than the Novelty Threshold
    max_lev_dist_df = max_lev_dist_df[novel_mask_df.values]
    #merge min_lev_dist_df and max_lev_dist_df
    ensemble_lev_dist_df = pd.concat([min_lev_dist_df, max_lev_dist_df])
    #add the ensemble model
    ensemble_lev_dist_df['Model'] = 'Ensemble'
    #add ensemble_lev_dist_df to lev_dist_df
    lev_dist_df = pd.concat([lev_dist_df, ensemble_lev_dist_df, baseline_df])
    return lev_dist_df.reset_index(drop=True)


def read_inference_model_config(model: str, mc_or_sc, trained_on: str, path_to_models: str):
    transforna_folder = "TransfoRNA_ID"
    if trained_on == "full":
        transforna_folder = "TransfoRNA_FULL"

    model_path = f"{path_to_models}/{transforna_folder}/{mc_or_sc}/{model}/meta/hp_settings.yaml"
    cfg = OmegaConf.load(model_path)
    return cfg
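# Illustrative sketch (hypothetical toy data, not part of the original module): shows
# the rule implemented by aggregate_ensemble_model. 'AAA' has one model under the
# novelty threshold, so its lowest-NLD row becomes the Ensemble row; 'CCC' has none,
# so its highest-NLD row is taken instead.
def _demo_aggregate_ensemble_model() -> pd.DataFrame:
    toy_df = pd.DataFrame({
        'Sequence': ['AAA', 'AAA', 'AAA', 'CCC', 'CCC', 'CCC'],
        'Model': ['Seq', 'Seq-Rev', 'Baseline', 'Seq', 'Seq-Rev', 'Baseline'],
        'NLD': [0.10, 0.40, 0.30, 0.50, 0.70, 0.60],
        'Novelty Threshold': [0.25] * 6,
    })
    #expected Ensemble rows: ('AAA', 'Seq', NLD 0.10) and ('CCC', 'Seq-Rev', NLD 0.70)
    return aggregate_ensemble_model(toy_df)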
def predict_transforna(sequences: List[str], model: str = "Seq-Rev", mc_or_sc: str = 'sub_class',
                       logits_flag: bool = False, attention_flag: bool = False,
                       similarity_flag: bool = False, n_sim: int = 3, embedds_flag: bool = False,
                       umap_flag: bool = False, trained_on: str = 'full', path_to_models: str = '') -> pd.DataFrame:
    '''
    This function predicts the major class or sub class of a list of sequences using the TransfoRNA model.
    Additionally, it can return logits, attention scores, similarity scores, gene embeddings or umap embeddings.

    Input:
        sequences: list of sequences to predict
        model: model to use for prediction
        mc_or_sc: whether to use the models trained on major class or on sub class
        logits_flag: whether to return logits
        attention_flag: whether to return attention scores (obtained from the self-attention layer)
        similarity_flag: whether to return explanatory/similar sequences in the training set
        n_sim: number of similar sequences to return
        embedds_flag: whether to return embeddings of the sequences
        umap_flag: whether to return umap embeddings
        trained_on: whether to use the model trained on the full dataset or the ID dataset
    Output:
        pd.DataFrame with the predictions
    '''
    #assert that at most one flag is True
    assert sum([logits_flag, attention_flag, similarity_flag, embedds_flag, umap_flag]) <= 1, 'One option at most can be True'
    #capitalize the first letter of the model and the first letter after the -
    model = "-".join([word.capitalize() for word in model.split("-")])
    cfg = read_inference_model_config(model, mc_or_sc, trained_on, path_to_models)
    cfg = update_config_with_inference_params(cfg, mc_or_sc, trained_on, path_to_models)
    root_dir = Path(__file__).parents[1].absolute()

    with redirect_stdout(None):
        cfg, net = get_model(cfg, root_dir)
    #the original sequences might be longer than the model input; if so, infer_pd contains the trimmed sequences
    infer_pd = pd.Series(sequences, name="Sequences").to_frame()
    predicted_labels, logits, gene_embedds_df, attn_scores_pd, all_data, max_len, net, _ = infer_from_pd(cfg, net, infer_pd, SeqTokenizer, attention_flag)

    if model == 'Seq':
        gene_embedds_df = gene_embedds_df.iloc[:, :int(gene_embedds_df.shape[1]/2)]
    if logits_flag:
        cfg['log_logits'] = True
    prepare_inference_results_tcga(cfg, predicted_labels, logits, all_data, max_len)
    infer_pd = all_data["infere_rna_seq"]

    if logits_flag:
        logits_cols = [col for col in infer_pd.columns if "Logits" in col]
        logits_df = infer_pd[logits_cols]
        logits_df.columns = pd.MultiIndex.from_tuples(logits_df.columns, names=["Logits", "Sub Class"])
        logits_df.columns = logits_df.columns.droplevel(0)
        return logits_df

    elif attention_flag:
        return attn_scores_pd

    elif embedds_flag:
        return gene_embedds_df

    else: #return table with predictions, entropy, threshold, is familiar
        embedds_path = '/'.join(cfg['inference_settings']["model_path"].split('/')[:-2]) + '/embedds'
        results: Results_Handler = Results_Handler(embedds_path=embedds_path, splits=['train'])
        results.get_knn_model()
        lv_threshold = load(results.analysis_path + "/novelty_model_coef")["Threshold"]
        logger.info('computing Levenshtein distance for the inference set')
        #prepare the infer split
        gene_embedds_df.columns = results.embedds_cols[:len(gene_embedds_df.columns)]
        #add the index of gene_embedds_df as a column named results.seq_col
        gene_embedds_df[results.seq_col] = gene_embedds_df.index
        #set gene_embedds_df as the new infer split
        results.splits_df_dict['infer_df'] = gene_embedds_df
        _, _, top_n_seqs, top_n_labels, distances, lev_dist = get_closest_ngbr_per_split(results, 'infer', num_neighbors=n_sim)

        if similarity_flag:
            sim_df = pd.DataFrame()
            #duplicate each query sequence n_sim times so that each row holds one neighbor
            sequences = gene_embedds_df.index.tolist()
            sequences_duplicated = [seq for seq in sequences for _ in range(n_sim)]
            sim_df['Sequence'] = sequences_duplicated
            #assign the top_n_seqs list to a df column
            sim_df['Explanatory Sequence'] = top_n_seqs
            sim_df['NLD'] = lev_dist
            sim_df['Explanatory Label'] = top_n_labels
            sim_df['Novelty Threshold'] = lv_threshold
            #for every query sequence, order the NLDs in increasing order
            sim_df = sim_df.sort_values(by=['Sequence', 'NLD'], ascending=[False, True])
            return sim_df

        logger.info(f'num of hico based on entropy novelty prediction is {sum(infer_pd["Is Familiar?"])}')
        #for every n_sim consecutive neighbors, keep the one with the smallest Levenshtein distance
        lv_dist_closest = [min(lev_dist[i:i+n_sim]) for i in range(0, len(lev_dist), n_sim)]
        top_n_labels_closest = [top_n_labels[i:i+n_sim][np.argmin(lev_dist[i:i+n_sim])] for i in range(0, len(lev_dist), n_sim)]
        top_n_seqs_closest = [top_n_seqs[i:i+n_sim][np.argmin(lev_dist[i:i+n_sim])] for i in range(0, len(lev_dist), n_sim)]
        infer_pd['Is Familiar?'] = [lv < lv_threshold for lv in lv_dist_closest]

        if umap_flag:
            #(reconstruction, original lines were garbled) assumed handling of umap_flag:
            #standardize the gene embeddings and project them to 2D with UMAP
            logger.info('computing umap embeddings for the inference set')
            scaled_embedds = StandardScaler().fit_transform(gene_embedds_df.drop(columns=[results.seq_col]).values)
            umap_embedds = UMAP(n_components=2, random_state=42).fit_transform(scaled_embedds)
            infer_pd['UMAP1'] = umap_embedds[:, 0]
            infer_pd['UMAP2'] = umap_embedds[:, 1]
            return infer_pd.rename_axis("Sequence").reset_index()

        infer_pd['NLD'] = lv_dist_closest
        infer_pd['Explanatory Label'] = top_n_labels_closest
        infer_pd['Explanatory Sequence'] = top_n_seqs_closest
        infer_pd['Novelty Threshold'] = lv_threshold
        logger.info(f'num of hico based on Levenshtein distance is {sum(infer_pd["Is Familiar?"])}')
        return infer_pd.rename_axis("Sequence").reset_index()
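# Illustrative usage (not part of the original module): a minimal sketch of calling
# predict_transforna directly. The sequences and the models directory below are
# hypothetical placeholders; the call requires trained TransfoRNA models on disk.
def _demo_predict_transforna() -> pd.DataFrame:
    toy_seqs = ['AACGAAGCTCGACTTTTAAGG', 'TACCCTGTAGATCCGAATTTGTG']  #hypothetical queries
    #with all flags left False, the result holds one row per query with the predicted
    #label, the NLD to its closest training neighbour, the novelty threshold and the
    #"Is Familiar?" verdict
    return predict_transforna(toy_seqs, model='Seq-Rev', mc_or_sc='sub_class',
                              trained_on='full', path_to_models='/path/to/models')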
def predict_transforna_all_models(sequences: List[str], mc_or_sc: str = 'sub_class',
                                  logits_flag: bool = False, attention_flag: bool = False,
                                  similarity_flag: bool = False, n_sim: int = 3,
                                  embedds_flag: bool = False, umap_flag: bool = False,
                                  trained_on: str = 'full', path_to_models: str = '') -> pd.DataFrame:
    """
    Predicts the labels of the sequences using all the models available in the transforna package.
    If none of the flags are true, it constructs and aggregates the output of the ensemble model.

    Input:
        sequences: list of sequences to predict
        mc_or_sc: whether to use the models trained on major class or on sub class
        logits_flag: whether to return logits
        attention_flag: whether to return attention scores (obtained from the self-attention layer)
        similarity_flag: whether to return explanatory/similar sequences in the training set
        n_sim: number of similar sequences to return
        embedds_flag: whether to return embeddings of the sequences
        umap_flag: whether to return umap embeddings
        trained_on: whether to use the model trained on the full dataset or the ID dataset
    Output:
        df: dataframe with the predictions
    """
    now = datetime.now()
    before_time = now.strftime("%H:%M:%S")
    models = ["Baseline", "Seq", "Seq-Seq", "Seq-Struct", "Seq-Rev"]
    if similarity_flag or embedds_flag: #remove the baseline model, takes a long time
        models = ["Seq", "Seq-Seq", "Seq-Struct", "Seq-Rev"]
    if attention_flag: #remove the baseline and single-based transformer models
        models = ["Seq", "Seq-Struct", "Seq-Rev"]
    df = None
    for model in models:
        logger.info(model)
        df_ = predict_transforna(sequences, model, mc_or_sc, logits_flag, attention_flag, similarity_flag, n_sim, embedds_flag, umap_flag, trained_on=trained_on, path_to_models=path_to_models)
        df_["Model"] = model
        df = pd.concat([df, df_], axis=0)
    #aggregate the ensemble model if none of the flags are true
    if not logits_flag and not attention_flag and not similarity_flag and not embedds_flag and not umap_flag:
        df = aggregate_ensemble_model(df)

    now = datetime.now()
    after_time = now.strftime("%H:%M:%S")
    delta_time = datetime.strptime(after_time, "%H:%M:%S") - datetime.strptime(before_time, "%H:%M:%S")
    logger.info(f"Time taken: {delta_time}")

    return df


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("sequences", nargs="+")
    parser.add_argument("--logits_flag", nargs="?", const=True, default=False)
    parser.add_argument("--attention_flag", nargs="?", const=True, default=False)
    parser.add_argument("--similarity_flag", nargs="?", const=True, default=False)
    parser.add_argument("--n_sim", nargs="?", type=int, const=3, default=3)
    parser.add_argument("--embedds_flag", nargs="?", const=True, default=False)
    parser.add_argument("--trained_on", nargs="?", const="full", default="full")
    predict_transforna_all_models(**vars(parser.parse_args()))
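# Illustrative CLI usage (sequences and module path hypothetical; the relative
# imports above require running this file as a module from the package root):
#   python -m transforna.inference_api AACGAAGCTCGACTTTTAAGG --similarity_flag --n_sim 3
# runs every applicable model on the query and, because similarity_flag is set,
# returns the top-3 explanatory sequences per query instead of the ensemble table.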