|
|
|
|
|
|
|
from transforna import load |
|
from transforna import predict_transforna,predict_transforna_all_models |
|
import pandas as pd |
|
import plotly.graph_objects as go |
|
import numpy as np |
|
import plotly.io as pio |
|
import plotly.express as px |
|
# ---------------------------------------------------------------------------
# Configuration: input locations and the dataset / subset to analyse.
# ---------------------------------------------------------------------------

# JSON file (subclass_to_annotation.json) loaded into `mapping_dict` below;
# later code indexes it with label names (`mapping_dict[label]`).
mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/subclass_to_annotation.json'

# Directory handed to the transforna prediction helpers as `path_to_models`.
models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'

mapping_dict = load(mapping_dict_path)

# `dataset` selects the annotation table; `hico_loco_na_flag` selects which
# subset of its sequences is analysed further down.
dataset: str = 'LC'
hico_loco_na_flag: str = 'hico'

# Validate with an explicit exception rather than `assert`: asserts are removed
# under `python -O`, in which case an invalid flag would only surface later as
# a confusing NameError on variables that were never assigned.
if hico_loco_na_flag not in ('hico', 'loco_na'):
    raise ValueError('hico_loco_na_flag must be either hico or loco_na')

if dataset == 'TCGA':
    dataset_path_train: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv'
else:
    # Any value other than 'TCGA' falls back to the LC annotation table.
    dataset_path_train: str = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'
|
|
|
# Run one throw-away prediction; only the column names of the returned frame
# are used, as the list of sub-classes to filter the annotations against.
prediction_single_pd = predict_transforna(
    ['AAAAAAACCCCCTTTTTTT'],
    model='Seq',
    logits_flag=True,
    trained_on='id',
    path_to_models=models_path,
)
sub_classes_used_for_training = list(prediction_single_pd.columns)

# Annotation table indexed by sequence, restricted to sequences of length <= 30.
var = load(dataset_path_train).set_index('sequence')
var = var.loc[var.index.str.len() <= 30]

# Row masks, each computed once and reused for both the sequences and labels.
subclass_names = var['subclass_name']
is_hico = var['hico']
is_hico_id = var.hico & var.subclass_name.isin(sub_classes_used_for_training)
is_non_hico = is_hico.eq(False)

# All rows flagged `hico`.
hico_seqs_all = var.index[is_hico].tolist()
hico_labels_all = subclass_names[is_hico].values

# `hico` rows whose sub-class is one of the prediction columns above.
hico_seqs_id = var.index[is_hico_id].tolist()
hico_labels_id = subclass_names[is_hico_id].values

# Rows where `hico` is False.
non_hico_subclasses = subclass_names[is_non_hico]
non_hico_seqs = non_hico_subclasses.index.values
non_hico_labels = non_hico_subclasses.values
|
|
|
|
|
|
|
# Choose the sequences (and their annotated labels) for the configured subset.
if hico_loco_na_flag == 'hico':
    curr_seqs, curr_labels = hico_seqs_id, hico_labels_id
elif hico_loco_na_flag == 'loco_na':
    curr_seqs, curr_labels = non_hico_seqs, non_hico_labels

# Predictions of every available model for the selected sequences.
full_df = predict_transforna_all_models(
    sequences=curr_seqs,
    path_to_models=models_path,
)
|
|
|
|
|
|
|
# Major classes that the per-class statistics below are reported for.
mcs = [
    'rRNA',
    'tRNA',
    'snoRNA',
    'protein_coding',
    'snRNA',
    'miRNA',
    'miscRNA',
    'lncRNA',
    'piRNA',
    'YRNA',
    'vtRNA',
]

# Filled by the per-class loop: {major class: {model: count}}.
num_hicos_per_mc = {}

# Translate one label per sequence through `mapping_dict`, in `curr_seqs` order.
if hico_loco_na_flag == 'loco_na':
    # Use the Ensemble model's 'Net-Label' for each sequence as the label.
    is_ensemble_row = full_df.Model == 'Ensemble'
    ensemble_preds = full_df[is_ensemble_row].set_index('Sequence').loc[curr_seqs].reset_index()
    curr_labels_id_mc = [mapping_dict[net_label] for net_label in ensemble_preds['Net-Label']]
elif hico_loco_na_flag == 'hico':
    # Use the annotated labels selected together with `curr_seqs`.
    curr_labels_id_mc = [mapping_dict[annotated] for annotated in curr_labels]
|
|
|
# For every major class, count per model how many of the class's sequences are
# flagged in the 'Is Familiar?' column, then express each count relative to
# the weakest model (the minimum is subtracted).
#
# The 'Baseline' model is excluded everywhere and every remaining model always
# gets an entry (0 when it has no flagged rows). This keeps the key set
# identical across classes, avoids a KeyError when 'Baseline' has no flagged
# rows for a class, and lets zero-count models take part in the minimum.
models_without_baseline = [model for model in full_df.Model.unique() if model != 'Baseline']

for mc in mcs:
    mc_seqs = [seq for seq, label in zip(curr_seqs, curr_labels_id_mc) if label == mc]
    if not mc_seqs:
        # No sequence of this class in the current subset.
        num_hicos_per_mc[mc] = {model: 0 for model in models_without_baseline}
        continue

    mc_full_df = full_df[full_df.Sequence.isin(mc_seqs)]
    familiar_rows = mc_full_df[mc_full_df['Is Familiar?']]

    # Rows per model; reindex so that models without any flagged row count as 0.
    curr_num_hico_per_model = (
        familiar_rows.groupby('Model').size().reindex(models_without_baseline, fill_value=0)
    )
    curr_num_hico_per_model -= curr_num_hico_per_model.min()
    num_hicos_per_mc[mc] = curr_num_hico_per_model.to_dict()
