File size: 6,101 Bytes
0b11a42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
#in this file, the progression of the number of hicos per major class is computed per model
#this is done before ID, after FULL.
#%%
from transforna import load
from transforna import predict_transforna,predict_transforna_all_models
import pandas as pd
import plotly.graph_objects as go
import numpy as np
import plotly.io as pio
import plotly.express as px
mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/subclass_to_annotation.json'
models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'
mapping_dict = load(mapping_dict_path)
#%%
dataset:str = 'LC'
hico_loco_na_flag:str = 'hico'
assert hico_loco_na_flag in ['hico','loco_na'], 'hico_loco_na_flag must be either hico or loco_na'
if dataset == 'TCGA':
dataset_path_train: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv'
else:
dataset_path_train: str = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'
prediction_single_pd = predict_transforna(['AAAAAAACCCCCTTTTTTT'],model='Seq',logits_flag = True,trained_on='id',path_to_models=models_path)
sub_classes_used_for_training = prediction_single_pd.columns.tolist()
var = load(dataset_path_train).set_index('sequence')
#remove from var all indexes that are longer than 30
var = var[var.index.str.len() <= 30]
hico_seqs_all = var.index[var['hico']].tolist()
hico_labels_all = var['subclass_name'][var['hico']].values
hico_seqs_id = var.index[var.hico & var.subclass_name.isin(sub_classes_used_for_training)].tolist()
hico_labels_id = var['subclass_name'][var.hico & var.subclass_name.isin(sub_classes_used_for_training)].values
non_hico_seqs = var['subclass_name'][var['hico'] == False].index.values
non_hico_labels = var['subclass_name'][var['hico'] == False].values
#filter hico labels and hico seqs to hico ID
if hico_loco_na_flag == 'loco_na':
curr_seqs = non_hico_seqs
curr_labels = non_hico_labels
elif hico_loco_na_flag == 'hico':
curr_seqs = hico_seqs_id
curr_labels = hico_labels_id
full_df = predict_transforna_all_models(sequences=curr_seqs,path_to_models=models_path)
#%%
mcs = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','miRNA','miscRNA','lncRNA','piRNA','YRNA','vtRNA']
#for each mc, get the sequences of hicos in that mc and compute the number of hicos per model
num_hicos_per_mc = {}
if hico_loco_na_flag == 'hico':#this is where ground truth exists (hico id)
curr_labels_id_mc = [mapping_dict[label] for label in curr_labels]
elif hico_loco_na_flag == 'loco_na': # this is where ground truth does not exist (LOCO/NA)
ensemble_preds = full_df[full_df.Model == 'Ensemble'].set_index('Sequence').loc[curr_seqs].reset_index()
curr_labels_id_mc = [mapping_dict[label] for label in ensemble_preds['Net-Label']]
for mc in mcs:
#select sequences from hico_seqs that are in the major class mc
mc_seqs = [seq for seq,label in zip(curr_seqs,curr_labels_id_mc) if label == mc]
if len(mc_seqs) == 0:
num_hicos_per_mc[mc] = {model:0 for model in full_df.Model.unique()}
continue
#only keep in full_df the sequences that are in mc_seqs
mc_full_df = full_df[full_df.Sequence.isin(mc_seqs)]
curr_num_hico_per_model = mc_full_df[mc_full_df['Is Familiar?']].groupby(['Model'])['Is Familiar?'].value_counts().droplevel(1)
#remove Baseline from index
curr_num_hico_per_model = curr_num_hico_per_model.drop('Baseline')
curr_num_hico_per_model -= curr_num_hico_per_model.min()
num_hicos_per_mc[mc] = curr_num_hico_per_model.to_dict()
#%%
to_plot_df = pd.DataFrame(num_hicos_per_mc)
to_plot_mcs = ['rRNA','tRNA','snoRNA']
fig = go.Figure()
#x axis should be the mcs
for model in num_hicos_per_mc['rRNA'].keys():
fig.add_trace(go.Bar(x=mcs, y=[num_hicos_per_mc[mc][model] for mc in mcs], name=model))
fig.update_layout(barmode='group')
fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')
fig.write_image(f"num_hicos_per_model_{dataset}_{hico_loco_na_flag}.svg")
fig.update_yaxes(type="log")
fig.show()
#%%
import pandas as pd
import glob
from plotly import graph_objects as go
from transforna import load,predict_transforna
all_df = pd.DataFrame()
files = glob.glob('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_files/LC-novel_lev_dist_df.csv')
for file in files:
df = pd.read_csv(file)
all_df = pd.concat([all_df,df])
all_df = all_df.drop(columns=['Unnamed: 0'])
all_df.loc[all_df.split.isnull(),'split'] = 'NA'
ensemble_df = all_df[all_df.Model == 'Ensemble']
# %%
lc_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'
lc_df = load(lc_path)
lc_df.set_index('sequence',inplace=True)
# %%
#filter lc_df to only include sequences that are in ensemble_df
lc_df = lc_df.loc[ensemble_df.Sequence]
actual_major_classes = lc_df['small_RNA_class_annotation']
predicted_major_classes = ensemble_df[['Net-Label','Sequence']].set_index('Sequence').loc[lc_df.index]['Net-Label'].map(mapping_dict)
# %%
#plot correlation matrix between actual and predicted major classes
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
major_classes = list(set(list(predicted_major_classes.unique())+list(actual_major_classes.unique())))
conf_matrix = confusion_matrix(actual_major_classes,predicted_major_classes,labels=major_classes)
conf_matrix = conf_matrix/np.sum(conf_matrix,axis=1)
sns.heatmap(conf_matrix,annot=True,cmap='Blues')
for i in range(conf_matrix.shape[0]):
for j in range(conf_matrix.shape[1]):
conf_matrix[i,j] = round(conf_matrix[i,j],1)
plt.xlabel('Predicted Major Class')
plt.ylabel('Actual Major Class')
plt.xticks(np.arange(len(major_classes)),major_classes,rotation=90)
plt.yticks(np.arange(len(major_classes)),major_classes,rotation=0)
plt.show()
# %%
|