Yak-hbdx's picture
uploaded TransfoRNA repo
0b11a42 verified
raw
history blame
6.1 kB
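# In this file, the progression of the number of HICOs per major class is computed per model.
# This is done before ID, after FULL.
# The last cells plot a confusion matrix of actual vs. Ensemble-predicted major classes for the LC-novel set.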
#%%
from transforna import load
from transforna import predict_transforna,predict_transforna_all_models
import pandas as pd
import plotly.graph_objects as go
import numpy as np
import plotly.io as pio
import plotly.express as px
# Location of the trained TCGA models used for every prediction in this file.
models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'
# Presumably maps a subclass name to its major-class annotation (file: subclass_to_annotation.json).
mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/subclass_to_annotation.json'
mapping_dict = load(mapping_dict_path)
#%%
# Dataset to analyse: 'TCGA' selects the TCGA var file, any other value the LC annotation file.
dataset: str = 'LC'
# Sequence set to analyse: 'hico' = HICO sequences whose subclass the model was trained on,
# 'loco_na' = all non-HICO sequences.
hico_loco_na_flag: str = 'hico'
if hico_loco_na_flag not in ('hico', 'loco_na'):
    # Explicit check instead of `assert`, which is removed when Python runs with -O.
    raise ValueError('hico_loco_na_flag must be either hico or loco_na')

if dataset == 'TCGA':
    dataset_path_train: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv'
else:
    dataset_path_train: str = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'

# A single dummy sequence is predicted only to read the logit columns, which are taken
# as the subclasses the 'Seq' model (trained_on='id') was trained on.
prediction_single_pd = predict_transforna(
    ['AAAAAAACCCCCTTTTTTT'],
    model='Seq',
    logits_flag=True,
    trained_on='id',
    path_to_models=models_path,
)
sub_classes_used_for_training = prediction_single_pd.columns.tolist()

var = load(dataset_path_train).set_index('sequence')
# Remove all sequences longer than 30 nucleotides.
var = var[var.index.str.len() <= 30]

# Row masks, built once and reused below.
hico_mask = var['hico']
hico_id_mask = hico_mask & var['subclass_name'].isin(sub_classes_used_for_training)
non_hico_mask = var['hico'] == False

hico_seqs_all = var.index[hico_mask].tolist()
hico_labels_all = var['subclass_name'][hico_mask].values
hico_seqs_id = var.index[hico_id_mask].tolist()
hico_labels_id = var['subclass_name'][hico_id_mask].values
non_hico_seqs = var.index[non_hico_mask].values
non_hico_labels = var['subclass_name'][non_hico_mask].values

# Select the sequence set to run all models on.
if hico_loco_na_flag == 'loco_na':
    curr_seqs, curr_labels = non_hico_seqs, non_hico_labels
else:  # 'hico' -- the only other value allowed by the check above
    curr_seqs, curr_labels = hico_seqs_id, hico_labels_id
full_df = predict_transforna_all_models(sequences=curr_seqs, path_to_models=models_path)
#%%
# Major classes to report, in plotting order.
mcs = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','miRNA','miscRNA','lncRNA','piRNA','YRNA','vtRNA']
# Every model in full_df except 'Baseline', which is excluded from the comparison.
# Sorted to match the order a groupby over 'Model' produces.
reported_models = sorted(model for model in full_df.Model.unique() if model != 'Baseline')

# For each major class: per model, the number of that class's sequences flagged as
# familiar ('Is Familiar?'), expressed relative to the model with the fewest.
num_hicos_per_mc = {}
if hico_loco_na_flag == 'hico':  # ground truth exists (HICO ID): map annotated subclasses
    curr_labels_id_mc = [mapping_dict[label] for label in curr_labels]
elif hico_loco_na_flag == 'loco_na':  # no ground truth (LOCO/NA): map the Ensemble prediction
    ensemble_preds = full_df[full_df.Model == 'Ensemble'].set_index('Sequence').loc[curr_seqs].reset_index()
    curr_labels_id_mc = [mapping_dict[label] for label in ensemble_preds['Net-Label']]

for mc in mcs:
    # Sequences of curr_seqs whose (true or predicted) major class is mc.
    mc_seqs = [seq for seq, label in zip(curr_seqs, curr_labels_id_mc) if label == mc]
    if not mc_seqs:
        # Same keys as the non-empty case (Baseline excluded) so every class has the same models.
        num_hicos_per_mc[mc] = {model: 0 for model in reported_models}
        continue
    mc_full_df = full_df[full_df.Sequence.isin(mc_seqs)]
    # Summing the boolean flag counts familiar sequences per model. Reindexing keeps models
    # with zero familiar sequences (as 0) instead of dropping them, and removes 'Baseline'
    # without raising when it is absent.
    curr_num_hico_per_model = (
        mc_full_df.groupby('Model')['Is Familiar?'].sum()
        .reindex(reported_models, fill_value=0)
    )
    curr_num_hico_per_model -= curr_num_hico_per_model.min()
    num_hicos_per_mc[mc] = curr_num_hico_per_model.to_dict()
#%%
# One row per model, one column per major class. Building the table from every class
# (not only 'rRNA') means a model missing from some class no longer raises a KeyError;
# such gaps are plotted as 0. 'Baseline' is excluded from the comparison.
to_plot_df = pd.DataFrame(num_hicos_per_mc).drop(index='Baseline', errors='ignore').fillna(0)
# NOTE(review): to_plot_mcs is not used below; every class in `mcs` is plotted.
to_plot_mcs = ['rRNA','tRNA','snoRNA']
fig = go.Figure()
# x axis: major classes; one bar per model, grouped per class.
for model, counts in to_plot_df.iterrows():
    fig.add_trace(go.Bar(x=mcs, y=counts[mcs].tolist(), name=model))
fig.update_layout(barmode='group', plot_bgcolor='rgba(0,0,0,0)')
# NOTE(review): the SVG is written before the y axis is switched to log scale, so the saved
# file uses a linear axis while fig.show() displays a log axis -- confirm this is intended.
fig.write_image(f"num_hicos_per_model_{dataset}_{hico_loco_na_flag}.svg")
fig.update_yaxes(type="log")
fig.show()
#%%
import pandas as pd
import glob
from plotly import graph_objects as go
from transforna import load,predict_transforna
# The pattern contains no wildcard, so this matches at most one file.
files = glob.glob('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_files/LC-novel_lev_dist_df.csv')
if not files:
    # Fail clearly here rather than with a KeyError on the column drop below.
    raise FileNotFoundError('No LC-novel_lev_dist_df.csv found in lc_files')
# Read every match and concatenate once (concatenating inside a loop is quadratic).
all_df = pd.concat([pd.read_csv(file) for file in files])
# 'Unnamed: 0' is presumably a saved index column from an earlier to_csv call -- not used here.
all_df = all_df.drop(columns=['Unnamed: 0'], errors='ignore')
all_df['split'] = all_df['split'].fillna('NA')
ensemble_df = all_df[all_df.Model == 'Ensemble']
# %%
# Per-sequence LC annotation, indexed by sequence.
lc_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'
lc_df = load(lc_path).set_index('sequence')
# %%
# Restrict the LC annotation to the sequences in ensemble_df, in ensemble_df order.
lc_df = lc_df.loc[ensemble_df.Sequence]
# Annotated major class per sequence.
actual_major_classes = lc_df['small_RNA_class_annotation']
# Ensemble-predicted subclass mapped to its major class. lc_df rows are already in
# ensemble_df order, so the predictions align without a second lookup; re-indexing by
# lc_df.index would multiply rows whenever a sequence appears more than once.
predicted_major_classes = ensemble_df.set_index('Sequence')['Net-Label'].map(mapping_dict)
# %%
# Plot the row-normalized confusion matrix between actual and predicted major classes.
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
# Sorted so the class order (and hence the plot layout) is reproducible across runs;
# key=str keeps the sort from failing on a non-string label.
major_classes = sorted(
    set(predicted_major_classes.unique()) | set(actual_major_classes.unique()),
    key=str,
)
conf_matrix = confusion_matrix(actual_major_classes, predicted_major_classes, labels=major_classes)
# Normalize each row (actual class) by its own total. keepdims turns the row sums into a
# column vector so they broadcast across columns; a 1-D sum would divide column j by
# row j's total instead. Rows with no samples stay 0 rather than becoming NaN.
row_totals = conf_matrix.sum(axis=1, keepdims=True)
conf_matrix = np.divide(
    conf_matrix,
    row_totals,
    out=np.zeros(conf_matrix.shape, dtype=float),
    where=row_totals != 0,
)
# Round before drawing so the annotations show one decimal.
conf_matrix = conf_matrix.round(1)
# Passing the labels to heatmap places the ticks at cell centers.
sns.heatmap(conf_matrix, annot=True, cmap='Blues', xticklabels=major_classes, yticklabels=major_classes)
plt.xlabel('Predicted Major Class')
plt.ylabel('Actual Major Class')
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.show()
# %%