{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TransfoRNA ID models: post-training comparison\n",
    "\n",
    "Compares the five ID models (`Baseline`, `Seq`, `Seq-Seq`, `Seq-Struct`, `Seq-Rev`) on:\n",
    "\n",
    "1. the `B. Acc` column of each model's `summary_pd.tsv`,\n",
    "2. sub-class confusion matrices collapsed to major classes,\n",
    "3. agreement of predictions on the un-annotated sequences,\n",
    "4. UMAP projections of the training embeddings,\n",
    "5. the Levenshtein distances returned by `get_closest_ngbr_per_split` for each split.\n",
    "\n",
    "Run with **Restart Kernel -> Run All**: every import, path and constant lives in the first two code cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import seaborn as sns\n",
    "import umap\n",
    "\n",
    "from transforna import Results_Handler, fold_sequences, get_closest_ngbr_per_split\n",
    "from transforna.src.utils.tcga_post_analysis_utils import correct_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration. The paths are absolute to the original analysis machine; edit them here to run elsewhere.\n",
    "models = ['Baseline','Seq','Seq-Seq','Seq-Struct','Seq-Rev']\n",
    "models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_ID'\n",
    "MAPPING_PATH = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json'\n",
    "FIGURES_DIR = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/figures'\n",
    "\n",
    "# Each sub_class model directory is read as confusion_matrix_0.csv ... confusion_matrix_4.csv.\n",
    "N_CONFUSION_MATRICES = 5\n",
    "# Fixed seed so the UMAP layouts are reproducible (umap-learn runs single-threaded when seeded).\n",
    "UMAP_SEED = 42"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarize_b_acc(values):\n",
    "    \"\"\"Return '<mean*100>+/- (<std*100>)' for an iterable of accuracy fractions.\"\"\"\n",
    "    values = pd.Series(values)\n",
    "    return str(values.mean()*100)+'+/-'+' ('+str(values.std()*100)+')'\n",
    "\n",
    "\n",
    "def major_class_accuracy(confusion_matrix, mapping):\n",
    "    \"\"\"Collapse a sub-class confusion matrix to major classes and return trace / total.\n",
    "\n",
    "    `mapping` maps sub-class labels (row and column labels of `confusion_matrix`)\n",
    "    to major-class labels.\n",
    "\n",
    "    NOTE(review): trace / total is overall accuracy, not a per-class balanced\n",
    "    accuracy, although the result is stored under a `b_acc` name -- confirm intent.\n",
    "    \"\"\"\n",
    "    collapsed = confusion_matrix.copy()\n",
    "    collapsed.index = collapsed.index.map(mapping)\n",
    "    collapsed.columns = collapsed.columns.map(mapping)\n",
    "    # Sum rows, then columns, that share a major class. The double transpose\n",
    "    # replaces groupby(..., axis=1), which pandas deprecated in 2.1.\n",
    "    collapsed = collapsed.groupby(level=0).sum().T.groupby(level=0).sum().T\n",
    "    # Put rows and columns on one shared class order so the diagonal is aligned\n",
    "    # even when a class is missing from one axis (no-op when both axes agree).\n",
    "    classes = collapsed.index.union(collapsed.columns)\n",
    "    collapsed = collapsed.reindex(index=classes, columns=classes, fill_value=0)\n",
    "    return collapsed.values.diagonal().sum()/collapsed.values.sum()\n",
    "\n",
    "\n",
    "def read_train_embedds(level, model):\n",
    "    \"\"\"Read `train_embedds.tsv` (two header rows) of `model` for `level` ('sub_class' or 'major_class').\"\"\"\n",
    "    return pd.read_csv(f'{models_path}/{level}/{model}/embedds/train_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n",
    "\n",
    "\n",
    "def plot_umap(coords, color, hover_data, title, color_map, out_stem):\n",
    "    \"\"\"Scatter 2-D UMAP coordinates coloured by `color`, write `<out_stem>.svg` and `.png`, then show.\"\"\"\n",
    "    fig = px.scatter(x=coords[:,0],y=coords[:,1],color=color,labels={'color':'Major Class'},title=title,\n",
    "                     width=800,height=800,hover_data=hover_data,color_discrete_map=color_map)\n",
    "    fig.update_traces(marker=dict(size=1))\n",
    "    # transparent plot background\n",
    "    fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
    "    fig.write_image(out_stem+'.svg')\n",
    "    fig.write_image(out_stem+'.png')\n",
    "    fig.show()\n",
    "    return fig\n",
    "\n",
    "\n",
    "def plot_sec_struct_umap(level, coords, major_class_labels, out_name, model='Seq-Struct'):\n",
    "    \"\"\"Colour the UMAP of `model` by the fraction of non-'.' characters in each folded structure.\n",
    "\n",
    "    `coords` and `major_class_labels` must come from the same `level` as the\n",
    "    embeddings file read here. The figure is shown and written to\n",
    "    `<models_path>/<level>/<model>/figures/<out_name>`.\n",
    "    \"\"\"\n",
    "    df = read_train_embedds(level, model)\n",
    "    sec_struct = fold_sequences(df['RNA Sequences']['0'])['structure_37']\n",
    "    # ratio = number of non '.' characters divided by the length of the structure string\n",
    "    sec_struct_ratio = sec_struct.apply(lambda x: (len(x)-x.count('.'))/len(x))\n",
    "    fig = px.scatter(x=coords[:,0],y=coords[:,1],color=sec_struct_ratio,labels={'color':'Base Pairing'},title=model,\n",
    "                     width=800,height=800,hover_data={'Major Class':major_class_labels},\n",
    "                     color_continuous_scale='Viridis',range_color=[0,1])\n",
    "    fig.update_traces(marker=dict(size=3))\n",
    "    # transparent plot background\n",
    "    fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
    "    fig.show()\n",
    "    fig.write_image(f'{models_path}/{level}/{model}/figures/{out_name}')\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Balanced accuracy per model\n",
    "\n",
    "Mean and standard deviation (in percent) of the `B. Acc` column of each model's `summary_pd.tsv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = {'major_class':{},'sub_class':{}}\n",
    "for level in scores:\n",
    "    for model in models:\n",
    "        summary_pd = pd.read_csv(f'{models_path}/{level}/{model}/summary_pd.tsv',sep='\\t')\n",
    "        scores[level][model] = summarize_b_acc(summary_pd['B. Acc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Baseline': '52.83789870060305+/- (1.0961119898709506)',\n",
       " 'Seq': '97.70018230805728+/- (0.3819207447704567)',\n",
       " 'Seq-Seq': '95.65091330992355+/- (0.4963151975035616)',\n",
       " 'Seq-Struct': '97.71071590680333+/- (0.6173598637101496)',\n",
       " 'Seq-Rev': '97.51224133899979+/- (0.3418133671042992)'}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores['sub_class']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sub-class predictions scored at major-class level\n",
    "\n",
    "Each sub-class confusion matrix is collapsed to major classes through `mapping_dict` before scoring."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(MAPPING_PATH) as f:\n",
    "    mapping_dict = json.load(f)\n",
    "\n",
    "b_acc_sc_to_mc = {}\n",
    "for model in models:\n",
    "    b_acc = []\n",
    "    for idx in range(N_CONFUSION_MATRICES):\n",
    "        confusion_matrix = pd.read_csv(f'{models_path}/sub_class/{model}/embedds/confusion_matrix_{idx}.csv',sep=',',index_col=0)\n",
    "        b_acc.append(major_class_accuracy(confusion_matrix, mapping_dict))\n",
    "    b_acc_sc_to_mc[model] = summarize_b_acc(b_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Baseline': '89.6182558114013+/- (0.6372156071358975)',\n",
       " 'Seq': '99.66714304286457+/- (0.1404591049684126)',\n",
       " 'Seq-Seq': '99.40702944026852+/- (0.18268320317601783)',\n",
       " 'Seq-Struct': '99.77114728744993+/- (0.06976258667467564)',\n",
       " 'Seq-Rev': '99.70878801385821+/- (0.11954774341354062)'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b_acc_sc_to_mc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Agreement between models on un-annotated sequences\n",
    "\n",
    "For every pair of models, the fraction of un-annotated sequences on which the (label-corrected) predictions are equal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "no_annotation_predictions = {}\n",
    "for model in models:\n",
    "    # the file has a two-level column header (multiindex)\n",
    "    embedds = pd.read_csv(f'{models_path}/sub_class/{model}/embedds/no_annotation_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n",
    "    embedds = embedds.set_index([('RNA Sequences','0')])\n",
    "    embedds.index.name = 'RNA Sequences'\n",
    "    # predicted label = column holding the largest logit\n",
    "    no_annotation_predictions[model] = embedds['Logits'].idxmax(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# float dtype so the matrix is numeric when handed to px.imshow\n",
    "correlation = pd.DataFrame(index=models,columns=models,dtype=float)\n",
    "for model1 in models:\n",
    "    for model2 in models:\n",
    "        model1_predictions = correct_labels(no_annotation_predictions[model1],no_annotation_predictions[model2],mapping_dict)\n",
    "        is_equal = model1_predictions == no_annotation_predictions[model2].values\n",
    "        correlation.loc[model1,model2] = is_equal.sum()/len(is_equal)\n",
    "\n",
    "font_size = 20\n",
    "fig = px.imshow(correlation, color_continuous_scale='Blues')\n",
    "# annotate every cell; the diagonal gets white text, all other cells black\n",
    "for i in range(len(models)):\n",
    "    for j in range(len(models)):\n",
    "        font = dict(color='black' if i != j else 'white', size=font_size)\n",
    "        fig.add_annotation(\n",
    "            x=j, y=i,\n",
    "            text=str(round(correlation.iloc[i,j],2)),\n",
    "            showarrow=False,\n",
    "            font=font\n",
    "        )\n",
    "\n",
    "fig.update_layout(width=800, height=800)\n",
    "fig.update_layout(title='Correlation between models for each sub_class model')\n",
    "fig.update_xaxes(title_text='Models', tickfont=dict(size=font_size))\n",
    "fig.update_yaxes(title_text='Models', tickfont=dict(size=font_size))\n",
    "fig.show()\n",
    "fig.write_image(FIGURES_DIR+'/correlation_id_models_sub_class.png')\n",
    "fig.write_image(FIGURES_DIR+'/correlation_id_models_sub_class.svg')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## UMAP of the training embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sc_embedds = {}\n",
    "mc_embedds = {}\n",
    "sc_to_mc_labels = {}\n",
    "sc_labels = {}\n",
    "mc_labels = {}\n",
    "for model in models:\n",
    "    sc_df = read_train_embedds('sub_class', model)\n",
    "    sc_embedds[model] = sc_df['RNA Embedds'].values\n",
    "    sc_labels[model] = sc_df['Labels']['0']\n",
    "    sc_to_mc_labels[model] = sc_labels[model].map(mapping_dict).values\n",
    "\n",
    "    mc_df = read_train_embedds('major_class', model)\n",
    "    mc_embedds[model] = mc_df['RNA Embedds'].values\n",
    "    mc_labels[model] = mc_df['Labels']['0']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute umap coordinates\n",
    "sc_umap_coords = {}\n",
    "mc_umap_coords = {}\n",
    "for model in models:\n",
    "    sc_umap_coords[model] = umap.UMAP(n_neighbors=5, min_dist=0.3, n_components=2, metric='euclidean', random_state=UMAP_SEED).fit_transform(sc_embedds[model])\n",
    "    mc_umap_coords[model] = umap.UMAP(n_neighbors=5, min_dist=0.3, n_components=2, metric='euclidean', random_state=UMAP_SEED).fit_transform(mc_embedds[model])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mcs = np.unique(sc_to_mc_labels[models[0]])\n",
    "colors = px.colors.qualitative.Plotly\n",
    "# NOTE(review): zip stops at the shorter input, so major classes beyond len(colors)\n",
    "# get no fixed colour here -- confirm the number of major classes.\n",
    "color_mapping = dict(zip(mcs,colors))\n",
    "for model in models:\n",
    "    # sub-class model: colour by the major class each sub-class maps to\n",
    "    fig = plot_umap(sc_umap_coords[model], sc_to_mc_labels[model],\n",
    "                    {'Major Class':sc_to_mc_labels[model],'Sub Class':sc_labels[model]},\n",
    "                    model, color_mapping, models_path+'/sub_class/'+model+'/figures/sc_umap')\n",
    "\n",
    "    # major-class model\n",
    "    fig = plot_umap(mc_umap_coords[model], mc_labels[model],\n",
    "                    {'Major Class':mc_labels[model]},\n",
    "                    model, color_mapping, models_path+'/major_class/'+model+'/figures/mc_umap')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Seq-Struct UMAP coloured by base pairing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plot_sec_struct_umap('major_class', mc_umap_coords['Seq-Struct'], mc_labels['Seq-Struct'], 'mc_umap_sec_struct.svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# labels come from the sub_class embeddings so they line up with sc_umap_coords\n",
    "fig = plot_sec_struct_umap('sub_class', sc_umap_coords['Seq-Struct'], sc_to_mc_labels['Seq-Struct'], 'sc_umap_sec_struct.svg')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Levenshtein distance per split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "splits = ['train','valid','test','ood','artificial','no_annotation']\n",
    "splits_to_plot = ['test','ood','random','recombined','artificial_affix']\n",
    "renaming_dict= {'test':'ID (test)','ood':'Rare sub-classes','random':'Random','artificial_affix':'Putative 5\\'-adapter prefixes','recombined':'Recombined'}\n",
    "\n",
    "# collect one frame per (model, split) and concatenate once instead of growing a frame in the loop\n",
    "lev_dist_frames = []\n",
    "for model in models:\n",
    "    results = Results_Handler(models_path+f'/sub_class/{model}/embedds',splits=splits,read_dataset=True)\n",
    "    results.append_loco_variants()\n",
    "    results.get_knn_model()\n",
    "\n",
    "    for split in splits_to_plot:\n",
    "        split_seqs,split_labels,top_n_seqs,top_n_labels,distances,lev_dist = get_closest_ngbr_per_split(results,split)\n",
    "        lev_dist_frames.append(pd.DataFrame({\n",
    "            'split':renaming_dict[split],\n",
    "            'lev_dist':lev_dist,\n",
    "            'seqs':split_seqs,\n",
    "            'labels':split_labels,\n",
    "            'top_n_seqs':top_n_seqs,\n",
    "            'top_n_labels':top_n_labels,\n",
    "            'model':model,\n",
    "        }))\n",
    "\n",
    "# ignore_index gives unique row labels; the per-split frames all start at 0\n",
    "lev_dist_df = pd.concat(lev_dist_frames,axis=0,ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# distribution of lev_dist for each split for each model\n",
    "model_thresholds = {'Baseline':0.267,'Seq':0.246,'Seq-Seq':0.272,'Seq-Struct': 0.242,'Seq-Rev':0.237}\n",
    "model_aucs = {'Baseline':0.76,'Seq':0.97,'Seq-Seq':0.96,'Seq-Struct': 0.97,'Seq-Rev':0.97}\n",
    "\n",
    "# one call: separate sns.set(...) calls each re-apply the default style and would undo 'whitegrid'\n",
    "sns.set_theme(style='whitegrid', font_scale=1.5, rc={'figure.figsize':(15,10)})\n",
    "fig, ax = plt.subplots()\n",
    "sns.boxplot(x='model', y='lev_dist', hue='split', data=lev_dist_df, palette='Set3', order=models, showfliers=True, ax=ax)\n",
    "ax.set_facecolor('None')\n",
    "ax.set_title('Levenshtein Distance Distribution per Model on ID')\n",
    "ax.set(xlabel='Model', ylabel='Normalized Levenshtein Distance')\n",
    "# legend outside the axes, with a transparent background\n",
    "ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.,facecolor=None,framealpha=0.0)\n",
    "\n",
    "# one dashed threshold line per model, limited to that model's horizontal slot\n",
    "# (xmin/xmax are fractions of the axes width)\n",
    "slot_width = 1/len(models)\n",
    "for position, model in enumerate(models):\n",
    "    ax.axhline(y=model_thresholds[model], color='g', linestyle='--', xmin=position*slot_width, xmax=(position+1)*slot_width)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "transforna",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}