|
|
|
# Grouped bar chart: one trace per model, one bar group per major class.

# Tabular view of the counts and a shortlist of classes; neither is referenced
# by the plotting code in this section, they are kept for inspection.
to_plot_df = pd.DataFrame(num_hicos_per_mc)
to_plot_mcs = ['rRNA','tRNA','snoRNA']

fig = go.Figure()

# Collect the models from every class (first-seen order) instead of relying on
# the 'rRNA' entry alone: the per-class dicts do not necessarily share the same
# keys. 'Baseline' is skipped because the counting loop drops it from the counts.
models_to_plot = {}
for mc in mcs:
    for model in num_hicos_per_mc[mc]:
        if model != 'Baseline':
            models_to_plot[model] = None

for model in models_to_plot:
    # A model missing for a class is plotted as 0 rather than raising KeyError.
    counts = [num_hicos_per_mc[mc].get(model, 0) for mc in mcs]
    fig.add_trace(go.Bar(x=mcs, y=counts, name=model))

fig.update_layout(barmode='group', plot_bgcolor='rgba(0,0,0,0)')

# NOTE(review): the SVG is written before the y axis is switched to log scale,
# so the saved file uses a linear axis while the displayed figure uses a log
# axis. Order preserved from the original; confirm this is intended.
fig.write_image(f"num_hicos_per_model_{dataset}_{hico_loco_na_flag}.svg")
fig.update_yaxes(type="log")
fig.show()
|
|
|
|
|
|
|
import pandas as pd |
|
import glob |
|
from plotly import graph_objects as go |
|
from transforna import load,predict_transforna |
|
# Load the Levenshtein-distance result table(s) and keep the Ensemble rows.
lev_dist_pattern = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_files/LC-novel_lev_dist_df.csv'
files = glob.glob(lev_dist_pattern)
if not files:
    # Fail with a clear message instead of a KeyError on the column drop below.
    raise FileNotFoundError(f'No files matched {lev_dist_pattern!r}')

# Read everything first and concatenate once; growing a frame inside the loop
# re-copies all rows on every iteration and starts from an empty DataFrame.
all_df = pd.concat([pd.read_csv(file) for file in files])

# 'Unnamed: 0' is dropped as in the original (presumably the index column
# written by `to_csv` — confirm against the script that produced the files).
all_df = all_df.drop(columns=['Unnamed: 0'])

# Replace missing `split` values with the literal string 'NA'. Bracket access
# is used so the column can never be shadowed by a DataFrame attribute.
all_df.loc[all_df['split'].isnull(), 'split'] = 'NA'

ensemble_df = all_df[all_df['Model'] == 'Ensemble']
|
|
|
|
|
# LC annotation table, indexed by sequence.
lc_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'
lc_df = load(lc_path).set_index('sequence')

# Keep only the sequences present in the Ensemble predictions, in that order.
lc_df = lc_df.loc[ensemble_df['Sequence']]

# Annotated class per sequence.
actual_major_classes = lc_df['small_RNA_class_annotation']

# Ensemble 'Net-Label' per sequence, ordered like `lc_df`, then translated
# through `mapping_dict`.
net_label_by_sequence = ensemble_df[['Net-Label', 'Sequence']].set_index('Sequence')['Net-Label']
predicted_major_classes = net_label_by_sequence.loc[lc_df.index].map(mapping_dict)
|
|
|
|
|
from sklearn.metrics import confusion_matrix |
|
import seaborn as sns |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
# Row-normalised confusion matrix of annotated vs. predicted major classes,
# drawn as an annotated heatmap.

# Union of the classes seen on either side. Sorted so that the row/column order
# is the same on every run (iteration order of a set of strings is not stable
# across interpreter runs); `key=str` keeps the sort usable should a
# non-string value such as NaN be present.
major_classes = sorted(
    set(predicted_major_classes.unique()) | set(actual_major_classes.unique()),
    key=str,
)

conf_matrix = confusion_matrix(actual_major_classes, predicted_major_classes, labels=major_classes)

# Normalise each row by its own total. `keepdims=True` is essential: without it
# the (n,) vector of row sums broadcasts across columns, i.e. column j would be
# divided by the total of row j. Rows without any sample stay at 0 instead of
# triggering a division by zero.
row_totals = conf_matrix.sum(axis=1, keepdims=True)
conf_matrix = np.divide(
    conf_matrix,
    row_totals,
    out=np.zeros(conf_matrix.shape, dtype=float),
    where=row_totals != 0,
)

# Round before plotting so that the annotations show the rounded values.
conf_matrix = np.round(conf_matrix, 1)

# Let seaborn place the class names itself: heatmap cells are centred at
# i + 0.5, so ticks set manually at integer positions land on cell edges.
sns.heatmap(
    conf_matrix,
    annot=True,
    cmap='Blues',
    xticklabels=major_classes,
    yticklabels=major_classes,
)

plt.xlabel('Predicted Major Class')
plt.ylabel('Actual Major Class')
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.show()
|
|
|
|
|
|